chromium/chromeos/ash/services/device_sync/cryptauth_group_private_key_sharer_impl_unittest.cc

// Copyright 2019 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chromeos/ash/services/device_sync/cryptauth_group_private_key_sharer_impl.h"

#include <memory>
#include <optional>
#include <string>
#include <utility>

#include "base/containers/contains.h"
#include "base/containers/flat_set.h"
#include "base/memory/raw_ptr.h"
#include "base/no_destructor.h"
#include "base/timer/mock_timer.h"
#include "chromeos/ash/services/device_sync/cryptauth_client.h"
#include "chromeos/ash/services/device_sync/cryptauth_device.h"
#include "chromeos/ash/services/device_sync/cryptauth_device_sync_result.h"
#include "chromeos/ash/services/device_sync/cryptauth_ecies_encryptor_impl.h"
#include "chromeos/ash/services/device_sync/cryptauth_key.h"
#include "chromeos/ash/services/device_sync/cryptauth_key_bundle.h"
#include "chromeos/ash/services/device_sync/cryptauth_v2_device_sync_test_devices.h"
#include "chromeos/ash/services/device_sync/fake_cryptauth_ecies_encryptor.h"
#include "chromeos/ash/services/device_sync/fake_ecies_encryption.h"
#include "chromeos/ash/services/device_sync/mock_cryptauth_client.h"
#include "chromeos/ash/services/device_sync/network_request_error.h"
#include "chromeos/ash/services/device_sync/proto/cryptauth_common.pb.h"
#include "chromeos/ash/services/device_sync/proto/cryptauth_devicesync.pb.h"
#include "chromeos/ash/services/device_sync/proto/cryptauth_v2_test_util.h"
#include "crypto/sha2.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace ash {

