chromium/third_party/mediapipe/src/mediapipe/util/tracking/tone_estimation.cc

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mediapipe/util/tracking/tone_estimation.h"

#include <math.h>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <numeric>
#include <vector>

#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "mediapipe/util/tracking/motion_models.pb.h"
#include "mediapipe/util/tracking/tone_models.pb.h"

namespace mediapipe {

// Creates a tone estimator for frames of dimension frame_width x frame_height.
// The passed dimensions are remembered as the original frame size, while the
// working dimensions (frame_width_, frame_height_) are reduced according to
// options.downsample_mode(). If downsampling is enabled, buffers holding the
// resized current and previous frame are allocated.
ToneEstimation::ToneEstimation(const ToneEstimationOptions& options,
                               int frame_width, int frame_height)
    : options_(options),
      frame_width_(frame_width),
      frame_height_(frame_height),
      original_width_(frame_width),
      original_height_(frame_height) {
  // Downsamples such that `reference_size` (one of the frame dimensions) maps
  // to options_.downsampling_size(). Nothing happens unless the reference size
  // exceeds the target size by more than 3%.
  const auto downsample_to_target = [this](const float reference_size) {
    if (reference_size > 1.03f * options_.downsampling_size()) {
      downsample_scale_ = reference_size / options_.downsampling_size();
      frame_height_ /= downsample_scale_;
      frame_width_ /= downsample_scale_;
      use_downsampling_ = true;
    }
  };

  switch (options_.downsample_mode()) {
    case ToneEstimationOptions::DOWNSAMPLE_NONE:
      break;
    case ToneEstimationOptions::DOWNSAMPLE_TO_MAX_SIZE:
      // Larger frame dimension is the reference.
      downsample_to_target(std::max(frame_width_, frame_height_));
      break;
    case ToneEstimationOptions::DOWNSAMPLE_TO_MIN_SIZE:
      // Smaller frame dimension is the reference.
      downsample_to_target(std::min(frame_width_, frame_height_));
      break;
    case ToneEstimationOptions::DOWNSAMPLE_BY_FACTOR:
      // Always downsamples by the requested factor.
      ABSL_CHECK_GE(options_.downsample_factor(), 1);
      downsample_scale_ = options_.downsample_factor();
      use_downsampling_ = true;
      frame_width_ /= options_.downsample_factor();
      frame_height_ /= options_.downsample_factor();
      break;
  }

  if (!use_downsampling_) {
    return;
  }

  // Buffers for the downsampled current and previous frame.
  resized_input_.reset(new cv::Mat(frame_height_, frame_width_, CV_8UC3));
  prev_resized_input_.reset(
      new cv::Mat(frame_height_, frame_width_, CV_8UC3));
}

// Nothing to release explicitly; members clean up after themselves.
ToneEstimation::~ToneEstimation() = default;

// Estimates the tone change for the current frame and writes it to
// `tone_change` (must not be null).
//
// `curr_frame_input` has to match the original frame dimensions passed to the
// constructor. If downsampling is enabled, the frames are resized into the
// internal buffers and a copy of `feature_list_input` is scaled by
// 1 / downsample_scale_ before use.
//
// The fraction of clipped pixels and the intensity percentiles are always
// computed. The gain-bias model is only estimated if `prev_frame_input` is
// supplied; if the estimate is deemed unstable, it is reset to a
// default-constructed GainBiasModel and the tone change type is set to
// ToneChange::INVALID. `debug_output` is forwarded to ComputeToneMatches.
void ToneEstimation::EstimateToneChange(
    const RegionFlowFeatureList& feature_list_input,
    const cv::Mat& curr_frame_input, const cv::Mat* prev_frame_input,
    ToneChange* tone_change, cv::Mat* debug_output) {
  ABSL_CHECK_EQ(original_height_, curr_frame_input.rows);
  ABSL_CHECK_EQ(original_width_, curr_frame_input.cols);
  ABSL_CHECK(tone_change != nullptr);

  // Only populated when downsampling is active.
  RegionFlowFeatureList scaled_feature_list;

  if (use_downsampling_) {
    // Bring frames to the working resolution.
    cv::resize(curr_frame_input, *resized_input_, resized_input_->size());
    if (prev_frame_input) {
      cv::resize(*prev_frame_input, *prev_resized_input_,
                 prev_resized_input_->size());
    }

    // Bring features to the working resolution.
    scaled_feature_list = feature_list_input;
    LinearSimilarityModel scale_transform;
    scale_transform.set_a(1.0f / downsample_scale_);
    TransformRegionFlowFeatureList(scale_transform, &scaled_feature_list);
  }

  // Select the inputs all subsequent steps operate on.
  const cv::Mat& curr_frame =
      use_downsampling_ ? *resized_input_ : curr_frame_input;
  const cv::Mat* prev_frame = prev_frame_input;
  if (use_downsampling_ && prev_frame_input) {
    prev_frame = prev_resized_input_.get();
  }
  const RegionFlowFeatureList& feature_list =
      use_downsampling_ ? scaled_feature_list : feature_list_input;

  ABSL_CHECK_EQ(frame_height_, curr_frame.rows);
  ABSL_CHECK_EQ(frame_width_, curr_frame.cols);

  ClipMask<3> curr_clip;
  ComputeClipMask<3>(options_.clip_mask_options(), curr_frame, &curr_clip);

  // Tone statistics of the current frame.
  const auto clipped_sum = cv::sum(curr_clip.mask)[0];
  tone_change->set_frac_clipped(clipped_sum / (frame_height_ * frame_width_));

  IntensityPercentiles(curr_frame, curr_clip.mask,
                       options_.tone_match_options().log_domain(), tone_change);

  // Without a previous frame there is nothing to match against.
  if (prev_frame == nullptr) {
    return;
  }

  // TODO: Buffer clip mask.
  ClipMask<3> prev_clip;
  ComputeClipMask<3>(options_.clip_mask_options(), *prev_frame, &prev_clip);

  ColorToneMatches color_tone_matches;
  ComputeToneMatches<3>(options_.tone_match_options(), feature_list,
                        curr_frame, *prev_frame, curr_clip, prev_clip,
                        &color_tone_matches, debug_output);

  EstimateGainBiasModel(options_.irls_iterations(), &color_tone_matches,
                        tone_change->mutable_gain_bias());

  const bool is_stable = IsStableGainBiasModel(
      options_.stable_gain_bias_bounds(), tone_change->gain_bias(),
      color_tone_matches, tone_change->mutable_stability_stats());
  if (is_stable) {
    // TODO: EstimateMixtureGainBiasModel();
    return;
  }

  VLOG(1) << "Warning: Estimated gain-bias is unstable.";
  // Reset to identity.
  tone_change->mutable_gain_bias()->CopyFrom(GainBiasModel());
  tone_change->set_type(ToneChange::INVALID);

  // TODO: EstimateMixtureGainBiasModel();
}

// Computes intensity percentiles of `frame` and stores them in `tone_change`.
//
// `frame` is converted to grayscale via cv::COLOR_RGB2GRAY and a cumulative
// histogram over the 256 intensity bins is built from all pixels whose
// `clip_mask` entry is zero. For each of the five percentiles configured in
// the options (stats_{low, low_mid, mid, high_mid, high}_percentile) the first
// bin whose cumulative fraction reaches the percentile is looked up. The bin
// is reported normalized: divided by 255, or, if `log_domain` is set, mapped
// via LogDomainLUT() and divided by its maximum log-domain value.
//
// If no pixel is counted (every pixel is masked out or the frame is empty),
// `tone_change` is left untouched.
//
// NOTE(review): assumes `clip_mask` is an 8-bit single channel matrix with the
// same dimensions as `frame` -- confirm against ComputeClipMask.
void ToneEstimation::IntensityPercentiles(const cv::Mat& frame,
                                          const cv::Mat& clip_mask,
                                          bool log_domain,
                                          ToneChange* tone_change) const {
  constexpr int kNumBins = 256;

  cv::Mat intensity(frame.rows, frame.cols, CV_8UC1);
  cv::cvtColor(frame, intensity, cv::COLOR_RGB2GRAY);

  // Count with integers. A float counter stops changing on increment once it
  // reaches 2^24, which silently corrupts the histogram for very large
  // frames.
  std::vector<int64_t> counts(kNumBins, 0);
  for (int i = 0; i < intensity.rows; ++i) {
    const uint8_t* intensity_ptr = intensity.ptr<uint8_t>(i);
    const uint8_t* clip_ptr = clip_mask.ptr<uint8_t>(i);

    for (int j = 0; j < intensity.cols; ++j) {
      if (!clip_ptr[j]) {
        ++counts[intensity_ptr[j]];
      }
    }
  }

  // Construct cumulative histogram.
  std::partial_sum(counts.begin(), counts.end(), counts.begin());

  const int64_t num_counted = counts.back();
  if (num_counted == 0) {
    // No unclipped pixel contributed to the histogram. Use default values.
    return;
  }

  // Normalize cumulative histogram to (0, 1].
  const float denom = 1.0f / static_cast<float>(num_counted);
  std::vector<float> histogram(kNumBins);
  std::transform(
      counts.begin(), counts.end(), histogram.begin(),
      [denom](int64_t count) { return static_cast<float>(count) * denom; });

  const float percentiles[] = {
      options_.stats_low_percentile(),
      options_.stats_low_mid_percentile(),
      options_.stats_mid_percentile(),
      options_.stats_high_mid_percentile(),
      options_.stats_high_percentile(),
  };
  constexpr int kNumPercentiles = sizeof(percentiles) / sizeof(percentiles[0]);
  float percentile_values[kNumPercentiles];

  const float log_denom = 1.0f / LogDomainLUT().MaxLogDomainValue();
  for (int k = 0; k < kNumPercentiles; ++k) {
    const auto bin_pos =
        std::lower_bound(histogram.begin(), histogram.end(), percentiles[k]);
    // lower_bound yields end() if the percentile exceeds the last histogram
    // entry (rounding during normalization, or a percentile option > 1).
    // Clamp to the brightest valid intensity instead of reporting bin 256.
    const int percentile_bin =
        std::min(kNumBins - 1, static_cast<int>(bin_pos - histogram.begin()));

    const float bin_value = static_cast<float>(percentile_bin);
    if (log_domain) {
      percentile_values[k] = LogDomainLUT().Map(bin_value) * log_denom;
    } else {
      percentile_values[k] = bin_value * (1.0f / 255.0f);
    }
  }

  tone_change->set_low_percentile(percentile_values[0]);
  tone_change->set_low_mid_percentile(percentile_values[1]);
  tone_change->set_mid_percentile(percentile_values[2]);
  tone_change->set_high_mid_percentile(percentile_values[3]);
  tone_change->set_high_percentile(percentile_values[4]);
}

// Estimates a per-channel gain and bias such that
//   prev_val ~= gain * curr_val + bias
// over all tone matches of a channel, using iteratively re-weighted least
// squares with `irls_iterations` iterations.
//
// `color_tone_matches` (must not be null, 1 to 3 channels) is modified: the
// IRLS weight of every patch is reset to 1 and, after each successful
// iteration, set to the inverse of the patch's RMS registration error
// (errors are expressed scaled by 100).
//
// A channel keeps the identity (gain 1, bias 0) if it has fewer than three
// tone matches, if the linear system can not be solved, or if the solution is
// not finite. The result is written to `gain_bias_model` (must not be null);
// if the product of the three gains is (close to) zero, the whole model is
// reset to a default-constructed GainBiasModel.
void ToneEstimation::EstimateGainBiasModel(int irls_iterations,
                                           ColorToneMatches* color_tone_matches,
                                           GainBiasModel* gain_bias_model) {
  ABSL_CHECK(color_tone_matches != nullptr);
  ABSL_CHECK(gain_bias_model != nullptr);

  // Effectively estimate each model independently. Layout is
  // {gain, bias} per channel, initialized to identity.
  float solution_ptr[6] = {1.0f, 0.0f, 1.0f, 0.0f, 1.0f, 0.0f};

  const int num_channels = color_tone_matches->size();
  ABSL_CHECK_GT(num_channels, 0);
  ABSL_CHECK_LE(num_channels, 3);

  // TODO: One IRLS weight per color match.
  for (int c = 0; c < num_channels; ++c) {
    auto& patch_tone_matches = (*color_tone_matches)[c];
    // Reset irls weight.
    int num_matches = 0;
    for (auto& patch_tone_match : patch_tone_matches) {
      patch_tone_match.set_irls_weight(1.0);
      num_matches += patch_tone_match.tone_match_size();
    }

    // Do not attempt a solution if too few matches have been found.
    if (num_matches < 3) {
      continue;
    }

    cv::Mat model_mat(num_matches, 2, CV_32F);
    cv::Mat rhs(num_matches, 1, CV_32F);
    cv::Mat solution(2, 1, CV_32F);

    for (int iteration = 0; iteration < irls_iterations; ++iteration) {
      // Setup matrix: one weighted row [curr_val, 1] = prev_val per match.
      int row = 0;
      for (const auto& patch_tone_match : patch_tone_matches) {
        const float irls_weight = patch_tone_match.irls_weight();
        for (const auto& tone_match : patch_tone_match.tone_match()) {
          float* row_ptr = model_mat.ptr<float>(row);
          float* rhs_ptr = rhs.ptr<float>(row);
          row_ptr[0] = tone_match.curr_val() * irls_weight;
          row_ptr[1] = irls_weight;
          rhs_ptr[0] = tone_match.prev_val() * irls_weight;
          ++row;
        }
      }

      // Solve.
      if (!cv::solve(model_mat, rhs, solution, cv::DECOMP_QR)) {
        // Fallback to identity.
        solution_ptr[2 * c] = 1;
        solution_ptr[2 * c + 1] = 0;
        break;  // Break to next color channel.
      }

      const float a = solution.at<float>(0, 0);
      const float b = solution.at<float>(1, 0);

      // A NaN or infinite solution is as unusable as a failed solve. It would
      // also turn all IRLS weights into NaN and slip through the
      // invertibility test below, as comparisons with NaN are always false.
      if (!std::isfinite(a) || !std::isfinite(b)) {
        // Fallback to identity.
        solution_ptr[2 * c] = 1;
        solution_ptr[2 * c + 1] = 0;
        break;  // Break to next color channel.
      }

      // Copy to solution.
      solution_ptr[2 * c] = a;
      solution_ptr[2 * c + 1] = b;

      // Evaluate error and derive the weights for the next iteration.
      for (auto& patch_tone_match : patch_tone_matches) {
        const int num_tone_matches = patch_tone_match.tone_match_size();

        if (num_tone_matches == 0) {
          continue;
        }

        float summed_error = 0.0f;
        for (const auto& tone_match : patch_tone_match.tone_match()) {
          // Express tone registration error in 0 .. 100.
          const float error =
              100.0f * (tone_match.curr_val() * a + b - tone_match.prev_val());
          summed_error += error * error;
        }

        // Compute RMSE.
        const float patch_error =
            std::sqrt(static_cast<double>(summed_error / num_tone_matches));
        // TODO: L1 instead of L0?
        // Small epsilon avoids division by zero for perfect matches.
        patch_tone_match.set_irls_weight(1.0f / (patch_error + 1e-6f));
      }
    }
  }

  gain_bias_model->CopyFrom(
      GainBiasModelAdapter::FromPointer<float>(solution_ptr, false));

  // Test invertibility, reset if failed.
  const float det = gain_bias_model->gain_c1() * gain_bias_model->gain_c2() *
                    gain_bias_model->gain_c3();
  if (std::fabs(det) < 1e-6f) {
    ABSL_LOG(WARNING) << "Estimated gain bias model is not invertible. "
                      << "Falling back to identity model.";
    gain_bias_model->CopyFrom(GainBiasModel());
  }
}

// Returns true if `model` is considered stable w.r.t. `bounds`:
//  - gain and bias of each of the three channels lie within
//    [lower_gain, upper_gain] and [lower_bias, upper_bias], and
//  - for every channel in `color_tone_matches`, the number of patches with an
//    IRLS weight above min_inlier_weight is at least min_inlier_fraction of
//    the channel's patches.
//
// `stats` is optional. If supplied, it is always cleared first, and it is
// populated only if the model is stable and at least one patch was present.
bool ToneEstimation::IsStableGainBiasModel(
    const ToneEstimationOptions::GainBiasBounds& bounds,
    const GainBiasModel& model, const ColorToneMatches& color_tone_matches,
    ToneChange::StabilityStats* stats) {
  if (stats != nullptr) {
    stats->Clear();
  }

  // True if gain or bias of a single channel falls outside of the bounds.
  const auto outside_bounds = [&bounds](const auto gain, const auto bias) {
    return gain < bounds.lower_gain() || gain > bounds.upper_gain() ||
           bias < bounds.lower_bias() || bias > bounds.upper_bias();
  };

  // Test each channel for stability.
  if (outside_bounds(model.gain_c1(), model.bias_c1()) ||
      outside_bounds(model.gain_c2(), model.bias_c2()) ||
      outside_bounds(model.gain_c3(), model.bias_c3())) {
    return false;
  }

  // Inlier statistics, accumulated across all channels.
  int inlier_sum = 0;
  int patch_sum = 0;
  double inlier_weight_sum = 0.0;

  // Test each channel independently.
  for (const auto& channel_matches : color_tone_matches) {
    int channel_inliers = 0;
    for (const auto& patch_match : channel_matches) {
      if (!(patch_match.irls_weight() > bounds.min_inlier_weight())) {
        continue;  // Not an inlier.
      }

      ++channel_inliers;
      // Clamp the weight to a registration error of 1 intensity value
      // difference (out of 255). Since weights are inversely proportional to
      // registration errors in the range 0..100, this corresponds to a max
      // weight of 2.55.
      inlier_weight_sum += std::min(2.55f, patch_match.irls_weight());
    }

    // Too few inliers in this channel.
    if (channel_inliers <
        bounds.min_inlier_fraction() * channel_matches.size()) {
      return false;
    }

    inlier_sum += channel_inliers;
    patch_sum += channel_matches.size();
  }

  if (stats == nullptr || patch_sum <= 0) {
    return true;
  }

  stats->set_num_inliers(inlier_sum);
  stats->set_inlier_fraction(inlier_sum * 1.0f / patch_sum);
  stats->set_inlier_weight(inlier_weight_sum);
  return true;
}

}  // namespace mediapipe