folly/folly/coro/GmockHelpers.h

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <atomic>
#include <type_traits>

#include <folly/experimental/coro/Coroutine.h>
#include <folly/experimental/coro/GtestHelpers.h>
#include <folly/experimental/coro/Result.h>
#include <folly/experimental/coro/Task.h>
#include <folly/portability/GMock.h>

#if FOLLY_HAS_COROUTINES

namespace folly {
namespace coro {
namespace gmock_helpers {

// This helper function is intended for use in GMock implementations where the
// implementation of the method is a coroutine lambda.
//
// The GMock framework internally always takes a copy of an action/lambda
// before invoking it to prevent cases where invoking the method might end
// up destroying itself.
//
// However, this is problematic for coroutine-lambdas-with-captures as the
// return-value from invoking a coroutine lambda will typically capture a
// reference to the copy of the lambda which will immediately become a dangling
// reference as soon as the mocking framework returns that value to the caller.
//
// Use this action-factory instead of Invoke() when passing coroutine-lambdas
// to mock definitions to ensure that a copy of the lambda is kept alive until
// the coroutine completes. It does this by invoking the lambda using the
// folly::coro::co_invoke() helper instead of directly invoking the lambda.
//
//
// Example:
//   using namespace ::testing
//   using namespace folly::coro::gmock_helpers;
//
//   MockFoo mock;
//   int fooCallCount = 0;
//
//   EXPECT_CALL(mock, foo(_))
//     .WillRepeatedly(CoInvoke(
//         [&](int x) -> folly::coro::Task<int> {
//           ++fooCallCount;
//           co_return x + 1;
//         }));
//
template <typename F>
auto CoInvoke(F&& f) {
  return ::testing::Invoke([f = static_cast<F&&>(f)](auto&&... a) {
    return co_invoke(f, static_cast<decltype(a)>(a)...);
  });
}

// Member function overload
template <class Class, typename MethodPtr>
auto CoInvoke(Class* obj_ptr, MethodPtr method_ptr) {
  return ::testing::Invoke([=](auto&&... a) {
    return co_invoke(method_ptr, obj_ptr, static_cast<decltype(a)>(a)...);
  });
}

// CoInvoke variant that does not pass arguments to callback function.
//
// Example:
//   using namespace ::testing
//   using namespace folly::coro::gmock_helpers;
//
//   MockFoo mock;
//   int fooCallCount = 0;
//
//   EXPECT_CALL(mock, foo(_))
//     .WillRepeatedly(CoInvokeWithoutArgs(
//         [&]() -> folly::coro::Task<int> {
//           ++fooCallCount;
//           co_return 42;
//         }));
template <typename F>
auto CoInvokeWithoutArgs(F&& f) {
  return ::testing::InvokeWithoutArgs(
      [f = static_cast<F&&>(f)]() { return co_invoke(f); });
}

// Member function overload
template <class Class, typename MethodPtr>
auto CoInvokeWithoutArgs(Class* obj_ptr, MethodPtr method_ptr) {
  return ::testing::InvokeWithoutArgs(
      [=]() { return co_invoke(method_ptr, obj_ptr); });
}

namespace detail {
template <typename Fn>
auto makeCoAction(Fn&& fn) {
  static_assert(
      std::is_copy_constructible_v<remove_cvref_t<Fn>>,
      "Fn should be copyable to allow calling mocked call multiple times.");

  using Ret = std::invoke_result_t<remove_cvref_t<Fn>&&>;
  return ::testing::InvokeWithoutArgs(
      [fn = std::forward<Fn>(fn)]() mutable -> Ret { return co_invoke(fn); });
}

// Helper class to capture a ByMove return value for mocked coroutine function.
// Adds a test failure if it is moved twice like:
//    .WillRepeatedly(CoReturnByMove...)
template <typename R>
struct OnceForwarder {
  static_assert(std::is_reference_v<R>);
  using V = remove_cvref_t<R>;

  explicit OnceForwarder(R r) noexcept(std::is_nothrow_constructible_v<V>)
      : val_(static_cast<R>(r)) {}

