folly/folly/channels/MultiplexChannel-inl.h

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <folly/channels/MultiplexChannel.h>
#include <folly/channels/RateLimiter.h>
#include <folly/experimental/channels/detail/Utility.h>
#include <folly/experimental/coro/FutureUtil.h>
#include <folly/experimental/coro/Mutex.h>
#include <folly/experimental/coro/Promise.h>

namespace folly {
namespace channels {

template <typename MultiplexerType>
MultiplexedSubscriptions<MultiplexerType>::MultiplexedSubscriptions(
    SubscriptionMap& subscriptions)
    : subscriptions_(subscriptions) {}

/**
 * Reports whether `key` currently has an open subscription. A key that was
 * closed through this view during the current callback counts as absent,
 * even though it is still present in the underlying map.
 */
template <typename MultiplexerType>
bool MultiplexedSubscriptions<MultiplexerType>::hasSubscription(
    const MultiplexedSubscriptions::KeyType& key) {
  if (closedSubscriptionKeys_.contains(key)) {
    return false;
  }
  return subscriptions_.contains(key);
}

/**
 * Returns a mutable reference to the per-key context stored alongside the
 * subscription's sender.
 *
 * Throws std::runtime_error if `key` has no open subscription.
 */
template <typename MultiplexerType>
typename MultiplexedSubscriptions<MultiplexerType>::KeyContextType&
MultiplexedSubscriptions<MultiplexerType>::getKeyContext(
    const MultiplexedSubscriptions::KeyType& key) {
  ensureKeyExists(key);
  auto& entry = subscriptions_.at(key);
  return std::get<KeyContextType>(entry);
}

/**
 * Writes `value` to every receiver subscribed to `key`, via the key's
 * FanoutSender.
 *
 * Throws std::runtime_error if `key` has no open subscription.
 */
template <typename MultiplexerType>
template <typename U>
void MultiplexedSubscriptions<MultiplexerType>::write(
    const MultiplexedSubscriptions::KeyType& key, U&& value) {
  ensureKeyExists(key);
  auto& entry = subscriptions_.at(key);
  std::get<FanoutSender<OutputValueType>>(entry).write(std::forward<U>(value));
}

/**
 * Closes the subscription for `key`. If `ex` is set, the key's sender is
 * closed with it; otherwise the sender is closed without an exception.
 *
 * Throws std::runtime_error if `key` has no open subscription.
 */
template <typename MultiplexerType>
void MultiplexedSubscriptions<MultiplexerType>::close(
    const MultiplexedSubscriptions::KeyType& key, exception_wrapper ex) {
  ensureKeyExists(key);
  auto& entry = subscriptions_.at(key);
  auto& sender = std::get<FanoutSender<OutputValueType>>(entry);
  if (!ex) {
    std::move(sender).close();
  } else {
    std::move(sender).close(std::move(ex));
  }
  // The entry stays in subscriptions_ for now so that the view returned by
  // getSubscriptionKeys is not invalidated; recording the key here makes it
  // look closed to hasSubscription/ensureKeyExists.
  closedSubscriptionKeys_.insert(key);
}

/**
 * Throws std::runtime_error unless `key` has an open subscription. This is
 * the same condition that hasSubscription() checks.
 */
template <typename MultiplexerType>
void MultiplexedSubscriptions<MultiplexerType>::ensureKeyExists(
    const KeyType& key) {
  if (hasSubscription(key)) {
    return;
  }
  throw std::runtime_error("Subscription with the given key does not exist.");
}

template <typename MultiplexerType>
MultiplexChannel<MultiplexerType>::MultiplexChannel(TProcessor* processor)
    : processor_(processor) {}

template <typename MultiplexerType>
MultiplexChannel<MultiplexerType>::MultiplexChannel(
    MultiplexChannel&& other) noexcept
    : processor_(std::exchange(other.processor_, nullptr)) {}

/**
 * Closes the processor currently held (if any), then takes over `other`'s
 * processor and leaves `other` empty. Self-assignment is a no-op.
 */
template <typename MultiplexerType>
MultiplexChannel<MultiplexerType>& MultiplexChannel<MultiplexerType>::operator=(
    MultiplexChannel&& other) noexcept {
  if (this != &other) {
    if (processor_ != nullptr) {
      std::move(*this).close();
    }
    processor_ = std::exchange(other.processor_, nullptr);
  }
  return *this;
}

/**
 * Closes the processor without an exception unless this handle is empty
 * (moved from, or already closed via close()).
 */
template <typename MultiplexerType>
MultiplexChannel<MultiplexerType>::~MultiplexChannel() {
  if (!processor_) {
    return;
  }
  std::move(*this).close(exception_wrapper());
}

/**
 * True while this handle still owns a processor, i.e. it has not been moved
 * from or closed.
 */
template <typename MultiplexerType>
MultiplexChannel<MultiplexerType>::operator bool() const {
  return processor_ != nullptr;
}

/**
 * Subscribes to `key`, forwarding `subscriptionArg` to the multiplexer. The
 * processor returns the output receiver right away; registration with the
 * multiplexer happens asynchronously.
 *
 * Precondition: this handle is non-empty (operator bool() is true).
 */
template <typename MultiplexerType>
Receiver<typename MultiplexChannel<MultiplexerType>::OutputValueType>
MultiplexChannel<MultiplexerType>::subscribe(
    KeyType key, SubscriptionArgType subscriptionArg) {
  auto* const processor = processor_;
  return processor->subscribe(std::move(key), std::move(subscriptionArg));
}

/**
 * Removes every key that currently has no subscribers. The task completes
 * with the removed keys and their contexts.
 *
 * NOTE(review): this is a coroutine, and folly::coro::Task starts lazily, so
 * processor_ is dereferenced only when the returned task is first awaited.
 * Await it before moving from or closing this handle.
 */
template <typename MultiplexerType>
folly::coro::Task<std::vector<std::pair<
    typename MultiplexChannel<MultiplexerType>::KeyType,
    typename MultiplexChannel<MultiplexerType>::KeyContextType>>>
MultiplexChannel<MultiplexerType>::clearUnusedSubscriptions() {
  auto cleared = co_await processor_->clearUnusedSubscriptions();
  co_return std::move(cleared);
}

/**
 * Asks the processor whether any keyed subscription exists. This includes
 * subscriptions whose registration is still pending.
 *
 * Precondition: this handle is non-empty.
 */
template <typename MultiplexerType>
bool MultiplexChannel<MultiplexerType>::anySubscribers() const {
  const bool hasAny = processor_->anySubscribers();
  return hasAny;
}

/**
 * Closes the channel and releases the processor. If `ex` is set, it is
 * propagated to the processor's close handling; otherwise the channel closes
 * without an exception. Afterwards this handle is empty.
 *
 * Precondition: this handle is non-empty.
 */
template <typename MultiplexerType>
void MultiplexChannel<MultiplexerType>::close(exception_wrapper ex) && {
  if (ex) {
    processor_->destroyHandle(detail::CloseResult(std::move(ex)));
  } else {
    processor_->destroyHandle(detail::CloseResult());
  }
  processor_ = nullptr;
}

namespace detail {

/**
 * This object multiplexes values from the input receiver to keyed output
 * subscriptions. The user-supplied multiplexer (onNewSubscription and
 * onInputValue) decides what each subscription receives. The lifetime of this
 * object is described by the following state machine.
 *
 * The input receiver can be in one of three conceptual states: Active,
 * CancellationTriggered, or CancellationProcessed (removed). When the input
 * receiver reaches the CancellationProcessed state AND the user's
 * MultiplexChannel object is deleted, this object is deleted.
 *
 * When an input receiver receives a value indicating that the channel has
 * been closed, the state of the input receiver transitions from Active directly
 * to CancellationProcessed (and this object will be deleted once the user
 * destroys their MultiplexChannel object).
 *
 * When the user destroys their MultiplexChannel object, the state of the input
 * receiver transitions from Active to CancellationTriggered. This object will
 * then be deleted once the input receiver transitions to the
 * CancellationProcessed state.
 *
 * All state changes go through executeWithMutexWhenReady, so they run one at
 * a time under mutex_ on the multiplexer's executor. Deletion also waits until
 * no such queued calls remain (pendingAsyncCalls_ == 0).
 */
template <typename MultiplexerType>
class MultiplexChannelProcessor : public IChannelCallback {
 private:
  using KeyType = typename detail::MultiplexerTraits<MultiplexerType>::KeyType;
  using KeyContextType =
      typename detail::MultiplexerTraits<MultiplexerType>::KeyContextType;
  using SubscriptionArgType =
      typename detail::MultiplexerTraits<MultiplexerType>::SubscriptionArgType;
  using InputValueType =
      typename detail::MultiplexerTraits<MultiplexerType>::InputValueType;
  using OutputValueType =
      typename detail::MultiplexerTraits<MultiplexerType>::OutputValueType;

