/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <memory>
#include <folly/io/async/AsyncSocket.h>
#include <folly/io/async/test/BlockingSocket.h>
#include <folly/io/async/test/CallbackStateEnum.h>
#include <folly/io/async/test/ConnCallback.h>
#include <folly/net/NetOps.h>
#include <folly/net/NetworkSocket.h>
#include <folly/portability/Sockets.h>
namespace folly::test {
/**
 * Test WriteCallback that records how a write finished (state, byte count,
 * exception) and can optionally act as its own ReleaseIOBufCallback so a test
 * can count the IOBufs that are handed back after a write.
 */
class WriteCallback : public folly::AsyncTransport::WriteCallback,
                      public folly::AsyncWriter::ReleaseIOBufCallback {
 public:
  /**
   * @param enableReleaseIOBufCallback when true, getReleaseIOBufCallback()
   *        returns this object; otherwise it returns nullptr.
   */
  explicit WriteCallback(bool enableReleaseIOBufCallback = false)
      : releaseIOBufCallback(enableReleaseIOBufCallback ? this : nullptr) {}

  // Marks the write as succeeded, then runs successCallback if one is set.
  void writeSuccess() noexcept override {
    state = STATE_SUCCEEDED;
    if (!successCallback) {
      return;
    }
    successCallback();
  }

  // Logs the error, records the failure details, then runs errorCallback if
  // one is set.
  void writeErr(
      size_t nBytesWritten,
      const folly::AsyncSocketException& ex) noexcept override {
    LOG(ERROR) << ex.what();
    state = STATE_FAILED;
    bytesWritten.store(nBytesWritten);
    exception = ex;
    if (!errorCallback) {
      return;
    }
    errorCallback();
  }

  // Counts how many times the transport reported that a write is starting.
  void writeStarting() noexcept override { ++writeStartingInvocations; }

  folly::AsyncWriter::ReleaseIOBufCallback* getReleaseIOBufCallback() noexcept
      override {
    return releaseIOBufCallback;
  }

  // Accumulates the number of chain elements and total data bytes of every
  // IOBuf chain handed back; the buffer itself is dropped when ioBuf goes out
  // of scope.
  void releaseIOBuf(std::unique_ptr<folly::IOBuf> ioBuf) noexcept override {
    numIoBufCount.fetch_add(ioBuf->countChainElements());
    numIoBufBytes.fetch_add(ioBuf->computeChainDataLength());
  }

  StateEnum state{STATE_WAITING};
  std::atomic<size_t> bytesWritten{0};
  std::atomic<size_t> numIoBufCount{0};
  std::atomic<size_t> numIoBufBytes{0};
  folly::AsyncSocketException exception{
      folly::AsyncSocketException::UNKNOWN, "none"};
  VoidCallback successCallback;
  VoidCallback errorCallback;
  ReleaseIOBufCallback* releaseIOBufCallback;
  size_t writeStartingInvocations{0};
};
class ReadCallback : public folly::AsyncTransport::ReadCallback {
public:
explicit ReadCallback(size_t _maxBufferSz = 4096)
: state(STATE_WAITING),
exception(folly::AsyncSocketException::UNKNOWN, "none"),
buffers(),
maxBufferSz(_maxBufferSz) {}
~ReadCallback() override {
for (auto& buffer : buffers) {
buffer.free();
}
currentBuffer.free();
}
void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
if (!currentBuffer.buffer) {
currentBuffer.allocate(maxBufferSz);
}
*bufReturn = currentBuffer.buffer;
*lenReturn = currentBuffer.length;
}
void readDataAvailable(size_t len) noexcept override {
currentBuffer.length = len;
buffers.push_back(currentBuffer);
currentBuffer.reset();
if (dataAvailableCallback) {
dataAvailableCallback();
}
}
void readEOF() noexcept override { state = STATE_SUCCEEDED; }
void readErr(const folly::AsyncSocketException& ex) noexcept override {
state = STATE_FAILED;
exception = ex;
}
void verifyData(const char* expected, size_t expectedLen) const {
verifyData((const unsigned char*)expected, expectedLen);
}
void verifyData(const unsigned char* expected, size_t expectedLen) const {
size_t offset = 0;
for (size_t idx = 0; idx < buffers.size(); ++idx) {
const auto& buf = buffers[idx];
size_t cmpLen = std::min(buf.length, expectedLen - offset);
CHECK_EQ(memcmp(buf.buffer, expected + offset, cmpLen), 0);
CHECK_EQ(cmpLen, buf.length);
offset += cmpLen;
}
CHECK_EQ(offset, expectedLen);
}
void clearData() {
for (auto& buffer : buffers) {
buffer.free();
}
buffers.clear();
}
size_t dataRead() const {
size_t ret = 0;
for (const auto& buf : buffers) {
ret += buf.length;
}
return ret;
}
class Buffer {
public:
Buffer() : buffer(nullptr), length(0) {}
Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
void reset() {
buffer = nullptr;
length = 0;
}
void allocate(size_t len) {
assert(buffer == nullptr);
this->buffer = static_cast<char*>(malloc(len));
this->length = len;
}
void free() {
::free(buffer);
reset();
}
char* buffer;
size_t length;
};
StateEnum state;
folly::AsyncSocketException exception;
std::vector<Buffer> buffers;
Buffer currentBuffer;
VoidCallback dataAvailableCallback;
const size_t maxBufferSz;
};
/**
 * Test ReadCallback configured for ReadMode::ReadVec. Read buffers come from
 * an IOBufIovecBuilder, and the received bytes are chained into buf_.
 */
