#include "tensorflow/lite/kernels/internal/reference/batch_matmul.h"
#include <stddef.h>
#include <algorithm>
#include <cstdint>
#include <limits>
#include "tensorflow/lite/core/c/builtin_op_data.h"
#include "tensorflow/lite/core/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/optimized/batch_matmul.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
// Kernel implementation and registration for the BATCH_MATMUL builtin op.
namespace tflite {
namespace ops {
namespace builtin {
namespace batch_matmul {
// NOTE(review): presumably node input/output indices and the number of
// temporary tensors requested for the adjoint and hybrid paths — confirm
// against Prepare/InitializeTemporaries.
static const int kInputLHSTensor = …;
static const int kInputRHSTensor = …;
static const int kOutputTensor = …;
static const int kNumTempTensorsForAdjoints = …;
static const int kNumTempTensorsForHybrid = …;
enum KernelType { … };
struct OpData { … };
struct OpContext { … };
void* Init(TfLiteContext* context, const char* buffer, size_t length) { … }
void Free(TfLiteContext* context, void* buffer) { … }
TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
const RuntimeShape& extended_lhs_shape,
const RuntimeShape& extended_rhs_shape,
bool adj_x, bool adj_y, int output_rank,
TfLiteTensor* output) { … }
TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
OpContext* op_context) { … }
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { … }
// Element-type-generic helper; the non-template TransposeRowsColumns below
// presumably dispatches to it by tensor type — verify.
template <typename scalar>
void TransposeRowsColumnsImpl(const TfLiteTensor* tensor_in,
const scalar* input, TfLiteTensor* tensor_out,
scalar* output) { … }
TfLiteStatus TransposeRowsColumns(TfLiteContext* context,
const TfLiteTensor* tensor_in,
TfLiteTensor* tensor_out) { … }
RuntimeShape SwapRowColumnDims(const RuntimeShape& shape) { … }
// Hybrid path: the caller supplies the quantization scratch tensors
// (input_quantized, scaling_factors, accum_scratch, row_sums, input_offsets).
template <KernelType kernel_type>
TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, OpData* data,
const RuntimeShape& input_shape,
const TfLiteTensor* input,
const RuntimeShape& filter_shape,
const TfLiteTensor* filter,
TfLiteTensor* input_quantized,
TfLiteTensor* scaling_factors,
TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
TfLiteTensor* input_offsets, TfLiteTensor* output) { … }
template <KernelType kernel_type>
TfLiteStatus EvalInt8Int8(TfLiteContext* context, const OpData* data,
const RuntimeShape& lhs_shape,
const TfLiteTensor* lhs,
const RuntimeShape& rhs_shape,
const TfLiteTensor* rhs,
const RuntimeShape& output_shape,
TfLiteTensor* output, bool transpose_lhs) { … }
template <KernelType kernel_type>
TfLiteStatus EvalInt8Int32(TfLiteContext* context, const OpData* data,
const RuntimeShape& lhs_shape,
const TfLiteTensor* lhs,
const RuntimeShape& rhs_shape,
const TfLiteTensor* rhs,
const RuntimeShape& output_shape,
TfLiteTensor* output, bool transpose_lhs) { … }
template <KernelType kernel_type>
TfLiteStatus EvalInt16(TfLiteContext* context, const OpData* data,
const RuntimeShape& lhs_shape, const TfLiteTensor* lhs,
const RuntimeShape& rhs_shape, const TfLiteTensor* rhs,
const RuntimeShape& output_shape, TfLiteTensor* output) { … }
template <KernelType kernel_type>
TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
OpData* data, const RuntimeShape& lhs_shape,
const TfLiteTensor* lhs,
const RuntimeShape& rhs_shape,
const TfLiteTensor* rhs, TfLiteTensor* output,
bool transpose_lhs) { … }
TfLiteTensor* GetTempRhs(TfLiteContext* context, TfLiteNode* node,
const TfLiteTensor* rhs) { … }
TfLiteTensor* GetTempLhs(TfLiteContext* context, TfLiteNode* node,
const TfLiteTensor* lhs) { … }
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { … }
}  // namespace batch_matmul
// Registration entry points for the BATCH_MATMUL op: the REF, the
// GENERIC_OPTIMIZED, and the unsuffixed variant.
TfLiteRegistration* Register_BATCH_MATMUL_REF() { … }
TfLiteRegistration* Register_BATCH_MATMUL_GENERIC_OPTIMIZED() { … }
TfLiteRegistration* Register_BATCH_MATMUL() { … }
}  // namespace builtin
}  // namespace ops
}  // namespace tflite