folly/folly/coro/test/SharedMutexTest.cpp

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <folly/Portability.h>

#include <folly/executors/CPUThreadPoolExecutor.h>
#include <folly/executors/ManualExecutor.h>
#include <folly/experimental/coro/Baton.h>
#include <folly/experimental/coro/BlockingWait.h>
#include <folly/experimental/coro/SharedMutex.h>
#include <folly/experimental/coro/Task.h>
#include <folly/portability/GTest.h>

#include <mutex>

#if FOLLY_HAS_COROUTINES

using namespace folly;

class SharedMutexTest : public testing::Test {};

// Exercises the non-blocking try_lock()/try_lock_shared() entry points and
// the exclusion rules between exclusive and shared ownership.
TEST_F(SharedMutexTest, TryLock) {
  coro::SharedMutex mutex;

  // Exclusive ownership: no second exclusive or shared acquisition may
  // succeed until it is released.
  CHECK(mutex.try_lock());
  CHECK(!mutex.try_lock());
  CHECK(!mutex.try_lock_shared());
  mutex.unlock();

  // Shared ownership: additional shared acquisitions succeed, while an
  // exclusive acquisition is refused as long as any shared holder remains.
  CHECK(mutex.try_lock_shared());
  CHECK(!mutex.try_lock());
  CHECK(mutex.try_lock_shared());
  CHECK(!mutex.try_lock());
  mutex.unlock_shared();
  // One shared holder is still outstanding here.
  CHECK(!mutex.try_lock());
  CHECK(mutex.try_lock_shared());
  mutex.unlock_shared();
  mutex.unlock_shared();

  // Every shared holder has released, so exclusive acquisition works again.
  CHECK(mutex.try_lock());
  mutex.unlock();
}

// Drives two readers, two writers and a trailing reader through
// co_lock()/co_lock_shared() with explicit unlock calls, using batons to
// decide when each lock holder is allowed to finish.
TEST_F(SharedMutexTest, ManualLockAsync) {
  coro::SharedMutex mutex;
  int value = 0;

  // Reader: acquires the shared lock, snapshots `value` while holding it,
  // waits for its gate, then releases the lock and reports the snapshot.
  auto spawnReader = [&](coro::Baton& gate) -> coro::Task<int> {
    co_await mutex.co_lock_shared();
    const int snapshot = value;
    co_await gate;
    mutex.unlock_shared();
    co_return snapshot;
  };

  // Writer: acquires the exclusive lock, waits for its gate, then bumps
  // `value` and releases the lock.
  auto spawnWriter = [&](coro::Baton& gate) -> coro::Task<void> {
    co_await mutex.co_lock();
    co_await gate;
    value += 1;
    mutex.unlock();
  };

  ManualExecutor executor;

  // Posts a gate and then runs whatever work that made runnable.
  auto openGate = [&executor](coro::Baton& gate) {
    gate.post();
    executor.drain();
  };

  {
    coro::Baton readerGate1;
    coro::Baton readerGate2;
    coro::Baton writerGate1;
    coro::Baton writerGate2;
    coro::Baton readerGate3;

    auto r1 = spawnReader(readerGate1).scheduleOn(&executor).start();
    auto r2 = spawnReader(readerGate2).scheduleOn(&executor).start();
    auto writer1 = spawnWriter(writerGate1).scheduleOn(&executor).start();
    auto writer2 = spawnWriter(writerGate2).scheduleOn(&executor).start();
    auto r3 = spawnReader(readerGate3).scheduleOn(&executor).start();
    // Let every task run as far as it can before any gate is opened.
    executor.drain();

    openGate(readerGate1);
    CHECK_EQ(0, std::move(r1).get());

    openGate(readerGate2);
    CHECK_EQ(0, std::move(r2).get());

    openGate(writerGate1);
    CHECK_EQ(1, value);

    openGate(writerGate2);
    CHECK_EQ(2, value);

    // The last reader was started after both writers, so it should only
    // have obtained the read-lock once both write locks had been released;
    // its snapshot therefore reflects both increments.
    openGate(readerGate3);
    CHECK_EQ(2, std::move(r3).get());
  }
}

