chromium/chrome/services/speech/audio_source_fetcher_unittest.cc

// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/services/speech/audio_source_fetcher_impl.h"

#include <memory>
#include <utility>

#include "base/files/file_path.h"
#include "base/functional/callback_forward.h"
#include "base/memory/raw_ptr.h"
#include "base/path_service.h"
#include "base/run_loop.h"
#include "base/test/bind.h"
#include "base/test/metrics/histogram_tester.h"
#include "base/test/task_environment.h"
#include "chrome/services/speech/cros_speech_recognition_recognizer_impl.h"
#include "chrome/services/speech/speech_recognition_service_impl.h"
#include "media/base/audio_glitch_info.h"
#include "media/base/audio_timestamp_helper.h"
#include "media/mojo/mojom/audio_data.mojom.h"
#include "mojo/public/cpp/bindings/pending_receiver.h"
#include "mojo/public/cpp/bindings/receiver.h"
#include "services/audio/public/cpp/fake_stream_factory.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace speech {

namespace {

// The server-based recognition audio buffer duration is 100 ms
// (1600 frames at 16 kHz):
constexpr int kServerBasedRecognitionAudioSampleRate = 16000;
constexpr int kServerBasedRecognitionAudioFramesPerBuffer = 1600;

// The original audio buffer duration is 200 ms (9600 frames at 48 kHz):
constexpr int kOriginalSampleRate = 48000;
constexpr int kOriginalFramesPerBuffer = 9600;

constexpr char kServerBasedRecognitionSessionLength[] =
    "Ash.SpeechRecognitionSessionLength.ServerBased";
constexpr char kOnDeviceRecognitionSessionLength[] =
    "Ash.SpeechRecognitionSessionLength.OnDevice";

}  // namespace

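// Fake stream factory used by the tests. CreateInputStream() only stores the
// `created_callback` (without running it) so that it stays alive for the
// duration of the test.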
class MockStreamFactory : public audio::FakeStreamFactory {
 public:
  MockStreamFactory() = default;
  ~MockStreamFactory() override = default;

  void CreateInputStream(
      mojo::PendingReceiver<media::mojom::AudioInputStream> stream_receiver,
      mojo::PendingRemote<media::mojom::AudioInputStreamClient> client,
      mojo::PendingRemote<media::mojom::AudioInputStreamObserver> observer,
      mojo::PendingRemote<media::mojom::AudioLog> log,
      const std::string& device_id,
      const media::AudioParameters& params,
      uint32_t shared_memory_count,
      bool enable_agc,
      base::ReadOnlySharedMemoryRegion key_press_count_buffer,
      media::mojom::AudioProcessingConfigPtr processing_config,
      CreateInputStreamCallback created_callback) override {
    last_created_callback_ = std::move(created_callback);
  }

 private:
  // Keeps `last_created_callback_` alive for the duration of the test.
  CreateInputStreamCallback last_created_callback_;

  mojo::Receiver<media::mojom::AudioStreamFactory> receiver_{this};
};

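// Callback run with each media::mojom::AudioDataS16 buffer the fetcher
// forwards to its consumer; tests use it to inspect the captured audio.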
using OnSendAudioToSpeechRecognitionCallback =
    base::OnceCallback<void(media::mojom::AudioDataS16Ptr buffer)>;

// A consumer used to verify the media::mojom::AudioDataS16 data captured by
// AudioSourceFetcherImpl.
class MockAudioSourceConsumer : public AudioSourceConsumer {
 public:
  MockAudioSourceConsumer() = default;
  ~MockAudioSourceConsumer() override = default;

  // AudioSourceConsumer:
  void AddAudio(media::mojom::AudioDataS16Ptr buffer) override {
    EXPECT_FALSE(is_audio_end_);
    std::move(on_send_audio_to_speech_recognition_callback_)
        .Run(std::move(buffer));
  }

  void OnAudioCaptureEnd() override { is_audio_end_ = true; }

  void OnAudioCaptureError() override {}

  void SetOnSendAudioToSpeechRecognitionCallback(
      OnSendAudioToSpeechRecognitionCallback callback) {
    on_send_audio_to_speech_recognition_callback_ = std::move(callback);
  }

 private:
  // Used to verify the media::mojom::AudioDataS16 content.
  OnSendAudioToSpeechRecognitionCallback
      on_send_audio_to_speech_recognition_callback_;
  bool is_audio_end_ = false;
};

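// Test fixture parameterized on whether recognition is server-based (true) or
// on-device (false); the parameter is forwarded to AudioSourceFetcherImpl in
// SetUp().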
class AudioSourceFetcherImplTest
    : public testing::TestWithParam<bool>,
      public media::mojom::SpeechRecognitionRecognizerClient {
 public:
  AudioSourceFetcherImplTest() { is_server_based_ = GetParam(); }
  ~AudioSourceFetcherImplTest() override = default;

  void SetUp() override {
    std::unique_ptr<MockAudioSourceConsumer> speech_recognition_recognizer =
        std::make_unique<MockAudioSourceConsumer>();
    speech_recognition_recognizer_ = speech_recognition_recognizer.get();
    audio_source_fetcher_ = std::make_unique<AudioSourceFetcherImpl>(
        std::move(speech_recognition_recognizer),
        /*is_multi_channel_supported=*/true,
        /*is_server_based=*/is_server_based_);
  }

 protected:
  bool is_server_based() const { return is_server_based_; }

  AudioSourceFetcherImpl* audio_source_fetcher() {
    return audio_source_fetcher_.get();
  }

  void SetOnSendAudioToSpeechRecognitionCallback(
      OnSendAudioToSpeechRecognitionCallback callback) {
    speech_recognition_recognizer_->SetOnSendAudioToSpeechRecognitionCallback(
        std::move(callback));
  }

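  // Waits for the next AddAudio() call and checks the sample rate and frame
  // count of the delivered buffer. If `stop` is true, the fetcher is stopped
  // before waiting.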
  void VerifyAudioBuffer(int sample_rate, int frame_count, bool stop = false) {
    base::RunLoop run_loop;
    SetOnSendAudioToSpeechRecognitionCallback(
        base::BindLambdaForTesting([&](media::mojom::AudioDataS16Ptr buffer) {
          EXPECT_EQ(sample_rate, buffer->sample_rate);
          EXPECT_EQ(frame_count, buffer->frame_count);

          run_loop.Quit();
        }));
    if (stop) {
      audio_source_fetcher()->Stop();
    }
    run_loop.Run();
  }

  // media::mojom::SpeechRecognitionRecognizerClient:
  // These methods are intentionally left as no-ops since they are not
  // exercised by this test.
  void OnSpeechRecognitionRecognitionEvent(
      const media::SpeechRecognitionResult& result,
      OnSpeechRecognitionRecognitionEventCallback reply) override {}
  void OnSpeechRecognitionStopped() override {}
  void OnSpeechRecognitionError() override {}
  void OnLanguageIdentificationEvent(
      media::mojom::LanguageIdentificationEventPtr event) override {}

  base::test::TaskEnvironment task_environment_;
  std::unique_ptr<AudioSourceFetcherImpl> audio_source_fetcher_;
  base::HistogramTester histogram_tester_;

 private:
  raw_ptr<MockAudioSourceConsumer, DanglingUntriaged>
      speech_recognition_recognizer_;
  bool is_server_based_;
};

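// Captures a single 200 ms buffer and verifies that server-based recognition
// resamples it to 16 kHz / 100 ms buffers while on-device recognition keeps
// the original format, and that the session length metric is recorded when
// the fetcher is destroyed.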
TEST_P(AudioSourceFetcherImplTest, Resample) {
  MockStreamFactory fake_stream_factory;
  media::AudioParameters params =
      media::AudioParameters(media::AudioParameters::AUDIO_PCM_LOW_LATENCY,
                             media::ChannelLayoutConfig::Stereo(),
                             /*sample_rate=*/kOriginalSampleRate,
                             /*frames_per_buffer=*/kOriginalFramesPerBuffer);
  audio_source_fetcher()->Start(fake_stream_factory.MakeRemote(), "device_id",
                                params);

  std::unique_ptr<::media::AudioBus> audio_bus =
      media::AudioBus::Create(params);

  // Initialize channel data in `audio_bus`.
  audio_bus->Zero();
  audio_source_fetcher()->Capture(audio_bus.get(),
                                  /*audio_capture_time=*/base::TimeTicks::Now(),
                                  /*glitch_info=*/{},
                                  /*volume=*/1.0,
                                  /*key_pressed=*/true);
  if (is_server_based()) {
    VerifyAudioBuffer(kServerBasedRecognitionAudioSampleRate,
                      kServerBasedRecognitionAudioFramesPerBuffer);
  } else {
    VerifyAudioBuffer(kOriginalSampleRate, kOriginalFramesPerBuffer);
  }

  fake_stream_factory.ResetReceiver();
  audio_source_fetcher()->Stop();

  // Remaining frames buffered in the converter are flushed when
  // `audio_source_fetcher()->Stop()` is called.
  if (is_server_based()) {
    VerifyAudioBuffer(kServerBasedRecognitionAudioSampleRate,
                      kServerBasedRecognitionAudioFramesPerBuffer);
  }

  // Destroy the audio source fetcher and ensure that the session length
  // metric has been recorded.
  audio_source_fetcher_.reset();
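  // The recorded length should equal the duration of the single captured
  // buffer: 9600 frames at 48 kHz, i.e. 200 ms.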
  base::TimeDelta length = media::AudioTimestampHelper::FramesToTime(
      audio_bus->frames(), kOriginalSampleRate);

  const auto* histogram_name = is_server_based()
                                   ? kServerBasedRecognitionSessionLength
                                   : kOnDeviceRecognitionSessionLength;
  histogram_tester_.ExpectTimeBucketCount(histogram_name, length,
                                          /*count=*/1);
}

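// Verifies that stopping the fetcher while a resample is still pending does
// not deliver any further audio in the server-based case, while the on-device
// case still delivers the original, unconverted buffer.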
TEST_P(AudioSourceFetcherImplTest, StopDuringResample) {
  MockStreamFactory fake_stream_factory;
  media::AudioParameters params =
      media::AudioParameters(media::AudioParameters::AUDIO_PCM_LOW_LATENCY,
                             media::ChannelLayoutConfig::Stereo(),
                             /*sample_rate=*/kOriginalSampleRate,
                             /*frames_per_buffer=*/kOriginalFramesPerBuffer);
  audio_source_fetcher()->Start(fake_stream_factory.MakeRemote(), "device_id",
                                params);

  auto audio_bus = media::AudioBus::Create(params);

  // Initialize channel data in `audio_bus`.
  audio_bus->Zero();
  audio_source_fetcher()->Capture(audio_bus.get(),
                                  /*audio_capture_time=*/base::TimeTicks::Now(),
                                  /*glitch_info=*/{},
                                  /*volume=*/1.0,
                                  /*key_pressed=*/true);
  if (is_server_based()) {
    // Stop will prevent the pending resample call from running, so no audio
    // will be available to verify.
    audio_source_fetcher_->Stop();
    task_environment_.RunUntilIdle();
  } else {
    VerifyAudioBuffer(kOriginalSampleRate, kOriginalFramesPerBuffer,
                      /*stop=*/true);
  }
  audio_source_fetcher_.reset();
  fake_stream_factory.ResetReceiver();
}

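// Run every test with both on-device (false) and server-based (true)
// recognition.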
INSTANTIATE_TEST_SUITE_P(All, AudioSourceFetcherImplTest, ::testing::Bool());

}  // namespace speech