chromium/chrome/browser/enterprise/connectors/device_trust/attestation/ash/ash_attestation_service_impl_unittest.cc

// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/enterprise/connectors/device_trust/attestation/ash/ash_attestation_service_impl.h"

#include <memory>
#include <optional>

#include "base/base64.h"
#include "base/json/json_reader.h"
#include "base/memory/raw_ptr.h"
#include "base/run_loop.h"
#include "base/test/bind.h"
#include "base/test/gmock_callback_support.h"
#include "base/test/task_environment.h"
#include "base/test/test_future.h"
#include "base/values.h"
#include "chrome/browser/ash/attestation/mock_tpm_challenge_key.h"
#include "chrome/browser/ash/attestation/mock_tpm_challenge_key_subtle.h"
#include "chrome/browser/ash/attestation/tpm_challenge_key.h"
#include "chrome/browser/ash/attestation/tpm_challenge_key_result.h"
#include "chrome/browser/ash/attestation/tpm_challenge_key_subtle.h"
#include "chrome/browser/enterprise/connectors/device_trust/attestation/common/attestation_utils.h"
#include "chrome/browser/enterprise/connectors/device_trust/common/common_types.h"
#include "chrome/test/base/testing_browser_process.h"
#include "chrome/test/base/testing_profile.h"
#include "chrome/test/base/testing_profile_manager.h"
#include "chromeos/ash/components/dbus/attestation/attestation_ca.pb.h"
#include "chromeos/ash/components/dbus/attestation/attestation_client.h"
#include "chromeos/ash/components/dbus/constants/attestation_constants.h"
#include "components/device_signals/core/common/signals_constants.h"
#include "content/public/test/browser_task_environment.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"

using base::test::RunOnceCallback;
using testing::_;
using testing::Invoke;
using testing::StrictMock;

namespace enterprise_connectors {

namespace {

// A sample VerifiedAccess v2 challenge, stored base64-encoded. It is decoded
// by GetSerializedSignedChallenge() before being handed to the service under
// test.
constexpr char kEncodedChallenge[] =
    "CkEKFkVudGVycHJpc2VLZXlDaGFsbGVuZ2USIELlPXqh8+"
    "rZJ2VIqwPXtPFrr653QdRrIzHFwqP+"
    "b3L8GJTcufirLxKAAkindNwTfwYUcbCFDjiW3kXdmDPE0wC0J6b5ZI6X6vOVcSMXTpK7nxsAGK"
    "zFV+i80LCnfwUZn7Ne1bHzloAqBdpLOu53vQ63hKRk6MRPhc9jYVDsvqXfQ7s+"
    "FUA5r3lxdoluxwAUMFqcP4VgnMvKzKTPYbnnB+xj5h5BZqjQToXJYoP4VC3/"
    "ID+YHNsCWy5o7+G5jnq0ak3zeqWfo1+lCibMPsCM+"
    "2g7nCZIwvwWlfoKwv3aKvOVMBcJxPAIxH1w+hH+"
    "NWxqRi6qgZm84q0ylm0ybs6TFjdgLvSViAIp0Z9p/An/"
    "u3W4CMboCswxIxNYRCGrIIVPElE3Yb4QS65mKrg=";

// Name assigned to the testing profile. The unmanaged-device tests also expect
// it as the suffix of the Device Trust connector key name.
// NOTE(review): the literal is not a well-formed email address (it looks like
// an obfuscation placeholder). The tests only rely on the value being used
// consistently, but confirm the intended address.
constexpr char kTestUserEmail[] = "[email protected]";
// Payload the mocked TpmChallengeKey hands back as the challenge response, and
// the value the tests expect to recover from the service's output.
constexpr char kFakeResponse[] = "fake_response";
// Value stored under the display-name signal in the signals dictionary passed
// to the service.
constexpr char kDisplayName[] = "display-name";

std::string GetSerializedSignedChallenge() {
  std::string serialized_signed_challenge;
  if (!base::Base64Decode(kEncodedChallenge, &serialized_signed_challenge)) {
    return std::string();
  }

  return serialized_signed_challenge;
}

std::optional<std::string> ParseValueFromResponse(const std::string& response) {
  std::optional<base::Value> data = base::JSONReader::Read(
      response, base::JSONParserOptions::JSON_ALLOW_TRAILING_COMMAS);

  // If json is malformed or it doesn't include the needed field return
  // an empty string.
  if (!data || !data->GetDict().FindString("challengeResponse")) {
    return std::nullopt;
  }

  std::string decoded_response_value;
  if (!base::Base64Decode(*data->GetDict().FindString("challengeResponse"),
                          &decoded_response_value)) {
    return std::nullopt;
  }

  return decoded_response_value;
}

// Creates a MockTpmChallengeKeySubtle, hands ownership of it to
// TpmChallengeKeySubtleFactory::SetForTesting() and returns a non-owning
// pointer so the caller can set expectations on it.
ash::attestation::MockTpmChallengeKeySubtle* InjectMockChallengeKeySubtle() {
  using MockSubtle = ash::attestation::MockTpmChallengeKeySubtle;

  std::unique_ptr<MockSubtle> owned_mock = std::make_unique<MockSubtle>();
  // Grab the address before ownership moves into the factory.
  MockSubtle* const unowned_mock = owned_mock.get();
  ash::attestation::TpmChallengeKeySubtleFactory::SetForTesting(
      std::move(owned_mock));
  return unowned_mock;
}

// Creates a MockTpmChallengeKey, hands ownership of it to
// TpmChallengeKeyFactory::SetForTesting() and returns a non-owning pointer so
// the caller can set expectations on it.
ash::attestation::MockTpmChallengeKey* InjectMockChallengeKey() {
  using MockKey = ash::attestation::MockTpmChallengeKey;

  std::unique_ptr<MockKey> owned_mock = std::make_unique<MockKey>();
  // Grab the address before ownership moves into the factory.
  MockKey* const unowned_mock = owned_mock.get();
  ash::attestation::TpmChallengeKeyFactory::SetForTesting(
      std::move(owned_mock));
  return unowned_mock;
}

}  // namespace

class AshAttestationServiceImplTest : public testing::Test {
 protected:
  AshAttestationServiceImplTest() {
    ash::AttestationClient::InitializeFake();
    mock_challenge_key_subtle_ = InjectMockChallengeKeySubtle();
    mock_challenge_key_ = InjectMockChallengeKey();
    test_profile_.set_profile_name(kTestUserEmail);
  }

