// chromium/chrome/browser/extensions/api/enterprise_platform_keys/enterprise_platform_keys_api_unittest.cc

// Copyright 2013 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/extensions/api/enterprise_platform_keys/enterprise_platform_keys_api.h"

#include <optional>
#include <string_view>
#include <utility>

#include "base/containers/span.h"
#include "base/memory/ptr_util.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/ref_counted.h"
#include "base/memory/scoped_refptr.h"
#include "base/values.h"
#include "chrome/browser/ash/attestation/mock_tpm_challenge_key.h"
#include "chrome/browser/ash/platform_keys/key_permissions/fake_user_private_token_kpm_service.h"
#include "chrome/browser/ash/platform_keys/key_permissions/mock_key_permissions_manager.h"
#include "chrome/browser/ash/platform_keys/key_permissions/user_private_token_kpm_service_factory.h"
#include "chrome/browser/extensions/api/enterprise_platform_keys_private/enterprise_platform_keys_private_api.h"
#include "chrome/browser/signin/identity_manager_factory.h"
#include "chrome/common/extensions/api/enterprise_platform_keys.h"
#include "chrome/common/pref_names.h"
#include "chrome/test/base/browser_with_test_window_test.h"
#include "chrome/test/base/testing_profile_manager.h"
#include "chromeos/ash/components/dbus/attestation/keystore.pb.h"
#include "chromeos/ash/components/dbus/constants/attestation_constants.h"
#include "components/signin/public/identity_manager/identity_manager.h"
#include "components/signin/public/identity_manager/identity_test_utils.h"
#include "extensions/browser/api_test_utils.h"
#include "extensions/browser/extension_function_dispatcher.h"
#include "extensions/common/extension_builder.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"

using testing::_;
using testing::Invoke;
using testing::NiceMock;

namespace extensions {
namespace {

// E-mail address of the test user. It is used both as the default profile name
// and as the primary account made available to the IdentityManager, so it must
// be a syntactically valid address.
// NOTE(review): the domain is chosen to match the "google.com" domain that the
// fixture enrolls the device into; the local part is arbitrary - confirm.
constexpr char kUserEmail[] = "test@google.com";

// Replacement action for MockTpmChallengeKey::BuildResponse. It fails the
// running test if key registration was requested and then completes the
// request by handing a canned "response" challenge response to |callback|.
// All other arguments are ignored.
void FakeRunCheckNotRegister(::attestation::VerifiedAccessFlow flow_type,
                             Profile* profile,
                             ash::attestation::TpmChallengeKeyCallback callback,
                             const std::string& challenge,
                             bool register_key,
                             ::attestation::KeyType key_crypto_type,
                             const std::string& key_name_for_spkac,
                             const std::optional<std::string>& signals) {
  EXPECT_FALSE(register_key);

  auto fake_result =
      ash::attestation::TpmChallengeKeyResult::MakeChallengeResponse(
          "response");
  std::move(callback).Run(std::move(fake_result));
}

// Shared fixture for the challenge*Key extension function tests. It provides a
// test extension, a cloud-managed install-attributes stub, a logged-in
// affiliated user, helpers that install a mock TpmChallengeKey, and helpers
// that run an extension function with an explicit argument list.
class EPKChallengeKeyTestBase : public BrowserWithTestWindowTest {
 protected:
  EPKChallengeKeyTestBase() : extension_(ExtensionBuilder("Test").Build()) {
    stub_install_attributes_.SetCloudManaged("google.com", "device_id");
  }

  void SetUp() override {
    BrowserWithTestWindowTest::SetUp();
    prefs_ = browser()->profile()->GetPrefs();
    SetAuthenticatedUser();

    // UserPrivateTokenKeyPermissionsManagerService and the underlying
    // KeyPermissionsManager are not actually used by *ChallengeKey* classes,
    // but they are created as a part of KeystoreService, so just fake them out.
    // It is ok to pass an unretained pointer because the factory should only be
    // used during the tests' lifetime.
    ash::platform_keys::UserPrivateTokenKeyPermissionsManagerServiceFactory::
        GetInstance()
            ->SetTestingFactory(
                browser()->profile(),
                base::BindRepeating(&EPKChallengeKeyTestBase::
                                        CreateKeyPermissionsManagerService,
                                    base::Unretained(this)));
  }

