folly/folly/io/async/test/AsyncSSLSocketTest2.cpp

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <folly/io/async/AsyncSSLSocket.h>

#include <folly/futures/Promise.h>
#include <folly/init/Init.h>
#include <folly/io/async/EventBase.h>
#include <folly/io/async/SSLContext.h>
#include <folly/io/async/ScopedEventBaseThread.h>
#include <folly/io/async/test/AsyncSSLSocketTest.h>
#include <folly/portability/GTest.h>
#include <folly/portability/PThread.h>

using std::cerr;
using std::endl;

using namespace folly;
using namespace folly::test;

struct EvbAndContext {
  EvbAndContext() {
    ctx_.reset(new SSLContext());
    ctx_->setOptions(SSL_OP_NO_TICKET);
    ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
  }

  std::shared_ptr<AsyncSSLSocket> createSocket() {
    return AsyncSSLSocket::newSocket(ctx_, getEventBase());
  }

  EventBase* getEventBase() { return evb_.getEventBase(); }

  void attach(AsyncSSLSocket& socket) {
    socket.attachEventBase(getEventBase());
    socket.attachSSLContext(ctx_);
  }

  folly::ScopedEventBaseThread evb_;
  std::shared_ptr<SSLContext> ctx_;
};

/**
 * Test client that hands one AsyncSSLSocket back and forth between two
 * EventBase threads, each with its own SSLContext:
 *   1. the socket is created on t1 and detached;
 *   2. it is attached to t2 and connected there;
 *   3. after connecting it is detached again and re-attached to t1, where
 *      sizeof(buf_) bytes are written and the reply is read.
 *
 * The promise is fulfilled with `true` once sizeof(buf_) bytes have been read
 * back, or with an exception if connecting or reading fails.
 */
