folly/folly/concurrency/test/AtomicSharedPtrCounted.h

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <atomic>
#include <cstdint>
#include <cstdlib>
#include <new>
#include <stdexcept>
#include <type_traits>
#include <utility>

#include <glog/logging.h>

// Empty tag type selecting the counted_ptr constructor that adopts a raw
// pointer and takes one reference on it.
struct counted_shared_tag {
};
template <template <typename> class Atom = std::atomic>
struct intrusive_shared_count {
  // The reference counter itself.
  Atom<uint64_t> counts;

  // The counter starts at zero; owners take their references explicitly.
  intrusive_shared_count() {
    counts.store(0);
  }

  // Adds `count` references.
  void add_ref(uint64_t count = 1) {
    counts.fetch_add(count);
  }

  // Removes `count` references and returns the value the counter held
  // *before* the subtraction. A caller that gets back exactly `count` has
  // therefore dropped the last references.
  uint64_t release_ref(uint64_t count = 1) {
    const uint64_t previous = counts.fetch_sub(count);
    return previous;
  }
};

template <template <typename> class Atom = std::atomic>
struct counted_ptr_base {
 protected:
  // Maps an object pointer back to the intrusive_shared_count header that
  // sits immediately before it in the same allocation (the layout produced
  // by make_counted).
  static intrusive_shared_count<Atom>* getRef(void* pt) {
    using Header = intrusive_shared_count<Atom>;
    char* const object = static_cast<char*>(pt);
    return reinterpret_cast<Header*>(object - sizeof(Header));
  }
};

// Basically shared_ptr, but only supports make_counted, and provides
// access to add_ref / release_ref with a count.  Alias not supported.
//
// The pointee must live in a block allocated by make_counted. The reference
// count header is found at a fixed offset before p_ (counted_ptr_base::getRef),
// and the whole block is released with free() when the last reference goes.
template <typename T, template <typename> class Atom = std::atomic>
class counted_ptr : public counted_ptr_base<Atom> {
 public:
  // Raw managed pointer. It is public because counted_ptr_internals reads it
  // and detaches it directly, without touching the reference count.
  T* p_;

  counted_ptr() : p_(nullptr) {}

  // Adopts `p` (normally fresh from make_counted) and takes one reference.
  counted_ptr(counted_shared_tag, T* p) : p_(p) { addRef(p_); }

  counted_ptr(const counted_ptr& o) : p_(o.p_) { addRef(p_); }

  counted_ptr& operator=(const counted_ptr& o) {
    // Acquire the incoming reference *before* dropping the current one.
    // Releasing first would, on self-assignment by the last owner, destroy
    // and free the object and then add a reference to freed memory.
    T* const incoming = o.p_;
    addRef(incoming);
    T* const outgoing = p_;
    p_ = incoming;
    releaseRef(outgoing);
    return *this;
  }

  // Only a null pointer may be passed here; owning pointers must come from
  // make_counted via the counted_shared_tag constructor.
  explicit counted_ptr(T* p) : p_(p) { CHECK(!p); }

  ~counted_ptr() { releaseRef(p_); }

  typename std::add_lvalue_reference<T>::type operator*() const { return *p_; }

  T* get() const { return p_; }
  T* operator->() const { return p_; }
  explicit operator bool() const { return p_ != nullptr; }
  bool operator==(const counted_ptr<T, Atom>& p) const {
    return get() == p.get();
  }
  bool operator!=(const counted_ptr<T, Atom>& p) const {
    return !(*this == p);
  }

 private:
  // Takes one reference on `p`; does nothing for null.
  static void addRef(T* p) {
    if (p) {
      counted_ptr_base<Atom>::getRef(p)->add_ref();
    }
  }

  // Drops one reference on `p`. When it was the last reference, destroys the
  // object and frees the block. Does nothing for null.
  static void releaseRef(T* p) {
    if (!p) {
      return;
    }
    auto* ref = counted_ptr_base<Atom>::getRef(p);
    // release_ref returns the count held before the decrement.
    if (ref->release_ref() == 1) {
      p->~T();
      std::free(ref);
    }
  }
};

