folly/folly/io/async/test/RequestContextTest.cpp

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <atomic>
#include <cstdint>
#include <thread>

#include <folly/Memory.h>
#include <folly/io/async/EventBase.h>
#include <folly/io/async/Request.h>
#include <folly/io/async/test/RequestContextHelper.h>
#include <folly/portability/GTest.h>
#include <folly/system/ThreadName.h>

#include <boost/thread/barrier.hpp>

using namespace folly;

// Pre-built token for the "test" key. Several tests pass it to the
// RequestContext accessors in place of the "test" string literal.
RequestToken testtoken("test");

// Fixture shared by the RequestContext tests. It resets the calling thread's
// context before every test and offers small helpers that operate on whatever
// context RequestContext::get() currently returns.
class RequestContextTest : public ::testing::Test {
 protected:
  void SetUp() override {
    // Make sure each test starts out using the default context, and not some
    // other context left over by a previous test.
    RequestContext::setContext(nullptr);

    // Make sure no data is set for the "test" key when we start.  There could
    // be left over data in the default context from a previous test.  If we
    // don't clear it out future calls to setContextData() won't actually work,
    // and will reset the data to null instead of properly setting the new
    // desired data.
    //
    // (All of the tests generally want the behavior of overwriteContextData()
    // rather than setContextData(), but that method is private.)
    //
    // We ideally want to clear out data for any keys that may be set, not just
    // the "test" key, but there also isn't a RequestContext API to do this.
    clearData();
  }

  // Returns the context that RequestContext::get() reports for this thread.
  RequestContext& getContext() {
    auto* ctx = RequestContext::get();
    EXPECT_TRUE(ctx != nullptr);
    return *ctx;
  }

  // Stores a fresh TestData(data) under `key` in the current context.
  void setData(int data = 0, const std::string& key = "test") {
    getContext().setContextData(key, std::make_unique<TestData>(data));
  }

  // Reports whether the current context has an entry for `key`.
  bool hasData(const std::string& key = "test") {
    return getContext().hasContextData(key);
  }

  // Returns the TestData stored under `key` in the current context.
  //
  // NOTE: the lookup failure is only reported through a non-fatal expectation
  // (a fatal ASSERT_* cannot be used in a function that returns a value), so
  // callers must only ask for keys that hold a TestData.
  const TestData& getData(const std::string& key = "test") {
    auto* ptr = dynamic_cast<TestData*>(getContext().getContextData(key));
    EXPECT_TRUE(ptr != nullptr);
    return *ptr;
  }

  // Removes the entry for `key` from the current context.
  void clearData(const std::string& key = "test") {
    getContext().clearContextData(key);
  }

  // Collects just the `id` field of every entry returned by
  // RequestContext::getRootIdsFromAllThreads(), preserving their order.
  std::vector<intptr_t> getRootIdsFromAllThreads() {
    const auto rootids = RequestContext::getRootIdsFromAllThreads();
    std::vector<intptr_t> result;
    result.reserve(rootids.size());
    std::transform(
        rootids.begin(),
        rootids.end(),
        std::back_inserter(result),
        [](const auto& e) { return e.id; });
    return result;
  }
};