class AttachDetachClient : public AsyncSocket::ConnectCallback,
                           public AsyncTransport::WriteCallback,
                           public AsyncTransport::ReadCallback {
 private:
  // two threads here - we'll create the socket in one, connect
  // in the other, and then read/write in the initial one
  EvbAndContext t1_;
  EvbAndContext t2_;
  std::shared_ptr<AsyncSSLSocket> sslSocket_;
  folly::SocketAddress address_;
  // payload written to the peer; filled with a known pattern in the
  // constructor
  char buf_[128];
  // destination for the bytes read back from the peer
  char readbuf_[128];
  // number of valid bytes accumulated in readbuf_ so far
  uint32_t bytesRead_;
  // promise to fulfill when done
  folly::Promise<bool> promise_;

  // Detaches the socket from its current EventBase and SSLContext so that it
  // can be attached to the other thread.
  void detach() {
    sslSocket_->detachEventBase();
    sslSocket_->detachSSLContext();
  }

 public:
  /**
   * @param address  address the client will connect to.
   */
  explicit AttachDetachClient(const folly::SocketAddress& address)
      : address_(address), bytesRead_(0) {
    // Give the payload deterministic contents; otherwise write() would send
    // (and readDataAvailable() would compare against) indeterminate bytes.
    memset(buf_, 'a', sizeof(buf_));
  }

  // Future that completes when the exchange finishes (see class comment).
  Future<bool> getFuture() { return promise_.getFuture(); }

  // Starts the test: creates the socket on t1, then connects it on t2.
  void connect() {
    // create in one and then move to another
    auto t1Evb = t1_.getEventBase();
    t1Evb->runInEventBaseThread([this] {
      sslSocket_ = t1_.createSocket();
      // ensure we can detach and reattach the context before connecting
      for (int i = 0; i < 1000; ++i) {
        sslSocket_->detachSSLContext();
        sslSocket_->attachSSLContext(t1_.ctx_);
      }
      // detach from t1 and connect in t2
      detach();
      auto t2Evb = t2_.getEventBase();
      t2Evb->runInEventBaseThread([this] {
        t2_.attach(*sslSocket_);
        sslSocket_->connect(this, address_);
      });
    });
  }

  // Connect callback (expected on t2): re-checks context detach/attach on the
  // connected socket, then moves the socket back to t1 to write and read.
  void connectSuccess() noexcept override {
    auto t2Evb = t2_.getEventBase();
    EXPECT_TRUE(t2Evb->isInEventBaseThread());
    cerr << "client SSL socket connected" << endl;
    for (int i = 0; i < 1000; ++i) {
      sslSocket_->detachSSLContext();
      sslSocket_->attachSSLContext(t2_.ctx_);
    }

    // detach from t2 and then read/write in t1
    t2Evb->runInEventBaseThread([this] {
      detach();
      auto t1Evb = t1_.getEventBase();
      t1Evb->runInEventBaseThread([this] {
        t1_.attach(*sslSocket_);
        // Reset the read state before the read callback is installed so that
        // no read can ever observe (or be clobbered by) the reset.
        memset(readbuf_, 'b', sizeof(readbuf_));
        bytesRead_ = 0;
        sslSocket_->write(this, buf_, sizeof(buf_));
        sslSocket_->setReadCB(this);
      });
    });
  }

  // Connect failure: drop the socket and fail the promise so the waiting test
  // sees the real error instead of only a timeout.
  void connectErr(const AsyncSocketException& ex) noexcept override {
    cerr << "AttachDetachClient::connectError: " << ex.what() << endl;
    sslSocket_.reset();
    promise_.setException(ex);
  }

  void writeSuccess() noexcept override {
    cerr << "client write success" << endl;
  }

  void writeErr(
      size_t /* bytesWritten */,
      const AsyncSocketException& ex) noexcept override {
    cerr << "client writeError: " << ex.what() << endl;
  }

  // Hands out the unused tail of readbuf_.
  void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
    *bufReturn = readbuf_ + bytesRead_;
    *lenReturn = sizeof(readbuf_) - bytesRead_;
  }

  // Only logs. NOTE(review): closeNow() in readDataAvailable() may invoke this
  // before the promise is set, so it must not fulfill the promise - confirm
  // against AsyncSocket::closeNow().
  void readEOF() noexcept override { cerr << "client readEOF" << endl; }

  void readErr(const AsyncSocketException& ex) noexcept override {
    cerr << "client readError: " << ex.what() << endl;
    promise_.setException(ex);
  }

  // Accumulates received bytes; once the whole payload has arrived, checks it
  // against what was written, closes the socket and fulfills the promise.
  void readDataAvailable(size_t len) noexcept override {
    EXPECT_TRUE(t1_.getEventBase()->isInEventBaseThread());
    EXPECT_EQ(sslSocket_->getEventBase(), t1_.getEventBase());
    cerr << "client read data: " << len << endl;
    bytesRead_ += static_cast<uint32_t>(len);
    // Compare the running total, not this chunk's length: the reply may be
    // delivered across more than one read.
    if (bytesRead_ == sizeof(buf_)) {
      EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
      sslSocket_->closeNow();
      sslSocket_.reset();
      // Fulfill the promise last; the waiting thread may start tearing this
      // object down as soon as the future completes.
      promise_.setValue(true);
    }
  }
};

/**
 * Moves a client socket (and its SSL context) between EventBase threads while
 * it connects to a local test server and exchanges data with it.
 */
TEST(AsyncSSLSocketTest2, AttachDetachSSLContext) {
  // Server side: callback chain plus a listener on a local port.
  WriteCallbackBase writeCallback;
  ReadCallback readCallback(&writeCallback);
  HandshakeCallback handshakeCallback(&readCallback);
  SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
  TestSSLServer server(&acceptCallback);

  auto client = std::make_shared<AttachDetachClient>(server.getAddress());

  // Grab the future before kicking off the connection, then wait (bounded)
  // for the client to report completion.
  auto done = client->getFuture();
  client->connect();
  const bool finished = std::move(done).within(std::chrono::seconds(3)).get();
  EXPECT_TRUE(finished);
}

/**
 * Minimal client that only attempts an SSL connection on its own EventBase
 * thread and reports the outcome through a future: true if the connection
 * succeeded, false if it failed.
 */
class ConnectClient : public AsyncSocket::ConnectCallback {
 public:
  ConnectClient() = default;

  // Future that resolves to the connect outcome.
  Future<bool> getFuture() { return promise_.getFuture(); }

  /**
   * Schedules socket creation and the connect attempt on the client's
   * EventBase thread.
   *
   * @param addr  address to connect to. It is copied into the scheduled
   *              closure, so the caller's object does not need to outlive
   *              this call.
   */
  void connect(const folly::SocketAddress& addr) {
    // Capture the address by value: the closure runs later on another thread,
    // when a by-reference capture of the parameter could already dangle.
    t1_.getEventBase()->runInEventBaseThread([this, addr] {
      socket_ = t1_.createSocket();
      socket_->connect(this, addr);
    });
  }