  // Installs a NiceMock TpmChallengeKey with its fake behaviour enabled and
  // remembers a raw pointer to it in |mock_tpm_challenge_key_| so tests can set
  // expectations on it.
  void SetMockTpmChallenger() {
    auto mock_tpm_challenge_key =
        std::make_unique<NiceMock<ash::attestation::MockTpmChallengeKey>>();
    // Will be used with EXPECT_CALL.
    mock_tpm_challenge_key_ = mock_tpm_challenge_key.get();
    mock_tpm_challenge_key->EnableFake();
    // Ownership is transferred to the factory.
    ash::attestation::TpmChallengeKeyFactory::SetForTesting(
        std::move(mock_tpm_challenge_key));
  }

  // Installs a NiceMock TpmChallengeKey whose fake reports
  // kChallengeBadBase64Error. Unlike SetMockTpmChallenger(), this does not
  // update |mock_tpm_challenge_key_|.
  void SetMockTpmChallengerBadBase64Error() {
    auto mock_tpm_challenge_key =
        std::make_unique<NiceMock<ash::attestation::MockTpmChallengeKey>>();
    // Error text is "Challenge is not base64 encoded."
    mock_tpm_challenge_key->EnableFakeError(
        ash::attestation::TpmChallengeKeyResultCode::kChallengeBadBase64Error);
    // Ownership is transferred to the factory.
    ash::attestation::TpmChallengeKeyFactory::SetForTesting(
        std::move(mock_tpm_challenge_key));
  }

  // This will be called by BrowserWithTestWindowTest::SetUp();
  std::string GetDefaultProfileName() override { return kUserEmail; }

  // Adds |email| to the user manager as an affiliated user and marks that user
  // as logged in (not a browser restart, not a child account).
  void LogIn(const std::string& email) override {
    const AccountId account_id = AccountId::FromUserEmail(email);
    user_manager()->AddUserWithAffiliation(account_id,
                                           /*is_affiliated=*/true);
    user_manager()->UserLoggedIn(
        account_id,
        user_manager::FakeUserManager::GetFakeUsernameHash(account_id),
        /*browser_restart=*/false,
        /*is_child=*/false);
  }

  // Testing factory registered in SetUp(): builds a fake
  // UserPrivateTokenKeyPermissionsManagerService backed by
  // |key_permissions_manager_|. |context| is unused.
  std::unique_ptr<KeyedService> CreateKeyPermissionsManagerService(
      content::BrowserContext* context) {
    return std::make_unique<
        ash::platform_keys::FakeUserPrivateTokenKeyPermissionsManagerService>(
        &key_permissions_manager_);
  }

  // Derived classes can override this method to set the required authenticated
  // user in the IdentityManager class.
  virtual void SetAuthenticatedUser() {
    signin::MakePrimaryAccountAvailable(
        IdentityManagerFactory::GetForProfile(browser()->profile()), kUserEmail,
        signin::ConsentLevel::kSync);
  }

  // Like api_test_utils::RunFunctionAndReturnError but with an
  // explicit list of args. Expects the function to fail and returns its error
  // string.
  std::string RunFunctionAndReturnError(
      ExtensionFunction* function,
      base::Value::List args,
      content::BrowserContext* browser_context) {
    auto dispatcher =
        std::make_unique<ExtensionFunctionDispatcher>(browser_context);
    api_test_utils::RunFunction(
        function, std::move(args), std::move(dispatcher),
        extensions::api_test_utils::FunctionMode::kNone);
    // The response type is dereferenced below. If the function never
    // responded, record a test failure rather than dereferencing an empty
    // value, which would take down the whole test binary.
    if (!function->response_type()) {
      ADD_FAILURE() << "Extension function did not produce a response.";
      return function->GetError();
    }
    EXPECT_EQ(ExtensionFunction::FAILED, *function->response_type());
    return function->GetError();
  }

