folly/folly/io/async/test/SSLSessionTest.cpp

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <folly/ssl/SSLSession.h>

#include <memory>

#include <folly/io/async/test/AsyncSSLSocketTest.h>
#include <folly/net/NetOps.h>
#include <folly/net/NetworkSocket.h>
#include <folly/portability/GTest.h>
#include <folly/portability/OpenSSL.h>
#include <folly/portability/Sockets.h>
#include <folly/ssl/detail/OpenSSLSession.h>
#include <folly/testing/TestUtil.h>

using folly::ssl::SSLSession;
using folly::ssl::detail::OpenSSLSession;

using namespace folly;
using namespace folly::test;

class SSLSessionTest : public testing::Test {
 public:
  void SetUp() override {
    clientCtx_.reset(new folly::SSLContext());
    dfServerCtx_.reset(new folly::SSLContext());
    hskServerCtx_.reset(new folly::SSLContext());
    serverName_ = "xyz.newdev.facebook.com";
    getctx(clientCtx_, dfServerCtx_);
  }

  void TearDown() override {}

  void getfds(NetworkSocket fds[2]) {
    if (netops::socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) {
      FAIL() << "failed to create socketpair: " << errnoStr(errno);
    }
    for (int idx = 0; idx < 2; ++idx) {
      if (netops::set_socket_non_blocking(fds[idx]) != 0) {
        FAIL() << "failed to put socket " << idx
               << " in non-blocking mode: " << errnoStr(errno);
      }
    }
  }

  void getctx(
      std::shared_ptr<folly::SSLContext> clientCtx,
      std::shared_ptr<folly::SSLContext> serverCtx) {
    clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
    clientCtx->loadTrustedCertificates(find_resource(kTestCA).string().c_str());

    serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
    serverCtx->loadCertificate(find_resource(kTestCert).string().c_str());
    serverCtx->loadPrivateKey(find_resource(kTestKey).string().c_str());
  }

  folly::EventBase eventBase_;
  std::shared_ptr<SSLContext> clientCtx_;
  std::shared_ptr<SSLContext> dfServerCtx_;
  // Use the same SSLContext to continue the handshake after
  // tlsext_hostname match.
  std::shared_ptr<SSLContext> hskServerCtx_;
  std::string serverName_;
};

// TLS 1.2 and TLS 1.3 deliver session tickets in different ways, but we can use
// SSLContext::SessionLifecycleCallbacks to receive them in a similar manner so
// tests can work regardless of version.
class SimpleSessionLifecycleCallback
    : public SSLContext::SessionLifecycleCallbacks {
 public:
  void onNewSession(SSL*, ssl::SSLSessionUniquePtr session) override {
    // This can be called multiple times. OpenSSL sends two session tickets by
    // default). Grab the last one.
    session_ = std::move(session);
    ASSERT_TRUE(socket_ != nullptr);
    // At this point we have what we need to resume a session. Detach the
    // ReadCallback, allowing the socket's EventBase to stop looping.
    socket_->setReadCB(nullptr);
  }

  // set when session is available
  ssl::SSLSessionUniquePtr session_;
  // set after object construction
  folly::AsyncSSLSocket* socket_;
};

class SimpleReadCallback : public AsyncTransport::ReadCallback {
 public:
  void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
    *bufReturn = buffer_;
    *lenReturn = sizeof(buffer_);
  }

  void readDataAvailable(size_t) noexcept override {
    // this callback should only be used to read session tickets, which
    // aren't delivered to callbacks
    FAIL();
  }

  void readEOF() noexcept override { FAIL(); }

  void readErr(const AsyncSocketException& ex) noexcept override {
    FAIL() << ex;
  }

  char buffer_[1024];
};

TEST_F(SSLSessionTest, BasicTest) {
  ssl::SSLSessionUniquePtr sslSession;
  // Full handshake
  {
    NetworkSocket fds[2];
    getfds(fds);
    auto sessionCb = std::make_unique<SimpleSessionLifecycleCallback>();
    auto sessionCbPtr = sessionCb.get();
    clientCtx_->setSessionLifecycleCallbacks(std::move(sessionCb));

    AsyncSSLSocket::UniquePtr clientSock(
        new AsyncSSLSocket(clientCtx_, &eventBase_, fds[0], serverName_));
    auto clientPtr = clientSock.get();

    AsyncSSLSocket::UniquePtr serverSock(
        new AsyncSSLSocket(dfServerCtx_, &eventBase_, fds[1], true));
    SSLHandshakeClient client(std::move(clientSock), false, false);
    SSLHandshakeServerParseClientHello server(
        std::move(serverSock), false, false);
    sessionCbPtr->socket_ = clientPtr;
    SimpleReadCallback readCb;
    // register read callback to read incoming session tickets (for TLS 1.3)
    clientPtr->setReadCB(&readCb);
    // should stop when the session ticket is received
    eventBase_.loop();
    ASSERT_TRUE(client.handshakeSuccess_);
    sslSession = std::move(sessionCbPtr->session_);
    ASSERT_TRUE(sslSession != nullptr);
    ASSERT_FALSE(clientPtr->getSSLSessionReused());
  }

  // Session resumption
  {
    NetworkSocket fds[2];
    getfds(fds);
    AsyncSSLSocket::UniquePtr clientSock(
        new AsyncSSLSocket(clientCtx_, &eventBase_, fds[0], serverName_));
    auto clientPtr = clientSock.get();

    clientPtr->setRawSSLSession(std::move(sslSession));

    AsyncSSLSocket::UniquePtr serverSock(
        new AsyncSSLSocket(dfServerCtx_, &eventBase_, fds[1], true));
    SSLHandshakeClient client(std::move(clientSock), false, false);
    SSLHandshakeServerParseClientHello server(
        std::move(serverSock), false, false);

    eventBase_.loop();
    ASSERT_TRUE(client.handshakeSuccess_);
    ASSERT_TRUE(clientPtr->getSSLSessionReused());
  }
}

TEST_F(SSLSessionTest, NullSessionResumptionTest) {
  // Set null session, should result in full handshake
  {
    NetworkSocket fds[2];
    getfds(fds);
    AsyncSSLSocket::UniquePtr clientSock(
        new AsyncSSLSocket(clientCtx_, &eventBase_, fds[0], serverName_));
    auto clientPtr = clientSock.get();

    clientPtr->setSSLSession(nullptr);

    AsyncSSLSocket::UniquePtr serverSock(
        new AsyncSSLSocket(dfServerCtx_, &eventBase_, fds[1], true));
    SSLHandshakeClient client(std::move(clientSock), false, false);
    SSLHandshakeServerParseClientHello server(
        std::move(serverSock), false, false);

    eventBase_.loop();
    ASSERT_TRUE(client.handshakeSuccess_);
    ASSERT_FALSE(clientPtr->getSSLSessionReused());
  }
}