#pragma once
#include <folly/ScopeGuard.h>
#include <folly/Try.h>
#include <folly/experimental/coro/Coroutine.h>
#include <folly/experimental/coro/WithAsyncStack.h>
#include <folly/experimental/coro/detail/Malloc.h>
#include <folly/lang/Assume.h>
#include <folly/tracing/AsyncStack.h>
#include <cassert>
#include <utility>
#if FOLLY_HAS_COROUTINES
namespace folly {
namespace coro {
namespace detail {
// Forward declaration: the promise types below declare InlineTask<T> as the
// return type of get_return_object() before the class itself is defined.
template <typename T>
class InlineTask;
// Promise machinery shared by every InlineTaskPromise<T> specialisation.
//
// The coroutine is created suspended (initial_suspend returns suspend_always).
// When it reaches final_suspend, control is handed to whichever coroutine was
// registered through set_continuation(), by returning that handle from
// FinalAwaiter::await_suspend. Coroutine frames are allocated and released via
// folly_coro_async_malloc / folly_coro_async_free.
class InlineTaskPromiseBase {
  // Awaiter produced by final_suspend(): it always suspends and then transfers
  // control to the stored continuation.
  struct FinalAwaiter {
    bool await_ready() noexcept { return false; }

    template <typename Promise>
    coroutine_handle<> await_suspend(coroutine_handle<Promise> h) noexcept {
      // Bind through the base class so the private continuation_ is reachable.
      InlineTaskPromiseBase& base = h.promise();
      return base.continuation_;
    }

    void await_resume() noexcept {}
  };

 public:
  static void* operator new(std::size_t size) {
    return ::folly_coro_async_malloc(size);
  }

  static void operator delete(void* ptr, std::size_t size) {
    ::folly_coro_async_free(ptr, size);
  }

  suspend_always initial_suspend() noexcept { return suspend_always{}; }

  auto final_suspend() noexcept { return FinalAwaiter{}; }

  // Stores the coroutine to resume when this one finishes. In debug builds it
  // asserts that no continuation has been stored yet.
  void set_continuation(coroutine_handle<> continuation) noexcept {
    assert(!continuation_);
    continuation_ = continuation;
  }

 protected:
  // Only derived promises construct this base; it is neither copyable nor
  // movable.
  InlineTaskPromiseBase() noexcept = default;
  InlineTaskPromiseBase(const InlineTaskPromiseBase&) = delete;
  InlineTaskPromiseBase(InlineTaskPromiseBase&&) = delete;
  InlineTaskPromiseBase& operator=(const InlineTaskPromiseBase&) = delete;
  InlineTaskPromiseBase& operator=(InlineTaskPromiseBase&&) = delete;

 private:
  coroutine_handle<> continuation_;
};
template <typename T>
class InlineTaskPromise : public InlineTaskPromiseBase {
public:
static_assert(
std::is_move_constructible<T>::value,
"InlineTask<T> only supports types that are move-constructible.");
static_assert(
!std::is_rvalue_reference<T>::value, "InlineTask<T&&> is not supported");
InlineTaskPromise() noexcept = default;
~InlineTaskPromise() = default;
InlineTask<T> get_return_object() noexcept;
template <
typename Value = T,
std::enable_if_t<std::is_convertible<Value&&, T>::value, int> = 0>
void return_value(Value&& value) noexcept(
std::is_nothrow_constructible<T, Value&&>::value) {
result_.emplace(static_cast<Value&&>(value));
}
void unhandled_exception() noexcept {
result_.emplaceException(folly::exception_wrapper{current_exception()});
}
T result() { return std::move(result_).value(); }
private:
using StorageType = std::conditional_t<
std::is_lvalue_reference<T>::value,
std::reference_wrapper<std::remove_reference_t<T>>,
T>;
folly::Try<StorageType> result_;
};
template <>
class InlineTaskPromise<void> : public InlineTaskPromiseBase {
public:
InlineTaskPromise() noexcept = default;
InlineTask<void> get_return_object() noexcept;
void return_void() noexcept {}
void unhandled_exception() noexcept {
result_.emplaceException(folly::exception_wrapper{current_exception()});
}
void result() { return result_.value(); }
private:
folly::Try<void> result_;
};
template <typename T>
class InlineTask {
public:
using promise_type = detail::InlineTaskPromise<T>;
private:
using handle_t = coroutine_handle<promise_type>;
public:
InlineTask(InlineTask&& other) noexcept
: coro_(std::exchange(other.coro_, {})) {}
~InlineTask() {
if (coro_) {
coro_.destroy();
}
}
class Awaiter {
public:
~Awaiter() {
if (coro_) {
coro_.destroy();
}
}
bool await_ready() noexcept { return false; }
handle_t await_suspend(coroutine_handle<> awaitingCoroutine) noexcept {
assert(coro_ && !coro_.done());
coro_.promise().set_continuation(awaitingCoroutine);
return coro_;
}
T await_resume() {
auto destroyOnExit =
folly::makeGuard([this] { std::exchange(coro_, {}).destroy(); });
return coro_.promise().result();
}
private:
friend class InlineTask<T>;
explicit Awaiter(handle_t coro) noexcept : coro_(coro) {}
handle_t coro_;
};
Awaiter operator co_await() && {
assert(coro_ && !coro_.done());
return Awaiter{std::exchange(coro_, {})};
}
private:
friend class InlineTaskPromise<T>;
explicit InlineTask(handle_t coro) noexcept : coro_(coro) {}
handle_t coro_;
};
// Wraps this promise's coroutine frame in the InlineTask handed back to the
// caller of the coroutine function. Defined here, after InlineTask<T> is
// complete.
template <typename T>
inline InlineTask<T> InlineTaskPromise<T>::get_return_object() noexcept {
  using Handle = coroutine_handle<InlineTaskPromise<T>>;
  return InlineTask<T>{Handle::from_promise(*this)};
}
// Wraps this promise's coroutine frame in the InlineTask<void> handed back to
// the caller of the coroutine function.
inline InlineTask<void> InlineTaskPromise<void>::get_return_object() noexcept {
  const auto handle =
      coroutine_handle<InlineTaskPromise<void>>::from_promise(*this);
  return InlineTask<void>{handle};
}
// Fire-and-forget coroutine type.
//
// The coroutine is created suspended and runs only when start() is called.
// start() gives up ownership of the frame, and the frame then destroys itself
// in final_suspend. If the coroutine is never started, the destructor destroys
// the frame instead. An exception escaping the coroutine body calls
// std::terminate(). Every co_await in the body is routed through
// co_withAsyncStack so awaited operations take part in async stack tracking.
struct InlineTaskDetached {
  class promise_type {
    // Awaiter produced by final_suspend(): deactivates this coroutine's async
    // stack frame and destroys the coroutine frame, so it is never resumed.
    struct FinalAwaiter {
      bool await_ready() noexcept { return false; }
      void await_suspend(coroutine_handle<promise_type> h) noexcept {
        folly::deactivateAsyncStackFrame(h.promise().getAsyncFrame());
        h.destroy();
      }
      [[noreturn]] void await_resume() noexcept { folly::assume_unreachable(); }
    };