  void connectSuccess() noexcept override {
    socket_.reset();
    promise_.setValue(true);
  }

  void connectErr(const AsyncSocketException& /* ex */) noexcept override {
    socket_.reset();
    promise_.setValue(false);
  }

  // Replaces the SSL context used by sockets created after this call.
  void setCtx(std::shared_ptr<SSLContext> ctx) { t1_.ctx_ = std::move(ctx); }

 private:
  EvbAndContext t1_;
  // promise to fulfill when done with a value of true if connect succeeded
  folly::Promise<bool> promise_;
  std::shared_ptr<AsyncSSLSocket> socket_;
};

/**
 * Read callback that offers a one-byte buffer and ignores whatever arrives.
 * Its `state` is set to STATE_SUCCEEDED at construction.
 */
class NoopReadCallback : public ReadCallbackBase {
 public:
  NoopReadCallback() : ReadCallbackBase(nullptr) {
    state = STATE_SUCCEEDED;
  }

  // Always offers the same single-byte scratch buffer.
  void getReadBuffer(void** buf, size_t* lenReturn) override {
    *lenReturn = sizeof(buffer_);
    *buf = &buffer_;
  }

  // Received data is intentionally discarded.
  void readDataAvailable(size_t /* len */) noexcept override {}

  uint8_t buffer_{0};
};

// A client using the default context settings should be able to connect to a
// server whose context is created with SSLContext::TLSv1_2.
TEST(AsyncSSLSocketTest2, TestTLS12DefaultClient) {
  // Server side: listen on a local port with a TLSv1_2 context.
  NoopReadCallback readCallback;
  HandshakeCallback handshakeCallback(&readCallback);
  SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
  auto serverCtx = std::make_shared<SSLContext>(SSLContext::TLSv1_2);
  TestSSLServer server(&acceptCallback, serverCtx);
  server.loadTestCerts();

  // Client side: default-configured client.
  auto client = std::make_unique<ConnectClient>();
  auto connected = client->getFuture();
  client->connect(server.getAddress());
  const bool ok = std::move(connected).within(std::chrono::seconds(3)).get();
  EXPECT_TRUE(ok);
}

// A client restricted to pre-TLS 1.2 protocol versions must fail to connect
// to a server that requires TLS 1.2 or newer.
TEST(AsyncSSLSocketTest2, TestLegacyClientCannotConnectToTLS12Server) {
  // Server side: listen on a local port; the handshake is expected to fail.
  NoopReadCallback readCallback;
  HandshakeCallback handshakeCallback(
      &readCallback, HandshakeCallback::EXPECT_ERROR);
  SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
  auto serverCtx = std::make_shared<SSLContext>(SSLContext::TLSv1_2);
  TestSSLServer server(&acceptCallback, serverCtx);
  server.loadTestCerts();

  // Client side: context with TLS 1.2 and TLS 1.3 both switched off.
  auto legacyClient = std::make_unique<ConnectClient>();
  auto legacyCtx = std::make_shared<SSLContext>(SSLContext::TLSv1);
  legacyCtx->setOptions(SSL_OP_NO_TLSv1_2);
  legacyCtx->disableTLS13();
  legacyClient->setCtx(legacyCtx);

  auto connected = legacyClient->getFuture();
  legacyClient->connect(server.getAddress());
  const bool ok = std::move(connected).within(std::chrono::seconds(3)).get();
  EXPECT_FALSE(ok);
}

/**
 * Test entry point: ignores SIGPIPE where the platform defines it, initializes
 * GoogleTest and folly, then runs every registered test.
 *
 * @return the value of RUN_ALL_TESTS().
 */
int main(int argc, char* argv[]) {
#ifdef SIGPIPE
  // Ignore SIGPIPE so that the signal cannot terminate the test process.
  signal(SIGPIPE, SIG_IGN);
#endif
  testing::InitGoogleTest(&argc, argv);
  folly::Init init(&argc, &argv);
  // The OPENSSL_cleanup() call that used to follow the return statement was
  // unreachable and has been removed. OpenSSL documents that its cleanup runs
  // automatically at process exit, so no explicit call is needed here.
  return RUN_ALL_TESTS();
}