#pragma once
#include <folly/Range.h>
#include <folly/SocketAddress.h>
#include <folly/experimental/coro/Task.h>
#include <folly/io/IOBufQueue.h>
#include <folly/io/async/AsyncSocket.h>
#include <folly/io/async/AsyncSocketException.h>
#if FOLLY_HAS_COROUTINES
namespace folly {
namespace coro {
class TransportIf {
public:
using ErrorCode = AsyncSocketException::AsyncSocketExceptionType;
virtual ~TransportIf() = default;
virtual EventBase* getEventBase() noexcept = 0;
virtual Task<size_t> read(
MutableByteRange buf, std::chrono::milliseconds timeout) = 0;
Task<size_t> read(
void* buf, size_t buflen, std::chrono::milliseconds timeout) {
return read(MutableByteRange((unsigned char*)buf, buflen), timeout);
}
virtual Task<size_t> read(
IOBufQueue& buf,
size_t minReadSize,
size_t newAllocationSize,
std::chrono::milliseconds timeout) = 0;
struct WriteInfo {
size_t bytesWritten{0};
};
virtual Task<Unit> write(
ByteRange buf,
std::chrono::milliseconds timeout = std::chrono::milliseconds(0),
folly::WriteFlags writeFlags = folly::WriteFlags::NONE,
WriteInfo* writeInfo = nullptr) = 0;
virtual Task<Unit> write(
IOBufQueue& ioBufQueue,
std::chrono::milliseconds timeout = std::chrono::milliseconds(0),
folly::WriteFlags writeFlags = folly::WriteFlags::NONE,
WriteInfo* writeInfo = nullptr) = 0;
virtual SocketAddress getLocalAddress() const noexcept = 0;
virtual SocketAddress getPeerAddress() const noexcept = 0;
virtual void close() = 0;
virtual void shutdownWrite() = 0;
virtual void closeWithReset() = 0;
virtual folly::AsyncTransport* getTransport() const = 0;
virtual const AsyncTransportCertificate* getPeerCertificate() const = 0;
};
class Transport : public TransportIf {
public:
Transport(
folly::EventBase* eventBase, folly::AsyncTransport::UniquePtr transport)
: eventBase_(eventBase), transport_(std::move(transport)) {}
Transport(Transport&&) = default;
Transport& operator=(Transport&&) = default;
static Task<Transport> newConnectedSocket(
EventBase* evb,
const SocketAddress& destAddr,
std::chrono::milliseconds connectTimeout,
const SocketOptionMap& options = emptySocketOptionMap,
const SocketAddress& bindAddr = AsyncSocketTransport::anyAddress(),
const std::string& ifName = "");
virtual EventBase* getEventBase() noexcept override { return eventBase_; }
Task<size_t> read(
MutableByteRange buf, std::chrono::milliseconds timeout) override;
Task<size_t> read(
IOBufQueue& buf,
size_t minReadSize,
size_t newAllocationSize,
std::chrono::milliseconds timeout) override;
Task<Unit> write(
ByteRange buf,
std::chrono::milliseconds timeout = std::chrono::milliseconds(0),
folly::WriteFlags writeFlags = folly::WriteFlags::NONE,
WriteInfo* writeInfo = nullptr) override;
Task<folly::Unit> write(
IOBufQueue& ioBufQueue,
std::chrono::milliseconds timeout = std::chrono::milliseconds(0),
folly::WriteFlags writeFlags = folly::WriteFlags::NONE,
WriteInfo* writeInfo = nullptr) override;
AsyncTransport* getTransport() const override { return transport_.get(); }
SocketAddress getLocalAddress() const noexcept override {
SocketAddress addr;
transport_->getLocalAddress(&addr);
return addr;
}
SocketAddress getPeerAddress() const noexcept override {
SocketAddress addr;
transport_->getPeerAddress(&addr);
return addr;
}
void shutdownWrite() noexcept override {
if (transport_) {
transport_->shutdownWrite();
}
}
void close() noexcept override {
if (transport_) {
transport_->close();
}
}
void closeWithReset() noexcept override {
if (transport_) {
transport_->closeWithReset();
}
}
const AsyncTransportCertificate* getPeerCertificate() const override {
return transport_->getPeerCertificate();
}
private:
Transport(const Transport&) = delete;
Transport& operator=(const Transport&) = delete;
EventBase* eventBase_;
AsyncTransport::UniquePtr transport_;
bool deferredReadEOF_{false};
};
}
}
#endif