chromium/third_party/mediapipe/src/mediapipe/calculators/util/non_max_suppression_calculator.cc

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "mediapipe/calculators/util/non_max_suppression_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/location.h"
#include "mediapipe/framework/port/rectangle.h"
#include "mediapipe/framework/port/status.h"

namespace mediapipe {

typedef std::vector<Detection> Detections;
typedef std::vector<std::pair<int, float>> IndexedScores;

namespace {

constexpr char kImageTag[] = "IMAGE";

bool SortBySecond(const std::pair<int, float>& indexed_score_0,
                  const std::pair<int, float>& indexed_score_1) {
  return (indexed_score_0.second > indexed_score_1.second);
}

// Removes all but the max scoring label and its score from the detection.
// Returns true if the detection has at least one label.
bool RetainMaxScoringLabelOnly(Detection* detection) {
  if (detection->label_id_size() == 0 && detection->label_size() == 0) {
    return false;
  }
  ABSL_CHECK(detection->label_id_size() == detection->score_size() ||
             detection->label_size() == detection->score_size())
      << "Number of scores must be equal to number of detections.";

  std::vector<std::pair<int, float>> indexed_scores;
  indexed_scores.reserve(detection->score_size());
  for (int k = 0; k < detection->score_size(); ++k) {
    indexed_scores.push_back(std::make_pair(k, detection->score(k)));
  }
  std::sort(indexed_scores.begin(), indexed_scores.end(), SortBySecond);
  const int top_index = indexed_scores[0].first;
  detection->clear_score();
  detection->add_score(indexed_scores[0].second);
  if (detection->label_id_size() > top_index) {
    const int top_label_id = detection->label_id(top_index);
    detection->clear_label_id();
    detection->add_label_id(top_label_id);
  } else {
    const std::string top_label = detection->label(top_index);
    detection->clear_label();
    detection->add_label(top_label);
  }

  return true;
}

// Computes an overlap similarity between two rectangles. Similarity measure is
// defined by overlap_type parameter.
float OverlapSimilarity(
    const NonMaxSuppressionCalculatorOptions::OverlapType overlap_type,
    const Rectangle_f& rect1, const Rectangle_f& rect2) {
  if (!rect1.Intersects(rect2)) return 0.0f;
  const float intersection_area = Rectangle_f(rect1).Intersect(rect2).Area();
  float normalization;
  switch (overlap_type) {
    case NonMaxSuppressionCalculatorOptions::JACCARD:
      normalization = Rectangle_f(rect1).Union(rect2).Area();
      break;
    case NonMaxSuppressionCalculatorOptions::MODIFIED_JACCARD:
      normalization = rect2.Area();
      break;
    case NonMaxSuppressionCalculatorOptions::INTERSECTION_OVER_UNION:
      normalization = rect1.Area() + rect2.Area() - intersection_area;
      break;
    default:
      ABSL_LOG(FATAL) << "Unrecognized overlap type: " << overlap_type;
  }
  return normalization > 0.0f ? intersection_area / normalization : 0.0f;
}

// Computes an overlap similarity between two locations by first extracting the
// relative box (dimension normalized by frame width/height) from the location.
float OverlapSimilarity(
    const int frame_width, const int frame_height,
    const NonMaxSuppressionCalculatorOptions::OverlapType overlap_type,
    const Location& location1, const Location& location2) {
  const auto rect1 = location1.ConvertToRelativeBBox(frame_width, frame_height);
  const auto rect2 = location2.ConvertToRelativeBBox(frame_width, frame_height);
  return OverlapSimilarity(overlap_type, rect1, rect2);
}

// Computes an overlap similarity between two locations by first extracting the
// relative box from the location. It assumes that a relative-box representation
// is already available in the location, and therefore frame width and height
// are not needed for further normalization.
float OverlapSimilarity(
    const NonMaxSuppressionCalculatorOptions::OverlapType overlap_type,
    const Location& location1, const Location& location2) {
  const auto rect1 = location1.GetRelativeBBox();
  const auto rect2 = location2.GetRelativeBBox();
  return OverlapSimilarity(overlap_type, rect1, rect2);
}

}  // namespace

// A calculator performing non-maximum suppression on a set of detections.
// Inputs:
//   1. IMAGE (optional): A stream of ImageFrame used to obtain the frame size.
//      No image data is used. Not needed if the detection bounding boxes are
//      already represented in normalized dimensions (0.0~1.0).
//   2. A variable number of input streams of type std::vector<Detection>. The
//      exact number of such streams should be set via num_detection_streams
//      field in the calculator options.
//
// Outputs: a single stream of type std::vector<Detection> containing a subset
//   of the input detections after non-maximum suppression.
//
// Example config:
// node {
//   calculator: "NonMaxSuppressionCalculator"
//   input_stream: "IMAGE:frames"
//   input_stream: "detections1"
//   input_stream: "detections2"
//   output_stream: "detections"
//   options {
//     [mediapipe.NonMaxSuppressionCalculatorOptions.ext] {
//       num_detection_streams: 2
//       max_num_detections: 10
//       min_suppression_threshold: 0.2
//       overlap_type: JACCARD
//     }
//   }
// }
class NonMaxSuppressionCalculator : public CalculatorBase {
 public:
  NonMaxSuppressionCalculator() = default;
  ~NonMaxSuppressionCalculator() override = default;

