folly/folly/io/async/test/AsyncSSLSocketTest.h

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <fcntl.h>
#include <signal.h>
#include <sys/types.h>

#include <cerrno>
#include <condition_variable>
#include <iostream>
#include <list>
#include <memory>
#include <mutex>
#include <string>
#include <vector>

#include <folly/ExceptionWrapper.h>
#include <folly/SocketAddress.h>
#include <folly/fibers/FiberManagerMap.h>
#include <folly/io/SocketOptionMap.h>
#include <folly/io/async/AsyncSSLSocket.h>
#include <folly/io/async/AsyncServerSocket.h>
#include <folly/io/async/AsyncSocket.h>
#include <folly/io/async/AsyncTimeout.h>
#include <folly/io/async/AsyncTransport.h>
#include <folly/io/async/EventBase.h>
#include <folly/io/async/ssl/SSLErrors.h>
#include <folly/io/async/test/TestSSLServer.h>
#include <folly/portability/GMock.h>
#include <folly/portability/GTest.h>
#include <folly/portability/PThread.h>
#include <folly/portability/Sockets.h>
#include <folly/portability/String.h>
#include <folly/portability/Unistd.h>
#include <folly/testing/TestUtil.h>

namespace folly::test {

// The destructors of all callback classes assert that the state is
// STATE_SUCCEEDED, for both positive and negative tests. The tests
// are responsible for setting the succeeded state properly before the
// destructors are called.

class SendMsgParamsCallbackBase
    : public folly::AsyncSocket::SendMsgParamsCallback {
 public:
  SendMsgParamsCallbackBase() {}

  void setSocket(const std::shared_ptr<AsyncSSLSocket>& socket) {
    socket_ = socket;
    oldCallback_ = socket_->getSendMsgParamsCB();
    socket_->setSendMsgParamCB(this);
    socket_->setEorTracking(trackEor_);
  }

  void setEorTracking(bool track) {
    CHECK(!socket_); // should only be called during setup
    trackEor_ = track;
  }

  int getFlagsImpl(
      folly::WriteFlags flags, int /*defaultFlags*/) noexcept override {
    return oldCallback_->getFlags(flags, false /*zeroCopyEnabled*/);
  }

  void getAncillaryData(
      folly::WriteFlags flags,
      void* data,
      const AsyncSocket::WriteRequestTag& writeTag,
      const bool byteEventsEnabled) noexcept override {
    oldCallback_->getAncillaryData(flags, data, writeTag, byteEventsEnabled);
  }

  uint32_t getAncillaryDataSize(
      folly::WriteFlags flags,
      const AsyncSocket::WriteRequestTag& writeTag,
      const bool byteEventsEnabled) noexcept override {
    return oldCallback_->getAncillaryDataSize(
        flags, writeTag, byteEventsEnabled);
  }

  std::shared_ptr<AsyncSSLSocket> socket_;
  bool trackEor_{false};
  folly::AsyncSocket::SendMsgParamsCallback* oldCallback_{nullptr};
};

/**
 * SendMsgParamsCallback that forces the sendmsg() flags to a fixed value set
 * via resetFlags(). A value of 0 means "defer to the previously installed
 * callback".
 */
class SendMsgFlagsCallback : public SendMsgParamsCallbackBase {
 public:
  SendMsgFlagsCallback() = default;

  void resetFlags(int flags) { flags_ = flags; }

  int getFlagsImpl(
      folly::WriteFlags flags, int /*defaultFlags*/) noexcept override {
    if (flags_ == 0) {
      return oldCallback_->getFlags(flags, false /*zeroCopyEnabled*/);
    }
    return flags_;
  }

  // Forced flags; 0 disables the override.
  int flags_{0};
};

/**
 * SendMsgParamsCallback that can inject caller-supplied ancillary data and
 * records the WriteFlags seen on the most recent getAncillaryData() call.
 * While no data is set, both ancillary-data hooks defer to the previously
 * installed callback.
 */
class SendMsgAncillaryDataCallback : public SendMsgParamsCallbackBase {
 public:
  SendMsgAncillaryDataCallback() = default;

  /**
   * Sets the bytes returned by getAncillaryData(); an empty vector restores
   * delegation. Implemented as a swap, so `data` receives the previous bytes.
   */
  void resetData(std::vector<char>&& data) { ancillaryData_.swap(data); }

  /**
   * WriteFlags observed on the most recent call to getAncillaryData().
   */
  folly::WriteFlags getObservedWriteFlags() { return observedWriteFlags_; }

  void getAncillaryData(
      folly::WriteFlags flags,
      void* data,
      const AsyncSocket::WriteRequestTag& writeTag,
      const bool byteEventsEnabled) noexcept override {
    // This hook is reached through a long call chain after a send; remember
    // the flags that made it here so the test can compare them later.
    observedWriteFlags_ = flags;

    if (ancillaryData_.empty()) {
      oldCallback_->getAncillaryData(flags, data, writeTag, byteEventsEnabled);
      return;
    }
    std::cerr << "getAncillaryData: copying data" << std::endl;
    memcpy(data, ancillaryData_.data(), ancillaryData_.size());
  }

  uint32_t getAncillaryDataSize(
      folly::WriteFlags flags,
      const AsyncSocket::WriteRequestTag& writeTag,
      const bool byteEventsEnabled) noexcept override {
    if (ancillaryData_.empty()) {
      return oldCallback_->getAncillaryDataSize(
          flags, writeTag, byteEventsEnabled);
    }
    std::cerr << "getAncillaryDataSize: returning size" << std::endl;
    return static_cast<uint32_t>(ancillaryData_.size());
  }

  folly::WriteFlags observedWriteFlags_{};
  std::vector<char> ancillaryData_;
};

class WriteCallbackBase : public AsyncTransport::WriteCallback {
 public:
  explicit WriteCallbackBase(SendMsgParamsCallbackBase* mcb = nullptr)
      : state(STATE_WAITING),
        bytesWritten(0),
        exception(AsyncSocketException::UNKNOWN, "none"),
        mcb_(mcb) {}

  ~WriteCallbackBase() override { EXPECT_EQ(STATE_SUCCEEDED, state); }

  SemiFuture<StateEnum> getSemiFuture() { return promise_.getSemiFuture(); }

  virtual void setSocket(const std::shared_ptr<AsyncSSLSocket>& socket) {
    socket_ = socket;
    if (mcb_) {
      mcb_->setSocket(socket);
    }
  }

  void writeSuccess() noexcept override {
    std::cerr << "writeSuccess" << std::endl;
    state = STATE_SUCCEEDED;
    if (!promise_.isFulfilled()) {
      promise_.setValue(state);
    }
  }

  void writeErr(
      size_t nBytesWritten, const AsyncSocketException& ex) noexcept override {
    std::cerr << "writeError: bytesWritten " << nBytesWritten << ", exception "
              << ex.what() << std::endl;

    state = STATE_FAILED;
    if (!promise_.isFulfilled()) {
      promise_.setValue(state);
    }
    this->bytesWritten = nBytesWritten;
    exception = ex;
    socket_->close();
  }

