#include <folly/Portability.h>
#include <folly/experimental/coro/Baton.h>
#include <folly/io/coro/ServerSocket.h>
#if FOLLY_HAS_COROUTINES
using namespace folly::coro;
namespace {
class AcceptCallback : public folly::AsyncServerSocket::AcceptCallback {
public:
explicit AcceptCallback(
Baton& baton, std::shared_ptr<folly::AsyncServerSocket> socket)
: baton_{baton}, socket_(std::move(socket)) {}
~AcceptCallback() override = default;
int acceptFd{-1};
folly::exception_wrapper error;
private:
Baton& baton_;
std::shared_ptr<folly::AsyncServerSocket> socket_;
void connectionAccepted(
folly::NetworkSocket fdNetworkSocket,
const folly::SocketAddress& clientAddr,
AcceptInfo ) noexcept override {
VLOG(5) << "Connection accepted from: " << clientAddr.describe();
socket_->pauseAccepting();
socket_->removeAcceptCallback(this, nullptr);
acceptFd = fdNetworkSocket.toFd();
baton_.post();
}
void acceptError(folly::exception_wrapper ex) noexcept override {
VLOG(5) << "acceptError";
socket_->pauseAccepting();
socket_->removeAcceptCallback(this, nullptr);
error = std::move(ex);
acceptFd = -1;
baton_.post();
}
void acceptStarted() noexcept override { VLOG(5) << "acceptStarted"; }
void acceptStopped() noexcept override { VLOG(5) << "acceptStopped"; }
};
}
namespace folly {
namespace coro {
ServerSocket::ServerSocket(
std::shared_ptr<AsyncServerSocket> socket,
std::optional<SocketAddress> bindAddr,
uint32_t listenQueueDepth)
: socket_{socket} {
socket_->setReusePortEnabled(true);
if (bindAddr.has_value()) {
VLOG(1) << "ServerSocket binds on IP: " << bindAddr->describe();
socket_->bind(*bindAddr);
} else {
VLOG(1) << "ServerSocket binds on any addr, random port";
socket_->bind(0);
}
socket_->listen(listenQueueDepth);
}
Task<std::unique_ptr<Transport>> ServerSocket::accept() {
VLOG(5) << "accept() called";
co_await folly::coro::co_safe_point;
Baton baton;
AcceptCallback cb(baton, socket_);
socket_->addAcceptCallback(&cb, nullptr);
socket_->startAccepting();
auto cancelToken = co_await folly::coro::co_current_cancellation_token;
CancellationCallback cancellationCallback{
cancelToken, [&baton] { baton.post(); }};
co_await baton;
if (cancelToken.isCancellationRequested()) {
socket_->stopAccepting();
co_yield co_cancelled;
}
if (cb.error) {
co_yield co_error(std::move(cb.error));
}
co_return std::make_unique<Transport>(
socket_->getEventBase(),
AsyncSocket::newSocket(
socket_->getEventBase(), NetworkSocket::fromFd(cb.acceptFd)));
}
}
}
#endif