// Walks through the basic lifecycle: default context, explicit creation, root
// ids, and data visibility from a callback run on an EventBase thread.
TEST_F(RequestContextTest, SimpleTest) {
  EventBase base;

  // There should always be a default context with get()
  EXPECT_TRUE(RequestContext::get() != nullptr);
  // but fallback context should have rootid set to 0
  EXPECT_EQ(RequestContext::get()->getRootId(), 0);

  // but not with saveContext()
  EXPECT_EQ(RequestContext::saveContext(), nullptr);
  RequestContext::create();
  EXPECT_NE(RequestContext::saveContext(), nullptr);
  auto rootids = getRootIdsFromAllThreads();
  // Fatal on purpose: the element accesses below would otherwise read out of
  // bounds when the size is not what we expect.
  ASSERT_EQ(1u, rootids.size());
  EXPECT_EQ(RequestContext::get()->getRootId(), rootids[0]);
  EXPECT_EQ(reinterpret_cast<intptr_t>(RequestContext::get()), rootids[0]);
  EXPECT_NE(RequestContext::get()->getRootId(), 0);
  RequestContext::create();
  EXPECT_NE(RequestContext::saveContext(), nullptr);
  EXPECT_NE(RequestContext::get()->getRootId(), rootids[0]);

  EXPECT_EQ(nullptr, RequestContext::get()->getContextData("test"));

  RequestContext::get()->setContextData("test", std::make_unique<TestData>(10));
  base.runInEventBaseThread([&]() {
    // Only non-fatal checks are used inside this callback: a fatal assertion
    // would return early and skip terminateLoopSoon(), leaving loopForever()
    // (and therefore the join below) stuck.
    EXPECT_TRUE(RequestContext::get() != nullptr);
    auto* callbackData = dynamic_cast<TestData*>(
        RequestContext::get()->getContextData(testtoken));
    EXPECT_NE(callbackData, nullptr);
    if (callbackData != nullptr) {
      EXPECT_EQ(10, callbackData->data_);
    }
    rootids = getRootIdsFromAllThreads();
    EXPECT_EQ(2u, rootids.size());
    if (rootids.size() == 2) {
      EXPECT_EQ(RequestContext::get()->getRootId(), rootids[0]);
      EXPECT_EQ(RequestContext::get()->getRootId(), rootids[1]);
    }
    base.terminateLoopSoon();
  });
  auto th = std::thread([&]() { base.loopForever(); });
  th.join();
  EXPECT_TRUE(RequestContext::get() != nullptr);
  auto* testData =
      dynamic_cast<TestData*>(RequestContext::get()->getContextData("test"));
  // Guard the dereference instead of crashing when the data is missing.
  EXPECT_NE(testData, nullptr);
  if (testData != nullptr) {
    EXPECT_EQ(10, testData->data_);
  }

  RequestContext::setContext(std::shared_ptr<RequestContext>());
  // There should always be a default context
  EXPECT_TRUE(nullptr != RequestContext::get());
}

// Nested RequestContextScopeGuards: data set under the inner guard is separate
// from the outer guard's data, and the outer data is back once the inner guard
// goes out of scope.
TEST_F(RequestContextTest, RequestContextScopeGuard) {
  RequestContextScopeGuard outerGuard;
  setData(10);
  {
    RequestContextScopeGuard innerGuard;
    // Nothing is visible under the "test" key right after the inner guard.
    EXPECT_FALSE(hasData());
    setData(20);

    const TestData& innerData = getData();
    EXPECT_EQ(20, innerData.data_);
    EXPECT_EQ(1, innerData.set_);
    EXPECT_EQ(0, innerData.unset_);
  }

  // Back under the outer guard: its data went through one unset/set round trip.
  const TestData& outerData = getData();
  EXPECT_EQ(10, outerData.data_);
  EXPECT_EQ(2, outerData.set_);
  EXPECT_EQ(1, outerData.unset_);
}

// Same as the scope guard test, but the outer data lives in the default
// context (no top-level guard is created).
TEST_F(RequestContextTest, defaultContext) {
  setData(10);
  {
    RequestContextScopeGuard innerGuard;
    EXPECT_FALSE(hasData());
  }

  // The expected counters show no extra set/unset calls on the default
  // context's data after the inner guard came and went.
  const TestData& defaultData = getData();
  EXPECT_EQ(10, defaultData.data_);
  EXPECT_EQ(1, defaultData.set_);
  EXPECT_EQ(0, defaultData.unset_);
}

// setContextDataIfAbsent() must refuse to replace an existing entry and must
// store the value when the key is not present yet.
TEST_F(RequestContextTest, setIfAbsentTest) {
  EXPECT_TRUE(RequestContext::get() != nullptr);

  // SetUp() only clears the "test" key.  This test also writes "test2" into
  // the context returned by get(), so clear it up front; otherwise a second
  // run in the same process (e.g. --gtest_repeat) would find the key already
  // present and the "absent" expectation below would fail.
  clearData("test2");

  RequestContext::get()->setContextData("test", std::make_unique<TestData>(10));
  EXPECT_FALSE(RequestContext::get()->setContextDataIfAbsent(
      "test", std::make_unique<TestData>(20)));
  auto* existing =
      dynamic_cast<TestData*>(RequestContext::get()->getContextData(testtoken));
  // Fatal: avoid dereferencing a null pointer if the lookup fails.
  ASSERT_NE(existing, nullptr);
  EXPECT_EQ(10, existing->data_);

  EXPECT_TRUE(RequestContext::get()->setContextDataIfAbsent(
      "test2", std::make_unique<TestData>(20)));
  auto* added =
      dynamic_cast<TestData*>(RequestContext::get()->getContextData("test2"));
  ASSERT_NE(added, nullptr);
  EXPECT_EQ(20, added->data_);

  // Do not leave "test2" behind for whichever test runs next.
  clearData("test2");

  RequestContext::setContext(std::shared_ptr<RequestContext>());
  EXPECT_TRUE(nullptr != RequestContext::get());
}

