chromium/third_party/mediapipe/src/mediapipe/calculators/util/detection_unique_id_calculator.cc

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <atomic>
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/port/status.h"

namespace mediapipe {

namespace {

constexpr char kDetectionsTag[] = "DETECTIONS";
constexpr char kDetectionListTag[] = "DETECTION_LIST";

// Process-wide counter behind GetNextDetectionId(). Each detection processed
// by any DetectionUniqueIdCalculator is assigned a unique id starting from 1;
// an id already present on a detection (other than 0) is overwritten.
//
// The counter is shared by every calculator instance in the process, so it is
// atomic: Process() of different instances may run on different threads
// (e.g. on a multi-threaded executor), and a plain `++` on a shared global
// would be a data race.
std::atomic<int64_t> detection_id{0};

// Returns the next unique detection id: 1, 2, 3, ...
// The value is narrowed to int because that is what callers pass to
// Detection::set_detection_id(); presumably the proto field is 32-bit —
// TODO confirm against detection.proto.
inline int GetNextDetectionId() {
  // fetch_add returns the previous value, hence the +1. Relaxed ordering is
  // enough: only uniqueness of the returned values matters, and
  // read-modify-write operations on a single atomic are totally ordered.
  return static_cast<int>(
      detection_id.fetch_add(1, std::memory_order_relaxed) + 1);
}

}  // namespace

// Stamps every incoming detection with a unique id drawn from a counter that
// is shared by all instances of this calculator in the process.
//
// Detections may arrive as a std::vector<Detection> on the "DETECTIONS" tag,
// as a DetectionList on the "DETECTION_LIST" tag, or both. Every input tag
// that is used requires an output stream with the same tag and type.
// The calculator tries to take ownership of (consume) each input packet so it
// can rewrite the detections in place, so the input stream should not be
// shared with other calculators.
//
// Example config:
// node {
//   calculator: "DetectionUniqueIdCalculator"
//   input_stream: "DETECTIONS:detections"
//   output_stream: "DETECTIONS:output_detections"
// }
class DetectionUniqueIdCalculator : public CalculatorBase {
 public:
  // Declares the stream types. At least one of the two input tags must be
  // present, and each present input tag needs its matching output tag.
  static absl::Status GetContract(CalculatorContract* cc) {
    const bool has_detection_list = cc->Inputs().HasTag(kDetectionListTag);
    const bool has_detections = cc->Inputs().HasTag(kDetectionsTag);
    RET_CHECK(has_detection_list || has_detections)
        << "None of the input streams are provided.";

    // Each input is forwarded with its type unchanged, so it needs an output
    // stream of the same type under the same tag.
    if (has_detection_list) {
      RET_CHECK(cc->Outputs().HasTag(kDetectionListTag));
      cc->Inputs().Tag(kDetectionListTag).Set<DetectionList>();
      cc->Outputs().Tag(kDetectionListTag).Set<DetectionList>();
    }
    if (has_detections) {
      RET_CHECK(cc->Outputs().HasTag(kDetectionsTag));
      cc->Inputs().Tag(kDetectionsTag).Set<std::vector<Detection>>();
      cc->Outputs().Tag(kDetectionsTag).Set<std::vector<Detection>>();
    }
    return absl::OkStatus();
  }

  // Outputs are emitted at the input timestamp, hence the zero offset.
  absl::Status Open(CalculatorContext* cc) override {
    const auto zero_offset = mediapipe::TimestampDiff(0);
    cc->SetOffset(zero_offset);
    return absl::OkStatus();
  }

  // Assigns ids to the current input packets and forwards them.
  absl::Status Process(CalculatorContext* cc) override;
};
REGISTER_CALCULATOR(DetectionUniqueIdCalculator);

absl::Status DetectionUniqueIdCalculator::Process(CalculatorContext* cc) {
  if (cc->Inputs().HasTag(kDetectionListTag) &&
      !cc->Inputs().Tag(kDetectionListTag).IsEmpty()) {
    auto result =
        cc->Inputs().Tag(kDetectionListTag).Value().Consume<DetectionList>();
    if (result.ok()) {
      auto detection_list = std::move(result).value();
      for (Detection& detection : *detection_list->mutable_detection()) {
        detection.set_detection_id(GetNextDetectionId());
      }
      cc->Outputs()
          .Tag(kDetectionListTag)
          .Add(detection_list.release(), cc->InputTimestamp());
    }
  }

  if (cc->Inputs().HasTag(kDetectionsTag) &&
      !cc->Inputs().Tag(kDetectionsTag).IsEmpty()) {
    auto result = cc->Inputs()
                      .Tag(kDetectionsTag)
                      .Value()
                      .Consume<std::vector<Detection>>();
    if (result.ok()) {
      auto detections = std::move(result).value();
      for (Detection& detection : *detections) {
        detection.set_detection_id(GetNextDetectionId());
      }
      cc->Outputs()
          .Tag(kDetectionsTag)
          .Add(detections.release(), cc->InputTimestamp());
    }
  }
  return absl::OkStatus();
}

}  // namespace mediapipe