chromium/third_party/webrtc/modules/audio_processing/aec3/matched_filter.cc

/*
 *  Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */
#include "modules/audio_processing/aec3/matched_filter.h"

// Defines WEBRTC_ARCH_X86_FAMILY, used below.
#include "rtc_base/system/arch.h"

#if defined(WEBRTC_HAS_NEON)
#include <arm_neon.h>
#endif
#if defined(WEBRTC_ARCH_X86_FAMILY)
#include <emmintrin.h>
#endif
#include <algorithm>
#include <cstddef>
#include <initializer_list>
#include <iterator>
#include <numeric>

#include "absl/types/optional.h"
#include "api/array_view.h"
#include "modules/audio_processing/aec3/downsampled_render_buffer.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
#include "rtc_base/experiments/field_trial_parser.h"
#include "rtc_base/logging.h"
#include "system_wrappers/include/field_trial.h"

namespace {

// Subsample rate used for computing the accumulated error, i.e. the number of
// filter taps that share one accumulated-error entry.
// The implementation of some core functions depends on this constant being
// equal to 4 (e.g. the NEON core processes four taps per step and writes one
// accumulated-error entry per step).
constexpr int kAccumulatedErrorSubSampleRate = 4;

// NOTE(review): the body of this helper is empty in this view, so as written
// it is a no-op. Judging only by its name and parameters, it presumably folds
// |instantaneous_accumulated_error| into |accumulated_error| using
// |one_over_error_sum_anchor| as a normalization factor — TODO confirm against
// the full implementation.
void UpdateAccumulatedError(
    const rtc::ArrayView<const float> instantaneous_accumulated_error,
    const rtc::ArrayView<float> accumulated_error,
    float one_over_error_sum_anchor) {}

// NOTE(review): the body of this helper is empty in this view. As written it
// flows off the end of a function returning size_t, which is undefined
// behavior if it is ever called. Presumably it derives a pre-echo lag from
// |accumulated_error|, |lag| and |alignment_shift_winner| — TODO confirm.
size_t ComputePreEchoLag(
    const rtc::ArrayView<const float> accumulated_error,
    size_t lag,
    size_t alignment_shift_winner) {}

}  // namespace