  // Like api_test_utils::RunFunctionAndReturnSingleResult but
  // with an explicit list of args. Expects no error; returns a clone of the
  // first result, or a default-constructed base::Value when there is none.
  base::Value RunFunctionAndReturnSingleResult(
      ExtensionFunction* function,
      base::Value::List args,
      content::BrowserContext* browser_context) {
    scoped_refptr<ExtensionFunction> function_owner(function);
    // Without a callback the function will not generate a result.
    function->set_has_callback(true);
    auto dispatcher =
        std::make_unique<ExtensionFunctionDispatcher>(browser_context);
    api_test_utils::RunFunction(
        function, std::move(args), std::move(dispatcher),
        extensions::api_test_utils::FunctionMode::kNone);
    EXPECT_TRUE(function->GetError().empty())
        << "Unexpected error: " << function->GetError();
    if (function->GetResultListForTest() &&
        !function->GetResultListForTest()->empty()) {
      return (*function->GetResultListForTest())[0].Clone();
    }
    return base::Value();
  }

  // Extension on whose behalf the functions under test are run.
  scoped_refptr<const extensions::Extension> extension_;
  ash::StubInstallAttributes stub_install_attributes_;
  // Backs the fake key permissions manager service created for the profile.
  ash::platform_keys::MockKeyPermissionsManager key_permissions_manager_;
  // Pref service of the test profile; set in SetUp().
  raw_ptr<PrefService, DanglingUntriaged> prefs_ = nullptr;
  // Non-owning pointer to the mock installed by SetMockTpmChallenger(); the
  // TpmChallengeKeyFactory owns the object.
  raw_ptr<ash::attestation::MockTpmChallengeKey, DanglingUntriaged>
      mock_tpm_challenge_key_ = nullptr;
};

// Fixture for tests of EnterprisePlatformKeysChallengeMachineKeyFunction. Owns
// the function under test and builds its argument lists.
class EPKChallengeMachineKeyTest : public EPKChallengeKeyTestBase {
 protected:
  EPKChallengeMachineKeyTest() { func_->set_extension(extension_.get()); }

  // Arguments holding only the challenge (no register-key flag at all).
  base::Value::List CreateArgs() { return CreateArgsInternal(std::nullopt); }

  // Arguments holding the challenge followed by an explicit `false` flag.
  base::Value::List CreateArgsNoRegister() {
    return CreateArgsInternal(base::Value(false));
  }

  // Arguments holding the challenge followed by an explicit `true` flag.
  base::Value::List CreateArgsRegister() {
    return CreateArgsInternal(base::Value(true));
  }

  // Builds the argument list: the bytes of "challenge" as a blob, followed by
  // |register_key| when one is supplied.
  base::Value::List CreateArgsInternal(
      std::optional<base::Value> register_key) {
    static constexpr std::string_view kChallenge = "challenge";

    base::Value::List result;
    result.Append(base::Value(base::as_bytes(base::make_span(kChallenge))));
    if (!register_key.has_value()) {
      return result;
    }
    result.Append(std::move(register_key).value());
    return result;
  }

