chromium/third_party/mediapipe/src/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc

// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mediapipe/calculators/tensor/image_to_tensor_converter_opencv.h"

#include <cmath>
#include <memory>

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "mediapipe/calculators/tensor/image_to_tensor_converter.h"
#include "mediapipe/calculators/tensor/image_to_tensor_utils.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_format.pb.h"
#include "mediapipe/framework/formats/image_opencv.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
#include "mediapipe/framework/port/statusor.h"

namespace mediapipe {

namespace {

class ImageToTensorOpenCvConverter : public ImageToTensorConverter {
 public:
  ImageToTensorOpenCvConverter(BorderMode border_mode,
                               Tensor::ElementType tensor_type,
                               cv::InterpolationFlags flags)
      : tensor_type_(tensor_type), flags_(flags) {
    switch (border_mode) {
      case BorderMode::kReplicate:
        border_mode_ = cv::BORDER_REPLICATE;
        break;
      case BorderMode::kZero:
        border_mode_ = cv::BORDER_CONSTANT;
        break;
    }
    switch (tensor_type_) {
      case Tensor::ElementType::kInt8:
        mat_type_ = CV_8SC3;
        mat_gray_type_ = CV_8SC1;
        break;
      case Tensor::ElementType::kFloat32:
        mat_type_ = CV_32FC3;
        mat_gray_type_ = CV_32FC1;
        break;
      case Tensor::ElementType::kUInt8:
        mat_type_ = CV_8UC3;
        mat_gray_type_ = CV_8UC1;
        break;
      default:
        mat_type_ = -1;
        mat_gray_type_ = -1;
    }
  }

  absl::Status Convert(const mediapipe::Image& input, const RotatedRect& roi,
                       float range_min, float range_max,
                       int tensor_buffer_offset,
                       Tensor& output_tensor) override {
    const bool is_supported_format =
        input.image_format() == mediapipe::ImageFormat::SRGB ||
        input.image_format() == mediapipe::ImageFormat::SRGBA ||
        input.image_format() == mediapipe::ImageFormat::GRAY8;
    if (!is_supported_format) {
      return absl::InvalidArgumentError(absl::StrCat(
          "Unsupported format: ", static_cast<uint32_t>(input.image_format())));
    }

    RET_CHECK_GE(tensor_buffer_offset, 0)
        << "The input tensor_buffer_offset needs to be non-negative.";
    const auto& output_shape = output_tensor.shape();
    MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));

    const int output_height = output_shape.dims[1];
    const int output_width = output_shape.dims[2];
    const int output_channels = output_shape.dims[3];
    const int num_elements_per_img =
        output_height * output_width * output_channels;
    auto buffer_view = output_tensor.GetCpuWriteView();
    cv::Mat dst;
    const int dst_data_type = output_channels == 1 ? mat_gray_type_ : mat_type_;
    switch (tensor_type_) {
      case Tensor::ElementType::kInt8:
        RET_CHECK_GE(
            output_shape.num_elements(),
            tensor_buffer_offset / sizeof(int8_t) + num_elements_per_img)
            << "The buffer offset + the input image size is larger than the "
               "allocated tensor buffer.";
        dst = cv::Mat(output_height, output_width, dst_data_type,
                      buffer_view.buffer<int8_t>() +
                          tensor_buffer_offset / sizeof(int8_t));
        break;
      case Tensor::ElementType::kFloat32:
        RET_CHECK_GE(
            output_shape.num_elements(),
            tensor_buffer_offset / sizeof(float) + num_elements_per_img)
            << "The buffer offset + the input image size is larger than the "
               "allocated tensor buffer.";
        dst = cv::Mat(
            output_height, output_width, dst_data_type,
            buffer_view.buffer<float>() + tensor_buffer_offset / sizeof(float));
        break;
      case Tensor::ElementType::kUInt8:
        RET_CHECK_GE(
            output_shape.num_elements(),
            tensor_buffer_offset / sizeof(uint8_t) + num_elements_per_img)
            << "The buffer offset + the input image size is larger than the "
               "allocated tensor buffer.";
        dst = cv::Mat(output_height, output_width, dst_data_type,
                      buffer_view.buffer<uint8_t>() +
                          tensor_buffer_offset / sizeof(uint8_t));
        break;
      default:
        return absl::InvalidArgumentError(
            absl::StrCat("Unsupported tensor type: ", tensor_type_));
    }

    const cv::RotatedRect rotated_rect(cv::Point2f(roi.center_x, roi.center_y),
                                       cv::Size2f(roi.width, roi.height),
                                       roi.rotation * 180.f / M_PI);
    cv::Mat src_points;
    cv::boxPoints(rotated_rect, src_points);

    const float dst_width = output_width;
    const float dst_height = output_height;
    /* clang-format off */
    float dst_corners[8] = {0.0f,      dst_height,
                            0.0f,      0.0f,
                            dst_width, 0.0f,
                            dst_width, dst_height};
    /* clang-format on */

    auto src = mediapipe::formats::MatView(&input);
    cv::Mat dst_points = cv::Mat(4, 2, CV_32F, dst_corners);
    cv::Mat projection_matrix =
        cv::getPerspectiveTransform(src_points, dst_points);
    cv::Mat transformed;
    cv::warpPerspective(*src, transformed, projection_matrix,
                        cv::Size(dst_width, dst_height),
                        /*flags=*/flags_,
                        /*borderMode=*/border_mode_);

    if (transformed.channels() > output_channels) {
      cv::Mat proper_channels_mat;
      cv::cvtColor(transformed, proper_channels_mat, cv::COLOR_RGBA2RGB);
      transformed = proper_channels_mat;
    }

    constexpr float kInputImageRangeMin = 0.0f;
    constexpr float kInputImageRangeMax = 255.0f;
    MP_ASSIGN_OR_RETURN(
        auto transform,
        GetValueRangeTransformation(kInputImageRangeMin, kInputImageRangeMax,
                                    range_min, range_max));
    transformed.convertTo(dst, dst_data_type, transform.scale,
                          transform.offset);
    return absl::OkStatus();
  }

 private:
  absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) {
    RET_CHECK_EQ(output_shape.dims.size(), 4)
        << "Wrong output dims size: " << output_shape.dims.size();
    RET_CHECK_GE(output_shape.dims[0], 1)
        << "The batch dimension needs to be equal or larger than 1.";
    RET_CHECK(output_shape.dims[3] == 3 || output_shape.dims[3] == 1)
        << "Wrong output channel: " << output_shape.dims[3];
    return absl::OkStatus();
  }

  enum cv::BorderTypes border_mode_;
  Tensor::ElementType tensor_type_;
  cv::InterpolationFlags flags_;
  int mat_type_;
  int mat_gray_type_;
};

}  // namespace

absl::StatusOr<std::unique_ptr<ImageToTensorConverter>> CreateOpenCvConverter(
    CalculatorContext* cc, BorderMode border_mode,
    Tensor::ElementType tensor_type, cv::InterpolationFlags flags) {
  if (tensor_type != Tensor::ElementType::kInt8 &&
      tensor_type != Tensor::ElementType::kFloat32 &&
      tensor_type != Tensor::ElementType::kUInt8) {
    return absl::InvalidArgumentError(
        absl::StrCat("Tensor type is currently not supported by "
                     "ImageToTensorOpenCvConverter, type: ",
                     tensor_type));
  }
  return std::make_unique<ImageToTensorOpenCvConverter>(border_mode,
                                                        tensor_type, flags);
}

}  // namespace mediapipe