// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/device_signals/core/browser/crowdstrike_client.h"
#include <memory>
#include "base/base64url.h"
#include "base/files/file_path.h"
#include "base/files/file_util.h"
#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/location.h"
#include "base/memory/weak_ptr.h"
#include "base/notreached.h"
#include "base/sequence_checker.h"
#include "base/strings/string_split.h"
#include "base/task/bind_post_task.h"
#include "base/task/sequenced_task_runner.h"
#include "base/task/thread_pool.h"
#include "build/build_config.h"
#include "components/device_signals/core/browser/metrics_utils.h"
#include "components/device_signals/core/browser/signals_types.h"
#include "components/device_signals/core/common/cached_signal.h"
#include "components/device_signals/core/common/common_types.h"
#include "components/device_signals/core/common/platform_utils.h"
#include "components/device_signals/core/common/signals_constants.h"
#include "services/data_decoder/public/cpp/data_decoder.h"
namespace device_signals {
// Callback type used throughout this file to hand collected signals back to a
// caller. The first argument carries the CrowdStrike signals when some were
// found; the second carries a collection error when one occurred. Both may be
// std::nullopt at the same time (e.g. when no ZTA file exists).
using SignalsCallback =
base::OnceCallback<void(std::optional<CrowdStrikeSignals>,
std::optional<SignalCollectionError>)>;
namespace {
// Lifetime, in hours, given to the signals cache held by the client.
constexpr int kCacheExpiryInHours = 1;
// Upper bound handed to base::ReadFileToStringWithMaxSize when reading the ZTA
// file (32 KiB, assuming that API expresses its limit in bytes).
constexpr size_t kMaxZtaFileSize = 32 * 1024;
// Keys looked up in the decoded JWT payload dictionary to retrieve the agent
// ID and the customer ID, respectively.
constexpr char kAgentIdJwtPropertyKey[] = "sub";
constexpr char kCustomerIdJwtPropertyKey[] = "cid";
// Core logic of getting the CrowdStrike agent information. Extracted into
// a function in the anonymous namespace to have it run in a background
// thread.
//
// `zta_file_path` points to the data.zta file, which is expected to contain a
// JWT. `json_decode_callback` is invoked with the base64url-decoded payload
// section of that JWT (along with `results_callback`) so the JSON can be
// decoded out-of-process. `results_callback` is the final callback which
// ultimately returns the collected signals to the caller; it is invoked
// directly from here when the flow ends early:
//   - (nullopt, nullopt) when the file is missing or empty, which are both
//     supported use-cases;
//   - (nullopt, kParsingFailed) when the file cannot be read within the size
//     limit, is not made of three JWT sections, or its payload is not valid
//     base64url.
// Exactly one of the two callbacks is run.
void GetZtaJwtPayload(
    const base::FilePath& zta_file_path,
    base::OnceCallback<void(const std::string&, SignalsCallback)>
        json_decode_callback,
    SignalsCallback results_callback) {
  if (!base::PathExists(zta_file_path)) {
    // Not finding a file is a supported use-case (not an error).
    std::move(results_callback).Run(std::nullopt, std::nullopt);
    return;
  }

  std::string file_content;
  if (!base::ReadFileToStringWithMaxSize(zta_file_path, &file_content,
                                         kMaxZtaFileSize)) {
    LogCrowdStrikeParsingError(SignalsParsingError::kHitMaxDataSize);
    std::move(results_callback)
        .Run(std::nullopt, SignalCollectionError::kParsingFailed);
    return;
  }

  if (file_content.empty()) {
    // Having an empty file is a supported use-case (not an error).
    std::move(results_callback).Run(std::nullopt, std::nullopt);
    return;
  }

  // A valid ZTA file represents a JWT. For parsing out the identifiers, only
  // the payload section is relevant. More information on JWTs here:
  // https://en.wikipedia.org/wiki/JSON_Web_Token
  std::vector<std::string> jwt_sections = base::SplitString(
      file_content, ".", base::KEEP_WHITESPACE, base::SPLIT_WANT_ALL);
  if (jwt_sections.size() != 3) {
    // A JWT must have three sections (header, payload, signature). Malformed
    // content is reported as a parsing failure, consistently with every other
    // parsing error path in this file.
    LogCrowdStrikeParsingError(SignalsParsingError::kDataMalformed);
    std::move(results_callback)
        .Run(std::nullopt, SignalCollectionError::kParsingFailed);
    return;
  }

  // Only the middle (payload) section carries the identifiers.
  std::string json_payload;
  if (!base::Base64UrlDecode(jwt_sections[1],
                             base::Base64UrlDecodePolicy::IGNORE_PADDING,
                             &json_payload)) {
    LogCrowdStrikeParsingError(SignalsParsingError::kBase64DecodingFailed);
    std::move(results_callback)
        .Run(std::nullopt, SignalCollectionError::kParsingFailed);
    return;
  }

  // Hand the raw JSON over for decoding; `results_callback` travels along so
  // the decoding step can complete the flow.
  std::move(json_decode_callback)
      .Run(json_payload, std::move(results_callback));
}
// Reply step of the platform-specific fallback. `error` is whatever the
// ZTA-file flow reported before falling back and `signals` is the result of
// the fallback itself; both are handed to `callback` so that the earlier
// error still reaches the caller (and therefore the metrics).
void OnStaticSignalsRetrieved(SignalsCallback callback,
                              std::optional<SignalCollectionError> error,
                              std::optional<CrowdStrikeSignals> signals) {
  // Note the argument order: SignalsCallback takes signals first, then error.
  std::move(callback).Run(std::move(signals), std::move(error));
}
} // namespace
// Concrete CrowdStrikeClient which retrieves identifiers from the JWT stored
// in a ZTA file, and caches successfully retrieved values.
class CrowdStrikeClientImpl : public CrowdStrikeClient {
public:
// `zta_file_path` is the location of the ZTA file to read from.
explicit CrowdStrikeClientImpl(const base::FilePath& zta_file_path);
~CrowdStrikeClientImpl() override;
// CrowdStrikeClient:
void GetIdentifiers(SignalsCallback callback) override;
private:
// Delegates the JSON decoding of `json_content` to an out-of-process utility.
// Will invoke OnPayloadParsed with the result, while forwarding `callback`.
void DecodeJson(const std::string& json_content, SignalsCallback callback);
// Invoked after decoding some JSON content with `result`. That result is
// then parsed for the required signals. Then, `callback` is invoked with
// any signals that were found.
void OnPayloadParsed(SignalsCallback callback,
data_decoder::DataDecoder::ValueOrError result);
// Final function to be called in this flow with `signals` containing any
// value that was successfully found. When `signals` is set, this function
// will set the cache and then invoke the original caller's `callback`;
// otherwise it falls back to a platform-specific retrieval whose result is
// not cached.
void OnSignalsRetrieved(SignalsCallback callback,
std::optional<CrowdStrikeSignals> signals,
std::optional<SignalCollectionError> error);
SEQUENCE_CHECKER(sequence_checker_);
// Location of the ZTA file, fixed at construction.
const base::FilePath zta_file_path_;
// Used to parse JSON payloads in a child process.
data_decoder::DataDecoder data_decoder_;
// Holds the last successfully retrieved signals until they expire.
CachedSignal<CrowdStrikeSignals> cached_signals_;
// Kept as the last member so that it is the first one destroyed.
base::WeakPtrFactory<CrowdStrikeClientImpl> weak_ptr_factory_{this};
};
// static
// Builds the production client, which reads from the platform's ZTA file
// location. Only Windows and Mac are supported; reaching this function on any
// other platform is unexpected and yields a null client.
std::unique_ptr<CrowdStrikeClient> CrowdStrikeClient::Create() {
#if BUILDFLAG(IS_WIN) || BUILDFLAG(IS_MAC)
  const auto zta_file_path = GetCrowdStrikeZtaFilePath();
  return std::make_unique<CrowdStrikeClientImpl>(zta_file_path);
#else
  NOTREACHED_IN_MIGRATION();
  return nullptr;
#endif  // BUILDFLAG(IS_WIN) || BUILDFLAG(IS_MAC)
}
// Builds a client reading from the caller-provided `zta_file_path` instead of
// the platform's default location. Unlike Create(), this is not restricted to
// specific platforms.
std::unique_ptr<CrowdStrikeClient> CrowdStrikeClient::CreateForTesting(
const base::FilePath& zta_file_path) {
return std::make_unique<CrowdStrikeClientImpl>(zta_file_path);
}
// Stores a copy of `zta_file_path` and configures the signals cache with an
// expiry of `kCacheExpiryInHours` hours.
CrowdStrikeClientImpl::CrowdStrikeClientImpl(
const base::FilePath& zta_file_path)
: zta_file_path_(zta_file_path),
cached_signals_(base::Hours(kCacheExpiryInHours)) {}
// Nothing to clean up explicitly; members are destroyed in reverse
// declaration order.
CrowdStrikeClientImpl::~CrowdStrikeClientImpl() = default;
// Entry point of the collection flow. Answers synchronously from the cache
// when a value is available; otherwise posts the file reading work to the
// thread pool and arranges for the follow-up steps to run back on the current
// sequence.
void CrowdStrikeClientImpl::GetIdentifiers(SignalsCallback callback) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  if (const auto& cached_values = cached_signals_.Get(); cached_values) {
    std::move(callback).Run(cached_values.value(), /*error=*/std::nullopt);
    return;
  }

