#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BATCH_MATMUL_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BATCH_MATMUL_H_
#include <algorithm>
#include <cstdint>
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h"
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
namespace reference_ops {
namespace batch_matmul {
// Helper that takes one dimension size from the LHS and one from the RHS and
// returns a single int.
// NOTE(review): the body is not visible in this view; judging by the name it
// presumably yields the broadcast result of the two dimensions -- confirm
// against the implementation, including how mismatched sizes are handled.
inline int broadcast_dim(int lhs_dim, int rhs_dim) { … }
// Helper that takes a shape and an integer `x` and returns an int.
// NOTE(review): the body is not visible in this view; presumably `x` selects a
// dimension of `shape` and the result is a stride/extent used for indexing --
// confirm the exact meaning (and the result for size-1 dimensions) against the
// implementation.
inline int extent(const RuntimeShape& shape, int x) { … }
}
// Generic BatchMatMul overload.
//
// Template parameters:
//   Ta   - element type of the LHS buffer.
//   Tb   - element type of the RHS buffer.
//   Tout - element type of the output buffer.
// All three appear in the pointer parameters, so they can be deduced from the
// call arguments.
//
// Parameters:
//   lhs_shape / lhs_data       - shape and read-only data of the LHS operand.
//   rhs_shape / rhs_data       - shape and read-only data of the RHS operand.
//   output_shape / output_data - shape of the result and the writable buffer
//                                that receives it.
//
// NOTE(review): the body is not visible in this view. The expected rank,
// storage layout (row- vs column-major), and batch-broadcasting rules are not
// established here -- confirm against the implementation before relying on
// them.
template <typename Ta, typename Tb, typename Tout>
inline void BatchMatMul(const RuntimeShape& lhs_shape, const Ta* lhs_data,
const RuntimeShape& rhs_shape, const Tb* rhs_data,
const RuntimeShape& output_shape, Tout* output_data) { … }
// Non-template BatchMatMul overload taking int8_t inputs and producing float
// output.
//
// Parameters:
//   lhs_shape / lhs_data       - shape and read-only int8 data of the LHS.
//   rhs_shape / rhs_data       - shape and read-only int8 data of the RHS.
//   scaling_factors            - read-only float array.
//   input_offset               - read-only int32 array.
//   row_sums                   - non-const int32 buffer (may be written).
//   output_shape / output_data - shape of the result and the writable float
//                                buffer that receives it.
//   compute_row_sums           - non-const bool pointer (may be read and/or
//                                written).
//
// NOTE(review): the body is not visible in this view. Presumably this is the
// hybrid-quantized path, where scaling_factors/input_offset describe the
// quantization of one operand and row_sums is a cache that is refreshed when
// *compute_row_sums is set -- confirm the array lengths, which operand each
// array applies to, whether null pointers are accepted, and whether the output
// is overwritten or accumulated into.
inline void BatchMatMul(const RuntimeShape& lhs_shape, const int8_t* lhs_data,
const RuntimeShape& rhs_shape, const int8_t* rhs_data,
const float* scaling_factors,
const int32_t* input_offset, int32_t* row_sums,
const RuntimeShape& output_shape, float* output_data,
bool* compute_row_sums) { … }
// BatchMatMul overload configured by a FullyConnectedParams struct, with both
// inputs and the output sharing the element type T.
//
// Template parameters:
//   T      - element type of lhs_data, rhs_data and output_data.
//   AccumT - does not appear in any function parameter, so it cannot be
//            deduced; callers must spell out the template arguments
//            explicitly (e.g. BatchMatMul<T, AccumT>(...)).
//
// Parameters:
//   params                     - FullyConnectedParams passed by const
//                                reference.
//   lhs_shape / lhs_data       - shape and read-only data of the LHS operand.
//   rhs_shape / rhs_data       - shape and read-only data of the RHS operand.
//   output_shape / output_data - shape of the result and the writable buffer
//                                that receives it.
//
// NOTE(review): the body is not visible in this view. Presumably this is the
// fully-quantized path, with AccumT as the accumulator type and params
// supplying offsets, multiplier/shift and activation clamps -- confirm which
// fields of params are actually read and how the result is requantized.
template <typename T, typename AccumT>
inline void BatchMatMul(const FullyConnectedParams& params,
const RuntimeShape& lhs_shape, const T* lhs_data,
const RuntimeShape& rhs_shape, const T* rhs_data,
const RuntimeShape& output_shape, T* output_data) { … }
}
}
#endif