#pragma once
#include <type_traits>
#if __has_include(<variant>)
#include <variant>
#endif
#include <folly/Portability.h>
#include <folly/Utility.h>
#if FOLLY_HAS_COROUTINES
#if __has_include(<coroutine>) && !defined(LLVM_COROUTINES) && \
(!defined(_LIBCPP_VERSION) || __cplusplus > 201703L)
#define FOLLY_USE_STD_COROUTINE …
#else
#define FOLLY_USE_STD_COROUTINE …
#endif
#if FOLLY_USE_STD_COROUTINE
#include <coroutine>
#else
#include <experimental/coroutine>
#endif
#endif
#if FOLLY_HAS_COROUTINES
namespace folly {
class exception_wrapper;
struct AsyncStackFrame;
}
namespace folly::coro {
#if FOLLY_USE_STD_COROUTINE
namespace impl = std;
#else
namespace impl = std::experimental;
#endif
using impl::coroutine_handle;
using impl::coroutine_traits;
using impl::noop_coroutine;
using impl::noop_coroutine_handle;
using impl::noop_coroutine_promise;
using impl::suspend_always;
using impl::suspend_never;
/// Awaitable that never suspends and hands back a value stored at
/// construction time.
///
/// The primary template is for non-void `T` only; `ready_awaitable<void>` is a
/// separate specialization.
template <typename T = void>
class ready_awaitable {
  static_assert(!std::is_void_v<T>, "base template unsuitable for void");

  // True when building a T from a T&& cannot throw. Both the constructor and
  // await_resume() perform exactly that operation, so they share the answer.
  static constexpr bool nothrow_transfer_ = noexcept(T(FOLLY_DECLVAL(T&&)));

 public:
  /// Takes ownership of `value`, forwarding it into the stored member.
  explicit ready_awaitable(T value) noexcept(nothrow_transfer_)
      : value_(static_cast<T&&>(value)) {}

  /// Always ready: the awaiting coroutine is never suspended on this object.
  bool await_ready() noexcept {
    return true;
  }

  /// No-op; present only to satisfy the awaiter interface.
  void await_suspend(coroutine_handle<>) noexcept {}

  /// Hands the stored value out by casting the member to `T&&`.
  T await_resume() noexcept(nothrow_transfer_) {
    return static_cast<T&&>(value_);
  }

 private:
  T value_;
};
/// Specialization for `void`: an awaitable that never suspends and produces
/// no value.
template <>
class ready_awaitable<void> {
 public:
  ready_awaitable() noexcept = default;

  /// Always ready: the awaiting coroutine is never suspended on this object.
  bool await_ready() noexcept {
    return true;
  }

  /// No-op; present only to satisfy the awaiter interface.
  void await_suspend(coroutine_handle<>) noexcept {
  }

