chromium/chromeos/components/quick_answers/utils/language_detector_unittest.cc

// Copyright 2020 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chromeos/components/quick_answers/utils/language_detector.h"

#include <map>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "base/functional/bind.h"
#include "base/test/scoped_feature_list.h"
#include "base/test/task_environment.h"
#include "base/test/test_future.h"
#include "chromeos/components/quick_answers/quick_answers_model.h"
#include "chromeos/components/quick_answers/utils/quick_answers_utils.h"
#include "chromeos/services/machine_learning/public/mojom/machine_learning_service.mojom.h"
#include "chromeos/services/machine_learning/public/mojom/text_classifier.mojom.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace quick_answers {
namespace {

// Shorthand for the machine learning mojom types referenced in this file.
using ::chromeos::machine_learning::mojom::
    REMOVED_TextSuggestSelectionRequestPtr;
using ::chromeos::machine_learning::mojom::TextAnnotationRequestPtr;
using ::chromeos::machine_learning::mojom::TextClassifier;
using ::chromeos::machine_learning::mojom::TextLanguage;
using ::chromeos::machine_learning::mojom::TextLanguagePtr;

// Builds the language reported for any text that has no registered detection
// result: locale "en" with a confidence of 1.
TextLanguagePtr DefaultLanguage() {
  TextLanguagePtr fallback = TextLanguage::New("en", /*confidence=*/1);
  return fallback;
}

}  // namespace

// Test double for the TextClassifier mojom interface. Language detection
// results are canned: a test registers one result per input text through
// RegisterDetectionResult(), and FindLanguages() replays it synchronously.
class FakeTextClassifier
    : public chromeos::machine_learning::mojom::TextClassifier {
 public:
  FakeTextClassifier() = default;
  ~FakeTextClassifier() override = default;

  // chromeos::machine_learning::mojom::TextClassifier:

  // Stub: does nothing and never runs `callback`.
  void Annotate(TextAnnotationRequestPtr request,
                AnnotateCallback callback) override {}

  // Runs `callback` synchronously with exactly one language: a clone of the
  // result registered for `text`, or DefaultLanguage() when `text` was never
  // registered.
  void FindLanguages(const std::string& text,
                     FindLanguagesCallback callback) override {
    const auto it = detection_results_.find(text);

    std::vector<TextLanguagePtr> languages;
    languages.push_back(it != detection_results_.end() ? it->second.Clone()
                                                       : DefaultLanguage());
    std::move(callback).Run(std::move(languages));
  }

  // Associates `language` with `text`, replacing any result previously
  // registered for the same text. Both arguments are sinks and are moved into
  // the map rather than copied.
  void RegisterDetectionResult(std::string text, TextLanguagePtr language) {
    detection_results_.insert_or_assign(std::move(text), std::move(language));
  }

  // Stub: does nothing and never runs `callback`.
  void REMOVED_1(REMOVED_TextSuggestSelectionRequestPtr request,
                 REMOVED_1Callback callback) override {}

 private:
  // Canned detection results, keyed by the exact text passed to
  // FindLanguages().
  std::map<std::string, TextLanguagePtr> detection_results_;
};

class LanguageDetectorTest : public testing::Test {
 public:
  LanguageDetectorTest() : language_detector_(&text_classifier_) {}

  LanguageDetectorTest(const LanguageDetectorTest&) = delete;
  LanguageDetectorTest& operator=(const LanguageDetectorTest&) = delete;

  const std::optional<std::string>& DetectLanguage(
      const std::string& surrounding_text,
      const std::string& selected_text) {
    base::test::TestFuture<std::optional<std::string>> future;

    language_detector_.DetectLanguage(surrounding_text, selected_text,
                                      future.GetCallback());

    detected_locale_ = future.Take();

    return detected_locale_;
  }

  FakeTextClassifier* text_classifier() { return &text_classifier_; }

 private:
  base::test::TaskEnvironment task_environment_;

  std::optional<std::string> detected_locale_;

  FakeTextClassifier text_classifier_;
  LanguageDetector language_detector_;

  base::test::ScopedFeatureList scoped_feature_list_;
};

// Selected and surrounding text are both registered as Spanish with high
// confidence, so the detected locale is "es".
TEST_F(LanguageDetectorTest, DetectLanguage) {
  const std::string selection = "entrada de prueba";
  const std::string context = "otros textos entrada de prueba otros textos";

  FakeTextClassifier* const classifier = text_classifier();
  classifier->RegisterDetectionResult(selection,
                                      TextLanguage::New("es", 0.95));
  classifier->RegisterDetectionResult(context, TextLanguage::New("es", 1));

  const auto locale = DetectLanguage(context, selection);

  EXPECT_EQ(locale, "es");
}

// The selection is registered as Spanish (0.95) while the surrounding text is
// registered as Chinese (0.98); the selection's language is expected to win.
TEST_F(LanguageDetectorTest, DetectLanguageShouldPreferSelectedTexts) {
  const std::string selection = "entrada de prueba";
  const std::string context = "周围的文字 entrada de prueba 周围的文字";

  FakeTextClassifier* const classifier = text_classifier();
  classifier->RegisterDetectionResult(selection,
                                      TextLanguage::New("es", 0.95));
  classifier->RegisterDetectionResult(context,
                                      TextLanguage::New("zh", 0.98));

  const auto locale = DetectLanguage(context, selection);

  // The locale detected for the selected text takes precedence.
  EXPECT_EQ(locale, "es");
}

// The selection is registered as Latin with a confidence of only 0.6, while
// the surrounding text is registered as Italian at 0.98.
TEST_F(LanguageDetectorTest, DetectLanguageFallbackToSurroundingTexts) {
  const std::string selection = "immagini";
  const std::string context =
      "Cerca informazioni in tutto il mondo come pagine web, immagini, video, "
      "ecc.";

  FakeTextClassifier* const classifier = text_classifier();
  classifier->RegisterDetectionResult(selection,
                                      TextLanguage::New("la", 0.6));
  classifier->RegisterDetectionResult(context,
                                      TextLanguage::New("it", 0.98));

  const auto locale = DetectLanguage(context, selection);

  // The selected text's result is not confident enough, so the surrounding
  // text's result is used instead.
  EXPECT_EQ(locale, "it");
}

// Both the selection and the surrounding text are registered with a
// confidence of 0.4, and no locale is expected back.
TEST_F(LanguageDetectorTest, DetectLanguageLowConfidence) {
  const std::string selection = "video";
  const std::string context = "web, immagini, video, ecc";

  FakeTextClassifier* const classifier = text_classifier();
  classifier->RegisterDetectionResult(selection,
                                      TextLanguage::New("es", 0.4));
  classifier->RegisterDetectionResult(context,
                                      TextLanguage::New("it", 0.4));

  const auto locale = DetectLanguage(context, selection);

  // Neither result is confident enough, so the detector reports nothing.
  EXPECT_FALSE(locale.has_value());
}

}  // namespace quick_answers