/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <folly/channels/ConsumeChannel.h>
#include <folly/channels/MaxConcurrentRateLimiter.h>
#include <folly/channels/MultiplexChannel.h>
#include <folly/channels/test/ChannelTestUtil.h>
#include <folly/executors/ManualExecutor.h>
#include <folly/executors/SerialExecutor.h>
#include <folly/experimental/coro/BlockingWait.h>
#include <folly/portability/GMock.h>
#include <folly/portability/GTest.h>
namespace folly {
namespace channels {
using namespace testing;
using namespace std::string_literals;
class MultiplexChannelFixture : public Test {
protected:
MultiplexChannelFixture() {}
~MultiplexChannelFixture() { executor_.drain(); }
template <typename T>
using Callback = StrictMock<MockNextCallback<T>>;
template <typename T>
std::pair<ChannelCallbackHandle, Callback<T>*> processValues(
Receiver<T> receiver) {
auto callback = std::make_unique<Callback<T>>();
auto callbackPtr = callback.get();
auto handle = consumeChannelWithCallback(
std::move(receiver),
&executor_,
[cbk = std::move(callback)](
Try<T> resultTry) mutable -> folly::coro::Task<bool> {
(*cbk)(std::move(resultTry));
co_return true;
});
return std::make_pair(std::move(handle), callbackPtr);
}
folly::ManualExecutor executor_;
};
static constexpr int kCloseSubscription = -1;
struct TestInputValue {
folly::F14FastMap<std::string, int> values;
};
struct TestContext {
std::string contextValue;
};
struct TestSubscriptionArg {
int initialValue;
bool firstSubscriptionForKey;
bool throwException{false};
folly::SemiFuture<Unit> waitForSubscription{folly::makeSemiFuture()};
};
struct TestMultiplexer {
public:
TestMultiplexer(
folly::Executor::KeepAlive<folly::SequencedExecutor> executor,
std::shared_ptr<RateLimiter> rateLimiter = nullptr)
: executor_(std::move(executor)), rateLimiter_(std::move(rateLimiter)) {}
folly::Executor::KeepAlive<folly::SequencedExecutor> getExecutor() {
return executor_;
}
folly::coro::Task<std::vector<int>> onNewSubscription(
std::string key,
TestContext& subscriptionContext,
TestSubscriptionArg subscriptionArg) {
co_await std::move(subscriptionArg.waitForSubscription);
if (subscriptionArg.throwException) {
throw std::runtime_error("Error");
}
if (subscriptionArg.firstSubscriptionForKey) {
subscriptionContext.contextValue = key;
}
EXPECT_EQ(subscriptionContext.contextValue, key);
co_return toVector(subscriptionArg.initialValue);
}
folly::coro::Task<void> onInputValue(
Try<TestInputValue> inputValue,
MultiplexedSubscriptions<TestMultiplexer>& subscriptions) {
for (auto& [key, value] : inputValue.value().values) {
if (!subscriptions.hasSubscription(key)) {
continue;
}
if (value == kCloseSubscription) {
subscriptions.close(key, {} /* ex */);
} else {
subscriptions.write(key, value);
}
}
co_return;
}
std::shared_ptr<RateLimiter> getRateLimiter() { return rateLimiter_; }
private:
folly::Executor::KeepAlive<folly::SequencedExecutor> executor_;
std::shared_ptr<RateLimiter> rateLimiter_;
};
TEST_F(MultiplexChannelFixture, ReceiveValues_MultiplexesValues) {
auto [inputReceiver, inputSender] = Channel<TestInputValue>::create();
auto multiplexChannel = createMultiplexChannel(
TestMultiplexer(&executor_), std::move(inputReceiver));
EXPECT_FALSE(multiplexChannel.anySubscribers());
auto [handle1a, callback1a] = processValues(multiplexChannel.subscribe(
"one"s,
TestSubscriptionArg{
100 /* initialValue */, true /* firstSubscriptionForKey */}));
auto [handle1b, callback1b] = processValues(multiplexChannel.subscribe(
"one"s,
TestSubscriptionArg{
101 /* initialValue */, false /* firstSubscriptionForKey */}));
auto [handle2, callback2] = processValues(multiplexChannel.subscribe(
"two"s,
TestSubscriptionArg{
200 /* initialValue */, true /* firstSubscriptionForKey */}));
EXPECT_TRUE(multiplexChannel.anySubscribers());
EXPECT_CALL(*callback1a, onValue(100));
EXPECT_CALL(*callback1b, onValue(101));
EXPECT_CALL(*callback2, onValue(200));
executor_.drain();
EXPECT_CALL(*callback1a, onValue(110));
EXPECT_CALL(*callback1b, onValue(110));
inputSender.write(TestInputValue{toMap(std::make_pair("one"s, 110))});
executor_.drain();
EXPECT_CALL(*callback2, onValue(210));
inputSender.write(TestInputValue{toMap(std::make_pair("two"s, 210))});
executor_.drain();
EXPECT_CALL(*callback1a, onValue(120));
EXPECT_CALL(*callback1b, onValue(120));
EXPECT_CALL(*callback2, onValue(220));
inputSender.write(TestInputValue{toMap(
std::make_pair("one"s, 120),
std::make_pair("two"s, 220),
std::make_pair("three", 330))});
executor_.drain();
std::move(inputSender).close();
EXPECT_CALL(*callback1a, onClosed());
EXPECT_CALL(*callback1b, onClosed());
EXPECT_CALL(*callback2, onClosed());
executor_.drain();
EXPECT_FALSE(multiplexChannel.anySubscribers());
}
TEST_F(MultiplexChannelFixture, InputThrows_AllOutputReceiversGetException) {
auto [inputReceiver, inputSender] = Channel<TestInputValue>::create();
auto multiplexChannel = createMultiplexChannel(
TestMultiplexer(&executor_), std::move(inputReceiver));
auto [handle1a, callback1a] = processValues(multiplexChannel.subscribe(
"one"s,
TestSubscriptionArg{
100 /* initialValue */, true /* firstSubscriptionForKey */}));
auto [handle1b, callback1b] = processValues(multiplexChannel.subscribe(
"one"s,
TestSubscriptionArg{
101 /* initialValue */, false /* firstSubscriptionForKey */}));
auto [handle2, callback2] = processValues(multiplexChannel.subscribe(
"two"s,
TestSubscriptionArg{
200 /* initialValue */, true /* firstSubscriptionForKey */}));
EXPECT_CALL(*callback1a, onValue(100));
EXPECT_CALL(*callback1b, onValue(101));
EXPECT_CALL(*callback2, onValue(200));
executor_.drain();
std::move(inputSender)
.close(folly::make_exception_wrapper<std::runtime_error>("Error"));
EXPECT_CALL(*callback1a, onRuntimeError("std::runtime_error: Error"));
EXPECT_CALL(*callback1b, onRuntimeError("std::runtime_error: Error"));
EXPECT_CALL(*callback2, onRuntimeError("std::runtime_error: Error"));
executor_.drain();
}
TEST_F(MultiplexChannelFixture, ClearUnusedSubscriptions) {
auto [inputReceiver, inputSender] = Channel<TestInputValue>::create();
auto multiplexChannel = createMultiplexChannel(
TestMultiplexer(&executor_), std::move(inputReceiver));
EXPECT_FALSE(multiplexChannel.anySubscribers());
auto [handle1a, callback1a] = processValues(multiplexChannel.subscribe(
"one"s,
TestSubscriptionArg{
100 /* initialValue */, true /* firstSubscriptionForKey */}));
auto [handle1b, callback1b] = processValues(multiplexChannel.subscribe(
"one"s,
TestSubscriptionArg{
101 /* initialValue */, false /* firstSubscriptionForKey */}));
auto [handle2, callback2] = processValues(multiplexChannel.subscribe(
"two"s,
TestSubscriptionArg{
200 /* initialValue */, true /* firstSubscriptionForKey */}));
EXPECT_TRUE(multiplexChannel.anySubscribers());
EXPECT_CALL(*callback1a, onValue(100));
EXPECT_CALL(*callback1b, onValue(101));
EXPECT_CALL(*callback2, onValue(200));
executor_.drain();
auto clearedSubscriptions1Task = multiplexChannel.clearUnusedSubscriptions()
.scheduleOn(&executor_)
.start();
executor_.drain();
auto clearedSubscriptions1 =
folly::coro::blockingWait(std::move(clearedSubscriptions1Task));
EXPECT_TRUE(clearedSubscriptions1.empty());
EXPECT_CALL(*callback1a, onCancelled());
EXPECT_CALL(*callback2, onCancelled());
handle1a.reset();
handle2.reset();
executor_.drain();
auto clearedSubscriptions2Task = multiplexChannel.clearUnusedSubscriptions()
.scheduleOn(&executor_)
.start();
executor_.drain();
auto clearedSubscriptions2 =
folly::coro::blockingWait(std::move(clearedSubscriptions2Task));
EXPECT_EQ(clearedSubscriptions2.size(), 1);
EXPECT_EQ(clearedSubscriptions2[0].first, "two"s);
EXPECT_EQ(clearedSubscriptions2[0].second.contextValue, "two"s);
EXPECT_CALL(*callback1b, onCancelled());
handle1b.reset();
executor_.drain();
EXPECT_TRUE(multiplexChannel.anySubscribers());
auto clearedSubscriptions3Task = multiplexChannel.clearUnusedSubscriptions()
.scheduleOn(&executor_)
.start();
executor_.drain();
auto clearedSubscriptions3 =
folly::coro::blockingWait(std::move(clearedSubscriptions3Task));
EXPECT_EQ(clearedSubscriptions3.size(), 1);
EXPECT_EQ(clearedSubscriptions3[0].first, "one"s);
EXPECT_EQ(clearedSubscriptions3[0].second.contextValue, "one"s);
EXPECT_FALSE(multiplexChannel.anySubscribers());
}
TEST_F(MultiplexChannelFixture, OnNewSubscriptionThrows_OutputReceiverClosed) {
auto [inputReceiver, inputSender] = Channel<TestInputValue>::create();
auto multiplexChannel = createMultiplexChannel(
TestMultiplexer(&executor_), std::move(inputReceiver));
auto [handle1a, callback1a] = processValues(multiplexChannel.subscribe(
"one"s,
TestSubscriptionArg{
100 /* initialValue */,
true /* firstSubscriptionForKey */,
false /* throwException */}));
EXPECT_CALL(*callback1a, onValue(100));
executor_.drain();
auto [handle1b, callback1b] = processValues(multiplexChannel.subscribe(
"one"s,
TestSubscriptionArg{
101 /* initialValue */,
false /* firstSubscriptionForKey */,
true /* throwException */}));
EXPECT_CALL(*callback1b, onRuntimeError("std::runtime_error: Error"));
executor_.drain();
inputSender.write(TestInputValue{toMap(std::make_pair("one"s, 110))});
EXPECT_CALL(*callback1a, onValue(110));
executor_.drain();
std::move(inputSender).close();
EXPECT_CALL(*callback1a, onClosed());
executor_.drain();
}
TEST_F(MultiplexChannelFixture, HandleDestroyed) {
auto [inputReceiver, inputSender] = Channel<TestInputValue>::create();
auto multiplexChannel = createMultiplexChannel(
TestMultiplexer(&executor_), std::move(inputReceiver));
EXPECT_FALSE(multiplexChannel.anySubscribers());
auto [handle1a, callback1a] = processValues(multiplexChannel.subscribe(
"one"s,
TestSubscriptionArg{
100 /* initialValue */,
true /* firstSubscriptionForKey */,
false /* throwException */}));
EXPECT_CALL(*callback1a, onValue(100));
executor_.drain();
{ auto toDestroy = std::move(multiplexChannel); }
EXPECT_CALL(*callback1a, onClosed());
executor_.drain();
}
TEST_F(MultiplexChannelFixture, Subscribe_WithRateLimiter) {
auto rateLimiter = MaxConcurrentRateLimiter::create(1 /* maxConcurrent */);
auto [inputReceiver, inputSender] = Channel<TestInputValue>::create();
auto multiplexChannel = createMultiplexChannel(
TestMultiplexer(&executor_, std::move(rateLimiter)),
std::move(inputReceiver));
EXPECT_FALSE(multiplexChannel.anySubscribers());
auto promise1a = folly::Promise<Unit>();
auto [handle1a, callback1a] = processValues(multiplexChannel.subscribe(
"one"s,
TestSubscriptionArg{
100 /* initialValue */,
true /* firstSubscriptionForKey */,
false /* throwException */,
promise1a.getSemiFuture() /* waitForSubscription */}));
executor_.drain();
auto promise1b = folly::Promise<Unit>();
auto [handle1b, callback1b] = processValues(multiplexChannel.subscribe(
"one"s,
TestSubscriptionArg{
101 /* initialValue */,
true /* firstSubscriptionForKey */,
false /* throwException */,
promise1b.getSemiFuture() /* waitForSubscription */}));
executor_.drain();
EXPECT_CALL(*callback1a, onValue(100));
promise1a.setValue();
executor_.drain();
EXPECT_CALL(*callback1b, onValue(101));
promise1b.setValue();
executor_.drain();
std::move(inputSender).close();
EXPECT_CALL(*callback1a, onClosed());
EXPECT_CALL(*callback1b, onClosed());
executor_.drain();
}
class MultiplexChannelFixtureStress : public Test {
protected:
MultiplexChannelFixtureStress()
: producer_(makeProducer()),
consumers_(
toVector(makeConsumer(0), makeConsumer(1), makeConsumer(2))) {}
static std::unique_ptr<StressTestProducer<int>> makeProducer() {
return std::make_unique<StressTestProducer<int>>(
[value = 0]() mutable { return value++; });
}
static std::unique_ptr<StressTestConsumer<int>> makeConsumer(int remainder) {
return std::make_unique<StressTestConsumer<int>>(
ConsumptionMode::CallbackWithHandle,
[remainder, lastReceived = -1](int value) mutable {
if (lastReceived == -1) {
lastReceived = value;
EXPECT_EQ(lastReceived % 3, remainder);
} else {
lastReceived += 3;
EXPECT_EQ(value, lastReceived);
}
});
}
static void sleepFor(std::chrono::milliseconds duration) {
/* sleep override */
std::this_thread::sleep_for(duration);
}
static constexpr std::chrono::milliseconds kTestTimeout =
std::chrono::milliseconds{10};
std::unique_ptr<StressTestProducer<int>> producer_;
std::vector<std::unique_ptr<StressTestConsumer<int>>> consumers_;
};
struct NoContext {};
struct NoSubscriptionArg {};
struct TestMultiplexerStress {
public:
explicit TestMultiplexerStress(
folly::Executor::KeepAlive<folly::SequencedExecutor> executor)
: executor_(std::move(executor)) {}
folly::Executor::KeepAlive<folly::SequencedExecutor> getExecutor() {
return executor_;
}
std::shared_ptr<RateLimiter> getRateLimiter() {
return nullptr; // No rate limiting
}
folly::coro::Task<std::vector<int>> onNewSubscription(
int, NoContext&, NoSubscriptionArg) {
co_return std::vector<int>(); // No initial values
}
folly::coro::Task<void> onInputValue(
Try<int> inputValue,
MultiplexedSubscriptions<TestMultiplexerStress>& subscriptions) {
if (subscriptions.hasSubscription(inputValue.value() % 3)) {
subscriptions.write(inputValue.value() % 3, inputValue.value());
}
co_return;
}
private:
folly::Executor::KeepAlive<folly::SequencedExecutor> executor_;
};
TEST_F(MultiplexChannelFixtureStress, HandleClosed) {
auto [receiver, sender] = Channel<int>::create();
producer_->startProducing(std::move(sender), std::nullopt /* closeEx */);
folly::CPUThreadPoolExecutor multiplexChannelExecutor(1);
auto multiplexChannel = createMultiplexChannel(
TestMultiplexerStress(
folly::SerialExecutor::create(&multiplexChannelExecutor)),
std::move(receiver));
consumers_.at(0)->startConsuming(
multiplexChannel.subscribe(0 /* key */, NoSubscriptionArg()));
consumers_.at(1)->startConsuming(
multiplexChannel.subscribe(1 /* key */, NoSubscriptionArg()));
sleepFor(kTestTimeout / 3);
consumers_.at(2)->startConsuming(
multiplexChannel.subscribe(2 /* key */, NoSubscriptionArg()));
sleepFor(kTestTimeout / 3);
consumers_.at(0)->cancel();
EXPECT_EQ(consumers_.at(0)->waitForClose().get(), CloseType::Cancelled);
sleepFor(kTestTimeout / 3);
std::move(multiplexChannel).close();
EXPECT_EQ(consumers_.at(1)->waitForClose().get(), CloseType::NoException);
EXPECT_EQ(consumers_.at(2)->waitForClose().get(), CloseType::NoException);
}
TEST_F(MultiplexChannelFixtureStress, InputChannelClosed) {
auto [receiver, sender] = Channel<int>::create();
producer_->startProducing(std::move(sender), std::nullopt /* closeEx */);
folly::CPUThreadPoolExecutor multiplexChannelExecutor(1);
auto multiplexChannel = createMultiplexChannel(
TestMultiplexerStress(
folly::SerialExecutor::create(&multiplexChannelExecutor)),
std::move(receiver));
consumers_.at(0)->startConsuming(
multiplexChannel.subscribe(0 /* key */, NoSubscriptionArg()));
consumers_.at(1)->startConsuming(
multiplexChannel.subscribe(1 /* key */, NoSubscriptionArg()));
sleepFor(kTestTimeout / 3);
consumers_.at(2)->startConsuming(
multiplexChannel.subscribe(2 /* key */, NoSubscriptionArg()));
sleepFor(kTestTimeout / 3);
consumers_.at(0)->cancel();
EXPECT_EQ(consumers_.at(0)->waitForClose().get(), CloseType::Cancelled);
sleepFor(kTestTimeout / 3);
producer_->stopProducing();
EXPECT_EQ(consumers_.at(1)->waitForClose().get(), CloseType::NoException);
EXPECT_EQ(consumers_.at(2)->waitForClose().get(), CloseType::NoException);
}
} // namespace channels
} // namespace folly