folly/folly/channels/ChannelProcessor-inl.h

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <type_traits>

#include <folly/channels/ChannelProcessor.h>
#include <folly/channels/ConsumeChannel.h>
#include <folly/channels/MergeChannel.h>
#include <folly/channels/Transform.h>
#include <folly/executors/SerialExecutor.h>
#include <folly/experimental/channels/detail/IntrusivePtr.h>

namespace folly {
namespace channels {
namespace detail {

/**
 * Implementation object behind ChannelProcessor<KeyType>.
 *
 * Each channel added here is wrapped in a transform that runs the user's
 * callback for every incoming value and never yields an output value. The
 * transformed receiver is registered in a MergeChannel under its key, and the
 * merged output is consumed by a callback that ignores every event.
 */
template <typename KeyType>
class ChannelProcessorImpl {
 public:
  /**
   * @param executors             Executors the per-channel callbacks are
   *                              distributed over. Must be non-empty; the
   *                              first one also consumes the merge channel.
   * @param rateLimiter           Rate limiter handed to every transform.
   * @param mergeChannel          Merge channel that owns the per-key
   *                              receivers.
   * @param mergeChannelReceiver  Output side of mergeChannel. Its events are
   *                              consumed and discarded.
   */
  ChannelProcessorImpl(
      std::vector<folly::Executor::KeepAlive<folly::SequencedExecutor>>
          executors,
      std::shared_ptr<folly::channels::RateLimiter> rateLimiter,
      MergeChannel<KeyType, Unit> mergeChannel,
      Receiver<MergeChannelEvent<KeyType, Unit>> mergeChannelReceiver)
      : implState_(make_intrusive<ImplState>(
            std::move(executors), std::move(rateLimiter))),
        channels_(std::move(mergeChannel)),
        handle_(consumeChannelWithCallback(
            std::move(mergeChannelReceiver),
            implState_->executors[0],
            [](Try<MergeChannelEvent<KeyType, Unit>>)
                -> folly::coro::Task<bool> {
              // Do nothing
              co_return true;
            })) {}

  /**
   * Registers `receiver` under `key`. Any receiver already registered under
   * the same key is removed first.
   *
   * @param onUpdate  Called as onUpdate(Try<InputValueType>) for every value
   *                  coming out of `receiver`; must return a coroutine task.
   */
  template <typename ReceiverType, typename OnUpdateFunc>
  void addChannel(KeyType key, ReceiverType receiver, OnUpdateFunc onUpdate) {
    using InputValueType = typename ReceiverType::ValueType;
    channels_.removeReceiver(key);
    channels_.addNewReceiver(
        std::move(key),
        transform(
            std::move(receiver),
            Transformer<InputValueType, OnUpdateFunc>(
                implState_, std::move(onUpdate))));
  }

  /**
   * Registers a resumable channel that carries no per-channel state.
   *
   * This is the entry point ChannelProcessor<KeyType>::addResumableChannel
   * forwards to.
   *
   * @param initializeArg  Argument handed to `initialize`.
   * @param initialize     Called as initialize(InitializeArg); must return a
   *                       task producing the receiver to consume.
   * @param onUpdate       Called as onUpdate(Try<InputValueType>) for every
   *                       value; must return a coroutine task.
   */
  template <
      typename InitializeArg,
      typename InitializeFunc,
      typename OnUpdateFunc>
  void addResumableChannel(
      KeyType key,
      InitializeArg initializeArg,
      InitializeFunc initialize,
      OnUpdateFunc onUpdate) {
    addResumableChannelWithState(
        std::move(key),
        std::move(initializeArg),
        std::move(initialize),
        std::move(onUpdate),
        NoChannelState());
  }

  /**
   * Stateless four-argument overload, retained for source compatibility.
   * Prefer addResumableChannel(), which it forwards to.
   */
  template <
      typename InitializeArg,
      typename InitializeFunc,
      typename OnUpdateFunc>
  void addResumableChannelWithState(
      KeyType key,
      InitializeArg initializeArg,
      InitializeFunc initialize,
      OnUpdateFunc onUpdate) {
    addResumableChannel(
        std::move(key),
        std::move(initializeArg),
        std::move(initialize),
        std::move(onUpdate));
  }

