chromium/third_party/mediapipe/src/mediapipe/calculators/util/detections_to_timed_box_list_calculator.cc

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/location_data.pb.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/util/tracking/box_tracker.h"

namespace mediapipe {

namespace {

// Stream tags used by DetectionsToTimedBoxListCalculator. At least one of the
// two input tags has to be connected; GetContract rejects the graph otherwise.

// Input stream whose packets hold a std::vector<Detection>.
constexpr char kDetectionsTag[] = "DETECTIONS";
// Input stream whose packets hold a DetectionList.
constexpr char kDetectionListTag[] = "DETECTION_LIST";
// Output stream whose packets hold the converted TimedBoxProtoList.
constexpr char kBoxesTag[] = "BOXES";

}  // namespace

// Calculator that turns detections into a TimedBoxProtoList, e.g. as input for
// box tracking.
//
// Inputs (at least one of them must be connected):
//   DETECTION_LIST: a DetectionList.
//   DETECTIONS: a std::vector<Detection>.
// Output:
//   BOXES: a TimedBoxProtoList holding one box per converted detection,
//     emitted at the input timestamp. When both inputs are connected, boxes
//     converted from DETECTION_LIST come before those from DETECTIONS.
//
// Only the relative_bounding_box of each detection's location data is read, so
// RELATIVE_BOUNDING_BOX is the only supported Location Data format.
//
// Example config:
// node {
//   calculator: "DetectionsToTimedBoxListCalculator"
//   input_stream: "DETECTIONS:detections"
//   output_stream: "BOXES:boxes"
// }
class DetectionsToTimedBoxListCalculator : public CalculatorBase {
 public:
  // Declares the packet type of every connected stream. Fails when neither
  // DETECTION_LIST nor DETECTIONS is connected.
  static absl::Status GetContract(CalculatorContract* cc) {
    RET_CHECK(cc->Inputs().HasTag(kDetectionListTag) ||
              cc->Inputs().HasTag(kDetectionsTag))
        << "None of the input streams are provided.";

    auto& inputs = cc->Inputs();
    if (inputs.HasTag(kDetectionsTag)) {
      inputs.Tag(kDetectionsTag).Set<std::vector<Detection>>();
    }
    if (inputs.HasTag(kDetectionListTag)) {
      inputs.Tag(kDetectionListTag).Set<DetectionList>();
    }
    cc->Outputs().Tag(kBoxesTag).Set<TimedBoxProtoList>();

    return absl::OkStatus();
  }

  // Registers a timestamp offset of zero with the framework; no other setup
  // is performed.
  absl::Status Open(CalculatorContext* cc) override {
    cc->SetOffset(TimestampDiff(0));
    return absl::OkStatus();
  }

  // Converts the detections of the current timestamp and emits them on BOXES.
  absl::Status Process(CalculatorContext* cc) override;

 private:
  // Fills `box` with the bounding box edges and id of `detection`, and with
  // the current input timestamp of `cc` expressed in milliseconds.
  void ConvertDetectionToTimedBox(const Detection& detection,
                                  TimedBoxProto* box, CalculatorContext* cc);
};
REGISTER_CALCULATOR(DetectionsToTimedBoxListCalculator);

// Converts every detection available at the current input timestamp into a
// TimedBoxProto and emits the resulting TimedBoxProtoList on BOXES at that
// same timestamp. Boxes from DETECTION_LIST are appended before boxes from
// DETECTIONS.
//
// A connected input stream can be without a packet at a given timestamp (for
// instance when both inputs are connected and only one of them received a
// packet). Get<T>() must not be called on such an empty packet, so an empty
// input is skipped and simply contributes no boxes.
//
// Returns absl::OkStatus() once the output packet has been added.
absl::Status DetectionsToTimedBoxListCalculator::Process(
    CalculatorContext* cc) {
  auto output_timed_box_list = absl::make_unique<TimedBoxProtoList>();

  if (cc->Inputs().HasTag(kDetectionListTag) &&
      !cc->Inputs().Tag(kDetectionListTag).IsEmpty()) {
    const auto& detection_list =
        cc->Inputs().Tag(kDetectionListTag).Get<DetectionList>();
    for (const auto& detection : detection_list.detection()) {
      TimedBoxProto* box = output_timed_box_list->add_box();
      ConvertDetectionToTimedBox(detection, box, cc);
    }
  }
  if (cc->Inputs().HasTag(kDetectionsTag) &&
      !cc->Inputs().Tag(kDetectionsTag).IsEmpty()) {
    const auto& detections =
        cc->Inputs().Tag(kDetectionsTag).Get<std::vector<Detection>>();
    for (const auto& detection : detections) {
      TimedBoxProto* box = output_timed_box_list->add_box();
      ConvertDetectionToTimedBox(detection, box, cc);
    }
  }

  // The output stream takes ownership of the released list.
  cc->Outputs().Tag(kBoxesTag).Add(output_timed_box_list.release(),
                                   cc->InputTimestamp());
  return absl::OkStatus();
}

// Populates `box` from `detection`:
//   - left/top are the xmin/ymin of the detection's relative bounding box,
//   - right/bottom are xmin + width and ymin + height of that same box,
//   - id is the detection's detection_id(),
//   - time_msec is the current input timestamp of `cc`, taken in microseconds
//     and divided by 1000.
void DetectionsToTimedBoxListCalculator::ConvertDetectionToTimedBox(
    const Detection& detection, TimedBoxProto* box, CalculatorContext* cc) {
  box->set_id(detection.detection_id());
  box->set_time_msec(cc->InputTimestamp().Microseconds() / 1000);

  const auto& bbox = detection.location_data().relative_bounding_box();
  const auto left = bbox.xmin();
  const auto top = bbox.ymin();
  box->set_left(left);
  box->set_top(top);
  box->set_right(left + bbox.width());
  box->set_bottom(top + bbox.height());
}

}  // namespace mediapipe