namespace webrtc {
namespace aec3 {

#if defined(WEBRTC_HAS_NEON)

// Horizontal reduction of a four-lane float vector to a scalar.
// The lanes {e0, e1, e2, e3} are combined with two pairwise additions: the
// first yields {e0 + e1, e2 + e3}, and the second adds those two partial sums.
inline float SumAllElements(float32x4_t elements) {
  const float32x2_t lower_half = vget_low_f32(elements);
  const float32x2_t upper_half = vget_high_f32(elements);
  const float32x2_t pair_sums = vpadd_f32(lower_half, upper_half);
  const float32x2_t total = vpadd_f32(pair_sums, pair_sums);
  return vget_lane_f32(total, 0);
}

void MatchedFilterCoreWithAccumulatedError_NEON(
    size_t x_start_index,
    float x2_sum_threshold,
    float smoothing,
    rtc::ArrayView<const float> x,
    rtc::ArrayView<const float> y,
    rtc::ArrayView<float> h,
    bool* filters_updated,
    float* error_sum,
    rtc::ArrayView<float> accumulated_error,
    rtc::ArrayView<float> scratch_memory) {
  const int h_size = static_cast<int>(h.size());
  const int x_size = static_cast<int>(x.size());
  RTC_DCHECK_EQ(0, h_size % 4);
  std::fill(accumulated_error.begin(), accumulated_error.end(), 0.0f);
  // Process for all samples in the sub-block.
  for (size_t i = 0; i < y.size(); ++i) {
    // Apply the matched filter as filter * x, and compute x * x.
    RTC_DCHECK_GT(x_size, x_start_index);
    // Compute loop chunk sizes until, and after, the wraparound of the circular
    // buffer for x.
    const int chunk1 =
        std::min(h_size, static_cast<int>(x_size - x_start_index));
    if (chunk1 != h_size) {
      const int chunk2 = h_size - chunk1;
      std::copy(x.begin() + x_start_index, x.end(), scratch_memory.begin());
      std::copy(x.begin(), x.begin() + chunk2, scratch_memory.begin() + chunk1);
    }
    const float* x_p =
        chunk1 != h_size ? scratch_memory.data() : &x[x_start_index];
    const float* h_p = &h[0];
    float* accumulated_error_p = &accumulated_error[0];
    // Initialize values for the accumulation.
    float32x4_t x2_sum_128 = vdupq_n_f32(0);
    float x2_sum = 0.f;
    float s = 0;
    // Perform 128 bit vector operations.
    const int limit_by_4 = h_size >> 2;
    for (int k = limit_by_4; k > 0;
         --k, h_p += 4, x_p += 4, accumulated_error_p++) {
      // Load the data into 128 bit vectors.
      const float32x4_t x_k = vld1q_f32(x_p);
      const float32x4_t h_k = vld1q_f32(h_p);
      // Compute and accumulate x * x.
      x2_sum_128 = vmlaq_f32(x2_sum_128, x_k, x_k);
      // Compute x * h
      float32x4_t hk_xk_128 = vmulq_f32(h_k, x_k);
      s += SumAllElements(hk_xk_128);
      const float e = s - y[i];
      accumulated_error_p[0] += e * e;
    }
    // Combine the accumulated vector and scalar values.
    x2_sum += SumAllElements(x2_sum_128);
    // Compute the matched filter error.
    float e = y[i] - s;
    const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
    (*error_sum) += e * e;
    // Update the matched filter estimate in an NLMS manner.
    if (x2_sum > x2_sum_threshold && !saturation) {
      RTC_DCHECK_LT(0.f, x2_sum);
      const float alpha = smoothing * e / x2_sum;
      const float32x4_t alpha_128 = vmovq_n_f32(alpha);
      // filter = filter + smoothing * (y - filter * x) * x / x * x.
      float* h_p = &h[0];
      x_p = chunk1 != h_size ? scratch_memory.data() : &x[x_start_index];
      // Perform 128 bit vector operations.
      const int limit_by_4 = h_size >> 2;
      for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
        // Load the data into 128 bit vectors.
        float32x4_t h_k = vld1q_f32(h_p);
        const float32x4_t x_k = vld1q_f32(x_p);
        // Compute h = h + alpha * x.
        h_k = vmlaq_f32(h_k, alpha_128, x_k);
        // Store the result.
        vst1q_f32(h_p, h_k);
      }
      *filters_updated = true;
    }
    x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
  }
}

// NEON implementation of the matched filter core.
//
// For each sample y[i] of the sub-block, computes the filter output s = h * x
// over the h.size() samples of the circular buffer |x| that start at
// |x_start_index| and wrap past the end of |x| to its start. It then adds
// (y[i] - s)^2 to |error_sum|. When both of these hold:
//   - the energy of the covered x samples exceeds |x2_sum_threshold|;
//   - y[i] is not saturated (|y[i]| < 32000);
// h is adapted in an NLMS manner and |filters_updated| is set to true:
//   h += smoothing * (y[i] - s) * x / (x * x)
// The start index steps back by one sample, wrapping to the end of |x|, for
// every y[i].
//
// When |compute_accumulated_error| is true, the work is delegated to
// MatchedFilterCoreWithAccumulatedError_NEON. That function also fills
// |accumulated_error| and uses |scratch_memory| for wrapped x samples.
// Otherwise those two buffers are not touched. h.size() must be a multiple
// of 4.
void MatchedFilterCore_NEON(size_t x_start_index,
                            float x2_sum_threshold,
                            float smoothing,
                            rtc::ArrayView<const float> x,
                            rtc::ArrayView<const float> y,
                            rtc::ArrayView<float> h,
                            bool* filters_updated,
                            float* error_sum,
                            bool compute_accumulated_error,
                            rtc::ArrayView<float> accumulated_error,
                            rtc::ArrayView<float> scratch_memory) {
  const int h_size = static_cast<int>(h.size());
  const int x_size = static_cast<int>(x.size());
  RTC_DCHECK_EQ(0, h_size % 4);

  if (compute_accumulated_error) {
    MatchedFilterCoreWithAccumulatedError_NEON(
        x_start_index, x2_sum_threshold, smoothing, x, y, h, filters_updated,
        error_sum, accumulated_error, scratch_memory);
    return;
  }

  // Process for all samples in the sub-block.
  for (size_t i = 0; i < y.size(); ++i) {
    // Apply the matched filter as filter * x, and compute x * x.
    RTC_DCHECK_GT(x_size, x_start_index);
    const float* x_p = &x[x_start_index];
    const float* h_p = &h[0];

    // Initialize values for the accumulation. The vector accumulators cover
    // groups of four taps; the scalar ones cover any leftover taps of a chunk.
    float32x4_t s_128 = vdupq_n_f32(0);
    float32x4_t x2_sum_128 = vdupq_n_f32(0);
    float x2_sum = 0.f;
    float s = 0;

    // Split the taps into the part before (chunk1) and after (chunk2) the
    // wraparound of the circular buffer for x.
    const int chunk1 =
        std::min(h_size, static_cast<int>(x_size - x_start_index));
    const int chunk2 = h_size - chunk1;

    for (int limit : {chunk1, chunk2}) {
      // Perform 128 bit vector operations.
      const int limit_by_4 = limit >> 2;
      for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
        const float32x4_t x_k = vld1q_f32(x_p);
        const float32x4_t h_k = vld1q_f32(h_p);
        // Compute and accumulate x * x and h * x.
        x2_sum_128 = vmlaq_f32(x2_sum_128, x_k, x_k);
        s_128 = vmlaq_f32(s_128, h_k, x_k);
      }

      // Perform non-vector operations for any remaining items.
      for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) {
        const float x_k = *x_p;
        x2_sum += x_k * x_k;
        s += *h_p * x_k;
      }

      // The second chunk continues from the start of the circular buffer.
      x_p = &x[0];
    }