// Verifies that onSet()/onUnset() fire on the attached data every time the
// owning context is installed on, or removed from, the current thread.
TEST_F(RequestContextTest, testSetUnset) {
  RequestContext::create();
  auto ctx1 = RequestContext::saveContext();
  ctx1->setContextData("test", std::make_unique<TestData>(10));
  auto testData1 = dynamic_cast<TestData*>(ctx1->getContextData("test"));
  // Fatal: every check below dereferences this pointer.
  ASSERT_NE(testData1, nullptr);

  // onSet called in setContextData
  EXPECT_EQ(1, testData1->set_);
  EXPECT_EQ(ctx1.get(), testData1->onSetRctx);

  // Override RequestContext
  RequestContext::create();
  auto ctx2 = RequestContext::saveContext();
  ctx2->setContextData(testtoken, std::make_unique<TestData>(20));
  auto testData2 = dynamic_cast<TestData*>(ctx2->getContextData(testtoken));
  ASSERT_NE(testData2, nullptr);

  // onSet called in setContextData
  EXPECT_EQ(1, testData2->set_);
  EXPECT_EQ(ctx2.get(), testData2->onSetRctx);

  // Check ctx1->onUnset was called
  EXPECT_EQ(1, testData1->unset_);
  EXPECT_EQ(ctx1.get(), testData1->onUnSetRctx);

  RequestContext::setContext(ctx1);
  EXPECT_EQ(2, testData1->set_);
  EXPECT_EQ(1, testData1->unset_);
  EXPECT_EQ(1, testData2->unset_);
  EXPECT_EQ(ctx1.get(), testData1->onSetRctx);
  EXPECT_EQ(ctx1.get(), testData1->onUnSetRctx);
  EXPECT_EQ(ctx2.get(), testData2->onUnSetRctx);

  RequestContext::setContext(ctx2);
  EXPECT_EQ(2, testData1->set_);
  EXPECT_EQ(2, testData1->unset_);
  EXPECT_EQ(2, testData2->set_);
  EXPECT_EQ(1, testData2->unset_);

  // Uninstall ctx2 so this test does not leave a context on the thread for
  // tests that run without the fixture's SetUp().
  RequestContext::setContext(nullptr);
}

// Clearing an entry whose destructor writes back into the current context must
// complete (the test name suggests this used to be a deadlock risk).
TEST_F(RequestContextTest, deadlockTest) {
  // Request data that stores new data under `val_` from its own destructor.
  class DeadlockTestData : public RequestData {
   public:
    explicit DeadlockTestData(const std::string& val) : val_{val} {}

    ~DeadlockTestData() override {
      RequestContext::get()->setContextData(
          val_, std::make_unique<TestData>(1));
    }

    bool hasCallback() override { return false; }

    std::string val_;
  };

  // The same context is used for both calls; nothing switches it in between.
  auto* currentCtx = RequestContext::get();
  currentCtx->setContextData(
      "test", std::make_unique<DeadlockTestData>("test1"));
  // Destroys the DeadlockTestData, which re-enters setContextData().
  currentCtx->clearContextData(testtoken);
}

