folly/folly/channels/detail/Utility.h

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <optional>
#include <folly/ExceptionWrapper.h>
#include <folly/Function.h>
#include <folly/ScopeGuard.h>
#include <folly/channels/Channel.h>
#include <folly/channels/RateLimiter.h>
#include <folly/executors/SequencedExecutor.h>
#include <folly/experimental/coro/Promise.h>
#include <folly/experimental/coro/Task.h>

namespace folly {
namespace channels {
namespace detail {

struct CloseResult {
  CloseResult() {}

  explicit CloseResult(exception_wrapper _exception)
      : exception(std::move(_exception)) {}

  std::optional<exception_wrapper> exception;
};

/**
 * Cancellation lifecycle of one end of a channel, as reported by
 * getSenderState / getReceiverState.
 */
enum class ChannelState {
  // The endpoint exists and has not been closed/cancelled.
  Active,
  // The endpoint still exists but reports that it was closed/cancelled.
  CancellationTriggered,
  // The endpoint pointer is null (cancellation has been fully processed).
  CancellationProcessed,
};

/**
 * Classifies a sender's cancellation state.
 *
 * @param sender: The sender to inspect; may be null.
 * @return CancellationProcessed when sender is null, CancellationTriggered
 *   when sender->isSenderClosed() is true, and Active otherwise.
 */
template <typename TSender>
ChannelState getSenderState(TSender* sender) {
  if (sender != nullptr) {
    return sender->isSenderClosed() ? ChannelState::CancellationTriggered
                                    : ChannelState::Active;
  }
  return ChannelState::CancellationProcessed;
}

/**
 * Classifies a receiver's cancellation state.
 *
 * @param receiver: The receiver to inspect; may be null.
 * @return CancellationProcessed when receiver is null, CancellationTriggered
 *   when receiver->isReceiverCancelled() is true, and Active otherwise.
 */
template <typename TReceiver>
ChannelState getReceiverState(TReceiver* receiver) {
  if (receiver != nullptr) {
    return receiver->isReceiverCancelled()
        ? ChannelState::CancellationTriggered
        : ChannelState::Active;
  }
  return ChannelState::CancellationProcessed;
}

/**
 * Writes the enumerator name of a ChannelState to the stream (useful for
 * logging). Out-of-range values print a sentinel string.
 */
inline std::ostream& operator<<(std::ostream& os, ChannelState state) {
  const char* label = nullptr;
  switch (state) {
    case ChannelState::Active:
      label = "Active";
      break;
    case ChannelState::CancellationTriggered:
      label = "CancellationTriggered";
      break;
    case ChannelState::CancellationProcessed:
      label = "CancellationProcessed";
      break;
    default:
      label = "Should never be hit";
      break;
  }
  return os << label;
}

/**
 * A channel callback that temporarily stands in for an existing channel
 * callback while a coroutine operation is running.
 *
 * When this callback is fired (consume or canceled), it immediately requests
 * cancellation on its cancellation source, so the running operation can
 * observe the cancellation through getCancellationToken(). The wrapped channel
 * callback is NOT fired at that point. Instead, the event is recorded and
 * forwarded to the wrapped callback by onTaskCompleted(), after the operation
 * has finished.
 *
 * The constructor registers `this` with the sender, so instances must not be
 * copied or moved (a moved-from address would be left registered).
 */
template <typename TSender>
class SenderCancellationCallback : public IChannelCallback {
 public:
  /**
   * @param sender: The sender to listen on for cancellation. It is only
   *   dereferenced when channelCallback is non-null, in which case it must be
   *   non-null.
   * @param executor: The executor on which fired consume/canceled events are
   *   recorded.
   * @param channelCallback: The callback to restore (or notify) once the
   *   operation completes, or nullptr if the sender's cancellation has already
   *   been delivered.
   */
  explicit SenderCancellationCallback(
      TSender& sender,
      folly::Executor::KeepAlive<folly::SequencedExecutor> executor,
      IChannelCallback* channelCallback)
      : sender_(sender),
        executor_(std::move(executor)),
        channelCallback_(channelCallback),
        callbackToFire_(folly::coro::makePromiseContract<CallbackToFire>()) {
    if (channelCallback_ == nullptr) {
      // The sender was already cancelled before
      // runOperationWithSenderCancellation was even called. This means the
      // cancelled callback was already fired, so we will not set a callback to
      // fire here.
      cancelSource_.requestCancellation();
      return;
    }
    CHECK(sender_);
    if (!sender_->senderWait(this)) {
      // The sender was cancelled after runOperationWithSenderCancellation was
      // called, but before we had a chance to start the operation. This means
      // that the cancelled callback was never called. We will therefore set it
      // to fire here, when the operation is complete.
      cancelSource_.requestCancellation();
      callbackToFire_.first.setValue(CallbackToFire::Consume);
    }
  }

