// chromium/chrome/browser/ash/trusted_vault/trusted_vault_backend_ash_unittest.cc

// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/ash/trusted_vault/trusted_vault_backend_ash.h"

#include <memory>

#include "base/test/mock_callback.h"
#include "base/test/task_environment.h"
#include "chromeos/crosapi/mojom/account_manager.mojom.h"
#include "components/signin/public/identity_manager/account_info.h"
#include "components/signin/public/identity_manager/identity_test_environment.h"
#include "components/trusted_vault/test/fake_trusted_vault_client.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace ash {

using testing::Eq;
using testing::IsEmpty;
using testing::SizeIs;

// Test-only implementation of crosapi::mojom::TrustedVaultBackendObserver that
// simply counts how many times each notification has been received.
class TestTrustedVaultBackendObserver
    : public crosapi::mojom::TrustedVaultBackendObserver {
 public:
  // Registers this observer with `backend` through AddObserver() and flushes
  // the remote before returning. `backend` is dereferenced unconditionally, so
  // it must not be null.
  //
  // Marked `explicit` so that a pointer to a remote is never silently
  // converted into an observer object.
  explicit TestTrustedVaultBackendObserver(
      mojo::Remote<crosapi::mojom::TrustedVaultBackend>* backend) {
    backend->get()->AddObserver(receiver_.BindNewPipeAndPassRemote());
    backend->FlushForTesting();
  }

  ~TestTrustedVaultBackendObserver() override = default;

  // Number of OnTrustedVaultKeysChanged() notifications received so far.
  int num_on_trusted_vault_keys_changed_calls() const {
    return num_on_trusted_vault_keys_changed_calls_;
  }

  // Number of OnTrustedVaultRecoverabilityChanged() notifications received so
  // far.
  int num_on_trusted_vault_recoverability_changed_calls() const {
    return num_on_trusted_vault_recoverability_changed_calls_;
  }

  // crosapi::mojom::TrustedVaultBackendObserver overrides; each one only bumps
  // the matching counter.
  void OnTrustedVaultKeysChanged() override {
    num_on_trusted_vault_keys_changed_calls_++;
  }

  void OnTrustedVaultRecoverabilityChanged() override {
    num_on_trusted_vault_recoverability_changed_calls_++;
  }

 private:
  // Receiving end of the observer pipe handed to the backend in the
  // constructor.
  mojo::Receiver<crosapi::mojom::TrustedVaultBackendObserver> receiver_{this};

  int num_on_trusted_vault_keys_changed_calls_ = 0;
  int num_on_trusted_vault_recoverability_changed_calls_ = 0;
};

// Fixture that owns a TrustedVaultBackendAsh wired to a fake trusted vault
// client and to an identity test environment with a primary account made
// available at the kSignin consent level. The backend is reachable through a
// mojo remote that is bound in the constructor.
class TrustedVaultBackendAshTest : public testing::Test {
 public:
  // Members are initialized in declaration order, so the identity environment
  // and the fake client already exist when the backend is created.
  // NOTE(review): the e-mail literal below looks like an address-obfuscation
  // artifact rather than a real address - confirm the intended value.
  TrustedVaultBackendAshTest()
      : primary_account_info_(identity_test_env_.MakePrimaryAccountAvailable(
            "[email protected]",
            signin::ConsentLevel::kSignin)),
        backend_(std::make_unique<TrustedVaultBackendAsh>(
            identity_test_env_.identity_manager(),
            &trusted_vault_client_ash_)) {
    backend_->BindReceiver(backend_remote_.BindNewPipeAndPassReceiver());
  }

  ~TrustedVaultBackendAshTest() override = default;

  // Account info returned when the primary account was made available.
  AccountInfo* primary_account_info() { return &primary_account_info_; }

  // The fake client handed to the backend.
  trusted_vault::FakeTrustedVaultClient* client_ash() {
    return &trusted_vault_client_ash_;
  }

  // The backend under test; null after DeleteBackend().
  TrustedVaultBackendAsh* backend() { return backend_.get(); }

