folly/folly/io/async/test/AsyncIoUringSocketTest.cpp

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <array>
#include <chrono>
#include <map>
#include <random>
#include <vector>

#include <folly/FileUtil.h>
#include <folly/Subprocess.h>
#include <folly/executors/GlobalExecutor.h>
#include <folly/experimental/io/AsyncIoUringSocket.h>
#include <folly/experimental/io/IoUringBackend.h>
#include <folly/experimental/io/IoUringEvent.h>
#include <folly/futures/Future.h>
#include <folly/futures/Promise.h>
#include <folly/io/async/AsyncServerSocket.h>
#include <folly/io/async/AsyncSocket.h>
#include <folly/io/async/EventBase.h>
#include <folly/portability/GTest.h>
#include <folly/system/Shell.h>
#include <folly/test/SocketAddressTestHelper.h>

namespace folly {

namespace {

static constexpr std::chrono::milliseconds kTimeout{30000};
static constexpr size_t kBufferSize{1024};

std::string toString(std::unique_ptr<IOBuf> const& buf) {
  if (!buf) {
    return std::string();
  }
  auto coalesced = buf->coalesce();
  return std::string((char const*)coalesced.data(), coalesced.size());
}

class NullWriteCallback : public AsyncWriter::WriteCallback {
 public:
  void writeSuccess() noexcept override {}

  void writeErr(
      size_t bytesWritten, const AsyncSocketException& ex) noexcept override {
    LOG(FATAL) << "writeErr wrote=" << (int)bytesWritten << " " << ex;
  }
};

static NullWriteCallback nullWriteCallback;

class FutureWriteCallback : public AsyncWriter::WriteCallback {
 public:
  void writeSuccess() noexcept override {
    auto& [promise, future] = promiseContract;
    promise.setValue(Unit{});
  }

  void writeErr(
      size_t bytesWritten, const AsyncSocketException& ex) noexcept override {
    auto& [promise, future] = promiseContract;
    promise.setValue(makeUnexpected(std::make_pair(bytesWritten, ex)));
  }

  using TResult = Expected<Unit, std::pair<size_t, AsyncSocketException>>;
  SemiPromiseContract<TResult> promiseContract = makePromiseContract<TResult>();
};

} // namespace

class EchoTransport : public AsyncReader::ReadCallback,
                      public AsyncWriter::WriteCallback {
 public:
  explicit EchoTransport(AsyncSocketTransport::UniquePtr s, bool bm)
      : transport(std::move(s)), bufferMovable(bm) {}

  void start() { transport->setReadCB(this); }

  void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
    *bufReturn = buff.data();
    *lenReturn = buff.size();
  }

  void writeSuccess() noexcept override {}

  void writeErr(
      size_t bytesWritten, const AsyncSocketException& ex) noexcept override {
    LOG_EVERY_N(ERROR, 10000) << "writeErr " << bytesWritten << " " << ex;
    transport->close();
  }

  void readEOF() noexcept override {
    VLOG(1) << "Closing transport!";
    transport->close();
  }

  void readErr(const AsyncSocketException& ex) noexcept override {
    LOG(ERROR) << "readErr " << ex;
    transport->close();
  }

  void readDataAvailable(size_t len) noexcept override {
    VLOG(1) << "readDataAvailable " << len;
    // have to copy as buff will be reused after
    transport->writeChain(this, IOBuf::copyBuffer(buff.data(), len));
  }

  bool isBufferMovable() noexcept override { return bufferMovable; }

  void readBufferAvailable(std::unique_ptr<IOBuf> readBuf) noexcept override {
    VLOG(1) << "readBuffer available " << readBuf->computeChainDataLength();
    transport->writeChain(this, std::move(readBuf));
  }

  AsyncSocketTransport::UniquePtr transport;
  bool bufferMovable;
  std::array<char, kBufferSize> buff;
};