  /**
   * Registers a resumable channel under `key`, removing any receiver that is
   * already registered under the same key.
   *
   * @param initializeArg  Argument handed to `initialize`.
   * @param initialize     Called as initialize(InitializeArg, ChannelState&),
   *                       or as initialize(InitializeArg) when ChannelState
   *                       is NoChannelState. Must return a task producing the
   *                       receiver to consume.
   * @param onUpdate       Called as onUpdate(Try<InputValueType>,
   *                       ChannelState&), or without the state argument when
   *                       ChannelState is NoChannelState.
   * @param channelState   State object stored alongside the callbacks and
   *                       passed to them by reference.
   */
  template <
      typename InitializeArg,
      typename InitializeFunc,
      typename OnUpdateFunc,
      typename ChannelState>
  void addResumableChannelWithState(
      KeyType key,
      InitializeArg initializeArg,
      InitializeFunc initialize,
      OnUpdateFunc onUpdate,
      ChannelState channelState) {
    // Deduce the task type from the same call shape that
    // ResumableTransformer::initialize() uses, so a stateless initialize
    // function taking a single argument is accepted. std::invoke_result (not
    // the _t alias) keeps the unselected branch from being evaluated.
    using InitializeResult = typename std::conditional_t<
        std::is_same_v<ChannelState, NoChannelState>,
        std::invoke_result<InitializeFunc&, InitializeArg>,
        std::invoke_result<InitializeFunc&, InitializeArg, ChannelState&>>::
        type;
    using ReceiverType = typename InitializeResult::StorageType;
    using InputValueType = typename ReceiverType::ValueType;
    channels_.removeReceiver(key);
    channels_.addNewReceiver(
        std::move(key),
        resumableTransform(
            std::move(initializeArg),
            ResumableTransformer<
                InitializeArg,
                InputValueType,
                InitializeFunc,
                OnUpdateFunc,
                ChannelState>(
                implState_,
                std::move(initialize),
                std::move(onUpdate),
                std::move(channelState))));
  }

  /**
   * Removes the receiver registered under `keyType` from the merge channel.
   */
  void removeChannel(const KeyType& keyType) {
    channels_.removeReceiver(keyType);
  }

 private:
  // Tag type standing in for "no per-channel state". The resumable
  // transformer omits the state argument from callbacks when it sees it.
  struct NoChannelState {};

  /**
   * Invokes `func`, which must return a coroutine task. If `func` itself
   * throws before producing its task, the exception is turned into a task
   * that completes with that exception, so callers observe both kinds of
   * failure through the returned task.
   */
  template <
      typename Function,
      typename ReturnType =
          typename std::invoke_result_t<Function>::StorageType>
  static folly::coro::Task<ReturnType> catchNonCoroException(Function func) {
    auto result = folly::makeTryWith(std::move(func));
    if (result.hasException()) {
      return folly::coro::makeErrorTask<ReturnType>(
          std::move(result.exception()));
    } else {
      return std::move(result.value());
    }
  }

  // State shared (via intrusive_ptr) between the processor and every
  // transformer it creates.
  struct ImplState : public IntrusivePtrBase<ImplState> {
    ImplState(
        std::vector<folly::Executor::KeepAlive<folly::SequencedExecutor>>
            _executors,
        std::shared_ptr<folly::channels::RateLimiter> _rateLimiter)
        : executors(std::move(_executors)),
          rateLimiter(std::move(_rateLimiter)) {
      // executors[0] is read by the processor and the transformers index
      // modulo executors.size(); both are invalid for an empty list.
      CHECK(!executors.empty());
    }

    std::vector<folly::Executor::KeepAlive<folly::SequencedExecutor>> executors;
    std::shared_ptr<folly::channels::RateLimiter> rateLimiter;
  };

  /**
   * Transformer passed to transform() for channels added with addChannel().
   * The user callback is stored in the std::tuple base and fetched with
   * std::get.
   */
  template <typename InputValueType, typename OnUpdateFunc>
  class Transformer : public std::tuple<OnUpdateFunc> {
   public:
    Transformer(intrusive_ptr<ImplState> implState, OnUpdateFunc onUpdate)
        : std::tuple<OnUpdateFunc>(std::move(onUpdate)),
          implState_(std::move(implState)) {}

    // Picks one of the configured executors by hashing the address of this
    // transformer object.
    folly::Executor::KeepAlive<folly::SequencedExecutor> getExecutor() {
      return implState_->executors
          [std::hash<decltype(this)>()(this) % implState_->executors.size()];
    }

    std::shared_ptr<folly::channels::RateLimiter> getRateLimiter() {
      return implState_->rateLimiter;
    }

    /**
     * Runs the user callback on `value`. No output value is ever yielded.
     * OperationCancelled and OnClosedException from the callback are both
     * reported as OnClosedException; any other exception is fatal.
     */
    folly::coro::AsyncGenerator<Unit&&> transformValue(
        Try<InputValueType> value) {
      auto result = co_await folly::coro::co_awaitTry(catchNonCoroException(
          [&] { return std::get<OnUpdateFunc>(*this)(std::move(value)); }));
      if (result.template hasException<folly::OperationCancelled>() ||
          result.template hasException<OnClosedException>()) {
        co_yield folly::coro::co_error(OnClosedException());
      } else if (result.hasException()) {
        LOG(FATAL) << fmt::format(
            "Encountered exception from callback when consuming channel of "
            "type {}: {}",
            typeid(InputValueType).name(),
            result.exception().what());
      }
    }

