// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/formats/motion/optical_flow_field.h"
#include <math.h>
#include <cmath>
#include <cstdint>
#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/deps/mathutil.h"
#include "mediapipe/framework/formats/location.h"
#include "mediapipe/framework/formats/location_opencv.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/point2.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/type_map.h"
namespace {
// Magnitudes whose absolute value is at or above this threshold are skipped
// when MaxAbsoluteValueIgnoringHuge searches for the largest magnitude.
constexpr float kHugeToIgnore = 1e9f;
// Tags from the Middlebury ".flo" file specification used to check that floats
// are stored little endian: the 4-byte tag "PIEH" at the start of a file reads
// back as the float 202021.25 on a little-endian machine.
constexpr char kFloFileHeaderOnWrite[] = "PIEH";
constexpr float kFloFileHeaderOnRead = 202021.25f;
// Converts a two-channel (x, y) matrix into per-element polar coordinates.
// `magnitudes` receives the vector lengths and `angles` receives the angles in
// degrees (cv::cartToPolar is called with angleInDegrees = true). Both output
// pointers must be non-null.
void CartesianToPolarCoordinates(const cv::Mat& cartesian, cv::Mat* magnitudes,
                                 cv::Mat* angles) {
  ABSL_CHECK(magnitudes != nullptr);
  ABSL_CHECK(angles != nullptr);
  cv::Mat channels[2];
  cv::split(cartesian, channels);
  const cv::Mat& x_components = channels[0];
  const cv::Mat& y_components = channels[1];
  constexpr bool kAngleInDegrees = true;
  cv::cartToPolar(x_components, y_components, *magnitudes, *angles,
                  kAngleInDegrees);
}
// Returns the largest |value| in `values` that is strictly below `huge`, or 0
// if no element qualifies. Elements with |value| >= huge are skipped, and so
// are NaNs, because both comparisons below are false for NaN.
float MaxAbsoluteValueIgnoringHuge(const cv::Mat_<float>& values, float huge) {
  float largest = 0.0f;
  for (int row = 0; row < values.rows; ++row) {
    const float* row_values = values[row];
    for (int col = 0; col < values.cols; ++col) {
      const float magnitude = std::abs(row_values[col]);
      if (magnitude < huge && largest < magnitude) {
        largest = magnitude;
      }
    }
  }
  return largest;
}
// Builds an 8-bit, 3-channel HSV image from per-pixel flow angles and
// magnitudes. The angles are presumably in degrees, as produced by
// CartesianToPolarCoordinates.
//   - Hue maps degrees onto a 0-255 scale (255 * angle / 360).
//   - Saturation is 255 * magnitude / max_mag. It is fixed at 255 when
//     magnitude >= max_mag (or when the comparison is false, e.g. NaN).
//   - Value is always 255.
// `max_mag` must be positive; the caller is responsible for that.
cv::Mat MakeVisualizationHsv(const cv::Mat_<float>& angles,
                             const cv::Mat_<float>& magnitudes, float max_mag) {
  constexpr uint8_t kFullSaturation = 255;
  constexpr uint8_t kFullValue = 255;
  cv::Mat hsv(angles.size(), CV_8UC3);
  for (int row = 0; row < hsv.rows; ++row) {
    cv::Vec3b* out = hsv.ptr<cv::Vec3b>(row);
    for (int col = 0; col < hsv.cols; ++col) {
      const float angle = angles(row, col);
      const float magnitude = magnitudes(row, col);
      const uint8_t hue = static_cast<uint8_t>(255.0f * angle / 360.0f);
      const uint8_t saturation =
          magnitude < max_mag
              ? static_cast<uint8_t>(255.0f * magnitude / max_mag)
              : kFullSaturation;
      out[col] = cv::Vec3b(hue, saturation, kFullValue);
    }
  }
  return hsv;
}
} // namespace
namespace mediapipe {
// Creates a flow field holding a deep copy of `flow`. Later changes to the
// caller's matrix do not affect this object.
OpticalFlowField::OpticalFlowField(const cv::Mat_<cv::Point2f>& flow)
    : flow_data_(flow.clone()) {}
float OpticalFlowField::GetRobustMaximumMagnitude() const {
cv::Mat angles;
cv::Mat magnitudes;
CartesianToPolarCoordinates(flow_data_, &magnitudes, &angles);
return MaxAbsoluteValueIgnoringHuge(magnitudes, kHugeToIgnore);
}
// Renders the flow field as an RGB image. Hue encodes flow direction and
// saturation encodes magnitude relative to a saturation level:
//   - If `enforce_max_magnitude` is true, `max_magnitude` is that level.
//   - Otherwise `max_magnitude` is ignored and replaced by the field's largest
//     magnitude below kHugeToIgnore, floored at float epsilon.
// Dies if the resulting saturation level is not positive.
cv::Mat OpticalFlowField::GetVisualizationInternal(
    float max_magnitude, bool enforce_max_magnitude) const {
  cv::Mat magnitudes;
  cv::Mat angles;
  CartesianToPolarCoordinates(flow_data_, &magnitudes, &angles);
  // The epsilon floor guards against dividing by zero for an all-zero field.
  max_magnitude =
      enforce_max_magnitude
          ? max_magnitude
          : std::max(std::numeric_limits<float>::epsilon(),
                     MaxAbsoluteValueIgnoringHuge(magnitudes, kHugeToIgnore));
  ABSL_CHECK_LT(0, max_magnitude);
  const cv::Mat hsv = MakeVisualizationHsv(angles, magnitudes, max_magnitude);
  cv::Mat rgb;
  cv::cvtColor(hsv, rgb, 71 /*cv::COLOR_HSV2RGB_FULL*/);
  return rgb;
}
// Renders the flow field, saturating at the field's own largest magnitude
// below kHugeToIgnore.
cv::Mat OpticalFlowField::GetVisualization() const {
  constexpr bool kEnforceMaxMagnitude = false;
  // With enforcement off, GetVisualizationInternal recomputes the saturation
  // level from the data, so this value is only a placeholder.
  constexpr float kPlaceholderMaxMagnitude = 1.0f;
  return GetVisualizationInternal(kPlaceholderMaxMagnitude,
                                  kEnforceMaxMagnitude);
}
// Renders the flow field with saturation reaching its maximum at
// `max_magnitude`. Larger magnitudes are drawn fully saturated. Dies if
// `max_magnitude` is not positive.
cv::Mat OpticalFlowField::GetVisualizationSaturatedAt(
    float max_magnitude) const {
  ABSL_CHECK_LT(0, max_magnitude)
      << "Specified saturation magnitude must be positive.";
  constexpr bool kEnforceMaxMagnitude = true;
  return GetVisualizationInternal(max_magnitude, kEnforceMaxMagnitude);
}
void OpticalFlowField::Allocate(int width, int height) {
flow_data_.create(height, width);
}
void OpticalFlowField::Resize(int new_width, int new_height) {
if (new_width == flow_data_.cols && new_height == flow_data_.rows) {
return;
}
cv::Mat source_for_resize = flow_data_;
float width_scale = new_width / static_cast<float>(source_for_resize.cols);
float height_scale = new_height / static_cast<float>(source_for_resize.rows);
cv::resize(source_for_resize, flow_data_, cv::Size(new_width, new_height), 0,
0, cv::INTER_LINEAR);
for (int r = 0; r < new_height; ++r) {
for (int c = 0; c < new_width; ++c) {
cv::Point2f flow_vector = flow_data_.at<cv::Point2f>(r, c);
flow_data_.at<cv::Point2f>(r, c) = cv::Point2f(
flow_vector.x * width_scale, flow_vector.y * height_scale);
}
}
}
// Replaces the field with the contents of a float tensor of shape
// [height, width, 2]. Channel 0 is read as the x component and channel 1 as
// the y component. Dies on a wrong dtype or shape.
void OpticalFlowField::CopyFromTensor(const tensorflow::Tensor& tensor) {
  ABSL_CHECK_EQ(tensorflow::DT_FLOAT, tensor.dtype());
  ABSL_CHECK_EQ(3, tensor.dims()) << "Tensor must be height x width x 2.";
  ABSL_CHECK_EQ(2, tensor.dim_size(2)) << "Tensor must be height x width x 2.";
  const int rows = tensor.dim_size(0);
  const int cols = tensor.dim_size(1);
  Allocate(cols, rows);
  auto flow = tensor.shaped<float, 3>({rows, cols, 2});
  for (int r = 0; r < rows; ++r) {
    cv::Point2f* out = flow_data_[r];
    for (int c = 0; c < cols; ++c) {
      out[c] = cv::Point2f(flow(r, c, 0), flow(r, c, 1));
    }
  }
}
void OpticalFlowField::SetFromProto(const OpticalFlowFieldData& proto) {
ABSL_CHECK_EQ(proto.width() * proto.height(), proto.dx_size());
ABSL_CHECK_EQ(proto.width() * proto.height(), proto.dy_size());
flow_data_.create(proto.height(), proto.width());
int i = 0;
for (int r = 0; r < flow_data_.rows; ++r) {
for (int c = 0; c < flow_data_.cols; ++c, ++i) {
flow_data_(r, c) = cv::Point2f(proto.dx(i), // x component
proto.dy(i)); // y component
}
}
}
// Writes the field into `proto`: width, height, and the x/y components in
// row-major order in the dx/dy lists. Any existing dx/dy entries are cleared.
void OpticalFlowField::ConvertToProto(OpticalFlowFieldData* proto) const {
  proto->set_width(width());
  proto->set_height(height());
  proto->clear_dx();
  proto->clear_dy();
  // The final size is known, so reserve once instead of growing repeatedly.
  const int num_pixels = flow_data_.rows * flow_data_.cols;
  proto->mutable_dx()->Reserve(num_pixels);
  proto->mutable_dy()->Reserve(num_pixels);
  for (int r = 0; r < flow_data_.rows; ++r) {
    for (int c = 0; c < flow_data_.cols; ++c) {
      const cv::Point2f& flow_vector = flow_data_(r, c);
      proto->add_dx(flow_vector.x);
      proto->add_dy(flow_vector.y);
    }
  }
}
// Follows the bilinearly interpolated flow vector at (x, y). On success, writes
// (x, y) + flow(x, y) to (*new_x, *new_y) and returns true.
// Returns false, leaving the outputs untouched, if (x, y) lies outside
// [0, width - 1] x [0, height - 1] or either coordinate is NaN.
// Both output pointers must be non-null.
bool OpticalFlowField::FollowFlow(float x, float y, float* new_x,
                                  float* new_y) const {
  ABSL_CHECK(new_x);
  ABSL_CHECK(new_y);
  // Written as a positive in-range test so that NaN coordinates, for which
  // every comparison is false, are rejected here. Otherwise they would reach
  // the bounds CHECKs in InterpolatedFlowAt and abort.
  const bool in_bounds = x >= 0 && x <= flow_data_.cols - 1 &&  // horizontal
                         y >= 0 && y <= flow_data_.rows - 1;    // vertical
  if (!in_bounds) {
    return false;
  }
  const cv::Point2f flow_vector = InterpolatedFlowAt(x, y);
  *new_x = x + flow_vector.x;
  *new_y = y + flow_vector.y;
  return true;
}
// Returns the flow at the sub-pixel location (x, y) by bilinear interpolation
// of the four surrounding samples: horizontally first, then vertically.
// Dies unless 0 <= x <= width - 1 and 0 <= y <= height - 1.
cv::Point2f OpticalFlowField::InterpolatedFlowAt(float x, float y) const {
  // Sanity bounds checks.
  ABSL_CHECK_GE(x, 0);
  ABSL_CHECK_GE(y, 0);
  ABSL_CHECK_LE(x, flow_data_.cols - 1);
  ABSL_CHECK_LE(y, flow_data_.rows - 1);
  const int left = static_cast<int>(std::floor(x));
  const int top = static_cast<int>(std::floor(y));
  // On the last column/row no interpolation partner exists, so reuse the same
  // sample rather than reading out of bounds.
  const int right = left < flow_data_.cols - 1 ? left + 1 : left;
  const int bottom = top < flow_data_.rows - 1 ? top + 1 : top;
  const float fx = x - left;
  const float fy = y - top;
  const cv::Point2f* top_row = flow_data_[top];
  const cv::Point2f* bottom_row = flow_data_[bottom];
  const cv::Point2f flow_top =
      top_row[left] + fx * (top_row[right] - top_row[left]);
  const cv::Point2f flow_bottom =
      bottom_row[left] + fx * (bottom_row[right] - bottom_row[left]);
  return flow_top + fy * (flow_bottom - flow_top);
}
// Returns a matrix of absolute target positions: element (row, col) holds
// (col, row) + flow(row, col), i.e. where each pixel moves to.
cv::Mat OpticalFlowField::ConvertToCorrespondences() const {
  cv::Mat_<cv::Point2f> targets = flow_data_.clone();
  for (int row = 0; row < targets.rows; ++row) {
    cv::Point2f* target_row = targets[row];
    for (int col = 0; col < targets.cols; ++col) {
      target_row[col] += cv::Point2f(col, row);
    }
  }
  return targets;
}
// Returns true if `other` has the same dimensions and every x and y flow
// component is within `margin` of the corresponding component here, as judged
// by MathUtil::WithinMargin. Logs the (row, col) of the first mismatch.
bool OpticalFlowField::AllWithinMargin(const OpticalFlowField& other,
                                       float margin) const {
  if (other.width() != width() || other.height() != height()) {
    return false;
  }
  for (int r = 0; r < flow_data_.rows; ++r) {
    for (int c = 0; c < flow_data_.cols; ++c) {
      const cv::Point2f& this_motion = flow_data_.at<cv::Point2f>(r, c);
      const cv::Point2f& other_motion = other.flow_data().at<cv::Point2f>(r, c);
      if (!MathUtil::WithinMargin(this_motion.x, other_motion.x, margin) ||
          !MathUtil::WithinMargin(this_motion.y, other_motion.y, margin)) {
        // Trailing space added so the row index is not glued to the text.
        ABSL_LOG(INFO) << "First failure at " << r << " " << c;
        return false;
      }
    }
  }
  return true;
}
// Computes forward-backward consistency masks for a pair of flow fields of
// equal size; dies if their sizes differ.
//   - `occluded_mask` marks pixels whose forward->backward round trip is
//     inconsistent.
//   - `disoccluded_mask` applies the same test with the roles of the two
//     fields swapped.
// Either output may be null to skip computing it.
void OpticalFlowField::EstimateMotionConsistencyOcclusions(
    const OpticalFlowField& forward, const OpticalFlowField& backward,
    double spatial_distance_threshold, Location* occluded_mask,
    Location* disoccluded_mask) {
  ABSL_CHECK_EQ(forward.width(), backward.width())
      << "Flow fields have different widths.";
  ABSL_CHECK_EQ(forward.height(), backward.height())
      << "Flow fields have different heights.";
  if (occluded_mask) {
    *occluded_mask = FindMotionInconsistentPixels(forward, backward,
                                                  spatial_distance_threshold);
  }
  if (disoccluded_mask) {
    // Same test, starting from the second frame.
    *disoccluded_mask = FindMotionInconsistentPixels(
        backward, forward, spatial_distance_threshold);
  }
}
// Returns a mask (built from an 8-bit matrix the size of `forward`) marking
// pixels whose flow is not round-trip consistent. Each pixel is followed
// through `forward` and then back through `backward`. It is marked with 1 when
//   - the intermediate point is non-finite or leaves `backward`'s bounds, or
//   - the round trip lands farther than `spatial_distance_threshold` from the
//     starting pixel (compared as squared distances).
// All other pixels stay 0.
Location OpticalFlowField::FindMotionInconsistentPixels(
    const OpticalFlowField& forward, const OpticalFlowField& backward,
    double spatial_distance_threshold) {
  const uint8_t kOccludedPixelValue = 1;
  const double threshold_sq =
      spatial_distance_threshold * spatial_distance_threshold;
  cv::Mat occluded = cv::Mat::zeros(forward.height(), forward.width(), CV_8UC1);
  // Row-major traversal matches the memory layout of the matrices. Pixels are
  // independent, so the visiting order does not affect the result.
  for (int y = 0; y < forward.height(); ++y) {
    uint8_t* occluded_row = occluded.ptr<uint8_t>(y);
    for (int x = 0; x < forward.width(); ++x) {
      // Location of the point in the next frame.
      float new_x = 0.0f;
      float new_y = 0.0f;
      // Location of the point in this frame after a round trip to the next
      // frame and back.
      float round_trip_x = 0.0f;
      float round_trip_y = 0.0f;
      // (x, y) is always inside `forward`, so this call cannot fail.
      forward.FollowFlow(x, y, &new_x, &new_y);
      // A NaN/inf flow vector cannot be followed back; treat it as
      // inconsistent instead of handing it to `backward`. A NaN would
      // otherwise trip the bounds CHECKs in InterpolatedFlowAt.
      const bool in_bounds_in_next_frame =
          std::isfinite(new_x) && std::isfinite(new_y) &&
          backward.FollowFlow(new_x, new_y, &round_trip_x, &round_trip_y);
      if (!in_bounds_in_next_frame ||
          Point2_f(x - round_trip_x, y - round_trip_y).ToVector().Norm2() >
              threshold_sq) {
        occluded_row[x] = kOccludedPixelValue;
      }
    }
  }
  return CreateCvMaskLocation<uint8_t>(occluded);
}
} // namespace mediapipe