  scoped_refptr<EnterprisePlatformKeysChallengeMachineKeyFunction> func_ =
      base::MakeRefCounted<EnterprisePlatformKeysChallengeMachineKeyFunction>();
};

// With an empty attestation allowlist the call must fail with
// kExtensionNotAllowedErrorMsg.
TEST_F(EPKChallengeMachineKeyTest, ExtensionNotAllowed) {
  // Install an explicitly empty allowlist.
  prefs_->SetList(prefs::kAttestationExtensionAllowlist, base::Value::List());

  const std::string error =
      RunFunctionAndReturnError(func_.get(), CreateArgs(), profile());
  EXPECT_EQ(
      ash::attestation::TpmChallengeKeyResult::kExtensionNotAllowedErrorMsg,
      error);
}

// An allowlisted extension gets the fake challenge response back as a blob.
TEST_F(EPKChallengeMachineKeyTest, Success) {
  SetMockTpmChallenger();

  base::Value::List allowed_extensions;
  allowed_extensions.Append(extension_->id());
  prefs_->SetList(prefs::kAttestationExtensionAllowlist,
                  std::move(allowed_extensions));

  const base::Value result =
      RunFunctionAndReturnSingleResult(func_.get(), CreateArgs(), profile());

  ASSERT_TRUE(result.is_blob());
  const auto& blob = result.GetBlob();
  EXPECT_EQ("response", std::string(blob.begin(), blob.end()));
}

// When the challenger reports a bad-base64 challenge, the function must fail
// with kChallengeBadBase64ErrorMsg.
TEST_F(EPKChallengeMachineKeyTest, BadChallengeThenErrorMessageReturned) {
  SetMockTpmChallengerBadBase64Error();

  base::Value::List allowlist;
  allowlist.Append(extension_->id());
  prefs_->SetList(prefs::kAttestationExtensionAllowlist, std::move(allowlist));

  // Keep the error as a plain string (as the ExtensionNotAllowed test does)
  // instead of wrapping it in a base::Value just to compare it.
  const std::string error =
      RunFunctionAndReturnError(func_.get(), CreateArgs(), profile());

  EXPECT_EQ(
      ash::attestation::TpmChallengeKeyResult::kChallengeBadBase64ErrorMsg,
      error);
}

// When the arguments carry no register-key flag, BuildResponse must be asked
// not to register the key (checked inside FakeRunCheckNotRegister).
TEST_F(EPKChallengeMachineKeyTest, KeyNotRegisteredByDefault) {
  SetMockTpmChallenger();

  base::Value::List allowed_extensions;
  allowed_extensions.Append(extension_->id());
  prefs_->SetList(prefs::kAttestationExtensionAllowlist,
                  std::move(allowed_extensions));

  EXPECT_CALL(*mock_tpm_challenge_key_, BuildResponse)
      .WillOnce(Invoke(FakeRunCheckNotRegister));

  const bool run_result = api_test_utils::RunFunction(
      func_.get(), CreateArgs(),
      std::make_unique<ExtensionFunctionDispatcher>(profile()),
      extensions::api_test_utils::FunctionMode::kNone);
  EXPECT_TRUE(run_result);
}

// Fixture for tests of EnterprisePlatformKeysChallengeUserKeyFunction. Owns
// the function under test and builds its argument lists. SetUp() is inherited
// unchanged from EPKChallengeKeyTestBase.
class EPKChallengeUserKeyTest : public EPKChallengeKeyTestBase {
 protected:
  EPKChallengeUserKeyTest()
      : func_(base::MakeRefCounted<
              EnterprisePlatformKeysChallengeUserKeyFunction>()) {
    func_->set_extension(extension_.get());
  }

  // Arguments holding the challenge and a register-key flag of `true`.
  base::Value::List CreateArgs() { return CreateArgsInternal(true); }

  // Arguments holding the challenge and a register-key flag of `false`.
  base::Value::List CreateArgsNoRegister() { return CreateArgsInternal(false); }

  // Builds the argument list: the bytes of "challenge" as a blob, followed by
  // |register_key|.
  base::Value::List CreateArgsInternal(bool register_key) {
    static constexpr std::string_view kData = "challenge";
    base::Value::List args;
    args.Append(base::Value(base::as_bytes(base::make_span(kData))));
    args.Append(register_key);
    return args;
  }

