folly/folly/io/async/test/AsyncSSLSocketTest.cpp

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <folly/io/async/test/AsyncSSLSocketTest.h>

#include <fcntl.h>
#include <signal.h>
#include <sys/types.h>

#include <openssl/async.h>

#include <fstream>
#include <iostream>
#include <list>
#include <set>
#include <thread>

#include <folly/SocketAddress.h>
#include <folly/String.h>
#include <folly/io/Cursor.h>
#include <folly/io/async/AsyncPipe.h>
#include <folly/io/async/AsyncSSLSocket.h>
#include <folly/io/async/EventBase.h>
#include <folly/io/async/EventBaseThread.h>
#include <folly/io/async/SSLOptions.h>
#include <folly/io/async/ScopedEventBaseThread.h>
#include <folly/io/async/ssl/BasicTransportCertificate.h>
#include <folly/io/async/ssl/OpenSSLTransportCertificate.h>
#include <folly/io/async/test/BlockingSocket.h>
#include <folly/io/async/test/MockAsyncSocketLegacyObserver.h>
#include <folly/io/async/test/TFOUtil.h>
#include <folly/io/async/test/TestSSLServer.h>
#include <folly/net/NetOps.h>
#include <folly/net/NetworkSocket.h>
#include <folly/net/test/MockNetOpsDispatcher.h>
#include <folly/portability/GMock.h>
#include <folly/portability/GTest.h>
#include <folly/portability/OpenSSL.h>
#include <folly/portability/Unistd.h>
#include <folly/testing/TestUtil.h>

#ifdef __linux__
#include <dlfcn.h>
#endif

using std::cerr;
using std::endl;
using std::string;

using namespace folly;
using namespace folly::test;
using namespace testing;

namespace {

#if defined __linux__
// to store libc's original setsockopt()
typedef int (*setsockopt_ptr)(int, int, int, const void*, socklen_t);
setsockopt_ptr real_setsockopt_ = nullptr;

// global struct to initialize before main runs. we can init within a test,
// or in main, but this method seems to be least intrsive and universal
struct GlobalStatic {
  GlobalStatic() {
    real_setsockopt_ = (setsockopt_ptr)dlsym(RTLD_NEXT, "setsockopt");
  }
  void reset() noexcept { ttlsDisabledSet.clear(); }
  // for each fd, tracks whether TTLS is disabled or not
  std::unordered_set<folly::NetworkSocket /* fd */> ttlsDisabledSet;
};

// the constructor will be called before main() which is all we care about
GlobalStatic globalStatic;

} // namespace

// we intercept setsoctopt to test setting NO_TRANSPARENT_TLS opt
// this name has to be global
int setsockopt(
    int sockfd, int level, int optname, const void* optval, socklen_t optlen) {
  if (optname == SO_NO_TRANSPARENT_TLS) {
    globalStatic.ttlsDisabledSet.insert(folly::NetworkSocket::fromFd(sockfd));
    return 0;
  }
  return real_setsockopt_(sockfd, level, optname, optval, optlen);
}
#endif

namespace {

void getfds(NetworkSocket fds[2]) {
  if (netops::socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) {
    FAIL() << "failed to create socketpair: " << errnoStr(errno);
  }
  for (int idx = 0; idx < 2; ++idx) {
    if (netops::set_socket_non_blocking(fds[idx]) != 0) {
      FAIL() << "failed to put socket " << idx
             << " in non-blocking mode: " << errnoStr(errno);
    }
  }
}

void getctx(
    std::shared_ptr<folly::SSLContext> clientCtx,
    std::shared_ptr<folly::SSLContext> serverCtx) {
  clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");

  serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
  serverCtx->loadCertificate(find_resource(kTestCert).c_str());
  serverCtx->loadPrivateKey(find_resource(kTestKey).c_str());
}

std::string getFileAsBuf(const char* fileName) {
  std::string buffer;
  folly::readFile(find_resource(fileName).c_str(), buffer);
  return buffer;
}

folly::ssl::X509UniquePtr readCertFromFile(const std::string& filename) {
  auto path = find_resource(filename);
  folly::ssl::BioUniquePtr bio(BIO_new(BIO_s_file()));
  if (!bio) {
    throw std::runtime_error("Couldn't create BIO");
  }

  if (BIO_read_filename(bio.get(), path.c_str()) != 1) {
    throw std::runtime_error("Couldn't read cert file: " + filename);
  }
  return folly::ssl::X509UniquePtr(
      PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
}

void connectWriteReadClose(
    folly::AsyncReader::ReadCallback::ReadMode readMode) {
  // Start listening on a local port
  folly::test::WriteCallbackBase writeCallback;
  folly::test::ReadCallback readCallback(&writeCallback);
  folly::test::HandshakeCallback handshakeCallback(&readCallback);
  folly::test::SSLServerAcceptCallback acceptCallback(&handshakeCallback);
  folly::test::TestSSLServer server(&acceptCallback);

  // Set up SSL context.
  auto sslContext = std::make_shared<folly::SSLContext>();
  sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
  // sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
  // sslContext->authenticate(true, false);

  // connect
  auto socket = std::make_shared<folly::test::BlockingSocket>(
      server.getAddress(), sslContext);
  socket->setReadMode(readMode);
  socket->open(std::chrono::milliseconds(10000));

  // write()
  uint8_t buf[128];
  memset(buf, 'a', sizeof(buf));
  socket->write(buf, sizeof(buf));

  // read()
  uint8_t readbuf[128];
  uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
  EXPECT_EQ(bytesRead, 128);
  EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);

  // close()
  socket->close();

  cerr << "ConnectWriteReadClose test completed" << endl;
  EXPECT_EQ(socket->getSSLSocket()->getTotalConnectTimeout().count(), 10000);
}
} // namespace

/**
 * Test connecting to, writing to, reading from, and closing the
 * connection to the SSL server.
 */
TEST(AsyncSSLSocketTest, ConnectWriteReadClose) {
  connectWriteReadClose(folly::AsyncReader::ReadCallback::ReadMode::ReadBuffer);
}

TEST(AsyncSSLSocketTest, ConnectWriteReadvClose) {
  connectWriteReadClose(folly::AsyncReader::ReadCallback::ReadMode::ReadVec);
}

TEST(AsyncSSLSocketTest, ConnectWriteReadCloseReadable) {
  // Same as above, but test AsyncSSLSocket::readable along the way

  // Start listening on a local port
  WriteCallbackBase writeCallback;
  ReadCallback readCallback(&writeCallback);
  HandshakeCallback handshakeCallback(&readCallback);
  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
  TestSSLServer server(&acceptCallback);

  // Set up SSL context.
  std::shared_ptr<SSLContext> sslContext(new SSLContext());
  sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
  // sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
  // sslContext->authenticate(true, false);

  // connect
  auto socket =
      std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
  socket->open(std::chrono::milliseconds(10000));

  // write()
  uint8_t buf[128];
  memset(buf, 'a', sizeof(buf));
  socket->write(buf, sizeof(buf));

  // read()
  uint8_t readbuf[128];
  // The TLS record includes the full 128 bytes.  Even though we only read 1
  // byte out of the socket, the rest of the full record decrypted and buffered
  // in the underlying SSL session.
  uint32_t bytesRead = socket->readAll(readbuf, 1);
  EXPECT_EQ(bytesRead, 1);
  // The socket has no data pending in the kernel
  EXPECT_FALSE(socket->getSocket()->AsyncSocket::readable());
  // But the socket is readable
  EXPECT_TRUE(socket->getSocket()->readable());
  bytesRead += socket->readAll(readbuf + 1, sizeof(readbuf) - 1);
  EXPECT_EQ(bytesRead, 128);
  EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);

  // close()
  socket->close();

  cerr << "ConnectWriteReadClose test completed" << endl;
  EXPECT_EQ(socket->getSSLSocket()->getTotalConnectTimeout().count(), 10000);
}

/**
 * Check that zero copy options are a noop under AsyncSSLSocket since they
 * aren't supported.
 */
TEST(AsyncSSLSocketTest, ZeroCopy) {
  // Set up SSL context.
  std::shared_ptr<SSLContext> sslContext(new SSLContext());
  sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");

  auto socket = AsyncSSLSocket::newSocket(sslContext, /*evb=*/nullptr);
  EXPECT_FALSE(socket->setZeroCopy(true));
  EXPECT_FALSE(socket->getZeroCopy());
}

/**
 * Same as above simple test, but with a large read len to test
 * clamping behavior.
 */
TEST(AsyncSSLSocketTest, ConnectWriteReadLargeClose) {
  // Start listening on a local port
  WriteCallbackBase writeCallback;
  ReadCallback readCallback(&writeCallback);
  HandshakeCallback handshakeCallback(&readCallback);
  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
  std::unique_ptr<SSLContext> serverSslContext =
      TestSSLServer::getDefaultSSLContext();
  // With TLS 1.3, OpenSSL will send session tickets. These will arrive after
  // the handshake has completed and can interfere with the BlockingSocket
  // client's reads. Disable the session tickets so that doesn't happen.
  SSL_CTX_set_num_tickets(serverSslContext->getSSLCtx(), 0);
  TestSSLServer server(&acceptCallback, std::move(serverSslContext));

  // Set up SSL context.
  std::shared_ptr<SSLContext> sslContext(new SSLContext());
  sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
  // sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
  // sslContext->authenticate(true, false);

  // connect
  auto socket =
      std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
  socket->open(std::chrono::milliseconds(10000));

  // write()
  uint8_t buf[128];
  memset(buf, 'a', sizeof(buf));
  socket->write(buf, sizeof(buf));

  // read()
  uint8_t readbuf[128];
  // we will fake the read len but that should be fine
  size_t readLen = 1LL << 33;
  uint32_t bytesRead = socket->read(readbuf, readLen);
  EXPECT_EQ(bytesRead, 128);
  EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);

  // close()
  socket->close();

  cerr << "ConnectWriteReadClose test completed" << endl;
  EXPECT_EQ(socket->getSSLSocket()->getTotalConnectTimeout().count(), 10000);
}

/**
 * Test reading after server close.
 */
TEST(AsyncSSLSocketTest, ReadAfterClose) {
  // Start listening on a local port
  WriteCallbackBase writeCallback;
  ReadEOFCallback readCallback(&writeCallback);
  HandshakeCallback handshakeCallback(&readCallback);
  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
  auto server = std::make_unique<TestSSLServer>(&acceptCallback);

  // Set up SSL context.
  auto sslContext = std::make_shared<SSLContext>();
  sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");

  auto socket =
      std::make_shared<BlockingSocket>(server->getAddress(), sslContext);
  socket->open();

  // This should trigger an EOF on the client.
  auto evb = handshakeCallback.getSocket()->getEventBase();
  evb->runInEventBaseThreadAndWait([&]() { handshakeCallback.closeSocket(); });
  std::array<uint8_t, 128> readbuf;
  auto bytesRead = socket->read(readbuf.data(), readbuf.size());
  EXPECT_EQ(0, bytesRead);
}

/**
 * Test bad renegotiation
 */
#if !defined(OPENSSL_IS_BORINGSSL)
TEST(AsyncSSLSocketTest, Renegotiate) {
  EventBase eventBase;
  // renegotation in TLS 1.2 only
  auto clientCtx =
      std::make_shared<SSLContext>(SSLContext::SSLVersion::TLSv1_2);
  clientCtx->disableTLS13();
  auto dfServerCtx =
      std::make_shared<SSLContext>(SSLContext::SSLVersion::TLSv1_2);
  dfServerCtx->disableTLS13();
  std::array<NetworkSocket, 2> fds;
  getfds(fds.data());
  getctx(clientCtx, dfServerCtx);

  AsyncSSLSocket::UniquePtr clientSock(
      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
  AsyncSSLSocket::UniquePtr serverSock(
      new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
  SSLHandshakeClient client(std::move(clientSock), true, true);
  RenegotiatingServer server(std::move(serverSock));

  while (!client.handshakeSuccess_ && !client.handshakeError_) {
    eventBase.loopOnce();
  }

  ASSERT_TRUE(client.handshakeSuccess_);

  auto sslSock = std::move(client).moveSocket();
  sslSock->detachEventBase();
  // This is nasty, however we don't want to add support for
  // renegotiation in AsyncSSLSocket.
  SSL_renegotiate(const_cast<SSL*>(sslSock->getSSL()));

  auto socket = std::make_shared<BlockingSocket>(std::move(sslSock));

  std::thread t([&]() { eventBase.loopForever(); });

  // Trigger the renegotiation.
  std::array<uint8_t, 128> buf;
  memset(buf.data(), 'a', buf.size());
  try {
    socket->write(buf.data(), buf.size());
  } catch (AsyncSocketException& e) {
    LOG(INFO) << "client got error " << e.what();
  }
  eventBase.terminateLoopSoon();
  t.join();

  eventBase.loop();
  ASSERT_TRUE(server.renegotiationError_);
}
#endif

/**
 * Negative test for handshakeError().
 */
TEST(AsyncSSLSocketTest, HandshakeError) {
  // Start listening on a local port
  WriteCallbackBase writeCallback;
  WriteErrorCallback readCallback(&writeCallback);
  HandshakeCallback handshakeCallback(&readCallback);
  HandshakeErrorCallback acceptCallback(&handshakeCallback);
  TestSSLServer server(&acceptCallback);

  // Set up SSL context.
  std::shared_ptr<SSLContext> sslContext(new SSLContext());
  sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");

  // connect
  auto socket =
      std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
  // read()
  bool ex = false;
  try {
    socket->open();

    uint8_t readbuf[128];
    uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
    LOG(ERROR) << "readAll returned " << bytesRead << " instead of throwing";
  } catch (AsyncSocketException&) {
    ex = true;
  }
  EXPECT_TRUE(ex);

  // close()
  socket->close();
  cerr << "HandshakeError test completed" << endl;
}

/**
 * Negative test for readError().
 */
TEST(AsyncSSLSocketTest, ReadError) {
  // Start listening on a local port
  WriteCallbackBase writeCallback;
  ReadErrorCallback readCallback(&writeCallback);
  HandshakeCallback handshakeCallback(&readCallback);
  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
  TestSSLServer server(&acceptCallback);

  // Set up SSL context.
  std::shared_ptr<SSLContext> sslContext(new SSLContext());
  sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");

  // connect
  auto socket =
      std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
  socket->open();

  // write something to trigger ssl handshake
  uint8_t buf[128];
  memset(buf, 'a', sizeof(buf));
  socket->write(buf, sizeof(buf));

  socket->close();
  cerr << "ReadError test completed" << endl;
}

/**
 * Negative test for writeError().
 */
TEST(AsyncSSLSocketTest, WriteError) {
  // Start listening on a local port
  WriteCallbackBase writeCallback;
  WriteErrorCallback readCallback(&writeCallback);
  HandshakeCallback handshakeCallback(&readCallback);
  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
  TestSSLServer server(&acceptCallback);

  // Set up SSL context.
  std::shared_ptr<SSLContext> sslContext(new SSLContext());
  sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");

  // connect
  auto socket =
      std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
  socket->open();

  // write something to trigger ssl handshake
  uint8_t buf[128];
  memset(buf, 'a', sizeof(buf));
  socket->write(buf, sizeof(buf));

  socket->close();
  cerr << "WriteError test completed" << endl;
}

/**
 * Test a socket with TCP_NODELAY unset.
 */
TEST(AsyncSSLSocketTest, SocketWithDelay) {
  // Start listening on a local port
  WriteCallbackBase writeCallback;
  ReadCallback readCallback(&writeCallback);
  HandshakeCallback handshakeCallback(&readCallback);
  SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
  TestSSLServer server(&acceptCallback);

  // Set up SSL context.
  std::shared_ptr<SSLContext> sslContext(new SSLContext());
  sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");

  // connect
  auto socket =
      std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
  socket->open();

  // write()
  uint8_t buf[128];
  memset(buf, 'a', sizeof(buf));
  socket->write(buf, sizeof(buf));

  // read()
  uint8_t readbuf[128];
  uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
  EXPECT_EQ(bytesRead, 128);
  EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);

  // close()
  socket->close();

  cerr << "SocketWithDelay test completed" << endl;
}

class NextProtocolTest : public Test {
  // For matching protos
 public:
  void SetUp() override { getctx(clientCtx, serverCtx); }

  void connect(bool unset = false) {
    getfds(fds);

    if (unset) {
      // unsetting NPN for any of [client, server] is enough to make NPN not
      // work
      clientCtx->unsetNextProtocols();
    }

    AsyncSSLSocket::UniquePtr clientSock(
        new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
    AsyncSSLSocket::UniquePtr serverSock(
        new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
    client = std::make_unique<AlpnClient>(std::move(clientSock));
    server = std::make_unique<AlpnServer>(std::move(serverSock));

    eventBase.loop();
  }

  void expectProtocol(const std::string& proto) {
    expectHandshakeSuccess();
    EXPECT_NE(client->nextProtoLength, 0);
    EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
    EXPECT_EQ(
        memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
        0);
    string selected((const char*)client->nextProto, client->nextProtoLength);
    EXPECT_EQ(proto, selected);
  }

  void expectNoProtocol() {
    expectHandshakeSuccess();
    EXPECT_EQ(client->nextProtoLength, 0);
    EXPECT_EQ(server->nextProtoLength, 0);
    EXPECT_EQ(client->nextProto, nullptr);
    EXPECT_EQ(server->nextProto, nullptr);
  }

  void expectHandshakeSuccess() {
    EXPECT_FALSE(client->except.has_value())
        << "client handshake error: " << client->except->what();
    EXPECT_FALSE(server->except.has_value())
        << "server handshake error: " << server->except->what();
  }

  void expectHandshakeError() {
    EXPECT_TRUE(client->except.has_value())
        << "Expected client handshake error!";
    EXPECT_TRUE(server->except.has_value())
        << "Expected server handshake error!";
  }

  EventBase eventBase;
  std::shared_ptr<SSLContext> clientCtx{std::make_shared<SSLContext>()};
  std::shared_ptr<SSLContext> serverCtx{std::make_shared<SSLContext>()};
  NetworkSocket fds[2];
  std::unique_ptr<AlpnClient> client;
  std::unique_ptr<AlpnServer> server;
};

TEST_F(NextProtocolTest, AlpnTestOverlap) {
  clientCtx->setAdvertisedNextProtocols({"blub", "baz"});
  serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"});

  connect();

  expectProtocol("baz");
}

TEST_F(NextProtocolTest, AlpnTestUnset) {
  // Identical to above test, except that we want unset NPN before
  // looping.
  clientCtx->setAdvertisedNextProtocols({"blub", "baz"});
  serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"});

  connect(true /* unset */);

  expectNoProtocol();
}

TEST_F(NextProtocolTest, AlpnTestNoOverlap) {
  clientCtx->setAdvertisedNextProtocols({"blub"});
  serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"});
  connect();

  expectNoProtocol();
}

TEST_F(NextProtocolTest, RandomizedAlpnTest) {
  // Probability that this test will fail is 2^-64, which could be considered
  // as negligible.
  const int kTries = 64;

  clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"});
  serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}}, {1, {"bar"}}});

  std::set<string> selectedProtocols;
  for (int i = 0; i < kTries; ++i) {
    connect();

    EXPECT_NE(client->nextProtoLength, 0);
    EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
    EXPECT_EQ(
        memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
        0);
    string selected((const char*)client->nextProto, client->nextProtoLength);
    selectedProtocols.insert(selected);
    expectHandshakeSuccess();
  }
  EXPECT_EQ(selectedProtocols.size(), 2);
}

