folly/folly/channels/test/FanoutChannelTest.cpp

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <folly/channels/ConsumeChannel.h>
#include <folly/channels/FanoutChannel.h>
#include <folly/channels/test/ChannelTestUtil.h>
#include <folly/executors/ManualExecutor.h>
#include <folly/executors/SerialExecutor.h>
#include <folly/portability/GMock.h>
#include <folly/portability/GTest.h>

namespace folly {
namespace channels {

using namespace testing;

class FanoutChannelFixture : public Test {
 protected:
  FanoutChannelFixture() {}

  ~FanoutChannelFixture() { executor_.drain(); }

  template <typename T>
  using Callback = StrictMock<MockNextCallback<T>>;

  template <typename T>
  std::pair<ChannelCallbackHandle, Callback<T>*> processValues(
      Receiver<T> receiver) {
    auto callback = std::make_unique<Callback<T>>();
    auto callbackPtr = callback.get();
    auto handle = consumeChannelWithCallback(
        std::move(receiver),
        &executor_,
        [cbk = std::move(callback)](
            Try<T> resultTry) mutable -> folly::coro::Task<bool> {
          (*cbk)(std::move(resultTry));
          co_return true;
        });
    return std::make_pair(std::move(handle), callbackPtr);
  }

  StrictMock<MockNextCallback<std::string>> createCallback() {
    return StrictMock<MockNextCallback<std::string>>();
  }

