llvm/flang/include/flang/Common/reference-wrapper.h

//===-- include/flang/Common/reference-wrapper.h ----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// clang-format off
//
// Implementation of std::reference_wrapper borrowed from libcu++
// https://github.com/NVIDIA/libcudacxx/blob/f7e6cd07ed5ba826aeac0b742feafddfedc1e400/include/cuda/std/detail/libcxx/include/__functional/reference_wrapper.h#L1
// with modifications.
//
// The original source code is distributed under the Apache License v2.0
// with LLVM Exceptions.
//
// TODO: using libcu++ is the best option for CUDA, but there are a couple
// of issues:
//   * The include paths need to be set up such that all STD header files
//     are taken from libcu++.
//   * The cuda:: namespace needs to be forced for all std:: references.
//
// clang-format on

#ifndef FORTRAN_COMMON_REFERENCE_WRAPPER_H
#define FORTRAN_COMMON_REFERENCE_WRAPPER_H

#include "flang/Common/api-attrs.h"
#include <functional>
#include <type_traits>

#if !defined(STD_REFERENCE_WRAPPER_UNSUPPORTED) && \
    (defined(__CUDACC__) || defined(__CUDA__)) && defined(__CUDA_ARCH__)
#define STD_REFERENCE_WRAPPER_UNSUPPORTED 1
#endif

namespace Fortran::common {

/// C++17 stand-in for std::remove_cvref_t: strips a reference from _Tp and
/// then any top-level const/volatile qualifiers.
template <class _Tp>
using __remove_cvref_t =
    typename std::remove_cv<typename std::remove_reference<_Tp>::type>::type;

/// Yields true when _Tp and _Up denote the same type after references and
/// top-level cv-qualifiers have been removed from both.
template <class _Tp, class _Up>
struct __is_same_uncvref
    : std::is_same<__remove_cvref_t<_Tp>, __remove_cvref_t<_Up>> {};

#if STD_REFERENCE_WRAPPER_UNSUPPORTED
/// Replacement for std::reference_wrapper<_Tp> used when
/// STD_REFERENCE_WRAPPER_UNSUPPORTED is set. It stores the address of a _Tp
/// lvalue, converts back to _Tp &, and forwards calls through std::invoke.
template <class _Tp> class reference_wrapper {
public:
  // types
  typedef _Tp type;

private:
  type *__f_;

  // Overload pair used only in unevaluated contexts (decltype/noexcept) to
  // check that an argument converts to an lvalue _Tp &. Arguments that would
  // bind to _Tp && select the deleted overload, which removes the converting
  // constructor from overload resolution so temporaries cannot be wrapped.
  static RT_API_ATTRS void __fun(_Tp &) noexcept;
  static void __fun(_Tp &&) = delete;

  // Performs the same implicit conversion to _Tp & that __fun checks, so the
  // result can be used directly in a mem-initializer.
  static constexpr RT_API_ATTRS type &__as_ref(type &__r) noexcept {
    return __r;
  }

public:
  // Accepts any _Up other than reference_wrapper itself that implicitly
  // converts to an lvalue _Tp &. __f_ is initialized in the mem-initializer
  // list because a C++17 constexpr constructor must initialize every
  // non-static data member; assigning it in the body is only valid from C++20
  // on. As with std::reference_wrapper, the constructor is noexcept unless
  // the conversion of _Up to _Tp & may throw.
  template <class _Up,
      class =
          std::enable_if_t<!__is_same_uncvref<_Up, reference_wrapper>::value,
              decltype(__fun(std::declval<_Up>()))>>
  constexpr RT_API_ATTRS reference_wrapper(_Up &&__u) noexcept(
      noexcept(__fun(std::declval<_Up>())))
      : __f_(std::addressof(__as_ref(static_cast<_Up &&>(__u)))) {}

  // access
  constexpr RT_API_ATTRS operator type &() const noexcept { return *__f_; }
  constexpr RT_API_ATTRS type &get() const noexcept { return *__f_; }

  // invoke: calls the referenced object with the forwarded arguments.
  template <class... _ArgTypes>
  constexpr RT_API_ATTRS std::invoke_result_t<type &, _ArgTypes...>
  operator()(_ArgTypes &&...__args) const {
    return std::invoke(get(), std::forward<_ArgTypes>(__args)...);
  }
};

// Class template argument deduction: constructing reference_wrapper from an
// lvalue of type T deduces reference_wrapper<T>.
template <typename _Tp> reference_wrapper(_Tp &) -> reference_wrapper<_Tp>;

/// Wraps the lvalue __t in a reference_wrapper<_Tp>.
/// Declared noexcept like std::ref, so noexcept(ref(x)) gives the same answer
/// regardless of which implementation this header selects.
template <class _Tp>
inline constexpr RT_API_ATTRS reference_wrapper<_Tp> ref(_Tp &__t) noexcept {
  return reference_wrapper<_Tp>(__t);
}

/// Returns a copy of an existing wrapper, so ref() applied to a
/// reference_wrapper does not add another level of wrapping.
/// Declared noexcept like the corresponding std::ref overload.
template <class _Tp>
inline constexpr RT_API_ATTRS reference_wrapper<_Tp> ref(
    reference_wrapper<_Tp> __t) noexcept {
  return __t;
}

/// Wraps __t in a reference_wrapper<const _Tp>.
/// Declared noexcept like std::cref, so noexcept(cref(x)) gives the same
/// answer regardless of which implementation this header selects.
template <class _Tp>
inline constexpr RT_API_ATTRS reference_wrapper<const _Tp> cref(
    const _Tp &__t) noexcept {
  return reference_wrapper<const _Tp>(__t);
}

/// Converts an existing reference_wrapper<_Tp> to one that refers to the same
/// object as const _Tp. The conversion goes through the wrapper's implicit
/// conversion to _Tp &. Declared noexcept like the corresponding std::cref
/// overload.
template <class _Tp>
inline constexpr RT_API_ATTRS reference_wrapper<const _Tp> cref(
    reference_wrapper<_Tp> __t) noexcept {
  return __t;
}

// Rvalue arguments select these deleted overloads, so ref()/cref() refuse to
// wrap temporaries that would leave the wrapper dangling.
template <typename _Tp> void ref(const _Tp &&) = delete;
template <typename _Tp> void cref(const _Tp &&) = delete;
#else // !STD_REFERENCE_WRAPPER_UNSUPPORTED
// STD_REFERENCE_WRAPPER_UNSUPPORTED is not set: re-export the standard
// library's wrapper and its factory functions unchanged.
using std::reference_wrapper;
using std::ref;
using std::cref;
#endif // !STD_REFERENCE_WRAPPER_UNSUPPORTED

} // namespace Fortran::common

#endif // FORTRAN_COMMON_REFERENCE_WRAPPER_H