folly/folly/io/coro/test/TransportTest.cpp

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <folly/Portability.h>

#include <folly/experimental/coro/BlockingWait.h>
#include <folly/experimental/coro/Collect.h>
#include <folly/io/async/test/AsyncSocketTest.h>
#include <folly/io/async/test/MockAsyncTransport.h>
#include <folly/io/async/test/ScopedBoundPort.h>
#include <folly/io/coro/ServerSocket.h>
#include <folly/io/coro/Transport.h>
#include <folly/portability/GTest.h>

#if FOLLY_HAS_COROUTINES

using namespace std::chrono_literals;
using namespace folly;
using namespace folly::coro;
using namespace folly::test;

class TransportTest : public testing::Test {
 public:
  template <typename F>
  void run(F f) {
    blockingWait(co_invoke(std::move(f)), &evb);
  }

  folly::coro::Task<> requestCancellation() {
    cancelSource.requestCancellation();
    co_return;
  }

  EventBase evb;
  CancellationSource cancelSource;
};

class ServerTransportTest : public TransportTest {
 public:
  folly::coro::Task<Transport> connect() {
    co_return co_await Transport::newConnectedSocket(
        &evb, srv.getAddress(), 0ms);
  }
  TestServer srv;
};

TEST_F(TransportTest, ConnectFailure) {
  run([&]() -> Task<> {
    ScopedBoundPort ph;

    auto serverAddr = ph.getAddress();
    EXPECT_THROW(
        co_await Transport::newConnectedSocket(&evb, serverAddr, 0ms),
        AsyncSocketException);
  });
}

TEST_F(ServerTransportTest, ConnectSuccess) {
  run([&]() -> Task<> {
    auto cs = co_await connect();
    EXPECT_EQ(srv.getAddress(), cs.getPeerAddress());
  });
}

TEST_F(ServerTransportTest, ConnectCancelled) {
  run([&]() -> Task<> {
    co_await folly::coro::collectAll(
        // token would be cancelled while waiting on connect
        [&]() -> Task<> {
          EXPECT_THROW(
              co_await co_withCancellation(cancelSource.getToken(), connect()),
              OperationCancelled);
        }(),
        requestCancellation());
    // token was cancelled before read was called
    EXPECT_THROW(
        co_await co_withCancellation(
            cancelSource.getToken(),
            Transport::newConnectedSocket(&evb, srv.getAddress(), 0ms)),
        OperationCancelled);
  });
}

TEST_F(TransportTest, ConnectWithOptsAndBindAddr) {
  run([&]() -> Task<> {
    TestServer srv1;
    TestServer srv2;
    SocketOptionMap opts{
        {{SOL_SOCKET, SO_REUSEADDR}, 1},
    };
    auto c1 = co_await Transport::newConnectedSocket(
        &evb, srv1.getAddress(), 0ms, opts);
    // Connect with same local port, this would fail without SO_REUSEADDR.
    auto c2 = co_await Transport::newConnectedSocket(
        &evb, srv2.getAddress(), 0ms, opts, c1.getLocalAddress());
    EXPECT_EQ(c1.getLocalAddress(), c2.getLocalAddress());
  });
}

TEST_F(ServerTransportTest, SimpleRead) {
  run([&]() -> Task<> {
    constexpr auto kBufSize = 65536;
    auto cs = co_await connect();
    // produces blocking socket
    auto ss = srv.accept(-1);

    std::array<uint8_t, kBufSize> sndBuf;
    std::memset(sndBuf.data(), 'a', sndBuf.size());
    ss->write(sndBuf.data(), sndBuf.size());

    // read using coroutines
    std::array<uint8_t, kBufSize> rcvBuf;
    auto reader = [&rcvBuf, &cs]() -> Task<Unit> {
      int totalBytes{0};
      while (totalBytes < kBufSize) {
        auto bytesRead = co_await cs.read(
            MutableByteRange(
                rcvBuf.data() + totalBytes,
                (rcvBuf.data() + rcvBuf.size() - totalBytes)),
            0ms);
        totalBytes += bytesRead;
      }
      co_return unit;
    };

    co_await reader();
    EXPECT_EQ(0, memcmp(sndBuf.data(), rcvBuf.data(), rcvBuf.size()));
  });
}

