chromium/chrome/browser/ash/input_method/grammar_service_client.cc

// Copyright 2020 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/ash/input_method/grammar_service_client.h"

#include "base/strings/string_util.h"
#include "base/strings/utf_offset_string_conversions.h"
#include "base/strings/utf_string_conversions.h"
#include "chromeos/services/machine_learning/public/cpp/service_connection.h"
#include "components/prefs/pref_service.h"
#include "components/spellcheck/browser/pref_names.h"
#include "components/user_prefs/user_prefs.h"
#include "ui/gfx/range/range.h"

namespace ash {
namespace input_method {
namespace {

using ::chromeos::machine_learning::mojom::GrammarCheckerQuery;
using ::chromeos::machine_learning::mojom::GrammarCheckerQueryPtr;
using ::chromeos::machine_learning::mojom::GrammarCheckerResult;
using ::chromeos::machine_learning::mojom::GrammarCheckerResultPtr;
using ::chromeos::machine_learning::mojom::LoadModelResult;
using ::chromeos::machine_learning::mojom::TextLanguagePtr;

// Accepted request length range, measured in UTF-16 code units of the input.
constexpr uint32_t kMaxQueryLength = 200;
constexpr uint32_t kMinQueryLength = 5;
// The top language guess must reach at least this confidence to be trusted.
constexpr double kLanguageConfidenceThreshold = 0.9;
// Only text classified with this locale is sent to the grammar checker.
constexpr char kEnglishLocale[] = "en";

}  // namespace

// Caches a weak pointer to |this|. Every asynchronous Mojo reply in this file
// is bound through |weak_this_|, so replies arriving after this client has
// been destroyed are not dispatched to it.
GrammarServiceClient::GrammarServiceClient() {
  weak_this_ = weak_factory_.GetWeakPtr();
}

GrammarServiceClient::~GrammarServiceClient() = default;

// Reply handler for LoadGrammarChecker(). Records whether the model loaded.
// On success, forwards the pending |query| to the grammar checker. On failure,
// reports an unsuccessful check to |callback|.
void GrammarServiceClient::OnLoadGrammarCheckerDone(
    GrammarCheckerQueryPtr query,
    const std::string& query_text,
    TextCheckCompleteCallback callback,
    LoadModelResult result) {
  // Later requests consult this flag to skip the load step.
  grammar_checker_loaded_ = (result == LoadModelResult::OK);
  if (grammar_checker_loaded_) {
    auto on_checked =
        base::BindOnce(&GrammarServiceClient::ParseGrammarCheckerResult,
                       weak_this_, query_text, std::move(callback));
    grammar_checker_->Check(std::move(query), std::move(on_checked));
    return;
  }
  std::move(callback).Run(false, {});
}

// Reply handler for LoadTextClassifier(). Records whether the model loaded.
// On success, runs language detection on |query_text| for the pending request.
// On failure, reports an unsuccessful check to |callback|.
void GrammarServiceClient::OnLoadTextClassifierDone(
    const std::string& query_text,
    TextCheckCompleteCallback callback,
    LoadModelResult result) {
  // RequestTextCheck() consults this flag to skip the load step next time.
  text_classifier_loaded_ = (result == LoadModelResult::OK);
  if (text_classifier_loaded_) {
    auto on_detected =
        base::BindOnce(&GrammarServiceClient::OnLanguageDetectionDone,
                       weak_this_, query_text, std::move(callback));
    text_classifier_->FindLanguages(query_text, std::move(on_detected));
    return;
  }
  std::move(callback).Run(false, {});
}

// Starts an asynchronous grammar check of |text| for |profile|.
//
// The pipeline is: language detection (text classifier), then an English-only
// grammar check. Each model is loaded lazily on first use.
//
// Returns false and runs |callback| synchronously with (false, {}) in these
// cases:
//  - |profile| is null or the service is unavailable for it;
//  - |text| is outside [kMinQueryLength, kMaxQueryLength] UTF-16 code units;
//  - the text classifier pipe is bound but the model is not loaded yet.
//    NOTE(review): the classifier is bound while its load is in flight, and it
//    stays bound after a failed load. In both states requests are rejected.
// Otherwise returns true, and |callback| is run later with the result.
bool GrammarServiceClient::RequestTextCheck(
    Profile* profile,
    const std::u16string& text,
    TextCheckCompleteCallback callback) {
  if (!profile || !IsAvailable(profile) || text.size() > kMaxQueryLength ||
      text.size() < kMinQueryLength) {
    std::move(callback).Run(false, {});
    return false;
  }

  // Convert once. The UTF-8 text is both sent to the classifier and bound into
  // the follow-up callback.
  std::string query_text = base::UTF16ToUTF8(text);

  if (text_classifier_loaded_) {
    // Copy (not move) |query_text| into the callback: the first argument may
    // bind by reference, and argument evaluation order is unspecified.
    text_classifier_->FindLanguages(
        query_text,
        base::BindOnce(&GrammarServiceClient::OnLanguageDetectionDone,
                       weak_this_, query_text, std::move(callback)));
    return true;
  }

  if (!text_classifier_.is_bound()) {
    chromeos::machine_learning::ServiceConnection::GetInstance()
        ->GetMachineLearningService()
        .LoadTextClassifier(
            text_classifier_.BindNewPipeAndPassReceiver(),
            base::BindOnce(&GrammarServiceClient::OnLoadTextClassifierDone,
                           weak_this_, std::move(query_text),
                           std::move(callback)));
    return true;
  }

  std::move(callback).Run(false, {});
  return false;
}

// Reply handler for FindLanguages(). Continues to the grammar check only when
// the top language guess is English with enough confidence. The grammar
// checker is loaded on first use.
//
// Reports (false, {}) to |callback| when:
//  - the language is not confidently English; or
//  - the grammar checker pipe is bound but the model is not loaded.
//    In that state a load is pending or a previous load failed.
void GrammarServiceClient::OnLanguageDetectionDone(
    const std::string& query_text,
    TextCheckCompleteCallback callback,
    std::vector<TextLanguagePtr> languages) {
  const bool not_confident_english =
      languages.empty() ||
      languages[0]->confidence < kLanguageConfidenceThreshold ||
      languages[0]->locale != kEnglishLocale;
  if (not_confident_english) {
    std::move(callback).Run(false, {});
    return;
  }

  const auto& top_language = languages[0];
  auto query = GrammarCheckerQuery::New();
  query->text = query_text;
  query->language = top_language->locale;

  if (grammar_checker_loaded_) {
    auto on_checked =
        base::BindOnce(&GrammarServiceClient::ParseGrammarCheckerResult,
                       weak_this_, query_text, std::move(callback));
    grammar_checker_->Check(std::move(query), std::move(on_checked));
    return;
  }

  if (grammar_checker_.is_bound()) {
    std::move(callback).Run(false, {});
    return;
  }

  // First use: load the model, then run |query| once loading completes.
  auto on_loaded =
      base::BindOnce(&GrammarServiceClient::OnLoadGrammarCheckerDone,
                     weak_this_, std::move(query), query_text,
                     std::move(callback));
  chromeos::machine_learning::ServiceConnection::GetInstance()
      ->GetMachineLearningService()
      .LoadGrammarChecker(grammar_checker_.BindNewPipeAndPassReceiver(),
                          std::move(on_loaded));
}

void GrammarServiceClient::ParseGrammarCheckerResult(
    const std::string& query_text,
    TextCheckCompleteCallback callback,
    GrammarCheckerResultPtr result) const {
  if (result->status == GrammarCheckerResult::Status::OK &&
      !result->candidates.empty()) {
    const auto& top_candidate = result->candidates.front();
    if (!top_candidate->text.empty() && !top_candidate->fragments.empty()) {
      std::vector<ui::GrammarFragment> grammar_results;
      for (const auto& fragment : top_candidate->fragments) {
        uint32_t end;
        if (!base::CheckAdd(fragment->offset, fragment->length)
                 .AssignIfValid(&end) ||
            end > query_text.size()) {
          DLOG(ERROR) << "Grammar checker returns invalid correction "
                         "fragment, offset: "
                      << fragment->offset << ", length: " << fragment->length
                      << ", but the text length is " << query_text.size();
        } else {
          // Compute the offsets in string16.
          std::vector<size_t> offsets = {fragment->offset, end};
          base::UTF8ToUTF16AndAdjustOffsets(query_text, &offsets);
          grammar_results.emplace_back(gfx::Range(offsets[0], offsets[1]),
                                       fragment->replacement);
        }
      }
      std::move(callback).Run(true, grammar_results);
      return;
    }
  }
  std::move(callback).Run(false, {});
}

bool GrammarServiceClient::IsAvailable(Profile* profile) const {
  const PrefService* prefs = profile->GetPrefs();
  DCHECK(prefs);
  // The grammar service is unavailable when spell checking is disabled in
  // prefs, or when the profile is off the record.
  if (!prefs->GetBoolean(spellcheck::prefs::kSpellCheckEnable))
    return false;
  return !profile->IsOffTheRecord();
}

}  // namespace input_method
}  // namespace ash