chromium/chromeos/components/quick_answers/understanding/intent_generator.cc

// Copyright 2020 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chromeos/components/quick_answers/understanding/intent_generator.h"

#include <map>

#include "base/i18n/break_iterator.h"
#include "base/i18n/case_conversion.h"
#include "base/no_destructor.h"
#include "base/strings/string_split.h"
#include "base/strings/utf_string_conversions.h"
#include "chromeos/components/quick_answers/public/cpp/quick_answers_state.h"
#include "chromeos/components/quick_answers/quick_answers_model.h"
#include "chromeos/components/quick_answers/utils/quick_answers_metrics.h"
#include "chromeos/components/quick_answers/utils/quick_answers_utils.h"
#include "chromeos/components/quick_answers/utils/spell_checker.h"
#include "chromeos/components/quick_answers/utils/translation_v2_utils.h"
#include "chromeos/constants/chromeos_features.h"
#include "chromeos/services/machine_learning/public/cpp/service_connection.h"
#include "chromeos/services/machine_learning/public/mojom/machine_learning_service.mojom.h"
#include "components/translate/core/browser/translate_download_manager.h"
#include "third_party/abseil-cpp/absl/strings/ascii.h"
#include "ui/base/l10n/l10n_util.h"

namespace quick_answers {
namespace {

using ::chromeos::machine_learning::mojom::LoadModelResult;
using ::chromeos::machine_learning::mojom::TextAnnotationPtr;
using ::chromeos::machine_learning::mojom::TextAnnotationRequest;
using ::chromeos::machine_learning::mojom::TextAnnotationRequestPtr;
using ::chromeos::machine_learning::mojom::TextClassifier;

// TODO(llin): Finalize on the threshold based on user feedback.
// Maximum amount by which the selected text may be longer than the annotated
// unit entity (both measured on their UTF-16 forms, see RewriteIntent())
// before the intent is rewritten to |kUnknown|.
constexpr int kUnitConversionIntentAndSelectionLengthDiffThreshold = 5;
// Maximum length of the selected text (compared against
// |selected_text.length()| in MaybeGenerateTranslationIntent()) for which a
// translation intent is still attempted.
constexpr int kTranslationTextLengthThreshold = 100;
// Same limit as above, used instead when the rich card feature is enabled.
constexpr int kRichAnswersTranslationTextLengthThreshold = 250;
// Counterpart of the unit conversion threshold for dictionary (definition)
// intents.
constexpr int kDefinitionIntentAndSelectionLengthDiffThreshold = 2;

// TODO(b/169370175): Remove the temporary invalid set after we ramp up to v2
// model.
// Set of invalid characters for definition annotations. Selected text that
// contains any of them is skipped by ShouldSkipDefinition().
constexpr char kInvalidCharactersSet[] = "()[]{}<>_&|!";

// Language code compared against the output of l10n_util::GetLanguage().
constexpr char kEnglishLanguage[] = "en";

const std::map<std::string, IntentType>& GetIntentTypeMap() {
  static base::NoDestructor<std::map<std::string, IntentType>> kIntentTypeMap(
      {{"unit", IntentType::kUnit}, {"dictionary", IntentType::kDictionary}});
  return *kIntentTypeMap;
}

bool ExtractEntity(const std::string& selected_text,
                   const std::vector<TextAnnotationPtr>& annotations,
                   std::string* entity_str,
                   std::string* type) {
  for (auto& annotation : annotations) {
    // The offset in annotation result is by chars instead of by bytes. Converts
    // to string16 to support extracting substring from string with UTF-16
    // characters.
    *entity_str = base::UTF16ToUTF8(
        base::UTF8ToUTF16(selected_text)
            .substr(annotation->start_offset,
                    annotation->end_offset - annotation->start_offset));

    // Use the first entity type.
    auto intent_type_map = GetIntentTypeMap();
    for (const auto& entity : annotation->entities) {
      if (intent_type_map.find(entity->name) != intent_type_map.end()) {
        *type = entity->name;
        return true;
      }
    }
  }

  return false;
}

// Returns |intent| unchanged unless the selection is too much longer than the
// annotated entity, in which case IntentType::kUnknown is returned instead.
// Lengths are taken from the UTF-16 forms of |selected_text| and |entity_str|;
// only unit and dictionary intents are subject to a threshold.
IntentType RewriteIntent(const std::string& selected_text,
                         const std::string& entity_str,
                         const IntentType intent) {
  const int extra_length = base::UTF8ToUTF16(selected_text).length() -
                           base::UTF8ToUTF16(entity_str).length();

  // Unit conversion: tolerate a few extra characters around the entity.
  if (intent == IntentType::kUnit &&
      extra_length > kUnitConversionIntentAndSelectionLengthDiffThreshold) {
    return IntentType::kUnknown;
  }

  // Definition: the selection has to match the entity almost exactly.
  if (intent == IntentType::kDictionary &&
      extra_length > kDefinitionIntentAndSelectionLengthDiffThreshold) {
    return IntentType::kUnknown;
  }

  return intent;
}

// Returns true if |detected_language| equals the language part of any entry in
// the comma-separated preferred languages list held by QuickAnswersState.
// Entries are whitespace-trimmed and empty entries are ignored.
bool IsPreferredLanguage(const std::string& detected_language) {
  const auto locales =
      base::SplitString(QuickAnswersState::Get()->preferred_languages(), ",",
                        base::TRIM_WHITESPACE, base::SPLIT_WANT_NONEMPTY);

  bool is_preferred = false;
  // Stop scanning as soon as one locale matches.
  for (auto it = locales.begin(); it != locales.end() && !is_preferred; ++it) {
    is_preferred = l10n_util::GetLanguage(*it) == detected_language;
  }
  return is_preferred;
}

// TODO(b/169370175): The text classifier annotates concatenated words as
// definitions. Until the switch to the v2 model, such queries are filtered out
// here.
//
// Returns true when a definition annotation for |text| should be ignored:
//  - English is neither the device language nor a preferred language,
//  - |text| has capitalized characters after the first one without being
//    entirely capitalized, or
//  - |text| contains a character from |kInvalidCharactersSet|.
// |text| is expected to be non-empty.
bool ShouldSkipDefinition(const std::string& text) {
  // Currently the text classifier only works with English words, so English
  // has to be the device language or one of the user preferred languages.
  const auto device_language =
      l10n_util::GetLanguage(QuickAnswersState::Get()->application_locale());
  const bool english_in_use = device_language == kEnglishLanguage ||
                              IsPreferredLanguage(kEnglishLanguage);
  if (!english_in_use) {
    return true;
  }

  DCHECK(text.length());
  const std::u16string text_utf16 = base::UTF8ToUTF16(text);

  // Everything after the first character; lowercasing it changes the string
  // exactly when it contains a capitalized character.
  const std::u16string tail = text_utf16.substr(1);
  const bool has_upper_after_first = tail != base::i18n::ToLower(tail);
  const bool is_all_upper = text_utf16 == base::i18n::ToUpper(text_utf16);
  if (has_upper_after_first && !is_all_upper) {
    return true;
  }

  // Finally reject text that contains any of the invalid characters.
  return text.find_first_of(kInvalidCharactersSet) != std::string::npos;
}

// Returns true only if TranslationV2Utils reports both |source_language| and
// |target_language| as supported. The target is not queried when the source is
// already unsupported.
bool AreTranslationLanguagesSupported(const std::string& source_language,
                                      const std::string& target_language) {
  if (!TranslationV2Utils::IsSupported(source_language)) {
    return false;
  }
  return TranslationV2Utils::IsSupported(target_language);
}

// Returns true if |word| contains at least one ASCII digit.
bool HasDigits(const std::string& word) {
  auto it = word.begin();
  // Advance to the first digit, or to the end if there is none.
  while (it != word.end() &&
         !absl::ascii_isdigit(static_cast<unsigned char>(*it))) {
    ++it;
  }
  return it != word.end();
}

}  // namespace

// Takes a weak reference to the spell checker used for single-word selections
// and the callback that receives the generated IntentInfo. Both arguments are
// moved into the corresponding members.
IntentGenerator::IntentGenerator(base::WeakPtr<SpellChecker> spell_checker,
                                 IntentGeneratorCallback complete_callback)
    : spell_checker_(std::move(spell_checker)),
      complete_callback_(std::move(complete_callback)) {}

// If the completion callback has not been consumed yet, runs it with an empty
// text and an unknown intent so that it is not dropped without being run.
IntentGenerator::~IntentGenerator() {
  if (!complete_callback_) {
    return;
  }

  std::move(complete_callback_)
      .Run(IntentInfo(std::string(), IntentType::kUnknown));
}

// Entry point: starts intent generation for |request|.
//
// A selection that consists of exactly one word is sent to the spell checker
// (continued in CheckSpellingCallback()); anything else goes straight to the
// text classifier path. If the word break iterator cannot be initialized or
// advanced, the completion callback is run with an unknown intent.
void IntentGenerator::GenerateIntent(const QuickAnswersRequest& request) {
  // |iter| refers to this string, so it is kept alive for the whole function.
  const std::u16string u16_text = base::UTF8ToUTF16(request.selected_text);
  base::i18n::BreakIterator iter(u16_text,
                                 base::i18n::BreakIterator::BREAK_WORD);
  if (!iter.Init() || !iter.Advance()) {
    NOTREACHED_IN_MIGRATION() << "Failed to load BreakIterator.";

    std::move(complete_callback_)
        .Run(IntentInfo(request.selected_text, IntentType::kUnknown));
    return;
  }

  DCHECK(spell_checker_.get()) << "spell_checker_ should exist when the "
                                  "always trigger feature is enabled";
  // Check spelling if the selected text is a valid single word, i.e. the first
  // break is a word that spans the whole text.
  if (iter.IsWord() && iter.prev() == 0 && iter.pos() == u16_text.length()) {
    // Search server do not provide useful information for proper nouns and
    // abbreviations (such as "Amy" and "ASAP"). Check spelling of the word in
    // lower case to filter out such cases. The UTF-16 text computed above is
    // reused instead of converting |request.selected_text| a second time.
    const std::string lowercase_text =
        base::UTF16ToUTF8(base::i18n::ToLower(u16_text));
    spell_checker_->CheckSpelling(
        lowercase_text,
        base::BindOnce(&IntentGenerator::CheckSpellingCallback,
                       weak_factory_.GetWeakPtr(), request));
    return;
  }

  // Fallback to text classifier.
  MaybeLoadTextClassifier(request);
}

// Test-only helper: forwards to |text_classifier_|'s FlushForTesting().
void IntentGenerator::FlushForTesting() {
  text_classifier_.FlushForTesting();
}

// Requests the text classifier from the machine learning service when
// QuickAnswersState allows the text annotator; the result is delivered to
// LoadModelCallback(). Otherwise completes immediately with an unknown intent
// for the selected text.
void IntentGenerator::MaybeLoadTextClassifier(
    const QuickAnswersRequest& request) {
  if (!QuickAnswersState::Get()->ShouldUseQuickAnswersTextAnnotator()) {
    // Text annotator not to be used: there is nothing left to try.
    std::move(complete_callback_)
        .Run(IntentInfo(request.selected_text, IntentType::kUnknown));
    return;
  }

  auto on_model_loaded = base::BindOnce(&IntentGenerator::LoadModelCallback,
                                        weak_factory_.GetWeakPtr(), request);
  chromeos::machine_learning::ServiceConnection::GetInstance()
      ->GetMachineLearningService()
      .LoadTextClassifier(text_classifier_.BindNewPipeAndPassReceiver(),
                          std::move(on_model_loaded));
}

// Receives the spell check result for a single-word selection.
//
// A correctly spelled word without digits yields a dictionary intent (with
// |language| passed along to IntentInfo) and the matching metrics. Every other
// case falls back to the text classifier.
void IntentGenerator::CheckSpellingCallback(const QuickAnswersRequest& request,
                                            bool correctness,
                                            const std::string& language) {
  // The dictionaries treat digits as valid words, while we will not be able to
  // grab any useful information from the Search server for words like that, so
  // words containing digits are handled like misspelled ones.
  const bool is_dictionary_word =
      correctness && !HasDigits(request.selected_text);

  if (!is_dictionary_word) {
    // Fallback to the text classifier. It may still produce a unit conversion
    // intent, or a definition intent if the word is not covered in the
    // dictionary but in the model.
    MaybeLoadTextClassifier(request);
    return;
  }

  std::move(complete_callback_)
      .Run(IntentInfo(request.selected_text, IntentType::kDictionary,
                      QuickAnswersState::Get()->application_locale(),
                      language));

  // Record intent source type and language for dictionary intent.
  RecordDictionaryIntentSource(DictionaryIntentSource::kHunspell);
  RecordDictionaryIntentLanguage(language);
}

// Receives the result of loading the text classifier.
//
// On success the selected text is sent for annotation and the result is
// delivered to AnnotationCallback(). If loading failed, or if
// |text_classifier_| is not bound, the completion callback is run with an
// unknown intent. The unbound case used to return without running the
// callback, leaving the request unanswered until this object was destroyed.
void IntentGenerator::LoadModelCallback(const QuickAnswersRequest& request,
                                        LoadModelResult result) {
  if (result != LoadModelResult::OK) {
    LOG(ERROR) << "Failed to load TextClassifier.";
    std::move(complete_callback_)
        .Run(IntentInfo(request.selected_text, IntentType::kUnknown));
    return;
  }

  if (!text_classifier_) {
    // Nothing can be annotated without a bound classifier; report an unknown
    // intent instead of silently dropping the request.
    std::move(complete_callback_)
        .Run(IntentInfo(request.selected_text, IntentType::kUnknown));
    return;
  }

  TextAnnotationRequestPtr text_annotation_request =
      TextAnnotationRequest::New();
  text_annotation_request->text = request.selected_text;
  text_annotation_request->default_locales =
      QuickAnswersState::Get()->application_locale();
  text_annotation_request->trigger_dictionary_on_beginner_words = true;

  text_classifier_->Annotate(
      std::move(text_annotation_request),
      base::BindOnce(&IntentGenerator::AnnotationCallback,
                     weak_factory_.GetWeakPtr(), request));
}

void IntentGenerator::AnnotationCallback(
    const QuickAnswersRequest& request,
    std::vector<TextAnnotationPtr> annotations) {
  std::string entity_str;
  std::string type;

  if (ExtractEntity(request.selected_text, annotations, &entity_str, &type)) {
    auto intent_type_map = GetIntentTypeMap();
    auto it = intent_type_map.find(type);
    if (it != intent_type_map.end()) {
      // Skip the entity if the corresponding intent type is ineligible.
      bool definition_ineligible =
          !QuickAnswersState::IsIntentEligible(Intent::kDefinition);
      bool unit_conversion_ineligible =
          !QuickAnswersState::IsIntentEligible(Intent::kUnitConversion);
      if ((it->second == IntentType::kDictionary && definition_ineligible) ||
          (it->second == IntentType::kUnit && unit_conversion_ineligible)) {
        // Fallback to language detection for generating translation intent.
        MaybeGenerateTranslationIntent(request);
        return;
      }
      // Skip the entity for definition annonation.
      if (it->second == IntentType::kDictionary &&
          ShouldSkipDefinition(request.selected_text)) {
        // Fallback to language detection for generating translation intent.
        MaybeGenerateTranslationIntent(request);
        return;
      }
      std::move(complete_callback_)
          .Run(IntentInfo(
              entity_str,
              RewriteIntent(request.selected_text, entity_str, it->second),
              QuickAnswersState::Get()->application_locale()));

      // Record intent source type and language for dictionary intent.
      if (it->second == IntentType::kDictionary) {
        RecordDictionaryIntentSource(DictionaryIntentSource::kTextClassifier);
        // Record the English language since currently the text classifier only
        // works with English words.
        RecordDictionaryIntentLanguage(kEnglishLanguage);
      }
      return;
    }
  }
  // Fallback to language detection for generating translation intent.
  MaybeGenerateTranslationIntent(request);
}

void IntentGenerator::MaybeGenerateTranslationIntent(
    const QuickAnswersRequest& request) {
  DCHECK(complete_callback_);

  if (!QuickAnswersState::IsIntentEligible(Intent::kTranslation) ||
      chromeos::features::IsQuickAnswersV2TranslationDisabled()) {
    std::move(complete_callback_)
        .Run(IntentInfo(request.selected_text, IntentType::kUnknown));
    return;
  }

  size_t translation_text_length_threshold =
      chromeos::features::IsQuickAnswersRichCardEnabled()
          ? kRichAnswersTranslationTextLengthThreshold
          : kTranslationTextLengthThreshold;
  // Don't generate translation intent if no device language is provided or the
  // length of selected text is above the threshold. Returns unknown intent
  // type.
  if (QuickAnswersState::Get()->application_locale().empty() ||
      request.selected_text.length() > translation_text_length_threshold) {
    std::move(complete_callback_)
        .Run(IntentInfo(request.selected_text, IntentType::kUnknown));
    return;
  }

  language_detector_ =
      std::make_unique<LanguageDetector>(text_classifier_.get());
  language_detector_->DetectLanguage(
      request.context.surrounding_text, request.selected_text,
      base::BindOnce(&IntentGenerator::LanguageDetectorCallback,
                     weak_factory_.GetWeakPtr(), request));
}

// Receives the detected locale of the selected text, if any.
//
// Completes with a translation intent (device language and detected language
// passed to IntentInfo) when a language was detected that differs from the
// device language, is not a preferred language, and both languages are
// supported for translation. Otherwise completes with an unknown intent.
void IntentGenerator::LanguageDetectorCallback(
    const QuickAnswersRequest& request,
    std::optional<std::string> detected_locale) {
  // The detector is no longer needed once it has reported back.
  language_detector_.reset();

  const auto device_language =
      l10n_util::GetLanguage(QuickAnswersState::Get()->application_locale());

  // Stays empty when no locale was detected.
  std::string detected_language;
  if (detected_locale.has_value()) {
    detected_language = l10n_util::GetLanguage(detected_locale.value());
  }

  // Skip translation if the source or target languages are not supported.
  const bool should_translate =
      !detected_language.empty() && detected_language != device_language &&
      !IsPreferredLanguage(detected_language) &&
      AreTranslationLanguagesSupported(detected_language, device_language);

  if (!should_translate) {
    std::move(complete_callback_)
        .Run(IntentInfo(request.selected_text, IntentType::kUnknown));
    return;
  }

  std::move(complete_callback_)
      .Run(IntentInfo(request.selected_text, IntentType::kTranslation,
                      device_language, detected_language));
}

}  // namespace quick_answers