folly/folly/coro/Merge-inl.h

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <exception>
#include <memory>

#include <folly/CancellationToken.h>
#include <folly/Executor.h>
#include <folly/ScopeGuard.h>
#include <folly/experimental/coro/Baton.h>
#include <folly/experimental/coro/Mutex.h>
#include <folly/experimental/coro/Task.h>
#include <folly/experimental/coro/ViaIfAsync.h>
#include <folly/experimental/coro/WithCancellation.h>
#include <folly/experimental/coro/detail/Barrier.h>
#include <folly/experimental/coro/detail/BarrierTask.h>
#include <folly/experimental/coro/detail/CurrentAsyncFrame.h>
#include <folly/experimental/coro/detail/Helpers.h>

#if FOLLY_HAS_COROUTINES

namespace folly {
namespace coro {
namespace detail {

// Discriminator recording which alternative, if any, a CallbackRecord
// currently holds. `Invalid` is the state of a default-constructed record.
enum class CallbackRecordSelector { Invalid, Value, None, Error };

// Tag objects selecting which CallbackRecord constructor to invoke:
// index 0 -> value, index 1 -> none, index 2 -> error.
inline constexpr std::in_place_index_t<0> callback_record_value{};
inline constexpr std::in_place_index_t<1> callback_record_none{};
inline constexpr std::in_place_index_t<2> callback_record_error{};

//
// CallbackRecord records the result of a single invocation of a callback.
//
// This is closely related to Try and expected, but in addition to Value and
// Error results it can also record None.
//
// When the callback supports multiple overloads of Value then T would be
// something like a variant<tuple<..>, ..>
//
// When the callback supports multiple overloads of Error then all the errors
// are coerced to folly::exception_wrapper
//
// The record is a hand-rolled tagged union: `selector_` names the active
// member of the anonymous union, and the payload members are activated and
// deactivated explicitly through detail::activate / detail::deactivate.
// Only move construction and move assignment are declared, so records are
// moved rather than copied. A moved-from record keeps its selector; its
// payload is whatever moving out of it leaves behind.
//
template <class T>
class CallbackRecord {
  // convert_variant() and the converting constructor read the private
  // storage of *other* specializations (CallbackRecord<U> with U != T).
  // Distinct specializations are unrelated classes, so that access is only
  // well-formed if every specialization befriends every other one.
  template <class U>
  friend class CallbackRecord;

  // Deactivates whichever payload is active (if any) and leaves the record
  // in the Invalid state. The selector is reset first, before the payload is
  // deactivated.
  static void clear(CallbackRecord* that) {
    auto selector =
        std::exchange(that->selector_, CallbackRecordSelector::Invalid);
    if (selector == CallbackRecordSelector::Value) {
      detail::deactivate(that->value_);
    } else if (selector == CallbackRecordSelector::Error) {
      detail::deactivate(that->error_);
    }
  }

  // Copies the state of `other` into `that`. `that` must not currently hold
  // an active payload (callers either start from a fresh record or call
  // clear() first). The selector is written only after the payload has been
  // activated, so if activation throws `that` keeps its previous selector.
  template <class OtherReference>
  static void convert_variant(
      CallbackRecord* that, const CallbackRecord<OtherReference>& other) {
    if (other.hasValue()) {
      detail::activate(that->value_, other.value_.get());
    } else if (other.hasError()) {
      detail::activate(that->error_, other.error_.get());
    }
    that->selector_ = other.selector_;
  }

  // Same as above, but moves the payload out of `other`. `other` keeps its
  // selector, so its (moved-from) payload is still deactivated when `other`
  // is cleared or destroyed.
  template <class OtherReference>
  static void convert_variant(
      CallbackRecord* that, CallbackRecord<OtherReference>&& other) {
    if (other.hasValue()) {
      detail::activate(that->value_, std::move(other.value_).get());
    } else if (other.hasError()) {
      detail::activate(that->error_, std::move(other.error_).get());
    }
    that->selector_ = other.selector_;
  }

