folly/folly/concurrency/memory/test/ReadMostlySharedPtrTest.cpp

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <atomic>
#include <condition_variable>
#include <mutex>
#include <thread>

#include <folly/Memory.h>
#include <folly/concurrency/memory/ReadMostlySharedPtr.h>
#include <folly/portability/GTest.h>
#include <folly/synchronization/Baton.h>

using folly::ReadMostlyMainPtr;
using folly::ReadMostlyMainPtrDeleter;
using folly::ReadMostlySharedPtr;
using folly::ReadMostlyWeakPtr;

// Number of seconds the fixture's alarm() waits before SIGALRM is delivered
// to the test process; guards against a test deadlocking forever.
constexpr unsigned int TEST_TIMEOUT = 10;

// Fixture shared by every test in this file. Its constructor arms a SIGALRM
// watchdog so that a deadlocked test is interrupted after TEST_TIMEOUT
// seconds rather than hanging the whole run.
class ReadMostlySharedPtrTest : public ::testing::Test {
 public:
  ReadMostlySharedPtrTest() {
    const unsigned int watchdogSeconds = TEST_TIMEOUT;
    alarm(watchdogSeconds);
  }
};

// Payload stored behind the pointers under test. It tracks how many instances
// are alive through a caller-owned counter: the constructor increments it and
// the destructor decrements it.
struct TestObject {
  int value;
  std::atomic<int>& counter;

  TestObject(int value_, std::atomic<int>& counter_)
      : value(value_), counter(counter_) {
    ++counter;
  }

  // A copy would skip the constructor's increment while its destructor still
  // decrements, corrupting the live count. Forbid copying outright.
  TestObject(const TestObject&) = delete;
  TestObject& operator=(const TestObject&) = delete;

  ~TestObject() {
    // More destructions than constructions indicates a bug in the test.
    assert(counter.load() > 0);
    --counter;
  }
};

// One side calls requestAndWait(), the other side calls waitForRequest(),
// does something and calls completed().
class Coordinator {
 public:
  void requestAndWait() {
    requestBaton_.post();
    completeBaton_.wait();
  }

  void waitForRequest() { requestBaton_.wait(); }

  void completed() { completeBaton_.post(); }

 private:
  folly::Baton<> requestBaton_;
  folly::Baton<> completeBaton_;
};

// Each reset() on the main pointer must release the object it previously
// held when nothing else references that object.
TEST_F(ReadMostlySharedPtrTest, BasicStores) {
  ReadMostlyMainPtr<TestObject> mainPtr;

  // Publish the first object.
  std::atomic<int> firstLive{0};
  mainPtr.reset(std::make_unique<TestObject>(1, firstLive));
  EXPECT_EQ(1, firstLive.load());

  // Replace it; the first object has no other owner and must be destroyed.
  std::atomic<int> secondLive{0};
  mainPtr.reset(std::make_unique<TestObject>(2, secondLive));
  EXPECT_EQ(1, secondLive.load());
  EXPECT_EQ(0, firstLive.load());

  // Clear it; the second object must be destroyed as well.
  mainPtr.reset(nullptr);
  EXPECT_EQ(0, secondLive.load());
}

// A ReadMostlySharedPtr loaded from the main pointer keeps its object alive
// across later stores, clears, and even destruction of the main pointer.
TEST_F(ReadMostlySharedPtrTest, BasicLoads) {
  // Declared outside the inner scope so they outlive the main pointer.
  std::atomic<int> secondLive{0};
  ReadMostlySharedPtr<TestObject> reader;

  {
    ReadMostlyMainPtr<TestObject> mainPtr;

    // A default-constructed main pointer holds nothing.
    EXPECT_EQ(mainPtr.get(), nullptr);

    std::atomic<int> firstLive{0};
    mainPtr.reset(std::make_unique<TestObject>(1, firstLive));
    EXPECT_EQ(1, firstLive.load());

    reader = mainPtr;
    EXPECT_EQ(1, reader->value);

    // Storing a new object must not free the first one while `reader`
    // still refers to it.
    mainPtr.reset(std::make_unique<TestObject>(2, secondLive));
    EXPECT_EQ(1, secondLive.load());
    EXPECT_EQ(1, firstLive.load());

    // Re-pointing `reader` drops the last reference to the first object.
    reader = mainPtr;
    EXPECT_EQ(2, reader->value);
    EXPECT_EQ(0, firstLive.load());

    // Clearing the main pointer leaves the second object alive via `reader`.
    mainPtr.reset(nullptr);
    EXPECT_EQ(1, secondLive.load());
  }

  // Destroying the main pointer does not release it either.
  EXPECT_EQ(1, secondLive.load());

  reader.reset();
  EXPECT_EQ(0, secondLive.load());
}

