folly/folly/channels/Transform-inl.h

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <folly/channels/Channel.h>
#include <folly/executors/SequencedExecutor.h>
#include <folly/experimental/channels/detail/Utility.h>
#include <folly/experimental/coro/AsyncGenerator.h>
#include <folly/experimental/coro/Task.h>

namespace folly {
namespace channels {

namespace detail {

/**
 * This object transforms values from the input receiver to the output receiver.
 * It is not an object that the user is aware of or holds a pointer to. The
 * lifetime of this object is described by the following state machine.
 *
 * Both the sender and receiver can be in one of three states: Active,
 * CancellationTriggered, or CancellationProcessed. When both the sender and
 * receiver reach the CancellationProcessed state, this object is deleted.
 *
 * When the input receiver receives a value indicating that the channel has been
 * closed, the state of the receiver transitions from Active directly to
 * CancellationProcessed and the state of the sender transitions from Active to
 * CancellationTriggered. Once we receive a callback indicating the sender's
 * cancellation signal has been received, the sender's state is transitioned
 * from CancellationTriggered to CancellationProcessed (and the object is
 * deleted).
 *
 * When the sender receives notification that the consumer of the output
 * receiver has stopped consuming, the state of the sender transitions from
 * Active directly to CancellationProcessed, and the state of the input receiver
 * transitions from Active to CancellationTriggered. Once we receive a callback
 * indicating that the input receiver's cancellation signal has been received,
 * the input receiver's state is transitioned from CancellationTriggered to
 * to CancellationProcessed (and the object is deleted).
 */
template <
    typename InputValueType,
    typename OutputValueType,
    typename TransformerType>
class TransformProcessorBase : public IChannelCallback {
 public:
  TransformProcessorBase(
      Sender<OutputValueType> sender, TransformerType transformer)
      : sender_(std::move(senderGetBridge(sender))),
        transformer_(std::move(transformer)) {}

  template <typename ReceiverType>
  void startTransform(ReceiverType receiver) {
    executeWhenReady([=, this, receiver = std::move(receiver)](
                         std::unique_ptr<RateLimiter::Token> token) mutable {
      runOperationWithSenderCancellation(
          transformer_.getExecutor(),
          this->sender_,
          false /* alreadyStartedWaiting */,
          this /* channelCallbackToRestore */,
          startTransformImpl(std::move(receiver)),
          std::move(token));
    });
  }

 protected:
  /**
   * Starts transforming values from the input receiver and sending the
   * resulting transformed values to the output receiver.
   *
   * @param inputReceiver: The input receiver to transform values from.
   */
  folly::coro::Task<void> startTransformImpl(
      Receiver<InputValueType> receiver) {
    auto [unbufferedInputReceiver, buffer] =
        detail::receiverUnbuffer(std::move(receiver));
    receiver_ = std::move(unbufferedInputReceiver);
    co_await processAllAvailableValues(std::move(buffer));
  }

  /**
   * This is called when one of the channels we are listening to has an update
   * (either a value from the input receiver or a cancellation signal from the
   * sender).
   */
  void consume(ChannelBridgeBase* bridge) override {
    executeWhenReady([=, this](std::unique_ptr<RateLimiter::Token> token) {
      if (bridge == receiver_.get()) {
        // We have received new values from the input receiver.
        CHECK_NE(getReceiverState(), ChannelState::CancellationProcessed);
        runOperationWithSenderCancellation(
            transformer_.getExecutor(),
            this->sender_,
            true /* alreadyStartedWaiting */,
            this /* channelCallbackToRestore */,
            processAllAvailableValues(),
            std::move(token));
      } else {
        CHECK_NE(getSenderState(), ChannelState::CancellationProcessed);
        // The consumer of the output receiver has stopped consuming.
        if (getSenderState() == ChannelState::Active) {
          sender_->senderClose();
        }
        processSenderCancelled();
      }
    });
  }