class CollectCallback : public AsyncReader::ReadCallback {
 public:
  void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
    *bufReturn = buff.data();
    *lenReturn = buff.size();
  }

  void readEOF() noexcept override { clear(); }

  void readErr(const AsyncSocketException& ex) noexcept override {
    if (promise) {
      promise->second.setException(ex);
    }
    clear();
  }

  void readDataAvailable(size_t len) noexcept override {
    VLOG(1) << "CollectCallback::readDataAvailable " << len;
    if (hadBuffers.has_value() && *hadBuffers) {
      LOG(FATAL) << "must either send buffers or use getReadBuffer";
    }
    hadBuffers = false;
    data += std::string(buff.data(), len);
    dataAvailable();
  }

  SemiFuture<std::string> waitFor(size_t n) {
    auto [p, f] = makePromiseContract<std::string>();
    promise = std::make_pair(n, std::move(p));
    return std::move(f).within(kTimeout);
  }

  void readBufferAvailable(std::unique_ptr<IOBuf> readBuf) noexcept override {
    if (hadBuffers.has_value() && !*hadBuffers) {
      LOG(FATAL) << "must either send buffers or use getReadBuffer";
    }
    hadBuffers = true;

    if (holdData) {
      auto cloned = readBuf->clone();
      if (bufData) {
        bufData->appendToChain(std::move(readBuf));
      } else {
        bufData = std::move(readBuf);
      }
      readBuf = std::move(cloned);
    }

    data += toString(readBuf);
    dataAvailable();
  }

  void clear() {
    data.clear();
    promise.reset();
    bufData.reset();
  }

  void dataAvailable() {
    VLOG(1) << "CollectCallback::dataAvailable have=" << data.size()
            << " want=" << (promise ? static_cast<int>(promise->first) : -1);
    if (!promise || promise->first > data.size()) {
      return;
    }
    promise->second.setValue(data.substr(0, promise->first));
    data.erase(0, promise->first);
    promise.reset();
  }

  bool isBufferMovable() noexcept override { return true; }

  void setHoldData(bool b) { holdData = b; }

  std::optional<std::pair<size_t, Promise<std::string>>> promise;
  std::shared_ptr<AsyncSocket> sock;
  std::string data;
  std::unique_ptr<IOBuf> bufData;
  std::optional<bool> hadBuffers;
  std::array<char, kBufferSize> buff;
  bool holdData = true;
};

struct TestParams {
  bool ioUringServer = false;
  bool ioUringClient = false;
  bool manySmallBuffers = false;
  bool supportBufferMovable = true;
  bool sendzc = false;
  bool registerFd = true;
  std::string testName() const {
    return folly::to<std::string>(
        ioUringServer ? "ioUringServer" : "oldServer",
        "_",
        ioUringClient ? "ioUringClient" : "oldClient",
        "_",
        manySmallBuffers ? "manySmallBuffers" : "oneBigBuffer",
        supportBufferMovable ? "" : "_noSupportBufferMovable",
        sendzc ? "_zerocopy" : "",
        "_",
        registerFd ? "" : "_noRegisterFd",
        "iouringBackend");
  }
};

struct ConnectedOptions {
  std::string fastOpenInitial;
  bool serverShouldRead = true;

  ConnectedOptions withNoServerShouldRead() {
    auto ret = *this;
    ret.serverShouldRead = false;
    return ret;
  }

  ConnectedOptions withFastOpen(std::string i) {
    auto ret = *this;
    ret.fastOpenInitial = std::move(i);
    return ret;
  }
};