// Request data is commonly used to mirror per-request state into a thread
// global via onSet()/onUnset().  This regression test checks that onUnset()
// of the outgoing data always runs before onSet() of the incoming data.
TEST_F(RequestContextTest, sharedGlobalTest) {
  static bool global = false;

  // Flips `global` on set and off on unset, asserting that the two strictly
  // alternate.
  class GlobalTestData : public RequestData {
   public:
    void onSet() override {
      ASSERT_FALSE(global);
      global = true;
    }

    void onUnset() override {
      ASSERT_TRUE(global);
      global = false;
    }

    bool hasCallback() override { return true; }
  };

  // Root id of the context currently installed on this thread.
  const auto installedRootId = [] {
    return RequestContext::saveContext()->getRootId();
  };
  // First root id reported by the all-threads snapshot.
  const auto reportedRootId = [this] { return getRootIdsFromAllThreads()[0]; };

  const intptr_t rootWithoutRequest = 0;
  {
    RequestContextScopeGuard outerGuard;
    RequestContext::get()->setContextData(
        "test", std::make_unique<GlobalTestData>());
    const auto outerRoot = installedRootId();
    EXPECT_EQ(reportedRootId(), outerRoot);
    {
      RequestContextScopeGuard innerGuard;
      RequestContext::get()->setContextData(
          "test", std::make_unique<GlobalTestData>());
      const auto innerRoot = installedRootId();
      EXPECT_EQ(reportedRootId(), innerRoot);
    }
    EXPECT_EQ(reportedRootId(), outerRoot);
  }
  EXPECT_EQ(reportedRootId(), rootWithoutRequest);
}

// A shallow copy sees the data of the context it was copied from, while data
// added inside the copy is gone once the copy's guard ends.
TEST_F(RequestContextTest, ShallowCopyBasic) {
  // First root id reported by the all-threads snapshot.
  const auto reportedRootId = [this] { return getRootIdsFromAllThreads()[0]; };

  ShallowCopyRequestContextScopeGuard outerCopy;
  setData(123, "immutable");
  EXPECT_EQ(123, getData("immutable").data_);
  EXPECT_FALSE(hasData());
  EXPECT_EQ(0, reportedRootId());

  {
    ShallowCopyRequestContextScopeGuard innerCopy;
    EXPECT_EQ(123, getData("immutable").data_);
    setData(789);
    EXPECT_EQ(789, getData().data_);
    EXPECT_EQ(0, reportedRootId());
  }

  // The "test" entry added in the inner copy is not visible out here.
  EXPECT_FALSE(hasData());
  const TestData& immutableData = getData("immutable");
  EXPECT_EQ(123, immutableData.data_);
  EXPECT_EQ(1, immutableData.set_);
  EXPECT_EQ(0, immutableData.unset_);
  EXPECT_EQ(0, reportedRootId());
}

// A shallow copy constructed with replacement data overrides that key only for
// the lifetime of the guard, and keeps reporting the parent's root id.
TEST_F(RequestContextTest, ShallowCopyOverwrite) {
  // First root id reported by the all-threads snapshot.
  const auto reportedRootId = [this] { return getRootIdsFromAllThreads()[0]; };

  RequestContextScopeGuard rootGuard;
  setData(123);
  EXPECT_EQ(123, getData().data_);
  const auto rootid = RequestContext::get()->getRootId();
  EXPECT_EQ(rootid, reportedRootId());
  {
    ShallowCopyRequestContextScopeGuard overwriteGuard(
        "test", std::make_unique<TestData>(789));
    const TestData& overwritten = getData();
    EXPECT_EQ(789, overwritten.data_);
    EXPECT_EQ(1, overwritten.set_);
    EXPECT_EQ(0, overwritten.unset_);
    // should have inherited parent's rootid
    EXPECT_EQ(rootid, reportedRootId());

    {
      // rootId is preserved for shallow copies of shallow copies
      ShallowCopyRequestContextScopeGuard nestedCopy;
      EXPECT_EQ(rootid, reportedRootId());
    }
    EXPECT_EQ(rootid, reportedRootId());
  }

  // The original entry is visible again after one unset/set round trip.
  const TestData& restored = getData();
  EXPECT_EQ(123, restored.data_);
  EXPECT_EQ(2, restored.set_);
  EXPECT_EQ(1, restored.unset_);
  EXPECT_EQ(rootid, reportedRootId());
}