TEST_F(NextProtocolTest, AlpnNotAllowMismatchNoClientProtocol) {
  clientCtx->setAdvertisedNextProtocols({});
  serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"});
  serverCtx->setAlpnAllowMismatch(false);

  connect();

  expectHandshakeSuccess();
  expectNoProtocol();
  EXPECT_EQ(server->getClientAlpns(), std::vector<std::string>({}));
}

TEST_F(NextProtocolTest, AlpnNotAllowMismatchWithOverlap) {
  clientCtx->setAdvertisedNextProtocols({"blub", "baz"});
  serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"});
  serverCtx->setAlpnAllowMismatch(false);

  connect();

  expectProtocol("baz");
  EXPECT_EQ(
      server->getClientAlpns(), std::vector<std::string>({"blub", "baz"}));
}

TEST_F(NextProtocolTest, AlpnNotAllowMismatchWithoutOverlap) {
  clientCtx->setAdvertisedNextProtocols({"blub"});
  serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"});
  serverCtx->setAlpnAllowMismatch(false);

  connect();

  expectHandshakeError();
  EXPECT_EQ(server->getClientAlpns(), std::vector<std::string>({"blub"}));
}

/**
 * 1. Client sends TLSEXT_HOSTNAME in client hello.
 * 2. Server found a match SSL_CTX and use this SSL_CTX to
 *    continue the SSL handshake.
 * 3. Server sends back TLSEXT_HOSTNAME in server hello.
 */
TEST(AsyncSSLSocketTest, SNITestMatch) {
  EventBase eventBase;
  std::shared_ptr<SSLContext> clientCtx(new SSLContext);
  std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
  // Use the same SSLContext to continue the handshake after
  // tlsext_hostname match.
  std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
  const std::string serverName("xyz.newdev.facebook.com");
  NetworkSocket fds[2];
  getfds(fds);
  getctx(clientCtx, dfServerCtx);

  AsyncSSLSocket::UniquePtr clientSock(
      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
  AsyncSSLSocket::UniquePtr serverSock(
      new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
  SNIClient client(std::move(clientSock));
  SNIServer server(
      std::move(serverSock), dfServerCtx, hskServerCtx, serverName);

  eventBase.loop();

  EXPECT_TRUE(client.serverNameMatch);
  EXPECT_TRUE(server.serverNameMatch);
}

/**
 * 1. Client sends TLSEXT_HOSTNAME in client hello.
 * 2. Server cannot find a matching SSL_CTX and continue to use
 *    the current SSL_CTX to do the handshake.
 * 3. Server does not send back TLSEXT_HOSTNAME in server hello.
 */
TEST(AsyncSSLSocketTest, SNITestNotMatch) {
  EventBase eventBase;
  std::shared_ptr<SSLContext> clientCtx(new SSLContext);
  std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
  // Use the same SSLContext to continue the handshake after
  // tlsext_hostname match.
  std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
  const std::string clientRequestingServerName("foo.com");
  const std::string serverExpectedServerName("xyz.newdev.facebook.com");

  NetworkSocket fds[2];
  getfds(fds);
  getctx(clientCtx, dfServerCtx);

  AsyncSSLSocket::UniquePtr clientSock(new AsyncSSLSocket(
      clientCtx, &eventBase, fds[0], clientRequestingServerName));
  AsyncSSLSocket::UniquePtr serverSock(
      new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
  SNIClient client(std::move(clientSock));
  SNIServer server(
      std::move(serverSock),
      dfServerCtx,
      hskServerCtx,
      serverExpectedServerName);

  eventBase.loop();

  EXPECT_TRUE(!client.serverNameMatch);
  EXPECT_TRUE(!server.serverNameMatch);
}
/**
 * 1. Client sends TLSEXT_HOSTNAME in client hello.
 * 2. We then change the serverName.
 * 3. We expect that we get 'false' as the result for serNameMatch.
 */

TEST(AsyncSSLSocketTest, SNITestChangeServerName) {
  EventBase eventBase;
  std::shared_ptr<SSLContext> clientCtx(new SSLContext);
  std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
  // Use the same SSLContext to continue the handshake after
  // tlsext_hostname match.
  std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
  const std::string serverName("xyz.newdev.facebook.com");
  NetworkSocket fds[2];
  getfds(fds);
  getctx(clientCtx, dfServerCtx);

  AsyncSSLSocket::UniquePtr clientSock(
      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
  // Change the server name
  std::string newName("new.com");
  clientSock->setServerName(newName);
  AsyncSSLSocket::UniquePtr serverSock(
      new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
  SNIClient client(std::move(clientSock));
  SNIServer server(
      std::move(serverSock), dfServerCtx, hskServerCtx, serverName);

  eventBase.loop();

  EXPECT_TRUE(!client.serverNameMatch);
}

/**
 * 1. Client does not send TLSEXT_HOSTNAME in client hello.
 * 2. Server does not send back TLSEXT_HOSTNAME in server hello.
 */
TEST(AsyncSSLSocketTest, SNITestClientHelloNoHostname) {
  EventBase eventBase;
  std::shared_ptr<SSLContext> clientCtx(new SSLContext);
  std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
  // Use the same SSLContext to continue the handshake after
  // tlsext_hostname match.
  std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
  const std::string serverExpectedServerName("xyz.newdev.facebook.com");

  NetworkSocket fds[2];
  getfds(fds);
  getctx(clientCtx, dfServerCtx);

  AsyncSSLSocket::UniquePtr clientSock(
      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
  AsyncSSLSocket::UniquePtr serverSock(
      new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
  SNIClient client(std::move(clientSock));
  SNIServer server(
      std::move(serverSock),
      dfServerCtx,
      hskServerCtx,
      serverExpectedServerName);

  eventBase.loop();

  EXPECT_TRUE(!client.serverNameMatch);
  EXPECT_TRUE(!server.serverNameMatch);
}

/**
 * 1. Create an SSLContext that does not have an ALPN
 * 2. Use AsyncSSLSocket::setSupportedApplicationProtocols on the client and
 * server, and assert that a common ALPN was negotiated.
 */
TEST(AsyncSSLSocketTest, SetSupportedApplicationProtocols) {
  EventBase eventBase;
  auto clientCtx = std::make_shared<SSLContext>();
  auto dfServerCtx = std::make_shared<SSLContext>();
  // Use the same SSLContext to continue the handshake after
  // tlsext_hostname match.
  auto hskServerCtx = std::make_shared<SSLContext>();
  const std::string serverExpectedServerName("xyz.newdev.facebook.com");

  NetworkSocket fds[2];
  getfds(fds);
  getctx(clientCtx, dfServerCtx);

  AsyncSSLSocket::UniquePtr clientSock(
      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
  AsyncSSLSocket::UniquePtr serverSock(
      new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));

  std::vector<std::string> protocols;
  protocols.push_back("rs");

  clientSock->setSupportedApplicationProtocols(protocols);
  serverSock->setSupportedApplicationProtocols(protocols);

  SNIClient client(std::move(clientSock));
  SNIServer server(
      std::move(serverSock),
      dfServerCtx,
      hskServerCtx,
      serverExpectedServerName);

  eventBase.loop();

  EXPECT_TRUE(
      client.getApplicationProtocol().compare(
          server.getApplicationProtocol()) == 0);
}

/**
 * Test SSL client socket
 */
TEST(AsyncSSLSocketTest, SSLClientTest) {
  // Start listening on a local port
  WriteCallbackBase writeCallback;
  ReadCallback readCallback(&writeCallback);
  HandshakeCallback handshakeCallback(&readCallback);
  SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
  TestSSLServer server(&acceptCallback);

  // Set up SSL client
  EventBase eventBase;
  auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1);
  client->setSSLOptions(SSL_OP_NO_TICKET);

  client->connect();
  EventBaseAborter eba(&eventBase, 3000);
  eventBase.loop();

  EXPECT_EQ(client->getMiss(), 1);
  EXPECT_EQ(client->getHit(), 0);

  cerr << "SSLClientTest test completed" << endl;
}

/**
 * Test SSL client socket session re-use
 */
TEST(AsyncSSLSocketTest, SSLClientTestReuse) {
  // Start listening on a local port
  WriteCallbackBase writeCallback;
  ReadCallback readCallback(&writeCallback);
  HandshakeCallback handshakeCallback(&readCallback);
  SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
  TestSSLServer server(&acceptCallback);

  // Set up SSL client
  EventBase eventBase;
  auto client =
      std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10);
  client->setSSLOptions(SSL_OP_NO_TICKET);

  client->connect();
  EventBaseAborter eba(&eventBase, 3000);
  eventBase.loop();

  EXPECT_EQ(client->getMiss(), 1);
  EXPECT_EQ(client->getHit(), 9);

  cerr << "SSLClientTestReuse test completed" << endl;
}

/**
 * Test SSL client socket timeout
 */
TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
  // Start listening on a local port
  EmptyReadCallback readCallback;
  HandshakeCallback handshakeCallback(
      &readCallback, HandshakeCallback::EXPECT_ERROR);
  HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
  TestSSLServer server(&acceptCallback);

  // Set up SSL client
  EventBase eventBase;
  auto client =
      std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1, 10);
  client->setSSLOptions(SSL_OP_NO_TICKET);

  client->connect(true /* write before connect completes */);
  EventBaseAborter eba(&eventBase, 3000);
  eventBase.loop();

  usleep(100000);
  // This is checking that the connectError callback precedes any queued
  // writeError callbacks.  This matches AsyncSocket's behavior
  EXPECT_EQ(client->getWriteAfterConnectErrors(), 1);
  EXPECT_EQ(client->getErrors(), 1);
  EXPECT_EQ(client->getMiss(), 0);
  EXPECT_EQ(client->getHit(), 0);

  cerr << "SSLClientTimeoutTest test completed" << endl;
}

class PerLoopReadCallback : public AsyncTransport::ReadCallback {
 public:
  void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
    *bufReturn = buf_.data();
    *lenReturn = buf_.size();
  }

  void readDataAvailable(size_t len) noexcept override {
    VLOG(3) << "Read of size: " << len;
    s_->setReadCB(nullptr);
    s_->getEventBase()->runInLoop([this]() { s_->setReadCB(this); });
  }

  void readErr(const AsyncSocketException&) noexcept override {}

  void readEOF() noexcept override {}

  void setSocket(AsyncSocket* s) { s_ = s; }

 private:
  AsyncSocket* s_;
  std::array<uint8_t, 1000> buf_;
};

class CloseNotifyConnector : public AsyncSocket::ConnectCallback {
 public:
  CloseNotifyConnector(EventBase* evb, const SocketAddress& addr) {
    evb_ = evb;
    ssl_ = AsyncSSLSocket::newSocket(std::make_shared<SSLContext>(), evb_);
    ssl_->connect(this, addr);
  }

  void connectSuccess() noexcept override {
    ssl_->writeChain(nullptr, IOBuf::copyBuffer("hi"));
    auto ssl = const_cast<SSL*>(ssl_->getSSL());
    SSL_shutdown(ssl);
    auto fd = ssl_->detachNetworkSocket();
    tcp_.reset(new AsyncSocket(evb_, fd), AsyncSocket::Destructor());
    evb_->runAfterDelay(
        [this]() {
          perLoopReads_.setSocket(tcp_.get());
          tcp_->setReadCB(&perLoopReads_);
          evb_->runAfterDelay([this]() { tcp_->closeNow(); }, 10);
        },
        100);
  }

  void connectErr(const AsyncSocketException& ex) noexcept override {
    FAIL() << ex.what();
  }

 private:
  EventBase* evb_;
  std::shared_ptr<AsyncSSLSocket> ssl_;
  std::shared_ptr<AsyncSocket> tcp_;
  PerLoopReadCallback perLoopReads_;
};

class ErrorCheckingWriteCallback : public AsyncSocket::WriteCallback {
 public:
  void writeSuccess() noexcept override {}

  void writeErr(size_t, const AsyncSocketException& ex) noexcept override {
    LOG(ERROR) << "write error: " << ex.what();
    EXPECT_NE(
        ex.getType(),
        AsyncSocketException::AsyncSocketExceptionType::SSL_ERROR);
  }
};

class WriteOnEofReadCallback : public ReadCallback {
 public:
  using ReadCallback::ReadCallback;

  void readEOF() noexcept override {
    LOG(INFO) << "Got EOF";
    auto chain = IOBuf::create(0);
    for (size_t i = 0; i < 1000 * 1000; i++) {
      auto buf = IOBuf::create(10);
      buf->append(10);
      memset(buf->writableData(), 'x', 10);
      chain->prependChain(std::move(buf));
    }
    socket_->writeChain(&writeCallback_, std::move(chain));
  }

  void readErr(const AsyncSocketException& ex) noexcept override {
    LOG(ERROR) << ex.what();
  }

 private:
  ErrorCheckingWriteCallback writeCallback_;
};

TEST(AsyncSSLSocketTest, EarlyCloseNotify) {
  WriteOnEofReadCallback readCallback(nullptr);
  HandshakeCallback handshakeCallback(&readCallback);
  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
  TestSSLServer server(&acceptCallback);

  EventBase eventBase;
  CloseNotifyConnector cnc(&eventBase, server.getAddress());

  eventBase.loop();
}

/**
 * Verify Client Ciphers obtained using SSL MSG Callback.
 */
TEST(AsyncSSLSocketTest, SSLParseClientHelloSuccess) {
  EventBase eventBase;
  auto serverCtx = std::make_shared<SSLContext>();
  serverCtx->setVerificationOption(SSLContext::VerifyClientCertificate::ALWAYS);
  serverCtx->setCiphersuitesOrThrow(
      "TLS_AES_128_GCM_SHA256:TLS_AES_256_GCM_SHA384:TLS_CHACHA20_POLY1305_SHA256");
  serverCtx->loadPrivateKey(find_resource(kTestKey).c_str());
  serverCtx->loadCertificate(find_resource(kTestCert).c_str());
  serverCtx->loadTrustedCertificates(find_resource(kTestCA).c_str());
  serverCtx->setSupportedClientCertificateAuthorityNamesFromFile(
      find_resource(kTestCA).c_str());

  auto clientCtx = std::make_shared<SSLContext>();
  clientCtx->setVerificationOption(
      SSLContext::VerifyServerCertificate::IF_PRESENTED);
  clientCtx->setCiphersuitesOrThrow(
      "TLS_CHACHA20_POLY1305_SHA256:TLS_AES_256_GCM_SHA384");
  // With SSLContext min version set to TLS 1.2, TLS 1.2 ciphers are appended to
  // the list that's sent to the server, and those would appear in the
  // clientCiphers_ captured and verified below. Remove all of them by setting
  // eNULL.
  clientCtx->setCiphersOrThrow("eNULL");
  clientCtx->loadPrivateKey(find_resource(kTestKey).c_str());
  clientCtx->loadCertificate(find_resource(kTestCert).c_str());
  clientCtx->loadTrustedCertificates(find_resource(kTestCA).c_str());

  NetworkSocket fds[2];
  getfds(fds);

  AsyncSSLSocket::UniquePtr clientSock(
      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
  AsyncSSLSocket::UniquePtr serverSock(
      new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));

  SSLHandshakeClient client(std::move(clientSock), true, true);
  SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);

  eventBase.loop();

#if defined(OPENSSL_IS_BORINGSSL)
  EXPECT_EQ(
      server.clientCiphers_,
      "TLS_CHACHA20_POLY1305_SHA256:TLS_AES_256_GCM_SHA384");
#else
  EXPECT_EQ(
      server.clientCiphers_,
      "TLS_CHACHA20_POLY1305_SHA256:TLS_AES_256_GCM_SHA384:00ff");
#endif
  EXPECT_EQ(server.chosenCipher_, "TLS_CHACHA20_POLY1305_SHA256");
  EXPECT_TRUE(client.handshakeVerify_);
  EXPECT_TRUE(client.handshakeSuccess_);
  EXPECT_TRUE(!client.handshakeError_);
  EXPECT_TRUE(server.handshakeVerify_);
  EXPECT_TRUE(server.handshakeSuccess_);
  EXPECT_TRUE(!server.handshakeError_);
}

TEST(AsyncSSLSocketTest, SSLGetEKM) {
  EventBase eventBase;
  auto clientCtx = std::make_shared<SSLContext>();
  auto serverCtx = std::make_shared<SSLContext>();

  NetworkSocket fds[2];
  getfds(fds);
  getctx(clientCtx, serverCtx);

  AsyncSSLSocket::UniquePtr clientSock(
      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
  AsyncSSLSocket::UniquePtr serverSock(
      new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));

  EXPECT_EQ(
      nullptr, clientSock->getExportedKeyingMaterial("test", nullptr, 12));

  serverSock->sslAccept(nullptr, std::chrono::milliseconds::zero());
  clientSock->sslConn(nullptr, std::chrono::milliseconds::zero());
  eventBase.loop();
  auto ekm = clientSock->getExportedKeyingMaterial("test", nullptr, 32);

  EXPECT_NE(nullptr, ekm);
  EXPECT_EQ(ekm->computeChainDataLength(), 32);
  // non null context
  EXPECT_NE(
      nullptr,
      clientSock->getExportedKeyingMaterial(
          "test", IOBuf::copyBuffer("haha"), 32));
  // empty label
  EXPECT_NE(
      nullptr,
      clientSock->getExportedKeyingMaterial("", IOBuf::copyBuffer("haha"), 32));

  auto serverEkm = serverSock->getExportedKeyingMaterial("test", nullptr, 32);
  EXPECT_TRUE(folly::IOBufEqualTo{}(ekm, serverEkm));
}

TEST(AsyncSSLSocketTest, SSLGetEKMFailsOnTLS10) {
  EventBase eventBase;
  auto clientCtx = std::make_shared<SSLContext>();
  auto serverCtx = std::make_shared<SSLContext>();
  SSL_CTX_set_max_proto_version(serverCtx->getSSLCtx(), TLS1_VERSION);

  NetworkSocket fds[2];
  getfds(fds);
  getctx(clientCtx, serverCtx);

  AsyncSSLSocket::UniquePtr clientSock(
      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
  AsyncSSLSocket::UniquePtr serverSock(
      new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));

  serverSock->sslAccept(nullptr, std::chrono::milliseconds::zero());
  clientSock->sslConn(nullptr, std::chrono::milliseconds::zero());
  eventBase.loop();
  auto ekm = clientSock->getExportedKeyingMaterial("test", nullptr, 32);

  EXPECT_EQ(nullptr, ekm);
}

/**
 * Verify that server is able to get client cert by getPeerCert() API.
 */
TEST(AsyncSSLSocketTest, GetClientCertificate) {
  EventBase eventBase;
  auto clientCtx = std::make_shared<SSLContext>();
  auto serverCtx = std::make_shared<SSLContext>();
  serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
  serverCtx->ciphers("ECDHE-RSA-AES128-SHA:AES128-SHA:AES256-SHA");
  serverCtx->loadPrivateKey(find_resource(kTestKey).c_str());
  serverCtx->loadCertificate(find_resource(kTestCert).c_str());
  serverCtx->loadTrustedCertificates(find_resource(kClientTestCA).c_str());
  serverCtx->setSupportedClientCertificateAuthorityNamesFromFile(
      find_resource(kClientTestCA).c_str());

  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
  clientCtx->ciphers("AES256-SHA:AES128-SHA");
  clientCtx->loadPrivateKey(find_resource(kClientTestKey).c_str());
  clientCtx->loadCertificate(find_resource(kClientTestCert).c_str());
  clientCtx->loadTrustedCertificates(find_resource(kTestCA).c_str());

  std::array<NetworkSocket, 2> fds;
  getfds(fds.data());

  AsyncSSLSocket::UniquePtr clientSock(
      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
  AsyncSSLSocket::UniquePtr serverSock(
      new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));

  SSLHandshakeClient client(std::move(clientSock), true, true);
  SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);

  eventBase.loop();

  // Handshake should succeed.
  EXPECT_TRUE(client.handshakeSuccess_);
  EXPECT_TRUE(server.handshakeSuccess_);

  // Reclaim the sockets from SSLHandshakeBase.
  auto cliSocket = std::move(client).moveSocket();
  auto srvSocket = std::move(server).moveSocket();

  // Client cert retrieved from server side.
  auto serverPeerCert = srvSocket->getPeerCertificate();
  CHECK(serverPeerCert);

  // Client cert retrieved from client side.
  auto clientSelfCert = cliSocket->getSelfCertificate();
  CHECK(clientSelfCert);

  auto serverX509 =
      folly::OpenSSLTransportCertificate::tryExtractX509(serverPeerCert);
  CHECK(serverX509);

  auto clientX509 =
      folly::OpenSSLTransportCertificate::tryExtractX509(clientSelfCert);
  CHECK(clientX509);

  // The two certs should be the same.
  EXPECT_EQ(0, X509_cmp(clientX509.get(), serverX509.get()));
}

