/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <folly/Try.h>
#include <folly/executors/ManualExecutor.h>
#include <folly/experimental/coro/Coroutine.h>
#include <folly/experimental/coro/Task.h>
#include <folly/experimental/coro/Traits.h>
#include <folly/experimental/coro/ViaIfAsync.h>
#include <folly/experimental/coro/WithAsyncStack.h>
#include <folly/experimental/coro/detail/Malloc.h>
#include <folly/experimental/coro/detail/Traits.h>
#include <folly/fibers/Baton.h>
#include <folly/synchronization/Baton.h>
#include <folly/tracing/AsyncStack.h>
#include <cassert>
#include <exception>
#include <type_traits>
#include <utility>
#if FOLLY_HAS_COROUTINES
namespace folly {
namespace coro {
namespace detail {
template <typename T>
class BlockingWaitTask;
class BlockingWaitPromiseBase {
struct FinalAwaiter {
bool await_ready() noexcept { return false; }
template <typename Promise>
void await_suspend(coroutine_handle<Promise> coro) noexcept {
BlockingWaitPromiseBase& promise = coro.promise();
folly::deactivateAsyncStackFrame(promise.getAsyncFrame());
promise.baton_.post();
}
void await_resume() noexcept {}
};
public:
BlockingWaitPromiseBase() noexcept = default;
static void* operator new(std::size_t size) {
return ::folly_coro_async_malloc(size);
}
static void operator delete(void* ptr, std::size_t size) {
::folly_coro_async_free(ptr, size);
}
suspend_always initial_suspend() { return {}; }
FinalAwaiter final_suspend() noexcept { return {}; }
template <typename Awaitable>
decltype(auto) await_transform(Awaitable&& awaitable) {
return folly::coro::co_withAsyncStack(static_cast<Awaitable&&>(awaitable));
}
bool done() const noexcept { return baton_.ready(); }
void wait() noexcept { baton_.wait(); }
folly::AsyncStackFrame& getAsyncFrame() noexcept { return asyncFrame_; }
private:
folly::fibers::Baton baton_;
folly::AsyncStackFrame asyncFrame_;
};
/// Promise for a BlockingWaitTask<T> whose coroutine produces a non-reference
/// result via co_return.
///
/// The outcome is written through the pointer installed by setTry(). It is
/// either the returned value or the escaping exception. get()/getVia() install
/// that pointer before the coroutine is first resumed.
template <typename T>
class BlockingWaitPromise final : public BlockingWaitPromiseBase {
 public:
  BlockingWaitPromise() noexcept = default;
  ~BlockingWaitPromise() = default;

  BlockingWaitTask<T> get_return_object() noexcept;

  /// Records an exception that escaped the coroutine body.
  void unhandled_exception() noexcept {
    result_->emplaceException(folly::exception_wrapper{current_exception()});
  }

  /// Stores the co_returned value in the result slot.
  template <
      typename U = T,
      std::enable_if_t<std::is_convertible<U, T>::value, int> = 0>
  void return_value(U&& value) noexcept(
      std::is_nothrow_constructible<T, U&&>::value) {
    result_->emplace(static_cast<U&&>(value));
  }

  /// Installs the caller-owned slot that receives the coroutine's outcome.
  /// `result` must stay valid until the coroutine has completed, because
  /// return_value()/unhandled_exception() write through it.
  void setTry(folly::Try<T>* result) noexcept {
    // Keep the caller's pointer. Taking the address of this parameter would
    // yield a dangling Try<T>** instead.
    result_ = result;
  }

 private:
  // Non-owning pointer. It is null until setTry() is called.
  folly::Try<T>* result_ = nullptr;
};
/// Promise for BlockingWaitTask<T&>.
///
/// A reference result is delivered with co_yield rather than co_return.
/// yield_value() records the reference, then returns the same awaiter as
/// final_suspend(). The coroutine therefore suspends at the co_yield and
/// signals completion there, so its body never reaches the implicit co_return.
template <typename T>
class BlockingWaitPromise<T&> final : public BlockingWaitPromiseBase {
 public:
  BlockingWaitPromise() noexcept = default;
  ~BlockingWaitPromise() = default;