  std::shared_ptr<AsyncSSLSocket> socket_;
  StateEnum state;
  size_t bytesWritten;
  AsyncSocketException exception;
  SendMsgParamsCallbackBase* mcb_;
  Promise<StateEnum> promise_;
};

/**
 * Write callback for tests that expect the write to fail with a
 * NETWORK_ERROR carrying errno EINVAL.
 *
 * The destructor checks that outcome, then marks the state as succeeded so
 * ~WriteCallbackBase() does not report the expected failure.
 */
class ExpectWriteErrorCallback : public WriteCallbackBase {
 public:
  explicit ExpectWriteErrorCallback(SendMsgParamsCallbackBase* mcb = nullptr)
      : WriteCallbackBase(mcb) {}

  ~ExpectWriteErrorCallback() override {
    EXPECT_EQ(STATE_FAILED, state);
    EXPECT_EQ(
        exception.getType(),
        AsyncSocketException::AsyncSocketExceptionType::NETWORK_ERROR);
    // Use the symbolic errno instead of the bare value 22 (EINVAL on Linux,
    // macOS and Windows).
    EXPECT_EQ(exception.getErrno(), EINVAL);
    // Suppress the assert in ~WriteCallbackBase()
    state = STATE_SUCCEEDED;
  }
};

/**
 * Write callback for tests that expect the write to fail with an SSL_ERROR.
 *
 * After checking that outcome, the destructor resets the state to succeeded
 * so the base-class destructor check passes.
 */
class ExpectSSLWriteErrorCallback : public WriteCallbackBase {
 public:
  explicit ExpectSSLWriteErrorCallback(SendMsgParamsCallbackBase* mcb = nullptr)
      : WriteCallbackBase(mcb) {}

  ~ExpectSSLWriteErrorCallback() override {
    EXPECT_EQ(STATE_FAILED, state);
    EXPECT_EQ(
        exception.getType(),
        AsyncSocketException::AsyncSocketExceptionType::SSL_ERROR);
    // The failure was expected; keep ~WriteCallbackBase() from flagging it.
    state = STATE_SUCCEEDED;
  }
};

/**
 * Common base for the test read callbacks.
 *
 * readErr() marks failure; readEOF() leaves the state unchanged; both close
 * socket_. setState() also propagates the state to the associated write
 * callback, if any. The destructor expects STATE_SUCCEEDED.
 */
class ReadCallbackBase : public AsyncTransport::ReadCallback {
 public:
  explicit ReadCallbackBase(WriteCallbackBase* wcb)
      : wcb_(wcb), state(STATE_WAITING) {}

  ~ReadCallbackBase() override { EXPECT_EQ(STATE_SUCCEEDED, state); }

  void setSocket(const std::shared_ptr<AsyncSSLSocket>& socket) {
    socket_ = socket;
  }

  void setState(StateEnum s) {
    state = s;
    if (wcb_ == nullptr) {
      return;
    }
    wcb_->state = s;
  }

  void readErr(const AsyncSocketException& ex) noexcept override {
    std::cerr << "readError " << ex.what() << std::endl;
    state = STATE_FAILED;
    socket_->close();
  }

  void readEOF() noexcept override {
    std::cerr << "readEOF" << std::endl;
    socket_->close();
  }

  std::shared_ptr<AsyncSSLSocket> socket_;
  WriteCallbackBase* wcb_;
  StateEnum state;
};

/**
 * Read callback that keeps every chunk read from the socket in `buffers`,
 * one malloc'd 4096-byte buffer per readDataAvailable() call.
 *
 * When `reflect` is true, each chunk is also written straight back to the
 * socket through wcb_, using the WriteFlags set via setWriteFlags() (NONE by
 * default).
 */
class ReadCallback : public ReadCallbackBase {
 public:
  explicit ReadCallback(WriteCallbackBase* wcb, bool reflect = true)
      : ReadCallbackBase(wcb),
        buffers(),
        writeFlags(folly::WriteFlags::NONE),
        reflect(reflect) {}

  // Collect-only variant: no write callback and no echoing.
  explicit ReadCallback() : ReadCallback(nullptr, false) {}

  ~ReadCallback() override {
    for (auto& chunk : buffers) {
      chunk.free();
    }
    currentBuffer.free();
  }

  void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
    // Allocate lazily; the buffer is handed off in readDataAvailable().
    if (currentBuffer.buffer == nullptr) {
      currentBuffer.allocate(4096);
    }
    *bufReturn = currentBuffer.buffer;
    *lenReturn = currentBuffer.length;
  }

  void readDataAvailable(size_t len) noexcept override {
    std::cerr << "readDataAvailable, len " << len << std::endl;

    currentBuffer.length = len;

    if (wcb_ != nullptr) {
      wcb_->setSocket(socket_);
    }

    if (reflect) {
      // Echo the bytes that were just read.
      socket_->write(wcb_, currentBuffer.buffer, len, writeFlags);
    }

    // Ownership of the bytes moves into `buffers`.
    buffers.push_back(currentBuffer);
    currentBuffer.reset();
    state = STATE_SUCCEEDED;
  }

  void verifyData(const char* expected, size_t expectedLen) const {
    verifyData(reinterpret_cast<const unsigned char*>(expected), expectedLen);
  }

  /**
   * CHECK-fails unless the concatenation of `buffers` is exactly the
   * `expectedLen` bytes starting at `expected`.
   */
  void verifyData(const unsigned char* expected, size_t expectedLen) const {
    size_t offset = 0;
    for (const auto& buf : buffers) {
      const size_t cmpLen = std::min(buf.length, expectedLen - offset);
      CHECK_EQ(memcmp(buf.buffer, expected + offset, cmpLen), 0);
      CHECK_EQ(cmpLen, buf.length);
      offset += cmpLen;
    }
    CHECK_EQ(offset, expectedLen);
  }

  // Frees and forgets all collected chunks; currentBuffer is left alone.
  void clearData() {
    for (auto& chunk : buffers) {
      chunk.free();
    }
    buffers.clear();
  }

  // Total number of bytes held in `buffers`.
  size_t dataRead() const {
    size_t total = 0;
    for (const auto& chunk : buffers) {
      total += chunk.length;
    }
    return total;
  }

  /**
   * These flags will be used when writing the read data back to the socket.
   */
  void setWriteFlags(folly::WriteFlags flags) { writeFlags = flags; }

  // Minimal malloc-backed byte buffer. Copies are shallow; memory is managed
  // explicitly with allocate()/free()/reset().
  class Buffer {
   public:
    Buffer() = default;
    Buffer(char* buf, size_t len) : buffer(buf), length(len) {}

    // Forgets the buffer without freeing it.
    void reset() {
      buffer = nullptr;
      length = 0;
    }
    // Allocates `len` bytes; must not already hold a buffer.
    void allocate(size_t len) {
      assert(buffer == nullptr);
      buffer = static_cast<char*>(malloc(len));
      length = len;
    }
    // Releases the memory and returns to the empty state.
    void free() {
      ::free(buffer);
      reset();
    }

    char* buffer{nullptr};
    size_t length{0};
  };

  std::vector<Buffer> buffers;
  Buffer currentBuffer;
  folly::WriteFlags writeFlags;
  bool reflect; // whether read bytes will be written back to the transport
};

/**
 * Read callback that provokes a read error by offering a null, zero-length
 * buffer. It treats that error as success for itself and its write callback.
 */
class ReadErrorCallback : public ReadCallbackBase {
 public:
  explicit ReadErrorCallback(WriteCallbackBase* wcb) : ReadCallbackBase(wcb) {}

