// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <cstdint>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/calculators/tensorflow/object_detection_tensors_to_detections_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/location.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/source_location.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_builder.h"
#include "mediapipe/util/tensor_to_detection.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
namespace mediapipe {
class CalculatorOptions;
} // namespace mediapipe
namespace mediapipe {
namespace tf = ::tensorflow;
namespace {
const char kNumDetections[] = "NUM_DETECTIONS";
const char kBoxes[] = "BOXES";
const char kScores[] = "SCORES";
const char kClasses[] = "CLASSES";
const char kDetections[] = "DETECTIONS";
const char kKeypoints[] = "KEYPOINTS";
const char kMasks[] = "MASKS";
const char kLabelMap[] = "LABELMAP";
const int kNumCoordsPerBox = 4;
} // namespace
// Takes object detection results and converts them into MediaPipe Detections.
//
// Inputs are assumed to be tensors of the form:
//   `num_detections`     : float32 scalar tensor indicating the number of
//                          valid detections.
//   `detection_boxes`    : float32 tensor of the form [num_boxes, 4].
//                          Format for coordinates is {ymin, xmin, ymax, xmax}.
//   `detection_scores`   : float32 tensor of the form [num_boxes].
//   `detection_classes`  : float32 tensor of the form [num_boxes].
//   `detection_keypoints`: float32 tensor of the form
//                          [num_boxes, num_keypoints, 2].
//   `detection_masks`    : float32 tensor of the form
//                          [num_boxes, height, width].
//
// These are generated according to the Vale object detector model exporter,
// which may be found in
// image/understanding/object_detection/export_inference_graph.py
//
// By default, the output Detections store integer label ids for each
// detection. Optionally, a label map (a std::map<int, std::string> mapping
// label ids to label names) can be provided as an input side packet, in
// which case the output Detections store the label strings from the map
// instead (see the second example below).
//
// Usage example:
// node {
//   calculator: "ObjectDetectionTensorsToDetectionsCalculator"
//   input_stream: "BOXES:detection_boxes_tensor"
//   input_stream: "SCORES:detection_scores_tensor"
//   input_stream: "CLASSES:detection_classes_tensor"
//   input_stream: "NUM_DETECTIONS:num_detections_tensor"
//   output_stream: "DETECTIONS:detections"
//   options: {
//     [mediapipe.ObjectDetectionsTensorToDetectionsCalculatorOptions.ext]: {
//       tensor_dim_to_squeeze: 0
//     }
//   }
// }
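//
// A variant of the example above that also supplies the optional label map as
// an input side packet (the stream and side packet names here are
// illustrative):
// node {
//   calculator: "ObjectDetectionTensorsToDetectionsCalculator"
//   input_stream: "BOXES:detection_boxes_tensor"
//   input_stream: "SCORES:detection_scores_tensor"
//   input_stream: "CLASSES:detection_classes_tensor"
//   input_stream: "NUM_DETECTIONS:num_detections_tensor"
//   input_side_packet: "LABELMAP:label_map"
//   output_stream: "DETECTIONS:detections"
// }
// The LABELMAP side packet must hold a
// std::unique_ptr<std::map<int, std::string>>, for example a packet created
// with mediapipe::AdoptAsUniquePtr() when the graph is started.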
class ObjectDetectionTensorsToDetectionsCalculator : public CalculatorBase {
public:
ObjectDetectionTensorsToDetectionsCalculator() = default;
static absl::Status GetContract(CalculatorContract* cc) {
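    // BOXES and SCORES are required inputs; NUM_DETECTIONS, CLASSES,
    // KEYPOINTS, and MASKS are optional.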
cc->Inputs().Tag(kBoxes).Set<tf::Tensor>();
cc->Inputs().Tag(kScores).Set<tf::Tensor>();
if (cc->Inputs().HasTag(kNumDetections)) {
cc->Inputs().Tag(kNumDetections).Set<tf::Tensor>();
}
if (cc->Inputs().HasTag(kClasses)) {
cc->Inputs().Tag(kClasses).Set<tf::Tensor>();
}
if (cc->Inputs().HasTag(kKeypoints)) {
cc->Inputs().Tag(kKeypoints).Set<tf::Tensor>();
}
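    // When masks are consumed, validate the configured mask threshold so that
    // a misconfigured graph is rejected up front.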
if (cc->Inputs().HasTag(kMasks)) {
cc->Inputs().Tag(kMasks).Set<tf::Tensor>();
const auto& calculator_options =
cc->Options<ObjectDetectionsTensorToDetectionsCalculatorOptions>();
float mask_threshold = calculator_options.mask_threshold();
if (!(mask_threshold >= 0.0 && mask_threshold <= 1.0)) {
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
<< "mask_threshold must be in range [0.0, 1.0]";
}
}
cc->Outputs().Tag(kDetections).Set<std::vector<Detection>>();
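    // The optional LABELMAP side packet maps integer label ids to label
    // strings (see the class comment above).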
if (cc->InputSidePackets().HasTag(kLabelMap)) {
cc->InputSidePackets()
.Tag(kLabelMap)
.Set<std::unique_ptr<std::map<int, std::string>>>();
}
return absl::OkStatus();
}
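
  // Caches the optional label map and the list of tensor dimensions to
  // squeeze from the calculator options.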
absl::Status Open(CalculatorContext* cc) override {
if (cc->InputSidePackets().HasTag(kLabelMap)) {
label_map_ = GetFromUniquePtr<std::map<int, std::string>>(
cc->InputSidePackets().Tag(kLabelMap));
}
const auto& tensor_dim_to_squeeze_field =
cc->Options<ObjectDetectionsTensorToDetectionsCalculatorOptions>()
.tensor_dim_to_squeeze();
tensor_dims_to_squeeze_ = std::vector<int32_t>(
tensor_dim_to_squeeze_field.begin(), tensor_dim_to_squeeze_field.end());
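    // Sort in descending order so that removing one dimension does not shift
    // the indices of the dimensions still to be removed.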
std::sort(tensor_dims_to_squeeze_.rbegin(), tensor_dims_to_squeeze_.rend());
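    // Output packets carry the same timestamp as the input packets.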
cc->SetOffset(0);
return absl::OkStatus();
}
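
  // Converts the tensors for the current timestamp into a vector of
  // Detections and emits it on the DETECTIONS output stream.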
absl::Status Process(CalculatorContext* cc) override {
const auto& options =
cc->Options<ObjectDetectionsTensorToDetectionsCalculatorOptions>();
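    // Empty placeholder tensors stand in for optional inputs that are not
    // wired into the graph.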
tf::Tensor input_num_detections_tensor =
tf::Tensor(tf::DT_FLOAT, tf::TensorShape({0}));
    if (cc->Inputs().HasTag(kNumDetections)) {
      MP_ASSIGN_OR_RETURN(
          input_num_detections_tensor,
          MaybeSqueezeDims(kNumDetections,
                           cc->Inputs().Tag(kNumDetections).Get<tf::Tensor>()));
    }
if (input_num_detections_tensor.dtype() != tf::DT_INT32) {
RET_CHECK_EQ(input_num_detections_tensor.dtype(), tf::DT_FLOAT);
}
MP_ASSIGN_OR_RETURN(
auto input_boxes_tensor,
MaybeSqueezeDims(kBoxes, cc->Inputs().Tag(kBoxes).Get<tf::Tensor>()));
RET_CHECK_EQ(input_boxes_tensor.dtype(), tf::DT_FLOAT);
MP_ASSIGN_OR_RETURN(
auto input_scores_tensor,
MaybeSqueezeDims(kScores, cc->Inputs().Tag(kScores).Get<tf::Tensor>()));
RET_CHECK_EQ(input_scores_tensor.dtype(), tf::DT_FLOAT);
tf::Tensor input_classes_tensor =
tf::Tensor(tf::DT_FLOAT, tf::TensorShape({0}));
if (cc->Inputs().HasTag(kClasses)) {
MP_ASSIGN_OR_RETURN(
input_classes_tensor,
MaybeSqueezeDims(kClasses,
cc->Inputs().Tag(kClasses).Get<tf::Tensor>()));
}
RET_CHECK_EQ(input_classes_tensor.dtype(), tf::DT_FLOAT);
auto output_detections = absl::make_unique<std::vector<Detection>>();
const tf::Tensor& input_keypoints_tensor =
cc->Inputs().HasTag(kKeypoints)
? cc->Inputs().Tag(kKeypoints).Get<tf::Tensor>()
: tf::Tensor(tf::DT_FLOAT, tf::TensorShape({0, 0, 0}));
const tf::Tensor& input_masks_tensor =
cc->Inputs().HasTag(kMasks)
? cc->Inputs().Tag(kMasks).Get<tf::Tensor>()
: tf::Tensor(tf::DT_FLOAT, tf::TensorShape({0, 0, 0}));
RET_CHECK_EQ(input_masks_tensor.dtype(), tf::DT_FLOAT);
const std::map<int, std::string> label_map =
(label_map_ == nullptr) ? std::map<int, std::string>{} : *label_map_;
RET_CHECK_OK(TensorsToDetections(
input_num_detections_tensor, input_boxes_tensor, input_scores_tensor,
input_classes_tensor, input_keypoints_tensor, input_masks_tensor,
options.mask_threshold(), label_map, output_detections.get()));
cc->Outputs()
.Tag(kDetections)
.Add(output_detections.release(), cc->InputTimestamp());
return absl::OkStatus();
}

 private:
  // Maps label ids to label names. Points into the LABELMAP input side packet
  // when provided; remains null otherwise.
  std::map<int, std::string>* label_map_ = nullptr;
  std::vector<int32_t> tensor_dims_to_squeeze_;
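
  // Returns `input_tensor` reshaped so that each dimension listed in
  // `tensor_dims_to_squeeze_` is removed. Each such dimension must exist and
  // have size 1; otherwise an error is returned. When no dimensions are
  // configured, the tensor is returned unchanged.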
absl::StatusOr<tf::Tensor> MaybeSqueezeDims(const std::string& tensor_tag,
const tf::Tensor& input_tensor) {
if (tensor_dims_to_squeeze_.empty()) {
return input_tensor;
}
tf::TensorShape tensor_shape = input_tensor.shape();
for (const int dim : tensor_dims_to_squeeze_) {
      RET_CHECK_GT(tensor_shape.dims(), dim)
          << "Dimension " << dim << " does not exist in " << tensor_tag
          << " tensor with " << input_tensor.dims() << " dimensions";
RET_CHECK_EQ(tensor_shape.dim_size(dim), 1)
<< "Cannot remove dimension " << dim << " with size "
<< tensor_shape.dim_size(dim);
tensor_shape.RemoveDim(dim);
}
tf::Tensor output_tensor;
RET_CHECK(output_tensor.CopyFrom(input_tensor, tensor_shape));
return std::move(output_tensor);
}
};
REGISTER_CALCULATOR(ObjectDetectionTensorsToDetectionsCalculator);
} // namespace mediapipe