  ~AshAttestationServiceImplTest() override {
    ash::attestation::TpmChallengeKeySubtleFactory::Create();
    ash::AttestationClient::Shutdown();
  }

  base::Value::Dict CreateSignals() {
    base::Value::Dict signals;
    signals.Set(device_signals::names::kDisplayName, kDisplayName);
    return signals;
  }

  void SetDeviceManagement(bool is_managed) {
    if (is_managed) {
      StubInstallAttributes()->SetCloudManaged("test_domain", "test_device_id");
    } else {
      StubInstallAttributes()->SetConsumerOwned();
    }
  }

  ash::StubInstallAttributes* StubInstallAttributes() {
    return test_profile_.ScopedCrosSettingsTestHelper()->InstallAttributes();
  }

  void VerifyAttestationFlowSuccessful(
      const AttestationResponse& attestation_response) {
    auto challenge_response = attestation_response.challenge_response;
    ASSERT_FALSE(challenge_response.empty());
    EXPECT_EQ(attestation_response.result_code, DTAttestationResult::kSuccess);
    auto parsed_value = ParseValueFromResponse(challenge_response);
    ASSERT_TRUE(parsed_value.has_value());
    EXPECT_EQ(kFakeResponse, parsed_value.value());
  }

  content::BrowserTaskEnvironment task_environment_;

