folly/folly/executors/MeteredExecutor-inl.h

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <limits>

#include <folly/io/async/AtomicNotificationQueue.h>

namespace folly {
namespace detail {

template <template <typename> class Atom>
MeteredExecutorImpl<Atom>::MeteredExecutorImpl(
    KeepAlive keepAlive, Options options)
    : options_(std::move(options)), kaInner_(std::move(keepAlive)) {
  // Tasks passed to add() are forwarded, via worker callbacks, to the
  // executor referenced by `keepAlive`. add() only schedules a new worker
  // while fewer than `options.maxInQueue` are outstanding there.
  //
  // maxInQueue must be >= 1, otherwise add() could never schedule a worker.
  // It must also stay below 2^31, presumably so the in-queue counter fits
  // its bit field inside state_ (the layout is declared in the header;
  // TODO confirm).
  CHECK_GE(options_.maxInQueue, 1);
  CHECK_LT(options_.maxInQueue, uint32_t(1) << 31);
}

/**
 * Constructs a metered executor that takes ownership of `executor`.
 *
 * Delegates to the KeepAlive-based constructor with a keep-alive token on
 * `executor`, then stores the unique_ptr so the executor is owned by this
 * object.
 *
 * @param executor  Inner executor to own; must be non-null (checked).
 * @param options   Metering options, validated by the delegated constructor.
 */
template <template <typename> class Atom>
MeteredExecutorImpl<Atom>::MeteredExecutorImpl(
    std::unique_ptr<Executor> executor, Options options)
    // Dereferencing a null unique_ptr is undefined behavior; fail with a
    // clear message instead. CHECK_NOTNULL returns its argument, so it can be
    // used directly inside the delegating-constructor argument list.
    : MeteredExecutorImpl(
          getKeepAliveToken(*CHECK_NOTNULL(executor.get())),
          std::move(options)) {
  ownedExecutor_ = std::move(executor);
}

/**
 * Builds the optional QueueObserver used by add() to report enqueued tasks.
 *
 * Returns nullptr when queue observation is disabled in the options, or when
 * QueueObserverFactory::make() yields no factory. Otherwise, returns the
 * observer created for `options_.priority` by a factory named
 * "mex.<name>", where <name> falls back to "unk" if the executor is unnamed.
 */
template <template <typename> class Atom>
std::unique_ptr<QueueObserver> MeteredExecutorImpl<Atom>::setupQueueObserver() {
  if (!options_.enableQueueObserver) {
    return nullptr;
  }

  std::string label = options_.name;
  if (label == "") {
    label = "unk";
  }

  auto factory =
      folly::QueueObserverFactory::make("mex." + label, options_.numPriorities);
  if (!factory) {
    return nullptr;
  }
  return factory->create(options_.priority);
}

/**
 * Atomically applies `f` to state_ using a compare-and-swap retry loop.
 *
 * `f` receives the most recently observed state and returns the desired
 * replacement. Because the CAS can fail under contention, `f` may be invoked
 * several times; callers must only assign (never accumulate) any values they
 * capture by reference, so the final call's results win.
 */
template <template <typename> class Atom>
template <class F>
void MeteredExecutorImpl<Atom>::modifyState(F f) {
  uint64_t oldState = state_.load(std::memory_order_relaxed);
  for (;;) {
    const uint64_t newState = f(oldState);
    // Verify invariants: no more in-queue than allowed.
    DCHECK_LE(newState >> kInQueueShift, options_.maxInQueue);
    // No more in queue than pending tasks.
    DCHECK_LE(newState >> kInQueueShift, newState & kSizeMask);
    if (state_.compare_exchange_strong(
            oldState,
            newState,
            std::memory_order_seq_cst,
            std::memory_order_relaxed)) {
      return;
    }
    // On failure oldState now holds the current value; recompute from it.
  }
}

/**
 * Queues `func` for execution on the inner executor.
 *
 * The caller's RequestContext is saved with the task. If a queue observer is
 * configured, it is notified and its payload is stored on the task. A worker
 * is scheduled only if the executor is not paused and fewer than
 * `maxInQueue` workers are already outstanding. Otherwise the task waits in
 * queue_ for an existing worker, or for resume().
 */
template <template <typename> class Atom>
void MeteredExecutorImpl<Atom>::add(Func func) {
  Task task(std::move(func), RequestContext::saveContext());
  if (queueObserver_) {
    task.setQueueObserverPayload(
        queueObserver_->onEnqueued(task.requestContext()));
  }

  // The task is published before the pending count is bumped, so any worker
  // admitted by that count will find a task in queue_.
  queue_.enqueue(std::move(task));

  bool shouldScheduleWorker = false;
  modifyState([&](uint64_t state) {
    state += kSizeInc;
    CHECK_NE(state & kSizeMask, 0)
        << "Too many pending tasks in MeteredExecutor";
    const bool paused = (state & kPausedBit) != 0;
    const bool hasFreeSlot = (state >> kInQueueShift) < options_.maxInQueue;
    // Recomputed on every CAS attempt, so the final value matches the state
    // that was actually committed.
    shouldScheduleWorker = !paused && hasFreeSlot;
    if (shouldScheduleWorker) {
      state += kInQueueInc;
    }
    return state;
  });

  if (shouldScheduleWorker) {
    scheduleWorker();
  }
}

/**
 * Sets the paused bit so that no new workers are scheduled and running
 * workers abort without consuming tasks.
 *
 * @return true if this call transitioned the executor into the paused state,
 *         false if it was already paused.
 */
template <template <typename> class Atom>
bool MeteredExecutorImpl<Atom>::pause() {
  const auto previous = state_.fetch_or(kPausedBit, std::memory_order_relaxed);
  const bool wasAlreadyPaused = (previous & kPausedBit) != 0;
  return !wasAlreadyPaused;
}

/**
 * Clears the paused bit and reschedules enough workers to bring the in-queue
 * count up to min(maxInQueue, pending tasks).
 *
 * @return true if the executor was paused and is now resumed, false if it
 *         was not paused (in which case state_ is left unchanged).
 */
template <template <typename> class Atom>
bool MeteredExecutorImpl<Atom>::resume() {
  bool wasPaused = false;
  size_t workersToSchedule = 0;
  modifyState([&](uint64_t state) -> uint64_t {
    wasPaused = (state & kPausedBit) != 0;
    if (!wasPaused) {
      return state;
    }
    // Workers may have aborted without consuming tasks, reschedule them.
    auto curSize = state & kSizeMask;
    auto curInQueue = state >> kInQueueShift;
    DCHECK_LE(curInQueue, options_.maxInQueue);
    DCHECK_LE(curInQueue, curSize);
    workersToSchedule =
        std::min(static_cast<uint64_t>(options_.maxInQueue), curSize) -
        curInQueue;
    state &= ~kPausedBit;
    state += workersToSchedule * kInQueueInc;
    return state;
  });

  if (wasPaused) {
    for (size_t remaining = workersToSchedule; remaining > 0; --remaining) {
      scheduleWorker();
    }
  }
  return wasPaused;
}

/**
 * Runs the stored callable under the RequestContext captured when the task
 * was created. Any exception it throws is caught and reported by
 * invokeCatchingExns rather than propagated.
 */
template <template <typename> class Atom>
void MeteredExecutorImpl<Atom>::Task::run() && {
  // Declared first so it is destroyed last: the callable (and its captures)
  // is released while the task's request context is still installed.
  folly::RequestContextScopeGuard contextGuard{std::move(rctx_)};
  auto callable = std::exchange(func_, {});
  invokeCatchingExns("MeteredExecutor", std::move(callable));
}

/**
 * Body of one worker callback scheduled on the inner executor.
 *
 * If the executor is paused, the worker gives back its in-queue slot without
 * consuming a task; resume() reschedules it later. Otherwise the worker
 * consumes one pending task. If more tasks are pending than workers are
 * outstanding, it first reschedules itself, keeping the in-queue count
 * unchanged. It then dequeues and runs the task.
 */
template <template <typename> class Atom>
void MeteredExecutorImpl<Atom>::worker() {
  bool shouldAbort = false;
  bool shouldRescheduleWorker = false;
  modifyState([&](uint64_t state) {
    if (state & kPausedBit) {
      shouldAbort = true;
      shouldRescheduleWorker = false;
    } else {
      shouldAbort = false;
      // More work to do than workers in queue, re-schedule the worker without
      // changing the in-queue count.
      shouldRescheduleWorker = (state & kSizeMask) > (state >> kInQueueShift);
      DCHECK_GT(state & kSizeMask, 0);
      state -= kSizeInc;
    }
    if (!shouldRescheduleWorker) {
      state -= kInQueueInc;
    }
    return state;
  });

  if (shouldAbort) {
    return;
  }
  if (shouldRescheduleWorker) {
    scheduleWorker();
  }

  Task task;
  // add() enqueues before incrementing the pending count, so a task must be
  // available for the pending count consumed above.
  CHECK(queue_.try_dequeue(task));
  // Balance the onEnqueued() notification made in add(). Without this, the
  // observer never learns the task left the queue, and the payload it
  // returned (stored on the task) is never handed back to it.
  if (queueObserver_) {
    queueObserver_->onDequeued(task.getQueueObserverPayload());
  }
  std::move(task).run();
}

/**
 * Submits one worker() callback to the inner executor.
 *
 * The current RequestContext is cleared for the submission. Each task
 * restores its own saved context when it runs. The callback holds a
 * keep-alive token on this executor for as long as the callback exists.
 */
template <template <typename> class Atom>
void MeteredExecutorImpl<Atom>::scheduleWorker() {
  folly::RequestContextScopeGuard clearedContext{nullptr};
  auto workerFn = [keepAlive = getKeepAliveToken(this)] {
    keepAlive->worker();
  };
  kaInner_->add(std::move(workerFn));
}

/**
 * Destroys the executor after joinKeepAlive() returns. Presumably this blocks
 * until outstanding keep-alive tokens, including those held by scheduled
 * worker callbacks, are released (see DefaultKeepAliveExecutor; TODO confirm).
 */
template <template <typename> class Atom>
MeteredExecutorImpl<Atom>::~MeteredExecutorImpl() {
  this->joinKeepAlive();
}

} // namespace detail
} // namespace folly