folly/folly/channels/ProxyChannel-inl.h

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <folly/channels/ProxyChannel.h>
#include <folly/experimental/channels/detail/Utility.h>

namespace folly {
namespace channels {

template <typename ValueType>
ProxyChannel<ValueType>::ProxyChannel(TProcessor* processor)
    : processor_(processor) {}

template <typename ValueType>
ProxyChannel<ValueType>::ProxyChannel(ProxyChannel&& other) noexcept
    : processor_(std::exchange(other.processor_, nullptr)) {}

template <typename ValueType>
ProxyChannel<ValueType>& ProxyChannel<ValueType>::operator=(
    ProxyChannel&& other) noexcept {
  if (this != &other) {
    // Release the processor this handle currently owns (if any) before
    // adopting the other handle's processor.
    if (processor_ != nullptr) {
      std::move(*this).close();
    }
    processor_ = std::exchange(other.processor_, nullptr);
  }
  return *this;
}

template <typename ValueType>
ProxyChannel<ValueType>::~ProxyChannel() {
  // A moved-from or already-closed handle owns no processor to notify.
  if (processor_ == nullptr) {
    return;
  }
  std::move(*this).close();
}

template <typename ValueType>
ProxyChannel<ValueType>::operator bool() const {
  // The handle is usable only while it still owns a processor.
  return processor_ != nullptr;
}

template <typename ValueType>
void ProxyChannel<ValueType>::setInputReceiver(Receiver<ValueType> receiver) {
  // The processor replaces any previous input receiver with this one.
  auto* const proxyProcessor = processor_;
  proxyProcessor->setInputReceiver(std::move(receiver));
}

template <typename ValueType>
void ProxyChannel<ValueType>::removeInputReceiver() {
  // The processor cancels the current input receiver (if any).
  auto* const proxyProcessor = processor_;
  proxyProcessor->removeInputReceiver();
}

/**
 * Closes this handle. The processor closes the output receiver (with `ex` if it
 * holds an exception, otherwise normally) and cancels the current input
 * receiver, if any. Afterwards this handle no longer owns a processor.
 *
 * Calling close() on a moved-from or already-closed handle is a no-op instead
 * of a null dereference, matching the null guards in the destructor and the
 * move-assignment operator.
 */
template <typename ValueType>
void ProxyChannel<ValueType>::close(folly::exception_wrapper&& ex) && {
  if (processor_ == nullptr) {
    return;
  }
  processor_->destroyHandle(
      ex ? detail::CloseResult(std::move(ex)) : detail::CloseResult());
  processor_ = nullptr;
}

namespace detail {

/**
 * This object does the proxying of values from the input receiver to the output
 * receiver.
 *
 * Instances are heap-allocated by createProxyChannel and delete themselves (see
 * maybeDelete) once their refcount drops to zero. Bridge callbacks (consume /
 * canceled) are re-posted onto `executor_`, and all mutable state lives in a
 * folly::Synchronized that is write-locked while it is being modified.
 */
template <typename ValueType>
class ProxyChannelProcessor : public IChannelCallback {
 private:
  struct State {
    explicit State(ChannelBridgePtr<ValueType> _sender)
        : sender(std::move(_sender)) {}

    // Reports the output sender's state via detail::getSenderState. The states
    // this class relies on are Active and CancellationTriggered.
    ChannelState getSenderState() {
      return detail::getSenderState(sender.get());
    }

    // The output sender for the proxy channel. Reset once the sender's
    // cancellation has been processed (see processSenderCancelled).
    ChannelBridgePtr<ValueType> sender;

    // The current input receiver for the proxy channel. This is a raw bridge
    // pointer owned by this processor; it is freed in processReceiverCancelled.
    ChannelBridge<ValueType>* receiver{nullptr};

    // The refcount for this proxy channel. The handle (if not yet destroyed),
    // the sender (if not yet cancelled), the current input receiver (if any),
    // and any previous input receivers whose cancellation has not yet been
    // processed (if any) contribute to this refcount. It starts at 2, since a
    // new ProxyChannel always has one handle, one output sender, and no input
    // receivers.
    size_t refCount{2};
  };

  using WLockedStatePtr = typename folly::Synchronized<State>::WLockedPtr;

 public:
  /**
   * Takes ownership of the output sender's bridge and registers this processor
   * as the sender's callback, so that consume() is invoked when the consumer
   * of the output receiver stops consuming.
   */
  ProxyChannelProcessor(
      Sender<ValueType> sender,
      folly::Executor::KeepAlive<folly::SequencedExecutor> executor)
      : executor_(std::move(executor)),
        state_(State(std::move(detail::senderGetBridge(sender)))) {
    auto state = state_.wlock();
    CHECK(state->sender->senderWait(this));
  }