  // A null buffer is what triggers readErr().
  void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
    *lenReturn = 0;
    *bufReturn = nullptr;
  }

  void readDataAvailable(size_t /* len */) noexcept override {
    // Not expected to be called: no buffer is ever offered.
    FAIL();
  }

  void readErr(const AsyncSocketException& ex) noexcept override {
    ReadCallbackBase::readErr(ex);
    std::cerr << "ReadErrorCallback::readError" << std::endl;
    setState(STATE_SUCCEEDED);
  }
};

/**
 * Read callback that succeeds only when readEOF() is delivered.
 *
 * Like ReadErrorCallback it offers a null, zero-length buffer, so it never
 * accepts data. A read error falls through to ReadCallbackBase::readErr(),
 * which marks failure.
 */
class ReadEOFCallback : public ReadCallbackBase {
 public:
  explicit ReadEOFCallback(WriteCallbackBase* wcb) : ReadCallbackBase(wcb) {}

  void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
    *lenReturn = 0;
    *bufReturn = nullptr;
  }

  void readDataAvailable(size_t /* len */) noexcept override {
    // Not expected to be called: no buffer is ever offered.
    FAIL();
  }

  void readEOF() noexcept override {
    ReadCallbackBase::readEOF();
    setState(STATE_SUCCEEDED);
  }
};

/**
 * Read callback that forces the echo write to fail.
 *
 * On the first data it closes the socket's underlying descriptor and then
 * writes the data back. It reports success only if the write callback is
 * already in STATE_FAILED right after write() returns. Read errors are
 * expected and only logged.
 */
class WriteErrorCallback : public ReadCallback {
 public:
  explicit WriteErrorCallback(WriteCallbackBase* wcb) : ReadCallback(wcb) {}

  void readDataAvailable(size_t len) noexcept override {
    std::cerr << "readDataAvailable, len " << len << std::endl;

    currentBuffer.length = len;

    // Close the descriptor behind the socket's back so the write below
    // fails.
    netops::close(socket_->getNetworkSocket());

    wcb_->setSocket(socket_);

    // Echo the data; this is expected to fail synchronously.
    folly::test::msvcSuppressAbortOnInvalidParams(
        [&] { socket_->write(wcb_, currentBuffer.buffer, len); });

    if (wcb_->state != STATE_FAILED) {
      state = STATE_FAILED;
    } else {
      setState(STATE_SUCCEEDED);
    }

    buffers.push_back(currentBuffer);
    currentBuffer.reset();
  }

  void readErr(const AsyncSocketException& ex) noexcept override {
    std::cerr << "readError " << ex.what() << std::endl;
    // Expected after the descriptor was closed; nothing else to do.
  }
};

class EmptyReadCallback : public ReadCallback {
 public:
  explicit EmptyReadCallback() : ReadCallback(nullptr) {}

  void readErr(const AsyncSocketException& ex) noexcept override {
    std::cerr << "readError " << ex.what() << std::endl;
    state = STATE_FAILED;
    if (tcpSocket_) {
      tcpSocket_->close();
    }
  }

  void readEOF() noexcept override {
    std::cerr << "readEOF" << std::endl;
    if (tcpSocket_) {
      tcpSocket_->close();
    }
    state = STATE_SUCCEEDED;
  }

  std::shared_ptr<AsyncSocket> tcpSocket_;
};

// gmock implementation of CertificateIdentityVerifier, letting tests set
// expectations on (and stub the result of) the const verifyLeaf() method.
class MockCertificateIdentityVerifier : public CertificateIdentityVerifier {
 public:
  MOCK_METHOD(
      std::unique_ptr<AsyncTransportCertificate>,
      verifyLeaf,
      (const AsyncTransportCertificate&),
      (const));
};

/**
 * gmock HandshakeCB.
 *
 * Each noexcept HandshakeCB override forwards its arguments unchanged to a
 * mockable *Impl method, on which tests set expectations. `override` alone
 * is used on the overrides: it already implies `virtual`.
 */
class MockHandshakeCB : public AsyncSSLSocket::HandshakeCB {
 public:
  MOCK_METHOD(bool, handshakeVerImpl, (AsyncSSLSocket*, bool, X509_STORE_CTX*));
  bool handshakeVer(
      AsyncSSLSocket* sock,
      bool preverifyOk,
      X509_STORE_CTX* ctx) noexcept override {
    return handshakeVerImpl(sock, preverifyOk, ctx);
  }

  MOCK_METHOD(void, handshakeSucImpl, (AsyncSSLSocket*));
  void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
    handshakeSucImpl(sock);
  }

  MOCK_METHOD(
      void, handshakeErrImpl, (AsyncSSLSocket*, const AsyncSocketException&));
  void handshakeErr(
      AsyncSSLSocket* sock, const AsyncSocketException& ex) noexcept override {
    handshakeErrImpl(sock, ex);
  }
};

/**
 * Server-side handshake callback.
 *
 * On success it installs rcb_ as the socket's read callback. On error, if an
 * error was expected, it marks rcb_ as succeeded because rcb_ will never be
 * invoked. Either way it sets `state` according to `expect_` and wakes
 * waitForHandshake(). The destructor expects STATE_SUCCEEDED.
 */
class HandshakeCallback : public AsyncSSLSocket::HandshakeCB {
 public:
  enum ExpectType { EXPECT_SUCCESS, EXPECT_ERROR };

  explicit HandshakeCallback(
      ReadCallbackBase* rcb, ExpectType expect = EXPECT_SUCCESS)
      : state(STATE_WAITING), rcb_(rcb), expect_(expect) {}

  void setSocket(const std::shared_ptr<AsyncSSLSocket>& socket) {
    socket_ = socket;
  }

  // Sets the state of this callback and of its read callback together.
  void setState(StateEnum s) {
    state = s;
    rcb_->setState(s);
  }

  // --- AsyncSSLSocket::HandshakeCB overrides ---

  void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
    isResumed_ = sock->getSSLSessionReused();
    // The notification is sent while mutex_ is held, so waitForHandshake()
    // cannot re-check `state` until this function has finished with it.
    std::lock_guard<std::mutex> guard(mutex_);
    cv_.notify_all();
    EXPECT_EQ(sock, socket_.get());
    std::cerr << "HandshakeCallback::connectionAccepted" << std::endl;
    rcb_->setSocket(socket_);
    sock->setReadCB(rcb_);
    if (expect_ == EXPECT_SUCCESS) {
      state = STATE_SUCCEEDED;
    } else {
      state = STATE_FAILED;
    }
  }

  void handshakeErr(
      AsyncSSLSocket* /* sock */,
      const AsyncSocketException& ex) noexcept override {
    isResumed_ = false;
    std::lock_guard<std::mutex> guard(mutex_);
    cv_.notify_all();
    std::cerr << "HandshakeCallback::handshakeError " << ex.what() << std::endl;
    const bool errorExpected = expect_ == EXPECT_ERROR;
    state = errorExpected ? STATE_SUCCEEDED : STATE_FAILED;
    if (errorExpected) {
      // rcb will never be invoked
      rcb_->setState(STATE_SUCCEEDED);
    }
    errorString_ = ex.what();
  }

  // Blocks until `state` leaves STATE_WAITING; the handshake callbacks above
  // are what wake this waiter.
  void waitForHandshake() {
    std::unique_lock<std::mutex> lock(mutex_);
    const auto handshakeDone = [this] { return state != STATE_WAITING; };
    cv_.wait(lock, handshakeDone);
  }

  ~HandshakeCallback() override { EXPECT_EQ(STATE_SUCCEEDED, state); }

  void closeSocket() {
    socket_->close();
    state = STATE_SUCCEEDED;
  }

  std::shared_ptr<AsyncSSLSocket> getSocket() { return socket_; }

  // Whether the completed handshake reused an SSL session.
  bool isResumed() const { return isResumed_; }

  StateEnum state;
  std::shared_ptr<AsyncSSLSocket> socket_;
  ReadCallbackBase* rcb_;
  ExpectType expect_;
  std::mutex mutex_;
  std::condition_variable cv_;
  std::string errorString_;
  bool isResumed_{false};
};

