chromium/chromeos/ash/components/boca/babelorca/transcript_sender_unittest.cc

// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chromeos/ash/components/boca/babelorca/transcript_sender.h"

#include <string>

#include "base/test/task_environment.h"
#include "base/test/test_future.h"
#include "base/types/expected.h"
#include "chromeos/ash/components/boca/babelorca/fakes/fake_tachyon_authed_client.h"
#include "chromeos/ash/components/boca/babelorca/fakes/fake_tachyon_client.h"
#include "chromeos/ash/components/boca/babelorca/fakes/fake_tachyon_request_data_provider.h"
#include "chromeos/ash/components/boca/babelorca/proto/babel_orca_message.pb.h"
#include "chromeos/ash/components/boca/babelorca/proto/tachyon.pb.h"
#include "chromeos/ash/components/boca/babelorca/proto/tachyon_enums.pb.h"
#include "chromeos/ash/components/boca/babelorca/response_callback_wrapper.h"
#include "chromeos/ash/components/boca/babelorca/response_callback_wrapper_impl.h"
#include "chromeos/ash/components/boca/babelorca/tachyon_constants.h"
#include "media/mojo/mojom/speech_recognition_result.h"
#include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace ash::babelorca {
namespace {

// Sender identity handed to every TranscriptSender constructed in this file
// and expected back in the `sender_id` of the outgoing request.
// NOTE(review): the literal looks like a redacted address placeholder —
// confirm it is the intended value.
constexpr char kSenderEmail[] = "[email protected]";
// Language tag passed with every SendTranscriptionUpdate() call and expected
// in each verified transcript part.
constexpr char kLanguage[] = "en-US";

// One parameter set for the parameterized `SendTwoMessages` test.
//
// The field order must not change: the cases in INSTANTIATE_TEST_SUITE_P are
// written as positional aggregate initializers.
struct TranscriptSenderTestCase {
  // Returned by the name generator to label the generated test.
  std::string test_name;
  // Text of the first update, sent as non-final.
  std::string transcript1;
  // Text of the second update, sent as final.
  std::string transcript2;
  // Value forwarded to the sender's `max_allowed_char` option.
  size_t max_allowed_char;
  // Text expected in the current transcript part of the second message.
  std::string expected_sent_transcript;
  // `text_index` expected in the current transcript part of the second
  // message.
  int expected_index;
};

// Fixture for the TEST_P cases below; GetParam() yields a
// TranscriptSenderTestCase.
using TranscriptSenderTest = testing::TestWithParam<TranscriptSenderTestCase>;

void VerifyTranscriptPartProto(const TranscriptPart& transcript_part,
                               int transcript_id,
                               const std::string& transcript_text,
                               int index,
                               bool is_final,
                               const std::string& language) {
  EXPECT_EQ(transcript_part.transcript_id(), transcript_id);
  EXPECT_EQ(transcript_part.text(), transcript_text);
  EXPECT_EQ(transcript_part.text_index(), index);
  EXPECT_EQ(transcript_part.is_final(), is_final);
  EXPECT_EQ(transcript_part.language(), language);
}

// Sends a single non-final transcript whose length exceeds `max_allowed_char`
// and checks the complete outgoing request: auth token, destination, sender,
// receiver, message type/class, and a payload that carries the whole
// transcript text starting at index 0 with no previous transcript attached.
TEST(TranscriptSenderTest, SendOneMessageLongerThanMaxAllowed) {
  base::test::TaskEnvironment task_environment;
  const int kMaxAllowedChar = 5;
  const std::string kTranscriptText = "hello transcription";
  FakeTachyonAuthedClient fake_client;
  FakeTachyonRequestDataProvider data_provider;
  base::test::TestFuture<void> on_failure;
  TranscriptSender sender(&fake_client, &data_provider, kSenderEmail,
                          TRAFFIC_ANNOTATION_FOR_TESTS,
                          {.max_allowed_char = kMaxAllowedChar},
                          on_failure.GetCallback());

  media::SpeechRecognitionResult partial_result(kTranscriptText,
                                                /*is_final=*/false);
  EXPECT_TRUE(sender.SendTranscriptionUpdate(partial_result, kLanguage));
  fake_client.WaitForRequest();
  fake_client.ExecuteResponseCallback(
      InboxSendResponse().SerializeAsString());

  // A successful response must not trigger the failure callback.
  EXPECT_FALSE(on_failure.IsReady());

  InboxSendRequest request;
  ASSERT_TRUE(request.ParseFromString(fake_client.GetRequestString()));
  EXPECT_EQ(request.header().auth_token_payload(),
            data_provider.tachyon_token());

  // Destination of the request.
  const auto& destination = request.dest_id();
  EXPECT_EQ(destination.id(), data_provider.group_id());
  EXPECT_EQ(destination.type(), IdType::GROUP_ID);
  EXPECT_THAT(destination.app(), testing::StrEq(kTachyonAppName));

  EXPECT_EQ(request.fanout_sender(), MessageFanout::OTHER_SENDER_DEVICES);

  const auto& inbox_message = request.message();

  // Sender of the inbox message.
  const auto& from = inbox_message.sender_id();
  EXPECT_EQ(from.id(), kSenderEmail);
  EXPECT_EQ(from.type(), IdType::EMAIL);
  EXPECT_THAT(from.app(), testing::StrEq(kTachyonAppName));

  // Receiver of the inbox message.
  const auto& to = inbox_message.receiver_id();
  EXPECT_EQ(to.id(), data_provider.group_id());
  EXPECT_EQ(to.type(), IdType::GROUP_ID);
  EXPECT_THAT(to.app(), testing::StrEq(kTachyonAppName));

  EXPECT_EQ(inbox_message.message_type(), InboxMessage::GROUP);
  EXPECT_EQ(inbox_message.message_class(), InboxMessage::USER);

  // Payload: the first message has order 0 and only a current transcript.
  BabelOrcaMessage payload;
  ASSERT_TRUE(payload.ParseFromString(inbox_message.message()));
  EXPECT_EQ(payload.order(), 0);
  EXPECT_FALSE(payload.has_previous_transcript());
  ASSERT_TRUE(payload.has_current_transcript());
  VerifyTranscriptPartProto(payload.current_transcript(), /*transcript_id=*/0,
                            kTranscriptText, /*index=*/0, /*is_final=*/false,
                            kLanguage);
}

// Sends a final transcript followed by a new non-final one. The second
// message is expected to use the next order and transcript id, and to carry
// the tail of the finished transcript ("hello3" at index 14) as its previous
// transcript. Both messages must share the same sender UUID.
TEST(TranscriptSenderTest, SendNewTranscript) {
  base::test::TaskEnvironment task_environment;
  const int kMaxAllowedChar = 26;
  const std::string kTranscriptText = "hello1 hello2 hello3";
  FakeTachyonAuthedClient fake_client;
  FakeTachyonRequestDataProvider data_provider;
  base::test::TestFuture<void> on_failure;
  TranscriptSender sender(&fake_client, &data_provider, kSenderEmail,
                          TRAFFIC_ANNOTATION_FOR_TESTS,
                          {.max_allowed_char = kMaxAllowedChar},
                          on_failure.GetCallback());

  // First update: a final transcript, answered successfully.
  media::SpeechRecognitionResult final_result(kTranscriptText,
                                              /*is_final=*/true);
  EXPECT_TRUE(sender.SendTranscriptionUpdate(final_result, kLanguage));
  fake_client.WaitForRequest();
  const std::string first_request_string = fake_client.GetRequestString();
  fake_client.ExecuteResponseCallback(
      InboxSendResponse().SerializeAsString());

  // Second update: a non-final transcript, also answered successfully.
  media::SpeechRecognitionResult partial_result(kTranscriptText,
                                                /*is_final=*/false);
  EXPECT_TRUE(sender.SendTranscriptionUpdate(partial_result, kLanguage));
  fake_client.WaitForRequest();
  const std::string second_request_string = fake_client.GetRequestString();
  fake_client.ExecuteResponseCallback(
      InboxSendResponse().SerializeAsString());

  EXPECT_FALSE(on_failure.IsReady());

  // First message: only a current transcript, covering the full text.
  InboxSendRequest first_request;
  BabelOrcaMessage first_message;
  ASSERT_TRUE(first_request.ParseFromString(first_request_string));
  ASSERT_TRUE(
      first_message.ParseFromString(first_request.message().message()));
  EXPECT_EQ(first_message.order(), 0);
  EXPECT_FALSE(first_message.has_previous_transcript());
  ASSERT_TRUE(first_message.has_current_transcript());
  VerifyTranscriptPartProto(first_message.current_transcript(),
                            /*transcript_id=*/0, kTranscriptText, /*index=*/0,
                            /*is_final=*/true, kLanguage);

  // Second message: new current transcript plus the previous one's tail.
  InboxSendRequest second_request;
  BabelOrcaMessage second_message;
  ASSERT_TRUE(second_request.ParseFromString(second_request_string));
  ASSERT_TRUE(
      second_message.ParseFromString(second_request.message().message()));
  EXPECT_EQ(second_message.order(), 1);
  ASSERT_TRUE(second_message.has_current_transcript());
  VerifyTranscriptPartProto(second_message.current_transcript(),
                            /*transcript_id=*/1, kTranscriptText, /*index=*/0,
                            /*is_final=*/false, kLanguage);
  ASSERT_TRUE(second_message.has_previous_transcript());
  VerifyTranscriptPartProto(second_message.previous_transcript(),
                            /*transcript_id=*/0, "hello3", /*index=*/14,
                            /*is_final=*/true, kLanguage);
  EXPECT_EQ(first_message.sender_uuid(), second_message.sender_uuid());
}

// With `max_errors_num` set to 2, two failed requests in a row must make the
// sender refuse the next update and run the failure callback.
TEST(TranscriptSenderTest, RejectSendingAndReplyOnMaxErrorsReached) {
  base::test::TaskEnvironment task_environment;
  const int kMaxAllowedChar = 26;
  const std::string kTranscriptText = "hello1 hello2 hello3";
  const auto kHttpError =
      ResponseCallbackWrapper::TachyonRequestError::kHttpError;
  FakeTachyonAuthedClient fake_client;
  FakeTachyonRequestDataProvider data_provider;
  base::test::TestFuture<void> on_failure;
  TranscriptSender sender(
      &fake_client, &data_provider, kSenderEmail,
      TRAFFIC_ANNOTATION_FOR_TESTS,
      {.max_allowed_char = kMaxAllowedChar, .max_errors_num = 2},
      on_failure.GetCallback());

  // First failure.
  media::SpeechRecognitionResult first_update(kTranscriptText,
                                              /*is_final=*/true);
  EXPECT_TRUE(sender.SendTranscriptionUpdate(first_update, kLanguage));
  fake_client.WaitForRequest();
  fake_client.ExecuteResponseCallback(base::unexpected(kHttpError));

  // Second failure; the error limit is now reached.
  media::SpeechRecognitionResult second_update(kTranscriptText,
                                               /*is_final=*/false);
  EXPECT_TRUE(sender.SendTranscriptionUpdate(second_update, kLanguage));
  fake_client.WaitForRequest();
  fake_client.ExecuteResponseCallback(base::unexpected(kHttpError));

  // Further updates are rejected and the failure has been reported.
  media::SpeechRecognitionResult third_update(kTranscriptText,
                                              /*is_final=*/false);
  EXPECT_FALSE(sender.SendTranscriptionUpdate(third_update, kLanguage));

  EXPECT_TRUE(on_failure.IsReady());
}

// With `max_errors_num` set to 2, a failure / success / failure sequence must
// not run the failure callback: the success in between is expected to clear
// the error count.
TEST(TranscriptSenderTest, ResetErrorCountOnSuccess) {
  base::test::TaskEnvironment task_environment;
  const int kMaxAllowedChar = 26;
  const std::string kTranscriptText = "hello1 hello2 hello3";
  const auto kHttpError =
      ResponseCallbackWrapper::TachyonRequestError::kHttpError;
  FakeTachyonAuthedClient fake_client;
  FakeTachyonRequestDataProvider data_provider;
  base::test::TestFuture<void> on_failure;
  TranscriptSender sender(
      &fake_client, &data_provider, kSenderEmail,
      TRAFFIC_ANNOTATION_FOR_TESTS,
      {.max_allowed_char = kMaxAllowedChar, .max_errors_num = 2},
      on_failure.GetCallback());

  // Request 1 fails.
  media::SpeechRecognitionResult first_update(kTranscriptText,
                                              /*is_final=*/true);
  EXPECT_TRUE(sender.SendTranscriptionUpdate(first_update, kLanguage));
  fake_client.WaitForRequest();
  fake_client.ExecuteResponseCallback(base::unexpected(kHttpError));

  // Request 2 succeeds, which should reset the error count.
  media::SpeechRecognitionResult second_update(kTranscriptText,
                                               /*is_final=*/false);
  EXPECT_TRUE(sender.SendTranscriptionUpdate(second_update, kLanguage));
  fake_client.WaitForRequest();
  fake_client.ExecuteResponseCallback(
      InboxSendResponse().SerializeAsString());

  // Request 3 fails; on its own this stays below the limit, so the failure
  // callback must remain untouched.
  media::SpeechRecognitionResult third_update(kTranscriptText,
                                              /*is_final=*/false);
  EXPECT_TRUE(sender.SendTranscriptionUpdate(third_update, kLanguage));
  fake_client.WaitForRequest();
  fake_client.ExecuteResponseCallback(base::unexpected(kHttpError));

  EXPECT_FALSE(on_failure.IsReady());
}

// Issues three requests without answering any of them, then fails them one by
// one. With `max_errors_num` set to 2 the failure callback must have run after
// the second failure; the third, late failure must be handled without
// problems, and any update sent afterwards must be rejected.
TEST(TranscriptSenderTest, InflightRequestsAreHandledOnFailure) {
  base::test::TaskEnvironment task_environment;
  const int kMaxAllowedChar = 26;
  const std::string kTranscriptText = "hello1 hello2 hello3";
  FakeTachyonAuthedClient authed_client;
  FakeTachyonRequestDataProvider request_data_provider;
  base::test::TestFuture<void> failure_future;
  TranscriptSender sender(
      &authed_client, &request_data_provider, kSenderEmail,
      TRAFFIC_ANNOTATION_FOR_TESTS,
      {.max_allowed_char = kMaxAllowedChar, .max_errors_num = 2},
      failure_future.GetCallback());

  // Take ownership of each response callback instead of running it, so that
  // all three requests stay in flight at the same time.
  media::SpeechRecognitionResult transcript1(kTranscriptText,
                                             /*is_final=*/true);
  EXPECT_TRUE(sender.SendTranscriptionUpdate(transcript1, kLanguage));
  authed_client.WaitForRequest();
  std::unique_ptr<ResponseCallbackWrapper> response_cb1 =
      authed_client.TakeResponseCallback();

  media::SpeechRecognitionResult transcript2(kTranscriptText,
                                             /*is_final=*/false);
  EXPECT_TRUE(sender.SendTranscriptionUpdate(transcript2, kLanguage));
  authed_client.WaitForRequest();
  std::unique_ptr<ResponseCallbackWrapper> response_cb2 =
      authed_client.TakeResponseCallback();

  media::SpeechRecognitionResult transcript3(kTranscriptText,
                                             /*is_final=*/false);
  EXPECT_TRUE(sender.SendTranscriptionUpdate(transcript3, kLanguage));
  authed_client.WaitForRequest();
  std::unique_ptr<ResponseCallbackWrapper> response_cb3 =
      authed_client.TakeResponseCallback();

  // Two failures reach the configured limit.
  response_cb1->Run(base::unexpected(
      ResponseCallbackWrapper::TachyonRequestError::kHttpError));
  response_cb2->Run(base::unexpected(
      ResponseCallbackWrapper::TachyonRequestError::kHttpError));

  EXPECT_TRUE(failure_future.IsReady());

  // The remaining in-flight request fails after the limit was reached.
  response_cb3->Run(base::unexpected(
      ResponseCallbackWrapper::TachyonRequestError::kHttpError));

  // A new update must be rejected once the sender has reported failure.
  media::SpeechRecognitionResult transcript4(kTranscriptText,
                                             /*is_final=*/false);
  EXPECT_FALSE(sender.SendTranscriptionUpdate(transcript4, kLanguage));
}

// Sends a non-final update followed by a final update of the same transcript
// (no responses are delivered in between) and inspects the second request.
// Its current transcript part must hold the parameterized expected text and
// text index, and no previous transcript may be attached.
TEST_P(TranscriptSenderTest, SendTwoMessages) {
  base::test::TaskEnvironment task_environment;
  FakeTachyonAuthedClient authed_client;
  FakeTachyonRequestDataProvider request_data_provider;
  TranscriptSender sender(&authed_client, &request_data_provider, kSenderEmail,
                          TRAFFIC_ANNOTATION_FOR_TESTS,
                          {.max_allowed_char = GetParam().max_allowed_char},
                          base::DoNothing());

  media::SpeechRecognitionResult transcript1(GetParam().transcript1,
                                             /*is_final=*/false);
  EXPECT_TRUE(sender.SendTranscriptionUpdate(transcript1, kLanguage));
  authed_client.WaitForRequest();
  media::SpeechRecognitionResult transcript2(GetParam().transcript2,
                                             /*is_final=*/true);
  EXPECT_TRUE(sender.SendTranscriptionUpdate(transcript2, kLanguage));
  authed_client.WaitForRequest();

  // Only the most recent request (the second one) is examined.
  InboxSendRequest sent_request2;
  BabelOrcaMessage message2;
  ASSERT_TRUE(sent_request2.ParseFromString(authed_client.GetRequestString()));
  ASSERT_TRUE(message2.ParseFromString(sent_request2.message().message()));
  // Stop here if the part is missing, rather than comparing against a
  // default-constructed TranscriptPart.
  ASSERT_TRUE(message2.has_current_transcript());
  VerifyTranscriptPartProto(message2.current_transcript(), /*transcript_id=*/0,
                            GetParam().expected_sent_transcript,
                            /*index=*/GetParam().expected_index,
                            /*is_final=*/true, kLanguage);
  EXPECT_FALSE(message2.has_previous_transcript());
}

INSTANTIATE_TEST_SUITE_P(
    TranscriptSenderTestSuiteInstantiation,
    TranscriptSenderTest,
    testing::ValuesIn<TranscriptSenderTestCase>({
        {"DiffShorterThanMaxAllowed", "hello", "hello world", 7, "o world", 4},
        {"DiffLongerThanMaxAllowed", "hello", "hello world", 1, " world", 5},
    }),
    [](const testing::TestParamInfo<TranscriptSenderTest::ParamType>& info) {
      return info.param.test_name;
    });

}  // namespace
}  // namespace ash::babelorca