chromium/chromeos/ash/components/enhanced_network_tts/enhanced_network_tts_impl_unittest.cc

// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chromeos/ash/components/enhanced_network_tts/enhanced_network_tts_impl.h"

#include <list>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "base/check.h"
#include "base/functional/bind.h"
#include "base/memory/raw_ptr.h"
#include "base/test/task_environment.h"
#include "chromeos/ash/components/enhanced_network_tts/enhanced_network_tts_constants.h"
#include "chromeos/ash/components/enhanced_network_tts/enhanced_network_tts_test_utils.h"
#include "google_apis/google_api_keys.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "net/base/net_errors.h"
#include "net/http/http_status_code.h"
#include "services/data_decoder/public/cpp/test_support/in_process_data_decoder.h"
#include "services/network/public/cpp/weak_wrapper_shared_url_loader_factory.h"
#include "services/network/public/mojom/url_loader.mojom-shared.h"
#include "services/network/test/test_url_loader_factory.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "url/gurl.h"

namespace ash::enhanced_network_tts {
namespace {

// A fake server that supports test URL loading.
class TestServerURLLoaderFactory {
 public:
  TestServerURLLoaderFactory()
      : shared_loader_factory_(loader_factory_.GetSafeWeakWrapper()) {}
  TestServerURLLoaderFactory(const TestServerURLLoaderFactory&) = delete;
  TestServerURLLoaderFactory& operator=(const TestServerURLLoaderFactory&) =
      delete;
  ~TestServerURLLoaderFactory() = default;

  const std::vector<network::TestURLLoaderFactory::PendingRequest>& requests() {
    return *loader_factory_.pending_requests();
  }

  // Expects that the earliest received request has the given URL, headers and
  // body, and replies with the given response.
  //
  // |expected_headers| is a map from header key string to either:
  //   a) a null optional, if the given header should not be present, or
  //   b) a non-null optional, if the given header should be present and match
  //      the optional value.
  //
  // Consumes the earliest received request (i.e. a subsequent call will apply
  // to the second-earliest received request and so on).
  void ExpectRequestAndSimulateResponse(
      const std::string& expected_url,
      const std::map<std::string, std::optional<std::string>>& expected_headers,
      const std::string& expected_body,
      const std::string& response,
      const net::HttpStatusCode response_code) {
    const std::vector<network::TestURLLoaderFactory::PendingRequest>&
        pending_requests = *loader_factory_.pending_requests();

    ASSERT_FALSE(pending_requests.empty());
    const network::ResourceRequest& request = pending_requests.front().request;

    // Assert that the earliest request is for the given URL.
    EXPECT_EQ(request.url, GURL(expected_url));

    // Expect that specified headers are accurate.
    for (const auto& kv : expected_headers) {
      EXPECT_EQ(request.headers.GetHeader(kv.first), kv.second);
    }

    // Extract request body.
    std::string actual_body;
    if (request.request_body) {
      const std::vector<network::DataElement>* const elements =
          request.request_body->elements();

      // We only support the simplest body structure.
      if (elements && elements->size() == 1u &&
          (*elements)[0].type() ==
              network::mojom::DataElementDataView::Tag::kBytes) {
        actual_body = std::string(
            (*elements)[0].As<network::DataElementBytes>().AsStringPiece());
      }
    }

    EXPECT_TRUE(AreRequestsEqual(actual_body, expected_body));

    // Guaranteed to match the first request based on URL.
    loader_factory_.SimulateResponseForPendingRequest(expected_url, response,
                                                      response_code);
  }

  scoped_refptr<network::SharedURLLoaderFactory> AsSharedURLLoaderFactory() {
    return shared_loader_factory_;
  }