class SSLServerAcceptCallback : public SSLServerAcceptCallbackBase {
 public:
  uint32_t timeout_;

  explicit SSLServerAcceptCallback(HandshakeCallback* hcb, uint32_t timeout = 0)
      : SSLServerAcceptCallbackBase(hcb), timeout_(timeout) {}

  ~SSLServerAcceptCallback() override {
    if (timeout_ > 0) {
      // if we set a timeout, we expect failure
      EXPECT_EQ(hcb_->state, STATE_FAILED);
      hcb_->setState(STATE_SUCCEEDED);
    }
  }

  void connAccepted(
      const std::shared_ptr<folly::AsyncSSLSocket>& s) noexcept override {
    auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
    std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;

    hcb_->setSocket(sock);
    sock->sslAccept(hcb_, std::chrono::milliseconds(timeout_));
    EXPECT_EQ(sock->getSSLState(), AsyncSSLSocket::STATE_ACCEPTING);

    state = STATE_SUCCEEDED;
  }
};

class SSLServerAcceptCallbackDelay : public SSLServerAcceptCallback {
 public:
  explicit SSLServerAcceptCallbackDelay(HandshakeCallback* hcb)
      : SSLServerAcceptCallback(hcb) {}

  void connAccepted(
      const std::shared_ptr<folly::AsyncSSLSocket>& s) noexcept override {
    auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);

    std::cerr << "SSLServerAcceptCallbackDelay::connAccepted" << std::endl;
    auto fd = sock->getNetworkSocket();

#ifndef TCP_NOPUSH
    {
      // The accepted connection should already have TCP_NODELAY set
      int value;
      socklen_t valueLength = sizeof(value);
      int rc = netops::getsockopt(
          fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
      EXPECT_EQ(rc, 0);
      EXPECT_EQ(value, 1);
    }
#endif

    // Unset the TCP_NODELAY option.
    int value = 0;
    socklen_t valueLength = sizeof(value);
    int rc =
        netops::setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, valueLength);
    EXPECT_EQ(rc, 0);

    rc = netops::getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
    EXPECT_EQ(rc, 0);
    EXPECT_EQ(value, 0);

    SSLServerAcceptCallback::connAccepted(sock);
  }
};

class HandshakeErrorCallback : public SSLServerAcceptCallbackBase {
 public:
  explicit HandshakeErrorCallback(HandshakeCallback* hcb)
      : SSLServerAcceptCallbackBase(hcb) {}

  void connAccepted(
      const std::shared_ptr<folly::AsyncSSLSocket>& s) noexcept override {
    auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);

    std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;

    // The first call to sslAccept() should succeed.
    hcb_->setSocket(sock);
    sock->sslAccept(hcb_);
    EXPECT_EQ(sock->getSSLState(), AsyncSSLSocket::STATE_ACCEPTING);

    // The second call to sslAccept() should fail.
    HandshakeCallback callback2(hcb_->rcb_);
    callback2.setSocket(sock);
    sock->sslAccept(&callback2);
    EXPECT_EQ(sock->getSSLState(), AsyncSSLSocket::STATE_ERROR);

    // Both callbacks should be in the error state.
    EXPECT_EQ(hcb_->state, STATE_FAILED);
    EXPECT_EQ(callback2.state, STATE_FAILED);

    state = STATE_SUCCEEDED;
    hcb_->setState(STATE_SUCCEEDED);
    callback2.setState(STATE_SUCCEEDED);
  }
};

class HandshakeTimeoutCallback : public SSLServerAcceptCallbackBase {
 public:
  explicit HandshakeTimeoutCallback(HandshakeCallback* hcb)
      : SSLServerAcceptCallbackBase(hcb) {}

  void connAccepted(
      const std::shared_ptr<folly::AsyncSSLSocket>& s) noexcept override {
    std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;

    auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);

    hcb_->setSocket(sock);
    sock->getEventBase()->tryRunAfterDelay(
        [=] {
          std::cerr << "Delayed SSL accept, client will have close by now"
                    << std::endl;
          // SSL accept will fail
          EXPECT_EQ(sock->getSSLState(), AsyncSSLSocket::STATE_UNINIT);
          hcb_->socket_->sslAccept(hcb_);
          // This registers for an event
          EXPECT_EQ(sock->getSSLState(), AsyncSSLSocket::STATE_ACCEPTING);

          state = STATE_SUCCEEDED;
        },
        100);
  }
};

class ConnectTimeoutCallback : public SSLServerAcceptCallbackBase {
 public:
  ConnectTimeoutCallback() : SSLServerAcceptCallbackBase(nullptr) {
    // We don't care if we get invoked or not.
    // The client may time out and give up before connAccepted() is even
    // called.
    state = STATE_SUCCEEDED;
  }

  void connAccepted(
      const std::shared_ptr<folly::AsyncSSLSocket>& s) noexcept override {
    std::cerr << "ConnectTimeoutCallback::connAccepted" << std::endl;

    // Just wait a while before closing the socket, so the client
    // will time out waiting for the handshake to complete.
    s->getEventBase()->tryRunAfterDelay([=] { s->close(); }, 100);
  }
};

/**
 * Client side of the blocking-write test.
 *
 * After its TLS handshake completes, the client sends iovCount_ iovecs with a
 * single writev() and closes the socket once the write succeeds. All iovecs
 * point into one byte-patterned buffer and have widely varying lengths.
 * getIovec()/getIovecCount() expose them so the receiver can verify the data.
 */
class BlockingWriteClient : private AsyncSSLSocket::HandshakeCB,
                            private AsyncTransport::WriteCallback {
 public:
  explicit BlockingWriteClient(AsyncSSLSocket::UniquePtr socket)
      : socket_(std::move(socket)), bufLen_(2500), iovCount_(2000) {
    // iovec n starts at byte n. Even-indexed iovecs end exactly at bufLen_.
    // Odd-indexed ones are n bytes long, so they end at byte 2 * n, which is
    // past bufLen_ once n > bufLen_ / 2. Allocate iovCount_ bytes of slack
    // beyond bufLen_ so every iovec lies inside the allocation. The bound
    // holds as long as iovCount_ <= bufLen_.
    CHECK_LE(iovCount_, bufLen_);
    const uint32_t allocLen = bufLen_ + iovCount_;

    // Fill the whole allocation with a repeating pattern. The bound must be
    // the allocation length: sizeof(buf_) is only the size of the unique_ptr.
    buf_ = std::make_unique<uint8_t[]>(allocLen);
    for (uint32_t n = 0; n < allocLen; ++n) {
      buf_[n] = n % 0xff;
    }

    // Initialize iov_
    iov_ = std::make_unique<struct iovec[]>(iovCount_);
    for (uint32_t n = 0; n < iovCount_; ++n) {
      iov_[n].iov_base = buf_.get() + n;
      if (n & 0x1) {
        iov_[n].iov_len = n % bufLen_;
      } else {
        iov_[n].iov_len = bufLen_ - (n % bufLen_);
      }
    }

    socket_->sslConn(this, std::chrono::milliseconds(100));
  }

  // The iovecs that are (or will be) written, for verification by the peer.
  struct iovec* getIovec() const { return iov_.get(); }
  uint32_t getIovecCount() const { return iovCount_; }

 private:
  // Handshake done: send everything in one writev().
  void handshakeSuc(AsyncSSLSocket*) noexcept override {
    socket_->writev(this, iov_.get(), iovCount_);
  }
  void handshakeErr(
      AsyncSSLSocket*, const AsyncSocketException& ex) noexcept override {
    ADD_FAILURE() << "client handshake error: " << ex.what();
  }
  // All data written: close the connection.
  void writeSuccess() noexcept override { socket_->close(); }
  void writeErr(
      size_t bytesWritten, const AsyncSocketException& ex) noexcept override {
    ADD_FAILURE() << "client write error after " << bytesWritten
                  << " bytes: " << ex.what();
  }

  AsyncSSLSocket::UniquePtr socket_;
  uint32_t bufLen_;
  uint32_t iovCount_;
  std::unique_ptr<uint8_t[]> buf_;
  std::unique_ptr<struct iovec[]> iov_;
};