  /**
   * This is called after we explicitly cancel one of the channels we are
   * listening to.
   */
  void canceled(ChannelBridgeBase* bridge) override {
    executeWhenReady([=, this](std::unique_ptr<RateLimiter::Token> token) {
      if (bridge == receiver_.get()) {
        // We previously cancelled the input receiver (because the consumer of
        // the output receiver stopped consuming). Process the cancellation for
        // the input receiver.
        CHECK_EQ(getReceiverState(), ChannelState::CancellationTriggered);
        runOperationWithSenderCancellation(
            transformer_.getExecutor(),
            this->sender_,
            true /* alreadyStartedWaiting */,
            this /* channelCallbackToRestore */,
            processReceiverCancelled(CloseResult()),
            std::move(token));
      } else {
        // We previously cancelled the sender due to the closure of the input
        // receiver. Process the cancellation for the sender.
        CHECK_EQ(getSenderState(), ChannelState::CancellationTriggered);
        processSenderCancelled();
      }
    });
  }

  /**
   * Processes all available values from the input receiver (starting from the
   * provided buffer, if present).
   *
   * If a value was received indicating that the input channel has been closed
   * (or if the transform function indicated that channel should be closed), we
   * will process cancellation for the input receiver.
   */
  folly::coro::Task<void> processAllAvailableValues(
      std::optional<ReceiverQueue<InputValueType>> buffer = std::nullopt) {
    auto closeResult = buffer.has_value()
        ? co_await processValues(std::move(buffer.value()))
        : std::nullopt;
    while (!closeResult.has_value()) {
      if (receiver_->receiverWait(this)) {
        // There are no more values available right now, but more values may
        // come in the future. We will stop processing for now, until we
        // re-start processing when the consume() callback is fired.
        break;
      }
      auto values = receiver_->receiverGetValues();
      CHECK(!values.empty());
      closeResult = co_await processValues(std::move(values));
    }
    if (closeResult.has_value()) {
      // The output receiver should be closed (either because the input receiver
      // was closed or the transform function desired the closure of the output
      // receiver).
      receiver_->receiverCancel();
      co_await processReceiverCancelled(std::move(closeResult.value()));
    }
  }

  /**
   * Processes the given set of values for the input receiver. If the output
   * receiver should be closed (either because the input receiver was closed or
   * the transform function desired the closure of the output receiver), a
   * CloseResult is returned containing the exception (if any) that should be
   * used to close the output receiver.
   */
  folly::coro::Task<std::optional<CloseResult>> processValues(
      ReceiverQueue<InputValueType> values) {
    auto cancelToken = co_await folly::coro::co_current_cancellation_token;
    while (!values.empty()) {
      auto inputResult = std::move(values.front());
      values.pop();
      bool inputClosed = !inputResult.hasValue();
      if (!inputResult.hasValue() && !inputResult.hasException()) {
        inputResult = Try<InputValueType>(OnClosedException());
      }
      auto outputGen = folly::makeTryWith([&]() {
        return transformer_.transformValue(std::move(inputResult));
      });
      if (!outputGen.hasValue()) {
        // The transform function threw an exception and was not a coroutine.
        // We will close the output receiver.
        co_return outputGen.template hasException<OnClosedException>()
            ? CloseResult()
            : CloseResult(std::move(outputGen.exception()));
      }
      while (true) {
        auto outputResult =
            co_await folly::coro::co_awaitTry(outputGen->next());
        if (!outputResult.hasException() && !outputResult->has_value()) {
          break;
        }
        if (cancelToken.isCancellationRequested()) {
          co_return CloseResult();
        }
        if (!outputResult.hasException()) {
          sender_->senderPush(std::move(outputResult->value()));
        } else {
          // The transform coroutine threw an exception. We will close the
          // output receiver.
          co_return outputResult.template hasException<OnClosedException>()
              ? CloseResult()
              : CloseResult(std::move(outputResult.exception()));
        }
      }
      if (inputClosed) {
        // The input receiver was closed, and the transform function did not
        // explicitly close the output receiver. We will therefore close it
        // anyway, as it does not make sense to keep it open when no future
        // values will arrive.
        co_return CloseResult();
      }
    }
    co_return std::nullopt;
  }

  /**
   * Process cancellation for the input receiver.
   */
  virtual folly::coro::Task<void> processReceiverCancelled(
      CloseResult closeResult, bool noRetriesAllowed = false) = 0;

  /**
   * Process cancellation for the sender.
   */
  void processSenderCancelled() {
    CHECK_EQ(getSenderState(), ChannelState::CancellationTriggered);
    sender_ = nullptr;
    if (getReceiverState() == ChannelState::Active) {
      receiver_->receiverCancel();
    }
    maybeDelete();
  }

