chromium/chromeos/ash/services/libassistant/speaker_id_enrollment_controller_unittest.cc

// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chromeos/ash/services/libassistant/speaker_id_enrollment_controller.h"

#include "base/test/mock_callback.h"
#include "base/test/task_environment.h"
#include "chromeos/ash/services/libassistant/test_support/libassistant_service_tester.h"
#include "chromeos/assistant/internal/libassistant/shared_headers.h"
#include "chromeos/assistant/internal/proto/shared/proto/v2/delegate/event_handler_interface.pb.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace ash::libassistant {

namespace {
// Shorthand names used by the mock and the tests in this file.
// NOTE(review): SpeakerIdEnrollmentStatus and NiceMock do not appear to be
// referenced anywhere else in this file - confirm before removing.
using assistant_client::SpeakerIdEnrollmentStatus;
// State type stored in the |state| field of a SpeakerIdEnrollmentUpdate.
using SpeakerIdEnrollmentState =
    ::assistant_client::SpeakerIdEnrollmentUpdate::State;
// Callback type the tests pass to GetSpeakerIdEnrollmentStatus().
using GetSpeakerIdEnrollmentStatusCallback =
    SpeakerIdEnrollmentController::GetSpeakerIdEnrollmentStatusCallback;
using ::testing::NiceMock;
using ::testing::StrictMock;

// gMock implementation of mojom::SpeakerIdEnrollmentClient.
//
// The mock owns the receiver end of the client pipe. Tests obtain the matching
// remote through BindNewPipeAndPassRemote(), hand it to the controller, and
// then set expectations on the mocked interface methods.
class SpeakerIdEnrollmentClientMock : public mojom::SpeakerIdEnrollmentClient {
 public:
  SpeakerIdEnrollmentClientMock() = default;
  SpeakerIdEnrollmentClientMock(const SpeakerIdEnrollmentClientMock&) = delete;
  SpeakerIdEnrollmentClientMock& operator=(
      const SpeakerIdEnrollmentClientMock&) = delete;
  ~SpeakerIdEnrollmentClientMock() override = default;

  // Binds |receiver_| to a new message pipe and returns its remote end.
  ::mojo::PendingRemote<SpeakerIdEnrollmentClient> BindNewPipeAndPassRemote() {
    return receiver_.BindNewPipeAndPassRemote();
  }

  // mojom::SpeakerIdEnrollmentClient implementation:
  // The `override` spec makes the compiler reject a mock whose signature no
  // longer matches the interface, instead of silently declaring a new method.
  MOCK_METHOD(void, OnListeningHotword, (), (override));
  MOCK_METHOD(void, OnProcessingHotword, (), (override));
  MOCK_METHOD(void, OnSpeakerIdEnrollmentDone, (), (override));
  MOCK_METHOD(void, OnSpeakerIdEnrollmentFailure, (), (override));

  // Forwards to FlushForTesting() of the underlying receiver.
  void FlushForTesting() { receiver_.FlushForTesting(); }

 private:
  // Receiver end of the client pipe; dispatches incoming calls to this mock.
  ::mojo::Receiver<mojom::SpeakerIdEnrollmentClient> receiver_{this};
};
}  // namespace

// Fixture for the speaker id enrollment controller tests.
//
// It owns a LibassistantServiceTester and offers helpers to reach the
// controller through its mojom interface, to query its enrollment state, and
// to inject the events/responses the controller would otherwise receive.
class AssistantSpeakerIdEnrollmentControllerTest : public ::testing::Test {
 public:
  AssistantSpeakerIdEnrollmentControllerTest() = default;
  AssistantSpeakerIdEnrollmentControllerTest(
      const AssistantSpeakerIdEnrollmentControllerTest&) = delete;
  AssistantSpeakerIdEnrollmentControllerTest& operator=(
      const AssistantSpeakerIdEnrollmentControllerTest&) = delete;
  ~AssistantSpeakerIdEnrollmentControllerTest() override = default;

  // Mojom interface of the controller, as exposed by the service tester.
  mojom::SpeakerIdEnrollmentController& controller() {
    return service_tester_.speaker_id_enrollment_controller();
  }

  // Forwards to LibassistantServiceTester::FlushForTesting().
  void FlushForTesting() { service_tester_.FlushForTesting(); }

  LibassistantServiceTester& service_tester() { return service_tester_; }

  // Returns an enrollment update whose |state| field is set to |state|.
  assistant_client::SpeakerIdEnrollmentUpdate CreateUpdate(
      SpeakerIdEnrollmentState state) {
    assistant_client::SpeakerIdEnrollmentUpdate update;
    update.state = state;
    return update;
  }

  // Reports whether the controller considers an enrollment to be in progress.
  bool IsSpeakerIdEnrollmentInProgress() {
    auto&& enrollment_controller =
        service_tester_.service()
            .speaker_id_enrollment_controller_for_testing();
    return enrollment_controller.IsSpeakerIdEnrollmentInProgressForTesting();
  }

  // Translates |update| into an OnSpeakerIdEnrollmentEventRequest and feeds it
  // to the controller.
  void CallUpdateCallback(
      const ::assistant_client::SpeakerIdEnrollmentUpdate& update) {
    ::assistant::api::OnSpeakerIdEnrollmentEventRequest grpc_request;
    auto* enrollment_event = grpc_request.mutable_event();

    // Each mutable_*_state() accessor is called only for its effect on the
    // event; the returned sub-message is deliberately left untouched.
    switch (update.state) {
      case SpeakerIdEnrollmentState::INIT:
        enrollment_event->mutable_init_state();
        break;
      case SpeakerIdEnrollmentState::CHECK:
        enrollment_event->mutable_check_state();
        break;
      case SpeakerIdEnrollmentState::LISTEN:
        enrollment_event->mutable_listen_state();
        break;
      case SpeakerIdEnrollmentState::PROCESS:
        enrollment_event->mutable_process_state();
        break;
      case SpeakerIdEnrollmentState::UPLOAD:
        enrollment_event->mutable_upload_state();
        break;
      case SpeakerIdEnrollmentState::FETCH:
        enrollment_event->mutable_fetch_state();
        break;
      case SpeakerIdEnrollmentState::DONE:
        enrollment_event->mutable_done_state();
        break;
      case SpeakerIdEnrollmentState::FAILURE:
        enrollment_event->mutable_failure_state();
        break;
    }

    auto&& enrollment_controller =
        service_tester_.service()
            .speaker_id_enrollment_controller_for_testing();
    enrollment_controller.OnGrpcMessageForTesting(std::move(grpc_request));
  }

  // Makes the controller send a get-status response carrying
  // |user_model_exists|.
  void CallGetStatusCallback(bool user_model_exists) {
    auto&& enrollment_controller =
        service_tester_.service()
            .speaker_id_enrollment_controller_for_testing();
    enrollment_controller.SendGetStatusResponseForTesting(user_model_exists);
  }

 private:
  // Declared before |service_tester_| so it is constructed first and
  // destroyed last.
  base::test::SingleThreadTaskEnvironment environment_;
  LibassistantServiceTester service_tester_;
};

TEST_F(AssistantSpeakerIdEnrollmentControllerTest,
       StartShouldBeANoopIfLibassistantIsNotStarted) {
  // service_tester().Start() is deliberately not called in this test.
  StrictMock<SpeakerIdEnrollmentClientMock> enrollment_client;
  auto pending_client = enrollment_client.BindNewPipeAndPassRemote();

  controller().StartSpeakerIdEnrollment("user-id",
                                        /*skip_cloud_enrollment=*/false,
                                        std::move(pending_client));
  FlushForTesting();

  // The start request must not have put the controller into enrollment mode.
  EXPECT_FALSE(IsSpeakerIdEnrollmentInProgress());
}

TEST_F(AssistantSpeakerIdEnrollmentControllerTest,
       StopShouldBeANoopIfLibassistantIsNotStarted) {
  // No explicit expectations: the test only requires that a stop request sent
  // without calling service_tester().Start() completes cleanly.
  mojom::SpeakerIdEnrollmentController& enrollment_controller = controller();
  enrollment_controller.StopSpeakerIdEnrollment();
  FlushForTesting();
}

TEST_F(AssistantSpeakerIdEnrollmentControllerTest,
       ShouldBeAbleToStartEnrollment) {
  service_tester().Start();

  StrictMock<SpeakerIdEnrollmentClientMock> enrollment_client;
  auto pending_client = enrollment_client.BindNewPipeAndPassRemote();

  controller().StartSpeakerIdEnrollment("user-id",
                                        /*skip_cloud_enrollment=*/false,
                                        std::move(pending_client));
  FlushForTesting();

  // Once the service is started, the start request must take effect.
  EXPECT_TRUE(IsSpeakerIdEnrollmentInProgress());
}

TEST_F(AssistantSpeakerIdEnrollmentControllerTest,
       ShouldBeAbleToStopEnrollment) {
  service_tester().Start();

  StrictMock<SpeakerIdEnrollmentClientMock> enrollment_client;
  auto pending_client = enrollment_client.BindNewPipeAndPassRemote();

  // Start and immediately stop; both requests are flushed together below.
  mojom::SpeakerIdEnrollmentController& enrollment_controller = controller();
  enrollment_controller.StartSpeakerIdEnrollment(
      "user-id",
      /*skip_cloud_enrollment=*/false, std::move(pending_client));
  enrollment_controller.StopSpeakerIdEnrollment();
  FlushForTesting();

  EXPECT_FALSE(IsSpeakerIdEnrollmentInProgress());
}

// Verifies which enrollment states are forwarded to the client, and that each
// forwarded state triggers its own callback.
TEST_F(AssistantSpeakerIdEnrollmentControllerTest,
       ShouldPassEnrollmentUpdatesToClient) {
  service_tester().Start();
  StrictMock<SpeakerIdEnrollmentClientMock> client;

  controller().StartSpeakerIdEnrollment("user-id",
                                        /*skip_cloud_enrollment=*/false,
                                        client.BindNewPipeAndPassRemote());
  FlushForTesting();
  EXPECT_TRUE(IsSpeakerIdEnrollmentInProgress());

  // These are ignored: no expectation is set, so the StrictMock fails the test
  // if any of them results in a client call.
  CallUpdateCallback(CreateUpdate(SpeakerIdEnrollmentState::INIT));
  CallUpdateCallback(CreateUpdate(SpeakerIdEnrollmentState::CHECK));
  CallUpdateCallback(CreateUpdate(SpeakerIdEnrollmentState::UPLOAD));
  CallUpdateCallback(CreateUpdate(SpeakerIdEnrollmentState::FETCH));
  client.FlushForTesting();

  // Each phase below is verified and cleared before the next expectation is
  // set. This keeps the expectations from being interleaved with mock calls,
  // and checks that every callback is triggered by its own update rather than
  // merely at some point before |client| is destroyed.
  EXPECT_CALL(client, OnListeningHotword);
  CallUpdateCallback(CreateUpdate(SpeakerIdEnrollmentState::LISTEN));
  client.FlushForTesting();
  ::testing::Mock::VerifyAndClearExpectations(&client);

  EXPECT_CALL(client, OnProcessingHotword);
  CallUpdateCallback(CreateUpdate(SpeakerIdEnrollmentState::PROCESS));
  client.FlushForTesting();
  ::testing::Mock::VerifyAndClearExpectations(&client);

  EXPECT_CALL(client, OnSpeakerIdEnrollmentDone);
  CallUpdateCallback(CreateUpdate(SpeakerIdEnrollmentState::DONE));
  client.FlushForTesting();
  ::testing::Mock::VerifyAndClearExpectations(&client);

  EXPECT_CALL(client, OnSpeakerIdEnrollmentFailure);
  CallUpdateCallback(CreateUpdate(SpeakerIdEnrollmentState::FAILURE));
  client.FlushForTesting();
  ::testing::Mock::VerifyAndClearExpectations(&client);
}

TEST_F(AssistantSpeakerIdEnrollmentControllerTest,
       ShouldStartSecondEnrollmentEvenIfFirstIsOngoing) {
  service_tester().Start();

  StrictMock<SpeakerIdEnrollmentClientMock> earlier_client;
  StrictMock<SpeakerIdEnrollmentClientMock> later_client;
  mojom::SpeakerIdEnrollmentController& enrollment_controller = controller();

  // Two back-to-back start requests, each with its own client.
  auto earlier_remote = earlier_client.BindNewPipeAndPassRemote();
  enrollment_controller.StartSpeakerIdEnrollment(
      "first user",
      /*skip_cloud_enrollment=*/false, std::move(earlier_remote));

  auto later_remote = later_client.BindNewPipeAndPassRemote();
  enrollment_controller.StartSpeakerIdEnrollment(
      "second user",
      /*skip_cloud_enrollment=*/false, std::move(later_remote));

  FlushForTesting();

  // Only the second client may be notified; |earlier_client| is a StrictMock
  // without expectations, so any call on it fails the test.
  EXPECT_CALL(later_client, OnProcessingHotword);
  CallUpdateCallback(CreateUpdate(SpeakerIdEnrollmentState::PROCESS));

  earlier_client.FlushForTesting();
  later_client.FlushForTesting();
}

TEST_F(
    AssistantSpeakerIdEnrollmentControllerTest,
    GetSpeakerIdEnrollmentStatusShouldReturnFalseIfLibassistantIsNotStarted) {
  // service_tester().Start() is deliberately not called in this test.
  base::MockCallback<GetSpeakerIdEnrollmentStatusCallback> status_callback;

  // The callback must run exactly once and report that no user model exists.
  EXPECT_CALL(status_callback, Run).WillOnce([](auto status) {
    EXPECT_FALSE(status->user_model_exists);
  });

  controller().GetSpeakerIdEnrollmentStatus("user-id", status_callback.Get());
  FlushForTesting();
}

TEST_F(AssistantSpeakerIdEnrollmentControllerTest,
       ShouldReturnSpeakerIdEnrollmentStatus) {
  service_tester().Start();

  base::MockCallback<GetSpeakerIdEnrollmentStatusCallback> status_callback;

  // The callback must run exactly once and carry the injected status.
  EXPECT_CALL(status_callback, Run).WillOnce([](auto status) {
    EXPECT_TRUE(status->user_model_exists);
  });

  // Issue the request first, then inject the response it is waiting for.
  controller().GetSpeakerIdEnrollmentStatus("user-id", status_callback.Get());
  FlushForTesting();

  CallGetStatusCallback(/*user_model_exists=*/true);
  FlushForTesting();
}

}  // namespace ash::libassistant