namespace device_sync {

namespace {

// Token that the mock CryptAuthClient reports from GetAccessTokenUsed().
constexpr char kAccessTokenUsed[] = "access token used by CryptAuthClient";

const cryptauthv2::ClientMetadata& GetClientMetadata() {
  static const base::NoDestructor<cryptauthv2::ClientMetadata> client_metadata(
      cryptauthv2::BuildClientMetadata(0 /* retry_count */,
                                       cryptauthv2::ClientMetadata::PERIODIC));
  return *client_metadata;
}

const cryptauthv2::RequestContext& GetRequestContext() {
  static const base::NoDestructor<cryptauthv2::RequestContext> request_context(
      [] {
        return cryptauthv2::BuildRequestContext(
            CryptAuthKeyBundle::KeyBundleNameEnumToString(
                CryptAuthKeyBundle::Name::kDeviceSyncBetterTogether),
            GetClientMetadata(),
            cryptauthv2::GetClientAppMetadataForTest().instance_id(),
            cryptauthv2::GetClientAppMetadataForTest().instance_id_token());
      }());
  return *request_context;
}

const CryptAuthKey& GetGroupKey() {
  static const base::NoDestructor<CryptAuthKey> group_key([] {
    return CryptAuthKey(
        kGroupPublicKey, GetPrivateKeyFromPublicKeyForTest(kGroupPublicKey),
        CryptAuthKey::Status::kActive, cryptauthv2::KeyType::P256);
  }());
  return *group_key;
}

CryptAuthGroupPrivateKeySharer::IdToEncryptingKeyMap
IdToEncryptingKeyMapFromDeviceIds(
    const base::flat_set<std::string>& device_ids) {
  CryptAuthGroupPrivateKeySharer::IdToEncryptingKeyMap id_to_encrypting_key_map;
  for (const std::string& id : device_ids) {
    id_to_encrypting_key_map.insert_or_assign(
        id, GetTestDeviceWithId(id).device_better_together_public_key);
  }

  return id_to_encrypting_key_map;
}

}  // namespace

// Test fixture that drives CryptAuthGroupPrivateKeySharerImpl with a mock
// CryptAuthClient, a fake ECIES encryptor, and a mock timer. Tests call the
// protected helpers in order: start the flow, finish encryption, inspect the
// outgoing request, answer it, and verify the final result code.
class DeviceSyncCryptAuthGroupPrivateKeySharerImplTest
    : public testing::Test,
      public MockCryptAuthClientFactory::Observer {
 public:
  DeviceSyncCryptAuthGroupPrivateKeySharerImplTest(
      const DeviceSyncCryptAuthGroupPrivateKeySharerImplTest&) = delete;
  DeviceSyncCryptAuthGroupPrivateKeySharerImplTest& operator=(
      const DeviceSyncCryptAuthGroupPrivateKeySharerImplTest&) = delete;

 protected:
  DeviceSyncCryptAuthGroupPrivateKeySharerImplTest()
      : client_factory_(std::make_unique<MockCryptAuthClientFactory>(
            MockCryptAuthClientFactory::MockType::MAKE_NICE_MOCKS)),
        fake_cryptauth_ecies_encryptor_factory_(
            std::make_unique<FakeCryptAuthEciesEncryptorFactory>()) {
    // Observe client creation so that default actions get installed on every
    // client in OnCryptAuthClientCreated().
    client_factory_->AddObserver(this);
  }

  ~DeviceSyncCryptAuthGroupPrivateKeySharerImplTest() override {
    client_factory_->RemoveObserver(this);
  }

  // testing::Test:
  void SetUp() override {
    // Route CryptAuthEciesEncryptorImpl creation to the fake factory so the
    // test decides when, and with what output, encryption finishes.
    CryptAuthEciesEncryptorImpl::Factory::SetFactoryForTesting(
        fake_cryptauth_ecies_encryptor_factory_.get());

    // Ownership of the timer moves into the sharer; keep a non-owning pointer
    // so tests can fire it.
    auto mock_timer = std::make_unique<base::MockOneShotTimer>();
    timer_ = mock_timer.get();

    sharer_ = CryptAuthGroupPrivateKeySharerImpl::Factory::Create(
        client_factory_.get(), std::move(mock_timer));
  }

  // testing::Test:
  void TearDown() override {
    CryptAuthEciesEncryptorImpl::Factory::SetFactoryForTesting(nullptr);
  }

  // MockCryptAuthClientFactory::Observer:
  void OnCryptAuthClientCreated(MockCryptAuthClient* client) override {
    ON_CALL(*client, ShareGroupPrivateKey(testing::_, testing::_, testing::_))
        .WillByDefault(testing::Invoke(
            this, &DeviceSyncCryptAuthGroupPrivateKeySharerImplTest::
                      OnShareGroupPrivateKey));

    ON_CALL(*client, GetAccessTokenUsed())
        .WillByDefault(testing::Return(kAccessTokenUsed));
  }

  // Records |group_key| and |id_to_encrypting_key_map| for later checks, then
  // starts the sharer. OnShareGroupPrivateKeyComplete() captures the result.
  void ShareGroupPrivateKey(
      const CryptAuthKey& group_key,
      const CryptAuthGroupPrivateKeySharer::IdToEncryptingKeyMap&
          id_to_encrypting_key_map) {
    group_key_ = std::make_unique<CryptAuthKey>(group_key);
    id_to_encrypting_key_map_ = id_to_encrypting_key_map;

    sharer_->ShareGroupPrivateKey(
        GetRequestContext(), group_key, id_to_encrypting_key_map,
        base::BindOnce(&DeviceSyncCryptAuthGroupPrivateKeySharerImplTest::
                           OnShareGroupPrivateKeyComplete,
                       base::Unretained(this)));
  }

  // Checks the fake encryptor's inputs against the ShareGroupPrivateKey()
  // inputs, then finishes the encryption attempt. IDs in |device_ids_to_fail|
  // get std::nullopt, simulating a failed encryption (for instance, from an
  // invalid encrypting key). Every other ID gets the fake encryption of the
  // group private key.
  void RunGroupPrivateKeyEncryptor(
      const base::flat_set<std::string>& expected_device_ids,
      const base::flat_set<std::string>& device_ids_to_fail) {
    // Fail the test instead of crashing if no encryptor was ever created.
    ASSERT_TRUE(encryptor());
    ASSERT_TRUE(group_key_);
    ASSERT_FALSE(group_key_->private_key().empty());
    ASSERT_EQ(expected_device_ids.size(),
              encryptor()->id_to_input_map().size());

    for (const auto& [id, input] : encryptor()->id_to_input_map()) {
      const std::string& payload = input.payload;
      const std::string& encrypting_key = input.key;

      EXPECT_TRUE(base::Contains(expected_device_ids, id));

      // Verify that encryptor inputs agree with ShareGroupPrivateKey() inputs.
      const auto it = id_to_encrypting_key_map_.find(id);
      ASSERT_NE(id_to_encrypting_key_map_.end(), it);
      EXPECT_EQ(it->second, encrypting_key);
      EXPECT_EQ(group_key_->private_key(), payload);

      id_to_encrypted_group_private_key_map_[id] =
          base::Contains(device_ids_to_fail, id)
              ? std::nullopt
              : std::make_optional<std::string>(
                    MakeFakeEncryptedString(payload, encrypting_key));
    }

    encryptor()->FinishAttempt(FakeCryptAuthEciesEncryptor::Action::kEncryption,
                               id_to_encrypted_group_private_key_map_);
  }

  // Ensures that the captured ShareGroupPrivateKeyRequest matches the output of
  // the encryptor, |id_to_encrypted_group_private_key_map_|. It must contain
  // exactly one entry per ID in |expected_device_ids|.
  void VerifyShareGroupPrivateKeyRequest(
      const base::flat_set<std::string>& expected_device_ids) {
    ASSERT_TRUE(share_group_private_key_request_);
    ASSERT_TRUE(group_key_);
    EXPECT_TRUE(share_group_private_key_success_callback_);
    EXPECT_TRUE(share_group_private_key_failure_callback_);

    EXPECT_EQ(GetRequestContext().SerializeAsString(),
              share_group_private_key_request_->context().SerializeAsString());
    EXPECT_EQ(
        static_cast<int>(expected_device_ids.size()),
        share_group_private_key_request_->encrypted_group_private_keys_size());

    for (const cryptauthv2::EncryptedGroupPrivateKey& request_encrypted_key :
         share_group_private_key_request_->encrypted_group_private_keys()) {
      const std::string& recipient_id =
          request_encrypted_key.recipient_device_id();
      EXPECT_TRUE(base::Contains(expected_device_ids, recipient_id));

      const auto expected_it =
          id_to_encrypted_group_private_key_map_.find(recipient_id);
      ASSERT_NE(id_to_encrypted_group_private_key_map_.end(), expected_it);
      ASSERT_TRUE(expected_it->second);

      EXPECT_EQ(GetRequestContext().device_id(),
                request_encrypted_key.sender_device_id());
      EXPECT_EQ(kGroupPublicKeyHash,
                request_encrypted_key.group_public_key_hash());
      EXPECT_EQ(*expected_it->second,
                request_encrypted_key.encrypted_private_key());

      // Verify that the encrypted group private key can be decrypted with the
      // recipient device's private key.
      const std::string recipient_device_better_together_private_key =
          GetPrivateKeyFromPublicKeyForTest(
              GetTestDeviceWithId(recipient_id)
                  .device_better_together_public_key);
      EXPECT_EQ(group_key_->private_key(),
                DecryptFakeEncryptedString(
                    request_encrypted_key.encrypted_private_key(),
                    recipient_device_better_together_private_key));
    }
  }

  // Runs the captured success callback with an empty response.
  void SendShareGroupPrivateKeyResponse() {
    ASSERT_TRUE(share_group_private_key_success_callback_);
    std::move(share_group_private_key_success_callback_)
        .Run(cryptauthv2::ShareGroupPrivateKeyResponse());
  }

  // Runs the captured error callback with |network_request_error|.
  void FailShareGroupPrivateKeyRequest(
      const NetworkRequestError& network_request_error) {
    ASSERT_TRUE(share_group_private_key_failure_callback_);
    std::move(share_group_private_key_failure_callback_)
        .Run(network_request_error);
  }

  // Checks that the sharer has reported |expected_result_code|.
  void VerifyShareGroupPrivateKeyResult(
      CryptAuthDeviceSyncResult::ResultCode expected_result_code) {
    ASSERT_TRUE(device_sync_result_code_);
    EXPECT_EQ(expected_result_code, *device_sync_result_code_);
  }

  base::MockOneShotTimer* timer() { return timer_; }

 private:
  FakeCryptAuthEciesEncryptor* encryptor() {
    return fake_cryptauth_ecies_encryptor_factory_->instance();
  }

  // Default action for MockCryptAuthClient::ShareGroupPrivateKey(). Stores the
  // request and both callbacks so tests can inspect and complete the call.
  void OnShareGroupPrivateKey(
      const cryptauthv2::ShareGroupPrivateKeyRequest& request,
      CryptAuthClient::ShareGroupPrivateKeyCallback callback,
      CryptAuthClient::ErrorCallback error_callback) {
    EXPECT_FALSE(share_group_private_key_request_);
    EXPECT_FALSE(share_group_private_key_success_callback_);
    EXPECT_FALSE(share_group_private_key_failure_callback_);

    share_group_private_key_request_ = request;
    share_group_private_key_success_callback_ = std::move(callback);
    share_group_private_key_failure_callback_ = std::move(error_callback);
  }

  void OnShareGroupPrivateKeyComplete(
      CryptAuthDeviceSyncResult::ResultCode device_sync_result_code) {
    device_sync_result_code_ = device_sync_result_code;
  }

  std::unique_ptr<CryptAuthKey> group_key_;
  CryptAuthGroupPrivateKeySharer::IdToEncryptingKeyMap
      id_to_encrypting_key_map_;

  std::optional<cryptauthv2::ShareGroupPrivateKeyRequest>
      share_group_private_key_request_;
  CryptAuthClient::ShareGroupPrivateKeyCallback
      share_group_private_key_success_callback_;
  CryptAuthClient::ErrorCallback share_group_private_key_failure_callback_;

  CryptAuthEciesEncryptor::IdToOutputMap id_to_encrypted_group_private_key_map_;

  std::optional<CryptAuthDeviceSyncResult::ResultCode> device_sync_result_code_;

  std::unique_ptr<MockCryptAuthClientFactory> client_factory_;
  std::unique_ptr<FakeCryptAuthEciesEncryptorFactory>
      fake_cryptauth_ecies_encryptor_factory_;

  std::unique_ptr<CryptAuthGroupPrivateKeySharer> sharer_;

  // Points at the timer owned by |sharer_|. It is declared after |sharer_| so
  // it is destroyed first and never outlives the timer it points to.
  raw_ptr<base::MockOneShotTimer> timer_ = nullptr;
};

TEST_F(DeviceSyncCryptAuthGroupPrivateKeySharerImplTest, Success) {
  // Every test device is a recipient and no encryption fails. The request
  // should therefore cover all devices and the flow should succeed.
  const base::flat_set<std::string> all_device_ids = GetAllTestDeviceIds();
  ShareGroupPrivateKey(GetGroupKey(),
                       IdToEncryptingKeyMapFromDeviceIds(all_device_ids));
  RunGroupPrivateKeyEncryptor(all_device_ids, /*device_ids_to_fail=*/{});
  VerifyShareGroupPrivateKeyRequest(all_device_ids);

  SendShareGroupPrivateKeyResponse();
  VerifyShareGroupPrivateKeyResult(
      CryptAuthDeviceSyncResult::ResultCode::kSuccess);
}

TEST_F(DeviceSyncCryptAuthGroupPrivateKeySharerImplTest,
       FinishedWithNonFatalErrors_SingleEncryptionFails) {
  const base::flat_set<std::string> all_device_ids = GetAllTestDeviceIds();
  ShareGroupPrivateKey(GetGroupKey(),
                       IdToEncryptingKeyMapFromDeviceIds(all_device_ids));

  // Encryption fails for exactly one remote device. That device is expected to
  // be left out of the request, and the others are still shared with.
  const std::string failed_device_id =
      GetRemoteDeviceNeedsGroupPrivateKeyForTest().instance_id();
  RunGroupPrivateKeyEncryptor(all_device_ids,
                              /*device_ids_to_fail=*/{failed_device_id});

  base::flat_set<std::string> shared_device_ids = all_device_ids;
  shared_device_ids.erase(failed_device_id);
  VerifyShareGroupPrivateKeyRequest(shared_device_ids);

  SendShareGroupPrivateKeyResponse();
  VerifyShareGroupPrivateKeyResult(
      CryptAuthDeviceSyncResult::ResultCode::kFinishedWithNonFatalErrors);
}

TEST_F(DeviceSyncCryptAuthGroupPrivateKeySharerImplTest,
       Failure_AllEncryptionsFails) {
  const base::flat_set<std::string> all_device_ids = GetAllTestDeviceIds();
  ShareGroupPrivateKey(GetGroupKey(),
                       IdToEncryptingKeyMapFromDeviceIds(all_device_ids));

  // Encryption fails for every device. The flow is expected to end with an
  // encryption error.
  RunGroupPrivateKeyEncryptor(all_device_ids,
                              /*device_ids_to_fail=*/all_device_ids);
  VerifyShareGroupPrivateKeyResult(
      CryptAuthDeviceSyncResult::ResultCode::kErrorEncryptingGroupPrivateKey);
}

TEST_F(DeviceSyncCryptAuthGroupPrivateKeySharerImplTest,
       FinishedWithNonFatalErrors_SingleEncryptionKeyEmpty) {
  const base::flat_set<std::string> all_device_ids = GetAllTestDeviceIds();

  // Blank out the encrypting key of one remote device. That device is expected
  // to be skipped: it never reaches the encryptor and never appears in the
  // request.
  const std::string empty_key_device_id =
      GetRemoteDeviceNeedsGroupPrivateKeyForTest().instance_id();
  CryptAuthGroupPrivateKeySharer::IdToEncryptingKeyMap id_to_key =
      IdToEncryptingKeyMapFromDeviceIds(all_device_ids);
  id_to_key[empty_key_device_id].clear();
  ShareGroupPrivateKey(GetGroupKey(), id_to_key);

  base::flat_set<std::string> shared_device_ids = all_device_ids;
  shared_device_ids.erase(empty_key_device_id);
  RunGroupPrivateKeyEncryptor(shared_device_ids, /*device_ids_to_fail=*/{});
  VerifyShareGroupPrivateKeyRequest(shared_device_ids);

  SendShareGroupPrivateKeyResponse();
  VerifyShareGroupPrivateKeyResult(
      CryptAuthDeviceSyncResult::ResultCode::kFinishedWithNonFatalErrors);
}

TEST_F(DeviceSyncCryptAuthGroupPrivateKeySharerImplTest,
       Failure_AllEncryptionKeysEmpty) {
  // Give every recipient an empty encrypting key. The flow is expected to fail
  // with an encryption error without the test ever finishing an encryption.
  CryptAuthGroupPrivateKeySharer::IdToEncryptingKeyMap id_to_key =
      IdToEncryptingKeyMapFromDeviceIds(GetAllTestDeviceIds());
  for (auto& entry : id_to_key) {
    entry.second.clear();
  }

  ShareGroupPrivateKey(GetGroupKey(), id_to_key);

  VerifyShareGroupPrivateKeyResult(
      CryptAuthDeviceSyncResult::ResultCode::kErrorEncryptingGroupPrivateKey);
}

TEST_F(DeviceSyncCryptAuthGroupPrivateKeySharerImplTest,
       Failure_Timeout_Encryption) {
  ShareGroupPrivateKey(
      GetGroupKey(), IdToEncryptingKeyMapFromDeviceIds(GetAllTestDeviceIds()));

  // The encryptor is never finished; firing the timer simulates encryption
  // taking too long.
  timer()->Fire();

  VerifyShareGroupPrivateKeyResult(
      CryptAuthDeviceSyncResult::ResultCode::
          kErrorTimeoutWaitingForGroupPrivateKeyEncryption);
}

TEST_F(DeviceSyncCryptAuthGroupPrivateKeySharerImplTest,
       Failure_Timeout_ShareGroupPrivateKeyRequest) {
  const base::flat_set<std::string> all_device_ids = GetAllTestDeviceIds();
  ShareGroupPrivateKey(GetGroupKey(),
                       IdToEncryptingKeyMapFromDeviceIds(all_device_ids));
  RunGroupPrivateKeyEncryptor(all_device_ids, /*device_ids_to_fail=*/{});
  VerifyShareGroupPrivateKeyRequest(all_device_ids);

  // The request was sent but is never answered; firing the timer simulates
  // the server response taking too long.
  timer()->Fire();

  VerifyShareGroupPrivateKeyResult(
      CryptAuthDeviceSyncResult::ResultCode::
          kErrorTimeoutWaitingForShareGroupPrivateKeyResponse);
}

TEST_F(DeviceSyncCryptAuthGroupPrivateKeySharerImplTest,
       Failure_ApiCall_ShareGroupPrivateKey) {
  const base::flat_set<std::string> all_device_ids = GetAllTestDeviceIds();
  ShareGroupPrivateKey(GetGroupKey(),
                       IdToEncryptingKeyMapFromDeviceIds(all_device_ids));
  RunGroupPrivateKeyEncryptor(all_device_ids, /*device_ids_to_fail=*/{});
  VerifyShareGroupPrivateKeyRequest(all_device_ids);

  // The API call fails with kBadRequest. That error is expected to surface as
  // the matching ShareGroupPrivateKey result code.
  FailShareGroupPrivateKeyRequest(NetworkRequestError::kBadRequest);

  VerifyShareGroupPrivateKeyResult(
      CryptAuthDeviceSyncResult::ResultCode::
          kErrorShareGroupPrivateKeyApiCallBadRequest);
}

}  // namespace device_sync

}  // namespace ash