chromium/third_party/mediapipe/src/mediapipe/modules/objectron/calculators/filter_detection_calculator.cc

// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <memory>
#include <string>
#include <vector>

#include "absl/container/node_hash_set.h"
#include "absl/log/absl_log.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/location_data.pb.h"
#include "mediapipe/framework/port/map_util.h"
#include "mediapipe/framework/port/re2.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/modules/objectron/calculators/filter_detection_calculator.pb.h"

namespace mediapipe {

namespace {

constexpr char kDetectionTag[] = "DETECTION";
constexpr char kDetectionsTag[] = "DETECTIONS";
constexpr char kLabelsTag[] = "LABELS";
constexpr char kLabelsCsvTag[] = "LABELS_CSV";
constexpr char kLabelMapTag[] = "LABEL_MAP";

using mediapipe::RE2;
using Detections = std::vector<Detection>;
using Strings = std::vector<std::string>;

struct FirstGreaterComparator {
  bool operator()(const std::pair<float, int>& a,
                  const std::pair<float, int>& b) const {
    return a.first > b.first;
  }
};

absl::Status SortLabelsByDecreasingScore(const Detection& detection,
                                         Detection* sorted_detection) {
  RET_CHECK(sorted_detection);
  RET_CHECK_EQ(detection.score_size(), detection.label_size());
  if (!detection.label_id().empty()) {
    RET_CHECK_EQ(detection.score_size(), detection.label_id_size());
  }
  // Copies input to keep all fields unchanged, and to reserve space for
  // repeated fields. Repeated fields (score, label, and label_id) will be
  // overwritten.
  *sorted_detection = detection;

  std::vector<std::pair<float, int>> scores_and_indices(detection.score_size());
  for (int i = 0; i < detection.score_size(); ++i) {
    scores_and_indices[i].first = detection.score(i);
    scores_and_indices[i].second = i;
  }

  std::sort(scores_and_indices.begin(), scores_and_indices.end(),
            FirstGreaterComparator());

  for (int i = 0; i < detection.score_size(); ++i) {
    const int index = scores_and_indices[i].second;
    sorted_detection->set_score(i, detection.score(index));
    sorted_detection->set_label(i, detection.label(index));
  }

  if (!detection.label_id().empty()) {
    for (int i = 0; i < detection.score_size(); ++i) {
      const int index = scores_and_indices[i].second;
      sorted_detection->set_label_id(i, detection.label_id(index));
    }
  }
  return absl::OkStatus();
}

}  // namespace

// Filters the entries in a Detection to only those with valid scores
// for the specified allowed labels. Allowed labels are provided as a
// std::vector<std::string> in an optional input side packet. Allowed labels can
// contain simple strings or regular expressions. The valid score range
// can be set in the options.The allowed labels can be provided as
// std::vector<std::string> (LABELS) or CSV string (LABELS_CSV) containing class
// names of allowed labels. Note: Providing an empty vector in the input side
// packet Packet causes this calculator to act as a sink if
// empty_allowed_labels_means_allow_everything is set to false (default value).
// To allow all labels, use the calculator with no input side packet stream, or
// set empty_allowed_labels_means_allow_everything to true.
//
// Example config:
// node {
//   calculator: "FilterDetectionCalculator"
//   input_stream: "DETECTIONS:detections"
//   output_stream: "DETECTIONS:filtered_detections"
//   input_side_packet: "LABELS:allowed_labels"
//   options: {
//     [mediapipe.FilterDetectionCalculatorOptions.ext]: {
//       min_score: 0.5
//     }
//   }
// }

class FilterDetectionCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc);
  absl::Status Open(CalculatorContext* cc) override;
  absl::Status Process(CalculatorContext* cc) override;

 private:
  bool IsValidLabel(const std::string& label);
  bool IsValidScore(float score);
  // Stores numeric limits for filtering on the score.
  FilterDetectionCalculatorOptions options_;
  // We use the next two fields to possibly filter to a limited set of
  // classes.  The hash_set will be empty in two cases: 1) if no input
  // side packet stream is provided (not filtering on labels), or 2)
  // if the input side packet contains an empty vector (no labels are
  // allowed). We use limit_labels_ to distinguish between the two cases.
  bool limit_labels_ = true;
  absl::node_hash_set<std::string> allowed_labels_;
};
REGISTER_CALCULATOR(FilterDetectionCalculator);

absl::Status FilterDetectionCalculator::GetContract(CalculatorContract* cc) {
  RET_CHECK(!cc->Inputs().GetTags().empty());
  RET_CHECK(!cc->Outputs().GetTags().empty());

  if (cc->Inputs().HasTag(kDetectionTag)) {
    cc->Inputs().Tag(kDetectionTag).Set<Detection>();
    cc->Outputs().Tag(kDetectionTag).Set<Detection>();
  }
  if (cc->Inputs().HasTag(kDetectionsTag)) {
    cc->Inputs().Tag(kDetectionsTag).Set<Detections>();
    cc->Outputs().Tag(kDetectionsTag).Set<Detections>();
  }
  if (cc->InputSidePackets().HasTag(kLabelsTag)) {
    cc->InputSidePackets().Tag(kLabelsTag).Set<Strings>();
  }
  if (cc->InputSidePackets().HasTag(kLabelsCsvTag)) {
    cc->InputSidePackets().Tag(kLabelsCsvTag).Set<std::string>();
  }
  if (cc->InputSidePackets().HasTag(kLabelMapTag)) {
    cc->InputSidePackets()
        .Tag(kLabelMapTag)
        .Set<std::unique_ptr<std::map<int, std::string>>>();
  }
  return absl::OkStatus();
}

