chromium/third_party/tflite/src/tensorflow/lite/kernels/batch_matmul.cc

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/kernels/internal/reference/batch_matmul.h"

#include <stddef.h>

#include <algorithm>
#include <cstdint>
#include <limits>

#include "tensorflow/lite/core/c/builtin_op_data.h"
#include "tensorflow/lite/core/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/optimized/batch_matmul.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace batch_matmul {

// Indices of the op's input and output tensors.
static const int kInputLHSTensor = 0;
static const int kInputRHSTensor = 1;
static const int kOutputTensor = 0;

// Temporary tensors: two for the (optionally) transposed operands, and five
// for the hybrid path (quantized input, scaling factors, accumulator scratch,
// row sums, and input offsets).
static const int kNumTempTensorsForAdjoints = 2;
static const int kNumTempTensorsForHybrid = 5;

// This file has two implementations of BatchMatMul: reference and optimized.
enum KernelType {
  kReference,
  kGenericOptimized,
};

// Per-op state created in Init and populated in Prepare.
struct OpData {};

// Convenience wrapper around the node's params and input/output tensors.
struct OpContext {};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {}

void Free(TfLiteContext* context, void* buffer) {}

TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
                                const RuntimeShape& extended_lhs_shape,
                                const RuntimeShape& extended_rhs_shape,
                                bool adj_x, bool adj_y, int output_rank,
                                TfLiteTensor* output) {}
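
// A minimal sketch (not the actual implementation) of the broadcasting rule
// ResizeOutputTensor applies to the leading "batch" dimensions: once both
// shapes are extended to `output_rank`, each leading output dimension is the
// operands' shared value, or the non-1 value when exactly one operand has a
// 1 there. The helper name and raw-int-array interface are hypothetical and
// for illustration only.
namespace {
inline bool BroadcastBatchDimsSketch(const int* lhs_dims, const int* rhs_dims,
                                     int output_rank, int* output_dims) {
  // Only the output_rank - 2 leading dimensions broadcast; the trailing two
  // are the matrix dimensions produced by the matmul itself.
  for (int i = 0; i < output_rank - 2; ++i) {
    const int lhs_dim = lhs_dims[i];
    const int rhs_dim = rhs_dims[i];
    if (lhs_dim != rhs_dim && lhs_dim != 1 && rhs_dim != 1) {
      return false;  // Incompatible batch dimensions.
    }
    output_dims[i] = (lhs_dim == 1) ? rhs_dim : lhs_dim;
  }
  return true;
}
}  // namespace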

// Initializes temp tensors to store transposed operands.
TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
                                   OpContext* op_context) {}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {}

template <typename scalar>
void TransposeRowsColumnsImpl(const TfLiteTensor* tensor_in,
                              const scalar* input, TfLiteTensor* tensor_out,
                              scalar* output) {}

TfLiteStatus TransposeRowsColumns(TfLiteContext* context,
                                  const TfLiteTensor* tensor_in,
                                  TfLiteTensor* tensor_out) {}

RuntimeShape SwapRowColumnDims(const RuntimeShape& shape) {}

template <KernelType kernel_type>
TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, OpData* data,
                        const RuntimeShape& input_shape,
                        const TfLiteTensor* input,
                        const RuntimeShape& filter_shape,
                        const TfLiteTensor* filter,
                        TfLiteTensor* input_quantized,
                        TfLiteTensor* scaling_factors,
                        TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
                        TfLiteTensor* input_offsets, TfLiteTensor* output) {}

template <KernelType kernel_type>
TfLiteStatus EvalInt8Int8(TfLiteContext* context, const OpData* data,
                          const RuntimeShape& lhs_shape,
                          const TfLiteTensor* lhs,
                          const RuntimeShape& rhs_shape,
                          const TfLiteTensor* rhs,
                          const RuntimeShape& output_shape,
                          TfLiteTensor* output, bool transpose_lhs) {}

template <KernelType kernel_type>
TfLiteStatus EvalInt8Int32(TfLiteContext* context, const OpData* data,
                           const RuntimeShape& lhs_shape,
                           const TfLiteTensor* lhs,
                           const RuntimeShape& rhs_shape,
                           const TfLiteTensor* rhs,
                           const RuntimeShape& output_shape,
                           TfLiteTensor* output, bool transpose_lhs) {}

template <KernelType kernel_type>
TfLiteStatus EvalInt16(TfLiteContext* context, const OpData* data,
                       const RuntimeShape& lhs_shape, const TfLiteTensor* lhs,
                       const RuntimeShape& rhs_shape, const TfLiteTensor* rhs,
                       const RuntimeShape& output_shape, TfLiteTensor* output) {}

template <KernelType kernel_type>
TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
                           OpData* data, const RuntimeShape& lhs_shape,
                           const TfLiteTensor* lhs,
                           const RuntimeShape& rhs_shape,
                           const TfLiteTensor* rhs, TfLiteTensor* output,
                           bool transpose_lhs) {}

TfLiteTensor* GetTempRhs(TfLiteContext* context, TfLiteNode* node,
                         const TfLiteTensor* rhs) {}

TfLiteTensor* GetTempLhs(TfLiteContext* context, TfLiteNode* node,
                         const TfLiteTensor* lhs) {}

// Perform a batch matrix multiply on
// LHS <..., A, B>  X  RHS <..., B, C>
// where the leading dimensions of LHS and RHS obey broadcasting rules
// (this op applies the broadcasting itself).
// We assume that LHS and RHS are both row oriented (adjacent values in memory
// are in the same row) and will output in the same memory layout. However,
// our fast GEMM libraries assume RCC layout (LHS row oriented,
// RHS column oriented, output column oriented). Therefore, we perform
// RHS <..., C, B>  X  LHS <..., B, A>
// where the output is a C X A column-oriented matrix, which is equivalent to
// an A X C row-oriented matrix.
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {}
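
// A minimal sketch (not used by the op) of the layout trick described above:
// it relies on (LHS x RHS)^T == RHS^T x LHS^T, and on a column-major C x A
// buffer having exactly the same memory layout as a row-major A x C buffer.
// The helper names below are hypothetical and exist only to illustrate the
// equivalence for a single (non-batched) float product.
namespace {

// Row-major LHS<A, B> x RHS<B, C> -> row-major out<A, C>.
inline void RowMajorMatMulSketch(const float* lhs, const float* rhs, int A,
                                 int B, int C, float* out) {
  for (int a = 0; a < A; ++a) {
    for (int c = 0; c < C; ++c) {
      float acc = 0.f;
      for (int b = 0; b < B; ++b) acc += lhs[a * B + b] * rhs[b * C + c];
      out[a * C + c] = acc;
    }
  }
}

// Column-major RHS^T<C, B> x LHS^T<B, A> -> column-major out<C, A>.
// A row-major M<R, K> buffer is bit-identical to a column-major M^T<K, R>
// buffer, so this reads the very same lhs/rhs pointers and writes the same
// values to the same offsets as RowMajorMatMulSketch above.
inline void ColumnMajorSwappedMatMulSketch(const float* lhs, const float* rhs,
                                           int A, int B, int C, float* out) {
  for (int c = 0; c < C; ++c) {
    for (int a = 0; a < A; ++a) {
      float acc = 0.f;
      // rhs[b * C + c] is element (c, b) of the column-major RHS^T<C, B>;
      // lhs[a * B + b] is element (b, a) of the column-major LHS^T<B, A>.
      for (int b = 0; b < B; ++b) acc += rhs[b * C + c] * lhs[a * B + b];
      // Column-major offset (c, a) in a C x A matrix == row-major offset
      // (a, c) in an A x C matrix.
      out[a * C + c] = acc;
    }
  }
}

}  // namespace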

}  // namespace batch_matmul

TfLiteRegistration* Register_BATCH_MATMUL_REF() {}

TfLiteRegistration* Register_BATCH_MATMUL_GENERIC_OPTIMIZED() {}

TfLiteRegistration* Register_BATCH_MATMUL() {}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite