chromium/third_party/mediapipe/src/mediapipe/calculators/util/detections_to_render_data_calculator.cc

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "absl/log/absl_check.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "mediapipe/calculators/util/detections_to_render_data_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_options.pb.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/location_data.pb.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/util/color.pb.h"
#include "mediapipe/util/render_data.pb.h"
namespace mediapipe {

namespace {

// Input stream tags.
constexpr char kDetectionTag[] = "DETECTION";
constexpr char kDetectionsTag[] = "DETECTIONS";
constexpr char kDetectionListTag[] = "DETECTION_LIST";
// Output stream tag.
constexpr char kRenderDataTag[] = "RENDER_DATA";

// Scene tags set on the emitted RenderAnnotations, one per annotation kind.
constexpr char kSceneLabelLabel[] = "LABEL";
constexpr char kSceneFeatureLabel[] = "FEATURE";
constexpr char kSceneLocationLabel[] = "LOCATION";
constexpr char kKeypointLabel[] = "KEYPOINT";

// Upper bound on the ratio of detection label font height to the height of the
// detection bounding box. Used for RELATIVE_BOUNDING_BOX detections.
constexpr double kLabelToBoundingBoxRatio = 0.1;
// Preserve 2 decimal digits: scores are multiplied by this factor, rounded and
// divided back by it before being rendered as text.
constexpr float kNumScoreDecimalDigitsMultipler = 100;

}  // namespace

// A calculator that converts Detection proto to RenderData proto for
// visualization.
//
// Detection is the format for encoding one or more detections in an image.
// The input can be a single Detection (DETECTION), a std::vector<Detection>
// (DETECTIONS) or a DetectionList (DETECTION_LIST). At least one of these
// input streams must be connected; the detections from all connected streams
// at a given timestamp are combined into a single RenderData output packet.
//
// Please note that only Location Data formats of BOUNDING_BOX and
// RELATIVE_BOUNDING_BOX are supported. Normalized coordinates for
// RELATIVE_BOUNDING_BOX must be between 0.0 and 1.0. Any incremental normalized
// coordinates calculation in this calculator is capped at 1.0. Relative
// keypoints are rendered (as points) only for RELATIVE_BOUNDING_BOX detections.
//
// The text(s) for "label(_id),score" will be shown on top left
// corner of the bounding box. The text for "feature_tag" will be shown on
// bottom left corner of the bounding box.
//
// Unless produce_empty_packet is true, no output packet is emitted for a
// timestamp at which none of the inputs contains any detection.
//
// Example config:
// node {
//   calculator: "DetectionsToRenderDataCalculator"
//   input_stream: "DETECTION:detection"
//   input_stream: "DETECTIONS:detections"
//   input_stream: "DETECTION_LIST:detection_list"
//   output_stream: "RENDER_DATA:render_data"
//   options {
//     [DetectionsToRenderDataCalculatorOptions.ext] {
//       produce_empty_packet : false
//     }
//   }
// }
class DetectionsToRenderDataCalculator : public CalculatorBase {
 public:
  DetectionsToRenderDataCalculator() {}
  ~DetectionsToRenderDataCalculator() override {}
  DetectionsToRenderDataCalculator(const DetectionsToRenderDataCalculator&) =
      delete;
  DetectionsToRenderDataCalculator& operator=(
      const DetectionsToRenderDataCalculator&) = delete;

  // Declares the optional DETECTION / DETECTIONS / DETECTION_LIST inputs (at
  // least one is required) and the RENDER_DATA output.
  static absl::Status GetContract(CalculatorContract* cc);

  absl::Status Open(CalculatorContext* cc) override;

  absl::Status Process(CalculatorContext* cc) override;

 private:
  // Helpers used only by this class. They are private static members rather
  // than free functions so that no external client can depend on them.

  // Copies the configured color (r, g, b) and thickness onto the annotation.
  static void SetRenderAnnotationColorThickness(
      const DetectionsToRenderDataCalculatorOptions& options,
      RenderAnnotation* render_annotation);