class ReadvCallback : public folly::AsyncTransport::ReadCallback {
 public:
  /**
   * @param bufferSize block size for the internal IOBufIovecBuilder.
   * @param len        length passed to allocateBuffers() on every
   *                   getReadBuffers() call.
   */
  ReadvCallback(size_t bufferSize, size_t len)
      : queue_(folly::IOBufIovecBuilder::Options().setBlockSize(bufferSize)),
        len_(len) {
    setReadMode(folly::AsyncTransport::ReadCallback::ReadMode::ReadVec);
  }

  ~ReadvCallback() override = default;

  // This callback only expects getReadBuffers(); reaching here fails the test.
  void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
    (void)bufReturn;
    (void)lenReturn;
    CHECK(false); // this should not be called
  }

  void getReadBuffers(folly::IOBufIovecBuilder::IoVecVec& iovs) override {
    queue_.allocateBuffers(iovs, len_);
  }

  // Pulls `len` bytes out of the builder and appends them to buf_.
  void readDataAvailable(size_t len) noexcept override {
    auto chunk = queue_.extractIOBufChain(len);
    if (buf_) {
      buf_->prependChain(std::move(chunk));
      return;
    }
    buf_ = std::move(chunk);
  }

  // Drops all data received so far.
  void reset() { buf_.reset(); }

  void readEOF() noexcept override { state_ = STATE_SUCCEEDED; }

  void readErr(const folly::AsyncSocketException& ex) noexcept override {
    state_ = STATE_FAILED;
    exception_ = ex;
  }

  // CHECK-fails unless the received bytes equal `data`. Note that this
  // coalesces buf_ in place.
  void verifyData(const std::string& data) const {
    CHECK(buf_);
    const auto flat = buf_->coalesce();
    const std::string actual(
        reinterpret_cast<const char*>(flat.data()), flat.size());
    CHECK_EQ(data, actual);
  }

  std::unique_ptr<folly::IOBuf> buf_;

 private:
  StateEnum state_{STATE_WAITING};
  folly::AsyncSocketException exception_{
      folly::AsyncSocketException::UNKNOWN, "none"};
  folly::IOBufIovecBuilder queue_;
  const size_t len_;
};
/**
 * Test BufferCallback that checks the socket's app-level byte counters when
 * egress becomes buffered or is cleared, and records that each event fired.
 */
class BufferCallback : public folly::AsyncTransport::BufferCallback {
 public:
  /**
   * @param socket        socket whose counters are inspected; not owned.
   * @param expectedBytes total bytes the test expects to have been handed to
   *                      the socket (written + buffered).
   */
  BufferCallback(folly::AsyncSocket* socket, size_t expectedBytes)
      : socket_(socket), expectedBytes_(expectedBytes) {}

  // Some bytes must be buffered, and written + buffered must match the
  // expected total.
  void onEgressBuffered() override {
    const size_t appWritten = socket_->getAppBytesWritten();
    const size_t appBuffered = socket_->getAppBytesBuffered();
    CHECK_GT(appBuffered, 0);
    CHECK_EQ(expectedBytes_, appWritten + appBuffered);
    buffered_ = true;
  }

  // Nothing may remain buffered, and every expected byte must be written.
  void onEgressBufferCleared() override {
    const size_t appWritten = socket_->getAppBytesWritten();
    const size_t appBuffered = socket_->getAppBytesBuffered();
    CHECK_EQ(0, appBuffered);
    CHECK_EQ(expectedBytes_, appWritten);
    bufferCleared_ = true;
  }

