chromium/chrome/browser/ash/cert_provisioning/cert_provisioning_worker_static_unittest.cc

// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifdef UNSAFE_BUFFERS_BUILD
// TODO(crbug.com/40285824): Remove this and convert code to safer constructs.
#pragma allow_unsafe_buffers
#endif

#include "chrome/browser/ash/cert_provisioning/cert_provisioning_worker_static.h"

#include <stdint.h>

#include <memory>
#include <string>
#include <vector>

#include "base/base64.h"
#include "base/functional/callback.h"
#include "base/json/json_string_value_serializer.h"
#include "base/json/json_writer.h"
#include "base/memory/raw_ptr.h"
#include "base/strings/stringprintf.h"
#include "base/test/gmock_callback_support.h"
#include "base/test/metrics/histogram_tester.h"
#include "base/test/test_future.h"
#include "base/test/values_test_util.h"
#include "base/time/time.h"
#include "chrome/browser/ash/attestation/mock_tpm_challenge_key_subtle.h"
#include "chrome/browser/ash/attestation/tpm_challenge_key_subtle.h"
#include "chrome/browser/ash/cert_provisioning/cert_provisioning_common.h"
#include "chrome/browser/ash/cert_provisioning/cert_provisioning_metrics.h"
#include "chrome/browser/ash/cert_provisioning/cert_provisioning_test_helpers.h"
#include "chrome/browser/ash/cert_provisioning/mock_cert_provisioning_client.h"
#include "chrome/browser/ash/cert_provisioning/mock_cert_provisioning_invalidator.h"
#include "chrome/browser/ash/platform_keys/key_permissions/fake_user_private_token_kpm_service.h"
#include "chrome/browser/ash/platform_keys/key_permissions/key_permissions_manager.h"
#include "chrome/browser/ash/platform_keys/key_permissions/key_permissions_manager_impl.h"
#include "chrome/browser/ash/platform_keys/key_permissions/mock_key_permissions_manager.h"
#include "chrome/browser/ash/platform_keys/key_permissions/user_private_token_kpm_service_factory.h"
#include "chrome/browser/ash/platform_keys/mock_platform_keys_service.h"
#include "chrome/browser/ash/platform_keys/platform_keys_service.h"
#include "chrome/browser/ash/platform_keys/platform_keys_service_factory.h"
#include "chrome/browser/chromeos/platform_keys/platform_keys.h"
#include "chromeos/ash/components/dbus/attestation/fake_attestation_client.h"
#include "components/prefs/pref_change_registrar.h"
#include "components/prefs/pref_service.h"
#include "components/prefs/testing_pref_service.h"
#include "content/public/test/browser_task_environment.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace ash::cert_provisioning {
namespace {

namespace em = ::enterprise_management;

using attestation::MockTpmChallengeKeySubtle;
using ::base::test::IsJson;
using ::base::test::ParseJsonDict;
using ::base::test::RunOnceCallback;
using ::chromeos::platform_keys::HashAlgorithm;
using ::chromeos::platform_keys::KeyAttributeType;
using ::chromeos::platform_keys::Status;
using ::chromeos::platform_keys::TokenId;
using platform_keys::KeyUsage;
using ::testing::_;
using ::testing::AtLeast;
using ::testing::Eq;
using ::testing::Mock;
using ::testing::SaveArg;
using ::testing::StrictMock;

// PEM-encoded test certificate whose public key is kPublicKeyBase64 (see the
// extraction note on kPublicKeyBase64 below).
// Generated by chrome/test/data/policy/test_certs/create_test_certs.sh
constexpr char kFakeCertificate[] = R"(-----BEGIN CERTIFICATE-----
MIIDJzCCAg+gAwIBAgIBATANBgkqhkiG9w0BAQsFADAXMRUwEwYDVQQDDAxyb290
X2NhX2NlcnQwHhcNMjAwMjI1MTUyNTU2WhcNMzAwMjIyMTUyNTU2WjAUMRIwEAYD
VQQDDAkxMjcuMC4wLjEwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDW
druvpaJovmyWzIcjtsSk/lp319+zNPSYGLzJzTeEmnFoDf/b89ft6xR1NIahmvVd
UHGOMlzgDKnNkqWw+pgpn6U8dk+leWnwlUefzDz7OY8qXfX29Vh0m/kATQc64lnp
rX19fEi2DOgH6heCQDSaHI/KAnAXccwl8kdGuTEnvdzbdHqQq8pPGpEqzC/NOjk7
kDNkUt0J74ZVMm4+jhVOgZ35mFLtC+xjfycBgbnt8yfPOzmOMwXTjYDPNaIy32AZ
t66oIToteoW5Ilg+j5Mto3unBDHrw8rml3+W/nwHuOPEIgBqLQFfWtXpuX8CbcS6
SFNK4hxCJOvlzUbgTpsrAgMBAAGjgYAwfjAMBgNVHRMBAf8EAjAAMB0GA1UdDgQW
BBRDEl1/2pL5LtKnpIly+XCj3N6MwDAfBgNVHSMEGDAWgBQrwVEnUQZlX850A2N+
URfS8BxoyzAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDwYDVR0RBAgw
BocEfwAAATANBgkqhkiG9w0BAQsFAAOCAQEAXZd+Ul7GUFZPLSiTZ618hUI2UdO0
7rtPwBw3TephWuyEeHht+WhzA3sRL3nprEiJqIg5w/Tlfz4dsObpSU3vKmDhLzAx
HJrN5vKdbEj9wyuhYSRJwvJka1ZOgPzhQcDQOp1SqonNxLx/sSMDR2UIDMBGzrkQ
sDkn58N5eWm+hZADOAKROHR47j85VcsmYGK7z2x479YzsyWyOm0dbACXv7/HvFkz
56KvgxRaPZQzQUg5yuXa21IjQz07wyWSYnHpm2duAbYFl6CTR9Rlj5vpRkKsQP1W
mMhGDBfgEskdbM+0agsZrJupoQMBUbD5gflcJlW3kwlboi3dTtiGixfYWw==
-----END CERTIFICATE-----)";

// Base64 body of the public key of kFakeCertificate.
// Extracted from kFakeCertificate using the command:
// openssl x509 -pubkey -noout -in cert.pem
// and reformatted as a single line (PEM header/footer stripped).
constexpr char kPublicKeyBase64[] =
    "MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA1na7r6WiaL5slsyHI7bEpP5ad9ffsz"
    "T0mBi8yc03hJpxaA3/2/"
    "PX7esUdTSGoZr1XVBxjjJc4AypzZKlsPqYKZ+lPHZPpXlp8JVHn8w8+"
    "zmPKl319vVYdJv5AE0HOuJZ6a19fXxItgzoB+"
    "oXgkA0mhyPygJwF3HMJfJHRrkxJ73c23R6kKvKTxqRKswvzTo5O5AzZFLdCe+"
    "GVTJuPo4VToGd+ZhS7QvsY38nAYG57fMnzzs5jjMF042AzzWiMt9gGbeuqCE6LXqFuSJYPo+"
    "TLaN7pwQx68PK5pd/lv58B7jjxCIAai0BX1rV6bl/Am3EukhTSuIcQiTr5c1G4E6bKwIDAQAB";

// A certificate that doesn't match the public key kPublicKeyBase64; presumably
// used to exercise the worker's handling of a downloaded certificate whose key
// differs from the generated one.
// Taken from net/data/ssl/certificates/client_1.pem
constexpr char kFakeCertificatePubKeyMismatch[] = R"(-----BEGIN CERTIFICATE-----
MIIDEjCCAfqgAwIBAgICEAAwDQYJKoZIhvcNAQELBQAwDzENMAsGA1UEAwwEQiBD
QTAeFw0yMjEwMTkxNjU4NTVaFw0zMjEwMTYxNjU4NTVaMBgxFjAUBgNVBAMMDUNs
aWVudCBDZXJ0IEEwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDa+Dq7
TTFSw1AxRkaftrCM8tuPbYH7NTxLdHil0F2y4G+PvrlqN0qB43tRaKJPQEYhG+Rn
ppXeOk6/AbgOFXBQCPoVJWOjxwMX3ea3rSLM5C9xUP9Rsnf/fkngD6G6pOo2nYin
fgpINQDhGB/r8BJs69RNhvgdbN4aV7Bz8WGYqKF3DVhV+Di5zIOPNC9zoZQPey4d
uMS06OERG7Op8fFws3QoCzEywVdAbe/R+m5oeg875vLVvmONwDi52mqv4rgbfl+a
PhyyPzoR3hdIPEi13AQB5hmyLAcTDtvcib3beNLw586NXcYgQZdcbLmDkjVRDK4u
niE4QaRUGeRJD2+xAgMBAAGjbzBtMAwGA1UdEwEB/wQCMAAwHQYDVR0lBBYwFAYI
KwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBSQ42eTfRBubU1JKWU45fWcZmLd
aTAfBgNVHSMEGDAWgBRvxexARA9ceASOZhFOoe4eODj9cjANBgkqhkiG9w0BAQsF
AAOCAQEAPvOKn7eKh09kVgjmoAfkufGKiFCgzJW9q34Clw/OG/5W/EQ8+rKY+c+0
pgeetMVkmkmQS4Fc7e7MOk/KNujdPOfBK2u5Yin4bcphU0AgMhuF+VUOhVXPv+m4
XewWqo0dhdgYGmqq4pHPuotGLejUqSSYILX34Ln+Lu/plBfDVFePf/gEkWrTDDof
UBwYwQj+yYNPIy/EuI3b0/JFVCcpK/NQfXkfOBcJijiQY4spjBJ/G9oOAbXHfZtl
mVs+3guwg9eBsA20CbRjbSqwJfUfz9+x/IrEKSk70yy6rYMhVtwee4G4d0rLCeIV
Mjt4aTOX/y/glIOdbSfQj/SunXs1GA==
-----END CERTIFICATE-----)";

// Identity of the certificate profile used by the tests.
constexpr char kCertProfileId[] = "cert_profile_1";
constexpr char kCertProfileName[] = "Certificate Profile 1";
constexpr char kCertProfileVersion[] = "cert_profile_version_1";
constexpr base::TimeDelta kCertProfileRenewalPeriod = base::Seconds(0);
// Fake invalidation topic returned by the mocked StartCsr responses.
// NOTE(review): an older note here read "Prefix + certificate profile name";
// presumably real topics have that shape, but this value is just a fixed fake.
constexpr char kInvalidationTopic[] = "fake_invalidation_topic_1";
// Fake Verified Access challenge and the response the TPM mock returns for it.
constexpr char kChallenge[] = "fake_va_challenge_1";
constexpr char kChallengeResponse[] = "fake_va_challenge_response_1";
// RSA modulus length (bits), presumably expected for keys generated when VA is
// disabled — confirm against its uses later in the file.
constexpr unsigned int kNonVaKeyModulusLengthBits = 2048;

constexpr base::TimeDelta kSmallDelay = base::Milliseconds(500);
// A delay time that ensures that DownloadCert happens.
constexpr base::TimeDelta kInitialDownloadCertDelay = base::Seconds(35);

const std::string& GetPublicKey() {
  static std::string public_key;
  if (public_key.empty()) {
    base::Base64Decode(kPublicKeyBase64, &public_key);
  }
  return public_key;
}

// Byte-vector twin of GetPublicKey(): kPublicKeyBase64 decoded into raw bytes.
// Decoding happens once, on the first call; the cached vector is returned by
// reference afterwards. Crashes (CHECK) if the constant is not valid base64.
const std::vector<uint8_t>& GetPublicKeyBin() {
  static const std::vector<uint8_t> cached_key = [] {
    std::optional<std::vector<uint8_t>> decoded =
        base::Base64Decode(kPublicKeyBase64);
    CHECK(decoded.has_value());
    return std::move(decoded).value();
  }();
  return cached_key;
}

// Fixed payload that the mocked StartCsr responses hand to the worker as the
// data to sign.
std::vector<uint8_t> GetDataToSign() {
  return {10, 11, 12, 13, 14};
}

// The fake signature as a std::string: the same five bytes (0x01..0x05) that
// GetSignatureBin() returns.
std::string GetSignatureStr() {
  return "\x01\x02\x03\x04\x05";
}

// The fake signature produced by the mocked platform-keys signing calls.
std::vector<uint8_t> GetSignatureBin() {
  return {0x01, 0x02, 0x03, 0x04, 0x05};
}

// Returns the characters of kCertProfileId as raw bytes, without the trailing
// '\0' of the string literal.
std::vector<uint8_t> GetCertProfileIdBin() {
  // std::end() of the char array points one past the NUL terminator, so step
  // back by one to exclude it.
  return std::vector<uint8_t>(std::begin(kCertProfileId),
                              std::end(kCertProfileId) - 1);
}

// Verifies that the fake AttestationClient recorded exactly one DeleteKeys
// request. It also checks that the request targeted the key of
// kCertProfileId:
// - the key label is GetKeyName(kCertProfileId) with exact matching;
// - the username is non-empty for CertScope::kUser and empty for any other
//   scope.
void VerifyDeleteKeyCalledOnce(CertScope cert_scope) {
  // Bind by const reference; this avoids a copy if the accessor returns a
  // reference, and extends the temporary's lifetime if it returns by value.
  const std::vector<::attestation::DeleteKeysRequest>& delete_keys_history =
      AttestationClient::Get()->GetTestInterface()->delete_keys_history();
  // ASSERT (not EXPECT): the checks below read element 0, which would be an
  // out-of-bounds access if no request had been recorded.
  ASSERT_EQ(delete_keys_history.size(), 1u);
  const ::attestation::DeleteKeysRequest& request = delete_keys_history[0];
  EXPECT_EQ(request.username().empty(), cert_scope != CertScope::kUser);
  EXPECT_EQ(request.key_label_match(), GetKeyName(kCertProfileId));
  EXPECT_EQ(request.match_behavior(),
            ::attestation::DeleteKeysRequest::MATCH_BEHAVIOR_EXACT);
}

// Macros are used to reduce boilerplate while keeping the real line numbers in
// error messages when an expectation fails. They use some of the protected
// fields of the CertProvisioningWorkerStaticTest class and may be considered
// extra methods of it. *_OK macros immediately call the callbacks with
// successful results. *_NO_OP macros don't call the callbacks.
// Expects PREPARE_KEY_FUNC exactly once on MOCK_TPM_CHALLENGE_KEY and runs its
// callback (argument index 5, 0-based) with a public-key TpmChallengeKeyResult
// built from GetPublicKey().
#define EXPECT_PREPARE_KEY_OK(MOCK_TPM_CHALLENGE_KEY, PREPARE_KEY_FUNC)    \
  {                                                                        \
    auto public_key_result =                                               \
        attestation::TpmChallengeKeyResult::MakePublicKey(GetPublicKey()); \
    EXPECT_CALL((MOCK_TPM_CHALLENGE_KEY), PREPARE_KEY_FUNC)                \
        .Times(1)                                                          \
        .WillOnce(RunOnceCallback<5>(public_key_result));                  \
  }

// Expects SIGN_CHALLENGE_FUNC exactly once on MOCK_TPM_CHALLENGE_KEY and runs
// its callback (argument index 1, 0-based) with a challenge-response result
// containing kChallengeResponse.
#define EXPECT_SIGN_CHALLENGE_OK(MOCK_TPM_CHALLENGE_KEY, SIGN_CHALLENGE_FUNC) \
  {                                                                           \
    auto sign_challenge_result =                                              \
        attestation::TpmChallengeKeyResult::MakeChallengeResponse(            \
            kChallengeResponse);                                              \
    EXPECT_CALL((MOCK_TPM_CHALLENGE_KEY), SIGN_CHALLENGE_FUNC)                \
        .Times(1)                                                             \
        .WillOnce(RunOnceCallback<1>(sign_challenge_result));                 \
  }

// Expects REGISTER_KEY_FUNC exactly once on MOCK_TPM_CHALLENGE_KEY and runs its
// callback (argument index 0) with TpmChallengeKeyResult::MakeSuccess().
#define EXPECT_REGISTER_KEY_OK(MOCK_TPM_CHALLENGE_KEY, REGISTER_KEY_FUNC) \
  {                                                                       \
    auto register_key_result =                                            \
        attestation::TpmChallengeKeyResult::MakeSuccess();                \
    EXPECT_CALL((MOCK_TPM_CHALLENGE_KEY), REGISTER_KEY_FUNC)              \
        .Times(1)                                                         \
        .WillOnce(RunOnceCallback<0>(register_key_result));               \
  }

// The EXPECT_START_CSR_* macros all expect START_CSR_FUNC exactly once on
// `cert_provisioning_client_`. Except for EXPECT_START_CSR_NO_OP, they
// immediately run the callback at argument index 1 (0-based) with, in order:
// DM status, response error, try-again-later delay (ms), invalidation topic,
// VA challenge, hashing algorithm and data to sign.

// Success with VA challenge kChallenge, HASHING_ALGO and GetDataToSign().
#define EXPECT_START_CSR_OK(START_CSR_FUNC, HASHING_ALGO)            \
  {                                                                  \
    EXPECT_CALL(cert_provisioning_client_, START_CSR_FUNC)           \
        .Times(1)                                                    \
        .WillOnce(RunOnceCallback<1>(                                \
            policy::DeviceManagementStatus::DM_STATUS_SUCCESS,       \
            /*response_error=*/std::nullopt,                         \
            /*try_again_later_ms=*/std::nullopt, kInvalidationTopic, \
            kChallenge, HASHING_ALGO, GetDataToSign()));             \
  }