  // `this` is registered with the sender, so the object must stay put.
  SenderCancellationCallback(const SenderCancellationCallback&) = delete;
  SenderCancellationCallback& operator=(const SenderCancellationCallback&) =
      delete;
  SenderCancellationCallback(SenderCancellationCallback&&) = delete;
  SenderCancellationCallback& operator=(SenderCancellationCallback&&) = delete;

  /**
   * To be awaited once the operation has finished. It does one of two things:
   * - It re-registers the wrapped channel callback with the sender.
   * - If a cancellation event arrived while the operation was running, it
   *   forwards that event (consume or canceled) to the wrapped callback.
   *
   * Does nothing if the wrapped channel callback is null.
   *
   * NOTE(review): presumably must run on `executor` so it is sequenced with
   * the promise-setting in consume/canceled — confirm.
   */
  folly::coro::Task<void> onTaskCompleted() {
    if (!channelCallback_) {
      co_return;
    }
    auto callbackToFire = std::optional<CallbackToFire>();
    bool promiseSet = false;
    if (callbackToFire_.second.isReady()) {
      // The callback was fired.
      promiseSet = true;
      callbackToFire = co_await std::move(callbackToFire_.second);
    } else {
      // The callback has not yet been fired.
      if (!sender_->cancelSenderWait()) {
        // The sender has been cancelled, but the callback has not been called
        // yet. Wait for the callback to be called.
        promiseSet = true;
        callbackToFire = co_await std::move(callbackToFire_.second);
      } else if (!sender_->senderWait(channelCallback_)) {
        // The sender was cancelled between the call to cancelSenderWait and
        // the call to senderWait. This means that the cancelled callback was
        // never called. We will therefore set it to fire here.
        callbackToFire = CallbackToFire::Consume;
      }
    }
    if (!promiseSet) {
      // Set a default value here, so we don't need to waste time constructing a
      // broken promise exception when the promise is destructed. This value
      // will not be read.
      callbackToFire_.first.setValue(CallbackToFire::Consume);
    }
    if (callbackToFire.has_value()) {
      switch (callbackToFire.value()) {
        case CallbackToFire::Consume:
          channelCallback_->consume(sender_.get());
          co_return;
        case CallbackToFire::Canceled:
          channelCallback_->canceled(sender_.get());
          co_return;
      }
    }
    // The sender has not yet been cancelled, and we are now back in the state
    // where the sender is waiting on the user-provided callback. We are done.
  }

  /**
   * Returns a cancellation token that is triggered when this callback is fired
   * (consume or canceled). The token is triggered immediately if the sender was
   * already cancelled when this object was constructed.
   */
  folly::CancellationToken getCancellationToken() {
    return cancelSource_.getToken();
  }

  /**
   * Requests cancellation immediately. It then records (on the executor) that
   * the wrapped callback's consume function must be fired from
   * onTaskCompleted().
   */
  void consume(ChannelBridgeBase*) override {
    cancelSource_.requestCancellation();
    executor_->add([this]() {
      CHECK(!callbackToFire_.second.isReady());
      callbackToFire_.first.setValue(CallbackToFire::Consume);
    });
  }

  /**
   * Requests cancellation immediately. It then records (on the executor) that
   * the wrapped callback's canceled function must be fired from
   * onTaskCompleted().
   */
  void canceled(ChannelBridgeBase*) override {
    cancelSource_.requestCancellation();
    executor_->add([this]() {
      CHECK(!callbackToFire_.second.isReady());
      callbackToFire_.first.setValue(CallbackToFire::Canceled);
    });
  }

 private:
  // Which function of the wrapped callback onTaskCompleted() should fire.
  enum class CallbackToFire { Consume, Canceled };

