#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_BATCH_MATMUL_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_BATCH_MATMUL_H_
#include "tensorflow/lite/core/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
namespace optimized_ops {
// Float batch matrix multiplication: output = lhs * rhs, computed per batch.
// Batch dimensions are taken from the leading dims of the shapes; the trailing
// two dims of each operand form the matrices to multiply.
// NOTE(review): body elided in this view — behavior described from signature
// only; presumably dispatches to cpu_backend_gemm::Gemm. Confirm against the
// implementation.
//
//   lhs_shape/lhs_data:   left operand shape and buffer.
//   rhs_shape/rhs_data:   right operand shape and buffer.
//   output_shape/output_data: destination shape and buffer.
//   context:              CPU backend context providing the GEMM/thread pool.
//   transpose_lhs:        if true, treat lhs as transposed (defaults to false
//                         for backward compatibility with older callers).
inline void BatchMatMul(const RuntimeShape& lhs_shape, const float* lhs_data,
const RuntimeShape& rhs_shape, const float* rhs_data,
const RuntimeShape& output_shape, float* output_data,
CpuBackendContext* context,
bool transpose_lhs = false) { … }
// Hybrid batch matrix multiplication: int8 quantized operands producing a
// float output, dequantized via per-batch scaling factors.
// NOTE(review): body elided in this view — description inferred from the
// signature; verify details (e.g. row-sum usage for asymmetric input offsets)
// against the implementation.
//
//   lhs_shape/lhs_data:   int8 left operand.
//   rhs_shape/rhs_data:   int8 right operand.
//   scaling_factors:      per-batch dequantization scales applied to the
//                         accumulated int32 results.
//   input_offset:         zero-point offset(s) for the quantized input.
//   row_sums:             scratch for precomputed row sums used to correct for
//                         input_offset; recomputed when *compute_row_sums is
//                         true (caller-managed cache).
//   output_shape/output_data: float destination.
//   accum_scratch:        int32 accumulator scratch buffer.
//   compute_row_sums:     in/out flag — presumably cleared once row_sums is
//                         populated so later calls can reuse it; confirm.
//   context:              CPU backend context (threading/GEMM resources).
inline void BatchMatMul(const RuntimeShape& lhs_shape, const int8_t* lhs_data,
const RuntimeShape& rhs_shape, const int8_t* rhs_data,
const float* scaling_factors,
const int32_t* input_offset, int32_t* row_sums,
const RuntimeShape& output_shape,
int32_t* accum_scratch, float* output_data,
bool* compute_row_sums, CpuBackendContext* context) { … }
// Fully quantized batch matrix multiplication: int8 operands, int8 output.
// Quantization parameters (zero points, output multiplier/shift, activation
// clamp bounds) are carried in `params` (FullyConnectedParams).
// NOTE(review): body elided in this view — behavior described from the
// signature only; confirm requantization details in the implementation.
//
//   params:               quantization/activation parameters for the op.
//   lhs_shape/lhs_data:   int8 left operand.
//   rhs_shape/rhs_data:   int8 right operand.
//   output_shape/output_data: int8 destination.
//   context:              CPU backend context (threading/GEMM resources).
//   transpose_lhs:        if true, treat lhs as transposed (default false).
inline void BatchMatMul(const FullyConnectedParams& params,
const RuntimeShape& lhs_shape, const int8_t* lhs_data,
const RuntimeShape& rhs_shape, const int8_t* rhs_data,
const RuntimeShape& output_shape, int8_t* output_data,
CpuBackendContext* context,
bool transpose_lhs = false) { … }
// Quantized batch matrix multiplication: int8 operands with raw int32 output
// (accumulator-width results, i.e. without the final requantization to int8).
// NOTE(review): body elided in this view — the int32-output contract is
// inferred from the signature; verify whether `params` multiplier/shift are
// applied before writing, against the implementation.
//
//   params:               quantization/activation parameters for the op.
//   lhs_shape/lhs_data:   int8 left operand.
//   rhs_shape/rhs_data:   int8 right operand.
//   output_shape/output_data: int32 destination.
//   context:              CPU backend context (threading/GEMM resources).
//   transpose_lhs:        if true, treat lhs as transposed (default false).
inline void BatchMatMul(const FullyConnectedParams& params,
const RuntimeShape& lhs_shape, const int8_t* lhs_data,
const RuntimeShape& rhs_shape, const int8_t* rhs_data,
const RuntimeShape& output_shape, int32_t* output_data,
CpuBackendContext* context,
bool transpose_lhs = false) { … }
}
}
#endif