 public:
  ~CallbackRecord() { clear(this); }

  // Constructs an Invalid (empty) record.
  CallbackRecord() noexcept : selector_(CallbackRecordSelector::Invalid) {}

  // Constructs a Value record from `v` (tag: callback_record_value).
  template <class V>
  CallbackRecord(const std::in_place_index_t<0>&, V&& v) noexcept(
      std::is_nothrow_constructible_v<T, V>)
      : CallbackRecord() {
    detail::activate(value_, std::forward<V>(v));
    selector_ = CallbackRecordSelector::Value;
  }

  // Constructs a None record (tag: callback_record_none).
  explicit CallbackRecord(const std::in_place_index_t<1>&) noexcept
      : selector_(CallbackRecordSelector::None) {}

  // Constructs an Error record holding `e` (tag: callback_record_error).
  CallbackRecord(
      const std::in_place_index_t<2>&, folly::exception_wrapper e) noexcept
      : CallbackRecord() {
    detail::activate(error_, std::move(e));
    selector_ = CallbackRecordSelector::Error;
  }

  // Move construction: takes over the state of `other`.
  CallbackRecord(CallbackRecord&& other) noexcept(
      std::is_nothrow_move_constructible_v<T>)
      : CallbackRecord() {
    convert_variant(this, std::move(other));
  }

  // Move assignment: drops the current payload, then takes over the state of
  // `other`. Self-assignment is a no-op.
  CallbackRecord& operator=(CallbackRecord&& other) noexcept(
      std::is_nothrow_move_constructible_v<T>) {
    if (&other != this) {
      clear(this);
      convert_variant(this, std::move(other));
    }
    return *this;
  }

  // Converting move construction from a record with a different value type.
  template <class U>
  CallbackRecord(CallbackRecord<U>&& other) noexcept(
      std::is_nothrow_constructible_v<T, U>)
      : CallbackRecord() {
    convert_variant(this, std::move(other));
  }

  // True if the record holds None.
  bool hasNone() const noexcept {
    return selector_ == CallbackRecordSelector::None;
  }

  // True if the record holds an Error.
  bool hasError() const noexcept {
    return selector_ == CallbackRecordSelector::Error;
  }

  // Accessors for the stored error. Precondition: hasError() (DCHECKed).
  decltype(auto) error() & {
    DCHECK(hasError());
    return error_.get();
  }

  decltype(auto) error() && {
    DCHECK(hasError());
    return std::move(error_).get();
  }

  decltype(auto) error() const& {
    DCHECK(hasError());
    return error_.get();
  }

  decltype(auto) error() const&& {
    DCHECK(hasError());
    return std::move(error_).get();
  }

  // True if the record holds a Value.
  bool hasValue() const noexcept {
    return selector_ == CallbackRecordSelector::Value;
  }

  // Accessors for the stored value. Precondition: hasValue() (DCHECKed).
  decltype(auto) value() & {
    DCHECK(hasValue());
    return value_.get();
  }

  decltype(auto) value() && {
    DCHECK(hasValue());
    return std::move(value_).get();
  }

  decltype(auto) value() const& {
    DCHECK(hasValue());
    return value_.get();
  }

  decltype(auto) value() const&& {
    DCHECK(hasValue());
    return std::move(value_).get();
  }

  // True for any state other than Invalid (i.e. Value, None or Error).
  explicit operator bool() const noexcept {
    return selector_ != CallbackRecordSelector::Invalid;
  }