// Same scenario as ManualLockAsync, but using the scoped lock awaitables
// (co_scoped_lock() / co_scoped_lock_shared()) so that the lock is released
// when the returned lock object goes out of scope.
TEST_F(SharedMutexTest, ScopedLockAsync) {
  coro::SharedMutex mutex;
  int value = 0;

  // Reader: holds a scoped shared lock, waits for its baton, then returns
  // the value it observes while still holding the lock.
  auto makeReaderTask = [&](coro::Baton& b) -> coro::Task<int> {
    auto lock = co_await mutex.co_scoped_lock_shared();
    co_await b;
    co_return value;
  };

  // Writer: holds a scoped exclusive lock, waits for its baton, then
  // increments the value before the lock object is destroyed.
  auto makeWriterTask = [&](coro::Baton& b) -> coro::Task<void> {
    auto lock = co_await mutex.co_scoped_lock();
    co_await b;
    value += 1;
  };

  ManualExecutor executor;

  {
    coro::Baton b1;
    coro::Baton b2;
    coro::Baton b3;
    coro::Baton b4;
    coro::Baton b5;

    auto r1 = makeReaderTask(b1).scheduleOn(&executor).start();
    auto r2 = makeReaderTask(b2).scheduleOn(&executor).start();
    auto w1 = makeWriterTask(b3).scheduleOn(&executor).start();
    auto w2 = makeWriterTask(b4).scheduleOn(&executor).start();
    auto r3 = makeReaderTask(b5).scheduleOn(&executor).start();
    // Run every task up to its first suspension point before posting any
    // baton, mirroring ManualLockAsync. Without this drain, b1 is already
    // posted by the time r1 first runs, so r1 never waits on it and the two
    // initial readers never hold the shared lock at the same time.
    executor.drain();

    b1.post();
    executor.drain();
    CHECK_EQ(0, std::move(r1).get());

    b2.post();
    executor.drain();
    CHECK_EQ(0, std::move(r2).get());

    b3.post();
    executor.drain();
    CHECK_EQ(1, value);

    b4.post();
    executor.drain();
    CHECK_EQ(2, value);

    // This reader should have had to wait for the prior two write locks
    // to complete before it acquired the read-lock.
    b5.post();
    executor.drain();
    CHECK_EQ(2, std::move(r3).get());
  }
}

// Stress test: two writer coroutines and four reader coroutines hammer the
// same mutex from a three-thread pool. Writers update a pair of counters
// under the exclusive lock; readers verify under the shared lock that the
// pair is never observed half-updated.
TEST_F(SharedMutexTest, ThreadSafety) {
  CPUThreadPoolExecutor pool{
      3, std::make_shared<NamedThreadFactory>("TestThreadPool")};

  static constexpr int iterationCount = 100'000;

  coro::SharedMutex mutex;
  int value1 = 0;
  int value2 = 0;

  // Each writer performs iterationCount locked increments of both counters.
  auto spawnWriter = [&]() -> coro::Task<void> {
    for (int remaining = iterationCount; remaining > 0; --remaining) {
      auto exclusiveGuard = co_await mutex.co_scoped_lock();
      value1 += 1;
      value2 += 1;
    }
  };

  // Each reader performs iterationCount locked consistency checks.
  auto spawnReader = [&]() -> coro::Task<void> {
    for (int remaining = iterationCount; remaining > 0; --remaining) {
      auto sharedGuard = co_await mutex.co_scoped_lock_shared();
      CHECK_EQ(value1, value2);
    }
  };

  auto writerA = spawnWriter().scheduleOn(&pool).start();
  auto writerB = spawnWriter().scheduleOn(&pool).start();
  auto readerA = spawnReader().scheduleOn(&pool).start();
  auto readerB = spawnReader().scheduleOn(&pool).start();
  auto readerC = spawnReader().scheduleOn(&pool).start();
  auto readerD = spawnReader().scheduleOn(&pool).start();

  // Block until every coroutine has finished.
  std::move(writerA).get();
  std::move(writerB).get();
  std::move(readerA).get();
  std::move(readerB).get();
  std::move(readerC).get();
  std::move(readerD).get();

  // Two writers, iterationCount increments each.
  CHECK_EQ(value1, 2 * iterationCount);
  CHECK_EQ(value2, 2 * iterationCount);
}

#endif