chromium/chrome/browser/ash/input_method/suggestions_service_client_unittest.cc

// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/ash/input_method/suggestions_service_client.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "base/run_loop.h"
#include "base/test/bind.h"
#include "base/test/metrics/histogram_tester.h"
#include "chrome/browser/ash/input_method/suggestion_enums.h"
#include "chromeos/services/machine_learning/public/cpp/fake_service_connection.h"
#include "chromeos/services/machine_learning/public/mojom/text_suggester.mojom.h"
#include "content/public/test/browser_task_environment.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace ash {
namespace input_method {
namespace {

namespace machine_learning = ::chromeos::machine_learning;

using ime::AssistiveSuggestion;
using ime::AssistiveSuggestionMode;
using ime::AssistiveSuggestionType;
using ime::DecoderCompletionCandidate;

// Returns a TextSuggesterResult whose status is OK and whose candidate list
// is left empty.
machine_learning::mojom::TextSuggesterResultPtr NoCandidate() {
  machine_learning::mojom::TextSuggesterResultPtr empty_result =
      machine_learning::mojom::TextSuggesterResult::New();
  empty_result->status =
      machine_learning::mojom::TextSuggesterResult::Status::OK;
  return empty_result;
}

// Returns an OK TextSuggesterResult holding exactly one multi-word candidate.
// `text` becomes the candidate text and `score` its normalized score.
machine_learning::mojom::TextSuggesterResultPtr SingleCandidate(
    const std::string& text,
    float score) {
  // Start from the OK, candidate-free result and append the single entry.
  machine_learning::mojom::TextSuggesterResultPtr result = NoCandidate();
  result->candidates.push_back(
      machine_learning::mojom::TextSuggestionCandidate::NewMultiWord(
          machine_learning::mojom::MultiWordSuggestionCandidate::New(
              /*text=*/text, /*normalized_score=*/score)));
  return result;
}

class SuggestionsServiceClientTest : public testing::Test {
 public:
  SuggestionsServiceClientTest() {
    machine_learning::ServiceConnection::UseFakeServiceConnectionForTesting(
        &fake_service_connection_);
    machine_learning::ServiceConnection::GetInstance()->Initialize();
    // After initializing a client, we need to wait for any pending tasks to
    // resolve (ie. the task to connect to the fake service connection).
    client_ = std::make_unique<SuggestionsServiceClient>();
    base::RunLoop().RunUntilIdle();
  }

 protected:
  void SetTextSuggesterResult(
      machine_learning::mojom::TextSuggesterResultPtr result) {
    fake_service_connection_.SetOutputTextSuggesterResult(std::move(result));
  }

  void WaitForResults() { base::RunLoop().RunUntilIdle(); }

  SuggestionsServiceClient* client() { return client_.get(); }

