#include <folly/Portability.h>
#include <functional>
#include <folly/io/coro/Transport.h>
#include <folly/io/coro/TransportCallbackBase.h>
#if FOLLY_HAS_COROUTINES
using namespace folly::coro;
namespace {
class ConnectCallback : public TransportCallbackBase,
public folly::AsyncSocket::ConnectCallback {
public:
explicit ConnectCallback(folly::AsyncSocket& socket)
: TransportCallbackBase(socket), socket_(socket) {}
private:
void cancel() noexcept override { socket_.cancelConnect(); }
void connectSuccess() noexcept override { post(); }
void connectErr(const folly::AsyncSocketException& ex) noexcept override {
storeException(ex);
post();
}
folly::AsyncSocket& socket_;
};
/// Bridges folly::AsyncTransport::ReadCallback (plus an optional wheel-timer
/// timeout) to the awaitable wait provided by TransportCallbackBase.
///
/// Two buffer modes are supported:
///  * fixed buffer: bytes land in `buf` starting at offset `length`. The
///    callback uninstalls itself once the buffer is full.
///  * queue mode: bytes are appended to `*readBuf`, either through
///    preallocate()/postallocate() or by moving whole IOBufs in.
///    isBufferMovable() reports true only in this mode.
/// The timeout is scheduled only when `timeout` > 0. If it fires before any
/// byte was delivered, the read fails with TIMED_OUT.
class ReadCallback : public TransportCallbackBase,
                     public folly::AsyncTransport::ReadCallback,
                     public folly::HHWheelTimer::Callback {
 public:
  /// Fixed-buffer mode: read into `buf`.
  ReadCallback(
      folly::HHWheelTimer& timer,
      folly::AsyncTransport& transport,
      folly::MutableByteRange buf,
      std::chrono::milliseconds timeout)
      : TransportCallbackBase(transport), buf_{buf} {
    if (timeout.count() > 0) {
      timer.scheduleTimeout(this, timeout);
    }
  }

  /// Queue mode: append to `*readBuf`, preallocating at least `minReadSize`
  /// bytes (allocating `newAllocationSize` when a new buffer is needed).
  ReadCallback(
      folly::HHWheelTimer& timer,
      folly::AsyncTransport& transport,
      folly::IOBufQueue* readBuf,
      size_t minReadSize,
      size_t newAllocationSize,
      std::chrono::milliseconds timeout)
      : TransportCallbackBase(transport),
        readBuf_(readBuf),
        minReadSize_(minReadSize),
        newAllocationSize_(newAllocationSize) {
    if (timeout.count() > 0) {
      timer.scheduleTimeout(this, timeout);
    }
  }

  // Total bytes delivered to this callback so far, through any delivery path.
  size_t length{0};
  // Set once readEOF() has been received.
  bool eof{false};

 private:
  folly::MutableByteRange buf_;
  folly::IOBufQueue* readBuf_{nullptr};
  size_t minReadSize_{0};
  size_t newAllocationSize_{0};

  void cancel() noexcept override {
    transport_.setReadCB(nullptr);
    cancelTimeout();
  }

  // Moving IOBufs in is only possible when reading into a queue.
  bool isBufferMovable() noexcept override { return readBuf_ != nullptr; }

  void readBufferAvailable(
      std::unique_ptr<folly::IOBuf> readBuf) noexcept override {
    CHECK(readBuf_);
    // Count moved-in bytes as well. Otherwise `length` stays 0 on this path,
    // and a timeout firing after data arrived would wrongly report TIMED_OUT
    // (see timeoutExpired()).
    if (readBuf) {
      length += readBuf->computeChainDataLength();
    }
    readBuf_->append(std::move(readBuf));
    post();
  }

  void getReadBuffer(void** buf, size_t* len) override {
    if (readBuf_) {
      auto rbuf = readBuf_->preallocate(minReadSize_, newAllocationSize_);
      *buf = rbuf.first;
      *len = rbuf.second;
    } else {
      VLOG(5) << "getReadBuffer, size: " << buf_.size();
      // Continue filling the caller's buffer after what was already read.
      *buf = buf_.begin() + length;
      *len = buf_.size() - length;
    }
  }

  void readDataAvailable(size_t len) noexcept override {
    VLOG(5) << "readDataAvailable: " << len << " bytes";
    length += len;
    if (readBuf_) {
      readBuf_->postallocate(len);
    } else if (length == buf_.size()) {
      // Fixed buffer is full: stop reading until the next read() call.
      transport_.setReadCB(nullptr);
      cancelTimeout();
    }
    post();
  }

  void readEOF() noexcept override {
    VLOG(5) << "readEOF()";
    transport_.setReadCB(nullptr);
    cancelTimeout();
    eof = true;
    post();
  }

  void readErr(const folly::AsyncSocketException& ex) noexcept override {
    VLOG(5) << "readErr()";
    transport_.setReadCB(nullptr);
    cancelTimeout();
    storeException(ex);
    post();
  }

  void timeoutExpired() noexcept override {
    VLOG(5) << "timeoutExpired()";
    using Error = folly::AsyncSocketException::AsyncSocketExceptionType;
    transport_.setReadCB(nullptr);
    // If any data was delivered, the waiter has already been posted and the
    // partial read is reported as success; only an idle read times out.
    if (length == 0) {
      error_ = folly::make_exception_wrapper<folly::AsyncSocketException>(
          Error::TIMED_OUT, "Timed out waiting for data", errno);
      post();
    }
  }
};
/// Bridges folly::AsyncTransport::WriteCallback to the awaitable wait provided
/// by TransportCallbackBase.
///
/// On failure, the byte count reported by the transport and the exception are
/// stored in the public fields for the awaiting coroutine to inspect.
class WriteCallback : public TransportCallbackBase,
                      public folly::AsyncTransport::WriteCallback {
 public:
  explicit WriteCallback(folly::AsyncTransport& transport)
      : TransportCallbackBase(transport) {}
  ~WriteCallback() override = default;

  // Bytes the transport reported as written; only updated by writeErr().
  size_t bytesWritten{0};
  // Exception reported by writeErr(), if any.
  std::optional<folly::AsyncSocketException> error;

 private:
  // On cancellation the transport is closed with a reset.
  void cancel() noexcept override { transport_.closeWithReset(); }

  void writeSuccess() noexcept override {
    VLOG(5) << "writeSuccess";
    post();
  }

  void writeErr(
      size_t bytes, const folly::AsyncSocketException& ex) noexcept override {
    // Record the count before logging. The previous code logged bytesWritten
    // before assigning it, so the log always showed the stale default (0).
    bytesWritten = bytes;
    VLOG(5) << "writeErr, wrote " << bytes << " bytes";
    error = ex;
    post();
  }
};
}
namespace folly {
namespace coro {
/// Creates an AsyncSocket on `evb`, connects it to `destAddr`, and wraps it in
/// a Transport.
///
/// Fails with the awaiting exception (e.g. cancellation) or with the connect
/// error reported by the socket.
Task<Transport> Transport::newConnectedSocket(
    folly::EventBase* evb,
    const folly::SocketAddress& destAddr,
    std::chrono::milliseconds connectTimeout,
    const SocketOptionMap& options,
    const SocketAddress& bindAddr,
    const std::string& ifName) {
  auto sock = AsyncSocket::newSocket(evb);
  sock->setReadCB(nullptr);

  ConnectCallback connectCallback{*sock};
  sock->connect(
      &connectCallback,
      destAddr,
      connectTimeout.count(),
      options,
      bindAddr,
      ifName);

  auto connectOutcome = co_await co_awaitTry(connectCallback.wait());
  if (connectOutcome.hasException()) {
    co_yield co_error(std::move(connectOutcome.exception()));
  }
  // A failure delivered through connectErr().
  if (connectCallback.error()) {
    co_yield co_error(std::move(connectCallback.error()));
  }

  co_return Transport(evb, std::move(sock));
}
/// Reads up to buf.size() bytes into `buf`.
///
/// Returns the number of bytes read; 0 signals EOF. When data and EOF are
/// observed together, the data is returned now and the EOF is reported by the
/// next call. A cancellation that lands after data or EOF was observed is not
/// treated as an error.
Task<size_t> Transport::read(
    folly::MutableByteRange buf, std::chrono::milliseconds timeout) {
  if (deferredReadEOF_) {
    // Deliver the EOF that the previous call held back.
    deferredReadEOF_ = false;
    co_return 0;
  }
  VLOG(5) << "Transport::read(), expecting max len " << buf.size();

  ReadCallback reader{eventBase_->timer(), *transport_, buf, timeout};
  transport_->setReadCB(&reader);
  auto outcome = co_await co_awaitTry(reader.wait());

  // Errors stored by the callback (read error or timeout) take precedence.
  if (reader.error()) {
    co_yield co_error(std::move(reader.error()));
  }
  if (outcome.hasException()) {
    const bool wasCancelled =
        outcome.tryGetExceptionObject<OperationCancelled>() != nullptr;
    const bool madeProgress = reader.eof || reader.length != 0;
    // Only a cancellation that arrived after data/EOF is swallowed.
    if (!(wasCancelled && madeProgress)) {
      co_yield co_error(std::move(outcome.exception()));
    }
  }

  transport_->setReadCB(nullptr);
  const size_t bytesRead = reader.length;
  deferredReadEOF_ = reader.eof && bytesRead > 0;
  co_return bytesRead;
}
/// Reads into `readBuf`, asking the transport for buffers of at least
/// `minReadSize` bytes (`newAllocationSize` when a new buffer is allocated).
///
/// Returns the number of bytes appended to `readBuf` by this call; 0 signals
/// EOF. When data and EOF are observed together, the data is returned now and
/// the EOF is reported by the next call. A cancellation that lands after data
/// or EOF was observed is not treated as an error.
Task<size_t> Transport::read(
    folly::IOBufQueue& readBuf,
    std::size_t minReadSize,
    std::size_t newAllocationSize,
    std::chrono::milliseconds timeout) {
  if (deferredReadEOF_) {
    // Deliver the EOF that the previous call held back.
    deferredReadEOF_ = false;
    co_return 0;
  }
  VLOG(5) << "Transport::read(), expecting minReadSize=" << minReadSize;
  auto readBufStartLength = readBuf.chainLength();
  ReadCallback cb{
      eventBase_->timer(),
      *transport_,
      &readBuf,
      minReadSize,
      newAllocationSize,
      timeout};
  transport_->setReadCB(&cb);
  auto waitRet = co_await co_awaitTry(cb.wait());

  // Errors stored by the callback (read error or timeout) take precedence.
  if (cb.error()) {
    co_yield co_error(std::move(cb.error()));
  }
  if (waitRet.hasException()) {
    const bool cancelled =
        waitRet.tryGetExceptionObject<OperationCancelled>() != nullptr;
    // Judge progress from the queue itself. Bytes moved in through
    // readBufferAvailable() are appended to readBuf but may not be counted
    // in cb.length. chainLength() is evaluated only for the cancellation
    // case, so every other path behaves as before.
    if (!cancelled ||
        (!cb.eof && readBuf.chainLength() == readBufStartLength)) {
      co_yield co_error(std::move(waitRet.exception()));
    }
  }

  transport_->setReadCB(nullptr);
  auto length = readBuf.chainLength() - readBufStartLength;
  deferredReadEOF_ = (cb.eof && length > 0);
  co_return length;
}
/// Writes all of `buf` to the transport, using `timeout` as the send timeout.
///
/// On failure, `writeInfo` (when non-null) receives the number of bytes the
/// transport reported as written before the error.
Task<folly::Unit> Transport::write(
    folly::ByteRange buf,
    std::chrono::milliseconds timeout,
    folly::WriteFlags writeFlags,
    WriteInfo* writeInfo) {
  transport_->setSendTimeout(timeout.count());
  WriteCallback writer{*transport_};
  transport_->write(&writer, buf.begin(), buf.size(), writeFlags);
  auto outcome = co_await co_awaitTry(writer.wait());

  const bool failed = outcome.hasException() || writer.error.has_value();
  if (failed && writeInfo != nullptr) {
    writeInfo->bytesWritten = writer.bytesWritten;
  }
  // Awaiting failures (e.g. cancellation) win over a transport write error.
  if (outcome.hasException()) {
    co_yield co_error(std::move(outcome.exception()));
  }
  if (writer.error) {
    co_yield co_error(std::move(*writer.error));
  }
  co_return unit;
}
/// Writes the contents of `ioBufQueue` to the transport with a single
/// writev(), using `timeout` as the send timeout.
///
/// The queue is not modified. On failure, `writeInfo` (when non-null)
/// receives the number of bytes the transport reported as written before the
/// error. An empty queue completes immediately without touching the
/// transport.
Task<folly::Unit> Transport::write(
    folly::IOBufQueue& ioBufQueue,
    std::chrono::milliseconds timeout,
    folly::WriteFlags writeFlags,
    WriteInfo* writeInfo) {
  // front() is null when the queue holds no buffers. Calling getIov() on it
  // would dereference null, and there is nothing to send anyway.
  const folly::IOBuf* head = ioBufQueue.front();
  if (head == nullptr) {
    co_return unit;
  }

  transport_->setSendTimeout(timeout.count());
  WriteCallback cb{*transport_};
  auto iovec = head->getIov();
  transport_->writev(&cb, iovec.data(), iovec.size(), writeFlags);
  auto waitRet = co_await co_awaitTry(cb.wait());
  if (waitRet.hasException()) {
    if (writeInfo) {
      writeInfo->bytesWritten = cb.bytesWritten;
    }
    co_yield co_error(std::move(waitRet.exception()));
  }
  if (cb.error) {
    if (writeInfo) {
      writeInfo->bytesWritten = cb.bytesWritten;
    }
    co_yield co_error(std::move(*cb.error));
  }
  co_return unit;
}
}
}
#endif