    // Combine the accumulated vector and scalar values.
    s += SumAllElements(s_128);
    x2_sum += SumAllElements(x2_sum_128);

    // Compute the matched filter error.
    const float e = y[i] - s;
    const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
    (*error_sum) += e * e;

    // Update the matched filter estimate in an NLMS manner.
    if (x2_sum > x2_sum_threshold && !saturation) {
      RTC_DCHECK_LT(0.f, x2_sum);
      const float alpha = smoothing * e / x2_sum;
      const float32x4_t alpha_128 = vmovq_n_f32(alpha);

      // filter = filter + smoothing * (y - filter * x) * x / x * x.
      // A distinct name avoids shadowing the read-only |h_p| above.
      float* h_update_p = &h[0];
      x_p = &x[x_start_index];

      // Perform the loop in two chunks.
      for (int limit : {chunk1, chunk2}) {
        const int limit_by_4 = limit >> 2;
        for (int k = limit_by_4; k > 0; --k, h_update_p += 4, x_p += 4) {
          float32x4_t h_k = vld1q_f32(h_update_p);
          const float32x4_t x_k = vld1q_f32(x_p);
          // Compute h = h + alpha * x and store the result.
          h_k = vmlaq_f32(h_k, alpha_128, x_k);
          vst1q_f32(h_update_p, h_k);
        }

        // Perform non-vector operations for any remaining items.
        for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_update_p, ++x_p) {
          *h_update_p += alpha * *x_p;
        }

        x_p = &x[0];
      }

      *filters_updated = true;
    }

    // Step back one sample in the circular buffer for the next y sample.
    x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
  }
}

#endif

#if defined(WEBRTC_ARCH_X86_FAMILY)