// Two reader threads load from the main pointer at points sequenced by the
// main thread and must always observe the most recently stored object.
TEST_F(ReadMostlySharedPtrTest, LoadsFromThreads) {
  std::atomic<int> liveObjects{0};

  {
    ReadMostlyMainPtr<TestObject> mainPtr;
    // One handshake per load; the main thread drives them in index order.
    Coordinator steps[7];

    // Reader side of one handshake: wait for the request, check the value
    // visible through getShared(), then release the main thread.
    auto expectLoad = [&](Coordinator& step, int expected) {
      step.waitForRequest();
      EXPECT_EQ(expected, mainPtr.getShared()->value);
      step.completed();
    };

    std::thread readerA([&] {
      steps[0].waitForRequest();
      EXPECT_EQ(mainPtr.getShared(), nullptr);
      steps[0].completed();

      expectLoad(steps[3], 2);
      expectLoad(steps[4], 4);
      expectLoad(steps[5], 5);
    });

    std::thread readerB([&] {
      expectLoad(steps[1], 1);
      expectLoad(steps[2], 2);
      expectLoad(steps[6], 5);
    });

    // Nothing stored yet: readerA must see nullptr.
    steps[0].requestAndWait();

    mainPtr.reset(std::make_unique<TestObject>(1, liveObjects));
    steps[1].requestAndWait();

    mainPtr.reset(std::make_unique<TestObject>(2, liveObjects));
    steps[2].requestAndWait();
    steps[3].requestAndWait();

    // Two stores back to back; the next load must observe the second one.
    mainPtr.reset(std::make_unique<TestObject>(3, liveObjects));
    mainPtr.reset(std::make_unique<TestObject>(4, liveObjects));
    steps[4].requestAndWait();

    mainPtr.reset(std::make_unique<TestObject>(5, liveObjects));
    steps[5].requestAndWait();
    steps[6].requestAndWait();

    // Only the most recently stored object is still alive.
    EXPECT_EQ(1, liveObjects.load());

    readerA.join();
    readerB.join();
  }

  EXPECT_EQ(0, liveObjects.load());
}

// Constructing a main pointer directly from an owning pointer publishes the
// object, and destroying the main pointer releases it.
TEST_F(ReadMostlySharedPtrTest, Ctor) {
  std::atomic<int> liveCount{0};
  {
    ReadMostlyMainPtr<TestObject> mainPtr(
        std::make_unique<TestObject>(1, liveCount));
    EXPECT_EQ(1, mainPtr.getShared()->value);
  }
  EXPECT_EQ(0, liveCount.load());
}

// A store must release the previous object even if another thread has loaded
// (and thereby cached) it and is still alive.
TEST_F(ReadMostlySharedPtrTest, ClearingCache) {
  std::atomic<int> firstLive{0};
  std::atomic<int> secondLive{0};

  ReadMostlyMainPtr<TestObject> mainPtr;
  mainPtr.reset(std::make_unique<TestObject>(1, firstLive));

  Coordinator handshake;

  // Load once so this thread caches the pointer, then park until the main
  // thread has stored a replacement.
  std::thread reader([&] {
    mainPtr.getShared();
    handshake.requestAndWait();
  });

  handshake.waitForRequest();
  EXPECT_EQ(1, firstLive.load());

  // Replacing the object must destroy the first one despite the parked reader.
  mainPtr.reset(std::make_unique<TestObject>(2, secondLive));
  EXPECT_EQ(0, firstLive.load());

  handshake.completed();
  reader.join();
}

// Number of TestRefCount::useGlobal() invocations (either overload) so far.
size_t useGlobalCalls{0};

// Simple atomic reference count used to instantiate the ReadMostly pointer
// templates in place of the default count. Every useGlobal() call, per-count or
// batched, bumps the file-level useGlobalCalls counter so tests can observe it.
class TestRefCount {
 public:
  // All references must have been released by the time the count dies.
  ~TestRefCount() noexcept { DCHECK_EQ(count_.load(), 0); }

  // Adds a reference and returns the new count, which must be positive.
  int64_t operator++() noexcept {
    auto ret = ++count_;
    DCHECK_GT(ret, 0);
    return ret;
  }

  // Drops a reference and returns the new count, which must not go negative.
  int64_t operator--() noexcept {
    auto ret = --count_;
    DCHECK_GE(ret, 0);
    return ret;
  }

  // Reads the current count. It is const because observing the count does
  // not modify it.
  int64_t operator*() const noexcept { return count_.load(); }

  // Per-count variant: records one call.
  void useGlobal() { ++useGlobalCalls; }

  // Batched variant taking a container of counts: also records exactly one
  // call, regardless of the container's size.
  template <typename Container>
  static void useGlobal(const Container&) {
    ++useGlobalCalls;
  }

 private:
  // Starts at 1; presumably the creating owner holds the initial reference --
  // confirm against ReadMostlySharedPtr.h.
  std::atomic<int64_t> count_{1};
};