/**
 * Server side of the blocking-write test.
 *
 * It accepts the TLS handshake, then reads into one large buffer while
 * repeatedly pausing reads: 10ms before the first read and 2ms after every
 * chunk, so the peer's writes initially block. checkBuffer() compares what
 * was read against the peer's iovecs.
 */
class BlockingWriteServer : private AsyncSSLSocket::HandshakeCB,
                            private AsyncTransport::ReadCallback {
 public:
  explicit BlockingWriteServer(AsyncSSLSocket::UniquePtr socket)
      : socket_(std::move(socket)), bufSize_(2500 * 2000), bytesRead_(0) {
    buf_ = std::make_unique<uint8_t[]>(bufSize_);
    socket_->sslAccept(this, std::chrono::milliseconds(100));
  }

  /**
   * Records gtest failures unless the bytes read so far equal the
   * concatenation of the `count` iovecs at `iov`.
   */
  void checkBuffer(struct iovec* iov, uint32_t count) const {
    uint32_t idx = 0;
    for (uint32_t n = 0; n < count; ++n) {
      const struct iovec& entry = iov[n];
      size_t bytesLeft = bytesRead_ - idx;
      const size_t cmpLen = std::min(entry.iov_len, bytesLeft);
      const int rc = memcmp(buf_.get() + idx, entry.iov_base, cmpLen);
      if (rc != 0) {
        FAIL() << "buffer mismatch at iovec " << n << "/" << count
               << ": rc=" << rc;
      }
      if (entry.iov_len > bytesLeft) {
        FAIL() << "server did not read enough data: " << "ended at byte "
               << bytesLeft << "/" << entry.iov_len << " in iovec " << n << "/"
               << count;
      }

      idx += entry.iov_len;
    }
    if (idx != bytesRead_) {
      ADD_FAILURE() << "server read extra data: " << bytesRead_
                    << " bytes read; expected " << idx;
    }
  }

 private:
  void handshakeSuc(AsyncSSLSocket*) noexcept override {
    // Hold off reading for 10ms so the client's writes start out blocked.
    socket_->getEventBase()->tryRunAfterDelay(
        [this] { socket_->setReadCB(this); }, 10);
  }
  void handshakeErr(
      AsyncSSLSocket*, const AsyncSocketException& ex) noexcept override {
    ADD_FAILURE() << "server handshake error: " << ex.what();
  }
  // Always append after the bytes read so far.
  void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
    *bufReturn = buf_.get() + bytesRead_;
    *lenReturn = bufSize_ - bytesRead_;
  }
  // Account for the chunk, then pause reading for 2ms.
  void readDataAvailable(size_t len) noexcept override {
    bytesRead_ += len;
    socket_->setReadCB(nullptr);
    socket_->getEventBase()->tryRunAfterDelay(
        [this] { socket_->setReadCB(this); }, 2);
  }
  void readEOF() noexcept override { socket_->close(); }
  void readErr(const AsyncSocketException& ex) noexcept override {
    ADD_FAILURE() << "server read error: " << ex.what();
  }

  AsyncSSLSocket::UniquePtr socket_;
  uint32_t bufSize_;
  uint32_t bytesRead_;
  std::unique_ptr<uint8_t[]> buf_;
};

/**
 * TLS client that records the ALPN protocol selected during its handshake
 * (nextProto / nextProtoLength) or, if the handshake fails, the exception.
 */
class AlpnClient : private AsyncSSLSocket::HandshakeCB,
                   private AsyncTransport::WriteCallback {
 public:
  explicit AlpnClient(AsyncSSLSocket::UniquePtr socket)
      : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
    socket_->sslConn(this);
  }

  // Selected protocol; null / 0 until a successful handshake. Presumably
  // points into state owned by the socket — confirm before outliving it.
  const unsigned char* nextProto;
  unsigned nextProtoLength;
  // Set if the handshake failed.
  folly::Optional<AsyncSocketException> except;

 private:
  void handshakeSuc(AsyncSSLSocket*) noexcept override {
    socket_->getSelectedNextProtocol(&nextProto, &nextProtoLength);
  }
  void handshakeErr(
      AsyncSSLSocket*, const AsyncSocketException& ex) noexcept override {
    except = ex;
  }
  void writeSuccess() noexcept override { socket_->close(); }
  void writeErr(
      size_t bytesWritten, const AsyncSocketException& ex) noexcept override {
    ADD_FAILURE() << "client write error after " << bytesWritten
                  << " bytes: " << ex.what();
  }

  AsyncSSLSocket::UniquePtr socket_;
};

/**
 * TLS server that records the ALPN protocol it selected, or the handshake
 * exception.
 *
 * Client-hello parsing is enabled so the client's offered ALPN list can be
 * inspected via getClientAlpns(). Incoming data is not consumed.
 */
class AlpnServer : private AsyncSSLSocket::HandshakeCB,
                   private AsyncTransport::ReadCallback {
 public:
  explicit AlpnServer(AsyncSSLSocket::UniquePtr socket)
      : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
    socket_->sslAccept(this);
    socket_->enableClientHelloParsing();
  }

  // Selected protocol; null / 0 until a successful handshake.
  const unsigned char* nextProto;
  unsigned nextProtoLength;
  // Set if the handshake failed.
  folly::Optional<AsyncSocketException> except;

  // ALPN protocols offered in the client hello.
  const std::vector<std::string>& getClientAlpns() const {
    return socket_->getClientAlpns();
  }

 private:
  void handshakeSuc(AsyncSSLSocket*) noexcept override {
    socket_->getSelectedNextProtocol(&nextProto, &nextProtoLength);
  }
  void handshakeErr(
      AsyncSSLSocket*, const AsyncSocketException& ex) noexcept override {
    except = ex;
  }
  // Offers no buffer space (only the length is set).
  void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
    *lenReturn = 0;
  }
  void readDataAvailable(size_t /* len */) noexcept override {}
  void readEOF() noexcept override { socket_->close(); }
  void readErr(const AsyncSocketException& ex) noexcept override {
    ADD_FAILURE() << "server read error: " << ex.what();
  }

  AsyncSSLSocket::UniquePtr socket_;
};

