// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "content/browser/handwriting/handwriting_recognition_service_impl_cros.h"
#include <optional>
#include <utility>
#include <vector>
#include "base/command_line.h"
#include "base/run_loop.h"
#include "base/test/bind.h"
#include "chromeos/services/machine_learning/public/cpp/fake_service_connection.h"
#include "chromeos/services/machine_learning/public/cpp/ml_switches.h"
#include "chromeos/services/machine_learning/public/cpp/service_connection.h"
#include "content/browser/handwriting/handwriting_recognizer_impl_cros.h"
#include "content/public/test/test_renderer_host.h"
#include "mojo/public/cpp/bindings/pending_receiver.h"
#include "mojo/public/cpp/bindings/pending_remote.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "third_party/blink/public/mojom/handwriting/handwriting.mojom.h"
namespace content {
class HandwritingRecognitionServiceImplCrOSTest
: public RenderViewHostTestHarness {
public:
void SetUp() override {
RenderViewHostTestHarness::SetUp();
chromeos::machine_learning::ServiceConnection::
UseFakeServiceConnectionForTesting(&fake_ml_service_connection_);
chromeos::machine_learning::ServiceConnection::GetInstance()->Initialize();
// We need to add the switch to "enable" HWR support.
base::CommandLine::ForCurrentProcess()->AppendSwitchASCII(
::switches::kOndeviceHandwritingSwitch, "use_rootfs");
}
chromeos::machine_learning::FakeServiceConnectionImpl&
GetMlServiceConnection() {
return fake_ml_service_connection_;
}
private:
chromeos::machine_learning::FakeServiceConnectionImpl
fake_ml_service_connection_;
};
TEST_F(HandwritingRecognitionServiceImplCrOSTest, CreateHandwritingRecognizer) {
mojo::Remote<handwriting::mojom::HandwritingRecognitionService>
service_remote;
CrOSHandwritingRecognitionServiceImpl::Create(
service_remote.BindNewPipeAndPassReceiver());
auto model_constraint = handwriting::mojom::HandwritingModelConstraint::New();
model_constraint->languages.push_back("en");
bool is_callback_called = false;
base::RunLoop runloop;
service_remote->CreateHandwritingRecognizer(
std::move(model_constraint),
base::BindLambdaForTesting(
[&](handwriting::mojom::CreateHandwritingRecognizerResult result,
mojo::PendingRemote<handwriting::mojom::HandwritingRecognizer>
remote) {
EXPECT_EQ(
result,
handwriting::mojom::CreateHandwritingRecognizerResult::kOk);
is_callback_called = true;
runloop.Quit();
}));
runloop.Run();
EXPECT_TRUE(is_callback_called);
}
// In this test we provide valid input/output to check the mojo calls and data
// copying code work correctly.
TEST_F(HandwritingRecognitionServiceImplCrOSTest,
GetPredictionCorrectConversion) {
mojo::Remote<handwriting::mojom::HandwritingRecognitionService>
service_remote;
CrOSHandwritingRecognitionServiceImpl::Create(
service_remote.BindNewPipeAndPassReceiver());
auto model_constraint = handwriting::mojom::HandwritingModelConstraint::New();
model_constraint->languages.push_back("en");
bool is_callback_called = false;
mojo::Remote<handwriting::mojom::HandwritingRecognizer> recognizer_remote;
base::RunLoop runloop_create_recognizer;
service_remote->CreateHandwritingRecognizer(
std::move(model_constraint),
base::BindLambdaForTesting(
[&](handwriting::mojom::CreateHandwritingRecognizerResult result,
mojo::PendingRemote<handwriting::mojom::HandwritingRecognizer>
input_remote) {
is_callback_called = true;
ASSERT_EQ(
result,
handwriting::mojom::CreateHandwritingRecognizerResult::kOk);
recognizer_remote.Bind(std::move(input_remote));
runloop_create_recognizer.Quit();
}));
runloop_create_recognizer.Run();
ASSERT_TRUE(is_callback_called);
// Generate and set the fake recognition result.
auto prediction = chromeos::machine_learning::web_platform::mojom::
HandwritingPrediction::New();
prediction->text = "text wrote";
auto segment = chromeos::machine_learning::web_platform::mojom::
HandwritingSegment::New();
segment->grapheme = "seg";
segment->begin_index = 0u;
segment->end_index = 3u;
segment->drawing_segments.push_back(
chromeos::machine_learning::web_platform::mojom::
HandwritingDrawingSegment::New(0u, 10u, 15u));
segment->drawing_segments.push_back(
chromeos::machine_learning::web_platform::mojom::
HandwritingDrawingSegment::New(1u, 0u, 13u));
prediction->segmentation_result.push_back(std::move(segment));
std::vector<
chromeos::machine_learning::web_platform::mojom::HandwritingPredictionPtr>
predictions;
predictions.push_back(std::move(prediction));
GetMlServiceConnection().SetOutputWebPlatformHandwritingRecognizerResult(
std::move(predictions));
// Generate 3 input strokes.
std::vector<handwriting::mojom::HandwritingStrokePtr> strokes;
const std::vector<int> num_points = {15, 10, 21};
for (int npts : num_points) {
auto stroke = handwriting::mojom::HandwritingStroke::New();
for (int i = 0; i < npts; ++i) {
// The actual values of the points do not matter.
stroke->points.emplace_back(handwriting::mojom::HandwritingPoint::New());
}
strokes.emplace_back(std::move(stroke));
}
is_callback_called = false;
base::RunLoop runloop_prediction;
recognizer_remote->GetPrediction(
std::move(strokes), handwriting::mojom::HandwritingHints::New(),
base::BindLambdaForTesting(
[&](std::optional<std::vector<
handwriting::mojom::HandwritingPredictionPtr>> result) {
is_callback_called = true;
ASSERT_TRUE(result.has_value());
ASSERT_EQ(result.value().size(), 1u);
EXPECT_EQ(result.value()[0]->text, "text wrote");
ASSERT_EQ(result.value()[0]->segmentation_result.size(), 1u);
EXPECT_EQ(result.value()[0]->segmentation_result[0]->grapheme,
"seg");
// It is 0 because it is the first segment.
EXPECT_EQ(result.value()[0]->segmentation_result[0]->begin_index,
0u);
// This is the Length of the grapheme "seg".
EXPECT_EQ(result.value()[0]->segmentation_result[0]->end_index, 3u);
// Equals `ink_range->end_stroke-ink_range->begin_stroke+1`.
ASSERT_EQ(result.value()[0]
->segmentation_result[0]
->drawing_segments.size(),
2u);
EXPECT_EQ(result.value()[0]
->segmentation_result[0]
->drawing_segments[0]
->stroke_index,
0u);
// Equals `ink_range->start_point`.
EXPECT_EQ(result.value()[0]
->segmentation_result[0]
->drawing_segments[0]
->begin_point_index,
10u);
// Equals `num_points[0]`.
EXPECT_EQ(result.value()[0]
->segmentation_result[0]
->drawing_segments[0]
->end_point_index,
15u);
EXPECT_EQ(result.value()[0]
->segmentation_result[0]
->drawing_segments[1]
->stroke_index,
1u);
EXPECT_EQ(result.value()[0]
->segmentation_result[0]
->drawing_segments[1]
->begin_point_index,
0u);
// Equals `ink_range->end_point+1`.
EXPECT_EQ(result.value()[0]
->segmentation_result[0]
->drawing_segments[1]
->end_point_index,
13u);
runloop_prediction.Quit();
}));
runloop_prediction.Run();
EXPECT_TRUE(is_callback_called);
}
} // namespace content