  // Both continuations are bound to a weak pointer and wrapped so that they
  // are posted back to the current sequence, since GetZtaJwtPayload runs them
  // from a thread pool task.
  auto json_decode_callback = base::BindPostTaskToCurrentDefault(base::BindOnce(
      &CrowdStrikeClientImpl::DecodeJson, weak_ptr_factory_.GetWeakPtr()));
  auto result_callback = base::BindPostTaskToCurrentDefault(
      base::BindOnce(&CrowdStrikeClientImpl::OnSignalsRetrieved,
                     weak_ptr_factory_.GetWeakPtr(), std::move(callback)));

  auto read_zta_task =
      base::BindOnce(&GetZtaJwtPayload, zta_file_path_,
                     std::move(json_decode_callback),
                     std::move(result_callback));

  base::ThreadPool::PostTask(
      FROM_HERE,
      {base::MayBlock(), base::TaskPriority::USER_BLOCKING,
       base::TaskShutdownBehavior::SKIP_ON_SHUTDOWN},
      std::move(read_zta_task));
}
// Hands `json_content` to the data decoder so that it gets parsed in a child
// process. The parsed result is delivered to OnPayloadParsed, which receives
// `callback` in order to complete the flow.
void CrowdStrikeClientImpl::DecodeJson(const std::string& json_content,
                                       SignalsCallback callback) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  // Bound to a weak pointer so the reply is dropped if `this` goes away first.
  auto on_parsed =
      base::BindOnce(&CrowdStrikeClientImpl::OnPayloadParsed,
                     weak_ptr_factory_.GetWeakPtr(), std::move(callback));
  data_decoder_.ParseJson(json_content, std::move(on_parsed));
}
// Invoked with the `result` of decoding the JWT payload's JSON. Extracts the
// identifiers from it and reports them through `callback`:
//   - (nullopt, kParsingFailed) when decoding failed, when the decoded value
//     is not a dictionary, or when the agent ID property is missing;
//   - (identifiers, nullopt) otherwise. The customer ID is optional and is
//     left at its default value when absent.
void CrowdStrikeClientImpl::OnPayloadParsed(
    SignalsCallback callback,
    data_decoder::DataDecoder::ValueOrError result) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  // The payload comes from a file on disk, so it may be valid JSON without
  // being a JSON object (e.g. a list or a string). base::Value::GetDict()
  // CHECKs on the value type, so this must be verified before calling it.
  if (!result.has_value() || !result->is_dict()) {
    LogCrowdStrikeParsingError(SignalsParsingError::kJsonParsingFailed);
    std::move(callback).Run(std::nullopt,
                            SignalCollectionError::kParsingFailed);
    return;
  }

  const base::Value::Dict& result_dict = result->GetDict();

  const std::string* agent_id = result_dict.FindString(kAgentIdJwtPropertyKey);
  if (!agent_id) {
    LogCrowdStrikeParsingError(SignalsParsingError::kMissingRequiredProperty);
    std::move(callback).Run(std::nullopt,
                            SignalCollectionError::kParsingFailed);
    return;
  }

  CrowdStrikeSignals identifiers;
  identifiers.agent_id = *agent_id;

  const std::string* customer_id =
      result_dict.FindString(kCustomerIdJwtPropertyKey);
  if (customer_id) {
    identifiers.customer_id = *customer_id;
  }

  std::move(callback).Run(std::move(identifiers), /*error=*/std::nullopt);
}
// Last step of the ZTA-file flow. When `signals` holds a value, it is cached
// and returned to the caller through `callback`. Otherwise, a
// platform-specific fallback is attempted on the thread pool, and its result
// is returned along with the original `error`.
void CrowdStrikeClientImpl::OnSignalsRetrieved(
    SignalsCallback callback,
    std::optional<CrowdStrikeSignals> signals,
    std::optional<SignalCollectionError> error) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  if (signals) {
    cached_signals_.Set(signals.value());
    std::move(callback).Run(std::move(signals), error);
    return;
  }

  // Signals could not be retrieved via the ZTA file, so fall back to some
  // other platform-specific mechanism. That value is deliberately not cached,
  // as it is inexpensive to retrieve and the ZTA file is preferred.
  auto on_fallback_done = base::BindOnce(
      &OnStaticSignalsRetrieved, std::move(callback), std::move(error));
  base::ThreadPool::PostTaskAndReplyWithResult(
      FROM_HERE,
      {base::MayBlock(), base::TaskPriority::USER_BLOCKING,
       base::TaskShutdownBehavior::SKIP_ON_SHUTDOWN},
      base::BindOnce(&GetCrowdStrikeSignals), std::move(on_fallback_done));
}
} // namespace device_signals