  // Remote bound to the backend in the constructor.
  mojo::Remote<crosapi::mojom::TrustedVaultBackend>& backend_remote() {
    return backend_remote_;
  }

  // Destroys the backend under test.
  void DeleteBackend() { backend_.reset(); }

  // Builds a Gaia account key whose id is the primary account's gaia id.
  crosapi::mojom::AccountKeyPtr GetPrimaryAccountKey() const {
    auto key = crosapi::mojom::AccountKey::New();
    key->account_type = crosapi::mojom::AccountType::kGaia;
    key->id = primary_account_info_.gaia;
    return key;
  }

  // Builds a Gaia account key whose id is the primary account's gaia id with
  // "a" appended, i.e. an id that differs from the primary account's.
  crosapi::mojom::AccountKeyPtr GetNonPrimaryAccountKey() const {
    auto key = GetPrimaryAccountKey();
    key->id.append("a");
    return key;
  }

 private:
  base::test::SingleThreadTaskEnvironment task_environment_;

  signin::IdentityTestEnvironment identity_test_env_;
  AccountInfo primary_account_info_;

  trusted_vault::FakeTrustedVaultClient trusted_vault_client_ash_;
  std::unique_ptr<TrustedVaultBackendAsh> backend_;
  mojo::Remote<crosapi::mojom::TrustedVaultBackend> backend_remote_;
};

// FetchKeys() for the primary account should hand back the keys previously
// stored in the fake client.
TEST_F(TrustedVaultBackendAshTest, ShouldFetchKeys) {
  const std::vector<std::vector<uint8_t>> stored_keys = {{1, 2, 3}};

  base::MockCallback<TrustedVaultBackendAsh::FetchKeysCallback> fetch_done;
  EXPECT_CALL(fetch_done, Run(stored_keys));

  client_ash()->StoreKeys(primary_account_info()->gaia, stored_keys,
                          /*last_key_version=*/1);

  auto& remote = backend_remote();
  remote->FetchKeys(GetPrimaryAccountKey(), fetch_done.Get());

  // The result needs three hops to reach `fetch_done`:
  // 1. Let mojo deliver the FetchKeys() call to the backend.
  remote.FlushForTesting();
  // 2. Mimic asynchronous fetch completion on the client_ash() side.
  EXPECT_TRUE(client_ash()->CompleteAllPendingRequests());
  // 3. Let mojo deliver the reply back to the callback.
  remote.FlushForTesting();
}

// FetchKeys() with an account key that differs from the primary account's
// should report an empty key list, even though keys are stored for the primary
// account.
TEST_F(TrustedVaultBackendAshTest, ShouldValidateAccountKeyOnFetchKeys) {
  base::MockCallback<TrustedVaultBackendAsh::FetchKeysCallback> fetch_done;
  EXPECT_CALL(fetch_done, Run(IsEmpty()));

  const std::vector<std::vector<uint8_t>> primary_keys = {{1, 2, 3}};
  client_ash()->StoreKeys(primary_account_info()->gaia, primary_keys,
                          /*last_key_version=*/1);

  auto& remote = backend_remote();
  remote->FetchKeys(GetNonPrimaryAccountKey(), fetch_done.Get());
  remote.FlushForTesting();
}

// Without a primary account, FetchKeys() should report an empty key list. The
// shared fixture is not used here because it always makes a primary account
// available.
TEST(TrustedVaultBackendAshNoFixtureTest, ShouldHandleAbsenseOfPrimaryAccount) {
  base::test::SingleThreadTaskEnvironment task_environment;
  signin::IdentityTestEnvironment identity_test_env;

  trusted_vault::FakeTrustedVaultClient fake_client;
  std::unique_ptr<TrustedVaultBackendAsh> backend =
      std::make_unique<TrustedVaultBackendAsh>(
          identity_test_env.identity_manager(), &fake_client);

  mojo::Remote<crosapi::mojom::TrustedVaultBackend> remote;
  backend->BindReceiver(remote.BindNewPipeAndPassReceiver());

  // Precondition: no primary account exists.
  ASSERT_FALSE(identity_test_env.identity_manager()->HasPrimaryAccount(
      signin::ConsentLevel::kSignin));

  // Pretend that keys are stored for `gaia`. This shouldn't be possible while
  // there is no primary account, but it is what makes the expectation below
  // meaningful.
  const std::string gaia = "example";
  const std::vector<std::vector<uint8_t>> stored_keys = {{1, 2, 3}};
  fake_client.StoreKeys(gaia, stored_keys, /*last_key_version=*/1);

  base::MockCallback<TrustedVaultBackendAsh::FetchKeysCallback> fetch_done;
  EXPECT_CALL(fetch_done, Run(IsEmpty()));

  auto key = crosapi::mojom::AccountKey::New();
  key->account_type = crosapi::mojom::AccountType::kGaia;
  key->id = gaia;

  remote->FetchKeys(std::move(key), fetch_done.Get());
  remote.FlushForTesting();
}