class AsyncIoUringSocketTest : public ::testing::TestWithParam<TestParams>,
                               public AsyncServerSocket::AcceptCallback,
                               public AsyncSocket::ConnectCallback {
 public:
  static IoUringBackend::Options ioOptions(TestParams const& p) {
    auto options =
        IoUringBackend::Options{}.setUseRegisteredFds(p.registerFd ? 64 : 0);
    if (p.manySmallBuffers) {
      options.setInitialProvidedBuffers(1024, 2000);
    } else {
      options.setInitialProvidedBuffers(2000000, 1);
    }
    options.setDeferTaskRun(true);
    return options;
  }

  EventBase::Options ioUringEbOptions() {
    return EventBase::Options{}.setBackendFactory(
        [p = GetParam()]() -> std::unique_ptr<EventBaseBackendBase> {
          return std::make_unique<IoUringBackend>(ioOptions(p));
        });
  }

  EventBase::Options ebOptions() { return ioUringEbOptions(); }

  void maybeSkip() {
    if (unableToRun) {
      GTEST_SKIP();
    }
  }

  AsyncIoUringSocketTest() {
    try {
      base = std::make_unique<EventBase>(ebOptions());
    } catch (IoUringBackend::NotAvailable const&) {
      unableToRun = true;
      return;
    }

    backend = dynamic_cast<IoUringBackend*>(base->getBackend());

    backend->loopPoll(); // init delayed bits as this is the only thread

    serverSocket = AsyncServerSocket::newSocket(base.get());
    serverSocket->setTFOEnabled(true, 1);
    serverSocket->bind(0);
    serverSocket->listen(1024);
    serverSocket->addAcceptCallback(this, nullptr);
    serverSocket->startAccepting();
    serverSocket->getAddress(&serverAddress);
  }

  void connectionAccepted(
      NetworkSocket ns, const SocketAddress&, AcceptInfo) noexcept override {
    fdPromise.setValue(ns);
  }

  void connectSuccess() noexcept override {}
  void connectErr(const AsyncSocketException& ex) noexcept override {
    LOG(FATAL) << ex;
  }

  template <typename ServerReadCallback>
  struct Connected {
    std::unique_ptr<EchoTransport> client;
    AsyncTransport::UniquePtr server;
    std::unique_ptr<ServerReadCallback> callback;
    ~Connected() {
      if (server) {
        server->setReadCB(nullptr);
      }
    }
  };

  AsyncIoUringSocket::Options ioUringSocketOptions() const {
    AsyncIoUringSocket::Options ret;
    if (GetParam().sendzc) {
      ret.zeroCopyEnable = [](auto&&) { return true; };
    }
    return ret;
  }

  template <typename ServerReadCallback = CollectCallback>
  Connected<ServerReadCallback> makeConnected(
      ConnectedOptions options = ConnectedOptions{}) {
    AsyncSocketTransport::UniquePtr client;
    if (GetParam().ioUringClient) {
      client = AsyncSocketTransport::UniquePtr(
          new AsyncIoUringSocket(base.get(), ioUringSocketOptions()));
    } else {
      client =
          AsyncSocketTransport::UniquePtr(AsyncSocket::newSocket(base.get()));
    }

    if (options.fastOpenInitial.size()) {
      client->enableTFO();
    }
    client->connect(this, serverAddress);
    if (options.fastOpenInitial.size()) {
      client->writeChain(
          &nullWriteCallback, IOBuf::copyBuffer(options.fastOpenInitial));
    }

    auto fd = fdPromise.getFuture()
                  .within(kTimeout)
                  .via(base.get())
                  .getVia(base.get());
    fdPromise = {};
    auto c = std::make_unique<EchoTransport>(
        std::move(client), GetParam().supportBufferMovable);
    c->start();
    auto serverReadCallback = std::make_unique<ServerReadCallback>();
    AsyncTransport::UniquePtr server = GetParam().ioUringServer
        ? AsyncTransport::UniquePtr(new AsyncIoUringSocket(
              AsyncSocket::newSocket(base.get(), fd), ioUringSocketOptions()))
        : AsyncTransport::UniquePtr(AsyncSocket::newSocket(base.get(), fd));
    if (options.serverShouldRead) {
      server->setReadCB(serverReadCallback.get());
    }
    return Connected<ServerReadCallback>{
        std::move(c), std::move(server), std::move(serverReadCallback)};
  }

  bool unableToRun = false;
  std::unique_ptr<EventBase> base;
  std::unique_ptr<IoUringEvent> ioUringEvent;
  std::shared_ptr<AsyncServerSocket> serverSocket;
  IoUringBackend* backend = nullptr;
  SocketAddress serverAddress;

  Promise<NetworkSocket> fdPromise;
};