  BlockingWaitTask<T&> get_return_object() noexcept;

  /// Records an exception that escaped the coroutine body.
  void unhandled_exception() noexcept {
    result_->emplaceException(folly::exception_wrapper{current_exception()});
  }

  /// Stores a reference to `value` and completes via the final awaiter.
  auto yield_value(T& value) noexcept {
    result_->emplace(std::ref(value));
    return final_suspend();
  }

  /// Rvalue overload. Inside this function `value` is an lvalue, so the call
  /// resolves to the T& overload above.
  auto yield_value(T&& value) noexcept { return yield_value(value); }

  // Disabled because a promise type may not declare both return_value() and
  // return_void(), and this promise declares return_void() below.
#if 0
  void return_value(T& value) noexcept {
    result_->emplace(std::ref(value));
  }
#endif

  void return_void() {
    // Unreachable by design. The coroutine either suspended at its co_yield
    // or threw, in which case control went to unhandled_exception() and
    // skipped the implicit co_return.
    std::abort();
  }

  /// Installs the caller-owned slot that receives the coroutine's outcome.
  void setTry(folly::Try<std::reference_wrapper<T>>* result) noexcept {
    result_ = result;
  }

 private:
  folly::Try<std::reference_wrapper<T>>* result_;
};
/// Promise for BlockingWaitTask<void>. The only outcome it records is an
/// exception that escaped the coroutine body.
template <>
class BlockingWaitPromise<void> final : public BlockingWaitPromiseBase {
 public:
  BlockingWaitPromise() = default;

  BlockingWaitTask<void> get_return_object() noexcept;

  /// Installs the caller-owned slot that receives the coroutine's outcome.
  void setTry(folly::Try<void>* result) noexcept { result_ = result; }

  /// Normal completion writes nothing to the result slot.
  void return_void() noexcept {}

  /// Records an exception that escaped the coroutine body.
  void unhandled_exception() noexcept {
    result_->emplaceException(folly::exception_wrapper{current_exception()});
  }

