folly/folly/channels/detail/ChannelBridge.h

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <folly/Try.h>
#include <folly/experimental/channels/detail/AtomicQueue.h>

namespace folly {
namespace channels {
namespace detail {

// Empty, non-template base for ChannelBridge<TValue>. IChannelCallback's
// methods take a pointer to this type, so a callback can be handed a bridge
// without depending on the bridge's value type.
class ChannelBridgeBase {};

// Callback interface accepted by ChannelBridge::senderWait() and
// ChannelBridge::receiverWait(), which forward it to the underlying queue
// together with a ChannelBridgeBase pointer for the bridge.
// NOTE(review): the exact conditions under which each method is invoked are
// determined by AtomicQueue, which is not visible here -- confirm there.
class IChannelCallback {
 public:
  virtual ~IChannelCallback() = default;

  // Presumably invoked when the queue being waited on has something for the
  // waiter to consume; `bridge` is the bridge as its base type.
  virtual void consume(ChannelBridgeBase* bridge) = 0;

  // Presumably invoked when the queue being waited on has been closed by the
  // other side; `bridge` is the bridge as its base type.
  virtual void canceled(ChannelBridgeBase* bridge) = 0;
};

using SenderQueue = typename folly::channels::detail::Queue<Unit>;

template <typename TValue>
using ReceiverQueue = typename folly::channels::detail::Queue<Try<TValue>>;

template <typename TValue>
class ChannelBridge : public ChannelBridgeBase {
 public:
  struct Deleter {
    void operator()(ChannelBridge<TValue>* ptr) { ptr->decref(); }
  };
  using Ptr = std::unique_ptr<ChannelBridge<TValue>, Deleter>;

  static Ptr create() { return Ptr(new ChannelBridge<TValue>()); }

  Ptr copy() {
    auto refCount = refCount_.fetch_add(1, std::memory_order_relaxed);
    DCHECK(refCount > 0);
    return Ptr(this);
  }

  // These should only be called from the sender thread

  template <typename U = TValue>
  void senderPush(U&& value) {
    receiverQueue_.push(
        Try<TValue>(std::forward<U>(value)),
        static_cast<ChannelBridgeBase*>(this));
  }

  bool senderWait(IChannelCallback* callback) {
    return senderQueue_.wait(callback, static_cast<ChannelBridgeBase*>(this));
  }

  IChannelCallback* cancelSenderWait() { return senderQueue_.cancelCallback(); }

  void senderClose() {
    if (!isSenderClosed()) {
      receiverQueue_.push(Try<TValue>(), static_cast<ChannelBridgeBase*>(this));
      senderQueue_.close(static_cast<ChannelBridgeBase*>(this));
    }
  }

  void senderClose(exception_wrapper ex) {
    if (!isSenderClosed()) {
      receiverQueue_.push(
          Try<TValue>(std::move(ex)), static_cast<ChannelBridgeBase*>(this));
      senderQueue_.close(static_cast<ChannelBridgeBase*>(this));
    }
  }

  bool isSenderClosed() { return senderQueue_.isClosed(); }

  SenderQueue senderGetValues() {
    return senderQueue_.getMessages(static_cast<ChannelBridgeBase*>(this));
  }

  // These should only be called from the receiver thread

  void receiverCancel() {
    if (!isReceiverCancelled()) {
      senderQueue_.push(Unit(), static_cast<ChannelBridgeBase*>(this));
      receiverQueue_.close(static_cast<ChannelBridgeBase*>(this));
    }
  }

  bool isReceiverCancelled() { return receiverQueue_.isClosed(); }

  bool receiverWait(IChannelCallback* callback) {
    return receiverQueue_.wait(callback, static_cast<ChannelBridgeBase*>(this));
  }

  IChannelCallback* cancelReceiverWait() {
    return receiverQueue_.cancelCallback();
  }

  ReceiverQueue<TValue> receiverGetValues() {
    return receiverQueue_.getMessages(static_cast<ChannelBridgeBase*>(this));
  }

 private:
  using ReceiverAtomicQueue = typename folly::channels::detail::
      AtomicQueue<IChannelCallback, Try<TValue>>;

  using SenderAtomicQueue =
      typename folly::channels::detail::AtomicQueue<IChannelCallback, Unit>;

  void decref() {
    if (refCount_.fetch_sub(1, std::memory_order_acq_rel) == 1) {
      delete this;
    }
  }

  ReceiverAtomicQueue receiverQueue_;
  SenderAtomicQueue senderQueue_;
  std::atomic<int8_t> refCount_{1};
};

// Shorthand for ChannelBridge<TValue>::Ptr, the std::unique_ptr whose deleter
// releases one reference on the bridge rather than deleting it directly.
template <typename TValue>
using ChannelBridgePtr = typename ChannelBridge<TValue>::Ptr;
} // namespace detail
} // namespace channels
} // namespace folly