  folly::ManualExecutor executor_;
};

TEST_F(FanoutChannelFixture, ReceiveValue_FanoutBroadcastsValues) {
  // Context object handed to the fanout channel. update() receives each input
  // value together with a subscriber count, and getInitialValues callbacks
  // read it to seed new subscribers with the latest version.
  struct LatestVersion {
    int version{-1};
    size_t numSubscribers{0};

    // The new value is only read, so take it by const reference (this also
    // binds to const lvalues and temporaries, unlike `int&`).
    void update(const int& newVersion, size_t newNumSubscribers) {
      version = newVersion;
      numSubscribers = newNumSubscribers;
    }
  };

  auto [inputReceiver, sender] = Channel<int>::create();
  auto fanoutChannel = createFanoutChannel(
      std::move(inputReceiver), &executor_, LatestVersion());

  EXPECT_FALSE(fanoutChannel.anySubscribers());

  // Each initial subscriber first receives its own initial value.
  auto [handle1, callback1] = processValues(fanoutChannel.subscribe(
      [](const auto&) { return toVector(100); } /* getInitialValues */));
  auto [handle2, callback2] = processValues(fanoutChannel.subscribe(
      [](const auto&) { return toVector(200); } /* getInitialValues */));

  EXPECT_TRUE(fanoutChannel.anySubscribers());
  EXPECT_CALL(*callback1, onValue(100));
  EXPECT_CALL(*callback2, onValue(200));
  executor_.drain();

  // Input values are broadcast to every subscriber.
  EXPECT_CALL(*callback1, onValue(1));
  EXPECT_CALL(*callback2, onValue(1));
  EXPECT_CALL(*callback1, onValue(2));
  EXPECT_CALL(*callback2, onValue(2));
  sender.write(1);
  sender.write(2);
  executor_.drain();

  // A late subscriber is seeded from the context with the latest value (2),
  // which was recorded while two subscribers existed.
  auto [handle3, callback3] = processValues(
      fanoutChannel.subscribe([](const LatestVersion& latestVersion) {
        EXPECT_EQ(latestVersion.numSubscribers, size_t{2});
        return toVector(latestVersion.version);
      } /* getInitialValues */));

  EXPECT_CALL(*callback3, onValue(2));
  executor_.drain();

  // Register expectations before acting so they are in place regardless of
  // when the resulting calls are delivered.
  EXPECT_CALL(*callback1, onValue(3));
  EXPECT_CALL(*callback2, onValue(3));
  EXPECT_CALL(*callback3, onValue(3));
  EXPECT_CALL(*callback1, onClosed());
  EXPECT_CALL(*callback2, onClosed());
  EXPECT_CALL(*callback3, onClosed());
  sender.write(3);
  std::move(sender).close();
  executor_.drain();

  EXPECT_FALSE(fanoutChannel.anySubscribers());
}

TEST_F(FanoutChannelFixture, InputClosed_AllOutputReceiversClose) {
  // A clean close of the input must reach every subscriber, after the values
  // that were written before it.
  auto [input, inputSender] = Channel<int>::create();
  auto fanout = createFanoutChannel(std::move(input), &executor_);

  auto [firstHandle, firstCallback] = processValues(fanout.subscribe());
  auto [secondHandle, secondCallback] = processValues(fanout.subscribe());

  // Every expectation is registered before the first drain.
  for (auto* callback : {firstCallback, secondCallback}) {
    EXPECT_CALL(*callback, onValue(1));
  }
  for (auto* callback : {firstCallback, secondCallback}) {
    EXPECT_CALL(*callback, onClosed());
  }

  executor_.drain();
  EXPECT_TRUE(fanout.anySubscribers());

  inputSender.write(1);
  executor_.drain();

  std::move(inputSender).close();
  executor_.drain();
  EXPECT_FALSE(fanout.anySubscribers());
}

TEST_F(FanoutChannelFixture, InputThrows_AllOutputReceiversGetException) {
  // Closing the input with an exception must deliver that error to every
  // subscriber, after the values written before it.
  auto [input, inputSender] = Channel<int>::create();
  auto fanout = createFanoutChannel(std::move(input), &executor_);

  auto [firstHandle, firstCallback] = processValues(fanout.subscribe());
  auto [secondHandle, secondCallback] = processValues(fanout.subscribe());

  for (auto* callback : {firstCallback, secondCallback}) {
    EXPECT_CALL(*callback, onValue(1));
  }
  for (auto* callback : {firstCallback, secondCallback}) {
    EXPECT_CALL(*callback, onRuntimeError("std::runtime_error: Error"));
  }

  executor_.drain();
  EXPECT_TRUE(fanout.anySubscribers());

  inputSender.write(1);
  executor_.drain();

  std::move(inputSender).close(std::runtime_error("Error"));
  executor_.drain();
  EXPECT_FALSE(fanout.anySubscribers());
}

TEST_F(FanoutChannelFixture, ReceiversCancelled) {
  auto [input, inputSender] = Channel<int>::create();
  auto fanout = createFanoutChannel(std::move(input), &executor_);

  auto [firstHandle, firstCallback] = processValues(fanout.subscribe());
  auto [secondHandle, secondCallback] = processValues(fanout.subscribe());

  // The first subscriber sees value 1 and is then cancelled.
  EXPECT_CALL(*firstCallback, onValue(1));
  EXPECT_CALL(*firstCallback, onCancelled());
  // The second subscriber stays around long enough to also see value 2.
  EXPECT_CALL(*secondCallback, onValue(1));
  EXPECT_CALL(*secondCallback, onValue(2));
  EXPECT_CALL(*secondCallback, onCancelled());

  executor_.drain();
  EXPECT_TRUE(fanout.anySubscribers());

  inputSender.write(1);
  executor_.drain();
  EXPECT_TRUE(fanout.anySubscribers());

  // Dropping one handle cancels only that subscription.
  firstHandle.reset();
  inputSender.write(2);
  executor_.drain();
  EXPECT_TRUE(fanout.anySubscribers());

  // With the last handle gone, no callback expects value 3.
  secondHandle.reset();
  inputSender.write(3);
  executor_.drain();
  EXPECT_FALSE(fanout.anySubscribers());

  std::move(inputSender).close();
  executor_.drain();
  EXPECT_FALSE(fanout.anySubscribers());
}

TEST_F(FanoutChannelFixture, SubscribersClosed) {
  auto [input, inputSender] = Channel<int>::create();
  auto fanout = createFanoutChannel(std::move(input), &executor_);

  // Phase 1: two subscribers receive a broadcast value.
  auto [firstHandle, firstCallback] = processValues(fanout.subscribe());
  auto [secondHandle, secondCallback] = processValues(fanout.subscribe());
  executor_.drain();
  EXPECT_TRUE(fanout.anySubscribers());

  for (auto* callback : {firstCallback, secondCallback}) {
    EXPECT_CALL(*callback, onValue(1));
  }
  inputSender.write(1);
  executor_.drain();
  EXPECT_TRUE(fanout.anySubscribers());

  // Phase 2: closeSubscribers() closes the existing subscribers.
  for (auto* callback : {firstCallback, secondCallback}) {
    EXPECT_CALL(*callback, onClosed());
  }
  fanout.closeSubscribers();
  executor_.drain();
  EXPECT_FALSE(fanout.anySubscribers());

  // Phase 3: the fanout channel still accepts a new subscriber, which then
  // receives later input values.
  auto [thirdHandle, thirdCallback] = processValues(fanout.subscribe());
  executor_.drain();
  EXPECT_TRUE(fanout.anySubscribers());

  EXPECT_CALL(*thirdCallback, onValue(2));
  inputSender.write(2);
  executor_.drain();

  // Phase 4: closing the fanout channel itself closes that subscriber.
  EXPECT_CALL(*thirdCallback, onClosed());
  std::move(fanout).close();
  executor_.drain();
}

TEST_F(FanoutChannelFixture, VectorBool) {
  // bool payloads, both as initial values and as broadcast values. Presumably
  // this targets the std::vector<bool> specialization (per the test name).
  auto [input, inputSender] = Channel<bool>::create();
  auto fanout = createFanoutChannel(std::move(input), &executor_);

  auto [firstHandle, firstCallback] = processValues(fanout.subscribe(
      [](const auto&) { return toVector(true); } /* getInitialValues */));
  auto [secondHandle, secondCallback] = processValues(fanout.subscribe(
      [](const auto&) { return toVector(false); } /* getInitialValues */));

  // Each subscriber first receives its own initial value.
  EXPECT_CALL(*firstCallback, onValue(true));
  EXPECT_CALL(*secondCallback, onValue(false));
  executor_.drain();
  EXPECT_TRUE(fanout.anySubscribers());

  // Then both receive every broadcast value, followed by the close.
  for (bool value : {true, false}) {
    for (auto* callback : {firstCallback, secondCallback}) {
      EXPECT_CALL(*callback, onValue(value));
    }
  }
  for (auto* callback : {firstCallback, secondCallback}) {
    EXPECT_CALL(*callback, onClosed());
  }

  inputSender.write(true);
  inputSender.write(false);
  executor_.drain();

  std::move(inputSender).close();
  executor_.drain();
  EXPECT_FALSE(fanout.anySubscribers());
}

class FanoutChannelFixtureStress : public Test {
 protected:
  FanoutChannelFixtureStress()
      : producer_(makeProducer()),
        consumers_(toVector(makeConsumer(), makeConsumer(), makeConsumer())) {}