  static absl::Status GetContract(CalculatorContract* cc) {
    const auto& options = cc->Options<NonMaxSuppressionCalculatorOptions>();
    if (cc->Inputs().HasTag(kImageTag)) {
      cc->Inputs().Tag(kImageTag).Set<ImageFrame>();
    }
    for (int k = 0; k < options.num_detection_streams(); ++k) {
      cc->Inputs().Index(k).Set<Detections>();
    }
    cc->Outputs().Index(0).Set<Detections>();
    return absl::OkStatus();
  }

  absl::Status Open(CalculatorContext* cc) override {
    cc->SetOffset(TimestampDiff(0));

    options_ = cc->Options<NonMaxSuppressionCalculatorOptions>();
    ABSL_CHECK_GT(options_.num_detection_streams(), 0)
        << "At least one detection stream need to be specified.";
    ABSL_CHECK_NE(options_.max_num_detections(), 0)
        << "max_num_detections=0 is not a valid value. Please choose a "
        << "positive number of you want to limit the number of output "
        << "detections, or set -1 if you do not want any limit.";
    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) override {
    // Add all input detections to the same vector.
    Detections input_detections;
    for (int i = 0; i < options_.num_detection_streams(); ++i) {
      const auto& detections_packet = cc->Inputs().Index(i).Value();
      // Check whether this stream has a packet for this timestamp.
      if (detections_packet.IsEmpty()) {
        continue;
      }
      const auto& detections = detections_packet.Get<Detections>();

      input_detections.insert(input_detections.end(), detections.begin(),
                              detections.end());
    }

    // Check if there are any detections at all.
    if (input_detections.empty()) {
      if (options_.return_empty_detections()) {
        cc->Outputs().Index(0).Add(new Detections(), cc->InputTimestamp());
      }
      return absl::OkStatus();
    }

    // Remove all but the maximum scoring label from each input detection. This
    // corresponds to non-maximum suppression among detections which have
    // identical locations.
    Detections pruned_detections;
    pruned_detections.reserve(input_detections.size());
    for (auto& detection : input_detections) {
      if (RetainMaxScoringLabelOnly(&detection)) {
        pruned_detections.push_back(detection);
      }
    }

    // Copy all the scores (there is a single score in each detection after
    // the above pruning) to an indexed vector for sorting. The first value is
    // the index of the detection in the original vector from which the score
    // stems, while the second is the actual score.
    IndexedScores indexed_scores;
    indexed_scores.reserve(pruned_detections.size());
    for (int index = 0; index < pruned_detections.size(); ++index) {
      indexed_scores.push_back(
          std::make_pair(index, pruned_detections[index].score(0)));
    }
    std::sort(indexed_scores.begin(), indexed_scores.end(), SortBySecond);

    const int max_num_detections =
        (options_.max_num_detections() > -1)
            ? options_.max_num_detections()
            : static_cast<int>(indexed_scores.size());
    // A set of detections and locations, wrapping the location data from each
    // detection, which are retained after the non-maximum suppression.
    auto* retained_detections = new Detections();
    retained_detections->reserve(max_num_detections);

    if (options_.algorithm() == NonMaxSuppressionCalculatorOptions::WEIGHTED) {
      WeightedNonMaxSuppression(indexed_scores, pruned_detections,
                                max_num_detections, cc, retained_detections);
    } else {
      NonMaxSuppression(indexed_scores, pruned_detections, max_num_detections,
                        cc, retained_detections);
    }

    cc->Outputs().Index(0).Add(retained_detections, cc->InputTimestamp());

    return absl::OkStatus();
  }