  bool hasBuffered() const { return buffered_; }
  bool hasBufferCleared() const { return bufferCleared_; }

 private:
  folly::AsyncSocket* socket_{nullptr};
  size_t expectedBytes_{0};
  bool buffered_{false};
  bool bufferCleared_{false};
};
class ZeroCopyReadCallback : public folly::AsyncTransport::ReadCallback {
public:
explicit ZeroCopyReadCallback(
folly::AsyncTransport::ReadCallback::ZeroCopyMemStore* memStore,
size_t _maxBufferSz = 4096)
: memStore_(memStore),
state(STATE_WAITING),
exception(folly::AsyncSocketException::UNKNOWN, "none"),
maxBufferSz(_maxBufferSz) {}
~ZeroCopyReadCallback() override { currentBuffer.free(); }
// zerocopy
folly::AsyncTransport::ReadCallback::ZeroCopyMemStore*
readZeroCopyEnabled() noexcept override {
return memStore_;
}
void getZeroCopyFallbackBuffer(
void** bufReturn, size_t* lenReturn) noexcept override {
if (!currentZeroCopyBuffer.buffer) {
currentZeroCopyBuffer.allocate(maxBufferSz);
}
*bufReturn = currentZeroCopyBuffer.buffer;
*lenReturn = currentZeroCopyBuffer.length;
}
void readZeroCopyDataAvailable(
std::unique_ptr<folly::IOBuf>&& zeroCopyData,
size_t additionalBytes) noexcept override {
auto ioBuf = std::move(zeroCopyData);
if (additionalBytes) {
auto tmp = folly::IOBuf::takeOwnership(
currentZeroCopyBuffer.buffer,
currentZeroCopyBuffer.length,
0,
additionalBytes);
currentZeroCopyBuffer.reset();
if (ioBuf) {
ioBuf->prependChain(std::move(tmp));
} else {
ioBuf = std::move(tmp);
}
}
if (!data_) {
data_ = std::move(ioBuf);
} else {
data_->prependChain(std::move(ioBuf));
}
}
void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
if (!currentBuffer.buffer) {
currentBuffer.allocate(maxBufferSz);
}
*bufReturn = currentBuffer.buffer;
*lenReturn = currentBuffer.length;
}
void readDataAvailable(size_t len) noexcept override {
auto ioBuf = folly::IOBuf::takeOwnership(
currentBuffer.buffer, currentBuffer.length, 0, len);
currentBuffer.reset();
if (!data_) {
data_ = std::move(ioBuf);
} else {
data_->prependChain(std::move(ioBuf));
}
}
void readEOF() noexcept override { state = STATE_SUCCEEDED; }
void readErr(const folly::AsyncSocketException& ex) noexcept override {
state = STATE_FAILED;
exception = ex;
}
void verifyData(const std::string& expected) const {
verifyData((const unsigned char*)expected.data(), expected.size());
}
void verifyData(const unsigned char* expected, size_t expectedLen) const {
CHECK(!!data_);
auto len = data_->computeChainDataLength();
CHECK_EQ(len, expectedLen);
auto* buf = data_.get();
auto* current = buf;
size_t offset = 0;
do {
size_t cmpLen = std::min(current->length(), expectedLen - offset);
CHECK_EQ(cmpLen, current->length());
CHECK_EQ(memcmp(current->data(), expected + offset, cmpLen), 0);
offset += cmpLen;
current = current->next();
} while (current != buf);
std::ignore = expected;
CHECK_EQ(offset, expectedLen);
}
class Buffer {
public:
Buffer() = default;
Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
~Buffer() {
if (buffer) {
::free(buffer);
}
}
void reset() {
buffer = nullptr;
length = 0;
}
void allocate(size_t len) {
CHECK(buffer == nullptr);
buffer = static_cast<char*>(malloc(len));
length = len;
}
void free() {
::free(buffer);
reset();
}
char* buffer{nullptr};
size_t length{0};
};
folly::AsyncTransport::ReadCallback::ZeroCopyMemStore* memStore_;
StateEnum state;
folly::AsyncSocketException exception;
Buffer currentBuffer, currentZeroCopyBuffer;
VoidCallback dataAvailableCallback;
const size_t maxBufferSz;
std::unique_ptr<folly::IOBuf> data_;
};
// Empty class: it declares no members and has no behavior.
// NOTE(review): nothing in this header uses it — confirm whether code
// elsewhere still depends on it before removing.
class ReadVerifier {};
class TestSendMsgParamsCallback
: public folly::AsyncSocket::SendMsgParamsCallback {
public:
TestSendMsgParamsCallback(int flags, uint32_t dataSize, void* data)
: flags_(flags),
writeFlags_(folly::WriteFlags::NONE),
dataSize_(dataSize),
data_(data),
queriedFlags_(false),
queriedData_(false) {}
void reset(int flags) {
flags_ = flags;
writeFlags_ = folly::WriteFlags::NONE;
queriedFlags_ = false;
queriedData_ = false;
}
int getFlagsImpl(
folly::WriteFlags flags, int /*defaultFlags*/) noexcept override {
queriedFlags_ = true;
if (writeFlags_ == folly::WriteFlags::NONE) {
writeFlags_ = flags;
} else {
assert(flags == writeFlags_);
}
return flags_;
}
void getAncillaryData(
folly::WriteFlags flags,
void* data,
const folly::AsyncSocket::WriteRequestTag& tag,
const bool /* byteEventsEnabled */) noexcept override {
CHECK_EQ(tag, expectedTag_);
queriedData_ = true;
if (writeFlags_ == folly::WriteFlags::NONE) {
writeFlags_ = flags;
} else {
assert(flags == writeFlags_);
}
assert(data != nullptr);
memcpy(data, data_, dataSize_);
}
uint32_t getAncillaryDataSize(
folly::WriteFlags flags,
const folly::AsyncSocket::WriteRequestTag& tag,
const bool /* byteEventsEnabled */) noexcept override {
CHECK_EQ(tag, expectedTag_);
if (writeFlags_ == folly::WriteFlags::NONE) {
writeFlags_ = flags;
} else {
assert(flags == writeFlags_);
}
return dataSize_;
}
void wroteBytes(
const folly::AsyncSocket::WriteRequestTag& tag) noexcept override {
CHECK_EQ(tag, expectedTag_);
tagLastWritten_ = tag;
}
int flags_;
folly::WriteFlags writeFlags_;
uint32_t dataSize_;
void* data_;
bool queriedFlags_;
bool queriedData_;
folly::AsyncSocket::WriteRequestTag expectedTag_{
folly::AsyncSocket::WriteRequestTag::EmptyDummy()};
std::optional<folly::AsyncSocket::WriteRequestTag> tagLastWritten_;
};
/**
 * Plain (non-EventBase) TCP test server. The listening socket is IPv4 and
 * non-blocking, and is bound via AI_PASSIVE to an ephemeral port.
 * getAddress() reports it as 127.0.0.1:<port>. Accepted connections can be
 * taken as a raw NetworkSocket, a BlockingSocket, or an AsyncSocket.
 */
