#include "mediapipe/calculators/tensor/inference_io_mapper.h"
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/absl_log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_join.h"
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
#include "mediapipe/calculators/tensor/tensor_span.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/util/tflite/tflite_signature_reader.h"
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/interpreter_builder.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model_builder.h"
namespace mediapipe {
namespace {
FlatBufferModel;
Interpreter;
InterpreterBuilder;
BuiltinOpResolverWithoutDefaultDelegates;
absl::StatusOr<std::vector<int>> GenerateAndValidateTensorList(
const InferenceCalculatorOptions::InputOutputConfig::TensorIndicesMap&
tensor_indices_list) { … }
absl::StatusOr<absl::flat_hash_map<std::string, int>> CreateNameToIndexMap(
const std::vector<std::string>& names) { … }
template <typename T>
static bool ContainsDuplicates(const std::vector<T>& input) { … }
// Resolves every tensor name listed in `config_tensor_names` to its index in
// `signature_tensor_names`. The returned indices follow the order in which
// the names appear in the config.
//
// Returns an error status if:
//   - building the name -> index lookup from `signature_tensor_names` fails
//     (the error from CreateNameToIndexMap is propagated unchanged),
//   - a configured name is not present in `signature_tensor_names`, or
//   - the resolved indices contain duplicates, i.e. the config refers to the
//     same tensor more than once.
static absl::StatusOr<std::vector<int>> MapTensorNamesToIndices(
    const std::vector<std::string>& signature_tensor_names,
    const InferenceCalculatorOptions::InputOutputConfig::TensorNamesMap&
        config_tensor_names) {
  // Used for both input and output signatures, hence the neutral name.
  MP_ASSIGN_OR_RETURN(const auto name_to_index_map,
                      CreateNameToIndexMap(signature_tensor_names));

  std::vector<int> result;
  // Exactly one index is produced per configured name, so size the buffer by
  // the config rather than by the model signature (whose size may differ).
  result.reserve(config_tensor_names.tensor_names().size());
  for (const auto& tensor_name : config_tensor_names.tensor_names()) {
    const auto it = name_to_index_map.find(tensor_name);
    RET_CHECK(it != name_to_index_map.end())
        << "Tensor name " << tensor_name
        << " not found in model signatures. Model tensor names: "
        << absl::StrJoin(signature_tensor_names, ", ");
    result.push_back(it->second);
  }
  // Reject configs that map more than one entry onto the same model tensor.
  RET_CHECK(!ContainsDuplicates(result))
      << "Duplicate tensor names found in TensorNamesMap: "
      << absl::StrJoin(config_tensor_names.tensor_names(), ", ");
  return result;
}
absl::Status ExcludeFeedbackTensorsFromRemappingIndicesVector(
const InferenceCalculatorOptions::InputOutputConfig& io_config,
const std::vector<std::string>& model_tensor_names,
std::vector<int>& remapping_tensor_indices) { … }
}
absl::StatusOr<InputOutputTensorNames>
InferenceIoMapper::GetInputOutputTensorNamesFromInterpreter(
const tflite::Interpreter& interpreter) { … }
absl::StatusOr<InputOutputTensorNames>
InferenceIoMapper::GetInputOutputTensorNamesFromModel(
const tflite::FlatBufferModel& flatbuffer,
const tflite::OpResolver& op_resolver) { … }
absl::Status InferenceIoMapper::UpdateIoMap(
const InferenceCalculatorOptions::InputOutputConfig& io_config,
const InputOutputTensorNames& input_output_tensor_names) { … }
absl::StatusOr<TensorSpan> InferenceIoMapper::RemapInputTensors(
const TensorSpan& unmapped_tensors) { … }
absl::StatusOr<std::vector<Tensor>> InferenceIoMapper::RemapOutputTensors(
std::vector<Tensor>&& unmapped_tensors) { … }
}