 private:
  content::BrowserTaskEnvironment task_environment_;
  machine_learning::FakeServiceConnectionImpl fake_service_connection_;
  std::unique_ptr<SuggestionsServiceClient> client_;
};

// A completion-mode request should surface the fake service's multi-word
// candidate as a single completion suggestion.
TEST_F(SuggestionsServiceClientTest, ReturnsCompletionResultsFromMojoService) {
  SetTextSuggesterResult(SingleCandidate("hi there completion", 0.5f));

  std::vector<AssistiveSuggestion> received;
  auto capture_results = base::BindLambdaForTesting(
      [&received](const std::vector<AssistiveSuggestion>& suggestions) {
        received = suggestions;
      });

  client()->RequestSuggestions(
      /*preceding_text=*/"this is some text",
      /*suggestion_mode=*/AssistiveSuggestionMode::kCompletion,
      /*completion_candidates=*/std::vector<DecoderCompletionCandidate>{},
      /*callback=*/std::move(capture_results));
  WaitForResults();

  const std::vector<AssistiveSuggestion> expected = {
      AssistiveSuggestion{.mode = AssistiveSuggestionMode::kCompletion,
                          .type = AssistiveSuggestionType::kMultiWord,
                          .text = "hi there completion"},
  };
  EXPECT_EQ(received, expected);
}

// A prediction-mode request should surface the fake service's multi-word
// candidate as a single prediction suggestion.
TEST_F(SuggestionsServiceClientTest, ReturnsPredictionResultsFromMojoService) {
  SetTextSuggesterResult(SingleCandidate("hi there prediction", 0.5f));

  std::vector<AssistiveSuggestion> received;
  auto capture_results = base::BindLambdaForTesting(
      [&received](const std::vector<AssistiveSuggestion>& suggestions) {
        received = suggestions;
      });

  client()->RequestSuggestions(
      /*preceding_text=*/"this is some text",
      /*suggestion_mode=*/AssistiveSuggestionMode::kPrediction,
      /*completion_candidates=*/std::vector<DecoderCompletionCandidate>{},
      /*callback=*/std::move(capture_results));
  WaitForResults();

  const std::vector<AssistiveSuggestion> expected = {
      AssistiveSuggestion{.mode = AssistiveSuggestionMode::kPrediction,
                          .type = AssistiveSuggestionType::kMultiWord,
                          .text = "hi there prediction"},
  };
  EXPECT_EQ(received, expected);
}

// One request should add exactly one multi-word candidate generation time
// sample.
TEST_F(SuggestionsServiceClientTest, RecordsCandidateGenerationTimePerRequest) {
  // Single source for the histogram name, so the checks cannot diverge.
  constexpr char kGenerationTimeHistogram[] =
      "InputMethod.Assistive.CandidateGenerationTime.MultiWord";
  SetTextSuggesterResult(SingleCandidate("hi there prediction", 0.5f));

  base::HistogramTester histograms;
  histograms.ExpectTotalCount(kGenerationTimeHistogram, 0);

  client()->RequestSuggestions(
      /*preceding_text=*/"this is some text",
      /*suggestion_mode=*/AssistiveSuggestionMode::kPrediction,
      /*completion_candidates=*/std::vector<DecoderCompletionCandidate>{},
      /*callback=*/
      base::BindLambdaForTesting(
          [](const std::vector<AssistiveSuggestion>&) {}));
  WaitForResults();

  histograms.ExpectTotalCount(kGenerationTimeHistogram, 1);
}

// One request should add exactly one preceding-text-length sample, equal to
// preceding_text.size().
// NOTE(review): the expected sample is the full, untrimmed length, even though
// the text says it exceeds a 100-char trimming limit. Confirm this is
// intended.
TEST_F(SuggestionsServiceClientTest, RecordsPrecedingTextLengthPerRequest) {
  constexpr char kPrecedingTextLengthHistogram[] =
      "InputMethod.Assistive.MultiWord.PrecedingTextLength";
  SetTextSuggesterResult(SingleCandidate("hi there prediction", 0.5f));
  std::string preceding_text =
      "This is some text that is very long, so long in fact it should be "
      "greater then 100 chars which is the limit currently set when "
      "trimming text sent to the suggestion service.";

  base::HistogramTester histograms;
  histograms.ExpectTotalCount(kPrecedingTextLengthHistogram, 0);

  client()->RequestSuggestions(
      /*preceding_text=*/preceding_text,
      /*suggestion_mode=*/AssistiveSuggestionMode::kPrediction,
      /*completion_candidates=*/std::vector<DecoderCompletionCandidate>{},
      /*callback=*/
      base::BindLambdaForTesting(
          [](const std::vector<AssistiveSuggestion>&) {}));
  WaitForResults();

  histograms.ExpectTotalCount(kPrecedingTextLengthHistogram, 1);
  histograms.ExpectUniqueSample(kPrecedingTextLengthHistogram,
                                /*sample=*/preceding_text.size(),
                                /*expected_bucket_count=*/1);
}

// A completion-mode request should be logged once, in the kCompletion bucket
// of the request-candidates histogram.
TEST_F(SuggestionsServiceClientTest, RecordsRequestCandidatesForCompletion) {
  constexpr char kRequestCandidatesHistogram[] =
      "InputMethod.Assistive.MultiWord.RequestCandidates";
  SetTextSuggesterResult(SingleCandidate("hi there completion", 0.5f));

  base::HistogramTester histograms;
  histograms.ExpectTotalCount(kRequestCandidatesHistogram, 0);

  client()->RequestSuggestions(
      /*preceding_text=*/"hello",
      /*suggestion_mode=*/AssistiveSuggestionMode::kCompletion,
      /*completion_candidates=*/std::vector<DecoderCompletionCandidate>{},
      /*callback=*/
      base::BindLambdaForTesting(
          [](const std::vector<AssistiveSuggestion>&) {}));
  WaitForResults();

  histograms.ExpectTotalCount(kRequestCandidatesHistogram, 1);
  histograms.ExpectUniqueSample(
      kRequestCandidatesHistogram,
      /*sample=*/MultiWordSuggestionType::kCompletion,
      /*expected_bucket_count=*/1);
}

// A prediction-mode request should be logged once, in the kPrediction bucket
// of the request-candidates histogram.
TEST_F(SuggestionsServiceClientTest, RecordsRequestCandidatesForPrediction) {
  constexpr char kRequestCandidatesHistogram[] =
      "InputMethod.Assistive.MultiWord.RequestCandidates";
  SetTextSuggesterResult(SingleCandidate("hi there prediction", 0.5f));

  base::HistogramTester histograms;
  histograms.ExpectTotalCount(kRequestCandidatesHistogram, 0);

  client()->RequestSuggestions(
      /*preceding_text=*/"hello",
      /*suggestion_mode=*/AssistiveSuggestionMode::kPrediction,
      /*completion_candidates=*/std::vector<DecoderCompletionCandidate>{},
      /*callback=*/
      base::BindLambdaForTesting(
          [](const std::vector<AssistiveSuggestion>&) {}));
  WaitForResults();

  histograms.ExpectTotalCount(kRequestCandidatesHistogram, 1);
  histograms.ExpectUniqueSample(
      kRequestCandidatesHistogram,
      /*sample=*/MultiWordSuggestionType::kPrediction,
      /*expected_bucket_count=*/1);
}

// When the service returns no candidates in prediction mode, nothing should be
// logged to the candidates-generated histogram.
TEST_F(SuggestionsServiceClientTest,
       DoesNotRecordCandidatesGeneratedWhenNoneReturnedForPrediction) {
  constexpr char kCandidatesGeneratedHistogram[] =
      "InputMethod.Assistive.MultiWord.CandidatesGenerated";
  SetTextSuggesterResult(NoCandidate());

  base::HistogramTester histograms;
  histograms.ExpectTotalCount(kCandidatesGeneratedHistogram, 0);

  client()->RequestSuggestions(
      /*preceding_text=*/"hello",
      /*suggestion_mode=*/AssistiveSuggestionMode::kPrediction,
      /*completion_candidates=*/std::vector<DecoderCompletionCandidate>{},
      /*callback=*/
      base::BindLambdaForTesting(
          [](const std::vector<AssistiveSuggestion>&) {}));
  WaitForResults();

  histograms.ExpectTotalCount(kCandidatesGeneratedHistogram, 0);
}

// When the service returns no candidates in completion mode, nothing should be
// logged to the candidates-generated histogram.
TEST_F(SuggestionsServiceClientTest,
       DoesNotRecordCandidatesGeneratedWhenNoneReturnedForCompletion) {
  constexpr char kCandidatesGeneratedHistogram[] =
      "InputMethod.Assistive.MultiWord.CandidatesGenerated";
  SetTextSuggesterResult(NoCandidate());

  base::HistogramTester histograms;
  histograms.ExpectTotalCount(kCandidatesGeneratedHistogram, 0);

  client()->RequestSuggestions(
      /*preceding_text=*/"hello",
      /*suggestion_mode=*/AssistiveSuggestionMode::kCompletion,
      /*completion_candidates=*/std::vector<DecoderCompletionCandidate>{},
      /*callback=*/
      base::BindLambdaForTesting(
          [](const std::vector<AssistiveSuggestion>&) {}));
  WaitForResults();

  histograms.ExpectTotalCount(kCandidatesGeneratedHistogram, 0);
}

// A prediction-mode response with one candidate should add one kPrediction
// sample to the candidates-generated histogram.
TEST_F(SuggestionsServiceClientTest,
       RecordsCandidatesGeneratedWhenCandidateReturnedForPrediction) {
  constexpr char kCandidatesGeneratedHistogram[] =
      "InputMethod.Assistive.MultiWord.CandidatesGenerated";
  SetTextSuggesterResult(SingleCandidate("hi there prediction", 0.5f));

  base::HistogramTester histograms;
  histograms.ExpectTotalCount(kCandidatesGeneratedHistogram, 0);

  client()->RequestSuggestions(
      /*preceding_text=*/"hello",
      /*suggestion_mode=*/AssistiveSuggestionMode::kPrediction,
      /*completion_candidates=*/std::vector<DecoderCompletionCandidate>{},
      /*callback=*/
      base::BindLambdaForTesting(
          [](const std::vector<AssistiveSuggestion>&) {}));
  WaitForResults();

  histograms.ExpectTotalCount(kCandidatesGeneratedHistogram, 1);
  histograms.ExpectUniqueSample(
      kCandidatesGeneratedHistogram,
      /*sample=*/MultiWordSuggestionType::kPrediction,
      /*expected_bucket_count=*/1);
}

// A completion-mode response with one candidate should add one kCompletion
// sample to the candidates-generated histogram.
TEST_F(SuggestionsServiceClientTest,
       RecordsCandidatesGeneratedWhenCandidateReturnedForCompletion) {
  constexpr char kCandidatesGeneratedHistogram[] =
      "InputMethod.Assistive.MultiWord.CandidatesGenerated";
  SetTextSuggesterResult(SingleCandidate("hi there completion", 0.5f));

  base::HistogramTester histograms;
  histograms.ExpectTotalCount(kCandidatesGeneratedHistogram, 0);

  client()->RequestSuggestions(
      /*preceding_text=*/"hello",
      /*suggestion_mode=*/AssistiveSuggestionMode::kCompletion,
      /*completion_candidates=*/std::vector<DecoderCompletionCandidate>{},
      /*callback=*/
      base::BindLambdaForTesting(
          [](const std::vector<AssistiveSuggestion>&) {}));
  WaitForResults();

  histograms.ExpectTotalCount(kCandidatesGeneratedHistogram, 1);
  histograms.ExpectUniqueSample(
      kCandidatesGeneratedHistogram,
      /*sample=*/MultiWordSuggestionType::kCompletion,
      /*expected_bucket_count=*/1);
}

// The test expects one kCompletion sample in the empty-candidate histogram
// for each decoder completion candidate with empty text. Two of the three
// candidates below are empty.
TEST_F(SuggestionsServiceClientTest,
       RecordsEmptyCandidateTextWhenCandidateTextMissing) {
  constexpr char kEmptyCandidateHistogram[] =
      "InputMethod.Assistive.MultiWord.EmptyCandidate";
  SetTextSuggesterResult(SingleCandidate("hi there completion", 0.5f));

  base::HistogramTester histograms;
  histograms.ExpectTotalCount(kEmptyCandidateHistogram, 0);

  std::vector<DecoderCompletionCandidate> decoder_candidates = {
      DecoderCompletionCandidate{"hello", 0.1f},
      DecoderCompletionCandidate{"", 0.01f},
      DecoderCompletionCandidate{"", 0.001f},
  };
  client()->RequestSuggestions(
      /*preceding_text=*/"hel",
      /*suggestion_mode=*/AssistiveSuggestionMode::kCompletion,
      /*completion_candidates=*/std::move(decoder_candidates),
      /*callback=*/
      base::BindLambdaForTesting(
          [](const std::vector<AssistiveSuggestion>&) {}));
  WaitForResults();

  histograms.ExpectTotalCount(kEmptyCandidateHistogram, 2);
  histograms.ExpectUniqueSample(
      kEmptyCandidateHistogram,
      /*sample=*/MultiWordSuggestionType::kCompletion,
      /*expected_bucket_count=*/2);
}

}  // namespace
}  // namespace input_method
}  // namespace ash