 private:
  network::TestURLLoaderFactory loader_factory_;
  scoped_refptr<network::SharedURLLoaderFactory> shared_loader_factory_;
};

// Flattens |result| into plain values for assertions. An error response
// stores its code in |*error| and leaves both vectors untouched. A data
// response appends its audio bytes to |*audio_data| and a copy of every timing
// entry to |*timing_data|, leaving |*error| untouched.
void UnpackResult(std::optional<mojom::TtsRequestError>* const error,
                  std::vector<uint8_t>* const audio_data,
                  std::vector<mojom::TimingInfo>* const timing_data,
                  mojom::TtsResponsePtr result) {
  const bool is_error =
      result->which() == mojom::TtsResponse::Tag::kErrorCode;
  if (is_error) {
    *error = result->get_error_code();
    return;
  }

  const auto& payload = result->get_data();

  // Append the audio bytes in one go.
  audio_data->insert(audio_data->end(), payload->audio.begin(),
                     payload->audio.end());

  // Timing entries are held by pointer; store copies of the pointees.
  for (const auto& entry : payload->time_info) {
    timing_data->push_back(*entry);
  }
}

class TestAudioDataObserverImpl : public mojom::AudioDataObserver {
 public:
  TestAudioDataObserverImpl() = default;
  TestAudioDataObserverImpl(const TestAudioDataObserverImpl&) = delete;
  void operator=(const TestAudioDataObserverImpl&) = delete;
  ~TestAudioDataObserverImpl() override = default;

  // Binds a pending receiver.
  void BindReceiver(mojo::PendingReceiver<mojom::AudioDataObserver> receiver) {
    receiver_.reset();
    receiver_.Bind(std::move(receiver));
  }

  // mojom::AudioDataObserver:
  void OnAudioDataReceived(mojom::TtsResponsePtr response) override {
    received_responses_.push_back(std::move(response));
  }

  mojom::TtsResponsePtr GetNexResponse() {
    mojom::TtsResponsePtr next_response =
        std::move(received_responses_.front());
    received_responses_.pop_front();
    return next_response;
  }

 private:
  std::list<mojom::TtsResponsePtr> received_responses_;

  mojo::Receiver<mojom::AudioDataObserver> receiver_{this};
};

}  // namespace

// Test fixture that creates an EnhancedNetworkTtsImpl for every test, binds it
// to |remote_| and to the fake URL loader factory, and offers an observer that
// tests can hand to GetAudioData().
class EnhancedNetworkTtsImplTest : public testing::Test {
 protected:
  // Creates the in-process data decoder and the instance under test, then
  // binds the instance to |remote_| and to the fake URL loader factory.
  void SetUp() override {
    in_process_data_decoder_ =
        std::make_unique<data_decoder::test::InProcessDataDecoder>();
    // NOTE(review): the instance is allocated with |new| and nothing in this
    // fixture deletes it. Presumably this is intentional (for example, the
    // destructor may not be accessible here); confirm against the class
    // declaration before changing ownership.
    enhanced_network_tts_impl_ = new EnhancedNetworkTtsImpl();
    enhanced_network_tts_impl_->BindReceiverAndURLFactory(
        remote_.BindNewPipeAndPassReceiver(),
        test_url_factory_.AsSharedURLLoaderFactory());
  }

  // Returns the instance created in SetUp().
  EnhancedNetworkTtsImpl& GetTestingInstance() {
    return *enhanced_network_tts_impl_;
  }

  // Returns the fixture-owned observer; the pointer stays valid for the
  // lifetime of the fixture.
  TestAudioDataObserverImpl* GetTestingObserverPtr() { return &observer_; }