// Like EXPECT_START_CSR_OK, but with an empty VA challenge.
#define EXPECT_START_CSR_OK_WITHOUT_VA(START_CSR_FUNC, HASHING_ALGO) \
  {                                                                  \
    EXPECT_CALL(cert_provisioning_client_, START_CSR_FUNC)           \
        .Times(1)                                                    \
        .WillOnce(RunOnceCallback<1>(                                \
            policy::DeviceManagementStatus::DM_STATUS_SUCCESS,       \
            /*response_error=*/std::nullopt,                         \
            /*try_again_later_ms=*/std::nullopt, kInvalidationTopic, \
            /*va_challenge=*/"", HASHING_ALGO, GetDataToSign()));    \
  }

// DM success, but asks the worker to try again after DELAY_MS; no challenge,
// no hashing algorithm and no data to sign.
#define EXPECT_START_CSR_TRY_LATER(START_CSR_FUNC, DELAY_MS)       \
  {                                                                \
    EXPECT_CALL(cert_provisioning_client_, START_CSR_FUNC)         \
        .Times(1)                                                  \
        .WillOnce(RunOnceCallback<1>(                              \
            policy::DeviceManagementStatus::DM_STATUS_SUCCESS,     \
            /*response_error=*/std::nullopt,                       \
            /*try_again_later_ms=*/(DELAY_MS), kInvalidationTopic, \
            /*va_challenge=*/"",                                   \
            enterprise_management::HashingAlgorithm::              \
                HASHING_ALGORITHM_UNSPECIFIED,                     \
            /*data_to_sign=*/std::vector<uint8_t>()));             \
  }

// Fails with DM_STATUS_REQUEST_INVALID.
#define EXPECT_START_CSR_INVALID_REQUEST(START_CSR_FUNC)                    \
  {                                                                         \
    EXPECT_CALL(cert_provisioning_client_, START_CSR_FUNC)                  \
        .Times(1)                                                           \
        .WillOnce(RunOnceCallback<1>(                                       \
            policy::DeviceManagementStatus::DM_STATUS_REQUEST_INVALID,      \
            /*response_error=*/std::nullopt,                                \
            /*try_again_later_ms=*/std::nullopt, /*invalidation_topic=*/"", \
            /*va_challenge=*/"",                                            \
            enterprise_management::HashingAlgorithm::                       \
                HASHING_ALGORITHM_UNSPECIFIED,                              \
            /*data_to_sign=*/std::vector<uint8_t>()));                      \
  }

// DM success, but the response carries CertProvisioningResponseError::CA_ERROR.
#define EXPECT_START_CSR_CA_ERROR(START_CSR_FUNC)                           \
  {                                                                         \
    EXPECT_CALL(cert_provisioning_client_, START_CSR_FUNC)                  \
        .Times(1)                                                           \
        .WillOnce(RunOnceCallback<1>(                                       \
            policy::DeviceManagementStatus::DM_STATUS_SUCCESS,              \
            /*response_error=*/CertProvisioningResponseError::CA_ERROR,     \
            /*try_again_later_ms=*/std::nullopt, /*invalidation_topic=*/"", \
            /*va_challenge=*/"",                                            \
            enterprise_management::HashingAlgorithm::                       \
                HASHING_ALGORITHM_UNSPECIFIED,                              \
            /*data_to_sign=*/std::vector<uint8_t>()));                      \
  }

// Fails with DM_STATUS_TEMPORARY_UNAVAILABLE.
#define EXPECT_START_CSR_TEMPORARY_UNAVAILABLE(START_CSR_FUNC)               \
  {                                                                          \
    EXPECT_CALL(cert_provisioning_client_, START_CSR_FUNC)                   \
        .Times(1)                                                            \
        .WillOnce(RunOnceCallback<1>(                                        \
            policy::DeviceManagementStatus::DM_STATUS_TEMPORARY_UNAVAILABLE, \
            /*response_error=*/std::nullopt,                                 \
            /*try_again_later_ms=*/std::nullopt, /*invalidation_topic=*/"",  \
            /*va_challenge=*/"",                                             \
            enterprise_management::HashingAlgorithm::                        \
                HASHING_ALGORITHM_UNSPECIFIED,                               \
            /*data_to_sign=*/std::vector<uint8_t>()));                       \
  }

// Fails with DM_STATUS_SERVICE_ACTIVATION_PENDING.
#define EXPECT_START_CSR_SERVICE_ACTIVATION_PENDING(START_CSR_FUNC)            \
  {                                                                            \
    EXPECT_CALL(cert_provisioning_client_, START_CSR_FUNC)                     \
        .Times(1)                                                              \
        .WillOnce(                                                             \
            RunOnceCallback<1>(policy::DeviceManagementStatus::                \
                                   DM_STATUS_SERVICE_ACTIVATION_PENDING,       \
                               /*response_error=*/std::nullopt,                \
                               /*try_again_later_ms=*/std::nullopt,            \
                               /*invalidation_topic=*/"", /*va_challenge=*/"", \
                               enterprise_management::HashingAlgorithm::       \
                                   HASHING_ALGORITHM_UNSPECIFIED,              \
                               /*data_to_sign=*/std::vector<uint8_t>()));      \
  }

// DM success, but the response carries
// CertProvisioningResponseError::INCONSISTENT_DATA.
#define EXPECT_START_CSR_INCONSISTENT_DATA(START_CSR_FUNC)                  \
  {                                                                         \
    EXPECT_CALL(cert_provisioning_client_, START_CSR_FUNC)                  \
        .Times(1)                                                           \
        .WillOnce(RunOnceCallback<1>(                                       \
            policy::DeviceManagementStatus::                                \
                DM_STATUS_SUCCESS, /*response_error=*/                      \
            CertProvisioningResponseError::INCONSISTENT_DATA,               \
            /*try_again_later_ms=*/std::nullopt, /*invalidation_topic=*/"", \
            /*va_challenge=*/"",                                            \
            enterprise_management::HashingAlgorithm::                       \
                HASHING_ALGORITHM_UNSPECIFIED,                              \
            /*data_to_sign=*/std::vector<uint8_t>()));                      \
  }

// Only expects the call; the callback is never run (presumably leaving the
// worker waiting for the StartCsr response).
#define EXPECT_START_CSR_NO_OP(START_CSR_FUNC) \
  { EXPECT_CALL(cert_provisioning_client_, START_CSR_FUNC).Times(1); }

// The EXPECT_FINISH_CSR_* macros expect FINISH_CSR_FUNC exactly once on
// `cert_provisioning_client_` and run its callback (argument index 3, 0-based)
// with a DM status followed by two optionals. The last one is the
// try-again-later delay in ms; the middle one is presumably the response
// error (confirm against MockCertProvisioningClient).

// DM success, no error, no retry delay.
#define EXPECT_FINISH_CSR_OK(FINISH_CSR_FUNC)                                \
  {                                                                          \
    EXPECT_CALL(cert_provisioning_client_, FINISH_CSR_FUNC)                  \
        .Times(1)                                                            \
        .WillOnce(RunOnceCallback<3>(                                        \
            policy::DeviceManagementStatus::DM_STATUS_SUCCESS, std::nullopt, \
            std::nullopt));                                                  \
  }

// DM success, but asks the worker to try again after DELAY_MS.
#define EXPECT_FINISH_CSR_TRY_LATER(FINISH_CSR_FUNC, DELAY_MS)               \
  {                                                                          \
    EXPECT_CALL(cert_provisioning_client_, FINISH_CSR_FUNC)                  \
        .Times(1)                                                            \
        .WillOnce(RunOnceCallback<3>(                                        \
            policy::DeviceManagementStatus::DM_STATUS_SUCCESS, std::nullopt, \
            /*try_again_later_ms=*/(DELAY_MS)));                             \
  }

// Fails with DM_STATUS_SERVICE_ACTIVATION_PENDING.
#define EXPECT_FINISH_CSR_SERVICE_ACTIVATION_PENDING(FINISH_CSR_FUNC)          \
  {                                                                            \
    EXPECT_CALL(cert_provisioning_client_, FINISH_CSR_FUNC)                    \
        .Times(1)                                                              \
        .WillOnce(RunOnceCallback<3>(policy::DeviceManagementStatus::          \
                                         DM_STATUS_SERVICE_ACTIVATION_PENDING, \
                                     std::nullopt, std::nullopt));             \
  }

// The EXPECT_DOWNLOAD_CERT_* macros expect DOWNLOAD_CERT_FUNC exactly once on
// `cert_provisioning_client_`. Except for EXPECT_DOWNLOAD_CERT_NO_OP, they run
// the callback at argument index 1 (0-based) with, in order:
// - a DM status;
// - an optional, presumably the response error;
// - the try-again-later delay in ms;
// - the PEM certificate.

// DM success returning CERTIFICATE_PEM.
#define EXPECT_DOWNLOAD_CERT_OK(DOWNLOAD_CERT_FUNC, CERTIFICATE_PEM)         \
  {                                                                          \
    EXPECT_CALL(cert_provisioning_client_, DOWNLOAD_CERT_FUNC)               \
        .Times(1)                                                            \
        .WillOnce(RunOnceCallback<1>(                                        \
            policy::DeviceManagementStatus::DM_STATUS_SUCCESS, std::nullopt, \
            std::nullopt, CERTIFICATE_PEM));                                 \
  }

// Fails with DM_STATUS_SERVICE_ACTIVATION_PENDING (a certificate is still
// passed, but the status is not success).
#define EXPECT_DOWNLOAD_CERT_SERVICE_ACTIVATION_PENDING(DOWNLOAD_CERT_FUNC)    \
  {                                                                            \
    EXPECT_CALL(cert_provisioning_client_, DOWNLOAD_CERT_FUNC)                 \
        .Times(1)                                                              \
        .WillOnce(RunOnceCallback<1>(policy::DeviceManagementStatus::          \
                                         DM_STATUS_SERVICE_ACTIVATION_PENDING, \
                                     std::nullopt, std::nullopt,               \
                                     kFakeCertificate));                       \
  }

// DM success, but asks the worker to try again after DELAY_MS; no certificate.
#define EXPECT_DOWNLOAD_CERT_TRY_LATER(DOWNLOAD_CERT_FUNC, DELAY_MS)         \
  {                                                                          \
    EXPECT_CALL(cert_provisioning_client_, DOWNLOAD_CERT_FUNC)               \
        .Times(1)                                                            \
        .WillOnce(RunOnceCallback<1>(                                        \
            policy::DeviceManagementStatus::DM_STATUS_SUCCESS, std::nullopt, \
            /*try_again_later_ms=*/(DELAY_MS), /*certificate=*/""));         \
  }

// Only expects the call; the callback is never run.
#define EXPECT_DOWNLOAD_CERT_NO_OP(DOWNLOAD_CERT_FUNC) \
  { EXPECT_CALL(cert_provisioning_client_, DOWNLOAD_CERT_FUNC).Times(1); }

// The macros below expect the given call exactly once on
// `*platform_keys_service_` and immediately run its callback. The callback's
// argument index (0-based) differs per method and is noted on each macro.

// Attribute setter; callback (index 4) gets Status::kSuccess.
#define EXPECT_SET_ATTRIBUTE_FOR_KEY_OK(SET_FUNC)        \
  {                                                      \
    EXPECT_CALL(*platform_keys_service_, SET_FUNC)       \
        .Times(1)                                        \
        .WillOnce(RunOnceCallback<4>(Status::kSuccess)); \
  }

// Attribute setter; callback (index 4) gets Status::kErrorInternal.
#define EXPECT_SET_ATTRIBUTE_FOR_KEY_FAIL(SET_FUNC)            \
  {                                                            \
    EXPECT_CALL(*platform_keys_service_, SET_FUNC)             \
        .Times(1)                                              \
        .WillOnce(RunOnceCallback<4>(Status::kErrorInternal)); \
  }

// RSA PKCS#1 digest signing ("RSAPKC1" in the name is historical); callback
// (index 4) gets GetSignatureBin() and Status::kSuccess.
#define EXPECT_SIGN_RSAPKC1_DIGEST_OK(SIGN_FUNC)                            \
  {                                                                         \
    EXPECT_CALL(*platform_keys_service_, SIGN_FUNC)                         \
        .Times(1)                                                           \
        .WillOnce(RunOnceCallback<4>(GetSignatureBin(), Status::kSuccess)); \
  }

// RSA PKCS#1 raw signing; callback (index 3) gets GetSignatureBin() and
// Status::kSuccess.
#define EXPECT_SIGN_RSAPKC1_RAW_OK(SIGN_FUNC)                               \
  {                                                                         \
    EXPECT_CALL(*platform_keys_service_, SIGN_FUNC)                         \
        .Times(1)                                                           \
        .WillOnce(RunOnceCallback<3>(GetSignatureBin(), Status::kSuccess)); \
  }

// RSA PKCS#1 digest signing; callback (index 4) gets an empty signature and
// Status::kErrorInternal.
#define EXPECT_SIGN_RSAPKC1_DIGEST_FAIL(SIGN_FUNC)                         \
  {                                                                        \
    EXPECT_CALL(*platform_keys_service_, SIGN_FUNC)                        \
        .Times(1)                                                          \
        .WillOnce(RunOnceCallback<4>(/*signature=*/std::vector<uint8_t>(), \
                                     Status::kErrorInternal));             \
  }

// Certificate import; callback (index 2) gets Status::kSuccess.
#define EXPECT_IMPORT_CERTIFICATE_OK(IMPORT_FUNC)        \
  {                                                      \
    EXPECT_CALL(*platform_keys_service_, IMPORT_FUNC)    \
        .Times(1)                                        \
        .WillOnce(RunOnceCallback<2>(Status::kSuccess)); \
  }

// A mock for observing the state change callback of the worker. The fixture
// wraps it in StrictMock, so every state change must be matched by an
// explicit EXPECT_CALL(..., StateChangeCallback()).
class StateChangeCallbackObserver {
 public:
  MOCK_METHOD(void, StateChangeCallback, ());
};

// Test fixture for CertProvisioningWorkerStatic.
// - The constructor (via Init()) installs a MockPlatformKeysService and a
//   MockKeyPermissionsManager for the testing profile.
// - SetUp() initializes the fake AttestationClient and registers prefs on
//   `testing_pref_service_`.
// Member declaration order matters: `task_environment_` is declared first so
// it is constructed first and destroyed last.
class CertProvisioningWorkerStaticTest : public ::testing::Test {
 public:
  CertProvisioningWorkerStaticTest() { Init(); }
  CertProvisioningWorkerStaticTest(const CertProvisioningWorkerStaticTest&) =
      delete;
  CertProvisioningWorkerStaticTest& operator=(
      const CertProvisioningWorkerStaticTest&) = delete;
  ~CertProvisioningWorkerStaticTest() override = default;

  // Initializes the fake AttestationClient (shut down again in TearDown()) and
  // registers both profile and local-state prefs on `testing_pref_service_`.
  void SetUp() override {
    AttestationClient::InitializeFake();

    RegisterProfilePrefs(testing_pref_service_.registry());
    RegisterLocalStatePrefs(testing_pref_service_.registry());
  }

  // Checks that no TpmChallengeKeySubtle testing instance is left pending.
  // This presumably means every mock installed by PrepareTpmChallengeKey() was
  // picked up by a worker. Then shuts down the fake AttestationClient.
  void TearDown() override {
    EXPECT_FALSE(
        attestation::TpmChallengeKeySubtleFactory::WillReturnTestingInstance());
    AttestationClient::Shutdown();
  }

 protected:
  // Called from the constructor, i.e. before SetUp().
  // - Installs a MockPlatformKeysService as both the profile's
  //   PlatformKeysService and the device-wide one.
  // - Routes both the user-private-token and the system-token key permissions
  //   managers to `key_permissions_manager_`.
  void Init() {
    platform_keys_service_ =
        static_cast<platform_keys::MockPlatformKeysService*>(
            platform_keys::PlatformKeysServiceFactory::GetInstance()
                ->SetTestingFactoryAndUse(
                    GetProfile(),
                    base::BindRepeating(
                        &platform_keys::BuildMockPlatformKeysService)));
    // Note: a fatal failure here only returns from Init(), not from the test.
    ASSERT_TRUE(platform_keys_service_);
    platform_keys::PlatformKeysServiceFactory::GetInstance()
        ->SetDeviceWideServiceForTesting(platform_keys_service_);

    key_permissions_manager_ =
        std::make_unique<platform_keys::MockKeyPermissionsManager>();

    platform_keys::UserPrivateTokenKeyPermissionsManagerServiceFactory::
        GetInstance()
            ->SetTestingFactory(
                GetProfile(),
                base::BindRepeating(
                    &platform_keys::
                        BuildFakeUserPrivateTokenKeyPermissionsManagerService,
                    key_permissions_manager_.get()));
    platform_keys::KeyPermissionsManagerImpl::
        SetSystemTokenKeyPermissionsManagerForTesting(
            key_permissions_manager_.get());

    // Only explicitly expected removals are allowed.
    EXPECT_CALL(*platform_keys_service_, RemoveCertificate).Times(0);
    EXPECT_CALL(*platform_keys_service_, RemoveKey).Times(0);
  }

  // Forward the mock time by `delta` in the smallest possible intervals.
  void FastForwardBy(base::TimeDelta delta) {
    task_environment_.FastForwardBy(delta);
  }