  // NOTE(review): not referenced by any test visible in this file; kept so the
  // fixture's members are unchanged - confirm whether it can be dropped.
  EPKPChallengeKey impl_;
  scoped_refptr<EnterprisePlatformKeysChallengeUserKeyFunction> func_;
};

// An allowlisted extension gets the fake challenge response back as a blob.
TEST_F(EPKChallengeUserKeyTest, Success) {
  SetMockTpmChallenger();

  base::Value::List allowed_extensions;
  allowed_extensions.Append(extension_->id());
  prefs_->SetList(prefs::kAttestationExtensionAllowlist,
                  std::move(allowed_extensions));

  const base::Value result =
      RunFunctionAndReturnSingleResult(func_.get(), CreateArgs(), profile());

  ASSERT_TRUE(result.is_blob());
  const auto& blob = result.GetBlob();
  EXPECT_EQ("response", std::string(blob.begin(), blob.end()));
}

// When the challenger reports a bad-base64 challenge, the function must fail
// with kChallengeBadBase64ErrorMsg.
TEST_F(EPKChallengeUserKeyTest, BadChallengeThenErrorMessageReturned) {
  SetMockTpmChallengerBadBase64Error();

  base::Value::List allowlist;
  allowlist.Append(extension_->id());
  prefs_->SetList(prefs::kAttestationExtensionAllowlist, std::move(allowlist));

  // Keep the error as a plain string (as the ExtensionNotAllowed test does)
  // instead of wrapping it in a base::Value just to compare it.
  const std::string error =
      RunFunctionAndReturnError(func_.get(), CreateArgs(), profile());

  EXPECT_EQ(
      ash::attestation::TpmChallengeKeyResult::kChallengeBadBase64ErrorMsg,
      error);
}

// With an empty attestation allowlist the call must fail with
// kExtensionNotAllowedErrorMsg.
TEST_F(EPKChallengeUserKeyTest, ExtensionNotAllowedThenErrorMessageReturned) {
  // Install an explicitly empty allowlist.
  prefs_->SetList(prefs::kAttestationExtensionAllowlist, base::Value::List());

  const std::string error =
      RunFunctionAndReturnError(func_.get(), CreateArgs(), profile());
  EXPECT_EQ(
      ash::attestation::TpmChallengeKeyResult::kExtensionNotAllowedErrorMsg,
      error);
}

// Parameter type for EPKChallengeKeyTest: the key scope to request, plus an
// optional key algorithm. The tests below only build RegisterKeyOptions when
// the algorithm is present.
using EPKChallengeKeyParams =
    std::tuple<api::enterprise_platform_keys::Scope,
               std::optional<api::enterprise_platform_keys::Algorithm>>;

// Parameterized fixture for tests of
// EnterprisePlatformKeysChallengeKeyFunction. Parameters are described by
// EPKChallengeKeyParams.
class EPKChallengeKeyTest
    : public EPKChallengeKeyTestBase,
      public testing::WithParamInterface<EPKChallengeKeyParams> {
 protected:
  EPKChallengeKeyTest()
      : func_(base::MakeRefCounted<
              EnterprisePlatformKeysChallengeKeyFunction>()) {
    func_->set_extension(extension_.get());
  }

  // Puts the test extension on the attestation allowlist.
  void AllowlistExtension() {
    base::Value::List allowlist;
    allowlist.Append(extension_->id());
    prefs_->SetList(prefs::kAttestationExtensionAllowlist,
                    std::move(allowlist));
  }

  // Builds the single-element argument list for challengeKey: a
  // ChallengeKeyOptions value holding the bytes of "challenge", |scope| and,
  // when supplied, |register_key|.
  base::Value::List CreateArgs(
      std::optional<api::enterprise_platform_keys::RegisterKeyOptions>
          register_key,
      api::enterprise_platform_keys::Scope scope) {
    // Go through a string_view, as the other fixtures in this file do.
    // Spanning the string literal directly would cover the whole char array,
    // terminating NUL included, and produce a 10-byte challenge.
    static constexpr std::string_view kData = "challenge";

    api::enterprise_platform_keys::ChallengeKeyOptions options;
    const auto challenge = base::as_bytes(base::make_span(kData));
    options.challenge = std::vector(challenge.begin(), challenge.end());
    if (register_key.has_value()) {
      options.register_key.emplace(std::move(register_key.value()));
    }
    options.scope = scope;

    base::Value::List args;
    args.Append(options.ToValue());
    return args;
  }

  scoped_refptr<EnterprisePlatformKeysChallengeKeyFunction> func_;
  // NOTE(review): not referenced by any test visible in this file.
  base::Value::List args_;
};

// Verifies that challengeKey forwards the scope (as a Verified Access flow
// type), the register-key decision and the key algorithm to
// TpmChallengeKey::BuildResponse, and returns the challenge response blob.
TEST_P(EPKChallengeKeyTest, Success) {
  namespace epk = api::enterprise_platform_keys;

  SetMockTpmChallenger();
  AllowlistExtension();

  const epk::Scope scope = std::get<0>(GetParam());
  const std::optional<epk::Algorithm> algorithm = std::get<1>(GetParam());

  // Only the user scope maps to the user flow; the remaining scope values
  // (kNone, kMachine) are expected to use the machine flow.
  const ::attestation::VerifiedAccessFlow expected_va_flow_type =
      scope == epk::Scope::kUser ? ::attestation::ENTERPRISE_USER
                                 : ::attestation::ENTERPRISE_MACHINE;

  // A key is registered only when an algorithm is supplied. ECDSA is expected
  // to yield an ECC key; every other case (kNone, kRsa, or no algorithm at
  // all) expects RSA.
  const bool expect_register = algorithm.has_value();
  const auto expect_crypto_key_type = algorithm == epk::Algorithm::kEcdsa
                                          ? ::attestation::KEY_TYPE_ECC
                                          : ::attestation::KEY_TYPE_RSA;

  std::optional<epk::RegisterKeyOptions> register_key;
  if (expect_register) {
    register_key.emplace();
    register_key->algorithm = *algorithm;
  }

  EXPECT_CALL(*mock_tpm_challenge_key_,
              BuildResponse(expected_va_flow_type, _, _, _, expect_register,
                            expect_crypto_key_type, _, _));

  const base::Value result = RunFunctionAndReturnSingleResult(
      func_.get(), CreateArgs(std::move(register_key), scope), profile());

  ASSERT_TRUE(result.is_blob());
  const auto& blob = result.GetBlob();
  EXPECT_EQ("response", std::string(blob.begin(), blob.end()));
}

// Verifies that challengeKey is rejected with kExtensionNotAllowedErrorMsg
// when the calling extension is not on the allowlist, for every parameter
// combination.
TEST_P(EPKChallengeKeyTest, ExtensionNotAllowed) {
  namespace epk = api::enterprise_platform_keys;

  // Install an explicitly empty allowlist.
  prefs_->SetList(prefs::kAttestationExtensionAllowlist, base::Value::List());

  const epk::Scope scope = std::get<0>(GetParam());
  const std::optional<epk::Algorithm> algorithm = std::get<1>(GetParam());

  std::optional<epk::RegisterKeyOptions> register_key;
  if (algorithm.has_value()) {
    register_key.emplace();
    register_key->algorithm = *algorithm;
  }

  const std::string error = RunFunctionAndReturnError(
      func_.get(), CreateArgs(std::move(register_key), scope), profile());
  EXPECT_EQ(
      ash::attestation::TpmChallengeKeyResult::kExtensionNotAllowedErrorMsg,
      error);
}

// Runs EPKChallengeKeyTest for every combination of {machine, user} scope and
// {RSA, ECDSA, no algorithm}.
INSTANTIATE_TEST_SUITE_P(
    EPKChallengeKeyTests,
    EPKChallengeKeyTest,
    testing::Combine(
        testing::Values(api::enterprise_platform_keys::Scope::kMachine,
                        api::enterprise_platform_keys::Scope::kUser),
        testing::Values(api::enterprise_platform_keys::Algorithm::kRsa,
                        api::enterprise_platform_keys::Algorithm::kEcdsa,
                        std::nullopt)),
    // Names each instantiation "<scope><algorithm>", using "Unregistered" in
    // place of the algorithm when none is supplied.
    [](const testing::TestParamInfo<EPKChallengeKeyParams>& info) {
      // Element 0 of the parameter tuple is the scope, element 1 the optional
      // algorithm.
      const std::string scope_name =
          api::enterprise_platform_keys::ToString(std::get<0>(info.param));
      const auto algorithm_opt = std::get<1>(info.param);

      const std::string algorithm_name =
          algorithm_opt.has_value()
              ? api::enterprise_platform_keys::ToString(algorithm_opt.value())
              : "Unregistered";
      return scope_name + algorithm_name;
    });

}  // namespace
}  // namespace extensions