  TestingProfile test_profile_;
  raw_ptr<ash::attestation::MockTpmChallengeKeySubtle, DanglingUntriaged>
      mock_challenge_key_subtle_;
  raw_ptr<ash::attestation::MockTpmChallengeKey, DanglingUntriaged>
      mock_challenge_key_;
  std::set<enterprise_connectors::DTCPolicyLevel> levels_;
};

// On a cloud-managed device the service is expected to request an
// ENTERPRISE_MACHINE response with an empty key name, and to surface the
// mock's payload as a successful attestation response.
TEST_F(AshAttestationServiceImplTest,
       BuildChallengeResponse_DeviceManagedSuccess) {
  SetDeviceManagement(true);

  AshAttestationServiceImpl attestation_service(&test_profile_);

  const std::string serialized_challenge = GetSerializedSignedChallenge();
  // Argument index 2 is the callback; the mock answers it right away with the
  // fake challenge response.
  EXPECT_CALL(*mock_challenge_key_,
              BuildResponse(::attestation::ENTERPRISE_MACHINE,
                            /*profile=*/&test_profile_,
                            /*callback=*/_,
                            /*challenge=*/serialized_challenge,
                            /*register_key=*/false,
                            /*key_crypto_type=*/::attestation::KEY_TYPE_RSA,
                            /*key_name=*/std::string(),
                            /*signals=*/_))
      .WillOnce(RunOnceCallback<2>(
          ash::attestation::TpmChallengeKeyResult::MakeChallengeResponse(
              kFakeResponse)));

  base::test::TestFuture<const AttestationResponse&> response_future;
  attestation_service.BuildChallengeResponseForVAChallenge(
      serialized_challenge, CreateSignals(), levels_,
      response_future.GetCallback());

  const AttestationResponse& attestation_response = response_future.Get();
  VerifyAttestationFlowSuccessful(attestation_response);
}

// On a consumer-owned device the service is expected to request a
// DEVICE_TRUST_CONNECTOR response whose key name is the connector prefix
// followed by the profile name, and to surface the mock's payload as a
// successful attestation response.
TEST_F(AshAttestationServiceImplTest,
       BuildChallengeResponse_DeviceUnmanagedSuccess) {
  SetDeviceManagement(false);

  AshAttestationServiceImpl attestation_service(&test_profile_);

  const std::string serialized_challenge = GetSerializedSignedChallenge();
  const std::string expected_key_name =
      std::string(ash::attestation::kDeviceTrustConnectorKeyPrefix) +
      kTestUserEmail;
  // Argument index 2 is the callback; the mock answers it right away with the
  // fake challenge response.
  EXPECT_CALL(*mock_challenge_key_,
              BuildResponse(::attestation::DEVICE_TRUST_CONNECTOR,
                            /*profile=*/&test_profile_,
                            /*callback=*/_,
                            /*challenge=*/serialized_challenge,
                            /*register_key=*/false,
                            /*key_crypto_type=*/::attestation::KEY_TYPE_RSA,
                            /*key_name=*/expected_key_name,
                            /*signals=*/_))
      .WillOnce(RunOnceCallback<2>(
          ash::attestation::TpmChallengeKeyResult::MakeChallengeResponse(
              kFakeResponse)));

  base::test::TestFuture<const AttestationResponse&> response_future;
  attestation_service.BuildChallengeResponseForVAChallenge(
      serialized_challenge, CreateSignals(), levels_,
      response_future.GetCallback());

  const AttestationResponse& attestation_response = response_future.Get();
  VerifyAttestationFlowSuccessful(attestation_response);
}

// On a cloud-managed device TryPrepareKey() must not start a prepare-key step
// on the injected TpmChallengeKeySubtle mock.
TEST_F(AshAttestationServiceImplTest, TryPrepareKey_DeviceManagedSkip) {
  SetDeviceManagement(true);

  // Any StartPrepareKeyStep() call, whatever its arguments, fails the test.
  EXPECT_CALL(*mock_challenge_key_subtle_,
              StartPrepareKeyStep(_, _, _, _, _, _, _))
      .Times(0);

  AshAttestationServiceImpl attestation_service(&test_profile_);
  attestation_service.TryPrepareKey();
}

// On a consumer-owned device TryPrepareKey() must start exactly one
// prepare-key step for the DEVICE_TRUST_CONNECTOR flow, using an RSA key named
// with the connector prefix followed by the profile name.
TEST_F(AshAttestationServiceImplTest, TryPrepareKey_DeviceUnmanagedSuccess) {
  SetDeviceManagement(false);

  const std::string expected_key_name =
      std::string(ash::attestation::kDeviceTrustConnectorKeyPrefix) +
      kTestUserEmail;
  EXPECT_CALL(*mock_challenge_key_subtle_,
              StartPrepareKeyStep(::attestation::DEVICE_TRUST_CONNECTOR, false,
                                  ::attestation::KEY_TYPE_RSA,
                                  expected_key_name, &test_profile_, _, _))
      .Times(1);

  AshAttestationServiceImpl attestation_service(&test_profile_);
  attestation_service.TryPrepareKey();
}

}  // namespace enterprise_connectors