 public:
  explicit MultiplexChannelProcessor(MultiplexerType multiplexer)
      : multiplexer_(std::move(multiplexer)),
        totalSubscriptions_(0),
        pendingAsyncCalls_(0) {}

  /**
   * Starts multiplexing values from the input receiver to one or more keyed
   * subscriptions. The work is queued through executeWithMutexWhenReady, so
   * this returns before any input value is processed.
   */
  void start(Receiver<InputValueType> inputReceiver) {
    executeWithMutexWhenReady(
        [this,
         inputReceiver =
             std::move(inputReceiver)]() mutable -> folly::coro::Task<void> {
          co_await processStart(std::move(inputReceiver));
        });
  }

  /**
   * Creates a new output channel for `key` and returns its receiver right
   * away. Registration with the multiplexer happens asynchronously in
   * processNewSubscription. totalSubscriptions_ is incremented up front, so
   * anySubscribers() also counts subscriptions that are still pending.
   */
  Receiver<OutputValueType> subscribe(
      KeyType key, SubscriptionArgType subscriptionArg) {
    auto [receiver, sender] = Channel<OutputValueType>::create();
    totalSubscriptions_.fetch_add(1);
    executeWithMutexWhenReady(
        [this,
         key = std::move(key),
         subscriptionArg = std::move(subscriptionArg),
         sender_2 = std::move(sender)]() mutable -> folly::coro::Task<void> {
          co_await processNewSubscription(
              std::move(key), std::move(subscriptionArg), std::move(sender_2));
        });
    return std::move(receiver);
  }