#define MAYBE_SKIP()                   \
  if (unableToRun) {                   \
    LOG(INFO) << "Unsupported kernel"; \
    return;                            \
  }

TEST_P(AsyncIoUringSocketTest, ConnectTimeout) {
  MAYBE_SKIP();
  struct CB : AsyncSocket::ConnectCallback {
    void connectSuccess() noexcept override {
      prom.setValue(makeExpected<AsyncSocketException>(Unit{}));
    }
    void connectErr(const AsyncSocketException& ex) noexcept override {
      prom.setValue(makeUnexpected(ex));
    }
    Promise<Expected<Unit, AsyncSocketException>> prom;
  } cb;

  // Try connecting to server that won't respond.
  //
  // This depends somewhat on the network where this test is run.
  // Hopefully this IP will be routable but unresponsive.
  // (Alternatively, we could try listening on a local raw socket, but that
  // normally requires root privileges.)
  auto host = SocketAddressTestHelper::isIPv6Enabled()
      ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv6
      : SocketAddressTestHelper::isIPv4Enabled()
      ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv4
      : nullptr;

  AsyncIoUringSocket::UniquePtr socket(new AsyncIoUringSocket(base.get()));
  socket->connect(
      &cb, SocketAddress{host, 65535}, std::chrono::milliseconds(1));

  auto res = cb.prom.getSemiFuture()
                 .within(kTimeout)
                 .via(base.get())
                 .getVia(base.get());
  ASSERT_FALSE(res);
  if (res.error().getType() == AsyncSocketException::NOT_OPEN) {
    // This can happen if we could not route to the IP address picked above.
    // In this case the connect will fail immediately rather than timing out.
    // Just skip the test in this case.
    GTEST_SKIP() << "do not have a routable but unreachable IP address";
    return;
  }
  EXPECT_EQ(res.error().getType(), AsyncSocketException::TIMED_OUT)
      << res.error().what();
}

TEST_P(AsyncIoUringSocketTest, EoF) {
  MAYBE_SKIP();
  struct CB : AsyncReader::ReadCallback {
    void readDataAvailable(size_t) noexcept override {
      // will terminate...
      terminate_with<std::runtime_error>("unexpected data");
    }
    void readEOF() noexcept override {
      VLOG(1) << "CB Setting EOF";
      prom.setValue(makeExpected<AsyncSocketException>(Unit{}));
    }

    void readErr(const AsyncSocketException& ex) noexcept override {
      VLOG(1) << "CB Setting Err " << ex;
      prom.setValue(makeUnexpected(ex));
    }

    void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
      *bufReturn = &buff;
      *lenReturn = 1;
    }

    Promise<Expected<Unit, AsyncSocketException>> prom;
    char buff;
  };

  auto c = makeConnected(ConnectedOptions{}.withNoServerShouldRead());
  {
    CB cb_eof;
    c.server->setReadCB(&cb_eof);
    c.client->transport->closeNow();
    c.client.reset();
    EXPECT_TRUE(cb_eof.prom.getSemiFuture()
                    .within(kTimeout)
                    .via(base.get())
                    .getVia(base.get())
                    .hasValue());
    c.server->setReadCB(nullptr);
  }
  EXPECT_FALSE(c.server->good());
  {
    CB cb_invalid;
    c.server->setReadCB(&cb_invalid);
    auto ex = cb_invalid.prom.getSemiFuture()
                  .within(kTimeout)
                  .via(base.get())
                  .getVia(base.get());
    ASSERT_TRUE(ex.hasError());
    auto er = ex.error();
    EXPECT_EQ(AsyncSocketException::NOT_OPEN, er.getType());
    c.server->setReadCB(nullptr);
  }
}

