/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <folly/Portability.h>
#include <folly/CancellationToken.h>
#include <folly/ScopeGuard.h>
#include <folly/executors/CPUThreadPoolExecutor.h>
#include <folly/experimental/coro/AsyncGenerator.h>
#include <folly/experimental/coro/BlockingWait.h>
#include <folly/experimental/coro/Collect.h>
#include <folly/experimental/coro/CurrentExecutor.h>
#include <folly/experimental/coro/Invoke.h>
#include <folly/experimental/coro/Merge.h>
#include <folly/experimental/coro/Task.h>
#include <folly/portability/GTest.h>
#if FOLLY_HAS_COROUTINES
// Bring the coroutine primitives used unqualified below (Task, AsyncGenerator,
// merge, blockingWait, ...) and the chrono literal suffixes (e.g. `1s`) into
// this test file's scope.
using namespace folly::coro;
using namespace std::chrono_literals;
// Empty fixture: it adds no state or set-up and only gives the TEST_F cases
// in this file a common suite name.
class MergeTest : public testing::Test {};
// Merges two finite inner streams on the current executor and checks that
// their values arrive in the expected interleaved order, followed by
// end-of-stream.
TEST_F(MergeTest, SimpleMerge) {
  blockingWait([]() -> Task<void> {
    // Outer stream: hands out two inner streams covering [0, 3) and [3, 5).
    auto makeSources = []() -> AsyncGenerator<AsyncGenerator<int>> {
      // Inner stream: yields `count` consecutive ints starting at `start`,
      // rescheduling on the current executor after each value.
      auto makeGenerator = [](int start, int count) -> AsyncGenerator<int> {
        const int end = start + count;
        int i = start;
        while (i < end) {
          co_yield i;
          co_await co_reschedule_on_current_executor;
          ++i;
        }
      };
      co_yield makeGenerator(0, 3);
      co_yield makeGenerator(3, 2);
    };
    auto generator = merge(co_await co_current_executor, makeSources());

    // The two sources are expected to alternate until the shorter one ends.
    const std::array<int, 5> expectedValues = {{0, 3, 1, 4, 2}};
    for (int expectedValue : expectedValues) {
      auto item = co_await generator.next();
      CHECK(!!item);
      CHECK_EQ(expectedValue, *item);
    }

    // After the five expected values the merged stream must report
    // completion.
    auto item = co_await generator.next();
    CHECK(!item);
  }());
}
// Abandons a merged stream after consuming only part of it and verifies that
// every inner generator that was started has been torn down (its SCOPE_EXIT
// has run) once the executor has had a chance to process the cancellation.
TEST_F(MergeTest, TruncateStream) {
  blockingWait([]() -> Task<void> {
    int started = 0; // number of inner generator bodies entered
    int completed = 0; // number of inner generator bodies exited
    {
      auto generator = merge(
          co_await co_current_executor,
          co_invoke([&]() -> AsyncGenerator<AsyncGenerator<int>> {
            // Each inner stream would yield 1 and then 2, with a reschedule
            // in between, and records its start/exit in the counters above.
            auto makeGenerator = [&]() -> AsyncGenerator<int> {
              ++started;
              SCOPE_EXIT {
                ++completed;
              };
              co_yield 1;
              co_await co_reschedule_on_current_executor;
              co_yield 2;
            };
            co_yield co_invoke(makeGenerator);
            co_yield co_invoke(makeGenerator);
            co_yield co_invoke(makeGenerator);
          }));

      // Verify a value is present before dereferencing, so that an
      // unexpectedly early end-of-stream is reported as a CHECK failure
      // rather than an invalid dereference.
      auto item = co_await generator.next();
      CHECK(!!item);
      CHECK_EQ(1, *item);
      item = co_await generator.next();
      CHECK(!!item);
      CHECK_EQ(1, *item);
      CHECK_EQ(3, started);
      // Truncate the stream after consuming only 2 of the 6 values it
      // would have produced.
    }

    // Spin the executor until the generators finish responding to
    // cancellation. The spin is bounded so a regression fails the CHECK below
    // instead of hanging the test.
    for (int i = 0; completed != started && i < 10; ++i) {
      co_await co_reschedule_on_current_executor;
    }

    CHECK_EQ(3, completed);
  }());
}
// Same scenario as a truncated merge, but the sources run on the global CPU
// executor: after the merged generator is dropped, all three inner generators
// must exit (run their SCOPE_EXIT) within one second.
TEST_F(MergeTest, TruncateStreamMultiThreaded) {
  blockingWait([]() -> Task<void> {
    std::atomic<int> completed = 0; // inner generator bodies that have exited
    folly::Baton allCompleted; // posted by the third generator to exit
    {
      auto generator = merge(
          folly::getGlobalCPUExecutor(),
          co_invoke([&]() -> AsyncGenerator<AsyncGenerator<int>> {
            auto makeGenerator = [&]() -> AsyncGenerator<int> {
              SCOPE_EXIT {
                if (++completed == 3) {
                  allCompleted.post();
                }
              };
              co_yield 1;
              co_yield 2;
            };
            co_yield co_invoke(makeGenerator);
            co_yield co_invoke(makeGenerator);
            co_yield co_invoke(makeGenerator);
          }));

      // Verify a value is present before dereferencing, so that an
      // unexpectedly early end-of-stream is reported as a CHECK failure
      // rather than an invalid dereference.
      auto item = co_await generator.next();
      CHECK(!!item);
      CHECK_EQ(1, *item);
      // Only the presence of the second value is asserted, not the value
      // itself. NOTE(review): presumably the interleaving of sources is not
      // deterministic on a thread pool - confirm before tightening.
      item = co_await generator.next();
      CHECK(!!item);
      // Truncate the stream after consuming only 2 of the 6 values it
      // would have produced.
    }

    CHECK(allCompleted.try_wait_for(1s));
  }());
}
// Merges streams whose element type is an rvalue reference
// (std::vector<int>&&) and checks that every merged item is a three-element
// vector and that four items arrive in total.
TEST_F(MergeTest, SequencesOfRValueReferences) {
  blockingWait([]() -> Task<void> {
    auto makeStreamOfStreams =
        []() -> AsyncGenerator<AsyncGenerator<std::vector<int>&&>> {
      // Each inner stream yields two temporaries with a reschedule between.
      auto makeStreamOfVectors = []() -> AsyncGenerator<std::vector<int>&&> {
        co_yield std::vector{1, 2, 3};
        co_await co_reschedule_on_current_executor;
        co_yield std::vector{2, 4, 6};
      };
      for (int streamIndex = 0; streamIndex < 2; ++streamIndex) {
        co_yield makeStreamOfVectors();
      }
    };

    auto gen = merge(co_await co_current_executor, makeStreamOfStreams());

    int resultCount = 0;
    for (;;) {
      auto item = co_await gen.next();
      if (!item) {
        break; // end of the merged stream
      }
      ++resultCount;
      std::vector<int>&& v = *item;
      CHECK_EQ(3, v.size());
    }

    // Two inner streams, two vectors each.
    CHECK_EQ(4, resultCount);
  }());
}
// Merges streams whose element type is an lvalue reference
// (std::vector<int>&) and checks that mutations made by the consumer through
// that reference are visible to the producer, and vice versa.
TEST_F(MergeTest, SequencesOfLValueReferences) {
  blockingWait([]() -> Task<void> {
    auto makeStreamOfStreams =
        []() -> AsyncGenerator<AsyncGenerator<std::vector<int>&>> {
      // Each inner stream yields the same vector twice. Between the yields it
      // expects the consumer to have appended one element, then duplicates
      // the last element itself.
      auto makeStreamOfVectors = []() -> AsyncGenerator<std::vector<int>&> {
        std::vector<int> v{1, 2, 3};
        co_yield v;
        CHECK_EQ(4, v.size());
        co_await co_reschedule_on_current_executor;
        v.push_back(v.back());
        co_yield v;
      };
      for (int streamIndex = 0; streamIndex < 2; ++streamIndex) {
        co_yield makeStreamOfVectors();
      }
    };

    auto gen = merge(co_await co_current_executor, makeStreamOfStreams());

    int resultCount = 0;
    for (;;) {
      auto item = co_await gen.next();
      if (!item) {
        break; // end of the merged stream
      }
      ++resultCount;
      std::vector<int>& v = *item;

      // A three-element vector is the first yield of an inner stream; anything
      // else must be the second yield, which has to hold five elements.
      const bool firstVisit = (v.size() == 3);
      if (!firstVisit) {
        CHECK_EQ(5, v.size());
      }
      CHECK_EQ(1, v[0]);
      CHECK_EQ(2, v[1]);
      CHECK_EQ(3, v[2]);
      if (firstVisit) {
        // Mutate through the reference; the producer checks for this.
        v.push_back(7);
      } else {
        CHECK_EQ(7, v[3]);
        CHECK_EQ(7, v[4]);
      }
    }

    // Two inner streams, two yields each.
    CHECK_EQ(4, resultCount);
  }());
}
// Returns a stream that never yields a value. The coroutine registers a
// callback on its current cancellation token that posts a baton, waits on
// that baton, and then runs off the end of its body.
template <typename Ref, typename Value = folly::remove_cvref_t<Ref>>
folly::coro::AsyncGenerator<Ref, Value> neverStream() {
  folly::coro::Baton cancelled;
  folly::CancellationCallback postOnCancel{
      co_await folly::coro::co_current_cancellation_token,
      [&cancelled] { cancelled.post(); }};
  co_await cancelled;
}
// The outer stream (the stream of generators) never produces anything.
// Requesting cancellation on the consumer's token must unblock the pending
// next() call, which then reports end-of-stream.
TEST_F(MergeTest, CancellationTokenPropagatesToOuterFromConsumer) {
  folly::coro::blockingWait([]() -> folly::coro::Task<void> {
    folly::CancellationSource cancelSource;
    bool suspended = false;
    bool done = false;

    // Both closures are named locals so that they stay alive for the whole
    // co_await below; the tasks they create refer back to them.
    auto consume = [&]() -> folly::coro::Task<void> {
      auto stream = merge(
          co_await co_current_executor, neverStream<AsyncGenerator<int>>());
      suspended = true;
      auto result = co_await stream.next();
      CHECK(!result.has_value());
      done = true;
    };

    auto cancelLater = [&]() -> folly::coro::Task<void> {
      // Yield to the executor a few times so the consumer can reach its
      // next() call before cancellation is requested.
      for (int hop = 0; hop < 3; ++hop) {
        co_await folly::coro::co_reschedule_on_current_executor;
      }
      CHECK(suspended);
      CHECK(!done);
      cancelSource.requestCancellation();
    };

    co_await folly::coro::collectAll(
        folly::coro::co_withCancellation(cancelSource.getToken(), consume()),
        cancelLater());
    CHECK(done);
  }());
}
// The outer stream produces two inner streams, neither of which ever yields.
// Requesting cancellation on the consumer's token must unblock the pending
// next() call, which then reports end-of-stream.
TEST_F(MergeTest, CancellationTokenPropagatesToInnerFromConsumer) {
  folly::coro::blockingWait([]() -> folly::coro::Task<void> {
    folly::CancellationSource cancelSource;
    bool suspended = false;
    bool done = false;

    auto makeStreamOfStreams = []() -> AsyncGenerator<AsyncGenerator<int>> {
      for (int streamIndex = 0; streamIndex < 2; ++streamIndex) {
        co_yield neverStream<int>();
      }
    };

    // Both closures are named locals so that they stay alive for the whole
    // co_await below; the tasks they create refer back to them.
    auto consume = [&]() -> folly::coro::Task<void> {
      auto stream = merge(co_await co_current_executor, makeStreamOfStreams());
      suspended = true;
      auto result = co_await stream.next();
      CHECK(!result.has_value());
      done = true;
    };

    auto cancelLater = [&]() -> folly::coro::Task<void> {
      // Yield to the executor a few times so the consumer can reach its
      // next() call before cancellation is requested.
      for (int hop = 0; hop < 3; ++hop) {
        co_await folly::coro::co_reschedule_on_current_executor;
      }
      CHECK(suspended);
      CHECK(!done);
      cancelSource.requestCancellation();
    };

    co_await folly::coro::collectAll(
        folly::coro::co_withCancellation(cancelSource.getToken(), consume()),
        cancelLater());
    CHECK(done);
  }());
}
// Check that, by the time the merged generator's next() returns an empty
// value (end of stream) or throws an exception, all source generators have
// been destroyed. Liveness is tracked with two counters that the generator
// bodies increment on entry and decrement from SCOPE_EXIT.
TEST_F(MergeTest, SourcesAreDestroyedBeforeEof) {
// Number of sourceGenerator / listGenerator bodies currently alive.
std::atomic<int> runningSourceGenerators = 0;
std::atomic<int> runningListGenerators = 0;
// Inner stream: yields a single 42, then either finishes or throws
// "test exception" depending on `shouldThrow`.
auto sourceGenerator =
[&](bool shouldThrow) -> folly::coro::AsyncGenerator<int> {
++runningSourceGenerators;
SCOPE_EXIT {
--runningSourceGenerators;
};
co_await folly::coro::co_reschedule_on_current_executor;
co_yield 42;
co_await folly::coro::co_reschedule_on_current_executor;
if (shouldThrow) {
throw std::runtime_error("test exception");
}
};
// Outer stream: produces inner streams without end. When `shouldThrow` is
// set, every second inner stream (odd i) is a throwing one.
auto listGenerator = [&](bool shouldThrow)
-> folly::coro::AsyncGenerator<folly::coro::AsyncGenerator<int>> {
// At most one list generator is expected to be alive at a time.
CHECK(runningListGenerators == 0);
++runningListGenerators;
SCOPE_EXIT {
// Delay the decrement so that a merged generator which reports EOF before
// this generator is fully torn down is caught by the final CHECKs.
/* sleep override */
std::this_thread::sleep_for(std::chrono::milliseconds(10));
--runningListGenerators;
};
for (int i = 0;; ++i) {
co_await folly::coro::co_reschedule_on_current_executor;
co_yield sourceGenerator(shouldThrow && (i % 2 == 1));
}
};
folly::CPUThreadPoolExecutor exec(4);
// Stream interrupted by cancellation.
auto future =
folly::coro::co_invoke([&]() -> folly::coro::Task<void> {
auto gen =
folly::coro::merge(&exec, listGenerator(/* shouldThrow */ false));
folly::CancellationSource cancelSource;
auto r = co_await folly::coro::co_withCancellation(
cancelSource.getToken(), gen.next());
CHECK(r.has_value());
CHECK_EQ(*r, 42);
// Something must still be running after the first value is delivered.
CHECK_GT(
runningSourceGenerators.load() + runningListGenerators.load(), 0);
cancelSource.requestCancellation();
// Currently the merged generator discards items produced
// after cancellation. But this behavior is not important, and
// it would probably be equally fine to return them (but stop
// calling source generators for more), so this test accepts
// either behavior.
while (true) {
r = co_await folly::coro::co_withCancellation(
cancelSource.getToken(), gen.next());
if (!r.has_value()) {
break;
}
CHECK_EQ(*r, 42);
}
// End of stream was observed: nothing may be left alive.
CHECK_EQ(runningSourceGenerators.load(), 0);
CHECK_EQ(runningListGenerators.load(), 0);
})
.scheduleOn(&exec)
.start();
// Block the test thread until the coroutine above has finished.
std::move(future).get();
// Stream interrupted by exception.
future =
folly::coro::co_invoke([&]() -> folly::coro::Task<void> {
auto gen =
folly::coro::merge(&exec, listGenerator(/* shouldThrow */ true));
auto r = co_await gen.next();
CHECK(r.has_value());
CHECK_EQ(*r, 42);
CHECK_GT(
runningSourceGenerators.load() + runningListGenerators.load(), 0);
// Drain values until next() fails; co_awaitTry captures the failure in r2
// instead of letting it propagate out of this coroutine.
while (true) {
auto r2 = co_await folly::coro::co_awaitTry(gen.next());
if (!r2.hasValue()) {
CHECK(
r2.exception().what().find("test exception") !=
std::string::npos);
break;
}
CHECK(r2->has_value());
CHECK_EQ(r2->value(), 42);
}
// The exception was observed: nothing may be left alive.
CHECK_EQ(runningSourceGenerators.load(), 0);
CHECK_EQ(runningListGenerators.load(), 0);
})
.scheduleOn(&exec)
.start();
std::move(future).get();
}
// Verifies that request-context changes made inside the merged source
// generators do not leak into the consumer: after the merged stream has been
// drained, the consumer must still see the context data it installed itself.
TEST_F(MergeTest, DontLeakRequestContext) {
// Minimal RequestData stored in the current RequestContext under the key
// "test". set() installs a fresh instance; get() returns whatever is stored.
class TestData : public folly::RequestData {
public:
explicit TestData() noexcept {}
bool hasCallback() override { return false; }
static void set() {
folly::RequestContext::get()->setContextData(
"test", std::make_unique<TestData>());
}
static auto get() {
return folly::RequestContext::get()->getContextData("test");
}
};
blockingWait([]() -> Task<void> {
folly::RequestContextScopeGuard requestScope;
TestData::set();
// The consumer's own context data; everything below is compared against it.
auto initialContextData = TestData::get();
CHECK(initialContextData != nullptr);
auto generator = merge(
co_await co_current_executor,
co_invoke([&]() -> AsyncGenerator<AsyncGenerator<int>> {
// NOTE(review): this inner generator contains no co_yield, so the merged
// stream presumably never delivers a value - confirm this is intended.
auto makeGenerator = [&]() -> AsyncGenerator<int> {
for (int i = 0; i < 10; ++i) {
// Each iteration must start out seeing the consumer's context data.
CHECK(TestData::get() == initialContextData);
// The test expects no "test" data to be visible once the child scope
// guard is in place, including after a reschedule.
folly::RequestContextScopeGuard childScope;
CHECK(TestData::get() == nullptr);
co_await co_reschedule_on_current_executor;
CHECK(TestData::get() == nullptr);
// Install data of our own; it must differ from the consumer's and must
// still be visible after another reschedule.
TestData::set();
auto newContextData = TestData::get();
CHECK(newContextData != nullptr);
CHECK(newContextData != initialContextData);
co_await co_reschedule_on_current_executor;
CHECK(TestData::get() == newContextData);
}
};
for (int i = 0; i < 5; ++i) {
co_yield makeGenerator();
}
}));
// Drain the merged stream; the values themselves are irrelevant.
while (auto val = co_await generator.next()) {
}
// The consumer's context data must be unchanged after draining.
CHECK(TestData::get() == initialContextData);
}());
}
#endif