chromium/third_party/mediapipe/src/mediapipe/modules/holistic_landmark/calculators/roi_tracking_calculator.cc

// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <math.h>

#include <cstdlib>

#include "absl/strings/str_format.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/rectangle.h"
#include "mediapipe/modules/holistic_landmark/calculators/roi_tracking_calculator.pb.h"

namespace mediapipe {

namespace {

constexpr char kPrevLandmarksTag[] = "PREV_LANDMARKS";
constexpr char kPrevLandmarksRectTag[] = "PREV_LANDMARKS_RECT";
constexpr char kRecropRectTag[] = "RECROP_RECT";
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
constexpr char kTrackingRectTag[] = "TRACKING_RECT";

using ::mediapipe::NormalizedRect;

// TODO: Use rect rotation.
// Verifies that Intersection over Union of previous frame rect and current
// frame re-crop rect is less than threshold.
bool IouRequirementsSatisfied(const NormalizedRect& prev_rect,
                              const NormalizedRect& recrop_rect,
                              const std::pair<int, int>& image_size,
                              const float min_iou) {
  auto r1 = Rectangle_f(prev_rect.x_center() * image_size.first,
                        prev_rect.y_center() * image_size.second,
                        prev_rect.width() * image_size.first,
                        prev_rect.height() * image_size.second);
  auto r2 = Rectangle_f(recrop_rect.x_center() * image_size.first,
                        recrop_rect.y_center() * image_size.second,
                        recrop_rect.width() * image_size.first,
                        recrop_rect.height() * image_size.second);

  const float intersection_area = r1.Intersect(r2).Area();
  const float union_area = r1.Area() + r2.Area() - intersection_area;

  const float intersection_threshold = union_area * min_iou;
  if (intersection_area < intersection_threshold) {
    VLOG(1) << absl::StrFormat("Lost tracking: IoU intersection %f < %f",
                               intersection_area, intersection_threshold);
    return false;
  }

  return true;
}

// Verifies that current frame re-crop rect rotation/translation/scale didn't
// change much comparing to the previous frame rect. Translation and scale are
// normalized by current frame re-crop rect.
bool RectRequirementsSatisfied(const NormalizedRect& prev_rect,
                               const NormalizedRect& recrop_rect,
                               const std::pair<int, int> image_size,
                               const float rotation_degrees,
                               const float translation, const float scale) {
  // Rotate both rects so that re-crop rect edges are parallel to XY axes. That
  // will allow to compute x/y translation of the previous frame rect along axes
  // of the current frame re-crop rect.
  const float rotation = -recrop_rect.rotation();

  const float cosa = cos(rotation);
  const float sina = sin(rotation);

  // Rotate previous frame rect and get its parameters.
  const float prev_rect_x = prev_rect.x_center() * image_size.first * cosa -
                            prev_rect.y_center() * image_size.second * sina;
  const float prev_rect_y = prev_rect.x_center() * image_size.first * sina +
                            prev_rect.y_center() * image_size.second * cosa;
  const float prev_rect_width = prev_rect.width() * image_size.first;
  const float prev_rect_height = prev_rect.height() * image_size.second;
  const float prev_rect_rotation = prev_rect.rotation() / M_PI * 180.f;

  // Rotate current frame re-crop rect and get its parameters.
  const float recrop_rect_x = recrop_rect.x_center() * image_size.first * cosa -
                              recrop_rect.y_center() * image_size.second * sina;
  const float recrop_rect_y = recrop_rect.x_center() * image_size.first * sina +
                              recrop_rect.y_center() * image_size.second * cosa;
  const float recrop_rect_width = recrop_rect.width() * image_size.first;
  const float recrop_rect_height = recrop_rect.height() * image_size.second;
  const float recrop_rect_rotation = recrop_rect.rotation() / M_PI * 180.f;

  // Rect requirements are satisfied unless one of the checks below fails.
  bool satisfied = true;

  // Ensure that rotation diff is in [0, 180] range.
  float rotation_diff = prev_rect_rotation - recrop_rect_rotation;
  if (rotation_diff > 180.f) {
    rotation_diff -= 360.f;
  }
  if (rotation_diff < -180.f) {
    rotation_diff += 360.f;
  }
  rotation_diff = abs(rotation_diff);
  if (rotation_diff > rotation_degrees) {
    satisfied = false;
    VLOG(1) << absl::StrFormat("Lost tracking: rect rotation %f > %f",
                               rotation_diff, rotation_degrees);
  }

  const float x_diff = abs(prev_rect_x - recrop_rect_x);
  const float x_threshold = recrop_rect_width * translation;
  if (x_diff > x_threshold) {
    satisfied = false;
    VLOG(1) << absl::StrFormat("Lost tracking: rect x translation %f > %f",
                               x_diff, x_threshold);
  }

  const float y_diff = abs(prev_rect_y - recrop_rect_y);
  const float y_threshold = recrop_rect_height * translation;
  if (y_diff > y_threshold) {
    satisfied = false;
    VLOG(1) << absl::StrFormat("Lost tracking: rect y translation %f > %f",
                               y_diff, y_threshold);
  }

  const float width_diff = abs(prev_rect_width - recrop_rect_width);
  const float width_threshold = recrop_rect_width * scale;
  if (width_diff > width_threshold) {
    satisfied = false;
    VLOG(1) << absl::StrFormat("Lost tracking: rect width %f > %f", width_diff,
                               width_threshold);
  }

  const float height_diff = abs(prev_rect_height - recrop_rect_height);
  const float height_threshold = recrop_rect_height * scale;
  if (height_diff > height_threshold) {
    satisfied = false;
    VLOG(1) << absl::StrFormat("Lost tracking: rect height %f > %f",
                               height_diff, height_threshold);
  }

  return satisfied;
}

// Verifies that landmarks from the previous frame are within re-crop rectangle
// bounds on the current frame.
bool LandmarksRequirementsSatisfied(const NormalizedLandmarkList& landmarks,
                                    const NormalizedRect& recrop_rect,
                                    const std::pair<int, int> image_size,
                                    const float recrop_rect_margin) {
  // Rotate both re-crop rectangle and landmarks so that re-crop rectangle edges
  // are parallel to XY axes. It will allow to easily check if landmarks are
  // within re-crop rect bounds along re-crop rect axes.
  //
  // Rect rotation is specified clockwise. To apply cos/sin functions we
  // transform it into counterclockwise.
  const float rotation = -recrop_rect.rotation();

  const float cosa = cos(rotation);
  const float sina = sin(rotation);

  // Rotate rect.
  const float rect_x = recrop_rect.x_center() * image_size.first * cosa -
                       recrop_rect.y_center() * image_size.second * sina;
  const float rect_y = recrop_rect.x_center() * image_size.first * sina +
                       recrop_rect.y_center() * image_size.second * cosa;
  const float rect_width =
      recrop_rect.width() * image_size.first * (1.f + recrop_rect_margin);
  const float rect_height =
      recrop_rect.height() * image_size.second * (1.f + recrop_rect_margin);

  // Get rect bounds.
  const float rect_left = rect_x - rect_width * 0.5f;
  const float rect_right = rect_x + rect_width * 0.5f;
  const float rect_top = rect_y - rect_height * 0.5f;
  const float rect_bottom = rect_y + rect_height * 0.5f;

  for (int i = 0; i < landmarks.landmark_size(); ++i) {
    const auto& landmark = landmarks.landmark(i);
    const float x = landmark.x() * image_size.first * cosa -
                    landmark.y() * image_size.second * sina;
    const float y = landmark.x() * image_size.first * sina +
                    landmark.y() * image_size.second * cosa;

    if (!(rect_left < x && x < rect_right && rect_top < y && y < rect_bottom)) {
      VLOG(1) << "Lost tracking: landmarks out of re-crop rect";
      return false;
    }
  }

  return true;
}

}  // namespace

// A calculator to track object rectangle between frames.
//
// Calculator checks that all requirements for tracking are satisfied and uses
// rectangle from the previous frame in this case, otherwise - uses current
// frame re-crop rectangle.
//
// There are several types of tracking requirements that can be configured via
// options:
//   IoU: Verifies that IoU of the previous frame rectangle and current frame
//     re-crop rectangle is less than a given threshold.
//   Rect parameters: Verifies that rotation/translation/scale of the re-crop
//     rectangle on the current frame is close to the rectangle from the
//     previous frame within given thresholds.
//   Landmarks: Verifies that landmarks from the previous frame are within
//     the re-crop rectangle on the current frame.
//
// Inputs:
//   PREV_LANDMARKS: Object landmarks from the previous frame.
//   PREV_LANDMARKS_RECT: Object rectangle based on the landmarks from the
//     previous frame.
//   RECROP_RECT: Object re-crop rectangle from the current frame.
//   IMAGE_SIZE: Image size to transform normalized coordinates to absolute.
//
// Outputs:
//   TRACKING_RECT: Rectangle to use for object prediction on the current frame.
//     It will be either object rectangle from the previous frame (if all
//     tracking requirements are satisfied) or re-crop rectangle from the
//     current frame (if tracking lost the object).
//
// Example config:
//   node {
//     calculator: "RoiTrackingCalculator"
//     input_stream: "PREV_LANDMARKS:prev_hand_landmarks"
//     input_stream: "PREV_LANDMARKS_RECT:prev_hand_landmarks_rect"
//     input_stream: "RECROP_RECT:hand_recrop_rect"
//     input_stream: "IMAGE_SIZE:image_size"
//     output_stream: "TRACKING_RECT:hand_tracking_rect"
//     options: {
//       [mediapipe.RoiTrackingCalculatorOptions.ext] {
//         rect_requirements: {
//           rotation_degrees: 40.0
//           translation: 0.2
//           scale: 0.4
//         }
//         landmarks_requirements: {
//           recrop_rect_margin: -0.1
//         }
//       }
//     }
//   }
class RoiTrackingCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc);
  absl::Status Open(CalculatorContext* cc) override;
  absl::Status Process(CalculatorContext* cc) override;

 private:
  RoiTrackingCalculatorOptions options_;
};
REGISTER_CALCULATOR(RoiTrackingCalculator);