  // Jump the mock time by `delta`, then run all scheduled tasks that should
  // have started by then.
  //
  // This can be useful if scheduled tasks happen in sequence, each scheduling
  // the next task when executed, and a test is verifying the delays between
  // them.
  //
  // Example:
  // 1. Delayed Task TaskA has been posted to run in 10s.
  //    When executed, the task will post TaskB with a delay of 10s.
  // 2. FastForwardBy(base::Seconds(15))
  // Now TaskB is scheduled to run after 5 mock seconds.
  //
  // 1. Delayed Task TaskA has been posted to run in 10s.
  //    When executed, the task will post TaskB with a delay of 10s.
  // 2. AdvanceClockAndRunTasks(base::Seconds(15))
  // Now TaskB is scheduled to run after 10 mock seconds.
  void AdvanceClockAndRunTasks(base::TimeDelta delta) {
    task_environment_.AdvanceClock(delta);
    task_environment_.RunUntilIdle();
  }

  // Replaces the next result of TpmChallengeKeySubtleFactory and returns a
  // pointer to the mock. The mock will be injected into the next created
  // worker and will live until the worker's destruction. Should be called
  // before the creation of every worker.
  MockTpmChallengeKeySubtle* PrepareTpmChallengeKey() {
    auto mock_tpm_challenge_key_subtle_impl =
        std::make_unique<MockTpmChallengeKeySubtle>();

    // Keep a raw pointer before ownership is moved into the factory.
    MockTpmChallengeKeySubtle* tpm_challenge_key_impl =
        mock_tpm_challenge_key_subtle_impl.get();

    attestation::TpmChallengeKeySubtleFactory::SetForTesting(
        std::move(mock_tpm_challenge_key_subtle_impl));

    CHECK(tpm_challenge_key_impl);
    return tpm_challenge_key_impl;
  }

  // Returns a callback forwarding to `state_change_callback_observer_`. It uses
  // base::Unretained, so it must not outlive this fixture; workers created in
  // test bodies are presumably destroyed before the fixture.
  base::RepeatingClosure GetStateChangeCallback() {
    return base::BindRepeating(
        &StateChangeCallbackObserver ::StateChangeCallback,
        base::Unretained(&state_change_callback_observer_));
  }

  // Returns the callback of `callback_observer_`, a TestFuture that stores the
  // worker's (CertProfile, std::string, CertProvisioningWorkerState) result.
  CertProvisioningWorkerCallback GetResultCallback() {
    return callback_observer_.GetCallback();
  }

  Profile* GetProfile() { return profile_helper_for_testing_.GetProfile(); }

  // Returns a fresh mock invalidator that the test doesn't need to reach.
  std::unique_ptr<MockCertProvisioningInvalidator> MakeInvalidator() {
    return std::make_unique<MockCertProvisioningInvalidator>();
  }

  // Same as above, but also stores a non-owning pointer to the mock in
  // `*mock_invalidator`. That pointer is only valid while the returned object
  // (typically handed to a worker) is alive.
  std::unique_ptr<MockCertProvisioningInvalidator> MakeInvalidator(
      MockCertProvisioningInvalidator** mock_invalidator) {
    auto result = std::make_unique<MockCertProvisioningInvalidator>();
    *mock_invalidator = result.get();
    return result;
  }

  // Mock time lets tests control delayed tasks via FastForwardBy() and
  // AdvanceClockAndRunTasks().
  content::BrowserTaskEnvironment task_environment_{
      base::test::TaskEnvironment::TimeSource::MOCK_TIME};

  StrictMock<StateChangeCallbackObserver> state_change_callback_observer_;
  base::test::TestFuture<CertProfile, std::string, CertProvisioningWorkerState>
      callback_observer_;
  ProfileHelperForTesting profile_helper_for_testing_;
  TestingPrefServiceSimple testing_pref_service_;

  MockCertProvisioningClient cert_provisioning_client_;
  // Not owned; the service is created by the PlatformKeysServiceFactory
  // testing factory in Init().
  raw_ptr<platform_keys::MockPlatformKeysService> platform_keys_service_ =
      nullptr;
  std::unique_ptr<platform_keys::MockKeyPermissionsManager>
      key_permissions_manager_;
};

// Checks that the worker makes all necessary requests to other modules in the
// success scenario.
TEST_F(CertProvisioningWorkerStaticTest, Success) {
  // Collects every histogram sample emitted during the test, so the worker's
  // metrics can be checked at the end.
  base::HistogramTester histogram_tester;

  // User-scoped profile with VA enabled, using the static protocol.
  const CertProfile cert_profile(
      kCertProfileId, kCertProfileName, kCertProfileVersion,
      /*is_va_enabled=*/true, kCertProfileRenewalPeriod,
      ProtocolVersion::kStatic);
  const std::string process_id = GenerateCertProvisioningId();
  // The worker is expected to register for invalidations with a listener
  // type derived from the process id.
  const std::string listener_type = MakeInvalidationListenerType(process_id);
  // The process description every client request is expected to carry.
  const CertProvisioningClient::ProvisioningProcess provisioning_process(
      process_id, CertScope::kUser, kCertProfileId, kCertProfileVersion,
      GetPublicKeyBin());

  MockTpmChallengeKeySubtle* mock_tpm_challenge_key = PrepareTpmChallengeKey();
  // MakeInvalidator() stores here an unowned pointer to the invalidator that
  // is moved into the worker, so expectations can still be placed on it.
  MockCertProvisioningInvalidator* mock_invalidator = nullptr;
  CertProvisioningWorkerStatic worker(
      process_id, CertScope::kUser, GetProfile(), &testing_pref_service_,
      cert_profile, &cert_provisioning_client_,
      MakeInvalidator(&mock_invalidator), GetStateChangeCallback(),
      GetResultCallback());

  // Runs on every state change: no backend server error may have been
  // recorded at any point of the success path.
  auto VerifyNoBackendErrorsSeen = [&worker]() {
    EXPECT_EQ(worker.GetLastBackendServerError(), std::nullopt);
  };
  {
    // Phase 1: a single DoStep() is expected to drive the worker through the
    // calls below, in exactly this order, up to the FinishCsr response.
    testing::InSequence seq;

    EXPECT_PREPARE_KEY_OK(*mock_tpm_challenge_key,
                          StartPrepareKeyStep(::attestation::ENTERPRISE_USER,
                                              /*will_register_key=*/true,
                                              ::attestation::KEY_TYPE_RSA,
                                              GetKeyName(kCertProfileId),
                                              /*profile=*/_,
                                              /*callback=*/_, /*signals=*/_));

    // kKeypairGenerated
    EXPECT_CALL(state_change_callback_observer_, StateChangeCallback())
        .WillOnce(VerifyNoBackendErrorsSeen);

    EXPECT_START_CSR_OK(
        StartCsr(Eq(std::ref(provisioning_process)), /*callback=*/_),
        em::HashingAlgorithm::SHA256);
    // kStartCsrResponseReceived
    EXPECT_CALL(state_change_callback_observer_, StateChangeCallback())
        .WillOnce(VerifyNoBackendErrorsSeen);

    // Registration for invalidations is expected right after StartCsr.
    EXPECT_CALL(*mock_invalidator,
                Register(kInvalidationTopic, listener_type, _))
        .Times(1);

    EXPECT_SIGN_CHALLENGE_OK(*mock_tpm_challenge_key,
                             StartSignChallengeStep(kChallenge,
                                                    /*callback=*/_));
    // kVaChallengeFinished
    EXPECT_CALL(state_change_callback_observer_, StateChangeCallback())
        .WillOnce(VerifyNoBackendErrorsSeen);

    EXPECT_REGISTER_KEY_OK(*mock_tpm_challenge_key, StartRegisterKeyStep);
    // kKeyRegistered
    EXPECT_CALL(state_change_callback_observer_, StateChangeCallback())
        .WillOnce(VerifyNoBackendErrorsSeen);

    EXPECT_CALL(*key_permissions_manager_,
                AllowKeyForUsage(/*callback=*/_, KeyUsage::kCorporate,
                                 GetPublicKeyBin()));

    EXPECT_SET_ATTRIBUTE_FOR_KEY_OK(
        SetAttributeForKey(TokenId::kUser, GetPublicKeyBin(),
                           KeyAttributeType::kCertificateProvisioningId,
                           GetCertProfileIdBin(), _));
    // kKeypairMarked
    EXPECT_CALL(state_change_callback_observer_, StateChangeCallback())
        .WillOnce(VerifyNoBackendErrorsSeen);

    // StartCsr asked for SHA256, so the digest-based signing call is used.
    EXPECT_SIGN_RSAPKC1_DIGEST_OK(
        SignRsaPkcs1(::testing::Optional(TokenId::kUser), GetDataToSign(),
                     GetPublicKeyBin(), HashAlgorithm::HASH_ALGORITHM_SHA256,
                     /*callback=*/_));
    // kSignCsrFinished
    EXPECT_CALL(state_change_callback_observer_, StateChangeCallback())
        .WillOnce(VerifyNoBackendErrorsSeen);

    EXPECT_FINISH_CSR_OK(FinishCsr(Eq(std::ref(provisioning_process)),
                                   kChallengeResponse, GetSignatureStr(),
                                   /*callback=*/_));
    // State change to kFinishCsrResponseReceived.
    EXPECT_CALL(state_change_callback_observer_, StateChangeCallback())
        .WillOnce(VerifyNoBackendErrorsSeen);
    // Entering waiting mode.
    EXPECT_CALL(state_change_callback_observer_, StateChangeCallback())
        .WillOnce(VerifyNoBackendErrorsSeen);

    worker.DoStep();
    EXPECT_EQ(worker.GetState(),
              CertProvisioningWorkerState::kFinishCsrResponseReceived);
  }

  {
    // Phase 2 runs once mock time passes kInitialDownloadCertDelay (below).
    // The certificate is downloaded, then imported into the user token, then
    // the invalidation registration is dropped.
    testing::InSequence seq;

    EXPECT_DOWNLOAD_CERT_OK(
        DownloadCert(Eq(std::ref(provisioning_process)), /*callback=*/_),
        kFakeCertificate);

    EXPECT_IMPORT_CERTIFICATE_OK(ImportCertificate(TokenId::kUser,
                                                   /*certificate=*/_,
                                                   /*callback=*/_));
    // kSucceeded
    EXPECT_CALL(state_change_callback_observer_, StateChangeCallback())
        .WillOnce(VerifyNoBackendErrorsSeen);

    EXPECT_CALL(*mock_invalidator, Unregister()).Times(1);
  }

  // Skip just past the initial DownloadCert delay so phase 2 runs.
  FastForwardBy(kInitialDownloadCertDelay + kSmallDelay);

  // The result callback reports this profile with a successful final state.
  EXPECT_EQ(callback_observer_.Get<CertProfile>(), cert_profile);
  EXPECT_EQ(callback_observer_.Get<CertProvisioningWorkerState>(),
            CertProvisioningWorkerState::kSucceeded);

  // Expected metrics:
  // - exactly one successful result;
  // - one invalidation-topic registration;
  // - one retry that succeeded without an invalidation;
  // - one timing sample each for keypair generation, VA, and CSR signing.
  histogram_tester.ExpectUniqueSample("ChromeOS.CertProvisioning.Result.User",
                                      CertProvisioningWorkerState::kSucceeded,
                                      1);
  histogram_tester.ExpectBucketCount(
      "ChromeOS.CertProvisioning.Event.User",
      CertProvisioningEvent::kRegisteredToInvalidationTopic, 1);
  histogram_tester.ExpectBucketCount(
      "ChromeOS.CertProvisioning.Event.User",
      CertProvisioningEvent::kWorkerRetrySucceededWithoutInvalidation, 1);
  histogram_tester.ExpectTotalCount(
      "ChromeOS.CertProvisioning.KeypairGenerationTime.User", 1);
  histogram_tester.ExpectTotalCount("ChromeOS.CertProvisioning.VaTime.User", 1);
  histogram_tester.ExpectTotalCount(
      "ChromeOS.CertProvisioning.CsrSignTime.User", 1);
}

// Checks that the worker makes all necessary requests to other modules in the
// success scenario when VA is disabled, so no VA challenge is received.
TEST_F(CertProvisioningWorkerStaticTest, NoVaSuccess) {
  // User-scoped profile with VA disabled.
  const CertProfile cert_profile(
      kCertProfileId, kCertProfileName, kCertProfileVersion,
      /*is_va_enabled=*/false, kCertProfileRenewalPeriod,
      ProtocolVersion::kStatic);
  const std::string process_id = GenerateCertProvisioningId();
  const CertProvisioningClient::ProvisioningProcess provisioning_process(
      process_id, CertScope::kUser, kCertProfileId, kCertProfileVersion,
      GetPublicKeyBin());

  // No TPM challenge key is prepared. With VA disabled, the key is generated
  // through the platform keys service instead (see GenerateRSAKey below).
  CertProvisioningWorkerStatic worker(
      process_id, CertScope::kUser, GetProfile(), &testing_pref_service_,
      cert_profile, &cert_provisioning_client_, MakeInvalidator(),
      GetStateChangeCallback(), GetResultCallback());

  // Individual state changes are not verified here; only require that at
  // least one is reported.
  EXPECT_CALL(state_change_callback_observer_, StateChangeCallback)
      .Times(AtLeast(1));
  {
    // Phase 1: one DoStep() runs the flow up to the FinishCsr response.
    testing::InSequence seq;

    // Argument 3 is the callback. It is invoked with the test public key and
    // a success status.
    EXPECT_CALL(*platform_keys_service_,
                GenerateRSAKey(TokenId::kUser, kNonVaKeyModulusLengthBits,
                               /*sw_backed=*/false,
                               /*callback=*/_))
        .Times(1)
        .WillOnce(RunOnceCallback<3>(GetPublicKeyBin(), Status::kSuccess));

    EXPECT_START_CSR_OK_WITHOUT_VA(
        StartCsr(Eq(std::ref(provisioning_process)), /*callback=*/_),
        em::HashingAlgorithm::SHA256);

    EXPECT_CALL(*key_permissions_manager_,
                AllowKeyForUsage(/*callback=*/_, KeyUsage::kCorporate,
                                 GetPublicKeyBin()));

    EXPECT_SET_ATTRIBUTE_FOR_KEY_OK(
        SetAttributeForKey(TokenId::kUser, GetPublicKeyBin(),
                           KeyAttributeType::kCertificateProvisioningId,
                           GetCertProfileIdBin(), _));

    EXPECT_SIGN_RSAPKC1_DIGEST_OK(
        SignRsaPkcs1(::testing::Optional(TokenId::kUser), GetDataToSign(),
                     GetPublicKeyBin(), HashAlgorithm::HASH_ALGORITHM_SHA256,
                     /*callback=*/_));

    // No VA challenge was signed, so FinishCsr carries an empty response.
    EXPECT_FINISH_CSR_OK(FinishCsr(Eq(std::ref(provisioning_process)),
                                   /*va_challenge_response=*/"",
                                   GetSignatureStr(),
                                   /*callback=*/_));

    worker.DoStep();
    EXPECT_EQ(worker.GetState(),
              CertProvisioningWorkerState::kFinishCsrResponseReceived);
  }

  {
    // Phase 2 runs once mock time passes the initial download delay (below):
    // download the certificate, then import it into the user token.
    testing::InSequence seq;

    EXPECT_DOWNLOAD_CERT_OK(
        DownloadCert(Eq(std::ref(provisioning_process)), /*callback=*/_),
        kFakeCertificate);

    EXPECT_IMPORT_CERTIFICATE_OK(
        ImportCertificate(TokenId::kUser, /*certificate=*/_, /*callback=*/_));
  }

  FastForwardBy(kInitialDownloadCertDelay + kSmallDelay);

  EXPECT_EQ(callback_observer_.Get<CertProfile>(), cert_profile);
  EXPECT_EQ(callback_observer_.Get<CertProvisioningWorkerState>(),
            CertProvisioningWorkerState::kSucceeded);
}

// Checks that when StartCsr asks for hashing_algorithm=NO_HASH, the worker
// signs the CSR data with the raw PKCS#1 operation of platform_keys.
TEST_F(CertProvisioningWorkerStaticTest, NoHashInStartCsr) {
  const CertProfile cert_profile(
      kCertProfileId, kCertProfileName, kCertProfileVersion,
      /*is_va_enabled=*/true, kCertProfileRenewalPeriod,
      ProtocolVersion::kStatic);
  const std::string process_id = GenerateCertProvisioningId();
  const std::string listener_type = MakeInvalidationListenerType(process_id);
  const CertProvisioningClient::ProvisioningProcess provisioning_process(
      process_id, CertScope::kUser, kCertProfileId, kCertProfileVersion,
      GetPublicKeyBin());

  MockTpmChallengeKeySubtle* mock_tpm_challenge_key = PrepareTpmChallengeKey();
  // Receives an unowned pointer to the invalidator that moves into the worker.
  MockCertProvisioningInvalidator* mock_invalidator = nullptr;
  CertProvisioningWorkerStatic worker(
      process_id, CertScope::kUser, GetProfile(), &testing_pref_service_,
      cert_profile, &cert_provisioning_client_,
      MakeInvalidator(&mock_invalidator), GetStateChangeCallback(),
      GetResultCallback());

  {
    // Phase 1: one DoStep() runs the calls below in order, up to the FinishCsr
    // response. Each state change is expected exactly once.
    testing::InSequence seq;

    EXPECT_PREPARE_KEY_OK(*mock_tpm_challenge_key,
                          StartPrepareKeyStep(::attestation::ENTERPRISE_USER,
                                              /*will_register_key=*/true,
                                              ::attestation::KEY_TYPE_RSA,
                                              GetKeyName(kCertProfileId),
                                              /*profile=*/_,
                                              /*callback=*/_, /*signals=*/_));
    // kKeypairGenerated
    EXPECT_CALL(state_change_callback_observer_, StateChangeCallback());

    // The server asks for NO_HASH instead of a digest algorithm.
    EXPECT_START_CSR_OK(
        StartCsr(Eq(std::ref(provisioning_process)), /*callback=*/_),
        em::HashingAlgorithm::NO_HASH);
    // kStartCsrResponseReceived
    EXPECT_CALL(state_change_callback_observer_, StateChangeCallback());

    EXPECT_CALL(*mock_invalidator,
                Register(kInvalidationTopic, listener_type, _))
        .Times(1);

    EXPECT_SIGN_CHALLENGE_OK(*mock_tpm_challenge_key,
                             StartSignChallengeStep(kChallenge,
                                                    /*callback=*/_));
    // kVaChallengeFinished
    EXPECT_CALL(state_change_callback_observer_, StateChangeCallback());

    EXPECT_REGISTER_KEY_OK(*mock_tpm_challenge_key, StartRegisterKeyStep);
    // kKeyRegistered
    EXPECT_CALL(state_change_callback_observer_, StateChangeCallback());

    EXPECT_CALL(*key_permissions_manager_,
                AllowKeyForUsage(/*callback=*/_, KeyUsage::kCorporate,
                                 GetPublicKeyBin()));

    EXPECT_SET_ATTRIBUTE_FOR_KEY_OK(
        SetAttributeForKey(TokenId::kUser, GetPublicKeyBin(),
                           KeyAttributeType::kCertificateProvisioningId,
                           GetCertProfileIdBin(), _));
    // kKeypairMarked
    EXPECT_CALL(state_change_callback_observer_, StateChangeCallback());

    // With NO_HASH the data is signed through the raw PKCS#1 call, which takes
    // no HashAlgorithm, instead of the digest-based SignRsaPkcs1.
    EXPECT_SIGN_RSAPKC1_RAW_OK(
        SignRSAPKCS1Raw(::testing::Optional(TokenId::kUser), GetDataToSign(),
                        GetPublicKeyBin(), /*callback=*/_));
    // kSignCsrFinished
    EXPECT_CALL(state_change_callback_observer_, StateChangeCallback());

    EXPECT_FINISH_CSR_OK(FinishCsr(Eq(std::ref(provisioning_process)),
                                   kChallengeResponse, GetSignatureStr(),
                                   /*callback=*/_));
    // State change to kFinishCsrResponseReceived.
    EXPECT_CALL(state_change_callback_observer_, StateChangeCallback());
    // Entering waiting mode.
    EXPECT_CALL(state_change_callback_observer_, StateChangeCallback());

    worker.DoStep();
    EXPECT_EQ(worker.GetState(),
              CertProvisioningWorkerState::kFinishCsrResponseReceived);
  }

  {
    // Phase 2 runs once mock time passes the initial download delay (below).
    // The certificate is downloaded and imported, the final state change is
    // reported, and the invalidation registration is dropped.
    testing::InSequence seq;

    EXPECT_DOWNLOAD_CERT_OK(
        DownloadCert(Eq(std::ref(provisioning_process)), /*callback=*/_),
        kFakeCertificate);

    EXPECT_IMPORT_CERTIFICATE_OK(
        ImportCertificate(TokenId::kUser, /*certificate=*/_, /*callback=*/_));
    EXPECT_CALL(state_change_callback_observer_, StateChangeCallback());

    EXPECT_CALL(*mock_invalidator, Unregister()).Times(1);
  }

  FastForwardBy(kInitialDownloadCertDelay + kSmallDelay);

  EXPECT_EQ(callback_observer_.Get<CertProfile>(), cert_profile);
  EXPECT_EQ(callback_observer_.Get<CertProvisioningWorkerState>(),
            CertProvisioningWorkerState::kSucceeded);
}

// Checks that when the downloaded certificate does not match the generated
// public key, the worker removes the generated key from the user token and
// finishes in the kFailed state.
TEST_F(CertProvisioningWorkerStaticTest, PublicKeyMismatch) {
  const CertProfile cert_profile(
      kCertProfileId, kCertProfileName, kCertProfileVersion,
      /*is_va_enabled=*/true, kCertProfileRenewalPeriod,
      ProtocolVersion::kStatic);
  const std::string process_id = GenerateCertProvisioningId();
  const CertProvisioningClient::ProvisioningProcess provisioning_process(
      process_id, CertScope::kUser, kCertProfileId, kCertProfileVersion,
      GetPublicKeyBin());

  MockTpmChallengeKeySubtle* mock_tpm_challenge_key = PrepareTpmChallengeKey();
  CertProvisioningWorkerStatic worker(
      process_id, CertScope::kUser, GetProfile(), &testing_pref_service_,
      cert_profile, &cert_provisioning_client_, MakeInvalidator(),
      GetStateChangeCallback(), GetResultCallback());

  // Individual state changes are not verified in this test.
  EXPECT_CALL(state_change_callback_observer_, StateChangeCallback)
      .Times(AtLeast(1));
  {
    // Phase 1: the normal flow (here with NO_HASH signing) up to the FinishCsr
    // response.
    testing::InSequence seq;

    EXPECT_PREPARE_KEY_OK(*mock_tpm_challenge_key,
                          StartPrepareKeyStep(::attestation::ENTERPRISE_USER,
                                              /*will_register_key=*/true,
                                              ::attestation::KEY_TYPE_RSA,
                                              GetKeyName(kCertProfileId),
                                              /*profile=*/_,
                                              /*callback=*/_, /*signals=*/_));

    EXPECT_START_CSR_OK(
        StartCsr(Eq(std::ref(provisioning_process)), /*callback=*/_),
        em::HashingAlgorithm::NO_HASH);

    EXPECT_SIGN_CHALLENGE_OK(*mock_tpm_challenge_key,
                             StartSignChallengeStep(kChallenge,
                                                    /*callback=*/_));

    EXPECT_REGISTER_KEY_OK(*mock_tpm_challenge_key, StartRegisterKeyStep);

    EXPECT_CALL(*key_permissions_manager_,
                AllowKeyForUsage(/*callback=*/_, KeyUsage::kCorporate,
                                 GetPublicKeyBin()));

    EXPECT_SET_ATTRIBUTE_FOR_KEY_OK(
        SetAttributeForKey(TokenId::kUser, GetPublicKeyBin(),
                           KeyAttributeType::kCertificateProvisioningId,
                           GetCertProfileIdBin(), _));

    EXPECT_SIGN_RSAPKC1_RAW_OK(
        SignRSAPKCS1Raw(::testing::Optional(TokenId::kUser), GetDataToSign(),
                        GetPublicKeyBin(), /*callback=*/_));

    EXPECT_FINISH_CSR_OK(FinishCsr(Eq(std::ref(provisioning_process)),
                                   kChallengeResponse, GetSignatureStr(),
                                   /*callback=*/_));

    worker.DoStep();
    EXPECT_EQ(worker.GetState(),
              CertProvisioningWorkerState::kFinishCsrResponseReceived);
  }

  {
    // Phase 2: the server returns a certificate for a different public key.
    // The worker is expected to delete the key it generated; argument 2 of
    // RemoveKey is the callback, which receives a success status.
    testing::InSequence seq;

    EXPECT_DOWNLOAD_CERT_OK(
        DownloadCert(Eq(std::ref(provisioning_process)), /*callback=*/_),
        kFakeCertificatePubKeyMismatch);

    EXPECT_CALL(
        *platform_keys_service_,
        RemoveKey(TokenId::kUser,
                  /*public_key_spki_der=*/GetPublicKeyBin(), /*callback=*/_))
        .Times(1)
        .WillOnce(RunOnceCallback<2>(Status::kSuccess));
  }

  FastForwardBy(kInitialDownloadCertDelay + kSmallDelay);

  // No certificate is imported; the worker reports failure.
  EXPECT_EQ(callback_observer_.Get<CertProfile>(), cert_profile);
  EXPECT_EQ(callback_observer_.Get<CertProvisioningWorkerState>(),
            CertProvisioningWorkerState::kFailed);
}

// Checks that when the server returns the try_again_later field, the worker
// retries the request when it is asked to continue provisioning.
TEST_F(CertProvisioningWorkerStaticTest, TryLaterManualRetry) {
  const CertProfile cert_profile(
      kCertProfileId, kCertProfileName, kCertProfileVersion,
      /*is_va_enabled=*/true, kCertProfileRenewalPeriod,
      ProtocolVersion::kStatic);
  const std::string process_id = GenerateCertProvisioningId();
  // Device scope: the key is prepared as ENTERPRISE_MACHINE and stored in the
  // system token (see the expectations below).
  const CertProvisioningClient::ProvisioningProcess provisioning_process(
      process_id, CertScope::kDevice, kCertProfileId, kCertProfileVersion,
      GetPublicKeyBin());

  MockTpmChallengeKeySubtle* mock_tpm_challenge_key = PrepareTpmChallengeKey();
  CertProvisioningWorkerStatic worker(
      process_id, CertScope::kDevice, GetProfile(), &testing_pref_service_,
      cert_profile, &cert_provisioning_client_, MakeInvalidator(),
      GetStateChangeCallback(), GetResultCallback());
  // Server-requested retry delay. Mock time is never advanced in this test,
  // so every retry comes from an explicit DoStep() call, not from this delay.
  const base::TimeDelta delay = base::Seconds(30);

  EXPECT_CALL(state_change_callback_observer_, StateChangeCallback)
      .Times(AtLeast(1));
  {
    // StartCsr answers "try later": the worker stays in kKeypairGenerated.
    testing::InSequence seq;

    EXPECT_PREPARE_KEY_OK(
        *mock_tpm_challenge_key,
        StartPrepareKeyStep(::attestation::ENTERPRISE_MACHINE,
                            /*will_register_key=*/true,
                            ::attestation::KEY_TYPE_RSA,
                            /*key_name=*/GetKeyName(kCertProfileId),
                            /*profile=*/_,
                            /*callback=*/_, /*signals=*/_));

    EXPECT_START_CSR_TRY_LATER(
        StartCsr(Eq(std::ref(provisioning_process)), /*callback=*/_),
        delay.InMilliseconds());

    worker.DoStep();
    EXPECT_EQ(worker.GetState(),
              CertProvisioningWorkerState::kKeypairGenerated);
  }

  {
    // A manual retry of StartCsr succeeds. The flow then continues until
    // FinishCsr answers "try later", which leaves the worker in
    // kSignCsrFinished.
    testing::InSequence seq;

    EXPECT_START_CSR_OK(
        StartCsr(Eq(std::ref(provisioning_process)), /*callback=*/_),
        em::HashingAlgorithm::SHA256);

    EXPECT_SIGN_CHALLENGE_OK(*mock_tpm_challenge_key,
                             StartSignChallengeStep(kChallenge,
                                                    /*callback=*/_));

    EXPECT_REGISTER_KEY_OK(*mock_tpm_challenge_key, StartRegisterKeyStep);

    EXPECT_CALL(*key_permissions_manager_,
                AllowKeyForUsage(/*callback=*/_, KeyUsage::kCorporate,
                                 GetPublicKeyBin()));

    EXPECT_SET_ATTRIBUTE_FOR_KEY_OK(
        SetAttributeForKey(TokenId::kSystem, GetPublicKeyBin(),
                           KeyAttributeType::kCertificateProvisioningId,
                           GetCertProfileIdBin(), _));

    // NOTE(review): unlike the other tests, the SignRsaPkcs1 arguments (token,
    // data, key, hash algorithm) are not constrained here.
    EXPECT_SIGN_RSAPKC1_DIGEST_OK(SignRsaPkcs1);

    EXPECT_FINISH_CSR_TRY_LATER(FinishCsr(Eq(std::ref(provisioning_process)),
                                          kChallengeResponse, GetSignatureStr(),
                                          /*callback=*/_),
                                delay.InMilliseconds());

    worker.DoStep();
    EXPECT_EQ(worker.GetState(), CertProvisioningWorkerState::kSignCsrFinished);
  }

  {
    // A manual retry of FinishCsr succeeds.
    testing::InSequence seq;

    EXPECT_FINISH_CSR_OK(FinishCsr(Eq(std::ref(provisioning_process)),
                                   kChallengeResponse, GetSignatureStr(),
                                   /*callback=*/_));

    worker.DoStep();
    EXPECT_EQ(worker.GetState(),
              CertProvisioningWorkerState::kFinishCsrResponseReceived);
  }

  {
    // DownloadCert answers "try later": the state does not change.
    testing::InSequence seq;

    EXPECT_DOWNLOAD_CERT_TRY_LATER(
        DownloadCert(Eq(std::ref(provisioning_process)), /*callback=*/_),
        delay.InMilliseconds());

    worker.DoStep();
    EXPECT_EQ(worker.GetState(),
              CertProvisioningWorkerState::kFinishCsrResponseReceived);
  }

  {
    // A manual retry of DownloadCert succeeds. The certificate is imported into
    // the system token and the worker reports success.
    testing::InSequence seq;

    EXPECT_DOWNLOAD_CERT_OK(
        DownloadCert(Eq(std::ref(provisioning_process)), /*callback=*/_),
        kFakeCertificate);

    EXPECT_IMPORT_CERTIFICATE_OK(
        ImportCertificate(TokenId::kSystem, /*certificate=*/_, /*callback=*/_));

    worker.DoStep();
    EXPECT_EQ(worker.GetState(), CertProvisioningWorkerState::kSucceeded);

    EXPECT_EQ(callback_observer_.Get<CertProfile>(), cert_profile);
    EXPECT_EQ(callback_observer_.Get<CertProvisioningWorkerState>(),
              CertProvisioningWorkerState::kSucceeded);
  }
}

// Checks that when the server returns try_again_later field, the worker will
// automatically retry a request after some time.
TEST_F(CertProvisioningWorkerStaticTest, TryLaterWait) {
  const CertProfile cert_profile(
      kCertProfileId, kCertProfileName, kCertProfileVersion,
      /*is_va_enabled=*/true, kCertProfileRenewalPeriod,
      ProtocolVersion::kStatic);
  const std::string process_id = GenerateCertProvisioningId();
  const CertProvisioningClient::ProvisioningProcess provisioning_process(
      process_id, CertScope::kUser, kCertProfileId, kCertProfileVersion,
      GetPublicKeyBin());

  MockTpmChallengeKeySubtle* mock_tpm_challenge_key = PrepareTpmChallengeKey();
  CertProvisioningWorkerStatic worker(
      process_id, CertScope::kUser, GetProfile(), &testing_pref_service_,
      cert_profile, &cert_provisioning_client_, MakeInvalidator(),
      GetStateChangeCallback(), GetResultCallback());

  // Delays returned by the server in its "try later" responses. Unlike the
  // manual-retry test, retries here are triggered only by advancing mock time.
  const base::TimeDelta start_csr_delay = base::Seconds(30);
  const base::TimeDelta finish_csr_delay = base::Seconds(30);
  const base::TimeDelta download_cert_server_delay = base::Milliseconds(100);
  // The minimum "try_later" delay is 10 seconds.
  const base::TimeDelta download_cert_real_delay = base::Seconds(10);
  // Margin added to each delay so the pending retry has surely fired.
  const base::TimeDelta small_delay = base::Milliseconds(500);

  EXPECT_CALL(state_change_callback_observer_, StateChangeCallback)
      .Times(AtLeast(1));
  {
    // StartCsr answers "try later": the worker stays in kKeypairGenerated.
    testing::InSequence seq;

    EXPECT_PREPARE_KEY_OK(*mock_tpm_challenge_key,
                          StartPrepareKeyStep(::attestation::ENTERPRISE_USER,
                                              /*will_register_key=*/true,
                                              ::attestation::KEY_TYPE_RSA,
                                              GetKeyName(kCertProfileId),
                                              /*profile=*/_,
                                              /*callback=*/_, /*signals=*/_));

    EXPECT_START_CSR_TRY_LATER(
        StartCsr(Eq(std::ref(provisioning_process)), /*callback=*/_),
        start_csr_delay.InMilliseconds());

    worker.DoStep();
    EXPECT_EQ(worker.GetState(),
              CertProvisioningWorkerState::kKeypairGenerated);
  }

  {
    // Once start_csr_delay passes, the worker retries StartCsr by itself. It
    // then continues until FinishCsr answers "try later".
    testing::InSequence seq;

    EXPECT_START_CSR_OK(
        StartCsr(Eq(std::ref(provisioning_process)), /*callback=*/_),
        em::HashingAlgorithm::SHA256);

    EXPECT_SIGN_CHALLENGE_OK(*mock_tpm_challenge_key,
                             StartSignChallengeStep(kChallenge,
                                                    /*callback=*/_));

    EXPECT_REGISTER_KEY_OK(*mock_tpm_challenge_key, StartRegisterKeyStep);

    EXPECT_CALL(*key_permissions_manager_,
                AllowKeyForUsage(/*callback=*/_, KeyUsage::kCorporate,
                                 GetPublicKeyBin()));

    EXPECT_SET_ATTRIBUTE_FOR_KEY_OK(
        SetAttributeForKey(TokenId::kUser, GetPublicKeyBin(),
                           KeyAttributeType::kCertificateProvisioningId,
                           GetCertProfileIdBin(), _));

    EXPECT_SIGN_RSAPKC1_DIGEST_OK(
        SignRsaPkcs1(::testing::Optional(TokenId::kUser), GetDataToSign(),
                     GetPublicKeyBin(), HashAlgorithm::HASH_ALGORITHM_SHA256,
                     /*callback=*/_));

    EXPECT_FINISH_CSR_TRY_LATER(FinishCsr(Eq(std::ref(provisioning_process)),
                                          kChallengeResponse, GetSignatureStr(),
                                          /*callback=*/_),
                                finish_csr_delay.InMilliseconds());

    FastForwardBy(start_csr_delay + small_delay);
    EXPECT_EQ(worker.GetState(), CertProvisioningWorkerState::kSignCsrFinished);
  }

  {
    // Once finish_csr_delay passes, the worker retries FinishCsr by itself.
    testing::InSequence seq;

    EXPECT_FINISH_CSR_OK(FinishCsr(Eq(std::ref(provisioning_process)),
                                   kChallengeResponse, GetSignatureStr(),
                                   /*callback=*/_));

    FastForwardBy(finish_csr_delay + small_delay);
    EXPECT_EQ(worker.GetState(),
              CertProvisioningWorkerState::kFinishCsrResponseReceived);
  }

  {
    // The first DownloadCert (after the initial delay) answers "try later"
    // with a delay below the minimum.
    testing::InSequence seq;
    EXPECT_DOWNLOAD_CERT_TRY_LATER(
        DownloadCert(Eq(std::ref(provisioning_process)), /*callback=*/_),
        download_cert_server_delay.InMilliseconds());

    FastForwardBy(kInitialDownloadCertDelay + kSmallDelay);

    EXPECT_EQ(worker.GetState(),
              CertProvisioningWorkerState::kFinishCsrResponseReceived);
  }

  {
    testing::InSequence seq;

    // The arguments of the final DownloadCert call are not constrained.
    EXPECT_DOWNLOAD_CERT_OK(DownloadCert, kFakeCertificate);

    EXPECT_IMPORT_CERTIFICATE_OK(
        ImportCertificate(TokenId::kUser, /*certificate=*/_, /*callback=*/_));

    FastForwardBy(small_delay);
    // Check that minimum wait time is not too small even if the server
    // has responded with a small one.
    EXPECT_EQ(worker.GetState(),
              CertProvisioningWorkerState::kFinishCsrResponseReceived);

    // Only after the enforced minimum delay does the retry happen.
    FastForwardBy(download_cert_real_delay + small_delay);
    EXPECT_EQ(worker.GetState(), CertProvisioningWorkerState::kSucceeded);

    EXPECT_EQ(callback_observer_.Get<CertProfile>(), cert_profile);
    EXPECT_EQ(callback_observer_.Get<CertProvisioningWorkerState>(),
              CertProvisioningWorkerState::kSucceeded);
  }
}

