/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cassert>
#include <exception>
#include <type_traits>
#include <utility>
#include <folly/experimental/coro/Coroutine.h>
#include <folly/experimental/coro/Invoke.h>
#include <folly/lang/Exception.h>
#if FOLLY_HAS_COROUTINES
namespace folly {
namespace coro {
template <typename T>
class Generator {
public:
// Promise object backing a Generator<T> coroutine.
//
// It stores a pointer to the most recently yielded value and any exception
// that escaped the coroutine body. It also carries two links, m_root and
// m_parentOrLeaf, which yield_value(Generator&) rewires when one generator is
// yielded from inside another. Through those links the root promise resumes
// the innermost active generator in pull() and reads its value in value().
class promise_type final {
public:
// A fresh promise has no value and no exception, and is both its own root
// and its own leaf.
promise_type() noexcept
: m_value(nullptr),
m_exception(nullptr),
m_root(this),
m_parentOrLeaf(this) {}
// Promises are neither copyable nor movable.
promise_type(const promise_type&) = delete;
promise_type(promise_type&&) = delete;
// Creates the Generator handed back to the caller of the coroutine.
auto get_return_object() noexcept { return Generator<T>{*this}; }
// The coroutine starts suspended; its body only runs once it is resumed.
suspend_always initial_suspend() noexcept { return {}; }
// The coroutine also suspends at its final suspend point instead of
// destroying itself; the frame is released through destroy().
suspend_always final_suspend() noexcept { return {}; }
// Stores the exception that escaped the coroutine body so that
// throw_if_exception() can rethrow it later.
void unhandled_exception() noexcept { m_exception = current_exception(); }
void return_void() noexcept {}
// co_yield of an lvalue: remembers the address of `value` (no copy is made)
// and suspends the coroutine.
suspend_always yield_value(T& value) noexcept {
m_value = std::addressof(value);
return {};
}
// co_yield of an rvalue: same as above, the address of the referenced object
// is stored and the coroutine suspends.
suspend_always yield_value(T&& value) noexcept {
m_value = std::addressof(value);
return {};
}
// co_yield of an rvalue Generator: `generator` is a named parameter and thus
// an lvalue here, so this forwards to the Generator& overload below.
auto yield_value(Generator&& generator) noexcept {
return yield_value(generator);
}
// co_yield of a nested Generator: links the child's promise into the chain
// rooted at m_root, runs the child up to its first suspension, and returns an
// awaitable that decides whether this (parent) coroutine has to suspend.
auto yield_value(Generator& generator) noexcept {
// Awaitable returned to the co_yield expression.
// - Constructed with nullptr: ready immediately, the parent keeps running.
// - Constructed with a child promise: the parent suspends; when it is
//   resumed again, await_resume() rethrows any exception stored in the
//   child.
struct awaitable {
awaitable(promise_type* childPromise) : m_childPromise(childPromise) {}
bool await_ready() noexcept { return this->m_childPromise == nullptr; }
// Nothing to do on suspension; returning void leaves the parent suspended.
void await_suspend(coroutine_handle<promise_type>) noexcept {}
void await_resume() {
if (this->m_childPromise != nullptr) {
this->m_childPromise->throw_if_exception();
}
}
private:
promise_type* m_childPromise;
};
if (generator.m_promise != nullptr) {
// Make the child the leaf of the root, and point the child at the shared
// root and at this promise as its parent, before resuming the child.
m_root->m_parentOrLeaf = generator.m_promise;
generator.m_promise->m_root = m_root;
generator.m_promise->m_parentOrLeaf = this;
generator.m_promise->resume();
// NB: This branch looks like a (premature?) optimization for empty
// generators, and until proven otherwise in benchmarks, it may be
// advantageous to simply return `awaitable{generator.m_promise}`.
if (!generator.m_promise->is_complete() ||
generator.m_promise->m_exception != nullptr) {
return awaitable{generator.m_promise};
}
// The child ran to completion without suspending and without storing an
// exception: restore this promise as the root's leaf and do not suspend.
m_root->m_parentOrLeaf = this;
}
return awaitable{nullptr};
}
// Don't allow any use of 'co_await' inside the Generator
// coroutine.
template <typename U>
void await_transform(U&& value) = delete;
// Destroys the coroutine associated with this promise.
void destroy() noexcept {
coroutine_handle<promise_type>::from_promise(*this).destroy();
}
// Rethrows the exception captured by unhandled_exception(), if there is one.
void throw_if_exception() {
if (m_exception != nullptr) {
std::rethrow_exception(std::move(m_exception));
}
}
// Reports whether the coroutine handle for this promise is done().
bool is_complete() noexcept {
return coroutine_handle<promise_type>::from_promise(*this).done();
}
// Returns the value most recently yielded by the current leaf generator.
// Must be called on the root promise while the root is not complete (both
// conditions are asserted).
T& value() noexcept {
assert(this == m_root);
assert(!is_complete());
return *(m_parentOrLeaf->m_value);
}
// Advances the generator chain by one step. Must be called on the root
// promise while the current leaf is not complete (both are asserted).
void pull() noexcept {
assert(this == m_root);
assert(!m_parentOrLeaf->is_complete());
m_parentOrLeaf->resume();
// While the leaf has finished and is not the root itself, step back to the
// finished leaf's parent (a non-root promise stores its parent in
// m_parentOrLeaf) and resume that parent.
while (m_parentOrLeaf != this && m_parentOrLeaf->is_complete()) {
m_parentOrLeaf = m_parentOrLeaf->m_parentOrLeaf;
m_parentOrLeaf->resume();
}
}
private:
// Resumes the coroutine associated with this promise.
void resume() noexcept {
coroutine_handle<promise_type>::from_promise(*this).resume();
}
// Address of the object passed to the most recent value co_yield; nullptr
// until the first one.
std::add_pointer_t<T> m_value;
// Exception captured from the coroutine body, or null.
std::exception_ptr m_exception;
// Root promise of the chain this promise belongs to; initially `this`.
promise_type* m_root;
// If this is the promise of the root generator then this field
// is a pointer to the leaf promise.
// For non-root generators this is a pointer to the parent promise.
promise_type* m_parentOrLeaf;
};
Generator() noexcept : m_promise(nullptr) {}
Generator(promise_type& promise) noexcept : m_promise(&promise) {}
Generator(Generator&& other) noexcept : m_promise(other.m_promise) {
other.m_promise = nullptr;
}
Generator(const Generator& other) = delete;
Generator& operator=(const Generator& other) = delete;
~Generator() {
if (m_promise != nullptr) {
m_promise->destroy();
}
}
Generator& operator=(Generator&& other) noexcept {
if (this != &other) {
if (m_promise != nullptr) {
m_promise->destroy();
}
m_promise = other.m_promise;
other.m_promise = nullptr;
}
return *this;
}
class iterator {
public:
using iterator_category = std::input_iterator_tag;
// What type should we use for counting elements of a potentially infinite
// sequence?
using difference_type = std::ptrdiff_t;
using value_type = std::remove_reference_t<T>;
using reference = std::conditional_t<std::is_reference_v<T>, T, T&>;
using pointer = std::add_pointer_t<T>;
iterator() noexcept : m_promise(nullptr) {}
explicit iterator(promise_type* promise) noexcept : m_promise(promise) {}
bool operator==(const iterator& other) const noexcept {
return m_promise == other.m_promise;
}
bool operator!=(const iterator& other) const noexcept {
return m_promise != other.m_promise;
}
iterator& operator++() {
assert(m_promise != nullptr);
assert(!m_promise->is_complete());
m_promise->pull();
if (m_promise->is_complete()) {
auto* temp = m_promise;
m_promise = nullptr;
temp->throw_if_exception();
}
return *this;
}
void operator++(int) { (void)operator++(); }
reference operator*() const noexcept {
assert(m_promise != nullptr);
return static_cast<reference>(m_promise->value());
}
pointer operator->() const noexcept { return std::addressof(operator*()); }
private:
promise_type* m_promise;
};
// Runs the coroutine up to its first yielded value and returns an iterator to
// it. Returns the end sentinel when the generator is empty (no coroutine
// bound) or finishes without yielding; in the latter case an exception stored
// by the coroutine body is rethrown here.
iterator begin() {
  if (m_promise == nullptr) {
    return end();
  }
  m_promise->pull();
  if (m_promise->is_complete()) {
    m_promise->throw_if_exception();
    return end();
  }
  return iterator(m_promise);
}

// End sentinel: an iterator that holds no promise.
iterator end() noexcept {
  return iterator(nullptr);
}
// Exchanges the coroutines held by this generator and `other`.
void swap(Generator& other) noexcept {
  other.m_promise = std::exchange(m_promise, other.m_promise);
}
// tag_invoke overload selected by the co_invoke_fn tag for a Generator result.
// NOTE(review): presumably this is what co_invoke dispatches to when the
// callable returns this Generator type; confirm against Invoke.h.
//
// The callable `f` and the arguments `a...` are taken by value. The function
// is itself a coroutine: it invokes `f` with the arguments cast back to the
// original F / A... types and then re-yields, by move, every element of the
// range that the call returned.
template <typename F, typename... A, typename F_, typename... A_>
friend Generator tag_invoke(
tag_t<co_invoke_fn>, tag_t<Generator, F, A...>, F_ f, A_... a) {
auto&& r = invoke(static_cast<F&&>(f), static_cast<A&&>(a)...);
for (auto&& v : r) {
co_yield std::move(v);
}
}
private:
// promise_type reads and links m_promise of a nested generator passed to
// co_yield.
friend class promise_type;
// Promise of the coroutine held by this generator; nullptr when the generator
// is default-constructed or has been moved from.
promise_type* m_promise;
};
template <typename T>
void swap(Generator<T>& a, Generator<T>& b) noexcept {
a.swap(b);
}
} // namespace coro
} // namespace folly
#endif // FOLLY_HAS_COROUTINES