absl::Status FilterDetectionCalculator::Open(CalculatorContext* cc) {
  cc->SetOffset(TimestampDiff(0));
  options_ = cc->Options<FilterDetectionCalculatorOptions>();
  limit_labels_ = cc->InputSidePackets().HasTag(kLabelsTag) ||
                  cc->InputSidePackets().HasTag(kLabelsCsvTag) ||
                  cc->InputSidePackets().HasTag(kLabelMapTag);
  if (limit_labels_) {
    Strings allowlist_labels;
    if (cc->InputSidePackets().HasTag(kLabelsCsvTag)) {
      allowlist_labels = absl::StrSplit(
          cc->InputSidePackets().Tag(kLabelsCsvTag).Get<std::string>(), ',',
          absl::SkipWhitespace());
      for (auto& e : allowlist_labels) {
        absl::StripAsciiWhitespace(&e);
      }
    } else if (cc->InputSidePackets().HasTag(kLabelsTag)) {
      allowlist_labels = cc->InputSidePackets().Tag(kLabelsTag).Get<Strings>();
    } else if (cc->InputSidePackets().HasTag(kLabelMapTag)) {
      auto label_map = cc->InputSidePackets()
                           .Tag(kLabelMapTag)
                           .Get<std::unique_ptr<std::map<int, std::string>>>()
                           .get();
      for (const auto& [_, v] : *label_map) {
        allowlist_labels.push_back(v);
      }
    }
    allowed_labels_.insert(allowlist_labels.begin(), allowlist_labels.end());
  }
  if (limit_labels_ && allowed_labels_.empty()) {
    if (options_.fail_on_empty_labels()) {
      cc->GetCounter("VideosWithEmptyLabelsAllowlist")->Increment();
      return tool::StatusFail(
          "FilterDetectionCalculator received empty allowlist with "
          "fail_on_empty_labels = true.");
    }
    if (options_.empty_allowed_labels_means_allow_everything()) {
      // Continue as if side_input was not provided, i.e. pass all labels.
      limit_labels_ = false;
    }
  }
  return absl::OkStatus();
}

absl::Status FilterDetectionCalculator::Process(CalculatorContext* cc) {
  if (limit_labels_ && allowed_labels_.empty()) {
    return absl::OkStatus();
  }
  Detections detections;
  if (cc->Inputs().HasTag(kDetectionsTag)) {
    detections = cc->Inputs().Tag(kDetectionsTag).Get<Detections>();
  } else if (cc->Inputs().HasTag(kDetectionTag)) {
    detections.emplace_back(cc->Inputs().Tag(kDetectionTag).Get<Detection>());
  }
  std::unique_ptr<Detections> outputs(new Detections);
  for (const auto& input : detections) {
    Detection output;
    for (int i = 0; i < input.label_size(); ++i) {
      const std::string& label = input.label(i);
      const float score = input.score(i);
      if (IsValidLabel(label) && IsValidScore(score)) {
        output.add_label(label);
        output.add_score(score);
      }
    }
    if (output.label_size() > 0) {
      if (input.has_location_data()) {
        *output.mutable_location_data() = input.location_data();
      }
      Detection output_sorted;
      if (!SortLabelsByDecreasingScore(output, &output_sorted).ok()) {
        // Uses the orginal output if fails to sort.
        cc->GetCounter("FailedToSortLabelsInDetection")->Increment();
        output_sorted = output;
      }
      outputs->emplace_back(output_sorted);
    }
  }

  if (cc->Outputs().HasTag(kDetectionsTag)) {
    cc->Outputs()
        .Tag(kDetectionsTag)
        .Add(outputs.release(), cc->InputTimestamp());
  } else if (!outputs->empty()) {
    cc->Outputs()
        .Tag(kDetectionTag)
        .Add(new Detection((*outputs)[0]), cc->InputTimestamp());
  }
  return absl::OkStatus();
}

bool FilterDetectionCalculator::IsValidLabel(const std::string& label) {
  bool match = !limit_labels_ || allowed_labels_.contains(label);
  if (!match) {
    // If no exact match is found, check for regular expression
    // comparions in the allowed_labels.
    for (const auto& label_regexp : allowed_labels_) {
      match = match || RE2::FullMatch(label, RE2(label_regexp));
    }
  }
  return match;
}

bool FilterDetectionCalculator::IsValidScore(float score) {
  if (options_.has_min_score() && score < options_.min_score()) {
    ABSL_LOG(ERROR) << "Filter out detection with low score " << score;
    return false;
  }
  if (options_.has_max_score() && score > options_.max_score()) {
    ABSL_LOG(ERROR) << "Filter out detection with high score " << score;
    return false;
  }
  return true;
}

}  // namespace mediapipe