// Checks that when the device management server returns a
// DM_STATUS_SERVICE_ACTIVATION_PENDING status error (which is 412 pending
// approval), the worker retries the request after the expected delay, which
// depends on the request.
TEST_F(CertProvisioningWorkerStaticTest, ServiceActivationPendingResponse) {
  const CertProfile cert_profile(
      kCertProfileId, kCertProfileName, kCertProfileVersion,
      /*is_va_enabled=*/true, kCertProfileRenewalPeriod,
      ProtocolVersion::kStatic);
  const std::string process_id = GenerateCertProvisioningId();
  const CertProvisioningClient::ProvisioningProcess provisioning_process(
      process_id, CertScope::kUser, kCertProfileId, kCertProfileVersion,
      GetPublicKeyBin());

  MockTpmChallengeKeySubtle* mock_tpm_challenge_key = PrepareTpmChallengeKey();
  CertProvisioningWorkerStatic worker(
      process_id, CertScope::kUser, GetProfile(), &testing_pref_service_,
      cert_profile, &cert_provisioning_client_, MakeInvalidator(),
      GetStateChangeCallback(), GetResultCallback());

  // Expected wait before retrying StartCsr / FinishCsr after an
  // activation-pending response.
  const base::TimeDelta kExpectedStartCsrDelay = base::Hours(1);
  const base::TimeDelta kExpectedFinishCsrDelay = base::Hours(1);

  EXPECT_CALL(state_change_callback_observer_, StateChangeCallback)
      .Times(AtLeast(1));
  {
    // StartCsr answers "activation pending": the worker stays in
    // kKeypairGenerated.
    testing::InSequence seq;

    EXPECT_PREPARE_KEY_OK(*mock_tpm_challenge_key,
                          StartPrepareKeyStep(::attestation::ENTERPRISE_USER,
                                              /*will_register_key=*/true,
                                              ::attestation::KEY_TYPE_RSA,
                                              GetKeyName(kCertProfileId),
                                              /*profile=*/_,
                                              /*callback=*/_, /*signals=*/_));

    EXPECT_START_CSR_SERVICE_ACTIVATION_PENDING(
        StartCsr(Eq(std::ref(provisioning_process)), /*callback=*/_));

    worker.DoStep();
    EXPECT_EQ(worker.GetState(),
              CertProvisioningWorkerState::kKeypairGenerated);
  }

  {
    testing::InSequence seq;

    // Verify that nothing happens until right before the expected StartCsr
    // delay.
    // NOTE(review): |cert_provisioning_client_| is declared as a plain
    // MockCertProvisioningClient, not a StrictMock. Once expectations have
    // been cleared, a premature call may only show up as an "uninteresting
    // call" warning, unless the mock type is strict itself. Confirm that such
    // a call is still caught, e.g. by a later expectation going unsatisfied.
    AdvanceClockAndRunTasks(kExpectedStartCsrDelay - kSmallDelay);
    Mock::VerifyAndClearExpectations(&cert_provisioning_client_);

    EXPECT_START_CSR_OK(
        StartCsr(Eq(std::ref(provisioning_process)), /*callback=*/_),
        em::HashingAlgorithm::SHA256);

    EXPECT_SIGN_CHALLENGE_OK(*mock_tpm_challenge_key,
                             StartSignChallengeStep(kChallenge,
                                                    /*callback=*/_));

    EXPECT_REGISTER_KEY_OK(*mock_tpm_challenge_key, StartRegisterKeyStep);

    EXPECT_CALL(*key_permissions_manager_,
                AllowKeyForUsage(/*callback=*/_, KeyUsage::kCorporate,
                                 GetPublicKeyBin()));

    EXPECT_SET_ATTRIBUTE_FOR_KEY_OK(
        SetAttributeForKey(TokenId::kUser, GetPublicKeyBin(),
                           KeyAttributeType::kCertificateProvisioningId,
                           GetCertProfileIdBin(), _));

    EXPECT_SIGN_RSAPKC1_DIGEST_OK(
        SignRsaPkcs1(::testing::Optional(TokenId::kUser), GetDataToSign(),
                     GetPublicKeyBin(), HashAlgorithm::HASH_ALGORITHM_SHA256,
                     /*callback=*/_));

    EXPECT_FINISH_CSR_SERVICE_ACTIVATION_PENDING(
        FinishCsr(Eq(std::ref(provisioning_process)), kChallengeResponse,
                  GetSignatureStr(), /*callback=*/_));

    // Step just past the StartCsr delay: the retry succeeds, and FinishCsr
    // then answers "activation pending".
    AdvanceClockAndRunTasks(2 * kSmallDelay);
    Mock::VerifyAndClearExpectations(&cert_provisioning_client_);
    EXPECT_EQ(worker.GetState(), CertProvisioningWorkerState::kSignCsrFinished);
  }

  {
    testing::InSequence seq;

    // Verify that nothing happens until right before the expected FinishCsr
    // delay.
    AdvanceClockAndRunTasks(kExpectedFinishCsrDelay - kSmallDelay);
    Mock::VerifyAndClearExpectations(&cert_provisioning_client_);

    EXPECT_FINISH_CSR_OK(FinishCsr(Eq(std::ref(provisioning_process)),
                                   kChallengeResponse, GetSignatureStr(),
                                   /*callback=*/_));

    AdvanceClockAndRunTasks(2 * kSmallDelay);
    Mock::VerifyAndClearExpectations(&cert_provisioning_client_);
    EXPECT_EQ(worker.GetState(),
              CertProvisioningWorkerState::kFinishCsrResponseReceived);
  }

  {
    testing::InSequence seq;

    // Verify that nothing happens until the initial DownloadCert delay passes.
    AdvanceClockAndRunTasks(kInitialDownloadCertDelay - kSmallDelay);
    Mock::VerifyAndClearExpectations(&cert_provisioning_client_);

    EXPECT_DOWNLOAD_CERT_SERVICE_ACTIVATION_PENDING(
        DownloadCert(Eq(std::ref(provisioning_process)), /*callback=*/_));

    AdvanceClockAndRunTasks(2 * kSmallDelay);
    Mock::VerifyAndClearExpectations(&cert_provisioning_client_);
    EXPECT_EQ(worker.GetState(),
              CertProvisioningWorkerState::kFinishCsrResponseReceived);
  }

  // Backoff round 1
  // The DownloadCert backoff policy has a jitter of 10%. Jitter is always
  // subtracted.
  const double kEffectiveJitterFactor = 0.9;

  {
    // Expected retry window: [0.9 * 4 * initial, 4 * initial].
    const base::TimeDelta kBackoff1MaxDelay = kInitialDownloadCertDelay * 4;
    const base::TimeDelta kBackoff1MinDelay =
        kBackoff1MaxDelay * kEffectiveJitterFactor;

    testing::InSequence seq;

    // Verify that nothing happens until the backoff time is reached.
    AdvanceClockAndRunTasks(kBackoff1MinDelay - kSmallDelay);
    Mock::VerifyAndClearExpectations(&cert_provisioning_client_);

    EXPECT_DOWNLOAD_CERT_SERVICE_ACTIVATION_PENDING(
        DownloadCert(Eq(std::ref(provisioning_process)), /*callback=*/_));
    AdvanceClockAndRunTasks(kBackoff1MaxDelay - kBackoff1MinDelay +
                            2 * kSmallDelay);
    Mock::VerifyAndClearExpectations(&cert_provisioning_client_);
    EXPECT_EQ(worker.GetState(),
              CertProvisioningWorkerState::kFinishCsrResponseReceived);
  }

  // Backoff round 2
  {
    // Expected retry window: [0.9 * 16 * initial, 16 * initial].
    const base::TimeDelta kBackoff2MaxDelay = kInitialDownloadCertDelay * 16;
    const base::TimeDelta kBackoff2MinDelay =
        kBackoff2MaxDelay * kEffectiveJitterFactor;

    testing::InSequence seq;

    // Verify that nothing happens until the backoff time is reached.
    AdvanceClockAndRunTasks(kBackoff2MinDelay - kSmallDelay);
    Mock::VerifyAndClearExpectations(&cert_provisioning_client_);

    EXPECT_DOWNLOAD_CERT_SERVICE_ACTIVATION_PENDING(
        DownloadCert(Eq(std::ref(provisioning_process)), /*callback=*/_));
    AdvanceClockAndRunTasks(kBackoff2MaxDelay - kBackoff2MinDelay +
                            2 * kSmallDelay);
    Mock::VerifyAndClearExpectations(&cert_provisioning_client_);
    EXPECT_EQ(worker.GetState(),
              CertProvisioningWorkerState::kFinishCsrResponseReceived);
  }

  // Backoff round 3
  {
    // This value is set manually, so no jitter is applied.
    const base::TimeDelta kBackoff3Delay = base::Hours(8);

    testing::InSequence seq;

    // Verify that nothing happens until the backoff time is reached.
    AdvanceClockAndRunTasks(kBackoff3Delay - kSmallDelay);
    Mock::VerifyAndClearExpectations(&cert_provisioning_client_);

    // This time the download succeeds; the DownloadCert arguments are not
    // constrained.
    EXPECT_DOWNLOAD_CERT_OK(DownloadCert, kFakeCertificate);

    EXPECT_IMPORT_CERTIFICATE_OK(
        ImportCertificate(TokenId::kUser, /*certificate=*/_, /*callback=*/_));

    EXPECT_EQ(worker.GetState(),
              CertProvisioningWorkerState::kFinishCsrResponseReceived);

    AdvanceClockAndRunTasks(2 * kSmallDelay);
    Mock::VerifyAndClearExpectations(&cert_provisioning_client_);
    EXPECT_EQ(worker.GetState(), CertProvisioningWorkerState::kSucceeded);

    EXPECT_EQ(callback_observer_.Get<CertProfile>(), cert_profile);
    EXPECT_EQ(callback_observer_.Get<CertProvisioningWorkerState>(),
              CertProvisioningWorkerState::kSucceeded);
  }
}

// Tests that with the kCertProvisioningUseOnlyInvalidationsForTesting feature
// flag enabled, the worker only progresses when it receives an invalidation.
TEST_F(CertProvisioningWorkerStaticTest, TryLaterWaitForInvalidation) {
  // Enables the feature for the scope of this test.
  base::test::ScopedFeatureList scoped_feature_list{
      kCertProvisioningUseOnlyInvalidationsForTesting};

  const CertProfile cert_profile(
      kCertProfileId, kCertProfileName, kCertProfileVersion,
      /*is_va_enabled=*/true, kCertProfileRenewalPeriod,
      ProtocolVersion::kStatic);
  const std::string process_id = GenerateCertProvisioningId();
  // Request data that every client call below is expected to carry.
  const CertProvisioningClient::ProvisioningProcess provisioning_process(
      process_id, CertScope::kUser, kCertProfileId, kCertProfileVersion,
      GetPublicKeyBin());

  MockTpmChallengeKeySubtle* mock_tpm_challenge_key = PrepareTpmChallengeKey();
  CertProvisioningWorkerStatic worker(
      process_id, CertScope::kUser, GetProfile(), &testing_pref_service_,
      cert_profile, &cert_provisioning_client_, MakeInvalidator(),
      GetStateChangeCallback(), GetResultCallback());

  // Time advanced before every emulated invalidation to show that the passage
  // of time alone does not make the worker progress.
  const base::TimeDelta very_long_delay = base::Days(7);

  EXPECT_CALL(state_change_callback_observer_, StateChangeCallback)
      .Times(AtLeast(1));
  // Step 1: the key is prepared and StartCsr answers "service activation
  // pending"; the worker stops in kKeypairGenerated.
  {
    testing::InSequence seq;

    EXPECT_PREPARE_KEY_OK(*mock_tpm_challenge_key,
                          StartPrepareKeyStep(::attestation::ENTERPRISE_USER,
                                              /*will_register_key=*/true,
                                              ::attestation::KEY_TYPE_RSA,
                                              GetKeyName(kCertProfileId),
                                              /*profile=*/_,
                                              /*callback=*/_, /*signals=*/_));

    EXPECT_START_CSR_SERVICE_ACTIVATION_PENDING(
        StartCsr(Eq(std::ref(provisioning_process)), /*callback=*/_));

    worker.DoStep();
    EXPECT_EQ(worker.GetState(),
              CertProvisioningWorkerState::kKeypairGenerated);
  }

  // Step 2: after an invalidation, StartCsr succeeds, the challenge is signed,
  // the key is registered, allowed for corporate usage, given the
  // provisioning-id attribute and used to sign the CSR data. FinishCsr then
  // answers "service activation pending" and the worker stops in
  // kSignCsrFinished.
  {
    testing::InSequence seq;

    // Verify that the worker doesn't progress without an invalidation.
    AdvanceClockAndRunTasks(very_long_delay);
    Mock::VerifyAndClearExpectations(&cert_provisioning_client_);

    EXPECT_START_CSR_OK(
        StartCsr(Eq(std::ref(provisioning_process)), /*callback=*/_),
        em::HashingAlgorithm::SHA256);

    EXPECT_SIGN_CHALLENGE_OK(*mock_tpm_challenge_key,
                             StartSignChallengeStep(kChallenge,
                                                    /*callback=*/_));

    EXPECT_REGISTER_KEY_OK(*mock_tpm_challenge_key, StartRegisterKeyStep);

    EXPECT_CALL(*key_permissions_manager_,
                AllowKeyForUsage(/*callback=*/_, KeyUsage::kCorporate,
                                 GetPublicKeyBin()));

    EXPECT_SET_ATTRIBUTE_FOR_KEY_OK(
        SetAttributeForKey(TokenId::kUser, GetPublicKeyBin(),
                           KeyAttributeType::kCertificateProvisioningId,
                           GetCertProfileIdBin(), _));

    EXPECT_SIGN_RSAPKC1_DIGEST_OK(
        SignRsaPkcs1(::testing::Optional(TokenId::kUser), GetDataToSign(),
                     GetPublicKeyBin(), HashAlgorithm::HASH_ALGORITHM_SHA256,
                     /*callback=*/_));

    EXPECT_FINISH_CSR_SERVICE_ACTIVATION_PENDING(
        FinishCsr(Eq(std::ref(provisioning_process)), kChallengeResponse,
                  GetSignatureStr(), /*callback=*/_));

    // Emulate an invalidation.
    worker.DoStep();
    Mock::VerifyAndClearExpectations(&cert_provisioning_client_);
    EXPECT_EQ(worker.GetState(), CertProvisioningWorkerState::kSignCsrFinished);
  }

  // Step 3: after an invalidation, FinishCsr succeeds and the worker moves to
  // kFinishCsrResponseReceived.
  {
    testing::InSequence seq;

    // Verify that the worker doesn't progress without an invalidation.
    AdvanceClockAndRunTasks(very_long_delay);
    Mock::VerifyAndClearExpectations(&cert_provisioning_client_);

    EXPECT_FINISH_CSR_OK(FinishCsr(Eq(std::ref(provisioning_process)),
                                   kChallengeResponse, GetSignatureStr(),
                                   /*callback=*/_));

    // Emulate an invalidation.
    worker.DoStep();
    Mock::VerifyAndClearExpectations(&cert_provisioning_client_);
    EXPECT_EQ(worker.GetState(),
              CertProvisioningWorkerState::kFinishCsrResponseReceived);
  }

  // Step 4: after an invalidation, DownloadCert answers "service activation
  // pending" and the worker stays in kFinishCsrResponseReceived.
  {
    testing::InSequence seq;

    // Verify that the worker doesn't progress without an invalidation.
    AdvanceClockAndRunTasks(very_long_delay);
    Mock::VerifyAndClearExpectations(&cert_provisioning_client_);

    EXPECT_DOWNLOAD_CERT_SERVICE_ACTIVATION_PENDING(
        DownloadCert(Eq(std::ref(provisioning_process)), /*callback=*/_));

    // Emulate an invalidation.
    worker.DoStep();
    Mock::VerifyAndClearExpectations(&cert_provisioning_client_);
    EXPECT_EQ(worker.GetState(),
              CertProvisioningWorkerState::kFinishCsrResponseReceived);
  }

  // Step 5: after an invalidation, the certificate is downloaded and imported
  // into the user token, and the worker reaches kSucceeded.
  {
    testing::InSequence seq;

    // Verify that the worker doesn't progress without an invalidation.
    AdvanceClockAndRunTasks(very_long_delay);
    Mock::VerifyAndClearExpectations(&cert_provisioning_client_);

    EXPECT_DOWNLOAD_CERT_OK(DownloadCert, kFakeCertificate);

    EXPECT_IMPORT_CERTIFICATE_OK(
        ImportCertificate(TokenId::kUser, /*certificate=*/_, /*callback=*/_));

    // Emulate an invalidation.
    worker.DoStep();
    Mock::VerifyAndClearExpectations(&cert_provisioning_client_);
    EXPECT_EQ(worker.GetState(), CertProvisioningWorkerState::kSucceeded);
  }
}