// MarkLocalKeysAsStale() for the primary account should reach the client and
// report `true` through the callback.
TEST_F(TrustedVaultBackendAshTest, ShouldMarkLocalKeysAsStale) {
  const std::vector<std::vector<uint8_t>> stored_keys = {{1, 2, 3}};
  client_ash()->StoreKeys(primary_account_info()->gaia, stored_keys,
                          /*last_key_version=*/1);

  base::MockCallback<TrustedVaultBackendAsh::MarkLocalKeysAsStaleCallback>
      stale_done;
  EXPECT_CALL(stale_done, Run(true));

  auto& remote = backend_remote();
  remote->MarkLocalKeysAsStale(GetPrimaryAccountKey(), stale_done.Get());
  remote.FlushForTesting();

  // The client must have seen exactly one request.
  EXPECT_THAT(client_ash()->keys_marked_as_stale_count(), Eq(1));
}

// MarkLocalKeysAsStale() with a non-primary account key should report `false`
// and must not reach the client.
TEST_F(TrustedVaultBackendAshTest,
       ShouldValidateAccountKeyOnMarkLocalKeysAsStale) {
  const std::vector<std::vector<uint8_t>> stored_keys = {{1, 2, 3}};
  client_ash()->StoreKeys(primary_account_info()->gaia, stored_keys,
                          /*last_key_version=*/1);

  base::MockCallback<TrustedVaultBackendAsh::MarkLocalKeysAsStaleCallback>
      stale_done;
  EXPECT_CALL(stale_done, Run(false));

  auto& remote = backend_remote();
  remote->MarkLocalKeysAsStale(GetNonPrimaryAccountKey(), stale_done.Get());
  remote.FlushForTesting();

  // No request may have been forwarded to the client.
  EXPECT_THAT(client_ash()->keys_marked_as_stale_count(), Eq(0));
}

// StoreKeys() for the primary account should land in the fake client.
TEST_F(TrustedVaultBackendAshTest, ShouldStoreKeys) {
  const std::vector<std::vector<uint8_t>> expected_keys = {{1, 2, 3}};

  auto& remote = backend_remote();
  remote->StoreKeys(GetPrimaryAccountKey(), expected_keys,
                    /*last_key_version=*/1);
  remote.FlushForTesting();

  EXPECT_THAT(client_ash()->GetStoredKeys(primary_account_info()->gaia),
              Eq(expected_keys));
}

// GetIsRecoverabilityDegraded() for the primary account should report `true`
// once the client has been told that a recovery method is required.
TEST_F(TrustedVaultBackendAshTest, ShouldGetIsRecoverabilityDegraded) {
  base::MockCallback<
      TrustedVaultBackendAsh::GetIsRecoverabilityDegradedCallback>
      degraded_done;
  EXPECT_CALL(degraded_done, Run(true));

  client_ash()->SetIsRecoveryMethodRequired(true);

  auto& remote = backend_remote();
  remote->GetIsRecoverabilityDegraded(GetPrimaryAccountKey(),
                                      degraded_done.Get());

  // The answer needs three hops to reach `degraded_done`:
  // 1. Let mojo deliver the GetIsRecoverabilityDegraded() call to the backend.
  remote.FlushForTesting();
  // 2. Mimic asynchronous GetIsRecoverabilityDegraded() completion on the
  // client_ash() side.
  EXPECT_TRUE(client_ash()->CompleteAllPendingRequests());
  // 3. Let mojo deliver the reply back to the callback.
  remote.FlushForTesting();
}