  // Sets the text anchor. For normalized coordinates, left is clamped to be
  // >= 0.0 and baseline to be <= 1.0.
  static void SetTextCoordinate(bool normalized, double left, double baseline,
                                RenderAnnotation::Text* text);

  // Sets rect to [xmin, xmin + width] x [ymin, ymin + height], clamped to
  // [0.0, 1.0] when normalized. Leaves rect unmodified when the box ends below
  // 0.0 on either axis or, for normalized input, starts above 1.0.
  static void SetRectCoordinate(bool normalized, double xmin, double ymin,
                                double width, double height,
                                RenderAnnotation::Rectangle* rect);

  // Adds one text annotation per rendered label line ("label(_id),score",
  // optionally preceded by an "Id: ..." line), stacked down from the top of
  // the bounding box.
  static void AddLabels(const Detection& detection,
                        const DetectionsToRenderDataCalculatorOptions& options,
                        float text_line_height, RenderData* render_data);
  // Adds a text annotation with the detection's feature_tag, anchored at the
  // bottom left of the bounding box.
  static void AddFeatureTag(
      const Detection& detection,
      const DetectionsToRenderDataCalculatorOptions& options,
      float text_line_height, RenderData* render_data);
  // Adds the bounding-box rectangle annotation and, for relative boxes, one
  // point annotation per relative keypoint.
  static void AddLocationData(
      const Detection& detection,
      const DetectionsToRenderDataCalculatorOptions& options,
      RenderData* render_data);
  // Adds all annotations (labels, feature tag, location) for one detection.
  // Fails an ABSL_CHECK for unsupported location data formats.
  static void AddDetectionToRenderData(
      const Detection& detection,
      const DetectionsToRenderDataCalculatorOptions& options,
      RenderData* render_data);
};
REGISTER_CALCULATOR(DetectionsToRenderDataCalculator);

// Validates that at least one detection input is connected and declares the
// payload type of every connected input plus the RENDER_DATA output.
absl::Status DetectionsToRenderDataCalculator::GetContract(
    CalculatorContract* cc) {
  auto& inputs = cc->Inputs();
  const bool wants_single = inputs.HasTag(kDetectionTag);
  const bool wants_list = inputs.HasTag(kDetectionListTag);
  const bool wants_vector = inputs.HasTag(kDetectionsTag);

  RET_CHECK(wants_list || wants_vector || wants_single)
      << "None of the input streams are provided.";

  if (wants_single) inputs.Tag(kDetectionTag).Set<Detection>();
  if (wants_list) inputs.Tag(kDetectionListTag).Set<DetectionList>();
  if (wants_vector) {
    inputs.Tag(kDetectionsTag).Set<std::vector<Detection>>();
  }

  cc->Outputs().Tag(kRenderDataTag).Set<RenderData>();
  return absl::OkStatus();
}

// Output packets carry the same timestamp as the input (see Process), so a
// zero timestamp offset is declared for the outputs.
absl::Status DetectionsToRenderDataCalculator::Open(CalculatorContext* cc) {
  const TimestampDiff same_timestamp(0);
  cc->SetOffset(same_timestamp);
  return absl::OkStatus();
}

// Collects the detections present at the current input timestamp on every
// connected input stream and emits one RenderData packet holding their label,
// feature-tag, bounding-box and keypoint annotations. When none of the inputs
// carries a detection and produce_empty_packet is false, nothing is emitted.
absl::Status DetectionsToRenderDataCalculator::Process(CalculatorContext* cc) {
  const auto& options = cc->Options<DetectionsToRenderDataCalculatorOptions>();
  // With more than one input stream connected, a timestamp may carry a packet
  // on only some of them. Every stream is therefore checked with IsEmpty()
  // before its payload is read, because Get<T>() requires a non-empty packet.
  const bool has_detection_from_list =
      cc->Inputs().HasTag(kDetectionListTag) &&
      !cc->Inputs().Tag(kDetectionListTag).IsEmpty() &&
      !cc->Inputs()
           .Tag(kDetectionListTag)
           .Get<DetectionList>()
           .detection()
           .empty();
  const bool has_detection_from_vector =
      cc->Inputs().HasTag(kDetectionsTag) &&
      !cc->Inputs().Tag(kDetectionsTag).IsEmpty() &&
      !cc->Inputs().Tag(kDetectionsTag).Get<std::vector<Detection>>().empty();
  const bool has_single_detection = cc->Inputs().HasTag(kDetectionTag) &&
                                    !cc->Inputs().Tag(kDetectionTag).IsEmpty();
  if (!options.produce_empty_packet() && !has_detection_from_list &&
      !has_detection_from_vector && !has_single_detection) {
    return absl::OkStatus();
  }

  // TODO: Add score threshold to
  // DetectionsToRenderDataCalculatorOptions.
  auto render_data = absl::make_unique<RenderData>();
  render_data->set_scene_class(options.scene_class());
  if (has_detection_from_list) {
    for (const auto& detection :
         cc->Inputs().Tag(kDetectionListTag).Get<DetectionList>().detection()) {
      AddDetectionToRenderData(detection, options, render_data.get());
    }
  }
  if (has_detection_from_vector) {
    for (const auto& detection :
         cc->Inputs().Tag(kDetectionsTag).Get<std::vector<Detection>>()) {
      AddDetectionToRenderData(detection, options, render_data.get());
    }
  }
  if (has_single_detection) {
    AddDetectionToRenderData(cc->Inputs().Tag(kDetectionTag).Get<Detection>(),
                             options, render_data.get());
  }
  cc->Outputs()
      .Tag(kRenderDataTag)
      .Add(render_data.release(), cc->InputTimestamp());
  return absl::OkStatus();
}

// Copies the configured RGB color and line thickness onto render_annotation.
void DetectionsToRenderDataCalculator::SetRenderAnnotationColorThickness(
    const DetectionsToRenderDataCalculatorOptions& options,
    RenderAnnotation* render_annotation) {
  const auto& configured = options.color();
  auto* target = render_annotation->mutable_color();
  target->set_r(configured.r());
  target->set_g(configured.g());
  target->set_b(configured.b());
  render_annotation->set_thickness(options.thickness());
}

// Writes the text anchor (left, baseline) and the normalized flag. Normalized
// coordinates must be between 0.0 and 1.0, so in that case left is clamped to
// be >= 0.0 and baseline to be <= 1.0; pixel coordinates are passed through.
void DetectionsToRenderDataCalculator::SetTextCoordinate(
    bool normalized, double left, double baseline,
    RenderAnnotation::Text* text) {
  double anchor_x = left;
  double anchor_y = baseline;
  if (normalized) {
    anchor_x = std::max(left, 0.0);
    anchor_y = std::min(baseline, 1.0);
  }
  text->set_normalized(normalized);
  text->set_left(anchor_x);
  text->set_baseline(anchor_y);
}

// Fills rect with the box [xmin, xmin + width] x [ymin, ymin + height].
// The rect is left untouched when the box ends below 0.0 on either axis or,
// for normalized coordinates, starts above 1.0. Normalized edges are clamped
// to [0.0, 1.0].
void DetectionsToRenderDataCalculator::SetRectCoordinate(
    bool normalized, double xmin, double ymin, double width, double height,
    RenderAnnotation::Rectangle* rect) {
  const double xmax = xmin + width;
  const double ymax = ymin + height;
  const bool ends_before_origin = xmax < 0.0 || ymax < 0.0;
  const bool starts_past_unit_edge = normalized && (xmin > 1.0 || ymin > 1.0);
  if (ends_before_origin || starts_past_unit_edge) return;

  rect->set_normalized(normalized);
  // Right/bottom are xmin + width / ymin + height (no "- 1"): for relative
  // coordinates one pixel has no fixed size in [0, 1], and pixel coordinates
  // follow the same convention for consistency.
  if (normalized) {
    rect->set_left(std::max(xmin, 0.0));
    rect->set_top(std::max(ymin, 0.0));
    rect->set_right(std::min(xmax, 1.0));
    rect->set_bottom(std::min(ymax, 1.0));
  } else {
    rect->set_left(xmin);
    rect->set_top(ymin);
    rect->set_right(xmax);
    rect->set_bottom(ymax);
  }
}

// Adds one text annotation per rendered label line of `detection`.
//
// Each "label(_id),score" entry is built from the string label (or, when no
// string labels exist, the integer label id) and the score rounded to two
// decimals, each followed by options.text_delimiter(). The entries are either
// rendered one per line (one_label_per_line) or concatenated into a single
// line. An "Id: <detection_id>" line is prepended when render_detection_id is
// set. Line k (1-based) is anchored at the box's xmin and at
// ymin + k * text_line_height.
//
// Fails an ABSL_CHECK if string and integer labels are both present with
// different sizes, or if the number of scores differs from the number of
// labels.
void DetectionsToRenderDataCalculator::AddLabels(
    const Detection& detection,
    const DetectionsToRenderDataCalculatorOptions& options,
    float text_line_height, RenderData* render_data) {
  ABSL_CHECK(detection.label().empty() || detection.label_id().empty() ||
             detection.label_size() == detection.label_id_size())
      << "String or integer labels should be of same size. Or only one of them "
         "is present.";
  const int num_labels =
      std::max(detection.label_size(), detection.label_id_size());
  ABSL_CHECK_EQ(detection.score_size(), num_labels)
      << "Number of scores and labels should match for detection.";

  // Extracts all "label(_id),score" for the detection.
  std::vector<std::string> label_and_scores;
  label_and_scores.reserve(num_labels);
  for (int i = 0; i < num_labels; ++i) {
    const std::string label_str = detection.label().empty()
                                      ? absl::StrCat(detection.label_id(i))
                                      : detection.label(i);
    const float rounded_score =
        std::round(detection.score(i) * kNumScoreDecimalDigitsMultipler) /
        kNumScoreDecimalDigitsMultipler;
    label_and_scores.push_back(
        absl::StrCat(label_str, options.text_delimiter(), rounded_score,
                     options.text_delimiter()));
  }

  std::vector<std::string> labels;
  if (options.render_detection_id()) {
    labels.push_back(absl::StrCat("Id: ", detection.detection_id()));
  }
  if (options.one_label_per_line()) {
    labels.insert(labels.end(), label_and_scores.begin(),
                  label_and_scores.end());
  } else {
    labels.push_back(absl::StrJoin(label_and_scores, ""));
  }

  // Add the render annotations for "label(_id),score". line_number is 1-based
  // so that the first line's baseline sits one line height below ymin.
  int line_number = 0;
  for (const std::string& label : labels) {
    ++line_number;
    auto* label_annotation = render_data->add_render_annotations();
    label_annotation->set_scene_tag(kSceneLabelLabel);
    SetRenderAnnotationColorThickness(options, label_annotation);
    auto* text = label_annotation->mutable_text();
    *text = options.text();
    text->set_display_text(label);
    if (detection.location_data().format() == LocationData::BOUNDING_BOX) {
      const auto& box = detection.location_data().bounding_box();
      SetTextCoordinate(false, box.xmin(),
                        box.ymin() + line_number * text_line_height, text);
    } else {
      const auto& box = detection.location_data().relative_bounding_box();
      // Relative boxes derive the font size from the computed line height.
      text->set_font_height(text_line_height * 0.9);
      SetTextCoordinate(true, box.xmin(),
                        box.ymin() + line_number * text_line_height, text);
    }
  }
}

// Adds a text annotation showing detection.feature_tag(), anchored at the
// bottom-left corner of the bounding box (xmin, ymin + height). For relative
// boxes the font height is derived from text_line_height.
void DetectionsToRenderDataCalculator::AddFeatureTag(
    const Detection& detection,
    const DetectionsToRenderDataCalculatorOptions& options,
    float text_line_height, RenderData* render_data) {
  auto* annotation = render_data->add_render_annotations();
  annotation->set_scene_tag(kSceneFeatureLabel);
  SetRenderAnnotationColorThickness(options, annotation);
  auto* tag_text = annotation->mutable_text();
  tag_text->set_display_text(detection.feature_tag());

  const auto& location = detection.location_data();
  if (location.format() != LocationData::BOUNDING_BOX) {
    const auto& relative_box = location.relative_bounding_box();
    tag_text->set_font_height(text_line_height * 0.9);
    SetTextCoordinate(true, relative_box.xmin(),
                      relative_box.ymin() + relative_box.height(), tag_text);
    return;
  }

  const auto& pixel_box = location.bounding_box();
  SetTextCoordinate(false, pixel_box.xmin(),
                    pixel_box.ymin() + pixel_box.height(), tag_text);
}

// Adds a rectangle annotation for the detection's bounding box. For
// relative (normalized) boxes, also adds one normalized point annotation per
// relative keypoint; keypoints are not rendered for pixel BOUNDING_BOX data.
void DetectionsToRenderDataCalculator::AddLocationData(
    const Detection& detection,
    const DetectionsToRenderDataCalculatorOptions& options,
    RenderData* render_data) {
  const auto& location = detection.location_data();

  auto* box_annotation = render_data->add_render_annotations();
  box_annotation->set_scene_tag(kSceneLocationLabel);
  SetRenderAnnotationColorThickness(options, box_annotation);
  auto* box_rect = box_annotation->mutable_rectangle();

  if (location.format() == LocationData::BOUNDING_BOX) {
    const auto& pixel_box = location.bounding_box();
    SetRectCoordinate(false, pixel_box.xmin(), pixel_box.ymin(),
                      pixel_box.width(), pixel_box.height(), box_rect);
    return;
  }

  const auto& relative_box = location.relative_bounding_box();
  SetRectCoordinate(true, relative_box.xmin(), relative_box.ymin(),
                    relative_box.width(), relative_box.height(), box_rect);

  // See location_data.proto for the keypoint coordinate convention.
  for (const auto& keypoint : location.relative_keypoints()) {
    auto* point_annotation = render_data->add_render_annotations();
    point_annotation->set_scene_tag(kKeypointLabel);
    SetRenderAnnotationColorThickness(options, point_annotation);
    auto* point = point_annotation->mutable_point();
    point->set_normalized(true);
    point->set_x(keypoint.x());
    point->set_y(keypoint.y());
  }
}

// Appends the label, feature-tag and location annotations for one detection.
//
// For pixel BOUNDING_BOX data the text line height is options.text()
// .font_height(). For RELATIVE_BOUNDING_BOX data it is the box height times
// min(kLabelToBoundingBoxRatio, 1 / (number of labels + 1)), where the extra
// line accounts for the feature tag. Fails an ABSL_CHECK for any other format.
void DetectionsToRenderDataCalculator::AddDetectionToRenderData(
    const Detection& detection,
    const DetectionsToRenderDataCalculatorOptions& options,
    RenderData* render_data) {
  const auto format = detection.location_data().format();
  const bool is_pixel_box = format == LocationData::BOUNDING_BOX;
  ABSL_CHECK(is_pixel_box || format == LocationData::RELATIVE_BOUNDING_BOX)
      << "Only Detection with formats of BOUNDING_BOX or RELATIVE_BOUNDING_BOX "
         "are supported.";

  double text_line_height;
  if (is_pixel_box) {
    text_line_height = options.text().font_height();
  } else {
    const int num_labels =
        std::max(detection.label_size(), detection.label_id_size());
    const int num_text_lines = num_labels + 1;  // +1 for feature_tag.
    const double max_ratio = std::min(
        kLabelToBoundingBoxRatio, 1 / static_cast<double>(num_text_lines));
    text_line_height =
        detection.location_data().relative_bounding_box().height() * max_ratio;
  }

  AddLabels(detection, options, text_line_height, render_data);
  AddFeatureTag(detection, options, text_line_height, render_data);
  AddLocationData(detection, options, render_data);
}
}  // namespace mediapipe