TEST(AsyncSSLSocketTest, SSLParseClientHelloOnePacket) {
  EventBase eventBase;
  auto ctx = std::make_shared<SSLContext>();

  NetworkSocket fds[2];
  getfds(fds);

  int bufLen = 42;
  uint8_t majorVersion = 18;
  uint8_t minorVersion = 25;

  // Create callback buf
  auto buf = IOBuf::create(bufLen);
  buf->append(bufLen);
  folly::io::RWPrivateCursor cursor(buf.get());
  cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
  cursor.write<uint16_t>(0);
  cursor.write<uint8_t>(38);
  cursor.write<uint8_t>(majorVersion);
  cursor.write<uint8_t>(minorVersion);
  cursor.skip(32);
  cursor.write<uint32_t>(0);

  SSL* ssl = ctx->createSSL();
  SCOPE_EXIT {
    SSL_free(ssl);
  };
  AsyncSSLSocket::UniquePtr sock(
      new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
  sock->enableClientHelloParsing();

  // Test client hello parsing in one packet
  AsyncSSLSocket::clientHelloParsingCallback(
      0, 0, SSL3_RT_HANDSHAKE, buf->data(), buf->length(), ssl, sock.get());
  buf.reset();

  auto parsedClientHello = sock->getClientHelloInfo();
  EXPECT_TRUE(parsedClientHello != nullptr);
  EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
  EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
}

TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) {
  EventBase eventBase;
  auto ctx = std::make_shared<SSLContext>();

  NetworkSocket fds[2];
  getfds(fds);

  int bufLen = 42;
  uint8_t majorVersion = 18;
  uint8_t minorVersion = 25;

  // Create callback buf
  auto buf = IOBuf::create(bufLen);
  buf->append(bufLen);
  folly::io::RWPrivateCursor cursor(buf.get());
  cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
  cursor.write<uint16_t>(0);
  cursor.write<uint8_t>(38);
  cursor.write<uint8_t>(majorVersion);
  cursor.write<uint8_t>(minorVersion);
  cursor.skip(32);
  cursor.write<uint32_t>(0);

  SSL* ssl = ctx->createSSL();
  SCOPE_EXIT {
    SSL_free(ssl);
  };
  AsyncSSLSocket::UniquePtr sock(
      new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
  sock->enableClientHelloParsing();

  // Test parsing with two packets with first packet size < 3
  auto bufCopy = folly::IOBuf::copyBuffer(buf->data(), 2);
  AsyncSSLSocket::clientHelloParsingCallback(
      0,
      0,
      SSL3_RT_HANDSHAKE,
      bufCopy->data(),
      bufCopy->length(),
      ssl,
      sock.get());
  bufCopy.reset();
  bufCopy = folly::IOBuf::copyBuffer(buf->data() + 2, buf->length() - 2);
  AsyncSSLSocket::clientHelloParsingCallback(
      0,
      0,
      SSL3_RT_HANDSHAKE,
      bufCopy->data(),
      bufCopy->length(),
      ssl,
      sock.get());
  bufCopy.reset();

  auto parsedClientHello = sock->getClientHelloInfo();
  EXPECT_TRUE(parsedClientHello != nullptr);
  EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
  EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
}

TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) {
  EventBase eventBase;
  auto ctx = std::make_shared<SSLContext>();
  ctx->setSupportedGroups(std::vector<std::string>({"P-256"}));

  NetworkSocket fds[2];
  getfds(fds);

  int bufLen = 42;
  uint8_t majorVersion = 18;
  uint8_t minorVersion = 25;

  // Create callback buf
  auto buf = IOBuf::create(bufLen);
  buf->append(bufLen);
  folly::io::RWPrivateCursor cursor(buf.get());
  cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
  cursor.write<uint16_t>(0);
  cursor.write<uint8_t>(38);
  cursor.write<uint8_t>(majorVersion);
  cursor.write<uint8_t>(minorVersion);
  cursor.skip(32);
  cursor.write<uint32_t>(0);

  SSL* ssl = ctx->createSSL();
  SCOPE_EXIT {
    SSL_free(ssl);
  };
  AsyncSSLSocket::UniquePtr sock(
      new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
  sock->enableClientHelloParsing();

  // Test parsing with multiple small packets
  for (std::size_t i = 0; i < buf->length(); i += 3) {
    auto bufCopy = folly::IOBuf::copyBuffer(
        buf->data() + i, std::min((std::size_t)3, buf->length() - i));
    AsyncSSLSocket::clientHelloParsingCallback(
        0,
        0,
        SSL3_RT_HANDSHAKE,
        bufCopy->data(),
        bufCopy->length(),
        ssl,
        sock.get());
    bufCopy.reset();
  }

  auto parsedClientHello = sock->getClientHelloInfo();
  EXPECT_TRUE(parsedClientHello != nullptr);
  EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
  EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
}

/**
 * Verify sucessful behavior of SSL certificate validation.
 */
TEST(AsyncSSLSocketTest, SSLHandshakeValidationSuccess) {
  EventBase eventBase;
  auto clientCtx = std::make_shared<SSLContext>();
  auto dfServerCtx = std::make_shared<SSLContext>();

  NetworkSocket fds[2];
  getfds(fds);
  getctx(clientCtx, dfServerCtx);

  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
  dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);

  AsyncSSLSocket::UniquePtr clientSock(
      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
  AsyncSSLSocket::UniquePtr serverSock(
      new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));

  SSLHandshakeClient client(std::move(clientSock), true, true);
  clientCtx->loadTrustedCertificates(find_resource(kTestCA).c_str());

  SSLHandshakeServer server(std::move(serverSock), true, true);

  eventBase.loop();

  EXPECT_TRUE(client.handshakeVerify_);
  EXPECT_TRUE(client.handshakeSuccess_);
  EXPECT_TRUE(!client.handshakeError_);
  EXPECT_LE(0, client.handshakeTime.count());
  EXPECT_TRUE(!server.handshakeVerify_);
  EXPECT_TRUE(server.handshakeSuccess_);
  EXPECT_TRUE(!server.handshakeError_);
  EXPECT_LE(0, server.handshakeTime.count());
}

/**
 * Verify that the client's verification callback is able to fail SSL
 * connection establishment.
 */
TEST(AsyncSSLSocketTest, SSLHandshakeValidationFailure) {
  EventBase eventBase;
  auto clientCtx = std::make_shared<SSLContext>();
  auto dfServerCtx = std::make_shared<SSLContext>();

  NetworkSocket fds[2];
  getfds(fds);
  getctx(clientCtx, dfServerCtx);

  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
  dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);

  AsyncSSLSocket::UniquePtr clientSock(
      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
  AsyncSSLSocket::UniquePtr serverSock(
      new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));

  SSLHandshakeClient client(std::move(clientSock), true, false);
  clientCtx->loadTrustedCertificates(find_resource(kTestCA).c_str());

  SSLHandshakeServer server(std::move(serverSock), true, true);

  eventBase.loop();

  EXPECT_TRUE(client.handshakeVerify_);
  EXPECT_TRUE(!client.handshakeSuccess_);
  EXPECT_TRUE(client.handshakeError_);
  EXPECT_LE(0, client.handshakeTime.count());
  EXPECT_TRUE(!server.handshakeVerify_);
  EXPECT_TRUE(!server.handshakeSuccess_);
  EXPECT_TRUE(server.handshakeError_);
  EXPECT_LE(0, server.handshakeTime.count());
}

/**
 * Verify that the client successfully handshakes when
 * CertificateIdentityVerifier is set and returns with no exception.
 */
TEST(AsyncSSLSocketTest, SSLCertificateIdentityVerifierReturns) {
  EventBase eventBase;
  auto clientCtx = std::make_shared<folly::SSLContext>();
  auto serverCtx = std::make_shared<folly::SSLContext>();
  getctx(clientCtx, serverCtx);
  // the client socket will default to USE_CTX, so set VERIFY here
  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
  // load root certificate
  clientCtx->loadTrustedCertificates(find_resource(kTestCA).c_str());

  // prepare a basic server (callbacks have a few EXPECTS to fullfil)
  ReadCallback readCallback(nullptr);
  // expects successful handshake
  HandshakeCallback handshakeCallback(&readCallback);
  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
  TestSSLServer server(&acceptCallback, serverCtx);

  std::shared_ptr<MockCertificateIdentityVerifier> verifier =
      std::make_shared<MockCertificateIdentityVerifier>();

  // expecting to only verify once, with the leaf certificate
  // (kTestCert)
  EXPECT_CALL(
      *verifier,
      verifyLeaf(Property(
          &AsyncTransportCertificate::getIdentity, StrEq("Asox Company"))))
      .WillOnce(
          Return(ByMove(std::make_unique<folly::ssl::BasicTransportCertificate>(
              "Asox Company", readCertFromFile(kTestCert)))));

  AsyncSSLSocket::Options opts;
  opts.verifier = std::move(verifier);

  // connect to server and handshake
  AsyncSSLSocket::UniquePtr socket(
      new AsyncSSLSocket(clientCtx, &eventBase, std::move(opts)));
  socket->connect(nullptr, server.getAddress(), 0);

  // write to satisfy server ReadCallback EXPECTs
  std::array<uint8_t, 128> buf;
  memset(buf.data(), 'a', buf.size());
  socket->write(nullptr, buf.data(), buf.size());

  eventBase.loop();

  socket->close();
}

/**
 * Verify that the client fails to connect during handshake because
 * CertificateIdentityVerifier returns a failure while verifying the server's
 * leaf certificate.
 */
TEST(AsyncSSLSocketTest, SSLCertificateIdentityVerifierFailsToConnect) {
  EventBase eventBase;
  auto clientCtx = std::make_shared<folly::SSLContext>();
  auto serverCtx = std::make_shared<folly::SSLContext>();
  getctx(clientCtx, serverCtx);
  // the client socket will default to USE_CTX, so set VERIFY here
  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
  // load root certificate
  clientCtx->loadTrustedCertificates(find_resource(kTestCA).c_str());

  // prepare a basic server (callbacks have a few EXPECTS to fullfil)
  ReadCallback readCallback(nullptr);
  // expects a failed handshake
  HandshakeCallback handshakeCallback(
      &readCallback, HandshakeCallback::ExpectType::EXPECT_ERROR);
  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
  TestSSLServer server(&acceptCallback, serverCtx);

  std::shared_ptr<MockCertificateIdentityVerifier> verifier =
      std::make_shared<MockCertificateIdentityVerifier>();

  // Throw an exception on verification failure
  CertificateIdentityVerifierException failed{"a failed test reason"};

  // expecting to only verify once, with the leaf certificate (kTestCert)
  EXPECT_CALL(
      *verifier,
      verifyLeaf(Property(
          &AsyncTransportCertificate::getIdentity, StrEq("Asox Company"))))
      .WillOnce(Throw(failed));

  AsyncSSLSocket::Options opts;
  opts.verifier = std::move(verifier);

  // connect to server and handshake
  AsyncSSLSocket::UniquePtr socket(
      new AsyncSSLSocket(clientCtx, &eventBase, std::move(opts)));
  socket->connect(nullptr, server.getAddress(), 0);

  eventBase.loop();

  socket->close();
}

/**
 * Verify that the client's CertificateIdentityVerifier is not invoked if
 * OpenSSL's verification fails. (With no HandshakeCB.)
 */
TEST(AsyncSSLSocketTest, SSLCertificateIdentityVerifierNotInvokedX509Failure) {
  EventBase eventBase;
  auto clientCtx = std::make_shared<folly::SSLContext>();
  auto serverCtx = std::make_shared<folly::SSLContext>();
  getctx(clientCtx, serverCtx);
  // the client socket will default to USE_CTX, so set VERIFY here
  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
  // DO NOT load root certificate, so that server certificate is rejected

  // prepare a basic server (callbacks have a few EXPECTS to fullfil)
  ReadCallback readCallback(nullptr);
  // expects successful handshake
  HandshakeCallback handshakeCallback(
      &readCallback, HandshakeCallback::ExpectType::EXPECT_ERROR);
  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
  TestSSLServer server(&acceptCallback, serverCtx);

  // should not get called
  std::shared_ptr<StrictMock<MockCertificateIdentityVerifier>> verifier =
      std::make_shared<StrictMock<MockCertificateIdentityVerifier>>();

  AsyncSSLSocket::Options opts;
  opts.verifier = std::move(verifier);

  // connect to server and handshake
  AsyncSSLSocket::UniquePtr socket(
      new AsyncSSLSocket(clientCtx, &eventBase, std::move(opts)));
  socket->connect(nullptr, server.getAddress(), 0);

  eventBase.loop();

  socket->close();
}