 private:
  // Payload storage; which member (if any) is active is given by selector_.
  union {
    detail::ManualLifetime<T> value_;
    detail::ManualLifetime<folly::exception_wrapper> error_;
  };
  CallbackRecordSelector selector_;
};
} // namespace detail

// Merges a stream of streams into a single output stream.
//
// A consumer task, scheduled on `executor`, pulls each inner generator out of
// `sources` and starts one detached worker task per inner generator. Every
// worker pulls items from its own generator and hands them, one at a time, to
// the returned generator through a single shared `record` slot guarded by a
// coro::Mutex and a pair of batons (recordPublished / recordConsumed).
//
// How the output stream ends:
//  - Once `sources` is exhausted and the barrier wait for the started workers
//    has completed, a 'none' record is published (unless an error was already
//    recorded) and the returned generator finishes.
//  - If `sources` or any inner generator throws, cancellation is requested on
//    the shared cancellation source and the error is published; the returned
//    generator waits for the consumer task to complete and then throws the
//    recorded error.
//  - When the returned generator's body exits for any reason, cancellation is
//    requested and recordConsumed is posted so a worker waiting on it resumes.
//
// Cancellation requested on the awaiting coroutine's cancellation token while
// waiting for the next record is forwarded to the shared cancellation source.
template <typename Reference, typename Value>
AsyncGenerator<Reference, Value> merge(
    folly::Executor::KeepAlive<> executor,
    AsyncGenerator<AsyncGenerator<Reference, Value>> sources) {
  // State shared (via shared_ptr) between the returned generator, the
  // consumer task and every worker task.
  struct SharedState {
    explicit SharedState(folly::Executor::KeepAlive<> executor_)
        : executor(std::move(executor_)) {}

    const folly::Executor::KeepAlive<> executor;
    const folly::CancellationSource cancelSource;
    // Held by a worker for the whole publish -> consumed -> clear cycle, and
    // by whoever publishes an error.
    coro::Mutex mutex;
    // Posted when `record` has been filled in (value, error or none).
    coro::Baton recordPublished;
    // Posted by the output generator after it has yielded a value (and on
    // scope exit).
    coro::Baton recordConsumed;
    // Posted from the consumer task's completion callback.
    coro::Baton allTasksCompleted;
    // The single hand-off slot between producers and the output generator.
    detail::CallbackRecord<Reference> record;
  };

  // Task that drains `sources_`, starting a worker for each inner generator,
  // then waits on the barrier and publishes the final 'none' record.
  auto makeConsumerTask =
      [](std::shared_ptr<SharedState> state,
         AsyncGenerator<AsyncGenerator<Reference, Value>> sources_)
      -> Task<void> {
    // Detached task that forwards every item of one inner generator into the
    // shared record slot.
    auto makeWorkerTask = [](std::shared_ptr<SharedState> state_,
                             AsyncGenerator<Reference, Value> generator)
        -> detail::DetachedBarrierTask {
      exception_wrapper ex;
      auto cancelToken = state_->cancelSource.getToken();
      try {
        // Pull the next item with the shared cancellation token attached,
        // resuming via the shared executor.
        while (auto item = co_await co_viaIfAsync(
                   state_->executor.get_alias(),
                   co_withCancellation(cancelToken, generator.next()))) {
          // We have a new value to emit in the merged stream.
          {
            auto lock = co_await co_viaIfAsync(
                state_->executor.get_alias(), state_->mutex.co_scoped_lock());

            if (cancelToken.isCancellationRequested()) {
              // Consumer has detached and doesn't want any more values.
              // Discard this value.
              break;
            }

            // Publish the value.
            state_->record = detail::CallbackRecord<Reference>{
                detail::callback_record_value, *std::move(item)};
            state_->recordPublished.post();

            // Wait until the consumer is finished with it.
            co_await co_viaIfAsync(
                state_->executor.get_alias(), state_->recordConsumed);
            state_->recordConsumed.reset();

            // Clear the result before releasing the lock.
            state_->record = {};
          }

          // Re-check after releasing the lock so we stop pulling from the
          // inner generator once cancellation has been requested.
          if (cancelToken.isCancellationRequested()) {
            break;
          }
        }
      } catch (...) {
        // Capture here and publish below, outside the handler.
        ex = exception_wrapper{current_exception()};
      }

      if (ex) {
        // Ask every other participant to stop.
        state_->cancelSource.requestCancellation();

        auto lock = co_await co_viaIfAsync(
            state_->executor.get_alias(), state_->mutex.co_scoped_lock());
        // First error wins: never overwrite an error already in the slot.
        if (!state_->record.hasError()) {
          state_->record = detail::CallbackRecord<Reference>{
              detail::callback_record_error, std::move(ex)};
          state_->recordPublished.post();
        }
      };
    };

    // Initial count of 1; the matching arrival is the arriveAndWait() below.
    detail::Barrier barrier{1};

    auto& asyncFrame = co_await detail::co_current_async_stack_frame;

    // Save the initial context and restore it after starting each task
    // as the task may have modified the context before suspending and we
    // want to make sure the next task is started with the same initial
    // context.
    const auto context = RequestContext::saveContext();

    exception_wrapper ex;
    try {
      while (auto item = co_await sources_.next()) {
        // Don't start new workers once cancellation has been requested.
        if (state->cancelSource.isCancellationRequested()) {
          break;
        }
        makeWorkerTask(state, *std::move(item)).start(&barrier, asyncFrame);
        RequestContext::setContext(context);
      }
    } catch (...) {
      // Capture here and publish below, outside the handler.
      ex = exception_wrapper{current_exception()};
    }

    if (ex) {
      // The outer stream failed: stop the workers and publish the error,
      // unless an error has already been published.
      state->cancelSource.requestCancellation();

      auto lock = co_await co_viaIfAsync(
          state->executor.get_alias(), state->mutex.co_scoped_lock());
      if (!state->record.hasError()) {
        state->record = detail::CallbackRecord<Reference>{
            detail::callback_record_error, std::move(ex)};
        state->recordPublished.post();
      }
    }

    // Wait for all worker tasks to finish consuming the entirety of their
    // input streams.
    co_await detail::UnsafeResumeInlineSemiAwaitable{barrier.arriveAndWait()};

    // Guaranteed there are no more concurrent producers trying to acquire
    // the mutex here.
    if (!state->record.hasError()) {
      // Stream not yet been terminated with an error.
      // Terminate the stream with the 'end()' signal.
      assert(!state->record.hasValue());
      state->record =
          detail::CallbackRecord<Reference>{detail::callback_record_none};
      state->recordPublished.post();
    }
  };

  auto state = std::make_shared<SharedState>(executor);

  // Runs whenever this generator's body exits.
  SCOPE_EXIT {
    state->cancelSource.requestCancellation();
    // Make sure we resume the worker thread so that it has a chance to notice
    // that cancellation has been requested.
    state->recordConsumed.post();
  };

  // Start a task that consumes the stream of input streams.
  // Its completion callback posts allTasksCompleted, whatever the result.
  makeConsumerTask(state, std::move(sources))
      .scheduleOn(executor)
      .start(
          [state](auto&&) { state->allTasksCompleted.post(); },
          state->cancelSource.getToken());

  // Consume values produced by the input streams.
  while (true) {
    if (!state->recordPublished.ready()) {
      // While waiting, forward cancellation of the awaiting coroutine to the
      // shared cancellation source. The callback object lives only for the
      // duration of this wait.
      folly::CancellationCallback cb{
          co_await co_current_cancellation_token,
          [&] { state->cancelSource.requestCancellation(); }};
      co_await state->recordPublished;
    }
    // Re-arm the baton for the next record.
    state->recordPublished.reset();

    if (state->record.hasValue()) {
      // next value
      co_yield std::move(state->record).value();
      // Let the publishing worker clear the slot and release the mutex.
      state->recordConsumed.post();
    } else {
      // We're closing the output stream. In the spirit of structured
      // concurrency, let's make sure to not leave any background tasks behind.
      co_await state->allTasksCompleted;

      if (state->record.hasError()) {
        std::move(state->record).error().throw_exception();
      } else {
        // none
        assert(state->record.hasNone());
        break;
      }
    }
  }
}

} // namespace coro
} // namespace folly

#endif // FOLLY_HAS_COROUTINES