#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <utility>
#include <vector>
#include "Eigen/Core"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/core/c/builtin_op_data.h"
#include "tensorflow/lite/core/c/common.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/kernels/internal/runtime_shape.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/tensor_slice_util.h"
namespace tflite {
namespace ops {
namespace builtin {
namespace stablehlo_scatter {
namespace {
// Tensor slot indices used by this kernel. The concrete values are not
// visible in this view.
// NOTE(review): the names suggest three input slots (operand, scatter
// indices, updates) and one output slot -- confirm against Prepare/Eval.
constexpr int kInputsTensor = …;
constexpr int kScatterIndicesTensor = …;
constexpr int kUpdatesTensor = …;
constexpr int kOutputTensor = …;
// Tag describing which update computation is in effect. Values of this type
// are produced by OpCodeToComputationType, written through the out-parameter
// of GetComputationType, and consumed by ApplyComputation.
// NOTE(review): the enumerators are elided in this view -- list the supported
// computations here once confirmed.
enum class ComputationType { … };
// Kernel-private state for one node. Members are elided in this view.
// NOTE(review): presumably allocated by Init, released by Free, and holding
// the resolved ComputationType -- confirm against those functions.
struct OpData { … };
// DimVector is used below as the return type of GetUpdateScatterDims and as
// the `dims` parameter type of GatherIndex.
// NOTE(review): this declaration appears truncated in this view (presumably a
// type alias for a sequence of dimension indices) -- confirm against the full
// file.
DimVector;
// Builds a DimVector from the rank of the updates tensor and the list of
// update-window dimensions.
//
// Args:
//   updates_rank: rank of the updates tensor.
//   update_window_dims: pointer to `num_update_window_dims` dimension indices.
//   num_update_window_dims: number of entries readable through
//     `update_window_dims`.
//
// NOTE(review): presumably returns the dimensions in [0, updates_rank) that
// are NOT listed in `update_window_dims` (the "scatter" dimensions of the
// updates tensor) -- body is elided here, confirm.
static DimVector GetUpdateScatterDims(int64_t updates_rank,
const int64_t* update_window_dims,
int num_update_window_dims) { … }
// Derives a new Index<IndexType> from `index` and a list of dimensions.
//
// Args:
//   index: source multi-dimensional index (not modified; taken by const
//     reference).
//   dims: dimension positions to use.
//
// NOTE(review): presumably the result contains `index[d]` for each `d` in
// `dims`, in order -- body is elided here, confirm.
template <typename IndexType>
static Index<IndexType> GatherIndex(const Index<IndexType>& index,
const DimVector& dims) { … }
// Predicate relating a multi-dimensional index to a tensor shape.
//
// NOTE(review): presumably returns true only when every component of `index`
// lies inside the matching dimension of `shape` -- body is elided here,
// confirm (including how negative components are treated).
// NOTE(review): both arguments are taken by value, so each call copies the
// index and the RuntimeShape.
template <typename IndexType>
static bool IsInBounds(Index<IndexType> index, RuntimeShape shape) { … }
// Maps an integer operator code to a ComputationType.
// NOTE(review): `op_code` is presumably a builtin operator code from
// builtin_ops.h; the set of recognised codes and the value returned for
// unrecognised ones are not visible here -- confirm.
static ComputationType OpCodeToComputationType(int op_code) { … }
// Determines the ComputationType associated with a subgraph.
//
// Args:
//   computation_subgraph: subgraph to inspect (read-only).
//   computation_type: out-parameter for the result.
//   context: TFLite context.
//
// Returns a TfLiteStatus.
// NOTE(review): presumably inspects the operator(s) in the update-computation
// subgraph, writes the match to `*computation_type`, and reports unsupported
// subgraphs as an error through `context` -- body is elided here, confirm.
static TfLiteStatus GetComputationType(const Subgraph* computation_subgraph,
ComputationType* computation_type,
TfLiteContext* context) { … }
// Applies one update to `tensor` at position `index`.
//
// Template parameters are ordered <DataType, IndexType>; note this is the
// reverse of EvalWithTypes, which is <IndexType, DataType>.
//
// Args:
//   tensor: tensor to operate on.
//   index: multi-dimensional position within `tensor`.
//   input_value: value of type DataType supplied as the existing operand.
//   update_value: value of type DataType supplied as the update.
//   computation_type: selects the computation to apply.
//   context: TFLite context.
//
// Returns a TfLiteStatus.
// NOTE(review): presumably combines `input_value` with `update_value` as
// selected by `computation_type` and stores the result into `tensor` at
// `index` -- body is elided here, confirm.
template <typename DataType, typename IndexType>
static TfLiteStatus ApplyComputation(TfLiteTensor* tensor,
Index<IndexType> index,
DataType input_value,
DataType update_value,
ComputationType computation_type,
TfLiteContext* context) { … }
// Evaluation routine specialised on both the index element type and the data
// element type. Template parameters are ordered <IndexType, DataType>.
//
// Returns a TfLiteStatus.
// NOTE(review): presumably performs the actual scatter for `node` once both
// types have been resolved -- body is elided here, confirm.
template <typename IndexType, typename DataType>
TfLiteStatus EvalWithTypes(TfLiteContext* context, TfLiteNode* node) { … }
// Evaluation routine specialised on the index element type only; the runtime
// type tags are passed alongside.
//
// Args:
//   context: TFLite context.
//   node: node being evaluated.
//   index_type: runtime TfLiteType tag for the indices.
//   data_type: runtime TfLiteType tag for the data.
//
// Returns a TfLiteStatus.
// NOTE(review): presumably switches on `data_type` to select the DataType
// for EvalWithTypes<IndexType, DataType>, failing for unsupported types --
// body is elided here, confirm.
template <typename IndexType>
TfLiteStatus EvalWithIndexType(TfLiteContext* context, TfLiteNode* node,
TfLiteType index_type, TfLiteType data_type) { … }
// Returns an opaque pointer given the context and a (`buffer`, `length`)
// byte range.
// NOTE(review): presumably allocates an OpData whose pointer is later handed
// to Free -- body is elided here, confirm.
void* Init(TfLiteContext* context, const char* buffer, size_t length) { … }
// Takes an opaque `buffer` pointer and returns nothing.
// NOTE(review): presumably destroys the OpData created by Init -- body is
// elided here, confirm.
void Free(TfLiteContext* context, void* buffer) { … }
// Per-node preparation step; returns a TfLiteStatus.
// NOTE(review): presumably validates the node's tensors, resolves the update
// computation via GetComputationType, and sizes the output -- body is elided
// here, confirm.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { … }
}
// Per-node evaluation entry point; returns a TfLiteStatus. Unlike Init, Free
// and Prepare, this function is declared outside the anonymous namespace and
// directly in `stablehlo_scatter`.
// NOTE(review): presumably reads the index and data tensor types and
// dispatches to EvalWithIndexType -- body is elided here, confirm.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { … }
}
// Returns the TfLiteRegistration for the STABLEHLO_SCATTER builtin op.
// Declared in `tflite::ops::builtin`, outside the `stablehlo_scatter`
// namespace.
// NOTE(review): presumably the registration wires up the Init, Free, Prepare
// and Eval functions of `stablehlo_scatter`; the storage duration of the
// returned object is not visible here -- confirm.
TfLiteRegistration* Register_STABLEHLO_SCATTER() { … }
}
}
}