// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "content/browser/handwriting/handwriting_recognizer_impl_cros.h"
#include <memory>
#include <optional>
#include <string_view>
#include <utility>
#include <vector>
#include "base/functional/bind.h"
#include "base/memory/ptr_util.h"
#include "base/notreached.h"
#include "base/strings/string_util.h"
#include "mojo/public/cpp/bindings/self_owned_receiver.h"
#include "third_party/blink/public/mojom/handwriting/handwriting.mojom.h"
namespace {
// Language tags accepted from callers. At the moment, CrOS only ships two
// models, so there are exactly two tags.
constexpr char kLanguageTagEnglish[] = "en";
constexpr char kLanguageTagGesture[] = "zxx-x-Gesture";

// Model identifiers. These are the values handed to mlservice.
constexpr char kModelEn[] = "en";
constexpr char kModelGesture[] = "gesture_in_context";
// Model descriptors.

// Builds the hard-coded query result for the English model. Both
// `text_alternatives` and `text_segmentation` are set, and the hints section
// sets `alternatives` and `text_context`, lists the text recognition type, and
// lists mouse, stylus and touch as input types.
handwriting::mojom::QueryHandwritingRecognizerResultPtr
CreateEnglishModelDescriptor() {
  namespace hw = handwriting::mojom;

  // Assemble the hints first, then attach them to the descriptor.
  auto hints = hw::HandwritingHintsQueryResult::New();
  hints->recognition_type = {hw::HandwritingRecognitionType::kText};
  hints->input_type = {
      hw::HandwritingInputType::kMouse,
      hw::HandwritingInputType::kStylus,
      hw::HandwritingInputType::kTouch,
  };
  hints->alternatives = true;
  hints->text_context = true;

  auto descriptor = hw::QueryHandwritingRecognizerResult::New();
  descriptor->text_alternatives = true;
  descriptor->text_segmentation = true;
  descriptor->hints = std::move(hints);
  return descriptor;
}
// Builds the hard-coded query result for the gesture model. It carries the
// same values as the English descriptor except that `text_segmentation` is
// false.
handwriting::mojom::QueryHandwritingRecognizerResultPtr
CreateGestureModelDescriptor() {
  namespace hw = handwriting::mojom;

  // Assemble the hints first, then attach them to the descriptor.
  auto hints = hw::HandwritingHintsQueryResult::New();
  hints->recognition_type = {hw::HandwritingRecognitionType::kText};
  hints->input_type = {
      hw::HandwritingInputType::kMouse,
      hw::HandwritingInputType::kStylus,
      hw::HandwritingInputType::kTouch,
  };
  hints->alternatives = true;
  hints->text_context = true;

  auto descriptor = hw::QueryHandwritingRecognizerResult::New();
  descriptor->text_alternatives = true;
  descriptor->text_segmentation = false;
  descriptor->hints = std::move(hints);
  return descriptor;
}
// Decides whether language tags `a` and `b` are treated as the same tag. The
// check is a plain ASCII case-insensitive equality comparison; no subtag
// fallback or canonicalization is attempted.
// TODO(crbug.com/40742391): We may need a better language tag matching
// method (e.g. libicu's LocaleMatcher).
bool LanguageTagsAreMatching(std::string_view a, std::string_view b) {
  // BCP 47 states that language tags are compared case-insensitively.
  const bool same_tag = base::EqualsCaseInsensitiveASCII(a, b);
  return same_tag;
}
// Maps `language_tag` to the model identifier (the language field of
// HandwritingRecognizerSpec) used by the ml_service backend. Yields
// std::nullopt when the tag does not match any supported language.
std::optional<std::string> GetModelIdentifier(std::string_view language_tag) {
  struct TagToModel {
    const char* tag;
    const char* model;
  };
  // Entries are tried in order; the first matching tag wins.
  const TagToModel kSupportedModels[] = {
      {kLanguageTagEnglish, kModelEn},
      {kLanguageTagGesture, kModelGesture},
  };

  for (const TagToModel& entry : kSupportedModels) {
    if (LanguageTagsAreMatching(language_tag, entry.tag)) {
      return std::string(entry.model);
    }
  }
  return std::nullopt;
}
} // namespace
namespace content {
namespace {
using chromeos::machine_learning::mojom::LoadHandwritingModelResult;
// Completion handler for `mojom::MachineLearningService::LoadHandwritingModel`
// (CrOS). Translates the mlservice load `result` into a
// CreateHandwritingRecognizerResult and forwards it through `callback`.
// `remote` is handed over only when loading succeeded; every failure path
// passes a null remote instead.
void OnModelBinding(
    mojo::PendingRemote<handwriting::mojom::HandwritingRecognizer> remote,
    handwriting::mojom::HandwritingRecognitionService::
        CreateHandwritingRecognizerCallback callback,
    LoadHandwritingModelResult result) {
  using CreateResult = handwriting::mojom::CreateHandwritingRecognizerResult;

  // Deliberately no `default:` label so the compiler can flag newly added
  // enumerators that are not handled here.
  switch (result) {
    case LoadHandwritingModelResult::OK:
      std::move(callback).Run(CreateResult::kOk, std::move(remote));
      return;

    // MLService says the model isn't available, or the user doesn't want to
    // use handwriting recognition: surface these as NotSupported.
    case LoadHandwritingModelResult::FEATURE_NOT_SUPPORTED_ERROR:
    case LoadHandwritingModelResult::LANGUAGE_NOT_SUPPORTED_ERROR:
    case LoadHandwritingModelResult::FEATURE_DISABLED_BY_USER:
    case LoadHandwritingModelResult::DLC_DOES_NOT_EXIST:
      std::move(callback).Run(CreateResult::kNotSupported, mojo::NullRemote());
      return;

    // Anything else is reported as a generic error.
    case LoadHandwritingModelResult::DLC_GET_PATH_ERROR:
    case LoadHandwritingModelResult::DLC_INSTALL_ERROR:
    case LoadHandwritingModelResult::LOAD_NATIVE_LIB_ERROR:
    case LoadHandwritingModelResult::LOAD_FUNC_PTR_ERROR:
    case LoadHandwritingModelResult::LOAD_MODEL_FILES_ERROR:
    case LoadHandwritingModelResult::LOAD_MODEL_ERROR:
    case LoadHandwritingModelResult::DEPRECATED_MODEL_SPEC_ERROR:
      std::move(callback).Run(CreateResult::kError, mojo::NullRemote());
      return;
  }
}
// Completion handler for `mojom::HandwritingRecognizer::Recognize` (CrOS).
// When mlservice supplied no result, `callback` receives std::nullopt.
// Otherwise every prediction, segment and drawing segment is copied
// field-by-field into the corresponding handwriting::mojom type, keeping the
// original order, and the converted list is passed to `callback`.
void OnRecognitionResult(
    CrOSHandwritingRecognizerImpl::GetPredictionCallback callback,
    std::optional<std::vector<chromeos::machine_learning::web_platform::mojom::
                                  HandwritingPredictionPtr>>
        result_from_mlservice) {
  namespace blink_mojom = handwriting::mojom;

  if (!result_from_mlservice) {
    std::move(callback).Run(std::nullopt);
    return;
  }

  std::vector<blink_mojom::HandwritingPredictionPtr> converted;
  for (const auto& src_prediction : *result_from_mlservice) {
    auto dst_prediction = blink_mojom::HandwritingPrediction::New();
    dst_prediction->text = src_prediction->text;

    for (const auto& src_segment : src_prediction->segmentation_result) {
      auto dst_segment = blink_mojom::HandwritingSegment::New();
      dst_segment->grapheme = src_segment->grapheme;
      dst_segment->begin_index = src_segment->begin_index;
      dst_segment->end_index = src_segment->end_index;

      for (const auto& src_drawing : src_segment->drawing_segments) {
        auto dst_drawing = blink_mojom::HandwritingDrawingSegment::New();
        dst_drawing->stroke_index = src_drawing->stroke_index;
        dst_drawing->begin_point_index = src_drawing->begin_point_index;
        dst_drawing->end_point_index = src_drawing->end_point_index;
        dst_segment->drawing_segments.push_back(std::move(dst_drawing));
      }

      dst_prediction->segmentation_result.push_back(std::move(dst_segment));
    }

    converted.push_back(std::move(dst_prediction));
  }

  std::move(callback).Run(std::move(converted));
}
} // namespace
// static
// Creates a recognizer for `constraint_blink` and reports the outcome through
// `callback`. Requests that do not name exactly one language, or whose
// language has no model identifier, are answered immediately with
// kNotSupported and a null remote. Otherwise a self-owned
// CrOSHandwritingRecognizerImpl is created and the model load is requested
// from the machine learning service; `callback` is then run by
// OnModelBinding once the load completes.
void CrOSHandwritingRecognizerImpl::Create(
    handwriting::mojom::HandwritingModelConstraintPtr constraint_blink,
    handwriting::mojom::HandwritingRecognitionService::
        CreateHandwritingRecognizerCallback callback) {
  namespace ml_mojom = chromeos::machine_learning::web_platform::mojom;

  // On CrOS, only one language is supported, so the identifier is looked up
  // only for single-language requests.
  std::optional<std::string> model_id;
  const auto& requested_languages = constraint_blink->languages;
  if (requested_languages.size() == 1) {
    model_id = GetModelIdentifier(requested_languages[0]);
  }
  if (!model_id.has_value()) {
    std::move(callback).Run(
        handwriting::mojom::CreateHandwritingRecognizerResult::kNotSupported,
        mojo::NullRemote());
    return;
  }

  // Pipe to the CrOS recognizer: the implementation keeps the remote end and
  // the receiver end travels with the load request below.
  mojo::PendingRemote<ml_mojom::HandwritingRecognizer> ml_remote;
  auto ml_receiver = ml_remote.InitWithNewPipeAndPassReceiver();
  auto recognizer = base::WrapUnique(
      new CrOSHandwritingRecognizerImpl(std::move(ml_remote)));

  // Pipe for the handwriting::mojom interface: the implementation is bound as
  // a self-owned receiver, and the remote end is returned via `callback`.
  mojo::PendingRemote<handwriting::mojom::HandwritingRecognizer> blink_remote;
  mojo::MakeSelfOwnedReceiver<handwriting::mojom::HandwritingRecognizer>(
      std::move(recognizer), blink_remote.InitWithNewPipeAndPassReceiver());

  auto ml_constraint = ml_mojom::HandwritingModelConstraint::New();
  ml_constraint->languages.push_back(*model_id);

  auto& ml_service =
      chromeos::machine_learning::ServiceConnection::GetInstance()
          ->GetMachineLearningService();
  ml_service.LoadWebPlatformHandwritingModel(
      std::move(ml_constraint), std::move(ml_receiver),
      base::BindOnce(&OnModelBinding, std::move(blink_remote),
                     std::move(callback)));
}
// static
// Reports whether `language_tag` maps to one of the model identifiers known
// to GetModelIdentifier().
bool CrOSHandwritingRecognizerImpl::SupportsLanguageTag(
    std::string_view language_tag) {
  const std::optional<std::string> model_id = GetModelIdentifier(language_tag);
  return model_id.has_value();
}
// Initializes `remote_cros_` from `pending_remote`, the pending remote for the
// CrOS web-platform HandwritingRecognizer interface. GetPrediction() issues
// its requests through `remote_cros_`.
CrOSHandwritingRecognizerImpl::CrOSHandwritingRecognizerImpl(
mojo::PendingRemote<
chromeos::machine_learning::web_platform::mojom::HandwritingRecognizer>
pending_remote)
: remote_cros_(std::move(pending_remote)) {}
// Defaulted destructor: no cleanup is written here beyond what the members'
// own destructors perform.
CrOSHandwritingRecognizerImpl::~CrOSHandwritingRecognizerImpl() = default;
// Forwards a recognition request to the CrOS recognizer behind
// `remote_cros_`. `strokes_blink` and `hints_blink` are copied field-by-field
// into the chromeos web_platform mojom types, and the reply is routed through
// OnRecognitionResult, which runs `callback`.
void CrOSHandwritingRecognizerImpl::GetPrediction(
    std::vector<handwriting::mojom::HandwritingStrokePtr> strokes_blink,
    handwriting::mojom::HandwritingHintsPtr hints_blink,
    GetPredictionCallback callback) {
  namespace ml_mojom = chromeos::machine_learning::web_platform::mojom;

  // Strokes: one output stroke per input stroke, points kept in order.
  std::vector<ml_mojom::HandwritingStrokePtr> ml_strokes;
  for (const auto& src_stroke : strokes_blink) {
    auto dst_stroke = ml_mojom::HandwritingStroke::New();
    for (const auto& src_point : src_stroke->points) {
      auto dst_point = ml_mojom::HandwritingPoint::New();
      dst_point->location.set_x(src_point->location.x());
      dst_point->location.set_y(src_point->location.y());
      dst_point->t = src_point->t;
      dst_stroke->points.push_back(std::move(dst_point));
    }
    ml_strokes.push_back(std::move(dst_stroke));
  }

  // Hints: every field is carried over unchanged.
  auto ml_hints = ml_mojom::HandwritingHints::New();
  ml_hints->recognition_type = hints_blink->recognition_type;
  ml_hints->input_type = hints_blink->input_type;
  ml_hints->text_context = hints_blink->text_context;
  ml_hints->alternatives = hints_blink->alternatives;

  auto on_result = base::BindOnce(&OnRecognitionResult, std::move(callback));
  remote_cros_->GetPrediction(std::move(ml_strokes), std::move(ml_hints),
                              std::move(on_result));
}
// static
// Looks up the descriptor for the recognizer requested by `constraint`.
// Returns nullptr when `constraint` is null, when it does not name exactly
// one language, or when that language has no model identifier.
handwriting::mojom::QueryHandwritingRecognizerResultPtr
CrOSHandwritingRecognizerImpl::GetModelDescriptor(
    handwriting::mojom::HandwritingModelConstraintPtr constraint) {
  // CrOS doesn't provide a default recognizer (null constraint) and only
  // supports single language recognizers.
  if (!constraint || constraint->languages.size() != 1) {
    return nullptr;
  }

  // TODO(https://crbug.com/1231900): Integrate with language packs instead of
  // returning hard-coded values.
  const std::optional<std::string> model_id =
      GetModelIdentifier(constraint->languages[0]);
  if (model_id == kModelEn) {
    return CreateEnglishModelDescriptor();
  }
  if (model_id == kModelGesture) {
    return CreateGestureModelDescriptor();
  }
  return nullptr;
}
} // namespace content