folly/folly/coro/RustAdaptors.h

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <folly/CancellationToken.h>
#include <folly/Executor.h>
#include <folly/Optional.h>
#include <folly/experimental/coro/AsyncGenerator.h>
#include <folly/experimental/coro/Task.h>
#include <folly/futures/Future.h>
#include <folly/synchronization/Baton.h>

#if FOLLY_HAS_COROUTINES

namespace folly {
namespace coro {

template <typename T>
class PollFuture final : private Executor {
 public:
  using Poll = Optional<lift_unit_t<T>>;
  using Waker = Function<void()>;

  explicit PollFuture(Task<T> task) {
    Executor* self = this;
    std::move(task)
        .scheduleOn(makeKeepAlive(self))
        .start(
            [&](Try<T>&& result) noexcept {
              // Rust doesn't support exceptions
              DCHECK(!result.hasException());
              if constexpr (!std::is_same_v<T, void>) {
                result_ = std::move(result).value();
              } else {
                result_ = unit;
              }
            },
            cancellationSource_.getToken());
  }

  explicit PollFuture(SemiFuture<lift_unit_t<T>> future) {
    Executor* self = this;
    std::move(future)
        .via(makeKeepAlive(self))
        .setCallback_([&](Executor::KeepAlive<>&&, Try<T>&& result) mutable {
          result_ = std::move(result).value();
        });
  }

  ~PollFuture() override {
    cancellationSource_.requestCancellation();
    if (keepAliveCount_.load(std::memory_order_relaxed) > 0) {
      folly::Baton<> b;
      while (!poll([&] { b.post(); })) {
        b.wait();
        b.reset();
      }
    }
  }

  Poll poll(Waker waker) {
    while (true) {
      std::queue<Func> funcs;
      {
        auto wQueueAndWaker = queueAndWaker_.wlock();
        if (wQueueAndWaker->funcs.empty()) {
          wQueueAndWaker->waker = std::move(waker);
          break;
        }

        std::swap(funcs, wQueueAndWaker->funcs);
      }

      while (!funcs.empty()) {
        funcs.front()();
        funcs.pop();
      }
    }

    if (keepAliveCount_.load(std::memory_order_relaxed) == 0) {
      return std::move(result_);
    }
    return none;
  }

 private:
  void add(Func func) override {
    auto waker = [&] {
      auto wQueueAndWaker = queueAndWaker_.wlock();
      wQueueAndWaker->funcs.push(std::move(func));
      return std::exchange(wQueueAndWaker->waker, {});
    }();
    if (waker) {
      waker();
    }
  }

  bool keepAliveAcquire() noexcept override {
    auto keepAliveCount =
        keepAliveCount_.fetch_add(1, std::memory_order_relaxed);
    DCHECK(keepAliveCount > 0);
    return true;
  }

  void keepAliveRelease() noexcept override {
    auto keepAliveCount = keepAliveCount_.load(std::memory_order_relaxed);
    do {
      DCHECK(keepAliveCount > 0);
      if (keepAliveCount == 1) {
        add([this] {
          // the final count *must* be released from this executor so that we
          // don't race with poll.
          keepAliveCount_.fetch_sub(1, std::memory_order_relaxed);
        });
        return;
      }
    } while (!keepAliveCount_.compare_exchange_weak(
        keepAliveCount,
        keepAliveCount - 1,
        std::memory_order_release,
        std::memory_order_relaxed));
  }

  struct QueueAndWaker {
    std::queue<Func> funcs;
    Waker waker;
  };
  Synchronized<QueueAndWaker> queueAndWaker_;
  std::atomic<ssize_t> keepAliveCount_{1};
  Optional<lift_unit_t<T>> result_;
  CancellationSource cancellationSource_;
};

template <typename T>
class PollStream {
 public:
  using Poll = Optional<Optional<T>>;
  using Waker = Function<void()>;

  explicit PollStream(AsyncGenerator<T> asyncGenerator)
      : asyncGenerator_(std::move(asyncGenerator)) {}

  Poll poll(Waker waker) {
    if (!nextFuture_) {
      nextFuture_.emplace(getNext());
    }

    auto nextPoll = nextFuture_->poll(std::move(waker));
    if (!nextPoll) {
      return none;
    }

    nextFuture_.reset();
    return nextPoll;
  }

 private:
  Task<Optional<T>> getNext() {
    auto next = co_await asyncGenerator_.next();
    if (next) {
      co_return std::move(next).value();
    }
    co_return none;
  }

  AsyncGenerator<T> asyncGenerator_;
  Optional<PollFuture<Optional<T>>> nextFuture_;
};

} // namespace coro
} // namespace folly

#endif // FOLLY_HAS_COROUTINES