chromium/third_party/mediapipe/src/mediapipe/calculators/tflite/tflite_tensors_to_floats_calculator.cc

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
#include "tensorflow/lite/interpreter.h"

namespace mediapipe {

// Stream tags understood by this calculator.
constexpr char kTensorsTag[] = "TENSORS";  // Input: std::vector<TfLiteTensor>.
constexpr char kFloatTag[] = "FLOAT";      // Output: single float.
constexpr char kFloatsTag[] = "FLOATS";    // Output: std::vector<float>.

// A calculator for converting TFLite tensors to a float or a float vector.
//
// Input:
//  TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32. Only the first
//            tensor will be used.
// Output:
//  FLOAT(optional) - Converted single float number.
//  FLOATS(optional) - Converted float vector.
//
// Notes: To output the FLOAT stream, the input TFLite tensor must have size 1,
//        i.e. contain exactly one float number.
//
// Usage example:
// node {
//   calculator: "TfLiteTensorsToFloatsCalculator"
//   input_stream: "TENSORS:tensors"
//   output_stream: "FLOATS:floats"
// }
class TfLiteTensorsToFloatsCalculator : public CalculatorBase {
 public:
  // Requires the TENSORS input and at least one of FLOAT / FLOATS outputs,
  // and declares the packet types of every stream that is present.
  static absl::Status GetContract(CalculatorContract* cc);
  // Declares a zero timestamp offset between input and outputs.
  absl::Status Open(CalculatorContext* cc) override;
  // Converts the first tensor of each TENSORS packet into the requested
  // output(s).
  absl::Status Process(CalculatorContext* cc) override;
};
REGISTER_CALCULATOR(TfLiteTensorsToFloatsCalculator);

absl::Status TfLiteTensorsToFloatsCalculator::GetContract(
    CalculatorContract* cc) {
  // The input is mandatory; of the two outputs, at least one must be wired.
  RET_CHECK(cc->Inputs().HasTag(kTensorsTag));
  RET_CHECK(cc->Outputs().HasTag(kFloatsTag) ||
            cc->Outputs().HasTag(kFloatTag));

  cc->Inputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensor>>();

  // Only declare types for the outputs that are actually connected.
  auto& outputs = cc->Outputs();
  if (outputs.HasTag(kFloatTag)) outputs.Tag(kFloatTag).Set<float>();
  if (outputs.HasTag(kFloatsTag)) {
    outputs.Tag(kFloatsTag).Set<std::vector<float>>();
  }
  return absl::OkStatus();
}

absl::Status TfLiteTensorsToFloatsCalculator::Open(CalculatorContext* cc) {
  // Output packets are emitted at the input packet's timestamp, so the
  // outputs carry a zero offset relative to the input.
  const TimestampDiff no_offset(0);
  cc->SetOffset(no_offset);
  return absl::OkStatus();
}

// Converts the first tensor of the incoming TENSORS packet into a single float
// (FLOAT output, requires exactly one element) and/or a float vector (FLOATS
// output). Both outputs are stamped with the input timestamp.
//
// Returns an error status if the packet is missing or empty, if the tensor is
// not kTfLiteFloat32, if any dimension is non-positive, or if FLOAT is
// requested for a tensor with more than one element.
absl::Status TfLiteTensorsToFloatsCalculator::Process(CalculatorContext* cc) {
  RET_CHECK(!cc->Inputs().Tag(kTensorsTag).IsEmpty());

  const auto& input_tensors =
      cc->Inputs().Tag(kTensorsTag).Get<std::vector<TfLiteTensor>>();
  // Indexing an empty vector below would be undefined behavior.
  RET_CHECK(!input_tensors.empty());

  // TODO: Add option to specify which tensor to take from.
  const TfLiteTensor& raw_tensor = input_tensors[0];
  // data.f is only meaningful for float32 tensors; any other type would be
  // silently reinterpreted as floats.
  RET_CHECK_EQ(raw_tensor.type, kTfLiteFloat32);
  RET_CHECK(raw_tensor.dims != nullptr);

  int num_values = 1;
  for (int i = 0; i < raw_tensor.dims->size; ++i) {
    RET_CHECK_GT(raw_tensor.dims->data[i], 0);
    num_values *= raw_tensor.dims->data[i];
  }

  const float* raw_floats = raw_tensor.data.f;
  // num_values >= 1 here, so a tensor without a data buffer is malformed.
  RET_CHECK(raw_floats != nullptr);

  if (cc->Outputs().HasTag(kFloatTag)) {
    // TODO: Could add an index in the option to specify returning one
    // value of a float array.
    RET_CHECK_EQ(num_values, 1);
    cc->Outputs().Tag(kFloatTag).AddPacket(
        MakePacket<float>(raw_floats[0]).At(cc->InputTimestamp()));
  }
  if (cc->Outputs().HasTag(kFloatsTag)) {
    // Copies the tensor's contents into a new vector owned by the packet.
    cc->Outputs().Tag(kFloatsTag).AddPacket(
        MakePacket<std::vector<float>>(raw_floats, raw_floats + num_values)
            .At(cc->InputTimestamp()));
  }

  return absl::OkStatus();
}
}  // namespace mediapipe