/**
 * TLS server that checks that client-initiated renegotiation is refused.
 *
 * After the handshake it reads and discards data. It expects the resulting
 * read error to be an SSLException whose message contains the
 * SSLError::CLIENT_RENEGOTIATION message; if so, it sets
 * renegotiationError_.
 */
class RenegotiatingServer : public AsyncSSLSocket::HandshakeCB,
                            public AsyncTransport::ReadCallback {
 public:
  explicit RenegotiatingServer(AsyncSSLSocket::UniquePtr socket)
      : socket_(std::move(socket)) {
    socket_->sslAccept(this);
  }

  // Detach as read callback before socket_ is destroyed (presumably so the
  // socket cannot call back into a partially destroyed object).
  ~RenegotiatingServer() override { socket_->setReadCB(nullptr); }

  void handshakeSuc(AsyncSSLSocket* /* socket */) noexcept override {
    LOG(INFO) << "Renegotiating server handshake success";
    socket_->setReadCB(this);
  }
  void handshakeErr(
      AsyncSSLSocket*, const AsyncSocketException& ex) noexcept override {
    ADD_FAILURE() << "Renegotiating server handshake error: " << ex.what();
  }
  // Reads go into the fixed scratch buffer `buf`.
  void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
    *bufReturn = buf;
    *lenReturn = sizeof(buf);
  }
  void readDataAvailable(size_t /* len */) noexcept override {}
  void readEOF() noexcept override {}
  void readErr(const AsyncSocketException& ex) noexcept override {
    LOG(INFO) << "server got read error " << ex.what();
    const auto* exPtr = dynamic_cast<const SSLException*>(&ex);
    ASSERT_NE(nullptr, exPtr);
    const std::string exStr(ex.what());
    const SSLException sslEx(SSLError::CLIENT_RENEGOTIATION);
    ASSERT_NE(std::string::npos, exStr.find(sslEx.what()));
    renegotiationError_ = true;
  }

  AsyncSSLSocket::UniquePtr socket_;
  unsigned char buf[128];
  bool renegotiationError_{false};
};

/**
 * TLS client that records whether the server reported a server-name match
 * on a successful handshake (serverNameMatch).
 */
class SNIClient : private AsyncSSLSocket::HandshakeCB,
                  private AsyncTransport::WriteCallback {
 public:
  explicit SNIClient(AsyncSSLSocket::UniquePtr socket)
      : serverNameMatch(false), socket_(std::move(socket)) {
    socket_->sslConn(this);
  }

  std::string getApplicationProtocol() {
    return socket_->getApplicationProtocol();
  }

  // Result of isServerNameMatch() after a successful handshake.
  bool serverNameMatch;

 private:
  void handshakeSuc(AsyncSSLSocket*) noexcept override {
    serverNameMatch = socket_->isServerNameMatch();
  }
  void handshakeErr(
      AsyncSSLSocket*, const AsyncSocketException& ex) noexcept override {
    ADD_FAILURE() << "client handshake error: " << ex.what();
  }
  void writeSuccess() noexcept override { socket_->close(); }
  void writeErr(
      size_t bytesWritten, const AsyncSocketException& ex) noexcept override {
    ADD_FAILURE() << "client write error after " << bytesWritten
                  << " bytes: " << ex.what();
  }

  AsyncSSLSocket::UniquePtr socket_;
};

/**
 * TLS server that installs an SNI callback on `ctx`.
 *
 * When the client's requested server name equals `expectedServerName`
 * (case-insensitively) and `sniCtx` is non-null, the connection is switched
 * to `sniCtx` and serverNameMatch is set. Otherwise serverNameMatch is
 * cleared. Incoming data is not consumed.
 */
class SNIServer : private AsyncSSLSocket::HandshakeCB,
                  private AsyncTransport::ReadCallback {
 public:
  explicit SNIServer(
      AsyncSSLSocket::UniquePtr socket,
      const std::shared_ptr<folly::SSLContext>& ctx,
      const std::shared_ptr<folly::SSLContext>& sniCtx,
      const std::string& expectedServerName)
      : serverNameMatch(false),
        socket_(std::move(socket)),
        sniCtx_(sniCtx),
        expectedServerName_(expectedServerName) {
    // A lambda expresses the forwarding more directly than std::bind. The
    // callback captures `this`, so this object must outlive ctx's use of it.
    ctx->setServerNameCallback(
        [this](SSL* ssl) { return serverNameCallback(ssl); });
    socket_->sslAccept(this);
  }

  std::string getApplicationProtocol() {
    return socket_->getApplicationProtocol();
  }

  // Outcome of the most recent SNI callback.
  bool serverNameMatch;

 private:
  void handshakeSuc(AsyncSSLSocket* /* ssl */) noexcept override {}
  void handshakeErr(
      AsyncSSLSocket*, const AsyncSocketException& ex) noexcept override {
    ADD_FAILURE() << "server handshake error: " << ex.what();
  }
  // Offers no buffer space (only the length is set).
  void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
    *lenReturn = 0;
  }
  void readDataAvailable(size_t /* len */) noexcept override {}
  void readEOF() noexcept override { socket_->close(); }
  void readErr(const AsyncSocketException& ex) noexcept override {
    ADD_FAILURE() << "server read error: " << ex.what();
  }

  // SNI hook: switches to sniCtx_ when the requested name matches
  // expectedServerName_, and reports whether it did.
  folly::SSLContext::ServerNameCallbackResult serverNameCallback(SSL* ssl) {
    const char* sn = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
    const bool matches = sniCtx_ != nullptr && sn != nullptr &&
        strcasecmp(expectedServerName_.c_str(), sn) == 0;
    if (!matches) {
      serverNameMatch = false;
      return folly::SSLContext::SERVER_NAME_NOT_FOUND;
    }
    AsyncSSLSocket* sslSocket = AsyncSSLSocket::getFromSSL(ssl);
    sslSocket->switchServerSSLContext(sniCtx_);
    serverNameMatch = true;
    return folly::SSLContext::SERVER_NAME_FOUND;
  }

  AsyncSSLSocket::UniquePtr socket_;
  std::shared_ptr<folly::SSLContext> sniCtx_;
  std::string expectedServerName_;
};

