chromium/components/device_signals/core/browser/crowdstrike_client_unittest.cc

// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "components/device_signals/core/browser/crowdstrike_client.h"

#include <memory>
#include <optional>
#include <string>

#include "base/files/file_path.h"
#include "base/files/file_util.h"
#include "base/files/scoped_temp_dir.h"
#include "base/test/metrics/histogram_tester.h"
#include "base/test/task_environment.h"
#include "base/test/test_future.h"
#include "base/time/time.h"
#include "build/build_config.h"
#include "components/device_signals/core/browser/metrics_utils.h"
#include "components/device_signals/core/browser/signals_types.h"
#include "components/device_signals/core/common/common_types.h"
#include "services/data_decoder/public/cpp/test_support/in_process_data_decoder.h"
#include "testing/gtest/include/gtest/gtest.h"

#if BUILDFLAG(IS_WIN)
#include "base/strings/string_number_conversions.h"
#include "base/strings/string_util.h"
#include "base/test/test_reg_util_win.h"
#include "base/win/registry.h"
#endif  // BUILDFLAG(IS_WIN)

namespace device_signals {

namespace {

// Name of the fake CrowdStrike data file. The fixture places it inside a
// per-test temporary directory and hands that path to the client.
constexpr base::FilePath::CharType kFakeFileName[] =
    FILE_PATH_LITERAL("data.zta");

// A well-formed three-section JWT (header.payload.signature).
// The header decodes to {"alg":"RS256","kid":"v1","typ":"JWT"}.
// Among other claims, the payload carries:
//   "sub":"beefbeefbeefbeefbeefbeefbeef1111"
//   "cid":"abcdef123456789"
// These are the values the tests expect to be surfaced as the agent ID and
// customer ID respectively (kExpectedAgentId / kExpectedCustomerId).
constexpr char kValidFakeJwtZtaContent[] =
    "eyJhbGciOiJSUzI1NiIsImtpZCI6InYxIiwidHlwIjoiSldUIn0."
    "eyJhc3Nlc3NtZW50Ijp7Im92ZXJhbGwiOjU1LCJvcyI6NTAsInNlbnNvcl9jb25maWciOjYwLC"
    "J2ZXJzaW9uIjoiIn0sImV4cCI6MTYxMjg5NTk2NywiaWF0IjoxNjEyODkyMzY3LCJwbGF0Zm9y"
    "bSI6IldpbmRvd3MgMTAiLCJzZXJpYWxfbnVtYmVyIjoic2VyaWFsTnVtYmVyMTIzIiwic3ViIj"
    "oiYmVlZmJlZWZiZWVmYmVlZmJlZWZiZWVmYmVlZjExMTEiLCJjaWQiOiJhYmNkZWYxMjM0NTY3"
    "ODkiLCJ0eXAiOiJjcm93ZHN0cmlrZS16dGErand0In0."
    "Kn8vntxGzDH9D97eIp2JqchPUsrom4qudliIhRlpn1RpjC5ILX2u9hLqdR6yKSVh9VtA2QWx"
    "7EdB_1JeVZFb37sE9wDIq6vENctqH-CcCr2CK-4d8JeeHO_KEOK6xhWSRCD66f-"
    "ZYzYKG5xNlHToth4ef1lZqZyoJ3lLS0qv3uliAP_28c-stxioRUelh7p8lRIUQnS22b0Ud_"
    "LGKB0H7juFx9jdPSFBo31R63MvELMmneltQmjBrj5TKgG30NwAa_OKNsShlgM9kZQes-"
    "ms2RpfEq5UpJ5teDTdqpXtLUEwB7ROkfDhz6nhPHyfUh_S6ummIe_"
    "qLGYVq4dxDtSYnXFt5Etnip1KHBK5RXOBiFV11NahPFWRRd45CoX5mrD9PgL0JxtJLetShNT"
    "-nFktKEIbWtWX3OTiJn7SKnatGB-YRTKkTy0-"
    "2BlTITPM4Uqj3OVTbKimRYJ2bzdzyTc4Ls6FPih6I-"
    "j1KH1SKO80FyKXTUIbeYSGO3t3PcsVgUzVNUUXYdpwn7zHBEivuVgGw2hftdokK9ocx42Sad"
    "Pz_HnIvpt4JGXGOJMsemp4FeCT56hNKuCInN_zsFVe2O6xbZwU_8DTfIsfWNgErCroYr-"
    "Z6NSO6O6xaWojTiEsDSHFQ3lkpccscRZDz0rCluR-2xWUDWkrHht4FGRyCQz4NaM";

// Identifiers encoded in the payload of kValidFakeJwtZtaContent ("sub" and
// "cid" claims).
constexpr char kExpectedAgentId[] = "beefbeefbeefbeefbeefbeefbeef1111";
constexpr char kExpectedCustomerId[] = "abcdef123456789";

#if BUILDFLAG(IS_WIN)
// Registry path, relative to HKEY_LOCAL_MACHINE, under which the tests store
// the fake CrowdStrike agent identifiers.
constexpr wchar_t kCSAgentRegPath[] =
    L"SYSTEM\\CurrentControlSet\\services\\CSAgent\\Sim";

// Fake identifiers, written as hex strings. SetUpCrowdStrikeInfo() hex-decodes
// them before storing the raw bytes. The tests then expect the client to report
// them back in lower-case hex.
constexpr char kFakeHexCSCustomerId[] = "CABCDEF1234ABCD1234D";
constexpr char kFakeHexCSAgentId[] = "ADEBCA432156ABDC";

// CU is the registry value containing the customer ID.
constexpr wchar_t kCSCURegKey[] = L"CU";

// AG is the registry value containing the agent ID.
constexpr wchar_t kCSAGRegKey[] = L"AG";

void CreateRegistryKey() {
  base::win::RegKey key;
  LONG res = key.Create(HKEY_LOCAL_MACHINE, kCSAgentRegPath, KEY_WRITE);
  ASSERT_EQ(res, ERROR_SUCCESS);
}

void DeleteRegistryKey() {
  base::win::RegKey key(HKEY_LOCAL_MACHINE);
  LONG res = key.DeleteKey(kCSAgentRegPath);
  ASSERT_EQ(res, ERROR_SUCCESS);
}

void SetUpCrowdStrikeInfo(const std::optional<std::string>& customer_id,
                          const std::optional<std::string>& agent_id) {
  CreateRegistryKey();

  base::win::RegKey key;
  LONG res = key.Open(HKEY_LOCAL_MACHINE, kCSAgentRegPath, KEY_WRITE);
  ASSERT_EQ(res, ERROR_SUCCESS);

  if (customer_id) {
    // Have to Hex-decode the values before storing them.
    std::string decoded_customer_id;
    ASSERT_TRUE(base::HexStringToString(customer_id.value().c_str(),
                                        &decoded_customer_id));
    res = key.WriteValue(kCSCURegKey, decoded_customer_id.data(),
                         decoded_customer_id.size(), REG_BINARY);
    ASSERT_EQ(res, ERROR_SUCCESS);
  }

  if (agent_id) {
    // Have to Hex-decode the values before storing them.
    std::string decoded_agent_id;
    ASSERT_TRUE(
        base::HexStringToString(agent_id.value().c_str(), &decoded_agent_id));
    res = key.WriteValue(kCSAGRegKey, decoded_agent_id.data(),
                         decoded_agent_id.size(), REG_BINARY);
    ASSERT_EQ(res, ERROR_SUCCESS);
  }
}

#endif  // BUILDFLAG(IS_WIN)

}  // namespace

class CrowdStrikeClientTest : public testing::Test {
 protected:
  void SetUp() override {
#if BUILDFLAG(IS_WIN)
    registry_override_manager_.OverrideRegistry(HKEY_LOCAL_MACHINE);
#endif

    ASSERT_TRUE(scoped_temp_dir_.CreateUniqueTempDir());

    client_ = CrowdStrikeClient::CreateForTesting(GetDataFilePath());
  }