   private:
    intrusive_ptr<ImplState> implState_;
  };

  /**
   * Transformer passed to resumableTransform() for channels added with
   * addResumableChannel() / addResumableChannelWithState(). The initialize
   * callback, update callback and channel state are stored in the std::tuple
   * base and fetched with std::get.
   */
  template <
      typename InitializeArg,
      typename InputValueType,
      typename InitializeFunc,
      typename OnUpdateFunc,
      typename ChannelState>
  class ResumableTransformer
      : public std::tuple<InitializeFunc, OnUpdateFunc, ChannelState> {
   public:
    ResumableTransformer(
        intrusive_ptr<ImplState> implState,
        InitializeFunc initialize,
        OnUpdateFunc onUpdate,
        ChannelState channelState)
        : std::tuple<InitializeFunc, OnUpdateFunc, ChannelState>(
              std::move(initialize),
              std::move(onUpdate),
              std::move(channelState)),
          implState_(std::move(implState)) {}

    // Picks one of the configured executors by hashing the address of this
    // transformer object.
    folly::Executor::KeepAlive<folly::SequencedExecutor> getExecutor() {
      return implState_->executors
          [std::hash<decltype(this)>()(this) % implState_->executors.size()];
    }

    std::shared_ptr<folly::channels::RateLimiter> getRateLimiter() {
      return implState_->rateLimiter;
    }

    /**
     * Runs the user's initialize callback and returns the receiver it
     * produced together with an empty list of initial output values.
     * OperationCancelled and OnClosedException are both reported as
     * OnClosedException; any other exception is fatal.
     */
    folly::coro::Task<std::pair<std::vector<Unit>, Receiver<InputValueType>>>
    initializeTransform(InitializeArg initializeArg) {
      auto result = co_await folly::coro::co_awaitTry(
          initialize(std::move(initializeArg)));
      if (result.template hasException<folly::OperationCancelled>() ||
          result.template hasException<OnClosedException>()) {
        co_yield folly::coro::co_error(OnClosedException());
      } else if (result.hasException()) {
        LOG(FATAL) << folly::sformat(
            "Encountered exception from callback when consuming channel of "
            "type {}: {}",
            typeid(InputValueType).name(),
            result.exception().what());
      }
      co_return std::make_pair(std::vector<Unit>(), std::move(result.value()));
    }

    /**
     * Runs the user's update callback on `value`. No output value is ever
     * yielded. A ReinitializeException<InitializeArg> is forwarded unchanged;
     * OperationCancelled and OnClosedException are both reported as
     * OnClosedException; any other exception is fatal.
     */
    folly::coro::AsyncGenerator<Unit&&> transformValue(
        Try<InputValueType> value) {
      auto result =
          co_await folly::coro::co_awaitTry(onUpdate(std::move(value)));
      if (result
              .template hasException<ReinitializeException<InitializeArg>>()) {
        co_yield folly::coro::co_error(std::move(result.exception()));
      } else if (
          result.template hasException<folly::OperationCancelled>() ||
          result.template hasException<OnClosedException>()) {
        co_yield folly::coro::co_error(OnClosedException());
      } else if (result.hasException()) {
        LOG(FATAL) << folly::sformat(
            "Encountered exception from callback when consuming channel of "
            "type {}: {}",
            typeid(InputValueType).name(),
            result.exception().what());
      }
    }

   private:
    // Calls the stored initialize callback, passing the channel state only
    // when the channel actually has one.
    folly::coro::Task<Receiver<InputValueType>> initialize(
        InitializeArg initializeArg) {
      if constexpr (std::is_same_v<ChannelState, NoChannelState>) {
        co_return co_await catchNonCoroException([&] {
          return std::get<InitializeFunc>(*this)(std::move(initializeArg));
        });
      } else {
        co_return co_await catchNonCoroException([&] {
          return std::get<InitializeFunc>(*this)(
              std::move(initializeArg), std::get<ChannelState>(*this));
        });
      }
    }

    // Calls the stored update callback, passing the channel state only when
    // the channel actually has one.
    folly::coro::Task<void> onUpdate(Try<InputValueType> value) {
      if constexpr (std::is_same_v<ChannelState, NoChannelState>) {
        co_await catchNonCoroException(
            [&] { return std::get<OnUpdateFunc>(*this)(std::move(value)); });
      } else {
        co_await catchNonCoroException([&] {
          return std::get<OnUpdateFunc>(*this)(
              std::move(value), std::get<ChannelState>(*this));
        });
      }
    }

    intrusive_ptr<ImplState> implState_;
  };