/**
 * Verify that the client CertificateIdentityVerifier is not invoked if
 * HandshakeCB::handshakeVer verification fails.
 */
TEST(
    AsyncSSLSocketTest,
    SSLCertificateIdentityVerifierNotInvokedHandshakeCBFailure) {
  EventBase eventBase;
  auto clientCtx = std::make_shared<folly::SSLContext>();
  auto serverCtx = std::make_shared<folly::SSLContext>();
  getctx(clientCtx, serverCtx);
  // the client socket will default to USE_CTX, so set VERIFY here
  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
  // load root certificate
  clientCtx->loadTrustedCertificates(find_resource(kTestCA).c_str());

  NetworkSocket fds[2];
  getfds(fds);

  AsyncSocket::UniquePtr rawClient(new AsyncSocket(&eventBase, fds[0]));
  AsyncSocket::UniquePtr rawServer(new AsyncSocket(&eventBase, fds[1]));

  // should not be invoked
  std::shared_ptr<StrictMock<MockCertificateIdentityVerifier>> verifier =
      std::make_shared<StrictMock<MockCertificateIdentityVerifier>>();

  AsyncSSLSocket::Options clientOpts;
  clientOpts.verifier = verifier;

  AsyncSSLSocket::Options serverOpts;
  serverOpts.isServer = true;

  AsyncSSLSocket::UniquePtr clientSock(new AsyncSSLSocket(
      clientCtx, std::move(rawClient), std::move(clientOpts)));
  AsyncSSLSocket::UniquePtr serverSock(new AsyncSSLSocket(
      serverCtx, std::move(rawServer), std::move(serverOpts)));

  serverSock->sslAccept(nullptr, std::chrono::milliseconds::zero());

  StrictMock<MockHandshakeCB> clientHandshakeCB;

  // Force the end entity certificate, which normally is successfully verified,
  // to be considered as unsuccessful
  EXPECT_CALL(clientHandshakeCB, handshakeVerImpl(clientSock.get(), true, _))
      .Times(AtLeast(1))
      .WillRepeatedly(Invoke([&](auto&&, bool preverifyOk, auto&& ctx) {
        auto currentDepth = X509_STORE_CTX_get_error_depth(ctx);
        if (currentDepth == 0) {
          EXPECT_TRUE(preverifyOk);
          return false;
        }
        return preverifyOk;
      }));

  // failure callback to verify handshake failed
  EXPECT_CALL(clientHandshakeCB, handshakeErrImpl(clientSock.get(), _));

  clientSock->sslConn(&clientHandshakeCB);

  eventBase.loop();

  clientSock->close();
  serverSock->close();
}

/**
 * Verify that the client CertificateIdentityVerifier is invoked on a server
 * socket when peer verification is requested.
 */
TEST(AsyncSSLSocketTest, SSLCertificateIdentityVerifierSucceedsOnServer) {
  EventBase eventBase;
  auto clientCtx = std::make_shared<folly::SSLContext>();
  auto serverCtx = std::make_shared<folly::SSLContext>();
  getctx(clientCtx, serverCtx);
  // the client socket will default to USE_CTX, so set VERIFY here
  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
  // load root certificate
  clientCtx->loadTrustedCertificates(find_resource(kTestCA).c_str());
  // load identity and key on client, it's the same identity as server just for
  // convenience
  clientCtx->loadCertificate(find_resource(kTestCert).c_str());
  clientCtx->loadPrivateKey(find_resource(kTestKey).c_str());
  // instruct server to verify client
  serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
  serverCtx->loadTrustedCertificates(find_resource(kTestCA).c_str());

  NetworkSocket fds[2];
  getfds(fds);

  AsyncSocket::UniquePtr rawClient(new AsyncSocket(&eventBase, fds[0]));
  AsyncSocket::UniquePtr rawServer(new AsyncSocket(&eventBase, fds[1]));

  // client and server verifiers should verify only once each
  std::shared_ptr<MockCertificateIdentityVerifier> clientVerifier =
      std::make_shared<MockCertificateIdentityVerifier>();
  EXPECT_CALL(
      *clientVerifier,
      verifyLeaf(Property(
          &AsyncTransportCertificate::getIdentity, StrEq("Asox Company"))))
      .WillOnce(
          Return(ByMove(std::make_unique<folly::ssl::BasicTransportCertificate>(
              "Asox Company", readCertFromFile(kTestCert)))));
  std::shared_ptr<StrictMock<MockCertificateIdentityVerifier>> serverVerifier =
      std::make_shared<StrictMock<MockCertificateIdentityVerifier>>();
  EXPECT_CALL(
      *serverVerifier,
      verifyLeaf(Property(
          &AsyncTransportCertificate::getIdentity, StrEq("Asox Company"))))
      .WillOnce(
          Return(ByMove(std::make_unique<folly::ssl::BasicTransportCertificate>(
              "Asox Company", readCertFromFile(kTestCert)))));

  AsyncSSLSocket::Options clientOpts;
  clientOpts.verifier = clientVerifier;

  AsyncSSLSocket::Options serverOpts;
  serverOpts.isServer = true;
  serverOpts.verifier = serverVerifier;

  AsyncSSLSocket::UniquePtr clientSock(new AsyncSSLSocket(
      clientCtx, std::move(rawClient), std::move(clientOpts)));
  AsyncSSLSocket::UniquePtr serverSock(new AsyncSSLSocket(
      serverCtx, std::move(rawServer), std::move(serverOpts)));

  // no HandshakeCBs anywhere
  serverSock->sslAccept(nullptr, std::chrono::milliseconds::zero());
  clientSock->sslConn(nullptr);

  eventBase.loop();

  clientSock->close();
  serverSock->close();
}

/**
 * Verify that the options in SSLContext can be overridden in
 * sslConnect/Accept.i.e specifying that no validation should be performed
 * allows an otherwise-invalid certificate to be accepted and doesn't fire
 * the validation callback.
 */
TEST(AsyncSSLSocketTest, OverrideSSLCtxDisableVerify) {
  EventBase eventBase;
  auto clientCtx = std::make_shared<SSLContext>();
  auto dfServerCtx = std::make_shared<SSLContext>();

  NetworkSocket fds[2];
  getfds(fds);
  getctx(clientCtx, dfServerCtx);

  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
  dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);

  AsyncSSLSocket::UniquePtr clientSock(
      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
  AsyncSSLSocket::UniquePtr serverSock(
      new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));

  SSLHandshakeClientNoVerify client(std::move(clientSock), false, false);
  clientCtx->loadTrustedCertificates(find_resource(kTestCA).c_str());

  SSLHandshakeServerNoVerify server(std::move(serverSock), false, false);

  eventBase.loop();

  EXPECT_TRUE(!client.handshakeVerify_);
  EXPECT_TRUE(client.handshakeSuccess_);
  EXPECT_TRUE(!client.handshakeError_);
  EXPECT_LE(0, client.handshakeTime.count());
  EXPECT_TRUE(!server.handshakeVerify_);
  EXPECT_TRUE(server.handshakeSuccess_);
  EXPECT_TRUE(!server.handshakeError_);
  EXPECT_LE(0, server.handshakeTime.count());
}

/**
 * Verify that the options in SSLContext can be overridden in
 * sslConnect/Accept. Enable verification even if context says otherwise.
 * Test requireClientCert with client cert
 */
TEST(AsyncSSLSocketTest, OverrideSSLCtxEnableVerify) {
  EventBase eventBase;
  auto clientCtx = std::make_shared<SSLContext>();
  auto serverCtx = std::make_shared<SSLContext>();
  serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
  serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
  serverCtx->loadPrivateKey(find_resource(kTestKey).c_str());
  serverCtx->loadCertificate(find_resource(kTestCert).c_str());
  serverCtx->loadTrustedCertificates(find_resource(kTestCA).c_str());
  serverCtx->setSupportedClientCertificateAuthorityNamesFromFile(
      find_resource(kTestCA).c_str());

  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
  clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
  clientCtx->loadPrivateKey(find_resource(kTestKey).c_str());
  clientCtx->loadCertificate(find_resource(kTestCert).c_str());
  clientCtx->loadTrustedCertificates(find_resource(kTestCA).c_str());

  NetworkSocket fds[2];
  getfds(fds);

  AsyncSSLSocket::UniquePtr clientSock(
      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
  AsyncSSLSocket::UniquePtr serverSock(
      new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));

  SSLHandshakeClientDoVerify client(std::move(clientSock), true, true);
  SSLHandshakeServerDoVerify server(std::move(serverSock), true, true);

  eventBase.loop();

  EXPECT_TRUE(client.handshakeVerify_);
  EXPECT_TRUE(client.handshakeSuccess_);
  EXPECT_FALSE(client.handshakeError_);
  EXPECT_LE(0, client.handshakeTime.count());
  EXPECT_TRUE(server.handshakeVerify_);
  EXPECT_TRUE(server.handshakeSuccess_);
  EXPECT_FALSE(server.handshakeError_);
  EXPECT_LE(0, server.handshakeTime.count());
}

/**
 * Verify that the client's verification callback is able to override
 * the preverification failure and allow a successful connection.
 */
TEST(AsyncSSLSocketTest, SSLHandshakeValidationOverride) {
  EventBase eventBase;
  auto clientCtx = std::make_shared<SSLContext>();
  auto dfServerCtx = std::make_shared<SSLContext>();

  NetworkSocket fds[2];
  getfds(fds);
  getctx(clientCtx, dfServerCtx);

  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
  dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);

  AsyncSSLSocket::UniquePtr clientSock(
      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
  AsyncSSLSocket::UniquePtr serverSock(
      new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));

  SSLHandshakeClient client(std::move(clientSock), false, true);
  SSLHandshakeServer server(std::move(serverSock), true, true);

  eventBase.loop();

  EXPECT_TRUE(client.handshakeVerify_);
  EXPECT_TRUE(client.handshakeSuccess_);
  EXPECT_TRUE(!client.handshakeError_);
  EXPECT_LE(0, client.handshakeTime.count());
  EXPECT_TRUE(!server.handshakeVerify_);
  EXPECT_TRUE(server.handshakeSuccess_);
  EXPECT_TRUE(!server.handshakeError_);
  EXPECT_LE(0, server.handshakeTime.count());
}

/**
 * Verify that specifying that no validation should be performed allows an
 * otherwise-invalid certificate to be accepted and doesn't fire the validation
 * callback.
 */
TEST(AsyncSSLSocketTest, SSLHandshakeValidationSkip) {
  EventBase eventBase;
  auto clientCtx = std::make_shared<SSLContext>();
  auto dfServerCtx = std::make_shared<SSLContext>();

  NetworkSocket fds[2];
  getfds(fds);
  getctx(clientCtx, dfServerCtx);

  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
  dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);

  AsyncSSLSocket::UniquePtr clientSock(
      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
  AsyncSSLSocket::UniquePtr serverSock(
      new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));

  SSLHandshakeClient client(std::move(clientSock), false, false);
  SSLHandshakeServer server(std::move(serverSock), false, false);

  eventBase.loop();

  EXPECT_TRUE(!client.handshakeVerify_);
  EXPECT_TRUE(client.handshakeSuccess_);
  EXPECT_TRUE(!client.handshakeError_);
  EXPECT_LE(0, client.handshakeTime.count());
  EXPECT_TRUE(!server.handshakeVerify_);
  EXPECT_TRUE(server.handshakeSuccess_);
  EXPECT_TRUE(!server.handshakeError_);
  EXPECT_LE(0, server.handshakeTime.count());
}

/**
 * Test requireClientCert with client cert
 */
TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) {
  EventBase eventBase;
  auto clientCtx = std::make_shared<SSLContext>();
  auto serverCtx = std::make_shared<SSLContext>();
  serverCtx->setVerificationOption(
      SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
  serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
  serverCtx->loadPrivateKey(find_resource(kTestKey).c_str());
  serverCtx->loadCertificate(find_resource(kTestCert).c_str());
  serverCtx->loadTrustedCertificates(find_resource(kTestCA).c_str());
  serverCtx->setSupportedClientCertificateAuthorityNamesFromFile(
      find_resource(kTestCA).c_str());

  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
  clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
  clientCtx->loadPrivateKey(find_resource(kTestKey).c_str());
  clientCtx->loadCertificate(find_resource(kTestCert).c_str());
  clientCtx->loadTrustedCertificates(find_resource(kTestCA).c_str());

  NetworkSocket fds[2];
  getfds(fds);

  AsyncSSLSocket::UniquePtr clientSock(
      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
  AsyncSSLSocket::UniquePtr serverSock(
      new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));

  SSLHandshakeClient client(std::move(clientSock), true, true);
  SSLHandshakeServer server(std::move(serverSock), true, true);

  eventBase.loop();

  EXPECT_TRUE(client.handshakeVerify_);
  EXPECT_TRUE(client.handshakeSuccess_);
  EXPECT_FALSE(client.handshakeError_);
  EXPECT_LE(0, client.handshakeTime.count());
  EXPECT_TRUE(server.handshakeVerify_);
  EXPECT_TRUE(server.handshakeSuccess_);
  EXPECT_FALSE(server.handshakeError_);
  EXPECT_LE(0, server.handshakeTime.count());

  // check certificates
  auto clientSsl = std::move(client).moveSocket();
  auto serverSsl = std::move(server).moveSocket();

  auto clientPeer = clientSsl->getPeerCertificate();
  auto clientSelf = clientSsl->getSelfCertificate();
  auto serverPeer = serverSsl->getPeerCertificate();
  auto serverSelf = serverSsl->getSelfCertificate();

  EXPECT_NE(clientPeer, nullptr);
  EXPECT_NE(clientSelf, nullptr);
  EXPECT_NE(serverPeer, nullptr);
  EXPECT_NE(serverSelf, nullptr);

  EXPECT_EQ(clientPeer->getIdentity(), serverSelf->getIdentity());
  EXPECT_EQ(clientSelf->getIdentity(), serverPeer->getIdentity());
}

/**
 * Test requireClientCert with no client cert
 */
TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) {
  EventBase eventBase;
  auto clientCtx = std::make_shared<SSLContext>();
  auto serverCtx = std::make_shared<SSLContext>();
  serverCtx->setVerificationOption(
      SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
  serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
  serverCtx->loadPrivateKey(find_resource(kTestKey).c_str());
  serverCtx->loadCertificate(find_resource(kTestCert).c_str());
  serverCtx->loadTrustedCertificates(find_resource(kTestCA).c_str());
  serverCtx->setSupportedClientCertificateAuthorityNamesFromFile(
      find_resource(kTestCA).c_str());
  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
  clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");

  NetworkSocket fds[2];
  getfds(fds);

  AsyncSSLSocket::UniquePtr clientSock(
      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
  AsyncSSLSocket::UniquePtr serverSock(
      new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));

  SSLHandshakeClient client(std::move(clientSock), false, false);
  SSLHandshakeServer server(std::move(serverSock), false, false);

  eventBase.loop();

  EXPECT_FALSE(server.handshakeVerify_);
  EXPECT_FALSE(server.handshakeSuccess_);
  EXPECT_TRUE(server.handshakeError_);
  EXPECT_LE(0, client.handshakeTime.count());
  EXPECT_LE(0, server.handshakeTime.count());
}

static void makeNonBlockingPipe(int pipefds[2]) {
  if (pipe(pipefds) != 0) {
    throw std::runtime_error("Cannot create pipe");
  }
  if (::fcntl(pipefds[0], F_SETFL, O_NONBLOCK) != 0) {
    throw std::runtime_error("Cannot set pipe to nonblocking");
  }
  if (::fcntl(pipefds[1], F_SETFL, O_NONBLOCK) != 0) {
    throw std::runtime_error("Cannot set pipe to nonblocking");
  }
}

// Custom RSA private key encryption method
static int kRSAExIndex = -1;
static int kRSAEvbExIndex = -1;
static int kRSASocketExIndex = -1;
static constexpr StringPiece kEngineId = "AsyncSSLSocketTest";

static int customRsaPrivEnc(
    int flen,
    const unsigned char* from,
    unsigned char* to,
    RSA* rsa,
    int padding) {
  LOG(INFO) << "rsa_priv_enc";
  EventBase* asyncJobEvb =
      reinterpret_cast<EventBase*>(RSA_get_ex_data(rsa, kRSAEvbExIndex));
  CHECK(asyncJobEvb);

  RSA* actualRSA = reinterpret_cast<RSA*>(RSA_get_ex_data(rsa, kRSAExIndex));
  CHECK(actualRSA);

  AsyncSSLSocket* socket = reinterpret_cast<AsyncSSLSocket*>(
      RSA_get_ex_data(rsa, kRSASocketExIndex));

  ASYNC_JOB* job = ASYNC_get_current_job();
  if (job == nullptr) {
    throw std::runtime_error("Expected call in job context");
  }
  ASYNC_WAIT_CTX* waitctx = ASYNC_get_wait_ctx(job);
  OSSL_ASYNC_FD pipefds[2] = {0, 0};
  makeNonBlockingPipe(pipefds);
  if (!ASYNC_WAIT_CTX_set_wait_fd(
          waitctx, kEngineId.data(), pipefds[0], nullptr, nullptr)) {
    throw std::runtime_error("Cannot set wait fd");
  }
  int ret = 0;
  int* retptr = &ret;

  auto hand = folly::NetworkSocket::native_handle_type(pipefds[1]);
  auto asyncPipeWriter = folly::AsyncPipeWriter::newWriter(
      asyncJobEvb, folly::NetworkSocket(hand));

  if (socket) {
    LOG(INFO) << "Got a socket passed in, closing it...";
    socket->closeNow();
  }
  asyncJobEvb->runInEventBaseThread([retptr = retptr,
                                     flen = flen,
                                     from = from,
                                     to = to,
                                     padding = padding,
                                     actualRSA = actualRSA,
                                     writer = std::move(asyncPipeWriter)]() {
    LOG(INFO) << "Running job";
    *retptr = RSA_meth_get_priv_enc(RSA_PKCS1_OpenSSL())(
        flen, from, to, actualRSA, padding);
    LOG(INFO) << "Finished job, writing to pipe";
    uint8_t byte = *retptr > 0 ? 1 : 0;
    writer->write(nullptr, &byte, 1);
  });

  LOG(INFO) << "About to pause job";

  ASYNC_pause_job();
  LOG(INFO) << "Resumed job with ret: " << ret;
  return ret;
}

void rsaFree(void*, void* ptr, CRYPTO_EX_DATA*, int, long, void*) {
  LOG(INFO) << "RSA_free is called with ptr " << std::hex << ptr;
  if (ptr == nullptr) {
    LOG(INFO) << "Returning early from rsaFree because ptr is null";
    return;
  }
  RSA* rsa = (RSA*)ptr;
  auto meth = RSA_get_method(rsa);
  if (meth != RSA_get_default_method()) {
    auto nonconst = const_cast<RSA_METHOD*>(meth);
    RSA_meth_free(nonconst);
    RSA_set_method(rsa, RSA_get_default_method());
  }
  RSA_free(rsa);
}

struct RSAPointers {
  RSA* actualrsa{nullptr};
  RSA* dummyrsa{nullptr};
  RSA_METHOD* meth{nullptr};
};

inline void RSAPointersFree(RSAPointers* p) {
  if (p->meth && p->dummyrsa && RSA_get_method(p->dummyrsa) == p->meth) {
    RSA_set_method(p->dummyrsa, RSA_get_default_method());
  }

  if (p->meth) {
    LOG(INFO) << "Freeing meth";
    RSA_meth_free(p->meth);
  }

  if (p->actualrsa) {
    LOG(INFO) << "Freeing actualrsa";
    RSA_free(p->actualrsa);
  }

  if (p->dummyrsa) {
    LOG(INFO) << "Freeing dummyrsa";
    RSA_free(p->dummyrsa);
  }

  delete p;
}

using RSAPointersDeleter =
    folly::static_function_deleter<RSAPointers, RSAPointersFree>;

std::unique_ptr<RSAPointers, RSAPointersDeleter> setupCustomRSA(
    const char* certPath, const char* keyPath, EventBase* jobEvb) {
  auto certPEM = getFileAsBuf(certPath);
  auto keyPEM = getFileAsBuf(keyPath);

  ssl::BioUniquePtr certBio(
      BIO_new_mem_buf((void*)certPEM.data(), certPEM.size()));
  ssl::BioUniquePtr keyBio(
      BIO_new_mem_buf((void*)keyPEM.data(), keyPEM.size()));

  ssl::X509UniquePtr cert(
      PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr));
  ssl::EvpPkeyUniquePtr evpPkey(
      PEM_read_bio_PrivateKey(keyBio.get(), nullptr, nullptr, nullptr));
  ssl::EvpPkeyUniquePtr publicEvpPkey(X509_get_pubkey(cert.get()));

  std::unique_ptr<RSAPointers, RSAPointersDeleter> ret(new RSAPointers());

  RSA* actualrsa = EVP_PKEY_get1_RSA(evpPkey.get());
  LOG(INFO) << "actualrsa ptr " << std::hex << (void*)actualrsa;
  RSA* dummyrsa = EVP_PKEY_get1_RSA(publicEvpPkey.get());
  if (dummyrsa == nullptr) {
    throw std::runtime_error("Couldn't get RSA cert public factors");
  }
  RSA_METHOD* meth = RSA_meth_dup(RSA_get_default_method());
  if (meth == nullptr || RSA_meth_set1_name(meth, "Async RSA method") == 0 ||
      RSA_meth_set_priv_enc(meth, customRsaPrivEnc) == 0 ||
      RSA_meth_set_flags(meth, RSA_METHOD_FLAG_NO_CHECK) == 0) {
    throw std::runtime_error("Cannot create async RSA_METHOD");
  }
  RSA_set_method(dummyrsa, meth);
  RSA_set_flags(dummyrsa, RSA_FLAG_EXT_PKEY);

  kRSAExIndex = RSA_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
  kRSAEvbExIndex = RSA_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
  kRSASocketExIndex =
      RSA_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
  CHECK_NE(kRSAExIndex, -1);
  CHECK_NE(kRSAEvbExIndex, -1);
  CHECK_NE(kRSASocketExIndex, -1);
  RSA_set_ex_data(dummyrsa, kRSAExIndex, actualrsa);
  RSA_set_ex_data(dummyrsa, kRSAEvbExIndex, jobEvb);

  ret->actualrsa = actualrsa;
  ret->dummyrsa = dummyrsa;
  ret->meth = meth;

  return ret;
}