  /// Nothing to return.
  void await_resume() noexcept {
  }
};
namespace detail {

/// Function object that invokes `a.await_suspend(coro)` and normalizes its
/// result to a single type-erased `coroutine_handle<>`.
///
/// The three return shapes of `await_suspend` are mapped as follows:
///   - `void`  : the call is made and `noop_coroutine()` is returned.
///   - `bool`  : `true` yields `noop_coroutine()`; `false` yields `coro`.
///   - other   : the returned value is converted to `coroutine_handle<>`.
///
/// The call is `noexcept` exactly when `a.await_suspend(coro)` is.
struct await_suspend_return_coroutine_fn {
  template <typename A, typename P>
  coroutine_handle<> operator()(A& a, coroutine_handle<P> coro) const
      noexcept(noexcept(a.await_suspend(coro))) {
    using result = decltype(a.await_suspend(coro));
    if constexpr (std::is_same<void, result>::value) {
      a.await_suspend(coro);
      return noop_coroutine();
    } else if constexpr (std::is_same<bool, result>::value) {
      // Deliberately two return statements rather than `?:`. The operands
      // would be `noop_coroutine_handle` and `coroutine_handle<P>`, which are
      // unrelated class types with no conversion to one another, so a
      // conditional expression over them is ill-formed unless P is void.
      // Each return statement converts to `coroutine_handle<>` on its own.
      if (a.await_suspend(coro)) {
        return noop_coroutine();
      }
      return coro;
    } else {
      return a.await_suspend(coro);
    }
  }
};

/// Ready-to-use instance of `await_suspend_return_coroutine_fn`.
inline constexpr await_suspend_return_coroutine_fn
    await_suspend_return_coroutine{};

} // namespace detail
#if __has_include(<variant>)
template <typename... A>
class variant_awaitable : private std::variant<A...> {
private:
using base = std::variant<A...>;
template <typename Visitor>
auto visit(Visitor v) {
return std::visit(v, static_cast<base&>(*this));
}
public:
using base::base;
auto await_ready() noexcept(
(noexcept(FOLLY_DECLVAL(A&).await_ready()) && ...)) {
return visit([&](auto& a) { return a.await_ready(); });
}
template <typename P>
auto await_suspend(coroutine_handle<P> coro) noexcept(
(noexcept(FOLLY_DECLVAL(A&).await_suspend(coro)) && ...)) {
auto impl = detail::await_suspend_return_coroutine;
return visit([&](auto& a) { return impl(a, coro); });
}
auto await_resume() noexcept(
(noexcept(FOLLY_DECLVAL(A&).await_resume()) && ...)) {
return visit([&](auto& a) { return a.await_resume(); });
}
};
#endif
namespace detail {
// Probe coroutine type used to observe, at run time, whether the object
// returned by promise_type::get_return_object() is converted to the
// coroutine's declared return type while the promise is still alive.
//
// The promise and its return_object hold raw pointers to each other, and each
// one nulls the other's pointer from its own destructor. The converting
// constructor below therefore sees a non-null `promise` only if the promise
// has not yet been destroyed at the moment of conversion.
struct detect_promise_return_object_eager_conversion_ {
struct promise_type {
// Intermediate object handed out by get_return_object(); later converted to
// the enclosing probe type.
struct return_object {
// Implicit on purpose: get_return_object() builds it with `return {*this}`.
// Links both directions: remembers the promise and registers itself with it.
return_object(promise_type& p) noexcept : promise{&p} {
promise->object = this;
}
// Unregisters from the promise, but only if the promise is still alive
// (the promise's destructor nulls `promise` when it goes first).
~return_object() {
if (promise) {
promise->object = nullptr;
}
}
// Null once the promise has been destroyed.
promise_type* promise;
};
// Tells a still-living return_object that its promise is gone.
~promise_type() {
if (object) {
object->promise = nullptr;
}
}
// Neither suspend point suspends, so the body of go() runs straight through
// without ever handing control back at a suspension.
suspend_never initial_suspend() const noexcept { return {}; }
suspend_never final_suspend() const noexcept { return {}; }
void unhandled_exception() {}
return_object get_return_object() noexcept { return {*this}; }
void return_void() {}
// Null until a return_object registers itself, and again after it is
// destroyed.
return_object* object = nullptr;
};
// Implicit conversion from the intermediate return_object. Records whether
// the promise was still alive (non-null) when the conversion ran.
detect_promise_return_object_eager_conversion_(
promise_type::return_object const& o) noexcept
: eager{!!o.promise} {}
// NOTE(review): user-provided empty destructor makes this type
// non-trivially-destructible. Presumably deliberate (possibly a compiler
// workaround) -- confirm before replacing with `= default`.
~detect_promise_return_object_eager_conversion_() {}
// True when the conversion above observed a live promise.
bool eager = false;
// The probe coroutine itself: an empty body that immediately co_returns.
// The warning push/pop brackets the co_return so that, on clang 14 through
// 16, -Wdeprecated-experimental-coroutine is disabled for it.
static detect_promise_return_object_eager_conversion_ go() noexcept {
FOLLY_PUSH_WARNING
#if defined(__clang__) && (__clang_major__ < 17 && __clang_major__ > 13)
FOLLY_CLANG_DISABLE_WARNING("-Wdeprecated-experimental-coroutine")
#endif
co_return;
FOLLY_POP_WARNING
}
};
}
inline bool detect_promise_return_object_eager_conversion() {
using coro = detail::detect_promise_return_object_eager_conversion_;
constexpr auto t = kMscVer && kMscVer < 1925;
constexpr auto f = (kGnuc && !kIsClang) || (kMscVer >= 1925);
return t ? true : f ? false : coro::go().eager;
}
class ExtendedCoroutineHandle;

/// Interface a promise type can implement to expose more than a plain
/// `coroutine_handle<>` can carry.
class ExtendedCoroutinePromise {
 public:
  /// Returns the type-erased handle of the coroutine owning this promise.
  virtual coroutine_handle<> getHandle() = 0;

