#pragma once
#include <folly/experimental/coro/Coroutine.h>
#include <folly/experimental/coro/Traits.h>
#include <folly/functional/Invoke.h>
#include <folly/lang/Assume.h>
#include <folly/lang/CustomizationPoint.h>
#include <folly/tracing/AsyncStack.h>
#include <cassert>
#include <type_traits>
#include <utility>
#if FOLLY_HAS_COROUTINES
namespace folly::coro {
namespace detail {
class WithAsyncStackCoroutine {
public:
class promise_type {
public:
WithAsyncStackCoroutine get_return_object() noexcept {
return WithAsyncStackCoroutine{
coroutine_handle<promise_type>::from_promise(*this)};
}
suspend_always initial_suspend() noexcept { return {}; }
struct FinalAwaiter {
bool await_ready() noexcept { return false; }
void await_suspend(coroutine_handle<promise_type> h) noexcept {
auto& promise = h.promise();
folly::deactivateSuspendedLeaf(promise.getLeafFrame());
folly::resumeCoroutineWithNewAsyncStackRoot(
promise.continuation_, *promise.getLeafFrame().getParentFrame());
}
[[noreturn]] void await_resume() noexcept { folly::assume_unreachable(); }
};
FinalAwaiter final_suspend() noexcept { return {}; }
void return_void() noexcept {}
[[noreturn]] void unhandled_exception() noexcept {
folly::assume_unreachable();
}
folly::AsyncStackFrame& getLeafFrame() noexcept { return leafFrame; }
private:
friend WithAsyncStackCoroutine;
coroutine_handle<> continuation_;
folly::AsyncStackFrame leafFrame;
};
WithAsyncStackCoroutine() noexcept : coro_() {}
WithAsyncStackCoroutine(WithAsyncStackCoroutine&& other) noexcept
: coro_(std::exchange(other.coro_, {})) {}
~WithAsyncStackCoroutine() {
if (coro_) {
coro_.destroy();
}
}
WithAsyncStackCoroutine& operator=(WithAsyncStackCoroutine other) noexcept {
std::swap(coro_, other.coro_);
return *this;
}
static WithAsyncStackCoroutine create() { co_return; }
template <typename Promise>
coroutine_handle<promise_type> getWrapperHandleFor(
coroutine_handle<Promise> h, void* returnAddress) noexcept {
auto& promise = coro_.promise();
promise.continuation_ = h;
promise.getLeafFrame().setParentFrame(h.promise().getAsyncFrame());
promise.getLeafFrame().setReturnAddress(returnAddress);
return coro_;
}
folly::AsyncStackFrame& getLeafFrame() noexcept {
return coro_.promise().getLeafFrame();
}
private:
explicit WithAsyncStackCoroutine(coroutine_handle<promise_type> h) noexcept
: coro_(h) {}
coroutine_handle<promise_type> coro_;
};
// Awaiter produced by co_await-ing a WithAsyncStackAwaitable.
//
// Wraps the awaiter of `Awaitable`. Instead of handing the awaiting
// coroutine's handle to the inner awaiter, it hands over the handle of a
// WithAsyncStackCoroutine wrapper. While the caller is suspended:
//   - the caller's async-stack frame is deactivated, and
//   - the wrapper's leaf frame is registered as a suspended leaf. That leaf
//     is parented to the caller's frame and tagged with await_suspend's
//     return address.
// The wrapper resumes the caller once the inner awaiter resumes the
// wrapper's handle.
template <typename Awaitable>
class WithAsyncStackAwaiter {
using Awaiter = awaiter_type_t<Awaitable>;
public:
// Obtains the inner awaiter from `awaitable`. Also eagerly creates the wrapper
// coroutine, which starts suspended.
explicit WithAsyncStackAwaiter(Awaitable&& awaitable)
: awaiter_(folly::coro::get_awaiter(static_cast<Awaitable&&>(awaitable))),
coroWrapper_(WithAsyncStackCoroutine::create()) {}
// Forwards to the inner awaiter. If this returns true, await_suspend is never
// called, so no async-stack frames are touched.
auto await_ready() noexcept(noexcept(std::declval<Awaiter&>().await_ready()))
-> decltype(std::declval<Awaiter&>().await_ready()) {
return awaiter_.await_ready();
}
// Suspends the caller `h` by passing the wrapper's handle to the inner
// awaiter's await_suspend.
//
// Requires h.promise().getAsyncFrame() to exist. That frame must have a
// non-null stack root (asserted).
//
// Result handling:
//   - bool result: reported as-is. On `false`, the frame bookkeeping is undone
//     first.
//   - any other result type (void or a handle): returned unchanged.
//   - exception: the bookkeeping is undone and the exception is rethrown.
//
// NOTE(review): FOLLY_NOINLINE presumably keeps this a distinct call frame, so
// that FOLLY_ASYNC_STACK_RETURN_ADDRESS() yields a return address into the
// awaiting coroutine. Confirm before changing the inlining.
template <typename Promise>
FOLLY_NOINLINE auto await_suspend(coroutine_handle<Promise> h) {
AsyncStackFrame& callerFrame = h.promise().getAsyncFrame();
// Remember the root the caller is running on; it is needed to reactivate the
// caller's frame if we end up not suspending.
AsyncStackRoot* stackRoot = callerFrame.getStackRoot();
assert(stackRoot != nullptr);
// Point the wrapper at the caller: continuation, parent frame, return address.
auto wrapperHandle =
coroWrapper_.getWrapperHandleFor(h, FOLLY_ASYNC_STACK_RETURN_ADDRESS());
// Hand off from the caller's active frame to the wrapper's suspended leaf.
folly::deactivateAsyncStackFrame(callerFrame);
folly::activateSuspendedLeaf(coroWrapper_.getLeafFrame());
using await_suspend_result_t =
decltype(awaiter_.await_suspend(wrapperHandle));
try {
if constexpr (std::is_same_v<await_suspend_result_t, bool>) {
if (!awaiter_.await_suspend(wrapperHandle)) {
// The inner awaiter declined to suspend, so the caller resumes
// immediately. Undo the hand-off above.
folly::activateAsyncStackFrame(*stackRoot, callerFrame);
folly::deactivateSuspendedLeaf(coroWrapper_.getLeafFrame());
return false;
}
return true;
} else {
return awaiter_.await_suspend(wrapperHandle);
}
} catch (...) {
// The inner await_suspend threw. Restore the caller's frame before
// propagating the exception.
folly::activateAsyncStackFrame(*stackRoot, callerFrame);
folly::deactivateSuspendedLeaf(coroWrapper_.getLeafFrame());
throw;
}
}
// Destroys the wrapper coroutine frame by assigning an empty wrapper, then
// returns the inner awaiter's result.
auto await_resume() noexcept(
noexcept(std::declval<Awaiter&>().await_resume()))
-> decltype(std::declval<Awaiter&>().await_resume()) {
coroWrapper_ = WithAsyncStackCoroutine();
return awaiter_.await_resume();
}
// Same as await_resume(), but forwards to the inner awaiter's
// await_resume_try(). SFINAE on Awaiter2 removes this overload when the inner
// awaiter has no await_resume_try().
template <typename Awaiter2 = Awaiter>
auto await_resume_try() noexcept(
noexcept(std::declval<Awaiter2&>().await_resume_try()))
-> decltype(std::declval<Awaiter2&>().await_resume_try()) {
coroWrapper_ = WithAsyncStackCoroutine();
return awaiter_.await_resume_try();
}
private:
// Awaiter obtained from the wrapped awaitable.
awaiter_type_t<Awaitable> awaiter_;
// Trampoline coroutine that stands in for the caller while suspended.
WithAsyncStackCoroutine coroWrapper_;
};
// Stores an awaitable: by value, or by reference when `Awaitable` is an lvalue
// reference type. When co_awaited, it produces a WithAsyncStackAwaiter around
// the stored awaitable.
template <typename Awaitable>
class WithAsyncStackAwaitable {
 public:
  explicit WithAsyncStackAwaitable(Awaitable&& awaitable)
      : awaitable_(std::forward<Awaitable>(awaitable)) {}