// TODO: disabled with ASAN doesn't play nice with ASYNC for some reason
#ifndef FOLLY_SANITIZE_ADDRESS
TEST(AsyncSSLSocketTest, OpenSSL110AsyncTest) {
  ASYNC_init_thread(1, 1);
  EventBase eventBase;
  ScopedEventBaseThread jobEvbThread;
  auto clientCtx = std::make_shared<SSLContext>();
  auto serverCtx = std::make_shared<SSLContext>();
  serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
  serverCtx->loadCertificate(find_resource(kTestCert).c_str());
  serverCtx->loadTrustedCertificates(find_resource(kTestCA).c_str());
  serverCtx->setSupportedClientCertificateAuthorityNamesFromFile(
      find_resource(kTestCA).c_str());

  auto rsaPointers =
      setupCustomRSA(kTestCert, kTestKey, jobEvbThread.getEventBase());
  CHECK(rsaPointers->dummyrsa);
  // up-refs dummyrsa
  SSL_CTX_use_RSAPrivateKey(serverCtx->getSSLCtx(), rsaPointers->dummyrsa);
  SSL_CTX_set_mode(serverCtx->getSSLCtx(), SSL_MODE_ASYNC);

  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
  clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");

  NetworkSocket fds[2];
  getfds(fds);

  AsyncSSLSocket::UniquePtr clientSock(
      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
  AsyncSSLSocket::UniquePtr serverSock(
      new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));

  SSLHandshakeClient client(std::move(clientSock), false, false);
  SSLHandshakeServer server(std::move(serverSock), false, false);

  eventBase.loop();

  EXPECT_TRUE(server.handshakeSuccess_);
  EXPECT_TRUE(client.handshakeSuccess_);
  ASYNC_cleanup_thread();
}

TEST(AsyncSSLSocketTest, OpenSSL110AsyncTestFailure) {
  ASYNC_init_thread(1, 1);
  EventBase eventBase;
  ScopedEventBaseThread jobEvbThread;
  auto clientCtx = std::make_shared<SSLContext>();
  auto serverCtx = std::make_shared<SSLContext>();
  serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
  serverCtx->loadCertificate(find_resource(kTestCert).c_str());
  serverCtx->loadTrustedCertificates(find_resource(kTestCA).c_str());
  serverCtx->setSupportedClientCertificateAuthorityNamesFromFile(
      find_resource(kTestCA).c_str());
  // Set the wrong key for the cert
  auto rsaPointers =
      setupCustomRSA(kTestCert, kClientTestKey, jobEvbThread.getEventBase());
  CHECK(rsaPointers->dummyrsa);
  SSL_CTX_use_RSAPrivateKey(serverCtx->getSSLCtx(), rsaPointers->dummyrsa);
  SSL_CTX_set_mode(serverCtx->getSSLCtx(), SSL_MODE_ASYNC);

  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
  clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");

  NetworkSocket fds[2];
  getfds(fds);

  AsyncSSLSocket::UniquePtr clientSock(
      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
  AsyncSSLSocket::UniquePtr serverSock(
      new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));

  SSLHandshakeClient client(std::move(clientSock), false, false);
  SSLHandshakeServer server(std::move(serverSock), false, false);

  eventBase.loop();

  EXPECT_TRUE(server.handshakeError_);
  EXPECT_TRUE(client.handshakeError_);
  ASYNC_cleanup_thread();
}

TEST(AsyncSSLSocketTest, OpenSSL110AsyncTestClosedWithCallbackPending) {
  ASYNC_init_thread(1, 1);
  EventBase eventBase;
  std::optional<EventBaseThread> jobEvbThread;
  jobEvbThread.emplace();
  auto clientCtx = std::make_shared<SSLContext>();
  auto serverCtx = std::make_shared<SSLContext>();
  serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
  serverCtx->loadCertificate(find_resource(kTestCert).c_str());
  serverCtx->loadTrustedCertificates(find_resource(kTestCA).c_str());
  serverCtx->setSupportedClientCertificateAuthorityNamesFromFile(
      find_resource(kTestCA).c_str());

  auto rsaPointers =
      setupCustomRSA(kTestCert, kTestKey, jobEvbThread->getEventBase());
  CHECK(rsaPointers->dummyrsa);
  // up-refs dummyrsa
  SSL_CTX_use_RSAPrivateKey(serverCtx->getSSLCtx(), rsaPointers->dummyrsa);
  SSL_CTX_set_mode(serverCtx->getSSLCtx(), SSL_MODE_ASYNC);

  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
  clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");

  NetworkSocket fds[2];
  getfds(fds);

  AsyncSSLSocket::UniquePtr clientSock(
      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
  AsyncSSLSocket::UniquePtr serverSock(
      new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));

  RSA_set_ex_data(rsaPointers->dummyrsa, kRSASocketExIndex, serverSock.get());

  SSLHandshakeClient client(std::move(clientSock), false, false);
  SSLHandshakeServer server(std::move(serverSock), false, false);

  eventBase.loop();

  EXPECT_TRUE(server.handshakeError_);
  EXPECT_TRUE(client.handshakeError_);
  ASYNC_cleanup_thread();
  jobEvbThread.reset();
}
#endif // FOLLY_SANITIZE_ADDRESS

TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
  using folly::ssl::OpenSSLUtils;
  auto cert = getFileAsBuf(kTestCert);
  auto key = getFileAsBuf(kTestKey);

  ssl::BioUniquePtr certBio(BIO_new(BIO_s_mem()));
  BIO_write(certBio.get(), cert.data(), cert.size());
  ssl::BioUniquePtr keyBio(BIO_new(BIO_s_mem()));
  BIO_write(keyBio.get(), key.data(), key.size());

  // Create SSL structs from buffers to get properties
  ssl::X509UniquePtr certStruct(
      PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr));
  ssl::EvpPkeyUniquePtr keyStruct(
      PEM_read_bio_PrivateKey(keyBio.get(), nullptr, nullptr, nullptr));
  certBio = nullptr;
  keyBio = nullptr;

  auto origCommonName = OpenSSLUtils::getCommonName(certStruct.get());
  auto origKeySize = EVP_PKEY_bits(keyStruct.get());
  certStruct = nullptr;
  keyStruct = nullptr;

  auto ctx = std::make_shared<SSLContext>();
  ctx->loadPrivateKeyFromBufferPEM(key);
  ctx->loadCertificateFromBufferPEM(cert);
  ctx->loadTrustedCertificates(find_resource(kTestCA).c_str());

  ssl::SSLUniquePtr ssl(ctx->createSSL());

  auto newCert = SSL_get_certificate(ssl.get());
  auto newKey = SSL_get_privatekey(ssl.get());

  // Get properties from SSL struct
  auto newCommonName = OpenSSLUtils::getCommonName(newCert);
  auto newKeySize = EVP_PKEY_bits(newKey);

  // Check that the key and cert have the expected properties
  EXPECT_EQ(origCommonName, newCommonName);
  EXPECT_EQ(origKeySize, newKeySize);
}

TEST(AsyncSSLSocketTest, MinWriteSizeTest) {
  EventBase eb;

  // Set up SSL context.
  auto sslContext = std::make_shared<SSLContext>();
  sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");

  // create SSL socket
  AsyncSSLSocket::UniquePtr socket(new AsyncSSLSocket(sslContext, &eb));

  EXPECT_EQ(1500, socket->getMinWriteSize());

  socket->setMinWriteSize(0);
  EXPECT_EQ(0, socket->getMinWriteSize());
  socket->setMinWriteSize(50000);
  EXPECT_EQ(50000, socket->getMinWriteSize());
}

class ReadCallbackTerminator : public ReadCallback {
 public:
  ReadCallbackTerminator(EventBase* base, WriteCallbackBase* wcb)
      : ReadCallback(wcb), base_(base) {}

  // Do not write data back, terminate the loop.
  void readDataAvailable(size_t len) noexcept override {
    std::cerr << "readDataAvailable, len " << len << std::endl;

    currentBuffer.length = len;

    buffers.push_back(currentBuffer);
    currentBuffer.reset();
    state = STATE_SUCCEEDED;

    socket_->setReadCB(nullptr);
    base_->terminateLoopSoon();
  }

 private:
  EventBase* base_;
};

/**
 * Test a full unencrypted codepath
 */
TEST(AsyncSSLSocketTest, UnencryptedTest) {
  EventBase base;

  auto clientCtx = std::make_shared<folly::SSLContext>();
  auto serverCtx = std::make_shared<folly::SSLContext>();
  NetworkSocket fds[2];
  getfds(fds);
  getctx(clientCtx, serverCtx);
  auto client =
      AsyncSSLSocket::newSocket(clientCtx, &base, fds[0], false, true);
  std::shared_ptr<AsyncSSLSocket> server =
      AsyncSSLSocket::newSocket(serverCtx, &base, fds[1], true, true);

  ReadCallbackTerminator readCallback(&base, nullptr);
  server->setReadCB(&readCallback);
  readCallback.setSocket(server);

  uint8_t buf[128];
  memset(buf, 'a', sizeof(buf));
  client->write(nullptr, buf, sizeof(buf));

  // Check that bytes are unencrypted
  char c;
  EXPECT_EQ(1, netops::recv(fds[1], &c, 1, MSG_PEEK));
  EXPECT_EQ('a', c);

  EventBaseAborter eba(&base, 3000);
  base.loop();

  EXPECT_EQ(1, readCallback.buffers.size());
  EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, client->getSSLState());

  server->setReadCB(&readCallback);

  // Unencrypted
  server->sslAccept(nullptr);
  client->sslConn(nullptr);

  // Do NOT wait for handshake, writing should be queued and happen after

  client->write(nullptr, buf, sizeof(buf));

  // Check that bytes are *not* unencrypted
  char c2;
  EXPECT_EQ(1, netops::recv(fds[1], &c2, 1, MSG_PEEK));
  EXPECT_NE('a', c2);

  base.loop();

  EXPECT_EQ(2, readCallback.buffers.size());
  EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
}

TEST(AsyncSSLSocketTest, ConnectUnencryptedTest) {
  auto clientCtx = std::make_shared<folly::SSLContext>();
  auto serverCtx = std::make_shared<folly::SSLContext>();
  getctx(clientCtx, serverCtx);

  WriteCallbackBase writeCallback;
  ReadCallback readCallback(&writeCallback);
  HandshakeCallback handshakeCallback(&readCallback);
  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
  TestSSLServer server(&acceptCallback);

  EventBase evb;
  std::shared_ptr<AsyncSSLSocket> socket =
      AsyncSSLSocket::newSocket(clientCtx, &evb, true);
  socket->connect(nullptr, server.getAddress(), 0);

  evb.loop();

  EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, socket->getSSLState());
  socket->sslConn(nullptr);
  evb.loop();
  EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, socket->getSSLState());

  // write()
  std::array<uint8_t, 128> buf;
  memset(buf.data(), 'a', buf.size());
  socket->write(nullptr, buf.data(), buf.size());

  // wait until the server has reflected all data written by the client
  SemiFuture<StateEnum> future = writeCallback.getSemiFuture();
  StateEnum state = std::move(future).get(std::chrono::seconds(3));
  EXPECT_EQ(state, STATE_SUCCEEDED);
  socket->close();
}

/**
 * Test acceptrunner in various situations
 */
TEST(AsyncSSLSocketTest, SSLAcceptRunnerBasic) {
  EventBase eventBase;
  auto clientCtx = std::make_shared<SSLContext>();
  auto serverCtx = std::make_shared<SSLContext>();
  serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
  serverCtx->loadPrivateKey(find_resource(kTestKey).c_str());
  serverCtx->loadCertificate(find_resource(kTestCert).c_str());

  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
  clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
  clientCtx->loadTrustedCertificates(find_resource(kTestCA).c_str());

  NetworkSocket fds[2];
  getfds(fds);

  AsyncSSLSocket::UniquePtr clientSock(
      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
  AsyncSSLSocket::UniquePtr serverSock(
      new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));

  serverCtx->sslAcceptRunner(std::make_unique<SSLAcceptEvbRunner>(&eventBase));

  SSLHandshakeClient client(std::move(clientSock), true, true);
  SSLHandshakeServer server(std::move(serverSock), true, true);

  eventBase.loop();

  EXPECT_TRUE(client.handshakeSuccess_);
  EXPECT_FALSE(client.handshakeError_);
  EXPECT_LE(0, client.handshakeTime.count());
  EXPECT_TRUE(server.handshakeSuccess_);
  EXPECT_FALSE(server.handshakeError_);
  EXPECT_LE(0, server.handshakeTime.count());
}

