folly/folly/io/async/test/BlockingSocket.h

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <chrono>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <memory>
#include <utility>

#include <folly/Optional.h>
#include <folly/SocketAddress.h>
#include <folly/Utility.h>
#include <folly/io/async/AsyncSSLSocket.h>
#include <folly/io/async/AsyncSocket.h>
#include <folly/io/async/EventBase.h>
#include <folly/io/async/SSLContext.h>
#include <folly/net/NetworkSocket.h>

namespace folly::test {

/**
 * Blocking (synchronous) facade over folly::AsyncSocket / AsyncSSLSocket,
 * intended for tests.
 *
 * The object owns a private EventBase and registers itself as the connect,
 * read and write callback. Each public operation starts the asynchronous
 * call and then drives eventBase_, so test code can be written in
 * straight-line style.
 *
 * Any error delivered to a callback is stored in err_ and rethrown as a
 * folly::AsyncSocketException. The stored error is never cleared: later
 * open()/write()/writev() calls rethrow it, and so do reads while the socket
 * still reports good().
 */
class BlockingSocket : public folly::AsyncSocket::ConnectCallback,
                       public folly::AsyncTransport::ReadCallback,
                       public folly::AsyncTransport::WriteCallback {
 public:
  // Wraps an existing file descriptor in a plain (non-TLS) AsyncSocket.
  explicit BlockingSocket(folly::NetworkSocket fd)
      : sock_(new folly::AsyncSocket(&eventBase_, fd)) {}

  // Creates an unconnected socket that open() will connect to `address`.
  // A non-null sslContext selects an AsyncSSLSocket; otherwise a plain
  // AsyncSocket is created.
  BlockingSocket(
      folly::SocketAddress address,
      std::shared_ptr<folly::SSLContext> sslContext)
      : address_(std::move(address)),
        sock_(
            sslContext ? new folly::AsyncSSLSocket(sslContext, &eventBase_)
                       : new folly::AsyncSocket(&eventBase_)) {}

  // Takes ownership of an existing socket and attaches it to this object's
  // private EventBase.
  explicit BlockingSocket(folly::AsyncSocket::UniquePtr socket)
      : sock_(std::move(socket)) {
    sock_->attachEventBase(&eventBase_);
  }

  void enableTFO() { sock_->enableTFO(); }

  void setEorTracking(bool track) { sock_->setEorTracking(track); }

  // Sets the peer address used by a subsequent open().
  void setAddress(folly::SocketAddress address) {
    address_ = std::move(address);
  }

  // Connects to address_ and runs the event loop until it has no more work.
  // `timeout` (zero by default) is forwarded to AsyncSocket::connect() as an
  // int millisecond count, so it must fit in an int.
  // Throws folly::AsyncSocketException if an error has been recorded.
  void open(
      std::chrono::milliseconds timeout = std::chrono::milliseconds::zero()) {
    DCHECK_LE(timeout.count(), std::numeric_limits<int>::max());
    sock_->connect(this, address_, folly::to_narrow(timeout.count()));
    eventBase_.loop();
    throwIfError();
  }

  void close() { sock_->close(); }
  void closeWithReset() { sock_->closeWithReset(); }

  // Writes `len` bytes from `buf`, then runs the event loop until it has no
  // more work. Returns `len` narrowed to int32_t; throws on a recorded error.
  int32_t write(
      uint8_t const* buf,
      size_t len,
      folly::WriteFlags flags = folly::WriteFlags::NONE) {
    sock_->write(this, buf, len, flags);
    eventBase_.loop();
    throwIfError();
    return folly::to_narrow(folly::to_signed(len));
  }

  // Vectored counterpart of write(); throws on a recorded error.
  void writev(
      const iovec* vec,
      size_t count,
      folly::WriteFlags flags = folly::WriteFlags::NONE) {
    sock_->writev(this, vec, count, flags);
    eventBase_.loop();
    throwIfError();
  }

  // No-op: write()/writev() already drive the event loop before returning.
  void flush() {}

  // Keeps reading until `len` bytes have landed in `buf`. Throws if the
  // socket stops being good() before that (or on a recorded error).
  // Returns 0 without reading if the socket is not good() on entry.
  int32_t readAll(uint8_t* buf, size_t len) {
    return readHelper(buf, len, true);
  }

  // Runs a single loopOnce() iteration and returns how many bytes of `buf`
  // it filled, which may be fewer than `len`.
  int32_t read(uint8_t* buf, size_t len) { return readHelper(buf, len, false); }

  // Like read(), but the single iteration is run with EVLOOP_NONBLOCK.
  int32_t readNoBlock(uint8_t* buf, size_t len) {
    return readHelper(buf, len, false, EVLOOP_NONBLOCK);
  }

  folly::NetworkSocket getNetworkSocket() const {
    return sock_->getNetworkSocket();
  }

  folly::AsyncSocket* getSocket() { return sock_.get(); }

  // Returns the underlying socket as an AsyncSSLSocket, or nullptr when it
  // is a plain AsyncSocket.
  folly::AsyncSSLSocket* getSSLSocket() {
    return dynamic_cast<folly::AsyncSSLSocket*>(sock_.get());
  }

 private:
  // Declared first: it is constructed before sock_ (which is created on it)
  // and destroyed after sock_.
  folly::EventBase eventBase_;
  // Error captured by a callback; rethrown by throwIfError(), never cleared.
  folly::Optional<folly::AsyncSocketException> err_;
  // Destination cursor and remaining byte count for the read in progress.
  uint8_t* readBuf_{nullptr};
  size_t readLen_{0};
  // Peer address used by open().
  folly::SocketAddress address_;
  // Declared last so it is destroyed first: if tearing the socket down
  // invokes any callback on this object, err_ and the read state it writes
  // to are still alive.
  folly::AsyncSocket::UniquePtr sock_;

  // Rethrows the error recorded by a callback, if any.
  void throwIfError() const {
    if (err_.has_value()) {
      throw err_.value();
    }
  }

  void connectSuccess() noexcept override {}
  void connectErr(const folly::AsyncSocketException& ex) noexcept override {
    err_ = ex;
  }

  // Hands the socket the unfilled remainder of the caller's buffer.
  void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
    *bufReturn = readBuf_;
    *lenReturn = readLen_;
  }