  void CreateFakeFileWithContent(const std::string& file_content) {
    ASSERT_TRUE(base::WriteFile(GetDataFilePath(), file_content));
  }

  void DeleteFakeFile() { ASSERT_TRUE(base::DeleteFile(GetDataFilePath())); }

  base::FilePath GetDataFilePath() {
    return scoped_temp_dir_.GetPath().Append(kFakeFileName);
  }

  std::optional<CrowdStrikeSignals> GetSignals(
      std::optional<SignalCollectionError> expected_error = std::nullopt) {
    base::test::TestFuture<std::optional<CrowdStrikeSignals>,
                           std::optional<SignalCollectionError>>
        future;
    client_->GetIdentifiers(future.GetCallback());

    // Should not have an error if signals are expected to be returned.
    if (expected_error) {
      EXPECT_EQ(expected_error, future.Get<1>());
    } else {
      EXPECT_FALSE(future.Get<1>());
    }

    return future.Get<0>();
  }

  std::optional<SignalCollectionError> GetSignalCollectionError() {
    base::test::TestFuture<std::optional<CrowdStrikeSignals>,
                           std::optional<SignalCollectionError>>
        future;
    client_->GetIdentifiers(future.GetCallback());

    // Should not have signals if an error is expected to be returned.
    EXPECT_FALSE(future.Get<0>());

    return future.Get<1>();
  }

