#ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_
#define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_
#include "ruy/matrix.h"
#include "ruy/mul_params.h"
#include "ruy/ruy.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
namespace tflite {
namespace cpu_backend_gemm {
// Internal helpers that adapt cpu_backend_gemm parameter types to ruy types.
namespace detail {
// Converts a cpu_backend_gemm `CachePolicy` value into a `ruy::CachePolicy`.
// NOTE(review): the body is not visible in this view; confirm the exact
// enumerator mapping against the definition.
inline ruy::CachePolicy ToRuyCachePolicy(CachePolicy cache_policy) { … }
// Fills in `*dst` (a ruy::Matrix<Scalar>) from the matrix description in
// `params` and the buffer pointed to by `data_ptr`.
// `DataPointer` is a separate template parameter from `Scalar`, presumably so
// that both const and non-const element pointers are accepted (TODO confirm).
// `use_caching` defaults to false; presumably it controls the cache policy
// set on `*dst`, e.g. via ToRuyCachePolicy (TODO confirm against the body).
template <typename Scalar, typename DataPointer>
void MakeRuyMatrix(const MatrixParams<Scalar>& params, DataPointer data_ptr,
ruy::Matrix<Scalar>* dst, bool use_caching = false) { … }
// Implementation helper for MakeRuyMulParams. It is parameterized on the
// accumulator scalar, destination scalar and quantization flavor so that the
// behavior can be specialized per type combination. The primary template's
// body is not visible here.
template <typename AccumScalar, typename DstScalar,
QuantizationFlavor quantization_flavor>
struct MakeRuyMulParamsImpl final { … };
// NOTE(review): the next two lines look like fragments of partial
// specializations of MakeRuyMulParamsImpl. The first is for an int32
// accumulator with a generic DstScalar. The second is for an int32
// accumulator with an int32 destination. Their `template <...>` headers and
// `struct ... final { ... }` bodies are not visible in this view, so the
// lines are not valid C++ on their own. Verify them against the full file.
MakeRuyMulParamsImpl<std::int32_t, DstScalar, quantization_flavor>;
MakeRuyMulParamsImpl<std::int32_t, std::int32_t, quantization_flavor>;
// Translates the cpu_backend_gemm `params` into `*ruy_mul_params`.
// Presumably this forwards to MakeRuyMulParamsImpl with the same template
// arguments (body not visible; TODO confirm).
template <typename AccumScalar, typename DstScalar,
QuantizationFlavor quantization_flavor>
void MakeRuyMulParams(
const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
ruy::MulParams<AccumScalar, DstScalar>* ruy_mul_params) { … }
// GEMM implementation template parameterized on the LHS/RHS operand scalar
// types, the accumulator and destination scalar types, and the quantization
// flavor. Judging by its name and the helpers above, it presumably performs
// the multiplication through ruy. The body and entry-point signature are not
// visible here (TODO confirm).
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
typename DstScalar, QuantizationFlavor quantization_flavor>
struct GemmImplUsingRuy { … };
}  // namespace detail
}  // namespace cpu_backend_gemm
}  // namespace tflite
#endif