  TSender& sender_;
  folly::Executor::KeepAlive<folly::SequencedExecutor> executor_;
  IChannelCallback* channelCallback_;
  folly::CancellationSource cancelSource_;
  // Set once: by the constructor, by consume/canceled, or by onTaskCompleted
  // (as a never-read default) so the promise is never broken.
  std::pair<
      folly::coro::Promise<CallbackToFire>,
      folly::coro::Future<CallbackToFire>>
      callbackToFire_;
};

/**
 * Any object that produces an output receiver (transform, merge,
 * MergeChannel, etc) will listen for a cancellation signal from that output
 * receiver. Once the consumer of the output receiver stops consuming, a
 * callback will be called that triggers these objects to start cleaning
 * themselves up (and eventually destroy themselves).
 *
 * However, when one of these objects decides to run a user coroutine, they
 * would like that user coroutine to be able to get notified when that
 * cancellation signal is received. That allows the coroutine to stop any
 * long-running operations quickly, rather than running a long time when the
 * consumer of the output receiver no longer cares about the result.
 *
 * This function enables that behavior. It will run the provided operation
 * coroutine. While that coroutine is running, it will listen to cancellation
 * events from the output receiver (through its sender). If it receives a
 * cancellation signal from the sender, it will trigger cancellation of the
 * operation coroutine.
 *
 * Once the coroutine finishes, it will then call the given channel callback
 * to notify it of the cancellation event (the same way that callback would
 * have been notified if no coroutine had been started). It will also resume
 * waiting on the channel callback.
 *
 * @param executor: The executor to run the coroutine on.
 *
 * @param sender: The sender to use to listen for cancellation. If this is
 * null, we will assume that cancellation already occurred, and the channel
 * callback will be neither restored nor notified. The sender object is
 * captured by reference and used after this function returns, so it must
 * outlive the operation.
 *
 * @param alreadyStartedWaiting: Whether or not the caller already started
 * listening for a cancellation signal from the output receiver. If so, this
 * function will temporarily stop waiting with that callback (so it can listen
 * for the cancellation signal to stop the coroutine).
 *
 * @param channelCallbackToRestore: The channel callback to restore once the
 * coroutine operation is complete.
 *
 * @param operation: The operation to run. It must not throw; an exception is
 * treated as fatal.
 *
 * @param token: The rate limiter token for this operation. It is held until
 * the operation (and the follow-up callback handling) completes.
 */
template <typename TSender>
void runOperationWithSenderCancellation(
    folly::Executor::KeepAlive<folly::SequencedExecutor> executor,
    TSender& sender,
    bool alreadyStartedWaiting,
    IChannelCallback* channelCallbackToRestore,
    folly::coro::Task<void> operation,
    std::unique_ptr<RateLimiter::Token> token) noexcept {
  if (!sender || (alreadyStartedWaiting && !sender->cancelSenderWait())) {
    // Either the sender is already gone (cancellation was fully processed), or
    // the output receiver was cancelled before starting this operation. In
    // both cases the channel callback already ran, so there is nothing to
    // restore. Checking `!sender` unconditionally also keeps a null sender from
    // tripping CHECK(sender_) in SenderCancellationCallback when
    // alreadyStartedWaiting is false.
    channelCallbackToRestore = nullptr;
  }
  folly::coro::co_invoke(
      [&sender,
       executor,
       channelCallbackToRestore,
       // Not read below; captured so the rate limiter token lives in the
       // coroutine frame and is released only when the coroutine finishes.
       token = std::move(token),
       operation = std::move(operation)]() mutable -> folly::coro::Task<void> {
        auto senderCancellationCallback = SenderCancellationCallback(
            sender, executor, channelCallbackToRestore);
        auto result =
            co_await folly::coro::co_awaitTry(folly::coro::co_withCancellation(
                senderCancellationCallback.getCancellationToken(),
                std::move(operation)));
        if (result.hasException()) {
          LOG(FATAL) << fmt::format(
              "Unexpected exception when running coroutine operation with "
              "sender cancellation: {}",
              result.exception().what());
        }
        co_await senderCancellationCallback.onTaskCompleted();
      })
      .scheduleOn(executor)
      .start();
}
} // namespace detail
} // namespace channels
} // namespace folly