struct DetachCB : folly::AsyncDetachFdCallback {
  void fdDetached(
      NetworkSocket ns, std::unique_ptr<IOBuf> unread) noexcept override {
    promise.setValue(std::make_pair(ns, std::move(unread)));
  }

  void fdDetachFail(const AsyncSocketException& ex) noexcept override {
    promise.setException(ex);
  }

  folly::Promise<std::pair<NetworkSocket, std::unique_ptr<IOBuf>>> promise;
};

TEST_P(AsyncIoUringSocketTest, Detach) {
  MAYBE_SKIP();
  auto c = makeConnected();
  auto* was = c.server->getUnderlyingTransport<folly::AsyncIoUringSocket>();
  ASSERT_NE(was, nullptr);

  // write something
  c.client->transport->write(&nullWriteCallback, "hello", 5);
  // make sure it gets run
  backend->submitOutstanding();

  // sleep a bit to get the read all the way into the completion queue
  /* sleep override */ std::this_thread::sleep_for(
      std::chrono::milliseconds(5));

  // now detach before the write is completed
  DetachCB cb;
  was->asyncDetachFd(&cb);
  ASSERT_FALSE(cb.promise.isFulfilled()) << "must wait for read to finish";

  auto res = cb.promise.getSemiFuture()
                 .within(kTimeout)
                 .via(base.get())
                 .getVia(base.get());
  EXPECT_GE(res.first.toFd(), 0);
  if (res.second) {
    // did not cancel in time
    EXPECT_EQ("hello", toString(res.second));
  } else {
    // did cancel in time
    char buff[128];
    memset(buff, 0, sizeof(buff));
    int ret;
    do {
      ret = read(res.first.toFd(), &buff, sizeof(buff));
    } while (ret == -1 && errno == EINTR);
    ASSERT_EQ(5, ret);
    EXPECT_EQ("hello", std::string(buff));
  }
}

TEST_P(AsyncIoUringSocketTest, DetachEventBase) {
  MAYBE_SKIP();
  auto c = makeConnected();

  // write something
  FutureWriteCallback fwc;
  auto* transport = c.client->transport.get();
  transport->write(&fwc, "hello", 5);

  // make sure it gets run
  backend->submitOutstanding();
  ASSERT_TRUE(transport->isDetachable());
  transport->detachEventBase();

  EventBase newBase{ioUringEbOptions()};
  transport->attachEventBase(&newBase);

  auto resFut = c.callback->waitFor(5).via(folly::getGlobalCPUExecutor().get());
  auto start = std::chrono::steady_clock::now();
  do {
    newBase.loopOnce(EVLOOP_NONBLOCK);
    base->loopOnce(EVLOOP_NONBLOCK);
    /* sleep override */ std::this_thread::sleep_for(
        std::chrono::milliseconds(20));
    auto res = resFut.poll();
    if (res) {
      EXPECT_EQ("hello", res->value());
      break;
    }
    if (std::chrono::steady_clock::now() > start + std::chrono::seconds(1)) {
      FAIL();
      break;
    }
  } while (true);

  // make sure write arrived, it should be on the new event base
  auto& [promise, future] = fwc.promiseContract;
  ASSERT_TRUE(std::move(future).via(&newBase).getVia(&newBase).hasValue());

  ASSERT_TRUE(transport->isDetachable());
  transport->detachEventBase();
  transport->attachEventBase(base.get());
}

TEST_P(AsyncIoUringSocketTest, DetachEventBaseClear) {
  MAYBE_SKIP();
  auto c = makeConnected();

  // write something
  c.client->transport->write(&nullWriteCallback, "hello", 5);

  backend->submitOutstanding();
  ASSERT_TRUE(c.client->transport->isDetachable());
  c.client->transport->detachEventBase();

  // now free things in the middle
  base->loopOnce(EVLOOP_NONBLOCK);
}