// Destroying main pointers individually costs several useGlobal() calls,
// while ReadMostlyMainPtrDeleter batches them into one.
TEST_F(ReadMostlySharedPtrTest, ReadMostlyMainPtrDeleter) {
  EXPECT_EQ(0, useGlobalCalls);

  // Two main pointers destroyed individually: 4 useGlobal() calls in total
  // (presumably two per pointer -- confirm in ReadMostlySharedPtr.h).
  {
    ReadMostlyMainPtr<int, TestRefCount> first(std::make_shared<int>(42));
    ReadMostlyMainPtr<int, TestRefCount> second(std::make_shared<int>(42));
  }
  EXPECT_EQ(4, useGlobalCalls);

  // The same two pointers handed to a deleter: a single useGlobal() call.
  useGlobalCalls = 0;
  {
    ReadMostlyMainPtr<int, TestRefCount> first(std::make_shared<int>(42));
    ReadMostlyMainPtr<int, TestRefCount> second(std::make_shared<int>(42));

    ReadMostlyMainPtrDeleter<TestRefCount> batch;
    batch.add(std::move(first));
    batch.add(std::move(second));
  }
  EXPECT_EQ(1, useGlobalCalls);
}

// Null/non-null comparisons (==, != in both operand orders) and boolean
// conversion for both ReadMostlyMainPtr and ReadMostlySharedPtr.
TEST_F(ReadMostlySharedPtrTest, nullptr) {
  {
    ReadMostlyMainPtr<int, TestRefCount> nptr;
    EXPECT_TRUE(nptr == nullptr);
    EXPECT_TRUE(nullptr == nptr);
    EXPECT_EQ(nptr, nullptr);
    EXPECT_EQ(nullptr, nptr);
    EXPECT_FALSE(nptr);
    EXPECT_TRUE(!nptr);

    ReadMostlyMainPtr<int, TestRefCount> ptr(std::make_shared<int>(42));
    EXPECT_FALSE(ptr == nullptr);
    EXPECT_FALSE(nullptr == ptr);
    EXPECT_NE(ptr, nullptr);
    EXPECT_NE(nullptr, ptr);
    EXPECT_FALSE(!ptr);
    EXPECT_TRUE(ptr);
  }
  {
    ReadMostlySharedPtr<int, TestRefCount> nptr;
    EXPECT_TRUE(nptr == nullptr);
    EXPECT_TRUE(nullptr == nptr);
    EXPECT_EQ(nptr, nullptr);
    EXPECT_EQ(nullptr, nptr);
    EXPECT_FALSE(nptr);
    EXPECT_TRUE(!nptr);

    // The non-null case must exercise a ReadMostlySharedPtr as well, so
    // obtain one from a non-null main pointer. The shared pointer is declared
    // after its source and therefore destroyed first.
    ReadMostlyMainPtr<int, TestRefCount> mainPtr(std::make_shared<int>(42));
    ReadMostlySharedPtr<int, TestRefCount> ptr(mainPtr);
    EXPECT_FALSE(ptr == nullptr);
    EXPECT_FALSE(nullptr == ptr);
    EXPECT_NE(ptr, nullptr);
    EXPECT_NE(nullptr, ptr);
    EXPECT_FALSE(!ptr);
    EXPECT_TRUE(ptr);
  }
}

// Compile-only check: getStdShared() and getShared() are callable on a const
// main pointer, and their results can feed reset() and copy construction.
TEST_F(ReadMostlySharedPtrTest, getStdShared) {
  const ReadMostlyMainPtr<int> source(std::make_shared<int>(42));

  ReadMostlyMainPtr<int> target;
  target.reset(source.getStdShared());

  const ReadMostlySharedPtr<int> sharedFromSource = source.getShared();
  ReadMostlySharedPtr<int> sharedCopy(sharedFromSource);

  SUCCEED();
}

// Polymorphic base for the cast tests; getName() reports the dynamic type.
struct Base {
  virtual ~Base() = default;

  virtual std::string getName() const {
    return std::string{"Base"};
  }
};

// Subclass whose override lets the cast tests detect virtual dispatch.
struct Derived : Base {
  std::string getName() const override {
    return std::string{"Derived"};
  }
};

// ReadMostlySharedPtr<Base> can be produced from Derived-typed pointers by
// every converting constructor/assignment, and still dispatches to Derived.
TEST_F(ReadMostlySharedPtrTest, casts) {
  ReadMostlyMainPtr<Derived> mainPtr(std::make_shared<Derived>());
  ReadMostlySharedPtr<Derived> derivedPtr(mainPtr);

  // Both access paths (operator-> and getStdShared()) must reach Derived.
  auto expectDerived = [](auto& basePtr) {
    EXPECT_EQ("Derived", basePtr->getName());
    EXPECT_EQ("Derived", basePtr.getStdShared()->getName());
  };

  // Converting construction from the main pointer.
  {
    ReadMostlySharedPtr<Base> basePtr(mainPtr);
    expectDerived(basePtr);
  }
  // Converting copy construction.
  {
    ReadMostlySharedPtr<Base> basePtr(derivedPtr);
    expectDerived(basePtr);
  }
  // Converting copy assignment.
  {
    ReadMostlySharedPtr<Base> basePtr;
    basePtr = derivedPtr;
    expectDerived(basePtr);
  }
  // Converting move construction.
  {
    auto scratch = derivedPtr;
    ReadMostlySharedPtr<Base> basePtr(std::move(scratch));
    expectDerived(basePtr);
  }
  // Converting move assignment.
  {
    auto scratch = derivedPtr;
    ReadMostlySharedPtr<Base> basePtr;
    basePtr = std::move(scratch);
    expectDerived(basePtr);
  }
}