// Overwriting via a shallow copy of the default context (no enclosing scope
// guard) must leave the default context's own entry intact.
TEST_F(RequestContextTest, ShallowCopyDefaultContext) {
  setData(123);
  EXPECT_EQ(123, getData().data_);
  {
    ShallowCopyRequestContextScopeGuard overwriteGuard(
        "test", std::make_unique<TestData>(789));
    EXPECT_EQ(789, getData().data_);
  }

  // Expected counters: no additional set/unset calls on the original data.
  const TestData& original = getData();
  EXPECT_EQ(123, original.data_);
  EXPECT_EQ(1, original.set_);
  EXPECT_EQ(0, original.unset_);
}

// Clearing and re-setting a key inside a shallow copy must not affect the
// entry held by the parent context.
TEST_F(RequestContextTest, ShallowCopyClear) {
  RequestContextScopeGuard rootGuard;
  setData(123);
  EXPECT_EQ(123, getData().data_);
  {
    ShallowCopyRequestContextScopeGuard copyGuard;
    // The copy starts out seeing the parent's entry.
    EXPECT_EQ(123, getData().data_);
    clearData();
    setData(789);
    EXPECT_EQ(789, getData().data_);
  }

  // The parent's entry is back after one unset/set round trip.
  const TestData& parentData = getData();
  EXPECT_EQ(123, parentData.data_);
  EXPECT_EQ(2, parentData.set_);
  EXPECT_EQ(1, parentData.unset_);
}

// A shallow copy can overwrite several keys at once; the overrides, and any
// later changes to those keys inside the copy, disappear with the guard.
TEST_F(RequestContextTest, ShallowCopyMulti) {
  // Checks the values currently stored under "test1" and "test2".
  const auto expectValues = [this](int expected1, int expected2) {
    EXPECT_EQ(expected1, getData("test1").data_);
    EXPECT_EQ(expected2, getData("test2").data_);
  };

  RequestContextScopeGuard rootGuard;
  setData(1, "test1");
  setData(2, "test2");
  expectValues(1, 2);
  {
    ShallowCopyRequestContextScopeGuard copyGuard(
        RequestDataItem{"test1", std::make_unique<TestData>(2)},
        RequestDataItem{"test2", std::make_unique<TestData>(4)});

    expectValues(2, 4);
    clearData("test1");
    clearData("test2");
    setData(4, "test1");
    setData(8, "test2");
    expectValues(4, 8);
  }
  expectValues(1, 2);
}

// copyAsRoot() assigns the supplied root id, while copyAsChild() keeps the
// root id of the context being copied.
TEST_F(RequestContextTest, RootIdOnCopy) {
  const auto baseCtx = std::make_shared<RequestContext>(0xab);
  EXPECT_EQ(0xab, baseCtx->getRootId());

  {
    const auto rootCopy = RequestContext::copyAsRoot(*baseCtx, 0xabc);
    EXPECT_EQ(0xabc, rootCopy->getRootId());
  }

  {
    const auto childCopy = RequestContext::copyAsChild(*baseCtx);
    EXPECT_EQ(0xab, childCopy->getRootId());
  }
}

// Checks that the entries returned by getRootIdsFromAllThreads() carry thread
// ids that resolve to the expected thread name / OS thread id.
TEST_F(RequestContextTest, ThreadId) {
  folly::setThreadName("DummyThread");
  RequestContextScopeGuard g;
  auto rootids = RequestContext::getRootIdsFromAllThreads();
  // Fatal checks guard the element access and the optional dereference below.
  ASSERT_FALSE(rootids.empty());
  auto currentName = folly::getThreadName(rootids[0].tid);
  ASSERT_TRUE(currentName.has_value());
  EXPECT_EQ(*currentName, "DummyThread");
  EXPECT_EQ(rootids[0].tidOS, folly::getOSThreadID());

  EventBase base;
  base.runInEventBaseThread([&]() {
    RequestContextScopeGuard g_;
    folly::setThreadName("DummyThread2");
    // Snapshot taken while both this thread and the test thread are alive.
    rootids = RequestContext::getRootIdsFromAllThreads();
    base.terminateLoopSoon();
  });

  auto th = std::thread([&]() { base.loopForever(); });
  th.join();

  // Order by thread name; entries whose name cannot be resolved sort as "zzz".
  std::sort(rootids.begin(), rootids.end(), [](const auto& a, const auto& b) {
    auto aname = folly::getThreadName(a.tid);
    auto bname = folly::getThreadName(b.tid);
    return (aname ? *aname : "zzz") < (bname ? *bname : "zzz");
  });

  ASSERT_GE(rootids.size(), 2u);
  auto firstName = folly::getThreadName(rootids[0].tid);
  ASSERT_TRUE(firstName.has_value());
  EXPECT_EQ(*firstName, "DummyThread");
  // The event-base thread has been joined, so no name is expected for it.
  EXPECT_FALSE(folly::getThreadName(rootids[1].tid));
}

