chromium/third_party/mediapipe/src/mediapipe/calculators/tensor/tensors_dequantization_calculator.cc

// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <memory>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/memory_manager.h"
#include "mediapipe/framework/memory_manager_service.h"
#include "mediapipe/framework/port/ret_check.h"

namespace mediapipe {
namespace api2 {
namespace {

// Converts every element of the quantized `input` tensor to float and stores
// it at the same index of `output`:
//
//   output[i] = scale * (static_cast<int>(input[i]) - zero_point)
//
// `T` is the element type used to read the input buffer; `scale` and
// `zero_point` come from `input.quantization_parameters()`.
//
// The caller must pass an `output` whose float buffer can hold
// `input.shape().num_elements()` values; this function does not verify it.
template <typename T>
void Dequantize(const Tensor& input, Tensor* output) {
  // The views are held in named locals for the whole function.
  // NOTE(review): presumably the buffer pointers are only valid while their
  // view objects are alive - confirm against the Tensor view documentation.
  auto input_view = input.GetCpuReadView();
  auto input_buffer = input_view.buffer<T>();
  auto output_view = output->GetCpuWriteView();
  auto output_buffer = output_view.buffer<float>();

  // Loop-invariant values are read once instead of on every iteration.
  // `auto` keeps the exact declared types, so the arithmetic below uses the
  // same conversions as computing them inline.
  const auto& params = input.quantization_parameters();
  const auto scale = params.scale;
  const auto zero_point = params.zero_point;
  const auto num_elements = input.shape().num_elements();

  for (int i = 0; i < num_elements; ++i) {
    // Widen to int before subtracting so that narrow element types (e.g.
    // uint8_t, bool) are not subject to their own arithmetic rules.
    output_buffer[i] = scale * (static_cast<int>(input_buffer[i]) - zero_point);
  }
}

}  // namespace

// Performs dequantization using the quantization parameters from the input
// tensors. Each element of the input tensors is converted using:
//
//   output = quantization_parameters.scale *
//     (input - quantization_parameters.zero_point)
//
// Supported input element types are kUInt8, kInt8 and kBool (bool elements
// are converted to int before the formula is applied). Any other element type
// makes Process() return an InvalidArgument error.
//
// Input:
//  TENSORS - Vector of quantized Tensors of type kUInt8, kInt8 or kBool.
//            An absent packet produces no output; a packet holding an empty
//            vector fails a RET_CHECK in Process().
// Output:
//  TENSORS - Vector of dequantized Tensors of type kFloat32, one per input
//            tensor and created with the shape of the corresponding input.
//
// Usage example:
// node {
//   calculator: "TensorsDequantizationCalculator"
//   input_stream: "TENSORS:quantized_tensors"
//   output_stream: "TENSORS:dequantized_tensors"
// }
class TensorsDequantizationCalculator : public Node {
 public:
  // Stream ports; both use the "TENSORS" tag, one on the input side and one on
  // the output side.
  static constexpr Input<std::vector<Tensor>> kInTensors{"TENSORS"};
  static constexpr Output<std::vector<Tensor>> kOutTensors{"TENSORS"};
  MEDIAPIPE_NODE_CONTRACT(kInTensors, kOutTensors);

  // Stores a pointer to the memory manager service object when that service
  // is available.
  absl::Status Open(CalculatorContext* cc) override;
  // Dequantizes the tensors of the current input packet and sends the result.
  absl::Status Process(CalculatorContext* cc) override;

  // Requests kMemoryManagerService as an optional service.
  static absl::Status UpdateContract(CalculatorContract* cc);

 private:
  // Enable pooling of AHWBs in Tensor instances.
  // Set in Open() when the memory manager service is available, otherwise it
  // stays null. It is passed to every output Tensor constructed in Process().
  MemoryManager* memory_manager_ = nullptr;
};

// Caches a pointer to the memory manager when kMemoryManagerService is
// available in this context; otherwise `memory_manager_` keeps its initial
// null value. Always returns OK.
absl::Status TensorsDequantizationCalculator::Open(CalculatorContext* cc) {
  auto memory_service = cc->Service(kMemoryManagerService);
  if (memory_service.IsAvailable()) {
    memory_manager_ = &memory_service.GetObject();
  }
  return absl::OkStatus();
}

// Converts the quantized tensors of the current input packet into kFloat32
// tensors and sends them on the TENSORS output.
//
// Returns OK without sending anything when the input has no packet, fails the
// RET_CHECK when the packet holds an empty vector, and returns an
// InvalidArgument error when a tensor is not of type kUInt8, kInt8 or kBool.
absl::Status TensorsDequantizationCalculator::Process(CalculatorContext* cc) {
  if (kInTensors(cc).IsEmpty()) return absl::OkStatus();

  const auto& quantized = *kInTensors(cc);
  RET_CHECK(!quantized.empty());

  auto dequantized = std::make_unique<std::vector<Tensor>>();
  dequantized->reserve(quantized.size());

  for (const auto& source : quantized) {
    // The float destination is created first, with the shape of its source,
    // and only then filled according to the source element type.
    dequantized->emplace_back(Tensor::ElementType::kFloat32, source.shape(),
                              memory_manager_);
    Tensor* const target = &dequantized->back();

    const auto type = source.element_type();
    if (type == Tensor::ElementType::kUInt8) {
      Dequantize<uint8_t>(source, target);
    } else if (type == Tensor::ElementType::kInt8) {
      Dequantize<int8_t>(source, target);
    } else if (type == Tensor::ElementType::kBool) {
      Dequantize<bool>(source, target);
    } else {
      return absl::InvalidArgumentError(absl::StrCat(
          "Unsupported input tensor type: ", source.element_type()));
    }
  }

  kOutTensors(cc).Send(std::move(dequantized));
  return absl::OkStatus();
}

// static
// Declares that this calculator uses kMemoryManagerService and marks the
// request as Optional(). Open() checks IsAvailable() before touching the
// service, so the calculator also works when it is not provided.
// Always returns OK.
absl::Status TensorsDequantizationCalculator::UpdateContract(
    CalculatorContract* cc) {
  cc->UseService(kMemoryManagerService).Optional();
  return absl::OkStatus();
}

// Registration macro for the node class. Presumably this is what makes the
// "TensorsDequantizationCalculator" name usable in graph configs (as in the
// usage example on the class) - confirm against the macro definition.
MEDIAPIPE_REGISTER_NODE(TensorsDequantizationCalculator);

}  // namespace api2
}  // namespace mediapipe