folly/folly/io/async/AsyncUDPServerSocket.h

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <folly/Memory.h>
#include <folly/io/IOBufQueue.h>
#include <folly/io/async/AsyncUDPSocket.h>
#include <folly/io/async/EventBase.h>

namespace folly {

/**
 * UDP server socket
 *
 * It wraps a UDP socket waiting for packets and distributes them among
 * a set of event loops in round robin fashion.
 *
 * NOTE: At the moment it is designed to work with single packet protocols
 *       in mind. We distribute incoming packets among all the listeners in
 *       round-robin fashion. So, any protocol that expects to send/recv
 *       more than 1 packet will not work because they will end up with
 *       different event base to process.
 */
class AsyncUDPServerSocket : private AsyncUDPSocket::ReadCallback,
                             public AsyncSocketBase {
 public:
  class Callback {
   public:
    using OnDataAvailableParams =
        AsyncUDPSocket::ReadCallback::OnDataAvailableParams;
    /**
     * Invoked when we start reading data from socket. It is invoked in
     * each acceptors/listeners event base thread.
     */
    virtual void onListenStarted() noexcept = 0;

    /**
     * Invoked when the server socket is closed. It is invoked in each
     * acceptors/listeners event base thread.
     */
    virtual void onListenStopped() noexcept = 0;

    /**
     * Invoked when the server socket is paused. It is invoked in each
     * acceptors/listeners event base thread.
     */
    virtual void onListenPaused() noexcept {}

    /**
     * Invoked when the server socket is resumed. It is invoked in each
     * acceptors/listeners event base thread.
     */
    virtual void onListenResumed() noexcept {}

    /**
     * Invoked when the server socket can still read but need to inform the
     * callback object that it should not process read from new client address.
     * It is invoked in each acceptors/listeners event base thread.
     */
    virtual void onAcceptNewPeerPaused() noexcept {}

    /**
     * Invoked when need to inform the callback object that it can resume
     * process read from new client address. It is invoked in each
     * acceptors/listeners event base thread.
     */
    virtual void onAcceptNewPeerResumed() noexcept {}

    /**
     * Invoked when a new packet is received
     */
    virtual void onDataAvailable(
        std::shared_ptr<AsyncUDPSocket> socket,
        const folly::SocketAddress& addr,
        std::unique_ptr<folly::IOBuf> buf,
        bool truncated,
        OnDataAvailableParams) noexcept = 0;

    virtual ~Callback() = default;
  };

  enum class DispatchMechanism { RoundRobin, ClientAddressHash };

  /**
   * Create a new UDP server socket
   *
   * Note about packet size - We allocate buffer of packetSize_ size to read.
   * If packet are larger than this value, as per UDP protocol, remaining data
   * is dropped and you get `truncated = true` in onDataAvailable callback
   */
  explicit AsyncUDPServerSocket(
      EventBase* evb,
      size_t sz = 1500,
      DispatchMechanism dm = DispatchMechanism::RoundRobin)
      : evb_(evb), packetSize_(sz), dispatchMechanism_(dm), nextListener_(0) {}

  ~AsyncUDPServerSocket() override {
    if (socket_) {
      close();
    }
  }

  void bind(
      const folly::SocketAddress& addy,
      const SocketOptionMap& options = emptySocketOptionMap,
      const std::string& ifName = "") {
    CHECK(!socket_);

    socket_ = std::make_shared<AsyncUDPSocket>(evb_);
    socket_->setReusePort(reusePort_);
    socket_->setReuseAddr(reuseAddr_);
    socket_->setRecvTos(recvTos_);
    socket_->applyOptions(
        validateSocketOptions(
            options, addy.getFamily(), SocketOptionKey::ApplyPos::PRE_BIND),
        SocketOptionKey::ApplyPos::PRE_BIND);
    AsyncUDPSocket::BindOptions bindOptions;
    bindOptions.ifName = ifName;
    socket_->bind(addy, bindOptions);
    socket_->applyOptions(
        validateSocketOptions(
            options, addy.getFamily(), SocketOptionKey::ApplyPos::POST_BIND),
        SocketOptionKey::ApplyPos::POST_BIND);

    applyEventCallback();
  }

  void setReusePort(bool reusePort) { reusePort_ = reusePort; }

  void setReuseAddr(bool reuseAddr) { reuseAddr_ = reuseAddr; }

  void setRecvTos(bool recvTos) { recvTos_ = recvTos; }

  void setTosOrTrafficClass(uint8_t tosOrTclass) {
    CHECK(socket_);
    socket_->setTosOrTrafficClass(tosOrTclass);
  }

  folly::SocketAddress address() const {
    CHECK(socket_);
    return socket_->address();
  }

  void getAddress(SocketAddress* a) const override { *a = address(); }

  /**
   * Add a listener to the round robin list
   */
  void addListener(EventBase* evb, Callback* callback) {
    listeners_.emplace_back(evb, callback);
  }

  void listen() {
    CHECK(socket_) << "Need to bind before listening";

    for (auto& listener : listeners_) {
      auto callback = listener.second;

      listener.first->runInEventBaseThread(
          [callback]() mutable { callback->onListenStarted(); });
    }

    socket_->resumeRead(this);
  }

  NetworkSocket getNetworkSocket() const {
    CHECK(socket_) << "Need to bind before getting Network Socket";
    return socket_->getNetworkSocket();
  }

  const std::shared_ptr<AsyncUDPSocket>& getSocket() const { return socket_; }

  void close() {
    CHECK(socket_) << "Need to bind before closing";
    socket_->close();
    socket_.reset();
  }

  EventBase* getEventBase() const override { return evb_; }

  /**
   * Indicates if the current socket is accepting.
   */
  bool isAccepting() const { return socket_->isReading(); }

  /**
   * Pauses accepting datagrams on the underlying socket.
   */
  void pauseAccepting() {
    socket_->pauseRead();
    for (auto& listener : listeners_) {
      auto callback = listener.second;

      listener.first->runInEventBaseThread(
          [callback]() mutable { callback->onListenPaused(); });
    }
  }

  /**
   * Inform the callback object that it should not process read from new client
   * address.
   */
  void pauseAcceptingNewPeer() {
    for (auto& listener : listeners_) {
      auto callback = listener.second;

      listener.first->runInEventBaseThread(
          [callback]() mutable { callback->onAcceptNewPeerPaused(); });
    }
  }

  /**
   * Starts accepting datagrams once again.
   */
  void resumeAccepting() {
    socket_->resumeRead(this);
    for (auto& listener : listeners_) {
      auto callback = listener.second;

      listener.first->runInEventBaseThread(
          [callback]() mutable { callback->onListenResumed(); });
    }
  }

  /**
   * Inform the callback object that it can process read from new client address
   * now.
   */
  void resumeAcceptingNewPeer() {
    for (auto& listener : listeners_) {
      auto callback = listener.second;

      listener.first->runInEventBaseThread(
          [callback]() mutable { callback->onAcceptNewPeerResumed(); });
    }
  }

  void setEventCallback(EventRecvmsgCallback* cb) {
    eventCb_ = cb;
    applyEventCallback();
  }

  void setRecvmsgMultishotCallback(EventRecvmsgMultishotCallback* cb) {
    multishotCb_ = cb;
    applyEventCallback();
  }

  bool setTimestamping(int val) { return socket_->setTimestamping(val); }

 private:
  // AsyncUDPSocket::ReadCallback
  void getReadBuffer(void** buf, size_t* len) noexcept override {
    std::tie(*buf, *len) = buf_.preallocate(packetSize_, packetSize_);
  }

  void onDataAvailable(
      const folly::SocketAddress& clientAddress,
      size_t len,
      bool truncated,
      OnDataAvailableParams params) noexcept override {
    buf_.postallocate(len);
    auto data = buf_.split(len);

    if (listeners_.empty()) {
      LOG(WARNING) << "UDP server socket dropping packet, "
                   << "no listener registered";
      return;
    }

    uint32_t listenerId = 0;
    uint64_t client_hash_lo = 0;
    switch (dispatchMechanism_) {
      case DispatchMechanism::ClientAddressHash:
        // Hash base on clientAddress.
        // 1. This logic is samilar to: clientAddress.hash() % listeners_.size()
        //    But runs faster as it use multiply and shift instead of division.
        // 2. Only use the lower 32 bit from the address hash result for faster
        //    computation.
        client_hash_lo = static_cast<uint32_t>(clientAddress.hash());
        listenerId = (client_hash_lo * listeners_.size()) >> 32;
        break;
      case DispatchMechanism::RoundRobin: // round robin is default.
      default:
        if (nextListener_ >= listeners_.size()) {
          nextListener_ = 0;
        }
        listenerId = nextListener_;
        ++nextListener_;
        break;
    }

    auto callback = listeners_[listenerId].second;

    // Schedule it in the listener's eventbase
    // XXX: Speed this up
    auto f = [socket = socket_,
              client = clientAddress,
              callback,
              data_2 = std::move(data),
              truncated,
              params]() mutable {
      callback->onDataAvailable(
          socket, client, std::move(data_2), truncated, params);
    };

    listeners_[listenerId].first->runInEventBaseThread(std::move(f));
  }

  void onReadError(const AsyncSocketException& ex) noexcept override {
    LOG(ERROR) << ex.what();

    // Lets register to continue listening for packets
    socket_->resumeRead(this);
  }

  void onReadClosed() noexcept override {
    for (auto& listener : listeners_) {
      auto callback = listener.second;

      listener.first->runInEventBaseThread(
          [callback]() mutable { callback->onListenStopped(); });
    }
  }

  void applyEventCallback() {
    if (socket_) {
      if (eventCb_) {
        socket_->setEventCallback(eventCb_);
      } else if (multishotCb_) {
        socket_->setRecvmsgMultishotCallback(multishotCb_);
      } else {
        socket_->resetEventCallback();
      }
    }
  }

  EventBase* const evb_;
  const size_t packetSize_;

  std::shared_ptr<AsyncUDPSocket> socket_;

  // List of listener to distribute packets among
  typedef std::pair<EventBase*, Callback*> Listener;
  std::vector<Listener> listeners_;

  DispatchMechanism dispatchMechanism_;

  // Next listener to send packet to
  uint32_t nextListener_;

  // Temporary buffer for data
  folly::IOBufQueue buf_;

  bool reusePort_{false};
  bool reuseAddr_{false};
  bool recvTos_{false};

  EventRecvmsgCallback* eventCb_{nullptr};
  EventRecvmsgMultishotCallback* multishotCb_{nullptr};
};

} // namespace folly