TEST_F(ServerTransportTest, SimpleIOBufRead) {
  run([&]() -> Task<> {
    // Exactly fills a buffer mid-loop and triggers deferredReadEOF handling
    constexpr auto kBufSize = 55 * 1184;

    auto cs = co_await connect();
    // produces blocking socket
    auto ss = srv.accept(-1);

    std::array<uint8_t, kBufSize> sndBuf;
    std::memset(sndBuf.data(), 'a', sndBuf.size());
    ss->write(sndBuf.data(), sndBuf.size());
    ss->close();

    // read using coroutines
    IOBufQueue rcvBuf(IOBufQueue::cacheChainLength());
    int totalBytes{0};
    while (totalBytes < kBufSize) {
      auto bytesRead = co_await cs.read(rcvBuf, 1000, 1000, 0ms);
      totalBytes += bytesRead;
    }
    auto bytesRead = co_await cs.read(rcvBuf, 1000, 1000, 50ms);
    EXPECT_EQ(bytesRead, 0); // closed

    auto data = rcvBuf.move();
    data->coalesce();
    EXPECT_EQ(0, memcmp(sndBuf.data(), data->data(), data->length()));
  });
}

TEST_F(ServerTransportTest, ReadCancelled) {
  run([&]() -> Task<> {
    auto cs = co_await connect();
    auto reader = [&cs]() -> Task<Unit> {
      std::array<uint8_t, 1024> rcvBuf;
      EXPECT_THROW(
          co_await cs.read(
              MutableByteRange(rcvBuf.data(), (rcvBuf.data() + rcvBuf.size())),
              0ms),
          OperationCancelled);
      co_return unit;
    };

    co_await co_withCancellation(
        cancelSource.getToken(),
        folly::coro::collectAll(requestCancellation(), reader()));
    // token was cancelled before read was called
    co_await co_withCancellation(cancelSource.getToken(), reader());
  });
}

TEST_F(ServerTransportTest, ReadTimeout) {
  run([&]() -> Task<> {
    auto cs = co_await connect();
    std::array<uint8_t, 1024> rcvBuf;
    EXPECT_THROW(
        co_await cs.read(
            MutableByteRange(rcvBuf.data(), (rcvBuf.data() + rcvBuf.size())),
            50ms),
        AsyncSocketException);
  });
}

TEST_F(ServerTransportTest, ReadError) {
  run([&]() -> Task<> {
    auto cs = co_await connect();
    // produces blocking socket
    auto ss = srv.accept(-1);
    ss->closeWithReset();

    std::array<uint8_t, 1024> rcvBuf;
    EXPECT_THROW(
        co_await cs.read(
            MutableByteRange(rcvBuf.data(), (rcvBuf.data() + rcvBuf.size())),
            50ms),
        AsyncSocketException);
  });
}

TEST_F(ServerTransportTest, SimpleWrite) {
  run([&]() -> Task<> {
    auto cs = co_await connect();
    // produces blocking socket
    auto ss = srv.accept(-1);

    constexpr auto kBufSize = 65536;
    std::array<uint8_t, kBufSize> sndBuf;
    std::memset(sndBuf.data(), 'a', sndBuf.size());

    // write use co-routine
    co_await cs.write(ByteRange(sndBuf.data(), sndBuf.data() + sndBuf.size()));
    // read on server side
    std::array<uint8_t, kBufSize> rcvBuf;
    ss->readAll(rcvBuf.data(), rcvBuf.size());

    EXPECT_EQ(0, memcmp(sndBuf.data(), rcvBuf.data(), rcvBuf.size()));
  });
}

TEST_F(ServerTransportTest, SimpleWritev) {
  run([&]() -> Task<> {
    auto cs = co_await connect();
    // produces blocking socket
    auto ss = srv.accept(-1);

    IOBufQueue sndBuf;
    constexpr auto kBufSize = 65536;
    std::array<uint8_t, kBufSize> bufA;
    std::memset(bufA.data(), 'a', bufA.size());
    std::array<uint8_t, kBufSize> bufB;
    std::memset(bufB.data(), 'b', bufB.size());
    sndBuf.append(bufA.data(), bufA.size());
    sndBuf.append(bufB.data(), bufB.size());

    // write use co-routine
    co_await cs.write(sndBuf);

    // read on server side
    std::array<uint8_t, kBufSize> rcvBufA;
    ss->readAll(rcvBufA.data(), rcvBufA.size());
    EXPECT_EQ(0, memcmp(bufA.data(), rcvBufA.data(), rcvBufA.size()));
    std::array<uint8_t, kBufSize> rcvBufB;
    ss->readAll(rcvBufB.data(), rcvBufB.size());
    EXPECT_EQ(0, memcmp(bufB.data(), rcvBufB.data(), rcvBufB.size()));
  });
}

TEST_F(ServerTransportTest, WriteCancelled) {
  run([&]() -> Task<> {
    auto cs = co_await connect();
    // reduce the send buffer size so the write wouldn't complete immediately
    auto asyncSocket = dynamic_cast<folly::AsyncSocket*>(cs.getTransport());
    CHECK(asyncSocket);
    EXPECT_EQ(asyncSocket->setSendBufSize(4096), 0);
    // produces blocking socket
    auto ss = srv.accept(-1);
    constexpr auto kBufSize = 65536;
    std::array<uint8_t, kBufSize> sndBuf;
    std::memset(sndBuf.data(), 'a', sndBuf.size());

    // write use co-routine
    auto writter = [&]() -> Task<> {
      EXPECT_THROW(
          co_await co_withCancellation(
              cancelSource.getToken(),
              cs.write(
                  ByteRange(sndBuf.data(), sndBuf.data() + sndBuf.size()))),
          OperationCancelled);
    };

    co_await folly::coro::collectAll(requestCancellation(), writter());
    co_await co_withCancellation(cancelSource.getToken(), writter());
  });
}

TEST_F(TransportTest, SimpleAccept) {
  run([&]() -> Task<> {
    ServerSocket css(AsyncServerSocket::newSocket(&evb), std::nullopt, 16);
    auto serverAddr = css.getAsyncServerSocket()->getAddress();

    co_await folly::coro::collectAll(
        css.accept(), Transport::newConnectedSocket(&evb, serverAddr, 0ms));
  });
}

TEST_F(TransportTest, AcceptCancelled) {
  run([&]() -> Task<> {
    co_await folly::coro::collectAll(requestCancellation(), [&]() -> Task<> {
      ServerSocket css(AsyncServerSocket::newSocket(&evb), std::nullopt, 16);
      EXPECT_THROW(
          co_await co_withCancellation(cancelSource.getToken(), css.accept()),
          OperationCancelled);
    }());
  });
}

TEST_F(TransportTest, RequestCancellationInAnotherThread) {
  auto ass = AsyncServerSocket::newSocket(&evb);
  std::thread anotherThread([&] {
    std::this_thread::sleep_for(10ms);
    cancelSource.requestCancellation();
  });

  run([&]() -> Task<> {
    ServerSocket css(ass, std::nullopt, 16);
    EXPECT_THROW(
        co_await co_withCancellation(cancelSource.getToken(), css.accept()),
        OperationCancelled);
  });

  anotherThread.join();
}

TEST_F(TransportTest, AsyncClientAndServer) {
  run([&]() -> Task<> {
    constexpr int kSize = 128;
    ServerSocket css(AsyncServerSocket::newSocket(&evb), std::nullopt, 16);
    auto serverAddr = css.getAsyncServerSocket()->getAddress();
    auto cs = co_await Transport::newConnectedSocket(&evb, serverAddr, 0ms);

    co_await folly::coro::collectAll(
        [&css]() -> Task<> {
          auto sock = co_await css.accept();
          std::array<uint8_t, kSize> buf;
          memset(buf.data(), 'a', kSize);
          co_await sock->write(ByteRange(buf.begin(), buf.end()));
          css.close();
        }(),

        [&cs]() -> Task<> {
          std::array<uint8_t, kSize> buf;
          // For fun, shutdown the write half -- we don't need it
          cs.shutdownWrite();
          auto len =
              co_await cs.read(MutableByteRange(buf.begin(), buf.end()), 0ms);
          cs.close();
          EXPECT_TRUE(len == buf.size());
        }());
  });
}

class MockTransportTest : public TransportTest {
 public:
  folly::coro::Task<Transport> connect() {
    mockTransport = new testing::NiceMock<test::MockAsyncTransport>();
    folly::AsyncTransport::UniquePtr transport(mockTransport);
    co_return Transport(&evb, std::move(transport));
  }
  test::MockAsyncTransport* mockTransport;
};

TEST_F(MockTransportTest, readSuccessCanceled) {
  run([&]() -> Task<> {
    auto cs = co_await connect();
    constexpr auto kBufSize = 65536;
    std::array<uint8_t, kBufSize> rcvBuf;
    EXPECT_CALL(*mockTransport, setReadCB(testing::_))
        .WillOnce(testing::Invoke(
            [](AsyncReader::ReadCallback* rcb) { rcb->readEOF(); }));
    EXPECT_CALL(*mockTransport, setReadCB(nullptr)).Times(2);
    folly::CancellationSource cancellationSource;
    auto readFut =
        co_withCancellation(
            cancellationSource.getToken(),
            cs.read(
                MutableByteRange(rcvBuf.data(), rcvBuf.data() + rcvBuf.size()),
                100ms))
            .scheduleOn(&evb)
            .start();
    // Let the read coro start and get the EOF
    co_await co_reschedule_on_current_executor;
    // cancel
    cancellationSource.requestCancellation();
    // read succeeds with nRead == 0
    auto nRead = co_await std::move(readFut);
    EXPECT_EQ(nRead, 0);
  });
}
#endif // FOLLY_HAS_COROUTINES