/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <folly/synchronization/test/Semaphore.h>
#include <array>
#include <numeric>
#include <thread>
#include <vector>
#include <glog/logging.h>
#include <folly/Traits.h>
#include <folly/portability/GTest.h>
#include <folly/portability/SysMman.h>
#include <folly/synchronization/Latch.h>
#include <folly/synchronization/test/Barrier.h>
using namespace folly::test;
namespace {
template <SemaphoreWakePolicy WakePolicy>
auto wake_policy(PolicySemaphore<WakePolicy> const&) {
return WakePolicy;
}
// Single-threaded smoke test of try_wait / post / wait on a
// default-constructed semaphore of type Sem.
template <typename Sem>
void test_basic() {
  Sem semaphore;

  // Nothing has been posted yet, so a non-blocking wait is expected to fail.
  EXPECT_FALSE(semaphore.try_wait());

  // One post is expected to satisfy one non-blocking wait.
  semaphore.post();
  EXPECT_TRUE(semaphore.try_wait());

  // A post made beforehand lets the following wait() return on this same
  // thread; the test would hang here otherwise.
  semaphore.post();
  semaphore.wait();
}
template <typename Sem>
void test_handoff_destruction() {
// regression to check for race:
// * poster thread calls Sem::post()
// * waiter thread calls Sem::wait() and then dtor+free
// strategy: mprotect the sem page after dtor; racing post() will segv
// alternate strategy: overwrite the former sem object, but non-portable
constexpr auto const nthreads = 1ull << 4;
constexpr auto const rounds = 1ull << 6;
std::array<size_t, nthreads> waits{};
for (auto r = 0ull; r < rounds; ++r) {
std::array<void*, nthreads> sems{};
std::array<std::thread, nthreads> threads;
folly::Latch ready(nthreads);
for (auto thi = 0ull; thi < nthreads; ++thi) {
sems[thi] = mmap(
nullptr,
sizeof(Sem),
PROT_READ | PROT_WRITE,
MAP_PRIVATE | MAP_ANONYMOUS,
-1,
0);
PCHECK((void*)-1 != sems[thi]);
auto sem_ = new (sems[thi]) Sem(0);
threads[thi] = std::thread([&, thi, sem_] {
auto& sem = *reinterpret_cast<Sem*>(sem_);
sem.wait([&] { ready.count_down(); }, [&, thi] { ++waits[thi]; });
sem.~Sem();
mprotect(sem_, sizeof(Sem), PROT_NONE);
});
}
ready.wait();
for (auto sem_ : sems) {
auto& sem = *reinterpret_cast<Sem*>(sem_);
sem.post();
}
for (auto thi = 0ull; thi < nthreads; ++thi) {
threads[thi].join();
munmap(sems[thi], sizeof(Sem));
}
}
auto const allwaits = std::accumulate(waits.begin(), waits.end(), size_t(0));
EXPECT_EQ(nthreads * rounds, allwaits);
}
template <typename Sem>
void test_wake_policy() {
constexpr auto const nthreads = 16ull;
constexpr auto const rounds = 1ull << 4;
Sem sem;
std::array<std::thread, nthreads> threads;
for (auto i = 0ull; i < rounds; ++i) {
std::vector<uint64_t> wait_seq;
std::vector<uint64_t> wake_seq;
folly::Latch ready(nthreads); // first nthreads waits, then nthreads posts
for (auto thi = 0ull; thi < nthreads; ++thi) {
threads[thi] = std::thread([&, thi] {
sem.wait(
[&, thi] { wait_seq.push_back(thi), ready.count_down(); },
[&, thi] { wake_seq.push_back(thi); });
});
}
ready.wait(); // first nthreads waits, then nthreads posts
for (auto thi = 0ull; thi < nthreads; ++thi) {
sem.post();
}
for (auto thi = 0ull; thi < nthreads; ++thi) {
threads[thi].join();
}
EXPECT_EQ(nthreads, wait_seq.size());
EXPECT_EQ(nthreads, wake_seq.size());
switch (wake_policy(sem)) {
case SemaphoreWakePolicy::Fifo:
break;
case SemaphoreWakePolicy::Lifo:
std::reverse(wake_seq.begin(), wake_seq.end());
break;
}
EXPECT_EQ(wait_seq, wake_seq);
}
}
// Several threads repeatedly take and give back a single unit of the
// semaphore.
//
// Each of the `nthreads` workers performs `iters` iterations of wait()
// followed by post(), bumping a counter from every callback. The main thread
// posts once to get things going. Posts outnumber waits by exactly that one
// initial post, so after joining the workers one final wait() must succeed and
// a subsequent try_wait() must fail.
//
// NOTE(review): the three counters are plain size_t values incremented from
// several threads; presumably the callbacks run under the semaphore's
// internal lock - confirm against Semaphore.h.
template <typename Sem>
void test_multi_ping_pong() {
  constexpr auto const nthreads = 4ull;
  constexpr auto const iters = 1ull << 12;
  constexpr auto const expected = iters * nthreads;

  Sem sem;
  size_t waits_before = 0;
  size_t waits_after = 0;
  size_t posts = 0;

  auto const worker = [&] {
    auto remaining = iters;
    while (remaining != 0) {
      sem.wait([&] { ++waits_before; }, [&] { ++waits_after; });
      sem.post([&] { ++posts; });
      --remaining;
    }
  };

  std::array<std::thread, nthreads> threads;
  for (auto& slot : threads) {
    slot = std::thread(worker);
  }

  sem.post(); // start the flood

  for (auto& slot : threads) {
    slot.join();
  }

  // Consume the unit left over from the initial post; nothing may remain.
  sem.wait();
  EXPECT_FALSE(sem.try_wait());

  EXPECT_EQ(expected, waits_before);
  EXPECT_EQ(expected, waits_after);
  EXPECT_EQ(expected, posts);
}
// Runs dedicated poster threads against dedicated waiter threads.
//
// `nthreads` posters each call post() `iters` times while `nthreads` waiters
// each call wait() `iters` times. Every thread first calls wait() on a Barrier
// constructed with the total thread count (presumably so that all of them
// start their loops together - confirm against Barrier.h). The total number of
// posts equals the total number of waits, so the semaphore is expected to be
// empty once every thread has been joined.
template <typename Sem>
void test_concurrent_split_waiters_posters() {
  constexpr auto const nthreads = 4ull;
  constexpr auto const iters = 1ull << 12;
  // Posters yield once every `yield_period` iterations (16 times per run).
  constexpr auto const yield_period = iters >> 4;

  Sem sem;
  Barrier barrier(nthreads * 2);

  auto const post_loop = [&] {
    barrier.wait();
    for (auto n = 0ull; n < iters; ++n) {
      if (0 == n % yield_period) {
        std::this_thread::yield();
      }
      sem.post();
    }
  };
  auto const wait_loop = [&] {
    barrier.wait();
    for (auto n = 0ull; n < iters; ++n) {
      sem.wait();
    }
  };

  std::array<std::thread, nthreads> posters;
  std::array<std::thread, nthreads> waiters;
  for (auto& poster : posters) {
    poster = std::thread(post_loop);
  }
  for (auto& waiter : waiters) {
    waiter = std::thread(wait_loop);
  }

  for (auto& poster : posters) {
    poster.join();
  }
  for (auto& waiter : waiters) {
    waiter.join();
  }

  // Every post has been matched by a wait, so nothing may be left over.
  EXPECT_FALSE(sem.try_wait());
}
} // namespace
// Fixture for the tests instantiated with the `Semaphore` type (declared
// outside this file). It adds no state or setup of its own and only serves to
// group the test cases below under the SemaphoreTest name.
class SemaphoreTest : public testing::Test {};
// Single-threaded try_wait/post/wait smoke test.
TEST_F(SemaphoreTest, basic) {
test_basic<Semaphore>();
}
// Several threads alternating wait() and post() on one semaphore.
TEST_F(SemaphoreTest, multi_ping_pong) {
test_multi_ping_pong<Semaphore>();
}
// Dedicated poster threads running against dedicated waiter threads.
TEST_F(SemaphoreTest, concurrent_split_waiters_posters) {
test_concurrent_split_waiters_posters<Semaphore>();
}
// Regression test: post() racing with destruction by the woken waiter.
// Note that no wake_policy test is registered for this fixture.
TEST_F(SemaphoreTest, handoff) {
test_handoff_destruction<Semaphore>();
}
// Fixture for the tests instantiated with the `FifoSemaphore` type (declared
// outside this file). It adds no state or setup of its own and only serves to
// group the test cases below under the FifoSemaphoreTest name.
class FifoSemaphoreTest : public testing::Test {};
// Single-threaded try_wait/post/wait smoke test.
TEST_F(FifoSemaphoreTest, basic) {
test_basic<FifoSemaphore>();
}
// Compares the order in which waiters enqueue with the order in which they
// are woken, according to the policy reported by wake_policy().
TEST_F(FifoSemaphoreTest, wake_policy) {
test_wake_policy<FifoSemaphore>();
}
// Several threads alternating wait() and post() on one semaphore.
TEST_F(FifoSemaphoreTest, multi_ping_pong) {
test_multi_ping_pong<FifoSemaphore>();
}
// Dedicated poster threads running against dedicated waiter threads.
TEST_F(FifoSemaphoreTest, concurrent_split_waiters_posters) {
test_concurrent_split_waiters_posters<FifoSemaphore>();
}
// Regression test: post() racing with destruction by the woken waiter.
TEST_F(FifoSemaphoreTest, handoff) {
test_handoff_destruction<FifoSemaphore>();
}
// Fixture for the tests instantiated with the `LifoSemaphore` type (declared
// outside this file). It adds no state or setup of its own and only serves to
// group the test cases below under the LifoSemaphoreTest name.
class LifoSemaphoreTest : public testing::Test {};
// Single-threaded try_wait/post/wait smoke test.
TEST_F(LifoSemaphoreTest, basic) {
test_basic<LifoSemaphore>();
}
// Compares the order in which waiters enqueue with the order in which they
// are woken, according to the policy reported by wake_policy().
TEST_F(LifoSemaphoreTest, wake_policy) {
test_wake_policy<LifoSemaphore>();
}
// Several threads alternating wait() and post() on one semaphore.
TEST_F(LifoSemaphoreTest, multi_ping_pong) {
test_multi_ping_pong<LifoSemaphore>();
}
// Dedicated poster threads running against dedicated waiter threads.
TEST_F(LifoSemaphoreTest, concurrent_split_waiters_posters) {
test_concurrent_split_waiters_posters<LifoSemaphore>();
}
// Regression test: post() racing with destruction by the woken waiter.
TEST_F(LifoSemaphoreTest, handoff) {
test_handoff_destruction<LifoSemaphore>();
}