folly/folly/synchronization/test/LifoSemTests.cpp

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <folly/synchronization/LifoSem.h>

#include <thread>

#include <folly/Random.h>
#include <folly/portability/Asm.h>
#include <folly/portability/GFlags.h>
#include <folly/portability/GTest.h>
#include <folly/synchronization/NativeSemaphore.h>
#include <folly/test/DeterministicSchedule.h>

using namespace folly;
using namespace folly::test;

// Aliases for the deterministic-schedule flavour of LifoSem and the
// DeterministicSchedule driver used by the TEST_F cases in this file.
using DLifoSem = LifoSemImpl<DeterministicAtomic>;
using DSched = DeterministicSchedule;

// Fixture shared by the DeterministicSchedule-based tests. GoogleTest builds
// the fixture object before running each TEST_F body, so the node pool below
// is touched before the test body constructs its DSched.
class LifoSemTest : public testing::Test {
 private:
  using Node = detail::LifoSemRawNode<DeterministicAtomic>;

  // Pre-init the pool to avoid deadlock when using DeterministicAtomic.
  // The reference itself is never read; calling Node::pool() here is the
  // point (presumably it lazily creates the pool on first use — see
  // LifoSem.h).
  Node::Pool& pool_ = Node::pool();
};

// Single-threaded smoke test of post/wait/tryPost/tryWait on the real LifoSem.
TEST(LifoSem, basic) {
  LifoSem s;

  // Fresh semaphore: the test expects tryPost() to fail (presumably because
  // no waiter is blocked to hand off to — see LifoSem.h), and tryWait() has
  // nothing to consume.
  EXPECT_FALSE(s.tryPost());
  EXPECT_FALSE(s.tryWait());

  // One post is consumed by one non-blocking wait...
  s.post();
  EXPECT_TRUE(s.tryWait());

  // ...or by one blocking wait, which therefore returns without blocking.
  s.post();
  s.wait();
}

// Stress test on the real LifoSem: a single permit is passed around among
// kNumThreads threads. Each thread repeatedly acquires the permit (falling
// back to a blocking wait when tryWait() fails) and then releases it again.
TEST(LifoSem, multi) {
  LifoSem sem;

  constexpr int kNumThreads = 10;
  const int opsPerThread = 10000;
  std::thread workers[kNumThreads];
  // Total number of acquisitions that needed the blocking wait() path.
  std::atomic<int> blocks(0);

  for (auto& worker : workers) {
    worker = std::thread([&] {
      int blockedHere = 0;
      for (int op = 0; op < opsPerThread; ++op) {
        const bool acquiredFast = sem.tryWait();
        if (!acquiredFast) {
          sem.wait();
          ++blockedHere;
        }
        sem.post();
      }
      blocks += blockedHere;
    });
  }

  // start the flood: release the single permit the workers pass around
  sem.post();

  for (auto& worker : workers) {
    worker.join();
  }

  LOG(INFO) << opsPerThread * kNumThreads << " post/wait pairs, " << blocks
            << " blocked";
}

// Two semaphores used as a strict ping-pong handshake between the main thread
// and one DSched thread. Because control strictly alternates, both semaphores
// are expected to read 0 whenever either side has just been woken.
TEST_F(LifoSemTest, pingpong) {
  DSched sched(DSched::uniform(0));

  constexpr int kIters = 100;

  for (int pass = 0; pass < 10; ++pass) {
    DLifoSem ping;
    DLifoSem pong;

    auto responder = DSched::thread([&] {
      for (int round = 0; round < kIters; ++round) {
        ping.wait();
        // main thread can't be running here
        EXPECT_EQ(ping.valueGuess(), 0);
        EXPECT_EQ(pong.valueGuess(), 0);
        pong.post();
      }
    });

    for (int round = 0; round < kIters; ++round) {
      ping.post();
      pong.wait();
      // child thread can't be running here
      EXPECT_EQ(ping.valueGuess(), 0);
      EXPECT_EQ(pong.valueGuess(), 0);
    }

    DSched::join(responder);
  }
}

// One semaphore shared by two threads. The helper thread does wait-then-post
// and the main thread does post-then-wait, so neither side changes the net
// count over an iteration.
TEST_F(LifoSemTest, mutex) {
  DSched sched(DSched::uniform(0));

  constexpr int kIters = 100;

  for (int pass = 0; pass < 10; ++pass) {
    DLifoSem sem;

    auto helper = DSched::thread([&] {
      for (int round = 0; round < kIters; ++round) {
        sem.wait();
        sem.post();
      }
    });

    for (int round = 0; round < kIters; ++round) {
      sem.post();
      sem.wait();
    }

    // The main thread may have consumed the last permit while the helper
    // still has a wait() outstanding. Post once more so the helper can
    // finish, then reclaim that permit after the join.
    sem.post();
    DSched::join(helper);
    sem.wait();
  }
}