  static std::unique_ptr<StressTestProducer<int>> makeProducer() {
    return std::make_unique<StressTestProducer<int>>(
        [value = 0]() mutable { return value++; });
  }

  static std::unique_ptr<StressTestConsumer<int>> makeConsumer() {
    return std::make_unique<StressTestConsumer<int>>(
        ConsumptionMode::CallbackWithHandle,
        [lastReceived = -1](int value) mutable {
          if (lastReceived == -1) {
            lastReceived = value;
          } else {
            EXPECT_EQ(value, ++lastReceived);
          }
        });
  }

  static void sleepFor(std::chrono::milliseconds duration) {
    /* sleep override */
    std::this_thread::sleep_for(duration);
  }

  static constexpr std::chrono::milliseconds kTestTimeout =
      std::chrono::milliseconds{10};

  std::unique_ptr<StressTestProducer<int>> producer_;
  std::vector<std::unique_ptr<StressTestConsumer<int>>> consumers_;
};

TEST_F(FanoutChannelFixtureStress, HandleClosed) {
  // A continuously producing input feeds a fanout channel that runs on a
  // one-thread pool wrapped in a SerialExecutor.
  auto [receiver, sender] = Channel<int>::create();
  producer_->startProducing(std::move(sender), std::nullopt /* closeEx */);

  folly::CPUThreadPoolExecutor fanoutChannelExecutor(1);
  auto fanoutChannel = createFanoutChannel(
      std::move(receiver),
      folly::SerialExecutor::create(&fanoutChannelExecutor));

  auto& firstConsumer = *consumers_.at(0);
  auto& secondConsumer = *consumers_.at(1);
  auto& thirdConsumer = *consumers_.at(2);

  // Two consumers subscribe right away; the third joins mid-stream.
  firstConsumer.startConsuming(fanoutChannel.subscribe());
  secondConsumer.startConsuming(fanoutChannel.subscribe());
  sleepFor(kTestTimeout / 3);

  thirdConsumer.startConsuming(fanoutChannel.subscribe());
  sleepFor(kTestTimeout / 3);

  // One consumer cancels while the others keep running.
  firstConsumer.cancel();
  EXPECT_EQ(firstConsumer.waitForClose().get(), CloseType::Cancelled);
  sleepFor(kTestTimeout / 3);

  // Closing the fanout channel closes the remaining consumers cleanly.
  std::move(fanoutChannel).close();
  EXPECT_EQ(secondConsumer.waitForClose().get(), CloseType::NoException);
  EXPECT_EQ(thirdConsumer.waitForClose().get(), CloseType::NoException);
}

TEST_F(FanoutChannelFixtureStress, InputChannelClosed) {
  // A continuously producing input feeds a fanout channel that runs on a
  // one-thread pool wrapped in a SerialExecutor.
  auto [receiver, sender] = Channel<int>::create();
  producer_->startProducing(std::move(sender), std::nullopt /* closeEx */);

  folly::CPUThreadPoolExecutor fanoutChannelExecutor(1);
  auto fanoutChannel = createFanoutChannel(
      std::move(receiver),
      folly::SerialExecutor::create(&fanoutChannelExecutor));

  auto& firstConsumer = *consumers_.at(0);
  auto& secondConsumer = *consumers_.at(1);
  auto& thirdConsumer = *consumers_.at(2);

  // Two consumers subscribe right away; the third joins mid-stream.
  firstConsumer.startConsuming(fanoutChannel.subscribe());
  secondConsumer.startConsuming(fanoutChannel.subscribe());
  sleepFor(kTestTimeout / 3);

  thirdConsumer.startConsuming(fanoutChannel.subscribe());
  sleepFor(kTestTimeout / 3);

  // One consumer cancels while the others keep running.
  firstConsumer.cancel();
  EXPECT_EQ(firstConsumer.waitForClose().get(), CloseType::Cancelled);
  sleepFor(kTestTimeout / 3);

  // Stopping the producer (which was started with no close exception) is
  // expected to close the remaining consumers without an exception.
  producer_->stopProducing();
  EXPECT_EQ(secondConsumer.waitForClose().get(), CloseType::NoException);
  EXPECT_EQ(thirdConsumer.waitForClose().get(), CloseType::NoException);
}
} // namespace channels
} // namespace folly