TEST_P(AsyncIoUringSocketTest, FastOpen) {
  MAYBE_SKIP();
  bool had_fastopen = false;
  bool can_fastopen = false;
  std::string fo;
  if (readFile("/proc/sys/net/ipv4/tcp_fastopen", fo)) {
    auto fast_open = folly::to<int>(fo);
    if (fast_open == 3) {
      can_fastopen = true;
    }
  }
  if (!can_fastopen) {
    LOG(INFO) << "/proc/sys/net/ipv4/tcp_fastopen must be 3 to do fastopen, "
                 "but we will test the code flow anyway";
  }
  // technically we could run ip tcp_metrics flush here, but messing with the
  // system in a test is awful
  folly::Subprocess subProc("/sbin/ip tcp_metrics show ::1"_shellify());
  int has_cookies_already = subProc.wait().exitStatus();
  if (has_cookies_already == 0) {
    LOG(INFO)
        << "already had cookies, so cannot do fastopen test, but will test code flow anyway. "
        << " you could do a `/sbin/ip tcp_metrics flush` to test this";
    had_fastopen = true;
  }
  auto opts = ConnectedOptions{}.withFastOpen("hello");
  {
    auto conn = makeConnected(opts);
    EXPECT_EQ(
        "hello", conn.callback->waitFor(5).via(base.get()).getVia(base.get()));
    if (!had_fastopen) {
      EXPECT_FALSE(conn.client->transport->getTFOSucceded());
    }
  }
  {
    auto conn = makeConnected(opts);
    EXPECT_EQ(
        "hello", conn.callback->waitFor(5).via(base.get()).getVia(base.get()));
    if (can_fastopen) {
      EXPECT_TRUE(conn.client->transport->getTFOSucceded());
    }
  }
}

class AsyncIoUringSocketTestAll : public AsyncIoUringSocketTest {};

TEST_P(AsyncIoUringSocketTestAll, WriteChain2) {
  MAYBE_SKIP();
  auto [e, s, cb] = makeConnected();
  s->writeChain(&nullWriteCallback, IOBuf::copyBuffer("hello"));
  EXPECT_EQ("hello", cb->waitFor(5).via(base.get()).getVia(base.get()));
  s->writeChain(&nullWriteCallback, IOBuf::copyBuffer("there"));
  EXPECT_EQ("there", cb->waitFor(5).via(base.get()).getVia(base.get()));
}

TEST_P(AsyncIoUringSocketTestAll, WriteChainOrder) {
  MAYBE_SKIP();
  auto [e, s, cb] = makeConnected();
  auto chain = IOBuf::copyBuffer("h");
  chain->appendToChain(IOBuf::copyBuffer("e"));
  chain->appendToChain(IOBuf::copyBuffer("ll"));
  chain->appendToChain(IOBuf::copyBuffer("o"));
  s->writeChain(&nullWriteCallback, std::move(chain));
  EXPECT_EQ("hello", cb->waitFor(5).via(base.get()).getVia(base.get()));
}

TEST_P(AsyncIoUringSocketTestAll, WriteChainLong) {
  MAYBE_SKIP();
  auto [e, s, cb] = makeConnected();
  auto chain = IOBuf::copyBuffer("?");
  std::string res = "?";
  for (int i = 0; i < 4096; i++) {
    std::string x(1, 'a' + i % 26);
    chain->appendToChain(IOBuf::copyBuffer(x));
    res += x;
  }
  s->writeChain(&nullWriteCallback, std::move(chain));
  EXPECT_EQ(res, cb->waitFor(res.size()).via(base.get()).getVia(base.get()));
}