// numThreads DSched threads share one semaphore. Each thread repeatedly
// posts `width` permits in a single call and then takes `width` permits back
// one at a time.
TEST_F(LifoSemTest, no_blocking) {
  long seed = folly::randomNumberSeed() % 10000;
  LOG(INFO) << "seed=" << seed;
  DSched sched(DSched::uniform(seed));

  const int iters = 100;
  const int numThreads = 2;
  const int width = 10;

  for (int pass = 0; pass < 10; ++pass) {
    DLifoSem sem;

    std::vector<std::thread> workers;
    workers.reserve(numThreads);
    for (int t = 0; t < numThreads; ++t) {
      workers.emplace_back(DSched::thread([&] {
        for (int round = 0; round < iters; ++round) {
          sem.post(width);
          for (int taken = 0; taken < width; ++taken) {
            sem.wait();
          }
        }
      }));
    }

    for (auto& worker : workers) {
      DSched::join(worker);
    }
  }
}

// Unidirectional hand-off: the main thread posts `iters` times and a single
// DSched consumer waits `iters` times. The schedule picks from a random subset
// of the runnable threads.
TEST_F(LifoSemTest, one_way) {
  long seed = folly::randomNumberSeed() % 10000;
  LOG(INFO) << "seed=" << seed;
  DSched sched(DSched::uniformSubset(seed, 1, 6));

  const int iters = 1000;

  for (int pass = 0; pass < 10; ++pass) {
    DLifoSem sem;

    auto consumer = DSched::thread([&] {
      for (int taken = 0; taken < iters; ++taken) {
        sem.wait();
      }
    });

    for (int posted = 0; posted < iters; ++posted) {
      sem.post();
    }

    DSched::join(consumer);
  }
}

// After shutdown(), a post() that arrives later can still be consumed by
// wait(). Once the semaphore is empty, wait() throws ShutdownSemError.
TEST_F(LifoSemTest, shutdown_wait_order) {
  DLifoSem sem;
  sem.shutdown();

  sem.post();
  sem.wait(); // consumes the post made after shutdown

  EXPECT_THROW(sem.wait(), ShutdownSemError); // nothing left to take
  EXPECT_TRUE(sem.isShutdown());
}

// kNumWaiters DSched threads call wait() on a semaphore that is never posted.
// The main thread then calls shutdown(). Every waiter must leave wait() via
// ShutdownSemError, whether it was already blocked or reached wait() after
// the shutdown.
TEST_F(LifoSemTest, shutdown_multi) {
  DSched sched(DSched::uniform(0));

  constexpr size_t kNumWaiters = 20;

  for (int pass = 0; pass < 10; ++pass) {
    DLifoSem a;
    std::vector<std::thread> threads;
    threads.reserve(kNumWaiters);
    while (threads.size() < kNumWaiters) {
      threads.push_back(DSched::thread([&] {
        try {
          a.wait();
          // Nothing is ever posted, so a normal return from wait() is a bug.
          ADD_FAILURE();
        } catch (const ShutdownSemError&) {
          // expected
          EXPECT_TRUE(a.isShutdown());
        }
      }));
    }
    a.shutdown();
    for (auto& thr : threads) {
      DSched::join(thr);
    }
  }
}

// Regression test: asking tryWait() for more permits than are available must
// return only the available count (5 of the requested 10) rather than assert.
TEST(LifoSem, multiTryWaitSimple) {
  LifoSem s;
  s.post(5);
  auto granted = s.tryWait(10); // this used to trigger an assert
  ASSERT_EQ(5, granted);
}

// A producer posts NPOSTS single permits. A consumer repeatedly drains the
// semaphore with tryWait(10) until it has seen the stop flag and completed a
// full drain afterwards. Every posted permit must be counted exactly once.
TEST_F(LifoSemTest, multi_try_wait) {
  long seed = folly::randomNumberSeed() % 10000;
  LOG(INFO) << "seed=" << seed;
  DSched sched(DSched::uniform(seed));
  DLifoSem sem;

  const int NPOSTS = 1000;

  auto producer = [&] {
    for (int posted = 0; posted < NPOSTS; ++posted) {
      sem.post();
    }
  };

  DeterministicAtomic<bool> consumer_stop(false);
  int consumed = 0;

  auto consumer = [&] {
    for (;;) {
      // Sample the flag *before* draining. The flag is set only after the
      // producer has been joined, so observing it guarantees that the drain
      // below runs after the producer's final post.
      const bool stopRequested = consumer_stop.load();
      for (;;) {
        const int got = sem.tryWait(10);
        consumed += got;
        if (got <= 0) {
          break;
        }
      }
      if (stopRequested) {
        break;
      }
    }
  };

  std::thread producer_thread(DSched::thread(producer));
  std::thread consumer_thread(DSched::thread(consumer));
  DSched::join(producer_thread);
  consumer_stop.store(true);
  DSched::join(consumer_thread);

  ASSERT_EQ(NPOSTS, consumed);
}

