folly/folly/coro/test/GeneratorTest.cpp

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <folly/Portability.h>

#include <algorithm>
#include <cstdint>
#include <exception>
#include <memory>

#include <folly/ScopeGuard.h>
#include <folly/experimental/coro/Generator.h>
#include <folly/portability/GTest.h>

#if FOLLY_HAS_COROUTINES

namespace folly {
namespace coro {

class GeneratorTest : public testing::Test {};

// A default-constructed Generator must behave as an empty range: its begin()
// iterator compares equal to end() straight away.
TEST_F(GeneratorTest, DefaultConstructed_EmptySequence) {
  Generator<std::uint32_t> gen{};
  auto first = gen.begin();
  auto last = gen.end();
  EXPECT_EQ(first, last);
}

// A plain (non-recursive) generator produces its co_yield-ed values in order
// and then compares equal to end().
TEST_F(GeneratorTest, NonRecursiveUse) {
  auto makeFloats = []() -> Generator<float> {
    co_yield 1.0f;
    co_yield 2.0f;
  };

  auto floats = makeFloats();
  auto pos = floats.begin();
  // Dereference then advance once per expected value, in yield order.
  for (float expected : {1.0f, 2.0f}) {
    EXPECT_EQ(*pos, expected);
    ++pos;
  }
  EXPECT_EQ(pos, floats.end());
}

// An exception thrown before the first co_yield must surface from begin(),
// which is the call that first runs the generator body.
TEST_F(GeneratorTest, ThrowsBeforeYieldingFirstElement_RethrowsFromBegin) {
  class MyException : public std::exception {};

  auto throwsImmediately = []() -> Generator<std::uint32_t> {
    throw MyException{};
    // Never reached; its only job is to make this lambda a coroutine.
    co_return;
  };

  auto generator = throwsImmediately();
  EXPECT_THROW(generator.begin(), MyException);
}

// An exception thrown after the first co_yield must surface from the
// operator++ that resumes the body. After the rethrow the generator is
// finished, so the iterator compares equal to end().
TEST_F(GeneratorTest, ThrowsAfterYieldingFirstElement_RethrowsFromIncrement) {
  class MyException : public std::exception {};

  auto f = []() -> Generator<std::uint32_t> {
    co_yield 1;
    throw MyException{};
  };

  auto gen = f();
  auto iter = gen.begin();
  // Fatal check: dereferencing end() below would be undefined behaviour.
  ASSERT_NE(iter, gen.end());
  EXPECT_EQ(*iter, 1u);
  EXPECT_THROW(++iter, MyException);
  // Same post-condition as ExceptionThrownFromNestedCall_CanBeCaughtByCaller:
  // a rethrown exception leaves the iterator at end().
  EXPECT_EQ(iter, gen.end());
}

// Generator bodies are lazy. Creating the generator runs nothing. begin() runs
// the body up to the first co_yield, and each ++ runs it up to the next
// co_yield (or to the end of the body).
TEST_F(GeneratorTest, NotStartedUntilCalled) {
  bool beforeFirstYield = false;
  bool beforeSecondYield = false;
  bool afterLastYield = false;
  auto body = [&]() -> Generator<std::uint32_t> {
    beforeFirstYield = true;
    co_yield 1;
    beforeSecondYield = true;
    co_yield 2;
    afterLastYield = true;
  };

  auto gen = body();
  // Merely creating the generator must not execute any of its body.
  EXPECT_FALSE(beforeFirstYield);

  auto iter = gen.begin();
  EXPECT_TRUE(beforeFirstYield);
  EXPECT_FALSE(beforeSecondYield);
  EXPECT_EQ(*iter, 1u);

  ++iter;
  EXPECT_TRUE(beforeSecondYield);
  EXPECT_FALSE(afterLastYield);
  EXPECT_EQ(*iter, 2u);

  ++iter;
  EXPECT_TRUE(afterLastYield);
  EXPECT_EQ(iter, gen.end());
}

// Destroying a generator that is suspended mid-sequence must run the
// destructors of objects alive in its body (here the SCOPE_EXIT guard). It
// must do so without resuming the body to completion.
TEST_F(GeneratorTest, DestroyedBeforeCompletion_DestructsObjectsOnStack) {
  bool destructed = false;
  bool completed = false;
  auto f = [&]() -> Generator<std::uint32_t> {
    SCOPE_EXIT {
      destructed = true;
    };

    co_yield 1;
    co_yield 2;
    completed = true;
  };

  {
    auto g = f();
    auto it = g.begin();
    auto itEnd = g.end();
    // Fatal check: the dereference below is only valid on a non-end iterator,
    // so a non-fatal EXPECT_NE here would fall through into UB on failure.
    ASSERT_NE(it, itEnd);
    EXPECT_EQ(*it, 1u);
    EXPECT_FALSE(destructed);
  } // `g` is destroyed here while suspended at the first co_yield.

  EXPECT_FALSE(completed);
  EXPECT_TRUE(destructed);
}

// A generator may co_yield another generator of the same type. The nested
// sequence is spliced into the outer one at that point.
TEST_F(GeneratorTest, SimpleRecursiveYield) {
  // Yields n; for n > 0 it then yields everything produced for n - 1, then
  // n again. The result is the sequence n, n-1, ..., 0, ..., n-1, n.
  auto palindrome = [](int n, auto& self) -> Generator<const std::uint32_t> {
    co_yield n;
    if (n > 0) {
      co_yield self(n - 1, self);
      co_yield n;
    }
  };

  auto makePalindrome = [&palindrome](int n) {
    return palindrome(n, palindrome);
  };

  {
    auto gen = makePalindrome(1);
    auto iter = gen.begin();
    for (std::uint32_t expected : {1u, 0u, 1u}) {
      EXPECT_EQ(*iter, expected);
      ++iter;
    }
    EXPECT_EQ(iter, gen.end());
  }

  {
    auto gen = makePalindrome(2);
    auto iter = gen.begin();
    for (std::uint32_t expected : {2u, 1u, 0u, 1u, 2u}) {
      EXPECT_EQ(*iter, expected);
      ++iter;
    }
    EXPECT_EQ(iter, gen.end());
  }
}

// Yielding an empty nested generator contributes no elements. The outer
// sequence continues straight on to its next co_yield.
TEST_F(GeneratorTest, NestedEmptyYield) {
  auto empty = []() -> Generator<std::uint32_t> { co_return; };

  auto outer = [&empty]() -> Generator<std::uint32_t> {
    co_yield 1;
    co_yield empty();
    co_yield 2;
  };

  auto gen = outer();
  auto iter = gen.begin();
  for (std::uint32_t expected : {1u, 2u}) {
    EXPECT_EQ(*iter, expected);
    ++iter;
  }
  EXPECT_EQ(iter, gen.end());
}

// An exception thrown inside a nested generator is delivered to the enclosing
// generator's try/catch around the `co_yield nested`. It does not escape to
// the code that is iterating.
TEST_F(GeneratorTest, ExceptionThrownFromRecursiveCall_CanBeCaughtByCaller) {
  class SomeException : public std::exception {};
  bool handled = false;

  // At depth 1 the body throws before yielding anything. At depth 0 it
  // yields 1, recurses into depth 1 inside a try block, then yields 2.
  auto recurse = [&](std::uint32_t depth,
                     auto&& self) -> Generator<std::uint32_t> {
    if (depth == 1u) {
      throw SomeException{};
    }

    co_yield 1;

    try {
      co_yield self(1, self);
    } catch (const SomeException&) {
      handled = true;
    }

    co_yield 2;
  };

  auto gen = recurse(0, recurse);
  auto iter = gen.begin();
  EXPECT_EQ(*iter, 1u);
  // The nested call has not been reached yet, so nothing was thrown.
  EXPECT_FALSE(handled);

  ++iter;
  // Advancing ran the nested generator. Its exception was caught inside the
  // outer body, which then went on to yield 2.
  EXPECT_TRUE(handled);
  EXPECT_EQ(*iter, 2u);

  ++iter;
  EXPECT_EQ(iter, gen.end());
}

// Exceptions propagate outward through several levels of nested generators.
// Each level may catch them. An exception no level catches is rethrown to the
// iterating caller, and the iterator then compares equal to end().
TEST_F(GeneratorTest, ExceptionThrownFromNestedCall_CanBeCaughtByCaller) {
  class SomeException : public std::exception {};

  // Behaviour by depth:
  //   4: throws immediately.
  //   3: yields 3, catches depth 4's exception, yields 33, then throws.
  //   2: splices in depth 3's values, catches its exception, then yields 2.
  //   otherwise (1): yields 1, splices in depth 2, then splices in depth 3,
  //      whose final exception is not caught here.
  auto nested = [](std::uint32_t depth,
                   auto&& self) -> Generator<std::uint32_t> {
    if (depth == 4u) {
      throw SomeException{};
    }

    if (depth == 3u) {
      co_yield 3;

      bool innerCaught = false;
      try {
        co_yield self(4, self);
      } catch (const SomeException&) {
        innerCaught = true;
      }

      co_yield innerCaught ? 33 : 1337;

      throw SomeException{};
    }

    if (depth == 2u) {
      bool innerCaught = false;
      try {
        co_yield self(3, self);
      } catch (const SomeException&) {
        innerCaught = true;
      }

      if (innerCaught) {
        co_yield 2;
      }
      co_return;
    }

    co_yield 1;
    co_yield self(2, self);
    co_yield self(3, self);
  };

  auto gen = nested(1, nested);
  auto iter = gen.begin();
  EXPECT_EQ(*iter, 1u);
  // Advance then dereference once per remaining value, in yield order.
  for (std::uint32_t expected : {3u, 33u, 2u, 3u, 33u}) {
    ++iter;
    EXPECT_EQ(*iter, expected);
  }
  // Depth 3's trailing throw is not caught at depth 1, so it escapes here.
  EXPECT_THROW(++iter, SomeException);

  EXPECT_EQ(iter, gen.end());
}

namespace {
// Yields every value in the half-open range [begin, end) in ascending order.
// A range of at most 10 values is produced by a plain loop. A larger range is
// split in half and each half is produced by recursively co_yield-ing a nested
// generator, which exercises recursive-generator support. An empty or inverted
// range (begin >= end) yields nothing.
Generator<std::uint32_t> iterate_range(std::uint32_t begin, std::uint32_t end) {
  if (begin >= end) {
    // Without this guard, `end - begin` wraps around for an inverted range,
    // and the recursion below would walk a huge wrapped-around range instead
    // of producing an empty sequence.
    co_return;
  }

  if (end - begin <= 10u) {
    for (std::uint32_t i = begin; i < end; ++i) {
      co_yield i;
    }
  } else {
    // begin + (end - begin) / 2 cannot overflow, unlike (begin + end) / 2.
    const std::uint32_t mid = begin + (end - begin) / 2;
    co_yield iterate_range(begin, mid);
    co_yield iterate_range(mid, end);
  }
}
} // namespace

// Generator iterators can be passed directly to standard algorithms. Here
// std::equal compares two independently created generators element by element.
TEST_F(GeneratorTest, UsageInStandardAlgorithms) {
  auto sameSequence = [](Generator<std::uint32_t> lhs,
                         Generator<std::uint32_t> rhs) {
    return std::equal(lhs.begin(), lhs.end(), rhs.begin(), rhs.end());
  };

  // Identical ranges compare equal.
  EXPECT_TRUE(sameSequence(iterate_range(5, 30), iterate_range(5, 30)));
  // A longer range with the same prefix does not.
  EXPECT_FALSE(sameSequence(iterate_range(5, 30), iterate_range(5, 300)));
}

// A coroutine lambda started through co_invoke can own a move-only capture and
// hand it out by rvalue reference, so the consumer can move it out of the
// generator. (co_invoke is used instead of calling the lambda directly. The
// intent is presumably to keep the closure, and so `p`, alive while the
// generator runs.)
TEST_F(GeneratorTest, InvokeLambda) {
  auto ptr = std::make_unique<int>(123);
  auto gen = folly::coro::co_invoke(
      [p = std::move(
           ptr)]() mutable -> folly::coro::Generator<std::unique_ptr<int>&&> {
        co_yield std::move(p);
      });

  auto it = gen.begin();
  // Fatal checks: both dereferences below are UB if these conditions fail.
  ASSERT_NE(it, gen.end());
  auto result = std::move(*it);
  ASSERT_NE(result, nullptr);
  EXPECT_EQ(*result, 123);

  // The generator yields exactly one value.
  ++it;
  EXPECT_EQ(it, gen.end());
}
} // namespace coro
} // namespace folly

#endif