TEST_P(AsyncIoUringSocketTestAll, Write) {
  MAYBE_SKIP();
  auto [e, s, cb] = makeConnected();
  s->write(&nullWriteCallback, "hello", 5);
  EXPECT_EQ("hello", cb->waitFor(5).via(base.get()).getVia(base.get()));
}

TEST_P(AsyncIoUringSocketTestAll, WriteAfterWait) {
  MAYBE_SKIP();
  auto conn = makeConnected();
  auto& s = conn.server;
  auto& cb = conn.callback;

  EXPECT_EQ(
      "hello",
      folly::futures::sleep(std::chrono::milliseconds(500))
          .via(base.get())
          .thenValue([&](auto&&) { s->write(&nullWriteCallback, "hello", 5); })
          .thenValue([&](auto&&) { return cb->waitFor(5); })
          .getVia(base.get()));
}

namespace {
std::string randomString(size_t n) {
  std::random_device r;
  std::default_random_engine e1(r());

  std::uniform_int_distribution<char> uniform_dist('A', 'Z');

  std::string ret;
  ret.reserve(n);
  for (size_t i = 0; i < n; i++) {
    ret.push_back(uniform_dist(e1));
  }
  return ret;
}
} // namespace

TEST_P(AsyncIoUringSocketTestAll, WriteBig) {
  MAYBE_SKIP();
  auto [e, s, cb] = makeConnected();
  cb->setHoldData(true);
  std::string big = randomString(4000000);
  s->write(&nullWriteCallback, big.c_str(), big.size());
  auto res = cb->waitFor(big.size()).via(base.get()).getVia(base.get());
  EXPECT_TRUE(big == res) << big.size() << " vs " << res.size();
}

TEST_P(AsyncIoUringSocketTestAll, WriteBigChunked) {
  MAYBE_SKIP();
  auto [e, s, cb] = makeConnected();
  cb->setHoldData(true);
  std::string big = randomString(4000000);
  size_t at = 0;
  int const kChunkSize = 256;
  while (at < big.size()) {
    auto len = std::min<size_t>(big.size() - at, kChunkSize);
    s->write(&nullWriteCallback, big.c_str() + at, len);
    at += len;
  }
  auto res = cb->waitFor(big.size()).via(base.get()).getVia(base.get());
  EXPECT_TRUE(big == res) << big.size() << " vs " << res.size();
}

TEST_P(AsyncIoUringSocketTestAll, WriteBigDrop) {
  MAYBE_SKIP();
  auto [e, s, cb] = makeConnected();
  cb->setHoldData(false); // should trigger overflow in provided buffers
  std::string big(4000000, 'X');
  s->write(&nullWriteCallback, big.c_str(), big.size());
  EXPECT_EQ(big, cb->waitFor(big.size()).via(base.get()).getVia(base.get()));
}

TEST_P(AsyncIoUringSocketTestAll, Writev) {
  MAYBE_SKIP();
  auto [e, s, cb] = makeConnected();
  std::array<iovec, 2> iov = {{{(void*)"hel", 3}, {(void*)"lo", 2}}};
  s->writev(&nullWriteCallback, iov.data(), iov.size());
  EXPECT_EQ("hello", cb->waitFor(5).via(base.get()).getVia(base.get()));
}

TEST_P(AsyncIoUringSocketTestAll, SendTimeout) {
  MAYBE_SKIP();
  if (!GetParam().ioUringServer) {
    // folly::AsyncSocket is not totally reliable with timeouts
    return;
  }
  auto conn = makeConnected(ConnectedOptions{}.withNoServerShouldRead());
  FutureWriteCallback ecb;
  std::string big(40000000, 'X');
  std::vector<iovec> iov;
  iov.resize(100);
  for (size_t i = 0; i < iov.size(); i++) {
    iov[i].iov_base = big.data();
    iov[i].iov_len = big.size();
  }
  base->runInEventBaseThread([&]() {
    conn.server->setSendTimeout(1);
    conn.server->writev(&ecb, iov.data(), iov.size());
  });
  auto& [promise, future] = ecb.promiseContract;
  auto ex = std::move(future).via(base.get()).getVia(base.get());
  ASSERT_TRUE(ex.hasError());
  EXPECT_EQ(AsyncSocketException::TIMED_OUT, ex.error().second.getType());
}

