folly/folly/fibers/AtomicBatchDispatcher-inl.h

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

namespace folly {
namespace fibers {

// Shared per-batch state. The dispatcher and every Token it issues hold a
// std::shared_ptr to one DispatchBaton; inputs are recorded here as tokens are
// dispatched, and the batch is resolved from the destructor, i.e. when the
// last reference to the baton is dropped.
template <typename InputT, typename ResultT>
struct AtomicBatchDispatcher<InputT, ResultT>::DispatchBaton {
  // Takes ownership of the user supplied batch dispatch function. Marked
  // explicit so a dispatch function never converts to a baton implicitly.
  explicit DispatchBaton(DispatchFunctionT&& dispatchFunction)
      : expectedCount_(0), dispatchFunction_(std::move(dispatchFunction)) {}

  // Resolves every recorded promise (with a value or an exception).
  ~DispatchBaton() { fulfillPromises(); }

  // Pre-allocates storage for numEntries entries; does not change the size.
  void reserve(size_t numEntries) { optEntries_.reserve(numEntries); }

  // Records an error that, when set, is delivered to every recorded promise
  // instead of running the dispatch function.
  void setExceptionWrapper(folly::exception_wrapper&& exWrapper) {
    exceptionWrapper_ = std::move(exWrapper);
  }

  // Fixes the number of entries the batch must contain and sizes the entries
  // vector to match. Expected to be called at most once (checked in debug
  // builds only).
  void setExpectedCount(size_t expectedCount) {
    assert(expectedCount_ == 0 || !"expectedCount_ being set more than once");
    expectedCount_ = expectedCount;
    optEntries_.resize(expectedCount_);
  }

  // Stores `input` in the slot for `sequenceNumber` (growing the entries
  // vector if that slot does not exist yet) and returns the future tied to
  // that slot's promise.
  Future<ResultT> getFutureResult(InputT&& input, size_t sequenceNumber) {
    if (sequenceNumber >= optEntries_.size()) {
      optEntries_.resize(sequenceNumber + 1);
    }
    folly::Optional<Entry>& optEntry = optEntries_[sequenceNumber];
    assert(!optEntry || !"Multiple inputs have the same token sequence number");
    optEntry = Entry(std::move(input));
    return optEntry->promise.getFuture();
  }

 private:
  // Delivers `ew_` to every recorded promise that has not been fulfilled yet.
  void setExceptionResults(folly::exception_wrapper&& ew_) {
    auto ew = std::move(ew_);
    for (auto& optEntry : optEntries_) {
      // This can run from the catch handler in fulfillPromises() after some
      // promises already received a value. folly::Promise rejects a second
      // fulfillment by throwing, and an exception escaping here would leave
      // the destructor and terminate the process, so skip fulfilled promises.
      if (optEntry && !optEntry->promise.isFulfilled()) {
        optEntry->promise.setException(ew);
      }
    }
  }

  // Resolves the batch. In order of precedence:
  //   1. a recorded exception wrapper is delivered to every promise;
  //   2. if any slot below expectedCount_ is empty, every promise gets an
  //      ABDTokenNotDispatchedException naming the missing slots;
  //   3. otherwise the dispatch function is run on all inputs; a result
  //      count mismatch yields ABDUsageException for every promise, else
  //      each promise receives its result.
  // Anything thrown along the way is delivered to the promises that are still
  // unfulfilled.
  void fulfillPromises() {
    try {
      // If an error message is set, set all promises to exception with message
      if (exceptionWrapper_) {
        return setExceptionResults(std::move(exceptionWrapper_));
      }

      // Validate entries count same as expectedCount_
      assert(
          optEntries_.size() == expectedCount_ ||
          !"Entries vector did not have expected size");
      std::vector<size_t> vecTokensNotDispatched;
      for (size_t i = 0; i < expectedCount_; ++i) {
        if (!optEntries_[i]) {
          vecTokensNotDispatched.push_back(i);
        }
      }
      if (!vecTokensNotDispatched.empty()) {
        return setExceptionResults(
            make_exception_wrapper<ABDTokenNotDispatchedException>(
                detail::createABDTokenNotDispatchedExMsg(
                    vecTokensNotDispatched)));
      }

      // Create the inputs vector
      std::vector<InputT> inputs;
      inputs.reserve(expectedCount_);
      for (auto& optEntry : optEntries_) {
        inputs.emplace_back(std::move(optEntry->input));
      }

      // Call the user provided batch dispatch function to get all results
      // and make sure that we have the expected number of results returned
      auto results = dispatchFunction_(std::move(inputs));
      if (results.size() != expectedCount_) {
        return setExceptionResults(make_exception_wrapper<ABDUsageException>(
            detail::createUnexpectedNumResultsABDUsageExMsg(
                expectedCount_, results.size())));
      }

      // Fulfill the promises with the results from the batch dispatch
      for (size_t i = 0; i < expectedCount_; ++i) {
        optEntries_[i]->promise.setValue(std::move(results[i]));
      }
    } catch (...) {
      // Set exceptions thrown when executing the user provided dispatch func
      // (or while fulfilling promises) on the promises still outstanding
      return setExceptionResults(exception_wrapper{current_exception()});
    }
  }