// GetIsRecoverabilityDegraded() with a non-primary account key should report
// `false`, even though the client says a recovery method is required.
TEST_F(TrustedVaultBackendAshTest,
       ShouldValidateAccountOnGetIsRecoverabilityDegraded) {
  base::MockCallback<
      TrustedVaultBackendAsh::GetIsRecoverabilityDegradedCallback>
      degraded_done;
  EXPECT_CALL(degraded_done, Run(false));

  client_ash()->SetIsRecoveryMethodRequired(true);

  auto& remote = backend_remote();
  remote->GetIsRecoverabilityDegraded(GetNonPrimaryAccountKey(),
                                      degraded_done.Get());
  remote.FlushForTesting();
}

// AddTrustedRecoveryMethod() for the primary account should register exactly
// one recovery method, carrying the supplied public key and type hint, on the
// fake server.
TEST_F(TrustedVaultBackendAshTest, ShouldAddTrustedRecoveryMethod) {
  const std::vector<uint8_t> public_key = {1, 2, 3, 4};
  const int type_hint = 4;

  base::MockCallback<TrustedVaultBackendAsh::AddTrustedRecoveryMethodCallback>
      method_added;
  EXPECT_CALL(method_added, Run());

  auto& remote = backend_remote();
  remote->AddTrustedRecoveryMethod(GetPrimaryAccountKey(), public_key,
                                   type_hint, method_added.Get());
  remote.FlushForTesting();

  const auto recovery_methods =
      client_ash()->server()->GetRecoveryMethods(primary_account_info()->gaia);
  // Stop here if the count is wrong; the element access below relies on it.
  ASSERT_THAT(recovery_methods, SizeIs(1));
  const auto& added_method = recovery_methods[0];
  EXPECT_THAT(added_method.public_key, Eq(public_key));
  EXPECT_THAT(added_method.method_type_hint, Eq(type_hint));
}

// AddTrustedRecoveryMethod() with a non-primary account key should still run
// the callback, but must not register anything for the primary account.
TEST_F(TrustedVaultBackendAshTest,
       ShouldValidateAccountOnAddTrustedRecoveryMethod) {
  const std::vector<uint8_t> public_key = {1, 2, 3, 4};

  base::MockCallback<TrustedVaultBackendAsh::AddTrustedRecoveryMethodCallback>
      method_added;
  EXPECT_CALL(method_added, Run());

  auto& remote = backend_remote();
  remote->AddTrustedRecoveryMethod(GetNonPrimaryAccountKey(), public_key,
                                   /*method_type_hint=*/1, method_added.Get());
  remote.FlushForTesting();

  EXPECT_THAT(
      client_ash()->server()->GetRecoveryMethods(primary_account_info()->gaia),
      IsEmpty());
}

// ClearLocalDataForAccount() for the primary account should remove the keys
// stored in the fake client.
TEST_F(TrustedVaultBackendAshTest, ShouldClearLocalDataForAccount) {
  const std::vector<std::vector<uint8_t>> stored_keys = {{1, 2, 3}};
  client_ash()->StoreKeys(primary_account_info()->gaia, stored_keys,
                          /*last_key_version=*/1);

  auto& remote = backend_remote();
  remote->ClearLocalDataForAccount(GetPrimaryAccountKey());
  remote.FlushForTesting();

  EXPECT_THAT(client_ash()->GetStoredKeys(primary_account_info()->gaia),
              IsEmpty());
}

// ClearLocalDataForAccount() with a non-primary account key must leave the
// primary account's stored keys untouched.
TEST_F(TrustedVaultBackendAshTest,
       ShouldValidateAccountInfoOnClearLocalDataForAccount) {
  const std::vector<std::vector<uint8_t>> stored_keys = {{1, 2, 3}};
  client_ash()->StoreKeys(primary_account_info()->gaia, stored_keys,
                          /*last_key_version=*/1);

  auto& remote = backend_remote();
  remote->ClearLocalDataForAccount(GetNonPrimaryAccountKey());
  remote.FlushForTesting();

  EXPECT_THAT(client_ash()->GetStoredKeys(primary_account_info()->gaia),
              Eq(stored_keys));
}