  // Declaration order matters: handle_ is initialized from implState_.
  intrusive_ptr<ImplState> implState_;
  MergeChannel<KeyType, Unit> channels_;
  ChannelCallbackHandle handle_;
};
} // namespace detail

template <typename KeyType>
ChannelProcessor<KeyType>::ChannelProcessor(
    std::unique_ptr<detail::ChannelProcessorImpl<KeyType>> impl)
    : impl_(std::move(impl)) {}

// True while this processor still holds an implementation object.
template <typename KeyType>
ChannelProcessor<KeyType>::operator bool() const {
  return static_cast<bool>(impl_);
}

// Hands the key, receiver and update callback over to the implementation
// object. All three are moved.
template <typename KeyType>
template <typename ReceiverType, typename OnUpdateFunc>
void ChannelProcessor<KeyType>::addChannel(
    KeyType key, ReceiverType receiver, OnUpdateFunc onUpdate) {
  auto& impl = *impl_;
  impl.addChannel(std::move(key), std::move(receiver), std::move(onUpdate));
}

// Hands the key, initialize argument and both callbacks over to the
// implementation object's addResumableChannel(). All arguments are moved.
template <typename KeyType>
template <
    typename InitializeArg,
    typename InitializeFunc,
    typename OnUpdateFunc>
void ChannelProcessor<KeyType>::addResumableChannel(
    KeyType key,
    InitializeArg initializeArg,
    InitializeFunc initialize,
    OnUpdateFunc onUpdate) {
  auto& impl = *impl_;
  impl.addResumableChannel(
      std::move(key),
      std::move(initializeArg),
      std::move(initialize),
      std::move(onUpdate));
}

// Hands the key, initialize argument, both callbacks and the per-channel
// state over to the implementation object. All arguments are moved.
template <typename KeyType>
template <
    typename InitializeArg,
    typename InitializeFunc,
    typename OnUpdateFunc,
    typename ChannelState>
void ChannelProcessor<KeyType>::addResumableChannelWithState(
    KeyType key,
    InitializeArg initializeArg,
    InitializeFunc initialize,
    OnUpdateFunc onUpdate,
    ChannelState channelState) {
  auto& impl = *impl_;
  impl.addResumableChannelWithState(
      std::move(key),
      std::move(initializeArg),
      std::move(initialize),
      std::move(onUpdate),
      std::move(channelState));
}

// Asks the implementation object to remove the channel registered under the
// given key.
template <typename KeyType>
void ChannelProcessor<KeyType>::removeChannel(const KeyType& keyType) {
  auto& impl = *impl_;
  impl.removeChannel(keyType);
}

// Destroys the implementation object. Callable on rvalues only; afterwards
// operator bool() reports false.
template <typename KeyType>
void ChannelProcessor<KeyType>::close() && {
  impl_ = nullptr;
}

// Builds a ChannelProcessor on top of a caller-supplied, non-empty list of
// sequenced executors. The merge channel backing the processor is created on
// the first executor in the list.
template <typename KeyType>
ChannelProcessor<KeyType> createChannelProcessor(
    std::vector<folly::Executor::KeepAlive<folly::SequencedExecutor>> executors,
    std::shared_ptr<RateLimiter> rateLimiter) {
  CHECK_GT(executors.size(), 0);
  // executors[0] has to be read here, before the vector is moved below.
  auto [eventReceiver, merged] =
      createMergeChannel<KeyType, Unit>(executors[0]);
  auto impl = std::make_unique<detail::ChannelProcessorImpl<KeyType>>(
      std::move(executors),
      std::move(rateLimiter),
      std::move(merged),
      std::move(eventReceiver));
  return ChannelProcessor<KeyType>(std::move(impl));
}

// Builds a ChannelProcessor whose callbacks run on `numSequencedExecutors`
// SerialExecutors, each created on top of `executor`.
//
// @param executor               Executor every SerialExecutor is created on.
// @param rateLimiter            Forwarded to the executor-list overload.
// @param numSequencedExecutors  Number of SerialExecutors to create; must be
//                               greater than zero.
template <typename KeyType>
ChannelProcessor<KeyType> createChannelProcessor(
    folly::Executor::KeepAlive<> executor,
    std::shared_ptr<RateLimiter> rateLimiter,
    size_t numSequencedExecutors) {
  CHECK_GT(numSequencedExecutors, 0);
  std::vector<folly::Executor::KeepAlive<folly::SequencedExecutor>> executors;
  // The final size is known up front, so allocate once instead of growing the
  // vector on each push_back.
  executors.reserve(numSequencedExecutors);
  for (size_t i = 0; i < numSequencedExecutors; ++i) {
    executors.push_back(folly::SerialExecutor::create(executor));
  }
  return createChannelProcessor<KeyType>(
      std::move(executors), std::move(rateLimiter));
}
} // namespace channels
} // namespace folly