  /**
   * Queues removal of every key whose FanoutSender has no subscribers. The
   * returned task completes with the removed keys and their contexts once the
   * queued work (processClearUnusedSubscriptions) has run.
   */
  folly::coro::Task<std::vector<std::pair<KeyType, KeyContextType>>>
  clearUnusedSubscriptions() {
    auto [promise, future] = folly::coro::makePromiseContract<
        std::vector<std::pair<KeyType, KeyContextType>>>();
    executeWithMutexWhenReady(
        [this,
         promise_2 = std::move(promise)]() mutable -> folly::coro::Task<void> {
          co_await processClearUnusedSubscriptions(std::move(promise_2));
        });
    return folly::coro::toTask(std::move(future));
  }

  /**
   * Returns true if any keyed subscription exists or is pending registration.
   * Reads an atomic counter only, so it does not acquire mutex_.
   */
  bool anySubscribers() { return totalSubscriptions_.load() > 0; }

  /**
   * This is called when the user's MultiplexChannel object has been destroyed.
   */
  void destroyHandle(CloseResult closeResult) {
    executeWithMutexWhenReady(
        [this,
         closeResult =
             std::move(closeResult)]() mutable -> folly::coro::Task<void> {
          co_await processHandleDestroyed(std::move(closeResult));
        });
  }

 private:
  /**
   * Called when the input receiver has an update.
   */
  void consume(ChannelBridgeBase*) override {
    executeWithMutexWhenReady([this]() -> folly::coro::Task<void> {
      co_await processAllAvailableValues();
    });
  }

  /**
   * Called after we cancelled this input receiver, due to the destruction of
   * the handle.
   */
  void canceled(ChannelBridgeBase*) override {
    executeWithMutexWhenReady([this]() -> folly::coro::Task<void> {
      auto closeResult = CloseResult(); // Declaring first due to GCC bug
      co_await processReceiverCancelled(std::move(closeResult));
    });
  }