// Test TLS client that performs `requests` sequential connections to
// `address`. Each connection writes a 128-byte buffer of 'a's, reads the echo
// back, and then reconnects while requests remain. It counts session-reuse
// hits/misses, connect errors, and write errors that arrive after the socket
// was dropped.
class SSLClient : public AsyncSocket::ConnectCallback,
                  public AsyncTransport::WriteCallback,
                  public AsyncTransport::ReadCallback {
 private:
  EventBase* eventBase_;
  std::shared_ptr<AsyncSSLSocket> sslSocket_;
  // Session captured on the first full handshake, offered for resumption.
  std::shared_ptr<folly::ssl::SSLSession> session_;
  std::shared_ptr<folly::SSLContext> ctx_;
  uint32_t requests_;
  folly::SocketAddress address_;
  uint32_t timeout_;
  char buf_[128];
  char readbuf_[128];
  uint32_t bytesRead_{0};
  uint32_t hit_{0};
  uint32_t miss_{0};
  uint32_t errors_{0};
  uint32_t writeAfterConnectErrors_{0};

  // Small read limits force the reply to span multiple event loop
  // iterations, checking that the socket is eventually drained even when
  // maxReadsPerEvent is reached within one iteration.
  static constexpr size_t kMaxReadsPerEvent = 2;
  // Sized so reading the whole buffer takes 2 event loop iterations.
  static constexpr size_t kMaxReadBufferSz =
      sizeof(decltype(readbuf_)) / kMaxReadsPerEvent / 2;

 public:
  SSLClient(
      EventBase* eventBase,
      const folly::SocketAddress& address,
      uint32_t requests,
      uint32_t timeout = 0)
      : eventBase_(eventBase),
        ctx_(std::make_shared<folly::SSLContext>()),
        requests_(requests),
        address_(address),
        timeout_(timeout) {
    ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
    memset(buf_, 'a', sizeof(buf_));
  }

  ~SSLClient() override {
    // Only a run without connect errors is expected to have read everything.
    if (errors_ != 0) {
      return;
    }
    EXPECT_EQ(bytesRead_, sizeof(buf_));
  }

  uint32_t getHit() const { return hit_; }

  uint32_t getMiss() const { return miss_; }

  uint32_t getErrors() const { return errors_; }

  uint32_t getWriteAfterConnectErrors() const {
    return writeAfterConnectErrors_;
  }

  void setSSLOptions(long options) { ctx_->setOptions(options); }

  // Opens a new connection, offering the cached session if there is one.
  // With `writeNow`, junk is written right after connect() (error tests).
  void connect(bool writeNow = false) {
    sslSocket_ = AsyncSSLSocket::newSocket(ctx_, eventBase_);
    if (session_) {
      sslSocket_->setSSLSession(session_);
    }
    --requests_;
    sslSocket_->connect(this, address_, timeout_);
    // connectErr() resets sslSocket_, so it may be null here if the connect
    // failed before returning.
    if (writeNow && sslSocket_) {
      sslSocket_->write(this, buf_, sizeof(buf_));
    }
  }

  void connectSuccess() noexcept override {
    std::cerr << "client SSL socket connected" << std::endl;
    if (!sslSocket_->getSSLSessionReused()) {
      ++miss_;
      session_ = sslSocket_->getSSLSession();
    } else {
      ++hit_;
    }

    sslSocket_->setMaxReadsPerEvent(kMaxReadsPerEvent);
    sslSocket_->write(this, buf_, sizeof(buf_));
    sslSocket_->setReadCB(this);
    memset(readbuf_, 'b', sizeof(readbuf_));
    bytesRead_ = 0;
  }

  void connectErr(const AsyncSocketException& ex) noexcept override {
    std::cerr << "SSLClient::connectError: " << ex.what() << std::endl;
    ++errors_;
    sslSocket_.reset();
  }

  void writeSuccess() noexcept override {
    std::cerr << "client write success" << std::endl;
  }

  void writeErr(
      size_t /* bytesWritten */,
      const AsyncSocketException& ex) noexcept override {
    std::cerr << "client writeError: " << ex.what() << std::endl;
    // A null socket means the write failed after connectErr() dropped it.
    if (sslSocket_ == nullptr) {
      ++writeAfterConnectErrors_;
    }
  }

  void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
    const size_t remaining = sizeof(readbuf_) - bytesRead_;
    *bufReturn = readbuf_ + bytesRead_;
    *lenReturn = std::min(kMaxReadBufferSz, remaining);
  }

  void readEOF() noexcept override {
    std::cerr << "client readEOF" << std::endl;
  }

  void readErr(const AsyncSocketException& ex) noexcept override {
    std::cerr << "client readError: " << ex.what() << std::endl;
  }

  // Once the full echo has arrived, verifies it, drops the socket, and
  // starts the next request if any remain.
  void readDataAvailable(size_t len) noexcept override {
    std::cerr << "client read data: " << len << std::endl;
    bytesRead_ += len;
    if (bytesRead_ != sizeof(buf_)) {
      return;
    }
    EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
    sslSocket_->closeNow();
    sslSocket_.reset();
    if (requests_ > 0) {
      connect();
    }
  }
};

// Shared state and callbacks for the SSL handshake test helpers. Subclasses
// start the handshake (connect or accept) in their constructors, and tests
// then inspect the public flags below.
//
// `preverifyResult` is the preverify outcome expected on the first
// handshakeVer() call. `verifyResult` is returned from every handshakeVer()
// call.
class SSLHandshakeBase : public AsyncSSLSocket::HandshakeCB,
                         private AsyncTransport::WriteCallback {
 public:
  explicit SSLHandshakeBase(
      AsyncSSLSocket::UniquePtr socket, bool preverifyResult, bool verifyResult)
      : handshakeVerify_(false),
        handshakeSuccess_(false),
        handshakeError_(false),
        socket_(std::move(socket)),
        preverifyResult_(preverifyResult),
        verifyResult_(verifyResult) {}

  // Releases ownership of the socket. Afterwards the callbacks below skip any
  // socket access because they check `socket_` first.
  AsyncSSLSocket::UniquePtr moveSocket() && { return std::move(socket_); }

  // Set on the first handshakeVer() call.
  bool handshakeVerify_;
  bool handshakeSuccess_;
  bool handshakeError_;
  int handshakeVerifyInvocations_{};
  // Copied from the socket when the handshake succeeds or fails. It is
  // value-initialized (zero) so reading it is well-defined even if no
  // handshake callback ran or the socket had already been moved out.
  std::chrono::nanoseconds handshakeTime{};

 protected:
  AsyncSSLSocket::UniquePtr socket_;
  bool preverifyResult_;
  bool verifyResult_;

  // HandshakeCallback
  // Checks the preverify result only on the first invocation (later calls
  // presumably cover the rest of the peer's chain) and always returns the
  // configured verifyResult_.
  bool handshakeVer(
      AsyncSSLSocket* /* sock */,
      bool preverifyOk,
      X509_STORE_CTX* /* ctx */) noexcept override {
    auto invocation = handshakeVerifyInvocations_++;

    if (invocation == 0) {
      handshakeVerify_ = true;
      EXPECT_EQ(preverifyResult_, preverifyOk);
    }
    return verifyResult_;
  }

  void handshakeSuc(AsyncSSLSocket*) noexcept override {
    LOG(INFO) << "Handshake success";
    handshakeSuccess_ = true;
    if (socket_) {
      handshakeTime = socket_->getHandshakeTime();
    }
  }

  void handshakeErr(
      AsyncSSLSocket*, const AsyncSocketException& ex) noexcept override {
    LOG(INFO) << "Handshake error " << ex.what();
    handshakeError_ = true;
    if (socket_) {
      handshakeTime = socket_->getHandshakeTime();
    }
  }

  // WriteCallback: close the socket once a write completes.
  void writeSuccess() noexcept override {
    if (socket_) {
      socket_->close();
    }
  }

  void writeErr(
      size_t bytesWritten, const AsyncSocketException& ex) noexcept override {
    ADD_FAILURE() << "client write error after " << bytesWritten
                  << " bytes: " << ex.what();
  }
};

// Client-side handshake helper. It calls sslConn() on construction with a
// zero timeout and does not override the peer-verification mode.
class SSLHandshakeClient : public SSLHandshakeBase {
 public:
  SSLHandshakeClient(
      AsyncSSLSocket::UniquePtr socket, bool preverifyResult, bool verifyResult)
      : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
    const auto handshakeTimeout = std::chrono::milliseconds::zero();
    socket_->sslConn(this, handshakeTimeout);
  }
};

