folly/folly/channels/detail/AtomicQueue.h

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <atomic>
#include <cassert>
#include <memory>
#include <utility>
#include <glog/logging.h>

#include <folly/lang/Assume.h>

namespace folly {
namespace channels {
namespace detail {

/**
 * Minimal owning, singly-linked FIFO of T values.
 *
 * The queue owns every node reachable from `head_`: destroying it, calling
 * clear(), or move-assigning over it deletes those nodes. Nodes are released
 * with `delete` (via std::unique_ptr), so any Node handed to Queue(Node*) or
 * fromReversed() must have been allocated with `new`.
 *
 * front() and pop() do not check for emptiness; callers must ensure the queue
 * is non-empty first.
 */
template <typename T>
class Queue {
 public:
  /**
   * A single list cell. `next` points at the following cell, or is nullptr
   * for the last cell of the list.
   */
  struct Node {
    explicit Node(T&& t) : value(std::move(t)) {}

    T value;
    Node* next{nullptr};
  };

  constexpr Queue() noexcept {}

  // Takes over `other`'s nodes and leaves `other` empty.
  constexpr Queue(Queue&& other) noexcept : head_(other.head_) {
    other.head_ = nullptr;
  }

  // Frees this queue's nodes, then takes over `other`'s nodes and leaves
  // `other` empty. Self-move-assignment therefore leaves the queue empty.
  Queue& operator=(Queue&& other) noexcept {
    clear();
    head_ = other.head_;
    other.head_ = nullptr;
    return *this;
  }

  ~Queue() { clear(); }

  // Adopts an already forward-linked list whose first element is `head`.
  constexpr explicit Queue(Node* head) noexcept : head_(head) {}

  /**
   * Builds a queue from a list linked back-to-front: `tail` is the element
   * that should come out last, and following `next` from it reaches earlier
   * elements. The links are reversed in place (no allocation), so the
   * returned queue yields elements in front-to-back order.
   */
  static Queue fromReversed(Node* tail) noexcept {
    Node* reversed = nullptr;
    while (tail != nullptr) {
      Node* const rest = tail->next;
      tail->next = reversed;
      reversed = tail;
      tail = rest;
    }
    return Queue(reversed);
  }

  bool empty() const noexcept { return head_ == nullptr; }

  // True when at least one element is queued.
  explicit operator bool() const { return head_ != nullptr; }

  // Precondition: !empty().
  T& front() noexcept { return head_->value; }

  // Unlinks and destroys the front element. Precondition: !empty().
  void pop() noexcept {
    // Owning the detached node here deletes it when this scope ends, after
    // head_ has already moved on to the next element.
    std::unique_ptr<Node> detached{head_};
    head_ = detached->next;
  }

  // Destroys every element, front to back.
  void clear() {
    while (head_ != nullptr) {
      pop();
    }
  }

  Node* head_{nullptr};
};

template <typename Consumer, typename Message>
class AtomicQueue {
 public:
  using MessageQueue = Queue<Message>;

  AtomicQueue() {}
  ~AtomicQueue() {
    auto storage = storage_.load(std::memory_order_acquire);
    auto type = static_cast<Type>(storage & kTypeMask);
    auto ptr = storage & kPointerMask;
    switch (type) {
      case Type::EMPTY:
      case Type::CLOSED:
        return;
      case Type::TAIL:
        MessageQueue::fromReversed(
            reinterpret_cast<typename MessageQueue::Node*>(ptr));
        return;
      case Type::CONSUMER:
      default:
        folly::assume_unreachable();
    }
  }
  AtomicQueue(const AtomicQueue&) = delete;
  AtomicQueue& operator=(const AtomicQueue&) = delete;

  template <typename... ConsumerArgs>
  void push(Message&& value, ConsumerArgs&&... consumerArgs) {
    std::unique_ptr<typename MessageQueue::Node> node(
        new typename MessageQueue::Node(std::move(value)));
    assert(!(reinterpret_cast<intptr_t>(node.get()) & kTypeMask));

    auto storage = storage_.load(std::memory_order_relaxed);
    while (true) {
      auto type = static_cast<Type>(storage & kTypeMask);
      auto ptr = storage & kPointerMask;
      switch (type) {
        case Type::EMPTY:
        case Type::TAIL:
          node->next = reinterpret_cast<typename MessageQueue::Node*>(ptr);
          if (storage_.compare_exchange_weak(
                  storage,
                  reinterpret_cast<intptr_t>(node.get()) |
                      static_cast<intptr_t>(Type::TAIL),
                  std::memory_order_release,
                  std::memory_order_relaxed)) {
            node.release();
            return;
          }
          break;
        case Type::CLOSED:
          return;
        case Type::CONSUMER:
          node->next = nullptr;
          if (storage_.compare_exchange_weak(
                  storage,
                  reinterpret_cast<intptr_t>(node.get()) |
                      static_cast<intptr_t>(Type::TAIL),
                  std::memory_order_acq_rel,
                  std::memory_order_relaxed)) {
            node.release();
            auto consumer = reinterpret_cast<Consumer*>(ptr);
            consumer->consume(std::forward<ConsumerArgs>(consumerArgs)...);
            return;
          }
          break;
        default:
          folly::assume_unreachable();
      }
    }
  }