  /**
   * Sets a new input receiver (removing the old input receiver, if any).
   *
   * If the output sender is no longer active, the given receiver is dropped
   * without being read.
   */
  void setInputReceiver(Receiver<ValueType> receiver) {
    auto state = state_.wlock();
    if (state->getSenderState() != ChannelState::Active) {
      return;
    }
    // Split the receiver into its bridge and any values already buffered in
    // it; the buffered values are proxied first.
    auto [unbufferedReceiver, buffer] =
        detail::receiverUnbuffer(std::move(receiver));
    // The old receiver stays counted in refCount until its cancellation
    // callback is processed in processReceiverCancelled.
    cancelInputReceiverIfExists(state);
    auto receiverPtr = unbufferedReceiver.release();
    state->receiver = receiverPtr;
    state->refCount++;
    processAllAvailableValues(std::move(state), receiverPtr, std::move(buffer));
  }

  /**
   * Removes the current input receiver. The output sender is left open.
   */
  void removeInputReceiver() {
    auto state = state_.wlock();
    if (state->getSenderState() != ChannelState::Active) {
      return;
    }
    cancelInputReceiverIfExists(state);
  }

  /**
   * Called when the user's ProxyChannel object is destroyed.
   */
  void destroyHandle(CloseResult closeResult) {
    processHandleDestroyed(state_.wlock(), std::move(closeResult));
  }

  /**
   * Called when one of the channels we are listening to has an update (either
   * a value from an input receiver or a cancellation from the output receiver).
   * The work is re-posted onto `executor_` before any state is touched.
   */
  void consume(ChannelBridgeBase* bridge) override {
    executor_->add([=, this]() {
      auto state = state_.wlock();
      if (bridge == state->sender.get()) {
        // The consumer of the output receiver has stopped consuming. Close our
        // end so the sender state becomes CancellationTriggered, which
        // processSenderCancelled checks for.
        state->sender->senderClose();
        processSenderCancelled(std::move(state));
      } else {
        // One or more values are now available from an input receiver.
        auto* receiver = static_cast<ChannelBridge<ValueType>*>(bridge);
        processAllAvailableValues(std::move(state), receiver);
      }
    });
  }

  /**
   * Called after we cancelled one of the channels we were listening to (either
   * the sender or an input receiver). The work is re-posted onto `executor_`.
   */
  void canceled(ChannelBridgeBase* bridge) override {
    executor_->add([=, this]() {
      auto state = state_.wlock();
      if (bridge == state->sender.get()) {
        // We previously closed the sender ourselves, either because the current
        // input receiver closed (processReceiverCancelled) or because the handle
        // was destroyed (processHandleDestroyed). Process the cancellation for
        // the sender.
        CHECK(state->getSenderState() == ChannelState::CancellationTriggered);
        processSenderCancelled(std::move(state));
      } else {
        // We previously cancelled this input receiver. Process the cancellation
        // for this input receiver.
        auto* receiver = static_cast<ChannelBridge<ValueType>*>(bridge);
        processReceiverCancelled(std::move(state), receiver, CloseResult());
      }
    });
  }

 protected:
  /**
   * Processes all available values from the current input receiver channel
   * (starting from the provided buffer, if present).
   *
   * If a value was received indicating that the input channel has been closed,
   * we process the cancellation for the input receiver.
   *
   * If the receiver has already been cancelled by us, it is treated as closed
   * immediately and its cancellation is processed here.
   */
  void processAllAvailableValues(
      WLockedStatePtr state,
      ChannelBridge<ValueType>* receiver,
      std::optional<ReceiverQueue<ValueType>> buffer = std::nullopt) {
    CHECK_NOTNULL(receiver);
    // A receiver that has not been cancelled must be the current one, because
    // replacing or removing a receiver always cancels it first.
    if (!receiver->isReceiverCancelled()) {
      CHECK_EQ(receiver, state->receiver);
    }
    auto closeResult = receiver->isReceiverCancelled()
        ? CloseResult()
        : (buffer.has_value() ? processValues(state, std::move(buffer.value()))
                              : std::nullopt);
    while (!closeResult.has_value()) {
      if (receiver->receiverWait(this)) {
        // There are no more values available right now. We will stop processing
        // until the channel fires the consume() callback (indicating that more
        // values are available).
        break;
      }
      auto values = receiver->receiverGetValues();
      CHECK(!values.empty());
      closeResult = processValues(state, std::move(values));
    }
    if (closeResult.has_value()) {
      // The receiver received a value indicating channel closure.
      receiver->receiverCancel();
      processReceiverCancelled(
          std::move(state), receiver, std::move(closeResult.value()));
    }
  }

