chromium/third_party/mediapipe/src/mediapipe/calculators/tensor/inference_io_mapper.cc

// Copyright 2024 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mediapipe/calculators/tensor/inference_io_mapper.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/absl_log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_join.h"
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
#include "mediapipe/calculators/tensor/tensor_span.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/util/tflite/tflite_signature_reader.h"
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/interpreter_builder.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model_builder.h"

namespace mediapipe {

namespace {

using ::tflite::FlatBufferModel;
using ::tflite::Interpreter;
using ::tflite::InterpreterBuilder;
using ::tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates;

// Checks for duplicate indices in a TensorIndicesMap and returns the indices
// as a vector. (Sketch of the elided body; the repeated `model_tensor_indices`
// field is assumed from inference_calculator.proto.)
absl::StatusOr<std::vector<int>> GenerateAndValidateTensorList(
    const InferenceCalculatorOptions::InputOutputConfig::TensorIndicesMap&
        tensor_indices_list) {
  std::vector<int> result;
  result.reserve(tensor_indices_list.model_tensor_indices().size());
  absl::flat_hash_set<int> unique_indices;
  for (const int index : tensor_indices_list.model_tensor_indices()) {
    RET_CHECK(unique_indices.insert(index).second)
        << "Indices in TensorIndicesMap are not unique.";
    result.push_back(index);
  }
  return result;
}

// Builds a map from tensor name to its position in `names`; fails if a name
// appears more than once.
absl::StatusOr<absl::flat_hash_map<std::string, int>> CreateNameToIndexMap(
    const std::vector<std::string>& names) {
  absl::flat_hash_map<std::string, int> name_to_index_map;
  for (int i = 0; i < static_cast<int>(names.size()); ++i) {
    RET_CHECK(name_to_index_map.insert({names[i], i}).second)
        << "Duplicate tensor name: " << names[i];
  }
  return name_to_index_map;
}

// Returns true if `input` contains any element more than once.
template <typename T>
static bool ContainsDuplicates(const std::vector<T>& input) {
  const absl::flat_hash_set<T> unique_elements(input.begin(), input.end());
  return unique_elements.size() != input.size();
}

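// Maps the tensor names listed in the config's TensorNamesMap to their
// positions in the model signature. For example, for signature tensor names
// ["a", "b", "c"] and config names ["c", "a"], the result is {2, 0}.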
static absl::StatusOr<std::vector<int>> MapTensorNamesToIndices(
    const std::vector<std::string>& signature_tensor_names,
    const InferenceCalculatorOptions::InputOutputConfig::TensorNamesMap&
        config_tensor_names) {
  std::vector<int> result;
  result.reserve(signature_tensor_names.size());
  MP_ASSIGN_OR_RETURN(const auto input_name_to_index_map,
                      CreateNameToIndexMap(signature_tensor_names));
  for (const auto& tensor_name : config_tensor_names.tensor_names()) {
    const auto it = input_name_to_index_map.find(tensor_name);
    RET_CHECK(it != input_name_to_index_map.end())
        << "Tensor name " << tensor_name
        << " not found in model signatures. Model tensor names: "
        << absl::StrJoin(signature_tensor_names, ", ");
    result.push_back(it->second);
  }
  RET_CHECK(!ContainsDuplicates(result))
      << "Duplicate tensor names found in TensorNamesMap: "
      << absl::StrJoin(config_tensor_names.tensor_names(), ", ");
  return result;
}

// Feedback tensors are excluded from the InferenceRunner's inputs and
// outputs, since they are handled internally by the InferenceFeedbackManager.
// As a result, the InferenceRunner's input and output tensor order no longer
// matches the model's I/O tensors, so the tensor I/O indices must be adjusted
// accordingly.
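// For example, for model input tensors [in_a, feedback_in, in_b] where
// feedback_in is driven by a feedback link, the InferenceRunner only sees
// [in_a, in_b]: a remapping index pointing at in_b shifts from 2 to 1, and
// entries pointing at feedback_in are dropped.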
absl::Status ExcludeFeedbackTensorsFromRemappingIndicesVector(
    const InferenceCalculatorOptions::InputOutputConfig& io_config,
    const std::vector<std::string>& model_tensor_names,
    std::vector<int>& remapping_tensor_indices) {
  // Collect the names of all tensors that participate in feedback links.
  // (Sketch; link field names are assumed from the FeedbackTensorLink proto.)
  absl::flat_hash_set<std::string> feedback_tensor_names;
  for (const auto& link : io_config.feedback_tensor_links()) {
    feedback_tensor_names.insert(link.from_output_tensor_name());
    feedback_tensor_names.insert(link.to_input_tensor_name());
  }
  if (feedback_tensor_names.empty()) {
    return absl::OkStatus();
  }
  std::vector<int> adjusted_indices;
  adjusted_indices.reserve(remapping_tensor_indices.size());
  for (const int index : remapping_tensor_indices) {
    RET_CHECK(index >= 0 &&
              index < static_cast<int>(model_tensor_names.size()))
        << "Tensor index " << index << " out of range.";
    // Indices that refer to feedback tensors are dropped entirely.
    if (feedback_tensor_names.contains(model_tensor_names[index])) {
      continue;
    }
    // Shift the index down by the number of preceding feedback tensors.
    int offset = 0;
    for (int i = 0; i < index; ++i) {
      if (feedback_tensor_names.contains(model_tensor_names[i])) ++offset;
    }
    adjusted_indices.push_back(index - offset);
  }
  remapping_tensor_indices = std::move(adjusted_indices);
  return absl::OkStatus();
}

}  // namespace

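// Reads the input and output tensor names of the model signatures from an
// already-built interpreter.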
// static
absl::StatusOr<InputOutputTensorNames>
InferenceIoMapper::GetInputOutputTensorNamesFromInterpreter(
    const tflite::Interpreter& interpreter) {
  // Delegates to the signature reader declared in tflite_signature_reader.h.
  return GetInputOutputTensorNamesFromTfliteSignature(interpreter);
}

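// Convenience overload that builds a temporary interpreter from `flatbuffer`
// and `op_resolver` in order to read the signature tensor names.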
// static
absl::StatusOr<InputOutputTensorNames>
InferenceIoMapper::GetInputOutputTensorNamesFromModel(
    const tflite::FlatBufferModel& flatbuffer,
    const tflite::OpResolver& op_resolver) {
  std::unique_ptr<Interpreter> interpreter;
  InterpreterBuilder interpreter_builder(flatbuffer, op_resolver);
  RET_CHECK_EQ(interpreter_builder(&interpreter), kTfLiteOk)
      << "Failed to build interpreter.";
  RET_CHECK(interpreter != nullptr);
  return GetInputOutputTensorNamesFromTfliteSignature(*interpreter);
}

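// Resolves io_config into the input/output remapping index vectors.
// Index-based mappings are validated and taken over verbatim; name-based
// mappings are resolved against the (single) model signature and then
// adjusted for feedback tensors. The body below is a sketch: the member names
// (input_tensor_indices_, output_tensor_indices_), the proto field names, and
// the signature-keyed layout of InputOutputTensorNames are assumed from
// inference_io_mapper.h, inference_calculator.proto, and
// tflite_signature_reader.h.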
absl::Status InferenceIoMapper::UpdateIoMap(
    const InferenceCalculatorOptions::InputOutputConfig& io_config,
    const InputOutputTensorNames& input_output_tensor_names) {
  input_tensor_indices_.clear();
  output_tensor_indices_.clear();
  if (io_config.has_input_tensor_indices_map()) {
    MP_ASSIGN_OR_RETURN(
        input_tensor_indices_,
        GenerateAndValidateTensorList(io_config.input_tensor_indices_map()));
  }
  if (io_config.has_output_tensor_indices_map()) {
    MP_ASSIGN_OR_RETURN(
        output_tensor_indices_,
        GenerateAndValidateTensorList(io_config.output_tensor_indices_map()));
  }
  if (!io_config.has_input_tensor_names_map() &&
      !io_config.has_output_tensor_names_map()) {
    return absl::OkStatus();
  }
  // Name-based mapping needs a unique signature to resolve names against.
  RET_CHECK_EQ(input_output_tensor_names.size(), 1)
      << "Tensor name-based mapping requires a single model signature.";
  const auto& signature = input_output_tensor_names.begin()->second;
  if (io_config.has_input_tensor_names_map()) {
    MP_ASSIGN_OR_RETURN(
        input_tensor_indices_,
        MapTensorNamesToIndices(signature.input_tensor_names,
                                io_config.input_tensor_names_map()));
    MP_RETURN_IF_ERROR(ExcludeFeedbackTensorsFromRemappingIndicesVector(
        io_config, signature.input_tensor_names, input_tensor_indices_));
  }
  if (io_config.has_output_tensor_names_map()) {
    MP_ASSIGN_OR_RETURN(
        output_tensor_indices_,
        MapTensorNamesToIndices(signature.output_tensor_names,
                                io_config.output_tensor_names_map()));
    MP_RETURN_IF_ERROR(ExcludeFeedbackTensorsFromRemappingIndicesVector(
        io_config, signature.output_tensor_names, output_tensor_indices_));
  }
  return absl::OkStatus();
}

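// Reorders the incoming tensors so that the i-th unmapped tensor lands at
// model input position input_tensor_indices_[i]; pass-through when no input
// mapping is configured. The sketch below assumes TensorSpan can be built
// from a vector of Tensor pointers (see tensor_span.h).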
absl::StatusOr<TensorSpan> InferenceIoMapper::RemapInputTensors(
    const TensorSpan& unmapped_tensors) {
  if (input_tensor_indices_.empty()) {
    return unmapped_tensors;
  }
  RET_CHECK_EQ(unmapped_tensors.size(), input_tensor_indices_.size())
      << "Unexpected number of input tensors.";
  std::vector<const Tensor*> mapped_tensors(unmapped_tensors.size(), nullptr);
  for (int i = 0; i < static_cast<int>(input_tensor_indices_.size()); ++i) {
    const int index = input_tensor_indices_[i];
    RET_CHECK_LT(index, static_cast<int>(mapped_tensors.size()));
    mapped_tensors[index] = &unmapped_tensors[i];
  }
  return TensorSpan(std::move(mapped_tensors));
}

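// Selects the model output tensor at output_tensor_indices_[i] as the i-th
// result; pass-through when no output mapping is configured.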
absl::StatusOr<std::vector<Tensor>> InferenceIoMapper::RemapOutputTensors(
    std::vector<Tensor>&& unmapped_tensors) {
  if (output_tensor_indices_.empty()) {
    return std::move(unmapped_tensors);
  }
  RET_CHECK_EQ(unmapped_tensors.size(), output_tensor_indices_.size())
      << "Unexpected number of output tensors.";
  std::vector<Tensor> mapped_tensors;
  mapped_tensors.reserve(unmapped_tensors.size());
  for (const int index : output_tensor_indices_) {
    RET_CHECK_LT(index, static_cast<int>(unmapped_tensors.size()));
    mapped_tensors.emplace_back(std::move(unmapped_tensors[index]));
  }
  return mapped_tensors;
}

}  // namespace mediapipe