  /**
   * Takes ownership of the input receiver. Values already buffered in it are
   * split off into `buffer` and processed first.
   */
  folly::coro::Task<void> processStart(Receiver<InputValueType> inputReceiver) {
    auto [unbufferedInputReceiver, buffer] =
        detail::receiverUnbuffer(std::move(inputReceiver));
    receiver_ = std::move(unbufferedInputReceiver);

    // Start processing new values that come in from the input receiver.
    co_await processAllAvailableValues(std::move(buffer));
  }

  /**
   * Processes all available values from the input receiver (starting from the
   * provided buffer, if present).
   *
   * If a value was received indicating that the input channel has been closed
   * (or if the multiplexer's onInputValue indicated that the channel should be
   * closed), we will process cancellation for the input receiver.
   */
  folly::coro::Task<void> processAllAvailableValues(
      std::optional<ReceiverQueue<InputValueType>> buffer = std::nullopt) {
    CHECK_NE(getReceiverState(), ChannelState::CancellationProcessed);
    // If cancellation was already requested (e.g. the handle was destroyed),
    // skip the buffered values and go straight to cancellation processing.
    auto closeResult = receiver_->isReceiverCancelled()
        ? CloseResult()
        : (buffer.has_value()
               ? co_await processValues(std::move(buffer.value()))
               : std::nullopt);
    while (!closeResult.has_value()) {
      if (receiver_->receiverWait(this)) {
        // There are no more values available right now. We will stop processing
        // until the channel fires the consume() callback (indicating that more
        // values are available).
        break;
      }
      auto values = receiver_->receiverGetValues();
      CHECK(!values.empty());
      closeResult = co_await processValues(std::move(values));
    }
    if (closeResult.has_value()) {
      // The receiver received a value indicating channel closure.
      receiver_->receiverCancel();
      co_await processReceiverCancelled(std::move(closeResult.value()));
    }
  }

  /**
   * Processes the given set of values for the input receiver. Returns a
   * CloseResult if channel was closed, so the caller can stop attempting to
   * process values from it.
   *
   * Each value is handed to the multiplexer's onInputValue. Any subscriptions
   * the multiplexer closed during that call are then erased from
   * subscriptions_ and subtracted from totalSubscriptions_. If onInputValue
   * throws OnClosedException, the returned CloseResult carries no exception.
   * Any other exception it throws becomes the CloseResult's exception.
   */
  folly::coro::Task<std::optional<CloseResult>> processValues(
      ReceiverQueue<InputValueType> values) {
    while (!values.empty()) {
      auto inputResult = std::move(values.front());
      values.pop();
      bool inputClosed = !inputResult.hasValue();
      auto subscriptions =
          MultiplexedSubscriptions<MultiplexerType>(subscriptions_);
      if (inputClosed && !inputResult.hasException()) {
        // The input channel was closed. We will send an OnClosedException to
        // onInputValue.
        inputResult = Try<InputValueType>(
            folly::make_exception_wrapper<OnClosedException>());
      }

      // Process the input value by calling onInputValue on the user's
      // multiplexer.
      auto onInputValueResult = co_await folly::coro::co_awaitTry(
          multiplexer_.onInputValue(std::move(inputResult), subscriptions));

      // If the user closed any subscriptions, erase them from the subscriptions
      // map.
      for (const auto& key : subscriptions.closedSubscriptionKeys_) {
        subscriptions_.erase(key);
      }
      if (!subscriptions.closedSubscriptionKeys_.empty()) {
        totalSubscriptions_.fetch_sub(
            subscriptions.closedSubscriptionKeys_.size());
        subscriptions.closedSubscriptionKeys_.clear();
      }

      if (inputClosed && onInputValueResult.hasValue()) {
        // The input channel was closed, but the onInputValue function did not
        // throw. We need to close all output receivers.
        onInputValueResult =
            Try<void>(folly::make_exception_wrapper<OnClosedException>());
      }
      if (!onInputValueResult.hasValue()) {
        co_return onInputValueResult.template hasException<OnClosedException>()
            ? CloseResult()
            : CloseResult(std::move(onInputValueResult.exception()));
      }
    }
    co_return std::nullopt;
  }

  /**
   * Processes the cancellation of the input receiver. We will close all
   * senders with the exception received from the input receiver (if any).
   */
  folly::coro::Task<void> processReceiverCancelled(CloseResult closeResult) {
    CHECK_EQ(getReceiverState(), ChannelState::CancellationTriggered);
    receiver_ = nullptr;
    closeAllSubscriptions(std::move(closeResult));
    co_return;
  }

  /**
   * Registers `newSender` under `key`.
   *
   * subscribe() already incremented totalSubscriptions_, so if `key` is
   * already present that increment is undone here; the counter tracks keys.
   * The map entry (with a default-constructed sender and context) is created
   * before onNewSubscription runs. It is left in place even if
   * onNewSubscription throws, in which case `newSender` is closed with that
   * exception. Otherwise, the returned initial values are written to
   * `newSender` before it joins the key's FanoutSender.
   */
  folly::coro::Task<void> processNewSubscription(
      KeyType key,
      SubscriptionArgType subscriptionArg,
      Sender<OutputValueType> newSender) {
    if (subscriptions_.contains(key)) {
      // We already had a subscription for this key.
      totalSubscriptions_.fetch_sub(1);
    }
    auto& [sender, context] = subscriptions_[key];
    auto initialValues =
        co_await folly::coro::co_awaitTry(multiplexer_.onNewSubscription(
            key, context, std::move(subscriptionArg)));
    if (initialValues.hasException()) {
      std::move(newSender).close(initialValues.exception());
      co_return;
    }
    for (auto& initialValue : initialValues.value()) {
      newSender.write(std::move(initialValue));
    }
    sender.subscribe(std::move(newSender));
  }

  /**
   * Erases every key whose FanoutSender has no subscribers. It moves each
   * erased key and its context into the result fulfilled through `promise`,
   * and subtracts the erased count from totalSubscriptions_.
   */
  folly::coro::Task<void> processClearUnusedSubscriptions(
      folly::coro::Promise<std::vector<std::pair<KeyType, KeyContextType>>>
          promise) {
    auto clearedSubscriptions =
        std::vector<std::pair<KeyType, KeyContextType>>();
    size_t subscriptionsToRemove = 0;
    for (auto it = subscriptions_.begin(); it != subscriptions_.end();) {
      auto& sender = std::get<FanoutSender<OutputValueType>>(it->second);
      if (!sender.anySubscribers()) {
        clearedSubscriptions.push_back(std::make_pair(
            it->first, std::move(std::get<KeyContextType>(it->second))));
        it = subscriptions_.erase(it);
        subscriptionsToRemove++;
      } else {
        ++it;
      }
    }

    totalSubscriptions_.fetch_sub(subscriptionsToRemove);
    promise.setValue(std::move(clearedSubscriptions));
    co_return;
  }

  /**
   * Processes the destruction of the user's MultiplexChannel object.  We will
   * cancel the receiver and trigger cancellation for all senders not already
   * cancelled.
   */
  folly::coro::Task<void> processHandleDestroyed(CloseResult closeResult) {
    handleDeleted_ = true;
    if (getReceiverState() == ChannelState::Active) {
      receiver_->receiverCancel();
    }
    closeAllSubscriptions(std::move(closeResult));
    co_return;
  }

  /**
   * Deletes this object if three conditions hold: cancellation of the input
   * receiver has been processed, the user's MultiplexChannel object was
   * destroyed, and no calls queued by executeWithMutexWhenReady remain
   * outstanding. The lock is released before deletion because mutex_ is a
   * member of this object.
   */
  void maybeDelete(std::unique_lock<folly::coro::Mutex>& lock) {
    if (getReceiverState() == ChannelState::CancellationProcessed &&
        handleDeleted_ && pendingAsyncCalls_ == 0) {
      lock.unlock();
      delete this;
    }
  }

