chromium/chrome/browser/ash/cert_provisioning/cert_provisioning_client_unittest.cc

// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/ash/cert_provisioning/cert_provisioning_client.h"

#include <string>

#include "base/test/task_environment.h"
#include "base/test/test_future.h"
#include "base/types/expected.h"
#include "components/policy/core/common/cloud/cloud_policy_client.h"
#include "components/policy/core/common/cloud/mock_cloud_policy_client.h"
#include "components/policy/proto/device_management_backend.pb.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace em = enterprise_management;

using testing::_;
using testing::Invoke;
using testing::SizeIs;

namespace ash::cert_provisioning {

namespace {

MATCHER_P(EqualsProto,
          message,
          "Match a proto Message equal to the matcher's argument.") {
  std::string expected_serialized, actual_serialized;
  message.SerializeToString(&expected_serialized);
  arg.SerializeToString(&actual_serialized);
  return expected_serialized == actual_serialized;
}

// A fake CloudPolicyClient that can record cert provisioning actions and
// provides the test a way to supply a response by saving the callbacks passed
// by the code-under-test.
class FakeCloudPolicyClient : public policy::MockCloudPolicyClient {
 public:
  struct CertProvCall {
    em::ClientCertificateProvisioningRequest request;
    ClientCertProvisioningRequestCallback callback;
  };

  FakeCloudPolicyClient() {
    EXPECT_CALL(*this, ClientCertProvisioningRequest)
        .WillRepeatedly(Invoke(
            this, &FakeCloudPolicyClient::OnClientCertProvisioningRequest));
  }

  std::vector<CertProvCall>& cert_prov_calls() { return cert_prov_calls_; }

 private:
  void OnClientCertProvisioningRequest(
      em::ClientCertificateProvisioningRequest request,
      ClientCertProvisioningRequestCallback callback) {
    cert_prov_calls_.push_back({std::move(request), std::move(callback)});
  }

  std::vector<CertProvCall> cert_prov_calls_;
};

// A TestFuture that supports waiting for a
// CertProvisioningClient::StartCallback.
using StartFuture = base::test::TestFuture<
    base::expected<em::CertProvStartResponse, CertProvisioningClient::Error>>;

// A TestFuture that supports waiting for a
// CertProvisioningClient::NextInstructionCallback.
using NextInstructionFuture = base::test::TestFuture<
    base::expected<em::CertProvGetNextInstructionResponse,
                   CertProvisioningClient::Error>>;

// A callback that provides no data but could provide an Error.
// Currently this matches AuthorizeCallback and UploadProofOfPossessionCallback.
using NoDataCallback = base::OnceCallback<void(
    base::expected<void, CertProvisioningClient::Error> result)>;

// A TestFuture that supports waiting for a NoDataCallback (see above).
using NoDataFuture =
    base::test::TestFuture<base::expected<void, CertProvisioningClient::Error>>;

// A TestFuture that supports waiting for a
// CertProvisioningClient::StartCsrCallback.
class StartCsrFuture
    : public base::test::TestFuture<
          policy::DeviceManagementStatus,
          std::optional<em::ClientCertificateProvisioningResponse::Error>,
          std::optional<int64_t>,
          std::string,
          std::string,
          em::HashingAlgorithm,
          std::vector<uint8_t>> {
 public:
  CertProvisioningClient::StartCsrCallback GetStartCsrCallback() {
    return GetCallback<
        policy::DeviceManagementStatus,
        std::optional<em::ClientCertificateProvisioningResponse::Error>,
        std::optional<int64_t>, const std::string&, const std::string&,
        em::HashingAlgorithm, std::vector<uint8_t>>();
  }

  policy::DeviceManagementStatus GetStatus() { return Get<0>(); }

  std::optional<em::ClientCertificateProvisioningResponse::Error> GetError() {
    return Get<1>();
  }

  std::optional<int64_t> GetTryLater() { return Get<2>(); }

  const std::string& GetInvalidationTopic() { return Get<3>(); }

  const std::string& GetVaChallenge() { return Get<4>(); }

  em::HashingAlgorithm GetHashingAlgorithm() { return Get<5>(); }

  const std::vector<uint8_t>& GetDataToSign() { return Get<6>(); }
};

// A TestFuture that supports waiting for a
// CertProvisioningClient::FinishCsrCallback.
class FinishCsrFuture
    : public base::test::TestFuture<
          policy::DeviceManagementStatus,
          std::optional<em::ClientCertificateProvisioningResponse::Error>,
          std::optional<int64_t>> {
 public:
  CertProvisioningClient::FinishCsrCallback GetFinishCsrCallback() {
    return GetCallback();
  }

  policy::DeviceManagementStatus GetStatus() { return Get<0>(); }

  std::optional<em::ClientCertificateProvisioningResponse::Error> GetError() {
    return Get<1>();
  }

  std::optional<int64_t> GetTryLater() { return Get<2>(); }
};

// A TestFuture that supports waiting for a
// CertProvisioningClient::DownloadCertCallback.
class DownloadCertFuture
    : public base::test::TestFuture<
          policy::DeviceManagementStatus,
          std::optional<em::ClientCertificateProvisioningResponse::Error>,
          std::optional<int64_t>,
          std::string> {
 public:
  CertProvisioningClient::DownloadCertCallback GetDownloadCertCallback() {
    return GetCallback<
        policy::DeviceManagementStatus,
        std::optional<em::ClientCertificateProvisioningResponse::Error>,
        std::optional<int64_t>, const std::string&>();
  }

  policy::DeviceManagementStatus GetStatus() { return Get<0>(); }

  std::optional<em::ClientCertificateProvisioningResponse::Error> GetError() {
    return Get<1>();
  }

  std::optional<int64_t> GetTryLater() { return Get<2>(); }

  const std::string& GetPemEncodedCertificate() { return Get<3>(); }
};

}  // namespace

// Tuple of CertScope enum value and corresponding device management protocol
// string.
using CertScopePair = std::tuple<CertScope, std::string>;

// Base class for testing CertProvisioningClient.
// The subclasses will implement different test parameters.
class CertProvisioningClientTestBase : public testing::Test {
 public:
  CertProvisioningClientTestBase() = default;
  ~CertProvisioningClientTestBase() override = default;