  /**
   * Deletes this object if we have already processed cancellation for the
   * receiver and the sender.
   */
  void maybeDelete() {
    if (getReceiverState() == ChannelState::CancellationProcessed &&
        getSenderState() == ChannelState::CancellationProcessed) {
      delete this;
    }
  }

  ChannelState getReceiverState() {
    return detail::getReceiverState(receiver_.get());
  }

  ChannelState getSenderState() {
    return detail::getSenderState(sender_.get());
  }

  void executeWhenReady(
      folly::Function<void(std::unique_ptr<RateLimiter::Token>)> func) {
    auto rateLimiter = transformer_.getRateLimiter();
    if (rateLimiter != nullptr) {
      rateLimiter->executeWhenReady(
          std::move(func), transformer_.getExecutor());
    } else {
      transformer_.getExecutor()->add([func = std::move(func)]() mutable {
        func(std::unique_ptr<RateLimiter::Token>(nullptr));
      });
    }
  }

  ChannelBridgePtr<InputValueType> receiver_;
  ChannelBridgePtr<OutputValueType> sender_;
  TransformerType transformer_;
};

/**
 * This subclass is used for simple transformations triggered by a call to the
 * transform function (i.e. with a single input receiver and no initialization
 * function).
 */
template <
    typename InputValueType,
    typename OutputValueType,
    typename TransformerType>
class TransformProcessor : public TransformProcessorBase<
                               InputValueType,
                               OutputValueType,
                               TransformerType> {
 public:
  using Base =
      TransformProcessorBase<InputValueType, OutputValueType, TransformerType>;
  using Base::Base;

 private:
  /**
   * Process cancellation for the input receiver.
   */
  folly::coro::Task<void> processReceiverCancelled(
      CloseResult closeResult, bool /* noRetriesAllowed */) override {
    CHECK_EQ(this->getReceiverState(), ChannelState::CancellationTriggered);
    this->receiver_ = nullptr;
    if (this->getSenderState() == ChannelState::Active) {
      if (closeResult.exception.has_value()) {
        this->sender_->senderClose(std::move(closeResult.exception.value()));
      } else {
        this->sender_->senderClose();
      }
    }
    this->maybeDelete();
    co_return;
  }
};

/**
 * This subclass is used for resumable transformations triggered by a call to
 * the resumableTransform function.
 */
template <
    typename InitializeArg,
    typename InputValueType,
    typename OutputValueType,
    typename TransformerType>
class ResumableTransformProcessor : public TransformProcessorBase<
                                        InputValueType,
                                        OutputValueType,
                                        TransformerType> {
 public:
  using Base =
      TransformProcessorBase<InputValueType, OutputValueType, TransformerType>;
  using Base::Base;

  void initialize(InitializeArg initializeArg) {
    this->executeWhenReady(
        [=, this, initializeArg = std::move(initializeArg)](
            std::unique_ptr<RateLimiter::Token> token) mutable {
          runOperationWithSenderCancellation(
              this->transformer_.getExecutor(),
              this->sender_,
              false /* currentlyWaiting */,
              this /* channelCallbackToRestore */,
              initializeImpl(std::move(initializeArg)),
              std::move(token));
        });
  }

 private:
  /**
   * Runs the user-provided initialization function to get a set of initial
   * values and a receiver to continue transforming. This is called when the
   * resumableTransform is created, and again whenever the previous input
   * receiver closed without an exception.
   */
  folly::coro::Task<void> initializeImpl(InitializeArg initializeArg) {
    auto cancelToken = co_await folly::coro::co_current_cancellation_token;
    auto initializeResult = co_await folly::coro::co_awaitTry(
        this->transformer_.initializeTransform(std::move(initializeArg)));
    if (initializeResult.hasException()) {
      auto closeResult =
          initializeResult.template hasException<OnClosedException>()
          ? CloseResult()
          : CloseResult(std::move(initializeResult.exception()));
      co_await processReceiverCancelled(
          std::move(closeResult), true /* noRetriesAllowed */);
      co_return;
    }
    auto [initialValues, inputReceiver] = std::move(initializeResult.value());
    CHECK(inputReceiver)
        << "The initialize function of a resumableTransform returned an "
           "invalid receiver.";
    if (cancelToken.isCancellationRequested()) {
      // The sender was closed before we finished running the initialization
      // function. We will ignore the results from that function and proceed
      // to process cancellation for the receiver.
      co_await processReceiverCancelled(
          CloseResult(), true /* noRetriesAllowed */);
      co_return;
    }
    for (auto& initialValue : initialValues) {
      this->sender_->senderPush(std::move(initialValue));
    }
    co_await this->startTransformImpl(std::move(inputReceiver));
  }

  /**
   * Process cancellation for the input receiver.
   */
  folly::coro::Task<void> processReceiverCancelled(
      CloseResult closeResult, bool noRetriesAllowed) override {
    if (this->receiver_) {
      CHECK_EQ(this->getReceiverState(), ChannelState::CancellationTriggered);
      this->receiver_ = nullptr;
    }
    auto cancelToken = co_await folly::coro::co_current_cancellation_token;
    if (this->getSenderState() == ChannelState::Active &&
        !cancelToken.isCancellationRequested()) {
      if (!closeResult.exception.has_value()) {
        // We were closed without an exception. We will close the sender without
        // an exception.
        this->sender_->senderClose();
      } else if (
          noRetriesAllowed ||
          !closeResult.exception
               ->is_compatible_with<ReinitializeException<InitializeArg>>()) {
        // We were closed with an exception. We will close the sender with that
        // exception.
        this->sender_->senderClose(std::move(closeResult.exception.value()));
      } else {
        // We were closed with a ReinitializeException. We will re-run the
        // user's initialization function and resume the resumableTransform.
        auto* reinitializeEx =
            closeResult.exception
                ->get_exception<ReinitializeException<InitializeArg>>();
        co_await initializeImpl(std::move(reinitializeEx->initializeArg));
        co_return;
      }
    }
    this->maybeDelete();
  }
};

template <bool Enabled>
class RateLimiterHolder;

template <>
class RateLimiterHolder<true> {
 public:
  explicit RateLimiterHolder(std::shared_ptr<RateLimiter> rateLimiter)
      : rateLimiter_(std::move(rateLimiter)) {}

