/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <folly/experimental/channels/detail/AtomicQueue.h>
#include <folly/portability/GTest.h>
#include <folly/synchronization/Baton.h>
namespace folly {
namespace channels {
namespace detail {
// Produces the fixed, non-null pointer value that the tests below pass to
// AtomicQueue as the consumer parameter. Within this file the value is only
// ever compared for equality (inside the consumer callbacks); the tests
// themselves never dereference it.
static int* getConsumerParam() {
  auto* const sentinel = reinterpret_cast<int*>(1);
  return sentinel;
}
// Scripted hand-off between one producer thread and the test thread.
//
// The test thread first registers a consumer and lets the producer push a
// single message, then lets the producer push two more while no consumer is
// registered, and finally exercises wait()/cancelCallback() pairs. Batons are
// used so that each step happens in a fixed order.
TEST(AtomicQueueTest, Basic) {
  // Callback target handed to wait(). consume() checks the parameter it was
  // given and wakes the test thread; canceled() must never fire in this test.
  struct Consumer {
    void consume(int* consumerParam) {
      EXPECT_EQ(consumerParam, getConsumerParam());
      consumeCalled.post();
    }
    void canceled(int*) { ADD_FAILURE() << "canceled() shouldn't be called"; }
    folly::Baton<> consumeCalled;
  };

  // Test thread -> producer: "run your next push step".
  folly::Baton<> producerBaton;
  // Producer -> test thread: "messages 2 and 3 have been pushed".
  folly::Baton<> consumerBaton;

  AtomicQueue<Consumer, int> queue;
  Consumer waiter;

  std::thread producer([&] {
    // Step 1: a single message, pushed only once the test thread says so.
    producerBaton.wait();
    producerBaton.reset();
    queue.push(1, getConsumerParam());

    // Step 2: two more messages, then report back to the test thread.
    producerBaton.wait();
    producerBaton.reset();
    queue.push(2, getConsumerParam());
    queue.push(3, getConsumerParam());
    consumerBaton.post();
  });

  // Nothing has been pushed yet (the producer is still parked on its baton),
  // and the test expects wait() to report true here.
  EXPECT_TRUE(queue.wait(&waiter, getConsumerParam()));
  producerBaton.post();
  // The test relies on push(1) invoking waiter.consume(), which posts this
  // baton.
  waiter.consumeCalled.wait();
  waiter.consumeCalled.reset();
  {
    // Exactly one message (1) is expected to be available.
    auto batch = queue.getMessages(getConsumerParam());
    EXPECT_FALSE(batch.empty());
    EXPECT_EQ(1, batch.front());
    batch.pop();
    EXPECT_TRUE(batch.empty());
  }

  // Release step 2 and block until both pushes have completed.
  producerBaton.post();
  consumerBaton.wait();
  consumerBaton.reset();

  // With messages 2 and 3 already pushed, wait() is expected to report false.
  EXPECT_FALSE(queue.wait(&waiter, getConsumerParam()));
  {
    // Messages are expected to come back in push order: 2, then 3.
    auto batch = queue.getMessages(getConsumerParam());
    EXPECT_FALSE(batch.empty());
    EXPECT_EQ(2, batch.front());
    batch.pop();
    EXPECT_FALSE(batch.empty());
    EXPECT_EQ(3, batch.front());
    batch.pop();
    EXPECT_TRUE(batch.empty());
  }

  // After each successful wait(), cancelCallback() is expected to hand back
  // the consumer that was passed to wait(); this is checked twice in a row.
  EXPECT_TRUE(queue.wait(&waiter, getConsumerParam()));
  EXPECT_EQ(queue.cancelCallback(), &waiter);
  EXPECT_TRUE(queue.wait(&waiter, getConsumerParam()));
  EXPECT_EQ(queue.cancelCallback(), &waiter);
  // A further cancelCallback() without a preceding wait() is expected to
  // return nullptr.
  EXPECT_EQ(queue.cancelCallback(), nullptr);

  producer.join();
}
// Closes the queue while a consumer is registered and checks what the test
// expects afterwards: canceled() was invoked with the consumer parameter,
// the queue reports closed, and no messages are returned - not even after a
// later push().
TEST(AtomicQueueTest, Canceled) {
  // Records whether canceled() ran; consume() must never fire in this test.
  struct Consumer {
    bool sawCanceled{false};
    void consume(int*) { ADD_FAILURE() << "consume() shouldn't be called"; }
    void canceled(int* consumerParam) {
      EXPECT_EQ(consumerParam, getConsumerParam());
      sawCanceled = true;
    }
  };

  AtomicQueue<Consumer, int> queue;
  Consumer waiter;

  // Register the consumer, then close the queue underneath it.
  EXPECT_TRUE(queue.wait(&waiter, getConsumerParam()));
  queue.close(getConsumerParam());
  EXPECT_TRUE(waiter.sawCanceled);
  EXPECT_TRUE(queue.isClosed());

  // Reading from the closed queue is expected to yield nothing and to leave
  // it closed.
  EXPECT_TRUE(queue.getMessages(getConsumerParam()).empty());
  EXPECT_TRUE(queue.isClosed());

  // A push after close() is expected not to surface through getMessages().
  queue.push(42, getConsumerParam());
  EXPECT_TRUE(queue.getMessages(getConsumerParam()).empty());
  EXPECT_TRUE(queue.isClosed());
}
// Pushes the integers 1..kMessageCount from a producer thread while the test
// thread pulls them back out and checks that they arrive in order. Every
// kSyncInterval messages each side spins (yielding) until the other side's
// counter has caught up with its own.
TEST(AtomicQueueTest, Stress) {
  // Callback target handed to wait(). consume() checks the parameter it was
  // given and wakes the waiting thread; canceled() must never fire here.
  struct Consumer {
    void consume(int* consumerParam) {
      EXPECT_EQ(consumerParam, getConsumerParam());
      ready.post();
    }
    void canceled(int*) { ADD_FAILURE() << "canceled() shouldn't be called"; }
    folly::Baton<> ready;
  };

  AtomicQueue<Consumer, int> sharedQueue;

  // Returns the next message. `pending` holds the messages from the last
  // getMessages() call and persists across invocations; it is refilled only
  // once it has been fully drained.
  auto takeNext = [&sharedQueue, pending = Queue<int>()]() mutable {
    Consumer waiter;
    if (pending.empty()) {
      // When wait() reports true, block until consume() posts the baton
      // before fetching; otherwise fetch straight away.
      if (sharedQueue.wait(&waiter, getConsumerParam())) {
        waiter.ready.wait();
      }
      pending = sharedQueue.getMessages(getConsumerParam());
      EXPECT_FALSE(pending.empty());
    }
    auto value = pending.front();
    pending.pop();
    return value;
  };

  constexpr ssize_t kMessageCount = 100000;
  constexpr ssize_t kSyncInterval = 1000;
  std::atomic<ssize_t> produced{0};
  std::atomic<ssize_t> consumed{0};

  std::thread producer([&] {
    for (produced = 1; produced <= kMessageCount; ++produced) {
      sharedQueue.push(produced, getConsumerParam());
      if (produced % kSyncInterval == 0) {
        // Let the consumer side catch up before producing further.
        while (produced > consumed.load(std::memory_order_relaxed)) {
          std::this_thread::yield();
        }
      }
    }
  });

  for (consumed = 1; consumed <= kMessageCount; ++consumed) {
    // Messages must come out in exactly the order they were pushed.
    EXPECT_EQ(consumed, takeNext());
    if (consumed % kSyncInterval == 0) {
      // Let the producer side catch up before consuming further.
      while (consumed > produced.load(std::memory_order_relaxed)) {
        std::this_thread::yield();
      }
    }
  }

  producer.join();
}
} // namespace detail
} // namespace channels
} // namespace folly