// folly/folly/container/test/IntrusiveHeapTest.cpp

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <algorithm>
#include <memory>
#include <ostream>
#include <random>
#include <type_traits>
#include <vector>

#include <folly/Random.h>
#include <folly/container/IntrusiveHeap.h>
#include <folly/portability/GFlags.h>
#include <folly/portability/GTest.h>

// Command-line flag read as FLAGS_fuzz_count by the Fuzz test below, where it
// bounds the number of loop iterations.
DEFINE_int32(fuzz_count, 10000, "Number of operations to fuzz");

namespace folly {

namespace {

// Tag types. In this file they are only ever used as template arguments (to
// tell apart the two IntrusiveHeapNode bases of Task and the heaps built on
// them), so a forward declaration is all that is provided.
struct A;
struct B;

// Minimal prioritized payload shared by every task flavor in this test: a
// one-character label plus a mutable priority that starts at zero.
struct TaskBase {
  explicit TaskBase(char id) : id{id} {}

  // Ordering looks at the priority only; the label never participates.
  bool operator<(const TaskBase& rhs) const { return pri < rhs.pri; }

  // Streams the task as "<pri>: <id>".
  friend std::ostream& operator<<(std::ostream& out, const TaskBase& task) {
    out << task.pri << ": " << task.id;
    return out;
  }

  char id;
  int pri{0};
};

// A task carrying two intrusive heap nodes as (public) bases, one per tag, on
// top of the TaskBase payload.
struct Task : TaskBase, IntrusiveHeapNode<A>, IntrusiveHeapNode<B> {
  using TaskBase::TaskBase; // Reuse the single-character constructor.
};

// Task that also carries one bookkeeping flag per tag. The test flips a flag
// whenever it pushes the task to, or erases it from, the matching heap; the
// in() overloads let generic code select the flag simply by passing the heap.
struct TrackedTask : Task {
  using Task::Task;

  bool inA{false};
  bool inB{false};

  // Flag associated with the A-tagged heap.
  bool& in(IntrusiveHeap<TrackedTask, std::less<>, A>&) { return inA; }
  // Flag associated with the B-tagged heap.
  bool& in(IntrusiveHeap<TrackedTask, std::less<>, B>&) { return inB; }
};

// Task whose heap node is a private data member instead of a base class.
// NodeTraits names that member so it can be passed as the heap's traits
// parameter (see the BasicComposition test).
class TaskWithComposition : public TaskBase {
  // Private by class default. Must be declared before NodeTraits, which takes
  // a pointer to this member.
  IntrusiveHeapNode<void> node_;

 public:
  using TaskBase::TaskBase;

  using NodeTraits =
      MemberNodeTraits<TaskWithComposition, void, &TaskWithComposition::node_>;
};

} // namespace

// Test-only helper that reads a heap's internal links (root_, parent_, left_,
// right_) in order to dump the heap and to validate its structure.
class IntrusiveHeapTest {
 public:
  // Writes "EMPTY" for a heap without a root; otherwise writes a newline
  // followed by one line per node.
  template <class Heap>
  static void print(const Heap& heap, std::ostream& os) {
    if (heap.root_ != nullptr) {
      os << '\n';
      print<Heap>(0, heap.root_, os);
    } else {
      os << "EMPTY";
    }
  }

  // Walks the whole heap and CHECK-fails if a node's parent_ link disagrees
  // with the traversal or if Heap::compare(parent, child) holds on any edge.
  template <class Heap>
  static void check(const Heap& heap) {
    check<Heap>(heap, heap.root_, nullptr);
  }

 private:
  // Recursive worker for check(): validates `node` against the `parent` it was
  // reached from, then descends into both children.
  template <class Heap>
  static void check(
      const Heap& heap,
      const typename Heap::Node* node,
      const typename Heap::Node* parent) {
    if (node != nullptr) {
      CHECK_EQ(node->parent_, parent) << "on node " << node << " in " << heap;
      const bool isRoot = parent == nullptr;
      if (!isRoot) {
        CHECK(!Heap::compare(parent, node))
            << "on node " << node << " in " << heap;
      }
      check<Heap>(heap, node->left_, node);
      check<Heap>(heap, node->right_, node);
    }
  }

