/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include <algorithm> #include <cstdint> #include <memory> #include <vector> #include "Eigen/Core" // from @eigen_archive #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/core/c/c_api_types.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/kernels/internal/runtime_shape.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/internal/types.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/tensor_slice_util.h" namespace tflite { namespace ops { namespace builtin { namespace stablehlo_gather { namespace { constexpr int kOperandTensor = …; constexpr int kStartIndicesTensor = …; constexpr int kOutputTensor = …; TfLiteIntArrayUniquePtr; // Clips the starting indices given the operand_shape and slice_sizes. This // means the starting index in a dimension will be shifted back if necessary so // that the whole slice can fit in the operand. 
// Example: // starting_index = [i, j], operand_shape = [oi, oj], slice_sizes = [si, sj] // starting_index will be transformed to [min(i, oi - si), min(j, oj - sj)] template <typename IndexType> TfLiteStatus ClipStartingIndex(const RuntimeShape& operand_shape, const int64_t* slice_sizes, int num_slice_sizes, Index<IndexType>& starting_index) { … } // Returns a vector containing slice_sizes with all the entries with indices // that are present in collapsed_slice_dims removed. // Example: slice_sizes = {3, 5, 2, 7}, collapsed_slice_dims = {1, 3} // Result: {3, 2} static std::vector<int64_t> GetCollapsedSliceShape( const int64_t* slice_sizes, int num_slice_sizes, const int64_t* collapsed_slice_dims, int num_collapsed_slice_dims) { … } // Creates the result shape based on the rank of the result, options and // shape of the start_indices operand. // Refer to the spec for a full explanation: // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#gather static TfLiteIntArrayUniquePtr GetResultShape( int64_t result_rank, const TfLiteStablehloGatherParams* data, const RuntimeShape& start_indices_shape) { … } // Extracts the batch and offset indices out of a given result index. // Result index is the index of an element in the output (result) tensor. // The location of the offset dims is given in the offset_dims argument and // the rest are batch dimensions. template <typename IndexType> TfLiteStatus SetBatchAndOffsetIndices(const Index<IndexType>& result_index, const int64_t* offset_dims, int num_offset_dims, Index<IndexType>& batch_index, Index<IndexType>& offset_index) { … } // Evaluates this node given the type of the elements in the start_indices // and the type of the elements in the operand tensor. template <typename IndexType, typename DataType> TfLiteStatus EvalWithTypes(TfLiteContext* context, TfLiteNode* node) { … } // Evaluates this node given the type of the elements in the start_indices // tensor. 
template <typename IndexType> TfLiteStatus EvalWithIndexType(TfLiteContext* context, TfLiteNode* node, TfLiteType index_type, TfLiteType data_type) { … } } // namespace // This is the kernel for stablehlo.gather which receives `slice_sizes` as a // static attribute. TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { … } TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { … } } // namespace stablehlo_gather TfLiteRegistration* Register_STABLEHLO_GATHER() { … } } // namespace builtin } // namespace ops } // namespace tflite