folly/folly/io/async/test/AsyncSocketTest2.cpp

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <folly/io/async/test/AsyncSocketTest2.h>

#include <fcntl.h>
#include <sys/types.h>

#include <time.h>
#include <iostream>
#include <memory>
#include <thread>

#include <folly/ExceptionWrapper.h>
#include <folly/Random.h>
#include <folly/SocketAddress.h>
#include <folly/io/IOBuf.h>
#include <folly/io/SocketOptionMap.h>
#include <folly/io/async/AsyncTimeout.h>
#include <folly/io/async/EventBase.h>
#include <folly/io/async/ScopedEventBaseThread.h>
#include <folly/io/async/test/AsyncSocketTest.h>
#include <folly/io/async/test/MockAsyncSocketLegacyObserver.h>
#include <folly/io/async/test/MockAsyncSocketObserver.h>
#include <folly/io/async/test/TFOUtil.h>
#include <folly/io/async/test/Util.h>
#include <folly/net/test/MockNetOpsDispatcher.h>
#include <folly/net/test/MockTcpInfoDispatcher.h>
#include <folly/portability/GMock.h>
#include <folly/portability/GTest.h>
#include <folly/portability/Sockets.h>
#include <folly/portability/Unistd.h>
#include <folly/synchronization/Baton.h>
#include <folly/test/SocketAddressTestHelper.h>
#include <folly/testing/TestUtil.h>

using std::min;
using std::string;
using std::unique_ptr;
using std::vector;
using std::chrono::milliseconds;
using testing::MatchesRegex;

using namespace folly;
using namespace folly::test;
using namespace testing;

namespace {
// string and corresponding vector with 100 characters
const std::string kOneHundredCharacterString(
    "ThisIsAVeryLongStringThatHas100Characters"
    "AndIsUniqueEnoughToBeInterestingForTestUsageNowEndOfMessage");
const std::vector<uint8_t> kOneHundredCharacterVec(
    kOneHundredCharacterString.begin(), kOneHundredCharacterString.end());

WriteFlags msgFlagsToWriteFlags(const int msg_flags) {
  WriteFlags flags = WriteFlags::NONE;
#ifdef MSG_MORE
  if (msg_flags & MSG_MORE) {
    flags = flags | WriteFlags::CORK;
  }
#endif // MSG_MORE

#ifdef MSG_EOR
  if (msg_flags & MSG_EOR) {
    flags = flags | WriteFlags::EOR;
  }
#endif

#ifdef MSG_ZEROCOPY
  if (msg_flags & MSG_ZEROCOPY) {
    flags = flags | WriteFlags::WRITE_MSG_ZEROCOPY;
  }
#endif
  return flags;
}

WriteFlags getMsgAncillaryTsFlags(const struct msghdr& msg) {
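  // Expect a single SO_TIMESTAMPING control message carrying a uint32_t of
  // SOF_TIMESTAMPING_* flags; anything else yields WriteFlags::NONE.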
  const struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
  if (!cmsg || cmsg->cmsg_level != SOL_SOCKET ||
      cmsg->cmsg_type != SO_TIMESTAMPING ||
      cmsg->cmsg_len != CMSG_LEN(sizeof(uint32_t))) {
    return WriteFlags::NONE;
  }

  const uint32_t* sofFlags =
      (reinterpret_cast<const uint32_t*>(CMSG_DATA(cmsg)));
  WriteFlags flags = WriteFlags::NONE;
  if (*sofFlags & folly::netops::SOF_TIMESTAMPING_TX_SCHED) {
    flags = flags | WriteFlags::TIMESTAMP_SCHED;
  }
  if (*sofFlags & folly::netops::SOF_TIMESTAMPING_TX_SOFTWARE) {
    flags = flags | WriteFlags::TIMESTAMP_TX;
  }
  if (*sofFlags & folly::netops::SOF_TIMESTAMPING_TX_ACK) {
    flags = flags | WriteFlags::TIMESTAMP_ACK;
  }

  return flags;
}

WriteFlags getMsgAncillaryTsFlags(const struct msghdr* msg) {
  return getMsgAncillaryTsFlags(*msg);
}

MATCHER_P(SendmsgMsghdrHasTotalIovLen, len, "") {
  size_t iovLen = 0;
  for (size_t i = 0; i < arg.msg_iovlen; i++) {
    iovLen += arg.msg_iov[i].iov_len;
  }
  return len == iovLen;
}

MATCHER_P(SendmsgInvocHasTotalIovLen, len, "") {
  size_t iovLen = 0;
  for (const auto& iov : arg.iovs) {
    iovLen += iov.iov_len;
  }
  return len == iovLen;
}

MATCHER_P(SendmsgInvocHasIovFirstByte, firstBytePtr, "") {
  if (arg.iovs.empty()) {
    return false;
  }

  const auto& firstIov = arg.iovs.front();
  auto iovFirstBytePtr = const_cast<void*>(
      static_cast<const void*>(reinterpret_cast<uint8_t*>(firstIov.iov_base)));
  return firstBytePtr == iovFirstBytePtr;
}

MATCHER_P(SendmsgInvocHasIovLastByte, lastBytePtr, "") {
  if (arg.iovs.empty()) {
    return false;
  }

  const auto& lastIov = arg.iovs.back();
  auto iovLastBytePtr = const_cast<void*>(static_cast<const void*>(
      reinterpret_cast<uint8_t*>(lastIov.iov_base) + lastIov.iov_len - 1));
  return lastBytePtr == iovLastBytePtr;
}

MATCHER_P(SendmsgInvocMsgFlagsEq, writeFlags, "") {
  return writeFlags == arg.writeFlagsInMsgFlags;
}

MATCHER_P(SendmsgInvocAncillaryFlagsEq, writeFlags, "") {
  return writeFlags == arg.writeFlagsInAncillary;
}

MATCHER_P2(ByteEventMatching, type, offset, "") {
  if (type != arg.type || (size_t)offset != arg.offset) {
    return false;
  }
  return true;
}
} // namespace

class DelayedWrite : public AsyncTimeout {
 public:
  DelayedWrite(
      const std::shared_ptr<AsyncSocket>& socket,
      unique_ptr<IOBuf>&& bufs,
      AsyncTransportWrapper::WriteCallback* wcb,
      bool cork,
      bool lastWrite = false)
      : AsyncTimeout(socket->getEventBase()),
        socket_(socket),
        bufs_(std::move(bufs)),
        wcb_(wcb),
        cork_(cork),
        lastWrite_(lastWrite) {}

 private:
  void timeoutExpired() noexcept override {
    WriteFlags flags = cork_ ? WriteFlags::CORK : WriteFlags::NONE;
    socket_->writeChain(wcb_, std::move(bufs_), flags);
    if (lastWrite_) {
      socket_->shutdownWrite();
    }
  }

  std::shared_ptr<AsyncSocket> socket_;
  unique_ptr<IOBuf> bufs_;
  AsyncTransportWrapper::WriteCallback* wcb_;
  bool cork_;
  bool lastWrite_;
};
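
// Illustrative usage sketch (not taken from any test below; `sock` and `wcb`
// are hypothetical names): a DelayedWrite is armed via
// AsyncTimeout::scheduleTimeout(), e.g.
//
//   auto dw = std::make_unique<DelayedWrite>(
//       sock, IOBuf::copyBuffer("data"), &wcb, /*cork=*/true);
//   dw->scheduleTimeout(10); // perform the corked write after ~10ms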

///////////////////////////////////////////////////////////////////////////
// constructor related tests
///////////////////////////////////////////////////////////////////////////

/**
 * Test constructing with an existing fd.
 */
TEST(AsyncSocketTest, ConstructWithFd) {
  // construct a pair of unix sockets
  NetworkSocket fds[2];
  {
    auto ret = netops::socketpair(AF_UNIX, SOCK_STREAM, 0, fds);
    EXPECT_EQ(0, ret);
  }

  // "client" socket
  auto cfd = fds[0];
  ASSERT_NE(cfd, NetworkSocket());

  // instantiate AsyncSocket w/o any connectionEstablishTimestamp
  EventBase evb;
  auto socket = AsyncSocket::UniquePtr(new AsyncSocket(&evb, cfd));

  // should be no connect timestamps
  EXPECT_EQ(
      std::chrono::steady_clock::time_point(), socket->getConnectStartTime());
  EXPECT_EQ(
      std::chrono::steady_clock::time_point(), socket->getConnectEndTime());

  // should be no establish time, since not passed on construction
  EXPECT_FALSE(socket->getConnectionEstablishTime().has_value());
}

/**
 * Test constructing with an existing fd, passing a connection establish ts.
 */
TEST(AsyncSocketTest, ConstructWithFdAndTimestamp) {
  // construct a pair of unix sockets
  NetworkSocket fds[2];
  {
    auto ret = netops::socketpair(AF_UNIX, SOCK_STREAM, 0, fds);
    EXPECT_EQ(0, ret);
  }

  // "client" socket
  auto cfd = fds[0];
  ASSERT_NE(cfd, NetworkSocket());

  // instantiate AsyncSocket w/ a connectionEstablishTimestamp
  const auto connectionEstablishTime = std::chrono::steady_clock::now();
  EventBase evb;
  auto socket = AsyncSocket::UniquePtr(
      new AsyncSocket(&evb, cfd, 0, nullptr, connectionEstablishTime));

  // should be no connect timestamps
  EXPECT_EQ(
      std::chrono::steady_clock::time_point(), socket->getConnectStartTime());
  EXPECT_EQ(
      std::chrono::steady_clock::time_point(), socket->getConnectEndTime());

  // should have connection establish time, as passed on construction
  ASSERT_TRUE(socket->getConnectionEstablishTime().has_value());
  EXPECT_EQ(
      connectionEstablishTime, socket->getConnectionEstablishTime().value());
}

/**
 * Test constructing with an existing fd, then moving.
 */
TEST(AsyncSocketTest, ConstructWithFdThenMove) {
  // construct a pair of unix sockets
  NetworkSocket fds[2];
  {
    auto ret = netops::socketpair(AF_UNIX, SOCK_STREAM, 0, fds);
    EXPECT_EQ(0, ret);
  }

  // "client" socket
  auto cfd = fds[0];
  ASSERT_NE(cfd, NetworkSocket());

  // instantiate AsyncSocket
  EventBase evb;
  auto socket = AsyncSocket::UniquePtr(new AsyncSocket(&evb, cfd));

  // should be no connect timestamps
  EXPECT_EQ(
      std::chrono::steady_clock::time_point(), socket->getConnectStartTime());
  EXPECT_EQ(
      std::chrono::steady_clock::time_point(), socket->getConnectEndTime());

  // should be no establish time, since not passed on construction
  EXPECT_FALSE(socket->getConnectionEstablishTime().has_value());

  // move the socket
  auto socket2 = AsyncSocket::UniquePtr(new AsyncSocket(std::move(socket)));

  // should still be no connect timestamps
  EXPECT_EQ(
      std::chrono::steady_clock::time_point(), socket2->getConnectStartTime());
  EXPECT_EQ(
      std::chrono::steady_clock::time_point(), socket2->getConnectEndTime());

  // should still be no establish time, since not passed on orig construction
  EXPECT_FALSE(socket2->getConnectionEstablishTime().has_value());
}

/**
 * Test constructing with an existing fd and a connection establish timestamp,
 * then moving.
 */
TEST(AsyncSocketTest, ConstructWithFdAndTimestampThenMove) {
  // construct a pair of unix sockets
  NetworkSocket fds[2];
  {
    auto ret = netops::socketpair(AF_UNIX, SOCK_STREAM, 0, fds);
    EXPECT_EQ(0, ret);
  }

  // "client" socket
  auto cfd = fds[0];
  ASSERT_NE(cfd, NetworkSocket());

  // instantiate AsyncSocket w/ a connectionEstablishTimestamp
  const auto connectionEstablishTime = std::chrono::steady_clock::now();
  EventBase evb;
  auto socket = AsyncSocket::UniquePtr(
      new AsyncSocket(&evb, cfd, 0, nullptr, connectionEstablishTime));

  // should be no connect timestamps
  EXPECT_EQ(
      std::chrono::steady_clock::time_point(), socket->getConnectStartTime());
  EXPECT_EQ(
      std::chrono::steady_clock::time_point(), socket->getConnectEndTime());

  // should have connection establish time, as passed on construction
  ASSERT_TRUE(socket->getConnectionEstablishTime().has_value());
  EXPECT_EQ(
      connectionEstablishTime, socket->getConnectionEstablishTime().value());

  // move the socket
  auto socket2 = AsyncSocket::UniquePtr(new AsyncSocket(std::move(socket)));

  // should still be no connect timestamps
  EXPECT_EQ(
      std::chrono::steady_clock::time_point(), socket2->getConnectStartTime());
  EXPECT_EQ(
      std::chrono::steady_clock::time_point(), socket2->getConnectEndTime());

  // should have connection establish time, as passed on orig construction
  ASSERT_TRUE(socket2->getConnectionEstablishTime().has_value());
  EXPECT_EQ(
      connectionEstablishTime, socket2->getConnectionEstablishTime().value());
}

///////////////////////////////////////////////////////////////////////////
// connect() tests
///////////////////////////////////////////////////////////////////////////

/**
 * Test connecting to a server
 */
TEST(AsyncSocketTest, Connect) {
  // Start listening on a local port
  TestServer server;

  // Connect using a AsyncSocket
  EventBase evb;
  auto socket = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  EXPECT_EQ(
      std::chrono::steady_clock::time_point(), socket->getConnectStartTime());
  EXPECT_EQ(
      std::chrono::steady_clock::time_point(), socket->getConnectEndTime());
  EXPECT_FALSE(socket->getConnectionEstablishTime().has_value());
  ConnCallback cb;
  const auto startedAt = std::chrono::steady_clock::now();
  socket->connect(&cb, server.getAddress(), 30);

  evb.loop();
  const auto finishedAt = std::chrono::steady_clock::now();

  ASSERT_EQ(cb.state, STATE_SUCCEEDED);
  EXPECT_LE(0, socket->getConnectTime().count());
  EXPECT_EQ(std::chrono::milliseconds(30), socket->getConnectTimeout());

  EXPECT_GE(socket->getConnectStartTime(), startedAt);
  EXPECT_LE(socket->getConnectStartTime(), socket->getConnectEndTime());
  EXPECT_LE(socket->getConnectEndTime(), finishedAt);

  // since connect() successful, the establish time == connect() end time
  ASSERT_TRUE(socket->getConnectionEstablishTime().has_value());
  EXPECT_EQ(
      socket->getConnectEndTime(),
      socket->getConnectionEstablishTime().value());
}

/**
 * Test connecting to a server, then moving the socket.
 */
TEST(AsyncSocketTest, ConnectThenMove) {
  // Start listening on a local port
  TestServer server;

  // Connect using a AsyncSocket
  EventBase evb;
  auto socket = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  EXPECT_EQ(
      std::chrono::steady_clock::time_point(), socket->getConnectStartTime());
  EXPECT_EQ(
      std::chrono::steady_clock::time_point(), socket->getConnectEndTime());
  EXPECT_FALSE(socket->getConnectionEstablishTime().has_value());
  ConnCallback cb;
  const auto startedAt = std::chrono::steady_clock::now();
  socket->connect(&cb, server.getAddress(), 30);

  evb.loop();
  const auto finishedAt = std::chrono::steady_clock::now();

  ASSERT_EQ(cb.state, STATE_SUCCEEDED);
  EXPECT_LE(0, socket->getConnectTime().count());
  EXPECT_EQ(std::chrono::milliseconds(30), socket->getConnectTimeout());

  EXPECT_GE(socket->getConnectStartTime(), startedAt);
  EXPECT_LE(socket->getConnectStartTime(), socket->getConnectEndTime());
  EXPECT_LE(socket->getConnectEndTime(), finishedAt);

  // since connect() successful, the establish time == connect() end time
  ASSERT_TRUE(socket->getConnectionEstablishTime().has_value());
  EXPECT_EQ(
      socket->getConnectEndTime(),
      socket->getConnectionEstablishTime().value());

  // store timings, then move the socket
  const auto connectStartTime = socket->getConnectStartTime();
  const auto connectEndTime = socket->getConnectEndTime();
  auto socket2 = AsyncSocket::UniquePtr(new AsyncSocket(std::move(socket)));

  // timings should have been moved with the socket
  EXPECT_EQ(connectStartTime, socket2->getConnectStartTime());
  EXPECT_EQ(connectEndTime, socket2->getConnectEndTime());
  ASSERT_TRUE(socket2->getConnectionEstablishTime().has_value());
  EXPECT_EQ(connectEndTime, socket2->getConnectionEstablishTime().value());
}

/**
 * Test connecting to a server that isn't listening.
 */
TEST(AsyncSocketTest, ConnectRefused) {
  EventBase evb;
  auto socket = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  EXPECT_EQ(
      std::chrono::steady_clock::time_point(), socket->getConnectStartTime());
  EXPECT_EQ(
      std::chrono::steady_clock::time_point(), socket->getConnectEndTime());
  EXPECT_FALSE(socket->getConnectionEstablishTime().has_value());

  // Hopefully nothing is actually listening on this address
  folly::SocketAddress addr("127.0.0.1", 65535);
  ConnCallback cb;
  const auto startedAt = std::chrono::steady_clock::now();
  socket->connect(&cb, addr, 30);

  evb.loop();
  const auto finishedAt = std::chrono::steady_clock::now();

  EXPECT_EQ(STATE_FAILED, cb.state);
  EXPECT_EQ(AsyncSocketException::NOT_OPEN, cb.exception.getType());
  EXPECT_LE(0, socket->getConnectTime().count());
  EXPECT_EQ(std::chrono::milliseconds(30), socket->getConnectTimeout());

  EXPECT_GE(socket->getConnectStartTime(), startedAt);
  EXPECT_LE(socket->getConnectStartTime(), socket->getConnectEndTime());
  EXPECT_LE(socket->getConnectEndTime(), finishedAt);

  // since connect() failed, the establish time is empty.
  EXPECT_FALSE(socket->getConnectionEstablishTime().has_value());
}

/**
 * Test connection timeout
 */
TEST(AsyncSocketTest, ConnectTimeout) {
  EventBase evb;

  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);

  // Try connecting to server that won't respond.
  //
  // This depends somewhat on the network where this test is run.
  // Hopefully this IP will be routable but unresponsive.
  // (Alternatively, we could try listening on a local raw socket, but that
  // normally requires root privileges.)
  auto host = SocketAddressTestHelper::isIPv6Enabled()
      ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv6
      : SocketAddressTestHelper::isIPv4Enabled()
      ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv4
      : nullptr;
  SocketAddress addr(host, 65535);
  ConnCallback cb;
  const auto startedAt = std::chrono::steady_clock::now();
  socket->connect(&cb, addr, 1); // also set a ridiculously small timeout

  evb.loop();
  const auto finishedAt = std::chrono::steady_clock::now();

  ASSERT_EQ(cb.state, STATE_FAILED);
  if (cb.exception.getType() == AsyncSocketException::NOT_OPEN) {
    // This can happen if we could not route to the IP address picked above.
    // In this case the connect will fail immediately rather than timing out.
    // Just skip the test in this case.
    SKIP() << "do not have a routable but unreachable IP address";
  }
  ASSERT_EQ(cb.exception.getType(), AsyncSocketException::TIMED_OUT);

  EXPECT_LE(0, socket->getConnectTime().count());
  EXPECT_EQ(std::chrono::milliseconds(1), socket->getConnectTimeout());

  EXPECT_GE(socket->getConnectStartTime(), startedAt);
  EXPECT_LE(socket->getConnectStartTime(), socket->getConnectEndTime());
  EXPECT_LE(socket->getConnectEndTime(), finishedAt);

  // since connect() failed, the establish time is empty.
  EXPECT_FALSE(socket->getConnectionEstablishTime().has_value());

  // Verify that we can still get the peer address after a timeout.
  // Use case is if the client was created from a client pool, and we want
  // to log which peer failed.
  folly::SocketAddress peer;
  socket->getPeerAddress(&peer);
  ASSERT_EQ(peer, addr);
  EXPECT_LE(0, socket->getConnectTime().count());
  EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(1));
}

enum class TFOState {
  DISABLED,
  ENABLED,
};

class AsyncSocketConnectTest : public ::testing::TestWithParam<TFOState> {};

std::vector<TFOState> getTestingValues() {
  std::vector<TFOState> vals;
  vals.emplace_back(TFOState::DISABLED);

#if FOLLY_ALLOW_TFO
  vals.emplace_back(TFOState::ENABLED);
#endif
  return vals;
}

INSTANTIATE_TEST_SUITE_P(
    ConnectTests,
    AsyncSocketConnectTest,
    ::testing::ValuesIn(getTestingValues()));

/**
 * Test writing immediately after connecting, without waiting for connect
 * to finish.
 */
TEST_P(AsyncSocketConnectTest, ConnectAndWrite) {
  TestServer server;

  // connect()
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);

  if (GetParam() == TFOState::ENABLED) {
    socket->enableTFO();
  }

  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30);

  // write()
  char buf[128];
  memset(buf, 'a', sizeof(buf));
  WriteCallback wcb(true /*enableReleaseIOBufCallback*/);
  // use writeChain so we can pass an IOBuf
  socket->writeChain(&wcb, IOBuf::copyBuffer(buf, sizeof(buf)));

  // Loop.  We don't bother accepting on the server socket yet.
  // The kernel should be able to buffer the write request so it can succeed.
  evb.loop();

  ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
  ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
  ASSERT_EQ(wcb.numIoBufCount, 1);
  ASSERT_EQ(wcb.numIoBufBytes, sizeof(buf));

  // Make sure the server got a connection and received the data
  socket->close();
  server.verifyConnection(buf, sizeof(buf));

  ASSERT_TRUE(socket->isClosedBySelf());
  ASSERT_FALSE(socket->isClosedByPeer());
  EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
}

/**
 * Test connecting using a nullptr connect callback.
 */
TEST_P(AsyncSocketConnectTest, ConnectNullCallback) {
  TestServer server;

  // connect()
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  if (GetParam() == TFOState::ENABLED) {
    socket->enableTFO();
  }

  socket->connect(nullptr, server.getAddress(), 30);

  // write some data, just so we have some way of verifying
  // that the socket works correctly after connecting
  char buf[128];
  memset(buf, 'a', sizeof(buf));
  WriteCallback wcb;
  socket->write(&wcb, buf, sizeof(buf));

  evb.loop();

  ASSERT_EQ(wcb.state, STATE_SUCCEEDED);

  // Make sure the server got a connection and received the data
  socket->close();
  server.verifyConnection(buf, sizeof(buf));

  ASSERT_TRUE(socket->isClosedBySelf());
  ASSERT_FALSE(socket->isClosedByPeer());
}

/**
 * Test calling both write() and close() immediately after connecting, without
 * waiting for connect to finish.
 *
 * This exercises the STATE_CONNECTING_CLOSING code.
 */
TEST_P(AsyncSocketConnectTest, ConnectWriteAndClose) {
  TestServer server;

  // connect()
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  if (GetParam() == TFOState::ENABLED) {
    socket->enableTFO();
  }
  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30);

  // write()
  char buf[128];
  memset(buf, 'a', sizeof(buf));
  WriteCallback wcb;
  socket->write(&wcb, buf, sizeof(buf));

  // close()
  socket->close();

  // Loop.  We don't bother accepting on the server socket yet.
  // The kernel should be able to buffer the write request so it can succeed.
  evb.loop();

  ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
  ASSERT_EQ(wcb.state, STATE_SUCCEEDED);

  // Make sure the server got a connection and received the data
  server.verifyConnection(buf, sizeof(buf));

  ASSERT_TRUE(socket->isClosedBySelf());
  ASSERT_FALSE(socket->isClosedByPeer());
}

/**
 * Test calling close() immediately after connect()
 */
TEST(AsyncSocketTest, ConnectAndClose) {
  TestServer server;

  // Connect using a AsyncSocket
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30);

  // Hopefully the connect didn't succeed immediately.
  // If it did, we can't exercise the close-while-connecting code path.
  if (ccb.state == STATE_SUCCEEDED) {
    LOG(INFO) << "connect() succeeded immediately; aborting test "
                 "of close-during-connect behavior";
    return;
  }

  socket->close();

  // Loop, although there shouldn't be anything to do.
  evb.loop();

  // Make sure the connection was aborted
  ASSERT_EQ(ccb.state, STATE_FAILED);

  ASSERT_TRUE(socket->isClosedBySelf());
  ASSERT_FALSE(socket->isClosedByPeer());
}

/**
 * Test calling closeNow() immediately after connect()
 *
 * This should be identical to the normal close behavior.
 */
TEST(AsyncSocketTest, ConnectAndCloseNow) {
  TestServer server;

  // Connect using a AsyncSocket
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30);

  // Hopefully the connect didn't succeed immediately.
  // If it did, we can't exercise the close-while-connecting code path.
  if (ccb.state == STATE_SUCCEEDED) {
    LOG(INFO) << "connect() succeeded immediately; aborting test "
                 "of closeNow()-during-connect behavior";
    return;
  }

  socket->closeNow();

  // Loop, although there shouldn't be anything to do.
  evb.loop();

  // Make sure the connection was aborted
  ASSERT_EQ(ccb.state, STATE_FAILED);

  ASSERT_TRUE(socket->isClosedBySelf());
  ASSERT_FALSE(socket->isClosedByPeer());
}

/**
 * Test calling both write() and closeNow() immediately after connecting,
 * without waiting for connect to finish.
 *
 * This should abort the pending write.
 */
TEST(AsyncSocketTest, ConnectWriteAndCloseNow) {
  TestServer server;

  // connect()
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30);

  // Hopefully the connect didn't succeed immediately.
  // If it did, we can't exercise the close-while-connecting code path.
  if (ccb.state == STATE_SUCCEEDED) {
    LOG(INFO) << "connect() succeeded immediately; aborting test "
                 "of write-during-connect behavior";
    return;
  }

  // write()
  char buf[128];
  memset(buf, 'a', sizeof(buf));
  WriteCallback wcb;
  socket->write(&wcb, buf, sizeof(buf));

  // close()
  socket->closeNow();

  // Loop, although there shouldn't be anything to do.
  evb.loop();

  ASSERT_EQ(ccb.state, STATE_FAILED);
  ASSERT_EQ(wcb.state, STATE_FAILED);

  ASSERT_TRUE(socket->isClosedBySelf());
  ASSERT_FALSE(socket->isClosedByPeer());
}

/**
 * Test installing a read callback immediately, before connect() finishes.
 */
TEST_P(AsyncSocketConnectTest, ConnectAndRead) {
  TestServer server;

  // connect()
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  if (GetParam() == TFOState::ENABLED) {
    socket->enableTFO();
  }

  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30);

  ReadCallback rcb;
  socket->setReadCB(&rcb);

  if (GetParam() == TFOState::ENABLED) {
    // Trigger a connection
    socket->writeChain(nullptr, IOBuf::copyBuffer("hey"));
  }

  // Even though we haven't looped yet, we should be able to accept
  // the connection and send data to it.
  std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
  uint8_t buf[128];
  memset(buf, 'a', sizeof(buf));
  acceptedSocket->write(buf, sizeof(buf));
  acceptedSocket->flush();
  acceptedSocket->close();

  // Loop, although there shouldn't be anything to do.
  evb.loop();

  ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
  ASSERT_EQ(rcb.buffers.size(), 1);
  ASSERT_EQ(rcb.buffers[0].length, sizeof(buf));
  ASSERT_EQ(memcmp(rcb.buffers[0].buffer, buf, sizeof(buf)), 0);

  ASSERT_FALSE(socket->isClosedBySelf());
  ASSERT_FALSE(socket->isClosedByPeer());
}

TEST_P(AsyncSocketConnectTest, ConnectAndReadv) {
  TestServer server;

  // connect()
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  if (GetParam() == TFOState::ENABLED) {
    socket->enableTFO();
  }

  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30);

  static constexpr size_t kBuffSize = 10;
  static constexpr size_t kLen = 40;
  static constexpr size_t kDataSize = 128;

  ReadvCallback rcb(kBuffSize, kLen);
  socket->setReadCB(&rcb);

  if (GetParam() == TFOState::ENABLED) {
    // Trigger a connection
    socket->writeChain(nullptr, IOBuf::copyBuffer("hey"));
  }

  // Even though we haven't looped yet, we should be able to accept
  // the connection and send data to it.
  std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
  std::string data(kDataSize, 'A');
  acceptedSocket->write(
      reinterpret_cast<unsigned char*>(data.data()), data.size());
  acceptedSocket->flush();
  acceptedSocket->close();

  // Loop, although there shouldn't be anything to do.
  evb.loop();

  ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
  rcb.verifyData(data);

  ASSERT_FALSE(socket->isClosedBySelf());
  ASSERT_FALSE(socket->isClosedByPeer());
}

TEST_P(AsyncSocketConnectTest, ConnectAndZeroCopyRead) {
  TestServer server;

  // connect()
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  if (GetParam() == TFOState::ENABLED) {
    socket->enableTFO();
  }

  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30);

  static constexpr size_t kBuffSize = 4096;
  static constexpr size_t kDataSize = 32 * 1024;

  static constexpr size_t kNumEntries = 1024;
  static constexpr size_t kEntrySize = 128 * 1024;

  auto memStore =
      AsyncSocket::createDefaultZeroCopyMemStore(kNumEntries, kEntrySize);
  ZeroCopyReadCallback rcb(memStore.get(), kBuffSize);
  socket->setReadCB(&rcb);

  if (GetParam() == TFOState::ENABLED) {
    // Trigger a connection
    socket->writeChain(nullptr, IOBuf::copyBuffer("hey"));
  }

  // Even though we haven't looped yet, we should be able to accept
  // the connection and send data to it.
  std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
  std::string data(kDataSize, ' ');
  // generate random data
  std::mt19937 rng(folly::randomNumberSeed());
  for (size_t i = 0; i < data.size(); ++i) {
    data[i] = static_cast<char>(rng());
  }
  auto ret = acceptedSocket->write(
      reinterpret_cast<unsigned char*>(data.data()), data.size());
  ASSERT_EQ(ret, data.size());
  acceptedSocket->flush();
  acceptedSocket->close();

  // Loop
  evb.loop();

  ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
  rcb.verifyData(data);

  ASSERT_FALSE(socket->isClosedBySelf());
  ASSERT_FALSE(socket->isClosedByPeer());
}

/**
 * Test installing a read callback and then closing immediately before the
 * connect attempt finishes.
 */
TEST(AsyncSocketTest, ConnectReadAndClose) {
  TestServer server;

  // connect()
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30);

  // Hopefully the connect didn't succeed immediately.
  // If it did, we can't exercise the close-while-connecting code path.
  if (ccb.state == STATE_SUCCEEDED) {
    LOG(INFO) << "connect() succeeded immediately; aborting test "
                 "of read-during-connect behavior";
    return;
  }

  ReadCallback rcb;
  socket->setReadCB(&rcb);

  // close()
  socket->close();

  // Loop, although there shouldn't be anything to do.
  evb.loop();

  ASSERT_EQ(ccb.state, STATE_FAILED); // close() aborted the connect attempt
  ASSERT_EQ(rcb.buffers.size(), 0);
  ASSERT_EQ(rcb.state, STATE_SUCCEEDED); // this indicates EOF

  ASSERT_TRUE(socket->isClosedBySelf());
  ASSERT_FALSE(socket->isClosedByPeer());
}

/**
 * Test both writing and installing a read callback immediately,
 * before connect() finishes.
 */
TEST_P(AsyncSocketConnectTest, ConnectWriteAndRead) {
  TestServer server;

  // connect()
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  if (GetParam() == TFOState::ENABLED) {
    socket->enableTFO();
  }
  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30);

  // write()
  char buf1[128];
  memset(buf1, 'a', sizeof(buf1));
  WriteCallback wcb;
  socket->write(&wcb, buf1, sizeof(buf1));

  // set a read callback
  ReadCallback rcb;
  socket->setReadCB(&rcb);

  // Even though we haven't looped yet, we should be able to accept
  // the connection and send data to it.
  std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
  uint8_t buf2[128];
  memset(buf2, 'b', sizeof(buf2));
  acceptedSocket->write(buf2, sizeof(buf2));
  acceptedSocket->flush();

  // shut down the write half of acceptedSocket, so that the AsyncSocket
  // will stop reading and we can break out of the event loop.
  netops::shutdown(acceptedSocket->getNetworkSocket(), SHUT_WR);

  // Loop
  evb.loop();

  // Make sure the connect succeeded
  ASSERT_EQ(ccb.state, STATE_SUCCEEDED);

  // Make sure the AsyncSocket read the data written by the accepted socket
  ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
  ASSERT_EQ(rcb.buffers.size(), 1);
  ASSERT_EQ(rcb.buffers[0].length, sizeof(buf2));
  ASSERT_EQ(memcmp(rcb.buffers[0].buffer, buf2, sizeof(buf2)), 0);

  // Close the AsyncSocket so we'll see EOF on acceptedSocket
  socket->close();

  // Make sure the accepted socket saw the data written by the AsyncSocket
  uint8_t readbuf[sizeof(buf1)];
  acceptedSocket->readAll(readbuf, sizeof(readbuf));
  ASSERT_EQ(memcmp(buf1, readbuf, sizeof(buf1)), 0);
  uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf));
  ASSERT_EQ(bytesRead, 0);

  ASSERT_FALSE(socket->isClosedBySelf());
  ASSERT_TRUE(socket->isClosedByPeer());
}

/**
 * Test writing to the socket then shutting down writes before the connect
 * attempt finishes.
 */
TEST(AsyncSocketTest, ConnectWriteAndShutdownWrite) {
  TestServer server;

  // connect()
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30);

  // Hopefully the connect didn't succeed immediately.
  // If it did, we can't exercise the write-while-connecting code path.
  if (ccb.state == STATE_SUCCEEDED) {
    LOG(INFO) << "connect() succeeded immediately; skipping test";
    return;
  }

  // Ask to write some data
  char wbuf[128];
  memset(wbuf, 'a', sizeof(wbuf));
  WriteCallback wcb;
  socket->write(&wcb, wbuf, sizeof(wbuf));

  // Shutdown writes
  socket->shutdownWrite();

  // Even though we haven't looped yet, we should be able to accept
  // the connection.
  std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();

  // Since the connection is still in progress, there should be no data to
  // read yet.  Verify that the accepted socket is not readable.
  netops::PollDescriptor fds[1];
  fds[0].fd = acceptedSocket->getNetworkSocket();
  fds[0].events = POLLIN;
  fds[0].revents = 0;
  int rc = netops::poll(fds, 1, 0);
  ASSERT_EQ(rc, 0);

  // Write data to the accepted socket
  uint8_t acceptedWbuf[192];
  memset(acceptedWbuf, 'b', sizeof(acceptedWbuf));
  acceptedSocket->write(acceptedWbuf, sizeof(acceptedWbuf));
  acceptedSocket->flush();

  // Loop
  evb.loop();

  // The loop should have completed the connection, written the queued data,
  // and shutdown writes on the socket.
  //
  // Check that the connection was completed successfully and that the write
  // callback succeeded.
  ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
  ASSERT_EQ(wcb.state, STATE_SUCCEEDED);

  // Check that we can read the data that was written to the socket, and that
  // we see an EOF, since the AsyncSocket's write side was shut down.
  uint8_t readbuf[sizeof(wbuf)];
  acceptedSocket->readAll(readbuf, sizeof(readbuf));
  ASSERT_EQ(memcmp(wbuf, readbuf, sizeof(wbuf)), 0);
  uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf));
  ASSERT_EQ(bytesRead, 0);

  // Close the accepted socket.  This will cause it to see EOF
  // and uninstall the read callback when we loop next.
  acceptedSocket->close();

  // Install a read callback, then loop again.
  ReadCallback rcb;
  socket->setReadCB(&rcb);
  evb.loop();

  // This loop should have read the data and seen the EOF
  ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
  ASSERT_EQ(rcb.buffers.size(), 1);
  ASSERT_EQ(rcb.buffers[0].length, sizeof(acceptedWbuf));
  ASSERT_EQ(
      memcmp(rcb.buffers[0].buffer, acceptedWbuf, sizeof(acceptedWbuf)), 0);

  ASSERT_FALSE(socket->isClosedBySelf());
  ASSERT_FALSE(socket->isClosedByPeer());
}

/**
 * Test reading, writing, and shutting down writes before the connect attempt
 * finishes.
 */
TEST(AsyncSocketTest, ConnectReadWriteAndShutdownWrite) {
  TestServer server;

  // connect()
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30);

  // Hopefully the connect didn't succeed immediately.
  // If it did, we can't exercise the write-while-connecting code path.
  if (ccb.state == STATE_SUCCEEDED) {
    LOG(INFO) << "connect() succeeded immediately; skipping test";
    return;
  }

  // Install a read callback
  ReadCallback rcb;
  socket->setReadCB(&rcb);

  // Ask to write some data
  char wbuf[128];
  memset(wbuf, 'a', sizeof(wbuf));
  WriteCallback wcb;
  socket->write(&wcb, wbuf, sizeof(wbuf));

  // Shutdown writes
  socket->shutdownWrite();

  // Even though we haven't looped yet, we should be able to accept
  // the connection.
  std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();

  // Since the connection is still in progress, there should be no data to
  // read yet.  Verify that the accepted socket is not readable.
  netops::PollDescriptor fds[1];
  fds[0].fd = acceptedSocket->getNetworkSocket();
  fds[0].events = POLLIN;
  fds[0].revents = 0;
  int rc = netops::poll(fds, 1, 0);
  ASSERT_EQ(rc, 0);

  // Write data to the accepted socket
  uint8_t acceptedWbuf[192];
  memset(acceptedWbuf, 'b', sizeof(acceptedWbuf));
  acceptedSocket->write(acceptedWbuf, sizeof(acceptedWbuf));
  acceptedSocket->flush();
  // Shutdown writes to the accepted socket.  This will cause it to see EOF
  // and uninstall the read callback.
  netops::shutdown(acceptedSocket->getNetworkSocket(), SHUT_WR);

  // Loop
  evb.loop();

  // The loop should have completed the connection, written the queued data,
  // shutdown writes on the socket, read the data we wrote to it, and see the
  // EOF.
  //
  // Check that the connection was completed successfully and that the read
  // and write callbacks were invoked as expected.
  ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
  ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
  ASSERT_EQ(rcb.buffers.size(), 1);
  ASSERT_EQ(rcb.buffers[0].length, sizeof(acceptedWbuf));
  ASSERT_EQ(
      memcmp(rcb.buffers[0].buffer, acceptedWbuf, sizeof(acceptedWbuf)), 0);
  ASSERT_EQ(wcb.state, STATE_SUCCEEDED);

  // Check that we can read the data that was written to the socket, and that
  // we see an EOF, since the AsyncSocket's write side was shut down.
  uint8_t readbuf[sizeof(wbuf)];
  acceptedSocket->readAll(readbuf, sizeof(readbuf));
  ASSERT_EQ(memcmp(wbuf, readbuf, sizeof(wbuf)), 0);
  uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf));
  ASSERT_EQ(bytesRead, 0);

  // Fully close both sockets
  acceptedSocket->close();
  socket->close();

  ASSERT_FALSE(socket->isClosedBySelf());
  ASSERT_TRUE(socket->isClosedByPeer());
}

/**
 * Test reading, writing, and calling shutdownWriteNow() before the
 * connect attempt finishes.
 */
TEST(AsyncSocketTest, ConnectReadWriteAndShutdownWriteNow) {
  TestServer server;

  // connect()
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30);

  // Hopefully the connect didn't succeed immediately.
  // If it did, we can't exercise the write-while-connecting code path.
  if (ccb.state == STATE_SUCCEEDED) {
    LOG(INFO) << "connect() succeeded immediately; skipping test";
    return;
  }

  // Install a read callback
  ReadCallback rcb;
  socket->setReadCB(&rcb);

  // Ask to write some data
  char wbuf[128];
  memset(wbuf, 'a', sizeof(wbuf));
  WriteCallback wcb;
  socket->write(&wcb, wbuf, sizeof(wbuf));

  // Shutdown writes immediately.
  // This should immediately discard the data that we just tried to write.
  socket->shutdownWriteNow();

  // Verify that writeError() was invoked on the write callback.
  ASSERT_EQ(wcb.state, STATE_FAILED);
  ASSERT_EQ(wcb.bytesWritten, 0);

  // Even though we haven't looped yet, we should be able to accept
  // the connection.
  std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();

  // Since the connection is still in progress, there should be no data to
  // read yet.  Verify that the accepted socket is not readable.
  netops::PollDescriptor fds[1];
  fds[0].fd = acceptedSocket->getNetworkSocket();
  fds[0].events = POLLIN;
  fds[0].revents = 0;
  int rc = netops::poll(fds, 1, 0);
  ASSERT_EQ(rc, 0);

  // Write data to the accepted socket
  uint8_t acceptedWbuf[192];
  memset(acceptedWbuf, 'b', sizeof(acceptedWbuf));
  acceptedSocket->write(acceptedWbuf, sizeof(acceptedWbuf));
  acceptedSocket->flush();
  // Shutdown writes to the accepted socket.  This will cause it to see EOF
  // and uninstall the read callback.
  netops::shutdown(acceptedSocket->getNetworkSocket(), SHUT_WR);

  // Loop
  evb.loop();

  // The loop should have completed the connection, written the queued data,
  // shutdown writes on the socket, read the data we wrote to it, and see the
  // EOF.
  //
  // Check that the connection was completed successfully and that the read
  // callback was invoked as expected.
  ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
  ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
  ASSERT_EQ(rcb.buffers.size(), 1);
  ASSERT_EQ(rcb.buffers[0].length, sizeof(acceptedWbuf));
  ASSERT_EQ(
      memcmp(rcb.buffers[0].buffer, acceptedWbuf, sizeof(acceptedWbuf)), 0);

  // Since we used shutdownWriteNow(), it should have discarded all pending
  // write data.  Verify we see an immediate EOF when reading from the accepted
  // socket.
  uint8_t readbuf[sizeof(wbuf)];
  uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf));
  ASSERT_EQ(bytesRead, 0);

  // Fully close both sockets
  acceptedSocket->close();
  socket->close();

  ASSERT_FALSE(socket->isClosedBySelf());
  ASSERT_TRUE(socket->isClosedByPeer());
}

// Helper function for use in testConnectOptWrite()
// Temporarily disable the read callback
void tmpDisableReads(AsyncSocket* socket, ReadCallback* rcb) {
  // Uninstall the read callback
  socket->setReadCB(nullptr);
  // Schedule the read callback to be reinstalled in the next loop iteration
  socket->getEventBase()->runInLoop(
      std::bind(&AsyncSocket::setReadCB, socket, rcb));
}

/**
 * Test connect+write, then have the connect callback perform another write.
 *
 * This tests interaction of the optimistic writing after connect with
 * additional write attempts that occur in the connect callback.
 */
void testConnectOptWrite(size_t size1, size_t size2, bool close = false) {
  TestServer server;
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);

  // connect()
  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30);

  // Hopefully the connect didn't succeed immediately.
  // If it did, we can't exercise the optimistic write code path.
  if (ccb.state == STATE_SUCCEEDED) {
    LOG(INFO) << "connect() succeeded immediately; aborting test "
                 "of optimistic write behavior";
    return;
  }

  // Tell the connect callback to perform a write when the connect succeeds
  WriteCallback wcb2;
  std::unique_ptr<char[]> buf2(new char[size2]);
  memset(buf2.get(), 'b', size2);
  if (size2 > 0) {
    ccb.successCallback = [&] { socket->write(&wcb2, buf2.get(), size2); };
    // Tell the second write callback to close the connection when it is done
    wcb2.successCallback = [&] { socket->closeNow(); };
  }

  // Schedule one write() immediately, before the connect finishes
  std::unique_ptr<char[]> buf1(new char[size1]);
  memset(buf1.get(), 'a', size1);
  WriteCallback wcb1;
  if (size1 > 0) {
    socket->write(&wcb1, buf1.get(), size1);
  }

  if (close) {
    // immediately perform a close, before connect() completes
    socket->close();
  }

  // Start reading from the other endpoint after 10ms.
  // If we're using large buffers, we have to read so that the writes don't
  // block forever.
  std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&evb);
  ReadCallback rcb;
  rcb.dataAvailableCallback =
      std::bind(tmpDisableReads, acceptedSocket.get(), &rcb);
  socket->getEventBase()->tryRunAfterDelay(
      std::bind(&AsyncSocket::setReadCB, acceptedSocket.get(), &rcb), 10);

  // Loop.  We don't bother accepting on the server socket yet.
  // The kernel should be able to buffer the write request so it can succeed.
  evb.loop();

  ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
  if (size1 > 0) {
    ASSERT_EQ(wcb1.state, STATE_SUCCEEDED);
  }
  if (size2 > 0) {
    ASSERT_EQ(wcb2.state, STATE_SUCCEEDED);
  }

  socket->close();

  // Make sure the read callback received all of the data
  size_t bytesRead = 0;
  for (const auto& buffer : rcb.buffers) {
    size_t start = bytesRead;
    bytesRead += buffer.length;
    size_t end = bytesRead;
    if (start < size1) {
      size_t cmpLen = min(size1, end) - start;
      ASSERT_EQ(memcmp(buffer.buffer, buf1.get() + start, cmpLen), 0);
    }
    if (end > size1 && end <= size1 + size2) {
      size_t itOffset;
      size_t buf2Offset;
      size_t cmpLen;
      if (start >= size1) {
        itOffset = 0;
        buf2Offset = start - size1;
        cmpLen = end - start;
      } else {
        itOffset = size1 - start;
        buf2Offset = 0;
        cmpLen = end - size1;
      }
      ASSERT_EQ(
          memcmp(buffer.buffer + itOffset, buf2.get() + buf2Offset, cmpLen), 0);
    }
  }
  ASSERT_EQ(bytesRead, size1 + size2);
}

TEST(AsyncSocketTest, ConnectCallbackWrite) {
  // Test using small writes that should both succeed immediately
  testConnectOptWrite(100, 200);

  // Test using a large buffer in the connect callback, that should block
  const size_t largeSize = 32 * 1024 * 1024;
  testConnectOptWrite(100, largeSize);

  // Test using a large initial write
  testConnectOptWrite(largeSize, 100);

  // Test using two large buffers
  testConnectOptWrite(largeSize, largeSize);

  // Test a small write in the connect callback,
  // but no immediate write before connect completes
  testConnectOptWrite(0, 64);

  // Test a large write in the connect callback,
  // but no immediate write before connect completes
  testConnectOptWrite(0, largeSize);

  // Test connect, a small write, then immediately call close() before connect
  // completes
  testConnectOptWrite(211, 0, true);

  // Test connect, a large immediate write (that will block), then immediately
  // call close() before connect completes
  testConnectOptWrite(largeSize, 0, true);
}

///////////////////////////////////////////////////////////////////////////
// write() related tests
///////////////////////////////////////////////////////////////////////////

/**
 * Test writing using a nullptr callback
 */
TEST(AsyncSocketTest, WriteNullCallback) {
  TestServer server;

  // connect()
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket =
      AsyncSocket::newSocket(&evb, server.getAddress(), 30);
  evb.loop(); // loop until the socket is connected

  // write() with a nullptr callback
  char buf[128];
  memset(buf, 'a', sizeof(buf));
  socket->write(nullptr, buf, sizeof(buf));

  evb.loop(); // loop until the data is sent

  // Make sure the server got a connection and received the data
  socket->close();
  server.verifyConnection(buf, sizeof(buf));

  ASSERT_TRUE(socket->isClosedBySelf());
  ASSERT_FALSE(socket->isClosedByPeer());
}

/**
 * Test writing with a send timeout
 */
TEST(AsyncSocketTest, WriteTimeout) {
  TestServer server;

  // connect()
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket =
      AsyncSocket::newSocket(&evb, server.getAddress(), 30);
  evb.loop(); // loop until the socket is connected

  // write() a large chunk of data, with no-one on the other end reading.
  // Tricky: the kernel caches the connection metrics for recently-used
  // routes (see tcp_no_metrics_save) so a freshly opened connection can
  // have a send buffer size bigger than wmem_default.  This makes the test
  // flaky on contbuild if writeLength is < wmem_max (20M on our systems).
  size_t writeLength = 32 * 1024 * 1024;
  uint32_t timeout = 200;
  socket->setSendTimeout(timeout);
  std::unique_ptr<char[]> buf(new char[writeLength]);
  memset(buf.get(), 'a', writeLength);
  WriteCallback wcb;
  socket->write(&wcb, buf.get(), writeLength);

  TimePoint start;
  evb.loop();
  TimePoint end;

  // Make sure the write attempt timed out as requested
  ASSERT_EQ(wcb.state, STATE_FAILED);
  ASSERT_EQ(wcb.exception.getType(), AsyncSocketException::TIMED_OUT);

  // Check that the write timed out within a reasonable period of time.
  // We don't check for exactly the specified timeout, since AsyncSocket only
  // times out when it hasn't made progress for that period of time.
  //
  // On linux, the first write sends a few hundred kb of data, then blocks for
  // writability, and then unblocks again after 40ms and is able to write
  // another smaller chunk of data before blocking permanently.  Therefore it
  // doesn't
  // time out until 40ms + timeout.
  //
  // I haven't fully verified the cause of this, but I believe it probably
  // occurs because the receiving end delays sending an ack for up to 40ms.
  // (This is the default value for TCP_DELACK_MIN.)  Once the sender receives
  // the ack, it can send some more data.  However, after that point the
  // receiver's kernel buffer is full.  This 40ms delay happens even with
  // TCP_NODELAY and TCP_QUICKACK enabled on both endpoints.  However, the
  // kernel may be automatically disabling TCP_QUICKACK after receiving some
  // data.
  //
  // For now, we simply check that the timeout occurred within 160ms of
  // the requested value.
  T_CHECK_TIMEOUT(start, end, milliseconds(timeout), milliseconds(160));
}
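
// For experimentation only (a sketch, not used by the test above): the
// endpoints could request TCP_NODELAY / TCP_QUICKACK to probe the delayed-ACK
// behavior described in the comment above.  TCP_QUICKACK is Linux-specific.
//
//   socket->setNoDelay(true);
//   int one = 1;
//   netops::setsockopt(
//       socket->getNetworkSocket(), IPPROTO_TCP, TCP_QUICKACK, &one, sizeof(one));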

/**
 * Test getting local and peer addresses with no fd.
 *
 * Value returned should be empty; no failure should occur.
 */
TEST(AsyncSocketTest, GetAddressesNoFd) {
  EventBase evb;
  auto socket = AsyncSocket::newSocket(&evb);

  {
    folly::SocketAddress address;
    socket->getLocalAddress(&address);
    EXPECT_TRUE(address.empty());
  }

  {
    folly::SocketAddress address;
    socket->getPeerAddress(&address);
    EXPECT_TRUE(address.empty());
  }
}

/**
 * Test getting local and peer addresses after connecting.
 */
TEST(AsyncSocketTest, GetAddressesAfterConnectGetwhileopenandonclose) {
  EventBase evb;
  auto socket = AsyncSocket::newSocket(&evb);

  // Start listening on a local port
  TestServer server;

  // Connect
  {
    ConnCallback cb;
    socket->connect(&cb, server.getAddress(), 30);
    evb.loop();
    ASSERT_EQ(cb.state, STATE_SUCCEEDED);
  }

  // Get local, make sure it's not empty and not equal to server
  const folly::SocketAddress localAddress = [&socket]() {
    folly::SocketAddress address;
    socket->getLocalAddress(&address);
    return address;
  }();
  EXPECT_FALSE(localAddress.empty());
  EXPECT_NE(server.getAddress(), localAddress);

  const folly::SocketAddress peerAddress = [&socket]() {
    folly::SocketAddress address;
    socket->getPeerAddress(&address);
    return address;
  }();
  EXPECT_FALSE(peerAddress.empty());
  EXPECT_EQ(server.getAddress(), peerAddress);

  // Close
  socket->closeNow();

  // Addresses should still be available as they're cached
  const folly::SocketAddress localAddress2 = [&socket]() {
    folly::SocketAddress address;
    socket->getLocalAddress(&address);
    return address;
  }();
  EXPECT_EQ(localAddress2, localAddress);

  const folly::SocketAddress peerAddress2 = [&socket]() {
    folly::SocketAddress address;
    socket->getPeerAddress(&address);
    return address;
  }();
  EXPECT_EQ(peerAddress2, peerAddress);
}

/**
 * Test getting local and peer addresses after closing.
 *
 * Only peer address is available under these conditions.
 */
TEST(AsyncSocketTest, GetAddressesAfterConnectGetonlyafterclose) {
  EventBase evb;
  auto socket = AsyncSocket::newSocket(&evb);

  // Start listening on a local port
  TestServer server;

  // Connect
  {
    ConnCallback cb;
    socket->connect(&cb, server.getAddress(), 30);
    evb.loop();
    ASSERT_EQ(cb.state, STATE_SUCCEEDED);
  }

  // Close
  socket->closeNow();

  // Local address unavailable since never fetched
  {
    folly::SocketAddress address;
    socket->getLocalAddress(&address);
    EXPECT_TRUE(address.empty());
  }

  // Peer address available since it was passed to connect()
  {
    folly::SocketAddress address;
    socket->getPeerAddress(&address);
    EXPECT_FALSE(address.empty());
    EXPECT_EQ(server.getAddress(), address);
  }
}

/**
 * Test getting local and peer addresses after connecting.
 */
TEST(AsyncSocketTest, GetAddressesAfterInitFromFdGetoninitandonclose) {
  EventBase evb;

  // Start listening on a local port
  TestServer server;

  // Create a socket, connect, then create another AsyncSocket from just fd
  auto socket = [&server, &evb]() {
    auto socket1 = AsyncSocket::newSocket(&evb);
    ConnCallback cb;
    socket1->connect(&cb, server.getAddress(), 30);
    evb.loop();
    return AsyncSocket::newSocket(&evb, socket1->detachNetworkSocket());
  }();

  // Get local, make sure it's not empty and not equal to server
  const folly::SocketAddress localAddress = [&socket]() {
    folly::SocketAddress address;
    socket->getLocalAddress(&address);
    return address;
  }();
  EXPECT_FALSE(localAddress.empty());
  EXPECT_NE(server.getAddress(), localAddress);

  const folly::SocketAddress peerAddress = [&socket]() {
    folly::SocketAddress address;
    socket->getPeerAddress(&address);
    return address;
  }();
  EXPECT_FALSE(peerAddress.empty());
  EXPECT_EQ(server.getAddress(), peerAddress);

  // Close
  socket->closeNow();

  // Addresses should still be available as they're cached
  const folly::SocketAddress localAddress2 = [&socket]() {
    folly::SocketAddress address;
    socket->getLocalAddress(&address);
    return address;
  }();
  EXPECT_EQ(localAddress2, localAddress);

  const folly::SocketAddress peerAddress2 = [&socket]() {
    folly::SocketAddress address;
    socket->getPeerAddress(&address);
    return address;
  }();
  EXPECT_EQ(peerAddress2, peerAddress);
}

/**
 * Test writing to a socket that the remote endpoint has closed
 */
TEST(AsyncSocketTest, WritePipeError) {
  TestServer server;

  // connect()
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket =
      AsyncSocket::newSocket(&evb, server.getAddress(), 30);
  socket->setSendTimeout(1000);
  evb.loop(); // loop until the socket is connected

  // accept and immediately close the socket
  std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
  acceptedSocket->close();

  // write() a large chunk of data
  size_t writeLength = 32 * 1024 * 1024;
  std::unique_ptr<char[]> buf(new char[writeLength]);
  memset(buf.get(), 'a', writeLength);
  WriteCallback wcb;
  socket->write(&wcb, buf.get(), writeLength);

  evb.loop();

  // Make sure the write failed.
  // It would be nice if AsyncSocketException could convey the errno value,
  // so that we could check for EPIPE
  ASSERT_EQ(wcb.state, STATE_FAILED);
  ASSERT_EQ(wcb.exception.getType(), AsyncSocketException::INTERNAL_ERROR);
  ASSERT_THAT(
      wcb.exception.what(),
      MatchesRegex(
          kIsMobile
              ? "AsyncSocketException: writev\\(\\) failed \\(peer=.+\\), type = Internal error, errno = .+ \\(Broken pipe\\)"
              : "AsyncSocketException: writev\\(\\) failed \\(peer=.+, local=.+\\), type = Internal error, errno = .+ \\(Broken pipe\\)"));
  ASSERT_FALSE(socket->isClosedBySelf());
  ASSERT_FALSE(socket->isClosedByPeer());
}

/**
 * Test writing to a socket that has its read side closed
 */
TEST(AsyncSocketTest, WriteAfterReadEOF) {
  TestServer server;

  // connect()
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket =
      AsyncSocket::newSocket(&evb, server.getAddress(), 30);
  evb.loop(); // loop until the socket is connected

  // Accept the connection
  std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&evb);
  ReadCallback rcb;
  acceptedSocket->setReadCB(&rcb);

  // Shutdown the write side of client socket (read side of server socket)
  socket->shutdownWrite();
  evb.loop();

  // Check that the accepted socket saw the EOF but is still writable
  ASSERT_FALSE(acceptedSocket->good());
  ASSERT_TRUE(acceptedSocket->writable());

  // Write data to accepted socket
  constexpr size_t simpleBufLength = 5;
  char simpleBuf[simpleBufLength];
  memset(simpleBuf, 'a', simpleBufLength);
  WriteCallback wcb;
  acceptedSocket->write(&wcb, simpleBuf, simpleBufLength);
  evb.loop();

  // Make sure we were able to write even after getting a read EOF
  ASSERT_EQ(rcb.state, STATE_SUCCEEDED); // this indicates EOF
  ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
}

/**
 * Test that bytes written is correctly computed in case of write failure
 */
TEST(AsyncSocketTest, WriteErrorCallbackBytesWritten) {
  // Send and receive buffer sizes for the sockets.
  // Note that Linux will double this value to allow space for bookkeeping
  // overhead.
  constexpr size_t kSockBufSize = 8 * 1024;
  constexpr size_t kEffectiveSockBufSize = 2 * kSockBufSize;

  TestServer server(false, kSockBufSize);

  SocketOptionMap options{
      {{SOL_SOCKET, SO_SNDBUF}, int(kSockBufSize)},
      {{SOL_SOCKET, SO_RCVBUF}, int(kSockBufSize)},
      {{IPPROTO_TCP, TCP_NODELAY}, 1},
  };

  // The current thread will be used by the receiver - use a separate thread
  // for the sender.
  EventBase senderEvb;
  std::thread senderThread([&]() { senderEvb.loopForever(); });

  ConnCallback ccb;
  WriteCallback wcb;
  std::shared_ptr<AsyncSocket> socket;

  senderEvb.runInEventBaseThreadAndWait([&]() {
    socket = AsyncSocket::newSocket(&senderEvb);
    socket->connect(&ccb, server.getAddress(), 30, options);
  });

  // accept the socket on the server side
  std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();

  // Send a big (100KB) write so that it is partially written.
  constexpr size_t kSendSize = 100 * 1024;
  auto const sendBuf = std::vector<char>(kSendSize, 'a');

  senderEvb.runInEventBaseThreadAndWait(
      [&]() { socket->write(&wcb, sendBuf.data(), kSendSize); });

  // Read 20KB of data from the socket to allow the sender to send a bit more
  // data after it initially blocks.
  constexpr size_t kRecvSize = 20 * 1024;
  uint8_t recvBuf[kRecvSize];
  auto bytesRead = acceptedSocket->readAll(recvBuf, sizeof(recvBuf));
  ASSERT_EQ(kRecvSize, bytesRead);
  EXPECT_EQ(0, memcmp(recvBuf, sendBuf.data(), bytesRead));

  // We should be able to send at least the amount of data received plus the
  // send buffer size.  In practice we can probably send somewhat more, so
  // treat this as a conservative lower bound.
  constexpr size_t kMinExpectedBytesWritten = kRecvSize + kSockBufSize;

  // We shouldn't be able to send more than the amount of data received plus
  // the send buffer size of the sending socket (kEffectiveSockBufSize) plus
  // the receive buffer size on the receiving socket (kEffectiveSockBufSize)
  constexpr size_t kMaxExpectedBytesWritten =
      kRecvSize + kEffectiveSockBufSize + kEffectiveSockBufSize;
  static_assert(
      kMaxExpectedBytesWritten < kSendSize, "kSendSize set too small");

  // Need to delay after receiving 20KB and before closing the receive side so
  // that the send side has a chance to fill the send buffer past
  // kMinExpectedBytesWritten.
  using clock = std::chrono::steady_clock;
  auto const deadline = clock::now() + std::chrono::seconds(2);
  while (wcb.bytesWritten < kMinExpectedBytesWritten &&
         clock::now() < deadline) {
    std::this_thread::yield();
  }
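  // closeWithReset() aborts the connection (typically by closing with
  // SO_LINGER set to zero so that an RST is sent); the sender's pending write
  // then fails, and writeErr() reports how many bytes made it out first.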
  acceptedSocket->closeWithReset();

  senderEvb.terminateLoopSoon();
  senderThread.join();
  socket.reset();

  ASSERT_EQ(STATE_FAILED, wcb.state);
  ASSERT_LE(kMinExpectedBytesWritten, wcb.bytesWritten);
  ASSERT_GE(kMaxExpectedBytesWritten, wcb.bytesWritten);
}

/**
 * Test writing a mix of simple buffers and IOBufs
 */
TEST(AsyncSocketTest, WriteIOBuf) {
  TestServer server;

  // connect()
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30);

  // Accept the connection
  std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&evb);
  ReadCallback rcb;
  acceptedSocket->setReadCB(&rcb);

  // Check if EOR tracking flag can be set and reset.
  EXPECT_FALSE(socket->isEorTrackingEnabled());
  socket->setEorTracking(true);
  EXPECT_TRUE(socket->isEorTrackingEnabled());
  socket->setEorTracking(false);
  EXPECT_FALSE(socket->isEorTrackingEnabled());

  // Write a simple buffer to the socket
  constexpr size_t simpleBufLength = 5;
  char simpleBuf[simpleBufLength];
  memset(simpleBuf, 'a', simpleBufLength);
  WriteCallback wcb;
  socket->write(&wcb, simpleBuf, simpleBufLength);

  // Write a single-element IOBuf chain
  size_t buf1Length = 7;
  unique_ptr<IOBuf> buf1(IOBuf::create(buf1Length));
  memset(buf1->writableData(), 'b', buf1Length);
  buf1->append(buf1Length);
  unique_ptr<IOBuf> buf1Copy(buf1->clone());
  WriteCallback wcb2;
  socket->writeChain(&wcb2, std::move(buf1));

  // Write a multiple-element IOBuf chain
  size_t buf2Length = 11;
  unique_ptr<IOBuf> buf2(IOBuf::create(buf2Length));
  memset(buf2->writableData(), 'c', buf2Length);
  buf2->append(buf2Length);
  size_t buf3Length = 13;
  unique_ptr<IOBuf> buf3(IOBuf::create(buf3Length));
  memset(buf3->writableData(), 'd', buf3Length);
  buf3->append(buf3Length);
  buf2->appendToChain(std::move(buf3));
  unique_ptr<IOBuf> buf2Copy(buf2->clone());
  buf2Copy->coalesce();
  WriteCallback wcb3;
  socket->writeChain(&wcb3, std::move(buf2));
  socket->shutdownWrite();
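  // shutdownWrite() should queue the FIN behind the pending writes, so once
  // all three buffers are delivered the reader sees EOF and the loop below
  // can terminate.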

  // Let the reads and writes run to completion
  evb.loop();

  ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
  ASSERT_EQ(wcb2.state, STATE_SUCCEEDED);
  ASSERT_EQ(wcb3.state, STATE_SUCCEEDED);

  // Make sure the reader got the right data in the right order
  ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
  ASSERT_EQ(rcb.buffers.size(), 1);
  ASSERT_EQ(
      rcb.buffers[0].length,
      simpleBufLength + buf1Length + buf2Length + buf3Length);
  ASSERT_EQ(memcmp(rcb.buffers[0].buffer, simpleBuf, simpleBufLength), 0);
  ASSERT_EQ(
      memcmp(
          rcb.buffers[0].buffer + simpleBufLength,
          buf1Copy->data(),
          buf1Copy->length()),
      0);
  ASSERT_EQ(
      memcmp(
          rcb.buffers[0].buffer + simpleBufLength + buf1Length,
          buf2Copy->data(),
          buf2Copy->length()),
      0);

  acceptedSocket->close();
  socket->close();

  ASSERT_TRUE(socket->isClosedBySelf());
  ASSERT_FALSE(socket->isClosedByPeer());
}

TEST(AsyncSocketTest, WriteIOBufCorked) {
  TestServer server;

  // connect()
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30);

  // Accept the connection
  std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&evb);
  ReadCallback rcb;
  acceptedSocket->setReadCB(&rcb);

  // Do three writes, 100ms apart, with the "cork" flag set
  // on the second write.  The reader should see the first write
  // arrive by itself, followed by the second and third writes
  // arriving together.
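  // (The cork flag becomes WriteFlags::CORK, which should map to MSG_MORE on
  // Linux: the kernel holds the corked data until the next uncorked write,
  // which is why buf2 and buf3 are expected to arrive as one chunk.)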
  size_t buf1Length = 5;
  unique_ptr<IOBuf> buf1(IOBuf::create(buf1Length));
  memset(buf1->writableData(), 'a', buf1Length);
  buf1->append(buf1Length);
  size_t buf2Length = 7;
  unique_ptr<IOBuf> buf2(IOBuf::create(buf2Length));
  memset(buf2->writableData(), 'b', buf2Length);
  buf2->append(buf2Length);
  size_t buf3Length = 11;
  unique_ptr<IOBuf> buf3(IOBuf::create(buf3Length));
  memset(buf3->writableData(), 'c', buf3Length);
  buf3->append(buf3Length);
  WriteCallback wcb1;
  socket->writeChain(&wcb1, std::move(buf1));
  WriteCallback wcb2;
  DelayedWrite write2(socket, std::move(buf2), &wcb2, true);
  write2.scheduleTimeout(100);
  WriteCallback wcb3;
  DelayedWrite write3(socket, std::move(buf3), &wcb3, false, true);
  write3.scheduleTimeout(140);

  evb.loop();
  ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
  ASSERT_EQ(wcb1.state, STATE_SUCCEEDED);
  ASSERT_EQ(wcb2.state, STATE_SUCCEEDED);
  if (wcb3.state != STATE_SUCCEEDED) {
    throw(wcb3.exception);
  }
  ASSERT_EQ(wcb3.state, STATE_SUCCEEDED);

  // Make sure the reader got the data with the right grouping
  ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
  ASSERT_EQ(rcb.buffers.size(), 2);
  ASSERT_EQ(rcb.buffers[0].length, buf1Length);
  ASSERT_EQ(rcb.buffers[1].length, buf2Length + buf3Length);

  acceptedSocket->close();
  socket->close();

  ASSERT_TRUE(socket->isClosedBySelf());
  ASSERT_FALSE(socket->isClosedByPeer());
}

/**
 * Test performing a zero-length write
 */
TEST(AsyncSocketTest, ZeroLengthWrite) {
  TestServer server;

  // connect()
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket =
      AsyncSocket::newSocket(&evb, server.getAddress(), 30);
  evb.loop(); // loop until the socket is connected

  auto acceptedSocket = server.acceptAsync(&evb);
  ReadCallback rcb;
  acceptedSocket->setReadCB(&rcb);

  size_t len1 = 1024 * 1024;
  size_t len2 = 1024 * 1024;
  std::unique_ptr<char[]> buf(new char[len1 + len2]);
  memset(buf.get(), 'a', len1);
  memset(buf.get() + len1, 'b', len2);

  WriteCallback wcb1;
  WriteCallback wcb2;
  WriteCallback wcb3;
  WriteCallback wcb4;
  socket->write(&wcb1, buf.get(), 0);
  socket->write(&wcb2, buf.get(), len1);
  socket->write(&wcb3, buf.get() + len1, 0);
  socket->write(&wcb4, buf.get() + len1, len2);
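  // The zero-length writes should still complete successfully without putting
  // any bytes on the wire, so the reader should see exactly len1 + len2 bytes.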
  socket->close();

  evb.loop(); // loop until the data is sent

  ASSERT_EQ(wcb1.state, STATE_SUCCEEDED);
  ASSERT_EQ(wcb2.state, STATE_SUCCEEDED);
  ASSERT_EQ(wcb3.state, STATE_SUCCEEDED);
  ASSERT_EQ(wcb4.state, STATE_SUCCEEDED);
  rcb.verifyData(buf.get(), len1 + len2);

  ASSERT_TRUE(socket->isClosedBySelf());
  ASSERT_FALSE(socket->isClosedByPeer());
}

TEST(AsyncSocketTest, ZeroLengthWritev) {
  TestServer server;

  // connect()
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket =
      AsyncSocket::newSocket(&evb, server.getAddress(), 30);
  evb.loop(); // loop until the socket is connected

  auto acceptedSocket = server.acceptAsync(&evb);
  ReadCallback rcb;
  acceptedSocket->setReadCB(&rcb);

  size_t len1 = 1024 * 1024;
  size_t len2 = 1024 * 1024;
  std::unique_ptr<char[]> buf(new char[len1 + len2]);
  memset(buf.get(), 'a', len1);
  memset(buf.get() + len1, 'b', len2);

  WriteCallback wcb;
  constexpr size_t iovCount = 4;
  struct iovec iov[iovCount];
  iov[0].iov_base = buf.get();
  iov[0].iov_len = len1;
  iov[1].iov_base = buf.get() + len1;
  iov[1].iov_len = 0;
  iov[2].iov_base = buf.get() + len1;
  iov[2].iov_len = len2;
  iov[3].iov_base = buf.get() + len1 + len2;
  iov[3].iov_len = 0;

  socket->writev(&wcb, iov, iovCount);
  socket->close();
  evb.loop(); // loop until the data is sent

  ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
  rcb.verifyData(buf.get(), len1 + len2);

  ASSERT_TRUE(socket->isClosedBySelf());
  ASSERT_FALSE(socket->isClosedByPeer());
}

///////////////////////////////////////////////////////////////////////////
// close() related tests
///////////////////////////////////////////////////////////////////////////

/**
 * Test calling close() with pending writes when the socket is already closing.
 */
TEST(AsyncSocketTest, ClosePendingWritesWhileClosing) {
  TestServer server;

  // connect()
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30);

  // accept the socket on the server side
  std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();

  // Loop to ensure the connect has completed
  evb.loop();

  // Make sure we are connected
  ASSERT_EQ(ccb.state, STATE_SUCCEEDED);

  // Schedule pending writes until several write attempts have blocked
  char buf[128];
  memset(buf, 'a', sizeof(buf));
  typedef vector<std::shared_ptr<WriteCallback>> WriteCallbackVector;
  WriteCallbackVector writeCallbacks;

  writeCallbacks.reserve(5);
  while (writeCallbacks.size() < 5) {
    std::shared_ptr<WriteCallback> wcb(new WriteCallback);

    socket->write(wcb.get(), buf, sizeof(buf));
    if (wcb->state == STATE_SUCCEEDED) {
      // Succeeded immediately.  Keep performing more writes
      continue;
    }

    // This write is blocked.
    // Have the write callback call close() when writeError() is invoked
    wcb->errorCallback = std::bind(&AsyncSocket::close, socket.get());
    writeCallbacks.push_back(wcb);
  }

  // Call closeNow() to immediately fail the pending writes
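  // Each failed write's error callback calls close() again (set above), so
  // this also exercises a re-entrant close while closeNow() is still failing
  // the pending writes.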
  socket->closeNow();

  // Make sure writeError() was invoked on all of the pending write callbacks
  for (const auto& writeCallback : writeCallbacks) {
    ASSERT_EQ((writeCallback)->state, STATE_FAILED);
  }

  ASSERT_TRUE(socket->isClosedBySelf());
  ASSERT_FALSE(socket->isClosedByPeer());
}

///////////////////////////////////////////////////////////////////////////
// ImmediateRead related tests
///////////////////////////////////////////////////////////////////////////

/* AsyncSocket subclass used to verify that immediate read works */
class AsyncSocketImmediateRead : public folly::AsyncSocket {
 public:
  bool immediateReadCalled = false;
  explicit AsyncSocketImmediateRead(folly::EventBase* evb) : AsyncSocket(evb) {}

 protected:
  void checkForImmediateRead() noexcept override {
    immediateReadCalled = true;
    AsyncSocket::handleRead();
  }
};
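// AsyncSocket calls checkForImmediateRead() at points where data may already
// be waiting in the kernel buffer (for example when read processing was cut
// short by maxReadsPerEvent), so it can be consumed without waiting for
// another readiness notification.  The base implementation is believed to be
// a no-op; this subclass records the call and forces an actual handleRead().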

TEST(AsyncSocket, ConnectReadImmediateRead) {
  TestServer server;

  const size_t maxBufferSz = 100;
  const size_t maxReadsPerEvent = 1;
  const size_t expectedDataSz = maxBufferSz * 3;
  char expectedData[expectedDataSz];
  memset(expectedData, 'j', expectedDataSz);

  EventBase evb;
  ReadCallback rcb(maxBufferSz);
  AsyncSocketImmediateRead socket(&evb);
  socket.connect(nullptr, server.getAddress(), 30);

  evb.loop(); // loop until the socket is connected

  socket.setReadCB(&rcb);
  socket.setMaxReadsPerEvent(maxReadsPerEvent);
  socket.immediateReadCalled = false;

  auto acceptedSocket = server.acceptAsync(&evb);

  ReadCallback rcbServer;
  WriteCallback wcbServer;
  rcbServer.dataAvailableCallback = [&]() {
    if (rcbServer.dataRead() == expectedDataSz) {
      // write back all data read
      rcbServer.verifyData(expectedData, expectedDataSz);
      acceptedSocket->write(&wcbServer, expectedData, expectedDataSz);
      acceptedSocket->close();
    }
  };
  acceptedSocket->setReadCB(&rcbServer);

  // write data
  WriteCallback wcb1;
  socket.write(&wcb1, expectedData, expectedDataSz);
  evb.loop();
  ASSERT_EQ(wcb1.state, STATE_SUCCEEDED);
  rcb.verifyData(expectedData, expectedDataSz);
  ASSERT_EQ(socket.immediateReadCalled, true);

  ASSERT_FALSE(socket.isClosedBySelf());
  ASSERT_FALSE(socket.isClosedByPeer());
}

TEST(AsyncSocket, ConnectReadUninstallRead) {
  TestServer server;

  const size_t maxBufferSz = 100;
  const size_t maxReadsPerEvent = 1;
  const size_t expectedDataSz = maxBufferSz * 3;
  char expectedData[expectedDataSz];
  memset(expectedData, 'k', expectedDataSz);

  EventBase evb;
  ReadCallback rcb(maxBufferSz);
  AsyncSocketImmediateRead socket(&evb);
  socket.connect(nullptr, server.getAddress(), 30);

  evb.loop(); // loop until the socket is connected

  socket.setReadCB(&rcb);
  socket.setMaxReadsPerEvent(maxReadsPerEvent);
  socket.immediateReadCalled = false;

  auto acceptedSocket = server.acceptAsync(&evb);

  ReadCallback rcbServer;
  WriteCallback wcbServer;
  rcbServer.dataAvailableCallback = [&]() {
    if (rcbServer.dataRead() == expectedDataSz) {
      // write back all data read
      rcbServer.verifyData(expectedData, expectedDataSz);
      acceptedSocket->write(&wcbServer, expectedData, expectedDataSz);
      acceptedSocket->close();
    }
  };
  acceptedSocket->setReadCB(&rcbServer);

  rcb.dataAvailableCallback = [&]() {
    // we read data and reset readCB
    socket.setReadCB(nullptr);
  };

  // write data
  WriteCallback wcb;
  socket.write(&wcb, expectedData, expectedDataSz);
  evb.loop();
  ASSERT_EQ(wcb.state, STATE_SUCCEEDED);

  /* we should've only read maxBufferSz bytes since readCallback_
   * was reset in dataAvailableCallback */
  ASSERT_EQ(rcb.dataRead(), maxBufferSz);
  ASSERT_EQ(socket.immediateReadCalled, false);

  ASSERT_FALSE(socket.isClosedBySelf());
  ASSERT_FALSE(socket.isClosedByPeer());
}

// TODO:
// - Test connect() and have the connect callback set the read callback
// - Test connect() and have the connect callback unset the read callback
// - Test reading/writing/closing/destroying the socket in the connect callback
// - Test reading/writing/closing/destroying the socket in the read callback
// - Test reading/writing/closing/destroying the socket in the write callback
// - Test one-way shutdown behavior
// - Test changing the EventBase
//
// - TODO: test multiple threads sharing an AsyncSocket, and detaching from it
//   in connectSuccess(), readDataAvailable(), writeSuccess()

///////////////////////////////////////////////////////////////////////////
// AsyncServerSocket tests
///////////////////////////////////////////////////////////////////////////

/**
 * Make sure accepted sockets have O_NONBLOCK and TCP_NODELAY set
 */
TEST(AsyncSocketTest, ServerAcceptOptions) {
  EventBase eventBase;

  // Create a server socket
  std::shared_ptr<AsyncServerSocket> serverSocket(
      AsyncServerSocket::newSocket(&eventBase));
  serverSocket->bind(0);
  serverSocket->listen(16);
  folly::SocketAddress serverAddress;
  serverSocket->getAddress(&serverAddress);

  // Add a callback to accept one connection then stop the loop
  TestAcceptCallback acceptCallback;
  acceptCallback.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
      });
  acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
    serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
  });
  serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
  serverSocket->startAccepting();

  // Connect to the server socket
  std::shared_ptr<AsyncSocket> socket(
      AsyncSocket::newSocket(&eventBase, serverAddress));

  eventBase.loop();

  // Verify that the server accepted a connection
  ASSERT_EQ(acceptCallback.getEvents()->size(), 3);
  ASSERT_EQ(
      acceptCallback.getEvents()->at(0).type, TestAcceptCallback::TYPE_START);
  ASSERT_EQ(
      acceptCallback.getEvents()->at(1).type, TestAcceptCallback::TYPE_ACCEPT);
  ASSERT_EQ(
      acceptCallback.getEvents()->at(2).type, TestAcceptCallback::TYPE_STOP);
  auto fd = acceptCallback.getEvents()->at(1).fd;

#ifndef _WIN32
  // It is not possible to check whether a socket is already in non-blocking
  // mode on Windows (yes, really).  On other platforms, the accepted
  // connection should already be in non-blocking mode.
  int flags = fcntl(fd.toFd(), F_GETFL, 0);
  ASSERT_EQ(flags & O_NONBLOCK, O_NONBLOCK);
#endif

#ifndef TCP_NOPUSH
  // The accepted connection should already have TCP_NODELAY set
  int value;
  socklen_t valueLength = sizeof(value);
  int rc =
      netops::getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
  ASSERT_EQ(rc, 0);
  ASSERT_EQ(value, 1);
#endif
}

/**
 * Test AsyncServerSocket::removeAcceptCallback()
 */
TEST(AsyncSocketTest, RemoveAcceptCallback) {
  // Create a new AsyncServerSocket
  EventBase eventBase;
  std::shared_ptr<AsyncServerSocket> serverSocket(
      AsyncServerSocket::newSocket(&eventBase));
  serverSocket->bind(0);
  serverSocket->listen(16);
  folly::SocketAddress serverAddress;
  serverSocket->getAddress(&serverAddress);

  // Add several accept callbacks
  TestAcceptCallback cb1;
  TestAcceptCallback cb2;
  TestAcceptCallback cb3;
  TestAcceptCallback cb4;
  TestAcceptCallback cb5;
  TestAcceptCallback cb6;
  TestAcceptCallback cb7;

  // Test having callbacks remove other callbacks before them on the list,
  // after them on the list, or removing themselves.
  //
  // Have callback 2 remove callback 3 and callback 5 the first time it is
  // called.
  int cb2Count = 0;
  cb1.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        std::shared_ptr<AsyncSocket> sock2(AsyncSocket::newSocket(
            &eventBase, serverAddress)); // cb2: -cb3 -cb5
      });
  cb3.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {});
  cb4.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        std::shared_ptr<AsyncSocket> sock3(
            AsyncSocket::newSocket(&eventBase, serverAddress)); // cb4
      });
  cb5.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        std::shared_ptr<AsyncSocket> sock5(
            AsyncSocket::newSocket(&eventBase, serverAddress)); // cb7: -cb7
      });
  cb2.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        if (cb2Count == 0) {
          serverSocket->removeAcceptCallback(&cb3, nullptr);
          serverSocket->removeAcceptCallback(&cb5, nullptr);
        }
        ++cb2Count;
      });
  // Have callback 6 remove callback 4 the first time it is called,
  // and destroy the server socket the second time it is called
  int cb6Count = 0;
  cb6.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        if (cb6Count == 0) {
          serverSocket->removeAcceptCallback(&cb4, nullptr);
          std::shared_ptr<AsyncSocket> sock6(
              AsyncSocket::newSocket(&eventBase, serverAddress)); // cb1
          std::shared_ptr<AsyncSocket> sock7(
              AsyncSocket::newSocket(&eventBase, serverAddress)); // cb2
          std::shared_ptr<AsyncSocket> sock8(
              AsyncSocket::newSocket(&eventBase, serverAddress)); // cb6: stop

        } else {
          serverSocket.reset();
        }
        ++cb6Count;
      });
  // Have callback 7 remove itself
  cb7.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        serverSocket->removeAcceptCallback(&cb7, nullptr);
      });

  serverSocket->addAcceptCallback(&cb1, &eventBase);
  serverSocket->addAcceptCallback(&cb2, &eventBase);
  serverSocket->addAcceptCallback(&cb3, &eventBase);
  serverSocket->addAcceptCallback(&cb4, &eventBase);
  serverSocket->addAcceptCallback(&cb5, &eventBase);
  serverSocket->addAcceptCallback(&cb6, &eventBase);
  serverSocket->addAcceptCallback(&cb7, &eventBase);
  serverSocket->startAccepting();

  // Make several connections to the socket
  std::shared_ptr<AsyncSocket> sock1(
      AsyncSocket::newSocket(&eventBase, serverAddress)); // cb1
  std::shared_ptr<AsyncSocket> sock4(
      AsyncSocket::newSocket(&eventBase, serverAddress)); // cb6: -cb4

  // Loop until we are stopped
  eventBase.loop();

  // Check to make sure that the expected callbacks were invoked.
  //
  // NOTE: This code depends on the AsyncServerSocket calling all of the
  // AcceptCallbacks in round-robin fashion, in the order that they were
  // added.  The code is implemented this way right now, but the API doesn't
  // explicitly require it be done this way.  If we change the code not to be
  // exactly round robin in the future, we can simplify the test checks here.
  // (We'll also need to update the termination code, since we expect cb6 to
  // get called twice to terminate the loop.)
  ASSERT_EQ(cb1.getEvents()->size(), 4);
  ASSERT_EQ(cb1.getEvents()->at(0).type, TestAcceptCallback::TYPE_START);
  ASSERT_EQ(cb1.getEvents()->at(1).type, TestAcceptCallback::TYPE_ACCEPT);
  ASSERT_EQ(cb1.getEvents()->at(2).type, TestAcceptCallback::TYPE_ACCEPT);
  ASSERT_EQ(cb1.getEvents()->at(3).type, TestAcceptCallback::TYPE_STOP);

  ASSERT_EQ(cb2.getEvents()->size(), 4);
  ASSERT_EQ(cb2.getEvents()->at(0).type, TestAcceptCallback::TYPE_START);
  ASSERT_EQ(cb2.getEvents()->at(1).type, TestAcceptCallback::TYPE_ACCEPT);
  ASSERT_EQ(cb2.getEvents()->at(2).type, TestAcceptCallback::TYPE_ACCEPT);
  ASSERT_EQ(cb2.getEvents()->at(3).type, TestAcceptCallback::TYPE_STOP);

  ASSERT_EQ(cb3.getEvents()->size(), 2);
  ASSERT_EQ(cb3.getEvents()->at(0).type, TestAcceptCallback::TYPE_START);
  ASSERT_EQ(cb3.getEvents()->at(1).type, TestAcceptCallback::TYPE_STOP);

  ASSERT_EQ(cb4.getEvents()->size(), 3);
  ASSERT_EQ(cb4.getEvents()->at(0).type, TestAcceptCallback::TYPE_START);
  ASSERT_EQ(cb4.getEvents()->at(1).type, TestAcceptCallback::TYPE_ACCEPT);
  ASSERT_EQ(cb4.getEvents()->at(2).type, TestAcceptCallback::TYPE_STOP);

  ASSERT_EQ(cb5.getEvents()->size(), 2);
  ASSERT_EQ(cb5.getEvents()->at(0).type, TestAcceptCallback::TYPE_START);
  ASSERT_EQ(cb5.getEvents()->at(1).type, TestAcceptCallback::TYPE_STOP);

  ASSERT_EQ(cb6.getEvents()->size(), 4);
  ASSERT_EQ(cb6.getEvents()->at(0).type, TestAcceptCallback::TYPE_START);
  ASSERT_EQ(cb6.getEvents()->at(1).type, TestAcceptCallback::TYPE_ACCEPT);
  ASSERT_EQ(cb6.getEvents()->at(2).type, TestAcceptCallback::TYPE_ACCEPT);
  ASSERT_EQ(cb6.getEvents()->at(3).type, TestAcceptCallback::TYPE_STOP);

  ASSERT_EQ(cb7.getEvents()->size(), 3);
  ASSERT_EQ(cb7.getEvents()->at(0).type, TestAcceptCallback::TYPE_START);
  ASSERT_EQ(cb7.getEvents()->at(1).type, TestAcceptCallback::TYPE_ACCEPT);
  ASSERT_EQ(cb7.getEvents()->at(2).type, TestAcceptCallback::TYPE_STOP);
}

/**
 * Test accept callbacks running on a different thread from the one that
 * created the AsyncServerSocket
 */
TEST(AsyncSocketTest, OtherThreadAcceptCallback) {
  // Create a new AsyncServerSocket
  EventBase eventBase;
  std::shared_ptr<AsyncServerSocket> serverSocket(
      AsyncServerSocket::newSocket(&eventBase));
  serverSocket->bind(0);
  serverSocket->listen(16);
  folly::SocketAddress serverAddress;
  serverSocket->getAddress(&serverAddress);

  // Add several accept callbacks
  TestAcceptCallback cb1;
  auto thread_id = std::this_thread::get_id();
  cb1.setAcceptStartedFn([&]() {
    CHECK_NE(thread_id, std::this_thread::get_id());
    thread_id = std::this_thread::get_id();
  });
  cb1.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        ASSERT_EQ(thread_id, std::this_thread::get_id());
        serverSocket->removeAcceptCallback(&cb1, &eventBase);
      });
  cb1.setAcceptStoppedFn(
      [&]() { ASSERT_EQ(thread_id, std::this_thread::get_id()); });

  // Install the accept callback and start accepting
  serverSocket->addAcceptCallback(&cb1, &eventBase);
  serverSocket->startAccepting();

  // Make several connections to the socket
  std::shared_ptr<AsyncSocket> sock1(
      AsyncSocket::newSocket(&eventBase, serverAddress)); // cb1

  // Loop in another thread
  auto other = std::thread([&]() { eventBase.loop(); });
  other.join();

  // Check to make sure that the expected callbacks were invoked.
  //
  // NOTE: This code depends on the AsyncServerSocket calling all of the
  // AcceptCallbacks in round-robin fashion, in the order that they were
  // added.  The code is implemented this way right now, but the API doesn't
  // explicitly require it be done this way.  If we change the code not to be
  // exactly round robin in the future, we can simplify the test checks here.
  ASSERT_EQ(cb1.getEvents()->size(), 3);
  ASSERT_EQ(cb1.getEvents()->at(0).type, TestAcceptCallback::TYPE_START);
  ASSERT_EQ(cb1.getEvents()->at(1).type, TestAcceptCallback::TYPE_ACCEPT);
  ASSERT_EQ(cb1.getEvents()->at(2).type, TestAcceptCallback::TYPE_STOP);
}

void serverSocketSanityTest(AsyncServerSocket* serverSocket) {
  EventBase* eventBase = serverSocket->getEventBase();
  CHECK(eventBase);

  // Add a callback to accept one connection then stop accepting
  TestAcceptCallback acceptCallback;
  acceptCallback.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        serverSocket->removeAcceptCallback(&acceptCallback, eventBase);
      });
  acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
    serverSocket->removeAcceptCallback(&acceptCallback, eventBase);
  });
  serverSocket->addAcceptCallback(&acceptCallback, eventBase);
  serverSocket->startAccepting();

  // Connect to the server socket
  folly::SocketAddress serverAddress;
  serverSocket->getAddress(&serverAddress);
  AsyncSocket::UniquePtr socket(new AsyncSocket(eventBase, serverAddress));

  // Loop to process all events
  eventBase->loop();

  // Verify that the server accepted a connection
  ASSERT_EQ(acceptCallback.getEvents()->size(), 3);
  ASSERT_EQ(
      acceptCallback.getEvents()->at(0).type, TestAcceptCallback::TYPE_START);
  ASSERT_EQ(
      acceptCallback.getEvents()->at(1).type, TestAcceptCallback::TYPE_ACCEPT);
  ASSERT_EQ(
      acceptCallback.getEvents()->at(2).type, TestAcceptCallback::TYPE_STOP);
}

/* Verify that we don't leak sockets if the socket is destroy()ed while
 * there are still writes pending
 *
 * If destroy() only called close() instead of closeNow(), it would
 * shutdown(writes) on the socket, but the socket would never be close()'d,
 * and the file descriptor would leak
 */
TEST(AsyncSocketTest, DestroyCloseTest) {
  TestServer server;

  // connect()
  EventBase clientEB;
  EventBase serverEB;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&clientEB);
  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30);

  // Accept the connection
  std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&serverEB);
  ReadCallback rcb;
  acceptedSocket->setReadCB(&rcb);

  // Write a buffer to the socket that is larger than the kernel send buffer
  size_t simpleBufLength = 5000000;
  char* simpleBuf = new char[simpleBufLength];
  memset(simpleBuf, 'a', simpleBufLength);
  WriteCallback wcb;

  // Save the raw fd so we can verify later that it was actually closed
  int fd = acceptedSocket->getNetworkSocket().toFd();

  acceptedSocket->write(&wcb, simpleBuf, simpleBufLength);
  socket.reset();
  acceptedSocket.reset();

  // Test that server socket was closed
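  // Destroying the AsyncSocket while the large write was still pending must
  // have closeNow()'d the underlying fd, so reading from the saved raw fd
  // should fail with EBADF.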
  folly::test::msvcSuppressAbortOnInvalidParams([&] {
    ssize_t sz = read(fd, simpleBuf, simpleBufLength);
    ASSERT_EQ(sz, -1);
    ASSERT_EQ(errno, EBADF);
  });
  delete[] simpleBuf;
}

/**
 * Test AsyncServerSocket::useExistingSocket()
 */
TEST(AsyncSocketTest, ServerExistingSocket) {
  EventBase eventBase;

  // Test creating a socket, and letting AsyncServerSocket bind and listen
  {
    // Manually create a socket
    auto fd = netops::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
    ASSERT_NE(fd, NetworkSocket());

    // Create a server socket
    AsyncServerSocket::UniquePtr serverSocket(
        new AsyncServerSocket(&eventBase));
    serverSocket->useExistingSocket(fd);
    folly::SocketAddress address;
    serverSocket->getAddress(&address);
    address.setPort(0);
    serverSocket->bind(address);
    serverSocket->listen(16);

    // Make sure the socket works
    serverSocketSanityTest(serverSocket.get());
  }

  // Test creating a socket and binding manually,
  // then letting AsyncServerSocket listen
  {
    // Manually create a socket
    auto fd = netops::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
    ASSERT_NE(fd, NetworkSocket());
    // bind
    struct sockaddr_in addr;
    addr.sin_family = AF_INET;
    addr.sin_port = 0;
    addr.sin_addr.s_addr = INADDR_ANY;
    ASSERT_EQ(
        netops::bind(
            fd, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)),
        0);
    // Look up the address that we bound to
    folly::SocketAddress boundAddress;
    boundAddress.setFromLocalAddress(fd);

    // Create a server socket
    AsyncServerSocket::UniquePtr serverSocket(
        new AsyncServerSocket(&eventBase));
    serverSocket->useExistingSocket(fd);
    serverSocket->listen(16);

    // Make sure AsyncServerSocket reports the same address that we bound to
    folly::SocketAddress serverSocketAddress;
    serverSocket->getAddress(&serverSocketAddress);
    ASSERT_EQ(boundAddress, serverSocketAddress);

    // Make sure the socket works
    serverSocketSanityTest(serverSocket.get());
  }

  // Test creating a socket, binding and listening manually,
  // then giving it to AsyncServerSocket
  {
    // Manually create a socket
    auto fd = netops::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
    ASSERT_NE(fd, NetworkSocket());
    // bind
    struct sockaddr_in addr;
    addr.sin_family = AF_INET;
    addr.sin_port = 0;
    addr.sin_addr.s_addr = INADDR_ANY;
    ASSERT_EQ(
        netops::bind(
            fd, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)),
        0);
    // Look up the address that we bound to
    folly::SocketAddress boundAddress;
    boundAddress.setFromLocalAddress(fd);
    // listen
    ASSERT_EQ(netops::listen(fd, 16), 0);

    // Create a server socket
    AsyncServerSocket::UniquePtr serverSocket(
        new AsyncServerSocket(&eventBase));
    serverSocket->useExistingSocket(fd);

    // Make sure AsyncServerSocket reports the same address that we bound to
    folly::SocketAddress serverSocketAddress;
    serverSocket->getAddress(&serverSocketAddress);
    ASSERT_EQ(boundAddress, serverSocketAddress);

    // Make sure the socket works
    serverSocketSanityTest(serverSocket.get());
  }
}

TEST(AsyncSocketTest, UnixDomainSocketTest) {
  EventBase eventBase;

  // Create a server socket
  std::shared_ptr<AsyncServerSocket> serverSocket(
      AsyncServerSocket::newSocket(&eventBase));
  string path(1, 0);
  path.append(folly::to<string>("/anonymous", folly::Random::rand64()));
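  // A path that begins with a NUL byte selects the Linux abstract socket
  // namespace, so the test leaves nothing behind on the filesystem.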
  folly::SocketAddress serverAddress;
  serverAddress.setFromPath(path);
  serverSocket->bind(serverAddress);
  serverSocket->listen(16);

  // Add a callback to accept one connection then stop the loop
  TestAcceptCallback acceptCallback;
  acceptCallback.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
      });
  acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
    serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
  });
  serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
  serverSocket->startAccepting();

  // Connect to the server socket
  std::shared_ptr<AsyncSocket> socket(
      AsyncSocket::newSocket(&eventBase, serverAddress));

  eventBase.loop();

  // Verify that the server accepted a connection
  ASSERT_EQ(acceptCallback.getEvents()->size(), 3);
  ASSERT_EQ(
      acceptCallback.getEvents()->at(0).type, TestAcceptCallback::TYPE_START);
  ASSERT_EQ(
      acceptCallback.getEvents()->at(1).type, TestAcceptCallback::TYPE_ACCEPT);
  ASSERT_EQ(
      acceptCallback.getEvents()->at(2).type, TestAcceptCallback::TYPE_STOP);
  auto fd = acceptCallback.getEvents()->at(1).fd;

#ifndef _WIN32
  // It is not possible to check whether a socket is already in non-blocking
  // mode on Windows (yes, really).  On other platforms, the accepted
  // connection should already be in non-blocking mode.
  int flags = fcntl(fd.toFd(), F_GETFL, 0);
  ASSERT_EQ(flags & O_NONBLOCK, O_NONBLOCK);
#endif
}

TEST(AsyncSocketTest, ConnectionEventCallbackDefault) {
  EventBase eventBase;
  TestConnectionEventCallback connectionEventCallback;

  // Create a server socket
  std::shared_ptr<AsyncServerSocket> serverSocket(
      AsyncServerSocket::newSocket(&eventBase));
  serverSocket->setConnectionEventCallback(&connectionEventCallback);
  serverSocket->bind(0);
  serverSocket->listen(16);
  folly::SocketAddress serverAddress;
  serverSocket->getAddress(&serverAddress);

  // Add a callback to accept one connection then stop the loop
  TestAcceptCallback acceptCallback;
  acceptCallback.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
      });
  acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
    serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
  });
  serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
  serverSocket->startAccepting();

  // Connect to the server socket
  std::shared_ptr<AsyncSocket> socket(
      AsyncSocket::newSocket(&eventBase, serverAddress));

  eventBase.loop();

  // Validate the connection event counters
  ASSERT_EQ(connectionEventCallback.getConnectionAccepted(), 1);
  ASSERT_EQ(connectionEventCallback.getConnectionAcceptedError(), 0);
  ASSERT_EQ(connectionEventCallback.getConnectionDropped(), 0);
  ASSERT_EQ(
      connectionEventCallback.getConnectionEnqueuedForAcceptCallback(), 0);
  ASSERT_EQ(connectionEventCallback.getConnectionDequeuedByAcceptCallback(), 0);
  ASSERT_EQ(connectionEventCallback.getBackoffStarted(), 0);
  ASSERT_EQ(connectionEventCallback.getBackoffEnded(), 0);
  ASSERT_EQ(connectionEventCallback.getBackoffError(), 0);
}

TEST(AsyncSocketTest, CallbackInPrimaryEventBase) {
  EventBase eventBase;
  TestConnectionEventCallback connectionEventCallback;

  // Create a server socket
  std::shared_ptr<AsyncServerSocket> serverSocket(
      AsyncServerSocket::newSocket(&eventBase));
  serverSocket->setConnectionEventCallback(&connectionEventCallback);
  serverSocket->bind(0);
  serverSocket->listen(16);
  folly::SocketAddress serverAddress;
  serverSocket->getAddress(&serverAddress);

  // Add a callback to accept one connection then stop the loop
  TestAcceptCallback acceptCallback;
  acceptCallback.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
      });
  acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
    serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
  });
  bool acceptStartedFlag{false};
  acceptCallback.setAcceptStartedFn(
      [&acceptStartedFlag]() { acceptStartedFlag = true; });
  bool acceptStoppedFlag{false};
  acceptCallback.setAcceptStoppedFn(
      [&acceptStoppedFlag]() { acceptStoppedFlag = true; });
  serverSocket->addAcceptCallback(&acceptCallback, nullptr);
  serverSocket->startAccepting();

  // Connect to the server socket
  std::shared_ptr<AsyncSocket> socket(
      AsyncSocket::newSocket(&eventBase, serverAddress));

  eventBase.loop();

  ASSERT_TRUE(acceptStartedFlag);
  ASSERT_TRUE(acceptStoppedFlag);
  // Validate the connection event counters
  ASSERT_EQ(connectionEventCallback.getConnectionAccepted(), 1);
  ASSERT_EQ(connectionEventCallback.getConnectionAcceptedError(), 0);
  ASSERT_EQ(connectionEventCallback.getConnectionDropped(), 0);
  ASSERT_EQ(
      connectionEventCallback.getConnectionEnqueuedForAcceptCallback(), 0);
  ASSERT_EQ(connectionEventCallback.getConnectionDequeuedByAcceptCallback(), 0);
  ASSERT_EQ(connectionEventCallback.getBackoffStarted(), 0);
  ASSERT_EQ(connectionEventCallback.getBackoffEnded(), 0);
  ASSERT_EQ(connectionEventCallback.getBackoffError(), 0);
}

TEST(AsyncSocketTest, CallbackInSecondaryEventBase) {
  EventBase eventBase;
  TestConnectionEventCallback connectionEventCallback;

  // Create a server socket
  std::shared_ptr<AsyncServerSocket> serverSocket(
      AsyncServerSocket::newSocket(&eventBase));
  serverSocket->setConnectionEventCallback(&connectionEventCallback);
  serverSocket->bind(0);
  serverSocket->listen(16);
  SocketAddress serverAddress;
  serverSocket->getAddress(&serverAddress);

  // Add a callback to accept one connection then stop the loop
  TestAcceptCallback acceptCallback;
  ScopedEventBaseThread cobThread("ioworker_test");
  acceptCallback.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const SocketAddress& /* addr */) {
        eventBase.runInEventBaseThread([&] {
          serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
        });
      });
  acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
    eventBase.runInEventBaseThread(
        [&] { serverSocket->removeAcceptCallback(&acceptCallback, nullptr); });
  });
  std::atomic<bool> acceptStartedFlag{false};
  acceptCallback.setAcceptStartedFn([&]() { acceptStartedFlag = true; });
  Baton<> acceptStoppedFlag;
  acceptCallback.setAcceptStoppedFn([&]() { acceptStoppedFlag.post(); });
  serverSocket->addAcceptCallback(&acceptCallback, cobThread.getEventBase());
  serverSocket->startAccepting();
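  // Because the accept callback runs on the cobThread's EventBase, accepted
  // connections are handed off through AsyncServerSocket's internal message
  // queue, which is why the enqueued/dequeued counters are expected to be 1
  // below (they stay 0 when the callback runs on the primary EventBase).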

  // Connect to the server socket
  std::shared_ptr<AsyncSocket> socket(
      AsyncSocket::newSocket(&eventBase, serverAddress));

  eventBase.loop();

  ASSERT_TRUE(acceptStoppedFlag.try_wait_for(std::chrono::seconds(1)));
  ASSERT_TRUE(acceptStartedFlag);
  // Validate the connection event counters
  ASSERT_EQ(connectionEventCallback.getConnectionAccepted(), 1);
  ASSERT_EQ(connectionEventCallback.getConnectionAcceptedError(), 0);
  ASSERT_EQ(connectionEventCallback.getConnectionDropped(), 0);
  ASSERT_EQ(
      connectionEventCallback.getConnectionEnqueuedForAcceptCallback(), 1);
  ASSERT_EQ(connectionEventCallback.getConnectionDequeuedByAcceptCallback(), 1);
  ASSERT_EQ(connectionEventCallback.getBackoffStarted(), 0);
  ASSERT_EQ(connectionEventCallback.getBackoffEnded(), 0);
  ASSERT_EQ(connectionEventCallback.getBackoffError(), 0);
}

/**
 * Test AsyncServerSocket::getNumPendingMessagesInQueue()
 */
TEST(AsyncSocketTest, NumPendingMessagesInQueue) {
  EventBase eventBase;

  // Counter of how many connections have been accepted
  int count = 0;

  // Create a server socket
  auto serverSocket(AsyncServerSocket::newSocket(&eventBase));
  serverSocket->bind(0);
  serverSocket->listen(16);
  folly::SocketAddress serverAddress;
  serverSocket->getAddress(&serverAddress);

  // Add a callback to accept connections
  TestAcceptCallback acceptCallback;
  folly::ScopedEventBaseThread cobThread("ioworker_test");
  acceptCallback.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        count++;
        eventBase.runInEventBaseThreadAndWait([&] {
          ASSERT_EQ(4 - count, serverSocket->getNumPendingMessagesInQueue());
        });
        if (count == 4) {
          eventBase.runInEventBaseThread([&] {
            serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
          });
        }
      });
  acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
    eventBase.runInEventBaseThread(
        [&] { serverSocket->removeAcceptCallback(&acceptCallback, nullptr); });
  });
  serverSocket->addAcceptCallback(&acceptCallback, cobThread.getEventBase());
  serverSocket->startAccepting();

  // Connect 4 clients to the server socket, creating 4 pending connections
  auto socket1(AsyncSocket::newSocket(&eventBase, serverAddress));
  auto socket2(AsyncSocket::newSocket(&eventBase, serverAddress));
  auto socket3(AsyncSocket::newSocket(&eventBase, serverAddress));
  auto socket4(AsyncSocket::newSocket(&eventBase, serverAddress));

  eventBase.loop();
  ASSERT_EQ(4, count);
}

/**
 * Test AsyncTransport::BufferCallback
 */
TEST(AsyncSocketTest, BufferTest) {
  TestServer server(false, 1024 * 1024);

  EventBase evb;
  SocketOptionMap option{{{SOL_SOCKET, SO_SNDBUF}, 128}};
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30, option);

  char buf[100 * 1024];
  memset(buf, 'c', sizeof(buf));
  WriteCallback wcb;
  BufferCallback bcb(socket.get(), sizeof(buf));
  socket->setBufferCallback(&bcb);
  socket->write(&wcb, buf, sizeof(buf), WriteFlags::NONE);
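  // With SO_SNDBUF squeezed down to 128 bytes, the 100KB write cannot
  // complete in one shot; AsyncSocket buffers the remainder and should notify
  // the BufferCallback first that data is buffered and later that the buffer
  // has drained, which is what hasBuffered()/hasBufferCleared() check below.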

  std::thread t1([&]() { server.verifyConnection(buf, sizeof(buf)); });

  evb.loop();
  ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
  ASSERT_EQ(wcb.state, STATE_SUCCEEDED);

  ASSERT_TRUE(bcb.hasBuffered());
  ASSERT_TRUE(bcb.hasBufferCleared());

  socket->close();

  ASSERT_TRUE(socket->isClosedBySelf());
  ASSERT_FALSE(socket->isClosedByPeer());

  t1.join();
}

TEST(AsyncSocketTest, BufferTestChain) {
  TestServer server(false, 1024 * 1024);

  EventBase evb;
  SocketOptionMap option{{{SOL_SOCKET, SO_SNDBUF}, 128}};
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30, option);

  char buf1[100 * 1024];
  memset(buf1, 'c', sizeof(buf1));
  char buf2[100 * 1024];
  memset(buf2, 'f', sizeof(buf2));

  auto buf = folly::IOBuf::copyBuffer(buf1, sizeof(buf1));
  buf->appendToChain(folly::IOBuf::copyBuffer(buf2, sizeof(buf2)));
  ASSERT_EQ(sizeof(buf1) + sizeof(buf2), buf->computeChainDataLength());

  BufferCallback bcb(socket.get(), buf->computeChainDataLength());
  socket->setBufferCallback(&bcb);

  WriteCallback wcb;
  socket->writeChain(&wcb, buf->clone(), WriteFlags::NONE);

  std::thread t1([&]() {
    buf->coalesce();
    server.verifyConnection(
        reinterpret_cast<const char*>(buf->data()), buf->length());
  });

  evb.loop();
  ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
  ASSERT_EQ(wcb.state, STATE_SUCCEEDED);

  ASSERT_TRUE(bcb.hasBuffered());
  ASSERT_TRUE(bcb.hasBufferCleared());

  socket->close();

  ASSERT_TRUE(socket->isClosedBySelf());
  ASSERT_FALSE(socket->isClosedByPeer());
  t1.join();
}

TEST(AsyncSocketTest, BufferCallbackKill) {
  TestServer server(false, 1024 * 1024);

  EventBase evb;
  SocketOptionMap option{{{SOL_SOCKET, SO_SNDBUF}, 128}};
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30, option);
  evb.loopOnce();

  char buf[100 * 1024];
  memset(buf, 'c', sizeof(buf));
  BufferCallback bcb(socket.get(), sizeof(buf));
  socket->setBufferCallback(&bcb);
  WriteCallback wcb;
  wcb.successCallback = [&] {
    ASSERT_TRUE(socket.unique());
    socket.reset();
  };

  // This will trigger AsyncSocket::handleWrite,
  // which calls WriteCallback::writeSuccess,
  // which calls wcb.successCallback above,
  // which tries to delete socket
  // Then, the socket will also try to use this BufferCallback
  // And that should crash us, if there is no DestructorGuard on the stack
  socket->write(&wcb, buf, sizeof(buf), WriteFlags::NONE);

  std::thread t1([&]() { server.verifyConnection(buf, sizeof(buf)); });

  evb.loop();
  ASSERT_EQ(ccb.state, STATE_SUCCEEDED);

  t1.join();
}

#if FOLLY_ALLOW_TFO
TEST(AsyncSocketTest, ConnectTFO) {
  if (!folly::test::isTFOAvailable()) {
    GTEST_SKIP() << "TFO not supported.";
  }

  // Start listening on a local port
  TestServer server(true);

  // Connect using an AsyncSocket
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  socket->enableTFO();
  ConnCallback cb;
  socket->connect(&cb, server.getAddress(), 30);
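  // With TFO enabled, connect() completes without sending a SYN; the actual
  // connection attempt is deferred until the first write, which can then
  // carry data in the SYN (hence the "Should trigger the connect" write
  // below).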

  std::array<uint8_t, 128> buf;
  memset(buf.data(), 'a', buf.size());

  std::array<uint8_t, 3> readBuf;
  auto sendBuf = IOBuf::copyBuffer("hey");

  std::thread t([&] {
    auto acceptedSocket = server.accept();
    acceptedSocket->write(buf.data(), buf.size());
    acceptedSocket->flush();
    acceptedSocket->readAll(readBuf.data(), readBuf.size());
    acceptedSocket->close();
  });

  evb.loop();

  ASSERT_EQ(cb.state, STATE_SUCCEEDED);
  EXPECT_LE(0, socket->getConnectTime().count());
  EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
  EXPECT_TRUE(socket->getTFOAttempted());

  // Should trigger the connect
  WriteCallback write;
  ReadCallback rcb;
  socket->writeChain(&write, sendBuf->clone());
  socket->setReadCB(&rcb);
  evb.loop();

  t.join();

  EXPECT_EQ(STATE_SUCCEEDED, write.state);
  EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));
  EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
  ASSERT_EQ(1, rcb.buffers.size());
  ASSERT_EQ(sizeof(buf), rcb.buffers[0].length);
  EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
  EXPECT_EQ(socket->getTFOFinished(), socket->getTFOSucceded());
}

TEST(AsyncSocketTest, ConnectTFOSupplyEarlyReadCB) {
  if (!folly::test::isTFOAvailable()) {
    GTEST_SKIP() << "TFO not supported.";
  }

  // Start listening on a local port
  TestServer server(true);

  // Connect using an AsyncSocket
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  socket->enableTFO();
  ConnCallback cb;
  socket->connect(&cb, server.getAddress(), 30);
  ReadCallback rcb;
  socket->setReadCB(&rcb);

  std::array<uint8_t, 128> buf;
  memset(buf.data(), 'a', buf.size());

  std::array<uint8_t, 3> readBuf;
  auto sendBuf = IOBuf::copyBuffer("hey");

  std::thread t([&] {
    auto acceptedSocket = server.accept();
    acceptedSocket->write(buf.data(), buf.size());
    acceptedSocket->flush();
    acceptedSocket->readAll(readBuf.data(), readBuf.size());
    acceptedSocket->close();
  });

  evb.loop();

  ASSERT_EQ(cb.state, STATE_SUCCEEDED);
  EXPECT_LE(0, socket->getConnectTime().count());
  EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
  EXPECT_TRUE(socket->getTFOAttempted());

  // Should trigger the connect
  WriteCallback write;
  socket->writeChain(&write, sendBuf->clone());
  evb.loop();

  t.join();

  EXPECT_EQ(STATE_SUCCEEDED, write.state);
  EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));
  EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
  ASSERT_EQ(1, rcb.buffers.size());
  ASSERT_EQ(sizeof(buf), rcb.buffers[0].length);
  EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
  EXPECT_EQ(socket->getTFOFinished(), socket->getTFOSucceded());
}

/**
 * Test connecting to a server that isn't listening
 */
TEST(AsyncSocketTest, ConnectRefusedImmediatelyTFO) {
  EventBase evb;

  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);

  socket->enableTFO();

  // Hopefully nothing is actually listening on this address
  folly::SocketAddress addr("::1", 65535);
  ConnCallback cb;
  socket->connect(&cb, addr, 30);

  evb.loop();

  WriteCallback write1;
  // Trigger the connect if TFO attempt is supported.
  socket->writeChain(&write1, IOBuf::copyBuffer("hey"));
  WriteCallback write2;
  socket->writeChain(&write2, IOBuf::copyBuffer("hey"));
  evb.loop();
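  // If the kernel never finished the TFO attempt, the first write falls back
  // to a plain connect() that gets refused, so write1 fails; if the TFO
  // attempt did finish, write1 may have been accepted locally even though the
  // connection is ultimately refused, in which case TFO must not be reported
  // as succeeded.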

  if (!socket->getTFOFinished()) {
    EXPECT_EQ(STATE_FAILED, write1.state);
  } else {
    EXPECT_EQ(STATE_SUCCEEDED, write1.state);
    EXPECT_FALSE(socket->getTFOSucceded());
  }

  EXPECT_EQ(STATE_FAILED, write2.state);

  EXPECT_EQ(STATE_SUCCEEDED, cb.state);
  EXPECT_LE(0, socket->getConnectTime().count());
  EXPECT_EQ(std::chrono::milliseconds(30), socket->getConnectTimeout());
  EXPECT_TRUE(socket->getTFOAttempted());
}

/**
 * Test calling closeNow() immediately after connecting.
 */
TEST(AsyncSocketTest, ConnectWriteAndCloseNowTFO) {
  TestServer server(true);

  // connect()
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  socket->enableTFO();

  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30);

  // write()
  std::array<char, 128> buf;
  memset(buf.data(), 'a', buf.size());

  // close()
  socket->closeNow();

  // Loop, although there shouldn't be anything to do.
  evb.loop();

  ASSERT_EQ(ccb.state, STATE_SUCCEEDED);

  ASSERT_TRUE(socket->isClosedBySelf());
  ASSERT_FALSE(socket->isClosedByPeer());
}

/**
 * Test calling close() immediately after connect()
 */
TEST(AsyncSocketTest, ConnectAndCloseTFO) {
  TestServer server(true);

  // Connect using an AsyncSocket
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  socket->enableTFO();

  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30);

  socket->close();

  // Loop, although there shouldn't be anything to do.
  evb.loop();

  // With TFO enabled, connect() is considered complete immediately (no SYN is
  // sent until the first write), so the callback reports success even though
  // we closed the socket right away
  ASSERT_EQ(ccb.state, STATE_SUCCEEDED);

  ASSERT_TRUE(socket->isClosedBySelf());
  ASSERT_FALSE(socket->isClosedByPeer());
}

class MockAsyncTFOSocket : public AsyncSocket {
 public:
  using UniquePtr = std::unique_ptr<MockAsyncTFOSocket, Destructor>;

  explicit MockAsyncTFOSocket(EventBase* evb) : AsyncSocket(evb) {}

  MOCK_METHOD(
      ssize_t,
      tfoSendMsg,
      (NetworkSocket fd, struct msghdr* msg, int msg_flags));
};

TEST(AsyncSocketTest, TestTFOUnsupported) {
  TestServer server(true);

  // Connect using an AsyncSocket
  EventBase evb;
  auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
  socket->enableTFO();

  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30);
  ASSERT_EQ(ccb.state, STATE_SUCCEEDED);

  ReadCallback rcb;
  socket->setReadCB(&rcb);

  EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
      .WillOnce(SetErrnoAndReturn(EOPNOTSUPP, -1));
  WriteCallback write;
  auto sendBuf = IOBuf::copyBuffer("hey");
  socket->writeChain(&write, sendBuf->clone());
  EXPECT_EQ(STATE_WAITING, write.state);
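  // tfoSendMsg() failing with EOPNOTSUPP should make AsyncSocket fall back to
  // a regular connect(); the write stays pending (STATE_WAITING) until that
  // fallback connection completes and the data can be flushed.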

  std::array<uint8_t, 128> buf;
  memset(buf.data(), 'a', buf.size());

  std::array<uint8_t, 3> readBuf;

  std::thread t([&] {
    std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
    acceptedSocket->write(buf.data(), buf.size());
    acceptedSocket->flush();
    acceptedSocket->readAll(readBuf.data(), readBuf.size());
    acceptedSocket->close();
  });

  evb.loop();

  t.join();
  EXPECT_EQ(STATE_SUCCEEDED, ccb.state);
  EXPECT_EQ(STATE_SUCCEEDED, write.state);

  EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));
  EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
  ASSERT_EQ(1, rcb.buffers.size());
  ASSERT_EQ(sizeof(buf), rcb.buffers[0].length);
  EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
  EXPECT_EQ(socket->getTFOFinished(), socket->getTFOSucceded());
}

TEST(AsyncSocketTest, ConnectRefusedDelayedTFO) {
  EventBase evb;

  auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
  socket->enableTFO();

  // Hopefully nothing is listening on this address, so this connect fails
  folly::SocketAddress fakeAddr("127.0.0.1", 65535);
  EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
      .WillOnce(Invoke([&](NetworkSocket fd, struct msghdr*, int) {
        sockaddr_storage addr;
        auto len = fakeAddr.getAddress(&addr);
        auto ret = netops::connect(fd, (const struct sockaddr*)&addr, len);
        LOG(INFO) << "connecting the socket " << fd << " : " << ret << " : "
                  << errno;
        return ret;
      }));

  // Hopefully nothing is actually listening on this address
  ConnCallback cb;
  socket->connect(&cb, fakeAddr, 30);

  WriteCallback write1;
  // Trigger the connect if TFO attempt is supported.
  socket->writeChain(&write1, IOBuf::copyBuffer("hey"));

  if (socket->getTFOFinished()) {
    // This test is useless now.
    return;
  }
  WriteCallback write2;
  // Trigger the connect if TFO attempt is supported.
  socket->writeChain(&write2, IOBuf::copyBuffer("hey"));
  evb.loop();

  EXPECT_EQ(STATE_FAILED, write1.state);
  EXPECT_EQ(STATE_FAILED, write2.state);
  EXPECT_FALSE(socket->getTFOSucceded());

  EXPECT_EQ(STATE_SUCCEEDED, cb.state);
  EXPECT_LE(0, socket->getConnectTime().count());
  EXPECT_EQ(std::chrono::milliseconds(30), socket->getConnectTimeout());
  EXPECT_TRUE(socket->getTFOAttempted());
}

TEST(AsyncSocketTest, TestTFOUnsupportedTimeout) {
  // Try connecting to server that won't respond.
  //
  // This depends somewhat on the network where this test is run.
  // Hopefully this IP will be routable but unresponsive.
  // (Alternatively, we could try listening on a local raw socket, but that
  // normally requires root privileges.)
  auto host = SocketAddressTestHelper::isIPv6Enabled()
      ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv6
      : SocketAddressTestHelper::isIPv4Enabled()
      ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv4
      : nullptr;
  SocketAddress addr(host, 65535);

  // Connect using an AsyncSocket
  EventBase evb;
  auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
  socket->enableTFO();

  ConnCallback ccb;
  // Set a very small timeout
  socket->connect(&ccb, addr, 1);
  EXPECT_EQ(STATE_SUCCEEDED, ccb.state);

  ReadCallback rcb;
  socket->setReadCB(&rcb);

  EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
      .WillOnce(SetErrnoAndReturn(EOPNOTSUPP, -1));
  WriteCallback write;
  socket->writeChain(&write, IOBuf::copyBuffer("hey"));

  evb.loop();

  EXPECT_EQ(STATE_FAILED, write.state);
}

TEST(AsyncSocketTest, TestTFOFallbackToConnect) {
  TestServer server(true);

  // Connect using an AsyncSocket
  EventBase evb;
  auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
  socket->enableTFO();

  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30);
  ASSERT_EQ(ccb.state, STATE_SUCCEEDED);

  ReadCallback rcb;
  socket->setReadCB(&rcb);

  EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
      .WillOnce(Invoke([&](NetworkSocket fd, struct msghdr*, int) {
        sockaddr_storage addr;
        auto len = server.getAddress().getAddress(&addr);
        return netops::connect(fd, (const struct sockaddr*)&addr, len);
      }));
  WriteCallback write;
  auto sendBuf = IOBuf::copyBuffer("hey");
  socket->writeChain(&write, sendBuf->clone());
  EXPECT_EQ(STATE_WAITING, write.state);

  std::array<uint8_t, 128> buf;
  memset(buf.data(), 'a', buf.size());

  std::array<uint8_t, 3> readBuf;

  std::thread t([&] {
    std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
    acceptedSocket->write(buf.data(), buf.size());
    acceptedSocket->flush();
    acceptedSocket->readAll(readBuf.data(), readBuf.size());
    acceptedSocket->close();
  });

  evb.loop();

  t.join();
  EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));

  EXPECT_EQ(STATE_SUCCEEDED, ccb.state);
  EXPECT_EQ(STATE_SUCCEEDED, write.state);

  EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
  ASSERT_EQ(1, rcb.buffers.size());
  ASSERT_EQ(buf.size(), rcb.buffers[0].length);
  EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
}

TEST(AsyncSocketTest, TestTFOFallbackTimeout) {
  // Try connecting to server that won't respond.
  //
  // This depends somewhat on the network where this test is run.
  // Hopefully this IP will be routable but unresponsive.
  // (Alternatively, we could try listening on a local raw socket, but that
  // normally requires root privileges.)
  auto host = SocketAddressTestHelper::isIPv6Enabled()
      ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv6
      : SocketAddressTestHelper::isIPv4Enabled()
      ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv4
      : nullptr;
  SocketAddress addr(host, 65535);

  // Connect using an AsyncSocket
  EventBase evb;
  auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
  socket->enableTFO();

  ConnCallback ccb;
  // Set a very small timeout
  socket->connect(&ccb, addr, 1);
  EXPECT_EQ(STATE_SUCCEEDED, ccb.state);

  ReadCallback rcb;
  socket->setReadCB(&rcb);

  EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
      .WillOnce(Invoke([&](NetworkSocket fd, struct msghdr*, int) {
        sockaddr_storage addr2;
        auto len = addr.getAddress(&addr2);
        return netops::connect(fd, (const struct sockaddr*)&addr2, len);
      }));
  WriteCallback write;
  socket->writeChain(&write, IOBuf::copyBuffer("hey"));

  evb.loop();

  EXPECT_EQ(STATE_FAILED, write.state);
}

TEST(AsyncSocketTest, TestTFOEagain) {
  TestServer server(true);

  // Connect using an AsyncSocket
  EventBase evb;
  auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
  socket->enableTFO();

  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30);

  EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
      .WillOnce(SetErrnoAndReturn(EAGAIN, -1));
  WriteCallback write;
  socket->writeChain(&write, IOBuf::copyBuffer("hey"));

  evb.loop();

  EXPECT_EQ(STATE_SUCCEEDED, ccb.state);
  EXPECT_EQ(STATE_FAILED, write.state);
}

// Send a large amount of data in the first write; it will definitely not
// fit into the MSS.
TEST(AsyncSocketTest, ConnectTFOWithBigData) {
  if (!folly::test::isTFOAvailable()) {
    GTEST_SKIP() << "TFO not supported.";
  }

  // Start listening on a local port
  TestServer server(true);

  // Connect using an AsyncSocket
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  socket->enableTFO();
  ConnCallback cb;
  socket->connect(&cb, server.getAddress(), 30);

  std::array<uint8_t, 128> buf;
  memset(buf.data(), 'a', buf.size());

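  // 10KB payload: far larger than what fits in the TFO SYN, so the kernel
  // must complete the handshake before the remainder can be sent. The buffer
  // contents are left uninitialized; only the length matters here.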
  constexpr size_t len = 10 * 1024;
  auto sendBuf = IOBuf::create(len);
  sendBuf->append(len);
  std::array<uint8_t, len> readBuf;

  std::thread t([&] {
    auto acceptedSocket = server.accept();
    acceptedSocket->write(buf.data(), buf.size());
    acceptedSocket->flush();
    acceptedSocket->readAll(readBuf.data(), readBuf.size());
    acceptedSocket->close();
  });

  evb.loop();

  ASSERT_EQ(cb.state, STATE_SUCCEEDED);
  EXPECT_LE(0, socket->getConnectTime().count());
  EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
  EXPECT_TRUE(socket->getTFOAttempted());

  // Should trigger the connect
  WriteCallback write;
  ReadCallback rcb;
  socket->writeChain(&write, sendBuf->clone());
  socket->setReadCB(&rcb);
  evb.loop();

  t.join();

  EXPECT_EQ(STATE_SUCCEEDED, write.state);
  EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));
  EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
  ASSERT_EQ(1, rcb.buffers.size());
  ASSERT_EQ(sizeof(buf), rcb.buffers[0].length);
  EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
  EXPECT_EQ(socket->getTFOFinished(), socket->getTFOSucceded());
}

#endif // FOLLY_ALLOW_TFO

class MockEvbChangeCallback : public AsyncSocket::EvbChangeCallback {
 public:
  MOCK_METHOD(void, evbAttached, (AsyncSocket*));
  MOCK_METHOD(void, evbDetached, (AsyncSocket*));
};

TEST(AsyncSocketTest, EvbCallbacks) {
  auto cb = std::make_unique<MockEvbChangeCallback>();
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);

  InSequence seq;
  EXPECT_CALL(*cb, evbDetached(socket.get())).Times(1);
  EXPECT_CALL(*cb, evbAttached(socket.get())).Times(1);

  socket->setEvbChangedCallback(std::move(cb));
  socket->detachEventBase();
  socket->attachEventBase(&evb);
}

TEST(AsyncSocketTest, TestEvbDetachWtRegisteredIOHandlers) {
  // Start listening on a local port
  TestServer server;

  // Connect using an AsyncSocket
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  ConnCallback cb;
  socket->connect(&cb, server.getAddress(), 30);

  evb.loop();

  ASSERT_EQ(cb.state, STATE_SUCCEEDED);
  EXPECT_LE(0, socket->getConnectTime().count());
  EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));

  // After the ioHandlers are registered, we should still be able to
  // detach/attach
  ReadCallback rcb;
  socket->setReadCB(&rcb);

  auto cbEvbChg = std::make_unique<MockEvbChangeCallback>();
  InSequence seq;
  EXPECT_CALL(*cbEvbChg, evbDetached(socket.get())).Times(1);
  EXPECT_CALL(*cbEvbChg, evbAttached(socket.get())).Times(1);

  socket->setEvbChangedCallback(std::move(cbEvbChg));
  EXPECT_TRUE(socket->isDetachable());
  socket->detachEventBase();
  socket->attachEventBase(&evb);

  socket->close();
}

TEST(AsyncSocketTest, TestEvbDetachThenClose) {
  // Start listening on a local port
  TestServer server;

  // Connect an AsyncSocket to the server
  EventBase evb;
  auto socket = AsyncSocket::newSocket(&evb);
  ConnCallback cb;
  socket->connect(&cb, server.getAddress(), 30);

  evb.loop();

  ASSERT_EQ(cb.state, STATE_SUCCEEDED);
  EXPECT_LE(0, socket->getConnectTime().count());
  EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));

  // After the ioHandlers are registered, we should still be able to
  // detach/attach
  ReadCallback rcb;
  socket->setReadCB(&rcb);

  auto cbEvbChg = std::make_unique<MockEvbChangeCallback>();
  InSequence seq;
  EXPECT_CALL(*cbEvbChg, evbDetached(socket.get())).Times(1);

  socket->setEvbChangedCallback(std::move(cbEvbChg));

  // Should be possible to destroy/call closeNow() without an attached EventBase
  EXPECT_TRUE(socket->isDetachable());
  socket->detachEventBase();
  socket.reset();
}

TEST(AsyncSocket, BytesWrittenWithMove) {
  TestServer server;

  EventBase evb;
  auto socket1 = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  ConnCallback ccb;
  socket1->connect(&ccb, server.getAddress(), 30);
  std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();

  EXPECT_EQ(0, socket1->getRawBytesWritten());
  std::vector<uint8_t> wbuf(128, 'a');
  WriteCallback wcb;
  socket1->write(&wcb, wbuf.data(), wbuf.size());
  evb.loopOnce();
  ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
  EXPECT_EQ(wbuf.size(), socket1->getRawBytesWritten());
  EXPECT_EQ(wbuf.size(), socket1->getAppBytesWritten());

  auto socket2 = AsyncSocket::UniquePtr(new AsyncSocket(std::move(socket1)));
  EXPECT_EQ(wbuf.size(), socket2->getRawBytesWritten());
  EXPECT_EQ(wbuf.size(), socket2->getAppBytesWritten());
}

#ifdef FOLLY_HAVE_MSG_ERRQUEUE
struct AsyncSocketErrMessageCallbackTestParams {
  folly::Optional<int> resetCallbackAfter;
  folly::Optional<int> closeSocketAfter;
  int gotTimestampExpected{0};
  int gotByteSeqExpected{0};
};

class AsyncSocketErrMessageCallbackTest
    : public ::testing::TestWithParam<AsyncSocketErrMessageCallbackTestParams> {
 public:
  static std::vector<AsyncSocketErrMessageCallbackTestParams>
  getTestingValues() {
    std::vector<AsyncSocketErrMessageCallbackTestParams> vals;
    // each socket err message triggers two socket callbacks:
    //   (1) timestamp callback
    //   (2) byteseq callback

    // reset callback cases
    // resetting the callback should prevent any further callbacks
    {
      AsyncSocketErrMessageCallbackTestParams params;
      params.resetCallbackAfter = 1;
      params.gotTimestampExpected = 1;
      params.gotByteSeqExpected = 0;
      vals.push_back(params);
    }
    {
      AsyncSocketErrMessageCallbackTestParams params;
      params.resetCallbackAfter = 2;
      params.gotTimestampExpected = 1;
      params.gotByteSeqExpected = 1;
      vals.push_back(params);
    }
    {
      AsyncSocketErrMessageCallbackTestParams params;
      params.resetCallbackAfter = 3;
      params.gotTimestampExpected = 2;
      params.gotByteSeqExpected = 1;
      vals.push_back(params);
    }
    {
      AsyncSocketErrMessageCallbackTestParams params;
      params.resetCallbackAfter = 4;
      params.gotTimestampExpected = 2;
      params.gotByteSeqExpected = 2;
      vals.push_back(params);
    }

    // close socket cases
    // closing the socket will prevent callbacks after the current err message
    // callbacks (both timestamp and byteseq) are completed
    {
      AsyncSocketErrMessageCallbackTestParams params;
      params.closeSocketAfter = 1;
      params.gotTimestampExpected = 1;
      params.gotByteSeqExpected = 1;
      vals.push_back(params);
    }
    {
      AsyncSocketErrMessageCallbackTestParams params;
      params.closeSocketAfter = 2;
      params.gotTimestampExpected = 1;
      params.gotByteSeqExpected = 1;
      vals.push_back(params);
    }
    {
      AsyncSocketErrMessageCallbackTestParams params;
      params.closeSocketAfter = 3;
      params.gotTimestampExpected = 2;
      params.gotByteSeqExpected = 2;
      vals.push_back(params);
    }
    {
      AsyncSocketErrMessageCallbackTestParams params;
      params.closeSocketAfter = 4;
      params.gotTimestampExpected = 2;
      params.gotByteSeqExpected = 2;
      vals.push_back(params);
    }
    return vals;
  }
};

INSTANTIATE_TEST_SUITE_P(
    ErrMessageTests,
    AsyncSocketErrMessageCallbackTest,
    ::testing::ValuesIn(AsyncSocketErrMessageCallbackTest::getTestingValues()));

class TestErrMessageCallback : public folly::AsyncSocket::ErrMessageCallback {
 public:
  TestErrMessageCallback()
      : exception_(folly::AsyncSocketException::UNKNOWN, "none") {}

  void errMessage(const cmsghdr& cmsg) noexcept override {
    if (cmsg.cmsg_level == SOL_SOCKET && cmsg.cmsg_type == SCM_TIMESTAMPING) {
      gotTimestamp_++;
      checkResetCallback();
      checkCloseSocket();
    } else if (
        (cmsg.cmsg_level == SOL_IP && cmsg.cmsg_type == IP_RECVERR) ||
        (cmsg.cmsg_level == SOL_IPV6 && cmsg.cmsg_type == IPV6_RECVERR)) {
      gotByteSeq_++;
      checkResetCallback();
      checkCloseSocket();
    }
  }

  void errMessageError(
      const folly::AsyncSocketException& ex) noexcept override {
    exception_ = ex;
  }

  void checkResetCallback() noexcept {
    if (socket_ != nullptr && resetCallbackAfter_ != -1 &&
        gotTimestamp_ + gotByteSeq_ == resetCallbackAfter_) {
      socket_->setErrMessageCB(nullptr);
    }
  }

  void checkCloseSocket() noexcept {
    if (socket_ != nullptr && closeSocketAfter_ != -1 &&
        gotTimestamp_ + gotByteSeq_ == closeSocketAfter_) {
      socket_->close();
    }
  }

  folly::AsyncSocket* socket_{nullptr};
  folly::AsyncSocketException exception_;
  int gotTimestamp_{0};
  int gotByteSeq_{0};
  int resetCallbackAfter_{-1};
  int closeSocketAfter_{-1};
};

TEST_P(AsyncSocketErrMessageCallbackTest, ErrMessageCallback) {
  TestServer server;

  // connect()
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);

  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30);
  LOG(INFO) << "Client socket fd=" << socket->getNetworkSocket();

  // Let the socket connect.
  evb.loop();

  ASSERT_EQ(ccb.state, STATE_SUCCEEDED);

  // Set a read callback to keep the socket subscribed for event
  // notifications, even though we're not planning to read anything from
  // this side of the connection.
  ReadCallback rcb(1);
  socket->setReadCB(&rcb);

  // Set up timestamp callbacks
  TestErrMessageCallback errMsgCB;
  socket->setErrMessageCB(&errMsgCB);
  ASSERT_EQ(
      socket->getErrMessageCallback(),
      static_cast<folly::AsyncSocket::ErrMessageCallback*>(&errMsgCB));

  // Set the number of error messages after which the socket is closed or
  // the callback is reset
  const auto testParams = GetParam();
  errMsgCB.socket_ = socket.get();
  if (testParams.resetCallbackAfter.has_value()) {
    errMsgCB.resetCallbackAfter_ = testParams.resetCallbackAfter.value();
  }
  if (testParams.closeSocketAfter.has_value()) {
    errMsgCB.closeSocketAfter_ = testParams.closeSocketAfter.value();
  }

  // Enable timestamp notifications
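  // OPT_ID tags timestamps with a byte-stream based identifier, OPT_TSONLY
  // delivers them without the original packet payload, SOFTWARE requests
  // software timestamps, and TX_SCHED adds a timestamp when a segment enters
  // the packet scheduler.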
  ASSERT_NE(socket->getNetworkSocket(), NetworkSocket());
  int flags = folly::netops::SOF_TIMESTAMPING_OPT_ID |
      folly::netops::SOF_TIMESTAMPING_OPT_TSONLY |
      folly::netops::SOF_TIMESTAMPING_SOFTWARE |
      folly::netops::SOF_TIMESTAMPING_OPT_CMSG |
      folly::netops::SOF_TIMESTAMPING_TX_SCHED;
  SocketOptionKey tstampingOpt = {SOL_SOCKET, SO_TIMESTAMPING};
  EXPECT_EQ(tstampingOpt.apply(socket->getNetworkSocket(), flags), 0);

  // write()
  std::vector<uint8_t> wbuf(128, 'a');
  WriteCallback wcb;
  // Send two packets to get two EOM notifications
  socket->write(&wcb, wbuf.data(), wbuf.size() / 2);
  socket->write(&wcb, wbuf.data() + wbuf.size() / 2, wbuf.size() / 2);

  // Accept the connection.
  std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
  LOG(INFO) << "Server socket fd=" << acceptedSocket->getNetworkSocket();

  // Loop
  evb.loopOnce();
  ASSERT_EQ(wcb.state, STATE_SUCCEEDED);

  // Check that we can read the data that was written to the socket
  std::vector<uint8_t> rbuf(wbuf.size(), 0);
  uint32_t bytesRead = acceptedSocket->readAll(rbuf.data(), rbuf.size());
  ASSERT_EQ(bytesRead, wbuf.size());
  ASSERT_TRUE(std::equal(wbuf.begin(), wbuf.end(), rbuf.begin()));

  // Close both sockets
  acceptedSocket->close();
  socket->close();

  ASSERT_TRUE(socket->isClosedBySelf());
  ASSERT_FALSE(socket->isClosedByPeer());

  // Check for the timestamp notifications.
  ASSERT_EQ(
      errMsgCB.exception_.getType(), folly::AsyncSocketException::UNKNOWN);
  ASSERT_EQ(errMsgCB.gotByteSeq_, testParams.gotByteSeqExpected);
  ASSERT_EQ(errMsgCB.gotTimestamp_, testParams.gotTimestampExpected);
}

#endif // FOLLY_HAVE_MSG_ERRQUEUE

#if FOLLY_HAVE_SO_TIMESTAMPING

class AsyncSocketByteEventTest : public ::testing::Test {
 protected:
  using MockDispatcher = ::testing::NiceMock<netops::test::MockDispatcher>;
  using TestObserver = MockAsyncSocketLegacyLifecycleObserverForByteEvents;
  using ByteEventType = AsyncSocket::ByteEvent::Type;

  /**
   * Components of a client connection to TestServer.
   *
   * Includes EventBase, client's AsyncSocket, and corresponding server socket.
   */
  class ClientConn {
   public:
    /**
     * Call to sendmsg intercepted and recorded by netops::Dispatcher.
     */
    struct SendmsgInvocation {
      // the iovecs in the msghdr
      std::vector<iovec> iovs;

      // WriteFlags encoded in msg_flags
      WriteFlags writeFlagsInMsgFlags{WriteFlags::NONE};

      // WriteFlags encoded in the msghdr's ancillary data
      WriteFlags writeFlagsInAncillary{WriteFlags::NONE};
    };

    explicit ClientConn(
        std::shared_ptr<TestServer> server,
        std::shared_ptr<AsyncSocket> socket = nullptr,
        std::shared_ptr<BlockingSocket> acceptedSocket = nullptr)
        : server_(std::move(server)),
          socket_(std::move(socket)),
          acceptedSocket_(std::move(acceptedSocket)) {
      if (!socket_) {
        socket_ = AsyncSocket::newSocket(&getEventBase());
      } else {
        setReadCb();
      }
      socket_->setOverrideNetOpsDispatcher(netOpsDispatcher_);
      netOpsDispatcher_->forwardToDefaultImpl();
    }

    void connect() {
      CHECK_NOTNULL(socket_.get());
      CHECK_NOTNULL(socket_->getEventBase());
      socket_->connect(&connCb_, server_->getAddress(), 30);
      socket_->getEventBase()->loop();
      ASSERT_EQ(connCb_.state, STATE_SUCCEEDED);
      setReadCb();

      // accept the socket at the server
      acceptedSocket_ = server_->accept();
    }

    void setReadCb() {
      // Due to how libevent works, we currently need to be subscribed to
      // EV_READ events in order to get error messages.
      //
      // TODO(bschlinker): Resolve this with libevent modification.
      // See https://github.com/libevent/libevent/issues/1038 for details.
      socket_->setReadCB(&readCb_);
    }

    void setMockTcpInfoDispatcher(
        std::shared_ptr<MockTcpInfoDispatcher> mockTcpInfoDispatcher) {
      socket_->setOverrideTcpInfoDispatcher(mockTcpInfoDispatcher);
    }

    std::shared_ptr<NiceMock<TestObserver>> attachObserver(
        bool enableByteEvents, bool enablePrewrite = false) {
      auto observer = AsyncSocketByteEventTest::attachObserver(
          socket_.get(), enableByteEvents, enablePrewrite);
      observers_.push_back(observer);
      return observer;
    }

    /**
     * Write to client socket and read at server.
     */
    void writeAtClientReadAtServer(
        const iovec* iov, const size_t count, const WriteFlags writeFlags) {
      CHECK_NOTNULL(socket_.get());
      CHECK_NOTNULL(socket_->getEventBase());

      // read buffer for server
      std::vector<uint8_t> rbuf(iovsToNumBytes(iov, count), 0);
      uint64_t rbufReadBytes = 0;

      // write to the client socket, incrementally read at the server
      WriteCallback wcb;
      socket_->writev(&wcb, iov, count, writeFlags);
      while (wcb.state == STATE_WAITING) {
        socket_->getEventBase()->loopOnce();
        rbufReadBytes += acceptedSocket_->readNoBlock(
            rbuf.data() + rbufReadBytes, rbuf.size() - rbufReadBytes);
      }
      ASSERT_EQ(wcb.state, STATE_SUCCEEDED);

      // finish reading, then compare
      rbufReadBytes += acceptedSocket_->readAll(
          rbuf.data() + rbufReadBytes, rbuf.size() - rbufReadBytes);
      const auto cBuf = iovsToVector(iov, count);
      ASSERT_EQ(rbufReadBytes, cBuf.size());
      ASSERT_TRUE(std::equal(cBuf.begin(), cBuf.end(), rbuf.begin()));
    }

    /**
     * Write to client socket and read at server.
     */
    void writeAtClientReadAtServer(
        const std::vector<uint8_t>& wbuf, const WriteFlags writeFlags) {
      iovec op;
      op.iov_base = const_cast<void*>(static_cast<const void*>(wbuf.data()));
      op.iov_len = wbuf.size();
      writeAtClientReadAtServer(&op, 1, writeFlags);
    }

    /**
     * Write to client socket, echo at server, and wait for echo at client.
     *
     * Waiting for echo at client ensures that we have given opportunity for
     * timestamps to be generated by the kernel.
     */
    void writeAtClientReadAtServerReflectReadAtClient(
        const iovec* iov, const size_t count, const WriteFlags writeFlags) {
      writeAtClientReadAtServer(iov, count, writeFlags);

      // reflect
      const auto wbuf = iovsToVector(iov, count);
      acceptedSocket_->write(wbuf.data(), wbuf.size());
      while (wbuf.size() != readCb_.dataRead()) {
        socket_->getEventBase()->loopOnce();
      }
      readCb_.verifyData(wbuf.data(), wbuf.size());
      readCb_.clearData();
    }

    /**
     * Write at the server and wait for the client to read.
     */
    void writeAtServerReadAtClient(const iovec* iov, const size_t count) {
      const auto wbuf = iovsToVector(iov, count);
      acceptedSocket_->write(wbuf.data(), wbuf.size());
      while (wbuf.size() != readCb_.dataRead()) {
        socket_->getEventBase()->loopOnce();
      }
      readCb_.verifyData(wbuf.data(), wbuf.size());
      readCb_.clearData();
    }

    /**
     * Write to client socket, echo at server, and wait for echo at client.
     *
     * Waiting for echo at client ensures that we have given opportunity for
     * timestamps to be generated by the kernel.
     */
    void writeAtClientReadAtServerReflectReadAtClient(
        const std::vector<uint8_t>& wbuf, const WriteFlags writeFlags) {
      iovec op = {};
      op.iov_base = const_cast<void*>(static_cast<const void*>(wbuf.data()));
      op.iov_len = wbuf.size();
      writeAtClientReadAtServerReflectReadAtClient(&op, 1, writeFlags);
    }

    /**
     * Write directly to the NetworkSocket, bypassing AsyncSocket.
     */
    void writeAtClientDirectlyToNetworkSocket(
        const std::vector<uint8_t>& wbuf) {
      struct msghdr msg = {};
      struct iovec iovec = {};
      iovec.iov_base = (void*)wbuf.data();
      iovec.iov_len = wbuf.size();

      msg.msg_name = nullptr;
      msg.msg_namelen = 0;
      msg.msg_iov = &iovec;
      msg.msg_iovlen = 1;
      msg.msg_flags = 0;
      msg.msg_controllen = 0;
      msg.msg_control = nullptr;

      auto ret = netops::Dispatcher::getDefaultInstance()->sendmsg(
          socket_->getNetworkSocket(), &msg, 0);
      ASSERT_EQ(ret, wbuf.size());
    }

    std::shared_ptr<AsyncSocket> getRawSocket() { return socket_; }

    std::shared_ptr<BlockingSocket> getAcceptedSocket() {
      return acceptedSocket_;
    }

    EventBase& getEventBase() {
      static EventBase evb; // use same EventBase for all client sockets
      return evb;
    }

    std::shared_ptr<MockDispatcher> getNetOpsDispatcher() const {
      return netOpsDispatcher_;
    }

    /**
     * Get recorded SendmsgInvocations.
     */
    const std::vector<SendmsgInvocation>& getSendmsgInvocations() {
      return sendmsgInvocations_;
    }

    /**
     * Get the number of successful error queue reads.
     */
    int getErrorQueueReads() { return errorQueueReads_; }

    /**
     * Expect a call to setsockopt with optname SO_TIMESTAMPING.
     */
    void netOpsExpectTimestampingSetSockOpt() {
      // must whitelist other calls
      EXPECT_CALL(*netOpsDispatcher_, setsockopt(_, _, _, _, _))
          .Times(AnyNumber());
      EXPECT_CALL(
          *netOpsDispatcher_, setsockopt(_, SOL_SOCKET, SO_TIMESTAMPING, _, _))
          .Times(1);
    }

    /**
     * Expect NO calls to setsockopt with optname SO_TIMESTAMPING.
     */
    void netOpsExpectNoTimestampingSetSockOpt() {
      // must whitelist other calls
      EXPECT_CALL(*netOpsDispatcher_, setsockopt(_, _, _, _, _))
          .Times(AnyNumber());
      EXPECT_CALL(*netOpsDispatcher_, setsockopt(_, _, SO_TIMESTAMPING, _, _))
          .Times(0);
    }

    /**
     * Expect sendmsg to be called with the passed WriteFlags in ancillary data.
     */
    void netOpsExpectSendmsgWithAncillaryTsFlags(WriteFlags writeFlags) {
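      // the explicit function-pointer cast selects the msghdr* overload of
      // ::getMsgAncillaryTsFlags for use with std::bind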
      auto getMsgAncillaryTsFlags = std::bind(
          (WriteFlags(*)(const struct msghdr* msg)) & ::getMsgAncillaryTsFlags,
          std::placeholders::_1);
      EXPECT_CALL(
          *netOpsDispatcher_,
          sendmsg(_, ResultOf(getMsgAncillaryTsFlags, Eq(writeFlags)), _))
          .WillOnce(DoDefault());
    }

    /**
     * When sendmsg is called, record details and then forward to real sendmsg.
     *
     * This creates a default action.
     */
    void netOpsOnSendmsgRecordIovecsAndFlagsAndFwd() {
      ON_CALL(*netOpsDispatcher_, sendmsg(_, _, _))
          .WillByDefault(::testing::Invoke(
              [this](NetworkSocket s, const msghdr* message, int flags) {
                recordSendmsgInvocation(s, message, flags);
                return netops::Dispatcher::getDefaultInstance()->sendmsg(
                    s, message, flags);
              }));
    }

    /**
     * When recvmsg is called, forward to the real recvmsg and record details
     * on return.
     *
     * This creates a default action.
     */
    void netOpsOnRecvmsg() {
      ON_CALL(*netOpsDispatcher_, recvmsg(_, _, _))
          .WillByDefault(::testing::Invoke(
              [this](NetworkSocket s, msghdr* message, int flags) {
                int ret = netops::Dispatcher::getDefaultInstance()->recvmsg(
                    s, message, flags);
                recordRecvmsgInvocation(s, message, flags, ret);
                return ret;
              }));
    }

    void netOpsVerifyAndClearExpectations() {
      Mock::VerifyAndClearExpectations(netOpsDispatcher_.get());
    }

   private:
    void recordSendmsgInvocation(
        NetworkSocket /* s */, const msghdr* message, int flags) {
      SendmsgInvocation invoc = {};
      invoc.iovs = getMsgIovecs(message);
      invoc.writeFlagsInMsgFlags = msgFlagsToWriteFlags(flags);
      invoc.writeFlagsInAncillary = getMsgAncillaryTsFlags(message);
      sendmsgInvocations_.emplace_back(std::move(invoc));
    }

    void recordRecvmsgInvocation(
        NetworkSocket /* s */,
        msghdr* /* message */,
        int flags,
        int returnValue) {
      if (flags == MSG_ERRQUEUE && returnValue >= 0) {
        errorQueueReads_ += 1;
      }
    }

    // server
    std::shared_ptr<TestServer> server_;

    // managed observers
    std::vector<std::shared_ptr<TestObserver>> observers_;

    // socket components
    ConnCallback connCb_;
    ReadCallback readCb_;
    std::shared_ptr<MockDispatcher> netOpsDispatcher_{
        std::make_shared<MockDispatcher>()};
    std::shared_ptr<AsyncSocket> socket_;

    // accepted socket at server
    std::shared_ptr<BlockingSocket> acceptedSocket_;

    // sendmsg invocations observed
    std::vector<SendmsgInvocation> sendmsgInvocations_;

    // successful error queue reads observed
    int errorQueueReads_{0};
  };

  ClientConn getClientConn() { return ClientConn(server_); }

  /**
   * Static utility functions.
   */

  static std::shared_ptr<NiceMock<TestObserver>> attachObserver(
      AsyncSocket* socket, bool enableByteEvents, bool enablePrewrite = false) {
    AsyncSocket::LegacyLifecycleObserver::Config config = {};
    config.byteEvents = enableByteEvents;
    config.prewrite = enablePrewrite;
    return std::make_shared<NiceMock<TestObserver>>(socket, config);
  }

  static std::vector<uint8_t> getHundredBytesOfData() {
    return std::vector<uint8_t>(
        kOneHundredCharacterString.begin(), kOneHundredCharacterString.end());
  }

  static std::vector<uint8_t> get10KBOfData() {
    std::vector<uint8_t> vec;
    vec.reserve(kOneHundredCharacterString.size() * 100);
    for (auto i = 0; i < 100; i++) {
      vec.insert(
          vec.end(),
          kOneHundredCharacterString.begin(),
          kOneHundredCharacterString.end());
    }
    CHECK_EQ(10000, vec.size());
    return vec;
  }

  static std::vector<uint8_t> get1000KBOfData() {
    std::vector<uint8_t> vec;
    vec.reserve(kOneHundredCharacterString.size() * 10000);
    for (auto i = 0; i < 10000; i++) {
      vec.insert(
          vec.end(),
          kOneHundredCharacterString.begin(),
          kOneHundredCharacterString.end());
    }
    CHECK_EQ(1000000, vec.size());
    return vec;
  }

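  /**
   * Remove TIMESTAMP_WRITE from a WriteFlags set.
   *
   * WRITE ByteEvents are generated locally by AsyncSocket when a write
   * completes, so TIMESTAMP_WRITE is never encoded into sendmsg ancillary
   * data; tests drop it before forming sendmsg expectations.
   */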
  static WriteFlags dropWriteFromFlags(WriteFlags writeFlags) {
    return writeFlags & ~WriteFlags::TIMESTAMP_WRITE;
  }

  static std::vector<iovec> getMsgIovecs(const struct msghdr& msg) {
    std::vector<iovec> iovecs;
    for (size_t i = 0; i < msg.msg_iovlen; i++) {
      iovecs.emplace_back(msg.msg_iov[i]);
    }
    return iovecs;
  }

  static std::vector<iovec> getMsgIovecs(const struct msghdr* msg) {
    return getMsgIovecs(*msg);
  }

  static std::vector<uint8_t> iovsToVector(
      const iovec* iov, const size_t count) {
    std::vector<uint8_t> vec;
    for (size_t i = 0; i < count; i++) {
      if (iov[i].iov_len == 0) {
        continue;
      }
      const auto ptr = reinterpret_cast<uint8_t*>(iov[i].iov_base);
      vec.insert(vec.end(), ptr, ptr + iov[i].iov_len);
    }
    return vec;
  }

  static size_t iovsToNumBytes(const iovec* iov, const size_t count) {
    size_t bytes = 0;
    for (size_t i = 0; i < count; i++) {
      bytes += iov[i].iov_len;
    }
    return bytes;
  }

  std::vector<AsyncSocket::ByteEvent> filterToWriteEvents(
      const std::vector<AsyncSocket::ByteEvent>& input) {
    std::vector<AsyncSocket::ByteEvent> result;
    std::copy_if(
        input.begin(),
        input.end(),
        std::back_inserter(result),
        [](auto& event) {
          return event.type == AsyncSocket::ByteEvent::WRITE;
        });
    return result;
  }

  // server
  std::shared_ptr<TestServer> server_{std::make_shared<TestServer>()};
};

TEST_F(AsyncSocketByteEventTest, MsgFlagsToWriteFlags) {
#ifdef MSG_MORE
  EXPECT_EQ(WriteFlags::CORK, msgFlagsToWriteFlags(MSG_MORE));
#endif // MSG_MORE

#ifdef MSG_EOR
  EXPECT_EQ(WriteFlags::EOR, msgFlagsToWriteFlags(MSG_EOR));
#endif

#ifdef MSG_ZEROCOPY
  EXPECT_EQ(WriteFlags::WRITE_MSG_ZEROCOPY, msgFlagsToWriteFlags(MSG_ZEROCOPY));
#endif

#if defined(MSG_MORE) && defined(MSG_EOR)
  EXPECT_EQ(
      WriteFlags::CORK | WriteFlags::EOR,
      msgFlagsToWriteFlags(MSG_MORE | MSG_EOR));
#endif
}

TEST_F(AsyncSocketByteEventTest, GetMsgAncillaryTsFlags) {
  auto ancillaryDataSize = CMSG_LEN(sizeof(uint32_t));
  auto ancillaryData = reinterpret_cast<char*>(alloca(ancillaryDataSize));

  auto getMsg = [&ancillaryDataSize, &ancillaryData](uint32_t sofFlags) {
    struct msghdr msg = {};
    msg.msg_name = nullptr;
    msg.msg_namelen = 0;
    msg.msg_iov = nullptr;
    msg.msg_iovlen = 0;
    msg.msg_flags = 0;
    msg.msg_controllen = 0;
    msg.msg_control = nullptr;
    if (sofFlags) {
      msg.msg_controllen = ancillaryDataSize;
      msg.msg_control = ancillaryData;
      struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
      CHECK_NOTNULL(cmsg);
      cmsg->cmsg_level = SOL_SOCKET;
      cmsg->cmsg_type = SO_TIMESTAMPING;
      cmsg->cmsg_len = CMSG_LEN(sizeof(uint32_t));
      memcpy(CMSG_DATA(cmsg), &sofFlags, sizeof(sofFlags));
    }
    return msg;
  };

  // SCHED
  {
    auto msg = getMsg(folly::netops::SOF_TIMESTAMPING_TX_SCHED);
    EXPECT_EQ(WriteFlags::TIMESTAMP_SCHED, getMsgAncillaryTsFlags(msg));
  }

  // TX
  {
    auto msg = getMsg(folly::netops::SOF_TIMESTAMPING_TX_SOFTWARE);
    EXPECT_EQ(WriteFlags::TIMESTAMP_TX, getMsgAncillaryTsFlags(msg));
  }

  // ACK
  {
    auto msg = getMsg(folly::netops::SOF_TIMESTAMPING_TX_ACK);
    EXPECT_EQ(WriteFlags::TIMESTAMP_ACK, getMsgAncillaryTsFlags(msg));
  }

  // SCHED + TX + ACK
  {
    auto msg = getMsg(
        folly::netops::SOF_TIMESTAMPING_TX_SCHED |
        folly::netops::SOF_TIMESTAMPING_TX_SOFTWARE |
        folly::netops::SOF_TIMESTAMPING_TX_ACK);
    EXPECT_EQ(
        WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_TX |
            WriteFlags::TIMESTAMP_ACK,
        getMsgAncillaryTsFlags(msg));
  }
}

TEST_F(AsyncSocketByteEventTest, ObserverAttachedBeforeConnect) {
  const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
      WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
  const std::vector<uint8_t> wbuf(1, 'a');

  auto clientConn = getClientConn();
  auto observer = clientConn.attachObserver(true /* enableByteEvents */);
  clientConn.netOpsExpectTimestampingSetSockOpt();
  clientConn.connect();
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
  clientConn.netOpsVerifyAndClearExpectations();

  clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(dropWriteFromFlags(flags));
  clientConn.writeAtClientReadAtServerReflectReadAtClient(wbuf, flags);
  clientConn.netOpsVerifyAndClearExpectations();
  EXPECT_THAT(observer->byteEvents, SizeIs(4));
  EXPECT_EQ(0U, observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
  EXPECT_EQ(0U, observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
  EXPECT_EQ(0U, observer->maxOffsetForByteEventReceived(ByteEventType::TX));
  EXPECT_EQ(0U, observer->maxOffsetForByteEventReceived(ByteEventType::ACK));

  // write again to check offsets
  clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(dropWriteFromFlags(flags));
  clientConn.writeAtClientReadAtServerReflectReadAtClient(wbuf, flags);
  clientConn.netOpsVerifyAndClearExpectations();
  EXPECT_THAT(observer->byteEvents, SizeIs(8));
  EXPECT_EQ(1U, observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
  EXPECT_EQ(1U, observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
  EXPECT_EQ(1U, observer->maxOffsetForByteEventReceived(ByteEventType::TX));
  EXPECT_EQ(1U, observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
}

TEST_F(AsyncSocketByteEventTest, ObserverAttachedAfterConnect) {
  const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
      WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
  const std::vector<uint8_t> wbuf(1, 'a');

  auto clientConn = getClientConn();
  clientConn.netOpsExpectNoTimestampingSetSockOpt();
  clientConn.connect();
  clientConn.netOpsVerifyAndClearExpectations();

  clientConn.netOpsExpectTimestampingSetSockOpt();
  auto observer = clientConn.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
  clientConn.netOpsVerifyAndClearExpectations();

  clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(dropWriteFromFlags(flags));
  clientConn.writeAtClientReadAtServerReflectReadAtClient(wbuf, flags);
  clientConn.netOpsVerifyAndClearExpectations();
  EXPECT_THAT(observer->byteEvents, SizeIs(4));
  EXPECT_EQ(0U, observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
  EXPECT_EQ(0U, observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
  EXPECT_EQ(0U, observer->maxOffsetForByteEventReceived(ByteEventType::TX));
  EXPECT_EQ(0U, observer->maxOffsetForByteEventReceived(ByteEventType::ACK));

  // write again to check offsets
  clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(dropWriteFromFlags(flags));
  clientConn.writeAtClientReadAtServerReflectReadAtClient(wbuf, flags);
  clientConn.netOpsVerifyAndClearExpectations();
  EXPECT_THAT(observer->byteEvents, SizeIs(8));
  EXPECT_EQ(1U, observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
  EXPECT_EQ(1U, observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
  EXPECT_EQ(1U, observer->maxOffsetForByteEventReceived(ByteEventType::TX));
  EXPECT_EQ(1U, observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
}

TEST_F(
    AsyncSocketByteEventTest, ObserverAttachedBeforeConnectByteEventsDisabled) {
  const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
      WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
  const std::vector<uint8_t> wbuf(1, 'a');

  auto clientConn = getClientConn();
  auto observer = clientConn.attachObserver(false /* enableByteEvents */);
  clientConn.netOpsExpectNoTimestampingSetSockOpt();

  clientConn.connect(); // connect after observer attached
  EXPECT_EQ(0, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
  clientConn.netOpsVerifyAndClearExpectations();

  clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(
      WriteFlags::NONE); // events disabled
  clientConn.writeAtClientReadAtServerReflectReadAtClient(wbuf, flags);
  EXPECT_THAT(observer->byteEvents, IsEmpty());
  clientConn.netOpsVerifyAndClearExpectations();

  // now enable ByteEvents with another observer, then write again
  clientConn.netOpsExpectTimestampingSetSockOpt();
  auto observer2 = clientConn.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(0, observer->byteEventsEnabledCalled); // observer 1 doesn't want
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
  EXPECT_EQ(1, observer2->byteEventsEnabledCalled); // should be set
  EXPECT_EQ(0, observer2->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer2->byteEventsUnavailableCalledEx.has_value());
  EXPECT_NE(WriteFlags::NONE, flags);
  EXPECT_NE(WriteFlags::NONE, dropWriteFromFlags(flags));
  clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(dropWriteFromFlags(flags));
  clientConn.writeAtClientReadAtServerReflectReadAtClient(wbuf, flags);
  clientConn.netOpsVerifyAndClearExpectations();

  // expect no ByteEvents for first observer, four for the second
  EXPECT_THAT(observer->byteEvents, IsEmpty());
  EXPECT_THAT(observer2->byteEvents, SizeIs(4));
  EXPECT_EQ(1U, observer2->maxOffsetForByteEventReceived(ByteEventType::WRITE));
  EXPECT_EQ(1U, observer2->maxOffsetForByteEventReceived(ByteEventType::SCHED));
  EXPECT_EQ(1U, observer2->maxOffsetForByteEventReceived(ByteEventType::TX));
  EXPECT_EQ(1U, observer2->maxOffsetForByteEventReceived(ByteEventType::ACK));
}

TEST_F(
    AsyncSocketByteEventTest, ObserverAttachedAfterConnectByteEventsDisabled) {
  const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
      WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
  const std::vector<uint8_t> wbuf(1, 'a');

  auto clientConn = getClientConn();
  clientConn.netOpsExpectNoTimestampingSetSockOpt();

  clientConn.connect(); // connect before observer attached

  auto observer = clientConn.attachObserver(false /* enableByteEvents */);
  EXPECT_EQ(0, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
  clientConn.netOpsVerifyAndClearExpectations();

  clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(
      WriteFlags::NONE); // events disabled
  clientConn.writeAtClientReadAtServerReflectReadAtClient(wbuf, flags);
  EXPECT_THAT(observer->byteEvents, IsEmpty());
  clientConn.netOpsVerifyAndClearExpectations();

  // now enable ByteEvents with another observer, then write again
  clientConn.netOpsExpectTimestampingSetSockOpt();
  auto observer2 = clientConn.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(0, observer->byteEventsEnabledCalled); // observer 1 doesn't want
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
  EXPECT_EQ(1, observer2->byteEventsEnabledCalled); // should be set
  EXPECT_EQ(0, observer2->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer2->byteEventsUnavailableCalledEx.has_value());
  EXPECT_NE(WriteFlags::NONE, flags);
  EXPECT_NE(WriteFlags::NONE, dropWriteFromFlags(flags));
  clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(dropWriteFromFlags(flags));
  clientConn.writeAtClientReadAtServerReflectReadAtClient(wbuf, flags);
  clientConn.netOpsVerifyAndClearExpectations();

  // expect no ByteEvents for first observer, four for the second
  EXPECT_THAT(observer->byteEvents, IsEmpty());
  EXPECT_THAT(observer2->byteEvents, SizeIs(4));
  EXPECT_EQ(1U, observer2->maxOffsetForByteEventReceived(ByteEventType::WRITE));
  EXPECT_EQ(1U, observer2->maxOffsetForByteEventReceived(ByteEventType::SCHED));
  EXPECT_EQ(1U, observer2->maxOffsetForByteEventReceived(ByteEventType::TX));
  EXPECT_EQ(1U, observer2->maxOffsetForByteEventReceived(ByteEventType::ACK));
}

TEST_F(AsyncSocketByteEventTest, ObserverAttachedAfterWrite) {
  const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
      WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
  const std::vector<uint8_t> wbuf(1, 'a');

  auto clientConn = getClientConn();
  clientConn.netOpsExpectNoTimestampingSetSockOpt();
  clientConn.connect(); // connect before observer attached
  clientConn.netOpsVerifyAndClearExpectations();

  clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(
      WriteFlags::NONE); // events disabled
  clientConn.writeAtClientReadAtServerReflectReadAtClient(wbuf, flags);
  clientConn.netOpsVerifyAndClearExpectations();

  clientConn.netOpsExpectTimestampingSetSockOpt();
  auto observer = clientConn.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
  clientConn.netOpsVerifyAndClearExpectations();

  clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(dropWriteFromFlags(flags));
  clientConn.writeAtClientReadAtServerReflectReadAtClient(wbuf, flags);
  clientConn.netOpsVerifyAndClearExpectations();

  EXPECT_THAT(observer->byteEvents, SizeIs(4));
  EXPECT_EQ(1U, observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
  EXPECT_EQ(1U, observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
  EXPECT_EQ(1U, observer->maxOffsetForByteEventReceived(ByteEventType::TX));
  EXPECT_EQ(1U, observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
}

TEST_F(AsyncSocketByteEventTest, ObserverAttachedAfterClose) {
  auto clientConn = getClientConn();
  clientConn.connect();
  clientConn.getRawSocket()->close();
  EXPECT_TRUE(clientConn.getRawSocket()->isClosedBySelf());

  auto observer = clientConn.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(0, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
}

TEST_F(AsyncSocketByteEventTest, MultipleObserverAttached) {
  const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
      WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
  const std::vector<uint8_t> wbuf(50, 'a');

  // attach observer 1 before connect
  auto clientConn = getClientConn();
  auto observer = clientConn.attachObserver(true /* enableByteEvents */);
  clientConn.netOpsExpectTimestampingSetSockOpt();
  clientConn.connect();
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
  clientConn.netOpsVerifyAndClearExpectations();

  // attach observer 2 after connect
  auto observer2 = clientConn.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(1, observer2->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer2->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer2->byteEventsUnavailableCalledEx.has_value());

  // write
  clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(dropWriteFromFlags(flags));
  clientConn.writeAtClientReadAtServerReflectReadAtClient(wbuf, flags);
  clientConn.netOpsVerifyAndClearExpectations();

  // check observer1
  EXPECT_THAT(observer->byteEvents, SizeIs(4));
  EXPECT_EQ(49U, observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
  EXPECT_EQ(49U, observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
  EXPECT_EQ(49U, observer->maxOffsetForByteEventReceived(ByteEventType::TX));
  EXPECT_EQ(49U, observer->maxOffsetForByteEventReceived(ByteEventType::ACK));

  // check observer2
  EXPECT_THAT(observer2->byteEvents, SizeIs(4));
  EXPECT_EQ(
      49U, observer2->maxOffsetForByteEventReceived(ByteEventType::WRITE));
  EXPECT_EQ(
      49U, observer2->maxOffsetForByteEventReceived(ByteEventType::SCHED));
  EXPECT_EQ(49U, observer2->maxOffsetForByteEventReceived(ByteEventType::TX));
  EXPECT_EQ(49U, observer2->maxOffsetForByteEventReceived(ByteEventType::ACK));
}

/**
 * Test when kernel offset (uint32_t) wraps around.
 */
TEST_F(AsyncSocketByteEventTest, KernelOffsetWrap) {
  auto clientConn = getClientConn();
  clientConn.connect();
  clientConn.netOpsExpectTimestampingSetSockOpt();
  auto observer = clientConn.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
  clientConn.netOpsVerifyAndClearExpectations();

  const uint64_t wbufSize = 3000000;
  const std::vector<uint8_t> wbuf(wbufSize, 'a');

  // part 1: write close to the wrap point with no ByteEvents to speed things up
  const uint64_t bytesToWritePt1 =
      static_cast<uint64_t>(std::numeric_limits<uint32_t>::max()) -
      (wbufSize * 5);
  while (clientConn.getRawSocket()->getRawBytesWritten() < bytesToWritePt1) {
    clientConn.writeAtClientReadAtServer(
        wbuf, WriteFlags::NONE); // no reflect needed
  }

  // part 2: write over the wrap point with ByteEvents
  const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
      WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
  const uint64_t bytesToWritePt2 =
      static_cast<uint64_t>(std::numeric_limits<uint32_t>::max()) +
      (wbufSize * 5);
  while (clientConn.getRawSocket()->getRawBytesWritten() < bytesToWritePt2) {
    clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(
        dropWriteFromFlags(flags));
    clientConn.writeAtClientReadAtServerReflectReadAtClient(wbuf, flags);
    clientConn.netOpsVerifyAndClearExpectations();
    const uint64_t expectedOffset =
        clientConn.getRawSocket()->getRawBytesWritten() - 1;
    EXPECT_EQ(
        expectedOffset,
        observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
    EXPECT_EQ(
        expectedOffset,
        observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
    EXPECT_EQ(
        expectedOffset,
        observer->maxOffsetForByteEventReceived(ByteEventType::TX));
    EXPECT_EQ(
        expectedOffset,
        observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
  }

  // part 3: one more write outside of a loop with extra checks
  clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(dropWriteFromFlags(flags));
  clientConn.writeAtClientReadAtServerReflectReadAtClient(wbuf, flags);
  clientConn.netOpsVerifyAndClearExpectations();
  const auto expectedOffset =
      clientConn.getRawSocket()->getRawBytesWritten() - 1;
  EXPECT_LT(std::numeric_limits<uint32_t>::max(), expectedOffset);
  EXPECT_EQ(
      expectedOffset,
      observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
  EXPECT_EQ(
      expectedOffset,
      observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
  EXPECT_EQ(
      expectedOffset,
      observer->maxOffsetForByteEventReceived(ByteEventType::TX));
  EXPECT_EQ(
      expectedOffset,
      observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
}

/**
 * Ensure that ErrMessageCallback still works when ByteEvents are enabled.
 */
TEST_F(AsyncSocketByteEventTest, ErrMessageCallbackStillTriggered) {
  auto clientConn = getClientConn();
  clientConn.connect();
  clientConn.netOpsExpectTimestampingSetSockOpt();
  auto observer = clientConn.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
  clientConn.netOpsVerifyAndClearExpectations();

  TestErrMessageCallback errMsgCB;
  clientConn.getRawSocket()->setErrMessageCB(&errMsgCB);

  const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
      WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;

  std::vector<uint8_t> wbuf(1, 'a');
  EXPECT_NE(WriteFlags::NONE, flags);
  EXPECT_NE(WriteFlags::NONE, dropWriteFromFlags(flags));
  clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(dropWriteFromFlags(flags));
  clientConn.writeAtClientReadAtServerReflectReadAtClient(wbuf, flags);
  clientConn.netOpsVerifyAndClearExpectations();

  // observer should get events
  EXPECT_THAT(observer->byteEvents, SizeIs(4));
  EXPECT_EQ(0U, observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
  EXPECT_EQ(0U, observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
  EXPECT_EQ(0U, observer->maxOffsetForByteEventReceived(ByteEventType::TX));
  EXPECT_EQ(0U, observer->maxOffsetForByteEventReceived(ByteEventType::ACK));

  // the err message callback should get events, too
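  // a single write with SCHED, TX, and ACK timestamps requested yields three
  // timestamp messages, each paired with a byte-seq message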
  EXPECT_EQ(3, errMsgCB.gotByteSeq_);
  EXPECT_EQ(3, errMsgCB.gotTimestamp_);

  // write again, more events for both
  clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(dropWriteFromFlags(flags));
  clientConn.writeAtClientReadAtServerReflectReadAtClient(wbuf, flags);
  clientConn.netOpsVerifyAndClearExpectations();
  EXPECT_THAT(observer->byteEvents, SizeIs(8));
  EXPECT_EQ(1U, observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
  EXPECT_EQ(1U, observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
  EXPECT_EQ(1U, observer->maxOffsetForByteEventReceived(ByteEventType::TX));
  EXPECT_EQ(1U, observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
  EXPECT_EQ(6, errMsgCB.gotByteSeq_);
  EXPECT_EQ(6, errMsgCB.gotTimestamp_);
}

/**
 * Ensure that ByteEvents are disabled for Unix sockets (not supported).
 */
TEST_F(AsyncSocketByteEventTest, FailUnixSocket) {
  std::shared_ptr<NiceMock<TestObserver>> observer;
  auto netOpsDispatcher = std::make_shared<MockDispatcher>();

  NetworkSocket fd[2];
  EXPECT_EQ(netops::socketpair(AF_UNIX, SOCK_STREAM, 0, fd), 0);
  ASSERT_NE(fd[0], NetworkSocket());
  ASSERT_NE(fd[1], NetworkSocket());
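  // fd[0] is handed to (and eventually closed by) the AsyncSocket created
  // below; only fd[1] needs to be cleaned up manually.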
  SCOPE_EXIT {
    netops::close(fd[1]);
  };

  EXPECT_EQ(netops::set_socket_non_blocking(fd[0]), 0);
  EXPECT_EQ(netops::set_socket_non_blocking(fd[1]), 0);

  auto clientSocketRaw = AsyncSocket::newSocket(nullptr, fd[0]);
  auto clientBlockingSocket = BlockingSocket(std::move(clientSocketRaw));
  clientBlockingSocket.getSocket()->setOverrideNetOpsDispatcher(
      netOpsDispatcher);

  // make sure no SO_TIMESTAMPING setsockopt on observer attach
  EXPECT_CALL(*netOpsDispatcher, setsockopt(_, _, _, _, _)).Times(AnyNumber());
  EXPECT_CALL(
      *netOpsDispatcher, setsockopt(_, SOL_SOCKET, SO_TIMESTAMPING, _, _))
      .Times(0); // no calls
  observer = attachObserver(
      clientBlockingSocket.getSocket(), true /* enableByteEvents */);
  EXPECT_EQ(0, observer->byteEventsEnabledCalled);
  EXPECT_EQ(1, observer->byteEventsUnavailableCalled);
  EXPECT_TRUE(observer->byteEventsUnavailableCalledEx.has_value());
  Mock::VerifyAndClearExpectations(netOpsDispatcher.get());

  // do a write; we should see that it has no timestamp flags
  const std::vector<uint8_t> wbuf(1, 'a');
  EXPECT_CALL(*netOpsDispatcher, sendmsg(_, _, _))
      .WillOnce(WithArgs<1>(Invoke([](const msghdr* message) {
        EXPECT_EQ(WriteFlags::NONE, getMsgAncillaryTsFlags(*message));
        return 1;
      })));
  clientBlockingSocket.write(
      wbuf.data(),
      wbuf.size(),
      WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
          WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK);
  Mock::VerifyAndClearExpectations(netOpsDispatcher.get());
}

/**
 * If socket timestamps are already enabled, do not enable ByteEvents.
 */
TEST_F(AsyncSocketByteEventTest, FailTimestampsAlreadyEnabled) {
  auto clientConn = getClientConn();
  clientConn.connect();

  // enable timestamps via setsockopt
  const uint32_t flags = folly::netops::SOF_TIMESTAMPING_OPT_ID |
      folly::netops::SOF_TIMESTAMPING_OPT_TSONLY |
      folly::netops::SOF_TIMESTAMPING_SOFTWARE |
      folly::netops::SOF_TIMESTAMPING_RAW_HARDWARE |
      folly::netops::SOF_TIMESTAMPING_OPT_TX_SWHW;
  const auto ret = clientConn.getRawSocket()->setSockOpt(
      SOL_SOCKET, SO_TIMESTAMPING, &flags);
  EXPECT_EQ(0, ret);

  clientConn.netOpsExpectNoTimestampingSetSockOpt();
  auto observer = clientConn.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(0, observer->byteEventsEnabledCalled);
  EXPECT_EQ(1, observer->byteEventsUnavailableCalled); // fail
  EXPECT_TRUE(observer->byteEventsUnavailableCalledEx.has_value());
  clientConn.netOpsVerifyAndClearExpectations();

  std::vector<uint8_t> wbuf(1, 'a');
  clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(WriteFlags::NONE);
  clientConn.writeAtClientReadAtServerReflectReadAtClient(
      wbuf,
      WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
          WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK);
  clientConn.netOpsVerifyAndClearExpectations();
  EXPECT_THAT(observer->byteEvents, IsEmpty());
}

/**
 * Verify that ByteEvent information is properly copied during socket moves.
 */

TEST_F(AsyncSocketByteEventTest, MoveByteEventsEnabled) {
  const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
      WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
  const std::vector<uint8_t> wbuf(50, 'a');

  auto clientConn = getClientConn();
  clientConn.connect();

  // observer with ByteEvents enabled
  auto observer = clientConn.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());

  // move the socket immediately and add an observer with ByteEvents enabled
  auto clientConn2 = ClientConn(
      server_,
      AsyncSocket::UniquePtr(new AsyncSocket(clientConn.getRawSocket().get())),
      clientConn.getAcceptedSocket());
  auto observer2 = clientConn2.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(1, observer2->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer2->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer2->byteEventsUnavailableCalledEx.has_value());

  // write following move, make sure the offsets are correct
  clientConn2.netOpsExpectSendmsgWithAncillaryTsFlags(
      dropWriteFromFlags(flags));
  clientConn2.writeAtClientReadAtServerReflectReadAtClient(wbuf, flags);
  clientConn2.netOpsVerifyAndClearExpectations();
  EXPECT_THAT(observer2->byteEvents, SizeIs(Ge(4)));
  {
    const auto expectedOffset = 49U;
    EXPECT_EQ(
        expectedOffset,
        observer2->maxOffsetForByteEventReceived(ByteEventType::WRITE));
    EXPECT_EQ(
        expectedOffset,
        observer2->maxOffsetForByteEventReceived(ByteEventType::SCHED));
    EXPECT_EQ(
        expectedOffset,
        observer2->maxOffsetForByteEventReceived(ByteEventType::TX));
    EXPECT_EQ(
        expectedOffset,
        observer2->maxOffsetForByteEventReceived(ByteEventType::ACK));
  }

  // write again
  clientConn2.netOpsExpectSendmsgWithAncillaryTsFlags(
      dropWriteFromFlags(flags));
  clientConn2.writeAtClientReadAtServerReflectReadAtClient(wbuf, flags);
  clientConn2.netOpsVerifyAndClearExpectations();
  EXPECT_THAT(observer2->byteEvents, SizeIs(Ge(8)));
  {
    const auto expectedOffset = 99U;
    EXPECT_EQ(
        expectedOffset,
        observer2->maxOffsetForByteEventReceived(ByteEventType::WRITE));
    EXPECT_EQ(
        expectedOffset,
        observer2->maxOffsetForByteEventReceived(ByteEventType::SCHED));
    EXPECT_EQ(
        expectedOffset,
        observer2->maxOffsetForByteEventReceived(ByteEventType::TX));
    EXPECT_EQ(
        expectedOffset,
        observer2->maxOffsetForByteEventReceived(ByteEventType::ACK));
  }
}

TEST_F(AsyncSocketByteEventTest, WriteThenDetachThenEnableByteEvents) {
  const auto flags = WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_TX |
      WriteFlags::TIMESTAMP_ACK;
  const std::vector<uint8_t> wbuf(20, 'a');
  auto clientConn = getClientConn();
  clientConn.connect();

  // observer with ByteEvents enabled
  auto observer = clientConn.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());

  // write
  EXPECT_CALL(*clientConn.getNetOpsDispatcher(), recvmsg(_, _, _)).Times(0);
  clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(dropWriteFromFlags(flags));
  iovec op = {};
  op.iov_base = const_cast<void*>(static_cast<const void*>(wbuf.data()));
  op.iov_len = wbuf.size();
  clientConn.writeAtClientReadAtServer(&op, 1, flags);

  // now detach the fd and create a new AsyncSocket with the same fd and add an
  // observer with ByteEvents enabled
  auto fd = clientConn.getRawSocket().get()->detachNetworkSocket();
  auto clientConn2 = ClientConn(
      server_,
      AsyncSocket::UniquePtr(new AsyncSocket(&clientConn.getEventBase(), fd)),
      clientConn.getAcceptedSocket());
  clientConn2.netOpsOnRecvmsg();
  // initialize socket family from underlying network socket
  clientConn2.getRawSocket()->cacheAddresses();

  // byte events should not be enabled because the fd already has timestamping
  // enabled (from when it was controlled by clientConn)
  auto observer2 = clientConn2.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(0, observer2->byteEventsEnabledCalled);
  EXPECT_EQ(1, observer2->byteEventsUnavailableCalled);
  EXPECT_TRUE(observer2->byteEventsUnavailableCalledEx.has_value());

  // now have the server reflect the data previously written by the client and
  // check that the client is able to read the data
  clientConn2.writeAtServerReadAtClient(&op, 1);
  // now that we've read everything reflected by the server, loop once more to
  // allow the error queue to be read
  if (clientConn2.getErrorQueueReads() != 3) {
    clientConn2.getEventBase().loopOnce();
  }
  // we should read three timestamping (SCHED, TX, ACK) messages from the
  // error queue
  EXPECT_EQ(clientConn2.getErrorQueueReads(), 3);
}

TEST_F(AsyncSocketByteEventTest, WriteThenDetachThenDoNotEnableByteEvents) {
  const auto flags = WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_TX |
      WriteFlags::TIMESTAMP_ACK;
  const std::vector<uint8_t> wbuf(20, 'a');
  auto clientConn = getClientConn();
  clientConn.connect();

  // observer with ByteEvents enabled
  auto observer = clientConn.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());

  // write
  EXPECT_CALL(*clientConn.getNetOpsDispatcher(), recvmsg(_, _, _)).Times(0);
  clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(dropWriteFromFlags(flags));
  iovec op = {};
  op.iov_base = const_cast<void*>(static_cast<const void*>(wbuf.data()));
  op.iov_len = wbuf.size();
  clientConn.writeAtClientReadAtServer(&op, 1, flags);

  // now detach the fd and create a new AsyncSocket with the same fd
  // do not enable byte events on the new socket
  auto fd = clientConn.getRawSocket().get()->detachNetworkSocket();
  auto clientConn2 = ClientConn(
      server_,
      AsyncSocket::UniquePtr(new AsyncSocket(&clientConn.getEventBase(), fd)),
      clientConn.getAcceptedSocket());
  clientConn2.netOpsOnRecvmsg();
  // initialize socket family from underlying network socket
  clientConn2.getRawSocket()->cacheAddresses();

  // now have the server reflect the data previously written by the client and
  // check that the client is able to read the data
  clientConn2.writeAtServerReadAtClient(&op, 1);
  // now that we've read everything reflected by the server, loop once more to
  // allow the error queue to be read
  if (clientConn2.getErrorQueueReads() != 3) {
    clientConn2.getEventBase().loopOnce();
  }
  // we should read three timestamping (SCHED, TX, ACK) messages from the error
  // queue
  EXPECT_EQ(clientConn2.getErrorQueueReads(), 3);
}

TEST_F(AsyncSocketByteEventTest, WriteThenMoveByteEventsEnabled) {
  const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
      WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
  const std::vector<uint8_t> wbuf(50, 'a');

  auto clientConn = getClientConn();
  clientConn.connect();

  // observer with ByteEvents enabled
  auto observer = clientConn.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());

  // write
  clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(dropWriteFromFlags(flags));
  clientConn.writeAtClientReadAtServerReflectReadAtClient(wbuf, flags);
  clientConn.netOpsVerifyAndClearExpectations();
  EXPECT_THAT(observer->byteEvents, SizeIs(Ge(4)));
  {
    const auto expectedOffset = 49U;
    EXPECT_EQ(
        expectedOffset,
        observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
    EXPECT_EQ(
        expectedOffset,
        observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
    EXPECT_EQ(
        expectedOffset,
        observer->maxOffsetForByteEventReceived(ByteEventType::TX));
    EXPECT_EQ(
        expectedOffset,
        observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
  }

  // now move the socket and add an observer with ByteEvents enabled
  auto clientConn2 = ClientConn(
      server_,
      AsyncSocket::UniquePtr(
          new AsyncSocket(std::move(clientConn.getRawSocket().get()))),
      clientConn.getAcceptedSocket());
  auto observer2 = clientConn2.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(1, observer2->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer2->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());

  // write following move, make sure the offsets are correct
  clientConn2.netOpsExpectSendmsgWithAncillaryTsFlags(
      dropWriteFromFlags(flags));
  clientConn2.writeAtClientReadAtServerReflectReadAtClient(wbuf, flags);
  clientConn2.netOpsVerifyAndClearExpectations();
  EXPECT_THAT(observer2->byteEvents, SizeIs(Ge(4)));
  {
    const auto expectedOffset = 99U;
    EXPECT_EQ(
        expectedOffset,
        observer2->maxOffsetForByteEventReceived(ByteEventType::WRITE));
    EXPECT_EQ(
        expectedOffset,
        observer2->maxOffsetForByteEventReceived(ByteEventType::SCHED));
    EXPECT_EQ(
        expectedOffset,
        observer2->maxOffsetForByteEventReceived(ByteEventType::TX));
    EXPECT_EQ(
        expectedOffset,
        observer2->maxOffsetForByteEventReceived(ByteEventType::ACK));
  }

  // write again
  clientConn2.netOpsExpectSendmsgWithAncillaryTsFlags(
      dropWriteFromFlags(flags));
  clientConn2.writeAtClientReadAtServerReflectReadAtClient(wbuf, flags);
  clientConn2.netOpsVerifyAndClearExpectations();
  EXPECT_THAT(observer2->byteEvents, SizeIs(Ge(8)));
  {
    const auto expectedOffset = 149U;
    EXPECT_EQ(
        expectedOffset,
        observer2->maxOffsetForByteEventReceived(ByteEventType::WRITE));
    EXPECT_EQ(
        expectedOffset,
        observer2->maxOffsetForByteEventReceived(ByteEventType::SCHED));
    EXPECT_EQ(
        expectedOffset,
        observer2->maxOffsetForByteEventReceived(ByteEventType::TX));
    EXPECT_EQ(
        expectedOffset,
        observer2->maxOffsetForByteEventReceived(ByteEventType::ACK));
  }
}

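/**
 * Move the socket before any writes, then enable ByteEvents on the new
 * AsyncSocket and verify byte event offsets for subsequent writes.
 */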
TEST_F(AsyncSocketByteEventTest, MoveThenEnableByteEvents) {
  const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
      WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
  const std::vector<uint8_t> wbuf(50, 'a');

  auto clientConn = getClientConn();
  clientConn.connect();

  // observer with ByteEvents disabled
  auto observer = clientConn.attachObserver(false /* enableByteEvents */);
  EXPECT_EQ(0, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());

  // move the socket immediately and add an observer with ByteEvents enabled
  auto clientConn2 = ClientConn(
      server_,
      AsyncSocket::UniquePtr(new AsyncSocket(clientConn.getRawSocket().get())),
      clientConn.getAcceptedSocket());
  auto observer2 = clientConn2.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(1, observer2->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer2->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());

  // write following move, make sure the offsets are correct
  clientConn2.netOpsExpectSendmsgWithAncillaryTsFlags(
      dropWriteFromFlags(flags));
  clientConn2.writeAtClientReadAtServerReflectReadAtClient(wbuf, flags);
  clientConn2.netOpsVerifyAndClearExpectations();
  EXPECT_THAT(observer2->byteEvents, SizeIs(Ge(4)));
  {
    const auto expectedOffset = 49U;
    EXPECT_EQ(
        expectedOffset,
        observer2->maxOffsetForByteEventReceived(ByteEventType::WRITE));
    EXPECT_EQ(
        expectedOffset,
        observer2->maxOffsetForByteEventReceived(ByteEventType::SCHED));
    EXPECT_EQ(
        expectedOffset,
        observer2->maxOffsetForByteEventReceived(ByteEventType::TX));
    EXPECT_EQ(
        expectedOffset,
        observer2->maxOffsetForByteEventReceived(ByteEventType::ACK));
  }

  // write again
  clientConn2.netOpsExpectSendmsgWithAncillaryTsFlags(
      dropWriteFromFlags(flags));
  clientConn2.writeAtClientReadAtServerReflectReadAtClient(wbuf, flags);
  clientConn2.netOpsVerifyAndClearExpectations();
  EXPECT_THAT(observer2->byteEvents, SizeIs(Ge(8)));
  {
    const auto expectedOffset = 99U;
    EXPECT_EQ(
        expectedOffset,
        observer2->maxOffsetForByteEventReceived(ByteEventType::WRITE));
    EXPECT_EQ(
        expectedOffset,
        observer2->maxOffsetForByteEventReceived(ByteEventType::SCHED));
    EXPECT_EQ(
        expectedOffset,
        observer2->maxOffsetForByteEventReceived(ByteEventType::TX));
    EXPECT_EQ(
        expectedOffset,
        observer2->maxOffsetForByteEventReceived(ByteEventType::ACK));
  }
}

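/**
 * Write with ByteEvents disabled, then move the socket and enable ByteEvents;
 * byte event offsets reported after the move must include the bytes written
 * before the move.
 */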
TEST_F(AsyncSocketByteEventTest, WriteThenMoveThenEnableByteEvents) {
  const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
      WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
  const std::vector<uint8_t> wbuf(50, 'a');

  auto clientConn = getClientConn();
  clientConn.connect();

  // observer with ByteEvents disabled
  auto observer = clientConn.attachObserver(false /* enableByteEvents */);
  EXPECT_EQ(0, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());

  // write, ByteEvents disabled
  clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(
      WriteFlags::NONE); // events disabled
  clientConn.writeAtClientReadAtServerReflectReadAtClient(wbuf, flags);
  clientConn.netOpsVerifyAndClearExpectations();

  // now move the socket and add an observer with ByteEvents enabled
  auto clientConn2 = ClientConn(
      server_,
      AsyncSocket::UniquePtr(new AsyncSocket(clientConn.getRawSocket().get())),
      clientConn.getAcceptedSocket());
  auto observer2 = clientConn2.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(1, observer2->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer2->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());

  // write following move, make sure the offsets are correct
  clientConn2.netOpsExpectSendmsgWithAncillaryTsFlags(
      dropWriteFromFlags(flags));
  clientConn2.writeAtClientReadAtServerReflectReadAtClient(wbuf, flags);
  clientConn2.netOpsVerifyAndClearExpectations();
  EXPECT_THAT(observer2->byteEvents, SizeIs(Ge(4)));
  {
    const auto expectedOffset = 99U;
    EXPECT_EQ(
        expectedOffset,
        observer2->maxOffsetForByteEventReceived(ByteEventType::WRITE));
    EXPECT_EQ(
        expectedOffset,
        observer2->maxOffsetForByteEventReceived(ByteEventType::SCHED));
    EXPECT_EQ(
        expectedOffset,
        observer2->maxOffsetForByteEventReceived(ByteEventType::TX));
    EXPECT_EQ(
        expectedOffset,
        observer2->maxOffsetForByteEventReceived(ByteEventType::ACK));
  }

  // write again
  clientConn2.netOpsExpectSendmsgWithAncillaryTsFlags(
      dropWriteFromFlags(flags));
  clientConn2.writeAtClientReadAtServerReflectReadAtClient(wbuf, flags);
  clientConn2.netOpsVerifyAndClearExpectations();
  EXPECT_THAT(observer2->byteEvents, SizeIs(Ge(8)));
  {
    const auto expectedOffset = 149U;
    EXPECT_EQ(
        expectedOffset,
        observer2->maxOffsetForByteEventReceived(ByteEventType::WRITE));
    EXPECT_EQ(
        expectedOffset,
        observer2->maxOffsetForByteEventReceived(ByteEventType::SCHED));
    EXPECT_EQ(
        expectedOffset,
        observer2->maxOffsetForByteEventReceived(ByteEventType::TX));
    EXPECT_EQ(
        expectedOffset,
        observer2->maxOffsetForByteEventReceived(ByteEventType::ACK));
  }
}

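/**
 * Move a socket that never had an observer, then enable ByteEvents on the new
 * AsyncSocket and verify byte event offsets for subsequent writes.
 */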
TEST_F(AsyncSocketByteEventTest, NoObserverMoveThenEnableByteEvents) {
  const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
      WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
  const std::vector<uint8_t> wbuf(50, 'a');

  auto clientConn = getClientConn();
  clientConn.connect();

  // no observer

  // move the socket immediately and add an observer with ByteEvents enabled
  auto clientConn2 = ClientConn(
      server_,
      AsyncSocket::UniquePtr(new AsyncSocket(clientConn.getRawSocket().get())),
      clientConn.getAcceptedSocket());
  auto observer = clientConn2.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());

  // write following move, make sure the offsets are correct
  clientConn2.netOpsExpectSendmsgWithAncillaryTsFlags(
      dropWriteFromFlags(flags));
  clientConn2.writeAtClientReadAtServerReflectReadAtClient(wbuf, flags);
  clientConn2.netOpsVerifyAndClearExpectations();
  EXPECT_THAT(observer->byteEvents, SizeIs(Ge(4)));
  {
    const auto expectedOffset = 49U;
    EXPECT_EQ(
        expectedOffset,
        observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
    EXPECT_EQ(
        expectedOffset,
        observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
    EXPECT_EQ(
        expectedOffset,
        observer->maxOffsetForByteEventReceived(ByteEventType::TX));
    EXPECT_EQ(
        expectedOffset,
        observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
  }

  // write again
  clientConn2.netOpsExpectSendmsgWithAncillaryTsFlags(
      dropWriteFromFlags(flags));
  clientConn2.writeAtClientReadAtServerReflectReadAtClient(wbuf, flags);
  clientConn2.netOpsVerifyAndClearExpectations();
  EXPECT_THAT(observer->byteEvents, SizeIs(Ge(8)));
  {
    const auto expectedOffset = 99U;
    EXPECT_EQ(
        expectedOffset,
        observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
    EXPECT_EQ(
        expectedOffset,
        observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
    EXPECT_EQ(
        expectedOffset,
        observer->maxOffsetForByteEventReceived(ByteEventType::TX));
    EXPECT_EQ(
        expectedOffset,
        observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
  }
}

/**
 * Inspect ByteEvent fields, including xTimestampRequested in WRITE events.
 *
 * See CheckByteEventDetailsRawBytesWrittenAndTriedToWrite and
 * AsyncSocketByteEventDetailsTest::CheckByteEventDetails as well.
 */
TEST_F(AsyncSocketByteEventTest, CheckByteEventDetails) {
  const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
      WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;
  const std::vector<uint8_t> wbuf(1, 'a');

  auto clientConn = getClientConn();
  clientConn.connect();
  auto observer = clientConn.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());

  EXPECT_NE(WriteFlags::NONE, dropWriteFromFlags(flags));
  clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(dropWriteFromFlags(flags));
  clientConn.writeAtClientReadAtServerReflectReadAtClient(wbuf, flags);
  clientConn.netOpsVerifyAndClearExpectations();
  EXPECT_THAT(observer->byteEvents, SizeIs(Eq(4)));
  const auto expectedOffset = wbuf.size() - 1;

  // check WRITE
  {
    auto maybeByteEvent = observer->getByteEventReceivedWithOffset(
        expectedOffset, ByteEventType::WRITE);
    ASSERT_TRUE(maybeByteEvent.has_value());
    auto& byteEvent = maybeByteEvent.value();

    EXPECT_EQ(ByteEventType::WRITE, byteEvent.type);
    EXPECT_EQ(expectedOffset, byteEvent.offset);
    EXPECT_GE(std::chrono::steady_clock::now(), byteEvent.ts);
    EXPECT_LT(
        std::chrono::steady_clock::now() - std::chrono::seconds(60),
        byteEvent.ts);

    EXPECT_EQ(flags, byteEvent.maybeWriteFlags);
    EXPECT_TRUE(byteEvent.schedTimestampRequestedOnWrite());
    EXPECT_TRUE(byteEvent.txTimestampRequestedOnWrite());
    EXPECT_TRUE(byteEvent.ackTimestampRequestedOnWrite());

    EXPECT_FALSE(byteEvent.maybeSoftwareTs.has_value());
    EXPECT_FALSE(byteEvent.maybeHardwareTs.has_value());

    // maybeRawBytesWritten and maybeRawBytesTriedToWrite are tested in
    // CheckByteEventDetailsRawBytesWrittenAndTriedToWrite
  }

  // check SCHED, TX, ACK
  for (const auto& byteEventType :
       {ByteEventType::SCHED, ByteEventType::TX, ByteEventType::ACK}) {
    auto maybeByteEvent =
        observer->getByteEventReceivedWithOffset(expectedOffset, byteEventType);
    ASSERT_TRUE(maybeByteEvent.has_value());
    auto& byteEvent = maybeByteEvent.value();

    EXPECT_EQ(byteEventType, byteEvent.type);
    EXPECT_EQ(expectedOffset, byteEvent.offset);
    EXPECT_GE(std::chrono::steady_clock::now(), byteEvent.ts);
    EXPECT_LT(
        std::chrono::steady_clock::now() - std::chrono::seconds(60),
        byteEvent.ts);

    EXPECT_FALSE(byteEvent.maybeWriteFlags.has_value());
    EXPECT_DEATH((void)byteEvent.schedTimestampRequestedOnWrite(), ".*");
    EXPECT_DEATH((void)byteEvent.txTimestampRequestedOnWrite(), ".*");
    EXPECT_DEATH((void)byteEvent.ackTimestampRequestedOnWrite(), ".*");

    EXPECT_TRUE(byteEvent.maybeSoftwareTs.has_value());
    EXPECT_FALSE(byteEvent.maybeHardwareTs.has_value());
  }
}

/**
 * Inspect ByteEvent fields maybeRawBytesWritten and maybeRawBytesTriedToWrite.
 */
TEST_F(
    AsyncSocketByteEventTest,
    CheckByteEventDetailsRawBytesWrittenAndTriedToWrite) {
  auto clientConn = getClientConn();
  clientConn.connect();
  auto observer = clientConn.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());

  struct ExpectedSendmsgInvocation {
    size_t expectedTotalIovLen{0};
    ssize_t returnVal{0}; // number of bytes written or error val
    folly::Optional<size_t> maybeWriteEventExpectedOffset{};
    folly::Optional<WriteFlags> maybeWriteEventExpectedFlags{};
  };

  const auto flags = WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK |
      WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_WRITE;

  // first write
  //
  // no splits triggered by observer
  //
  // sendmsg will incrementally accept the bytes so we can test the values of
  // maybeRawBytesWritten and maybeRawBytesTriedToWrite
  {
    // bytes written per sendmsg call: 20, 10, 50, -1 (EAGAIN), 11, 9
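    // the expected WRITE event offset after each successful call is the
    // cumulative number of bytes written minus one:
    // 20 -> 19, +10 -> 29, +50 -> 79, EAGAIN -> none, +11 -> 90, +9 -> 99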
    const std::vector<ExpectedSendmsgInvocation> expectedSendmsgInvocations{
        // {
        //    expectedTotalIovLen, returnVal,
        //    maybeWriteEventExpectedOffset, maybeWriteEventExpectedFlags
        // },
        {100, 20, 19, flags},
        {80, 10, 29, flags},
        {70, 50, 79, flags},
        {20, -1, folly::none, flags},
        {20, 11, 90, flags},
        {9, 9, 99, flags}};

    // sendmsg will be called; we return the number of bytes written
    {
      InSequence s;
      for (const auto& expectedInvocation : expectedSendmsgInvocations) {
        EXPECT_CALL(
            *(clientConn.getNetOpsDispatcher()),
            sendmsg(
                _,
                Pointee(SendmsgMsghdrHasTotalIovLen(
                    expectedInvocation.expectedTotalIovLen)),
                _))
            .WillOnce(::testing::InvokeWithoutArgs([expectedInvocation]() {
              if (expectedInvocation.returnVal < 0) {
                errno = EAGAIN; // returning error, set EAGAIN
              }
              return expectedInvocation.returnVal;
            }));
      }
    }

    // write
    // writes will be intercepted, so we don't need to read at other end
    WriteCallback wcb;
    clientConn.getRawSocket()->write(
        &wcb,
        kOneHundredCharacterVec.data(),
        kOneHundredCharacterVec.size(),
        flags);
    while (STATE_WAITING == wcb.state) {
      clientConn.getRawSocket()->getEventBase()->loopOnce();
    }
    ASSERT_EQ(STATE_SUCCEEDED, wcb.state);

    // check write events
    for (const auto& expectedInvocation : expectedSendmsgInvocations) {
      if (expectedInvocation.returnVal < 0) {
        // should be no WriteEvent since the return value was an error
        continue;
      }

      ASSERT_TRUE(expectedInvocation.maybeWriteEventExpectedOffset.has_value());
      const auto& expectedOffset =
          *expectedInvocation.maybeWriteEventExpectedOffset;

      auto maybeByteEvent = observer->getByteEventReceivedWithOffset(
          expectedOffset, ByteEventType::WRITE);
      ASSERT_TRUE(maybeByteEvent.has_value());
      auto& byteEvent = maybeByteEvent.value();

      EXPECT_EQ(ByteEventType::WRITE, byteEvent.type);
      EXPECT_EQ(expectedOffset, byteEvent.offset);
      EXPECT_GE(std::chrono::steady_clock::now(), byteEvent.ts);
      EXPECT_LT(
          std::chrono::steady_clock::now() - std::chrono::seconds(60),
          byteEvent.ts);

      EXPECT_EQ(
          expectedInvocation.maybeWriteEventExpectedFlags,
          byteEvent.maybeWriteFlags);
      EXPECT_TRUE(byteEvent.schedTimestampRequestedOnWrite());
      EXPECT_TRUE(byteEvent.txTimestampRequestedOnWrite());
      EXPECT_TRUE(byteEvent.ackTimestampRequestedOnWrite());

      EXPECT_FALSE(byteEvent.maybeSoftwareTs.has_value());
      EXPECT_FALSE(byteEvent.maybeHardwareTs.has_value());

      // what we really want to test
      EXPECT_EQ(
          folly::to_unsigned(expectedInvocation.returnVal),
          byteEvent.maybeRawBytesWritten);
      EXPECT_EQ(
          expectedInvocation.expectedTotalIovLen,
          byteEvent.maybeRawBytesTriedToWrite);
    }
  }

  // everything should have occurred by now
  clientConn.netOpsVerifyAndClearExpectations();

  // second write
  //
  // sendmsg will incrementally accept the bytes so we can test the values of
  // maybeRawBytesWritten and maybeRawBytesTriedToWrite
  {
    // bytes written per sendmsg call: 20, 30, 50
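    // offsets continue from the 100 bytes written during the first write:
    // 100 + 20 -> 119, +30 -> 149, +50 -> 199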
    const std::vector<ExpectedSendmsgInvocation> expectedSendmsgInvocations{
        {100, 20, 119, flags}, {80, 30, 149, flags}, {50, 50, 199, flags}};

    // sendmsg will be called; we return the number of bytes written
    {
      InSequence s;
      for (const auto& expectedInvocation : expectedSendmsgInvocations) {
        EXPECT_CALL(
            *(clientConn.getNetOpsDispatcher()),
            sendmsg(
                _,
                Pointee(SendmsgMsghdrHasTotalIovLen(
                    expectedInvocation.expectedTotalIovLen)),
                _))
            .WillOnce(::testing::InvokeWithoutArgs([expectedInvocation]() {
              return expectedInvocation.returnVal;
            }));
      }
    }

    // write
    // writes will be intercepted, so we don't need to read at other end
    WriteCallback wcb;
    clientConn.getRawSocket()->write(
        &wcb,
        kOneHundredCharacterVec.data(),
        kOneHundredCharacterVec.size(),
        flags);
    while (STATE_WAITING == wcb.state) {
      clientConn.getRawSocket()->getEventBase()->loopOnce();
    }
    ASSERT_EQ(STATE_SUCCEEDED, wcb.state);

    // check write events
    for (const auto& expectedInvocation : expectedSendmsgInvocations) {
      ASSERT_TRUE(expectedInvocation.maybeWriteEventExpectedOffset.has_value());
      const auto& expectedOffset =
          *expectedInvocation.maybeWriteEventExpectedOffset;

      auto maybeByteEvent = observer->getByteEventReceivedWithOffset(
          expectedOffset, ByteEventType::WRITE);
      ASSERT_TRUE(maybeByteEvent.has_value());
      auto& byteEvent = maybeByteEvent.value();

      EXPECT_EQ(ByteEventType::WRITE, byteEvent.type);
      EXPECT_EQ(expectedOffset, byteEvent.offset);
      EXPECT_GE(std::chrono::steady_clock::now(), byteEvent.ts);
      EXPECT_LT(
          std::chrono::steady_clock::now() - std::chrono::seconds(60),
          byteEvent.ts);

      EXPECT_EQ(
          expectedInvocation.maybeWriteEventExpectedFlags,
          byteEvent.maybeWriteFlags);
      EXPECT_TRUE(byteEvent.schedTimestampRequestedOnWrite());
      EXPECT_TRUE(byteEvent.txTimestampRequestedOnWrite());
      EXPECT_TRUE(byteEvent.ackTimestampRequestedOnWrite());

      EXPECT_FALSE(byteEvent.maybeSoftwareTs.has_value());
      EXPECT_FALSE(byteEvent.maybeHardwareTs.has_value());

      // what we really want to test
      EXPECT_EQ(
          folly::to_unsigned(expectedInvocation.returnVal),
          byteEvent.maybeRawBytesWritten);
      EXPECT_EQ(
          expectedInvocation.expectedTotalIovLen,
          byteEvent.maybeRawBytesTriedToWrite);
    }
  }
}

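/**
 * Exercise AsyncSocket::splitIovecArray when the source data lives in a single
 * iovec: split at the first byte, the 50th byte, the penultimate byte, and the
 * full range.
 */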
TEST_F(AsyncSocketByteEventTest, SplitIoVecArraySingleIoVec) {
  // build srcIov via a lambda so it can be kept const during the test
  const char* buf = kOneHundredCharacterString.c_str();
  auto getSrcIov = [&buf]() {
    std::vector<struct iovec> srcIov(2);
    srcIov[0].iov_base = const_cast<void*>(static_cast<const void*>(buf));
    srcIov[0].iov_len = kOneHundredCharacterString.size();
    return srcIov;
  };

  std::vector<struct iovec> srcIov = getSrcIov();
  const auto data = srcIov.data();

  // split 0 -> 0 (first byte)
  {
    std::vector<struct iovec> dstIov(4);
    size_t dstIovCount = dstIov.size();
    AsyncSocket::splitIovecArray(
        0, 0, data, srcIov.size(), dstIov.data(), dstIovCount);

    ASSERT_EQ(1, dstIovCount);
    EXPECT_EQ(1, dstIov[0].iov_len);
    EXPECT_EQ(srcIov[0].iov_base, dstIov[0].iov_base);
    EXPECT_EQ(buf, dstIov[0].iov_base);
  }

  // split 0 -> 49 (50th byte)
  {
    std::vector<struct iovec> dstIov(4);
    size_t dstIovCount = dstIov.size();
    AsyncSocket::splitIovecArray(
        0, 49, data, srcIov.size(), dstIov.data(), dstIovCount);

    ASSERT_EQ(1, dstIovCount);
    EXPECT_EQ(srcIov[0].iov_base, dstIov[0].iov_base);
    EXPECT_EQ(50, dstIov[0].iov_len);
  }

  // split 0 -> 98 (penultimate byte)
  {
    std::vector<struct iovec> dstIov(4);
    size_t dstIovCount = dstIov.size();
    AsyncSocket::splitIovecArray(
        0, 98, data, srcIov.size(), dstIov.data(), dstIovCount);

    ASSERT_EQ(1, dstIovCount);
    EXPECT_EQ(srcIov[0].iov_base, dstIov[0].iov_base);
    EXPECT_EQ(99, dstIov[0].iov_len);
  }

  // split 0 -> 99 (pointless split)
  {
    std::vector<struct iovec> dstIov(4);
    size_t dstIovCount = dstIov.size();
    AsyncSocket::splitIovecArray(
        0, 99, data, srcIov.size(), dstIov.data(), dstIovCount);

    ASSERT_EQ(1, dstIovCount);
    EXPECT_EQ(srcIov[0].iov_base, dstIov[0].iov_base);
    EXPECT_EQ(srcIov[0].iov_len, dstIov[0].iov_len);
  }
}

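/**
 * AsyncSocket::splitIovecArray requires the destination array to be at least
 * as large as the source array; a smaller destination is expected to die.
 */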
TEST_F(AsyncSocketByteEventTest, SplitIoVecArrayMultiIoVecInvalid) {
  // build srcIov via a lambda so it can be kept const during the test
  const char* buf = kOneHundredCharacterString.c_str();
  auto getSrcIov = [&buf]() {
    std::vector<struct iovec> srcIov(4);
    srcIov[0].iov_base = const_cast<void*>(static_cast<const void*>(buf));
    srcIov[0].iov_len = 50;
    srcIov[1].iov_base = const_cast<void*>(static_cast<const void*>(buf + 50));
    srcIov[1].iov_len = 50;
    return srcIov;
  };

  std::vector<struct iovec> srcIov = getSrcIov();
  const auto data = srcIov.data();

  // dstIov.size() < srcIov.size(); this is not allowed
  std::vector<struct iovec> dstIov(1);
  size_t dstIovCount = dstIov.size();
  EXPECT_LT(dstIovCount, srcIov.size());
  EXPECT_DEATH(
      AsyncSocket::splitIovecArray(
          0, 0, data, srcIov.size(), dstIov.data(), dstIovCount),
      ".*");
}

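/**
 * Exercise AsyncSocket::splitIovecArray across four 25-byte iovecs, including
 * start and end offsets that fall on or near iovec boundaries.
 */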
TEST_F(AsyncSocketByteEventTest, SplitIoVecArrayMultiIoVec) {
  // build srcIov via a lambda so it can be kept const during the test
  const char* buf = kOneHundredCharacterString.c_str();
  auto getSrcIov = [&buf]() {
    std::vector<struct iovec> srcIov(4);
    srcIov[0].iov_base = const_cast<void*>(static_cast<const void*>(buf));
    srcIov[0].iov_len = 25;
    srcIov[1].iov_base = const_cast<void*>(static_cast<const void*>(buf + 25));
    srcIov[1].iov_len = 25;
    srcIov[2].iov_base = const_cast<void*>(static_cast<const void*>(buf + 50));
    srcIov[2].iov_len = 25;
    srcIov[3].iov_base = const_cast<void*>(static_cast<const void*>(buf + 75));
    srcIov[3].iov_len = 25;
    return srcIov;
  };

  std::vector<struct iovec> srcIov = getSrcIov();
  const auto data = srcIov.data();

  // split 0 -> 0 (first byte)
  {
    std::vector<struct iovec> dstIov(4);
    size_t dstIovCount = dstIov.size();
    AsyncSocket::splitIovecArray(
        0, 0, data, srcIov.size(), dstIov.data(), dstIovCount);

    ASSERT_EQ(1, dstIovCount);
    EXPECT_EQ(1, dstIov[0].iov_len);
    EXPECT_EQ(srcIov[0].iov_base, dstIov[0].iov_base);
    EXPECT_EQ(buf, dstIov[0].iov_base);
  }

  // split 0 -> 98 (penultimate byte)
  {
    std::vector<struct iovec> dstIov(4);
    size_t dstIovCount = dstIov.size();
    AsyncSocket::splitIovecArray(
        0, 98, data, srcIov.size(), dstIov.data(), dstIovCount);

    ASSERT_EQ(4, dstIovCount);
    EXPECT_EQ(srcIov[0].iov_base, dstIov[0].iov_base);
    EXPECT_EQ(srcIov[0].iov_len, dstIov[0].iov_len);
    EXPECT_EQ(srcIov[1].iov_base, dstIov[1].iov_base);
    EXPECT_EQ(srcIov[1].iov_len, dstIov[1].iov_len);
    EXPECT_EQ(srcIov[2].iov_base, dstIov[2].iov_base);
    EXPECT_EQ(srcIov[2].iov_len, dstIov[2].iov_len);

    // last iovec is different
    EXPECT_EQ(24, dstIov[3].iov_len);
    EXPECT_EQ(srcIov[3].iov_base, dstIov[3].iov_base);
  }

  // split 0 -> 99 (pointless split)
  {
    std::vector<struct iovec> dstIov(4);
    size_t dstIovCount = dstIov.size();
    AsyncSocket::splitIovecArray(
        0, 99, data, srcIov.size(), dstIov.data(), dstIovCount);

    ASSERT_EQ(4, dstIovCount);
    EXPECT_EQ(srcIov[0].iov_base, dstIov[0].iov_base);
    EXPECT_EQ(srcIov[0].iov_len, dstIov[0].iov_len);
    EXPECT_EQ(srcIov[1].iov_base, dstIov[1].iov_base);
    EXPECT_EQ(srcIov[1].iov_len, dstIov[1].iov_len);
    EXPECT_EQ(srcIov[2].iov_base, dstIov[2].iov_base);
    EXPECT_EQ(srcIov[2].iov_len, dstIov[2].iov_len);
    EXPECT_EQ(srcIov[3].iov_base, dstIov[3].iov_base);
    EXPECT_EQ(srcIov[3].iov_len, dstIov[3].iov_len);
  }

  //
  // test when endOffset is near an iovec boundary
  //

  // split 0 -> 49 (50th byte)
  {
    std::vector<struct iovec> dstIov(4);
    size_t dstIovCount = dstIov.size();
    AsyncSocket::splitIovecArray(
        0, 49, data, srcIov.size(), dstIov.data(), dstIovCount);

    ASSERT_EQ(2, dstIovCount);
    EXPECT_EQ(srcIov[0].iov_base, dstIov[0].iov_base);
    EXPECT_EQ(srcIov[0].iov_len, dstIov[0].iov_len);
    EXPECT_EQ(srcIov[1].iov_base, dstIov[1].iov_base);
    EXPECT_EQ(srcIov[1].iov_len, dstIov[1].iov_len);
  }

  // split 0 -> 50 (51st byte)
  {
    std::vector<struct iovec> dstIov(4);
    size_t dstIovCount = dstIov.size();
    AsyncSocket::splitIovecArray(
        0, 50, data, srcIov.size(), dstIov.data(), dstIovCount);

    ASSERT_EQ(3, dstIovCount);
    EXPECT_EQ(srcIov[0].iov_base, dstIov[0].iov_base);
    EXPECT_EQ(srcIov[0].iov_len, dstIov[0].iov_len);
    EXPECT_EQ(srcIov[1].iov_base, dstIov[1].iov_base);
    EXPECT_EQ(srcIov[1].iov_len, dstIov[1].iov_len);

    // last iovec is one byte
    EXPECT_EQ(1, dstIov[2].iov_len);
    EXPECT_EQ(srcIov[2].iov_base, dstIov[2].iov_base);
  }

  // split 0 -> 51 (52nd byte)
  {
    std::vector<struct iovec> dstIov(4);
    size_t dstIovCount = dstIov.size();
    AsyncSocket::splitIovecArray(
        0, 51, data, srcIov.size(), dstIov.data(), dstIovCount);

    ASSERT_EQ(3, dstIovCount);
    EXPECT_EQ(srcIov[0].iov_base, dstIov[0].iov_base);
    EXPECT_EQ(srcIov[0].iov_len, dstIov[0].iov_len);
    EXPECT_EQ(srcIov[1].iov_base, dstIov[1].iov_base);
    EXPECT_EQ(srcIov[1].iov_len, dstIov[1].iov_len);

    // last iovec is two bytes
    EXPECT_EQ(2, dstIov[2].iov_len);
    EXPECT_EQ(srcIov[2].iov_base, dstIov[2].iov_base);
  }

  //
  // test when startOffset is near an iovec boundary
  //

  // split 49 -> 99
  {
    std::vector<struct iovec> dstIov(4);
    size_t dstIovCount = dstIov.size();
    AsyncSocket::splitIovecArray(
        49, 99, data, srcIov.size(), dstIov.data(), dstIovCount);

    ASSERT_EQ(3, dstIovCount);

    // first dst iovec is one byte, starts 24 bytes into the second src iovec
    EXPECT_EQ(1, dstIov[0].iov_len);
    EXPECT_EQ(
        dstIov[0].iov_base,
        const_cast<void*>(static_cast<const void*>(
            reinterpret_cast<uint8_t*>(srcIov[1].iov_base) + 24)));

    // second dst iovec is third src iovec
    // third dst iovec is fourth src iovec
    EXPECT_EQ(dstIov[1].iov_base, srcIov[2].iov_base);
    EXPECT_EQ(dstIov[1].iov_len, srcIov[2].iov_len);
    EXPECT_EQ(dstIov[2].iov_base, srcIov[3].iov_base);
    EXPECT_EQ(dstIov[2].iov_len, srcIov[3].iov_len);
  }

  // split 50 -> 99
  {
    std::vector<struct iovec> dstIov(4);
    size_t dstIovCount = dstIov.size();
    AsyncSocket::splitIovecArray(
        50, 99, data, srcIov.size(), dstIov.data(), dstIovCount);

    ASSERT_EQ(2, dstIovCount);

    // first dst iovec is third src iovec
    // second dst iovec is fourth src iovec
    EXPECT_EQ(dstIov[0].iov_base, srcIov[2].iov_base);
    EXPECT_EQ(dstIov[0].iov_len, srcIov[2].iov_len);
    EXPECT_EQ(dstIov[1].iov_base, srcIov[3].iov_base);
    EXPECT_EQ(dstIov[1].iov_len, srcIov[3].iov_len);
  }

  // split 51 -> 99
  {
    std::vector<struct iovec> dstIov(4);
    size_t dstIovCount = dstIov.size();
    AsyncSocket::splitIovecArray(
        51, 99, data, srcIov.size(), dstIov.data(), dstIovCount);

    ASSERT_EQ(2, dstIovCount);

    // first dst iovec is 24 bytes, starts 1 byte into the third src iovec
    EXPECT_EQ(24, dstIov[0].iov_len);
    EXPECT_EQ(
        dstIov[0].iov_base,
        const_cast<void*>(static_cast<const void*>(
            reinterpret_cast<uint8_t*>(srcIov[2].iov_base) + 1)));

    // second dst iovec is fourth src iovec
    EXPECT_EQ(dstIov[1].iov_base, srcIov[3].iov_base);
    EXPECT_EQ(dstIov[1].iov_len, srcIov[3].iov_len);
  }

  //
  // test when startOffset and endOffset are near iovec boundaries
  //

  // split 49 -> 49
  {
    std::vector<struct iovec> dstIov(4);
    size_t dstIovCount = dstIov.size();
    AsyncSocket::splitIovecArray(
        49, 49, data, srcIov.size(), dstIov.data(), dstIovCount);

    ASSERT_EQ(1, dstIovCount);

    // first dst iovec is one byte, starts 24 bytes into the second src iovec
    EXPECT_EQ(1, dstIov[0].iov_len);
    EXPECT_EQ(
        dstIov[0].iov_base,
        const_cast<void*>(static_cast<const void*>(
            reinterpret_cast<uint8_t*>(srcIov[1].iov_base) + 24)));
  }

  // split 49 -> 50
  {
    std::vector<struct iovec> dstIov(4);
    size_t dstIovCount = dstIov.size();
    AsyncSocket::splitIovecArray(
        49, 50, data, srcIov.size(), dstIov.data(), dstIovCount);

    ASSERT_EQ(2, dstIovCount);

    // first dst iovec is one byte, starts 24 bytes into the second src iovec
    EXPECT_EQ(1, dstIov[0].iov_len);
    EXPECT_EQ(
        dstIov[0].iov_base,
        const_cast<void*>(static_cast<const void*>(
            reinterpret_cast<uint8_t*>(srcIov[1].iov_base) + 24)));

    // second iovec is one byte, starts at the third src iovec
    EXPECT_EQ(1, dstIov[1].iov_len);
    EXPECT_EQ(dstIov[1].iov_base, srcIov[2].iov_base);
  }

  // split 49 -> 51
  {
    std::vector<struct iovec> dstIov(4);
    size_t dstIovCount = dstIov.size();
    AsyncSocket::splitIovecArray(
        49, 51, data, srcIov.size(), dstIov.data(), dstIovCount);

    ASSERT_EQ(2, dstIovCount);

    // first dst iovec is one byte, starts 24 bytes into the second src iovec
    EXPECT_EQ(1, dstIov[0].iov_len);
    EXPECT_EQ(
        dstIov[0].iov_base,
        const_cast<void*>(static_cast<const void*>(
            reinterpret_cast<uint8_t*>(srcIov[1].iov_base) + 24)));

    // second iovec is two bytes, starts at the third src iovec
    EXPECT_EQ(2, dstIov[1].iov_len);
    EXPECT_EQ(dstIov[1].iov_base, srcIov[2].iov_base);
  }

  // split 50 -> 50
  {
    std::vector<struct iovec> dstIov(4);
    size_t dstIovCount = dstIov.size();
    AsyncSocket::splitIovecArray(
        50, 50, data, srcIov.size(), dstIov.data(), dstIovCount);

    ASSERT_EQ(1, dstIovCount);

    // first dst iovec is one byte, starts at the third src iovec
    EXPECT_EQ(1, dstIov[0].iov_len);
    EXPECT_EQ(dstIov[0].iov_base, srcIov[2].iov_base);
  }

  // split 50 -> 51
  {
    std::vector<struct iovec> dstIov(4);
    size_t dstIovCount = dstIov.size();
    AsyncSocket::splitIovecArray(
        50, 51, data, srcIov.size(), dstIov.data(), dstIovCount);

    ASSERT_EQ(1, dstIovCount);

    // first dst iovec is two bytes, starts at the third src iovec
    EXPECT_EQ(2, dstIov[0].iov_len);
    EXPECT_EQ(dstIov[0].iov_base, srcIov[2].iov_base);
  }

  // split 51 -> 51
  {
    std::vector<struct iovec> dstIov(4);
    size_t dstIovCount = dstIov.size();
    AsyncSocket::splitIovecArray(
        51, 51, data, srcIov.size(), dstIov.data(), dstIovCount);

    ASSERT_EQ(1, dstIovCount);

    // first dst iovec is one byte, starts 1 byte into the third src iovec
    EXPECT_EQ(1, dstIov[0].iov_len);
    EXPECT_EQ(
        dstIov[0].iov_base,
        const_cast<void*>(static_cast<const void*>(
            reinterpret_cast<uint8_t*>(srcIov[2].iov_base) + 1)));
  }

  // split 48 -> 98
  {
    std::vector<struct iovec> dstIov(4);
    size_t dstIovCount = dstIov.size();
    AsyncSocket::splitIovecArray(
        48, 98, data, srcIov.size(), dstIov.data(), dstIovCount);

    ASSERT_EQ(3, dstIovCount);

    // first dst iovec is two bytes, starts 23 bytes into the second src iovec
    EXPECT_EQ(2, dstIov[0].iov_len);
    EXPECT_EQ(
        dstIov[0].iov_base,
        const_cast<void*>(static_cast<const void*>(
            reinterpret_cast<uint8_t*>(srcIov[1].iov_base) + 23)));

    // second dst iovec is third src iovec
    EXPECT_EQ(dstIov[1].iov_base, srcIov[2].iov_base);
    EXPECT_EQ(dstIov[1].iov_len, srcIov[2].iov_len);

    // third dst iovec is 24 bytes, starts at the fourth src iovec
    EXPECT_EQ(24, dstIov[2].iov_len);
    EXPECT_EQ(dstIov[2].iov_base, srcIov[3].iov_base);
  }

  // split 49 -> 98
  {
    std::vector<struct iovec> dstIov(4);
    size_t dstIovCount = dstIov.size();
    AsyncSocket::splitIovecArray(
        49, 98, data, srcIov.size(), dstIov.data(), dstIovCount);

    ASSERT_EQ(3, dstIovCount);

    // first dst iovec is one byte, starts 24 bytes into the second src iovec
    EXPECT_EQ(1, dstIov[0].iov_len);
    EXPECT_EQ(
        dstIov[0].iov_base,
        const_cast<void*>(static_cast<const void*>(
            reinterpret_cast<uint8_t*>(srcIov[1].iov_base) + 24)));

    // second dst iovec is third src iovec
    EXPECT_EQ(dstIov[1].iov_base, srcIov[2].iov_base);
    EXPECT_EQ(dstIov[1].iov_len, srcIov[2].iov_len);

    // third dst iovec is 24 bytes, starts at the fourth src iovec
    EXPECT_EQ(24, dstIov[2].iov_len);
    EXPECT_EQ(dstIov[2].iov_base, srcIov[3].iov_base);
  }

  // split 50 -> 98
  {
    std::vector<struct iovec> dstIov(4);
    size_t dstIovCount = dstIov.size();
    AsyncSocket::splitIovecArray(
        50, 98, data, srcIov.size(), dstIov.data(), dstIovCount);

    ASSERT_EQ(2, dstIovCount);

    // first dst iovec is third src iovec
    EXPECT_EQ(dstIov[0].iov_base, srcIov[2].iov_base);
    EXPECT_EQ(dstIov[0].iov_len, srcIov[2].iov_len);

    // second dst iovec is 24 bytes, starts at the fourth src iovec
    EXPECT_EQ(24, dstIov[1].iov_len);
    EXPECT_EQ(dstIov[1].iov_base, srcIov[3].iov_base);
  }

  // split 51 -> 98
  {
    std::vector<struct iovec> dstIov(4);
    size_t dstIovCount = dstIov.size();
    AsyncSocket::splitIovecArray(
        51, 98, data, srcIov.size(), dstIov.data(), dstIovCount);

    ASSERT_EQ(2, dstIovCount);

    // first dst iovec is 24 bytes, starts 1 byte into the third src iovec
    EXPECT_EQ(24, dstIov[0].iov_len);
    EXPECT_EQ(
        dstIov[0].iov_base,
        const_cast<void*>(static_cast<const void*>(
            reinterpret_cast<uint8_t*>(srcIov[2].iov_base) + 1)));

    // second dst iovec is 24 bytes, starts at the fourth src iovec
    EXPECT_EQ(24, dstIov[1].iov_len);
    EXPECT_EQ(dstIov[1].iov_base, srcIov[3].iov_base);
  }
}

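/**
 * Sanity check the SendmsgInvocHasTotalIovLen, SendmsgInvocHasIovFirstByte,
 * SendmsgInvocHasIovLastByte, and SendmsgMsghdrHasTotalIovLen matchers against
 * hand-built SendmsgInvocation and msghdr structures.
 */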
TEST_F(AsyncSocketByteEventTest, SendmsgMatchers) {
  // empty
  {
    const ClientConn::SendmsgInvocation sendmsgInvoc = {};
    // length
    EXPECT_THAT(sendmsgInvoc, SendmsgInvocHasTotalIovLen(size_t(0)));

    // iov first byte
    EXPECT_THAT(
        sendmsgInvoc,
        Not(SendmsgInvocHasIovFirstByte(kOneHundredCharacterVec.data())));
    EXPECT_THAT(
        sendmsgInvoc,
        Not(SendmsgInvocHasIovFirstByte(kOneHundredCharacterVec.data() + 5)));

    // iov last byte
    EXPECT_THAT(
        sendmsgInvoc,
        Not(SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data())));
    EXPECT_THAT(
        sendmsgInvoc,
        Not(SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 5)));
  }

  // single iov, last byte = end of kOneHundredCharacterVec
  {
    struct iovec iov = {};
    iov.iov_base = const_cast<void*>(
        static_cast<const void*>((kOneHundredCharacterVec.data())));
    iov.iov_len = kOneHundredCharacterVec.size();
    const ClientConn::SendmsgInvocation sendmsgInvoc = {.iovs = {iov}};

    struct msghdr msg = {};
    msg.msg_name = nullptr;
    msg.msg_namelen = 0;
    msg.msg_iov = const_cast<struct iovec*>(sendmsgInvoc.iovs.data());
    msg.msg_iovlen = sendmsgInvoc.iovs.size();

    // length
    EXPECT_THAT(sendmsgInvoc, SendmsgInvocHasTotalIovLen(size_t(100)));
    EXPECT_THAT(msg, SendmsgMsghdrHasTotalIovLen(size_t(100)));

    // iov first byte
    EXPECT_THAT(
        sendmsgInvoc,
        SendmsgInvocHasIovFirstByte(kOneHundredCharacterVec.data()));
    EXPECT_THAT(
        sendmsgInvoc,
        Not(SendmsgInvocHasIovFirstByte(
            kOneHundredCharacterVec.data() + kOneHundredCharacterVec.size() -
            1)));

    // iov last byte
    EXPECT_THAT(
        sendmsgInvoc,
        Not(SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data())));
    EXPECT_THAT(
        sendmsgInvoc,
        SendmsgInvocHasIovLastByte(
            kOneHundredCharacterVec.data() + kOneHundredCharacterVec.size() -
            1));
  }

  // single iov, first and last byte = start of kOneHundredCharacterVec
  {
    struct iovec iov = {};
    iov.iov_base = const_cast<void*>(
        static_cast<const void*>((kOneHundredCharacterVec.data())));
    iov.iov_len = 1;
    const ClientConn::SendmsgInvocation sendmsgInvoc = {.iovs = {iov}};

    struct msghdr msg = {};
    msg.msg_name = nullptr;
    msg.msg_namelen = 0;
    msg.msg_iov = const_cast<struct iovec*>(sendmsgInvoc.iovs.data());
    msg.msg_iovlen = sendmsgInvoc.iovs.size();

    // length
    EXPECT_THAT(sendmsgInvoc, SendmsgInvocHasTotalIovLen(size_t(1)));
    EXPECT_THAT(msg, SendmsgMsghdrHasTotalIovLen(size_t(1)));

    // iov first byte
    EXPECT_THAT(
        sendmsgInvoc,
        SendmsgInvocHasIovFirstByte(kOneHundredCharacterVec.data()));
    EXPECT_THAT(
        sendmsgInvoc,
        Not(SendmsgInvocHasIovFirstByte(
            kOneHundredCharacterVec.data() + kOneHundredCharacterVec.size() -
            1)));

    // iov last byte
    EXPECT_THAT(
        sendmsgInvoc,
        SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data()));
    EXPECT_THAT(
        sendmsgInvoc,
        Not(SendmsgInvocHasIovLastByte(
            kOneHundredCharacterVec.data() + kOneHundredCharacterVec.size() -
            1)));
  }

  // single iov, first and last byte = end of kOneHundredCharacterVec
  {
    struct iovec iov = {};
    iov.iov_base = const_cast<void*>(static_cast<const void*>(
        (kOneHundredCharacterVec.data() + kOneHundredCharacterVec.size())));
    iov.iov_len = 1;
    const ClientConn::SendmsgInvocation sendmsgInvoc = {.iovs = {iov}};

    struct msghdr msg = {};
    msg.msg_name = nullptr;
    msg.msg_namelen = 0;
    msg.msg_iov = const_cast<struct iovec*>(sendmsgInvoc.iovs.data());
    msg.msg_iovlen = sendmsgInvoc.iovs.size();

    // length
    EXPECT_THAT(sendmsgInvoc, SendmsgInvocHasTotalIovLen(size_t(1)));
    EXPECT_THAT(msg, SendmsgMsghdrHasTotalIovLen(size_t(1)));

    // iov first byte
    EXPECT_THAT(
        sendmsgInvoc,
        SendmsgInvocHasIovFirstByte(
            kOneHundredCharacterVec.data() + kOneHundredCharacterVec.size()));

    // iov last byte
    EXPECT_THAT(
        sendmsgInvoc,
        SendmsgInvocHasIovLastByte(
            kOneHundredCharacterVec.data() + kOneHundredCharacterVec.size()));
  }

  // two iov, (0 -> 0, 1 -> 99), last byte = end of kOneHundredCharacterVec
  {
    struct iovec iov1 = {};
    iov1.iov_base = const_cast<void*>(
        static_cast<const void*>((kOneHundredCharacterVec.data())));
    iov1.iov_len = 1;

    struct iovec iov2 = {};
    iov2.iov_base = const_cast<void*>(
        static_cast<const void*>((kOneHundredCharacterVec.data() + 1)));
    iov2.iov_len = 99;

    const ClientConn::SendmsgInvocation sendmsgInvoc = {.iovs = {iov1, iov2}};

    struct msghdr msg = {};
    msg.msg_name = nullptr;
    msg.msg_namelen = 0;
    msg.msg_iov = const_cast<struct iovec*>(sendmsgInvoc.iovs.data());
    msg.msg_iovlen = sendmsgInvoc.iovs.size();

    // length
    EXPECT_THAT(sendmsgInvoc, SendmsgInvocHasTotalIovLen(size_t(100)));
    EXPECT_THAT(msg, SendmsgMsghdrHasTotalIovLen(size_t(100)));

    // iov first byte
    EXPECT_THAT(
        sendmsgInvoc,
        SendmsgInvocHasIovFirstByte(kOneHundredCharacterVec.data()));

    // iov last byte
    EXPECT_THAT(
        sendmsgInvoc,
        SendmsgInvocHasIovLastByte(
            kOneHundredCharacterVec.data() + kOneHundredCharacterVec.size() -
            1));
  }

  // two iov, (0 -> 49, 50 -> 99), last byte = end of kOneHundredCharacterVec
  {
    struct iovec iov1 = {};
    iov1.iov_base = const_cast<void*>(
        static_cast<const void*>((kOneHundredCharacterVec.data())));
    iov1.iov_len = 50;

    struct iovec iov2 = {};
    iov2.iov_base = const_cast<void*>(
        static_cast<const void*>((kOneHundredCharacterVec.data() + 50)));
    iov2.iov_len = 50;

    const ClientConn::SendmsgInvocation sendmsgInvoc = {.iovs = {iov1, iov2}};

    struct msghdr msg = {};
    msg.msg_name = nullptr;
    msg.msg_namelen = 0;
    msg.msg_iov = const_cast<struct iovec*>(sendmsgInvoc.iovs.data());
    msg.msg_iovlen = sendmsgInvoc.iovs.size();

    // length
    EXPECT_THAT(sendmsgInvoc, SendmsgInvocHasTotalIovLen(size_t(100)));
    EXPECT_THAT(msg, SendmsgMsghdrHasTotalIovLen(size_t(100)));

    // iov first byte
    EXPECT_THAT(
        sendmsgInvoc,
        SendmsgInvocHasIovFirstByte(kOneHundredCharacterVec.data()));

    // iov last byte
    EXPECT_THAT(
        sendmsgInvoc,
        SendmsgInvocHasIovLastByte(
            kOneHundredCharacterVec.data() + kOneHundredCharacterVec.size() -
            1));
  }

  // two iov, (0 -> 49, 50 -> 98), last byte = penultimate byte
  {
    struct iovec iov1 = {};
    iov1.iov_base = const_cast<void*>(
        static_cast<const void*>((kOneHundredCharacterVec.data())));
    iov1.iov_len = 50;

    struct iovec iov2 = {};
    iov2.iov_base = const_cast<void*>(
        static_cast<const void*>((kOneHundredCharacterVec.data() + 50)));
    iov2.iov_len = 49;

    const ClientConn::SendmsgInvocation sendmsgInvoc = {.iovs = {iov1, iov2}};

    struct msghdr msg = {};
    msg.msg_name = nullptr;
    msg.msg_namelen = 0;
    msg.msg_iov = const_cast<struct iovec*>(sendmsgInvoc.iovs.data());
    msg.msg_iovlen = sendmsgInvoc.iovs.size();

    // length
    EXPECT_THAT(sendmsgInvoc, SendmsgInvocHasTotalIovLen(size_t(99)));
    EXPECT_THAT(msg, SendmsgMsghdrHasTotalIovLen(size_t(99)));

    // iov first byte
    EXPECT_THAT(
        sendmsgInvoc,
        SendmsgInvocHasIovFirstByte(kOneHundredCharacterVec.data()));

    // iov last byte
    EXPECT_THAT(
        sendmsgInvoc,
        SendmsgInvocHasIovLastByte(
            kOneHundredCharacterVec.data() + kOneHundredCharacterVec.size() -
            2));
  }
}

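/**
 * Sanity check the SendmsgInvocMsgFlagsEq matcher; it must match the write
 * flags recorded from msg_flags exactly.
 */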
TEST_F(AsyncSocketByteEventTest, SendmsgInvocMsgFlagsEq) {
  // empty
  {
    const ClientConn::SendmsgInvocation sendmsgInvoc;
    EXPECT_THAT(sendmsgInvoc, SendmsgInvocMsgFlagsEq(WriteFlags::NONE));
    EXPECT_THAT(sendmsgInvoc, Not(SendmsgInvocMsgFlagsEq(WriteFlags::CORK)));
  }

  // flag set
  {
    ClientConn::SendmsgInvocation sendmsgInvoc = {};
    sendmsgInvoc.writeFlagsInMsgFlags = WriteFlags::CORK;
    EXPECT_THAT(sendmsgInvoc, Not(SendmsgInvocMsgFlagsEq(WriteFlags::NONE)));
    EXPECT_THAT(
        sendmsgInvoc,
        Not(SendmsgInvocMsgFlagsEq(
            WriteFlags::EOR | WriteFlags::CORK))); // should be exact match
    EXPECT_THAT(sendmsgInvoc, SendmsgInvocMsgFlagsEq(WriteFlags::CORK));
  }
}

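/**
 * Sanity check the SendmsgInvocAncillaryFlagsEq matcher; it must match the
 * recorded ancillary timestamping flags exactly.
 */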
TEST_F(AsyncSocketByteEventTest, SendmsgInvocAncillaryFlagsEq) {
  // empty
  {
    const ClientConn::SendmsgInvocation sendmsgInvoc;
    EXPECT_THAT(sendmsgInvoc, SendmsgInvocAncillaryFlagsEq(WriteFlags::NONE));
    EXPECT_THAT(
        sendmsgInvoc,
        Not(SendmsgInvocAncillaryFlagsEq(WriteFlags::TIMESTAMP_TX)));
  }

  // flag set
  {
    ClientConn::SendmsgInvocation sendmsgInvoc = {};
    sendmsgInvoc.writeFlagsInAncillary = WriteFlags::TIMESTAMP_TX;
    EXPECT_THAT(
        sendmsgInvoc, Not(SendmsgInvocAncillaryFlagsEq(WriteFlags::NONE)));
    EXPECT_THAT(
        sendmsgInvoc,
        Not(SendmsgInvocAncillaryFlagsEq(
            WriteFlags::TIMESTAMP_TX |
            WriteFlags::TIMESTAMP_ACK))); // should be exact match
    EXPECT_THAT(
        sendmsgInvoc, SendmsgInvocAncillaryFlagsEq(WriteFlags::TIMESTAMP_TX));
  }
}

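/**
 * Sanity check the ByteEventMatching matcher; both the ByteEvent type and the
 * offset must match.
 */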
TEST_F(AsyncSocketByteEventTest, ByteEventMatching) {
  // offset = 0, type = WRITE
  {
    AsyncSocket::ByteEvent event = {};
    event.type = ByteEventType::WRITE;
    event.offset = 0;
    EXPECT_THAT(event, ByteEventMatching(ByteEventType::WRITE, 0));

    // not matching
    EXPECT_THAT(event, Not(ByteEventMatching(ByteEventType::WRITE, 10)));
    EXPECT_THAT(event, Not(ByteEventMatching(ByteEventType::TX, 0)));
    EXPECT_THAT(event, Not(ByteEventMatching(ByteEventType::ACK, 0)));
    EXPECT_THAT(event, Not(ByteEventMatching(ByteEventType::SCHED, 0)));
  }

  // offset = 10, type = TX
  {
    AsyncSocket::ByteEvent event = {};
    event.type = ByteEventType::TX;
    event.offset = 10;
    EXPECT_THAT(event, ByteEventMatching(ByteEventType::TX, 10));

    // not matching
    EXPECT_THAT(event, Not(ByteEventMatching(ByteEventType::TX, 0)));
    EXPECT_THAT(event, Not(ByteEventMatching(ByteEventType::WRITE, 10)));
    EXPECT_THAT(event, Not(ByteEventMatching(ByteEventType::ACK, 10)));
    EXPECT_THAT(event, Not(ByteEventMatching(ByteEventType::SCHED, 10)));
  }
}

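/**
 * Single observer with prewrite enabled splits the write at offsets 0, 50, and
 * 98 with timestamp flags; verify the resulting sendmsg invocations (CORK on
 * every split except the final one) and the WRITE ByteEvents.
 */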
TEST_F(AsyncSocketByteEventTest, PrewriteSingleObserver) {
  auto clientConn = getClientConn();
  clientConn.connect();
  auto observer = clientConn.attachObserver(
      true /* enableByteEvents */, true /* enablePrewrite */);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());

  clientConn.netOpsOnSendmsgRecordIovecsAndFlagsAndFwd();

  const auto flags = WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK |
      WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_WRITE;
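  // prewrite handler: request a split with timestamp flags at offsets 0, 50,
  // and 98; the remaining tail (offset 99) is written without flags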
  ON_CALL(*observer, prewriteMock(_, _, _))
      .WillByDefault(testing::Invoke(
          [](AsyncTransport*,
             const AsyncSocketObserverInterface::PrewriteState& state,
             AsyncSocketObserverInterface::PrewriteRequestContainer&
                 container) {
            AsyncSocketObserverInterface::PrewriteRequest request;
            if (state.startOffset == 0) {
              request.maybeOffsetToSplitWrite = 0;
            } else if (state.startOffset <= 50) {
              request.maybeOffsetToSplitWrite = 50;
            } else if (state.startOffset <= 98) {
              request.maybeOffsetToSplitWrite = 98;
            }

            request.writeFlagsToAddAtOffset = flags;
            container.addRequest(request);
          }));
  clientConn.writeAtClientReadAtServerReflectReadAtClient(
      kOneHundredCharacterVec, WriteFlags::NONE);

  EXPECT_THAT(
      clientConn.getSendmsgInvocations(),
      ElementsAre(
          AllOf(
              SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data()),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(dropWriteFromFlags(flags))),
          AllOf(
              SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 50),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(dropWriteFromFlags(flags))),
          AllOf(
              SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 98),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(dropWriteFromFlags(flags))),
          AllOf(
              SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 99),
              SendmsgInvocMsgFlagsEq(WriteFlags::NONE),
              SendmsgInvocAncillaryFlagsEq(WriteFlags::NONE))));

  // verify WRITE events exist at the appropriate locations
  // we verify timestamp events are generated elsewhere
  //
  // should _not_ contain an event for offset 99 since prewrite requested no
  // split there
  EXPECT_THAT(
      filterToWriteEvents(observer->byteEvents),
      ElementsAre(
          ByteEventMatching(ByteEventType::WRITE, 0),
          ByteEventMatching(ByteEventType::WRITE, 50),
          ByteEventMatching(ByteEventType::WRITE, 98)));
}

/**
 * Test explicitly that CORK (MSG_MORE) is set if write is split in middle.
 */
TEST_F(AsyncSocketByteEventTest, PrewriteSingleObserverCorkIfSplitMiddle) {
  auto clientConn = getClientConn();
  clientConn.connect();
  auto observer = clientConn.attachObserver(
      true /* enableByteEvents */, true /* enablePrewrite */);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());

  clientConn.netOpsOnSendmsgRecordIovecsAndFlagsAndFwd();

  const auto flags = WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK |
      WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_WRITE;
  ON_CALL(*observer, prewriteMock(_, _, _))
      .WillByDefault(testing::Invoke(
          [](AsyncTransport*,
             const AsyncSocketObserverInterface::PrewriteState& state,
             AsyncSocketObserverInterface::PrewriteRequestContainer&
                 container) {
            AsyncSocketObserverInterface::PrewriteRequest request;
            if (state.startOffset <= 50) {
              request.maybeOffsetToSplitWrite = 50;
            }
            request.writeFlagsToAddAtOffset = flags;
            container.addRequest(request);
          }));
  clientConn.writeAtClientReadAtServerReflectReadAtClient(
      kOneHundredCharacterVec, WriteFlags::NONE);

  EXPECT_THAT(
      clientConn.getSendmsgInvocations(),
      ElementsAre(
          AllOf(
              SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 50),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(dropWriteFromFlags(flags))),
          AllOf(
              SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 99),
              SendmsgInvocMsgFlagsEq(WriteFlags::NONE),
              SendmsgInvocAncillaryFlagsEq(WriteFlags::NONE))));

  // verify WRITE events exist at the appropriate locations
  // we verify timestamp events are generated elsewhere
  EXPECT_THAT(
      filterToWriteEvents(observer->byteEvents),
      ElementsAre(ByteEventMatching(ByteEventType::WRITE, 50)));
}

/**
 * Test explicitly that CORK (MSG_MORE) is not set if the write is split at
 * its end.
 */
TEST_F(AsyncSocketByteEventTest, PrewriteSingleObserverNoCorkIfSplitAtEnd) {
  auto clientConn = getClientConn();
  clientConn.connect();
  auto observer = clientConn.attachObserver(
      true /* enableByteEvents */, true /* enablePrewrite */);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());

  clientConn.netOpsOnSendmsgRecordIovecsAndFlagsAndFwd();

  const auto flags = WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK |
      WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_WRITE;
  ON_CALL(*observer, prewriteMock(_, _, _))
      .WillByDefault(testing::Invoke(
          [](AsyncTransport*,
             const AsyncSocketObserverInterface::PrewriteState& state,
             AsyncSocketObserverInterface::PrewriteRequestContainer&
                 container) {
            AsyncSocketObserverInterface::PrewriteRequest request;
            if (state.startOffset <= 99) {
              request.maybeOffsetToSplitWrite = 99;
            }
            request.writeFlagsToAddAtOffset = flags;
            container.addRequest(request);
          }));
  clientConn.writeAtClientReadAtServerReflectReadAtClient(
      kOneHundredCharacterVec, WriteFlags::NONE);

  EXPECT_THAT(
      clientConn.getSendmsgInvocations(),
      ElementsAre(AllOf(
          SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 99),
          SendmsgInvocMsgFlagsEq(WriteFlags::NONE), // no cork!
          SendmsgInvocAncillaryFlagsEq(dropWriteFromFlags(flags)))));

  // verify WRITE events exist at the appropriate locations
  // we verify timestamp events are generated elsewhere
  EXPECT_THAT(
      filterToWriteEvents(observer->byteEvents),
      ElementsAre(ByteEventMatching(ByteEventType::WRITE, 99)));
}

/**
 * Test explicitly that split flags are NOT added if no split.
 */
TEST_F(AsyncSocketByteEventTest, PrewriteSingleObserverNoSplitFlagsIfNoSplit) {
  auto clientConn = getClientConn();
  clientConn.connect();
  auto observer = clientConn.attachObserver(
      true /* enableByteEvents */, true /* enablePrewrite */);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());

  clientConn.netOpsOnSendmsgRecordIovecsAndFlagsAndFwd();

  const auto flags = WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK |
      WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_WRITE;
  ON_CALL(*observer, prewriteMock(_, _, _))
      .WillByDefault(testing::Invoke(
          [](AsyncTransport*,
             const AsyncSocketObserverInterface::PrewriteState& /* state */,
             AsyncSocketObserverInterface::PrewriteRequestContainer&
                 container) {
            AsyncSocketObserverInterface::PrewriteRequest request;
            request.writeFlagsToAddAtOffset = flags;
            container.addRequest(request);
          }));
  clientConn.writeAtClientReadAtServerReflectReadAtClient(
      kOneHundredCharacterVec, WriteFlags::NONE);

  EXPECT_THAT(
      clientConn.getSendmsgInvocations(),
      ElementsAre(AllOf(
          SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 99),
          SendmsgInvocMsgFlagsEq(WriteFlags::NONE),
          SendmsgInvocAncillaryFlagsEq(WriteFlags::NONE))));
}

/**
 * Test more combinations of prewrite flags, including writeFlagsToAdd.
 */
TEST_F(AsyncSocketByteEventTest, PrewriteSingleObserverFlagsOnAll) {
  auto clientConn = getClientConn();
  clientConn.connect();
  auto observer = clientConn.attachObserver(
      true /* enableByteEvents */, true /* enablePrewrite */);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());

  clientConn.netOpsOnSendmsgRecordIovecsAndFlagsAndFwd();
  ON_CALL(*observer, prewriteMock(_, _, _))
      .WillByDefault(testing::Invoke(
          [](AsyncTransport*,
             const AsyncSocketObserverInterface::PrewriteState& state,
             AsyncSocketObserverInterface::PrewriteRequestContainer&
                 container) {
            AsyncSocketObserverInterface::PrewriteRequest request;
            if (state.startOffset == 0) {
              request.maybeOffsetToSplitWrite = 0;
              request.writeFlagsToAddAtOffset |= WriteFlags::TIMESTAMP_WRITE;
            } else if (state.startOffset <= 10) {
              request.maybeOffsetToSplitWrite = 10;
              request.writeFlagsToAddAtOffset |= WriteFlags::TIMESTAMP_SCHED;
            } else if (state.startOffset <= 20) {
              request.writeFlagsToAddAtOffset |= WriteFlags::TIMESTAMP_TX;
              request.maybeOffsetToSplitWrite = 20;
            } else if (state.startOffset <= 30) {
              request.writeFlagsToAddAtOffset |= WriteFlags::TIMESTAMP_ACK;
              request.maybeOffsetToSplitWrite = 30;
            } else if (state.startOffset <= 40) {
              request.writeFlagsToAddAtOffset |= WriteFlags::TIMESTAMP_TX;
              request.writeFlagsToAdd |= WriteFlags::TIMESTAMP_WRITE;
              request.maybeOffsetToSplitWrite = 40;
            } else {
              request.writeFlagsToAdd |= WriteFlags::TIMESTAMP_WRITE;
            }

            container.addRequest(request);
          }));
  clientConn.writeAtClientReadAtServerReflectReadAtClient(
      kOneHundredCharacterVec, WriteFlags::NONE);

  EXPECT_THAT(
      clientConn.getSendmsgInvocations(),
      ElementsAre(
          AllOf(
              SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data()),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(WriteFlags::NONE)),
          AllOf(
              SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 10),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(WriteFlags::TIMESTAMP_SCHED)),
          AllOf(
              SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 20),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(WriteFlags::TIMESTAMP_TX)),
          AllOf(
              SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 30),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(WriteFlags::TIMESTAMP_ACK)),
          AllOf(
              SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 40),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(WriteFlags::TIMESTAMP_TX)),
          AllOf(
              SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 99),
              SendmsgInvocMsgFlagsEq(WriteFlags::NONE),
              SendmsgInvocAncillaryFlagsEq(WriteFlags::NONE))));

  // verify WRITE events exist at the appropriate locations
  // we verify timestamp events are generated elsewhere
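  //
  // WRITE events appear only where TIMESTAMP_WRITE was requested: at offset 0
  // (via writeFlagsToAddAtOffset) and at offsets 40 and 99 (via
  // writeFlagsToAdd); the chunks ending at 10, 20, and 30 never request it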
  EXPECT_THAT(
      filterToWriteEvents(observer->byteEvents),
      ElementsAre(
          ByteEventMatching(ByteEventType::WRITE, 0),
          ByteEventMatching(ByteEventType::WRITE, 40),
          ByteEventMatching(ByteEventType::WRITE, 99)));
}

/**
 * Test merging of write flags with those passed to AsyncSocket::write().
 */
TEST_F(AsyncSocketByteEventTest, PrewriteSingleObserverFlagsOnWrite) {
  auto clientConn = getClientConn();
  clientConn.connect();
  auto observer = clientConn.attachObserver(
      true /* enableByteEvents */, true /* enablePrewrite */);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());

  clientConn.netOpsOnSendmsgRecordIovecsAndFlagsAndFwd();

  // for the first byte the observer adds TX and WRITE; for all later bytes it
  // just adds WRITE
  ON_CALL(*observer, prewriteMock(_, _, _))
      .WillByDefault(testing::Invoke(
          [](AsyncTransport*,
             const AsyncSocketObserverInterface::PrewriteState& state,
             AsyncSocketObserverInterface::PrewriteRequestContainer&
                 container) {
            AsyncSocketObserverInterface::PrewriteRequest request;
            if (state.startOffset == 0) {
              request.maybeOffsetToSplitWrite = 0;
              request.writeFlagsToAddAtOffset |= WriteFlags::TIMESTAMP_TX;
            }
            request.writeFlagsToAdd |= WriteFlags::TIMESTAMP_WRITE;

            container.addRequest(request);
          }));

  // application does a write with ACK and CORK set
  clientConn.writeAtClientReadAtServerReflectReadAtClient(
      kOneHundredCharacterVec, WriteFlags::CORK | WriteFlags::TIMESTAMP_ACK);

  // make sure we have the merge
  //   first write, TX is added
  //   second write, CORK is passed through
  EXPECT_THAT(
      clientConn.getSendmsgInvocations(),
      ElementsAre(
          AllOf(
              SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data()),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK), // set by split
              SendmsgInvocAncillaryFlagsEq(
                  WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK)),
          AllOf(
              SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 99),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK), // still set
              SendmsgInvocAncillaryFlagsEq(
                  dropWriteFromFlags(WriteFlags::TIMESTAMP_ACK)))));

  // verify WRITE events exist at the appropriate locations
  // we verify timestamp events are generated elsewhere
  EXPECT_THAT(
      filterToWriteEvents(observer->byteEvents),
      ElementsAre(
          ByteEventMatching(ByteEventType::WRITE, 0),
          ByteEventMatching(ByteEventType::WRITE, 99)));
}

/**
 * Test invalid offset for prewrite, ensure death via CHECK.
 */
TEST_F(AsyncSocketByteEventTest, PrewriteSingleObserverInvalidOffset) {
  auto clientConn = getClientConn();
  clientConn.connect();
  auto observer = clientConn.attachObserver(
      true /* enableByteEvents */, true /* enablePrewrite */);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());

  clientConn.netOpsOnSendmsgRecordIovecsAndFlagsAndFwd();

  ON_CALL(*observer, prewriteMock(_, _, _))
      .WillByDefault(testing::Invoke(
          [](AsyncTransport*,
             const AsyncSocketObserverInterface::PrewriteState& state,
             AsyncSocketObserverInterface::PrewriteRequestContainer&
                 container) {
            AsyncSocketObserverInterface::PrewriteRequest request;
            EXPECT_GT(200, state.endOffset);
            request.maybeOffsetToSplitWrite = 200; // invalid
            container.addRequest(request);
          }));

  // check will fail due to invalid offset
  EXPECT_DEATH(
      clientConn.writeAtClientReadAtServerReflectReadAtClient(
          kOneHundredCharacterVec, WriteFlags::NONE),
      ".*");
}

/**
 * Test prewrite with multiple iovecs.
 */
TEST_F(AsyncSocketByteEventTest, PrewriteSingleObserverTwoIovec) {
  // two iovec, each with half of the kOneHundredCharacterVec
  std::vector<iovec> iovs;
  {
    iovec iov = {};
    iov.iov_base = const_cast<void*>(
        static_cast<const void*>((kOneHundredCharacterVec.data())));
    iov.iov_len = 50;
    iovs.push_back(iov);
  }
  {
    iovec iov = {};
    iov.iov_base = const_cast<void*>(
        static_cast<const void*>((kOneHundredCharacterVec.data() + 50)));
    iov.iov_len = 50;
    iovs.push_back(iov);
  }

  auto clientConn = getClientConn();
  clientConn.connect();
  auto observer = clientConn.attachObserver(
      true /* enableByteEvents */, true /* enablePrewrite */);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());

  clientConn.netOpsOnSendmsgRecordIovecsAndFlagsAndFwd();

  const auto flags = WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK |
      WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_WRITE;
  ON_CALL(*observer, prewriteMock(_, _, _))
      .WillByDefault(testing::Invoke(
          [](AsyncTransport*,
             const AsyncSocketObserverInterface::PrewriteState& state,
             AsyncSocketObserverInterface::PrewriteRequestContainer&
                 container) {
            AsyncSocketObserverInterface::PrewriteRequest request;
            if (state.startOffset == 0) {
              request.maybeOffsetToSplitWrite = 0;
            } else if (state.startOffset <= 49) {
              request.maybeOffsetToSplitWrite = 49;
            } else if (state.startOffset <= 99) {
              request.maybeOffsetToSplitWrite = 99;
            }

            request.writeFlagsToAddAtOffset = flags;
            container.addRequest(request);
          }));

  clientConn.writeAtClientReadAtServerReflectReadAtClient(
      iovs.data(), iovs.size(), WriteFlags::NONE);

  EXPECT_THAT(
      clientConn.getSendmsgInvocations(),
      ElementsAre(
          AllOf(
              SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data()),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(dropWriteFromFlags(flags))),
          AllOf(
              SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 49),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(dropWriteFromFlags(flags))),
          AllOf(
              SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 99),
              SendmsgInvocMsgFlagsEq(WriteFlags::NONE),
              SendmsgInvocAncillaryFlagsEq(dropWriteFromFlags(flags)))));

  // verify WRITE events exist at the appropriate locations
  // we verify timestamp events are generated elsewhere
  EXPECT_THAT(
      filterToWriteEvents(observer->byteEvents),
      ElementsAre(
          ByteEventMatching(ByteEventType::WRITE, 0),
          ByteEventMatching(ByteEventType::WRITE, 49),
          ByteEventMatching(ByteEventType::WRITE, 99)));
}

/**
 * Test prewrite with a large number of iovecs to trigger the malloc
 * codepath.
 */
TEST_F(AsyncSocketByteEventTest, PrewriteSingleObserverManyIovec) {
  // make a long vector, 10000 bytes long
  auto tenThousandByteVec = get10KBOfData();
  ASSERT_THAT(tenThousandByteVec, SizeIs(10000));

  // put each byte in the vector into its own iovec
  std::vector<iovec> tenThousandIovec;
  for (size_t i = 0; i < tenThousandByteVec.size(); i++) {
    iovec iov = {};
    iov.iov_base = tenThousandByteVec.data() + i;
    iov.iov_len = 1;
    tenThousandIovec.push_back(iov);
  }

  auto clientConn = getClientConn();
  clientConn.connect();
  auto observer = clientConn.attachObserver(
      true /* enableByteEvents */, true /* enablePrewrite */);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());

  clientConn.netOpsOnSendmsgRecordIovecsAndFlagsAndFwd();

  const auto flags = WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK |
      WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_WRITE;
  ON_CALL(*observer, prewriteMock(_, _, _))
      .WillByDefault(testing::Invoke(
          [](AsyncTransport*,
             const AsyncSocketObserverInterface::PrewriteState& state,
             AsyncSocketObserverInterface::PrewriteRequestContainer&
                 container) {
            AsyncSocketObserverInterface::PrewriteRequest request;
            if (state.startOffset == 0) {
              request.maybeOffsetToSplitWrite = 0;
            } else if (state.startOffset <= 1000) {
              request.maybeOffsetToSplitWrite = 1000;
            } else if (state.startOffset <= 5000) {
              request.maybeOffsetToSplitWrite = 5000;
            }

            request.writeFlagsToAddAtOffset = flags;
            container.addRequest(request);
          }));

  clientConn.writeAtClientReadAtServerReflectReadAtClient(
      tenThousandIovec.data(), tenThousandIovec.size(), WriteFlags::NONE);

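  // use Contains() rather than ElementsAre(): the chunks between split points
  // are large and may themselves be broken into multiple sendmsg calls, so we
  // only check for the invocations ending at the offsets we care about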
  EXPECT_THAT(
      clientConn.getSendmsgInvocations(),
      AllOf(
          Contains(AllOf(
              SendmsgInvocHasIovLastByte(tenThousandByteVec.data()),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(dropWriteFromFlags(flags)))),
          Contains(AllOf(
              SendmsgInvocHasIovLastByte(tenThousandByteVec.data() + 1000),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(dropWriteFromFlags(flags)))),
          Contains(AllOf(
              SendmsgInvocHasIovLastByte(tenThousandByteVec.data() + 5000),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(dropWriteFromFlags(flags)))),
          Contains(AllOf(
              SendmsgInvocHasIovLastByte(tenThousandByteVec.data() + 9999),
              SendmsgInvocMsgFlagsEq(WriteFlags::NONE),
              SendmsgInvocAncillaryFlagsEq(WriteFlags::NONE)))));

  // verify WRITE events exist at the appropriate locations
  // we verify timestamp events are generated elsewhere
  //
  // should _not_ contain a WRITE event for offset 9999: prewrite requests no
  // split for the final chunk, so the per-offset flags are never applied to it
  EXPECT_THAT(
      filterToWriteEvents(observer->byteEvents),
      AllOf(
          Contains(ByteEventMatching(ByteEventType::WRITE, 0)),
          Contains(ByteEventMatching(ByteEventType::WRITE, 1000)),
          Contains(ByteEventMatching(ByteEventType::WRITE, 5000))));
}

TEST_F(AsyncSocketByteEventTest, PrewriteMultipleObservers) {
  auto clientConn = getClientConn();
  clientConn.connect();

  // six observers
  // observers 1 - 4 have byte events and prewrite enabled
  // observer5 has byte events enabled but not prewrite
  // observer6 has neither byte events nor prewrite
  auto observer1 = clientConn.attachObserver(
      true /* enableByteEvents */, true /* enablePrewrite */);
  auto observer2 = clientConn.attachObserver(
      true /* enableByteEvents */, true /* enablePrewrite */);
  auto observer3 = clientConn.attachObserver(
      true /* enableByteEvents */, true /* enablePrewrite */);
  auto observer4 = clientConn.attachObserver(
      true /* enableByteEvents */, true /* enablePrewrite */);
  auto observer5 = clientConn.attachObserver(
      true /* enableByteEvents */, false /* enablePrewrite */);
  auto observer6 = clientConn.attachObserver(
      false /* enableByteEvents */, false /* enablePrewrite */);

  clientConn.netOpsOnSendmsgRecordIovecsAndFlagsAndFwd();

  // observer 1 wants TX timestamps at 25, 50, 75
  ON_CALL(*observer1, prewriteMock(_, _, _))
      .WillByDefault(testing::Invoke(
          [](AsyncTransport*,
             const AsyncSocketObserverInterface::PrewriteState& state,
             AsyncSocketObserverInterface::PrewriteRequestContainer&
                 container) {
            AsyncSocketObserverInterface::PrewriteRequest request;
            if (state.startOffset <= 25) {
              request.maybeOffsetToSplitWrite = 25;
            } else if (state.startOffset <= 50) {
              request.maybeOffsetToSplitWrite = 50;
            } else if (state.startOffset <= 75) {
              request.maybeOffsetToSplitWrite = 75;
            }
            request.writeFlagsToAddAtOffset = WriteFlags::TIMESTAMP_TX;
            container.addRequest(request);
          }));

  // observer 2 wants ACK timestamps at 35, 65, 75
  ON_CALL(*observer2, prewriteMock(_, _, _))
      .WillByDefault(testing::Invoke(
          [](AsyncTransport*,
             const AsyncSocketObserverInterface::PrewriteState& state,
             AsyncSocketObserverInterface::PrewriteRequestContainer&
                 container) {
            AsyncSocketObserverInterface::PrewriteRequest request;
            if (state.startOffset <= 35) {
              request.maybeOffsetToSplitWrite = 35;
            } else if (state.startOffset <= 65) {
              request.maybeOffsetToSplitWrite = 65;
            } else if (state.startOffset <= 75) {
              request.maybeOffsetToSplitWrite = 75;
            }
            request.writeFlagsToAddAtOffset = WriteFlags::TIMESTAMP_ACK;
            container.addRequest(request);
          }));

  // observer 3 wants WRITE and SCHED flags on every write that occurs
  ON_CALL(*observer3, prewriteMock(_, _, _))
      .WillByDefault(testing::Invoke(
          [](AsyncTransport*,
             const AsyncSocketObserverInterface::PrewriteState& /* state */,
             AsyncSocketObserverInterface::PrewriteRequestContainer&
                 container) {
            AsyncSocketObserverInterface::PrewriteRequest request;
            request.writeFlagsToAdd =
                WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED;
            container.addRequest(request);
          }));

  // observer 4 has prewrite but makes no requests
  ON_CALL(*observer4, prewriteMock(_, _, _))
      .WillByDefault(testing::Invoke(
          [](AsyncTransport*,
             const AsyncSocketObserverInterface::PrewriteState& /* state */,
             AsyncSocketObserverInterface::
                 PrewriteRequestContainer& /* container */) {
            return; // do nothing
          }));

  // no calls for observer 5 or observer 6
  EXPECT_CALL(*observer5, prewriteMock(_, _, _)).Times(0);
  EXPECT_CALL(*observer6, prewriteMock(_, _, _)).Times(0);

  // write
  clientConn.writeAtClientReadAtServerReflectReadAtClient(
      kOneHundredCharacterVec, WriteFlags::NONE);

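  // the writes are split at the union of the offsets requested by observer1
  // (25, 50, 75) and observer2 (35, 65, 75), and each resulting chunk carries
  // the flags merged from all observers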
  EXPECT_THAT(
      clientConn.getSendmsgInvocations(),
      ElementsAre(
          AllOf(
              SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 25),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(
                  WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_TX)),
          AllOf(
              SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 35),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(
                  WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_ACK)),
          AllOf(
              SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 50),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(
                  WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_TX)),
          AllOf(
              SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 65),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(
                  WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_ACK)),
          AllOf(
              SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 75),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(
                  WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_TX |
                  WriteFlags::TIMESTAMP_ACK)),
          AllOf(
              SendmsgInvocHasIovLastByte(kOneHundredCharacterVec.data() + 99),
              SendmsgInvocMsgFlagsEq(WriteFlags::NONE),
              SendmsgInvocAncillaryFlagsEq(WriteFlags::TIMESTAMP_SCHED))));

  // verify WRITE events exist at the appropriate locations
  // we verify timestamp events are generated elsewhere
  for (const auto& observer : {observer1, observer2, observer3}) {
    EXPECT_THAT(
        filterToWriteEvents(observer->byteEvents),
        ElementsAre(
            ByteEventMatching(ByteEventType::WRITE, 25),
            ByteEventMatching(ByteEventType::WRITE, 35),
            ByteEventMatching(ByteEventType::WRITE, 50),
            ByteEventMatching(ByteEventType::WRITE, 65),
            ByteEventMatching(ByteEventType::WRITE, 75),
            ByteEventMatching(ByteEventType::WRITE, 99)));
  }
}

/**
 * Test prewrite with a large write that enables testing of timestamps.
 *
 * We need to use a long vector to ensure that the kernel will not coalesce
 * the writes into a single SKB due to MSG_MORE.
 */
TEST_F(AsyncSocketByteEventTest, PrewriteTimestampedByteEvents) {
  // need a large block of data to ensure that MSG_MORE doesn't limit us
  const auto thousandKBVec = get1000KBOfData();
  ASSERT_THAT(thousandKBVec, SizeIs(1000000));

  auto clientConn = getClientConn();
  clientConn.connect();
  auto observer = clientConn.attachObserver(
      true /* enableByteEvents */, true /* enablePrewrite */);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());

  clientConn.netOpsOnSendmsgRecordIovecsAndFlagsAndFwd();

  const auto flags = WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK |
      WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_WRITE;
  ON_CALL(*observer, prewriteMock(_, _, _))
      .WillByDefault(testing::Invoke(
          [](AsyncTransport*,
             const AsyncSocketObserverInterface::PrewriteState& state,
             AsyncSocketObserverInterface::PrewriteRequestContainer&
                 container) {
            AsyncSocketObserverInterface::PrewriteRequest request;
            if (state.startOffset == 0) {
              request.maybeOffsetToSplitWrite = 0;
            } else if (state.startOffset <= 500000) {
              request.maybeOffsetToSplitWrite = 500000;
            } else {
              request.maybeOffsetToSplitWrite = 999999;
            }

            request.writeFlagsToAdd = flags;
            container.addRequest(request);
          }));

  clientConn.writeAtClientReadAtServerReflectReadAtClient(
      thousandKBVec, WriteFlags::NONE);

  EXPECT_THAT(
      clientConn.getSendmsgInvocations(),
      AllOf(
          Contains(AllOf(
              SendmsgInvocHasIovLastByte(thousandKBVec.data()),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(dropWriteFromFlags(flags)))),
          Contains(AllOf(
              SendmsgInvocHasIovLastByte(thousandKBVec.data() + 500000),
              SendmsgInvocMsgFlagsEq(WriteFlags::CORK),
              SendmsgInvocAncillaryFlagsEq(dropWriteFromFlags(flags)))),
          Contains(AllOf(
              SendmsgInvocHasIovLastByte(thousandKBVec.data() + 999999),
              SendmsgInvocMsgFlagsEq(WriteFlags::NONE),
              SendmsgInvocAncillaryFlagsEq(dropWriteFromFlags(flags))))));

  // verify WRITE events exist at the appropriate locations
  EXPECT_THAT(
      filterToWriteEvents(observer->byteEvents),
      AllOf(
          Contains(ByteEventMatching(ByteEventType::WRITE, 0)),
          Contains(ByteEventMatching(ByteEventType::WRITE, 500000)),
          Contains(ByteEventMatching(ByteEventType::WRITE, 999999))));

  // verify SCHED, TX, and ACK events available at specified locations
  EXPECT_THAT(
      observer->byteEvents,
      AllOf(
          Contains(ByteEventMatching(ByteEventType::SCHED, 0)),
          Contains(ByteEventMatching(ByteEventType::TX, 0)),
          Contains(ByteEventMatching(ByteEventType::ACK, 0)),
          Contains(ByteEventMatching(ByteEventType::SCHED, 500000)),
          Contains(ByteEventMatching(ByteEventType::TX, 500000)),
          Contains(ByteEventMatching(ByteEventType::ACK, 500000)),
          Contains(ByteEventMatching(ByteEventType::SCHED, 999999)),
          Contains(ByteEventMatching(ByteEventType::TX, 999999)),
          Contains(ByteEventMatching(ByteEventType::ACK, 999999))));
}

/**
 * Test raw bytes written and bytes tried to write with prewrite.
 */
TEST_F(AsyncSocketByteEventTest, PrewriteRawBytesWrittenAndTriedToWrite) {
  auto clientConn = getClientConn();
  clientConn.connect();
  auto observer = clientConn.attachObserver(
      true /* enableByteEvents */, true /* enablePrewrite */);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());

  struct ExpectedSendmsgInvocation {
    size_t expectedTotalIovLen{0};
    ssize_t returnVal{0}; // number of bytes written or error val
    folly::Optional<size_t> maybeWriteEventExpectedOffset{};
    folly::Optional<WriteFlags> maybeWriteEventExpectedFlags{};
  };

  const auto flags = WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK |
      WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_WRITE;

  // first write
  //
  // no splits triggered by observer
  //
  // sendmsg will incrementally accept the bytes so we can test the values of
  // maybeRawBytesWritten and maybeRawBytesTriedToWrite
  {
    // bytes written per sendmsg call: 20, 10, 50, -1 (EAGAIN), 11, 9
    const std::vector<ExpectedSendmsgInvocation> expectedSendmsgInvocations{
        // {
        //    expectedTotalIovLen, returnVal,
        //    maybeWriteEventExpectedOffset, maybeWriteEventExpectedFlags
        // },
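        //
        // running totals after each successful call: 20 (offset 19), 30 (29),
        // 80 (79), EAGAIN, 91 (90), 100 (99)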
        {100, 20, 19, flags},
        {80, 10, 29, flags},
        {70, 50, 79, flags},
        {20, -1, folly::none, flags},
        {20, 11, 90, flags},
        {9, 9, 99, flags}};

    // prewrite will be called, we request all events
    EXPECT_CALL(*observer, prewriteMock(_, _, _))
        .Times(expectedSendmsgInvocations.size())
        .WillRepeatedly(testing::Invoke(
            [](AsyncTransport*,
               const AsyncSocketObserverInterface::PrewriteState& /* state */,
               AsyncSocketObserverInterface::PrewriteRequestContainer&
                   container) {
              AsyncSocketObserverInterface::PrewriteRequest request = {};
              request.writeFlagsToAdd = flags;
              container.addRequest(request);
            }));

    // sendmsg will be called, we return # of bytes written
    {
      InSequence s;
      for (const auto& expectedInvocation : expectedSendmsgInvocations) {
        EXPECT_CALL(
            *(clientConn.getNetOpsDispatcher()),
            sendmsg(
                _,
                Pointee(SendmsgMsghdrHasTotalIovLen(
                    expectedInvocation.expectedTotalIovLen)),
                _))
            .WillOnce(::testing::InvokeWithoutArgs([expectedInvocation]() {
              if (expectedInvocation.returnVal < 0) {
                errno = EAGAIN; // returning error, set EAGAIN
              }
              return expectedInvocation.returnVal;
            }));
      }
    }

    // write
    // writes will be intercepted, so we don't need to read at other end
    WriteCallback wcb;
    clientConn.getRawSocket()->write(
        &wcb,
        kOneHundredCharacterVec.data(),
        kOneHundredCharacterVec.size(),
        WriteFlags::NONE);
    while (STATE_WAITING == wcb.state) {
      clientConn.getRawSocket()->getEventBase()->loopOnce();
    }
    ASSERT_EQ(STATE_SUCCEEDED, wcb.state);

    // check write events
    for (const auto& expectedInvocation : expectedSendmsgInvocations) {
      if (expectedInvocation.returnVal < 0) {
        // should be no WriteEvent since the return value was an error
        continue;
      }

      ASSERT_TRUE(expectedInvocation.maybeWriteEventExpectedOffset.has_value());
      const auto& expectedOffset =
          *expectedInvocation.maybeWriteEventExpectedOffset;

      auto maybeByteEvent = observer->getByteEventReceivedWithOffset(
          expectedOffset, ByteEventType::WRITE);
      ASSERT_TRUE(maybeByteEvent.has_value());
      auto& byteEvent = maybeByteEvent.value();

      EXPECT_EQ(ByteEventType::WRITE, byteEvent.type);
      EXPECT_EQ(expectedOffset, byteEvent.offset);
      EXPECT_GE(std::chrono::steady_clock::now(), byteEvent.ts);
      EXPECT_LT(
          std::chrono::steady_clock::now() - std::chrono::seconds(60),
          byteEvent.ts);

      EXPECT_EQ(
          expectedInvocation.maybeWriteEventExpectedFlags,
          byteEvent.maybeWriteFlags);
      EXPECT_TRUE(byteEvent.schedTimestampRequestedOnWrite());
      EXPECT_TRUE(byteEvent.txTimestampRequestedOnWrite());
      EXPECT_TRUE(byteEvent.ackTimestampRequestedOnWrite());

      EXPECT_FALSE(byteEvent.maybeSoftwareTs.has_value());
      EXPECT_FALSE(byteEvent.maybeHardwareTs.has_value());

      // what we really want to test
      EXPECT_EQ(
          folly::to_unsigned(expectedInvocation.returnVal),
          byteEvent.maybeRawBytesWritten);
      EXPECT_EQ(
          expectedInvocation.expectedTotalIovLen,
          byteEvent.maybeRawBytesTriedToWrite);
    }
  }

  // everything should have occurred by now
  clientConn.netOpsVerifyAndClearExpectations();

  // second write
  //
  // start offset is 100
  //
  // split at 150th byte triggered by observer
  //
  // sendmsg will incrementally accept the bytes so we can test the values of
  // maybeRawBytesWritten and maybeRawBytesTriedToWrite
  {
    // due to the split at the 150th byte (offset 149), we expect sendmsg to be
    // called with only bytes at offsets 100 -> 149 until the byte at offset
    // 149 has been written; in addition, the socket only accepts 20 of the 50
    // bytes on the first call.
    //
    // bytes written per sendmsg call: 20, 30, 50
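    //
    // with a start offset of 100: the first call accepts 20 bytes (offset
    // 119), the second the remaining 30 up to the split (offset 149), and the
    // third the final 50 bytes (offset 199)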
    const std::vector<ExpectedSendmsgInvocation> expectedSendmsgInvocations{
        {50, 20, 119, flags | WriteFlags::CORK},
        {30, 30, 149, flags | WriteFlags::CORK},
        {50, 50, 199, flags}};

    // prewrite will be called, split at the 150th byte (offset = 149)
    EXPECT_CALL(*observer, prewriteMock(_, _, _))
        .Times(expectedSendmsgInvocations.size())
        .WillRepeatedly(testing::Invoke(
            [](AsyncTransport*,
               const AsyncSocketObserverInterface::PrewriteState& state,
               AsyncSocketObserverInterface::PrewriteRequestContainer&
                   container) {
              AsyncSocketObserverInterface::PrewriteRequest request;
              if (state.startOffset <= 149) {
                request.maybeOffsetToSplitWrite = 149; // start offset = 100
              }
              request.writeFlagsToAdd = flags;
              container.addRequest(request);
            }));

    // sendmsg will be called, we return # of bytes written
    {
      InSequence s;
      for (const auto& expectedInvocation : expectedSendmsgInvocations) {
        EXPECT_CALL(
            *(clientConn.getNetOpsDispatcher()),
            sendmsg(
                _,
                Pointee(SendmsgMsghdrHasTotalIovLen(
                    expectedInvocation.expectedTotalIovLen)),
                _))
            .WillOnce(::testing::InvokeWithoutArgs([expectedInvocation]() {
              return expectedInvocation.returnVal;
            }));
      }
    }

    // write
    // writes will be intercepted, so we don't need to read at other end
    WriteCallback wcb;
    clientConn.getRawSocket()->write(
        &wcb,
        kOneHundredCharacterVec.data(),
        kOneHundredCharacterVec.size(),
        WriteFlags::NONE);
    while (STATE_WAITING == wcb.state) {
      clientConn.getRawSocket()->getEventBase()->loopOnce();
    }
    ASSERT_EQ(STATE_SUCCEEDED, wcb.state);

    // check write events
    for (const auto& expectedInvocation : expectedSendmsgInvocations) {
      ASSERT_TRUE(expectedInvocation.maybeWriteEventExpectedOffset.has_value());
      const auto& expectedOffset =
          *expectedInvocation.maybeWriteEventExpectedOffset;

      auto maybeByteEvent = observer->getByteEventReceivedWithOffset(
          expectedOffset, ByteEventType::WRITE);
      ASSERT_TRUE(maybeByteEvent.has_value());
      auto& byteEvent = maybeByteEvent.value();

      EXPECT_EQ(ByteEventType::WRITE, byteEvent.type);
      EXPECT_EQ(expectedOffset, byteEvent.offset);
      EXPECT_GE(std::chrono::steady_clock::now(), byteEvent.ts);
      EXPECT_LT(
          std::chrono::steady_clock::now() - std::chrono::seconds(60),
          byteEvent.ts);

      EXPECT_EQ(
          expectedInvocation.maybeWriteEventExpectedFlags,
          byteEvent.maybeWriteFlags);
      EXPECT_TRUE(byteEvent.schedTimestampRequestedOnWrite());
      EXPECT_TRUE(byteEvent.txTimestampRequestedOnWrite());
      EXPECT_TRUE(byteEvent.ackTimestampRequestedOnWrite());

      EXPECT_FALSE(byteEvent.maybeSoftwareTs.has_value());
      EXPECT_FALSE(byteEvent.maybeHardwareTs.has_value());

      // what we really want to test
      EXPECT_EQ(
          folly::to_unsigned(expectedInvocation.returnVal),
          byteEvent.maybeRawBytesWritten);
      EXPECT_EQ(
          expectedInvocation.expectedTotalIovLen,
          byteEvent.maybeRawBytesTriedToWrite);
    }
  }
}

TEST_F(AsyncSocketByteEventTest, GetTcpInfo_SocketStates) {
  const folly::TcpInfo::LookupOptions options = {};

  auto clientConn = getClientConn();

  // not open
  auto expectedTcpInfo = clientConn.getRawSocket()->getTcpInfo(options);
  EXPECT_FALSE(expectedTcpInfo.hasValue());

  // connected
  clientConn.connect();
  expectedTcpInfo = clientConn.getRawSocket()->getTcpInfo(options);
  EXPECT_TRUE(expectedTcpInfo.hasValue());

  // connected then closed
  clientConn.getRawSocket()->close();
  expectedTcpInfo = clientConn.getRawSocket()->getTcpInfo(options);
  EXPECT_FALSE(expectedTcpInfo.hasValue());
}

/**
 * Enable byte events and have offset correction immediately succeed.
 *
 * tcpi_bytes_sent and sendBufInUseBytes stay the same, so offset correction
 * completes on the first attempt.
 */
TEST_F(
    AsyncSocketByteEventTest,
    EnableByteEvents_OffsetCorrection_ValuesStaySame) {
  std::shared_ptr<MockTcpInfoDispatcher> mockTcpInfoDispatcher =
      std::make_shared<MockTcpInfoDispatcher>();

  folly::TcpInfo::tcp_info tInfoBefore = {};
  folly::TcpInfo::tcp_info tInfoAfter = {};
  tInfoBefore.tcpi_bytes_sent = 35;
  tInfoAfter.tcpi_bytes_sent = 35;

  folly::TcpInfo wrappedTcpInfoBefore{tInfoBefore};
  folly::TcpInfo wrappedTcpInfoAfter{tInfoAfter};

  wrappedTcpInfoBefore.setSendBufInUseBytes(0);
  wrappedTcpInfoAfter.setSendBufInUseBytes(0);

  auto clientConn = getClientConn();
  clientConn.netOpsExpectNoTimestampingSetSockOpt();
  clientConn.connect();
  clientConn.netOpsVerifyAndClearExpectations();

  clientConn.setMockTcpInfoDispatcher(mockTcpInfoDispatcher);

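  // each attempt to enable byte events takes two TcpInfo snapshots (the
  // "before" and "after" values mocked here); because they match, offset
  // correction completes on the first attempt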
  {
    InSequence s;

    EXPECT_CALL(*mockTcpInfoDispatcher, initFromFd(_, _, _, _))
        .WillOnce(Return(wrappedTcpInfoBefore))
        .RetiresOnSaturation();

    EXPECT_CALL(*mockTcpInfoDispatcher, initFromFd(_, _, _, _))
        .WillOnce(Return(wrappedTcpInfoAfter))
        .RetiresOnSaturation();
  }

  auto observer = clientConn.attachObserver(true /* enableByteEvents */);

  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
  clientConn.netOpsVerifyAndClearExpectations();
}

/**
 * Enable byte events and have offset correction repeat due to sendBufInUseBytes
 * changing in between calls to the kernel trying to enable timestamping.
 *
 * The operation should be retried SO_MAX_ATTEMPTS_ENABLE_BYTEEVENTS times and
 * then fail.
 */
TEST_F(
    AsyncSocketByteEventTest,
    EnableByteEvents_OffsetCorrection_sendBufInUseBytesChangingFail) {
  std::shared_ptr<MockTcpInfoDispatcher> mockTcpInfoDispatcher =
      std::make_shared<MockTcpInfoDispatcher>();

  folly::TcpInfo::tcp_info tInfoBefore = {};
  folly::TcpInfo::tcp_info tInfoAfter = {};
  tInfoBefore.tcpi_bytes_sent = 35;
  tInfoAfter.tcpi_bytes_sent = 35;

  folly::TcpInfo wrappedTcpInfoBefore{tInfoBefore};
  folly::TcpInfo wrappedTcpInfoAfter{tInfoAfter};

  wrappedTcpInfoBefore.setSendBufInUseBytes(1);
  wrappedTcpInfoAfter.setSendBufInUseBytes(0);

  auto clientConn = getClientConn();
  clientConn.netOpsExpectNoTimestampingSetSockOpt();
  clientConn.connect();
  clientConn.netOpsVerifyAndClearExpectations();

  clientConn.setMockTcpInfoDispatcher(mockTcpInfoDispatcher);

  auto byteEventsEnabledAttempts = 0;

  {
    InSequence s;

    for (; byteEventsEnabledAttempts < SO_MAX_ATTEMPTS_ENABLE_BYTEEVENTS;
         byteEventsEnabledAttempts++) {
      EXPECT_CALL(*mockTcpInfoDispatcher, initFromFd(_, _, _, _))
          .WillOnce(Return(wrappedTcpInfoBefore))
          .RetiresOnSaturation();

      EXPECT_CALL(*mockTcpInfoDispatcher, initFromFd(_, _, _, _))
          .WillOnce(Return(wrappedTcpInfoAfter))
          .RetiresOnSaturation();
    }
  }

  auto observer = clientConn.attachObserver(true /* enableByteEvents */);

  EXPECT_EQ(byteEventsEnabledAttempts, SO_MAX_ATTEMPTS_ENABLE_BYTEEVENTS);
  EXPECT_EQ(0, observer->byteEventsEnabledCalled);
  EXPECT_EQ(1, observer->byteEventsUnavailableCalled);
  EXPECT_TRUE(observer->byteEventsUnavailableCalledEx.has_value());
  clientConn.netOpsVerifyAndClearExpectations();
}

/**
 * Enable byte events and have offset correction repeat due to sentBytes
 * changing in between calls to the kernel trying to enable timestamping.
 *
 * The operation should be retried SO_MAX_ATTEMPTS_ENABLE_BYTEEVENTS times and
 * then fail.
 */
TEST_F(
    AsyncSocketByteEventTest,
    EnableByteEvents_OffsetCorrection_sentBytesChangingFail) {
  std::shared_ptr<MockTcpInfoDispatcher> mockTcpInfoDispatcher =
      std::make_shared<MockTcpInfoDispatcher>();

  folly::TcpInfo::tcp_info tInfoBefore = {};
  folly::TcpInfo::tcp_info tInfoAfter = {};
  tInfoBefore.tcpi_bytes_sent = 35;
  tInfoAfter.tcpi_bytes_sent = 36;

  folly::TcpInfo wrappedTcpInfoBefore{tInfoBefore};
  folly::TcpInfo wrappedTcpInfoAfter{tInfoAfter};

  wrappedTcpInfoBefore.setSendBufInUseBytes(0);
  wrappedTcpInfoAfter.setSendBufInUseBytes(0);

  auto clientConn = getClientConn();
  clientConn.netOpsExpectNoTimestampingSetSockOpt();
  clientConn.connect();
  clientConn.netOpsVerifyAndClearExpectations();

  clientConn.setMockTcpInfoDispatcher(mockTcpInfoDispatcher);

  auto byteEventsEnabledAttempts = 0;

  {
    InSequence s;

    for (; byteEventsEnabledAttempts < SO_MAX_ATTEMPTS_ENABLE_BYTEEVENTS;
         byteEventsEnabledAttempts++) {
      EXPECT_CALL(*mockTcpInfoDispatcher, initFromFd(_, _, _, _))
          .WillOnce(Return(wrappedTcpInfoBefore))
          .RetiresOnSaturation();

      EXPECT_CALL(*mockTcpInfoDispatcher, initFromFd(_, _, _, _))
          .WillOnce(Return(wrappedTcpInfoAfter))
          .RetiresOnSaturation();
    }
  }

  auto observer = clientConn.attachObserver(true /* enableByteEvents */);

  EXPECT_EQ(byteEventsEnabledAttempts, SO_MAX_ATTEMPTS_ENABLE_BYTEEVENTS);
  EXPECT_EQ(0, observer->byteEventsEnabledCalled);
  EXPECT_EQ(1, observer->byteEventsUnavailableCalled);
  EXPECT_TRUE(observer->byteEventsUnavailableCalledEx.has_value());
  clientConn.netOpsVerifyAndClearExpectations();
}

/**
 * Enable byte events and have offset correction repeat due to sendBufInUseBytes
 * changing in between calls to the kernel trying to enable timestamping.
 *
 * The operation should be retried at most SO_MAX_ATTEMPTS_ENABLE_BYTEEVENTS
 * times and then succeed when sendBufInUseBytes does not change.
 */
TEST_F(
    AsyncSocketByteEventTest,
    EnableByteEvents_OffsetCorrection_sendBufInUseBytesChangingSuccess) {
  std::shared_ptr<MockTcpInfoDispatcher> mockTcpInfoDispatcher =
      std::make_shared<MockTcpInfoDispatcher>();

  folly::TcpInfo::tcp_info tInfoBefore = {};
  folly::TcpInfo::tcp_info tInfoAfter = {};
  tInfoBefore.tcpi_bytes_sent = 36;
  tInfoAfter.tcpi_bytes_sent = 36;

  folly::TcpInfo::tcp_info tInfoBefore2 = {};
  folly::TcpInfo::tcp_info tInfoAfter2 = {};
  tInfoBefore2.tcpi_bytes_sent = 36;
  tInfoAfter2.tcpi_bytes_sent = 36;

  folly::TcpInfo wrappedTcpInfoBefore{tInfoBefore};
  folly::TcpInfo wrappedTcpInfoAfter{tInfoAfter};
  folly::TcpInfo wrappedTcpInfoBefore2{tInfoBefore2};
  folly::TcpInfo wrappedTcpInfoAfter2{tInfoAfter2};

  wrappedTcpInfoBefore.setSendBufInUseBytes(1);
  wrappedTcpInfoAfter.setSendBufInUseBytes(0);
  wrappedTcpInfoBefore2.setSendBufInUseBytes(0);
  wrappedTcpInfoAfter2.setSendBufInUseBytes(0);

  auto clientConn = getClientConn();
  clientConn.netOpsExpectNoTimestampingSetSockOpt();
  clientConn.connect();
  clientConn.netOpsVerifyAndClearExpectations();

  clientConn.setMockTcpInfoDispatcher(mockTcpInfoDispatcher);

  auto byteEventsEnabledAttempts = 0;
  auto constexpr kRetriesUntilByteEventsSuccessful = 5;
  EXPECT_LE(
      kRetriesUntilByteEventsSuccessful, SO_MAX_ATTEMPTS_ENABLE_BYTEEVENTS);

  {
    InSequence s;

    for (; byteEventsEnabledAttempts < SO_MAX_ATTEMPTS_ENABLE_BYTEEVENTS;
         byteEventsEnabledAttempts++) {
      if (byteEventsEnabledAttempts == kRetriesUntilByteEventsSuccessful) {
        EXPECT_CALL(*mockTcpInfoDispatcher, initFromFd(_, _, _, _))
            .WillOnce(Return(wrappedTcpInfoBefore2))
            .RetiresOnSaturation();

        EXPECT_CALL(*mockTcpInfoDispatcher, initFromFd(_, _, _, _))
            .WillOnce(Return(wrappedTcpInfoAfter2))
            .RetiresOnSaturation();

        break;
      } else {
        EXPECT_CALL(*mockTcpInfoDispatcher, initFromFd(_, _, _, _))
            .WillOnce(Return(wrappedTcpInfoBefore))
            .RetiresOnSaturation();

        EXPECT_CALL(*mockTcpInfoDispatcher, initFromFd(_, _, _, _))
            .WillOnce(Return(wrappedTcpInfoAfter))
            .RetiresOnSaturation();
      }
    }
  }

  auto observer = clientConn.attachObserver(true /* enableByteEvents */);

  EXPECT_EQ(byteEventsEnabledAttempts, kRetriesUntilByteEventsSuccessful);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
  clientConn.netOpsVerifyAndClearExpectations();
}

/**
 * Enable byte events and have offset correction repeat due to sentBytes
 * changing in between calls to the kernel trying to enable timestamping.
 *
 * The operation should be retried at most SO_MAX_ATTEMPTS_ENABLE_BYTEEVENTS
 * times and then succeed when sentBytes does not change.
 */
TEST_F(
    AsyncSocketByteEventTest,
    EnableByteEvents_OffsetCorrection_sentBytesChangingSuccess) {
  std::shared_ptr<MockTcpInfoDispatcher> mockTcpInfoDispatcher =
      std::make_shared<MockTcpInfoDispatcher>();

  folly::TcpInfo::tcp_info tInfoBefore = {};
  folly::TcpInfo::tcp_info tInfoAfter = {};
  tInfoBefore.tcpi_bytes_sent = 35;
  tInfoAfter.tcpi_bytes_sent = 36;

  folly::TcpInfo::tcp_info tInfoBefore2 = {};
  folly::TcpInfo::tcp_info tInfoAfter2 = {};
  tInfoBefore2.tcpi_bytes_sent = 36;
  tInfoAfter2.tcpi_bytes_sent = 36;

  folly::TcpInfo wrappedTcpInfoBefore{tInfoBefore};
  folly::TcpInfo wrappedTcpInfoAfter{tInfoAfter};
  folly::TcpInfo wrappedTcpInfoBefore2{tInfoBefore2};
  folly::TcpInfo wrappedTcpInfoAfter2{tInfoAfter2};

  wrappedTcpInfoBefore.setSendBufInUseBytes(0);
  wrappedTcpInfoAfter.setSendBufInUseBytes(0);
  wrappedTcpInfoBefore2.setSendBufInUseBytes(0);
  wrappedTcpInfoAfter2.setSendBufInUseBytes(0);

  auto clientConn = getClientConn();
  clientConn.netOpsExpectNoTimestampingSetSockOpt();
  clientConn.connect();
  clientConn.netOpsVerifyAndClearExpectations();

  clientConn.setMockTcpInfoDispatcher(mockTcpInfoDispatcher);

  auto byteEventsEnabledAttempts = 0;
  auto constexpr kRetriesUntilByteEventsSuccessful = 5;
  EXPECT_LE(
      kRetriesUntilByteEventsSuccessful, SO_MAX_ATTEMPTS_ENABLE_BYTEEVENTS);

  {
    InSequence s;

    for (; byteEventsEnabledAttempts < SO_MAX_ATTEMPTS_ENABLE_BYTEEVENTS;
         byteEventsEnabledAttempts++) {
      if (byteEventsEnabledAttempts == kRetriesUntilByteEventsSuccessful) {
        EXPECT_CALL(*mockTcpInfoDispatcher, initFromFd(_, _, _, _))
            .WillOnce(Return(wrappedTcpInfoBefore2))
            .RetiresOnSaturation();

        EXPECT_CALL(*mockTcpInfoDispatcher, initFromFd(_, _, _, _))
            .WillOnce(Return(wrappedTcpInfoAfter2))
            .RetiresOnSaturation();

        break;
      } else {
        EXPECT_CALL(*mockTcpInfoDispatcher, initFromFd(_, _, _, _))
            .WillOnce(Return(wrappedTcpInfoBefore))
            .RetiresOnSaturation();

        EXPECT_CALL(*mockTcpInfoDispatcher, initFromFd(_, _, _, _))
            .WillOnce(Return(wrappedTcpInfoAfter))
            .RetiresOnSaturation();
      }
    }
  }

  auto observer = clientConn.attachObserver(true /* enableByteEvents */);

  EXPECT_EQ(byteEventsEnabledAttempts, kRetriesUntilByteEventsSuccessful);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
  clientConn.netOpsVerifyAndClearExpectations();
}

class AsyncSocketByteEventRawOffsetTest
    : public AsyncSocketByteEventTest,
      public testing::WithParamInterface<size_t> {
 public:
  // byte offset of the AsyncSocket when ByteEvents are enabled
  //
  // for some of the tests, the value returned by sendBufInUseBytes
  // will be greater than this value to simulate a case in which
  // bytes are written to the socket prior to the AsyncSocket being
  // initialized, and those bytes have not yet been ACKed.
  static constexpr size_t kRawByteOffsetWhenByteEventsEnabled = 20;

  // values returned by sendBufInUseBytes()
  static std::vector<size_t> getTestingValues() {
    std::vector<size_t> vals{/* Values for sendBufInUseBytes */
                             0,
                             1,
                             10,
                             kRawByteOffsetWhenByteEventsEnabled,
                             // simulate cases where bytes have already been
                             // written to the kernel socket prior to the
                             // AsyncSocket being initialized and are still
                             // in the sendbuf (either not sent, or not ACKed).
                             kRawByteOffsetWhenByteEventsEnabled + 1,
                             kRawByteOffsetWhenByteEventsEnabled + 10};
    return vals;
  }
};

INSTANTIATE_TEST_SUITE_P(
    ByteEventRawOffsets,
    AsyncSocketByteEventRawOffsetTest,
    ::testing::ValuesIn(AsyncSocketByteEventRawOffsetTest::getTestingValues()));

/**
 * Enable byte events with varying values of sendBufInUseBytes.
 *
 * This is an end-to-end test verifying proper delivery of timestamps with
 * different byte offset corrections. sendBufInUseBytes varies between zero
 * and a value greater than that reported by getRawBytesWritten(), with the
 * latter providing coverage of a case where bytes were written to the
 * kernel socket prior to the AsyncSocket being initialized and are still
 * in the sendbuf.
 */
TEST_P(AsyncSocketByteEventRawOffsetTest, EnableByteEvents_CheckRawByteOffset) {
  const auto flags = WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
      WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK;

  const auto bytesInSendBuf = GetParam();
  const std::vector<uint8_t> wbuf1(kRawByteOffsetWhenByteEventsEnabled, 'a');
  const std::vector<uint8_t> wbuf2(1, 'a');
  const std::vector<uint8_t> wbufBytesInSendBufOnEnable(bytesInSendBuf, 'a');

  std::shared_ptr<MockTcpInfoDispatcher> mockTcpInfoDispatcher =
      std::make_shared<MockTcpInfoDispatcher>();

  folly::TcpInfo::tcp_info tInfoBefore = {};
  folly::TcpInfo::tcp_info tInfoAfter = {};
  folly::TcpInfo wrappedTcpInfoBefore{tInfoBefore};
  folly::TcpInfo wrappedTcpInfoAfter{tInfoAfter};
  wrappedTcpInfoBefore.setSendBufInUseBytes(bytesInSendBuf);
  wrappedTcpInfoAfter.setSendBufInUseBytes(bytesInSendBuf);

  auto clientConn = getClientConn();
  clientConn.netOpsExpectNoTimestampingSetSockOpt();
  clientConn.connect();
  clientConn.netOpsVerifyAndClearExpectations();

  clientConn.setMockTcpInfoDispatcher(mockTcpInfoDispatcher);

  {
    InSequence s;

    EXPECT_CALL(*mockTcpInfoDispatcher, initFromFd(_, _, _, _))
        .WillOnce(Return(wrappedTcpInfoBefore))
        .RetiresOnSaturation();

    EXPECT_CALL(*mockTcpInfoDispatcher, initFromFd(_, _, _, _))
        .WillOnce(Return(wrappedTcpInfoAfter))
        .RetiresOnSaturation();
  }

  // Write any bytes that we wanted to have sent through the AsyncSocket
  // prior to timestamps being enabled to adjust the rawByteOffset
  clientConn.writeAtClientReadAtServer(wbuf1, WriteFlags::NONE);

  // Enable timestamps
  //
  // AsyncSocket will record the value returned by TcpInfo::sendBufInUseBytes
  // when enabling timestamps to determine the correction factor for timestamp
  // byte offsets. This test controls the value returned by sendBufInUseBytes.
  clientConn.netOpsExpectTimestampingSetSockOpt();
  auto observer = clientConn.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());
  clientConn.netOpsVerifyAndClearExpectations();

  // Write wbufBytesInSendBufOnEnable to the socket (bypassing AsyncSocket)
  //
  // We can't control the actual number of bytes in the socket sendbuf when
  // we enable timestamping in the previous step. Instead, we create the
  // scenario where there are bytes in the sendBuf that we need to correct for
  // by writing bytes directly to the network socket after enabling
  // timestamping. The number of bytes written is identical to the number of
  // bytes that we reported were in the buffer in the previous step, thereby
  // causing kernel timestamp byte offsets to be offset by this amount.
  clientConn.writeAtClientDirectlyToNetworkSocket(wbufBytesInSendBufOnEnable);

  clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(dropWriteFromFlags(flags));
  clientConn.writeAtClientReadAtServerReflectReadAtClient(wbuf2, flags);
  clientConn.netOpsVerifyAndClearExpectations();
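  // expect one byte event per requested timestamp type (WRITE, SCHED, TX,
  // ACK) for the single byte written; the bytes written directly to the
  // network socket bypass AsyncSocket and do not advance its offsets, so the
  // max offset is still kRawByteOffsetWhenByteEventsEnabled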
  EXPECT_THAT(observer->byteEvents, SizeIs(4));
  EXPECT_EQ(
      kRawByteOffsetWhenByteEventsEnabled,
      observer->maxOffsetForByteEventReceived(ByteEventType::WRITE).value());
  EXPECT_EQ(
      kRawByteOffsetWhenByteEventsEnabled,
      observer->maxOffsetForByteEventReceived(ByteEventType::SCHED).value());
  EXPECT_EQ(
      kRawByteOffsetWhenByteEventsEnabled,
      observer->maxOffsetForByteEventReceived(ByteEventType::TX).value());
  EXPECT_EQ(
      kRawByteOffsetWhenByteEventsEnabled,
      observer->maxOffsetForByteEventReceived(ByteEventType::ACK).value());

  // write again to check offsets
  clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(dropWriteFromFlags(flags));
  clientConn.writeAtClientReadAtServerReflectReadAtClient(wbuf2, flags);
  clientConn.netOpsVerifyAndClearExpectations();
  EXPECT_THAT(observer->byteEvents, SizeIs(8));
  EXPECT_EQ(
      kRawByteOffsetWhenByteEventsEnabled + 1,
      observer->maxOffsetForByteEventReceived(ByteEventType::WRITE));
  EXPECT_EQ(
      kRawByteOffsetWhenByteEventsEnabled + 1,
      observer->maxOffsetForByteEventReceived(ByteEventType::SCHED));
  EXPECT_EQ(
      kRawByteOffsetWhenByteEventsEnabled + 1,
      observer->maxOffsetForByteEventReceived(ByteEventType::TX));
  EXPECT_EQ(
      kRawByteOffsetWhenByteEventsEnabled + 1,
      observer->maxOffsetForByteEventReceived(ByteEventType::ACK));
}

struct AsyncSocketByteEventDetailsTestParams {
  struct WriteParams {
    WriteParams(uint64_t bufferSize, WriteFlags writeFlags)
        : bufferSize(bufferSize), writeFlags(writeFlags) {}
    uint64_t bufferSize{0};
    WriteFlags writeFlags{WriteFlags::NONE};
  };

  std::vector<WriteParams> writesWithParams;
};

class AsyncSocketByteEventDetailsTest
    : public AsyncSocketByteEventTest,
      public testing::WithParamInterface<
          AsyncSocketByteEventDetailsTestParams> {
 public:
  static std::vector<AsyncSocketByteEventDetailsTestParams> getTestingValues() {
    const std::array<WriteFlags, 9> writeFlagCombinations{
        // SCHED
        WriteFlags::TIMESTAMP_SCHED,
        // TX
        WriteFlags::TIMESTAMP_TX,
        // ACK
        WriteFlags::TIMESTAMP_ACK,
        // SCHED + TX + ACK
        WriteFlags::TIMESTAMP_SCHED | WriteFlags::TIMESTAMP_TX |
            WriteFlags::TIMESTAMP_ACK,
        // WRITE
        WriteFlags::TIMESTAMP_WRITE,
        // WRITE + SCHED
        WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED,
        // WRITE + TX
        WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_TX,
        // WRITE + ACK
        WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_ACK,
        // WRITE + SCHED + TX + ACK
        WriteFlags::TIMESTAMP_WRITE | WriteFlags::TIMESTAMP_SCHED |
            WriteFlags::TIMESTAMP_TX | WriteFlags::TIMESTAMP_ACK,
    };

    std::vector<AsyncSocketByteEventDetailsTestParams> vals;
    for (const auto& writeFlags : writeFlagCombinations) {
      // write 1 byte
      {
        AsyncSocketByteEventDetailsTestParams params;
        params.writesWithParams.emplace_back(1, writeFlags);
        vals.push_back(params);
      }

      // write 1 byte twice
      {
        AsyncSocketByteEventDetailsTestParams params;
        params.writesWithParams.emplace_back(1, writeFlags);
        params.writesWithParams.emplace_back(1, writeFlags);
        vals.push_back(params);
      }

      // write 10 bytes
      {
        AsyncSocketByteEventDetailsTestParams params;
        params.writesWithParams.emplace_back(10, writeFlags);
        vals.push_back(params);
      }

      // write 10 bytes twice
      {
        AsyncSocketByteEventDetailsTestParams params;
        params.writesWithParams.emplace_back(10, writeFlags);
        params.writesWithParams.emplace_back(10, writeFlags);
        vals.push_back(params);
      }
    }

    return vals;
  }
};

INSTANTIATE_TEST_SUITE_P(
    ByteEventDetailsTest,
    AsyncSocketByteEventDetailsTest,
    ::testing::ValuesIn(AsyncSocketByteEventDetailsTest::getTestingValues()));

/**
 * Inspect ByteEvent fields, including xTimestampRequested in WRITE events.
 */
TEST_P(AsyncSocketByteEventDetailsTest, CheckByteEventDetails) {
  auto params = GetParam();

  auto clientConn = getClientConn();
  clientConn.connect();
  auto observer = clientConn.attachObserver(true /* enableByteEvents */);
  EXPECT_EQ(1, observer->byteEventsEnabledCalled);
  EXPECT_EQ(0, observer->byteEventsUnavailableCalled);
  EXPECT_FALSE(observer->byteEventsUnavailableCalledEx.has_value());

  uint64_t expectedNumByteEvents = 0;
  for (const auto& writeParams : params.writesWithParams) {
    const std::vector<uint8_t> wbuf(writeParams.bufferSize, 'a');
    const auto flags = writeParams.writeFlags;
    clientConn.netOpsExpectSendmsgWithAncillaryTsFlags(
        dropWriteFromFlags(flags));
    clientConn.writeAtClientReadAtServerReflectReadAtClient(wbuf, flags);
    clientConn.netOpsVerifyAndClearExpectations();
    const auto expectedOffset =
        clientConn.getRawSocket()->getRawBytesWritten() - 1;

    // check WRITE
    if ((flags & WriteFlags::TIMESTAMP_WRITE) != WriteFlags::NONE) {
      expectedNumByteEvents++;

      auto maybeByteEvent = observer->getByteEventReceivedWithOffset(
          expectedOffset, ByteEventType::WRITE);
      ASSERT_TRUE(maybeByteEvent.has_value());
      auto& byteEvent = maybeByteEvent.value();

      EXPECT_EQ(ByteEventType::WRITE, byteEvent.type);
      EXPECT_EQ(expectedOffset, byteEvent.offset);
      EXPECT_GE(std::chrono::steady_clock::now(), byteEvent.ts);
      EXPECT_LT(
          std::chrono::steady_clock::now() - std::chrono::seconds(60),
          byteEvent.ts);

      EXPECT_EQ(flags, byteEvent.maybeWriteFlags);
      EXPECT_EQ(
          isSet(flags, WriteFlags::TIMESTAMP_SCHED),
          byteEvent.schedTimestampRequestedOnWrite());
      EXPECT_EQ(
          isSet(flags, WriteFlags::TIMESTAMP_TX),
          byteEvent.txTimestampRequestedOnWrite());
      EXPECT_EQ(
          isSet(flags, WriteFlags::TIMESTAMP_ACK),
          byteEvent.ackTimestampRequestedOnWrite());

      EXPECT_FALSE(byteEvent.maybeSoftwareTs.has_value());
      EXPECT_FALSE(byteEvent.maybeHardwareTs.has_value());
    }

    // check SCHED, TX, ACK
    for (const auto& byteEventType :
         {ByteEventType::SCHED, ByteEventType::TX, ByteEventType::ACK}) {
      auto maybeByteEvent = observer->getByteEventReceivedWithOffset(
          expectedOffset, byteEventType);
      switch (byteEventType) {
        case ByteEventType::WRITE:
          FAIL();
        case ByteEventType::SCHED:
          if ((flags & WriteFlags::TIMESTAMP_SCHED) == WriteFlags::NONE) {
            EXPECT_FALSE(maybeByteEvent.has_value());
            continue;
          }
          break;
        case ByteEventType::TX:
          if ((flags & WriteFlags::TIMESTAMP_TX) == WriteFlags::NONE) {
            EXPECT_FALSE(maybeByteEvent.has_value());
            continue;
          }
          break;
        case ByteEventType::ACK:
          if ((flags & WriteFlags::TIMESTAMP_ACK) == WriteFlags::NONE) {
            EXPECT_FALSE(maybeByteEvent.has_value());
            continue;
          }
          break;
      }

      expectedNumByteEvents++;
      ASSERT_TRUE(maybeByteEvent.has_value());
      auto& byteEvent = maybeByteEvent.value();

      EXPECT_EQ(byteEventType, byteEvent.type);
      EXPECT_EQ(expectedOffset, byteEvent.offset);
      EXPECT_GE(std::chrono::steady_clock::now(), byteEvent.ts);
      EXPECT_LT(
          std::chrono::steady_clock::now() - std::chrono::seconds(60),
          byteEvent.ts);
      EXPECT_FALSE(byteEvent.maybeWriteFlags.has_value());
      // don't check *TimestampRequestedOnWrite* fields to avoid CHECK_DEATH,
      // already checked in CheckByteEventDetailsApplicationSetsFlags

      EXPECT_TRUE(byteEvent.maybeSoftwareTs.has_value());
      EXPECT_FALSE(byteEvent.maybeHardwareTs.has_value());
    }
  }

  // should have at least expectedNumByteEvents
  // may be more if writes were split up by kernel
  EXPECT_THAT(observer->byteEvents, SizeIs(Ge(expectedNumByteEvents)));
}

class AsyncSocketByteEventHelperTest : public ::testing::Test {
 protected:
  using ByteEventType = AsyncSocket::ByteEvent::Type;

  /**
   * Wrapper around a vector containing cmsg header + data.
   */
  class WrappedCMsg {
   public:
    explicit WrappedCMsg(std::vector<char>&& data) : data_(std::move(data)) {}

    operator const struct cmsghdr &() {
      return *reinterpret_cast<struct cmsghdr*>(data_.data());
    }

   protected:
    std::vector<char> data_;
  };

  /**
   * Wrapper around a vector containing a cmsg header + scm_timestamping data.
   */
  class WrappedSockExtendedErrTsCMsg : public WrappedCMsg {
   public:
    using WrappedCMsg::WrappedCMsg;

    // ts[0] -> software timestamp
    // ts[1] -> hardware timestamp transformed to userspace time (deprecated)
    // ts[2] -> hardware timestamp
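    // (layout of struct scm_timestamping from <linux/errqueue.h>:
    //  struct timespec ts[3])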

    void setSoftwareTimestamp(
        const std::chrono::seconds seconds,
        const std::chrono::nanoseconds nanoseconds) {
      struct cmsghdr* cmsg{reinterpret_cast<cmsghdr*>(data_.data())};
      struct scm_timestamping* tss{
          reinterpret_cast<struct scm_timestamping*>(CMSG_DATA(cmsg))};
      tss->ts[0].tv_sec = seconds.count();
      tss->ts[0].tv_nsec = nanoseconds.count();
    }

    void setHardwareTimestamp(
        const std::chrono::seconds seconds,
        const std::chrono::nanoseconds nanoseconds) {
      struct cmsghdr* cmsg{reinterpret_cast<cmsghdr*>(data_.data())};
      struct scm_timestamping* tss{
          reinterpret_cast<struct scm_timestamping*>(CMSG_DATA(cmsg))};
      tss->ts[2].tv_sec = seconds.count();
      tss->ts[2].tv_nsec = nanoseconds.count();
    }
  };

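  /**
   * Build a standalone, zero-initialized control message with the given
   * level, type, and payload length; callers fill the payload in place
   * via CMSG_DATA.
   */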
  static std::vector<char> cmsgData(int level, int type, size_t len) {
    std::vector<char> data(CMSG_LEN(len), 0);
    struct cmsghdr* cmsg{reinterpret_cast<cmsghdr*>(data.data())};
    cmsg->cmsg_level = level;
    cmsg->cmsg_type = type;
    cmsg->cmsg_len = CMSG_LEN(len);
    return data;
  }

  static WrappedSockExtendedErrTsCMsg cmsgForSockExtendedErrTimestamping() {
    return WrappedSockExtendedErrTsCMsg(
        cmsgData(SOL_SOCKET, SO_TIMESTAMPING, sizeof(struct scm_timestamping)));
  }

  static WrappedCMsg cmsgForScmTimestamping(
      const uint32_t type, const uint32_t kernelByteOffset) {
    auto data = cmsgData(SOL_IP, IP_RECVERR, sizeof(struct sock_extended_err));
    struct cmsghdr* cmsg{reinterpret_cast<cmsghdr*>(data.data())};
    struct sock_extended_err* serr{
        reinterpret_cast<struct sock_extended_err*>(CMSG_DATA(cmsg))};
    serr->ee_errno = ENOMSG;
    serr->ee_origin = SO_EE_ORIGIN_TIMESTAMPING;
    serr->ee_info = type;
    serr->ee_data = kernelByteOffset;
    return WrappedCMsg(std::move(data));
  }
};

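// ByteEventHelper::processCmsg pairs the IP_RECVERR cmsg (which carries the
// byte offset) with the SO_TIMESTAMPING cmsg (which carries the timestamps);
// a ByteEvent is emitted only once both halves of a pair have been seen.
// The next two tests cover the two possible arrival orders.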
TEST_F(AsyncSocketByteEventHelperTest, ByteOffsetThenTs) {
  auto scmTs = cmsgForScmTimestamping(folly::netops::SCM_TSTAMP_SND, 0);
  const auto softwareTsSec = std::chrono::seconds(59);
  const auto softwareTsNs = std::chrono::nanoseconds(11);
  auto serrTs = cmsgForSockExtendedErrTimestamping();
  serrTs.setSoftwareTimestamp(softwareTsSec, softwareTsNs);

  AsyncSocket::ByteEventHelper helper = {};
  helper.byteEventsEnabled = true;
  helper.rawBytesWrittenWhenByteEventsEnabled = 0;

  EXPECT_FALSE(helper.processCmsg(scmTs, 1 /* rawBytesWritten */));
  EXPECT_TRUE(helper.processCmsg(serrTs, 1 /* rawBytesWritten */));
}

TEST_F(AsyncSocketByteEventHelperTest, TsThenByteOffset) {
  auto scmTs = cmsgForScmTimestamping(folly::netops::SCM_TSTAMP_SND, 0);
  const auto softwareTsSec = std::chrono::seconds(59);
  const auto softwareTsNs = std::chrono::nanoseconds(11);
  auto serrTs = cmsgForSockExtendedErrTimestamping();
  serrTs.setSoftwareTimestamp(softwareTsSec, softwareTsNs);

  AsyncSocket::ByteEventHelper helper = {};
  helper.byteEventsEnabled = true;
  helper.rawBytesWrittenWhenByteEventsEnabled = 0;

  EXPECT_FALSE(helper.processCmsg(serrTs, 1 /* rawBytesWritten */));
  EXPECT_TRUE(helper.processCmsg(scmTs, 1 /* rawBytesWritten */));
}

TEST_F(AsyncSocketByteEventHelperTest, ByteEventsDisabled) {
  auto scmTs = cmsgForScmTimestamping(folly::netops::SCM_TSTAMP_SND, 0);
  const auto softwareTsSec = std::chrono::seconds(59);
  const auto softwareTsNs = std::chrono::nanoseconds(11);
  auto serrTs = cmsgForSockExtendedErrTimestamping();
  serrTs.setSoftwareTimestamp(softwareTsSec, softwareTsNs);

  AsyncSocket::ByteEventHelper helper = {};
  helper.byteEventsEnabled = false;
  helper.rawBytesWrittenWhenByteEventsEnabled = 0;

  // fails because disabled
  EXPECT_FALSE(helper.processCmsg(scmTs, 1 /* rawBytesWritten */));
  EXPECT_FALSE(helper.processCmsg(serrTs, 1 /* rawBytesWritten */));

  // enable, try again to prove this works
  helper.byteEventsEnabled = true;
  EXPECT_FALSE(helper.processCmsg(scmTs, 1 /* rawBytesWritten */));
  EXPECT_TRUE(helper.processCmsg(serrTs, 1 /* rawBytesWritten */));
}

TEST_F(AsyncSocketByteEventHelperTest, IgnoreUnsupportedEvent) {
  auto scmType =
      folly::netops::SCM_TSTAMP_ACK + 10; // imaginary new type of SCM event
  auto scmTs = cmsgForScmTimestamping(scmType, 0);
  const auto softwareTsSec = std::chrono::seconds(59);
  const auto softwareTsNs = std::chrono::nanoseconds(11);
  auto serrTs = cmsgForSockExtendedErrTimestamping();
  serrTs.setSoftwareTimestamp(softwareTsSec, softwareTsNs);

  AsyncSocket::ByteEventHelper helper = {};
  helper.byteEventsEnabled = true;
  helper.rawBytesWrittenWhenByteEventsEnabled = 0;

  // unsupported event is eaten
  EXPECT_FALSE(helper.processCmsg(scmTs, 1 /* rawBytesWritten */));
  EXPECT_FALSE(helper.processCmsg(serrTs, 1 /* rawBytesWritten */));

  // change type, try again to prove this works
  scmTs = cmsgForScmTimestamping(folly::netops::SCM_TSTAMP_ACK, 0);
  EXPECT_FALSE(helper.processCmsg(scmTs, 1 /* rawBytesWritten */));
  EXPECT_TRUE(helper.processCmsg(serrTs, 1 /* rawBytesWritten */));
}

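// Seeing the same half of a pair twice (two offset cmsgs or two timestamp
// cmsgs) without the matching counterpart is treated as an error, so
// ByteEventHelper::processCmsg throws.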
TEST_F(AsyncSocketByteEventHelperTest, ErrorDoubleScmCmsg) {
  auto scmTs = cmsgForScmTimestamping(folly::netops::SCM_TSTAMP_SND, 0);

  AsyncSocket::ByteEventHelper helper = {};
  helper.byteEventsEnabled = true;
  helper.rawBytesWrittenWhenByteEventsEnabled = 0;
  EXPECT_FALSE(helper.processCmsg(scmTs, 1 /* rawBytesWritten */));
  EXPECT_THROW(
      helper.processCmsg(scmTs, 1 /* rawBytesWritten */),
      AsyncSocket::ByteEventHelper::Exception);
}

TEST_F(AsyncSocketByteEventHelperTest, ErrorDoubleSerrCmsg) {
  const auto softwareTsSec = std::chrono::seconds(59);
  const auto softwareTsNs = std::chrono::nanoseconds(11);
  auto serrTs = cmsgForSockExtendedErrTimestamping();
  serrTs.setSoftwareTimestamp(softwareTsSec, softwareTsNs);

  AsyncSocket::ByteEventHelper helper = {};
  helper.byteEventsEnabled = true;
  helper.rawBytesWrittenWhenByteEventsEnabled = 0;
  EXPECT_FALSE(helper.processCmsg(serrTs, 1 /* rawBytesWritten */));
  EXPECT_THROW(
      helper.processCmsg(serrTs, 1 /* rawBytesWritten */),
      AsyncSocket::ByteEventHelper::Exception);
}

TEST_F(AsyncSocketByteEventHelperTest, ErrorExceptionSet) {
  auto scmTs = cmsgForScmTimestamping(folly::netops::SCM_TSTAMP_SND, 0);
  const auto softwareTsSec = std::chrono::seconds(59);
  const auto softwareTsNs = std::chrono::nanoseconds(11);
  auto serrTs = cmsgForSockExtendedErrTimestamping();
  serrTs.setSoftwareTimestamp(softwareTsSec, softwareTsNs);

  AsyncSocket::ByteEventHelper helper = {};
  helper.byteEventsEnabled = true;
  helper.rawBytesWrittenWhenByteEventsEnabled = 0;
  helper.maybeEx = AsyncSocketException(
      AsyncSocketException::AsyncSocketExceptionType::UNKNOWN, "");

  // fails due to existing exception
  EXPECT_FALSE(helper.processCmsg(scmTs, 1 /* rawBytesWritten */));
  EXPECT_FALSE(helper.processCmsg(serrTs, 1 /* rawBytesWritten */));

  // clear the exception, then repeat to prove the exception was blocking
  helper.maybeEx = folly::none;
  EXPECT_FALSE(helper.processCmsg(scmTs, 1 /* rawBytesWritten */));
  EXPECT_TRUE(helper.processCmsg(serrTs, 1 /* rawBytesWritten */));
}

struct AsyncSocketByteEventHelperTimestampTestParams {
  AsyncSocketByteEventHelperTimestampTestParams(
      uint32_t scmType,
      AsyncSocket::ByteEvent::Type expectedByteEventType,
      bool includeSoftwareTs,
      bool includeHardwareTs)
      : scmType(scmType),
        expectedByteEventType(expectedByteEventType),
        includeSoftwareTs(includeSoftwareTs),
        includeHardwareTs(includeHardwareTs) {}
  uint32_t scmType{0};
  AsyncSocket::ByteEvent::Type expectedByteEventType;
  bool includeSoftwareTs{false};
  bool includeHardwareTs{false};
};

class AsyncSocketByteEventHelperTimestampTest
    : public AsyncSocketByteEventHelperTest,
      public testing::WithParamInterface<
          AsyncSocketByteEventHelperTimestampTestParams> {
 public:
  static std::vector<AsyncSocketByteEventHelperTimestampTestParams>
  getTestingValues() {
    std::vector<AsyncSocketByteEventHelperTimestampTestParams> vals;

    // software + hardware timestamps
    {
      vals.emplace_back(
          folly::netops::SCM_TSTAMP_SCHED, ByteEventType::SCHED, true, true);
      vals.emplace_back(
          folly::netops::SCM_TSTAMP_SND, ByteEventType::TX, true, true);
      vals.emplace_back(
          folly::netops::SCM_TSTAMP_ACK, ByteEventType::ACK, true, true);
    }

    // software ts only
    {
      vals.emplace_back(
          folly::netops::SCM_TSTAMP_SCHED, ByteEventType::SCHED, true, false);
      vals.emplace_back(
          folly::netops::SCM_TSTAMP_SND, ByteEventType::TX, true, false);
      vals.emplace_back(
          folly::netops::SCM_TSTAMP_ACK, ByteEventType::ACK, true, false);
    }

    // hardware ts only
    {
      vals.emplace_back(
          folly::netops::SCM_TSTAMP_SCHED, ByteEventType::SCHED, false, true);
      vals.emplace_back(
          folly::netops::SCM_TSTAMP_SND, ByteEventType::TX, false, true);
      vals.emplace_back(
          folly::netops::SCM_TSTAMP_ACK, ByteEventType::ACK, false, true);
    }

    return vals;
  }
};

INSTANTIATE_TEST_SUITE_P(
    ByteEventTimestampTest,
    AsyncSocketByteEventHelperTimestampTest,
    ::testing::ValuesIn(
        AsyncSocketByteEventHelperTimestampTest::getTestingValues()));

/**
 * Check timestamp parsing for software and hardware timestamps.
 */
TEST_P(AsyncSocketByteEventHelperTimestampTest, CheckEventTimestamps) {
  const auto softwareTsSec = std::chrono::seconds(59);
  const auto softwareTsNs = std::chrono::nanoseconds(11);
  const auto hardwareTsSec = std::chrono::seconds(79);
  const auto hardwareTsNs = std::chrono::nanoseconds(31);

  auto params = GetParam();
  auto scmTs = cmsgForScmTimestamping(params.scmType, 0);
  auto serrTs = cmsgForSockExtendedErrTimestamping();
  if (params.includeSoftwareTs) {
    serrTs.setSoftwareTimestamp(softwareTsSec, softwareTsNs);
  }
  if (params.includeHardwareTs) {
    serrTs.setHardwareTimestamp(hardwareTsSec, hardwareTsNs);
  }

  AsyncSocket::ByteEventHelper helper = {};
  helper.byteEventsEnabled = true;
  helper.rawBytesWrittenWhenByteEventsEnabled = 0;
  folly::Optional<AsyncSocket::ByteEvent> maybeByteEvent;
  maybeByteEvent = helper.processCmsg(serrTs, 1 /* rawBytesWritten */);
  EXPECT_FALSE(maybeByteEvent.has_value());
  maybeByteEvent = helper.processCmsg(scmTs, 1 /* rawBytesWritten */);

  // common checks
  ASSERT_TRUE(maybeByteEvent.has_value());
  const auto& byteEvent = *maybeByteEvent;
  EXPECT_EQ(0, byteEvent.offset);
  EXPECT_GE(std::chrono::steady_clock::now(), byteEvent.ts);

  EXPECT_EQ(params.expectedByteEventType, byteEvent.type);
  if (params.includeSoftwareTs) {
    EXPECT_EQ(softwareTsSec + softwareTsNs, byteEvent.maybeSoftwareTs);
  }
  if (params.includeHardwareTs) {
    EXPECT_EQ(hardwareTsSec + hardwareTsNs, byteEvent.maybeHardwareTs);
  }
}

struct AsyncSocketByteEventHelperOffsetTestParams {
  uint64_t rawBytesWrittenWhenByteEventsEnabled{0};
  uint64_t byteTimestamped;
  uint64_t rawBytesWrittenWhenTimestampReceived;
};

class AsyncSocketByteEventHelperOffsetTest
    : public AsyncSocketByteEventHelperTest,
      public testing::WithParamInterface<
          AsyncSocketByteEventHelperOffsetTestParams> {
 public:
  static std::vector<AsyncSocketByteEventHelperOffsetTestParams>
  getTestingValues() {
    std::vector<AsyncSocketByteEventHelperOffsetTestParams> vals;
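    // values straddle the 32-bit boundary (2^32 - 1 and 2^32) to exercise the
    // offset wrap handling below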
    const std::array<uint64_t, 5> rawBytesWrittenWhenByteEventsEnabledVals{
        0, 1, 100, 4294967295, 4294967296};
    for (const auto& rawBytesWrittenWhenByteEventsEnabled :
         rawBytesWrittenWhenByteEventsEnabledVals) {
      auto addParams = [&](auto params) {
        // check if case is valid based on rawBytesWrittenWhenByteEventsEnabled
        if (rawBytesWrittenWhenByteEventsEnabled <= params.byteTimestamped) {
          vals.push_back(params);
        }
      };

      // case 1
      // bytes sent on receipt of timestamp == byte timestamped
      {
        AsyncSocketByteEventHelperOffsetTestParams params;
        params.rawBytesWrittenWhenByteEventsEnabled =
            rawBytesWrittenWhenByteEventsEnabled;
        params.byteTimestamped = 0;
        params.rawBytesWrittenWhenTimestampReceived = 0;
        addParams(params);
      }
      {
        AsyncSocketByteEventHelperOffsetTestParams params;
        params.rawBytesWrittenWhenByteEventsEnabled =
            rawBytesWrittenWhenByteEventsEnabled;
        params.byteTimestamped = 1;
        params.rawBytesWrittenWhenTimestampReceived = 1;
        addParams(params);
      }
      {
        AsyncSocketByteEventHelperOffsetTestParams params;
        params.rawBytesWrittenWhenByteEventsEnabled =
            rawBytesWrittenWhenByteEventsEnabled;
        params.byteTimestamped = 101;
        params.rawBytesWrittenWhenTimestampReceived = 101;
        addParams(params);
      }

      // bytes sent on receipt of timestamp > byte timestamped
      {
        AsyncSocketByteEventHelperOffsetTestParams params;
        params.rawBytesWrittenWhenByteEventsEnabled =
            rawBytesWrittenWhenByteEventsEnabled;
        params.byteTimestamped = 1;
        params.rawBytesWrittenWhenTimestampReceived = 2;
        addParams(params);
      }
      {
        AsyncSocketByteEventHelperOffsetTestParams params;
        params.rawBytesWrittenWhenByteEventsEnabled =
            rawBytesWrittenWhenByteEventsEnabled;
        params.byteTimestamped = 101;
        params.rawBytesWrittenWhenTimestampReceived = 102;
        addParams(params);
      }

      // case 2
      // bytes sent on receipt of timestamp == byte timestamped, boundary test
      // (boundary is at 2^32)
      {
        AsyncSocketByteEventHelperOffsetTestParams params;
        params.rawBytesWrittenWhenByteEventsEnabled =
            rawBytesWrittenWhenByteEventsEnabled;
        params.byteTimestamped = 4294967294;
        params.rawBytesWrittenWhenTimestampReceived = 4294967294;
        addParams(params);
      }
      {
        AsyncSocketByteEventHelperOffsetTestParams params;
        params.rawBytesWrittenWhenByteEventsEnabled =
            rawBytesWrittenWhenByteEventsEnabled;
        params.byteTimestamped = 4294967295;
        params.rawBytesWrittenWhenTimestampReceived = 4294967295;
        addParams(params);
      }
      {
        AsyncSocketByteEventHelperOffsetTestParams params;
        params.rawBytesWrittenWhenByteEventsEnabled =
            rawBytesWrittenWhenByteEventsEnabled;
        params.byteTimestamped = 4294967296;
        params.rawBytesWrittenWhenTimestampReceived = 4294967296;
        addParams(params);
      }
      {
        AsyncSocketByteEventHelperOffsetTestParams params;
        params.rawBytesWrittenWhenByteEventsEnabled =
            rawBytesWrittenWhenByteEventsEnabled;
        params.byteTimestamped = 4294967297;
        params.rawBytesWrittenWhenTimestampReceived = 4294967297;
        addParams(params);
      }
      {
        AsyncSocketByteEventHelperOffsetTestParams params;
        params.rawBytesWrittenWhenByteEventsEnabled =
            rawBytesWrittenWhenByteEventsEnabled;
        params.byteTimestamped = 4294967298;
        params.rawBytesWrittenWhenTimestampReceived = 4294967298;
        addParams(params);
      }

      // case 3
      // bytes sent on receipt of timestamp > byte timestamped, boundary test
      // (boundary is at 2^32)
      {
        AsyncSocketByteEventHelperOffsetTestParams params;
        params.rawBytesWrittenWhenByteEventsEnabled =
            rawBytesWrittenWhenByteEventsEnabled;
        params.byteTimestamped = 4294967293;
        params.rawBytesWrittenWhenTimestampReceived = 4294967294;
        addParams(params);
      }
      {
        AsyncSocketByteEventHelperOffsetTestParams params;
        params.rawBytesWrittenWhenByteEventsEnabled =
            rawBytesWrittenWhenByteEventsEnabled;
        params.byteTimestamped = 4294967294;
        params.rawBytesWrittenWhenTimestampReceived = 4294967295;
        addParams(params);
      }
      {
        AsyncSocketByteEventHelperOffsetTestParams params;
        params.rawBytesWrittenWhenByteEventsEnabled =
            rawBytesWrittenWhenByteEventsEnabled;
        params.byteTimestamped = 4294967295;
        params.rawBytesWrittenWhenTimestampReceived = 4294967296;
        addParams(params);
      }
      {
        AsyncSocketByteEventHelperOffsetTestParams params;
        params.rawBytesWrittenWhenByteEventsEnabled =
            rawBytesWrittenWhenByteEventsEnabled;
        params.byteTimestamped = 4294967296;
        params.rawBytesWrittenWhenTimestampReceived = 4294967297;
        addParams(params);
      }

      // case 4
      // bytes sent on receipt of timestamp > byte timestamped, wrap test
      // (boundary is at 2^32)
      {
        AsyncSocketByteEventHelperOffsetTestParams params;
        params.rawBytesWrittenWhenByteEventsEnabled =
            rawBytesWrittenWhenByteEventsEnabled;
        params.byteTimestamped = 4294967275;
        params.rawBytesWrittenWhenTimestampReceived = 4294967305;
        addParams(params);
      }
      {
        AsyncSocketByteEventHelperOffsetTestParams params;
        params.rawBytesWrittenWhenByteEventsEnabled =
            rawBytesWrittenWhenByteEventsEnabled;
        params.byteTimestamped = 4294967295;
        params.rawBytesWrittenWhenTimestampReceived = 4294967296;
        addParams(params);
      }
      {
        AsyncSocketByteEventHelperOffsetTestParams params;
        params.rawBytesWrittenWhenByteEventsEnabled =
            rawBytesWrittenWhenByteEventsEnabled;
        params.byteTimestamped = 4294967285;
        params.rawBytesWrittenWhenTimestampReceived = 4294967305;
        addParams(params);
      }

      // case 5
      // special case when timestamp enabled when bytes transferred > (2^32)
      // bytes sent on receipt of timestamp == byte timestamped, boundary test
      // (boundary is at 2^32)
      {
        AsyncSocketByteEventHelperOffsetTestParams params;
        params.rawBytesWrittenWhenByteEventsEnabled =
            rawBytesWrittenWhenByteEventsEnabled;
        params.byteTimestamped = 6442450943;
        params.rawBytesWrittenWhenTimestampReceived = 6442450943;
        addParams(params);
      }

      // case 6
      // special case when timestamp enabled when bytes transferred > (2^32)
      // bytes sent on receipt of timestamp > byte timestamped, boundary test
      // (boundary is at 2^32)
      {
        AsyncSocketByteEventHelperOffsetTestParams params;
        params.rawBytesWrittenWhenByteEventsEnabled =
            rawBytesWrittenWhenByteEventsEnabled;
        params.byteTimestamped = 6442450943;
        params.rawBytesWrittenWhenTimestampReceived = 6442450944;
        addParams(params);
      }

      // case 7
      // special case when timestamp enabled when bytes transferred > (2^32)
      // bytes sent on receipt of timestamp > byte timestamped, wrap test
      // (boundary is at 2^32)
      {
        AsyncSocketByteEventHelperOffsetTestParams params;
        params.rawBytesWrittenWhenByteEventsEnabled =
            rawBytesWrittenWhenByteEventsEnabled;
        params.byteTimestamped = 6442450943;
        params.rawBytesWrittenWhenTimestampReceived = 8589934591;
        addParams(params);
      }
    }

    return vals;
  }
};

INSTANTIATE_TEST_SUITE_P(
    ByteEventOffsetTest,
    AsyncSocketByteEventHelperOffsetTest,
    ::testing::ValuesIn(
        AsyncSocketByteEventHelperOffsetTest::getTestingValues()));

/**
 * Check byte offset handling, including boundary cases.
 *
 * See AsyncSocket::ByteEventHelper::processCmsg for details.
 */
TEST_P(AsyncSocketByteEventHelperOffsetTest, CheckCalculatedOffset) {
  auto params = GetParam();

  // because we use SOF_TIMESTAMPING_OPT_ID, byte offsets reported by the
  // kernel are relative to the number of bytes AsyncSocket had already written
  // to the socket when timestamps were enabled, and they wrap at 2^32
  //
  // here we calculate what the kernel offset would be for the given byte offset
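  //
  // e.g., if timestamps were enabled after 100 bytes had been written and the
  // byte being timestamped is global offset 101, the kernel reports offset 1;
  // given rawBytesWrittenWhenByteEventsEnabled and rawBytesWritten, the helper
  // recovers the global offset 101 (verified by the EXPECT_EQ at the end)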
  const uint64_t bytesPerOffsetWrap =
      static_cast<uint64_t>(std::numeric_limits<uint32_t>::max()) + 1;

  auto kernelByteOffset =
      params.byteTimestamped - params.rawBytesWrittenWhenByteEventsEnabled;
  if (kernelByteOffset > 0) {
    kernelByteOffset = kernelByteOffset % bytesPerOffsetWrap;
  }

  auto scmTs =
      cmsgForScmTimestamping(folly::netops::SCM_TSTAMP_SND, kernelByteOffset);
  const auto softwareTsSec = std::chrono::seconds(59);
  const auto softwareTsNs = std::chrono::nanoseconds(11);
  auto serrTs = cmsgForSockExtendedErrTimestamping();
  serrTs.setSoftwareTimestamp(softwareTsSec, softwareTsNs);

  AsyncSocket::ByteEventHelper helper = {};
  helper.byteEventsEnabled = true;
  helper.rawBytesWrittenWhenByteEventsEnabled =
      params.rawBytesWrittenWhenByteEventsEnabled;

  EXPECT_FALSE(helper.processCmsg(
      scmTs,
      params.rawBytesWrittenWhenTimestampReceived /* rawBytesWritten */));
  const auto maybeByteEvent = helper.processCmsg(
      serrTs,
      params.rawBytesWrittenWhenTimestampReceived /* rawBytesWritten */);
  ASSERT_TRUE(maybeByteEvent.has_value());
  const auto& byteEvent = *maybeByteEvent;

  EXPECT_EQ(params.byteTimestamped, byteEvent.offset);
  EXPECT_EQ(softwareTsSec + softwareTsNs, byteEvent.maybeSoftwareTs);
}

#endif // FOLLY_HAVE_SO_TIMESTAMPING

TEST(AsyncSocket, LifecycleCtorCallback) {
  EventBase evb;
  // create socket and verify that w/o a ctor callback, nothing happens
  auto socket1 = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  EXPECT_EQ(socket1->getLifecycleObservers().size(), 0);

  // Then register a ctor callback that registers a mock lifecycle observer.
  // NB: use NiceMock instead of StrictMock because detailed lifecycle testing
  // is done in the tests below; this keeps this test simple.
  auto lifecycleCB =
      std::make_shared<NiceMock<MockAsyncSocketLifecycleObserver>>();
  auto lifecycleRawPtr = lifecycleCB.get();
  // register a constructor callback that attaches the lifecycle observer
  ConstructorCallbackList<AsyncSocket>::addCallback(
      [lifecycleRawPtr](AsyncSocket* s) {
        s->addLifecycleObserver(lifecycleRawPtr);
      });
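  // callbacks registered via ConstructorCallbackList are invoked for every
  // AsyncSocket constructed from this point on, so socket2 below is created
  // with the observer already attached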
  auto socket2 = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  EXPECT_EQ(socket2->getLifecycleObservers().size(), 1);
  EXPECT_THAT(
      socket2->getLifecycleObservers(),
      UnorderedElementsAre(lifecycleCB.get()));
  Mock::VerifyAndClearExpectations(lifecycleCB.get());
}

TEST(AsyncSocket, LifecycleObserverDetachAndAttachEvb) {
  auto cb = std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  EventBase evb;
  EventBase evb2;
  auto socket = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  EXPECT_CALL(*cb, observerAttachMock(socket.get()));
  socket->addLifecycleObserver(cb.get());
  EXPECT_THAT(socket->getLifecycleObservers(), UnorderedElementsAre(cb.get()));
  Mock::VerifyAndClearExpectations(cb.get());

  // Detach the evb and attach a new evb2
  EXPECT_CALL(*cb, evbDetachMock(socket.get(), &evb));
  socket->detachEventBase();
  EXPECT_EQ(nullptr, socket->getEventBase());
  Mock::VerifyAndClearExpectations(cb.get());

  EXPECT_CALL(*cb, evbAttachMock(socket.get(), &evb2));
  socket->attachEventBase(&evb2);
  EXPECT_EQ(&evb2, socket->getEventBase());
  Mock::VerifyAndClearExpectations(cb.get());

  // detach the new evb2 and re-attach the old evb.
  EXPECT_CALL(*cb, evbDetachMock(socket.get(), &evb2));
  socket->detachEventBase();
  EXPECT_EQ(nullptr, socket->getEventBase());
  Mock::VerifyAndClearExpectations(cb.get());

  EXPECT_CALL(*cb, evbAttachMock(socket.get(), &evb));
  socket->attachEventBase(&evb);
  EXPECT_EQ(&evb, socket->getEventBase());
  Mock::VerifyAndClearExpectations(cb.get());

  InSequence s;
  EXPECT_CALL(*cb, destroyMock(socket.get()));
  socket = nullptr;
  Mock::VerifyAndClearExpectations(cb.get());
}

TEST(AsyncSocket, LifecycleObserverAttachThenDestroySocket) {
  auto cb = std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  TestServer server;

  EventBase evb;
  auto socket = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  EXPECT_CALL(*cb, observerAttachMock(socket.get()));
  socket->addLifecycleObserver(cb.get());
  EXPECT_THAT(socket->getLifecycleObservers(), UnorderedElementsAre(cb.get()));
  Mock::VerifyAndClearExpectations(cb.get());

  EXPECT_CALL(*cb, connectAttemptMock(socket.get()));
  EXPECT_CALL(*cb, fdAttachMock(socket.get()));
  EXPECT_CALL(*cb, connectSuccessMock(socket.get()));
  socket->connect(nullptr, server.getAddress(), 30);
  evb.loop();
  Mock::VerifyAndClearExpectations(cb.get());

  InSequence s;
  EXPECT_CALL(*cb, closeMock(socket.get()));
  EXPECT_CALL(*cb, destroyMock(socket.get()));
  socket = nullptr;
  Mock::VerifyAndClearExpectations(cb.get());
}

TEST(AsyncSocket, LifecycleObserverAttachThenConnectError) {
  auto cb = std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  // port 1 is unreachable on localhost
  folly::SocketAddress unreachable{"::1", 1};

  EventBase evb;
  auto socket = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  EXPECT_CALL(*cb, observerAttachMock(socket.get()));
  socket->addLifecycleObserver(cb.get());
  EXPECT_THAT(socket->getLifecycleObservers(), UnorderedElementsAre(cb.get()));
  Mock::VerifyAndClearExpectations(cb.get());

  // the current state machine calls AsyncSocket::invokeConnectionError() twice
  // for this use-case...
  EXPECT_CALL(*cb, connectAttemptMock(socket.get()));
  EXPECT_CALL(*cb, fdAttachMock(socket.get()));
  EXPECT_CALL(*cb, connectErrorMock(socket.get(), _)).Times(2);
  EXPECT_CALL(*cb, closeMock(socket.get()));
  socket->connect(nullptr, unreachable, 1);
  evb.loop();
  Mock::VerifyAndClearExpectations(cb.get());

  EXPECT_CALL(*cb, destroyMock(socket.get()));
  socket = nullptr;
  Mock::VerifyAndClearExpectations(cb.get());
}

TEST(AsyncSocket, LifecycleObserverMultipleAttachThenDestroySocket) {
  auto cb1 = std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  auto cb2 = std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  TestServer server;

  EventBase evb;
  auto socket = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  EXPECT_CALL(*cb1, observerAttachMock(socket.get()));
  socket->addLifecycleObserver(cb1.get());
  EXPECT_THAT(socket->getLifecycleObservers(), UnorderedElementsAre(cb1.get()));
  Mock::VerifyAndClearExpectations(cb1.get());
  Mock::VerifyAndClearExpectations(cb2.get());

  EXPECT_CALL(*cb2, observerAttachMock(socket.get()));
  socket->addLifecycleObserver(cb2.get());
  EXPECT_THAT(
      socket->getLifecycleObservers(),
      UnorderedElementsAre(cb1.get(), cb2.get()));
  Mock::VerifyAndClearExpectations(cb1.get());
  Mock::VerifyAndClearExpectations(cb2.get());

  InSequence s;
  EXPECT_CALL(*cb1, connectAttemptMock(socket.get()));
  EXPECT_CALL(*cb2, connectAttemptMock(socket.get()));
  EXPECT_CALL(*cb1, fdAttachMock(socket.get()));
  EXPECT_CALL(*cb2, fdAttachMock(socket.get()));
  EXPECT_CALL(*cb1, connectSuccessMock(socket.get()));
  EXPECT_CALL(*cb2, connectSuccessMock(socket.get()));
  socket->connect(nullptr, server.getAddress(), 30);
  evb.loop();
  Mock::VerifyAndClearExpectations(cb1.get());
  Mock::VerifyAndClearExpectations(cb2.get());

  EXPECT_CALL(*cb1, closeMock(socket.get()));
  EXPECT_CALL(*cb2, closeMock(socket.get()));
  EXPECT_CALL(*cb1, destroyMock(socket.get()));
  EXPECT_CALL(*cb2, destroyMock(socket.get()));
  socket = nullptr;
  Mock::VerifyAndClearExpectations(cb1.get());
  Mock::VerifyAndClearExpectations(cb2.get());
}

TEST(AsyncSocket, LifecycleObserverAttachRemove) {
  EventBase evb;
  auto socket = AsyncSocket::UniquePtr(new AsyncSocket(&evb));

  auto cb = std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  EXPECT_CALL(*cb, observerAttachMock(socket.get()));
  socket->addLifecycleObserver(cb.get());
  Mock::VerifyAndClearExpectations(cb.get());

  EXPECT_THAT(socket->getLifecycleObservers(), UnorderedElementsAre(cb.get()));
  EXPECT_CALL(*cb, observerDetachMock(socket.get()));
  EXPECT_TRUE(socket->removeLifecycleObserver(cb.get()));
  EXPECT_THAT(socket->getLifecycleObservers(), IsEmpty());
  Mock::VerifyAndClearExpectations(cb.get());
}

TEST(AsyncSocket, LifecycleObserverAttachRemoveMultiple) {
  EventBase evb;
  auto socket = AsyncSocket::UniquePtr(new AsyncSocket(&evb));

  auto cb1 = std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  EXPECT_CALL(*cb1, observerAttachMock(socket.get()));
  socket->addLifecycleObserver(cb1.get());
  Mock::VerifyAndClearExpectations(cb1.get());
  EXPECT_THAT(socket->getLifecycleObservers(), UnorderedElementsAre(cb1.get()));

  auto cb2 = std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  EXPECT_CALL(*cb2, observerAttachMock(socket.get()));
  socket->addLifecycleObserver(cb2.get());
  Mock::VerifyAndClearExpectations(cb2.get());
  EXPECT_THAT(
      socket->getLifecycleObservers(),
      UnorderedElementsAre(cb1.get(), cb2.get()));

  EXPECT_CALL(*cb1, observerDetachMock(socket.get()));
  EXPECT_TRUE(socket->removeLifecycleObserver(cb1.get()));
  Mock::VerifyAndClearExpectations(cb1.get());
  EXPECT_THAT(socket->getLifecycleObservers(), UnorderedElementsAre(cb2.get()));

  EXPECT_CALL(*cb2, observerDetachMock(socket.get()));
  EXPECT_TRUE(socket->removeLifecycleObserver(cb2.get()));
  Mock::VerifyAndClearExpectations(cb2.get());
  EXPECT_THAT(socket->getLifecycleObservers(), IsEmpty());
}

TEST(AsyncSocket, LifecycleObserverAttachRemoveMultipleReverse) {
  EventBase evb;
  auto socket = AsyncSocket::UniquePtr(new AsyncSocket(&evb));

  auto cb1 = std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  EXPECT_CALL(*cb1, observerAttachMock(socket.get()));
  socket->addLifecycleObserver(cb1.get());
  Mock::VerifyAndClearExpectations(cb1.get());
  EXPECT_THAT(socket->getLifecycleObservers(), UnorderedElementsAre(cb1.get()));

  auto cb2 = std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  EXPECT_CALL(*cb2, observerAttachMock(socket.get()));
  socket->addLifecycleObserver(cb2.get());
  Mock::VerifyAndClearExpectations(cb2.get());
  EXPECT_THAT(
      socket->getLifecycleObservers(),
      UnorderedElementsAre(cb1.get(), cb2.get()));

  EXPECT_CALL(*cb2, observerDetachMock(socket.get()));
  EXPECT_TRUE(socket->removeLifecycleObserver(cb2.get()));
  Mock::VerifyAndClearExpectations(cb2.get());
  EXPECT_THAT(socket->getLifecycleObservers(), UnorderedElementsAre(cb1.get()));

  EXPECT_CALL(*cb1, observerDetachMock(socket.get()));
  EXPECT_TRUE(socket->removeLifecycleObserver(cb1.get()));
  Mock::VerifyAndClearExpectations(cb1.get());
  EXPECT_THAT(socket->getLifecycleObservers(), IsEmpty());
}

TEST(AsyncSocket, LifecycleObserverRemoveMissing) {
  auto cb = std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  EventBase evb;
  auto socket = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  EXPECT_FALSE(socket->removeLifecycleObserver(cb.get()));
}

TEST(AsyncSocket, LifecycleObserverMultipleAttachThenRemove) {
  auto cb1 = std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  auto cb2 = std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  TestServer server;

  EventBase evb;
  auto socket = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  EXPECT_CALL(*cb1, observerAttachMock(socket.get()));
  socket->addLifecycleObserver(cb1.get());
  EXPECT_THAT(socket->getLifecycleObservers(), UnorderedElementsAre(cb1.get()));
  Mock::VerifyAndClearExpectations(cb1.get());
  Mock::VerifyAndClearExpectations(cb2.get());

  EXPECT_CALL(*cb2, observerAttachMock(socket.get()));
  socket->addLifecycleObserver(cb2.get());
  EXPECT_THAT(
      socket->getLifecycleObservers(),
      UnorderedElementsAre(cb1.get(), cb2.get()));
  Mock::VerifyAndClearExpectations(cb1.get());
  Mock::VerifyAndClearExpectations(cb2.get());

  EXPECT_CALL(*cb2, observerDetachMock(socket.get()));
  EXPECT_TRUE(socket->removeLifecycleObserver(cb2.get()));
  EXPECT_THAT(socket->getLifecycleObservers(), UnorderedElementsAre(cb1.get()));
  Mock::VerifyAndClearExpectations(cb1.get());
  Mock::VerifyAndClearExpectations(cb2.get());

  EXPECT_CALL(*cb1, observerDetachMock(socket.get()));
  socket->removeLifecycleObserver(cb1.get());
  EXPECT_THAT(socket->getLifecycleObservers(), IsEmpty());
  Mock::VerifyAndClearExpectations(cb1.get());
  Mock::VerifyAndClearExpectations(cb2.get());
}

TEST(AsyncSocket, LifecycleObserverDetach) {
  auto cb = std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  TestServer server;

  EventBase evb;
  auto socket1 = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  EXPECT_CALL(*cb, observerAttachMock(socket1.get()));
  socket1->addLifecycleObserver(cb.get());
  EXPECT_THAT(socket1->getLifecycleObservers(), UnorderedElementsAre(cb.get()));
  Mock::VerifyAndClearExpectations(cb.get());

  EXPECT_CALL(*cb, connectAttemptMock(socket1.get()));
  EXPECT_CALL(*cb, fdAttachMock(socket1.get()));
  EXPECT_CALL(*cb, connectSuccessMock(socket1.get()));
  socket1->connect(nullptr, server.getAddress(), 30);
  evb.loop();
  Mock::VerifyAndClearExpectations(cb.get());

  EXPECT_CALL(*cb, fdDetachMock(socket1.get()));
  auto fd = socket1->detachNetworkSocket();
  Mock::VerifyAndClearExpectations(cb.get());

  // create socket2, then immediately destroy it, should get no callbacks
  auto socket2 = AsyncSocket::UniquePtr(new AsyncSocket(&evb, fd));
  socket2 = nullptr;

  // finally, destroy socket1
  EXPECT_CALL(*cb, destroyMock(socket1.get()));
}

TEST(AsyncSocket, LifecycleObserverMoveResubscribe) {
  auto cb = std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  TestServer server;

  EventBase evb;
  auto socket1 = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  EXPECT_CALL(*cb, observerAttachMock(socket1.get()));
  socket1->addLifecycleObserver(cb.get());
  EXPECT_THAT(socket1->getLifecycleObservers(), UnorderedElementsAre(cb.get()));
  Mock::VerifyAndClearExpectations(cb.get());

  EXPECT_CALL(*cb, connectAttemptMock(socket1.get()));
  EXPECT_CALL(*cb, fdAttachMock(socket1.get()));
  EXPECT_CALL(*cb, connectSuccessMock(socket1.get()));
  socket1->connect(nullptr, server.getAddress(), 30);
  evb.loop();
  Mock::VerifyAndClearExpectations(cb.get());

  AsyncSocket* socket2PtrCapturedmoved = nullptr;
  {
    InSequence s;
    EXPECT_CALL(*cb, fdDetachMock(socket1.get()));
    EXPECT_CALL(*cb, moveMock(socket1.get(), Not(socket1.get())))
        .WillOnce(Invoke(
            [&socket2PtrCapturedmoved, &cb](auto oldSocket, auto newSocket) {
              socket2PtrCapturedmoved = newSocket;
              EXPECT_CALL(*cb, observerDetachMock(oldSocket));
              EXPECT_CALL(*cb, observerAttachMock(newSocket));
              EXPECT_TRUE(oldSocket->removeLifecycleObserver(cb.get()));
              EXPECT_THAT(oldSocket->getLifecycleObservers(), IsEmpty());
              newSocket->addLifecycleObserver(cb.get());
              EXPECT_THAT(
                  newSocket->getLifecycleObservers(),
                  UnorderedElementsAre(cb.get()));
            }));
  }
  auto socket2 = AsyncSocket::UniquePtr(new AsyncSocket(std::move(socket1)));
  Mock::VerifyAndClearExpectations(cb.get());
  EXPECT_EQ(socket2.get(), socket2PtrCapturedmoved);

  {
    InSequence s;
    EXPECT_CALL(*cb, closeMock(socket2.get()));
    EXPECT_CALL(*cb, destroyMock(socket2.get()));
  }
  socket2 = nullptr;
}

TEST(AsyncSocket, LifecycleObserverMoveDoNotResubscribe) {
  auto cb = std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  TestServer server;

  EventBase evb;
  auto socket1 = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
  EXPECT_CALL(*cb, observerAttachMock(socket1.get()));
  socket1->addLifecycleObserver(cb.get());
  EXPECT_THAT(socket1->getLifecycleObservers(), UnorderedElementsAre(cb.get()));
  Mock::VerifyAndClearExpectations(cb.get());

  EXPECT_CALL(*cb, connectAttemptMock(socket1.get()));
  EXPECT_CALL(*cb, fdAttachMock(socket1.get()));
  EXPECT_CALL(*cb, connectSuccessMock(socket1.get()));
  socket1->connect(nullptr, server.getAddress(), 30);
  evb.loop();
  Mock::VerifyAndClearExpectations(cb.get());

  // close will not be called on socket1 because the fd is detached
  AsyncSocket* socket2PtrCapturedMoved = nullptr;
  InSequence s;
  EXPECT_CALL(*cb, fdDetachMock(socket1.get()));
  EXPECT_CALL(*cb, moveMock(socket1.get(), Not(socket1.get())))
      .WillOnce(Invoke(
          [&socket2PtrCapturedMoved](auto /* oldSocket */, auto newSocket) {
            socket2PtrCapturedMoved = newSocket;
          }));
  EXPECT_CALL(*cb, destroyMock(socket1.get()));
  auto socket2 = AsyncSocket::UniquePtr(new AsyncSocket(std::move(socket1)));
  Mock::VerifyAndClearExpectations(cb.get());
  EXPECT_EQ(socket2.get(), socket2PtrCapturedMoved);
}

TEST(AsyncSocket, LifecycleObserverDetachCallbackImmediately) {
  auto cb = std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  TestServer server;

  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  EXPECT_CALL(*cb, observerAttachMock(socket.get()));
  socket->addLifecycleObserver(cb.get());
  EXPECT_THAT(socket->getLifecycleObservers(), UnorderedElementsAre(cb.get()));
  Mock::VerifyAndClearExpectations(cb.get());

  EXPECT_CALL(*cb, observerDetachMock(socket.get()));
  EXPECT_TRUE(socket->removeLifecycleObserver(cb.get()));
  EXPECT_THAT(socket->getLifecycleObservers(), IsEmpty());
  Mock::VerifyAndClearExpectations(cb.get());

  // keep going to ensure no further callbacks
  socket->connect(nullptr, server.getAddress(), 30);
  evb.loop();
}

TEST(AsyncSocket, LifecycleObserverDetachCallbackAfterConnect) {
  auto cb = std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  TestServer server;

  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  EXPECT_CALL(*cb, observerAttachMock(socket.get()));
  socket->addLifecycleObserver(cb.get());
  Mock::VerifyAndClearExpectations(cb.get());

  EXPECT_CALL(*cb, connectAttemptMock(socket.get()));
  EXPECT_CALL(*cb, fdAttachMock(socket.get()));
  EXPECT_CALL(*cb, connectSuccessMock(socket.get()));
  socket->connect(nullptr, server.getAddress(), 30);
  evb.loop();
  Mock::VerifyAndClearExpectations(cb.get());

  EXPECT_CALL(*cb, observerDetachMock(socket.get()));
  EXPECT_TRUE(socket->removeLifecycleObserver(cb.get()));
  Mock::VerifyAndClearExpectations(cb.get());
}

TEST(AsyncSocket, LifecycleObserverDetachCallbackAfterClose) {
  auto cb = std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  TestServer server;

  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  EXPECT_CALL(*cb, observerAttachMock(socket.get()));
  socket->addLifecycleObserver(cb.get());
  Mock::VerifyAndClearExpectations(cb.get());

  EXPECT_CALL(*cb, connectAttemptMock(socket.get()));
  EXPECT_CALL(*cb, fdAttachMock(socket.get()));
  EXPECT_CALL(*cb, connectSuccessMock(socket.get()));
  socket->connect(nullptr, server.getAddress(), 30);
  evb.loop();
  Mock::VerifyAndClearExpectations(cb.get());

  EXPECT_CALL(*cb, closeMock(socket.get()));
  socket->closeNow();
  Mock::VerifyAndClearExpectations(cb.get());

  EXPECT_CALL(*cb, observerDetachMock(socket.get()));
  EXPECT_TRUE(socket->removeLifecycleObserver(cb.get()));
  Mock::VerifyAndClearExpectations(cb.get());
}

TEST(AsyncSocket, LifecycleObserverDetachCallbackcloseDuringDestroy) {
  auto cb = std::make_unique<StrictMock<MockAsyncSocketLifecycleObserver>>();
  TestServer server;

  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  EXPECT_CALL(*cb, observerAttachMock(socket.get()));
  socket->addLifecycleObserver(cb.get());
  Mock::VerifyAndClearExpectations(cb.get());

  EXPECT_CALL(*cb, connectAttemptMock(socket.get()));
  EXPECT_CALL(*cb, fdAttachMock(socket.get()));
  EXPECT_CALL(*cb, connectSuccessMock(socket.get()));
  socket->connect(nullptr, server.getAddress(), 30);
  evb.loop();
  Mock::VerifyAndClearExpectations(cb.get());

  InSequence s;
  EXPECT_CALL(*cb, closeMock(socket.get()))
      .WillOnce(Invoke([&cb](auto callbackSocket) {
        EXPECT_TRUE(callbackSocket->removeLifecycleObserver(cb.get()));
      }));
  EXPECT_CALL(*cb, observerDetachMock(socket.get()));
  socket = nullptr;
  Mock::VerifyAndClearExpectations(cb.get());
}

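// setPreReceivedData() queues bytes to be delivered to the read callback ahead
// of any further data read from the socket. The tests below cover full,
// partial, and move/takeover consumption of the pre-received data.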
TEST(AsyncSocket, PreReceivedData) {
  TestServer server;

  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  socket->connect(nullptr, server.getAddress(), 30);
  evb.loop();

  socket->writeChain(nullptr, IOBuf::copyBuffer("hello"));

  auto acceptedSocket = server.acceptAsync(&evb);

  ReadCallback peekCallback(2);
  ReadCallback readCallback;
  peekCallback.dataAvailableCallback = [&]() {
    peekCallback.verifyData("he", 2);
    acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("h"));
    acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("e"));
    acceptedSocket->setReadCB(nullptr);
    acceptedSocket->setReadCB(&readCallback);
  };
  readCallback.dataAvailableCallback = [&]() {
    if (readCallback.dataRead() == 5) {
      readCallback.verifyData("hello", 5);
      acceptedSocket->setReadCB(nullptr);
    }
  };

  acceptedSocket->setReadCB(&peekCallback);

  evb.loop();
}

TEST(AsyncSocket, PreReceivedDataOnly) {
  TestServer server;

  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  socket->connect(nullptr, server.getAddress(), 30);
  evb.loop();

  socket->writeChain(nullptr, IOBuf::copyBuffer("hello"));

  auto acceptedSocket = server.acceptAsync(&evb);

  ReadCallback peekCallback;
  ReadCallback readCallback;
  peekCallback.dataAvailableCallback = [&]() {
    peekCallback.verifyData("hello", 5);
    acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("hello"));
    EXPECT_TRUE(acceptedSocket->readable());
    acceptedSocket->setReadCB(&readCallback);
  };
  readCallback.dataAvailableCallback = [&]() {
    readCallback.verifyData("hello", 5);
    acceptedSocket->setReadCB(nullptr);
  };

  acceptedSocket->setReadCB(&peekCallback);

  evb.loop();
}

TEST(AsyncSocket, PreReceivedDataPartial) {
  TestServer server;

  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  socket->connect(nullptr, server.getAddress(), 30);
  evb.loop();

  socket->writeChain(nullptr, IOBuf::copyBuffer("hello"));

  auto acceptedSocket = server.acceptAsync(&evb);

  ReadCallback peekCallback;
  ReadCallback smallReadCallback(3);
  ReadCallback normalReadCallback;
  peekCallback.dataAvailableCallback = [&]() {
    peekCallback.verifyData("hello", 5);
    acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("hello"));
    acceptedSocket->setReadCB(&smallReadCallback);
  };
  smallReadCallback.dataAvailableCallback = [&]() {
    smallReadCallback.verifyData("hel", 3);
    acceptedSocket->setReadCB(&normalReadCallback);
  };
  normalReadCallback.dataAvailableCallback = [&]() {
    normalReadCallback.verifyData("lo", 2);
    acceptedSocket->setReadCB(nullptr);
  };

  acceptedSocket->setReadCB(&peekCallback);

  evb.loop();
}

TEST(AsyncSocket, PreReceivedDataTakeover) {
  TestServer server;

  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
  socket->connect(nullptr, server.getAddress(), 30);
  evb.loop();

  socket->writeChain(nullptr, IOBuf::copyBuffer("hello"));

  auto fd = server.acceptFD();
  SocketAddress peerAddress;
  peerAddress.setFromPeerAddress(fd);
  auto acceptedSocket =
      AsyncSocket::UniquePtr(new AsyncSocket(&evb, fd, 0, &peerAddress));
  AsyncSocket::UniquePtr takeoverSocket;

  ReadCallback peekCallback(3);
  ReadCallback readCallback;
  peekCallback.dataAvailableCallback = [&]() {
    peekCallback.verifyData("hel", 3);
    acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("hello"));
    acceptedSocket->setReadCB(nullptr);
    takeoverSocket =
        AsyncSocket::UniquePtr(new AsyncSocket(std::move(acceptedSocket)));
    takeoverSocket->setReadCB(&readCallback);
  };
  readCallback.dataAvailableCallback = [&]() {
    readCallback.verifyData("hello", 5);
    takeoverSocket->setReadCB(nullptr);
  };

  acceptedSocket->setReadCB(&peekCallback);

  evb.loop();
  // Verify we can still get the peer address after the peer socket is reset.
  socket->closeWithReset();
  evb.loopOnce();
  SocketAddress socketPeerAddress;
  takeoverSocket->getPeerAddress(&socketPeerAddress);
  EXPECT_EQ(socketPeerAddress, peerAddress);
}

#ifdef MSG_NOSIGNAL
TEST(AsyncSocketTest, SendMessageFlags) {
  TestServer server;
  TestSendMsgParamsCallback sendMsgCB(
      MSG_DONTWAIT | MSG_NOSIGNAL | MSG_MORE, 0, nullptr);

  // connect()
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);

  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 30);
  std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();

  evb.loop();
  ASSERT_EQ(ccb.state, STATE_SUCCEEDED);

  // Set SendMsgParamsCallback
  socket->setSendMsgParamCB(&sendMsgCB);
  ASSERT_EQ(socket->getSendMsgParamsCB(), &sendMsgCB);

  // Write the first portion of data. This data is expected to be
  // sent out immediately.
  std::vector<uint8_t> buf(128, 'a');
  WriteCallback wcb;
  sendMsgCB.reset(MSG_DONTWAIT | MSG_NOSIGNAL);
  socket->write(&wcb, buf.data(), buf.size());
  ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
  ASSERT_TRUE(sendMsgCB.queriedFlags_);
  ASSERT_FALSE(sendMsgCB.queriedData_);

  // Using different flags for the second write operation.
  // MSG_MORE flag is expected to delay sending this
  // data to the wire.
  sendMsgCB.reset(MSG_DONTWAIT | MSG_NOSIGNAL | MSG_MORE);
  socket->write(&wcb, buf.data(), buf.size());
  ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
  ASSERT_TRUE(sendMsgCB.queriedFlags_);
  ASSERT_FALSE(sendMsgCB.queriedData_);

  // Make sure the accepted socket saw only the data from
  // the first write request.
  std::vector<uint8_t> readbuf(2 * buf.size());
  uint32_t bytesRead = acceptedSocket->read(readbuf.data(), readbuf.size());
  ASSERT_TRUE(std::equal(buf.begin(), buf.end(), readbuf.begin()));
  ASSERT_EQ(bytesRead, buf.size());

  // Make sure the server got a connection and received the data
  acceptedSocket->close();
  socket->close();

  ASSERT_TRUE(socket->isClosedBySelf());
  ASSERT_FALSE(socket->isClosedByPeer());
}

TEST(AsyncSocketTest, SendMessageAncillaryData) {
  NetworkSocket fds[2];
  EXPECT_EQ(netops::socketpair(AF_UNIX, SOCK_STREAM, 0, fds), 0);

  // "Client" socket
  auto cfd = fds[0];
  ASSERT_NE(cfd, NetworkSocket());

  // "Server" socket
  auto sfd = fds[1];
  ASSERT_NE(sfd, NetworkSocket());
  SCOPE_EXIT {
    netops::close(sfd);
  };

  // Instantiate AsyncSocket object for the connected socket
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb, cfd);

  // Open a temporary file and write a magic string to it
  // We'll transfer the file handle to test the message parameters
  // callback logic.
  TemporaryFile file(
      StringPiece(), fs::path(), TemporaryFile::Scope::UNLINK_IMMEDIATELY);
  int tmpfd = file.fd();
  ASSERT_NE(tmpfd, -1) << "Failed to open a temporary file";
  std::string magicString("Magic string");
  ASSERT_EQ(
      write(tmpfd, magicString.c_str(), magicString.length()),
      magicString.length());

  // Send message
  union {
    // Space large enough to hold an 'int'
    char control[CMSG_SPACE(sizeof(int))];
    struct cmsghdr cmh;
  } s_u;
  s_u.cmh.cmsg_len = CMSG_LEN(sizeof(int));
  s_u.cmh.cmsg_level = SOL_SOCKET;
  s_u.cmh.cmsg_type = SCM_RIGHTS;
  memcpy(CMSG_DATA(&s_u.cmh), &tmpfd, sizeof(int));

  // Set up the callback providing message parameters
  TestSendMsgParamsCallback sendMsgCB(
      MSG_DONTWAIT | MSG_NOSIGNAL, sizeof(s_u.control), s_u.control);
  socket->setSendMsgParamCB(&sendMsgCB);

  // We must transmit at least 1 byte of real data in order
  // to send ancillary data
  int s_data = 12345;
  WriteCallback wcb;
  auto ioBuf = folly::IOBuf::wrapBuffer(&s_data, sizeof(s_data));
  sendMsgCB.expectedTag_ = folly::AsyncSocket::WriteRequestTag{
      ioBuf.get()}; // Also test write tagging.
  ASSERT_FALSE(sendMsgCB.tagLastWritten_.has_value());
  socket->writeChain(&wcb, std::move(ioBuf));
  ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
  ASSERT_TRUE(sendMsgCB.queriedData_); // Did the tag check run?
  ASSERT_EQ(sendMsgCB.expectedTag_, *sendMsgCB.tagLastWritten_);

  // Receive the message
  union {
    // Space large enough to hold an 'int'
    char control[CMSG_SPACE(sizeof(int))];
    struct cmsghdr cmh;
  } r_u;
  struct msghdr msgh;
  struct iovec iov;
  int r_data = 0;

  msgh.msg_control = r_u.control;
  msgh.msg_controllen = sizeof(r_u.control);
  msgh.msg_name = nullptr;
  msgh.msg_namelen = 0;
  msgh.msg_iov = &iov;
  msgh.msg_iovlen = 1;
  iov.iov_base = &r_data;
  iov.iov_len = sizeof(r_data);

  // Receive data
  ASSERT_NE(netops::recvmsg(sfd, &msgh, 0), -1) << "recvmsg failed: " << errno;

  // Validate the received message
  ASSERT_EQ(r_u.cmh.cmsg_len, CMSG_LEN(sizeof(int)));
  ASSERT_EQ(r_u.cmh.cmsg_level, SOL_SOCKET);
  ASSERT_EQ(r_u.cmh.cmsg_type, SCM_RIGHTS);
  ASSERT_EQ(r_data, s_data);
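  // SCM_RIGHTS makes the kernel install a new descriptor in the receiving
  // process that refers to the same open file as tmpfd; reading the magic
  // string back through it proves the transfer worked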
  int fd = 0;
  memcpy(&fd, CMSG_DATA(&r_u.cmh), sizeof(int));
  ASSERT_NE(fd, 0);
  SCOPE_EXIT {
    close(fd);
  };

  std::vector<uint8_t> transferredMagicString(magicString.length() + 1, 0);

  // Reposition to the beginning of the file
  ASSERT_EQ(0, lseek(fd, 0, SEEK_SET));

  // Read the magic string back, and compare it with the original
  ASSERT_EQ(
      magicString.length(),
      read(fd, transferredMagicString.data(), transferredMagicString.size()));
  ASSERT_TRUE(std::equal(
      magicString.begin(), magicString.end(), transferredMagicString.begin()));
}

namespace {

// Subclasses of AsyncSocket (e.g. AsyncFdSocket) want to be able to fail reads
// from the read ancillary data callback or the regular read callback. Test this.
struct FailableSocket : public AsyncSocket {
  FailableSocket(EventBase* evb, NetworkSocket fd) : AsyncSocket(evb, fd) {}
  void testFailRead() {
    AsyncSocketException ex(
        AsyncSocketException::INTERNAL_ERROR, "FailableSocket::testFailRead");
    AsyncSocket::failRead(__func__, ex);
  }
};

class TruncateAncillaryDataAndCallFn
    : public folly::AsyncSocket::ReadAncillaryDataCallback {
 public:
  explicit TruncateAncillaryDataAndCallFn(VoidCallback cob)
      : callback_(std::move(cob)) {}

  void ancillaryData(struct msghdr& msg) noexcept override {
    sawCtrunc_ = sawCtrunc_ || (msg.msg_flags & MSG_CTRUNC);
    callback_();
  }
  folly::MutableByteRange getAncillaryDataCtrlBuffer() override {
    return folly::MutableByteRange(ancillaryDataCtrlBuffer_);
  }

  bool sawCtrunc_{false};

 private:
  VoidCallback callback_;
  // Empty to trigger MSG_CTRUNC
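  // (recvmsg() sets MSG_CTRUNC in msg_flags when the supplied control buffer
  // is too small for the incoming ancillary data, so a zero-length buffer
  // guarantees truncation and the transferred fd is discarded)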
  std::array<uint8_t, 0> ancillaryDataCtrlBuffer_;
};

// Returns the error string from the read callback (can be "none")
std::string testTruncateAncillaryDataAndCall(
    std::function<void(FailableSocket*)> fn,
    std::function<void(FailableSocket*)> postConditionCheck) {
  NetworkSocket fds[2];
  CHECK_EQ(netops::socketpair(AF_UNIX, SOCK_STREAM, 0, fds), 0);

  EventBase evb;
  std::shared_ptr<AsyncSocket> sendSock = AsyncSocket::newSocket(&evb, fds[0]);
  ReadCallback rcb; // outlives socket since ~AsyncSocket calls rcb.readEOF
  FailableSocket recvSock(&evb, fds[1]);

  TruncateAncillaryDataAndCallFn ancillaryCob{[&]() { fn(&recvSock); }};
  recvSock.setReadAncillaryDataCB(&ancillaryCob);

  // Send the stderr FD with ancillary data
  int tmpfd = 2;
  union { // `man cmsg` suggests this idiom for a "large enough" `cmsghdr`
    char buf[CMSG_SPACE(sizeof(tmpfd))];
    struct cmsghdr cmh;
  } u;
  u.cmh.cmsg_len = CMSG_LEN(sizeof(tmpfd));
  u.cmh.cmsg_level = SOL_SOCKET;
  u.cmh.cmsg_type = SCM_RIGHTS;
  memcpy(CMSG_DATA(&u.cmh), &tmpfd, sizeof(tmpfd));

  TestSendMsgParamsCallback sendMsgCB(
      MSG_DONTWAIT | MSG_NOSIGNAL, sizeof(u.buf), u.buf);
  sendSock->setSendMsgParamCB(&sendMsgCB);

  // Transmit at least 1 byte of real data to send ancillary data
  int s_data = 12345;
  WriteCallback wcb;
  sendSock->write(&wcb, &s_data, sizeof(s_data));
  CHECK_EQ(wcb.state, STATE_SUCCEEDED);

  // The FD will be discarded (MSG_CTRUNC) since our ancillary data callback
  // deliberately supplies an undersized control buffer to recvmsg().
  recvSock.setReadCB(&rcb);
  CHECK(!ancillaryCob.sawCtrunc_);
  evb.loopOnce();

  // Ensure that `ancillaryData()` actually ran, and saw the error condition.
  CHECK(ancillaryCob.sawCtrunc_);
  postConditionCheck(&recvSock);

  return rcb.exception.what();
}

} // namespace

// These tests do double-duty:
//  - show that `ReadAncillaryDataCallback` can safely close or fail a socket
//  - exercise getting & handling `MSG_CTRUNC`

TEST(AsyncSocketTest, ReceiveTruncatedAncillaryDataAndFail) {
  EXPECT_THAT(
      testTruncateAncillaryDataAndCall(
          [](FailableSocket* sock) { sock->testFailRead(); },
          [](FailableSocket* sock) { ASSERT_TRUE(sock->error()); }),
      testing::HasSubstr("FailableSocket::testFailRead"));
}

TEST(AsyncSocketTest, ReceiveTruncatedAncillaryDataAndClose) {
  EXPECT_THAT(
      testTruncateAncillaryDataAndCall(
          [](FailableSocket* sock) { sock->close(); },
          [](FailableSocket* sock) { ASSERT_TRUE(sock->isClosedBySelf()); }),
      testing::HasSubstr("AsyncSocketException: none, type =")); // no error
}

TEST(AsyncSocketTest, ReceiveTruncatedAncillaryDataUnhandled) {
  // Since this `ancillaryData` fails to check MSG_CTRUNC, the last-ditch
  // check in `AsyncSocket::processNormalRead` will fire.
  EXPECT_THAT(
      testTruncateAncillaryDataAndCall(
          [](FailableSocket*) {},
          [](FailableSocket* sock) { ASSERT_TRUE(sock->error()); }),
      testing::HasSubstr("recvmsg() got MSG_CTRUNC"));
}

TEST(AsyncSocketTest, UnixDomainSocketErrMessageCB) {
  // In the latest stable kernel 4.14.3 as of 2017-12-04, Unix Domain
  // Socket (UDS) does not support MSG_ERRQUEUE. So
  // recvmsg(MSG_ERRQUEUE) will read application data from UDS which
  // breaks application message flow.  To avoid this problem,
  // AsyncSocket currently disables setErrMessageCB for UDS.
  //
  // This test checks two things for UDS:
  // 1. setErrMessageCB fails
  // 2. recvmsg(MSG_ERRQUEUE) reads application data
  //
  // Feel free to remove this test if UDS supports MSG_ERRQUEUE in the future.

  NetworkSocket fd[2];
  EXPECT_EQ(netops::socketpair(AF_UNIX, SOCK_STREAM, 0, fd), 0);
  ASSERT_NE(fd[0], NetworkSocket());
  ASSERT_NE(fd[1], NetworkSocket());
  SCOPE_EXIT {
    netops::close(fd[1]);
  };

  EXPECT_EQ(netops::set_socket_non_blocking(fd[0]), 0);
  EXPECT_EQ(netops::set_socket_non_blocking(fd[1]), 0);

  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb, fd[0]);

  // setErrMessageCB should fail for unix domain socket
  TestErrMessageCallback errMsgCB;
  ASSERT_NE(&errMsgCB, nullptr);
  socket->setErrMessageCB(&errMsgCB);
  ASSERT_EQ(socket->getErrMessageCallback(), nullptr);

#ifdef FOLLY_HAVE_MSG_ERRQUEUE
  // The following verifies that MSG_ERRQUEUE does not work for UDS,
  // and recvmsg reads application data
  union {
    // Space large enough to hold an 'int'
    char control[CMSG_SPACE(sizeof(int))];
    struct cmsghdr cmh;
  } r_u;
  struct msghdr msgh;
  struct iovec iov;
  int recv_data = 0;

  msgh.msg_control = r_u.control;
  msgh.msg_controllen = sizeof(r_u.control);
  msgh.msg_name = nullptr;
  msgh.msg_namelen = 0;
  msgh.msg_iov = &iov;
  msgh.msg_iovlen = 1;
  iov.iov_base = &recv_data;
  iov.iov_len = sizeof(recv_data);

  // There is no data, so recvmsg should fail
  EXPECT_EQ(netops::recvmsg(fd[1], &msgh, MSG_ERRQUEUE), -1);
  EXPECT_TRUE(errno == EAGAIN || errno == EWOULDBLOCK);

  // Provide some application data; the error queue, if one exists, should be
  // empty. However, UDS reads the application data as an error message.
  int test_data = 123456;
  WriteCallback wcb;
  socket->write(&wcb, &test_data, sizeof(test_data));
  recv_data = 0;
  ASSERT_NE(netops::recvmsg(fd[1], &msgh, MSG_ERRQUEUE), -1);
  ASSERT_EQ(recv_data, test_data);
#endif // FOLLY_HAVE_MSG_ERRQUEUE
}

TEST(AsyncSocketTest, V6TosReflectTest) {
  EventBase eventBase;

  // Create a server socket
  std::shared_ptr<AsyncServerSocket> serverSocket(
      AsyncServerSocket::newSocket(&eventBase));
  folly::IPAddress ip("::1");
  std::vector<folly::IPAddress> serverIp;
  serverIp.push_back(ip);
  serverSocket->bind(serverIp, 0);
  serverSocket->listen(16);
  folly::SocketAddress serverAddress;
  serverSocket->getAddress(&serverAddress);

  // Enable TOS reflect
  serverSocket->setTosReflect(true);

  // Add a callback to accept one connection then stop the loop
  TestAcceptCallback acceptCallback;
  acceptCallback.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
      });
  acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
    serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
  });
  serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
  serverSocket->startAccepting();

  // Create a client socket, setsockopt() the TOS before connecting
  auto clientThread = [](std::shared_ptr<AsyncSocket>& clientSock,
                         ConnCallback* ccb,
                         EventBase* evb,
                         folly::SocketAddress sAddr) {
    clientSock = AsyncSocket::newSocket(evb);
    SocketOptionKey v6Opts = {IPPROTO_IPV6, IPV6_TCLASS};
    SocketOptionMap optionMap;
    optionMap.insert({v6Opts, 0x2c});
    SocketAddress bindAddr("0.0.0.0", 0);
    clientSock->connect(ccb, sAddr, 30, optionMap, bindAddr);
  };

  std::shared_ptr<AsyncSocket> socket(nullptr);
  ConnCallback cb;
  clientThread(socket, &cb, &eventBase, serverAddress);

  eventBase.loop();

  // Verify if the connection is accepted and if the accepted socket has
  // setsockopt on the TOS for the same value that was on the client socket
  auto fd = acceptCallback.getEvents()->at(1).fd;
  ASSERT_NE(fd, NetworkSocket());
  int value;
  socklen_t valueLength = sizeof(value);
  int rc =
      netops::getsockopt(fd, IPPROTO_IPV6, IPV6_TCLASS, &value, &valueLength);
  ASSERT_EQ(rc, 0);
  ASSERT_EQ(value, 0x2c);

  // Additional Test for ConnectCallback without bindAddr
  serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
  serverSocket->startAccepting();

  auto newClientSock = AsyncSocket::newSocket(&eventBase);
  TestConnectCallback callback;
  // The connect call only sets SO_REUSEADDR when a bindAddress is passed;
  // since none is passed here, we can safely verify that preConnect set it.
  newClientSock->connect(&callback, serverAddress, 30);

  // Collect events
  eventBase.loop();

  auto acceptedFd = acceptCallback.getEvents()->at(1).fd;
  ASSERT_NE(acceptedFd, NetworkSocket());
  int reuseAddrVal;
  socklen_t reuseAddrValLen = sizeof(reuseAddrVal);
  // Get the socket created underneath connect call of AsyncSocket
  auto usedSockFd = newClientSock->getNetworkSocket();
  int getOptRet = netops::getsockopt(
      usedSockFd, SOL_SOCKET, SO_REUSEADDR, &reuseAddrVal, &reuseAddrValLen);
  ASSERT_EQ(getOptRet, 0);
  ASSERT_EQ(reuseAddrVal, 1 /* configured through preConnect */);
}

TEST(AsyncSocketTest, V4TosReflectTest) {
  EventBase eventBase;

  // Create a server socket
  std::shared_ptr<AsyncServerSocket> serverSocket(
      AsyncServerSocket::newSocket(&eventBase));
  folly::IPAddress ip("127.0.0.1");
  std::vector<folly::IPAddress> serverIp;
  serverIp.push_back(ip);
  serverSocket->bind(serverIp, 0);
  serverSocket->listen(16);
  folly::SocketAddress serverAddress;
  serverSocket->getAddress(&serverAddress);

  // Enable TOS reflect
  serverSocket->setTosReflect(true);

  // Add a callback to accept one connection then stop the loop
  TestAcceptCallback acceptCallback;
  acceptCallback.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
      });
  acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
    serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
  });
  serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
  serverSocket->startAccepting();

  // Create a client socket, setsockopt() the TOS before connecting
  auto clientThread = [](std::shared_ptr<AsyncSocket>& clientSock,
                         ConnCallback* ccb,
                         EventBase* evb,
                         folly::SocketAddress sAddr) {
    clientSock = AsyncSocket::newSocket(evb);
    SocketOptionKey v4Opts = {IPPROTO_IP, IP_TOS};
    SocketOptionMap optionMap;
    optionMap.insert({v4Opts, 0x2c});
    SocketAddress bindAddr("0.0.0.0", 0);
    clientSock->connect(ccb, sAddr, 30, optionMap, bindAddr);
  };

  std::shared_ptr<AsyncSocket> socket(nullptr);
  ConnCallback cb;
  clientThread(socket, &cb, &eventBase, serverAddress);

  eventBase.loop();

  // Verify if the connection is accepted and if the accepted socket has
  // setsockopt on the TOS for the same value that was on the client socket
  auto fd = acceptCallback.getEvents()->at(1).fd;
  ASSERT_NE(fd, NetworkSocket());
  int value;
  socklen_t valueLength = sizeof(value);
  int rc = netops::getsockopt(fd, IPPROTO_IP, IP_TOS, &value, &valueLength);
  ASSERT_EQ(rc, 0);
  ASSERT_EQ(value, 0x2c);
}

TEST(AsyncSocketTest, V6AcceptedTosTest) {
  EventBase eventBase;

  // This test verifies if the ListenerTos set on a socket is
  // propagated properly to accepted socket connections

  // Create a server socket
  std::shared_ptr<AsyncServerSocket> serverSocket(
      AsyncServerSocket::newSocket(&eventBase));
  folly::IPAddress ip("::1");
  std::vector<folly::IPAddress> serverIp;
  serverIp.push_back(ip);
  serverSocket->bind(serverIp, 0);
  serverSocket->listen(16);
  folly::SocketAddress serverAddress;
  serverSocket->getAddress(&serverAddress);

  // Set listener TOS to 0x74, i.e., DSCP 29
  serverSocket->setListenerTos(0x74);

  // Add a callback to accept one connection then stop the loop
  TestAcceptCallback acceptCallback;
  acceptCallback.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
      });
  acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
    serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
  });
  serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
  serverSocket->startAccepting();

  // Create a client socket, setsockopt() the TOS before connecting
  auto clientThread = [](std::shared_ptr<AsyncSocket>& clientSock,
                         ConnCallback* ccb,
                         EventBase* evb,
                         folly::SocketAddress sAddr) {
    clientSock = AsyncSocket::newSocket(evb);
    SocketOptionKey v6Opts = {IPPROTO_IPV6, IPV6_TCLASS};
    SocketOptionMap optionMap;
    optionMap.insert({v6Opts, 0x2c});
    SocketAddress bindAddr("0.0.0.0", 0);
    clientSock->connect(ccb, sAddr, 30, optionMap, bindAddr);
  };

  std::shared_ptr<AsyncSocket> socket(nullptr);
  ConnCallback cb;
  clientThread(socket, &cb, &eventBase, serverAddress);

  eventBase.loop();

  // Verify if the connection is accepted and if the accepted socket has
  // setsockopt on the TOS for the same value that the listener was set to
  auto fd = acceptCallback.getEvents()->at(1).fd;
  ASSERT_NE(fd, NetworkSocket());
  int value;
  socklen_t valueLength = sizeof(value);
  int rc =
      netops::getsockopt(fd, IPPROTO_IPV6, IPV6_TCLASS, &value, &valueLength);
  ASSERT_EQ(rc, 0);
  ASSERT_EQ(value, 0x74);
}

TEST(AsyncSocketTest, V4AcceptedTosTest) {
  EventBase eventBase;

  // This test verifies if the ListenerTos set on a socket is
  // propagated properly to accepted socket connections

  // Create a server socket
  std::shared_ptr<AsyncServerSocket> serverSocket(
      AsyncServerSocket::newSocket(&eventBase));
  folly::IPAddress ip("127.0.0.1");
  std::vector<folly::IPAddress> serverIp;
  serverIp.push_back(ip);
  serverSocket->bind(serverIp, 0);
  serverSocket->listen(16);
  folly::SocketAddress serverAddress;
  serverSocket->getAddress(&serverAddress);

  // Set listener TOS to 0x74, i.e., DSCP 29
  serverSocket->setListenerTos(0x74);

  // Add a callback to accept one connection then stop the loop
  TestAcceptCallback acceptCallback;
  acceptCallback.setConnectionAcceptedFn(
      [&](NetworkSocket /* fd */, const folly::SocketAddress& /* addr */) {
        serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
      });
  acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
    serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
  });
  serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
  serverSocket->startAccepting();

  // Create a client socket, setsockopt() the TOS before connecting
  auto clientThread = [](std::shared_ptr<AsyncSocket>& clientSock,
                         ConnCallback* ccb,
                         EventBase* evb,
                         folly::SocketAddress sAddr) {
    clientSock = AsyncSocket::newSocket(evb);
    SocketOptionKey v4Opts = {IPPROTO_IP, IP_TOS};
    SocketOptionMap optionMap;
    optionMap.insert({v4Opts, 0x2c});
    SocketAddress bindAddr("0.0.0.0", 0);
    clientSock->connect(ccb, sAddr, 30, optionMap, bindAddr);
  };

  std::shared_ptr<AsyncSocket> socket(nullptr);
  ConnCallback cb;
  clientThread(socket, &cb, &eventBase, serverAddress);

  eventBase.loop();

  // Verify if the connection is accepted and if the accepted socket has
  // setsockopt on the TOS for the same value that the listener was set to
  auto fd = acceptCallback.getEvents()->at(1).fd;
  ASSERT_NE(fd, NetworkSocket());
  int value;
  socklen_t valueLength = sizeof(value);
  int rc = netops::getsockopt(fd, IPPROTO_IP, IP_TOS, &value, &valueLength);
  ASSERT_EQ(rc, 0);
  ASSERT_EQ(value, 0x74);
}
#endif

#if defined(__linux__)
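// Exercises getRecvBufInUse() / getSendBufInUse(), which report how many
// bytes currently sit in the kernel receive and send buffers; Linux-only
// since the test relies on that kernel-level accounting.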
TEST(AsyncSocketTest, getBufInUse) {
  EventBase eventBase;
  std::shared_ptr<AsyncServerSocket> server(
      AsyncServerSocket::newSocket(&eventBase));
  server->bind(0);
  server->listen(5);

  std::shared_ptr<AsyncSocket> client = AsyncSocket::newSocket(&eventBase);
  client->connect(nullptr, server->getAddress());

  NetworkSocket servfd = server->getNetworkSocket();
  NetworkSocket accepted;
  uint64_t maxTries = 5;

  do {
    std::this_thread::yield();
    eventBase.loop();
    accepted = netops::accept(servfd, nullptr, nullptr);
  } while (accepted == NetworkSocket() && --maxTries);

  // Exhausted the number of tries to accept the client connection; give up.
  ASSERT_TRUE(accepted != NetworkSocket());

  auto clientAccepted = AsyncSocket::newSocket(nullptr, accepted);

  // Use minimum receive buffer size
  clientAccepted->setRecvBufSize(0);

  // Use maximum send buffer size
  client->setSendBufSize((unsigned)-1);

  std::string testData;
  for (int i = 0; i < 10000; ++i) {
    testData += "0123456789";
  }

  client->write(nullptr, (const void*)testData.c_str(), testData.size());

  std::this_thread::yield();
  eventBase.loop();

  size_t recvBufSize = clientAccepted->getRecvBufInUse();
  size_t sendBufSize = client->getSendBufInUse();

  EXPECT_EQ((recvBufSize + sendBufSize), testData.size());
  EXPECT_GT(recvBufSize, 0);
  EXPECT_GT(sendBufSize, 0);
}
#endif

TEST(AsyncSocketTest, QueueTimeout) {
  // Create a new AsyncServerSocket
  EventBase eventBase;
  std::shared_ptr<AsyncServerSocket> serverSocket(
      AsyncServerSocket::newSocket(&eventBase));
  serverSocket->bind(0);
  serverSocket->listen(16);
  folly::SocketAddress serverAddress;
  serverSocket->getAddress(&serverAddress);

  constexpr auto kConnectionTimeout = milliseconds(10);
  serverSocket->setQueueTimeout(kConnectionTimeout);

  TestAcceptCallback acceptCb;
  acceptCb.setConnectionAcceptedFn(
      [&, called = false](auto&&...) mutable {
        ASSERT_FALSE(called)
            << "Only the first connection should have been dequeued";
        called = true;
        // Allow plenty of time for the AsyncServerSocket's event loop to run.
        // This should leave no doubt that the acceptor thread has enough time
        // to dequeue. If the dequeue succeeds, then our expiry code is broken.
        static constexpr auto kEventLoopTime = kConnectionTimeout * 5;
        eventBase.runInEventBaseThread([&]() {
          eventBase.tryRunAfterDelay(
              [&]() { serverSocket->removeAcceptCallback(&acceptCb, nullptr); },
              milliseconds(kEventLoopTime).count());
        });
        // Now that the first message has been dequeued, sleep long enough so
        // that the second message expires before it has a chance to dequeue.
        std::this_thread::sleep_for(kConnectionTimeout);
      });
  ScopedEventBaseThread acceptThread("ioworker_test");

  TestConnectionEventCallback connectionEventCb;
  serverSocket->setConnectionEventCallback(&connectionEventCb);
  serverSocket->addAcceptCallback(&acceptCb, acceptThread.getEventBase());
  serverSocket->startAccepting();

  std::shared_ptr<AsyncSocket> clientSocket1(
      AsyncSocket::newSocket(&eventBase, serverAddress));
  std::shared_ptr<AsyncSocket> clientSocket2(
      AsyncSocket::newSocket(&eventBase, serverAddress));

  // Loop until we are stopped
  eventBase.loop();

  EXPECT_EQ(connectionEventCb.getConnectionEnqueuedForAcceptCallback(), 2);
  // Since the second message is expired, it should NOT be dequeued
  EXPECT_EQ(connectionEventCb.getConnectionDequeuedByAcceptCallback(), 1);
}

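// Records software RX timestamps delivered via SO_TIMESTAMPING ancillary data.
// It can also close the socket from inside ancillaryData() to verify that
// doing so is safe.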
class TestRXTimestampsCallback
    : public folly::AsyncSocket::ReadAncillaryDataCallback {
 public:
  explicit TestRXTimestampsCallback(AsyncSocket* sock) : socket_(sock) {}

  void ancillaryData(struct msghdr& msgh) noexcept override {
    if (closeSocket_) {
      socket_->close();
      return;
    }

    struct cmsghdr* cmsg;
    for (cmsg = CMSG_FIRSTHDR(&msgh); cmsg != nullptr;
         cmsg = CMSG_NXTHDR(&msgh, cmsg)) {
      if (cmsg->cmsg_level != SOL_SOCKET ||
          cmsg->cmsg_type != SO_TIMESTAMPING) {
        continue;
      }
      callCount_++;
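      // The SO_TIMESTAMPING payload is an array of timespecs
      // (struct scm_timestamping); ts[0] carries the software receive
      // timestamp.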
      auto ts = reinterpret_cast<const struct timespec*>(CMSG_DATA(cmsg));
      actualRxTimestampSec_ = ts[0].tv_sec;
    }
  }
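  // Returns the buffer AsyncSocket passes as msg_control to recvmsg(); any
  // ancillary data the kernel writes there is handed to ancillaryData().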
  folly::MutableByteRange getAncillaryDataCtrlBuffer() override {
    return folly::MutableByteRange(ancillaryDataCtrlBuffer_);
  }

  uint32_t callCount_{0};
  long actualRxTimestampSec_{0};
  bool closeSocket_{false};

 private:
  AsyncSocket* socket_;
  std::array<uint8_t, 1024> ancillaryDataCtrlBuffer_;
};

/**
 * Test read ancillary data callback
 */
TEST(AsyncSocketTest, readAncillaryData) {
  TestServer server;

  // connect()
  EventBase evb;
  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);

  ConnCallback ccb;
  socket->connect(&ccb, server.getAddress(), 1);
  LOG(INFO) << "Client socket fd=" << socket->getNetworkSocket();

  // Enable rx timestamp notifications
  ASSERT_NE(socket->getNetworkSocket(), NetworkSocket());
  int flags = folly::netops::SOF_TIMESTAMPING_SOFTWARE |
      folly::netops::SOF_TIMESTAMPING_RX_SOFTWARE |
      folly::netops::SOF_TIMESTAMPING_RX_HARDWARE;
  SocketOptionKey tstampingOpt = {SOL_SOCKET, SO_TIMESTAMPING};
  EXPECT_EQ(tstampingOpt.apply(socket->getNetworkSocket(), flags), 0);

  // Accept the connection.
  std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
  LOG(INFO) << "Server socket fd=" << acceptedSocket->getNetworkSocket();

  // Wait for connection
  evb.loop();
  ASSERT_EQ(ccb.state, STATE_SUCCEEDED);

  TestRXTimestampsCallback rxcb{socket.get()};

  // Set read callback
  ReadCallback rcb(100);
  socket->setReadCB(&rcb);

  // Record the timestamp just before the message is written
  struct timespec currentTime;
  clock_gettime(CLOCK_REALTIME, &currentTime);
  long writeTimestampSec = currentTime.tv_sec;

  // write bytes from server (acceptedSocket) to client (socket).
  std::vector<uint8_t> wbuf(128, 'a');
  acceptedSocket->write(wbuf.data(), wbuf.size());

  // Wait for reading to complete.
  evb.loopOnce();
  ASSERT_NE(rcb.buffers.size(), 0);

  // Verify that if the callback is not set, it will not be called
  ASSERT_EQ(rxcb.callCount_, 0);

  // Set up rx timestamp callbacks
  socket->setReadAncillaryDataCB(&rxcb);
  acceptedSocket->write(wbuf.data(), wbuf.size());

  // Wait for reading to complete.
  evb.loopOnce();
  ASSERT_NE(rcb.buffers.size(), 0);

  // Verify that after setting callback, the callback was called
  ASSERT_GT(rxcb.callCount_, 0);
  // Check that the received timestamp falls within the expected range
  clock_gettime(CLOCK_REALTIME, &currentTime);
  ASSERT_LE(rxcb.actualRxTimestampSec_, currentTime.tv_sec);
  ASSERT_GE(rxcb.actualRxTimestampSec_, writeTimestampSec);

  // Check that the callback can close the socket.
  rxcb.closeSocket_ = true;
  ASSERT_FALSE(socket->isClosedBySelf());
  acceptedSocket->write(wbuf.data(), wbuf.size());
  evb.loopOnce();
  ASSERT_TRUE(socket->isClosedBySelf());
}

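// Fixture that routes the socket's netops calls through a mock dispatcher so
// tests can count sendmsg() invocations and simulate partial or failed writes.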
class AsyncSocketWriteCallbackTest : public ::testing::Test {
 protected:
  using MockDispatcher = ::testing::NiceMock<netops::test::MockDispatcher>;
  void SetUp() override {
    socket_ = AsyncSocket::newSocket(&evb_);
    socket_->setOverrideNetOpsDispatcher(netOpsDispatcher_);
    netOpsDispatcher_->forwardToDefaultImpl();

    socket_->connect(nullptr, server_.getAddress());
  }

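  // Count sendmsg() calls while forwarding to the default implementation.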
  void netOpsOnSendmsg() {
    ON_CALL(*netOpsDispatcher_, sendmsg(_, _, _))
        .WillByDefault(::testing::Invoke(
            [this](NetworkSocket s, const msghdr* message, int flags) {
              sendMsgInvocations_++;
              return netops::Dispatcher::getDefaultInstance()->sendmsg(
                  s, message, flags);
            }));
  }

  // Simulate splitting a write into two parts by returning one byte fewer
  // than was actually written whenever splitNextWrite_ is set.
  void netOpsOnSendmsgPartial() {
    ON_CALL(*netOpsDispatcher_, sendmsg(_, _, _))
        .WillByDefault(::testing::Invoke(
            [this](NetworkSocket s, const msghdr* message, int flags) {
              sendMsgInvocations_++;
              auto totalWritten =
                  netops::Dispatcher::getDefaultInstance()->sendmsg(
                      s, message, flags);
              if (splitNextWrite_) {
                splitNextWrite_ = false;
                return totalWritten - 1;
              } else {
                splitNextWrite_ = true;
                return totalWritten;
              }
            }));
  }

  // Simulate a failed write by performing the sendmsg() but returning -1.
  void netOpsOnSendmsgFail() {
    ON_CALL(*netOpsDispatcher_, sendmsg(_, _, _))
        .WillByDefault(::testing::Invoke(
            [this](NetworkSocket s, const msghdr* message, int flags) {
              sendMsgInvocations_++;
              netops::Dispatcher::getDefaultInstance()->sendmsg(
                  s, message, flags);
              return -1;
            }));
  }

  WriteCallback writeCallback1_;
  WriteCallback writeCallback2_;
  TestServer server_;
  std::shared_ptr<AsyncSocket> socket_;
  folly::EventBase evb_;
  std::shared_ptr<MockDispatcher> netOpsDispatcher_{
      std::make_shared<MockDispatcher>()};
  size_t sendMsgInvocations_{0};
  bool splitNextWrite_{false};
};

/**
 * Call write once successfully and expect `writeStarting` to be called once.
 */
TEST_F(AsyncSocketWriteCallbackTest, WriteStartingTests_WriteOnceSuccess) {
  const std::vector<uint8_t> wbuf(20, 'a');
  iovec op = {};
  op.iov_base = const_cast<void*>(static_cast<const void*>(wbuf.data()));
  op.iov_len = wbuf.size();
  WriteFlags flags = WriteFlags::NONE;

  netOpsOnSendmsg();

  ASSERT_THAT(writeCallback1_.writeStartingInvocations, Eq(0));
  ASSERT_THAT(writeCallback2_.writeStartingInvocations, Eq(0));
  socket_->writev(&writeCallback1_, &op, 1, flags);
  while (writeCallback1_.state == STATE_WAITING) {
    socket_->getEventBase()->loopOnce();
  }
  ASSERT_EQ(writeCallback1_.state, STATE_SUCCEEDED);
  ASSERT_EQ(writeCallback2_.state, STATE_WAITING);
  EXPECT_EQ(writeCallback1_.writeStartingInvocations, 1);
  EXPECT_EQ(writeCallback2_.writeStartingInvocations, 0);
  EXPECT_EQ(sendMsgInvocations_, 1);
}

/**
 * Call write once but do not write all bytes the first time; expect
 * `writeStarting` to be called once.
 */
TEST_F(AsyncSocketWriteCallbackTest, WriteStartingTests_WriteOnceIncomplete) {
  const std::vector<uint8_t> wbuf(20, 'a');
  iovec op = {};
  op.iov_base = const_cast<void*>(static_cast<const void*>(wbuf.data()));
  op.iov_len = wbuf.size();
  WriteFlags flags = WriteFlags::NONE;

  // make sure there are no pending WriteRequests
  socket_->getEventBase()->loopOnce();

  splitNextWrite_ = true;
  netOpsOnSendmsgPartial();

  ASSERT_THAT(writeCallback1_.writeStartingInvocations, Eq(0));
  socket_->writev(&writeCallback1_, &op, 1, flags);
  while (writeCallback1_.state == STATE_WAITING) {
    socket_->getEventBase()->loopOnce();
  }

  ASSERT_EQ(writeCallback1_.state, STATE_SUCCEEDED);
  EXPECT_EQ(writeCallback1_.writeStartingInvocations, 1);
  EXPECT_EQ(sendMsgInvocations_, 2);
}

/**
 * Call write twice successfully and expect `writeStarting` to be called twice.
 */
TEST_F(AsyncSocketWriteCallbackTest, WriteStartingTests_WriteTwiceSuccess) {
  const std::vector<uint8_t> wbuf(20, 'a');
  iovec op = {};
  op.iov_base = const_cast<void*>(static_cast<const void*>(wbuf.data()));
  op.iov_len = wbuf.size();
  WriteFlags flags = WriteFlags::NONE;

  netOpsOnSendmsg();

  ASSERT_THAT(writeCallback1_.writeStartingInvocations, Eq(0));
  socket_->writev(&writeCallback1_, &op, 1, flags);
  socket_->writev(&writeCallback1_, &op, 1, flags);
  while (writeCallback1_.state == STATE_WAITING) {
    socket_->getEventBase()->loopOnce();
  }
  ASSERT_EQ(writeCallback1_.state, STATE_SUCCEEDED);
  EXPECT_EQ(writeCallback1_.writeStartingInvocations, 2);
  EXPECT_EQ(sendMsgInvocations_, 2);
}

/**
 * Call write twice, with the first write incomplete; expect `writeStarting` to
 * be called twice.
 */
TEST_F(
    AsyncSocketWriteCallbackTest,
    WriteStartingTests_WriteTwiceIncompleteThenSuccess) {
  const std::vector<uint8_t> wbuf(20, 'a');
  iovec op = {};
  op.iov_base = const_cast<void*>(static_cast<const void*>(wbuf.data()));
  op.iov_len = wbuf.size();
  WriteFlags flags = WriteFlags::NONE;

  ASSERT_THAT(writeCallback1_.writeStartingInvocations, Eq(0));

  // We split the first write in two parts. The first part is written
  // immediately and a WriteRequest is created to write the bytes from the
  // second part.
  splitNextWrite_ = true;
  netOpsOnSendmsgPartial();
  socket_->writev(&writeCallback1_, &op, 1, flags);
  while (writeCallback1_.state == STATE_WAITING) {
    socket_->getEventBase()->loopOnce();
  }
  ASSERT_EQ(writeCallback1_.state, STATE_SUCCEEDED);
  EXPECT_EQ(writeCallback1_.writeStartingInvocations, 1);

  // We do not split the second write
  netOpsOnSendmsg();
  socket_->writev(&writeCallback2_, &op, 1, flags);
  while (writeCallback2_.state == STATE_WAITING) {
    socket_->getEventBase()->loopOnce();
  }
  ASSERT_EQ(writeCallback1_.state, STATE_SUCCEEDED);
  EXPECT_EQ(writeCallback1_.writeStartingInvocations, 1);
  ASSERT_EQ(writeCallback2_.state, STATE_SUCCEEDED);
  EXPECT_EQ(writeCallback2_.writeStartingInvocations, 1);
  EXPECT_EQ(sendMsgInvocations_, 3);
}

/**
 * Call write twice, both times incomplete; expect `writeStarting` to be called
 * twice.
 */
TEST_F(
    AsyncSocketWriteCallbackTest,
    WriteStartingTests_WriteTwiceIncompleteThenIncomplete) {
  const std::vector<uint8_t> wbuf(20, 'a');
  iovec op = {};
  op.iov_base = const_cast<void*>(static_cast<const void*>(wbuf.data()));
  op.iov_len = wbuf.size();
  WriteFlags flags = WriteFlags::NONE;

  ASSERT_THAT(writeCallback1_.writeStartingInvocations, Eq(0));

  // We split the first write in two parts. The first part is written
  // immediately and a WriteRequest is created to write the bytes from the
  // second part.
  splitNextWrite_ = true;
  netOpsOnSendmsgPartial();
  socket_->writev(&writeCallback1_, &op, 1, flags);
  socket_->getEventBase()->loopOnce();

  EXPECT_EQ(sendMsgInvocations_, 1);
  ASSERT_EQ(writeCallback1_.state, STATE_WAITING);
  EXPECT_EQ(writeCallback1_.writeStartingInvocations, 1);

  // We also split the second write. Since the WriteRequest queue is not empty,
  // a new WriteRequest for all bytes in the second write is created when writev
  // is called. The write will be split into two parts when we process this
  // request. The first part is written when the request is processed and
  // another WriteRequest will be created for the second part. This new
  // WriteRequest will be processed on the next event loop iteration.
  socket_->writev(&writeCallback2_, &op, 1, flags);
  socket_->getEventBase()->loopOnce();

  EXPECT_EQ(sendMsgInvocations_, 3);
  ASSERT_EQ(writeCallback1_.state, STATE_SUCCEEDED);
  EXPECT_EQ(writeCallback1_.writeStartingInvocations, 1);
  ASSERT_EQ(writeCallback2_.state, STATE_WAITING);
  EXPECT_EQ(writeCallback2_.writeStartingInvocations, 1);

  socket_->getEventBase()->loopOnce();

  ASSERT_EQ(writeCallback1_.state, STATE_SUCCEEDED);
  EXPECT_EQ(writeCallback1_.writeStartingInvocations, 1);
  ASSERT_EQ(writeCallback2_.state, STATE_SUCCEEDED);
  EXPECT_EQ(writeCallback2_.writeStartingInvocations, 1);
  EXPECT_EQ(sendMsgInvocations_, 4);
}

/**
 * Call write once and fail immediately; expect `writeStarting` to not be
 * called.
 */
TEST_F(
    AsyncSocketWriteCallbackTest, WriteStartingTests_WriteOnceFailImmediately) {
  const std::vector<uint8_t> wbuf(20, 'a');
  iovec op = {};
  op.iov_base = const_cast<void*>(static_cast<const void*>(wbuf.data()));
  op.iov_len = wbuf.size();
  WriteFlags flags = WriteFlags::NONE;

  ASSERT_THAT(writeCallback1_.writeStartingInvocations, Eq(0));
  socket_->writev(&writeCallback1_, &op, 1, flags);
  socket_->shutdownWriteNow();
  ASSERT_EQ(writeCallback1_.state, STATE_FAILED);
  EXPECT_EQ(writeCallback1_.writeStartingInvocations, 0);
}

/**
 * Call write once and fail after `sendmsg` was called; expect `writeStarting`
 * to be called once.
 */
TEST_F(AsyncSocketWriteCallbackTest, WriteStartingTests_WriteOnceFail) {
  const std::vector<uint8_t> wbuf(20, 'a');
  iovec op = {};
  op.iov_base = const_cast<void*>(static_cast<const void*>(wbuf.data()));
  op.iov_len = wbuf.size();
  WriteFlags flags = WriteFlags::NONE;

  netOpsOnSendmsgFail();
  writeCallback1_.errorCallback = std::bind(&AsyncSocket::close, socket_.get());

  ASSERT_THAT(writeCallback1_.writeStartingInvocations, Eq(0));
  socket_->writev(&writeCallback1_, &op, 1, flags);
  while (writeCallback1_.state == STATE_WAITING) {
    socket_->getEventBase()->loopOnce();
  }
  ASSERT_EQ(writeCallback1_.state, STATE_FAILED);
  EXPECT_EQ(writeCallback1_.writeStartingInvocations, 1);
}