   public:
    static void* operator new(std::size_t size) {
      return ::folly_coro_async_malloc(size);
    }
    static void operator delete(void* ptr, std::size_t size) {
      ::folly_coro_async_free(ptr, size);
    }

    // Detached coroutines have no awaiting parent; their async frame hangs off
    // the detached root frame.
    promise_type() noexcept {
      asyncFrame_.setParentFrame(folly::getDetachedRootAsyncStackFrame());
    }

    InlineTaskDetached get_return_object() noexcept {
      return InlineTaskDetached{
          coroutine_handle<promise_type>::from_promise(*this)};
    }

    suspend_always initial_suspend() noexcept { return {}; }
    FinalAwaiter final_suspend() noexcept { return {}; }
    void return_void() noexcept {}
    [[noreturn]] void unhandled_exception() noexcept { std::terminate(); }

    template <typename Awaitable>
    decltype(auto) await_transform(Awaitable&& awaitable) {
      return folly::coro::co_withAsyncStack(
          static_cast<Awaitable&&>(awaitable));
    }

    folly::AsyncStackFrame& getAsyncFrame() noexcept { return asyncFrame_; }

   private:
    folly::AsyncStackFrame asyncFrame_;
  };

  InlineTaskDetached(InlineTaskDetached&& other) noexcept
      : coro_(std::exchange(other.coro_, {})) {}

  // Single owner of the frame: copying would alias it, and assignment is not
  // supported.
  InlineTaskDetached(const InlineTaskDetached&) = delete;
  InlineTaskDetached& operator=(const InlineTaskDetached&) = delete;
  InlineTaskDetached& operator=(InlineTaskDetached&&) = delete;

  // Destroys the frame only if the coroutine was never started.
  ~InlineTaskDetached() {
    if (coro_) {
      coro_.destroy();
    }
  }

  // Starts the coroutine, recording FOLLY_ASYNC_STACK_RETURN_ADDRESS() as its
  // return address. NOTE(review): FOLLY_NOINLINE presumably keeps that
  // address pointing at this call site; confirm against folly/tracing.
  FOLLY_NOINLINE void start() noexcept {
    start(FOLLY_ASYNC_STACK_RETURN_ADDRESS());
  }

  // Records `returnAddress` on the coroutine's async frame, releases ownership
  // of the frame (this object becomes empty) and resumes the coroutine under a
  // new async stack root. Must be called at most once, on a non-empty task.
  void start(void* returnAddress) noexcept {
    // A moved-from or already-started task holds a null handle; calling
    // promise() on it would be undefined behaviour.
    assert(coro_ && "InlineTaskDetached::start() called on an empty task");
    coro_.promise().getAsyncFrame().setReturnAddress(returnAddress);
    folly::resumeCoroutineWithNewAsyncStackRoot(std::exchange(coro_, {}));
  }

 private:
  explicit InlineTaskDetached(coroutine_handle<promise_type> h) noexcept
      : coro_(h) {}
  coroutine_handle<promise_type> coro_;
};
}
}
}
#endif