// Allocates a T together with its intrusive reference-count header in one
// malloc'd block, constructs T from `args`, and returns the first owning
// counted_ptr.
//
// Layout: [intrusive_shared_count<Atom>][T]. counted_ptr_base::getRef relies
// on this exact layout.
// NOTE(review): T is placed at offset sizeof(intrusive_shared_count<Atom>)
// (typically 8), so a T with stricter alignment than that offset would be
// misaligned. Confirm no over-aligned types are used here.
//
// Throws std::bad_alloc if allocation fails. If T's constructor throws, the
// block is released and the exception propagates.
template <
    template <typename> class Atom = std::atomic,
    typename T,
    typename... Args>
counted_ptr<T, Atom> make_counted(Args&&... args) {
  using Count = intrusive_shared_count<Atom>;
  char* mem = static_cast<char*>(std::malloc(sizeof(T) + sizeof(Count)));
  if (!mem) {
    throw std::bad_alloc();
  }
  Count* header = new (mem) Count();
  T* ptr = reinterpret_cast<T*>(mem + sizeof(Count));
  try {
    new (ptr) T(std::forward<Args>(args)...);
  } catch (...) {
    // Without this, a throwing T constructor leaked the whole block.
    header->~Count();
    std::free(mem);
    throw;
  }
  return counted_ptr<T, Atom>(counted_shared_tag(), ptr);
}

// Static adapter that exposes counted_ptr's raw pointer and reference count
// through a type-erased "counted base" interface.
// NOTE(review): presumably consumed as the CountedDetail policy of
// folly::atomic_shared_ptr; confirm against the caller.
template <template <typename> class Atom = std::atomic>
class counted_ptr_internals : public counted_ptr_base<Atom> {
 public:
  // Allocates a T via make_counted, forwarding every constructor argument.
  template <typename T, typename... Args>
  static counted_ptr<T, Atom> make_ptr(Args&&... args) {
    // Expand the pack per argument. `std::forward<Args...>(args...)` only
    // compiled for exactly one argument.
    return make_counted<Atom, T>(std::forward<Args>(args)...);
  }
  template <typename T>
  using CountedPtr = counted_ptr<T, Atom>;
  // Type-erased handle: the raw object pointer itself.
  using counted_base = void;

  // Returns the raw object pointer without touching the reference count.
  template <typename T>
  static counted_base* get_counted_base(const counted_ptr<T, Atom>& bar) {
    return bar.p_;
  }

  // Converts a type-erased handle back to the object pointer.
  template <typename T>
  static T* get_shared_ptr(counted_base* base) {
    return static_cast<T*>(base);
  }

  // Detaches the raw pointer from `p` without touching the count. The
  // reference `p` held now belongs to the caller.
  template <typename T>
  static T* release_ptr(counted_ptr<T, Atom>& p) {
    T* res = p.p_;
    p.p_ = nullptr;
    return res;
  }

  // Wraps `base` in a counted_ptr. With inc == true a new reference is
  // taken. With inc == false the caller's existing reference is adopted
  // instead.
  template <typename T>
  static counted_ptr<T, Atom> get_shared_ptr_from_counted_base(
      counted_base* base, bool inc = true) {
    auto res =
        counted_ptr<T, Atom>(counted_shared_tag(), static_cast<T*>(base));
    // The constructor took a reference; give it back so the net effect is
    // adopting the caller's reference. A null base has no count header to
    // touch, so skip it.
    if (!inc && base) {
      release_shared<T>(base, 1);
    }
    return res;
  }

  // Adds `count` references. A negative count converts to a large uint64_t;
  // with std::atomic the unsigned fetch_add wraps, which effectively
  // subtracts.
  static void inc_shared_count(counted_base* base, int64_t count) {
    counted_ptr_base<Atom>::getRef(base)->add_ref(count);
  }

  // Drops `count` references. If they were the last ones (release_ref
  // returns the pre-decrement value), destroys the object and frees the
  // block.
  template <typename T>
  static void release_shared(counted_base* base, uint64_t count) {
    auto* ref = counted_ptr_base<Atom>::getRef(base);
    if (ref->release_ref(count) == count) {
      static_cast<T*>(base)->~T();
      std::free(ref);
    }
  }
};