folly/folly/detail/test/ThreadLocalBenchmark.cpp

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <thread>

#include <folly/Benchmark.h>
#include <folly/ThreadLocal.h>
#include <folly/synchronization/test/Barrier.h>

using namespace folly;

namespace folly {
namespace threadlocal_detail {

/**
 * Container for the thread-local elements used by one benchmark run.
 *
 * Tag isolates the run: both the element type and the Meta alias are keyed on
 * it.
 */
template <class Tag>
struct ThreadLocalTestHelper {
  // Thread-local int keyed on Tag.
  using TLElem = ThreadLocal<int, Tag>;
  // StaticMeta instantiation for the same Tag.
  using Meta = StaticMeta<Tag, void>;

  // One entry per thread-local under test; filled by the benchmark body.
  std::vector<TLElem> elements;
};

void measureAccessAllThreads(
    int iters, uint32_t totalThreads, uint32_t tlCount, uint32_t spreadPerTL) {
  struct Tag {};

  ThreadLocalTestHelper<Tag> helper;

  std::vector<std::thread> threads;
  std::vector<std::unique_ptr<test::Barrier>> threadBarriers;
  test::Barrier allThreadsBarriers{static_cast<size_t>(totalThreads + 1)};

  BENCHMARK_SUSPEND {
    helper.elements.reserve(tlCount);
    for (uint32_t i = 0; i < tlCount; ++i) {
      helper.elements.push_back({});
    }
    for (uint32_t i = 0; i < totalThreads; ++i) {
      threadBarriers.push_back(std::make_unique<test::Barrier>(2));
      threads.push_back(std::thread([&, index = i]() {
        for (uint32_t tlIdx = 0; tlIdx < tlCount; ++tlIdx) {
          uint32_t startIndex = tlIdx % totalThreads;
          uint32_t endIndex = (tlIdx + spreadPerTL) % totalThreads;
          if (startIndex < endIndex) {
            if (index >= startIndex && index <= endIndex) {
              *helper.elements[tlIdx] = tlIdx;
            }
          } else {
            if (!(index >= endIndex && index <= startIndex)) {
              *helper.elements[tlIdx] = tlIdx;
            }
          }
        }
        allThreadsBarriers.wait();
        threadBarriers[index]->wait();
      }));
    }

    // Wait for all threads to start.
    allThreadsBarriers.wait();
  }

  [[maybe_unused]] uint64_t countFound = 0;
  [[maybe_unused]] uint64_t sum = 0;
  for (int32_t k = 0; k < iters; k++) {
    for (uint32_t i = 0; i < tlCount; ++i) {
      for (auto& elem : helper.elements[i].accessAllThreads()) {
        sum += elem;
        countFound++;
      }
    }
  }

  BENCHMARK_SUSPEND {
    while (!threads.empty()) {
      threadBarriers.back()->wait();
      threads.back().join();
      threads.pop_back();
      threadBarriers.pop_back();
    }
  }
}

} // namespace threadlocal_detail
} // namespace folly

void accessAllThreads(
    int iters, uint32_t totalThreads, uint32_t tlCount, uint32_t spreadPerTL) {
  folly::threadlocal_detail::measureAccessAllThreads(
      iters, totalThreads, tlCount, spreadPerTL);
}

// Registered benchmark cases. Each case name spells out the three arguments
// passed to accessAllThreads() after the iteration count:
//   <totalThreads>t_<tlCount>c_<spreadPerTL>s
BENCHMARK_DRAW_LINE();
BENCHMARK_NAMED_PARAM(accessAllThreads, 100t_100c_1s, 100, 100, 1)
BENCHMARK_NAMED_PARAM(accessAllThreads, 100t_10000c_10s, 100, 10000, 10)
BENCHMARK_NAMED_PARAM(accessAllThreads, 1000t_1000c_25s, 1000, 1000, 25)
BENCHMARK_NAMED_PARAM(accessAllThreads, 2000t_20000c_50s, 2000, 20000, 50)
BENCHMARK_DRAW_LINE();

// Program entry: parses command-line flags with gflags, then runs every
// registered benchmark.
int main(int argc, char* argv[]) {
  // Passed as gflags' remove_flags argument.
  constexpr bool kRemoveFlags = true;
  gflags::ParseCommandLineFlags(&argc, &argv, kRemoveFlags);

  folly::runBenchmarks();
  return 0;
}