// Per pass:
//   - 20 DSched threads each make 10 try_wait_for(1ms) attempts.
//   - 20 other threads each post 10 times.
//   - From pass 6 onward the main thread also calls shutdown(), so some
//     attempts may throw ShutdownSemError.
// `handoffs` counts posts minus successful timed waits.
TEST_F(LifoSemTest, timeout) {
  long seed = folly::randomNumberSeed() % 10000;
  LOG(INFO) << "seed=" << seed;
  DSched sched(DSched::uniform(seed));
  // NOTE(review): declared outside the pass loop, so the count carries over
  // between passes and the EXPECT_GT below is cumulative. Once any pass
  // leaves a post unconsumed, later passes satisfy it trivially. Kept as-is
  // to avoid making the test stricter (and possibly flaky).
  // The counter is unsigned, and a waiter can succeed between a poster's
  // post() and its handoffs++. The value may therefore wrap transiently, but
  // unsigned arithmetic gives the right net value once all threads are joined.
  DeterministicAtomic<uint32_t> handoffs{0};

  for (int pass = 0; pass < 10; ++pass) {
    DLifoSem a;
    std::vector<std::thread> threads;
    while (threads.size() < 20) {
      threads.push_back(DSched::thread([&] {
        for (int i = 0; i < 10; i++) {
          try {
            if (a.try_wait_for(std::chrono::milliseconds(1))) {
              handoffs--;
            }
          } catch (const ShutdownSemError&) {
            // expected
            EXPECT_TRUE(a.isShutdown());
          }
        }
      }));
    }
    std::vector<std::thread> threads2;
    while (threads2.size() < 20) {
      threads2.push_back(DSched::thread([&] {
        for (int i = 0; i < 10; i++) {
          a.post();
          handoffs++;
        }
      }));
    }
    if (pass > 5) {
      a.shutdown();
    }
    for (auto& thr : threads) {
      DSched::join(thr);
    }
    for (auto& thr : threads2) {
      DSched::join(thr);
    }
    // At least one timeout must occur (cumulatively — see NOTE above).
    EXPECT_GT(handoffs.load(), 0);
  }
}

// Two worker loops run until the semaphore reports isShutdown():
//   - worker1 only polls isShutdown().
//   - worker2 repeatedly calls try_wait_for(1ms), which may throw
//     ShutdownSemError once shutdown begins.
// A third thread calls shutdown(). The test passes if all three threads can
// be joined.
TEST_F(LifoSemTest, shutdown_try_wait_for) {
  long seed = folly::randomNumberSeed() % 1000000;
  LOG(INFO) << "seed=" << seed;
  DSched sched(DSched::uniform(seed));

  DLifoSem stopped;
  std::thread worker1 = DSched::thread([&stopped] {
    while (!stopped.isShutdown()) {
      // i.e. poll for messages with timeout
      LOG(INFO) << "thread polled";
    }
  });
  std::thread worker2 = DSched::thread([&stopped] {
    while (!stopped.isShutdown()) {
      // Do some work every 1 second

      try {
        // this is normally 1 second in prod use case.
        stopped.try_wait_for(std::chrono::milliseconds(1));
      } catch (const folly::ShutdownSemError&) {
        LOG(INFO) << "try_wait_for shutdown";
      }
    }
  });

  std::thread shutdown = DSched::thread([&stopped] {
    LOG(INFO) << "LifoSem shutdown";
    stopped.shutdown();
    LOG(INFO) << "LifoSem shutdown done";
  });

  DSched::join(shutdown);
  DSched::join(worker1);
  DSched::join(worker2);
  LOG(INFO) << "Threads joined";
}

// Test entry point. GoogleTest strips its own flags from argv first; the
// remaining arguments are then parsed as gflags. Exits with RUN_ALL_TESTS()'s
// status.
int main(int argc, char** argv) {
  testing::InitGoogleTest(&argc, argv);
  gflags::ParseCommandLineFlags(&argc, &argv, true);
  return RUN_ALL_TESTS();
}