  std::shared_ptr<RateLimiter> getRateLimiter() { return rateLimiter_; }

 private:
  std::shared_ptr<RateLimiter> rateLimiter_;
};

template <>
class RateLimiterHolder<false> {
 public:
  explicit RateLimiterHolder(std::shared_ptr<RateLimiter> rateLimiter) {
    CHECK_NULL(rateLimiter.get());
  }

  std::shared_ptr<RateLimiter> getRateLimiter() { return nullptr; }
};

template <
    typename InputValueType,
    typename OutputValueType,
    typename TransformValueFunc,
    bool RateLimiterEnabled>
class DefaultTransformer : public RateLimiterHolder<RateLimiterEnabled> {
 public:
  DefaultTransformer(
      folly::Executor::KeepAlive<folly::SequencedExecutor> executor,
      TransformValueFunc transformValue,
      std::shared_ptr<RateLimiter> rateLimiter)
      : RateLimiterHolder<RateLimiterEnabled>(std::move(rateLimiter)),
        executor_(std::move(executor)),
        transformValue_(std::move(transformValue)) {}

  folly::Executor::KeepAlive<folly::SequencedExecutor> getExecutor() {
    return executor_;
  }

  auto transformValue(Try<InputValueType> inputValue) {
    return transformValue_(std::move(inputValue));
  }