  R operator()() noexcept {
    auto performedPreviously =
        performed_.exchange(true, std::memory_order_relaxed);
    if (performedPreviously) {
      terminate_with<std::runtime_error>(
          "a CoReturnByMove action must be performed only once");
    }
    return static_cast<R>(val_);
  }

 private:
  V val_;
  std::atomic<bool> performed_ = false;
};

// Allow to return a value by providing a convertible value.
// This works similarly to Return(x):
// MOCK_METHOD1(Method, T(U));
// EXPECT_CALL(mock, Method(_)).WillOnce(Return(F()));
// should work as long as F is convertible to T.
template <typename T>
class CoReturnImpl {
 public:
  explicit CoReturnImpl(T&& value) : value_(std::move(value)) {}

  template <typename Result, typename ArgumentTuple>
  Result Perform(const ArgumentTuple& /* unused */) const {
    return [](T value) -> Result { co_return value; }(T(value_));
  }

 private:
  T value_;
};

template <typename T>
class CoReturnByMoveImpl {
 public:
  explicit CoReturnByMoveImpl(std::shared_ptr<OnceForwarder<T&&>> forwarder)
      : forwarder_(std::move(forwarder)) {}

  template <typename Result, typename ArgumentTuple>
  Result Perform(const ArgumentTuple& /* unused */) const {
    return [](std::shared_ptr<OnceForwarder<T&&>> forwarder) -> Result {
      co_return (*forwarder)();
    }(forwarder_);
  }

 private:
  std::shared_ptr<OnceForwarder<T&&>> forwarder_;
};

} // namespace detail

// Helper functions to adapt CoRoutines enabled functions to be mocked using
// gMock. CoReturn and CoThrows are gMock Action types that mirror the Return
// and Throws Action types used in EXPECT_CALL|ON_CALL invocations.
//
// Example:
//   using namespace ::testing
//   using namespace folly::coro::gmock_helpers;
//
//   MockFoo mock;
//   std::string result = "abc";
//
//   EXPECT_CALL(mock, co_foo(_))
//     .WillRepeatedly(CoReturn(result));
//
//   // For Task<void> return types.
//   EXPECT_CALL(mock, co_bar(_))
//     .WillRepeatedly(CoReturn());
//
//   // For returning by move.
//   EXPECT_CALL(mock, co_bar(_))
//     .WillRepeatedly(CoReturnByMove(std::move(result)));
//
//   // For returning by move.
//   EXPECT_CALL(mock, co_bar(_))
//     .WillRepeatedly(CoReturnByMove(std::make_unique(result)));
//
//
//  EXPECT_CALL(mock, co_foo(_))
//     .WillRepeatedly(CoThrow<std::string>(std::runtime_error("error")));
template <typename T>
auto CoReturn(T ret) {
  return ::testing::MakePolymorphicAction(
      detail::CoReturnImpl<T>(std::move(ret)));
}

inline auto CoReturn() {
  return ::testing::InvokeWithoutArgs([]() -> Task<> { co_return; });
}

template <typename T>
auto CoReturnByMove(T&& ret) {
  static_assert(
      !std::is_lvalue_reference_v<decltype(ret)>,
      "the argument must be passed as non-const rvalue-ref");
  static_assert(
      !std::is_const_v<T>,
      "the argument must be passed as non-const rvalue-ref");

  auto ptr = std::make_shared<detail::OnceForwarder<T&&>>(std::move(ret));

  return ::testing::MakePolymorphicAction(
      detail::CoReturnByMoveImpl<T>(std::move(ptr)));
}

template <typename T, typename Ex>
auto CoThrow(Ex&& e) {
  return detail::makeCoAction(
      [ex = std::forward<Ex>(e)]() -> Task<T> { co_yield co_error(ex); });
}

} // namespace gmock_helpers
} // namespace coro
} // namespace folly

#define CO_ASSERT_THAT(value, matcher) \
  CO_ASSERT_PRED_FORMAT1(              \
      ::testing::internal::MakePredicateFormatterFromMatcher(matcher), value)

#endif // FOLLY_HAS_COROUTINES