// onClear() must run exactly once before the data is destroyed, whether the
// entry is cleared explicitly or dropped together with its context.
TEST_F(RequestContextTest, Clear) {
  // Records clear/destroy events into flags owned by the test body.
  struct ClearTracker : public RequestData {
    bool& cleared;
    bool& deleted;
    ClearTracker(bool& clearedFlag, bool& deletedFlag)
        : cleared(clearedFlag), deleted(deletedFlag) {}
    ~ClearTracker() override {
      // Destruction must always come after onClear().
      EXPECT_TRUE(cleared);
      deleted = true;
    }
    bool hasCallback() override { return false; }
    void onClear() override {
      // onClear() must not be delivered twice.
      EXPECT_FALSE(cleared);
      cleared = true;
    }
  };

  const std::string key = "clear";

  // Case 1: explicit clearContextData().
  {
    bool sawClear = false;
    bool sawDelete = false;
    {
      RequestContextScopeGuard guard;
      auto* ctx = RequestContext::get();
      ctx->setContextData(
          key, std::make_unique<ClearTracker>(sawClear, sawDelete));
      EXPECT_FALSE(sawClear);
      ctx->clearContextData(key);
      EXPECT_TRUE(sawClear);
    }
    EXPECT_TRUE(sawDelete);
  }

  // Case 2: the entry goes away when the guard's context does.
  {
    bool sawClear = false;
    bool sawDelete = false;
    {
      RequestContextScopeGuard guard;
      RequestContext::get()->setContextData(
          key, std::make_unique<ClearTracker>(sawClear, sawDelete));
      EXPECT_FALSE(sawClear);
      EXPECT_FALSE(sawDelete);
    }
    EXPECT_TRUE(sawClear);
    EXPECT_TRUE(sawDelete);
  }
}

// A key that was registered with null data in an outer shallow copy can be
// overwritten with real data by a nested shallow copy.
TEST_F(RequestContextTest, OverwriteNullData) {
  ShallowCopyRequestContextScopeGuard nullGuard("token", nullptr);
  {
    ShallowCopyRequestContextScopeGuard dataGuard(
        "token", std::make_unique<TestData>(0));
    auto* stored = RequestContext::get()->getContextData("token");
    EXPECT_NE(stored, nullptr);
  }
}

// Repeatedly lets two threads drop their references to the same request data
// at about the same time.  The threads are sequenced through `step`:
//   1 - th1 has published ctx1 in sp1
//   2 - th2 has cleared the data reference held by ctx1
//   3 - th1 has left the scope of g1
TEST_F(RequestContextTest, ConcurrentDataRefRelease) {
  for (int i = 0; i < 100; ++i) {
    std::atomic<int> step{0};
    std::shared_ptr<folly::RequestContext> sp1;
    auto th1 = std::thread([&]() {
      folly::RequestContextScopeGuard g0; // Creates ctx0.
      setData(); // Creates data0 with one reference in ctx0.
      {
        folly::ShallowCopyRequestContextScopeGuard g1;
        // g1 created ctx1 with second reference to data0.
        EXPECT_NE(&getData(), nullptr);
        // Keep shared_ptr to ctx1 to pass to th2
        sp1 = folly::RequestContext::saveContext();
        step.store(1); // sp1 is ready.
        while (step.load() < 2)
          /* Busy-wait for th2 to clear its reference to data0. */;
      }
      // End of g1 released th1's shared_ptr to ctx1, switched back to ctx0
      // At this point:
      // - One shared_ptr to ctx0, held by th1.
      // - One shared_ptr to ctx1, held by th2.
      // - data0 has one clear count (for reference from ctx0) and
      //   two delete counts (one each from ctx0 and ctx1).
      step.store(3);
      // End of g0 will destroy ctx0, release clear/delete counts for data0.
    });
    auto th2 = std::thread([&]() {
      while (step.load() < 1)
        /* Busy-wait for th1 to set sp1. */;
      folly::RequestContextScopeGuard g2(std::move(sp1));
      // g2 set context to ctx1.
      EXPECT_EQ(sp1.get(), nullptr);
      EXPECT_NE(&getData(), nullptr);
      clearData();
      step.store(2); // th2 cleared reference to data0 in ctx1.
      while (step.load() < 3)
        /* Busy-wait for th1 to release its shared_ptr to ctx1. */;
      // End of g2 will destroy ctx1, release delete count for data0.
    });
    th1.join();
    th2.join();
  }
}