  /// Returns a handle together with an `AsyncStackFrame*` for the given
  /// exception.
  /// NOTE(review): presumably the handle to resume while `ex` is being
  /// propagated -- confirm against the implementers.
  virtual std::pair<ExtendedCoroutineHandle, AsyncStackFrame*> getErrorHandle(
      exception_wrapper& ex) = 0;

 protected:
  // Protected and non-virtual: objects cannot be destroyed through a pointer
  // to this interface from outside the hierarchy.
  ~ExtendedCoroutinePromise() = default;
};
/// A type-erased `coroutine_handle<>` paired with an optional pointer to the
/// coroutine's `ExtendedCoroutinePromise`, when the promise implements that
/// interface.
class ExtendedCoroutineHandle {
 public:
  /// Implicit conversion from a typed handle. The extended pointer is filled
  /// in only when `Promise*` converts to `ExtendedCoroutinePromise*`.
  template <typename Promise>
  ExtendedCoroutineHandle(coroutine_handle<Promise> handle) noexcept
      : basic_(handle), extended_(fromBasic(handle)) {}

  /// Implicit conversion from a type-erased handle; no extended promise.
  ExtendedCoroutineHandle(coroutine_handle<> handle) noexcept
      : basic_(handle) {}

  /// Implicit conversion from an extended promise. `ptr` is dereferenced
  /// unconditionally, so it must not be null.
  ExtendedCoroutineHandle(ExtendedCoroutinePromise* ptr) noexcept
      : basic_(ptr->getHandle()), extended_(ptr) {}

  /// Empty handle: default-constructed basic handle, no extended promise.
  ExtendedCoroutineHandle() noexcept = default;

  /// Resumes the underlying coroutine.
  void resume() {
    basic_.resume();
  }

  /// Destroys the underlying coroutine.
  void destroy() {
    basic_.destroy();
  }

  /// The type-erased handle.
  coroutine_handle<> getHandle() const noexcept {
    return basic_;
  }

  /// The extended promise, or null if there is none.
  ExtendedCoroutinePromise* getPromise() const noexcept {
    return extended_;
  }

  /// Delegates to the extended promise when there is one; otherwise returns
  /// the basic handle with a null `AsyncStackFrame*`.
  std::pair<ExtendedCoroutineHandle, AsyncStackFrame*> getErrorHandle(
      exception_wrapper& ex) {
    if (extended_ == nullptr) {
      return {basic_, nullptr};
    }
    return extended_->getErrorHandle(ex);
  }

  /// True when the basic handle is non-empty.
  explicit operator bool() const noexcept {
    return static_cast<bool>(basic_);
  }

 private:
  // Extracts the extended-promise pointer from a typed handle, or null when
  // the promise type does not convert to ExtendedCoroutinePromise*.
  template <typename Promise>
  static ExtendedCoroutinePromise* fromBasic(
      coroutine_handle<Promise> handle) noexcept {
    if constexpr (std::is_convertible_v<Promise*, ExtendedCoroutinePromise*>) {
      return static_cast<ExtendedCoroutinePromise*>(&handle.promise());
    } else {
      return nullptr;
    }
  }

  coroutine_handle<> basic_;
  ExtendedCoroutinePromise* extended_{nullptr};
};
template <typename Promise>
class ExtendedCoroutinePromiseImpl : public ExtendedCoroutinePromise {
public:
coroutine_handle<> getHandle() final {
return coroutine_handle<Promise>::from_promise(
*static_cast<Promise*>(this));
}
std::pair<ExtendedCoroutineHandle, AsyncStackFrame*> getErrorHandle(
exception_wrapper&) override {
return {getHandle(), nullptr};
}
protected:
~ExtendedCoroutinePromiseImpl() = default;
};
}
#endif