// Checks that when the server returns a try_again_later field, the worker
// retries once it has successfully subscribed for invalidations or when an
// invalidation is received.
TEST_F(CertProvisioningWorkerStaticTest, InvalidationRespected) {
  const CertProfile cert_profile(
      kCertProfileId, kCertProfileName, kCertProfileVersion,
      /*is_va_enabled=*/true, kCertProfileRenewalPeriod,
      ProtocolVersion::kStatic);
  const std::string process_id = GenerateCertProvisioningId();
  // Listener type the worker is expected to register with the invalidator.
  const std::string listener_type = MakeInvalidationListenerType(process_id);
  const CertProvisioningClient::ProvisioningProcess provisioning_process(
      process_id, CertScope::kUser, kCertProfileId, kCertProfileVersion,
      GetPublicKeyBin());

  MockTpmChallengeKeySubtle* mock_tpm_challenge_key = PrepareTpmChallengeKey();
  // Filled in by MakeInvalidator() so expectations can be set on the
  // invalidator owned by the worker.
  MockCertProvisioningInvalidator* mock_invalidator = nullptr;
  CertProvisioningWorkerStatic worker(
      process_id, CertScope::kUser, GetProfile(), &testing_pref_service_,
      cert_profile, &cert_provisioning_client_,
      MakeInvalidator(&mock_invalidator), GetStateChangeCallback(),
      GetResultCallback());

  // Try-later delays returned by the fake server for each request.
  const base::TimeDelta start_csr_delay = base::Seconds(30);
  const base::TimeDelta finish_csr_delay = base::Seconds(30);
  const base::TimeDelta download_cert_server_delay = base::Milliseconds(100);
  // Extra time added when fast-forwarding past a try-later delay.
  const base::TimeDelta small_delay = base::Milliseconds(500);

  EXPECT_CALL(state_change_callback_observer_, StateChangeCallback)
      .Times(AtLeast(1));
  // Step 1: the key is prepared and StartCsr answers "try later".
  {
    testing::InSequence seq;

    EXPECT_PREPARE_KEY_OK(*mock_tpm_challenge_key,
                          StartPrepareKeyStep(::attestation::ENTERPRISE_USER,
                                              /*will_register_key=*/true,
                                              ::attestation::KEY_TYPE_RSA,
                                              GetKeyName(kCertProfileId),
                                              /*profile=*/_,
                                              /*callback=*/_, /*signals=*/_));

    EXPECT_START_CSR_TRY_LATER(
        StartCsr(Eq(std::ref(provisioning_process)), /*callback=*/_),
        start_csr_delay.InMilliseconds());

    worker.DoStep();
    EXPECT_EQ(worker.GetState(),
              CertProvisioningWorkerState::kKeypairGenerated);
  }

  // Captured from the worker's Register() call; used below to emulate events
  // coming from the invalidation service.
  OnInvalidationEventCallback on_invalidation_event_callback;
  // Step 2: once the StartCsr delay has passed, StartCsr succeeds, the worker
  // registers for invalidations and goes through signing; FinishCsr then
  // answers "try later".
  {
    testing::InSequence seq;

    EXPECT_START_CSR_OK(
        StartCsr(Eq(std::ref(provisioning_process)), /*callback=*/_),
        em::HashingAlgorithm::SHA256);
    EXPECT_CALL(*mock_invalidator,
                Register(kInvalidationTopic, listener_type, _))
        .WillOnce(SaveArg<2>(&on_invalidation_event_callback));

    EXPECT_SIGN_CHALLENGE_OK(*mock_tpm_challenge_key,
                             StartSignChallengeStep(kChallenge,
                                                    /*callback=*/_));

    EXPECT_REGISTER_KEY_OK(*mock_tpm_challenge_key, StartRegisterKeyStep);

    EXPECT_CALL(*key_permissions_manager_,
                AllowKeyForUsage(/*callback=*/_, KeyUsage::kCorporate,
                                 GetPublicKeyBin()));

    EXPECT_SET_ATTRIBUTE_FOR_KEY_OK(
        SetAttributeForKey(TokenId::kUser, GetPublicKeyBin(),
                           KeyAttributeType::kCertificateProvisioningId,
                           GetCertProfileIdBin(), _));

    EXPECT_SIGN_RSAPKC1_DIGEST_OK(
        SignRsaPkcs1(::testing::Optional(TokenId::kUser), GetDataToSign(),
                     GetPublicKeyBin(), HashAlgorithm::HASH_ALGORITHM_SHA256,
                     /*callback=*/_));

    EXPECT_FINISH_CSR_TRY_LATER(FinishCsr(Eq(std::ref(provisioning_process)),
                                          kChallengeResponse, GetSignatureStr(),
                                          /*callback=*/_),
                                finish_csr_delay.InMilliseconds());

    FastForwardBy(start_csr_delay + small_delay);
    EXPECT_EQ(worker.GetState(), CertProvisioningWorkerState::kSignCsrFinished);
  }

  // Step 3: once the FinishCsr delay has passed, FinishCsr succeeds and
  // DownloadCert answers "try later".
  {
    testing::InSequence seq;

    EXPECT_FINISH_CSR_OK(FinishCsr(Eq(std::ref(provisioning_process)),
                                   kChallengeResponse, GetSignatureStr(),
                                   /*callback=*/_));
    EXPECT_DOWNLOAD_CERT_TRY_LATER(
        DownloadCert(Eq(std::ref(provisioning_process)), /*callback=*/_),
        download_cert_server_delay.InMilliseconds());

    FastForwardBy(finish_csr_delay + small_delay);
    EXPECT_EQ(worker.GetState(),
              CertProvisioningWorkerState::kFinishCsrResponseReceived);
  }

  // Step 4: report that the invalidation subscription succeeded.
  {
    EXPECT_EQ(worker.GetState(),
              CertProvisioningWorkerState::kFinishCsrResponseReceived);

    on_invalidation_event_callback.Run(
        InvalidationEvent::kSuccessfullySubscribed);
  }

  // Step 5: deliver an invalidation; the certificate is downloaded and
  // imported, the worker unregisters from invalidations and succeeds.
  {
    EXPECT_EQ(worker.GetState(),
              CertProvisioningWorkerState::kFinishCsrResponseReceived);

    testing::InSequence seq;

    EXPECT_DOWNLOAD_CERT_OK(DownloadCert, kFakeCertificate);

    EXPECT_IMPORT_CERTIFICATE_OK(
        ImportCertificate(TokenId::kUser, /*certificate=*/_, /*callback=*/_));

    EXPECT_CALL(*mock_invalidator, Unregister()).Times(1);

    on_invalidation_event_callback.Run(
        InvalidationEvent::kInvalidationReceived);
    FastForwardBy(small_delay);

    EXPECT_EQ(worker.GetState(), CertProvisioningWorkerState::kSucceeded);

    EXPECT_EQ(callback_observer_.Get<CertProfile>(), cert_profile);
    EXPECT_EQ(callback_observer_.Get<CertProvisioningWorkerState>(),
              CertProvisioningWorkerState::kSucceeded);
  }
}

// Checks that when the server returns error status, the worker will enter an
// error state and stop the provisioning.
TEST_F(CertProvisioningWorkerStaticTest, StatusErrorHandling) {
  const CertScope kCertScope = CertScope::kUser;
  const CertProfile cert_profile(
      kCertProfileId, kCertProfileName, kCertProfileVersion,
      /*is_va_enabled=*/true, kCertProfileRenewalPeriod,
      ProtocolVersion::kStatic);
  const std::string process_id = GenerateCertProvisioningId();
  // Built from `kCertScope`, the same scope the worker is created with, so the
  // expected request and the worker cannot drift apart.
  const CertProvisioningClient::ProvisioningProcess provisioning_process(
      process_id, kCertScope, kCertProfileId, kCertProfileVersion,
      GetPublicKeyBin());

  MockTpmChallengeKeySubtle* mock_tpm_challenge_key = PrepareTpmChallengeKey();
  CertProvisioningWorkerStatic worker(
      process_id, kCertScope, GetProfile(), &testing_pref_service_,
      cert_profile, &cert_provisioning_client_, MakeInvalidator(),
      GetStateChangeCallback(), GetResultCallback());

  EXPECT_CALL(state_change_callback_observer_, StateChangeCallback)
      .Times(AtLeast(1));
  {
    testing::InSequence seq;

    EXPECT_PREPARE_KEY_OK(*mock_tpm_challenge_key,
                          StartPrepareKeyStep(::attestation::ENTERPRISE_USER,
                                              /*will_register_key=*/true,
                                              ::attestation::KEY_TYPE_RSA,
                                              GetKeyName(kCertProfileId),
                                              /*profile=*/_,
                                              /*callback=*/_, /*signals=*/_));

    // The server rejects the request as invalid.
    EXPECT_START_CSR_INVALID_REQUEST(
        StartCsr(Eq(std::ref(provisioning_process)), /*callback=*/_));
  }

  worker.DoStep();
  FastForwardBy(base::Seconds(1));

  // The prepared key is cleaned up and the worker reports failure.
  VerifyDeleteKeyCalledOnce(kCertScope);

  EXPECT_EQ(callback_observer_.Get<CertProfile>(), cert_profile);
  EXPECT_EQ(callback_observer_.Get<CertProvisioningWorkerState>(),
            CertProvisioningWorkerState::kFailed);
}

// Checks that when the server returns a response error, the worker enters an
// error state and stops the provisioning. The worker is created through
// CertProvisioningWorkerFactory, and the recorded result histograms are checked.
TEST_F(CertProvisioningWorkerStaticTest, ResponseErrorHandling) {
  // Result histogram inspected at the end of the test.
  constexpr char kResultHistogram[] = "ChromeOS.CertProvisioning.Result.User";
  const CertScope kCertScope = CertScope::kUser;
  base::HistogramTester histogram_tester;

  const CertProfile cert_profile(
      kCertProfileId, kCertProfileName, kCertProfileVersion,
      /*is_va_enabled=*/true, kCertProfileRenewalPeriod,
      ProtocolVersion::kStatic);
  const std::string process_id = GenerateCertProvisioningId();
  const CertProvisioningClient::ProvisioningProcess provisioning_process(
      process_id, kCertScope, kCertProfileId, kCertProfileVersion,
      GetPublicKeyBin());

  MockTpmChallengeKeySubtle* mock_tpm_challenge_key = PrepareTpmChallengeKey();
  // Created through the factory rather than directly, so that path is
  // exercised as well.
  std::unique_ptr<CertProvisioningWorker> worker =
      CertProvisioningWorkerFactory::Get()->Create(
          process_id, kCertScope, GetProfile(), &testing_pref_service_,
          cert_profile, &cert_provisioning_client_, MakeInvalidator(),
          GetStateChangeCallback(), GetResultCallback());

  EXPECT_CALL(state_change_callback_observer_, StateChangeCallback)
      .Times(AtLeast(1));
  {
    testing::InSequence seq;

    EXPECT_PREPARE_KEY_OK(*mock_tpm_challenge_key,
                          StartPrepareKeyStep(::attestation::ENTERPRISE_USER,
                                              /*will_register_key=*/true,
                                              ::attestation::KEY_TYPE_RSA,
                                              GetKeyName(kCertProfileId),
                                              /*profile=*/_,
                                              /*callback=*/_, /*signals=*/_));

    // The CA rejects the StartCsr request.
    EXPECT_START_CSR_CA_ERROR(StartCsr);
  }

  worker->DoStep();
  FastForwardBy(base::Seconds(1));

  VerifyDeleteKeyCalledOnce(kCertScope);

  EXPECT_EQ(callback_observer_.Get<CertProfile>(), cert_profile);
  EXPECT_EQ(callback_observer_.Get<CertProvisioningWorkerState>(),
            CertProvisioningWorkerState::kFailed);

  // Exactly two samples: the final kFailed result and kKeypairGenerated.
  histogram_tester.ExpectBucketCount(kResultHistogram,
                                     CertProvisioningWorkerState::kFailed, 1);
  histogram_tester.ExpectBucketCount(
      kResultHistogram, CertProvisioningWorkerState::kKeypairGenerated, 1);
  histogram_tester.ExpectTotalCount(kResultHistogram, 2);
}

TEST_F(CertProvisioningWorkerStaticTest, InconsistentDataErrorHandling) {
  // A StartCsr response reporting inconsistent data must end provisioning in
  // kInconsistentDataError and delete the prepared key.
  const CertScope kCertScope = CertScope::kUser;
  const CertProfile cert_profile(
      kCertProfileId, kCertProfileName, kCertProfileVersion,
      /*is_va_enabled=*/true, kCertProfileRenewalPeriod,
      ProtocolVersion::kStatic);
  const std::string process_id = GenerateCertProvisioningId();
  MockTpmChallengeKeySubtle* mock_tpm_challenge_key = PrepareTpmChallengeKey();
  std::unique_ptr<CertProvisioningWorker> worker =
      CertProvisioningWorkerFactory::Get()->Create(
          process_id, kCertScope, GetProfile(), &testing_pref_service_,
          cert_profile, &cert_provisioning_client_, MakeInvalidator(),
          GetStateChangeCallback(), GetResultCallback());

  EXPECT_CALL(state_change_callback_observer_, StateChangeCallback)
      .Times(AtLeast(1));
  {
    testing::InSequence seq;

    EXPECT_PREPARE_KEY_OK(*mock_tpm_challenge_key,
                          StartPrepareKeyStep(::attestation::ENTERPRISE_USER,
                                              /*will_register_key=*/true,
                                              ::attestation::KEY_TYPE_RSA,
                                              GetKeyName(kCertProfileId),
                                              /*profile=*/_,
                                              /*callback=*/_, /*signals=*/_));

    EXPECT_START_CSR_INCONSISTENT_DATA(StartCsr);
  }

  worker->DoStep();
  FastForwardBy(base::Seconds(1));

  // Cleanup and result reporting.
  VerifyDeleteKeyCalledOnce(kCertScope);

  EXPECT_EQ(callback_observer_.Get<CertProfile>(), cert_profile);
  EXPECT_EQ(callback_observer_.Get<CertProvisioningWorkerState>(),
            CertProvisioningWorkerState::kInconsistentDataError);
}