  // One dispatched input together with the promise for its result.
  struct Entry {
    InputT input;
    folly::Promise<ResultT> promise;

    Entry(Entry&& other) noexcept
        : input(std::move(other.input)), promise(std::move(other.promise)) {}

    Entry& operator=(Entry&& other) noexcept {
      input = std::move(other.input);
      promise = std::move(other.promise);
      return *this;
    }

    explicit Entry(InputT&& input_) : input(std::move(input_)) {}
  };

  // Number of entries the batch must contain; 0 until setExpectedCount().
  size_t expectedCount_;
  // User supplied function mapping the batch of inputs to a batch of results.
  DispatchFunctionT dispatchFunction_;
  // Slot i holds the entry dispatched with sequence number i, if any.
  std::vector<folly::Optional<Entry>> optEntries_;
  // When set, overrides normal dispatch and is delivered to every promise.
  folly::exception_wrapper exceptionWrapper_;
};

template <typename InputT, typename ResultT>
AtomicBatchDispatcher<InputT, ResultT>::Token::Token(
    std::shared_ptr<DispatchBaton> baton, size_t sequenceNumber)
    : baton_(std::move(baton)), sequenceNumber_(sequenceNumber) {}

// Returns the sequence number this token was constructed with.
template <typename InputT, typename ResultT>
size_t AtomicBatchDispatcher<InputT, ResultT>::Token::sequenceNumber() const {
  return this->sequenceNumber_;
}

// Submits `input` for this token's slot and returns the future for its result.
// A token can be dispatched only once: the first call gives up the token's
// reference to the baton, and any later call throws ABDUsageException.
template <typename InputT, typename ResultT>
Future<ResultT> AtomicBatchDispatcher<InputT, ResultT>::Token::dispatch(
    InputT input) {
  if (!baton_) {
    throw ABDUsageException(
        "Dispatch called more than once on the same Token object");
  }
  // Take the reference out of the token so it is dropped when this call
  // returns, after the input has been recorded.
  auto heldBaton = std::move(baton_);
  return heldBaton->getFutureResult(std::move(input), sequenceNumber_);
}

template <typename InputT, typename ResultT>
AtomicBatchDispatcher<InputT, ResultT>::AtomicBatchDispatcher(
    DispatchFunctionT&& dispatchFunc)
    : numTokensIssued_(0),
      baton_(std::make_shared<DispatchBaton>(std::move(dispatchFunc))) {}

// If the dispatcher still owns its baton (commit() was never called), marks
// the batch as failed with ABDCommitNotCalledException and commits it.
template <typename InputT, typename ResultT>
AtomicBatchDispatcher<InputT, ResultT>::~AtomicBatchDispatcher() {
  if (!baton_) {
    // The baton has already been moved out (e.g. by commit()).
    return;
  }
  // Record the error on the baton instead of throwing: destructors of
  // AtomicBatchDispatcher must not throw.
  baton_->setExceptionWrapper(
      folly::make_exception_wrapper<ABDCommitNotCalledException>());
  commit();
}

// Forwards a capacity hint for `numEntries` entries to the baton.
// Throws ABDUsageException once the dispatcher no longer owns its baton.
template <typename InputT, typename ResultT>
void AtomicBatchDispatcher<InputT, ResultT>::reserve(size_t numEntries) {
  if (baton_) {
    baton_->reserve(numEntries);
    return;
  }
  throw ABDUsageException("Cannot call reserve(....) after calling commit()");
}

// Issues the next token: it shares the dispatcher's baton and carries the next
// unused sequence number. Throws ABDUsageException once the dispatcher no
// longer owns its baton.
template <typename InputT, typename ResultT>
auto AtomicBatchDispatcher<InputT, ResultT>::getToken() -> Token {
  if (baton_) {
    // Post-increment: the token gets the current count, then the count grows.
    const auto nextSequenceNumber = numTokensIssued_++;
    return Token(baton_, nextSequenceNumber);
  }
  throw ABDUsageException("Cannot issue more tokens after calling commit()");
}

// Seals the batch: tells the baton how many tokens were issued and gives up
// the dispatcher's reference to it. A second call throws ABDUsageException.
template <typename InputT, typename ResultT>
void AtomicBatchDispatcher<InputT, ResultT>::commit() {
  if (!baton_) {
    throw ABDUsageException(
        "Cannot call commit() more than once on the same dispatcher");
  }
  // Move the reference into a local so the dispatcher's share of the baton is
  // dropped when this function returns, after the expected count is set.
  auto releasedBaton = std::move(baton_);
  releasedBaton->setExpectedCount(numTokensIssued_);
}

// Convenience factory: builds a dispatcher around `dispatchFunc` and, when
// `initialCapacity` is non-zero, reserves room for that many entries.
template <typename InputT, typename ResultT>
AtomicBatchDispatcher<InputT, ResultT> createAtomicBatchDispatcher(
    folly::Function<std::vector<ResultT>(std::vector<InputT>&&)> dispatchFunc,
    size_t initialCapacity) {
  AtomicBatchDispatcher<InputT, ResultT> dispatcher(std::move(dispatchFunc));
  if (initialCapacity != 0) {
    dispatcher.reserve(initialCapacity);
  }
  return dispatcher;
}

} // namespace fibers
} // namespace folly