  void ValidateHistogram(std::optional<SignalsParsingError> error) {
    static constexpr char kCrowdStrikeErrorHistogram[] =
        "Enterprise.DeviceSignals.Collection.CrowdStrike.Error";
    if (error) {
      histogram_tester_.ExpectUniqueSample(kCrowdStrikeErrorHistogram,
                                           error.value(), 1);
    } else {
      histogram_tester_.ExpectTotalCount(kCrowdStrikeErrorHistogram, 0);
    }
  }

  base::test::TaskEnvironment task_environment_{
      base::test::TaskEnvironment::TimeSource::MOCK_TIME};
  base::ScopedTempDir scoped_temp_dir_;
  data_decoder::test::InProcessDataDecoder in_process_data_decoder_;
  base::HistogramTester histogram_tester_;

#if BUILDFLAG(IS_WIN)
  registry_util::RegistryOverrideManager registry_override_manager_;
#endif

  std::unique_ptr<CrowdStrikeClient> client_;
};

TEST_F(CrowdStrikeClientTest, Identifiers_NoFile) {
  // Without a data file, the client reports neither signals nor an error.
  const std::optional<SignalCollectionError> error = GetSignalCollectionError();
  EXPECT_FALSE(error.has_value());

  // A missing file is not a failure, so no error sample should be recorded.
  ValidateHistogram(std::nullopt);
}

TEST_F(CrowdStrikeClientTest, Identifiers_EmptyFile) {
  CreateFakeFileWithContent("");

  // An empty data file yields neither signals nor an error.
  const std::optional<SignalCollectionError> error = GetSignalCollectionError();
  EXPECT_FALSE(error.has_value());

  // An empty file is not a failure, so no error sample should be recorded.
  ValidateHistogram(std::nullopt);
}

TEST_F(CrowdStrikeClientTest, Identifiers_NotJwt) {
  // Three dot-separated sections, but the payload is not JSON.
  CreateFakeFileWithContent("some.random.content");

  const std::optional<SignalCollectionError> error = GetSignalCollectionError();
  ASSERT_TRUE(error.has_value());
  EXPECT_EQ(*error, SignalCollectionError::kParsingFailed);

  ValidateHistogram(SignalsParsingError::kJsonParsingFailed);
}

TEST_F(CrowdStrikeClientTest, Identifiers_MaxDataSize) {
  // 33 KiB of filler: large enough to exceed the maximum data size the client
  // accepts.
  const std::string oversized_content(33 * 1024, 'a');
  CreateFakeFileWithContent(oversized_content);

  const std::optional<SignalCollectionError> error = GetSignalCollectionError();
  ASSERT_TRUE(error.has_value());
  EXPECT_EQ(*error, SignalCollectionError::kParsingFailed);

  ValidateHistogram(SignalsParsingError::kHitMaxDataSize);
}

TEST_F(CrowdStrikeClientTest, Identifiers_DecodingFailed) {
  // '%' is outside the base64 alphabet, so decoding the sections fails.
  CreateFakeFileWithContent("some.random%%.content");

  const std::optional<SignalCollectionError> error = GetSignalCollectionError();
  ASSERT_TRUE(error.has_value());
  EXPECT_EQ(*error, SignalCollectionError::kParsingFailed);

  ValidateHistogram(SignalsParsingError::kBase64DecodingFailed);
}

TEST_F(CrowdStrikeClientTest, Identifiers_MissingJwtSection) {
  // Only the header and payload sections are present; the trailing signature
  // section is missing entirely.
  static constexpr char kJwtWithoutSignature[] =
      "eyJhbGciOiJSUzI1NiIsImtpZCI6InYxIiwidHlwIjoiSldUIn0."
      "eyJhc3Nlc3NtZW50Ijp7Im92ZXJhbGwiOjU1LCJvcyI6NTAsInNlbnNvcl9jb25maWciOjYw"
      "LCJ2ZXJzaW9uIjoiIn0sImV4cCI6MTYxMjg5NTk2NywiaWF0IjoxNjEyODkyMzY3LCJwbGF0"
      "Zm9ybSI6IldpbmRvd3MgMTAiLCJzZXJpYWxfbnVtYmVyIjoic2VyaWFsTnVtYmVyMTIzIiwi"
      "c3ViIjoiYmVlZmJlZWZiZWVmYmVlZmJlZWZiZWVmYmVlZjExMTEiLCJ0eXAiOiJjcm93ZHN0"
      "cmlrZS16dGErand0In0";
  CreateFakeFileWithContent(kJwtWithoutSignature);

  const std::optional<SignalCollectionError> error = GetSignalCollectionError();
  ASSERT_TRUE(error.has_value());
  EXPECT_EQ(*error, SignalCollectionError::kUnexpectedValue);

  ValidateHistogram(SignalsParsingError::kDataMalformed);
}

TEST_F(CrowdStrikeClientTest, Identifiers_MissingSub) {
  // Structurally valid JWT whose payload has no `sub` claim.
  static constexpr char kJwtWithoutSub[] =
      "eyJhbGciOiJSUzI1NiIsImtpZCI6InYxIiwidHlwIjoiSldUIn0."
      "eyJhc3Nlc3NtZW50Ijp7Im92ZXJhbGwiOjU1LCJvcyI6NTAsInNlbnNvcl9jb25maWciOjYw"
      "LCJ2ZXJzaW9uIjoiIn0sImV4cCI6MTYxMjg5NTk2NywiaWF0IjoxNjEyODkyMzY3LCJwbGF0"
      "Zm9ybSI6IldpbmRvd3MgMTAiLCJzZXJpYWxfbnVtYmVyIjoic2VyaWFsTnVtYmVyMTIzIiwi"
      "dHlwIjoiY3Jvd2RzdHJpa2UtenRhK2p3dCJ9."
      "Kn8vntxGzDH9D97eIp2JqchPUsrom4qudliIhRlpn1RpjC5ILX2u9hLqdR6yKSVh9VtA2QWx"
      "7EdB_1JeVZFb37sE9wDIq6vENctqH-CcCr2CK-4d8JeeHO_KEOK6xhWSRCD66f-"
      "ZYzYKG5xNlHToth4ef1lZqZyoJ3lLS0qv3uliAP_28c-stxioRUelh7p8lRIUQnS22b0Ud_"
      "LGKB0H7juFx9jdPSFBo31R63MvELMmneltQmjBrj5TKgG30NwAa_OKNsShlgM9kZQes-"
      "ms2RpfEq5UpJ5teDTdqpXtLUEwB7ROkfDhz6nhPHyfUh_S6ummIe_"
      "qLGYVq4dxDtSYnXFt5Etnip1KHBK5RXOBiFV11NahPFWRRd45CoX5mrD9PgL0JxtJLetShNT"
      "-nFktKEIbWtWX3OTiJn7SKnatGB-YRTKkTy0-"
      "2BlTITPM4Uqj3OVTbKimRYJ2bzdzyTc4Ls6FPih6I-"
      "j1KH1SKO80FyKXTUIbeYSGO3t3PcsVgUzVNUUXYdpwn7zHBEivuVgGw2hftdokK9ocx42Sad"
      "Pz_HnIvpt4JGXGOJMsemp4FeCT56hNKuCInN_zsFVe2O6xbZwU_8DTfIsfWNgErCroYr-"
      "Z6NSO6O6xaWojTiEsDSHFQ3lkpccscRZDz0rCluR-2xWUDWkrHht4FGRyCQz4NaM";
  CreateFakeFileWithContent(kJwtWithoutSub);

  const std::optional<SignalCollectionError> error = GetSignalCollectionError();
  ASSERT_TRUE(error.has_value());
  EXPECT_EQ(*error, SignalCollectionError::kParsingFailed);

  ValidateHistogram(SignalsParsingError::kMissingRequiredProperty);
}

TEST_F(CrowdStrikeClientTest, Identifiers_Success) {
  CreateFakeFileWithContent(kValidFakeJwtZtaContent);

  const std::optional<CrowdStrikeSignals> signals = GetSignals();
  ASSERT_TRUE(signals.has_value());
  EXPECT_EQ(signals->agent_id, kExpectedAgentId);
  EXPECT_EQ(signals->customer_id, kExpectedCustomerId);

  // A successful parse must not record any error sample.
  ValidateHistogram(std::nullopt);
}

TEST_F(CrowdStrikeClientTest, Identifiers_Success_CachedValue) {
  CreateFakeFileWithContent(kValidFakeJwtZtaContent);

  const std::optional<CrowdStrikeSignals> first_signals = GetSignals();
  ASSERT_TRUE(first_signals.has_value());
  EXPECT_EQ(first_signals->agent_id, kExpectedAgentId);
  EXPECT_EQ(first_signals->customer_id, kExpectedCustomerId);

  // With the file gone, the next request is expected to be served from the
  // client's cache.
  DeleteFakeFile();

  const std::optional<CrowdStrikeSignals> cached_signals = GetSignals();
  ASSERT_TRUE(cached_signals.has_value());
  EXPECT_EQ(cached_signals->agent_id, kExpectedAgentId);
  EXPECT_EQ(cached_signals->customer_id, kExpectedCustomerId);

  // Once mock time moves past the cache expiry, and with no data file, the
  // client should return no value.
  static constexpr int kBeyondCacheExpiryInHours = 2;
  task_environment_.FastForwardBy(base::Hours(kBeyondCacheExpiryInHours));

  EXPECT_FALSE(GetSignals().has_value());
}

#if BUILDFLAG(IS_WIN)

// Tests that a registry key holding only the customer ID (no agent ID) still
// yields signals. The customer ID is reported in lower-case hex and the agent
// ID is left empty.
TEST_F(CrowdStrikeClientTest, Identifiers_NoFile_RegistryNoAgentId) {
  ASSERT_NO_FATAL_FAILURE(
      SetUpCrowdStrikeInfo(kFakeHexCSCustomerId, std::nullopt));

  auto signals = GetSignals();

  ASSERT_TRUE(signals);
  EXPECT_EQ(signals->customer_id, base::ToLowerASCII(kFakeHexCSCustomerId));
  EXPECT_TRUE(signals->agent_id.empty());
}

// Tests that a registry key holding only the agent ID (no customer ID) yields
// signals with an empty customer ID. It also tests that registry-sourced values
// are not served again once the key has been removed.
TEST_F(CrowdStrikeClientTest, Identifiers_NoFile_RegistryNoCustomerId) {
  ASSERT_NO_FATAL_FAILURE(
      SetUpCrowdStrikeInfo(std::nullopt, kFakeHexCSAgentId));

  auto signals = GetSignals();

  ASSERT_TRUE(signals);
  EXPECT_EQ(signals->agent_id, base::ToLowerASCII(kFakeHexCSAgentId));
  EXPECT_TRUE(signals->customer_id.empty());

  // If the key could not be removed, the caching check below would be
  // meaningless, so stop here on failure.
  ASSERT_NO_FATAL_FAILURE(DeleteRegistryKey());

  // Expect the value to not have been cached.
  EXPECT_FALSE(GetSignals());
}

// Tests that when both the data file and the registry provide identifiers,
// the values parsed from the data file are the ones returned.
TEST_F(CrowdStrikeClientTest, Identifiers_FileHasPrecendence) {
  ASSERT_NO_FATAL_FAILURE(
      SetUpCrowdStrikeInfo(kFakeHexCSCustomerId, kFakeHexCSAgentId));

  ASSERT_NO_FATAL_FAILURE(CreateFakeFileWithContent(kValidFakeJwtZtaContent));

  auto signals = GetSignals();

  ASSERT_TRUE(signals);
  EXPECT_EQ(signals->agent_id, kExpectedAgentId);
  EXPECT_EQ(signals->customer_id, kExpectedCustomerId);
}

// Tests that when the data file cannot be decoded, the parsing error is still
// reported, but the identifiers fall back to the ones stored in the registry.
TEST_F(CrowdStrikeClientTest, Identifiers_DecodingFailed_RegistryFallback) {
  ASSERT_NO_FATAL_FAILURE(CreateFakeFileWithContent("some.random%%.content"));
  ASSERT_NO_FATAL_FAILURE(
      SetUpCrowdStrikeInfo(kFakeHexCSCustomerId, kFakeHexCSAgentId));

  auto signals =
      GetSignals(/*expected_error=*/SignalCollectionError::kParsingFailed);

  ASSERT_TRUE(signals);
  EXPECT_EQ(signals->customer_id, base::ToLowerASCII(kFakeHexCSCustomerId));
  EXPECT_EQ(signals->agent_id, base::ToLowerASCII(kFakeHexCSAgentId));
  ValidateHistogram(SignalsParsingError::kBase64DecodingFailed);
}

#endif  // BUILDFLAG(IS_WIN)

}  // namespace device_signals