 private:
  folly::Executor::KeepAlive<folly::SequencedExecutor> executor_;
  TransformValueFunc transformValue_;
};

template <
    typename InitializeArg,
    typename InputValueType,
    typename OutputValueType,
    typename InitializeTransformFunc,
    typename TransformValueFunc,
    bool RateLimiterEnabled>
class DefaultResumableTransformer : public DefaultTransformer<
                                        InputValueType,
                                        OutputValueType,
                                        TransformValueFunc,
                                        RateLimiterEnabled> {
 public:
  using Base = DefaultTransformer<
      InputValueType,
      OutputValueType,
      TransformValueFunc,
      RateLimiterEnabled>;

  DefaultResumableTransformer(
      folly::Executor::KeepAlive<folly::SequencedExecutor> executor,
      InitializeTransformFunc initializeTransform,
      TransformValueFunc transformValue,
      std::shared_ptr<RateLimiter> rateLimiter)
      : Base(
            std::move(executor),
            std::move(transformValue),
            std::move(rateLimiter)),
        initializeTransform_(std::move(initializeTransform)) {}

  auto initializeTransform(InitializeArg initializeArg) {
    return initializeTransform_(std::move(initializeArg));
  }

 private:
  InitializeTransformFunc initializeTransform_;
};
} // namespace detail

template <
    typename ReceiverType,
    typename TransformValueFunc,
    typename InputValueType,
    typename OutputValueType>
Receiver<OutputValueType> transform(
    ReceiverType inputReceiver,
    folly::Executor::KeepAlive<folly::SequencedExecutor> executor,
    TransformValueFunc transformValue,
    std::shared_ptr<RateLimiter> rateLimiter) {
  if (rateLimiter != nullptr) {
    using TransformerType = detail::DefaultTransformer<
        InputValueType,
        OutputValueType,
        TransformValueFunc,
        true /* RateLimiterEnabled */>;
    return transform(
        std::move(inputReceiver),
        TransformerType(
            std::move(executor),
            std::move(transformValue),
            std::move(rateLimiter)));

  } else {
    using TransformerType = detail::DefaultTransformer<
        InputValueType,
        OutputValueType,
        TransformValueFunc,
        false /* RateLimiterEnabled */>;
    return transform(
        std::move(inputReceiver),
        TransformerType(
            std::move(executor),
            std::move(transformValue),
            nullptr /* rateLimiter */));
  }
}

template <
    typename ReceiverType,
    typename TransformerType,
    typename InputValueType,
    typename OutputValueType>
Receiver<OutputValueType> transform(
    ReceiverType inputReceiver, TransformerType transformer) {
  auto [outputReceiver, outputSender] = Channel<OutputValueType>::create();
  using TProcessor = detail::
      TransformProcessor<InputValueType, OutputValueType, TransformerType>;
  auto* processor =
      new TProcessor(std::move(outputSender), std::move(transformer));
  processor->startTransform(std::move(inputReceiver));
  return std::move(outputReceiver);
}

template <
    typename InitializeArg,
    typename InitializeTransformFunc,
    typename TransformValueFunc,
    typename ReceiverType,
    typename InputValueType,
    typename OutputValueType>
Receiver<OutputValueType> resumableTransform(
    folly::Executor::KeepAlive<folly::SequencedExecutor> executor,
    InitializeArg initializeArg,
    InitializeTransformFunc initializeTransform,
    TransformValueFunc transformValue,
    std::shared_ptr<RateLimiter> rateLimiter) {
  if (rateLimiter != nullptr) {
    using TransformerType = detail::DefaultResumableTransformer<
        InitializeArg,
        InputValueType,
        OutputValueType,
        InitializeTransformFunc,
        TransformValueFunc,
        true /* RateLimiterEnabled */>;
    return resumableTransform(
        std::move(initializeArg),
        TransformerType(
            std::move(executor),
            std::move(initializeTransform),
            std::move(transformValue),
            std::move(rateLimiter)));
  } else {
    using TransformerType = detail::DefaultResumableTransformer<
        InitializeArg,
        InputValueType,
        OutputValueType,
        InitializeTransformFunc,
        TransformValueFunc,
        false /* RateLimiterEnabled */>;
    return resumableTransform(
        std::move(initializeArg),
        TransformerType(
            std::move(executor),
            std::move(initializeTransform),
            std::move(transformValue),
            nullptr /* rateLimiter */));
  }
}

template <
    typename InitializeArg,
    typename TransformerType,
    typename ReceiverType,
    typename InputValueType,
    typename OutputValueType>
Receiver<OutputValueType> resumableTransform(
    InitializeArg initializeArg, TransformerType transformer) {
  auto [outputReceiver, outputSender] = Channel<OutputValueType>::create();
  using TProcessor = detail::ResumableTransformProcessor<
      InitializeArg,
      InputValueType,
      OutputValueType,
      TransformerType>;
  auto* processor =
      new TProcessor(std::move(outputSender), std::move(transformer));
  processor->initialize(std::move(initializeArg));
  return std::move(outputReceiver);
}

} // namespace channels
} // namespace folly