 private:
  void NonMaxSuppression(const IndexedScores& indexed_scores,
                         const Detections& detections, int max_num_detections,
                         CalculatorContext* cc, Detections* output_detections) {
    std::vector<Location> retained_locations;
    retained_locations.reserve(max_num_detections);
    // We traverse the detections by decreasing score.
    for (const auto& indexed_score : indexed_scores) {
      const auto& detection = detections[indexed_score.first];
      if (options_.min_score_threshold() > 0 &&
          detection.score(0) < options_.min_score_threshold()) {
        break;
      }
      const Location location(detection.location_data());
      bool suppressed = false;
      // The current detection is suppressed iff there exists a retained
      // detection, whose location overlaps more than the specified
      // threshold with the location of the current detection.
      for (const auto& retained_location : retained_locations) {
        float similarity;
        if (cc->Inputs().HasTag(kImageTag)) {
          const auto& frame = cc->Inputs().Tag(kImageTag).Get<ImageFrame>();
          similarity = OverlapSimilarity(frame.Width(), frame.Height(),
                                         options_.overlap_type(),
                                         retained_location, location);
        } else {
          similarity = OverlapSimilarity(options_.overlap_type(),
                                         retained_location, location);
        }
        if (similarity > options_.min_suppression_threshold()) {
          suppressed = true;
          break;
        }
      }
      if (!suppressed) {
        output_detections->push_back(detection);
        retained_locations.push_back(location);
      }
      if (output_detections->size() >= max_num_detections) {
        break;
      }
    }
  }

  void WeightedNonMaxSuppression(const IndexedScores& indexed_scores,
                                 const Detections& detections,
                                 int max_num_detections, CalculatorContext* cc,
                                 Detections* output_detections) {
    IndexedScores remained_indexed_scores;
    remained_indexed_scores.assign(indexed_scores.begin(),
                                   indexed_scores.end());

    IndexedScores remained;
    IndexedScores candidates;
    output_detections->clear();
    while (!remained_indexed_scores.empty()) {
      const int original_indexed_scores_size = remained_indexed_scores.size();
      const auto& detection = detections[remained_indexed_scores[0].first];
      if (options_.min_score_threshold() > 0 &&
          detection.score(0) < options_.min_score_threshold()) {
        break;
      }
      remained.clear();
      candidates.clear();
      const Location location(detection.location_data());
      // This includes the first box.
      for (const auto& indexed_score : remained_indexed_scores) {
        Location rest_location(detections[indexed_score.first].location_data());
        float similarity =
            OverlapSimilarity(options_.overlap_type(), rest_location, location);
        if (similarity > options_.min_suppression_threshold()) {
          candidates.push_back(indexed_score);
        } else {
          remained.push_back(indexed_score);
        }
      }
      auto weighted_detection = detection;
      if (!candidates.empty()) {
        const int num_keypoints =
            detection.location_data().relative_keypoints_size();
        std::vector<float> keypoints(num_keypoints * 2);
        float w_xmin = 0.0f;
        float w_ymin = 0.0f;
        float w_xmax = 0.0f;
        float w_ymax = 0.0f;
        float total_score = 0.0f;
        for (const auto& candidate : candidates) {
          total_score += candidate.second;
          const auto& location_data =
              detections[candidate.first].location_data();
          const auto& bbox = location_data.relative_bounding_box();
          w_xmin += bbox.xmin() * candidate.second;
          w_ymin += bbox.ymin() * candidate.second;
          w_xmax += (bbox.xmin() + bbox.width()) * candidate.second;
          w_ymax += (bbox.ymin() + bbox.height()) * candidate.second;

          for (int i = 0; i < num_keypoints; ++i) {
            keypoints[i * 2] +=
                location_data.relative_keypoints(i).x() * candidate.second;
            keypoints[i * 2 + 1] +=
                location_data.relative_keypoints(i).y() * candidate.second;
          }
        }
        auto* weighted_location = weighted_detection.mutable_location_data()
                                      ->mutable_relative_bounding_box();
        weighted_location->set_xmin(w_xmin / total_score);
        weighted_location->set_ymin(w_ymin / total_score);
        weighted_location->set_width((w_xmax / total_score) -
                                     weighted_location->xmin());
        weighted_location->set_height((w_ymax / total_score) -
                                      weighted_location->ymin());
        for (int i = 0; i < num_keypoints; ++i) {
          auto* keypoint = weighted_detection.mutable_location_data()
                               ->mutable_relative_keypoints(i);
          keypoint->set_x(keypoints[i * 2] / total_score);
          keypoint->set_y(keypoints[i * 2 + 1] / total_score);
        }
      }

      output_detections->push_back(weighted_detection);
      // Breaks the loop if the size of indexed scores doesn't change after an
      // iteration.
      if (original_indexed_scores_size == remained.size()) {
        break;
      } else {
        remained_indexed_scores = std::move(remained);
      }
    }
  }

  NonMaxSuppressionCalculatorOptions options_;
};
REGISTER_CALCULATOR(NonMaxSuppressionCalculator);

}  // namespace mediapipe