  /**
   * Processes the given set of values for an input receiver by pushing each
   * normal value to the output sender. Returns a CloseResult (carrying the
   * exception, if any) when a closure value is encountered, so the caller can
   * stop attempting to process values from that channel. Returns std::nullopt
   * if every value was forwarded.
   */
  std::optional<CloseResult> processValues(
      WLockedStatePtr& state, ReceiverQueue<ValueType> values) {
    while (!values.empty()) {
      auto inputResult = std::move(values.front());
      values.pop();
      if (inputResult.hasValue()) {
        // We have received a normal value from an input receiver. Write it to
        // the output receiver.
        state->sender->senderPush(std::move(inputResult.value()));
      } else {
        // The input receiver was closed.
        return inputResult.hasException()
            ? CloseResult(std::move(inputResult.exception()))
            : CloseResult();
      }
    }
    return std::nullopt;
  }

  /**
   * Processes the cancellation of an input receiver.
   *
   * Only a closure of the current receiver propagates to the output sender
   * (and only while it is still active). Cancellations of previously replaced
   * or removed receivers just free the bridge and release its refcount.
   */
  void processReceiverCancelled(
      WLockedStatePtr state,
      ChannelBridge<ValueType>* receiver,
      CloseResult closeResult) {
    CHECK(receiver->isReceiverCancelled());
    if (receiver == state->receiver &&
        state->getSenderState() == ChannelState::Active) {
      if (closeResult.exception.has_value()) {
        state->sender->senderClose(std::move(closeResult.exception.value()));
      } else {
        state->sender->senderClose();
      }
    }
    if (state->receiver == receiver) {
      state->receiver = nullptr;
    }
    // Wrapping the raw pointer in a temporary ChannelBridgePtr takes ownership
    // and destroys it at the end of this statement.
    (ChannelBridgePtr<ValueType>(receiver)); // Delete the receiver
    state->refCount--;
    maybeDelete(std::move(state));
  }

  /**
   * Processes the cancellation of the sender (indicating either that the
   * consumer of the output receiver has stopped consuming, or that we closed
   * the sender ourselves). We will trigger cancellation for the input receiver
   * if it is not already cancelled.
   */
  void processSenderCancelled(WLockedStatePtr state) {
    CHECK(state->getSenderState() == ChannelState::CancellationTriggered);
    state->sender.reset();
    state->refCount--;
    cancelInputReceiverIfExists(state);
    maybeDelete(std::move(state));
  }

  /**
   * Processes the destruction of the user's ProxyChannel object.  We will
   * close the sender and trigger cancellation for the input receiver (if any).
   */
  void processHandleDestroyed(WLockedStatePtr state, CloseResult closeResult) {
    if (state->getSenderState() == ChannelState::Active) {
      if (closeResult.exception.has_value()) {
        state->sender->senderClose(std::move(closeResult.exception.value()));
      } else {
        state->sender->senderClose();
      }
    }
    cancelInputReceiverIfExists(state);
    state->refCount--;
    maybeDelete(std::move(state));
  }

  /**
   * Cancels the current input receiver if it exists, and clears it as the
   * current receiver. This does not decrement refCount: that reference is
   * released once the receiver's cancellation is processed in
   * processReceiverCancelled.
   */
  void cancelInputReceiverIfExists(WLockedStatePtr& state) {
    if (state->receiver != nullptr) {
      CHECK(!state->receiver->isReceiverCancelled());
      state->receiver->receiverCancel();
      state->receiver = nullptr;
    }
  }

  /**
   * Deletes this object if we have already processed cancellation for the
   * sender, the current input receiver, and all previous input receivers, and
   * if the user's ProxyChannel object was destroyed.
   */
  void maybeDelete(WLockedStatePtr state) {
    if (state->refCount == 0) {
      CHECK_NULL(state->sender.get());
      CHECK_NULL(state->receiver);
      // The lock must be released before `delete this`, because the
      // Synchronized that owns it is a member of this object.
      state.unlock();
      delete this;
    }
  }

  folly::Executor::KeepAlive<folly::SequencedExecutor> executor_;
  folly::Synchronized<State> state_;
};
} // namespace detail

template <typename ValueType>
std::pair<Receiver<ValueType>, ProxyChannel<ValueType>> createProxyChannel(
    folly::Executor::KeepAlive<folly::SequencedExecutor> executor) {
  // The receiver end goes back to the caller. The sender end is owned by the
  // processor, which pushes the proxied values into it.
  auto [outputReceiver, outputSender] = Channel<ValueType>::create();
  // The processor manages its own lifetime (it deletes itself in maybeDelete),
  // so a plain `new` is used and the pointer is handed to the handle.
  auto* proxyProcessor = new detail::ProxyChannelProcessor<ValueType>(
      std::move(outputSender), std::move(executor));
  return std::make_pair(
      std::move(outputReceiver), ProxyChannel<ValueType>(proxyProcessor));
}
} // namespace channels
} // namespace folly