folly/folly/fibers/AddTasks-inl.h

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <memory>
#include <vector>

namespace folly {
namespace fibers {

template <typename T>
TaskIterator<T>::TaskIterator(TaskIterator&& other) noexcept
    : context_(std::move(other.context_)), id_(other.id_), fm_(other.fm_) {}

template <typename T>
inline bool TaskIterator<T>::hasCompleted() const {
  return context_->tasksConsumed < context_->results.size();
}

template <typename T>
inline bool TaskIterator<T>::hasPending() const {
  return !context_.unique();
}

template <typename T>
inline bool TaskIterator<T>::hasNext() const {
  return hasPending() || hasCompleted();
}

template <typename T>
folly::Try<T> TaskIterator<T>::awaitNextResult() {
  assert(hasCompleted() || hasPending());
  reserve(1);

  size_t i = context_->tasksConsumed++;
  id_ = context_->results[i].first;
  return std::move(context_->results[i].second);
}

template <typename T>
inline T TaskIterator<T>::awaitNext() {
  return std::move(awaitNextResult().value());
}

template <>
inline void TaskIterator<void>::awaitNext() {
  awaitNextResult().value();
}

template <typename T>
inline void TaskIterator<T>::reserve(size_t n) {
  size_t tasksReady = context_->results.size() - context_->tasksConsumed;

  // we don't need to do anything if there are already n or more tasks complete
  // or if we have no tasks left to execute.
  if (!hasPending() || tasksReady >= n) {
    return;
  }

  n -= tasksReady;
  size_t tasksLeft = context_->totalTasks - context_->results.size();
  n = std::min(n, tasksLeft);

  await_async([this, n](Promise<void> promise) {
    context_->tasksToFulfillPromise = n;
    context_->promise.assign(std::move(promise));
  });
}

template <typename T>
inline size_t TaskIterator<T>::getTaskID() const {
  assert(id_ != static_cast<size_t>(-1));
  return id_;
}

template <typename T>
template <typename F>
void TaskIterator<T>::addTask(F&& func) {
  static_assert(
      std::is_convertible<invoke_result_t<F>, T>::value,
      "TaskIterator<T>: T must be convertible from func()'s return type");

  auto taskId = context_->totalTasks++;

  fm_.addTask(
      [taskId, context = context_, func = std::forward<F>(func)]() mutable {
        context->results.emplace_back(
            taskId, folly::makeTryWith(std::move(func)));

        // Check for awaiting iterator.
        if (context->promise.hasValue()) {
          if (--context->tasksToFulfillPromise == 0) {
            context->promise->setValue();
            context->promise.clear();
          }
        }
      });
}

template <class InputIterator>
TaskIterator<
    invoke_result_t<typename std::iterator_traits<InputIterator>::value_type>>
addTasks(InputIterator first, InputIterator last) {
  typedef invoke_result_t<
      typename std::iterator_traits<InputIterator>::value_type>
      ResultType;
  typedef TaskIterator<ResultType> IteratorType;

  IteratorType iterator;

  for (; first != last; ++first) {
    iterator.addTask(std::move(*first));
  }

  iterator.context_->results.reserve(iterator.context_->totalTasks);

  return iterator;
}
} // namespace fibers
} // namespace folly