// Checks that when the server returns a TEMPORARY_UNAVAILABLE status code, the
// worker automatically retries the request using an exponential backoff
// strategy.
TEST_F(CertProvisioningWorkerStaticTest, BackoffStrategy) {
  const CertProfile cert_profile(
      kCertProfileId, kCertProfileName, kCertProfileVersion,
      /*is_va_enabled=*/true, kCertProfileRenewalPeriod,
      ProtocolVersion::kStatic);

  const std::string process_id = GenerateCertProvisioningId();
  const CertProvisioningClient::ProvisioningProcess provisioning_process(
      process_id, CertScope::kUser, kCertProfileId, kCertProfileVersion,
      GetPublicKeyBin());

  MockTpmChallengeKeySubtle* mock_tpm_challenge_key = PrepareTpmChallengeKey();
  CertProvisioningWorkerStatic worker(
      process_id, CertScope::kUser, GetProfile(), &testing_pref_service_,
      cert_profile, &cert_provisioning_client_, MakeInvalidator(),
      GetStateChangeCallback(), GetResultCallback());

  // Delay after which the next StartCsr retry is expected. The test doubles it
  // after every retry to follow the exponential backoff.
  base::TimeDelta next_delay = base::Seconds(30);
  // Slack added on top of `next_delay` when fast-forwarding, so the scheduled
  // retry falls inside the fast-forwarded window.
  const base::TimeDelta small_delay = base::Milliseconds(500);

  EXPECT_CALL(state_change_callback_observer_, StateChangeCallback)
      .Times(AtLeast(1));
  // Initial attempt: StartCsr answers TEMPORARY_UNAVAILABLE.
  {
    testing::InSequence seq;

    EXPECT_PREPARE_KEY_OK(*mock_tpm_challenge_key,
                          StartPrepareKeyStep(::attestation::ENTERPRISE_USER,
                                              /*will_register_key=*/true,
                                              ::attestation::KEY_TYPE_RSA,
                                              GetKeyName(kCertProfileId),
                                              /*profile=*/_,
                                              /*callback=*/_, /*signals=*/_));

    EXPECT_START_CSR_TEMPORARY_UNAVAILABLE(
        StartCsr(Eq(std::ref(provisioning_process)), /*callback=*/_));
    worker.DoStep();
  }

  Mock::VerifyAndClearExpectations(&cert_provisioning_client_);

  // Retry 1.
  {
    EXPECT_START_CSR_TEMPORARY_UNAVAILABLE(
        StartCsr(Eq(std::ref(provisioning_process)), /*callback=*/_));
    FastForwardBy(next_delay + small_delay * 10);
    next_delay *= 2;
  }

  Mock::VerifyAndClearExpectations(&cert_provisioning_client_);

  // Retry 2.
  {
    EXPECT_START_CSR_TEMPORARY_UNAVAILABLE(
        StartCsr(Eq(std::ref(provisioning_process)), /*callback=*/_));
    FastForwardBy(next_delay + small_delay * 10);
    next_delay *= 2;
  }

  Mock::VerifyAndClearExpectations(&cert_provisioning_client_);

  // Retry 3. `next_delay` is not read after this point, so it is not doubled
  // again.
  {
    EXPECT_START_CSR_TEMPORARY_UNAVAILABLE(
        StartCsr(Eq(std::ref(provisioning_process)), /*callback=*/_));
    FastForwardBy(next_delay + small_delay);
  }
}

// Checks that when the server returns a TEMPORARY_UNAVAILABLE status code, the
// worker records it as its last backend server error, and that the error is
// cleared once a subsequent request succeeds.
TEST_F(CertProvisioningWorkerStaticTest, ProcessBackendServerErrorResponse) {
  const CertProfile cert_profile(
      kCertProfileId, kCertProfileName, kCertProfileVersion,
      /*is_va_enabled=*/true, kCertProfileRenewalPeriod,
      ProtocolVersion::kStatic);
  const std::string process_id = GenerateCertProvisioningId();
  const CertProvisioningClient::ProvisioningProcess provisioning_process(
      process_id, CertScope::kUser, kCertProfileId, kCertProfileVersion,
      GetPublicKeyBin());

  MockTpmChallengeKeySubtle* mock_tpm_challenge_key = PrepareTpmChallengeKey();
  CertProvisioningWorkerStatic worker(
      process_id, CertScope::kUser, GetProfile(), &testing_pref_service_,
      cert_profile, &cert_provisioning_client_, MakeInvalidator(),
      GetStateChangeCallback(), GetResultCallback());

  // Here the state change notifications are expected individually and in
  // order (instead of AtLeast(1)), so the backend server error can be
  // inspected at a specific notification.
  {
    testing::InSequence seq;

    EXPECT_PREPARE_KEY_OK(*mock_tpm_challenge_key,
                          StartPrepareKeyStep(::attestation::ENTERPRISE_USER,
                                              /*will_register_key=*/true,
                                              ::attestation::KEY_TYPE_RSA,
                                              GetKeyName(kCertProfileId),
                                              /*profile=*/_,
                                              /*callback=*/_, /*signals=*/_));
    // Notification between key preparation and StartCsr.
    EXPECT_CALL(state_change_callback_observer_, StateChangeCallback());

    EXPECT_START_CSR_TEMPORARY_UNAVAILABLE(
        StartCsr(Eq(std::ref(provisioning_process)), /*callback=*/_));
    // Notification after the TEMPORARY_UNAVAILABLE response: the worker must
    // expose a backend server error by now.
    EXPECT_CALL(state_change_callback_observer_, StateChangeCallback())
        .WillOnce([&worker]() {
          EXPECT_THAT(worker.GetLastBackendServerError(),
                      testing::Ne(std::nullopt));
        });
    worker.DoStep();
  }

  Mock::VerifyAndClearExpectations(&cert_provisioning_client_);

  // Call DoStep() directly (instead of fast-forwarding to the scheduled retry)
  // with a successful StartCsr; the next notification must see the error
  // cleared.
  {
    testing::InSequence seq;
    EXPECT_START_CSR_OK(
        StartCsr(Eq(std::ref(provisioning_process)), /*callback=*/_),
        em::HashingAlgorithm::SHA256);
    EXPECT_CALL(state_change_callback_observer_, StateChangeCallback())
        .WillOnce([&worker]() {
          EXPECT_THAT(worker.GetLastBackendServerError(),
                      testing::Eq(std::nullopt));
        });
    worker.DoStep();
  }
}

// Checks that no backend server error is reported after a successful StartCsr
// request.
TEST_F(CertProvisioningWorkerStaticTest, ClearBackendServerError) {
  const CertProfile cert_profile(
      kCertProfileId, kCertProfileName, kCertProfileVersion,
      /*is_va_enabled=*/true, kCertProfileRenewalPeriod,
      ProtocolVersion::kStatic);
  const std::string process_id = GenerateCertProvisioningId();
  const CertProvisioningClient::ProvisioningProcess provisioning_process(
      process_id, CertScope::kUser, kCertProfileId, kCertProfileVersion,
      GetPublicKeyBin());

  MockTpmChallengeKeySubtle* mock_tpm_challenge_key = PrepareTpmChallengeKey();
  CertProvisioningWorkerStatic worker(
      process_id, CertScope::kUser, GetProfile(), &testing_pref_service_,
      cert_profile, &cert_provisioning_client_, MakeInvalidator(),
      GetStateChangeCallback(), GetResultCallback());

  EXPECT_CALL(state_change_callback_observer_, StateChangeCallback)
      .Times(AtLeast(1));
  {
    testing::InSequence seq;

    EXPECT_PREPARE_KEY_OK(*mock_tpm_challenge_key,
                          StartPrepareKeyStep(::attestation::ENTERPRISE_USER,
                                              /*will_register_key=*/true,
                                              ::attestation::KEY_TYPE_RSA,
                                              GetKeyName(kCertProfileId),
                                              /*profile=*/_,
                                              /*callback=*/_, /*signals=*/_));

    EXPECT_START_CSR_OK(
        StartCsr(Eq(std::ref(provisioning_process)), /*callback=*/_),
        em::HashingAlgorithm::SHA256);
  }
  // The ordering set up above still applies after `seq` goes out of scope.
  worker.DoStep();

  Mock::VerifyAndClearExpectations(&cert_provisioning_client_);
  // A successful StartCsr must leave no backend server error behind.
  EXPECT_THAT(worker.GetLastBackendServerError(), testing::Eq(std::nullopt));
}

// Checks that the worker removes the key when an error occurs after the key
// has been registered.
TEST_F(CertProvisioningWorkerStaticTest, RemoveRegisteredKey) {
  // Result histogram inspected at the end of the test.
  constexpr char kResultHistogram[] = "ChromeOS.CertProvisioning.Result.User";
  base::HistogramTester histogram_tester;

  const CertProfile cert_profile(
      kCertProfileId, kCertProfileName, kCertProfileVersion,
      /*is_va_enabled=*/true, kCertProfileRenewalPeriod,
      ProtocolVersion::kStatic);
  const std::string process_id = GenerateCertProvisioningId();
  const std::string listener_type = MakeInvalidationListenerType(process_id);
  const CertProvisioningClient::ProvisioningProcess provisioning_process(
      process_id, CertScope::kUser, kCertProfileId, kCertProfileVersion,
      GetPublicKeyBin());

  MockTpmChallengeKeySubtle* mock_tpm_challenge_key = PrepareTpmChallengeKey();
  MockCertProvisioningInvalidator* mock_invalidator = nullptr;
  CertProvisioningWorkerStatic worker(
      process_id, CertScope::kUser, GetProfile(), &testing_pref_service_,
      cert_profile, &cert_provisioning_client_,
      MakeInvalidator(&mock_invalidator), GetStateChangeCallback(),
      GetResultCallback());

  EXPECT_CALL(state_change_callback_observer_, StateChangeCallback)
      .Times(AtLeast(1));
  {
    testing::InSequence seq;

    EXPECT_PREPARE_KEY_OK(*mock_tpm_challenge_key,
                          StartPrepareKeyStep(::attestation::ENTERPRISE_USER,
                                              /*will_register_key=*/true,
                                              ::attestation::KEY_TYPE_RSA,
                                              GetKeyName(kCertProfileId),
                                              /*profile=*/_,
                                              /*callback=*/_, /*signals=*/_));

    EXPECT_START_CSR_OK(
        StartCsr(Eq(std::ref(provisioning_process)), /*callback=*/_),
        em::HashingAlgorithm::SHA256);

    EXPECT_CALL(*mock_invalidator,
                Register(kInvalidationTopic, listener_type, _))
        .Times(1);

    EXPECT_SIGN_CHALLENGE_OK(*mock_tpm_challenge_key,
                             StartSignChallengeStep(kChallenge,
                                                    /*callback=*/_));

    EXPECT_REGISTER_KEY_OK(*mock_tpm_challenge_key, StartRegisterKeyStep);

    EXPECT_CALL(*key_permissions_manager_,
                AllowKeyForUsage(/*callback=*/_, KeyUsage::kCorporate,
                                 GetPublicKeyBin()));

    // The failure is injected here, after the key has been registered.
    EXPECT_SET_ATTRIBUTE_FOR_KEY_FAIL(
        SetAttributeForKey(TokenId::kUser, GetPublicKeyBin(),
                           KeyAttributeType::kCertificateProvisioningId,
                           GetCertProfileIdBin(), _));

    // Cleanup: unregister from invalidations, then remove the registered key.
    EXPECT_CALL(*mock_invalidator, Unregister()).Times(1);

    // A single WillOnce() already implies exactly one call.
    EXPECT_CALL(
        *platform_keys_service_,
        RemoveKey(TokenId::kUser,
                  /*public_key_spki_der=*/GetPublicKeyBin(), /*callback=*/_))
        .WillOnce(RunOnceCallback<2>(Status::kSuccess));
  }

  worker.DoStep();
  FastForwardBy(base::Seconds(1));

  EXPECT_EQ(callback_observer_.Get<CertProfile>(), cert_profile);
  EXPECT_EQ(callback_observer_.Get<CertProvisioningWorkerState>(),
            CertProvisioningWorkerState::kFailed);

  // Exactly two samples: the final kFailed result and kKeyRegistered.
  histogram_tester.ExpectBucketCount(kResultHistogram,
                                     CertProvisioningWorkerState::kFailed, 1);
  histogram_tester.ExpectBucketCount(
      kResultHistogram, CertProvisioningWorkerState::kKeyRegistered, 1);
  histogram_tester.ExpectTotalCount(kResultHistogram, 2);
}

// Checks that the worker reset flag is raised once the worker is marked for
// reset.
TEST_F(CertProvisioningWorkerStaticTest, ResetWorker) {
  const CertScope kCertScope = CertScope::kDevice;
  const CertProfile cert_profile(
      kCertProfileId, kCertProfileName, kCertProfileVersion,
      /*is_va_enabled=*/true, kCertProfileRenewalPeriod,
      ProtocolVersion::kStatic);
  const std::string process_id = GenerateCertProvisioningId();

  auto worker = CertProvisioningWorkerFactory::Get()->Create(
      process_id, kCertScope, GetProfile(), &testing_pref_service_,
      cert_profile, &cert_provisioning_client_, MakeInvalidator(),
      GetStateChangeCallback(), GetResultCallback());

  // A freshly created worker must not be marked yet; otherwise the check
  // below would pass even if MarkWorkerForReset() did nothing.
  EXPECT_FALSE(worker->IsWorkerMarkedForReset());

  worker->MarkWorkerForReset();
  // Booleans are checked with EXPECT_TRUE rather than compared to `true`.
  EXPECT_TRUE(worker->IsWorkerMarkedForReset());
}

// Test helper that watches a single preference on a PrefService. Every change
// notification is forwarded to the mockable OnPrefValueUpdated() method, so
// tests can set expectations on the exact value written to the preference.
class PrefServiceObserver {
 public:
  // `service` is dereferenced on every pref change, so it must outlive this
  // observer. `pref_name` is copied and does not need to outlive the call.
  PrefServiceObserver(PrefService* service, const char* pref_name)
      : service_(service), pref_name_(pref_name) {
    pref_change_registrar_.Init(service);
    pref_change_registrar_.Add(
        pref_name_, base::BindRepeating(&PrefServiceObserver::OnPrefsChange,
                                        weak_factory_.GetWeakPtr()));
  }

  // Called by `pref_change_registrar_` when the observed pref changes. Reads
  // the current value and reports it through OnPrefValueUpdated().
  void OnPrefsChange() {
    const base::Value& pref_value = service_->GetValue(pref_name_);
    OnPrefValueUpdated(pref_value);
  }

  // Allows to add expectations about preference changes and verify new values.
  MOCK_METHOD(void, OnPrefValueUpdated, (const base::Value& value));

 private:
  raw_ptr<PrefService> service_ = nullptr;
  // Owned copy of the observed pref name. Keeping the caller's `const char*`
  // would dangle if it pointed into a temporary string. Declared before
  // `pref_change_registrar_` so it is initialized before the constructor
  // body uses it.
  std::string pref_name_;
  PrefChangeRegistrar pref_change_registrar_;
  // Kept as the last member so weak pointers are invalidated before the other
  // members are destroyed.
  base::WeakPtrFactory<PrefServiceObserver> weak_factory_{this};
};