class TestServer {
 public:
  // Create a TestServer.
  // This immediately starts listening on an ephemeral port.
  //
  // enableTFO: enable TCP Fast Open on the listening socket. This has no
  //            effect unless FOLLY_ALLOW_TFO is set at build time.
  // bufSize:   when > 0, applied as both SO_SNDBUF and SO_RCVBUF.
  //
  // Throws folly::AsyncSocketException (INTERNAL_ERROR) if setup fails. Any
  // exception thrown after the socket is created closes it before
  // propagating.
  explicit TestServer(bool enableTFO = false, int bufSize = -1) : fd_() {
    fd_ = folly::netops::socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
    if (fd_ == folly::NetworkSocket()) {
      throw folly::AsyncSocketException(
          folly::AsyncSocketException::INTERNAL_ERROR,
          "failed to create test server socket",
          errno);
    }
    // The destructor does not run for an object whose constructor throws,
    // so close fd_ here on any setup failure rather than leaking it.
    try {
      configureAndListen(enableTFO, bufSize);
    } catch (...) {
      folly::netops::close(fd_);
      fd_ = folly::NetworkSocket();
      throw;
    }
  }

  ~TestServer() {
    if (fd_ != folly::NetworkSocket()) {
      folly::netops::close(fd_);
    }
  }

  // Get the address for connecting to the server
  const folly::SocketAddress& getAddress() const { return address_; }

