// folly/folly/coro/detail/BarrierTask.h

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <folly/experimental/coro/Coroutine.h>
#include <folly/experimental/coro/WithAsyncStack.h>
#include <folly/experimental/coro/detail/Barrier.h>
#include <folly/experimental/coro/detail/Malloc.h>

#include <cassert>
#include <utility>

#if FOLLY_HAS_COROUTINES

namespace folly {
namespace coro {
namespace detail {

// Move-only owner of a coroutine that reports its completion to a Barrier.
//
// The coroutine is created suspended (initial_suspend() returns
// suspend_always) and only begins executing when start() is called. When the
// coroutine body finishes, the final awaiter calls Barrier::arrive() and
// transfers control to the handle that call returns. The coroutine frame is
// not destroyed at final suspend; it is destroyed by ~BarrierTask().
class BarrierTask {
 public:
  class promise_type {
    // Awaiter returned from final_suspend(). Always suspends, then hands
    // control to whatever handle the barrier returns from arrive().
    struct FinalAwaiter {
      bool await_ready() noexcept { return false; }

      coroutine_handle<> await_suspend(
          coroutine_handle<promise_type> h) noexcept {
        auto& promise = h.promise();
        // start() must have installed a barrier before the coroutine ran.
        assert(promise.barrier_ != nullptr);
        return promise.barrier_->arrive(promise.asyncFrame_);
      }

      void await_resume() noexcept {}
    };

   public:
    // Allocation/deallocation hooks declared on the promise; both forward to
    // the ::folly_coro_async_malloc / ::folly_coro_async_free pair, which is
    // declared outside this file.
    static void* operator new(std::size_t size) {
      return ::folly_coro_async_malloc(size);
    }

    static void operator delete(void* ptr, std::size_t size) {
      ::folly_coro_async_free(ptr, size);
    }

    // Builds the BarrierTask that owns this promise's coroutine.
    BarrierTask get_return_object() noexcept {
      return BarrierTask{coroutine_handle<promise_type>::from_promise(*this)};
    }

    // Suspend immediately on creation; execution is deferred until start().
    suspend_always initial_suspend() noexcept { return {}; }

    FinalAwaiter final_suspend() noexcept { return {}; }

    // Every co_await operand inside the coroutine body is passed through
    // co_withAsyncStack() before being awaited.
    template <typename Awaitable>
    auto await_transform(Awaitable&& awaitable) {
      return folly::coro::co_withAsyncStack(
          static_cast<Awaitable&&>(awaitable));
    }

    void return_void() noexcept {}

    // An exception escaping the coroutine body terminates the process.
    [[noreturn]] void unhandled_exception() noexcept { std::terminate(); }

    // Records the barrier to notify on completion. Asserts that no barrier
    // has been set yet, i.e. this is expected to be called at most once.
    void setBarrier(Barrier* barrier) noexcept {
      assert(barrier_ == nullptr);
      barrier_ = barrier;
    }

    folly::AsyncStackFrame& getAsyncFrame() noexcept { return asyncFrame_; }

   private:
    // Async stack frame for this coroutine; passed to Barrier::arrive().
    folly::AsyncStackFrame asyncFrame_;
    // Barrier to notify at final suspend; null until setBarrier() is called.
    Barrier* barrier_ = nullptr;
  };

 private:
  using handle_t = coroutine_handle<promise_type>;

  // Only promise_type::get_return_object() constructs from a raw handle.
  explicit BarrierTask(handle_t coro) noexcept : coro_(coro) {}

 public:
  // Transfers ownership of the coroutine; `other` is left with a null handle.
  BarrierTask(BarrierTask&& other) noexcept
      : coro_(std::exchange(other.coro_, {})) {}

  // Destroys the owned coroutine frame, if any.
  ~BarrierTask() {
    if (coro_) {
      coro_.destroy();
    }
  }

  // By-value parameter + swap: the previously owned coroutine (if any) is
  // destroyed when `other` goes out of scope.
  BarrierTask& operator=(BarrierTask other) noexcept {
    swap(other);
    return *this;
  }

  void swap(BarrierTask& b) noexcept { std::swap(coro_, b.coro_); }

  // Starts the coroutine using the detached root async stack frame as parent.
  FOLLY_NOINLINE void start(Barrier* barrier) noexcept {
    start(barrier, folly::getDetachedRootAsyncStackFrame());
  }

  // Starts the coroutine with `parentFrame` as the parent of its async stack
  // frame and `barrier` as the barrier to notify on completion.
  //
  // This function does not call barrier->add().
  // NOTE(review): presumably the caller has already accounted for this task
  // in the barrier's count -- confirm against Barrier and the call sites.
  FOLLY_NOINLINE void start(
      Barrier* barrier, folly::AsyncStackFrame& parentFrame) noexcept {
    assert(coro_);
    auto& calleeFrame = coro_.promise().getAsyncFrame();
    calleeFrame.setParentFrame(parentFrame);
    // No argument is passed, so setReturnAddress() uses its default.
    // NOTE(review): presumably that default captures the return address at
    // this call site, which would be why start() is FOLLY_NOINLINE -- confirm
    // against AsyncStackFrame.
    calleeFrame.setReturnAddress();
    coro_.promise().setBarrier(barrier);

    // Ownership stays with this object; coro_ is not cleared here.
    folly::resumeCoroutineWithNewAsyncStackRoot(coro_);
  }