// Walks a user-scope static worker through the whole flow while checking what
// it serializes into prefs. Twice, the worker is replaced by one created with
// CertProvisioningWorkerFactory::Deserialize() from the last serialized dict:
//  1. Prepare key, then StartCsr (EXPECT_START_CSR_NO_OP). The pref entry has
//     an empty "invalidation_topic" and "state": 1.
//  2. Deserialize. The restored worker calls RestorePreparedKeyState().
//  3. StartCsr succeeds, the invalidator is registered, the challenge is
//     signed, the key is registered, allowed for corporate usage and tagged,
//     the data is signed and FinishCsr is sent. The pref entry now has the
//     invalidation topic and "state": 7.
//  4. Deserialize again. This time Register() is expected on the new
//     invalidator.
//  5. DownloadCert and ImportCertificate. The pref becomes an empty dict, the
//     invalidator is unregistered and the result callback reports kSucceeded.
// NOTE(review): the numeric "state" values are serialized
// CertProvisioningWorkerState values; confirm which enumerators they map to.
TEST_F(CertProvisioningWorkerStaticTest, SerializationSuccess) {
  const base::TimeDelta kRenewalPeriod = base::Seconds(1200300);
  const CertProfile cert_profile(
      kCertProfileId, kCertProfileName, kCertProfileVersion,
      /*is_va_enabled=*/true, kRenewalPeriod, ProtocolVersion::kStatic);
  const CertScope kCertScope = CertScope::kUser;
  const std::string process_id = GenerateCertProvisioningId();
  const std::string listener_type = MakeInvalidationListenerType(process_id);
  const CertProvisioningClient::ProvisioningProcess provisioning_process(
      process_id, kCertScope, kCertProfileId, kCertProfileVersion,
      GetPublicKeyBin());

  // `mock_invalidator_obj` is only used for the second recreation. There the
  // invalidator is created up front so the Register() expectation can be set
  // on it before Deserialize() runs. `mock_invalidator` always points at the
  // invalidator handed to the current worker.
  std::unique_ptr<MockCertProvisioningInvalidator> mock_invalidator_obj;
  MockCertProvisioningInvalidator* mock_invalidator = nullptr;

  MockTpmChallengeKeySubtle* mock_tpm_challenge_key = PrepareTpmChallengeKey();
  std::unique_ptr<CertProvisioningWorker> worker =
      CertProvisioningWorkerFactory::Get()->Create(
          process_id, kCertScope, GetProfile(), &testing_pref_service_,
          cert_profile, &cert_provisioning_client_, MakeInvalidator(),
          GetStateChangeCallback(), GetResultCallback());

  // StrictMock: any pref update not matched by an expectation below fails
  // the test.
  StrictMock<PrefServiceObserver> pref_observer(
      &testing_pref_service_, GetPrefNameForSerialization(kCertScope));
  // Expected pref contents. Each Deserialize() call below reads the per-profile
  // dict from the value assigned in the preceding step.
  base::Value::Dict pref_val;

  EXPECT_CALL(state_change_callback_observer_, StateChangeCallback)
      .Times(AtLeast(1));

  // Prepare key, send start csr request.
  {
    testing::InSequence seq;

    EXPECT_PREPARE_KEY_OK(*mock_tpm_challenge_key,
                          StartPrepareKeyStep(::attestation::ENTERPRISE_USER,
                                              /*will_register_key=*/true,
                                              ::attestation::KEY_TYPE_RSA,
                                              GetKeyName(kCertProfileId),
                                              /*profile=*/_,
                                              /*callback=*/_, /*signals=*/_));

    // "renewal_period" is present because the profile uses kRenewalPeriod.
    pref_val = ParseJsonDict(base::StringPrintf(
        R"({
          "cert_profile_1": {
            "cert_profile": {
              "policy_version": "cert_profile_version_1",
              "name": "Certificate Profile 1",
              "profile_id": "cert_profile_1",
              "va_enabled": true,
              "renewal_period": 1200300
            },
            "cert_scope": 0,
            "invalidation_topic": "",
            "process_id": "%s",
            "public_key": "%s",
            "state": 1
          }
        })",
        process_id.c_str(), kPublicKeyBase64));
    EXPECT_CALL(pref_observer, OnPrefValueUpdated(IsJson(pref_val))).Times(1);

    EXPECT_START_CSR_NO_OP(
        StartCsr(Eq(std::ref(provisioning_process)), /*callback=*/_));

    worker->DoStep();
  }

  // Recreate worker.
  // The serialized "invalidation_topic" is empty here and, unlike the second
  // recreation below, no Register() call is expected. MakeInvalidator() stores
  // the new invalidator in `mock_invalidator` for the next step.
  {
    testing::InSequence seq;

    mock_tpm_challenge_key = PrepareTpmChallengeKey();

    EXPECT_CALL(*mock_tpm_challenge_key,
                RestorePreparedKeyState(
                    ::attestation::ENTERPRISE_USER,
                    /*will_register_key=*/true, ::attestation::KEY_TYPE_RSA,
                    GetKeyName(kCertProfileId), GetPublicKey(), /*profile=*/_))
        .Times(1);

    worker = CertProvisioningWorkerFactory::Get()->Deserialize(
        kCertScope, GetProfile(), &testing_pref_service_,
        *pref_val.FindDict(kCertProfileId), &cert_provisioning_client_,
        MakeInvalidator(&mock_invalidator), GetStateChangeCallback(),
        GetResultCallback());
  }

  // Retry start csr request, receive response, try sign challenge.
  {
    testing::InSequence seq;

    EXPECT_START_CSR_OK(
        StartCsr(Eq(std::ref(provisioning_process)), /*callback=*/_),
        em::HashingAlgorithm::SHA256);

    // NOTE(review): an empty-dict pref update is expected between StartCsr and
    // Register(); confirm which worker transition produces it.
    pref_val = ParseJsonDict("{}");
    EXPECT_CALL(pref_observer, OnPrefValueUpdated(IsJson(pref_val))).Times(1);

    EXPECT_CALL(*mock_invalidator,
                Register(kInvalidationTopic, listener_type, _))
        .Times(1);

    EXPECT_SIGN_CHALLENGE_OK(*mock_tpm_challenge_key,
                             StartSignChallengeStep(kChallenge,
                                                    /*callback=*/_));

    EXPECT_REGISTER_KEY_OK(*mock_tpm_challenge_key, StartRegisterKeyStep);

    EXPECT_CALL(*key_permissions_manager_,
                AllowKeyForUsage(/*callback=*/_, KeyUsage::kCorporate,
                                 GetPublicKeyBin()));

    EXPECT_SET_ATTRIBUTE_FOR_KEY_OK(
        SetAttributeForKey(TokenId::kUser, GetPublicKeyBin(),
                           KeyAttributeType::kCertificateProvisioningId,
                           GetCertProfileIdBin(), _));

    EXPECT_SIGN_RSAPKC1_DIGEST_OK(
        SignRsaPkcs1(::testing::Optional(TokenId::kUser), GetDataToSign(),
                     GetPublicKeyBin(), HashAlgorithm::HASH_ALGORITHM_SHA256,
                     /*callback=*/_));

    EXPECT_FINISH_CSR_OK(FinishCsr(Eq(std::ref(provisioning_process)),
                                   kChallengeResponse, GetSignatureStr(),
                                   /*callback=*/_));

    // After FinishCsr the serialized entry carries the invalidation topic.
    pref_val = ParseJsonDict(base::StringPrintf(
        R"({
          "cert_profile_1": {
            "cert_profile": {
              "policy_version": "cert_profile_version_1",
              "name": "Certificate Profile 1",
              "profile_id": "cert_profile_1",
              "va_enabled": true,
              "renewal_period": 1200300
            },
            "cert_scope": 0,
            "invalidation_topic": "fake_invalidation_topic_1",
            "process_id": "%s",
            "public_key": "%s",
            "state": 7
          }
        })",
        process_id.c_str(), kPublicKeyBase64));
    EXPECT_CALL(pref_observer, OnPrefValueUpdated(IsJson(pref_val))).Times(1);

    worker->DoStep();
  }

  // Recreate worker.
  // The serialized entry now has a non-empty invalidation topic, and Register()
  // is expected on the new invalidator. That expectation is set before
  // Deserialize() runs.
  {
    testing::InSequence seq;

    mock_invalidator_obj = MakeInvalidator(&mock_invalidator);
    EXPECT_CALL(*mock_invalidator,
                Register(kInvalidationTopic, listener_type, _))
        .Times(1);

    mock_tpm_challenge_key = PrepareTpmChallengeKey();
    EXPECT_CALL(*mock_tpm_challenge_key,
                RestorePreparedKeyState(
                    ::attestation::ENTERPRISE_USER,
                    /*will_register_key=*/true, ::attestation::KEY_TYPE_RSA,
                    GetKeyName(kCertProfileId), GetPublicKey(), /*profile=*/_))
        .Times(1);

    worker = CertProvisioningWorkerFactory::Get()->Deserialize(
        kCertScope, GetProfile(), &testing_pref_service_,
        *pref_val.FindDict(kCertProfileId), &cert_provisioning_client_,
        std::move(mock_invalidator_obj), GetStateChangeCallback(),
        GetResultCallback());
  }

  // Retry download cert request, receive response, try import certificate.
  {
    testing::InSequence seq;

    EXPECT_DOWNLOAD_CERT_OK(
        DownloadCert(Eq(std::ref(provisioning_process)), /*callback=*/_),
        kFakeCertificate);

    EXPECT_IMPORT_CERTIFICATE_OK(
        ImportCertificate(TokenId::kUser, /*certificate=*/_, /*callback=*/_));

    // On success the serialized entry is removed (the pref becomes "{}").
    pref_val = ParseJsonDict("{}");
    EXPECT_CALL(pref_observer, OnPrefValueUpdated(IsJson(pref_val))).Times(1);

    EXPECT_CALL(*mock_invalidator, Unregister()).Times(1);

    worker->DoStep();

    EXPECT_EQ(callback_observer_.Get<CertProfile>(), cert_profile);
    EXPECT_EQ(callback_observer_.Get<CertProvisioningWorkerState>(),
              CertProvisioningWorkerState::kSucceeded);
  }
}

// Checks the flow when StartCsr fails with a CA error
// (EXPECT_START_CSR_CA_ERROR). The worker first serializes its state after key
// preparation, then writes an empty dict to the pref. The result callback
// reports kFailed, and VerifyDeleteKeyCalledOnce() is checked after the
// failure.
TEST_F(CertProvisioningWorkerStaticTest, SerializationOnFailure) {
  const CertScope kCertScope = CertScope::kUser;
  const CertProfile cert_profile(
      kCertProfileId, kCertProfileName, kCertProfileVersion,
      /*is_va_enabled=*/true, kCertProfileRenewalPeriod,
      ProtocolVersion::kStatic);

  const std::string process_id = GenerateCertProvisioningId();
  MockTpmChallengeKeySubtle* mock_tpm_challenge_key = PrepareTpmChallengeKey();
  auto worker = CertProvisioningWorkerFactory::Get()->Create(
      process_id, kCertScope, GetProfile(), &testing_pref_service_,
      cert_profile, &cert_provisioning_client_, MakeInvalidator(),
      GetStateChangeCallback(), GetResultCallback());

  // Both expected pref writes below are pinned in order by InSequence.
  PrefServiceObserver pref_observer(&testing_pref_service_,
                                    GetPrefNameForSerialization(kCertScope));
  base::Value::Dict pref_val;

  EXPECT_CALL(state_change_callback_observer_, StateChangeCallback)
      .Times(AtLeast(1));
  {
    testing::InSequence seq;

    EXPECT_PREPARE_KEY_OK(*mock_tpm_challenge_key,
                          StartPrepareKeyStep(::attestation::ENTERPRISE_USER,
                                              /*will_register_key=*/true,
                                              ::attestation::KEY_TYPE_RSA,
                                              GetKeyName(kCertProfileId),
                                              /*profile=*/_,
                                              /*callback=*/_, /*signals=*/_));

    // NOTE(review): no "renewal_period" key is expected here; presumably
    // kCertProfileRenewalPeriod is a value that is not serialized (e.g.
    // zero). Confirm.
    pref_val = ParseJsonDict(base::StringPrintf(
        R"({
          "cert_profile_1": {
            "cert_profile": {
              "policy_version": "cert_profile_version_1",
              "name": "Certificate Profile 1",
              "profile_id": "cert_profile_1",
              "va_enabled": true
            },
            "cert_scope": 0,
            "invalidation_topic": "",
            "process_id": "%s",
            "public_key": "%s",
            "state": 1
          }
        })",
        process_id.c_str(), kPublicKeyBase64));
    EXPECT_CALL(pref_observer, OnPrefValueUpdated(IsJson(pref_val))).Times(1);

    EXPECT_START_CSR_CA_ERROR(StartCsr);

    // After the CA error the serialized entry is removed.
    pref_val = ParseJsonDict("{}");
    EXPECT_CALL(pref_observer, OnPrefValueUpdated(IsJson(pref_val))).Times(1);
  }

  worker->DoStep();
  // NOTE(review): presumably lets delayed work posted on the failure path run
  // before key deletion is verified. Confirm.
  FastForwardBy(base::Seconds(1));

  VerifyDeleteKeyCalledOnce(kCertScope);

  EXPECT_EQ(callback_observer_.Get<CertProfile>(), cert_profile);
  EXPECT_EQ(callback_observer_.Get<CertProvisioningWorkerState>(),
            CertProvisioningWorkerState::kFailed);
}

// Checks the informational getters (GetState, GetPreviousState,
// GetCertProfile, GetPublicKey) at two points:
//  - after a step where StartCsr answers "try later": state kKeypairGenerated,
//    previous state kInitState;
//  - after a following step where StartCsr fails with a CA error: state
//    kFailed, previous state kKeypairGenerated.
// The profile and public key are expected to be unchanged in both cases.
TEST_F(CertProvisioningWorkerStaticTest, InformationalGetters) {
  const CertScope kCertScope = CertScope::kUser;
  const CertProfile cert_profile(
      kCertProfileId, kCertProfileName, kCertProfileVersion,
      /*is_va_enabled=*/true, kCertProfileRenewalPeriod,
      ProtocolVersion::kStatic);
  const std::string process_id = GenerateCertProvisioningId();
  MockTpmChallengeKeySubtle* mock_tpm_challenge_key = PrepareTpmChallengeKey();
  // Constructed directly as CertProvisioningWorkerStatic rather than through
  // CertProvisioningWorkerFactory.
  CertProvisioningWorkerStatic worker(
      process_id, kCertScope, GetProfile(), &testing_pref_service_,
      cert_profile, &cert_provisioning_client_, MakeInvalidator(),
      GetStateChangeCallback(), GetResultCallback());

  EXPECT_CALL(state_change_callback_observer_, StateChangeCallback)
      .Times(AtLeast(1));
  {
    testing::InSequence seq;

    EXPECT_PREPARE_KEY_OK(*mock_tpm_challenge_key, StartPrepareKeyStep);

    // Try-later delay of 30 seconds (in milliseconds) returned by the server.
    EXPECT_START_CSR_TRY_LATER(StartCsr, base::Seconds(30).InMilliseconds());

    worker.DoStep();
    EXPECT_EQ(worker.GetState(),
              CertProvisioningWorkerState::kKeypairGenerated);
    EXPECT_EQ(worker.GetPreviousState(),
              CertProvisioningWorkerState::kInitState);
    EXPECT_EQ(worker.GetCertProfile(), cert_profile);
    EXPECT_EQ(worker.GetPublicKey(), GetPublicKeyBin());
  }

  {
    testing::InSequence seq;

    EXPECT_START_CSR_CA_ERROR(StartCsr);

    worker.DoStep();
    FastForwardBy(base::Seconds(1));

    VerifyDeleteKeyCalledOnce(kCertScope);

    EXPECT_EQ(worker.GetState(), CertProvisioningWorkerState::kFailed);
    EXPECT_EQ(worker.GetPreviousState(),
              CertProvisioningWorkerState::kKeypairGenerated);
    EXPECT_EQ(worker.GetCertProfile(), cert_profile);
    EXPECT_EQ(worker.GetPublicKey(), GetPublicKeyBin());

    EXPECT_EQ(callback_observer_.Get<CertProfile>(), cert_profile);
    EXPECT_EQ(callback_observer_.Get<CertProvisioningWorkerState>(),
              CertProvisioningWorkerState::kFailed);
  }
}

// Checks what happens when a device-scope worker is stopped with kCanceled
// after key preparation:
//  - its serialized pref entry is replaced by an empty dict;
//  - VerifyDeleteKeyCalledOnce() is checked;
//  - the result callback reports kCanceled;
//  - exactly one kCanceled sample is recorded in
//    "ChromeOS.CertProvisioning.Result.Device".
TEST_F(CertProvisioningWorkerStaticTest, CancelDeviceWorker) {
  base::HistogramTester histogram_tester;

  const CertScope kCertScope = CertScope::kDevice;
  const CertProfile cert_profile(
      kCertProfileId, kCertProfileName, kCertProfileVersion,
      /*is_va_enabled=*/true, kCertProfileRenewalPeriod,
      ProtocolVersion::kStatic);

  EXPECT_CALL(state_change_callback_observer_, StateChangeCallback)
      .Times(AtLeast(1));
  const std::string process_id = GenerateCertProvisioningId();
  MockTpmChallengeKeySubtle* mock_tpm_challenge_key = PrepareTpmChallengeKey();
  auto worker = CertProvisioningWorkerFactory::Get()->Create(
      process_id, kCertScope, GetProfile(), &testing_pref_service_,
      cert_profile, &cert_provisioning_client_, MakeInvalidator(),
      GetStateChangeCallback(), GetResultCallback());

  PrefServiceObserver pref_observer(&testing_pref_service_,
                                    GetPrefNameForSerialization(kCertScope));
  base::Value::Dict pref_val;

  // Prepare a machine key and leave StartCsr pending (EXPECT_START_CSR_NO_OP).
  {
    testing::InSequence seq;

    EXPECT_PREPARE_KEY_OK(
        *mock_tpm_challenge_key,
        StartPrepareKeyStep(::attestation::ENTERPRISE_MACHINE,
                            /*will_register_key=*/true,
                            ::attestation::KEY_TYPE_RSA,
                            /*key_name=*/GetKeyName(kCertProfileId),
                            /*profile=*/_,
                            /*callback=*/_, /*signals=*/_));

    // Device scope serializes "cert_scope": 1 (the user-scope tests expect 0).
    pref_val = ParseJsonDict(base::StringPrintf(
        R"({
          "cert_profile_1": {
            "cert_profile": {
              "policy_version": "cert_profile_version_1",
              "name": "Certificate Profile 1",
              "profile_id": "cert_profile_1",
              "va_enabled": true
            },
            "cert_scope": 1,
            "invalidation_topic": "",
            "process_id": "%s",
            "public_key": "%s",
            "state": 1
          }
        })",
        process_id.c_str(), kPublicKeyBase64));
    EXPECT_CALL(pref_observer, OnPrefValueUpdated(IsJson(pref_val))).Times(1);

    EXPECT_START_CSR_NO_OP(StartCsr);

    worker->DoStep();
  }

  // Cancel the worker; its serialized entry is expected to be removed.
  {
    pref_val = ParseJsonDict("{}");
    EXPECT_CALL(pref_observer, OnPrefValueUpdated(IsJson(pref_val))).Times(1);

    worker->Stop(CertProvisioningWorkerState::kCanceled);

    // NOTE(review): presumably lets the cancellation's deferred cleanup run
    // before key deletion is verified. Confirm.
    FastForwardBy(base::Seconds(1));

    VerifyDeleteKeyCalledOnce(kCertScope);
  }

  EXPECT_EQ(callback_observer_.Get<CertProfile>(), cert_profile);
  EXPECT_EQ(callback_observer_.Get<CertProvisioningWorkerState>(),
            CertProvisioningWorkerState::kCanceled);

  histogram_tester.ExpectUniqueSample("ChromeOS.CertProvisioning.Result.Device",
                                      CertProvisioningWorkerState::kCanceled,
                                      1);
}
}  // namespace
}  // namespace ash::cert_provisioning