chromium/third_party/tflite/src/tensorflow/lite/kernels/stablehlo_scatter.cc

/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <utility>
#include <vector>

#include "Eigen/Core"  // from @eigen_archive
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/core/c/builtin_op_data.h"
#include "tensorflow/lite/core/c/common.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/kernels/internal/runtime_shape.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/tensor_slice_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace stablehlo_scatter {
namespace {

constexpr int kInputsTensor =;
constexpr int kScatterIndicesTensor =;
constexpr int kUpdatesTensor =;
constexpr int kOutputTensor =;

// Indicates the type of the computation performed in the op region of the
// scatter kernel.
enum class ComputationType {};

// Contains the data that the operation sets in the Prepare phase and uses in
// the Eval phase.
struct OpData {};

// Contains a vector with each element being a dimension index
// example: [1, 4] means the second and fifth dimensions of another vector.
DimVector;

// Returns the update scatter dimension given the update window dimensions.
// Example:
// When updates_rank=5, update_window_dims=[2, 4]
// it returns [0, 1, 3]
static DimVector GetUpdateScatterDims(int64_t updates_rank,
                                      const int64_t* update_window_dims,
                                      int num_update_window_dims) {}

// Creates a new Index from a given one that contains only the asked dimensions.
// Example: If update_index is [i,j,k,l,m] and update_scatter_dims
// is [1, 3, 4], the result is [j, l, m]
template <typename IndexType>
static Index<IndexType> GatherIndex(const Index<IndexType>& index,
                                    const DimVector& dims) {}

// Checks if the given index is within the bounds of the provided shape.
template <typename IndexType>
static bool IsInBounds(Index<IndexType> index, RuntimeShape shape) {}

static ComputationType OpCodeToComputationType(int op_code) {}

// Inspects the scatter op region subgraph and extracts the right
// ComputationType from the nodes of the Subgraph.
static TfLiteStatus GetComputationType(const Subgraph* computation_subgraph,
                                       ComputationType* computation_type,
                                       TfLiteContext* context) {}

// Applies the provided computation to `input_value` and `update_value` and
// stores the result in `tensor[index]`.
template <typename DataType, typename IndexType>
static TfLiteStatus ApplyComputation(TfLiteTensor* tensor,
                                     Index<IndexType> index,
                                     DataType input_value,
                                     DataType update_value,
                                     ComputationType computation_type,
                                     TfLiteContext* context) {}

// Evaluates this node given the type of the elements in the scatter_indices
// and the type of the elements in the input/updates tensors.
template <typename IndexType, typename DataType>
TfLiteStatus EvalWithTypes(TfLiteContext* context, TfLiteNode* node) {}

// Evaluates this node given the type of the elements in the scatter_indices
// tensor.
template <typename IndexType>
TfLiteStatus EvalWithIndexType(TfLiteContext* context, TfLiteNode* node,
                               TfLiteType index_type, TfLiteType data_type) {}

void* Init(TfLiteContext* context, const char* buffer, size_t length) {}

void Free(TfLiteContext* context, void* buffer) {}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {}

}  // namespace

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {}

}  // namespace stablehlo_scatter

TfLiteRegistration* Register_STABLEHLO_SCATTER() {}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite