folly/folly/channels/test/ChannelTestUtil.h

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <folly/channels/ConsumeChannel.h>
#include <folly/executors/CPUThreadPoolExecutor.h>
#include <folly/executors/IOThreadPoolExecutor.h>
#include <folly/executors/SequencedExecutor.h>
#include <folly/experimental/coro/DetachOnCancel.h>
#include <folly/experimental/coro/Sleep.h>
#include <folly/futures/SharedPromise.h>
#include <folly/portability/GMock.h>

namespace folly {
namespace channels {

// Builds a std::vector<T> holding `firstItem` followed by `items...`, in
// argument order. The element type is taken from the first argument; every
// remaining argument is moved into push_back and therefore must be implicitly
// convertible to T.
template <typename T, typename... Others>
std::vector<T> toVector(T firstItem, Others... items) {
  std::vector<T> itemsVector;
  // The final size is known at compile time, so allocate once up front
  // instead of regrowing the buffer on successive push_backs.
  itemsVector.reserve(sizeof...(Others) + 1);
  itemsVector.push_back(std::move(firstItem));
  // Fold over the comma operator; the void cast keeps a user-defined
  // operator, from ever being selected.
  (void(itemsVector.push_back(std::move(items))), ...);
  return itemsVector;
}

// Builds a folly::F14FastMap from `firstPair` followed by `items...`. The key
// and mapped types are deduced from the first pair (with any const stripped
// from the key type); every argument is moved into the map's insert().
template <typename Key, typename Mapped, typename... Others>
folly::F14FastMap<std::remove_const_t<Key>, Mapped> toMap(
    std::pair<Key, Mapped> firstPair, Others... items) {
  using ResultMap = folly::F14FastMap<std::remove_const_t<Key>, Mapped>;

  ResultMap result;
  result.insert(std::move(firstPair));
  // Insert the remaining arguments left to right; insert()'s return value is
  // deliberately discarded.
  (static_cast<void>(result.insert(std::move(items))), ...);
  return result;
}

// gMock-backed callable that takes a Try<TValue> and dispatches it to exactly
// one mocked method, so tests can set expectations per outcome:
//   - holds a value                       -> onValue(value)
//   - holds folly::OperationCancelled     -> onCancelled()
//   - holds std::runtime_error            -> onRuntimeError(message)
//   - holds any other exception           -> LOG(FATAL)
//   - holds neither value nor exception   -> onClosed()
template <typename TValue>
class MockNextCallback {
 public:
  void operator()(Try<TValue> result) {
    if (result.hasValue()) {
      // `result` is our own by-value copy, so hand the value over by move.
      // This avoids an extra copy and lets move-only TValue types compile.
      onValue(std::move(result.value()));
    } else if (result.template hasException<folly::OperationCancelled>()) {
      onCancelled();
    } else if (result.template hasException<std::runtime_error>()) {
      // Only the exception's what() text is forwarded to the mock.
      onRuntimeError(result.exception().what().toStdString());
    } else if (result.hasException()) {
      // Any other exception type is not expected by the tests using this mock.
      LOG(FATAL) << "Unexpected exception: " << result.exception().what();
    } else {
      onClosed();
    }
  }

  MOCK_METHOD(void, onValue, (TValue));
  MOCK_METHOD(void, onClosed, ());
  MOCK_METHOD(void, onCancelled, ());
  MOCK_METHOD(void, onRuntimeError, (std::string));
};

// Selects how ChannelConsumerBase::processValuesCoro pulls results out of a
// Receiver.
enum class ConsumptionMode {
  // Coroutine loop that awaits folly::coro::co_awaitTry(receiver.next()), so
  // each result arrives already wrapped in a Try.
  CoroWithTry,
  // Coroutine loop that awaits receiver.next() directly and converts the
  // returned optional-like result, or any thrown exception, into a Try.
  CoroWithoutTry,
  // Hands the receiver to consumeChannelWithCallback and keeps the returned
  // handle so that cancellation can reset it.
  CallbackWithHandle,
};

// Base class for test consumers of a channel Receiver<TValue>.
//
// A subclass supplies the executor to run on (getExecutor) and what to do with
// each result (onNext). startConsuming() then drives the receiver according to
// the ConsumptionMode chosen at construction, under the cancellation token of
// cancellationSource_.
template <typename TValue>
class ChannelConsumerBase {
 public:
  // Stores the consumption mode and fulfils continueConsuming_ with `true`,
  // so by default consumption carries on after every delivered value.
  explicit ChannelConsumerBase(ConsumptionMode mode) : mode_(mode) {
    continueConsuming_.setValue(true);
  }

  ChannelConsumerBase(ChannelConsumerBase&&) = default;
  ChannelConsumerBase& operator=(ChannelConsumerBase&&) = default;

  virtual ~ChannelConsumerBase() = default;

  // Executor on which the consuming coroutine / callback is scheduled.
  virtual folly::Executor::KeepAlive<folly::SequencedExecutor>
  getExecutor() = 0;

  // Invoked once per result taken from the receiver. In the coroutine modes a
  // result without a value (exception or empty Try) is the last one delivered.
  virtual void onNext(Try<TValue> result) = 0;

  // Starts processValuesCoro(receiver) on getExecutor(), attaching the token
  // of cancellationSource_. The task is started detached (start()), so this
  // call does not wait for consumption to finish.
  void startConsuming(Receiver<TValue> receiver) {
    folly::coro::co_withCancellation(
        cancellationSource_.getToken(), processValuesCoro(std::move(receiver)))
        .scheduleOn(getExecutor())
        .start();
  }

  // Consumes `receiver` according to mode_, forwarding every result to
  // onNext().
  folly::coro::Task<void> processValuesCoro(Receiver<TValue> receiver) {
    if (mode_ == ConsumptionMode::CoroWithTry ||
        mode_ == ConsumptionMode::CoroWithoutTry) {
      // Coroutine modes: pull one result per iteration.
      do {
        Try<TValue> resultTry;
        if (mode_ == ConsumptionMode::CoroWithTry) {
          resultTry = co_await folly::coro::co_awaitTry(receiver.next());
        } else if (mode_ == ConsumptionMode::CoroWithoutTry) {
          // Rebuild the same Try shape by hand: a value, an empty Try when
          // next() yields no value, or the thrown exception.
          try {
            auto result = co_await receiver.next();
            if (result.has_value()) {
              resultTry = Try<TValue>(result.value());
            } else {
              resultTry = Try<TValue>();
            }
          } catch (...) {
            resultTry = Try<TValue>(exception_wrapper(current_exception()));
          }
        } else {
          // Not reachable given the enclosing condition; kept as a guard.
          LOG(FATAL) << "Unknown consumption mode";
        }
        // Record hasValue before the Try is moved into onNext.
        bool hasValue = resultTry.hasValue();
        onNext(std::move(resultTry));
        if (!hasValue) {
          // Exception or empty Try: nothing more to consume.
          co_return;
        }
        // Keep looping only while continueConsuming_ resolves to true.
        // NOTE(review): detachOnCancel presumably lets cancellation interrupt
        // this wait when the promise is unfulfilled — confirm in
        // DetachOnCancel.h.
      } while (co_await folly::coro::detachOnCancel(
          continueConsuming_.getSemiFuture()));
    } else if (mode_ == ConsumptionMode::CallbackWithHandle) {
      // Callback mode: each result is passed to the lambda below, which
      // forwards it to onNext() and then returns the value that
      // continueConsuming_ resolves to.
      // NOTE(review): the returned bool presumably tells
      // consumeChannelWithCallback whether to keep delivering — confirm in
      // ConsumeChannel.h.
      auto callbackHandle = consumeChannelWithCallback(
          std::move(receiver),
          getExecutor(),
          [=, this](Try<TValue> resultTry) -> folly::coro::Task<bool> {
            onNext(std::move(resultTry));
            co_return co_await folly::coro::detachOnCancel(
                continueConsuming_.getSemiFuture());
          });
      // Tie the handle's lifetime to this coroutine's cancellation token: the
      // stored CancellationCallback owns the handle and resets it when
      // cancellation is requested.
      cancelCallback_ = std::make_unique<folly::CancellationCallback>(
          co_await folly::coro::co_current_cancellation_token,
          [=, handle = std::move(callbackHandle)]() mutable {
            handle.reset();
          });
    } else {
      LOG(FATAL) << "Unknown consumption mode";
    }
  }

 protected:
  // How processValuesCoro consumes the receiver.
  ConsumptionMode mode_;
  // Source of the cancellation token attached in startConsuming().
  folly::CancellationSource cancellationSource_;
  // Awaited after each delivered value; its value decides whether
  // consumption continues. Fulfilled with `true` by the constructor.
  folly::SharedPromise<bool> continueConsuming_;
  // Set only in CallbackWithHandle mode; holds the callback that resets the
  // consumeChannelWithCallback handle on cancellation.
  std::unique_ptr<folly::CancellationCallback> cancelCallback_;
};

// How a StressTestConsumer saw its channel end (set in
// StressTestConsumer::onNext):
//   NoException - the final result held neither a value nor an exception
//   Exception   - the final result held an exception other than
//                 folly::OperationCancelled
//   Cancelled   - the final result held folly::OperationCancelled
enum class CloseType { NoException, Exception, Cancelled };

// Channel consumer used by stress tests.
//
// Owns a one-thread IOThreadPoolExecutor and consumes on that thread's
// EventBase. Every received value is forwarded to the `onValue` function given
// at construction; the way the channel ends is published through the future
// returned by waitForClose().
template <typename TValue>
class StressTestConsumer : public ChannelConsumerBase<TValue> {
 public:
  // `mode` selects how the receiver is consumed (see ConsumptionMode);
  // `onValue` is called with each value delivered to onNext().
  StressTestConsumer(
      ConsumptionMode mode, folly::Function<void(TValue)> onValue)
      : ChannelConsumerBase<TValue>(mode),
        executor_(std::make_unique<folly::IOThreadPoolExecutor>(1)),
        onValue_(std::move(onValue)) {}

  // Not movable: the consuming coroutine started by startConsuming() is a
  // member function of this object and keeps referring to it.
  StressTestConsumer(StressTestConsumer&&) = delete;
  // Declared with the conventional `T&` return type even though it is deleted.
  StressTestConsumer& operator=(StressTestConsumer&&) = delete;

  // Requests cancellation, makes sure continueConsuming_ is fulfilled (with
  // `false` if nobody fulfilled it), then destroys the executor.
  // NOTE(review): destroying the executor presumably joins its thread so no
  // consumer code runs after the members below are gone — confirm.
  ~StressTestConsumer() override {
    this->cancellationSource_.requestCancellation();
    if (!this->continueConsuming_.isFulfilled()) {
      this->continueConsuming_.setValue(false);
    }
    executor_.reset();
  }

  // Consumption runs on the EventBase of the owned IO executor.
  folly::Executor::KeepAlive<folly::SequencedExecutor> getExecutor() override {
    return executor_->getEventBase();
  }

  // Forwards values to onValue_; any other result fulfils closedType_ with
  // the matching CloseType. A non-cancellation exception is additionally
  // expected (EXPECT_TRUE) to be a std::runtime_error.
  void onNext(Try<TValue> result) override {
    if (result.hasValue()) {
      onValue_(std::move(result.value()));
    } else if (result.template hasException<folly::OperationCancelled>()) {
      closedType_.setValue(CloseType::Cancelled);
    } else if (result.hasException()) {
      EXPECT_TRUE(result.template hasException<std::runtime_error>());
      closedType_.setValue(CloseType::Exception);
    } else {
      closedType_.setValue(CloseType::NoException);
    }
  }

  // Requests cancellation of the consuming task.
  void cancel() { this->cancellationSource_.requestCancellation(); }

  // Future that resolves to the CloseType recorded by onNext().
  folly::SemiFuture<CloseType> waitForClose() {
    return closedType_.getSemiFuture();
  }

 private:
  // Single-threaded executor whose EventBase runs the consumer.
  std::unique_ptr<folly::IOThreadPoolExecutor> executor_;
  // User callback invoked for every received value.
  folly::Function<void(TValue)> onValue_;
  // Fulfilled by onNext() with how the channel ended.
  folly::Promise<CloseType> closedType_;
};

// Channel producer used by stress tests.
//
// Owns a one-thread CPUThreadPoolExecutor. startProducing() schedules a
// coroutine on it that keeps writing values obtained from `getNextValue` to a
// Sender until stopProducing() is called, and then closes the sender.
template <typename TValue>
class StressTestProducer {
 public:
  // `getNextValue` is called once for every value written to the sender.
  explicit StressTestProducer(folly::Function<TValue()> getNextValue)
      : executor_(std::make_unique<folly::CPUThreadPoolExecutor>(1)),
        getNextValue_(std::move(getNextValue)) {}

  // Not movable: the producing coroutine captures `this`.
  StressTestProducer(StressTestProducer&&) = delete;
  // Declared with the conventional `T&` return type even though it is deleted.
  StressTestProducer& operator=(StressTestProducer&&) = delete;

  // Signals the producing loop to stop and then destroys the executor.
  // NOTE(review): destroying the executor presumably waits for the scheduled
  // task to finish before the remaining members are destroyed — confirm.
  ~StressTestProducer() {
    if (executor_) {
      stopProducing();
      executor_.reset();
    }
  }

  // Schedules the producing coroutine on the owned executor. Once stopped,
  // the sender is closed with `closeException` if one was supplied, otherwise
  // closed without an exception.
  void startProducing(
      Sender<TValue> sender, std::optional<exception_wrapper> closeException) {
    auto produceTask = folly::coro::co_invoke(
        [=,
         this,
         sender = std::move(sender),
         ex = std::move(closeException)]() mutable -> folly::coro::Task<void> {
          for (int i = 1; !stopped_.load(std::memory_order_relaxed); i++) {
            // Every 1000th iteration, sleep for 100ms before writing.
            if (i % 1000 == 0) {
              co_await folly::coro::sleep(std::chrono::milliseconds(100));
            }
            sender.write(getNextValue_());
          }
          if (ex.has_value()) {
            std::move(sender).close(std::move(ex.value()));
          } else {
            std::move(sender).close();
          }
          co_return;
        });
    std::move(produceTask).scheduleOn(executor_.get()).start();
  }

  // Asks the producing loop to exit; the loop checks the flag before each
  // write.
  void stopProducing() { stopped_.store(true); }

 private:
  // Single-threaded executor that runs the producing coroutine.
  std::unique_ptr<folly::CPUThreadPoolExecutor> executor_;
  // Source of the values written to the sender.
  folly::Function<TValue()> getNextValue_;
  // Set by stopProducing(); read by the producing loop.
  std::atomic<bool> stopped_{false};
};
} // namespace channels
} // namespace folly