  /**
   * Runs `func` on the multiplexer's executor while holding mutex_. If the
   * multiplexer provides a rate limiter, execution is also deferred until the
   * limiter grants a token. The token is captured into the coroutine alongside
   * `func`.
   *
   * pendingAsyncCalls_ is incremented synchronously here and decremented after
   * `func` completes, which keeps maybeDelete from deleting this object while
   * work is still queued.
   *
   * NOTE(review): if `func()` throws, the decrement and maybeDelete are
   * skipped, so this object would never be deleted. Current callers route
   * multiplexer exceptions through co_awaitTry, which appears to avoid this.
   */
  void executeWithMutexWhenReady(
      folly::Function<folly::coro::Task<void>()> func) {
    pendingAsyncCalls_++;
    auto rateLimiter = multiplexer_.getRateLimiter();
    if (rateLimiter != nullptr) {
      rateLimiter->executeWhenReady(
          [this, func = std::move(func), executor = multiplexer_.getExecutor()](
              std::unique_ptr<RateLimiter::Token> token) mutable {
            folly::coro::co_invoke(
                [this,
                 token = std::move(token),
                 func = std::move(func)]() mutable -> folly::coro::Task<void> {
                  auto lock = co_await mutex_.co_scoped_lock();
                  co_await func();
                  pendingAsyncCalls_--;
                  maybeDelete(lock);
                })
                .scheduleOn(executor)
                .start();
          },
          multiplexer_.getExecutor());
    } else {
      folly::coro::co_invoke(
          [this, func = std::move(func)]() mutable -> folly::coro::Task<void> {
            auto lock = co_await mutex_.co_scoped_lock();
            co_await func();
            pendingAsyncCalls_--;
            maybeDelete(lock);
          })
          .scheduleOn(multiplexer_.getExecutor())
          .start();
    }
  }

  /**
   * Returns the state of the input receiver, derived from receiver_.
   * receiver_ is reset to null in processReceiverCancelled, which presumably
   * maps to CancellationProcessed (TODO confirm in detail::getReceiverState).
   */
  ChannelState getReceiverState() {
    return detail::getReceiverState(receiver_.get());
  }

  /**
   * Closes every subscription's sender, using closeResult's exception if one
   * is present. Then removes all subscriptions from the map and subtracts them
   * from totalSubscriptions_.
   */
  void closeAllSubscriptions(CloseResult closeResult) {
    for (auto& [key, subscription] : subscriptions_) {
      auto& sender = std::get<FanoutSender<OutputValueType>>(subscription);
      std::move(sender).close(
          closeResult.exception.has_value() ? closeResult.exception.value()
                                            : exception_wrapper());
    }
    totalSubscriptions_.fetch_sub(subscriptions_.size());
    subscriptions_.clear();
  }

  using SubscriptionMap = folly::F14FastMap<
      KeyType,
      std::tuple<FanoutSender<OutputValueType>, KeyContextType>>;

  coro::Mutex mutex_;

  // The above coro mutex must be acquired before accessing this state.
  ChannelBridgePtr<InputValueType> receiver_;
  SubscriptionMap subscriptions_;
  bool handleDeleted_{false};

  // The above coro mutex does not need to be acquired before accessing this
  // state.
  MultiplexerType multiplexer_;
  std::atomic<uint64_t> totalSubscriptions_; // Includes pending subscriptions
  std::atomic<uint64_t> pendingAsyncCalls_;
};
} // namespace detail

/**
 * Creates a MultiplexChannel that feeds values from `inputReceiver` through
 * `multiplexer`. The heap-allocated processor is started immediately, and the
 * returned handle is responsible for closing it. The processor deletes itself
 * once both the handle and the input receiver are done.
 */
template <typename MultiplexerType, typename InputReceiverType>
MultiplexChannel<MultiplexerType> createMultiplexChannel(
    MultiplexerType multiplexer, InputReceiverType inputReceiver) {
  using Processor = detail::MultiplexChannelProcessor<MultiplexerType>;
  Processor* const processor = new Processor(std::move(multiplexer));
  processor->start(std::move(inputReceiver));
  return MultiplexChannel<MultiplexerType>(processor);
}
} // namespace channels
} // namespace folly