folly/folly/io/async/test/TestSSLServer.h

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <fcntl.h>
#include <sys/types.h>

#include <list>

#include <folly/SocketAddress.h>
#include <folly/io/async/AsyncSSLSocket.h>
#include <folly/io/async/AsyncServerSocket.h>
#include <folly/io/async/AsyncSocket.h>
#include <folly/io/async/AsyncTimeout.h>
#include <folly/io/async/AsyncTransport.h>
#include <folly/io/async/EventBase.h>
#include <folly/io/async/ssl/SSLErrors.h>
#include <folly/io/async/test/CallbackStateEnum.h>
#include <folly/portability/GTest.h>
#include <folly/portability/Sockets.h>
#include <folly/portability/Unistd.h>
#include <folly/testing/TestUtil.h>

namespace folly::test {

extern const char* kTestCert;
extern const char* kTestKey;
extern const char* kTestCA;
extern const char* kTestCertCN;

extern const char* kClientTestCert;
extern const char* kClientTestKey;
extern const char* kClientTestCA;
extern const char* kClientTestChain;

class HandshakeCallback;

class SSLServerAcceptCallbackBase : public AsyncServerSocket::AcceptCallback {
 public:
  explicit SSLServerAcceptCallbackBase(HandshakeCallback* hcb)
      : state(STATE_WAITING), hcb_(hcb) {}

  ~SSLServerAcceptCallbackBase() override { EXPECT_EQ(STATE_SUCCEEDED, state); }

  void acceptError(folly::exception_wrapper ex) noexcept override {
    LOG(WARNING) << "SSLServerAcceptCallbackBase::acceptError " << ex;
    state = STATE_FAILED;
  }

  void connectionAccepted(
      folly::NetworkSocket fd,
      const SocketAddress& clientAddr,
      AcceptInfo /* info */) noexcept override {
    if (socket_) {
      socket_->detachEventBase();
    }
    LOG(INFO) << "Connection accepted";
    try {
      // Create a AsyncSSLSocket object with the fd. The socket should be
      // added to the event base and in the state of accepting SSL connection.
      socket_ = AsyncSSLSocket::newSocket(
          ctx_,
          base_,
          fd,
          /*server=*/true,
          /*deferSecurityNegotiation=*/false,
          &clientAddr);
    } catch (const std::exception& e) {
      LOG(ERROR) << "Exception %s caught while creating a AsyncSSLSocket "
                    "object with socket "
                 << e.what() << fd;
      folly::netops::close(fd);
      acceptError(e);
      return;
    }

    connAccepted(socket_);
  }

  virtual void connAccepted(const std::shared_ptr<AsyncSSLSocket>& s) = 0;

  void detach() {
    if (socket_) {
      socket_->detachEventBase();
    }
  }

  StateEnum state;
  HandshakeCallback* hcb_;
  std::shared_ptr<SSLContext> ctx_;
  std::shared_ptr<AsyncSSLSocket> socket_;
  EventBase* base_;
};

class TestSSLServer {
 public:
  static std::unique_ptr<SSLContext> getDefaultSSLContext();

  // Create a TestSSLServer.
  // This immediately starts listening on the given port.
  explicit TestSSLServer(
      SSLServerAcceptCallbackBase* acb, bool enableTFO = false);
  explicit TestSSLServer(
      SSLServerAcceptCallbackBase* acb,
      std::shared_ptr<SSLContext> ctx,
      bool enableTFO = false);

  // Kills the thread.
  virtual ~TestSSLServer();

  EventBase& getEventBase() { return evb_; }

  void loadTestCerts();

  const SocketAddress& getAddress() const { return address_; }

 protected:
  EventBase evb_;
  std::shared_ptr<SSLContext> ctx_;
  SSLServerAcceptCallbackBase* acb_;
  std::shared_ptr<AsyncServerSocket> socket_;
  SocketAddress address_;
  std::thread thread_;

 private:
  void init(bool);
};
} // namespace folly::test