// chromium/third_party/webrtc/modules/audio_processing/agc2/rnn_vad/rnn_gru.cc

/*
 *  Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#include "modules/audio_processing/agc2/rnn_vad/rnn_gru.h"

#include <algorithm>
#include <array>
#include <vector>

#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_conversions.h"
#include "third_party/rnnoise/src/rnn_activations.h"
#include "third_party/rnnoise/src/rnn_vad_weights.h"

namespace webrtc {
namespace rnn_vad {
namespace {

constexpr int kNumGruGates = 3;  // Update, reset, output.

// Converts a quantized GRU tensor into the float layout consumed by
// `ComputeUpdateResetGate()` and `ComputeStateGate()`.
//
// Each int8 coefficient is cast to float and multiplied by
// `::rnnoise::kWeightsScale`. The tensor is also transposed. The source is read
// as `[n][kNumGruGates][output_size]` and written as
// `[kNumGruGates][output_size][n]`. As a result, the `n` coefficients for each
// gate and output unit are contiguous and can be passed directly to a dot
// product.
//
// `n` is derived from `tensor_src.size()`, which must be an exact multiple of
// `kNumGruGates * output_size` (enforced by `rtc::CheckedDivExact`). For example,
// a bias tensor with `kNumGruGates * output_size` values has `n == 1`.
std::vector<float> PreprocessGruTensor(rtc::ArrayView<const int8_t> tensor_src,
                                       int output_size) {
  RTC_DCHECK_GT(output_size, 0);
  // Size of the first (input) dimension of the source tensor.
  const int n = rtc::CheckedDivExact(rtc::dchecked_cast<int>(tensor_src.size()),
                                     output_size * kNumGruGates);
  const int stride_src = kNumGruGates * output_size;
  const int stride_dst = n * output_size;
  std::vector<float> tensor_dst(tensor_src.size());
  for (int g = 0; g < kNumGruGates; ++g) {
    for (int o = 0; o < output_size; ++o) {
      for (int i = 0; i < n; ++i) {
        tensor_dst[g * stride_dst + o * n + i] =
            ::rnnoise::kWeightsScale *
            static_cast<float>(tensor_src[i * stride_src + g * output_size + o]);
      }
    }
  }
  return tensor_dst;
}

// Computes the output for the update or the reset gate.
// Operation: `g = sigmoid(W^T∙i + R^T∙s + b)` where
// - `g`: output gate vector
// - `W`: weights matrix
// - `i`: input vector
// - `R`: recurrent weights matrix
// - `s`: state gate vector
// - `b`: bias vector
void ComputeUpdateResetGate(int input_size,
                            int output_size,
                            const VectorMath& vector_math,
                            rtc::ArrayView<const float> input,
                            rtc::ArrayView<const float> state,
                            rtc::ArrayView<const float> bias,
                            rtc::ArrayView<const float> weights,
                            rtc::ArrayView<const float> recurrent_weights,
                            rtc::ArrayView<float> gate) {}

// Computes the output for the state gate.
// Operation: `s' = u .* s + (1 - u) .* ReLU(W^T∙i + R^T∙(s .* r) + b)` where
// - `s'`: output state gate vector
// - `s`: previous state gate vector
// - `u`: update gate vector
// - `W`: weights matrix
// - `i`: input vector
// - `R`: recurrent weights matrix
// - `r`: reset gate vector
// - `b`: bias vector
// - `.*` element-wise product
void ComputeStateGate(int input_size,
                      int output_size,
                      const VectorMath& vector_math,
                      rtc::ArrayView<const float> input,
                      rtc::ArrayView<const float> update,
                      rtc::ArrayView<const float> reset,
                      rtc::ArrayView<const float> bias,
                      rtc::ArrayView<const float> weights,
                      rtc::ArrayView<const float> recurrent_weights,
                      rtc::ArrayView<float> state) {}

}  // namespace

// Builds a GRU layer from quantized int8 tensors. Each of `bias`, `weights` and
// `recurrent_weights` holds the coefficients of all `kNumGruGates` gates. They
// are converted to float and re-laid out by `PreprocessGruTensor()`.
// `layer_name` is used only in the diagnostic messages of the size checks.
// The recurrent state starts at zero.
GatedRecurrentLayer::GatedRecurrentLayer(
    const int input_size,
    const int output_size,
    const rtc::ArrayView<const int8_t> bias,
    const rtc::ArrayView<const int8_t> weights,
    const rtc::ArrayView<const int8_t> recurrent_weights,
    const AvailableCpuFeatures& cpu_features,
    absl::string_view layer_name)
    : input_size_(input_size),
      output_size_(output_size),
      bias_(PreprocessGruTensor(bias, output_size)),
      weights_(PreprocessGruTensor(weights, output_size)),
      recurrent_weights_(PreprocessGruTensor(recurrent_weights, output_size)),
      vector_math_(cpu_features) {
  RTC_DCHECK_LE(output_size_, kGruLayerMaxUnits)
      << "Insufficient GRU layer over-allocation (" << layer_name << ").";
  RTC_DCHECK_EQ(kNumGruGates * output_size_, bias_.size())
      << "Mismatching output size and bias terms array size (" << layer_name
      << ").";
  RTC_DCHECK_EQ(kNumGruGates * input_size_ * output_size_, weights_.size())
      << "Mismatching input-output size and weight coefficients array size ("
      << layer_name << ").";
  RTC_DCHECK_EQ(kNumGruGates * output_size_ * output_size_,
                recurrent_weights_.size())
      << "Mismatching input-output size and recurrent weight coefficients array"
      << " size (" << layer_name << ").";
  Reset();
}

// Nothing to release explicitly; member destructors run automatically.
GatedRecurrentLayer::~GatedRecurrentLayer() {}

void GatedRecurrentLayer::Reset() {}

// Advances the GRU by one step. The update and reset gates are computed first,
// from `input` and the previous state. The state gate then overwrites the first
// `output_size_` entries of `state_` with the new state, which is also the
// layer output.
void GatedRecurrentLayer::ComputeOutput(rtc::ArrayView<const float> input) {
  RTC_DCHECK_EQ(input.size(), input_size_);

  // Each tensor below is a sequence of flattened per-gate tensors in the order
  // update, reset, state (see `PreprocessGruTensor()`).
  const rtc::ArrayView<const float> bias(bias_);
  const rtc::ArrayView<const float> weights(weights_);
  const rtc::ArrayView<const float> recurrent_weights(recurrent_weights_);
  // Number of coefficients per gate in each of the tensors above.
  const int stride_bias = output_size_;
  const int stride_weights = input_size_ * output_size_;
  const int stride_recurrent_weights = output_size_ * output_size_;

  // Only the first `output_size_` entries of `state_` are in use.
  rtc::ArrayView<float> state(state_.data(), output_size_);

  // Update gate (gate index 0). The gate buffers are over-allocated to
  // `kGruLayerMaxUnits`.
  std::array<float, kGruLayerMaxUnits> update;
  ComputeUpdateResetGate(
      input_size_, output_size_, vector_math_, input, state,
      bias.subview(0, stride_bias), weights.subview(0, stride_weights),
      recurrent_weights.subview(0, stride_recurrent_weights), update);

  // Reset gate (gate index 1).
  std::array<float, kGruLayerMaxUnits> reset;
  ComputeUpdateResetGate(input_size_, output_size_, vector_math_, input, state,
                         bias.subview(stride_bias, stride_bias),
                         weights.subview(stride_weights, stride_weights),
                         recurrent_weights.subview(stride_recurrent_weights,
                                                   stride_recurrent_weights),
                         reset);

  // State gate (gate index 2). This must run last because it overwrites
  // `state`, which the two gates above read.
  ComputeStateGate(input_size_, output_size_, vector_math_, input, update,
                   reset, bias.subview(2 * stride_bias, stride_bias),
                   weights.subview(2 * stride_weights, stride_weights),
                   recurrent_weights.subview(2 * stride_recurrent_weights,
                                             stride_recurrent_weights),
                   state);
}

}  // namespace rnn_vad
}  // namespace webrtc