// A key change and a recoverability change on the client should each produce
// exactly one notification of the matching kind on a registered observer.
TEST_F(TrustedVaultBackendAshTest, ShouldNotifyObservers) {
  auto& remote = backend_remote();
  TestTrustedVaultBackendObserver observer(&remote);

  // Storing keys must trigger only the keys-changed notification.
  const std::vector<std::vector<uint8_t>> stored_keys = {{1, 2, 3}};
  client_ash()->StoreKeys(primary_account_info()->gaia, stored_keys,
                          /*last_key_version=*/1);
  remote.FlushForTesting();
  EXPECT_EQ(observer.num_on_trusted_vault_keys_changed_calls(), 1);
  EXPECT_EQ(observer.num_on_trusted_vault_recoverability_changed_calls(), 0);

  // Changing the recovery requirement must trigger only the
  // recoverability-changed notification.
  client_ash()->SetIsRecoveryMethodRequired(true);
  remote.FlushForTesting();
  EXPECT_EQ(observer.num_on_trusted_vault_recoverability_changed_calls(), 1);
  EXPECT_EQ(observer.num_on_trusted_vault_keys_changed_calls(), 1);
}

// Every registered observer should receive the keys-changed notification.
TEST_F(TrustedVaultBackendAshTest, ShouldSupportMultipleObservers) {
  auto& remote = backend_remote();
  TestTrustedVaultBackendObserver first_observer(&remote);
  TestTrustedVaultBackendObserver second_observer(&remote);

  const std::vector<std::vector<uint8_t>> stored_keys = {{1, 2, 3}};
  client_ash()->StoreKeys(primary_account_info()->gaia, stored_keys,
                          /*last_key_version=*/1);
  remote.FlushForTesting();

  EXPECT_EQ(first_observer.num_on_trusted_vault_keys_changed_calls(), 1);
  EXPECT_EQ(second_observer.num_on_trusted_vault_keys_changed_calls(), 1);
}

// Two remotes bound to the same backend must both be serviced, regardless of
// the order in which they were bound.
TEST_F(TrustedVaultBackendAshTest, ShouldSupportMultipleRemotes) {
  mojo::Remote<crosapi::mojom::TrustedVaultBackend> secondary_remote;
  backend()->BindReceiver(secondary_remote.BindNewPipeAndPassReceiver());

  // Store keys using the first remote; this should succeed even though
  // `secondary_remote` was bound after it.
  const std::vector<std::vector<uint8_t>> keys = {{1, 2, 3}};
  backend_remote()->StoreKeys(GetPrimaryAccountKey(), keys,
                              /*last_key_version=*/1);
  backend_remote().FlushForTesting();
  EXPECT_THAT(client_ash()->GetStoredKeys(primary_account_info()->gaia),
              Eq(keys));

  // Verify that `secondary_remote` is able to store keys too. Both the call
  // and the flush have to go through `secondary_remote` itself; issuing them
  // on the first remote would leave the second one completely unexercised.
  const std::vector<std::vector<uint8_t>> other_keys = {{1, 2, 3, 4}};
  secondary_remote->StoreKeys(GetPrimaryAccountKey(), other_keys,
                              /*last_key_version=*/2);
  secondary_remote.FlushForTesting();
  EXPECT_THAT(client_ash()->GetStoredKeys(primary_account_info()->gaia),
              Eq(other_keys));
}

// Destroying the backend should disconnect a remote that was bound to it.
TEST_F(TrustedVaultBackendAshTest, ShouldDisconnectOnDelete) {
  auto& remote = backend_remote();
  ASSERT_TRUE(remote.is_connected());

  DeleteBackend();
  // Give mojo a chance to propagate the disconnection to the remote.
  remote.FlushForTesting();

  EXPECT_FALSE(remote.is_connected());
}

}  // namespace ash