TEST(AsyncSSLSocketTest, SSLAcceptRunnerAcceptError) {
  EventBase eventBase;
  auto clientCtx = std::make_shared<SSLContext>();
  auto serverCtx = std::make_shared<SSLContext>();
  serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
  serverCtx->loadPrivateKey(find_resource(kTestKey).c_str());
  serverCtx->loadCertificate(find_resource(kTestCert).c_str());

  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
  clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
  clientCtx->loadTrustedCertificates(find_resource(kTestCA).c_str());

  NetworkSocket fds[2];
  getfds(fds);

  AsyncSSLSocket::UniquePtr clientSock(
      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
  AsyncSSLSocket::UniquePtr serverSock(
      new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));

  serverCtx->sslAcceptRunner(
      std::make_unique<SSLAcceptErrorRunner>(&eventBase));

  SSLHandshakeClient client(std::move(clientSock), true, true);
  SSLHandshakeServer server(std::move(serverSock), true, true);

  eventBase.loop();

  EXPECT_FALSE(client.handshakeSuccess_);
  EXPECT_TRUE(client.handshakeError_);
  EXPECT_FALSE(server.handshakeSuccess_);
  EXPECT_TRUE(server.handshakeError_);
}

TEST(AsyncSSLSocketTest, SSLAcceptRunnerAcceptClose) {
  EventBase eventBase;
  auto clientCtx = std::make_shared<SSLContext>();
  auto serverCtx = std::make_shared<SSLContext>();
  serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
  serverCtx->loadPrivateKey(find_resource(kTestKey).c_str());
  serverCtx->loadCertificate(find_resource(kTestCert).c_str());

  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
  clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
  clientCtx->loadTrustedCertificates(find_resource(kTestCA).c_str());

  NetworkSocket fds[2];
  getfds(fds);

  AsyncSSLSocket::UniquePtr clientSock(
      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
  AsyncSSLSocket::UniquePtr serverSock(
      new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));

  serverCtx->sslAcceptRunner(
      std::make_unique<SSLAcceptCloseRunner>(&eventBase, serverSock.get()));

  SSLHandshakeClient client(std::move(clientSock), true, true);
  SSLHandshakeServer server(std::move(serverSock), true, true);

  eventBase.loop();

  EXPECT_FALSE(client.handshakeSuccess_);
  EXPECT_TRUE(client.handshakeError_);
  EXPECT_FALSE(server.handshakeSuccess_);
  EXPECT_TRUE(server.handshakeError_);
}

TEST(AsyncSSLSocketTest, SSLAcceptRunnerAcceptDestroy) {
  EventBase eventBase;
  auto clientCtx = std::make_shared<SSLContext>();
  auto serverCtx = std::make_shared<SSLContext>();
  serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
  serverCtx->loadPrivateKey(find_resource(kTestKey).c_str());
  serverCtx->loadCertificate(find_resource(kTestCert).c_str());

  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
  clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
  clientCtx->loadTrustedCertificates(find_resource(kTestCA).c_str());

  NetworkSocket fds[2];
  getfds(fds);

  AsyncSSLSocket::UniquePtr clientSock(
      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
  AsyncSSLSocket::UniquePtr serverSock(
      new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));

  SSLHandshakeClient client(std::move(clientSock), true, true);
  SSLHandshakeServer server(std::move(serverSock), true, true);

  serverCtx->sslAcceptRunner(
      std::make_unique<SSLAcceptDestroyRunner>(&eventBase, &server));

  eventBase.loop();

  EXPECT_FALSE(client.handshakeSuccess_);
  EXPECT_TRUE(client.handshakeError_);
  EXPECT_FALSE(server.handshakeSuccess_);
  EXPECT_TRUE(server.handshakeError_);
}

TEST(AsyncSSLSocketTest, SSLAcceptRunnerFiber) {
  EventBase eventBase;
  auto clientCtx = std::make_shared<SSLContext>();
  auto serverCtx = std::make_shared<SSLContext>();
  serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
  serverCtx->loadPrivateKey(find_resource(kTestKey).c_str());
  serverCtx->loadCertificate(find_resource(kTestCert).c_str());

  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
  clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
  clientCtx->loadTrustedCertificates(find_resource(kTestCA).c_str());

  NetworkSocket fds[2];
  getfds(fds);

  AsyncSSLSocket::UniquePtr clientSock(
      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
  AsyncSSLSocket::UniquePtr serverSock(
      new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));

  SSLHandshakeClient client(std::move(clientSock), true, true);
  SSLHandshakeServer server(std::move(serverSock), true, true);

  serverCtx->sslAcceptRunner(
      std::make_unique<SSLAcceptFiberRunner>(&eventBase));

  eventBase.loop();

  EXPECT_TRUE(client.handshakeSuccess_);
  EXPECT_FALSE(client.handshakeError_);
  EXPECT_TRUE(server.handshakeSuccess_);
  EXPECT_FALSE(server.handshakeError_);
}

static int newCloseCb(SSL* ssl, SSL_SESSION*) {
  AsyncSSLSocket::getFromSSL(ssl)->closeNow();

  // We do not want ownership over the supplied SSL_SESSION*, we return 0 to
  // tell OpenSSL to free it on our behalf.
  return 0;
}

static SSL_SESSION* getCloseCb(SSL* ssl, const unsigned char*, int, int*) {
  AsyncSSLSocket::getFromSSL(ssl)->closeNow();
  return nullptr;
} // namespace

TEST(AsyncSSLSocketTest, SSLAcceptRunnerFiberCloseSessionCb) {
  EventBase eventBase;
  auto clientCtx = std::make_shared<SSLContext>();
  auto serverCtx = std::make_shared<SSLContext>();
  serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
  serverCtx->loadPrivateKey(find_resource(kTestKey).c_str());
  serverCtx->loadCertificate(find_resource(kTestCert).c_str());
  SSL_CTX_set_session_cache_mode(
      serverCtx->getSSLCtx(),
      SSL_SESS_CACHE_NO_INTERNAL | SSL_SESS_CACHE_SERVER);
  SSL_CTX_sess_set_new_cb(serverCtx->getSSLCtx(), &newCloseCb);
  SSL_CTX_sess_set_get_cb(serverCtx->getSSLCtx(), &getCloseCb);
  serverCtx->sslAcceptRunner(
      std::make_unique<SSLAcceptFiberRunner>(&eventBase));

  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
  clientCtx->ciphers("AES128-SHA256");
  clientCtx->loadTrustedCertificates(find_resource(kTestCA).c_str());
  clientCtx->setOptions(SSL_OP_NO_TICKET);

  NetworkSocket fds[2];
  getfds(fds);

  AsyncSSLSocket::UniquePtr clientSock(
      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
  AsyncSSLSocket::UniquePtr serverSock(
      new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));

  SSLHandshakeClient client(std::move(clientSock), true, true);
  SSLHandshakeServer server(std::move(serverSock), true, true);

  eventBase.loop();

  // As close() is called during session callbacks, client sees it as a
  // successful connection
  EXPECT_TRUE(client.handshakeSuccess_);
  EXPECT_FALSE(client.handshakeError_);
  EXPECT_FALSE(server.handshakeSuccess_);
  EXPECT_TRUE(server.handshakeError_);
}

TEST(AsyncSSLSocketTest, ConnResetErrorString) {
  // Start listening on a local port
  WriteCallbackBase writeCallback;
  WriteErrorCallback readCallback(&writeCallback);
  HandshakeCallback handshakeCallback(
      &readCallback, HandshakeCallback::EXPECT_ERROR);
  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
  TestSSLServer server(&acceptCallback);

  auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
  socket->open();
  uint8_t buf[3] = {0x16, 0x03, 0x01};
  socket->write(buf, sizeof(buf));
  socket->closeWithReset();

  handshakeCallback.waitForHandshake();
  EXPECT_NE(
      handshakeCallback.errorString_.find("Network error"), std::string::npos);
  EXPECT_NE(handshakeCallback.errorString_.find("104"), std::string::npos);
}

TEST(AsyncSSLSocketTest, ConnEOFErrorString) {
  // Start listening on a local port
  WriteCallbackBase writeCallback;
  WriteErrorCallback readCallback(&writeCallback);
  HandshakeCallback handshakeCallback(
      &readCallback, HandshakeCallback::EXPECT_ERROR);
  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
  TestSSLServer server(&acceptCallback);

  auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
  socket->open();
  uint8_t buf[3] = {0x16, 0x03, 0x01};
  socket->write(buf, sizeof(buf));
  socket->close();

  handshakeCallback.waitForHandshake();
  EXPECT_NE(
      handshakeCallback.errorString_.find("Network error"), std::string::npos);
}

TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) {
  // Start listening on a local port
  WriteCallbackBase writeCallback;
  WriteErrorCallback readCallback(&writeCallback);
  HandshakeCallback handshakeCallback(
      &readCallback, HandshakeCallback::EXPECT_ERROR);
  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
  TestSSLServer server(&acceptCallback);

  auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
  socket->open();
  uint8_t buf[256] = {0x16, 0x03};
  memset(buf + 2, 'a', sizeof(buf) - 2);
  socket->write(buf, sizeof(buf));
  socket->close();

  handshakeCallback.waitForHandshake();
  EXPECT_NE(
      handshakeCallback.errorString_.find("SSL routines"), std::string::npos);
#if defined(OPENSSL_IS_BORINGSSL)
  EXPECT_NE(
      handshakeCallback.errorString_.find("ENCRYPTED_LENGTH_TOO_LONG"),
      std::string::npos);
#else
  EXPECT_NE(
      handshakeCallback.errorString_.find("packet length too long"),
      std::string::npos);
#endif
}

TEST(AsyncSSLSocketTest, TestSSLCipherCodeToNameMap) {
  using folly::ssl::OpenSSLUtils;
  EXPECT_EQ(
      OpenSSLUtils::getCipherName(0xc02c), "ECDHE-ECDSA-AES256-GCM-SHA384");
  // TLS_DHE_RSA_WITH_DES_CBC_SHA - We shouldn't be building with this
  EXPECT_EQ(OpenSSLUtils::getCipherName(0x0015), "");
  // This indicates TLS_EMPTY_RENEGOTIATION_INFO_SCSV, no name expected
  EXPECT_EQ(OpenSSLUtils::getCipherName(0x00ff), "");
}

#if defined __linux__
/**
 * Ensure TransparentTLS flag is disabled with AsyncSSLSocket
 */
TEST(AsyncSSLSocketTest, TTLSDisabled) {
  // clear all setsockopt tracking history
  globalStatic.reset();

  // Start listening on a local port
  WriteCallbackBase writeCallback;
  ReadCallback readCallback(&writeCallback);
  HandshakeCallback handshakeCallback(&readCallback);
  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
  TestSSLServer server(&acceptCallback);

  // Set up SSL context.
  auto sslContext = std::make_shared<SSLContext>();

  // connect
  auto socket =
      std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
  socket->open();

  EXPECT_EQ(1, globalStatic.ttlsDisabledSet.count(socket->getNetworkSocket()));

  // write()
  std::array<uint8_t, 128> buf;
  memset(buf.data(), 'a', buf.size());
  socket->write(buf.data(), buf.size());

  // wait until the server has reflected all data written by the client
  SemiFuture<StateEnum> future = writeCallback.getSemiFuture();
  StateEnum state = std::move(future).get(std::chrono::seconds(3));
  EXPECT_EQ(state, STATE_SUCCEEDED);

  // close()
  socket->close();
}
#endif

#if FOLLY_ALLOW_TFO

class MockAsyncTFOSSLSocket : public AsyncSSLSocket {
 public:
  using UniquePtr = std::unique_ptr<MockAsyncTFOSSLSocket, Destructor>;

  explicit MockAsyncTFOSSLSocket(
      std::shared_ptr<folly::SSLContext> sslCtx, EventBase* evb)
      : AsyncSSLSocket(sslCtx, evb) {}

  MOCK_METHOD(
      ssize_t,
      tfoSendMsg,
      (NetworkSocket fd, struct msghdr* msg, int msg_flags));
};

#if defined __linux__
/**
 * Ensure TransparentTLS flag is disabled with AsyncSSLSocket + TFO
 */
TEST(AsyncSSLSocketTest, TTLSDisabledWithTFO) {
  // clear all setsockopt tracking history
  globalStatic.reset();

  // Start listening on a local port
  WriteCallbackBase writeCallback;
  ReadCallback readCallback(&writeCallback);
  HandshakeCallback handshakeCallback(&readCallback);
  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
  TestSSLServer server(&acceptCallback, true);

  // Set up SSL context.
  auto sslContext = std::make_shared<SSLContext>();

  // connect
  auto socket =
      std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
  socket->enableTFO();
  socket->open();

  EXPECT_EQ(1, globalStatic.ttlsDisabledSet.count(socket->getNetworkSocket()));

  // write()
  std::array<uint8_t, 128> buf;
  memset(buf.data(), 'a', buf.size());
  socket->write(buf.data(), buf.size());

  // wait until the server has reflected all data written by the client
  SemiFuture<StateEnum> future = writeCallback.getSemiFuture();
  StateEnum state = std::move(future).get(std::chrono::seconds(3));
  EXPECT_EQ(state, STATE_SUCCEEDED);

  // close()
  socket->close();
}
#endif

/**
 * Test connecting to, writing to, reading from, and closing the
 * connection to the SSL server with TFO.
 */
TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFO) {
  // Start listening on a local port
  WriteCallbackBase writeCallback;
  ReadCallback readCallback(&writeCallback);
  HandshakeCallback handshakeCallback(&readCallback);
  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
  TestSSLServer server(&acceptCallback, true);

  // Set up SSL context.
  auto sslContext = std::make_shared<SSLContext>();

  // connect
  auto socket =
      std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
  socket->enableTFO();
  socket->open();

  // write()
  std::array<uint8_t, 128> buf;
  memset(buf.data(), 'a', buf.size());
  socket->write(buf.data(), buf.size());

  // read()
  std::array<uint8_t, 128> readbuf;
  uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
  EXPECT_EQ(bytesRead, 128);
  EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);

  // close()
  socket->close();
}

/**
 * Test connecting to, writing to, reading from, and closing the
 * connection to the SSL server with TFO.
 */
TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOWithTFOServerDisabled) {
  // Start listening on a local port
  WriteCallbackBase writeCallback;
  ReadCallback readCallback(&writeCallback);
  HandshakeCallback handshakeCallback(&readCallback);
  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
  TestSSLServer server(&acceptCallback);

  // Set up SSL context.
  auto sslContext = std::make_shared<SSLContext>();

  // connect
  auto socket =
      std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
  socket->enableTFO();
  socket->open();

  // write()
  std::array<uint8_t, 128> buf;
  memset(buf.data(), 'a', buf.size());
  socket->write(buf.data(), buf.size());

  // read()
  std::array<uint8_t, 128> readbuf;
  uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
  EXPECT_EQ(bytesRead, 128);
  EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);

  // close()
  socket->close();
}

class ConnCallback : public AsyncSocket::ConnectCallback {
 public:
  void connectSuccess() noexcept override { state = State::SUCCESS; }

  void connectErr(const AsyncSocketException& ex) noexcept override {
    state = State::ERROR;
    error = ex.what();
  }

  enum class State { WAITING, SUCCESS, ERROR };

  State state{State::WAITING};
  std::string error;
};

template <class Cardinality>
MockAsyncTFOSSLSocket::UniquePtr setupSocketWithFallback(
    EventBase* evb, const SocketAddress& address, Cardinality cardinality) {
  // Set up SSL context.
  auto sslContext = std::make_shared<SSLContext>();

  // connect
  auto socket = MockAsyncTFOSSLSocket::UniquePtr(
      new MockAsyncTFOSSLSocket(sslContext, evb));
  socket->enableTFO();

  EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
      .Times(cardinality)
      .WillOnce(Invoke([&](NetworkSocket fd, struct msghdr*, int) {
        sockaddr_storage addr;
        auto len = address.getAddress(&addr);
        return netops::connect(fd, (const struct sockaddr*)&addr, len);
      }));
  return socket;
}

TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOFallback) {
  if (!folly::test::isTFOAvailable()) {
    GTEST_SKIP() << "TFO not supported.";
  }

  // Start listening on a local port
  WriteCallbackBase writeCallback;
  ReadCallback readCallback(&writeCallback);
  HandshakeCallback handshakeCallback(&readCallback);
  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
  TestSSLServer server(&acceptCallback, true);

  EventBase evb;

  auto socket = setupSocketWithFallback(&evb, server.getAddress(), 1);
  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30);

  evb.loop();
  EXPECT_EQ(ConnCallback::State::SUCCESS, ccb.state);

  evb.runInEventBaseThread([&] { socket->detachEventBase(); });
  evb.loop();

  BlockingSocket sock(std::move(socket));
  // write()
  std::array<uint8_t, 128> buf;
  memset(buf.data(), 'a', buf.size());
  sock.write(buf.data(), buf.size());

  // read()
  std::array<uint8_t, 128> readbuf;
  uint32_t bytesRead = sock.readAll(readbuf.data(), readbuf.size());
  EXPECT_EQ(bytesRead, 128);
  EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);

  // close()
  sock.close();
}

#if !defined(OPENSSL_IS_BORINGSSL)
TEST(AsyncSSLSocketTest, ConnectTFOTimeout) {
  // Start listening on a local port
  ConnectTimeoutCallback acceptCallback;
  TestSSLServer server(&acceptCallback, true);

  // Set up SSL context.
  auto sslContext = std::make_shared<SSLContext>();

  // connect
  auto socket =
      std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
  socket->enableTFO();
  EXPECT_THROW(
      socket->open(std::chrono::milliseconds(20)), AsyncSocketException);
}
#endif