auto mkAllTestParams() {
  std::vector<TestParams> t;

  auto addFeatureCases = [&](TestParams const& base) {
    TestParams all = base;

    // add test cases where each feature is not the default, as well as one
    // where all the features are not the default
    auto add_flip_case = [&](auto ptr) {
      auto tc = base;
      tc.*ptr = all.*ptr = !(tc.*ptr);
      t.push_back(tc);
    };

    add_flip_case(&TestParams::registerFd);
    if (IoUringBackend::kernelSupportsSendZC()) {
      add_flip_case(&TestParams::sendzc);
    }
    add_flip_case(&TestParams::supportBufferMovable);
    t.push_back(all);
  };

  for (bool server : {false, true}) {
    for (bool client : {false, true}) {
      for (bool manySmallBuffers : {false, true}) {
        TestParams base;
        base.ioUringServer = server;
        base.ioUringClient = client;
        base.manySmallBuffers = manySmallBuffers;
        t.push_back(base);

        // only expand feature flags in some cases to reduce the massive
        // explosion of tests
        if (server && client) {
          addFeatureCases(base);
        }
      }
    }
  }
  return t;
}

INSTANTIATE_TEST_SUITE_P(
    AsyncIoUringSocketTest,
    AsyncIoUringSocketTestAll,
    ::testing::ValuesIn(mkAllTestParams()),
    [](const ::testing::TestParamInfo<TestParams>& info) {
      return info.param.testName();
    });

TestParams mkBasicTestParams() {
  TestParams t;
  t.ioUringClient = t.ioUringServer = true;
  return t;
}

INSTANTIATE_TEST_SUITE_P(
    AsyncIoUringSocketTest,
    AsyncIoUringSocketTest,
    ::testing::Values(mkBasicTestParams()),
    [](const ::testing::TestParamInfo<TestParams>& info) {
      return info.param.testName();
    });

class AsyncIoUringSocketTakeoverTest : public AsyncIoUringSocketTest {};

class AsyncSocketWithPreRead : public AsyncSocket {
 public:
  AsyncSocketWithPreRead(AsyncSocket::UniquePtr a, std::string const& pre_read)
      : AsyncSocket(std::move(a)) {
    preReceivedData_ = IOBuf::copyBuffer(pre_read);
  }
};

TEST_P(AsyncIoUringSocketTakeoverTest, PreRead) {
  MAYBE_SKIP();
  auto conn = makeConnected(ConnectedOptions{}.withNoServerShouldRead());
  AsyncSocket::UniquePtr sock(
      dynamic_cast<AsyncSocket*>(conn.server.release()));
  ASSERT_NE(sock, nullptr);
  AsyncIoUringSocket::UniquePtr io_uring(
      new AsyncIoUringSocket(AsyncSocket::UniquePtr(
          new AsyncSocketWithPreRead(std::move(sock), "hello"))));
  io_uring->setReadCB(conn.callback.get());
  io_uring->write(&nullWriteCallback, "there", 5);
  EXPECT_EQ(
      "hellothere",
      conn.callback->waitFor(10).via(base.get()).getVia(base.get()));
}

TestParams mkTakeoverParams() {
  TestParams t;
  t.ioUringClient = t.ioUringServer = false;
  return t;
}

INSTANTIATE_TEST_SUITE_P(
    AsyncIoUringSocketTakeoverTest,
    AsyncIoUringSocketTakeoverTest,
    ::testing::Values(mkTakeoverParams()),
    [](const ::testing::TestParamInfo<TestParams>& info) {
      return info.param.testName();
    });

} // namespace folly