  // Recursive worker for print(): one line per node, indented two extra
  // spaces per level of depth.
  template <class Heap>
  static void print(
      int indent, const typename Heap::Node* node, std::ostream& os) {
    if (node == nullptr) {
      return;
    }
    const int childIndent = indent + 2;
    // The right subtree goes out first and the left one last, so the dump
    // reads like a top-down tree turned on its side (root at the left margin).
    print<Heap>(childIndent, node->right_, os);
    os << std::string(indent, ' ') << *Heap::asT(node);
    os << " (" << node << "; parent: " << node->parent_;
    os << "; left: " << node->left_ << "; right: " << node->right_ << ")\n";
    print<Heap>(childIndent, node->left_, os);
  }
};

// Streams a heap by delegating to IntrusiveHeapTest::print, which lets heaps
// be appended to the gtest and logging messages used in this file.
template <class T, class Compare, class Tag, class Traits>
auto operator<<(
    std::ostream& os, const IntrusiveHeap<T, Compare, Tag, Traits>& heap)
    -> std::ostream& {
  IntrusiveHeapTest::print(heap, os);
  return os;
}

// Compile-time and layout properties of a type deriving from two
// differently-tagged heap nodes.
TEST(IntrusiveHeap, Static) {
  // Task must be neither copyable nor movable.
  static_assert(!std::is_copy_constructible_v<Task>);
  static_assert(!std::is_move_constructible_v<Task>);
  static_assert(!std::is_copy_assignable_v<Task>);
  static_assert(!std::is_move_assignable_v<Task>);

  Task task('x');
  IntrusiveHeapNode<A>* const nodeA = &task;
  IntrusiveHeapNode<B>* const nodeB = &task;

  // Either node base casts back to the very same Task...
  EXPECT_EQ(static_cast<Task*>(nodeA), &task);
  EXPECT_EQ(static_cast<Task*>(nodeB), &task);
  // ...although the two base subobjects sit at different addresses.
  EXPECT_NE(static_cast<void*>(nodeA), static_cast<void*>(nodeB));
}

// Drives a heap through push/update/pop with four tasks, checking the results
// of top() and pop() as well as each node's linked state along the way.
// T is the element type; Heap is the heap instantiation storing T.
template <class T, class Heap>
void testBasic() {
  T a('a');
  T b('b');
  T c('c');
  T d('d');
  Heap heap;
  const auto linked = [](T& task) {
    return Heap::NodeTraits::asNode(&task)->isLinked();
  };

  // A node reports itself as linked only after it has been pushed.
  EXPECT_FALSE(linked(a));
  heap.push(&a);
  EXPECT_TRUE(linked(a));
  heap.push(&b);
  heap.push(&c);
  heap.push(&d);

  // Priorities are now a=0 b=3 c=0 d=0.
  b.pri = 3;
  heap.update(&b);
  EXPECT_TRUE(linked(b));
  EXPECT_EQ(heap.top(), &b) << heap;

  // Raising d above b puts d on top...
  d.pri = 4;
  heap.update(&d);
  EXPECT_EQ(heap.top(), &d) << heap;

  // ...and dropping it back to 2 hands the top back to b.
  d.pri = 2;
  heap.update(&d);
  EXPECT_EQ(heap.pop(), &b) << heap;
  EXPECT_FALSE(linked(b));

  // Remaining priorities: a=1 c=0 d=2, so the pops yield d, a, c in turn.
  a.pri = 1;
  heap.update(&a);
  EXPECT_EQ(heap.pop(), &d) << heap;
  EXPECT_FALSE(linked(d));
  EXPECT_EQ(heap.pop(), &a) << heap;
  EXPECT_FALSE(linked(a));
  EXPECT_EQ(heap.pop(), &c) << heap;
  EXPECT_FALSE(linked(c));

  // Nothing is left.
  EXPECT_EQ(heap.top(), nullptr);
}

// Runs the basic scenario with the node supplied as a base class (tag A).
TEST(IntrusiveHeap, BasicDerived) {
  using Heap = IntrusiveHeap<Task, std::less<>, A>;
  testBasic<Task, Heap>();
}

// Runs the basic scenario with the node supplied as a data member, located
// through TaskWithComposition::NodeTraits.
TEST(IntrusiveHeap, BasicComposition) {
  using Heap = IntrusiveHeap<
      TaskWithComposition,
      std::less<>,
      void,
      TaskWithComposition::NodeTraits>;
  testBasic<TaskWithComposition, Heap>();
}

// Randomly pushes, erases and re-prioritizes the tasks 'a'..'z' across two
// heaps (tags A and B), validating the heap structure after every mutation.
// FLAGS_fuzz_count bounds the number of rounds.
TEST(IntrusiveHeap, Fuzz) {
  using HeapA = IntrusiveHeap<TrackedTask, std::less<>, A>;
  using HeapB = IntrusiveHeap<TrackedTask, std::less<>, B>;

  std::default_random_engine rng(1729); // Deterministic seed.
  std::vector<std::unique_ptr<TrackedTask>> tasks;

  HeapA aHeap;
  HeapB bHeap;
  for (char id = 'a'; id <= 'z'; ++id) {
    tasks.emplace_back(std::make_unique<TrackedTask>(id));
  }

  // Orders task pointers (owning or raw) by the priority of their pointees.
  const auto byPriority = [](auto& lhs, auto& rhs) { return *lhs < *rhs; };

  // One fuzz step on `heap`: on a oneIn(2) draw, toggles `item`'s membership;
  // always validates the heap; on a oneIn(50) draw, additionally verifies that
  // no task flagged as a member has a priority different from top()'s among
  // those not ordered below top().
  const auto fuzzHeap = [&](auto& heap, TrackedTask* item) {
    using Heap = std::decay_t<decltype(heap)>;
    if (folly::Random::oneIn(2, rng)) {
      bool& contained = item->in(heap);
      if (contained) {
        VLOG(1) << "removing " << *item << " from " << heap;
        EXPECT_TRUE(Heap::NodeTraits::asNode(item)->isLinked());
        heap.erase(item);
        EXPECT_FALSE(Heap::NodeTraits::asNode(item)->isLinked());
      } else {
        VLOG(1) << "pushing " << *item << " to " << heap;
        heap.push(item);
      }
      // Membership was flipped in either branch.
      contained = !contained;
    }
    IntrusiveHeapTest::check(heap);
    VLOG(1) << heap;

    if (!folly::Random::oneIn(50, rng)) {
      return;
    }
    // Sorting only permutes the owning pointers; raw task pointers held
    // elsewhere (e.g. `item`) keep pointing at the same objects.
    std::sort(tasks.begin(), tasks.end(), byPriority);
    TrackedTask* const top = heap.top();
    if (top == nullptr) {
      return;
    }
    // First task that is not ordered below top; everything from here to the
    // end has a priority of at least top->pri.
    auto candidate =
        std::lower_bound(tasks.begin(), tasks.end(), top, byPriority);
    ASSERT_TRUE(candidate != tasks.end());
    for (; candidate != tasks.end(); ++candidate) {
      TrackedTask* const task = candidate->get();
      if (task->in(heap)) {
        ASSERT_EQ(top->pri, task->pri);
      }
    }
  };

  // Re-establishes heap order for `item` if it is flagged as being in `heap`,
  // then validates the heap.
  const auto refresh = [](auto& heap, TrackedTask* item) {
    if (item->in(heap)) {
      heap.update(item);
      IntrusiveHeapTest::check(heap);
      VLOG(1) << heap;
    }
  };

  for (auto remaining = FLAGS_fuzz_count; remaining > 0; --remaining) {
    const auto index = folly::Random::rand32(0, tasks.size(), rng);
    TrackedTask* const item = tasks[index].get();

    // Exactly one of the two heaps is fuzzed per round.
    if (folly::Random::oneIn(2, rng)) {
      fuzzHeap(aHeap, item);
    } else {
      fuzzHeap(bHeap, item);
    }

    // Shift the priority by a small signed amount (presumably in [-3, 3],
    // assuming rand32's upper bound is exclusive) and let every heap that
    // holds the task re-sort it.
    const auto dp = static_cast<int>(folly::Random::rand32(0, 7, rng)) - 3;
    VLOG(1) << "adjusting " << *item << " by " << dp;
    item->pri += dp;
    refresh(aHeap, item);
    refresh(bHeap, item);
  }
}

} // namespace folly