 private:
  // Owned coroutine handle; null after being moved from.
  handle_t coro_;
};

// Move-only handle to a coroutine that is detached when started and notifies
// a Barrier on completion.
//
// The coroutine is created suspended. Calling one of the rvalue-qualified
// start() overloads increments the barrier's count via Barrier::add(1),
// releases ownership of the coroutine from this object, and resumes it. When
// the coroutine body finishes, its final awaiter calls Barrier::arrive() and
// then destroys the coroutine frame itself (except under the MSVC workaround
// noted in final_suspend()). If start() is never called, the destructor
// destroys the still-suspended coroutine.
class DetachedBarrierTask {
 public:
  class promise_type {
   public:
    // The async frame's parent defaults to the detached root frame; the
    // start(barrier, parentFrame) overload may replace it before resumption.
    promise_type() noexcept {
      asyncFrame_.setParentFrame(folly::getDetachedRootAsyncStackFrame());
    }

    // Builds the DetachedBarrierTask that initially owns this coroutine.
    DetachedBarrierTask get_return_object() noexcept {
      return DetachedBarrierTask{
          coroutine_handle<promise_type>::from_promise(*this)};
    }

    // Suspend immediately on creation; execution is deferred until start().
    suspend_always initial_suspend() noexcept { return {}; }

    // Returns an awaiter that always suspends, notifies the barrier, destroys
    // the coroutine frame, and then transfers control to the handle returned
    // by Barrier::arrive().
    auto final_suspend() noexcept {
      struct awaiter {
        bool await_ready() noexcept { return false; }
        auto await_suspend(coroutine_handle<promise_type> h) noexcept {
          // barrier_ is null until setBarrier() runs, so this catches a
          // coroutine that reached completion without going through start().
          assert(h.promise().barrier_ != nullptr);
          // Obtain the continuation before destroying the frame: the promise
          // (and thus barrier_ and the async frame) lives inside the frame.
          auto continuation =
              h.promise().barrier_->arrive(h.promise().getAsyncFrame());

          // Due to a bug in MSVC versions up to and including 19.39, we observe
          // an extra call to the destructor of the task with an explicit call
          // to coroutine_handle::destroy. Furthermore, with versions
          // above 19.30, this causes a crash when named return value
          // optimization is enabled.
#if !(!defined(__clang__) && defined(_MSC_VER) && _MSC_VER <= 1939)
          h.destroy();
#endif

          return continuation;
        }
        void await_resume() noexcept {}
      };
      return awaiter{};
    }

    // An exception escaping the coroutine body terminates the process.
    [[noreturn]] void unhandled_exception() noexcept { std::terminate(); }

    void return_void() noexcept {}

    // Every co_await operand inside the coroutine body is passed through
    // co_withAsyncStack() before being awaited.
    template <typename Awaitable>
    auto await_transform(Awaitable&& awaitable) {
      return folly::coro::co_withAsyncStack(
          static_cast<Awaitable&&>(awaitable));
    }

    // Records the barrier to notify at final suspend.
    void setBarrier(Barrier* barrier) noexcept { barrier_ = barrier; }

    AsyncStackFrame& getAsyncFrame() noexcept { return asyncFrame_; }

   private:
    // Async stack frame for this coroutine; passed to Barrier::arrive().
    AsyncStackFrame asyncFrame_;
    // Barrier to notify at final suspend. Initialised to null so that the
    // assertion in final_suspend() tests a well-defined value rather than an
    // indeterminate pointer.
    Barrier* barrier_ = nullptr;
  };

 private:
  using handle_t = coroutine_handle<promise_type>;

  // Only promise_type::get_return_object() constructs from a raw handle.
  explicit DetachedBarrierTask(handle_t coro) noexcept : coro_(coro) {}

 public:
  // Transfers ownership of the coroutine; `other` is left with a null handle.
  DetachedBarrierTask(DetachedBarrierTask&& other) noexcept
      : coro_(std::exchange(other.coro_, {})) {}

  // Destroys the coroutine only if it was never started (start() clears
  // coro_, after which the coroutine is responsible for its own frame).
  ~DetachedBarrierTask() {
    if (coro_) {
      coro_.destroy();
    }
  }

  // Starts the coroutine, recording the return address captured in this
  // function as the async frame's return address. The parent frame is left as
  // set by the promise constructor.
  FOLLY_NOINLINE void start(Barrier* barrier) && noexcept {
    std::move(*this).start(barrier, FOLLY_ASYNC_STACK_RETURN_ADDRESS());
  }

  // Starts the coroutine with `parentFrame` as the parent of its async stack
  // frame, recording the return address captured in this function.
  FOLLY_NOINLINE void start(
      Barrier* barrier, folly::AsyncStackFrame& parentFrame) && noexcept {
    assert(coro_);
    coro_.promise().getAsyncFrame().setParentFrame(parentFrame);
    std::move(*this).start(barrier, FOLLY_ASYNC_STACK_RETURN_ADDRESS());
  }

  // Core start routine used by the other overloads.
  //
  // `barrier` must be non-null; its count is incremented by one before the
  // coroutine is resumed. `returnAddress` is stored in the coroutine's async
  // stack frame. After this call the task object no longer owns a coroutine.
  void start(Barrier* barrier, void* returnAddress) && noexcept {
    assert(coro_);
    assert(barrier != nullptr);
    // Register with the barrier before running, so that the arrive() issued
    // at final suspend is matched by this add().
    barrier->add(1);
    // Release ownership first: the coroutine may run to completion inside the
    // resume call below, and must not be destroyed again by ~DetachedBarrierTask().
    auto coro = std::exchange(coro_, {});
    coro.promise().setBarrier(barrier);
    coro.promise().getAsyncFrame().setReturnAddress(returnAddress);
    folly::resumeCoroutineWithNewAsyncStackRoot(coro);
  }

 private:
  // Owned coroutine handle; null once started or moved from.
  handle_t coro_;
};

} // namespace detail
} // namespace coro
} // namespace folly

#endif // FOLLY_HAS_COROUTINES