#if !defined(OPENSSL_IS_BORINGSSL)
TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) {
  // Start listening on a local port
  ConnectTimeoutCallback acceptCallback;
  TestSSLServer server(&acceptCallback, true);

  EventBase evb;

  auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
  ConnCallback ccb;
  // Set a short timeout
  socket->connect(&ccb, server.getAddress(), 1);

  evb.loop();
  EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
}
#endif

TEST(AsyncSSLSocketTest, HandshakeTFOFallbackTimeout) {
  // Start listening on a local port
  EmptyReadCallback readCallback;
  HandshakeCallback handshakeCallback(
      &readCallback, HandshakeCallback::EXPECT_ERROR);
  HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
  TestSSLServer server(&acceptCallback, true);

  EventBase evb;

  auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 100);

  evb.loop();
  EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
  EXPECT_THAT(ccb.error, testing::HasSubstr("SSL connect timed out"));
}

TEST(AsyncSSLSocketTest, HandshakeTFORefused) {
  // Start listening on a local port
  EventBase evb;

  // Hopefully nothing is listening on this address
  SocketAddress addr("127.0.0.1", 65535);
  auto socket = setupSocketWithFallback(&evb, addr, AtMost(1));
  ConnCallback ccb;
  socket->connect(&ccb, addr, 100);

  evb.loop();
  EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
  EXPECT_THAT(ccb.error, testing::HasSubstr("refused"));
}