// Holds the accessAllThreads() accessor while worker threads are released to
// exit, and checks that every per-thread entry can still be iterated.
TEST_F(RequestContextTest, AccessAllThreadsDestructionGuard) {
  constexpr auto kNumThreads = 128;

  std::vector<std::thread> threads{kNumThreads};
  // One slot per worker plus one for the test thread.
  boost::barrier barrier{kNumThreads + 1};

  std::atomic<std::size_t> count{0};
  for (auto& thread : threads) {
    thread = std::thread([&] {
      // Force creation of thread local
      RequestContext::get();
      ++count;
      // Wait for all other threads to do the same
      barrier.wait();
      // Wait until signaled to die
      barrier.wait();
    });
  }

  // First rendezvous: every worker has touched its thread local.
  barrier.wait();
  // Sanity check
  EXPECT_EQ(count.load(), kNumThreads);

  {
    auto accessor = RequestContext::accessAllThreads();
    // Allow threads to die (but they should not as long as we hold accessor!)
    barrier.wait();
    auto accessorsCount = std::distance(accessor.begin(), accessor.end());
    // The expected extra entry is presumably the test thread's own.
    EXPECT_EQ(accessorsCount, kNumThreads + 1);
    for (RequestContext::StaticContext& staticContext : accessor) {
      // No thread installed a request context, so each entry should be empty.
      EXPECT_EQ(staticContext.requestContext, nullptr);
    }
  }

  for (auto& thread : threads) {
    thread.join();
  }
}

namespace {

// Traits types used as the template argument of
// getThreadCachedContextData<Traits>() in the test below; each one exposes the
// token for its key as a static `kToken` member.

// Traits for the "keyA" token.
struct KeyATraits {
  static inline const RequestToken kToken{"keyA"};
};

// Traits for the "keyB" token.
struct KeyBTraits {
  static inline const RequestToken kToken{"keyB"};
};

} // namespace

// Exercises getThreadCachedContextData<Traits>(): the value it returns must
// track set, overwrite and clear operations on the context, including changes
// made from another thread and switches to a different context.
TEST_F(RequestContextTest, GetThreadCachedContextData) {
  // Wraps an int in the request data type stored by this test.
  auto makeData = [](int value) {
    return std::make_unique<ImmutableRequestData<int>>(value);
  };

  // Looks up the data for the given traits type and CHECK-fails if it is
  // missing or is not an ImmutableRequestData<int>.  Note that this local
  // lambda shadows the fixture's getData() helper inside this test.
  auto getData = [](auto traits) {
    auto* data = RequestContext::try_get()
                     ->getThreadCachedContextData<decltype(traits)>();
    CHECK(data != nullptr);
    auto* idata = dynamic_cast<ImmutableRequestData<int>*>(data);
    CHECK(idata != nullptr);
    return idata;
  };

  RequestContextScopeGuard guard;

  RequestContext::try_get()->setContextData(KeyATraits::kToken, makeData(1));
  RequestContext::try_get()->setContextData(KeyBTraits::kToken, makeData(2));

  EXPECT_EQ(getData(KeyATraits{})->value(), 1);
  EXPECT_EQ(getData(KeyBTraits{})->value(), 2);

  // Overwriting one key must be reflected for that key only.
  RequestContext::try_get()->overwriteContextData(
      KeyATraits::kToken, makeData(3));
  EXPECT_EQ(getData(KeyATraits{})->value(), 3);
  EXPECT_EQ(getData(KeyBTraits{})->value(), 2);

  // Clearing one key must make its lookup return null, leaving the other key
  // untouched.
  RequestContext::try_get()->clearContextData(KeyATraits::kToken);
  EXPECT_TRUE(
      RequestContext::try_get()->getThreadCachedContextData<KeyATraits>() ==
      nullptr);
  EXPECT_EQ(getData(KeyBTraits{})->value(), 2);

  // Invalidations are delivered from other threads too.
  std::thread([&, ctx = RequestContext::saveContext()] {
    RequestContextScopeGuard guard2(ctx);
    RequestContext::try_get()->setContextData(KeyATraits::kToken, makeData(4));
  }).join();
  EXPECT_EQ(getData(KeyATraits{})->value(), 4);
  EXPECT_EQ(getData(KeyBTraits{})->value(), 2);

  // Caches are not leaked when switching request context.
  {
    RequestContextScopeGuard guard3;
    EXPECT_TRUE(
        RequestContext::try_get()->getThreadCachedContextData<KeyATraits>() ==
        nullptr);
    EXPECT_TRUE(
        RequestContext::try_get()->getThreadCachedContextData<KeyBTraits>() ==
        nullptr);
  }
}