  // co_await on an lvalue: the awaiter works on a reference to the stored
  // awaitable.
  WithAsyncStackAwaiter<Awaitable&> operator co_await() & {
    Awaitable& stored = awaitable_;
    return WithAsyncStackAwaiter<Awaitable&>(stored);
  }

  // co_await on an rvalue: the stored awaitable is forwarded into the awaiter
  // (moved unless `Awaitable` is itself a reference type).
  WithAsyncStackAwaiter<Awaitable> operator co_await() && {
    return WithAsyncStackAwaiter<Awaitable>(
        std::forward<Awaitable>(awaitable_));
  }

 private:
  Awaitable awaitable_;
};
// Callable type behind the `co_withAsyncStack` customization point.
//
// Dispatch:
//   - If tag_invoke(WithAsyncStackFunction, Awaitable) is well-formed, the
//     call is forwarded to that customization.
//   - Otherwise, for any awaitable, the argument is wrapped in a
//     WithAsyncStackAwaitable.
struct WithAsyncStackFunction {
  // Fallback path: no customization exists, so wrap the awaitable generically.
  template <
      typename Awaitable,
      std::enable_if_t<
          !folly::is_tag_invocable_v<WithAsyncStackFunction, Awaitable>,
          int> = 0,
      std::enable_if_t<folly::coro::is_awaitable_v<Awaitable>, int> = 0>
  WithAsyncStackAwaitable<Awaitable> operator()(Awaitable&& awaitable) const
      noexcept(std::is_nothrow_move_constructible_v<Awaitable>) {
    return WithAsyncStackAwaitable<Awaitable>(
        std::forward<Awaitable>(awaitable));
  }

  // Customized path: defer to the awaitable's own tag_invoke overload.
  template <
      typename Awaitable,
      std::enable_if_t<
          folly::is_tag_invocable_v<WithAsyncStackFunction, Awaitable>,
          int> = 0>
  auto operator()(Awaitable&& awaitable) const noexcept(
      folly::is_nothrow_tag_invocable_v<WithAsyncStackFunction, Awaitable>)
      -> folly::tag_invoke_result_t<WithAsyncStackFunction, Awaitable> {
    return folly::tag_invoke(
        WithAsyncStackFunction{}, std::forward<Awaitable>(awaitable));
  }
};
}
// True when tag_invoke(detail::WithAsyncStackFunction, Awaitable) is
// well-formed. In that case `Awaitable` supplies its own `co_withAsyncStack`
// customization instead of the generic WithAsyncStackAwaitable wrapper.
template <typename Awaitable>
constexpr inline bool is_awaitable_async_stack_aware_v =
    folly::is_tag_invocable_v<detail::WithAsyncStackFunction, Awaitable>;
// `co_withAsyncStack`: a customization-point object, declared through
// FOLLY_DEFINE_CPO, whose callable type is detail::WithAsyncStackFunction.
// Calling co_withAsyncStack(awaitable) uses the awaitable's tag_invoke
// customization if it has one. Otherwise it wraps the awaitable in a
// detail::WithAsyncStackAwaitable.
FOLLY_DEFINE_CPO(detail::WithAsyncStackFunction, co_withAsyncStack)
}
#endif