absl::Status RoiTrackingCalculator::GetContract(CalculatorContract* cc) {
  cc->Inputs().Tag(kPrevLandmarksTag).Set<NormalizedLandmarkList>();
  cc->Inputs().Tag(kPrevLandmarksRectTag).Set<NormalizedRect>();
  cc->Inputs().Tag(kRecropRectTag).Set<NormalizedRect>();
  cc->Inputs().Tag(kImageSizeTag).Set<std::pair<int, int>>();
  cc->Outputs().Tag(kTrackingRectTag).Set<NormalizedRect>();

  return absl::OkStatus();
}

absl::Status RoiTrackingCalculator::Open(CalculatorContext* cc) {
  cc->SetOffset(TimestampDiff(0));
  options_ = cc->Options<RoiTrackingCalculatorOptions>();
  return absl::OkStatus();
}

absl::Status RoiTrackingCalculator::Process(CalculatorContext* cc) {
  // If there is no current frame re-crop rect (i.e. object is not present on
  // the current frame) - return empty packet.
  if (cc->Inputs().Tag(kRecropRectTag).IsEmpty()) {
    return absl::OkStatus();
  }

  // If there is no previous rect, but there is current re-crop rect - return
  // current re-crop rect as is.
  if (cc->Inputs().Tag(kPrevLandmarksRectTag).IsEmpty()) {
    cc->Outputs()
        .Tag(kTrackingRectTag)
        .AddPacket(cc->Inputs().Tag(kRecropRectTag).Value());
    return absl::OkStatus();
  }

  // At this point we have both previous rect (which also means we have previous
  // landmarks) and currrent re-crop rect.
  const auto& prev_landmarks =
      cc->Inputs().Tag(kPrevLandmarksTag).Get<NormalizedLandmarkList>();
  const auto& prev_rect =
      cc->Inputs().Tag(kPrevLandmarksRectTag).Get<NormalizedRect>();
  const auto& recrop_rect =
      cc->Inputs().Tag(kRecropRectTag).Get<NormalizedRect>();
  const auto& image_size =
      cc->Inputs().Tag(kImageSizeTag).Get<std::pair<int, int>>();

  // Keep tracking unless one of the requirements below is not satisfied.
  bool keep_tracking = true;

  // If IoU of the previous rect and current re-crop rect is lower than allowed
  // threshold - use current re-crop rect.
  if (options_.has_iou_requirements() &&
      !IouRequirementsSatisfied(prev_rect, recrop_rect, image_size,
                                options_.iou_requirements().min_iou())) {
    keep_tracking = false;
  }

  // If previous rect and current re-crop rect differ more than it is allowed by
  // the augmentations (used during the model training) - use current re-crop
  // rect.
  if (options_.has_rect_requirements() &&
      !RectRequirementsSatisfied(
          prev_rect, recrop_rect, image_size,
          options_.rect_requirements().rotation_degrees(),
          options_.rect_requirements().translation(),
          options_.rect_requirements().scale())) {
    keep_tracking = false;
  }

  // If landmarks from the previous frame are not in the current re-crop rect
  // (i.e. object moved too fast and using previous frame rect won't cover
  // landmarks on the current frame) - use current re-crop rect.
  if (options_.has_landmarks_requirements() &&
      !LandmarksRequirementsSatisfied(
          prev_landmarks, recrop_rect, image_size,
          options_.landmarks_requirements().recrop_rect_margin())) {
    keep_tracking = false;
  }

  // If object didn't move a lot comparing to the previous frame - we'll keep
  // tracking it and will return rect from the previous frame, otherwise -
  // return re-crop rect from the current frame.
  if (keep_tracking) {
    cc->Outputs()
        .Tag(kTrackingRectTag)
        .AddPacket(cc->Inputs().Tag(kPrevLandmarksRectTag).Value());
  } else {
    cc->Outputs()
        .Tag(kTrackingRectTag)
        .AddPacket(cc->Inputs().Tag(kRecropRectTag).Value());
    VLOG(1) << "Lost tracking: check messages above for details";
  }

  return absl::OkStatus();
}

}  // namespace mediapipe