chromium/third_party/mediapipe/src/mediapipe/calculators/util/detection_classifications_merger_calculator.cc

// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <utility>

#include "absl/strings/substitute.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/statusor.h"

namespace mediapipe {
namespace api2 {

namespace {}  // namespace

// Replaces the classification labels and scores from the input `Detection` with
// the ones provided into the input `ClassificationList`. Namely:
// * `label_id[i]` becomes `classification[i].index`
// * `score[i]` becomes `classification[i].score`
// * `label[i]` becomes `classification[i].label` (if present)
// * `display_name[i]` becomes `classification[i].display_name` (if present)
//
// Every input `Classification` must provide both `index` and `score`;
// otherwise an `InvalidArgument` error is returned. `label` and `display_name`
// must each be provided by either all or none of the classifications;
// otherwise an `InvalidArgument` error is returned as well.
//
// In case the input `ClassificationList` contains no results (i.e.
// `classification` is empty, which may happen if the classifier uses a score
// threshold and no confident enough results were returned), the input
// `Detection` is returned unchanged.
//
// If both input streams are empty at a given timestamp, nothing is output for
// that timestamp. If only one of them is empty, processing fails.
//
// This is specifically designed for two-stage detection cascades where the
// detections returned by a standalone detector (typically a class-agnostic
// localizer) are fed e.g. into a `TfLiteTaskImageClassifierCalculator` through
// the optional "RECT" or "NORM_RECT" input, e.g:
//
// node {
//   calculator: "DetectionsToRectsCalculator"
//   # Output of an upstream object detector.
//   input_stream: "DETECTION:detection"
//   output_stream: "NORM_RECT:norm_rect"
// }
// node {
//   calculator: "TfLiteTaskImageClassifierCalculator"
//   input_stream: "IMAGE:image"
//   input_stream: "NORM_RECT:norm_rect"
//   output_stream: "CLASSIFICATION_RESULT:classification_result"
// }
// node {
//   calculator: "TfLiteTaskClassificationResultToClassificationsCalculator"
//   input_stream: "CLASSIFICATION_RESULT:classification_result"
//   output_stream: "CLASSIFICATION_LIST:classification_list"
// }
// node {
//   calculator: "DetectionClassificationsMergerCalculator"
//   input_stream: "INPUT_DETECTION:detection"
//   input_stream: "CLASSIFICATION_LIST:classification_list"
//   # Final output.
//   output_stream: "OUTPUT_DETECTION:classified_detection"
// }
//
// Inputs:
// INPUT_DETECTION: `Detection` proto.
// CLASSIFICATION_LIST: `ClassificationList` proto.
//
// Output:
// OUTPUT_DETECTION: modified `Detection` proto.
class DetectionClassificationsMergerCalculator : public Node {
 public:
  // Detection whose `label_id`, `score`, `label` and `display_name` fields are
  // to be replaced.
  static constexpr Input<Detection> kInputDetection{"INPUT_DETECTION"};
  // Classifications whose values are copied into the detection.
  static constexpr Input<ClassificationList> kClassificationList{
      "CLASSIFICATION_LIST"};
  // Copy of the input detection carrying the merged classification fields.
  static constexpr Output<Detection> kOutputDetection{"OUTPUT_DETECTION"};

  // Contract listing the three ports declared above.
  MEDIAPIPE_NODE_CONTRACT(kInputDetection, kClassificationList,
                          kOutputDetection);

  // Merges the classification list into the detection for the current
  // timestamp and sends the result on `kOutputDetection`.
  absl::Status Process(CalculatorContext* cc) override;
};
MEDIAPIPE_REGISTER_NODE(DetectionClassificationsMergerCalculator);

// Merges the `ClassificationList` received at the current timestamp into a copy
// of the input `Detection` and sends the result on OUTPUT_DETECTION.
//
// Returns:
//   * OK, without emitting anything, when both inputs are empty.
//   * A RET_CHECK failure when exactly one of the two inputs is empty.
//   * InvalidArgument when a `Classification` lacks its `index` or `score`
//     field, or when `label` / `display_name` are set on only some of the
//     classifications.
//   * OK, after sending the (possibly updated) detection, otherwise.
absl::Status DetectionClassificationsMergerCalculator::Process(
    CalculatorContext* cc) {
  // Nothing to do for timestamps at which neither stream carries a packet.
  if (kInputDetection(cc).IsEmpty() && kClassificationList(cc).IsEmpty()) {
    return absl::OkStatus();
  }
  // From here on both packets are required; explain which one is missing.
  RET_CHECK(!kInputDetection(cc).IsEmpty())
      << "INPUT_DETECTION is empty while CLASSIFICATION_LIST is not.";
  RET_CHECK(!kClassificationList(cc).IsEmpty())
      << "CLASSIFICATION_LIST is empty while INPUT_DETECTION is not.";

  // Copy the input detection so that its classification fields can be
  // rewritten below.
  Detection detection = *kInputDetection(cc);
  const ClassificationList& classification_list = *kClassificationList(cc);

  // Update input detection only if classification did return results.
  if (classification_list.classification_size() != 0) {
    detection.clear_label_id();
    detection.clear_score();
    detection.clear_label();
    detection.clear_display_name();
    for (const auto& classification : classification_list.classification()) {
      // Validate the required fields before touching the detection.
      if (!classification.has_index()) {
        return absl::InvalidArgumentError(
            "Missing required 'index' field in Classification proto.");
      }
      if (!classification.has_score()) {
        return absl::InvalidArgumentError(
            "Missing required 'score' field in Classification proto.");
      }
      detection.add_label_id(classification.index());
      detection.add_score(classification.score());
      // `label` and `display_name` are optional; consistency across all
      // classifications is verified after the loop.
      if (classification.has_label()) {
        detection.add_label(classification.label());
      }
      if (classification.has_display_name()) {
        detection.add_display_name(classification.display_name());
      }
    }
    // Post-conversion sanity checks: optional fields must be present on either
    // all classifications or none of them, so that the repeated fields of the
    // detection stay index-aligned with `label_id`.
    if (detection.label_size() != 0 &&
        detection.label_size() != detection.label_id_size()) {
      return absl::InvalidArgumentError(absl::Substitute(
          "Each input Classification is expected to either always or never "
          "provide a 'label' field. Found $0 'label' fields for $1 "
          "'Classification' objects.",
          /*$0=*/detection.label_size(), /*$1=*/detection.label_id_size()));
    }
    if (detection.display_name_size() != 0 &&
        detection.display_name_size() != detection.label_id_size()) {
      return absl::InvalidArgumentError(absl::Substitute(
          "Each input Classification is expected to either always or never "
          "provide a 'display_name' field. Found $0 'display_name' fields for "
          "$1 'Classification' objects.",
          /*$0=*/detection.display_name_size(),
          /*$1=*/detection.label_id_size()));
    }
  }
  // `detection` is not used past this point, so hand it over to the output
  // instead of copying the proto a second time.
  kOutputDetection(cc).Send(std::move(detection));
  return absl::OkStatus();
}

}  // namespace api2
}  // namespace mediapipe