  // Instance under test, created in SetUp().
  raw_ptr<EnhancedNetworkTtsImpl> enhanced_network_tts_impl_;
  // Created in SetUp().
  std::unique_ptr<data_decoder::test::InProcessDataDecoder>
      in_process_data_decoder_;
  // Tests call RunUntilIdle() on this to flush pending tasks.
  base::test::TaskEnvironment test_task_env_;
  // Fake server that records requests and replies with canned responses.
  TestServerURLLoaderFactory test_url_factory_;
  // Remote end of the pipe whose receiver is bound to the instance under test.
  mojo::Remote<mojom::EnhancedNetworkTts> remote_;
  // Default observer handed out by GetTestingObserverPtr().
  TestAudioDataObserverImpl observer_;
};

// Happy path: a valid request is sent to the server with the API key header
// and the expected body, and the observer receives the audio bytes and timing
// information carried by the server's reply.
TEST_F(EnhancedNetworkTtsImplTest, GetAudioDataSucceeds) {
  const std::string input_text = "Hi.";
  const float rate = 1.0;
  GetTestingInstance().GetAudioData(
      mojom::TtsRequest::New(input_text, rate, std::nullopt, std::nullopt),
      base::BindOnce(
          [](TestAudioDataObserverImpl* observer,
             mojo::PendingReceiver<mojom::AudioDataObserver> pending_receiver) {
            observer->BindReceiver(std::move(pending_receiver));
          },
          GetTestingObserverPtr()));
  test_task_env_.RunUntilIdle();

  const std::map<std::string, std::optional<std::string>> expected_headers = {
      {kGoogApiKeyHeader, google_apis::GetReadAloudAPIKey()}};
  const std::string expected_body = CreateCorrectRequest(input_text, rate);
  // |expected_output| here is arbitrary, which is encoded into a fake response
  // sent by the fake server, |TestServerURLLoaderFactory|. In general, we
  // expect the real server sends the audio data back as a base64 encoded JSON
  // string.
  const std::vector<uint8_t> expected_output = {1, 2, 5};
  test_url_factory_.ExpectRequestAndSimulateResponse(
      kReadAloudServerUrl, expected_headers, expected_body,
      CreateServerResponse(expected_output), net::HTTP_OK);
  test_task_env_.RunUntilIdle();

  // We only get the data after the server's response. We simulate the response
  // in the code above.
  std::optional<mojom::TtsRequestError> error;
  std::vector<uint8_t> audio_data;
  std::vector<mojom::TimingInfo> timing_data;
  UnpackResult(&error, &audio_data, &timing_data,
               GetTestingObserverPtr()->GetNexResponse());
  // A data response must not carry an error code.
  EXPECT_FALSE(error.has_value());
  EXPECT_EQ(audio_data, expected_output);
  // The timing data is hardcoded in |kTemplateResponse|. This is a fatal
  // assertion because the element accesses below would otherwise read out of
  // bounds when fewer entries were delivered.
  ASSERT_EQ(timing_data.size(), 2u);
  EXPECT_EQ(timing_data[0].text, "test1");
  EXPECT_EQ(timing_data[0].time_offset, "0.01s");
  EXPECT_EQ(timing_data[0].duration, "0.14s");
  EXPECT_EQ(timing_data[0].text_offset, 0u);
  EXPECT_EQ(timing_data[1].text, "test2");
  EXPECT_EQ(timing_data[1].time_offset, "0.16s");
  EXPECT_EQ(timing_data[1].duration, "0.17s");
  EXPECT_EQ(timing_data[1].text_offset, 6u);
}

// Leading whitespace is stripped from the text sent to the server, and the
// text offsets reported back are shifted by the number of stripped characters
// so that they index into the original, untrimmed input.
TEST_F(EnhancedNetworkTtsImplTest, GetAudioDataIgnoresWhitespacesAtStart) {
  const std::string input_text = "    test1 test2";
  const std::string input_text_trimmed = "test1 test2";
  const float rate = 1.0;
  GetTestingInstance().GetAudioData(
      mojom::TtsRequest::New(input_text, rate, std::nullopt, std::nullopt),
      base::BindOnce(
          [](TestAudioDataObserverImpl* observer,
             mojo::PendingReceiver<mojom::AudioDataObserver> pending_receiver) {
            observer->BindReceiver(std::move(pending_receiver));
          },
          GetTestingObserverPtr()));
  test_task_env_.RunUntilIdle();

  const std::map<std::string, std::optional<std::string>> expected_headers = {
      {kGoogApiKeyHeader, google_apis::GetReadAloudAPIKey()}};
  const std::string expected_body =
      CreateCorrectRequest(input_text_trimmed, rate);
  // |expected_output| here is arbitrary, which is encoded into a fake response
  // sent by the fake server, |TestServerURLLoaderFactory|. In general, we
  // expect the real server sends the audio data back as a base64 encoded JSON
  // string.
  const std::vector<uint8_t> expected_output = {1, 2, 5};
  test_url_factory_.ExpectRequestAndSimulateResponse(
      kReadAloudServerUrl, expected_headers, expected_body,
      CreateServerResponse(expected_output), net::HTTP_OK);
  test_task_env_.RunUntilIdle();

  // We only get the data after the server's response. We simulate the response
  // in the code above.
  std::optional<mojom::TtsRequestError> error;
  std::vector<uint8_t> audio_data;
  std::vector<mojom::TimingInfo> timing_data;
  UnpackResult(&error, &audio_data, &timing_data,
               GetTestingObserverPtr()->GetNexResponse());
  // A data response must not carry an error code.
  EXPECT_FALSE(error.has_value());
  // Stop here if the two timing entries were not delivered; indexing into a
  // shorter vector below would read out of bounds.
  ASSERT_EQ(timing_data.size(), 2u);
  // The text offset will be compensated with whitespaces.
  EXPECT_EQ(timing_data[0].text, "test1");
  EXPECT_EQ(timing_data[0].text_offset, 4u);
  EXPECT_EQ(timing_data[1].text, "test2");
  EXPECT_EQ(timing_data[1].text_offset, 10u);
}

// A rate above |kMaxRate| is capped: the outgoing request carries |kMaxRate|
// and the audio from the server's reply still reaches the observer.
TEST_F(EnhancedNetworkTtsImplTest, GetAudioDataSucceedsWithFasterRate) {
  using HeaderExpectations = std::map<std::string, std::optional<std::string>>;

  const std::string utterance = "Rate will be capped to kMaxRate";
  const float requested_rate = kMaxRate + 1.0f;

  // Hand the pending receiver to the fixture's observer once it is available.
  auto bind_observer = base::BindOnce(
      [](TestAudioDataObserverImpl* observer,
         mojo::PendingReceiver<mojom::AudioDataObserver> pending_receiver) {
        observer->BindReceiver(std::move(pending_receiver));
      },
      GetTestingObserverPtr());
  GetTestingInstance().GetAudioData(
      mojom::TtsRequest::New(utterance, requested_rate, std::nullopt,
                             std::nullopt),
      std::move(bind_observer));
  test_task_env_.RunUntilIdle();

  // The request on the wire must use the capped rate, not the requested one.
  const HeaderExpectations expected_headers = {
      {kGoogApiKeyHeader, google_apis::GetReadAloudAPIKey()}};
  const std::string expected_body = CreateCorrectRequest(utterance, kMaxRate);
  // Arbitrary payload that the fake server encodes into its reply.
  const std::vector<uint8_t> expected_output = {1, 2, 5};
  test_url_factory_.ExpectRequestAndSimulateResponse(
      kReadAloudServerUrl, expected_headers, expected_body,
      CreateServerResponse(expected_output), net::HTTP_OK);
  test_task_env_.RunUntilIdle();

  // The observer only has data once the simulated reply above was processed.
  std::optional<mojom::TtsRequestError> error;
  std::vector<uint8_t> audio_data;
  std::vector<mojom::TimingInfo> timing_data;
  UnpackResult(&error, &audio_data, &timing_data,
               GetTestingObserverPtr()->GetNexResponse());
  EXPECT_EQ(audio_data, expected_output);
}

// A rate below |kMinRate| is raised to |kMinRate| in the outgoing request, and
// the audio from the server's reply still reaches the observer.
TEST_F(EnhancedNetworkTtsImplTest, GetAudioDataSucceedsWithSlowerRate) {
  using HeaderExpectations = std::map<std::string, std::optional<std::string>>;

  const std::string utterance = "Rate will be floored to kMinRate";
  const float requested_rate = kMinRate - 0.1f;

  // Hand the pending receiver to the fixture's observer once it is available.
  auto bind_observer = base::BindOnce(
      [](TestAudioDataObserverImpl* observer,
         mojo::PendingReceiver<mojom::AudioDataObserver> pending_receiver) {
        observer->BindReceiver(std::move(pending_receiver));
      },
      GetTestingObserverPtr());
  GetTestingInstance().GetAudioData(
      mojom::TtsRequest::New(utterance, requested_rate, std::nullopt,
                             std::nullopt),
      std::move(bind_observer));
  test_task_env_.RunUntilIdle();

  // The request on the wire must use the minimum rate, not the requested one.
  const HeaderExpectations expected_headers = {
      {kGoogApiKeyHeader, google_apis::GetReadAloudAPIKey()}};
  const std::string expected_body = CreateCorrectRequest(utterance, kMinRate);
  // Arbitrary payload that the fake server encodes into its reply.
  const std::vector<uint8_t> expected_output = {1, 2, 5};
  test_url_factory_.ExpectRequestAndSimulateResponse(
      kReadAloudServerUrl, expected_headers, expected_body,
      CreateServerResponse(expected_output), net::HTTP_OK);
  test_task_env_.RunUntilIdle();

  // The observer only has data once the simulated reply above was processed.
  std::optional<mojom::TtsRequestError> error;
  std::vector<uint8_t> audio_data;
  std::vector<mojom::TimingInfo> timing_data;
  UnpackResult(&error, &audio_data, &timing_data,
               GetTestingObserverPtr()->GetNexResponse());
  EXPECT_EQ(audio_data, expected_output);
}

// An utterance longer than the per-request character limit is split into
// several requests that are sent one after another, each after the previous
// one has been answered.
TEST_F(EnhancedNetworkTtsImplTest, GetAudioDataWithLongUtterance) {
  using HeaderExpectations = std::map<std::string, std::optional<std::string>>;

  const std::string utterance = "Sent 1. Hello world!";
  const float rate = 1.0;
  // With a limit of 8 the first sentence fits into one request, and so does
  // each individual word of the second sentence.
  GetTestingInstance().SetCharLimitPerRequestForTesting(8);

  // Hand the pending receiver to the fixture's observer once it is available.
  auto bind_observer = base::BindOnce(
      [](TestAudioDataObserverImpl* observer,
         mojo::PendingReceiver<mojom::AudioDataObserver> pending_receiver) {
        observer->BindReceiver(std::move(pending_receiver));
      },
      GetTestingObserverPtr());
  GetTestingInstance().GetAudioData(
      mojom::TtsRequest::New(utterance, rate, std::nullopt, std::nullopt),
      std::move(bind_observer));
  test_task_env_.RunUntilIdle();

  const HeaderExpectations expected_headers = {
      {kGoogApiKeyHeader, google_apis::GetReadAloudAPIKey()}};
  // Arbitrary payload that the fake server encodes into each reply.
  const std::vector<uint8_t> expected_output = {1, 2, 5};

  // Text expected in each successive request: the first sentence, then the
  // first word of the second sentence, then its second word.
  const std::vector<std::string> expected_chunks = {"Sent 1. ", "Hello",
                                                    " world!"};
  for (const std::string& chunk : expected_chunks) {
    const std::string chunk_body = CreateCorrectRequest(chunk, rate);
    test_url_factory_.ExpectRequestAndSimulateResponse(
        kReadAloudServerUrl, expected_headers, chunk_body,
        CreateServerResponse(expected_output), net::HTTP_OK);
    test_task_env_.RunUntilIdle();
  }
}

// An empty utterance is answered with |kEmptyUtterance|; the test never has to
// simulate a server reply for the error to arrive.
TEST_F(EnhancedNetworkTtsImplTest, EmptyUtteranceError) {
  const std::string utterance("");
  const float rate = 1.0;

  // Hand the pending receiver to the fixture's observer once it is available.
  auto bind_observer = base::BindOnce(
      [](TestAudioDataObserverImpl* observer,
         mojo::PendingReceiver<mojom::AudioDataObserver> pending_receiver) {
        observer->BindReceiver(std::move(pending_receiver));
      },
      GetTestingObserverPtr());
  GetTestingInstance().GetAudioData(
      mojom::TtsRequest::New(utterance, rate, std::nullopt, std::nullopt),
      std::move(bind_observer));
  test_task_env_.RunUntilIdle();

  // The error is already queued on the observer at this point.
  std::optional<mojom::TtsRequestError> error;
  std::vector<uint8_t> audio_data;
  std::vector<mojom::TimingInfo> timing_data;
  UnpackResult(&error, &audio_data, &timing_data,
               GetTestingObserverPtr()->GetNexResponse());
  EXPECT_EQ(error, mojom::TtsRequestError::kEmptyUtterance);
}

// A second request issued before the first one is answered overrides it: the
// first observer is told |kRequestOverride| and receives no data, while the
// second observer receives the audio.
TEST_F(EnhancedNetworkTtsImplTest, OverrideRequest) {
  using HeaderExpectations = std::map<std::string, std::optional<std::string>>;

  const std::string utterance("request");
  const float rate = 1.0;

  // First request, observed by the fixture's observer.
  auto bind_first_observer = base::BindOnce(
      [](TestAudioDataObserverImpl* observer,
         mojo::PendingReceiver<mojom::AudioDataObserver> pending_receiver) {
        observer->BindReceiver(std::move(pending_receiver));
      },
      GetTestingObserverPtr());
  GetTestingInstance().GetAudioData(
      mojom::TtsRequest::New(utterance, rate, std::nullopt, std::nullopt),
      std::move(bind_first_observer));
  test_task_env_.RunUntilIdle();

  // Second request with its own observer, issued before the server has
  // replied to the first one.
  TestAudioDataObserverImpl second_observer;
  auto bind_second_observer = base::BindOnce(
      [](TestAudioDataObserverImpl* observer,
         mojo::PendingReceiver<mojom::AudioDataObserver> pending_receiver) {
        observer->BindReceiver(std::move(pending_receiver));
      },
      &second_observer);
  GetTestingInstance().GetAudioData(
      mojom::TtsRequest::New(utterance, rate, std::nullopt, std::nullopt),
      std::move(bind_second_observer));
  test_task_env_.RunUntilIdle();

  // The server answers both requests in order, with the same reply each time.
  const HeaderExpectations expected_headers = {
      {kGoogApiKeyHeader, google_apis::GetReadAloudAPIKey()}};
  std::string expected_body = CreateCorrectRequest(utterance, rate);
  const std::vector<uint8_t> expected_output = {1, 2, 5};
  for (int reply = 0; reply < 2; ++reply) {
    test_url_factory_.ExpectRequestAndSimulateResponse(
        kReadAloudServerUrl, expected_headers, expected_body,
        CreateServerResponse(expected_output), net::HTTP_OK);
    test_task_env_.RunUntilIdle();
  }

  // The overridden (first) request reports an error and carries no data.
  std::optional<mojom::TtsRequestError> first_error;
  std::vector<uint8_t> first_audio;
  std::vector<mojom::TimingInfo> first_timing;
  UnpackResult(&first_error, &first_audio, &first_timing,
               GetTestingObserverPtr()->GetNexResponse());
  EXPECT_EQ(first_error, mojom::TtsRequestError::kRequestOverride);
  EXPECT_EQ(first_timing.size(), 0u);
  EXPECT_EQ(first_audio.size(), 0u);

  // The overriding (second) request receives the audio.
  std::optional<mojom::TtsRequestError> second_error;
  std::vector<uint8_t> second_audio;
  std::vector<mojom::TimingInfo> second_timing;
  UnpackResult(&second_error, &second_audio, &second_timing,
               second_observer.GetNexResponse());
  EXPECT_EQ(second_audio, expected_output);
}

// An HTTP 500 reply from the server is reported to the observer as
// |kServerError|.
TEST_F(EnhancedNetworkTtsImplTest, ServerError) {
  using HeaderExpectations = std::map<std::string, std::optional<std::string>>;

  const std::string utterance = "Hi.";
  const float rate = 1.0;

  // Hand the pending receiver to the fixture's observer once it is available.
  auto bind_observer = base::BindOnce(
      [](TestAudioDataObserverImpl* observer,
         mojo::PendingReceiver<mojom::AudioDataObserver> pending_receiver) {
        observer->BindReceiver(std::move(pending_receiver));
      },
      GetTestingObserverPtr());
  GetTestingInstance().GetAudioData(
      mojom::TtsRequest::New(utterance, rate, std::nullopt, std::nullopt),
      std::move(bind_observer));
  test_task_env_.RunUntilIdle();

  // Reply with an empty body and an internal-server-error status.
  const HeaderExpectations expected_headers = {
      {kGoogApiKeyHeader, google_apis::GetReadAloudAPIKey()}};
  const std::string expected_body = CreateCorrectRequest(utterance, rate);
  const std::string empty_response;
  test_url_factory_.ExpectRequestAndSimulateResponse(
      kReadAloudServerUrl, expected_headers, expected_body, empty_response,
      net::HTTP_INTERNAL_SERVER_ERROR);
  test_task_env_.RunUntilIdle();

  // The observer only has a result once the simulated reply above was
  // processed.
  std::optional<mojom::TtsRequestError> error;
  std::vector<uint8_t> audio_data;
  std::vector<mojom::TimingInfo> timing_data;
  UnpackResult(&error, &audio_data, &timing_data,
               GetTestingObserverPtr()->GetNexResponse());
  EXPECT_EQ(error, mojom::TtsRequestError::kServerError);
}

// An HTTP 200 reply whose body is not the expected JSON is reported to the
// observer as |kReceivedUnexpectedData|.
TEST_F(EnhancedNetworkTtsImplTest, JsonDecodingError) {
  using HeaderExpectations = std::map<std::string, std::optional<std::string>>;

  const std::string utterance = "Hi.";
  const float rate = 1.0;

  // Hand the pending receiver to the fixture's observer once it is available.
  auto bind_observer = base::BindOnce(
      [](TestAudioDataObserverImpl* observer,
         mojo::PendingReceiver<mojom::AudioDataObserver> pending_receiver) {
        observer->BindReceiver(std::move(pending_receiver));
      },
      GetTestingObserverPtr());
  GetTestingInstance().GetAudioData(
      mojom::TtsRequest::New(utterance, rate, std::nullopt, std::nullopt),
      std::move(bind_observer));
  test_task_env_.RunUntilIdle();

  // Reply successfully at the HTTP level, but with a malformed body.
  const HeaderExpectations expected_headers = {
      {kGoogApiKeyHeader, google_apis::GetReadAloudAPIKey()}};
  const std::string expected_body = CreateCorrectRequest(utterance, rate);
  const std::string malformed_response = R"([{some wired response)";
  test_url_factory_.ExpectRequestAndSimulateResponse(
      kReadAloudServerUrl, expected_headers, expected_body, malformed_response,
      net::HTTP_OK);
  test_task_env_.RunUntilIdle();

  // The observer only has a result once the simulated reply above was
  // processed.
  std::optional<mojom::TtsRequestError> error;
  std::vector<uint8_t> audio_data;
  std::vector<mojom::TimingInfo> timing_data;
  UnpackResult(&error, &audio_data, &timing_data,
               GetTestingObserverPtr()->GetNexResponse());
  EXPECT_EQ(error, mojom::TtsRequestError::kReceivedUnexpectedData);
}

}  // namespace ash::enhanced_network_tts