// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "absl/container/node_hash_map.h"
#include "mediapipe/calculators/video/tracked_detection_manager_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/location_data.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/util/tracking/box_tracker.h"
#include "mediapipe/util/tracking/tracked_detection.h"
#include "mediapipe/util/tracking/tracked_detection_manager.h"
#include "mediapipe/util/tracking/tracking.h"
namespace mediapipe {
namespace {
using ::mediapipe::NormalizedRect;
constexpr int kDetectionUpdateTimeOutMS = 5000;
constexpr char kDetectionsTag[] = "DETECTIONS";
constexpr char kDetectionBoxesTag[] = "DETECTION_BOXES";
constexpr char kDetectionListTag[] = "DETECTION_LIST";
constexpr char kTrackingBoxesTag[] = "TRACKING_BOXES";
constexpr char kCancelObjectIdTag[] = "CANCEL_OBJECT_ID";
// Move |src| to the back of |dst|.
void MoveIds(std::vector<int>* dst, std::vector<int> src) {
dst->insert(dst->end(), std::make_move_iterator(src.begin()),
std::make_move_iterator(src.end()));
}
int64_t GetInputTimestampMs(::mediapipe::CalculatorContext* cc) {
return cc->InputTimestamp().Microseconds() / 1000; // 1 ms = 1000 us.
}
// Converts a Mediapipe Detection Proto to a TrackedDetection class.
std::unique_ptr<TrackedDetection> GetTrackedDetectionFromDetection(
const Detection& detection, int64_t timestamp) {
std::unique_ptr<TrackedDetection> tracked_detection =
absl::make_unique<TrackedDetection>(detection.detection_id(), timestamp);
const float top = detection.location_data().relative_bounding_box().ymin();
const float bottom =
detection.location_data().relative_bounding_box().ymin() +
detection.location_data().relative_bounding_box().height();
const float left = detection.location_data().relative_bounding_box().xmin();
const float right = detection.location_data().relative_bounding_box().xmin() +
detection.location_data().relative_bounding_box().width();
NormalizedRect bounding_box;
bounding_box.set_x_center((left + right) / 2.f);
bounding_box.set_y_center((top + bottom) / 2.f);
bounding_box.set_height(bottom - top);
bounding_box.set_width(right - left);
tracked_detection->set_bounding_box(bounding_box);
for (int i = 0; i < detection.label_size(); ++i) {
tracked_detection->AddLabel(detection.label(i), detection.score(i));
}
return tracked_detection;
}
// Converts a TrackedDetection class to a Mediapipe Detection Proto.
Detection GetAxisAlignedDetectionFromTrackedDetection(
const TrackedDetection& tracked_detection) {
Detection detection;
LocationData* location_data = detection.mutable_location_data();
auto corners = tracked_detection.GetCorners();
float x_min = std::numeric_limits<float>::max();
float x_max = std::numeric_limits<float>::min();
float y_min = std::numeric_limits<float>::max();
float y_max = std::numeric_limits<float>::min();
for (int i = 0; i < 4; ++i) {
x_min = std::min(x_min, corners[i].x());
x_max = std::max(x_max, corners[i].x());
y_min = std::min(y_min, corners[i].y());
y_max = std::max(y_max, corners[i].y());
}
location_data->set_format(LocationData::RELATIVE_BOUNDING_BOX);
LocationData::RelativeBoundingBox* relative_bbox =
location_data->mutable_relative_bounding_box();
relative_bbox->set_xmin(x_min);
relative_bbox->set_ymin(y_min);
relative_bbox->set_width(x_max - x_min);
relative_bbox->set_height(y_max - y_min);
// Use previous id which is the id the object when it's first detected.
if (tracked_detection.previous_id() > 0) {
detection.set_detection_id(tracked_detection.previous_id());
} else {
detection.set_detection_id(tracked_detection.unique_id());
}
// Sort the labels by descending scores.
std::vector<std::pair<std::string, float>> labels_and_scores;
for (const auto& label_and_score : tracked_detection.label_to_score_map()) {
labels_and_scores.push_back(label_and_score);
}
std::sort(labels_and_scores.begin(), labels_and_scores.end(),
[](const auto& a, const auto& b) { return a.second > b.second; });
for (const auto& label_and_score : labels_and_scores) {
detection.add_label(label_and_score.first);
detection.add_score(label_and_score.second);
}
return detection;
}
} // namespace
// TrackedDetectionManagerCalculator accepts detections and tracking results at
// different frame rate for real time tracking of targets.
// Input:
// DETECTIONS: A vector<Detection> of newly detected targets.
// TRACKING_BOXES: A TimedBoxProtoList which contains a list of tracked boxes
// from previous detections.
//
// Output:
// CANCEL_OBJECT_ID: Ids of targets that are missing/lost such that it should
// be removed from tracking.
// DETECTIONS: List of detections that are being tracked.
// DETECTION_BOXES: List of bounding boxes of detections that are being
// tracked.
//
// Usage example:
// node {
// calculator: "TrackedDetectionManagerCalculator"
// input_stream: "DETECTIONS:detections"
// input_stream: "TRACKING_BOXES:boxes"
// output_stream: "CANCEL_OBJECT_ID:cancel_object_id"
// output_stream: "DETECTIONS:output_detections"
// }
class TrackedDetectionManagerCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc);
absl::Status Open(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override;
private:
// Adds new list of detections to |waiting_for_update_detections_|.
void AddDetectionList(const DetectionList& detection_list,
CalculatorContext* cc);
void AddDetections(const std::vector<Detection>& detections,
CalculatorContext* cc);
// Manages existing and new detections.
TrackedDetectionManager tracked_detection_manager_;
// Set of detections that are not up to date yet. These detections will be
// added to the detection manager until they got updated from the box tracker.
absl::node_hash_map<int, std::unique_ptr<TrackedDetection>>
waiting_for_update_detections_;
};
REGISTER_CALCULATOR(TrackedDetectionManagerCalculator);
absl::Status TrackedDetectionManagerCalculator::GetContract(
CalculatorContract* cc) {
if (cc->Inputs().HasTag(kDetectionsTag)) {
cc->Inputs().Tag(kDetectionsTag).Set<std::vector<Detection>>();
}
if (cc->Inputs().HasTag(kDetectionListTag)) {
cc->Inputs().Tag(kDetectionListTag).Set<DetectionList>();
}
if (cc->Inputs().HasTag(kTrackingBoxesTag)) {
cc->Inputs().Tag(kTrackingBoxesTag).Set<TimedBoxProtoList>();
}
if (cc->Outputs().HasTag(kCancelObjectIdTag)) {
cc->Outputs().Tag(kCancelObjectIdTag).Set<int>();
}
if (cc->Outputs().HasTag(kDetectionsTag)) {
cc->Outputs().Tag(kDetectionsTag).Set<std::vector<Detection>>();
}
if (cc->Outputs().HasTag(kDetectionBoxesTag)) {
cc->Outputs().Tag(kDetectionBoxesTag).Set<std::vector<NormalizedRect>>();
}
return absl::OkStatus();
}
absl::Status TrackedDetectionManagerCalculator::Open(CalculatorContext* cc) {
mediapipe::TrackedDetectionManagerCalculatorOptions options =
cc->Options<mediapipe::TrackedDetectionManagerCalculatorOptions>();
tracked_detection_manager_.SetConfig(
options.tracked_detection_manager_options());
return absl::OkStatus();
}
absl::Status TrackedDetectionManagerCalculator::Process(CalculatorContext* cc) {
if (cc->Inputs().HasTag(kTrackingBoxesTag) &&
!cc->Inputs().Tag(kTrackingBoxesTag).IsEmpty()) {
const TimedBoxProtoList& tracked_boxes =
cc->Inputs().Tag(kTrackingBoxesTag).Get<TimedBoxProtoList>();
// Collect all detections that are removed.
auto removed_detection_ids = absl::make_unique<std::vector<int>>();
for (const TimedBoxProto& tracked_box : tracked_boxes.box()) {
NormalizedRect bounding_box;
bounding_box.set_x_center((tracked_box.left() + tracked_box.right()) /
2.f);
bounding_box.set_y_center((tracked_box.bottom() + tracked_box.top()) /
2.f);
bounding_box.set_height(tracked_box.bottom() - tracked_box.top());
bounding_box.set_width(tracked_box.right() - tracked_box.left());
bounding_box.set_rotation(tracked_box.rotation());
// First check if this box updates a detection that's waiting for
// update from the tracker.
auto waiting_for_update_detectoin_ptr =
waiting_for_update_detections_.find(tracked_box.id());
if (waiting_for_update_detectoin_ptr !=
waiting_for_update_detections_.end()) {
// Add the detection and remove duplicated detections.
auto removed_ids = tracked_detection_manager_.AddDetection(
std::move(waiting_for_update_detectoin_ptr->second));
MoveIds(removed_detection_ids.get(), std::move(removed_ids));
waiting_for_update_detections_.erase(waiting_for_update_detectoin_ptr);
}
auto removed_ids = tracked_detection_manager_.UpdateDetectionLocation(
tracked_box.id(), bounding_box, tracked_box.time_msec());
MoveIds(removed_detection_ids.get(), std::move(removed_ids));
}
// TODO: Should be handled automatically in detection manager.
auto removed_ids = tracked_detection_manager_.RemoveObsoleteDetections(
GetInputTimestampMs(cc) - kDetectionUpdateTimeOutMS);
MoveIds(removed_detection_ids.get(), std::move(removed_ids));
// TODO: Should be handled automatically in detection manager.
removed_ids = tracked_detection_manager_.RemoveOutOfViewDetections();
MoveIds(removed_detection_ids.get(), std::move(removed_ids));
if (!removed_detection_ids->empty() &&
cc->Outputs().HasTag(kCancelObjectIdTag)) {
auto timestamp = cc->InputTimestamp();
for (int box_id : *removed_detection_ids) {
// The timestamp is incremented (by 1 us) because currently the box
// tracker calculator only accepts one cancel object ID for any given
// timestamp.
cc->Outputs()
.Tag(kCancelObjectIdTag)
.AddPacket(mediapipe::MakePacket<int>(box_id).At(timestamp++));
}
}
// Output detections and corresponding bounding boxes.
const auto& all_detections =
tracked_detection_manager_.GetAllTrackedDetections();
auto output_detections = absl::make_unique<std::vector<Detection>>();
auto output_boxes = absl::make_unique<std::vector<NormalizedRect>>();
for (const auto& detection_ptr : all_detections) {
const auto& detection = *detection_ptr.second;
// Only output detections that are synced.
if (detection.last_updated_timestamp() <
cc->InputTimestamp().Microseconds() / 1000) {
continue;
}
output_detections->emplace_back(
GetAxisAlignedDetectionFromTrackedDetection(detection));
output_boxes->emplace_back(detection.bounding_box());
}
if (cc->Outputs().HasTag(kDetectionsTag)) {
cc->Outputs()
.Tag(kDetectionsTag)
.Add(output_detections.release(), cc->InputTimestamp());
}
if (cc->Outputs().HasTag(kDetectionBoxesTag)) {
cc->Outputs()
.Tag(kDetectionBoxesTag)
.Add(output_boxes.release(), cc->InputTimestamp());
}
}
if (cc->Inputs().HasTag(kDetectionsTag) &&
!cc->Inputs().Tag(kDetectionsTag).IsEmpty()) {
const auto detections =
cc->Inputs().Tag(kDetectionsTag).Get<std::vector<Detection>>();
AddDetections(detections, cc);
}
if (cc->Inputs().HasTag(kDetectionListTag) &&
!cc->Inputs().Tag(kDetectionListTag).IsEmpty()) {
const auto detection_list =
cc->Inputs().Tag(kDetectionListTag).Get<DetectionList>();
AddDetectionList(detection_list, cc);
}
return absl::OkStatus();
}
void TrackedDetectionManagerCalculator::AddDetectionList(
const DetectionList& detection_list, CalculatorContext* cc) {
for (const auto& detection : detection_list.detection()) {
// Convert from microseconds to milliseconds.
std::unique_ptr<TrackedDetection> new_detection =
GetTrackedDetectionFromDetection(
detection, cc->InputTimestamp().Microseconds() / 1000);
const int id = new_detection->unique_id();
waiting_for_update_detections_[id] = std::move(new_detection);
}
}
void TrackedDetectionManagerCalculator::AddDetections(
const std::vector<Detection>& detections, CalculatorContext* cc) {
for (const auto& detection : detections) {
// Convert from microseconds to milliseconds.
std::unique_ptr<TrackedDetection> new_detection =
GetTrackedDetectionFromDetection(
detection, cc->InputTimestamp().Microseconds() / 1000);
const int id = new_detection->unique_id();
waiting_for_update_detections_[id] = std::move(new_detection);
}
}
} // namespace mediapipe