 private:
  folly::Try<void>* result_;
};
/// Owning handle to a coroutine whose promise is BlockingWaitPromise<T>.
///
/// The coroutine starts suspended, because initial_suspend() returns
/// suspend_always. It only begins running when get() or getVia() is invoked
/// on an rvalue task; both are &&-qualified. The destructor destroys the
/// coroutine frame if this task still owns it.
template <typename T>
class BlockingWaitTask {
 public:
  using promise_type = BlockingWaitPromise<T>;
  using handle_t = coroutine_handle<promise_type>;
  // Takes ownership of `coro`. This is constructed by get_return_object().
  explicit BlockingWaitTask(handle_t coro) noexcept : coro_(coro) {}
  // Transfers ownership. `other` is left empty, so its destructor does nothing.
  BlockingWaitTask(BlockingWaitTask&& other) noexcept
      : coro_(std::exchange(other.coro_, {})) {}
  BlockingWaitTask& operator=(BlockingWaitTask&& other) noexcept = delete;
  ~BlockingWaitTask() {
    if (coro_) {
      coro_.destroy();
    }
  }
  /// Starts the coroutine on the calling thread, blocks until it completes,
  /// and returns its result.
  ///
  /// The coroutine is first resumed inline. If it suspends and completes
  /// elsewhere, this thread blocks in promise.wait() until the final awaiter
  /// posts the baton. An exception from the coroutine is rethrown by
  /// folly::Try::value().
  ///
  /// NOTE(review): FOLLY_NOINLINE presumably keeps setReturnAddress() (which
  /// records a return address) pointing at a genuine caller frame. Confirm
  /// before restructuring.
  FOLLY_NOINLINE T get(folly::AsyncStackFrame& parentFrame) && {
    // The outcome slot lives on this stack frame. It stays valid because we
    // do not return until the coroutine has signalled completion.
    folly::Try<detail::lift_lvalue_reference_t<T>> result;
    auto& promise = coro_.promise();
    promise.setTry(&result);
    // Parent the coroutine's async frame to the caller-provided frame.
    auto& asyncFrame = promise.getAsyncFrame();
    asyncFrame.setParentFrame(parentFrame);
    asyncFrame.setReturnAddress();
    {
      // Save the current RequestContext for the duration of the inline
      // resume. The guard presumably restores it when this scope exits.
      RequestContextScopeGuard guard{RequestContext::saveContext()};
      folly::resumeCoroutineWithNewAsyncStackRoot(coro_);
    }
    promise.wait();
    return std::move(result).value();
  }
  /// Like get(), but the first resume is scheduled onto `executor`. The
  /// calling thread then drives `executor` until the coroutine has posted
  /// its completion baton (promise.done()).
  FOLLY_NOINLINE T getVia(
      folly::DrivableExecutor* executor,
      folly::AsyncStackFrame& parentFrame) && {
    folly::Try<detail::lift_lvalue_reference_t<T>> result;
    auto& promise = coro_.promise();
    promise.setTry(&result);
    auto& asyncFrame = promise.getAsyncFrame();
    asyncFrame.setReturnAddress();
    asyncFrame.setParentFrame(parentFrame);
    // The lambda copies the raw handle; ownership stays with this task. The
    // caller's RequestContext is captured here and re-established around
    // the resume on the executor.
    executor->add(
        [coro = coro_, rctx = RequestContext::saveContext()]() mutable {
          RequestContextScopeGuard guard{std::move(rctx)};
          folly::resumeCoroutineWithNewAsyncStackRoot(coro);
        });
    while (!promise.done()) {
      executor->drive();
    }
    return std::move(result).value();
  }
 private:
  handle_t coro_;
};
/// Wraps this promise's coroutine handle in the BlockingWaitTask that owns it.
template <typename T>
inline BlockingWaitTask<T>
BlockingWaitPromise<T>::get_return_object() noexcept {
  auto self = coroutine_handle<BlockingWaitPromise<T>>::from_promise(*this);
  return BlockingWaitTask<T>{self};
}
/// Wraps this promise's coroutine handle in the BlockingWaitTask that owns it.
template <typename T>
inline BlockingWaitTask<T&>
BlockingWaitPromise<T&>::get_return_object() noexcept {
  auto self = coroutine_handle<BlockingWaitPromise<T&>>::from_promise(*this);
  return BlockingWaitTask<T&>{self};
}
/// Wraps this promise's coroutine handle in the BlockingWaitTask that owns it.
inline BlockingWaitTask<void>
BlockingWaitPromise<void>::get_return_object() noexcept {
  auto self = coroutine_handle<BlockingWaitPromise<void>>::from_promise(*this);
  return BlockingWaitTask<void>{self};
}
/// Wraps `awaitable` in a BlockingWaitTask whose result type is the awaited
/// result with rvalue references decayed. This overload handles results that
/// are not lvalue references (including void); they are handed back with
/// co_return.
template <
    typename Awaitable,
    typename Result = await_result_t<Awaitable>,
    std::enable_if_t<!std::is_lvalue_reference_v<Result>, int> = 0>
auto makeBlockingWaitTask(Awaitable&& awaitable)
    -> BlockingWaitTask<detail::decay_rvalue_reference_t<Result>> {
  co_return co_await std::forward<Awaitable>(awaitable);
}

/// Overload for lvalue-reference results. BlockingWaitPromise<T&> has no
/// return_value(), so the reference is handed back with co_yield.
template <
    typename Awaitable,
    typename Result = await_result_t<Awaitable>,
    std::enable_if_t<std::is_lvalue_reference_v<Result>, int> = 0>
auto makeBlockingWaitTask(Awaitable&& awaitable)
    -> BlockingWaitTask<detail::decay_rvalue_reference_t<Result>> {
  co_yield co_await std::forward<Awaitable>(awaitable);
}
/// Builds the BlockingWaitTask used by blocking_wait() for a void result.
/// The only outcome is completion, or an exception.
template <
    typename Awaitable,
    typename Result = await_result_t<Awaitable>,
    std::enable_if_t<std::is_void_v<Result>, int> = 0>
BlockingWaitTask<void> makeRefBlockingWaitTask(Awaitable&& awaitable) {
  co_await std::forward<Awaitable>(awaitable);
}

/// Builds the BlockingWaitTask used by blocking_wait() for a non-void result.
/// The result is always yielded as an lvalue reference (Result&), and the
/// caller decides whether to move from it. A temporary produced by the
/// co_await lives until the end of the full-expression. That full-expression
/// contains the co_yield suspension, so the reference stays valid while the
/// coroutine is suspended there.
template <
    typename Awaitable,
    typename Result = await_result_t<Awaitable>,
    std::enable_if_t<!std::is_void_v<Result>, int> = 0>
auto makeRefBlockingWaitTask(Awaitable&& awaitable)
    -> BlockingWaitTask<std::add_lvalue_reference_t<Result>> {
  co_yield co_await std::forward<Awaitable>(awaitable);
}
/// DrivableExecutor that blocking_wait() creates when the caller supplies no
/// executor. Callbacks queued with add() run on whichever thread calls drive().
class BlockingWaitExecutor final : public folly::DrivableExecutor {
 public:
  // Keep driving until every keep-alive token has been released. Queued
  // callbacks, including the deferred final release in keepAliveRelease(),
  // therefore get to run before the executor is torn down.
  ~BlockingWaitExecutor() override {
    while (keepAliveCount_.load() > 0) {
      drive();
    }
  }
  // Enqueues `func`. Only the add() that moves the queue from empty to
  // non-empty posts the baton. Later adds rely on that pending post, or on a
  // drive() that has reset the baton but not yet swapped the queue out.
  void add(Func func) override {
    bool empty;
    {
      auto wQueue = queue_.wlock();
      empty = wQueue->empty();
      wQueue->push_back(std::move(func));
    }
    if (empty) {
      baton_.post();
    }
  }
  // Blocks until work has been posted, then runs every callback queued at the
  // time of the swap. Callbacks added while these run are left for the next
  // drive().
  // NOTE(review): runInMainContext presumably runs the lambda on the main
  // (non-fiber) stack when drive() is called from a fiber. Confirm.
  void drive() override {
    baton_.wait();
    baton_.reset();
    folly::fibers::runInMainContext([&]() {
      std::vector<Func> funcs;
      queue_.swap(funcs);
      for (auto& func : funcs) {
        // Move the callback out before invoking it. It is destroyed at the
        // end of this statement, right after it runs.
        std::exchange(func, nullptr)();
      }
    });
  }
 private:
  // Keep-alives are always granted (returns true); the count only grows here.
  bool keepAliveAcquire() noexcept override {
    auto keepAliveCount =
        keepAliveCount_.fetch_add(1, std::memory_order_relaxed);
    DCHECK(keepAliveCount >= 0);
    return true;
  }
  // Decrements the keep-alive count with a CAS loop. When this would release
  // the last token, the decrement is instead queued onto this executor. The
  // count can then only reach zero from inside drive().
  void keepAliveRelease() noexcept override {
    auto keepAliveCount = keepAliveCount_.load(std::memory_order_relaxed);
    do {
      DCHECK(keepAliveCount > 0);
      if (keepAliveCount == 1) {
        add([this] {
          // the final count *must* be released from this executor or else if we
          // are mid-destructor we have a data race
          keepAliveCount_.fetch_sub(1, std::memory_order_relaxed);
        });
        return;
      }
    } while (!keepAliveCount_.compare_exchange_weak(
        keepAliveCount,
        keepAliveCount - 1,
        std::memory_order_release,
        std::memory_order_relaxed));
  }
  // Callbacks waiting for the next drive().
  folly::Synchronized<std::vector<Func>> queue_;
  // Posted by add() when the queue goes from empty to non-empty.
  fibers::Baton baton_;
  // Outstanding keep-alive tokens. The destructor drives until this is 0.
  std::atomic<ssize_t> keepAliveCount_{0};
};
} // namespace detail
/// blocking_wait_fn
///
/// co_awaits the passed awaitable and blocks the current thread until the
/// await operation completes, then returns its result.
///
/// Useful for launching an asynchronous operation from the top-level main()
/// function or from unit-tests.
///
/// WARNING:
/// Avoid using this function within any code that might run on the thread
/// of an executor as this can potentially lead to deadlock if the operation
/// you are waiting on needs to do some work on that executor in order to
/// complete.
struct blocking_wait_fn {
  /// Blocks on an Awaitable. The coroutine is first resumed inline on the
  /// calling thread (BlockingWaitTask::get) and this thread then waits for
  /// it to finish.
  template <typename Awaitable>
  FOLLY_NOINLINE auto operator()(Awaitable&& awaitable) const
      -> detail::decay_rvalue_reference_t<await_result_t<Awaitable>> {
    // Synthetic async frame for this call. It becomes the parent of the
    // wrapper coroutine's frame in get().
    folly::AsyncStackFrame frame;
    frame.setReturnAddress();
    // Async stack root for this stack frame. It is chained to any enclosing
    // root, with `frame` as its top frame.
    folly::AsyncStackRoot stackRoot;
    stackRoot.setNextRoot(folly::tryGetCurrentAsyncStackRoot());
    stackRoot.setStackFrameContext();
    stackRoot.setTopFrame(frame);
    // makeRefBlockingWaitTask yields the result as Result& (or void). Casting
    // to Result&& moves from value and rvalue-reference results and leaves
    // lvalue-reference results as references.
    return static_cast<std::add_rvalue_reference_t<await_result_t<Awaitable>>>(
        detail::makeRefBlockingWaitTask(static_cast<Awaitable&&>(awaitable))
            .get(frame));
  }
  /// Blocks on a SemiAwaitable. It is bound to `executor` via co_viaIfAsync,
  /// its first resume is scheduled there, and the calling thread drives
  /// `executor` until completion (BlockingWaitTask::getVia).
  template <typename SemiAwaitable>
  FOLLY_NOINLINE auto operator()(
      SemiAwaitable&& awaitable, folly::DrivableExecutor* executor) const
      -> detail::decay_rvalue_reference_t<semi_await_result_t<SemiAwaitable>> {
    folly::AsyncStackFrame frame;
    frame.setReturnAddress();
    folly::AsyncStackRoot stackRoot;
    stackRoot.setNextRoot(folly::tryGetCurrentAsyncStackRoot());
    stackRoot.setStackFrameContext();
    stackRoot.setTopFrame(frame);
    return static_cast<
        std::add_rvalue_reference_t<semi_await_result_t<SemiAwaitable>>>(
        detail::makeRefBlockingWaitTask(
            folly::coro::co_viaIfAsync(
                folly::getKeepAliveToken(executor),
                static_cast<SemiAwaitable&&>(awaitable)))
            .getVia(executor, frame));
  }
  /// Blocks on a SemiAwaitable that is not directly awaitable when no
  /// executor is given. A temporary BlockingWaitExecutor is driven instead.
  template <
      typename SemiAwaitable,
      std::enable_if_t<!is_awaitable_v<SemiAwaitable>, int> = 0>
  auto operator()(SemiAwaitable&& awaitable) const
      -> detail::decay_rvalue_reference_t<semi_await_result_t<SemiAwaitable>> {
    std::exception_ptr eptr;
    {
      detail::BlockingWaitExecutor executor;
      try {
        return operator()(static_cast<SemiAwaitable&&>(awaitable), &executor);
      } catch (...) {
        // Capture the exception now and rethrow it below. This lets
        // `executor`'s destructor drain outstanding work before the
        // exception propagates.
        eptr = current_exception();
      }
    }
    std::rethrow_exception(eptr);
  }
};
/// Function object that blocks the current thread on an awaitable; see
/// blocking_wait_fn.
inline constexpr blocking_wait_fn blocking_wait{};
/// Backwards-compatible alias for blocking_wait. It is declared `inline`
/// rather than `static`, so every translation unit including this header
/// shares one entity. With `static`, each TU gets its own internal-linkage
/// copy.
inline constexpr blocking_wait_fn const& blockingWait =
    blocking_wait; // backcompat
} // namespace coro
} // namespace folly
#endif // FOLLY_HAS_COROUTINES