// NOTE(review): the body is empty in this view, so as written this is a no-op.
// By analogy with MatchedFilterCoreWithAccumulatedError_NEON in this file and
// its identical parameter list, this is presumably the SSE2 variant that also
// fills |accumulated_error| (using |scratch_memory| for wrapped x samples) —
// TODO confirm against the full implementation.
void MatchedFilterCore_AccumulatedError_SSE2(
    size_t x_start_index,
    float x2_sum_threshold,
    float smoothing,
    rtc::ArrayView<const float> x,
    rtc::ArrayView<const float> y,
    rtc::ArrayView<float> h,
    bool* filters_updated,
    float* error_sum,
    rtc::ArrayView<float> accumulated_error,
    rtc::ArrayView<float> scratch_memory) {}

// NOTE(review): the body is empty in this view, so as written this is a no-op.
// By analogy with MatchedFilterCore_NEON in this file and its identical
// parameter list, this is presumably the SSE2 matched filter core — TODO
// confirm against the full implementation.
void MatchedFilterCore_SSE2(size_t x_start_index,
                            float x2_sum_threshold,
                            float smoothing,
                            rtc::ArrayView<const float> x,
                            rtc::ArrayView<const float> y,
                            rtc::ArrayView<float> h,
                            bool* filters_updated,
                            float* error_sum,
                            bool compute_accumulated_error,
                            rtc::ArrayView<float> accumulated_error,
                            rtc::ArrayView<float> scratch_memory) {}
#endif

// NOTE(review): the body is empty in this view, so as written this is a no-op.
// Its parameters mirror MatchedFilterCore_NEON minus |scratch_memory|; it is
// presumably the portable (non-SIMD) matched filter core — TODO confirm.
void MatchedFilterCore(size_t x_start_index,
                       float x2_sum_threshold,
                       float smoothing,
                       rtc::ArrayView<const float> x,
                       rtc::ArrayView<const float> y,
                       rtc::ArrayView<float> h,
                       bool* filters_updated,
                       float* error_sum,
                       bool compute_accumulated_error,
                       rtc::ArrayView<float> accumulated_error) {}

// NOTE(review): the body is empty in this view. As written it flows off the
// end of a function returning size_t, which is undefined behavior if it is
// ever called. Judging by the name, it presumably returns the index of the
// coefficient of |h| with the largest square — TODO confirm.
size_t MaxSquarePeakIndex(rtc::ArrayView<const float> h) {}

}  // namespace aec3

// NOTE(review): the member-initializer list and body of this constructor are
// elided in this view; `:{}` is not valid C++ as written. The parameters are
// presumably stored into members declared in matched_filter.h — TODO confirm.
MatchedFilter::MatchedFilter(ApmDataDumper* data_dumper,
                             Aec3Optimization optimization,
                             size_t sub_block_size,
                             size_t window_size_sub_blocks,
                             int num_matched_filters,
                             size_t alignment_shift_sub_blocks,
                             float excitation_limit,
                             float smoothing_fast,
                             float smoothing_slow,
                             float matching_filter_threshold,
                             bool detect_pre_echo)
    :{}

MatchedFilter::~MatchedFilter() = default;

// NOTE(review): the body is empty in this view, so as written this is a no-op.
// TODO confirm what state |full_reset| is meant to control.
void MatchedFilter::Reset(bool full_reset) {}

// NOTE(review): the body is empty in this view, so as written this is a no-op.
// Presumably it runs the matched filters on |render_buffer| against |capture|.
// |use_slow_smoothing| likely selects between the slow and fast smoothing
// factors passed to the constructor — TODO confirm.
void MatchedFilter::Update(const DownsampledRenderBuffer& render_buffer,
                           rtc::ArrayView<const float> capture,
                           bool use_slow_smoothing) {}

// NOTE(review): the body is empty in this view, so as written this is a no-op.
// Presumably it logs properties of the filters for the given rate, shift and
// downsampling factor — TODO confirm.
void MatchedFilter::LogFilterProperties(int sample_rate_hz,
                                        size_t shift,
                                        size_t downsampling_factor) const {}

// NOTE(review): the body is empty in this view, so as written this is a no-op.
// Presumably it writes debug data through the ApmDataDumper given to the
// constructor — TODO confirm.
void MatchedFilter::Dump() {}

}  // namespace webrtc