// Client-side handshake helper. It calls sslConn() on construction with a
// zero timeout and forces NO_VERIFY for the peer.
class SSLHandshakeClientNoVerify : public SSLHandshakeBase {
 public:
  SSLHandshakeClientNoVerify(
      AsyncSSLSocket::UniquePtr socket, bool preverifyResult, bool verifyResult)
      : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
    const auto handshakeTimeout = std::chrono::milliseconds::zero();
    const auto verifyMode = folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY;
    socket_->sslConn(this, handshakeTimeout, verifyMode);
  }
};

// Client-side handshake helper. It calls sslConn() on construction with a
// zero timeout and forces VERIFY for the peer.
class SSLHandshakeClientDoVerify : public SSLHandshakeBase {
 public:
  SSLHandshakeClientDoVerify(
      AsyncSSLSocket::UniquePtr socket, bool preverifyResult, bool verifyResult)
      : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
    const auto handshakeTimeout = std::chrono::milliseconds::zero();
    const auto verifyMode = folly::SSLContext::SSLVerifyPeerEnum::VERIFY;
    socket_->sslConn(this, handshakeTimeout, verifyMode);
  }
};

// Server-side handshake helper. It calls sslAccept() on construction with a
// zero timeout and does not override the peer-verification mode.
class SSLHandshakeServer : public SSLHandshakeBase {
 public:
  SSLHandshakeServer(
      AsyncSSLSocket::UniquePtr socket, bool preverifyResult, bool verifyResult)
      : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
    const auto handshakeTimeout = std::chrono::milliseconds::zero();
    socket_->sslAccept(this, handshakeTimeout);
  }
};

class SSLHandshakeServerParseClientHello : public SSLHandshakeBase {
 public:
  SSLHandshakeServerParseClientHello(
      AsyncSSLSocket::UniquePtr socket, bool preverifyResult, bool verifyResult)
      : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
    socket_->enableClientHelloParsing();
    socket_->sslAccept(this, std::chrono::milliseconds::zero());
  }

  std::string clientCiphers_, sharedCiphers_, serverCiphers_, chosenCipher_;

 protected:
  void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
    handshakeSuccess_ = true;
    sock->getSSLSharedCiphers(sharedCiphers_);
    sock->getSSLServerCiphers(serverCiphers_);
    sock->getSSLClientCiphers(clientCiphers_);
    chosenCipher_ = sock->getNegotiatedCipherName();
  }
};

// Server-side handshake helper. It calls sslAccept() on construction with a
// zero timeout and forces NO_VERIFY for the peer.
class SSLHandshakeServerNoVerify : public SSLHandshakeBase {
 public:
  SSLHandshakeServerNoVerify(
      AsyncSSLSocket::UniquePtr socket, bool preverifyResult, bool verifyResult)
      : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
    const auto handshakeTimeout = std::chrono::milliseconds::zero();
    const auto verifyMode = folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY;
    socket_->sslAccept(this, handshakeTimeout, verifyMode);
  }
};

// Server-side handshake helper. It calls sslAccept() on construction with a
// zero timeout and requires a client certificate (VERIFY_REQ_CLIENT_CERT).
class SSLHandshakeServerDoVerify : public SSLHandshakeBase {
 public:
  SSLHandshakeServerDoVerify(
      AsyncSSLSocket::UniquePtr socket, bool preverifyResult, bool verifyResult)
      : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
    const auto handshakeTimeout = std::chrono::milliseconds::zero();
    const auto verifyMode =
        folly::SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT;
    socket_->sslAccept(this, handshakeTimeout, verifyMode);
  }
};

// Test watchdog. It schedules itself for `timeoutMS` milliseconds on
// construction. If it fires, it stops `eventBase`'s loop and records a fatal
// test failure.
class EventBaseAborter : public AsyncTimeout {
 public:
  EventBaseAborter(EventBase* eventBase, uint32_t timeoutMS)
      : AsyncTimeout(eventBase, AsyncTimeout::InternalEnum::INTERNAL),
        eventBase_(eventBase) {
    scheduleTimeout(timeoutMS);
  }

  void timeoutExpired() noexcept override {
    // FAIL() expands to a `return` statement, so nothing after it executes.
    // The loop must be asked to stop *before* failing; otherwise a timed-out
    // test keeps spinning instead of being aborted.
    eventBase_->terminateLoopSoon();
    FAIL() << "test timed out";
  }

 private:
  EventBase* eventBase_;
};

class SSLAcceptEvbRunner : public SSLAcceptRunner {
 public:
  explicit SSLAcceptEvbRunner(EventBase* evb) : evb_(evb) {}
  ~SSLAcceptEvbRunner() override = default;

  void run(Function<int()> acceptFunc, Function<void(int)> finallyFunc)
      const override {
    evb_->runInLoop([acceptFunc = std::move(acceptFunc),
                     finallyFunc = std::move(finallyFunc)]() mutable {
      finallyFunc(acceptFunc());
    });
  }

 protected:
  EventBase* evb_;
};

class SSLAcceptErrorRunner : public SSLAcceptEvbRunner {
 public:
  explicit SSLAcceptErrorRunner(EventBase* evb) : SSLAcceptEvbRunner(evb) {}
  ~SSLAcceptErrorRunner() override = default;

  void run(Function<int()> /*acceptFunc*/, Function<void(int)> finallyFunc)
      const override {
    evb_->runInLoop(
        [finallyFunc = std::move(finallyFunc)]() mutable { finallyFunc(-1); });
  }
};

class SSLAcceptCloseRunner : public SSLAcceptEvbRunner {
 public:
  explicit SSLAcceptCloseRunner(EventBase* evb, folly::AsyncSSLSocket* sock)
      : SSLAcceptEvbRunner(evb), socket_(sock) {}
  ~SSLAcceptCloseRunner() override = default;

  void run(Function<int()> acceptFunc, Function<void(int)> finallyFunc)
      const override {
    evb_->runInLoop([acceptFunc = std::move(acceptFunc),
                     finallyFunc = std::move(finallyFunc),
                     sock = socket_]() mutable {
      auto ret = acceptFunc();
      sock->closeNow();
      finallyFunc(ret);
    });
  }

 private:
  folly::AsyncSSLSocket* socket_;
};

class SSLAcceptDestroyRunner : public SSLAcceptEvbRunner {
 public:
  explicit SSLAcceptDestroyRunner(EventBase* evb, SSLHandshakeBase* base)
      : SSLAcceptEvbRunner(evb), sslBase_(base) {}
  ~SSLAcceptDestroyRunner() override = default;

  void run(Function<int()> acceptFunc, Function<void(int)> finallyFunc)
      const override {
    evb_->runInLoop([acceptFunc = std::move(acceptFunc),
                     finallyFunc = std::move(finallyFunc),
                     sslBase = sslBase_]() mutable {
      auto ret = acceptFunc();
      std::move(*sslBase).moveSocket();
      finallyFunc(ret);
    });
  }

 private:
  SSLHandshakeBase* sslBase_;
};

class SSLAcceptFiberRunner : public SSLAcceptEvbRunner {
 public:
  explicit SSLAcceptFiberRunner(EventBase* evb) : SSLAcceptEvbRunner(evb) {}
  ~SSLAcceptFiberRunner() override = default;

  void run(Function<int()> acceptFunc, Function<void(int)> finallyFunc)
      const override {
    auto& fiberManager = folly::fibers::getFiberManager(*evb_);
    fiberManager.addTaskFinally(
        std::move(acceptFunc),
        [finally = std::move(finallyFunc)](folly::Try<int>&& res) mutable {
          finally(res.value());
        });
  }
};

} // namespace folly::test