#ifndef EIGEN_BFLOAT16_H
#define EIGEN_BFLOAT16_H
#include "../../InternalHeaderCheck.h"
#if defined(EIGEN_HAS_HIP_BF16)
#pragma push_macro("EIGEN_CONSTEXPR")
#undef EIGEN_CONSTEXPR
#define EIGEN_CONSTEXPR
#endif
#define BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, METHOD) …
#if defined(EIGEN_HAS_HIP_BF16) && defined(EIGEN_GPU_COMPILE_PHASE)
#define EIGEN_USE_HIP_BF16
#endif
namespace Eigen {
struct bfloat16;
namespace numext {
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bit_cast<Eigen::bfloat16, uint16_t>(const uint16_t& src);
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint16_t bit_cast<uint16_t, Eigen::bfloat16>(const Eigen::bfloat16& src);
}
namespace bfloat16_impl {
#if defined(EIGEN_USE_HIP_BF16)
struct __bfloat16_raw : public hip_bfloat16 {
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw() {}
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw(hip_bfloat16 hb) : hip_bfloat16(hb) {}
explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw(unsigned short raw) : hip_bfloat16(raw) {}
};
#else
struct __bfloat16_raw { … };
#endif
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw raw_uint16_to_bfloat16(unsigned short value);
template <bool AssumeArgumentIsNormalOrInfinityOrZero>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne(float ff);
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<false>(float ff);
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<true>(float ff);
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h);
struct bfloat16_base : public __bfloat16_raw { … };
}
struct bfloat16 : public bfloat16_impl::bfloat16_base { … };
namespace bfloat16_impl {
template <typename = void>
struct numeric_limits_bfloat16_impl { … };
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::is_specialized;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::is_signed;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::is_integer;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::is_exact;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::has_infinity;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::has_quiet_NaN;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::has_signaling_NaN;
#if __cplusplus >= 202302L
EIGEN_DIAGNOSTICS(push)
EIGEN_DISABLE_DEPRECATED_WARNING
#endif
template <typename T>
EIGEN_CONSTEXPR const std::float_denorm_style numeric_limits_bfloat16_impl<T>::has_denorm;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::has_denorm_loss;
#if __cplusplus >= 202302L
EIGEN_DIAGNOSTICS(pop)
#endif
template <typename T>
EIGEN_CONSTEXPR const std::float_round_style numeric_limits_bfloat16_impl<T>::round_style;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::is_iec559;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::is_bounded;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::is_modulo;
template <typename T>
EIGEN_CONSTEXPR const int numeric_limits_bfloat16_impl<T>::digits;
template <typename T>
EIGEN_CONSTEXPR const int numeric_limits_bfloat16_impl<T>::digits10;
template <typename T>
EIGEN_CONSTEXPR const int numeric_limits_bfloat16_impl<T>::max_digits10;
template <typename T>
EIGEN_CONSTEXPR const int numeric_limits_bfloat16_impl<T>::radix;
template <typename T>
EIGEN_CONSTEXPR const int numeric_limits_bfloat16_impl<T>::min_exponent;
template <typename T>
EIGEN_CONSTEXPR const int numeric_limits_bfloat16_impl<T>::min_exponent10;
template <typename T>
EIGEN_CONSTEXPR const int numeric_limits_bfloat16_impl<T>::max_exponent;
template <typename T>
EIGEN_CONSTEXPR const int numeric_limits_bfloat16_impl<T>::max_exponent10;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::traps;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_bfloat16_impl<T>::tinyness_before;
}
}
namespace std {
template <>
class numeric_limits<Eigen::bfloat16> : public Eigen::bfloat16_impl::numeric_limits_bfloat16_impl<> { … };
template <>
class numeric_limits<const Eigen::bfloat16> : public numeric_limits<Eigen::bfloat16> { … };
template <>
class numeric_limits<volatile Eigen::bfloat16> : public numeric_limits<Eigen::bfloat16> { … };
template <>
class numeric_limits<const volatile Eigen::bfloat16> : public numeric_limits<Eigen::bfloat16> { … };
}
namespace Eigen {
namespace bfloat16_impl {
#if !defined(EIGEN_HAS_NATIVE_BF16) || (EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC)
#if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
#pragma push_macro("EIGEN_DEVICE_FUNC")
#undef EIGEN_DEVICE_FUNC
#if (defined(EIGEN_HAS_GPU_BF16) && defined(EIGEN_HAS_NATIVE_BF16))
#define EIGEN_DEVICE_FUNC …
#else
#define EIGEN_DEVICE_FUNC …
#endif
#endif
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator+(const bfloat16& a, const bfloat16& b) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator+(const bfloat16& a, const int& b) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator+(const int& a, const bfloat16& b) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator*(const bfloat16& a, const bfloat16& b) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator-(const bfloat16& a, const bfloat16& b) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator/(const bfloat16& a, const bfloat16& b) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator-(const bfloat16& a) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator+=(bfloat16& a, const bfloat16& b) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator*=(bfloat16& a, const bfloat16& b) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator-=(bfloat16& a, const bfloat16& b) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator/=(bfloat16& a, const bfloat16& b) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a, int) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a, int) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator==(const bfloat16& a, const bfloat16& b) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator!=(const bfloat16& a, const bfloat16& b) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<(const bfloat16& a, const bfloat16& b) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<=(const bfloat16& a, const bfloat16& b) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>(const bfloat16& a, const bfloat16& b) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>=(const bfloat16& a, const bfloat16& b) { … }
#if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
#pragma pop_macro("EIGEN_DEVICE_FUNC")
#endif
#endif
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator/(const bfloat16& a, Index b) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw truncate_to_bfloat16(const float v) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw raw_uint16_to_bfloat16(numext::uint16_t value) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR numext::uint16_t raw_bfloat16_as_uint16(
const __bfloat16_raw& bf) { … }
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<false>(float ff) { … }
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<true>(float ff) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(isinf)(const bfloat16& a) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(isnan)(const bfloat16& a) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(isfinite)(const bfloat16& a) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 abs(const bfloat16& a) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 exp(const bfloat16& a) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 expm1(const bfloat16& a) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log(const bfloat16& a) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log1p(const bfloat16& a) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log10(const bfloat16& a) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log2(const bfloat16& a) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sqrt(const bfloat16& a) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 pow(const bfloat16& a, const bfloat16& b) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atan2(const bfloat16& a, const bfloat16& b) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sin(const bfloat16& a) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cos(const bfloat16& a) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tan(const bfloat16& a) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asin(const bfloat16& a) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acos(const bfloat16& a) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atan(const bfloat16& a) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sinh(const bfloat16& a) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cosh(const bfloat16& a) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tanh(const bfloat16& a) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asinh(const bfloat16& a) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acosh(const bfloat16& a) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atanh(const bfloat16& a) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 floor(const bfloat16& a) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 ceil(const bfloat16& a) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 rint(const bfloat16& a) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 round(const bfloat16& a) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 trunc(const bfloat16& a) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmod(const bfloat16& a, const bfloat16& b) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16(min)(const bfloat16& a, const bfloat16& b) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16(max)(const bfloat16& a, const bfloat16& b) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmin(const bfloat16& a, const bfloat16& b) { … }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmax(const bfloat16& a, const bfloat16& b) { … }
#ifndef EIGEN_NO_IO
EIGEN_ALWAYS_INLINE std::ostream& operator<<(std::ostream& os, const bfloat16& v) { … }
#endif
}
namespace internal {
template <>
struct is_arithmetic<bfloat16> { … };
template <>
struct random_impl<bfloat16> { … };
}
template <>
struct NumTraits<Eigen::bfloat16> : GenericNumTraits<Eigen::bfloat16> { … };
}
#if defined(EIGEN_HAS_HIP_BF16)
#pragma pop_macro("EIGEN_CONSTEXPR")
#endif
namespace Eigen {
namespace numext {
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isnan)(const Eigen::bfloat16& h) { … }
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isinf)(const Eigen::bfloat16& h) { … }
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isfinite)(const Eigen::bfloat16& h) { … }
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bit_cast<Eigen::bfloat16, uint16_t>(const uint16_t& src) { … }
template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint16_t bit_cast<uint16_t, Eigen::bfloat16>(const Eigen::bfloat16& src) { … }
}
}
#if EIGEN_HAS_STD_HASH
namespace std {
template <>
struct hash<Eigen::bfloat16> { … };
}
#endif
#if defined(EIGEN_HIPCC)
#if defined(EIGEN_HAS_HIP_BF16)
__device__ EIGEN_STRONG_INLINE Eigen::bfloat16 __shfl(Eigen::bfloat16 var, int srcLane, int width = warpSize) {
const int ivar = static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
return Eigen::numext::bit_cast<Eigen::bfloat16>(static_cast<Eigen::numext::uint16_t>(__shfl(ivar, srcLane, width)));
}
__device__ EIGEN_STRONG_INLINE Eigen::bfloat16 __shfl_up(Eigen::bfloat16 var, unsigned int delta,
int width = warpSize) {
const int ivar = static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
return Eigen::numext::bit_cast<Eigen::bfloat16>(static_cast<Eigen::numext::uint16_t>(__shfl_up(ivar, delta, width)));
}
__device__ EIGEN_STRONG_INLINE Eigen::bfloat16 __shfl_down(Eigen::bfloat16 var, unsigned int delta,
int width = warpSize) {
const int ivar = static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
return Eigen::numext::bit_cast<Eigen::bfloat16>(
static_cast<Eigen::numext::uint16_t>(__shfl_down(ivar, delta, width)));
}
__device__ EIGEN_STRONG_INLINE Eigen::bfloat16 __shfl_xor(Eigen::bfloat16 var, int laneMask, int width = warpSize) {
const int ivar = static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
return Eigen::numext::bit_cast<Eigen::bfloat16>(
static_cast<Eigen::numext::uint16_t>(__shfl_xor(ivar, laneMask, width)));
}
#endif
#endif
#if defined(EIGEN_HIPCC)
EIGEN_STRONG_INLINE __device__ Eigen::bfloat16 __ldg(const Eigen::bfloat16* ptr) {
return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(
__ldg(Eigen::numext::bit_cast<const Eigen::numext::uint16_t*>(ptr)));
}
#endif
#endif