  /**
   * Waits for a pending connection and accepts it.
   *
   * @param timeout poll() timeout in milliseconds.
   * @return the accepted socket; the caller owns it.
   * @throws folly::AsyncSocketException if the wait times out or if poll()
   *         or accept() fails.
   */
  folly::NetworkSocket acceptFD(int timeout = 50) {
    folly::netops::PollDescriptor pfd;
    pfd.fd = fd_;
    pfd.events = POLLIN;
    int ret = folly::netops::poll(&pfd, 1, timeout);
    if (ret == 0) {
      throw folly::AsyncSocketException(
          folly::AsyncSocketException::INTERNAL_ERROR,
          "test server accept() timed out");
    } else if (ret < 0) {
      throw folly::AsyncSocketException(
          folly::AsyncSocketException::INTERNAL_ERROR,
          "test server accept() poll failed",
          errno);
    }
    auto acceptedFd = folly::netops::accept(fd_, nullptr, nullptr);
    if (acceptedFd == folly::NetworkSocket()) {
      throw folly::AsyncSocketException(
          folly::AsyncSocketException::INTERNAL_ERROR,
          "test server accept() failed",
          errno);
    }
    return acceptedFd;
  }

  // Accepts a connection (see acceptFD) and wraps it in a BlockingSocket.
  std::shared_ptr<BlockingSocket> accept(int timeout = 50) {
    auto fd = acceptFD(timeout);
    return std::make_shared<BlockingSocket>(fd);
  }

  // Accepts a connection (see acceptFD) and wraps it in an AsyncSocket
  // attached to `evb`.
  std::shared_ptr<folly::AsyncSocket> acceptAsync(
      folly::EventBase* evb, int timeout = 50) {
    auto fd = acceptFD(timeout);
    return folly::AsyncSocket::newSocket(evb, fd);
  }

  /**
   * Accept a connection, read data from it, and verify that it matches the
   * data in the specified buffer.
   */
  void verifyConnection(const char* buf, size_t len) {
    // accept a connection
    std::shared_ptr<BlockingSocket> acceptedSocket = accept();
    // read the data and compare it to the specified buffer
    auto readbuf = std::make_unique<uint8_t[]>(len);
    acceptedSocket->readAll(readbuf.get(), len);
    CHECK_EQ(memcmp(buf, readbuf.get(), len), 0);
    // make sure we get EOF next
    uint32_t bytesRead = acceptedSocket->read(readbuf.get(), len);
    CHECK_EQ(bytesRead, 0);
  }

 private:
  // Configures the already-created fd_ and starts listening:
  //  - switches it to non-blocking mode;
  //  - optionally enables TFO;
  //  - applies the buffer sizes;
  //  - binds it to an ephemeral IPv4 port and listens with a backlog of 10;
  //  - records the loopback address in address_.
  // Throws on failure. The constructor is responsible for closing fd_.
  void configureAndListen(bool enableTFO, int bufSize) {
    if (folly::netops::set_socket_non_blocking(fd_) != 0) {
      throw folly::AsyncSocketException(
          folly::AsyncSocketException::INTERNAL_ERROR,
          "failed to put test server socket in "
          "non-blocking mode",
          errno);
    }
    if (enableTFO) {
#if FOLLY_ALLOW_TFO
      folly::detail::tfo_enable(fd_, 100);
#endif
    }
    struct addrinfo hints {};
    hints.ai_family = AF_INET;
    hints.ai_socktype = SOCK_STREAM;
    hints.ai_flags = AI_PASSIVE;
    struct addrinfo* res = nullptr;
    if (getaddrinfo(nullptr, "0", &hints, &res)) {
      throw folly::AsyncSocketException(
          folly::AsyncSocketException::INTERNAL_ERROR,
          "Attempted to bind address to socket with "
          "bad getaddrinfo",
          errno);
    }
    SCOPE_EXIT {
      freeaddrinfo(res);
    };
    if (bufSize > 0) {
      // Return values are ignored: the buffer sizes are applied best-effort.
      folly::netops::setsockopt(
          fd_, SOL_SOCKET, SO_SNDBUF, &bufSize, sizeof(bufSize));
      folly::netops::setsockopt(
          fd_, SOL_SOCKET, SO_RCVBUF, &bufSize, sizeof(bufSize));
    }
    if (folly::netops::bind(fd_, res->ai_addr, res->ai_addrlen)) {
      throw folly::AsyncSocketException(
          folly::AsyncSocketException::INTERNAL_ERROR,
          "failed to bind to async server socket",
          errno);
    }
    if (folly::netops::listen(fd_, 10) != 0) {
      throw folly::AsyncSocketException(
          folly::AsyncSocketException::INTERNAL_ERROR,
          "failed to listen on test server socket",
          errno);
    }
    address_.setFromLocalAddress(fd_);
    // The local address will contain 0.0.0.0.
    // Change it to 127.0.0.1, so it can be used to connect to the server
    address_.setFromIpPort("127.0.0.1", address_.getPort());
  }

  folly::NetworkSocket fd_;
  folly::SocketAddress address_;
};
} // namespace folly::test