// Checks that try_get() never creates a context by itself, that it agrees
// with get() once a context exists, and that neither accessor deadlocks while
// an accessAllThreads() accessor is held.
TEST(RequestContextTryGetTest, TryGetTest) {
  // This is a plain TEST, so no fixture SetUp() resets the thread's context.
  // Start from "no context installed" explicitly, and restore that state on
  // every exit path, so the test does not depend on (or affect) the tests
  // that run before or after it, including repeated or shuffled runs.
  RequestContext::setContext(nullptr);
  struct ResetContextOnExit {
    ~ResetContextOnExit() { RequestContext::setContext(nullptr); }
  } resetContextOnExit;

  // try_get() should not create a default RequestContext object if none exists.
  EXPECT_EQ(RequestContext::try_get(), nullptr);
  // Explicitly create a new instance so that subsequent calls to try_get()
  // return it.
  RequestContext::create();
  EXPECT_NE(RequestContext::saveContext(), nullptr);
  EXPECT_NE(RequestContext::try_get(), nullptr);
  // Make sure that the pointers returned by both get() and try_get() point to
  // the same underlying instance.
  EXPECT_EQ(RequestContext::try_get(), RequestContext::get());
  // Set some context data and read it out via try_get() accessor.
  RequestContext::get()->setContextData("test", std::make_unique<TestData>(10));
  auto rc = RequestContext::try_get();
  // Fatal: the following lines dereference these pointers.
  ASSERT_NE(rc, nullptr);
  EXPECT_TRUE(rc->hasContextData("test"));
  auto* dataPtr = dynamic_cast<TestData*>(rc->getContextData("test"));
  ASSERT_NE(dataPtr, nullptr);
  EXPECT_EQ(dataPtr->data_, 10);

  auto thread = std::thread([&] {
    auto accessor = RequestContext::accessAllThreads();
    // test there is no deadlock with try_get()
    RequestContext::try_get();
  });
  thread.join();

  thread = std::thread([&] {
    RequestContext::get();
    auto accessor = RequestContext::accessAllThreads();
    // test there is no deadlock with get()
    RequestContext::get();
  });
  thread.join();
}

// ImmutableRequestData hands back the value it was constructed with.
TEST(ImmutableRequestTest, simple) {
  ImmutableRequestData<int> wrapped(4);
  const auto stored = wrapped.value();
  EXPECT_EQ(stored, 4);
}

// ImmutableRequestData<int> must be (nothrow) constructible from a single int
// and must not accept two ints.
TEST(ImmutableRequestTest, typeTraits) {
  using IntData = ImmutableRequestData<int>;

  // The traits are evaluated into named constants first because the template
  // argument lists contain commas, which the gtest macros cannot take directly.
  constexpr bool fromOneInt = std::is_constructible_v<IntData, int>;
  constexpr bool nothrowFromOneInt =
      std::is_nothrow_constructible_v<IntData, int>;
  constexpr bool fromTwoInts = std::is_constructible_v<IntData, int, int>;

  EXPECT_TRUE(fromOneInt);
  EXPECT_TRUE(nothrowFromOneInt);
  EXPECT_FALSE(fromTwoInts);
}