  virtual CertScope cert_scope() const = 0;
  virtual const std::string& cert_scope_dm_api_string() const = 0;

 protected:
  base::test::SingleThreadTaskEnvironment task_environment_;

  FakeCloudPolicyClient cloud_policy_client_;

  const std::string kCertProfileId = "fake_cert_profile_id_1";
  const std::string kCertProfileVersion = "fake_cert_profile_version_1";
  const std::vector<uint8_t> kPublicKey = {0x66, 0x61, 0x6B, 0x65,
                                           0x5F, 0x6B, 0x65, 0x79};
  const std::string kPublicKeyAsString =
      std::string(kPublicKey.begin(), kPublicKey.end());

  const std::string kInvalidationTopic = "fake_invalidation_topic_1";
  const std::string kVaChallange = "fake_va_challenge_1";
  const std::string kDataToSignStr = {10, 11, 12, 13, 14};
  const std::vector<uint8_t> kDataToSignBin = {10, 11, 12, 13, 14};
  const em::HashingAlgorithm kHashAlgorithm = em::HashingAlgorithm::SHA256;
  const em::SigningAlgorithm kSignAlgorithm =
      em::SigningAlgorithm::RSA_PKCS1_V1_5;
  const std::string kVaChallengeResponse = "fake_va_challenge_response_1";
  const std::string kSignature = "fake_signature_1";
  const std::string kPemEncodedCert = "fake_pem_encoded_cert_1";
};

// Test fixture for CertProvisioningClient, parametrized by CertScope.
class CertProvisioningClientTest
    : public CertProvisioningClientTestBase,
      public testing::WithParamInterface<CertScopePair> {
 public:
  CertScope cert_scope() const override { return std::get<0>(GetParam()); }

  const std::string& cert_scope_dm_api_string() const override {
    return std::get<1>(GetParam());
  }

  const std::string kCertProvisioningId = GenerateCertProvisioningId();
};

// Checks a successful invocation of Start.
TEST_P(CertProvisioningClientTest, StartSuccess) {
  CertProvisioningClientImpl cert_provisioning_client(cloud_policy_client_);

  StartFuture start_future;
  cert_provisioning_client.Start(
      CertProvisioningClient::ProvisioningProcess(
          kCertProvisioningId, cert_scope(), kCertProfileId,
          kCertProfileVersion, kPublicKey),
      start_future.GetCallback());

  // Expect one request to CloudPolicyClient, verify its contents.
  ASSERT_THAT(cloud_policy_client_.cert_prov_calls(), SizeIs(1));
  FakeCloudPolicyClient::CertProvCall& cert_prov_call =
      cloud_policy_client_.cert_prov_calls().back();
  {
    em::ClientCertificateProvisioningRequest expected_request;
    expected_request.set_certificate_provisioning_process_id(
        kCertProvisioningId);
    expected_request.set_certificate_scope(cert_scope_dm_api_string());
    expected_request.set_cert_profile_id(kCertProfileId);
    expected_request.set_policy_version(kCertProfileVersion);
    expected_request.set_public_key(kPublicKeyAsString);
    // Sets the request type, no actual data is required.
    expected_request.mutable_start_request();

    EXPECT_THAT(cert_prov_call.request, EqualsProto(expected_request));
  }

  // Make CloudPolicyClient answer the request.
  const std::string invalidation_topic = "test";

  em::ClientCertificateProvisioningResponse response;
  response.mutable_start_response()->set_invalidation_topic(invalidation_topic);

  std::move(cert_prov_call.callback).Run(policy::DM_STATUS_SUCCESS, response);

  // Check that CertProvisioningClient has forwarded the answer correctly.
  ASSERT_TRUE(start_future.Get().has_value());
  EXPECT_EQ(start_future.Get().value().invalidation_topic(),
            invalidation_topic);
}

// Checks a successful invocation of GetNextInstruction.
TEST_P(CertProvisioningClientTest, GetNextInstructionSuccess) {
  CertProvisioningClientImpl cert_provisioning_client(cloud_policy_client_);

  NextInstructionFuture next_instruction_future;
  cert_provisioning_client.GetNextInstruction(
      CertProvisioningClient::ProvisioningProcess(
          kCertProvisioningId, cert_scope(), kCertProfileId,
          kCertProfileVersion, kPublicKey),
      next_instruction_future.GetCallback());

  // Expect one request to CloudPolicyClient, verify its contents.
  ASSERT_THAT(cloud_policy_client_.cert_prov_calls(), SizeIs(1));
  FakeCloudPolicyClient::CertProvCall& cert_prov_call =
      cloud_policy_client_.cert_prov_calls().back();
  {
    em::ClientCertificateProvisioningRequest expected_request;
    expected_request.set_certificate_provisioning_process_id(
        kCertProvisioningId);
    expected_request.set_certificate_scope(cert_scope_dm_api_string());
    expected_request.set_cert_profile_id(kCertProfileId);
    expected_request.set_policy_version(kCertProfileVersion);
    expected_request.set_public_key(kPublicKeyAsString);
    // Sets the request type, no actual data is required.
    expected_request.mutable_get_next_instruction_request();

    EXPECT_THAT(cert_prov_call.request, EqualsProto(expected_request));
  }

  // Make CloudPolicyClient answer the request.
  const std::string va_challenge = "test";

  em::CertProvGetNextInstructionResponse next_instruction_response;
  next_instruction_response.mutable_authorize_instruction()->set_va_challenge(
      va_challenge);
  em::ClientCertificateProvisioningResponse response;
  *response.mutable_get_next_instruction_response() = next_instruction_response;

  std::move(cert_prov_call.callback).Run(policy::DM_STATUS_SUCCESS, response);

  // Check that CertProvisioningClient has forwarded the answer correctly.
  ASSERT_TRUE(next_instruction_future.Get().has_value());
  EXPECT_THAT(next_instruction_future.Get().value(),
              EqualsProto(next_instruction_response));
}

// Checks a successful invocation of Authorize.
TEST_P(CertProvisioningClientTest, AuthorizeSuccess) {
  CertProvisioningClientImpl cert_provisioning_client(cloud_policy_client_);

  NoDataFuture no_data_future;
  cert_provisioning_client.Authorize(
      CertProvisioningClient::ProvisioningProcess(
          kCertProvisioningId, cert_scope(), kCertProfileId,
          kCertProfileVersion, kPublicKey),
      kVaChallengeResponse, no_data_future.GetCallback());

  // Expect one request to CloudPolicyClient, verify its contents.
  ASSERT_THAT(cloud_policy_client_.cert_prov_calls(), SizeIs(1));
  FakeCloudPolicyClient::CertProvCall& cert_prov_call =
      cloud_policy_client_.cert_prov_calls().back();
  {
    em::ClientCertificateProvisioningRequest expected_request;
    expected_request.set_certificate_provisioning_process_id(
        kCertProvisioningId);
    expected_request.set_certificate_scope(cert_scope_dm_api_string());
    expected_request.set_cert_profile_id(kCertProfileId);
    expected_request.set_policy_version(kCertProfileVersion);
    expected_request.set_public_key(kPublicKeyAsString);

    auto* authorize_request = expected_request.mutable_authorize_request();
    authorize_request->set_va_challenge_response(kVaChallengeResponse);

    EXPECT_THAT(cert_prov_call.request, EqualsProto(expected_request));
  }

  // Make CloudPolicyClient answer the request.
  em::ClientCertificateProvisioningResponse response;
  response.mutable_authorize_response();

  std::move(cert_prov_call.callback).Run(policy::DM_STATUS_SUCCESS, response);

  // Check that the response has no error.
  EXPECT_TRUE(no_data_future.Get().has_value());
}

// Checks a successful invocation of UploadProofOfPossession.
TEST_P(CertProvisioningClientTest, UploadProofOfPossessionSuccess) {
  CertProvisioningClientImpl cert_provisioning_client(cloud_policy_client_);

  NoDataFuture no_data_future;
  cert_provisioning_client.UploadProofOfPossession(
      CertProvisioningClient::ProvisioningProcess(
          kCertProvisioningId, cert_scope(), kCertProfileId,
          kCertProfileVersion, kPublicKey),
      kSignature, no_data_future.GetCallback());

  // Expect one request to CloudPolicyClient, verify its contents.
  ASSERT_THAT(cloud_policy_client_.cert_prov_calls(), SizeIs(1));
  FakeCloudPolicyClient::CertProvCall& cert_prov_call =
      cloud_policy_client_.cert_prov_calls().back();
  {
    em::ClientCertificateProvisioningRequest expected_request;
    expected_request.set_certificate_provisioning_process_id(
        kCertProvisioningId);
    expected_request.set_certificate_scope(cert_scope_dm_api_string());
    expected_request.set_cert_profile_id(kCertProfileId);
    expected_request.set_policy_version(kCertProfileVersion);
    expected_request.set_public_key(kPublicKeyAsString);

    auto* upload_proof_of_possession_request =
        expected_request.mutable_upload_proof_of_possession_request();
    upload_proof_of_possession_request->set_signature(kSignature);

    EXPECT_THAT(cert_prov_call.request, EqualsProto(expected_request));
  }

  // Make CloudPolicyClient answer the request.
  em::ClientCertificateProvisioningResponse response;
  response.mutable_upload_proof_of_possession_response();

  std::move(cert_prov_call.callback).Run(policy::DM_STATUS_SUCCESS, response);

  // Check that the response has no error.
  EXPECT_TRUE(no_data_future.Get().has_value());
}

// 1. Checks that `StartCsr` generates a correct request.
// 2. Checks that CertProvisioningClient correctly extracts data from a response
// that contains data.
TEST_P(CertProvisioningClientTest, StartCsrSuccess) {
  CertProvisioningClientImpl cert_provisioning_client(cloud_policy_client_);

  StartCsrFuture start_csr_future;
  cert_provisioning_client.StartCsr(
      CertProvisioningClient::ProvisioningProcess(
          kCertProvisioningId, cert_scope(), kCertProfileId,
          kCertProfileVersion, kPublicKey),
      start_csr_future.GetStartCsrCallback());

  // Expect one request to CloudPolicyClient, verify its contents.
  ASSERT_THAT(cloud_policy_client_.cert_prov_calls(), SizeIs(1));
  FakeCloudPolicyClient::CertProvCall& cert_prov_call =
      cloud_policy_client_.cert_prov_calls().back();
  {
    em::ClientCertificateProvisioningRequest expected_request;
    expected_request.set_certificate_provisioning_process_id(
        kCertProvisioningId);
    expected_request.set_certificate_scope(cert_scope_dm_api_string());
    expected_request.set_cert_profile_id(kCertProfileId);
    expected_request.set_policy_version(kCertProfileVersion);
    expected_request.set_public_key(kPublicKeyAsString);
    // Sets the request type, no actual data is required.
    expected_request.mutable_start_csr_request();

    EXPECT_THAT(cert_prov_call.request, EqualsProto(expected_request));
  }

  // Make CloudPolicyClient answer the request.
  em::ClientCertificateProvisioningResponse response;
  {
    em::StartCsrResponse* start_csr_response =
        response.mutable_start_csr_response();
    start_csr_response->set_invalidation_topic(kInvalidationTopic);
    start_csr_response->set_va_challenge(kVaChallange);
    start_csr_response->set_hashing_algorithm(kHashAlgorithm);
    start_csr_response->set_signing_algorithm(kSignAlgorithm);
    start_csr_response->set_data_to_sign(kDataToSignStr);
  }
  std::move(cert_prov_call.callback).Run(policy::DM_STATUS_SUCCESS, response);

  // Check that CertProvisioningClient has translated the answer correctly.
  EXPECT_EQ(start_csr_future.GetStatus(), policy::DM_STATUS_SUCCESS);
  EXPECT_EQ(start_csr_future.GetError(), std::nullopt);
  EXPECT_EQ(start_csr_future.GetTryLater(), std::nullopt);
  EXPECT_EQ(start_csr_future.GetInvalidationTopic(), kInvalidationTopic);
  EXPECT_EQ(start_csr_future.GetVaChallenge(), kVaChallange);
  EXPECT_EQ(start_csr_future.GetHashingAlgorithm(), kHashAlgorithm);
  EXPECT_EQ(start_csr_future.GetDataToSign(), kDataToSignBin);
}

// Checks that CertProvisioningClient correctly reacts on the `try_later` field
// in a response to StartCsr.
TEST_P(CertProvisioningClientTest, StartCsrTryLater) {
  const int64_t try_later = 60000;

  CertProvisioningClientImpl cert_provisioning_client(cloud_policy_client_);

  StartCsrFuture start_csr_future;
  cert_provisioning_client.StartCsr(
      CertProvisioningClient::ProvisioningProcess(
          kCertProvisioningId, cert_scope(), kCertProfileId,
          kCertProfileVersion, kPublicKey),
      start_csr_future.GetStartCsrCallback());

  // Expect one request to CloudPolicyClient.
  ASSERT_THAT(cloud_policy_client_.cert_prov_calls(), SizeIs(1));
  FakeCloudPolicyClient::CertProvCall& cert_prov_call =
      cloud_policy_client_.cert_prov_calls().back();

  // Make CloudPolicyClient answer the request.
  em::ClientCertificateProvisioningResponse response;
  response.set_try_again_later(try_later);
  std::move(cert_prov_call.callback).Run(policy::DM_STATUS_SUCCESS, response);

  // Check that CertProvisioningClient has translated the answer correctly.
  EXPECT_EQ(start_csr_future.GetStatus(), policy::DM_STATUS_SUCCESS);
  EXPECT_EQ(start_csr_future.GetError(), std::nullopt);
  EXPECT_EQ(start_csr_future.GetTryLater(), std::make_optional(try_later));
}

// Checks that CertProvisioningClient correctly reacts on the `error` field
// in a response to StartCsr.
TEST_P(CertProvisioningClientTest, StartCsrError) {
  const CertProvisioningResponseErrorType error =
      CertProvisioningResponseError::CA_ERROR;

  CertProvisioningClientImpl cert_provisioning_client(cloud_policy_client_);

  StartCsrFuture start_csr_future;
  cert_provisioning_client.StartCsr(
      CertProvisioningClient::ProvisioningProcess(
          kCertProvisioningId, cert_scope(), kCertProfileId,
          kCertProfileVersion, kPublicKey),
      start_csr_future.GetStartCsrCallback());

  // Expect one request to CloudPolicyClient.
  ASSERT_THAT(cloud_policy_client_.cert_prov_calls(), SizeIs(1));
  FakeCloudPolicyClient::CertProvCall& cert_prov_call =
      cloud_policy_client_.cert_prov_calls().back();

  // Make CloudPolicyClient answer the request.
  em::ClientCertificateProvisioningResponse response;
  response.set_error(error);
  std::move(cert_prov_call.callback).Run(policy::DM_STATUS_SUCCESS, response);

  // Check that CertProvisioningClient has translated the answer correctly.
  EXPECT_EQ(start_csr_future.GetStatus(), policy::DM_STATUS_SUCCESS);
  EXPECT_EQ(start_csr_future.GetError(), std::make_optional(error));
  EXPECT_EQ(start_csr_future.GetTryLater(), std::nullopt);
}

// 1. Checks that `FinishCsr` generates a correct request.
// 2. Checks that CertProvisioningClient correctly extracts data from a response
// that contains data.
TEST_P(CertProvisioningClientTest, FinishCsrSuccess) {
  CertProvisioningClientImpl cert_provisioning_client(cloud_policy_client_);

  FinishCsrFuture finish_csr_future;
  cert_provisioning_client.FinishCsr(
      CertProvisioningClient::ProvisioningProcess(
          kCertProvisioningId, cert_scope(), kCertProfileId,
          kCertProfileVersion, kPublicKey),
      kVaChallengeResponse, kSignature,
      finish_csr_future.GetFinishCsrCallback());

  // Expect one request to CloudPolicyClient, verify its contents.
  ASSERT_THAT(cloud_policy_client_.cert_prov_calls(), SizeIs(1));
  FakeCloudPolicyClient::CertProvCall& cert_prov_call =
      cloud_policy_client_.cert_prov_calls().back();
  {
    em::ClientCertificateProvisioningRequest expected_request;
    expected_request.set_certificate_provisioning_process_id(
        kCertProvisioningId);
    expected_request.set_certificate_scope(cert_scope_dm_api_string());
    expected_request.set_cert_profile_id(kCertProfileId);
    expected_request.set_policy_version(kCertProfileVersion);
    expected_request.set_public_key(kPublicKeyAsString);

    em::FinishCsrRequest* finish_csr_request =
        expected_request.mutable_finish_csr_request();
    finish_csr_request->set_va_challenge_response(kVaChallengeResponse);
    finish_csr_request->set_signature(kSignature);

    EXPECT_THAT(cert_prov_call.request, EqualsProto(expected_request));
  }

  // Make CloudPolicyClient answer the request.
  em::ClientCertificateProvisioningResponse response;
  {
    // Sets the response id, no actual data is required.
    response.mutable_finish_csr_response();
  }
  std::move(cert_prov_call.callback).Run(policy::DM_STATUS_SUCCESS, response);

  // Check that CertProvisioningClient has translated the answer correctly.
  EXPECT_EQ(finish_csr_future.GetStatus(), policy::DM_STATUS_SUCCESS);
  EXPECT_EQ(finish_csr_future.GetError(), std::nullopt);
  EXPECT_EQ(finish_csr_future.GetTryLater(), std::nullopt);
}

// Checks that CertProvisioningClient correctly reacts on the `error` field
// in a response to FinishCsr.
TEST_P(CertProvisioningClientTest, FinishCsrError) {
  CertProvisioningClientImpl cert_provisioning_client(cloud_policy_client_);

  FinishCsrFuture finish_csr_future;
  cert_provisioning_client.FinishCsr(
      CertProvisioningClient::ProvisioningProcess(
          kCertProvisioningId, cert_scope(), kCertProfileId,
          kCertProfileVersion, kPublicKey),
      kVaChallengeResponse, kSignature,
      finish_csr_future.GetFinishCsrCallback());

  // Expect one request to CloudPolicyClient.
  ASSERT_THAT(cloud_policy_client_.cert_prov_calls(), SizeIs(1));
  FakeCloudPolicyClient::CertProvCall& cert_prov_call =
      cloud_policy_client_.cert_prov_calls().back();

  // Make CloudPolicyClient answer the request.
  const CertProvisioningResponseErrorType error =
      CertProvisioningResponseError::CA_ERROR;
  em::ClientCertificateProvisioningResponse response;
  response.set_error(error);
  std::move(cert_prov_call.callback).Run(policy::DM_STATUS_SUCCESS, response);

  // Check that CertProvisioningClient has translated the answer correctly.
  EXPECT_EQ(finish_csr_future.GetStatus(), policy::DM_STATUS_SUCCESS);
  EXPECT_EQ(finish_csr_future.GetError(), std::make_optional(error));
  EXPECT_EQ(finish_csr_future.GetTryLater(), std::nullopt);
}

// 1. Checks that `DownloadCert` generates a correct request.
// 2. Checks that CertProvisioningClient correctly extracts data from a response
// that contains data.
TEST_P(CertProvisioningClientTest, DownloadCertSuccess) {
  CertProvisioningClientImpl cert_provisioning_client(cloud_policy_client_);

  DownloadCertFuture download_cert_future;
  cert_provisioning_client.DownloadCert(
      CertProvisioningClient::ProvisioningProcess(
          kCertProvisioningId, cert_scope(), kCertProfileId,
          kCertProfileVersion, kPublicKey),
      download_cert_future.GetDownloadCertCallback());

  // Expect one request to CloudPolicyClient, verify its contents.
  ASSERT_THAT(cloud_policy_client_.cert_prov_calls(), SizeIs(1));
  FakeCloudPolicyClient::CertProvCall& cert_prov_call =
      cloud_policy_client_.cert_prov_calls().back();
  {
    em::ClientCertificateProvisioningRequest expected_request;
    expected_request.set_certificate_provisioning_process_id(
        kCertProvisioningId);
    expected_request.set_certificate_scope(cert_scope_dm_api_string());
    expected_request.set_cert_profile_id(kCertProfileId);
    expected_request.set_policy_version(kCertProfileVersion);
    expected_request.set_public_key(kPublicKeyAsString);

    // Sets the request type, no actual data is required.
    expected_request.mutable_download_cert_request();

    EXPECT_THAT(cert_prov_call.request, EqualsProto(expected_request));
  }

  // Make CloudPolicyClient answer the request.
  em::ClientCertificateProvisioningResponse response;
  {
    em::DownloadCertResponse* download_cert_response =
        response.mutable_download_cert_response();
    download_cert_response->set_pem_encoded_certificate(kPemEncodedCert);
  }
  std::move(cert_prov_call.callback).Run(policy::DM_STATUS_SUCCESS, response);

  // Check that CertProvisioningClient has translated the answer correctly.
  EXPECT_EQ(download_cert_future.GetStatus(), policy::DM_STATUS_SUCCESS);
  EXPECT_EQ(download_cert_future.GetError(), std::nullopt);
  EXPECT_EQ(download_cert_future.GetTryLater(), std::nullopt);
  EXPECT_EQ(download_cert_future.GetPemEncodedCertificate(), kPemEncodedCert);
}

// Checks that CertProvisioningClient correctly reacts on the `error` field
// in a response to DownloadCert.
TEST_P(CertProvisioningClientTest, DownloadCertError) {
  CertProvisioningClientImpl cert_provisioning_client(cloud_policy_client_);

  DownloadCertFuture download_cert_future;
  cert_provisioning_client.DownloadCert(
      CertProvisioningClient::ProvisioningProcess(
          kCertProvisioningId, cert_scope(), kCertProfileId,
          kCertProfileVersion, kPublicKey),
      download_cert_future.GetDownloadCertCallback());

  // Expect one request to CloudPolicyClient.
  ASSERT_THAT(cloud_policy_client_.cert_prov_calls(), SizeIs(1));
  FakeCloudPolicyClient::CertProvCall& cert_prov_call =
      cloud_policy_client_.cert_prov_calls().back();

  // Make CloudPolicyClient answer the request.
  const CertProvisioningResponseErrorType error =
      CertProvisioningResponseError::CA_ERROR;
  em::ClientCertificateProvisioningResponse response;
  response.set_error(error);
  std::move(cert_prov_call.callback).Run(policy::DM_STATUS_SUCCESS, response);

  // Check that CertProvisioningClient has translated the answer correctly.
  EXPECT_EQ(download_cert_future.GetStatus(), policy::DM_STATUS_SUCCESS);
  EXPECT_EQ(download_cert_future.GetError(), std::make_optional(error));
  EXPECT_EQ(download_cert_future.GetTryLater(), std::nullopt);
  EXPECT_EQ(download_cert_future.GetPemEncodedCertificate(), std::string());
}

INSTANTIATE_TEST_SUITE_P(
    AllScopes,
    CertProvisioningClientTest,
    ::testing::Values(CertScopePair(CertScope::kUser, "google/chromeos/user"),
                      CertScopePair(CertScope::kDevice,
                                    "google/chromeos/device")));

// A Test case for CertProvisioningClientErrorHandlingTest.
struct ErrorHandlingTestCase {
  // Invokes a CertProvisioningClient API call.
  // As these tests only test error cases, it is expected that any callback will
  // be adapted to a NoDataCallback.
  base::RepeatingCallback<void(
      CertProvisioningClient*,
      CertProvisioningClient::ProvisioningProcess provisioining_process,
      NoDataCallback callback)>
      act_function;
};

// Test fixture for CertProvisioningClient, parametrized by CertScope and a
// ErrorHandlingTestCase which implements a call to one of the
// "dynamic flow" API calls.
// This is useful for testing error response processing across all "dynamic
// flow" API calls.
class CertProvisioningClientErrorHandlingTest
    : public CertProvisioningClientTestBase,
      public testing::WithParamInterface<
          std::tuple<CertScopePair, ErrorHandlingTestCase>> {
 public:
  CertScope cert_scope() const override {
    return std::get<0>(cert_scope_pair());
  }

  const std::string& cert_scope_dm_api_string() const override {
    return std::get<1>(cert_scope_pair());
  }

  void ExecuteCertProvisioningClientCall(
      CertProvisioningClient* client,
      CertProvisioningClient::ProvisioningProcess provisioning_process,
      NoDataCallback callback) const {
    std::get<1>(GetParam())
        .act_function.Run(client, std::move(provisioning_process),
                          std::move(callback));
  }

  const std::string kCertProvisioningId = GenerateCertProvisioningId();

 private:
  const CertScopePair& cert_scope_pair() const {
    return std::get<0>(GetParam());
  }
};

// Checks that all "dynamic flow" API calls forward a CertProvBackendError
// correctly.
TEST_P(CertProvisioningClientErrorHandlingTest, CertProvBackendError) {
  CertProvisioningClientImpl cert_provisioning_client(cloud_policy_client_);

  // Execute the CertProvisioningClient API call. Don't verify the filled
  // request proto - this is done by other tests in this file.
  NoDataFuture no_data_future;
  ExecuteCertProvisioningClientCall(
      &cert_provisioning_client,
      CertProvisioningClient::ProvisioningProcess(
          kCertProvisioningId, cert_scope(), kCertProfileId,
          kCertProfileVersion, kPublicKey),
      no_data_future.GetCallback());

  ASSERT_THAT(cloud_policy_client_.cert_prov_calls(), SizeIs(1));
  FakeCloudPolicyClient::CertProvCall& cert_prov_call =
      cloud_policy_client_.cert_prov_calls().back();
  // Make CloudPolicyClient answer the request.
  const em::CertProvBackendError::Error error =
      em::CertProvBackendError::CA_FAILURE;
  const std::string debug_message = "debug info";
  em::ClientCertificateProvisioningResponse response;
  response.mutable_backend_error()->set_error(error);
  response.mutable_backend_error()->set_debug_message(debug_message);
  std::move(cert_prov_call.callback).Run(policy::DM_STATUS_SUCCESS, response);

  // Expect that the CertProvisioningClient provides the error.
  ASSERT_FALSE(no_data_future.Get().has_value());
  EXPECT_EQ(no_data_future.Get().error().device_management_status,
            policy::DM_STATUS_SUCCESS);
  EXPECT_EQ(no_data_future.Get().error().backend_error.error(), error);
  EXPECT_EQ(no_data_future.Get().error().backend_error.debug_message(),
            debug_message);
}

// Checks that all "dynamic flow" API calls forward forward a "DM_STATUS_.."
// error correctly.
TEST_P(CertProvisioningClientErrorHandlingTest, DeviceManagementError) {
  CertProvisioningClientImpl cert_provisioning_client(cloud_policy_client_);

  // Execute the CertProvisioningClient API call. Don't verify the filled
  // request proto - this is done by other tests in this file.
  NoDataFuture no_data_future;
  ExecuteCertProvisioningClientCall(
      &cert_provisioning_client,
      CertProvisioningClient::ProvisioningProcess(
          kCertProvisioningId, cert_scope(), kCertProfileId,
          kCertProfileVersion, kPublicKey),
      no_data_future.GetCallback());

  ASSERT_THAT(cloud_policy_client_.cert_prov_calls(), SizeIs(1));
  FakeCloudPolicyClient::CertProvCall& cert_prov_call =
      cloud_policy_client_.cert_prov_calls().back();
  // Make CloudPolicyClient answer the request with a device management error.
  em::ClientCertificateProvisioningResponse response;
  std::move(cert_prov_call.callback)
      .Run(policy::DM_STATUS_SERVICE_DEVICE_NOT_FOUND, response);

  // Expect that the CertProvisioningClient provides the error.
  ASSERT_FALSE(no_data_future.Get().has_value());
  EXPECT_EQ(no_data_future.Get().error().device_management_status,
            policy::DM_STATUS_SERVICE_DEVICE_NOT_FOUND);
}

// Checks that if no "oneof" field of ClientCertificateProvisioningResponseis
// filled, a decoding error will be signaled.
TEST_P(CertProvisioningClientErrorHandlingTest, ResponseFieldNotFilled) {
  CertProvisioningClientImpl cert_provisioning_client(cloud_policy_client_);

  // Execute the CertProvisioningClient API call. Don't verify the filled
  // request proto - this is done by other tests in this file.
  NoDataFuture no_data_future;
  ExecuteCertProvisioningClientCall(
      &cert_provisioning_client,
      CertProvisioningClient::ProvisioningProcess(
          kCertProvisioningId, cert_scope(), kCertProfileId,
          kCertProfileVersion, kPublicKey),
      no_data_future.GetCallback());

  ASSERT_THAT(cloud_policy_client_.cert_prov_calls(), SizeIs(1));
  FakeCloudPolicyClient::CertProvCall& cert_prov_call =
      cloud_policy_client_.cert_prov_calls().back();
  // Make CloudPolicyClient answer the request with no "oneof" field filled.
  em::ClientCertificateProvisioningResponse response;
  std::move(cert_prov_call.callback).Run(policy::DM_STATUS_SUCCESS, response);

  // Expect that the CertProvisioningClient provides a decoding error.
  ASSERT_FALSE(no_data_future.Get().has_value());
  EXPECT_EQ(no_data_future.Get().error().device_management_status,
            policy::DM_STATUS_RESPONSE_DECODING_ERROR);
}

template <typename ResponseType>
void AdaptToNoDataCallback(
    NoDataCallback no_data_callback,
    base::expected<ResponseType, CertProvisioningClient::Error> result) {
  if (result.has_value()) {
    return std::move(no_data_callback).Run({});
  }
  std::move(no_data_callback).Run(base::unexpected(result.error()));
}

INSTANTIATE_TEST_SUITE_P(
    AllTests,
    CertProvisioningClientErrorHandlingTest,
    ::testing::Combine(
        ::testing::Values(
            CertScopePair(CertScope::kUser, "google/chromeos/user"),
            CertScopePair(CertScope::kDevice, "google/chromeos/device")),
        ::testing::Values(
            ErrorHandlingTestCase{base::BindRepeating(
                [](CertProvisioningClient* client,
                   CertProvisioningClient::ProvisioningProcess
                       provisioning_process,
                   NoDataCallback callback) {
                  client->Start(
                      std::move(provisioning_process),
                      base::BindOnce(
                          &AdaptToNoDataCallback<em::CertProvStartResponse>,
                          std::move(callback)));
                })},
            ErrorHandlingTestCase{base::BindRepeating(
                [](CertProvisioningClient* client,
                   CertProvisioningClient::ProvisioningProcess
                       provisioning_process,
                   NoDataCallback callback) {
                  client->GetNextInstruction(
                      std::move(provisioning_process),
                      base::BindOnce(
                          &AdaptToNoDataCallback<
                              em::CertProvGetNextInstructionResponse>,
                          std::move(callback)));
                })},
            ErrorHandlingTestCase{base::BindRepeating(
                [](CertProvisioningClient* client,
                   CertProvisioningClient::ProvisioningProcess
                       provisioning_process,
                   NoDataCallback callback) {
                  client->Authorize(std::move(provisioning_process),
                                    /*va_challenge_response=*/std::string(),
                                    std::move(callback));
                })},
            ErrorHandlingTestCase{base::BindRepeating(
                [](CertProvisioningClient* client,
                   CertProvisioningClient::ProvisioningProcess
                       provisioning_process,
                   NoDataCallback callback) {
                  client->UploadProofOfPossession(
                      std::move(provisioning_process),
                      /*signature=*/std::string(), std::move(callback));
                })})));

}  // namespace ash::cert_provisioning