// chromium/third_party/mediapipe/src/mediapipe/framework/formats/location_opencv.cc

// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mediapipe/framework/formats/location_opencv.h"

#include <algorithm>

#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "absl/memory/memory.h"
#include "absl/strings/substitute.h"
#include "mediapipe/framework/formats/annotation/rasterization.pb.h"
#include "mediapipe/framework/formats/location.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
#include "mediapipe/framework/port/statusor.h"

namespace mediapipe {

namespace {
// Computes the tight, axis-aligned extent of all rasterization intervals
// stored in the mask of `location_data`.
//
// CHECK-fails if the mask carries no rasterization. When the rasterization
// holds no intervals, Rectangle_i(0, 0, 0, 0) is returned. Otherwise the
// result is built from the smallest left_x / y seen and the inclusive spans
// (largest right_x - smallest left_x + 1, largest y - smallest y + 1).
Rectangle_i MaskToRectangle(const LocationData& location_data) {
  ABSL_CHECK(location_data.mask().has_rasterization());
  const auto& intervals = location_data.mask().rasterization().interval();
  if (intervals.empty()) {
    return Rectangle_i(0, 0, 0, 0);
  }

  // Start from the widest possible "inverted" box so that the first interval
  // always tightens every bound.
  int left = std::numeric_limits<int>::max();
  int top = std::numeric_limits<int>::max();
  int right = std::numeric_limits<int>::lowest();
  int bottom = std::numeric_limits<int>::lowest();
  for (const auto& run : intervals) {
    if (run.left_x() < left) left = run.left_x();
    if (run.right_x() > right) right = run.right_x();
    if (run.y() < top) top = run.y();
    if (run.y() > bottom) bottom = run.y();
  }

  // Interval endpoints are inclusive, hence the +1 on both extents.
  const int span_x = right - left + 1;
  const int span_y = bottom - top + 1;
  return Rectangle_i(left, top, span_x, span_y);
}

// Renders the rasterization of `mask` into a newly allocated single-channel
// float image (CV_32FC1) of size mask.width() x mask.height().
//
// Every pixel covered by an interval (both left_x and right_x inclusive) is
// set to 1.0f; all other pixels stay 0.0f.
//
// Interval coordinates come straight from the proto and are not guaranteed to
// lie inside the declared mask dimensions. cv::Mat::at() only validates
// indices in debug builds, so intervals are clipped to the image here: rows
// outside the image are skipped and column ranges are clamped to
// [0, cols - 1]. Intervals that are fully inside the image are rendered
// exactly as before.
std::unique_ptr<cv::Mat> MaskToMat(const LocationData::BinaryMask& mask) {
  auto image = absl::make_unique<cv::Mat>();
  *image = cv::Mat::zeros(cv::Size(mask.width(), mask.height()), CV_32FC1);
  for (const auto& interval : mask.rasterization().interval()) {
    const int y = interval.y();
    if (y < 0 || y >= image->rows) continue;  // Row lies outside the image.
    const int first_x = std::max<int>(interval.left_x(), 0);
    const int last_x = std::min<int>(interval.right_x(), image->cols - 1);
    // Empty after clipping (or an empty image) => loop body never runs.
    for (int x = first_x; x <= last_x; ++x) {
      image->at<float>(y, x) = 1.0f;
    }
  }
  return image;
}
// Produces an image_width x image_height single-channel float image
// (CV_32FC1) in which the pixels with xmin <= x < xmax and ymin <= y < ymax
// of `rect` are 1.0f and every other pixel is 0.0f.
//
// Returns InvalidArgumentError when `rect` has a negative xmin/ymin or when
// its xmax/ymax exceed the image width/height.
absl::StatusOr<std::unique_ptr<cv::Mat>> RectangleToMat(
    int image_width, int image_height, const Rectangle_i& rect) {
  // Writing pixels for a rectangle that sticks out of the image would be
  // undefined behavior, so reject such rectangles up front.
  const bool fits_in_image = rect.ymin() >= 0 && rect.xmin() >= 0 &&
                             rect.xmax() <= image_width &&
                             rect.ymax() <= image_height;
  if (!fits_in_image) {
    return absl::InvalidArgumentError(absl::Substitute(
        "Rectangle must be bounded by image boundaries.\nImage Width: "
        "$0\nImage Height: $1\nRectangle: [($2, $3), ($4, $5)]",
        image_width, image_height, rect.xmin(), rect.ymin(), rect.xmax(),
        rect.ymax()));
  }

  // Start from an all-background image, then paint the rectangle interior.
  auto canvas = absl::make_unique<cv::Mat>();
  *canvas = cv::Mat::zeros(cv::Size(image_width, image_height), CV_32FC1);
  const int row_begin = rect.ymin();
  const int row_end = rect.ymax();
  const int col_begin = rect.xmin();
  const int col_end = rect.xmax();
  for (int row = row_begin; row < row_end; ++row) {
    for (int col = col_begin; col < col_end; ++col) {
      canvas->at<float>(row, col) = 1.0f;
    }
  }

  // The explicit move selects StatusOr's converting constructor.
  return std::move(canvas);
}
}  // namespace

// Creates a bounding-box Location from an OpenCV rectangle by forwarding the
// rectangle's x, y, width and height to Location::CreateBBoxLocation.
Location CreateBBoxLocation(const cv::Rect& rect) {
  const int left = rect.x;
  const int top = rect.y;
  const int box_width = rect.width;
  const int box_height = rect.height;
  return Location::CreateBBoxLocation(left, top, box_width, box_height);
}

std::unique_ptr<cv::Mat> GetCvMask(const Location& location) {
  const auto location_data = location.ConvertToProto();
  ABSL_CHECK_EQ(LocationData::MASK, location_data.format());
  const auto& mask = location_data.mask();
  std::unique_ptr<cv::Mat> mat(
      new cv::Mat(mask.height(), mask.width(), CV_8UC1, cv::Scalar(0)));
  for (const auto& interval : location_data.mask().rasterization().interval()) {
    for (int x = interval.left_x(); x <= interval.right_x(); ++x) {
      mat->at<uint8_t>(interval.y(), x) = 255;
    }
  }
  return mat;
}

std::unique_ptr<cv::Mat> ConvertToCvMask(const Location& location,
                                         int image_width, int image_height) {
  const auto location_data = location.ConvertToProto();
  switch (location_data.format()) {
    case LocationData::GLOBAL:
    case LocationData::BOUNDING_BOX:
    case LocationData::RELATIVE_BOUNDING_BOX: {
      auto status_or_mat = RectangleToMat(
          image_width, image_height,
          location.ConvertToBBox<Rectangle_i>(image_width, image_height));
      if (!status_or_mat.ok()) {
        ABSL_LOG(ERROR) << status_or_mat.status().message();
        return nullptr;
      }
      return std::move(status_or_mat).value();
    }
    case LocationData::MASK: {
      return MaskToMat(location_data.mask());
    }
  }
// This should never happen; a new LocationData::Format enum was introduced
// without updating this function's switch(...) to support it.
#if !defined(MEDIAPIPE_MOBILE) && !defined(MEDIAPIPE_LITE)
  ABSL_LOG(ERROR) << "Location's LocationData has format not supported by "
                     "Location::ConvertToMask: "
                  << location_data.DebugString();
#endif
  return nullptr;
}

// Scales `location` in place by `factor` (CHECK-fails unless factor > 0).
// A factor of exactly 1 leaves the location untouched.
//
//  * GLOBAL: unchanged.
//  * BOUNDING_BOX: width and height are multiplied by `factor` and rounded;
//    xmin/ymin are shifted so the box stays centered (integer arithmetic) and
//    are then clamped to be non-negative.
//  * RELATIVE_BOUNDING_BOX: width and height are multiplied by `factor` and
//    xmin/ymin are shifted by half of the size change; no clamping.
//  * MASK: the mask is dilated (factor > 1) or eroded (factor < 1) with a
//    rectangular kernel whose size is |factor - 1| times the size of the
//    mask's bounding rectangle. If either kernel dimension rounds to zero the
//    mask is left as is.
void EnlargeLocation(Location& location, const float factor) {
  ABSL_CHECK_GT(factor, 0.0f);
  if (factor == 1.0f) return;

  auto proto = location.ConvertToProto();
  switch (proto.format()) {
    case LocationData::GLOBAL:
      // A global location has no extent to scale.
      break;
    case LocationData::BOUNDING_BOX: {
      auto* bbox = proto.mutable_bounding_box();
      const int scaled_width =
          static_cast<int>(std::round(factor * bbox->width()));
      const int scaled_height =
          static_cast<int>(std::round(factor * bbox->height()));
      // Keep the center fixed: old center minus half of the new size.
      const int centered_xmin =
          bbox->xmin() + bbox->width() / 2 - scaled_width / 2;
      const int centered_ymin =
          bbox->ymin() + bbox->height() / 2 - scaled_height / 2;
      bbox->set_xmin(std::max(centered_xmin, 0));
      bbox->set_ymin(std::max(centered_ymin, 0));
      bbox->set_width(scaled_width);
      bbox->set_height(scaled_height);
      break;
    }
    case LocationData::RELATIVE_BOUNDING_BOX: {
      auto* rel_box = proto.mutable_relative_bounding_box();
      // Half of the size change along each axis, computed from the original
      // width/height before they are overwritten below.
      const auto half_growth_x = ((factor - 1.0) * rel_box->width()) / 2.0;
      const auto half_growth_y = ((factor - 1.0) * rel_box->height()) / 2.0;
      rel_box->set_xmin(rel_box->xmin() - half_growth_x);
      rel_box->set_ymin(rel_box->ymin() - half_growth_y);
      rel_box->set_width(factor * rel_box->width());
      rel_box->set_height(factor * rel_box->height());
      break;
    }
    case LocationData::MASK: {
      auto mask_extent = MaskToRectangle(proto);
      const float size_change = std::fabs(factor - 1.0f);
      const int kernel_cols =
          static_cast<int>(std::round(size_change * mask_extent.Width()));
      const int kernel_rows =
          static_cast<int>(std::round(size_change * mask_extent.Height()));
      // A kernel with a zero dimension is skipped; the mask stays unchanged.
      if (kernel_cols != 0 && kernel_rows != 0) {
        cv::Mat kernel(kernel_rows, kernel_cols, CV_8U, cv::Scalar(1));
        auto cv_mask = GetCvMask(location);
        const bool grow = factor > 1.0f;
        if (grow) {
          cv::dilate(*cv_mask, *cv_mask, kernel);
        } else {
          cv::erode(*cv_mask, *cv_mask, kernel);
        }
        CreateCvMaskLocation<uint8_t>(*cv_mask).ConvertToProto(&proto);
      }
      break;
    }
  }
  location.SetFromProto(proto);
}

// Builds a MASK-format Location from a single-channel matrix.
//
// The mask width/height are taken from mask.cols/mask.rows. Each row is
// scanned left to right and every maximal run of pixels whose value is
// greater than zero becomes one rasterization interval with inclusive
// left_x/right_x. CHECK-fails if `mask` has more than one channel.
template <typename T>
Location CreateCvMaskLocation(const cv::Mat_<T>& mask) {
  ABSL_CHECK_EQ(1, mask.channels())
      << "The specified cv::Mat mask should be single-channel.";

  LocationData proto;
  proto.set_format(LocationData::MASK);
  auto* binary_mask = proto.mutable_mask();
  binary_mask->set_width(mask.cols);
  binary_mask->set_height(mask.rows);
  auto* runs = binary_mask->mutable_rasterization();

  const T background_level = static_cast<T>(0);
  for (int row = 0; row < mask.rows; ++row) {
    // Non-null while the scan is inside a foreground run on this row.
    Rasterization::Interval* open_run = nullptr;
    for (int col = 0; col < mask.cols; ++col) {
      if (!(mask(row, col) > background_level)) {
        // Background pixel: whatever run was open ends here.
        open_run = nullptr;
        continue;
      }
      if (open_run == nullptr) {
        open_run = runs->add_interval();
        open_run->set_y(row);
        open_run->set_left_x(col);
      }
      // Extend the run to include the current pixel.
      open_run->set_right_x(col);
    }
  }
  return Location(proto);
}

// Explicit instantiations of CreateCvMaskLocation for 8-bit and float masks.
// The template body is defined in this translation unit, so these are the
// element types it is emitted for here.
// NOTE(review): presumably other element types would need an additional
// instantiation to link -- confirm against the header.
template Location CreateCvMaskLocation(const cv::Mat_<uint8_t>& mask);
template Location CreateCvMaskLocation(const cv::Mat_<float>& mask);

}  // namespace mediapipe