  // Advances past the bytes just received and detaches the read callback
  // once the requested amount has arrived.
  void readDataAvailable(size_t len) noexcept override {
    readBuf_ += len;
    readLen_ -= len;

    if (readLen_ == 0) {
      sock_->setReadCB(nullptr);
    }
  }

  // Vectored-read buffer supply: one single-byte iovec per remaining byte,
  // all pointing into the same caller buffer as getReadBuffer().
  // NOTE(review): the per-byte split looks deliberate (exercising reads with
  // many iovecs) — confirm before collapsing it into a single iovec.
  void getReadBuffers(folly::IOBufIovecBuilder::IoVecVec& iovs) override {
    iovs.clear();
    for (size_t i = 0; i < readLen_; i++) {
      struct iovec iov;
      iov.iov_base = &readBuf_[i];
      iov.iov_len = 1;
      iovs.push_back(iov);
    }
  }

  void readEOF() noexcept override {}
  void readErr(const folly::AsyncSocketException& ex) noexcept override {
    err_ = ex;
  }
  void writeSuccess() noexcept override {}
  void writeErr(
      size_t /* bytesWritten */,
      const folly::AsyncSocketException& ex) noexcept override {
    err_ = ex;
  }

  // Shared implementation of readAll()/read()/readNoBlock().
  // Registers `this` as read callback for [buf, buf + len) and runs
  // eventBase_.loopOnce(flags) until the buffer is full, the socket stops
  // being good(), or an error is recorded; with all == false it runs at most
  // one iteration. Returns the number of bytes read (0 without reading if the
  // socket is not good() on entry). Throws the recorded error, or an UNKNOWN
  // "eof" exception when `all` is set and the buffer was not filled.
  int32_t readHelper(uint8_t* buf, size_t len, bool all, int flags = 0) {
    if (!sock_->good()) {
      return 0;
    }
    readBuf_ = buf;
    readLen_ = len;
    sock_->setReadCB(this);
    while (!err_ && sock_->good() && readLen_ > 0) {
      eventBase_.loopOnce(flags);
      if (!all) {
        break;
      }
    }
    sock_->setReadCB(nullptr);
    throwIfError();
    if (all && readLen_ > 0) {
      throw folly::AsyncSocketException(
          folly::AsyncSocketException::UNKNOWN, "eof");
    }
    return folly::to_narrow(folly::to_signed(len - readLen_));
  }
};

} // namespace folly::test