#include "modules/audio_processing/aec3/matched_filter.h"
#include "rtc_base/system/arch.h"
#if defined(WEBRTC_HAS_NEON)
#include <arm_neon.h>
#endif
#if defined(WEBRTC_ARCH_X86_FAMILY)
#include <emmintrin.h>
#endif
#include <algorithm>
#include <cstddef>
#include <initializer_list>
#include <iterator>
#include <numeric>
#include "absl/types/optional.h"
#include "api/array_view.h"
#include "modules/audio_processing/aec3/downsampled_render_buffer.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
#include "rtc_base/experiments/field_trial_parser.h"
#include "rtc_base/logging.h"
#include "system_wrappers/include/field_trial.h"
namespace {
constexpr int kAccumulatedErrorSubSampleRate = …;
void UpdateAccumulatedError(
const rtc::ArrayView<const float> instantaneous_accumulated_error,
const rtc::ArrayView<float> accumulated_error,
float one_over_error_sum_anchor) { … }
size_t ComputePreEchoLag(
const rtc::ArrayView<const float> accumulated_error,
size_t lag,
size_t alignment_shift_winner) { … }
}
namespace webrtc {
namespace aec3 {
#if defined(WEBRTC_HAS_NEON)
// Horizontal sum of the four lanes of a NEON float vector.
// The additions are performed as (e0 + e1) + (e2 + e3) via two pairwise adds.
inline float SumAllElements(float32x4_t elements) {
  // First pairwise add over the two halves: {e0 + e1, e2 + e3}.
  const float32x2_t pair_sums =
      vpadd_f32(vget_low_f32(elements), vget_high_f32(elements));
  // Second pairwise add folds both lanes together; lane 0 holds the total.
  return vget_lane_f32(vpadd_f32(pair_sums, pair_sums), 0);
}
// Matched-filter core (NEON). In addition to filtering and adapting, it scores
// the filter output after every group of four taps.
//
// For each capture sample y[i]:
//  * The filter window is the h.size() samples of the circular buffer x that
//    start at x_start_index, continuing from x[0] if the end of x is reached.
//    A wrapping window is first gathered into `scratch_memory`, which must
//    hold at least h.size() floats.
//  * The filter output s = sum_k h[k] * window[k] is built group by group.
//    After group g, the squared deviation (s_partial - y[i])^2 is added to
//    accumulated_error[g]. `accumulated_error` is zeroed on entry and must
//    hold at least h.size() / 4 entries.
//  * The squared prediction error (y[i] - s)^2 is added to *error_sum.
//  * h is adapted as h += smoothing * (y[i] - s) / energy * window when both
//    of these hold:
//      - the window energy sum_k window[k]^2 exceeds `x2_sum_threshold`;
//      - y[i] is not saturated (y[i] >= 32000 or y[i] <= -32000 counts as
//        saturated).
//    In that case *filters_updated is set to true; it is never cleared here.
//  * x_start_index then steps back by one, wrapping from 0 to x.size() - 1.
// h.size() must be a multiple of 4.
void MatchedFilterCoreWithAccumulatedError_NEON(
    size_t x_start_index,
    float x2_sum_threshold,
    float smoothing,
    rtc::ArrayView<const float> x,
    rtc::ArrayView<const float> y,
    rtc::ArrayView<float> h,
    bool* filters_updated,
    float* error_sum,
    rtc::ArrayView<float> accumulated_error,
    rtc::ArrayView<float> scratch_memory) {
  const int h_size = static_cast<int>(h.size());
  const int x_size = static_cast<int>(x.size());
  RTC_DCHECK_EQ(0, h_size % 4);
  const int num_groups = h_size >> 2;
  std::fill(accumulated_error.begin(), accumulated_error.end(), 0.0f);

  for (size_t i = 0; i < y.size(); ++i) {
    RTC_DCHECK_GT(x_size, x_start_index);
    const float y_i = y[i];

    // Make the filter window contiguous. If it runs past the end of x, copy
    // both pieces back to back into the scratch buffer.
    const int len_before_wrap =
        std::min(h_size, static_cast<int>(x_size - x_start_index));
    const bool window_wraps = len_before_wrap != h_size;
    if (window_wraps) {
      const int len_after_wrap = h_size - len_before_wrap;
      std::copy(x.begin() + x_start_index, x.end(), scratch_memory.begin());
      std::copy(x.begin(), x.begin() + len_after_wrap,
                scratch_memory.begin() + len_before_wrap);
    }
    const float* const window =
        window_wraps ? scratch_memory.data() : &x[x_start_index];
    float* const taps = &h[0];
    float* const group_errors = &accumulated_error[0];

    // Filter pass. The output `s` is a scalar running sum so that the partial
    // output after each group of four taps can be scored. The window energy is
    // only needed in total, so it stays in a vector accumulator.
    float32x4_t energy_128 = vdupq_n_f32(0);
    float s = 0;
    for (int g = 0; g < num_groups; ++g) {
      const float32x4_t x_g = vld1q_f32(window + 4 * g);
      const float32x4_t h_g = vld1q_f32(taps + 4 * g);
      energy_128 = vmlaq_f32(energy_128, x_g, x_g);
      s += SumAllElements(vmulq_f32(h_g, x_g));
      const float partial_error = s - y_i;
      group_errors[g] += partial_error * partial_error;
    }
    const float x2_sum = SumAllElements(energy_128);

    const float e = y_i - s;
    const bool saturation = y_i >= 32000.f || y_i <= -32000.f;
    (*error_sum) += e * e;

    // Energy-normalized adaptation step, skipped for low-energy windows and
    // for saturated capture samples.
    const bool adapt = x2_sum > x2_sum_threshold && !saturation;
    if (adapt) {
      RTC_DCHECK_LT(0.f, x2_sum);
      const float alpha = smoothing * e / x2_sum;
      const float32x4_t alpha_128 = vmovq_n_f32(alpha);
      for (int g = 0; g < num_groups; ++g) {
        float* const h_g = taps + 4 * g;
        vst1q_f32(h_g, vmlaq_f32(vld1q_f32(h_g), alpha_128,
                                 vld1q_f32(window + 4 * g)));
      }
      *filters_updated = true;
    }

    x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
  }
}
// Matched-filter core (NEON). For each capture sample y[i]:
//  * The filter window is the h.size() samples of the circular buffer x that
//    start at x_start_index. Any part past the end of x continues from x[0].
//    The window is processed in place as two segments, before and after the
//    wrap.
//  * The filter output s = sum_k h[k] * window[k] and the window energy
//    sum_k window[k]^2 are computed, and (y[i] - s)^2 is added to *error_sum.
//  * h is adapted as h += smoothing * (y[i] - s) / energy * window when both
//    of these hold:
//      - the energy exceeds `x2_sum_threshold`;
//      - y[i] is not saturated (y[i] >= 32000 or y[i] <= -32000 counts as
//        saturated).
//    In that case *filters_updated is set to true; it is never cleared here.
//  * x_start_index then steps back by one, wrapping from 0 to x.size() - 1.
// If `compute_accumulated_error` is set, all work is delegated to
// MatchedFilterCoreWithAccumulatedError_NEON, which also uses
// `accumulated_error` and `scratch_memory`. Otherwise those two views are not
// touched. h.size() must be a multiple of 4.
void MatchedFilterCore_NEON(size_t x_start_index,
                            float x2_sum_threshold,
                            float smoothing,
                            rtc::ArrayView<const float> x,
                            rtc::ArrayView<const float> y,
                            rtc::ArrayView<float> h,
                            bool* filters_updated,
                            float* error_sum,
                            bool compute_accumulated_error,
                            rtc::ArrayView<float> accumulated_error,
                            rtc::ArrayView<float> scratch_memory) {
  const int h_size = static_cast<int>(h.size());
  const int x_size = static_cast<int>(x.size());
  RTC_DCHECK_EQ(0, h_size % 4);
  if (compute_accumulated_error) {
    MatchedFilterCoreWithAccumulatedError_NEON(
        x_start_index, x2_sum_threshold, smoothing, x, y, h, filters_updated,
        error_sum, accumulated_error, scratch_memory);
    return;
  }

  for (size_t i = 0; i < y.size(); ++i) {
    RTC_DCHECK_GT(x_size, x_start_index);

    // Split the window at the wraparound point of the circular buffer.
    const float* const x_first = &x[x_start_index];
    float* const h_first = &h[0];
    const int len_first =
        std::min(h_size, static_cast<int>(x_size - x_start_index));
    const int len_second = h_size - len_first;
    const float* const x_second = &x[0];
    float* const h_second = h_first + len_first;

    // Filter pass. Four taps at a time go into vector accumulators. The 0-3
    // taps left at the end of each segment go into scalar accumulators, and
    // the two kinds are combined only at the end.
    float32x4_t s_128 = vdupq_n_f32(0);
    float32x4_t x2_sum_128 = vdupq_n_f32(0);
    float s = 0.f;
    float x2_sum = 0.f;
    auto correlate_segment = [&](const float* xs, const float* hs,
                                 int length) {
      const int num_vectors = length >> 2;
      for (int v = 0; v < num_vectors; ++v, xs += 4, hs += 4) {
        const float32x4_t x_v = vld1q_f32(xs);
        const float32x4_t h_v = vld1q_f32(hs);
        x2_sum_128 = vmlaq_f32(x2_sum_128, x_v, x_v);
        s_128 = vmlaq_f32(s_128, h_v, x_v);
      }
      for (int r = length - 4 * num_vectors; r > 0; --r, ++xs, ++hs) {
        x2_sum += *xs * *xs;
        s += *hs * *xs;
      }
    };
    correlate_segment(x_first, h_first, len_first);
    correlate_segment(x_second, h_second, len_second);
    s += SumAllElements(s_128);
    x2_sum += SumAllElements(x2_sum_128);

    const float e = y[i] - s;
    const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
    (*error_sum) += e * e;

    // Energy-normalized adaptation step, skipped for low-energy windows and
    // for saturated capture samples.
    if (x2_sum > x2_sum_threshold && !saturation) {
      RTC_DCHECK_LT(0.f, x2_sum);
      const float alpha = smoothing * e / x2_sum;
      const float32x4_t alpha_128 = vmovq_n_f32(alpha);
      auto adapt_segment = [&](const float* xs, float* hs, int length) {
        const int num_vectors = length >> 2;
        for (int v = 0; v < num_vectors; ++v, xs += 4, hs += 4) {
          vst1q_f32(hs, vmlaq_f32(vld1q_f32(hs), alpha_128, vld1q_f32(xs)));
        }
        for (int r = length - 4 * num_vectors; r > 0; --r, ++xs, ++hs) {
          *hs += alpha * *xs;
        }
      };
      adapt_segment(x_first, h_first, len_first);
      adapt_segment(x_second, h_second, len_second);
      *filters_updated = true;
    }

    x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
  }
}
#endif
#if defined(WEBRTC_ARCH_X86_FAMILY)
void MatchedFilterCore_AccumulatedError_SSE2(
size_t x_start_index,
float x2_sum_threshold,
float smoothing,
rtc::ArrayView<const float> x,
rtc::ArrayView<const float> y,
rtc::ArrayView<float> h,
bool* filters_updated,
float* error_sum,
rtc::ArrayView<float> accumulated_error,
rtc::ArrayView<float> scratch_memory) { … }
void MatchedFilterCore_SSE2(size_t x_start_index,
float x2_sum_threshold,
float smoothing,
rtc::ArrayView<const float> x,
rtc::ArrayView<const float> y,
rtc::ArrayView<float> h,
bool* filters_updated,
float* error_sum,
bool compute_accumulated_error,
rtc::ArrayView<float> accumulated_error,
rtc::ArrayView<float> scratch_memory) { … }
#endif
void MatchedFilterCore(size_t x_start_index,
float x2_sum_threshold,
float smoothing,
rtc::ArrayView<const float> x,
rtc::ArrayView<const float> y,
rtc::ArrayView<float> h,
bool* filters_updated,
float* error_sum,
bool compute_accumulated_error,
rtc::ArrayView<float> accumulated_error) { … }
size_t MaxSquarePeakIndex(rtc::ArrayView<const float> h) { … }
}
MatchedFilter::MatchedFilter(ApmDataDumper* data_dumper,
Aec3Optimization optimization,
size_t sub_block_size,
size_t window_size_sub_blocks,
int num_matched_filters,
size_t alignment_shift_sub_blocks,
float excitation_limit,
float smoothing_fast,
float smoothing_slow,
float matching_filter_threshold,
bool detect_pre_echo)
: … { … }
MatchedFilter::~MatchedFilter() = default;
void MatchedFilter::Reset(bool full_reset) { … }
void MatchedFilter::Update(const DownsampledRenderBuffer& render_buffer,
rtc::ArrayView<const float> capture,
bool use_slow_smoothing) { … }
void MatchedFilter::LogFilterProperties(int sample_rate_hz,
size_t shift,
size_t downsampling_factor) const { … }
void MatchedFilter::Dump() { … }
}