TEST(AsyncSSLSocketTest, TestPreReceivedData) {
  EventBase eventBase;
  auto clientCtx = std::make_shared<SSLContext>();
  auto dfServerCtx = std::make_shared<SSLContext>();
  std::array<NetworkSocket, 2> fds;
  getfds(fds.data());
  getctx(clientCtx, dfServerCtx);

  AsyncSSLSocket::UniquePtr clientSockPtr(
      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
  AsyncSSLSocket::UniquePtr serverSockPtr(
      new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
  auto clientSock = clientSockPtr.get();
  auto serverSock = serverSockPtr.get();
  SSLHandshakeClient client(std::move(clientSockPtr), true, true);

  // Steal some data from the server.
  std::array<uint8_t, 10> buf;
  auto bytesReceived = netops::recv(fds[1], buf.data(), buf.size(), 0);
  checkUnixError(bytesReceived, "recv failed");

  serverSock->setPreReceivedData(
      IOBuf::wrapBuffer(ByteRange(buf.data(), bytesReceived)));
  SSLHandshakeServer server(std::move(serverSockPtr), true, true);
  // The order in which the client and server handshake completion callbacks
  // fire differs between TLS 1.2 and TLS 1.3. Loop until both are complete to
  // ensure this works regardless of TLS version.
  while ((!client.handshakeSuccess_ && !client.handshakeError_) ||
         (!server.handshakeSuccess_ && !server.handshakeError_)) {
    eventBase.loopOnce();
  }

  EXPECT_TRUE(client.handshakeSuccess_);
  EXPECT_TRUE(server.handshakeSuccess_);
  EXPECT_EQ(
      serverSock->getRawBytesReceived(), clientSock->getRawBytesWritten());
}

TEST(AsyncSSLSocketTest, TestMoveFromAsyncSocket) {
  EventBase eventBase;
  auto clientCtx = std::make_shared<SSLContext>();
  auto dfServerCtx = std::make_shared<SSLContext>();
  std::array<NetworkSocket, 2> fds;
  getfds(fds.data());
  getctx(clientCtx, dfServerCtx);

  AsyncSSLSocket::UniquePtr clientSockPtr(
      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
  AsyncSocket::UniquePtr serverSockPtr(new AsyncSocket(&eventBase, fds[1]));
  auto clientSock = clientSockPtr.get();
  auto serverSock = serverSockPtr.get();
  SSLHandshakeClient client(std::move(clientSockPtr), true, true);

  // Steal some data from the server.
  std::array<uint8_t, 10> buf;
  auto bytesReceived = netops::recv(fds[1], buf.data(), buf.size(), 0);
  checkUnixError(bytesReceived, "recv failed");

  serverSock->setPreReceivedData(
      IOBuf::wrapBuffer(ByteRange(buf.data(), bytesReceived)));
  AsyncSSLSocket::UniquePtr serverSSLSockPtr(
      new AsyncSSLSocket(dfServerCtx, std::move(serverSockPtr), true));
  auto serverSSLSock = serverSSLSockPtr.get();
  SSLHandshakeServer server(std::move(serverSSLSockPtr), true, true);
  // The order in which the client and server handshake completion callbacks
  // fire differs between TLS 1.2 and TLS 1.3. Loop until both are complete to
  // ensure this works regardless of TLS version.
  while ((!client.handshakeSuccess_ && !client.handshakeError_) ||
         (!server.handshakeSuccess_ && !server.handshakeError_)) {
    eventBase.loopOnce();
  }

  EXPECT_TRUE(client.handshakeSuccess_);
  EXPECT_TRUE(server.handshakeSuccess_);
  EXPECT_EQ(
      serverSSLSock->getRawBytesReceived(), clientSock->getRawBytesWritten());
}

TEST(AsyncSSLSocketTest, TestNullConnectCallbackError) {
  EventBase eventBase;
  auto clientCtx = std::make_shared<SSLContext>();
  auto serverCtx = std::make_shared<SSLContext>();
  std::array<NetworkSocket, 2> fds;
  getfds(fds.data());
  getctx(clientCtx, serverCtx);
  // Mismatched ciphers
  clientCtx->setCiphersuitesOrThrow("TLS_AES_256_GCM_SHA384");
  serverCtx->setCiphersuitesOrThrow("TLS_AES_128_GCM_SHA256");

  AsyncSSLSocket::UniquePtr clientSockPtr(
      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
  AsyncSSLSocket::UniquePtr serverSockPtr(
      new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
  auto clientSock = clientSockPtr.get();

  clientSock->sslConn(nullptr);

  ExpectSSLWriteErrorCallback wcb;
  wcb.setSocket(std::move(clientSockPtr));
  clientSock->write(&wcb, "abc", 3);

  SSLHandshakeServer server(std::move(serverSockPtr), true, true);
  eventBase.loop();

  EXPECT_FALSE(server.handshakeSuccess_);
}

TEST(AsyncSSLSocketTest, TestSSLSetClientOptionsP256) {
  EventBase evb;
  std::array<NetworkSocket, 2> fds;
  getfds(fds.data());
  auto clientCtx = std::make_shared<SSLContext>();
  auto serverCtx = std::make_shared<SSLContext>();
  serverCtx->setSupportedGroups(std::vector<std::string>({"P-256"}));
  serverCtx->loadCertificate(find_resource(kTestCert).c_str());
  serverCtx->loadPrivateKey(find_resource(kTestKey).c_str());
  ssl::SSLCommonOptions::setClientOptions(*clientCtx);

  auto clientSocket =
      AsyncSSLSocket::newSocket(clientCtx, &evb, fds[0], false, true);
  std::shared_ptr<AsyncSSLSocket> serverSocket =
      AsyncSSLSocket::newSocket(serverCtx, &evb, fds[1], true, true);

  ReadCallbackTerminator readCallback(&evb, nullptr);
  serverSocket->setReadCB(&readCallback);
  readCallback.setSocket(serverSocket);
  serverSocket->sslAccept(nullptr);
  clientSocket->sslConn(nullptr);

  uint8_t buf[128];
  memset(buf, 'a', sizeof(buf));
  clientSocket->write(nullptr, buf, sizeof(buf));

  EventBaseAborter eba(&evb, 3000);

  evb.loop();

  auto sharedGroupName = serverSocket->getNegotiatedGroup();
  EXPECT_THAT(sharedGroupName, testing::HasSubstr("prime256v1"));
}

TEST(AsyncSSLSocketTest, TestSSLSetClientOptionsX25519) {
  EventBase evb;
  std::array<NetworkSocket, 2> fds;
  getfds(fds.data());
  auto clientCtx = std::make_shared<SSLContext>();
  auto serverCtx = std::make_shared<SSLContext>();
  serverCtx->setSupportedGroups(std::vector<std::string>({"X25519", "P-256"}));
  serverCtx->loadCertificate(find_resource(kTestCert).c_str());
  serverCtx->loadPrivateKey(find_resource(kTestKey).c_str());
  ssl::SSLCommonOptions::setClientOptions(*clientCtx);

  auto clientSocket =
      AsyncSSLSocket::newSocket(clientCtx, &evb, fds[0], false, true);
  std::shared_ptr<AsyncSSLSocket> serverSocket =
      AsyncSSLSocket::newSocket(serverCtx, &evb, fds[1], true, true);

  ReadCallbackTerminator readCallback(&evb, nullptr);
  serverSocket->setReadCB(&readCallback);
  readCallback.setSocket(serverSocket);
  serverSocket->sslAccept(nullptr);
  clientSocket->sslConn(nullptr);

  uint8_t buf[128];
  memset(buf, 'a', sizeof(buf));
  clientSocket->write(nullptr, buf, sizeof(buf));

  EventBaseAborter eba(&evb, 3000);

  evb.loop();

  auto sharedGroupName = serverSocket->getNegotiatedGroup();
  EXPECT_THAT(sharedGroupName, testing::HasSubstr("X25519"));
}

/**
 * Test overriding the flags passed to "sendmsg()" system call,
 * and verifying that write requests fail properly.
 */
TEST(AsyncSSLSocketTest, SendMsgParamsCallback) {
  // Start listening on a local port
  SendMsgFlagsCallback msgCallback;
  ExpectWriteErrorCallback writeCallback(&msgCallback);
  ReadCallback readCallback(&writeCallback);
  HandshakeCallback handshakeCallback(&readCallback);
  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
  TestSSLServer server(&acceptCallback);

  // Set up SSL context.
  auto sslContext = std::make_shared<SSLContext>();
  sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");

  // connect
  auto socket =
      std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
  socket->open();

  // Setting flags to "-1" to trigger "Invalid argument" error
  // on attempt to use this flags in sendmsg() system call.
  msgCallback.resetFlags(-1);

  // write()
  std::vector<uint8_t> buf(128, 'a');
  ASSERT_EQ(socket->write(buf.data(), buf.size()), buf.size());

  // close()
  socket->close();

  cerr << "SendMsgParamsCallback test completed" << endl;
}

#if FOLLY_HAVE_SO_TIMESTAMPING

class AsyncSSLSocketByteEventTest : public ::testing::Test {
 protected:
  using MockDispatcher = ::testing::NiceMock<netops::test::MockDispatcher>;
  using TestObserver =
      test::MockAsyncSocketLegacyLifecycleObserverForByteEvents;
  using ByteEventType = AsyncSocket::ByteEvent::Type;

  /**
   * Components of a client connection to TestServer.
   *
   * Includes EventBase, client's AsyncSocket.
   */
  class ClientConn {
   public:
    explicit ClientConn(
        std::shared_ptr<TestSSLServer> server,
        std::shared_ptr<AsyncSSLSocket> socket = nullptr)
        : server_(std::move(server)), socket_(std::move(socket)) {
      if (!socket_) {
        socket_ = AsyncSSLSocket::newSocket(getSslContext(), &getEventBase());
      }
      socket_->setOverrideNetOpsDispatcher(netOpsDispatcher_);
      netOpsDispatcher_->forwardToDefaultImpl();
    }

    ~ClientConn() {
      if (socket_) {
        socket_->close();
      }
    }

    void connect() {
      CHECK_NOTNULL(socket_.get());
      CHECK_NOTNULL(socket_->getEventBase());
      socket_->connect(&connCb_, server_->getAddress(), 30);
      socket_->getEventBase()->loop();
      ASSERT_EQ(connCb_.state, ConnCallback::State::SUCCESS);
      setReadCb();
    }

    void waitForHandshake() {
      while (connCb_.state == ConnCallback::State::WAITING) {
        socket_->getEventBase()->loopOnce();
      }
    }

    void setReadCb() {
      // Due to how libevent works, we currently need to be subscribed to
      // EV_READ events in order to get error messages.
      //
      // TODO(bschlinker): Resolve this with libevent modification.
      // See https://github.com/libevent/libevent/issues/1038 for details.
      socket_->setReadCB(&readCb_);
      readCb_.setSocket(socket_);
    }

    std::shared_ptr<NiceMock<TestObserver>> attachObserver(
        bool enableByteEvents) {
      auto observer = AsyncSSLSocketByteEventTest::attachObserver(
          socket_.get(), enableByteEvents);
      observers.push_back(observer);
      return observer;
    }

    /**
     * Write to client socket, echo at server, and wait for echo at client.
     *
     * Waiting for echo at client ensures that we have given opportunity for
     * timestamps to be generated by the kernel.
     */
    void writeAndReflect(
        const std::vector<uint8_t>& wbuf, const WriteFlags writeFlags) {
      CHECK_NOTNULL(socket_.get());
      CHECK_NOTNULL(socket_->getEventBase());

      // write to the client socket
      WriteCallbackBase wcb;
      socket_->write(&wcb, wbuf.data(), wbuf.size(), writeFlags);
      while (wcb.state == STATE_WAITING) {
        socket_->getEventBase()->loopOnce();
      }
      ASSERT_EQ(wcb.state, STATE_SUCCEEDED);

      // TestSSLServer reads and reflects for us

      // read reflection at client
      while (wbuf.size() != readCb_.dataRead()) {
        socket_->getEventBase()->loopOnce();
      }
      readCb_.verifyData(wbuf.data(), wbuf.size());
      readCb_.clearData();
    }

    std::shared_ptr<AsyncSSLSocket> getRawSocket() { return socket_; }

    std::shared_ptr<SSLContext> getSslContext() {
      static std::shared_ptr<SSLContext> sslContext = initSslContext();
      return sslContext;
    }

    EventBase& getEventBase() {
      static EventBase evb; // use same EventBase for all client sockets
      return evb;
    }

    void netOpsExpectTimestampingSetSockOpt() {
      // must whitelist other calls
      EXPECT_CALL(*netOpsDispatcher_, setsockopt(_, _, _, _, _))
          .Times(AnyNumber());
      EXPECT_CALL(
          *netOpsDispatcher_, setsockopt(_, SOL_SOCKET, SO_TIMESTAMPING, _, _))
          .Times(1);
    }

    void netOpsExpectNoTimestampingSetSockOpt() {
      // must whitelist other calls
      EXPECT_CALL(*netOpsDispatcher_, setsockopt(_, _, _, _, _))
          .Times(AnyNumber());
      EXPECT_CALL(
          *netOpsDispatcher_, setsockopt(_, SOL_SOCKET, SO_TIMESTAMPING, _, _))
          .Times(0);
    }

    void netOpsExpectWriteWithFlags(WriteFlags writeFlags) {
      EXPECT_CALL(*netOpsDispatcher_, sendmsg(_, _, _))
          .WillOnce(Invoke(
              [this, writeFlags](
                  NetworkSocket socket, const msghdr* message, int flags) {
                EXPECT_EQ(writeFlags, getMsgWriteFlags(*message));
                return netOpsDispatcher_->netops::Dispatcher::sendmsg(
                    socket, message, flags);
              }));
    }

    void netOpsVerifyAndClearExpectations() {
      Mock::VerifyAndClearExpectations(netOpsDispatcher_.get());
    }

    /**
     * Static utilities.
     */
    static std::shared_ptr<SSLContext> initSslContext() {
      auto sslContext = std::make_shared<SSLContext>();
      sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
      return sslContext;
    }

   private:
    // server
    std::shared_ptr<TestSSLServer> server_;

    // managed observers
    std::vector<std::shared_ptr<TestObserver>> observers;

    // socket components
    ConnCallback connCb_;
    ReadCallback readCb_;
    std::shared_ptr<MockDispatcher> netOpsDispatcher_{
        std::make_shared<MockDispatcher>()};
    std::shared_ptr<AsyncSSLSocket> socket_;
  };

  ClientConn getClientConn() { return ClientConn(server_); }

  void SetUp() override {
    serverWriteCb_ = std::make_unique<WriteCallbackBase>();
    serverReadCb_ = std::make_unique<ReadCallback>(serverWriteCb_.get());
    serverHandshakeCb_ =
        std::make_unique<HandshakeCallback>(serverReadCb_.get());
    serverAcceptCb_ =
        std::make_unique<SSLServerAcceptCallback>(serverHandshakeCb_.get());
    server_ = std::make_shared<TestSSLServer>(serverAcceptCb_.get());
  }

  /**
   * Static utility functions.
   */

  static std::shared_ptr<NiceMock<TestObserver>> attachObserver(
      AsyncSocket* socket, bool enableByteEvents) {
    AsyncSocket::LegacyLifecycleObserver::Config config = {};
    config.byteEvents = enableByteEvents;
    return std::make_shared<NiceMock<TestObserver>>(socket, config);
  }

  static WriteFlags getMsgWriteFlags(const struct msghdr& msg) {
    const struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
    if (!cmsg || cmsg->cmsg_level != SOL_SOCKET ||
        cmsg->cmsg_type != SO_TIMESTAMPING ||
        cmsg->cmsg_len != CMSG_LEN(sizeof(uint32_t))) {
      return WriteFlags::NONE;
    }

    const uint32_t* sofFlags =
        (reinterpret_cast<const uint32_t*>(CMSG_DATA(cmsg)));
    WriteFlags flags = WriteFlags::NONE;
    if (*sofFlags & folly::netops::SOF_TIMESTAMPING_TX_SCHED) {
      flags = flags | WriteFlags::TIMESTAMP_SCHED;
    }
    if (*sofFlags & folly::netops::SOF_TIMESTAMPING_TX_SOFTWARE) {
      flags = flags | WriteFlags::TIMESTAMP_TX;
    }
    if (*sofFlags & folly::netops::SOF_TIMESTAMPING_TX_ACK) {
      flags = flags | WriteFlags::TIMESTAMP_ACK;
    }

    return flags;
  }

  static WriteFlags dropWriteFromFlags(WriteFlags writeFlags) {
    return writeFlags & ~WriteFlags::TIMESTAMP_WRITE;
  }

  // server components
  std::unique_ptr<WriteCallbackBase> serverWriteCb_;
  std::unique_ptr<ReadCallback> serverReadCb_;
  std::unique_ptr<HandshakeCallback> serverHandshakeCb_;
  std::unique_ptr<SSLServerAcceptCallback> serverAcceptCb_;
  std::shared_ptr<TestSSLServer> server_;
};

TEST_F(AsyncSSLSocketByteEventTest, ObserverAttachedBeforeConnect) {
  const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
      WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
  const std::vector<uint8_t> wbuf(1, 'a');

  auto clientConn = getClientConn();
  auto observer = clientConn.attachObserver(true /* enableByteEvents */);
  clientConn.netOpsExpectTimestampingSetSockOpt();
  clientConn.connect();
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
  clientConn.netOpsVerifyAndClearExpectations();

  {
    clientConn.netOpsExpectWriteWithFlags(dropWriteFromFlags(flags));
    clientConn.writeAndReflect(wbuf, flags);
    clientConn.netOpsVerifyAndClearExpectations();

    // may have more than four new ByteEvents if write split further by kernel
    EXPECT_THAT(observer->byteEvents, SizeIs(Ge(4)));

    // due to SSL overhead, offset will not be 0
    auto offsetExpected = clientConn.getRawSocket()->getRawBytesWritten() - 1;
    EXPECT_EQ(
        offsetExpected,
        observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
    EXPECT_EQ(
        offsetExpected,
        observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
    EXPECT_EQ(
        offsetExpected,
        observer->maxOffsetForByteEventReceived(ByteEventType::TX));
    EXPECT_EQ(
        offsetExpected,
        observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
  }

  // write again to check offsets
  {
    const auto startNumByteEvents = observer->byteEvents.size();
    clientConn.netOpsExpectWriteWithFlags(dropWriteFromFlags(flags));
    clientConn.writeAndReflect(wbuf, flags);
    clientConn.netOpsVerifyAndClearExpectations();

    // may have more than four new ByteEvents if write split further by kernel
    EXPECT_THAT(observer->byteEvents, SizeIs(Ge(startNumByteEvents + 4)));

    // due to SSL overhead, offset will not be 1
    auto offsetExpected = clientConn.getRawSocket()->getRawBytesWritten() - 1;
    EXPECT_EQ(
        offsetExpected,
        observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
    EXPECT_EQ(
        offsetExpected,
        observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
    EXPECT_EQ(
        offsetExpected,
        observer->maxOffsetForByteEventReceived(ByteEventType::TX));
    EXPECT_EQ(
        offsetExpected,
        observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
  }
}

TEST_F(AsyncSSLSocketByteEventTest, ObserverAttachedAfterConnect) {
  const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
      WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
  const std::vector<uint8_t> wbuf(1, 'a');

  auto clientConn = getClientConn();
  clientConn.netOpsExpectNoTimestampingSetSockOpt();
  clientConn.connect();
  clientConn.netOpsVerifyAndClearExpectations();

  // We make sure the server writes at least one byte to the client before
  // enabling byte events. Otherwise, the test is flaky. We believe this happens
  // when the following sequence occurs:
  //
  // (1) The client socket enables byte events successfully;
  //
  // (2) The client socket writes data to the server with timestamping set;
  //
  // (3) The client socket receives byte events for WRITE, SCHED, and TX and
  //     processes them using `handleErrMessages` in `ioReady`. Once the error
  //     queue is empty, `ioReady` calls `handleRead`. During this time, the
  //     client also begins to read data sent by the server as part of the SSL
  //     handshake completion;
  //
  // (4) The server socket receives the data from the client and reflects it
  //     back;
  //
  // (5) The client socket receives the data reflected by the server. When a
  //     data packet contains an ACK for previous data, the kernel emits the ACK
  //     timestamp before handing off the received data to userspace. However,
  //     in this case, since the client socket is still inside the `handleRead`
  //     function (step 3), it will process the read, even though the ACK
  //     timestamp is already enqueued in the error queue. When it finishes
  //     reading the data, it does not go back to handle messages from the error
  //     queue until the next call to `ioReady`. This causes several `EXPECT`
  //     statements below to fail.

  serverHandshakeCb_->waitForHandshake();
  clientConn.waitForHandshake();
  clientConn.writeAndReflect(wbuf, flags);
  clientConn.netOpsVerifyAndClearExpectations();

  clientConn.netOpsExpectTimestampingSetSockOpt();
  auto observer = clientConn.attachObserver(true);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
  clientConn.netOpsVerifyAndClearExpectations();

  {
    clientConn.netOpsExpectWriteWithFlags(dropWriteFromFlags(flags));
    clientConn.writeAndReflect(wbuf, flags);
    clientConn.netOpsVerifyAndClearExpectations();

    // may have more than four new ByteEvents if write split further by kernel
    EXPECT_THAT(observer->byteEvents, SizeIs(Ge(4)));

    // due to SSL overhead, offset will not be 0
    auto offsetExpected = clientConn.getRawSocket()->getRawBytesWritten() - 1;
    EXPECT_EQ(
        offsetExpected,
        observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
    EXPECT_EQ(
        offsetExpected,
        observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
    EXPECT_EQ(
        offsetExpected,
        observer->maxOffsetForByteEventReceived(ByteEventType::TX));
    EXPECT_EQ(
        offsetExpected,
        observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
  }

  // write again to check offsets
  {
    const auto startNumByteEvents = observer->byteEvents.size();
    clientConn.netOpsExpectWriteWithFlags(dropWriteFromFlags(flags));
    clientConn.writeAndReflect(wbuf, flags);
    clientConn.netOpsVerifyAndClearExpectations();

    // may have more than four new ByteEvents if write split further by kernel
    EXPECT_THAT(observer->byteEvents, SizeIs(Ge(startNumByteEvents + 4)));

    // due to SSL overhead, offset will not be 1
    auto offsetExpected = clientConn.getRawSocket()->getRawBytesWritten() - 1;
    EXPECT_EQ(
        offsetExpected,
        observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
    EXPECT_EQ(
        offsetExpected,
        observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
    EXPECT_EQ(
        offsetExpected,
        observer->maxOffsetForByteEventReceived(ByteEventType::TX));
    EXPECT_EQ(
        offsetExpected,
        observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
  }
}

TEST_F(AsyncSSLSocketByteEventTest, MultiByteWrites) {
  const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
      WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;

  auto clientConn = getClientConn();
  auto observer = clientConn.attachObserver(true /* enableByteEvents */);
  clientConn.netOpsExpectTimestampingSetSockOpt();
  clientConn.connect();
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
  clientConn.netOpsVerifyAndClearExpectations();

  // write 20 bytes
  {
    std::vector<uint8_t> wbuf(20, 'a'); // 20 bytes

    clientConn.netOpsExpectWriteWithFlags(dropWriteFromFlags(flags));
    clientConn.writeAndReflect(wbuf, flags);
    clientConn.netOpsVerifyAndClearExpectations();

    // may have more than four new ByteEvents if write split further by kernel
    EXPECT_THAT(observer->byteEvents, SizeIs(Ge(4)));

    // due to SSL overhead, offset will not be 0
    auto offsetExpected = clientConn.getRawSocket()->getRawBytesWritten() - 1;
    EXPECT_EQ(
        offsetExpected,
        observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
    EXPECT_EQ(
        offsetExpected,
        observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
    EXPECT_EQ(
        offsetExpected,
        observer->maxOffsetForByteEventReceived(ByteEventType::TX));
    EXPECT_EQ(
        offsetExpected,
        observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
  }

  // write 40 bytes
  {
    std::vector<uint8_t> wbuf(20, 'a'); // 20 bytes

    const auto startNumByteEvents = observer->byteEvents.size();
    clientConn.netOpsExpectWriteWithFlags(dropWriteFromFlags(flags));
    clientConn.writeAndReflect(wbuf, flags);
    clientConn.netOpsVerifyAndClearExpectations();

    // may have more than four new ByteEvents if write split further by kernel
    EXPECT_THAT(observer->byteEvents, SizeIs(Ge(startNumByteEvents + 4)));

    // due to SSL overhead, offset will not be 1
    auto offsetExpected = clientConn.getRawSocket()->getRawBytesWritten() - 1;
    EXPECT_EQ(
        offsetExpected,
        observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
    EXPECT_EQ(
        offsetExpected,
        observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
    EXPECT_EQ(
        offsetExpected,
        observer->maxOffsetForByteEventReceived(ByteEventType::TX));
    EXPECT_EQ(
        offsetExpected,
        observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
  }
}

TEST_F(AsyncSSLSocketByteEventTest, MultiByteWritesEnableSecondWrite) {
  const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
      WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;

  auto clientConn = getClientConn();
  clientConn.netOpsExpectNoTimestampingSetSockOpt();
  clientConn.connect();
  clientConn.netOpsVerifyAndClearExpectations();

  // write 20 bytes with no ByteEvents / observer
  {
    std::vector<uint8_t> wbuf(20, 'a'); // 20 bytes
    clientConn.netOpsExpectWriteWithFlags(WriteFlags::NONE);
    clientConn.writeAndReflect(wbuf, flags);
    clientConn.netOpsVerifyAndClearExpectations();
  }

  // enable observer
  clientConn.netOpsExpectTimestampingSetSockOpt();
  auto observer = clientConn.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
  clientConn.netOpsVerifyAndClearExpectations();

  // write 40 bytes
  {
    std::vector<uint8_t> wbuf(20, 'a'); // 20 bytes

    const auto startNumByteEvents = observer->byteEvents.size();
    clientConn.netOpsExpectWriteWithFlags(dropWriteFromFlags(flags));
    clientConn.writeAndReflect(wbuf, flags);
    clientConn.netOpsVerifyAndClearExpectations();

    // may have more than four new ByteEvents if write split further by kernel
    EXPECT_THAT(observer->byteEvents, SizeIs(Ge(startNumByteEvents + 4)));

    // due to SSL overhead, offset will not be 1
    auto offsetExpected = clientConn.getRawSocket()->getRawBytesWritten() - 1;
    EXPECT_EQ(
        offsetExpected,
        observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
    EXPECT_EQ(
        offsetExpected,
        observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
    EXPECT_EQ(
        offsetExpected,
        observer->maxOffsetForByteEventReceived(ByteEventType::TX));
    EXPECT_EQ(
        offsetExpected,
        observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
  }
}

#endif // FOLLY_HAVE_SO_TIMESTAMPING

#endif // __linux__

// TLS 1.2 and TLS 1.3 deliver session tickets in different ways, but we can use
// SSLContext::SessionLifecycleCallbacks to receive them in a similar manner so
// tests can work regardless of version.
class SimpleSessionLifecycleCallback
    : public SSLContext::SessionLifecycleCallbacks {
 public:
  void onNewSession(SSL*, ssl::SSLSessionUniquePtr session) override {
    // This can be called multiple times. OpenSSL sends two session tickets by
    // default). Grab the last one.
    session_ = std::move(session);
    ASSERT_TRUE(socket_ != nullptr);
    // At this point we have what we need to resume a session. Detach the
    // ReadCallback, allowing the socket's EventBase to stop looping.
    socket_->setReadCB(nullptr);
  }

  // set when session is available
  folly::ssl::SSLSessionUniquePtr session_;
  // set after object construction
  folly::AsyncSSLSocket* socket_;
};

TEST(AsyncSSLSocketTest, TestSNIClientHelloBehavior) {
  EventBase eventBase;
  auto serverCtx = std::make_shared<SSLContext>();
  auto clientCtx = std::make_shared<SSLContext>();
  serverCtx->loadPrivateKey(find_resource(kTestKey).c_str());
  serverCtx->loadCertificate(find_resource(kTestCert).c_str());

  auto sessionCb = std::make_unique<SimpleSessionLifecycleCallback>();
  auto sessionCbPtr = sessionCb.get();
  clientCtx->setSessionLifecycleCallbacks(std::move(sessionCb));
  clientCtx->setSessionCacheContext("test context");
  clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
  folly::ssl::SSLSessionUniquePtr resumptionSession = nullptr;

  {
    std::array<NetworkSocket, 2> fds;
    getfds(fds.data());

    AsyncSSLSocket::UniquePtr clientSock(
        new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
    AsyncSSLSocket::UniquePtr serverSock(
        new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));

    // Client sends SNI that doesn't match anything the server cert advertises
    clientSock->setServerName("Foobar");
    auto clientSockPtr = clientSock.get();
    SSLHandshakeServerParseClientHello server(
        std::move(serverSock), true, true);
    // capture client socket in session callback, so we can detach event base
    sessionCbPtr->socket_ = clientSockPtr;
    SSLHandshakeClient client(std::move(clientSock), true, true);

    // should stop when the session ticket is received
    ReadCallback readCallback;
    clientSockPtr->setReadCB(&readCallback);
    readCallback.state = STATE_SUCCEEDED;
    // loop until we receive session (and unregister read callback)
    eventBase.loop();
    resumptionSession = std::move(sessionCbPtr->session_);

    serverSock = std::move(server).moveSocket();
    auto chi = serverSock->getClientHelloInfo();
    ASSERT_NE(chi, nullptr);
    EXPECT_EQ(
        std::string("Foobar"), std::string(serverSock->getSSLServerName()));

    // create another client, resuming with the prior session, but under a
    // different common name.
    clientSock = std::move(client).moveSocket();
    ASSERT_TRUE(resumptionSession != nullptr);
  }

  {
    std::array<NetworkSocket, 2> fds;
    getfds(fds.data());

    AsyncSSLSocket::UniquePtr clientSock(
        new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
    AsyncSSLSocket::UniquePtr serverSock(
        new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));

    clientSock->setRawSSLSession(std::move(resumptionSession));
    clientSock->setServerName("Baz");
    SSLHandshakeServerParseClientHello server(
        std::move(serverSock), true, true);
    SSLHandshakeClient client(std::move(clientSock), true, true);
    eventBase.loop();

    serverSock = std::move(server).moveSocket();
    clientSock = std::move(client).moveSocket();
    EXPECT_TRUE(clientSock->getSSLSessionReused());

    // OpenSSL 1.1.1 changes the semantics of SSL_get_servername
    // in
    // https://github.com/openssl/openssl/commit/1c4aa31d79821dee9be98e915159d52cc30d8403
    //
    // Previously, the SNI would be taken from the ClientHello.
    // Now, the SNI will be taken from the established session.
    //
    // But the session that was established with the client (prior handshake)
    // would not have set the server name field because the SNI that the client
    // requested ("Foobar") did not match any of the SANs that the server was
    // presenting ("127.0.0.1")
    //
    // To preserve this 1.1.0 behavior, getSSLServerName() should return the
    // parsed ClientHello servername. This test asserts this behavior.
    auto sni = serverSock->getSSLServerName();
    ASSERT_NE(sni, nullptr);

    std::string sniStr(sni);
    EXPECT_EQ(sniStr, std::string("Baz"));
  }
}

TEST(AsyncSSLSocketTest, BytesWrittenWithMove) {
  WriteCallbackBase writeCallback;
  ReadCallback readCallback(&writeCallback);
  HandshakeCallback handshakeCallback(&readCallback);
  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
  TestSSLServer server(&acceptCallback);

  auto sslContext = std::make_shared<SSLContext>();
  sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
  auto socket1 =
      std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
  socket1->open(std::chrono::milliseconds(10000));

  // write
  std::vector<uint8_t> wbuf(128, 'a');
  socket1->write(wbuf.data(), wbuf.size());
  const auto socket1AppBytes = socket1->getSocket()->getAppBytesWritten();
  const auto socket1RawBytes = socket1->getSocket()->getRawBytesWritten();
  EXPECT_EQ(128, socket1AppBytes);
  EXPECT_LT(128, socket1RawBytes);

  // read reflection
  std::vector<uint8_t> readbuf(wbuf.size());
  uint32_t bytesRead = socket1->readAll(readbuf.data(), readbuf.size());
  EXPECT_EQ(bytesRead, wbuf.size());

  // additional sanity checks on virtuals
  EXPECT_EQ(
      socket1->getSSLSocket()->getRawBytesWritten(),
      socket1->getSocket()->getRawBytesWritten());
  EXPECT_EQ(128, socket1->getSocket()->getAppBytesWritten());
  EXPECT_EQ(128, socket1->getSSLSocket()->getAppBytesWritten());

  // move to another AsyncSSLSocket
  AsyncSSLSocket::UniquePtr socket2(
      new AsyncSSLSocket(sslContext, socket1->getSocket()));
  EXPECT_EQ(socket1AppBytes, socket2->getAppBytesWritten());
  EXPECT_EQ(socket1RawBytes, socket2->getRawBytesWritten());

  // move to an AsyncSocket
  AsyncSocket::UniquePtr socket3(new AsyncSocket(std::move(socket2)));
  EXPECT_EQ(socket1AppBytes, socket3->getAppBytesWritten());
  EXPECT_EQ(socket1RawBytes, socket3->getRawBytesWritten());
}

#ifdef SIGPIPE
///////////////////////////////////////////////////////////////////////////
// init_unit_test_suite
///////////////////////////////////////////////////////////////////////////
namespace {
struct Initializer {
  Initializer() { signal(SIGPIPE, SIG_IGN); }
};
Initializer initializer;
} // namespace
#endif