  template <typename... ConsumerArgs>
  bool wait(Consumer* consumer, ConsumerArgs&&... consumerArgs) {
    assert(!(reinterpret_cast<intptr_t>(consumer) & kTypeMask));
    auto storage = storage_.load(std::memory_order_relaxed);
    while (true) {
      auto type = static_cast<Type>(storage & kTypeMask);
      switch (type) {
        case Type::EMPTY:
          if (storage_.compare_exchange_weak(
                  storage,
                  reinterpret_cast<intptr_t>(consumer) |
                      static_cast<intptr_t>(Type::CONSUMER),
                  std::memory_order_release,
                  std::memory_order_relaxed)) {
            return true;
          }
          break;
        case Type::CLOSED:
          consumer->canceled(std::forward<ConsumerArgs>(consumerArgs)...);
          return true;
        case Type::TAIL:
          return false;
        case Type::CONSUMER:
        default:
          folly::assume_unreachable();
      }
    }
  }

  template <typename... ConsumerArgs>
  void close(ConsumerArgs&&... consumerArgs) {
    auto storage = storage_.exchange(
        static_cast<intptr_t>(Type::CLOSED), std::memory_order_acquire);
    auto type = static_cast<Type>(storage & kTypeMask);
    auto ptr = storage & kPointerMask;
    switch (type) {
      case Type::EMPTY:
        return;
      case Type::TAIL:
        MessageQueue::fromReversed(
            reinterpret_cast<typename MessageQueue::Node*>(ptr));
        return;
      case Type::CONSUMER:
        reinterpret_cast<Consumer*>(ptr)->canceled(
            std::forward<ConsumerArgs>(consumerArgs)...);
        return;
      case Type::CLOSED:
      default:
        folly::assume_unreachable();
    }
  }

  bool isClosed() {
    auto type = static_cast<Type>(storage_ & kTypeMask);
    return type == Type::CLOSED;
  }

  template <typename... ConsumerArgs>
  MessageQueue getMessages(ConsumerArgs&&... consumerArgs) {
    auto storage = storage_.exchange(
        static_cast<intptr_t>(Type::EMPTY), std::memory_order_acquire);
    auto type = static_cast<Type>(storage & kTypeMask);
    auto ptr = storage & kPointerMask;
    switch (type) {
      case Type::TAIL:
        return MessageQueue::fromReversed(
            reinterpret_cast<typename MessageQueue::Node*>(ptr));
      case Type::EMPTY:
        return MessageQueue();
      case Type::CLOSED:
        // We accidentally re-opened the queue, so close it again.
        // This is only safe to do because isClosed() can't be called
        // concurrently with getMessages().
        close(std::forward<ConsumerArgs>(consumerArgs)...);
        return MessageQueue();
      case Type::CONSUMER:
      default:
        folly::assume_unreachable();
    }
  }

  Consumer* cancelCallback() {
    auto storage = storage_.load(std::memory_order_acquire);
    while (true) {
      auto type = static_cast<Type>(storage & kTypeMask);
      auto ptr = storage & kPointerMask;
      switch (type) {
        case Type::CONSUMER:
          if (storage_.compare_exchange_weak(
                  storage,
                  static_cast<intptr_t>(Type::EMPTY),
                  std::memory_order_relaxed,
                  std::memory_order_relaxed)) {
            return reinterpret_cast<Consumer*>(ptr);
          }
          break;
        case Type::TAIL:
        case Type::EMPTY:
        case Type::CLOSED:
        default:
          return nullptr;
      }
    }
  }

 private:
  enum class Type : intptr_t { EMPTY = 0, CONSUMER = 1, TAIL = 2, CLOSED = 3 };

  static constexpr intptr_t kTypeMask = 3;
  static constexpr intptr_t kPointerMask = ~kTypeMask;

  std::atomic<intptr_t> storage_{0};
};
} // namespace detail
} // namespace channels
} // namespace folly