folly/folly/gen/ParallelMap-inl.h

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef FOLLY_GEN_PARALLELMAP_H_
#error This file may only be included from folly/gen/ParallelMap.h
#endif

#include <atomic>
#include <exception>
#include <thread>
#include <type_traits>
#include <utility>
#include <vector>

#include <folly/Expected.h>
#include <folly/MPMCPipeline.h>
#include <folly/functional/Invoke.h>
#include <folly/lang/Exception.h>
#include <folly/portability/Unistd.h>
#include <folly/synchronization/EventCount.h>

namespace folly::gen::detail {

/**
 * PMap - Map in parallel (using threads). Produces a sequence of values
 * by passing each value from a source collection through a predicate,
 * with the predicate invocations running in parallel across multiple
 * threads.
 *
 * This type is usually used through the 'pmap' helper function:
 *
 *   auto sumOfFibs = seq(1, 10) | pmap(fibonacci, 4) | sum;
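 *
 * Results are yielded in source order regardless of which worker thread
 * finishes first, and an exception thrown by the predicate is rethrown
 * on the consuming thread when the corresponding result is read. A
 * sketch of that usage (assuming a 'parseLine' function that may throw
 * and a 'lines' container; both are illustrative, not part of this API):
 *
 *   auto records = from(lines) | pmap(parseLine, 8) | as<std::vector>();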
 */
template <class Predicate>
class PMap : public Operator<PMap<Predicate>> {
  Predicate pred_;
  size_t nThreads_;

 public:
  PMap() = default;

  PMap(Predicate pred, size_t nThreads)
      : pred_(std::move(pred)), nThreads_(nThreads) {}

  template <
      class Value,
      class Source,
      class Input = typename std::decay<Value>::type,
      class Output =
          typename std::decay<invoke_result_t<Predicate, Value>>::type>
  class Generator
      : public GenImpl<Output, Generator<Value, Source, Input, Output>> {
    Source source_;
    Predicate pred_;
    const size_t nThreads_;

    using Result = folly::Expected<Output, std::exception_ptr>;
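    // Fans input values out to worker threads through an MPMCPipeline and
    // yields results in input order. Only the single thread that owns *this
    // may call write/blockingWrite/read/readPendingResult.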
    class ExecutionPipeline {
      std::vector<std::thread> workers_;
      std::atomic<bool> done_{false};
      const Predicate& pred_;
      using Pipeline = MPMCPipeline<Input, Result>;
      Pipeline pipeline_;
      EventCount wake_;

      // Total active item count as observed by the single thread owning *this.
      size_t active_ = 0;

     public:
      ExecutionPipeline(const Predicate& pred, size_t nThreads)
          : pred_(pred), pipeline_(nThreads, nThreads) {
        workers_.reserve(nThreads);
        for (size_t i = 0; i < nThreads; i++) {
          workers_.push_back(std::thread([this] { this->predApplier(); }));
        }
      }

      ~ExecutionPipeline() {
        done_.store(true, std::memory_order_release);
        wake_.notifyAll();
        // Drain the pipeline in case not all results were consumed, either
        // due to an exception or due to early termination requested by the
        // consumer (e.g. take(n)).
        for (Result result; readPendingResult(result);) {
        }
        for (auto& w : workers_) {
          w.join();
        }
      }

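      // Non-blocking write into the input queue. Returns false if the queue
      // is full; on success, records the in-flight item and wakes a worker.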
      bool write(Value&& value) {
        if (!pipeline_.write(std::forward<Value>(value))) {
          return false;
        }
        ++active_;
        wake_.notify();
        return true;
      }

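      // Write into the input queue, blocking until space is available.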
      void blockingWrite(Value&& value) {
        pipeline_.blockingWrite(std::forward<Value>(value));
        ++active_;
        wake_.notify();
      }

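      // Non-blocking read of a completed result. Returns false if no result
      // is ready yet.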
      bool read(Result& result) {
        if (!pipeline_.read(result)) {
          return false;
        }
        --active_;
        return true;
      }

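      // Blocking read of the next completed result. Returns false if nothing
      // is in flight.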
      bool readPendingResult(Result& result) {
        if (active_ == 0) {
          return false;
        }
        pipeline_.blockingRead(result);
        --active_;
        return true;
      }

     private:
      void predApplier() {
        // Each thread takes a value from the pipeline_, runs the
        // predicate and enqueues the result. The pipeline preserves
        // ordering. NOTE: don't use blockingReadStage<0> to read from
        // the pipeline_, since more input may never arrive: end-of-data
        // is signaled separately using done_/wake_.
        Input in;
        for (;;) {
          auto key = wake_.prepareWait();

          typename Pipeline::template Ticket<0> ticket;
          if (pipeline_.template readStage<0>(ticket, in)) {
            wake_.cancelWait();
            try {
              Output out = pred_(std::move(in));
              pipeline_.template blockingWriteStage<0>(ticket, std::move(out));
            } catch (...) {
              pipeline_.template blockingWriteStage<0>(
                  ticket, makeUnexpected(current_exception()));
            }
            continue;
          }

          if (done_.load(std::memory_order_acquire)) {
            wake_.cancelWait();
            break;
          }

          // Not done_, but no items in the queue.
          wake_.wait(key);
        }
      }
    };

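    // Unwrap a Result: rethrow the captured exception if the predicate
    // failed, otherwise move the computed value out.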
    static Output&& getOutput(Result& result) {
      if (result.hasError()) {
        std::rethrow_exception(std::move(result).error());
      }
      return std::move(result).value();
    }

   public:
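    // If nThreads is 0, default to the number of online CPUs.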
    Generator(Source source, const Predicate& pred, size_t nThreads)
        : source_(std::move(source)),
          pred_(pred),
          nThreads_(nThreads ? nThreads : sysconf(_SC_NPROCESSORS_ONLN)) {}

    template <class Body>
    void foreach(Body&& body) const {
      ExecutionPipeline pipeline(pred_, nThreads_);

      source_.foreach([&](Value value) {
        if (pipeline.write(std::forward<Value>(value))) {
          // input queue not yet full, saturate it before we process
          // anything downstream
          return;
        }

        // input queue full; drain ready items from the queue
        for (Result result; pipeline.read(result);) {
          body(getOutput(result));
        }

        // now that we've made room, write the value we originally intended
        // to write.
        pipeline.blockingWrite(std::forward<Value>(value));
      });
      // flush the output queue
      for (Result result; pipeline.readPendingResult(result);) {
        body(getOutput(result));
      }
    }

    template <class Handler>
    bool apply(Handler&& handler) const {
      ExecutionPipeline pipeline(pred_, nThreads_);

      if (!source_.apply([&](Value value) {
            if (pipeline.write(std::forward<Value>(value))) {
              // input queue not yet full, saturate it before we process
              // anything downstream
              return true;
            }

            // input queue full; drain ready items from the queue
            for (Result result; pipeline.read(result);) {
              if (!handler(getOutput(result))) {
                return false;
              }
            }

            // now that we've made room, write the value we originally
            // intended to write.
            pipeline.blockingWrite(std::forward<Value>(value));
            return true;
          })) {
        return false;
      }

      // flush the output queue
      for (Result result; pipeline.readPendingResult(result);) {
        if (!handler(getOutput(result))) {
          return false;
        }
      }

      return true;
    }

    static constexpr bool infinite = Source::infinite;
  };

  template <class Source, class Value, class Gen = Generator<Value, Source>>
  Gen compose(GenImpl<Value, Source>&& source) const {
    return Gen(std::move(source.self()), pred_, nThreads_);
  }

  template <class Source, class Value, class Gen = Generator<Value, Source>>
  Gen compose(const GenImpl<Value, Source>& source) const {
    return Gen(source.self(), pred_, nThreads_);
  }
};

} // namespace folly::gen::detail