#include <folly/CancellationToken.h>
#include <folly/experimental/coro/Baton.h>
#include <folly/experimental/coro/WithCancellation.h>
#if FOLLY_HAS_COROUTINES
namespace folly::coro {
namespace detail {
// State that timeoutImpl keeps only for the "discard on timeout" variant.
// Keying it on the `Discard` flag lets the non-discarding variant use an
// empty type instead of carrying unused members.
template <bool Discard>
struct DiscardImpl {
  // Posted by the sleep callback whenever the sleep future completes;
  // baton.ready() therefore means "the timer has finished".
  folly::coro::Baton baton;
  // Error reported when the timer wins: FutureTimeout, or the exception the
  // sleep future itself completed with.
  exception_wrapper timeoutResult;
  // Set by the parent-cancellation callback so that a timer which fired
  // concurrently is not reported as a timeout.
  bool parentCancelled{false};
  // Set once the success path has already looked at the timer, so the
  // exception path does not repeat that check.
  bool checkedTimeout{false};
};

// Non-discarding variant: no shared state is needed.
template <>
struct DiscardImpl<false> {};
// Shared implementation behind timeout() and timeoutNoDiscard().
//
// Awaits `semiAwaitable` under a child cancellation token (from
// `cancelSource`). That token is cancelled when either:
//   (a) `folly::futures::sleep(timeoutDuration, tk)` completes, or
//   (b) the awaiting coroutine's own cancellation token is cancelled.
//
// discard == true:
//   If the sleep has completed by the time the awaitable finishes, and the
//   parent was not cancelled, the awaitable's outcome is dropped. The task
//   then completes with `impl.timeoutResult` as its error: FutureTimeout, or
//   the sleep future's own exception.
//   Before finishing, the coroutine always ensures the sleep callback has
//   posted `impl.baton`. The callback captures `impl` by reference, so it
//   must be done with it before this frame goes away.
//
// discard == false:
//   The awaitable's value or error is always propagated unchanged; the timer
//   only requests cancellation. In this mode the sleep callback touches only
//   its by-value copy of `cancelSource`, so nothing needs to be waited for.
//
// Errors are reported by completing the task via co_error.
template <typename SemiAwaitable, typename Duration, bool discard>
Task<typename semi_await_try_result_t<SemiAwaitable>::element_type> timeoutImpl(
SemiAwaitable semiAwaitable, Duration timeoutDuration, Timekeeper* tk) {
// Source of the token handed to the inner awaitable; cancelled on timeout or
// on parent cancellation.
CancellationSource cancelSource;
DiscardImpl<discard> impl;
// Start the timer. The code below relies on this callback running even after
// sleepFuture.cancel(): the result.hasException() branch handles that case,
// and `co_await impl.baton` would otherwise never resume.
auto sleepFuture =
folly::futures::sleep(timeoutDuration, tk).toUnsafeFuture();
sleepFuture.setCallback_(
[&, cancelSource](Executor::KeepAlive<>&&, Try<Unit>&& result) noexcept {
if constexpr (discard) {
// Record what to report if the timer wins: the sleep's own error, or a
// plain FutureTimeout when the sleep completed normally.
if (result.hasException()) {
impl.timeoutResult = std::move(result.exception());
} else {
impl.timeoutResult = folly::make_exception_wrapper<FutureTimeout>();
}
// Publishes timeoutResult and marks "timer finished"; after this the
// callback no longer touches `impl`.
impl.baton.post();
}
cancelSource.requestCancellation();
});
// Ensures sleepFuture.cancel() is issued at most once. It is reachable from
// the parent-cancellation callback and from this coroutine, but the
// coroutine only calls it after cancelCallback.reset().
bool isSleepCancelled = false;
auto tryCancelSleep = [&]() noexcept {
if (!isSleepCancelled) {
isSleepCancelled = true;
sleepFuture.cancel();
}
};
// Forward the parent's cancellation to the inner awaitable and stop the
// timer. Held in an optional so it can be deregistered explicitly before the
// state it writes (parentCancelled, isSleepCancelled) is read.
std::optional<CancellationCallback> cancelCallback{
std::in_place, co_await co_current_cancellation_token, [&]() {
cancelSource.requestCancellation();
tryCancelSleep();
if constexpr (discard) {
impl.parentCancelled = true;
}
}};
exception_wrapper error;
try {
auto resultTry =
co_await folly::coro::co_awaitTry(folly::coro::co_withCancellation(
cancelSource.getToken(), std::move(semiAwaitable)));
// NOTE(review): relies on CancellationCallback destruction not returning
// while the callback is still executing on another thread.
cancelCallback.reset();
if constexpr (discard) {
// The timer already fired and the parent did not cancel us: report the
// timeout and drop the awaitable's outcome.
if (!impl.parentCancelled && impl.baton.ready()) {
co_yield folly::coro::co_error(std::move(impl.timeoutResult));
}
// Tell the exception path below not to redo this check.
impl.checkedTimeout = true;
}
tryCancelSleep();
if constexpr (discard) {
// Wait until the sleep callback is done with `impl` before the frame can
// be destroyed.
co_await impl.baton;
}
// Propagate the awaitable's error directly instead of rethrowing it via
// value().
if (resultTry.hasException()) {
co_yield folly::coro::co_error(std::move(resultTry).exception());
}
co_return std::move(resultTry).value();
} catch (...) {
// co_await/co_yield are not permitted inside a handler, so stash the
// exception and finish the cleanup after the try/catch.
error = exception_wrapper{current_exception()};
}
assert(error);
// Same cleanup as the success path: deregister the parent callback, report a
// timeout if the timer fired first (unless already checked above), stop the
// timer, and wait for its callback before completing with `error`.
cancelCallback.reset();
if constexpr (discard) {
if (!impl.checkedTimeout && !impl.parentCancelled && impl.baton.ready()) {
co_yield folly::coro::co_error(std::move(impl.timeoutResult));
}
}
tryCancelSleep();
if constexpr (discard) {
co_await impl.baton;
}
co_yield folly::coro::co_error(std::move(error));
}
}
// Awaits `semiAwaitable`, requesting its cancellation once `timeoutDuration`
// has elapsed (timed via `tk`).
//
// If the timer has fired by the time the awaitable completes, and the caller
// itself was not cancelled, the awaitable's outcome is discarded. The task
// then completes with a timeout error instead: FutureTimeout, or the error
// the sleep itself completed with.
template <typename SemiAwaitable, typename Duration>
Task<typename semi_await_try_result_t<SemiAwaitable>::element_type> timeout(
    SemiAwaitable semiAwaitable, Duration timeoutDuration, Timekeeper* tk) {
  constexpr bool kDiscardOnTimeout = true;
  return detail::timeoutImpl<SemiAwaitable, Duration, kDiscardOnTimeout>(
      std::move(semiAwaitable), timeoutDuration, tk);
}
// Awaits `semiAwaitable`, requesting its cancellation once `timeoutDuration`
// has elapsed (timed via `tk`).
//
// Unlike the discarding variant, the awaitable's value or error is always
// propagated as-is. A result produced after the deadline is still returned
// rather than replaced by a timeout error.
template <typename SemiAwaitable, typename Duration>
Task<typename semi_await_try_result_t<SemiAwaitable>::element_type>
timeoutNoDiscard(
    SemiAwaitable semiAwaitable, Duration timeoutDuration, Timekeeper* tk) {
  constexpr bool kDiscardOnTimeout = false;
  return detail::timeoutImpl<SemiAwaitable, Duration, kDiscardOnTimeout>(
      std::move(semiAwaitable), timeoutDuration, tk);
}
}
#endif