chromium/content/browser/handwriting/handwriting_recognition_service_impl_cros_unittest.cc

// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "content/browser/handwriting/handwriting_recognition_service_impl_cros.h"

#include <optional>
#include <utility>
#include <vector>

#include "base/command_line.h"
#include "base/run_loop.h"
#include "base/test/bind.h"
#include "chromeos/services/machine_learning/public/cpp/fake_service_connection.h"
#include "chromeos/services/machine_learning/public/cpp/ml_switches.h"
#include "chromeos/services/machine_learning/public/cpp/service_connection.h"
#include "content/browser/handwriting/handwriting_recognizer_impl_cros.h"
#include "content/public/test/test_renderer_host.h"
#include "mojo/public/cpp/bindings/pending_receiver.h"
#include "mojo/public/cpp/bindings/pending_remote.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "third_party/blink/public/mojom/handwriting/handwriting.mojom.h"

namespace content {

class HandwritingRecognitionServiceImplCrOSTest
    : public RenderViewHostTestHarness {
 public:
  void SetUp() override {
    RenderViewHostTestHarness::SetUp();
    chromeos::machine_learning::ServiceConnection::
        UseFakeServiceConnectionForTesting(&fake_ml_service_connection_);
    chromeos::machine_learning::ServiceConnection::GetInstance()->Initialize();
    // We need to add the switch to "enable" HWR support.
    base::CommandLine::ForCurrentProcess()->AppendSwitchASCII(
        ::switches::kOndeviceHandwritingSwitch, "use_rootfs");
  }

  chromeos::machine_learning::FakeServiceConnectionImpl&
  GetMlServiceConnection() {
    return fake_ml_service_connection_;
  }

 private:
  chromeos::machine_learning::FakeServiceConnectionImpl
      fake_ml_service_connection_;
};

// Checks that asking the service for an "en" recognizer invokes the callback
// with a `kOk` result.
TEST_F(HandwritingRecognitionServiceImplCrOSTest, CreateHandwritingRecognizer) {
  using handwriting::mojom::CreateHandwritingRecognizerResult;
  using handwriting::mojom::HandwritingModelConstraint;
  using handwriting::mojom::HandwritingRecognitionService;
  using handwriting::mojom::HandwritingRecognizer;

  // Connect a remote to a freshly created service implementation.
  mojo::Remote<HandwritingRecognitionService> service;
  CrOSHandwritingRecognitionServiceImpl::Create(
      service.BindNewPipeAndPassReceiver());

  auto constraint = HandwritingModelConstraint::New();
  constraint->languages.push_back("en");

  bool callback_ran = false;
  base::RunLoop run_loop;
  // The recognizer endpoint itself is not inspected by this test; only the
  // reported result is checked.
  auto on_recognizer_created = base::BindLambdaForTesting(
      [&](CreateHandwritingRecognizerResult result,
          mojo::PendingRemote<HandwritingRecognizer> /*recognizer*/) {
        EXPECT_EQ(result, CreateHandwritingRecognizerResult::kOk);
        callback_ran = true;
        run_loop.Quit();
      });
  service->CreateHandwritingRecognizer(std::move(constraint),
                                       std::move(on_recognizer_created));

  // Blocks until the callback above calls Quit().
  run_loop.Run();
  EXPECT_TRUE(callback_ran);
}

// In this test we provide valid input/output to check that the mojo calls and
// the data copying code work correctly.
//
// Both callbacks call RunLoop::Quit() before any ASSERT_*: a failing ASSERT_*
// returns from the lambda right away, so a Quit() placed after it would be
// skipped and the matching Run() would block until the test times out.
TEST_F(HandwritingRecognitionServiceImplCrOSTest,
       GetPredictionCorrectConversion) {
  namespace ml_mojom = chromeos::machine_learning::web_platform::mojom;

  mojo::Remote<handwriting::mojom::HandwritingRecognitionService>
      service_remote;
  CrOSHandwritingRecognitionServiceImpl::Create(
      service_remote.BindNewPipeAndPassReceiver());
  auto model_constraint = handwriting::mojom::HandwritingModelConstraint::New();
  model_constraint->languages.push_back("en");
  bool is_callback_called = false;
  mojo::Remote<handwriting::mojom::HandwritingRecognizer> recognizer_remote;
  base::RunLoop runloop_create_recognizer;
  service_remote->CreateHandwritingRecognizer(
      std::move(model_constraint),
      base::BindLambdaForTesting(
          [&](handwriting::mojom::CreateHandwritingRecognizerResult result,
              mojo::PendingRemote<handwriting::mojom::HandwritingRecognizer>
                  input_remote) {
            is_callback_called = true;
            // Quit first so that a failing ASSERT below cannot leave Run()
            // waiting forever; the rest of this callback still executes.
            runloop_create_recognizer.Quit();
            ASSERT_EQ(
                result,
                handwriting::mojom::CreateHandwritingRecognizerResult::kOk);
            recognizer_remote.Bind(std::move(input_remote));
          }));
  runloop_create_recognizer.Run();
  ASSERT_TRUE(is_callback_called);
  // A fatal failure inside the callback only leaves the lambda, not this test
  // body. Stop here instead of calling through an unbound remote.
  ASSERT_TRUE(recognizer_remote.is_bound());

  // Generate and set the fake recognition result: one prediction holding one
  // segment that spans two drawing segments.
  auto prediction = ml_mojom::HandwritingPrediction::New();
  prediction->text = "text wrote";
  auto segment = ml_mojom::HandwritingSegment::New();
  segment->grapheme = "seg";
  segment->begin_index = 0u;
  segment->end_index = 3u;
  segment->drawing_segments.push_back(
      ml_mojom::HandwritingDrawingSegment::New(0u, 10u, 15u));
  segment->drawing_segments.push_back(
      ml_mojom::HandwritingDrawingSegment::New(1u, 0u, 13u));
  prediction->segmentation_result.push_back(std::move(segment));

  std::vector<ml_mojom::HandwritingPredictionPtr> predictions;
  predictions.push_back(std::move(prediction));
  GetMlServiceConnection().SetOutputWebPlatformHandwritingRecognizerResult(
      std::move(predictions));

  // Generate 3 input strokes.
  std::vector<handwriting::mojom::HandwritingStrokePtr> strokes;
  const std::vector<int> num_points = {15, 10, 21};
  for (int npts : num_points) {
    auto stroke = handwriting::mojom::HandwritingStroke::New();
    for (int i = 0; i < npts; ++i) {
      // The actual values of the points do not matter.
      stroke->points.emplace_back(handwriting::mojom::HandwritingPoint::New());
    }
    strokes.emplace_back(std::move(stroke));
  }

  is_callback_called = false;
  base::RunLoop runloop_prediction;
  recognizer_remote->GetPrediction(
      std::move(strokes), handwriting::mojom::HandwritingHints::New(),
      base::BindLambdaForTesting(
          [&](std::optional<std::vector<
                  handwriting::mojom::HandwritingPredictionPtr>> result) {
            is_callback_called = true;
            // Quit first; see the note in the creation callback.
            runloop_prediction.Quit();

            ASSERT_TRUE(result.has_value());
            ASSERT_EQ(result.value().size(), 1u);
            const auto& first_prediction = result.value()[0];
            EXPECT_EQ(first_prediction->text, "text wrote");

            ASSERT_EQ(first_prediction->segmentation_result.size(), 1u);
            const auto& first_segment =
                first_prediction->segmentation_result[0];
            EXPECT_EQ(first_segment->grapheme, "seg");
            // It is 0 because it is the first segment.
            EXPECT_EQ(first_segment->begin_index, 0u);
            // This is the length of the grapheme "seg".
            EXPECT_EQ(first_segment->end_index, 3u);

            // Both drawing segments attached to the fake result must arrive
            // with their fields unchanged.
            const auto& drawing_segments = first_segment->drawing_segments;
            ASSERT_EQ(drawing_segments.size(), 2u);

            EXPECT_EQ(drawing_segments[0]->stroke_index, 0u);
            EXPECT_EQ(drawing_segments[0]->begin_point_index, 10u);
            // Equals `num_points[0]`.
            EXPECT_EQ(drawing_segments[0]->end_point_index, 15u);

            EXPECT_EQ(drawing_segments[1]->stroke_index, 1u);
            EXPECT_EQ(drawing_segments[1]->begin_point_index, 0u);
            EXPECT_EQ(drawing_segments[1]->end_point_index, 13u);
          }));
  runloop_prediction.Run();
  EXPECT_TRUE(is_callback_called);
}

}  // namespace content