// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/device_bound_sessions/registration_fetcher.h"
#include <utility>
#include "components/unexportable_keys/background_task_priority.h"
#include "components/unexportable_keys/unexportable_key_service.h"
#include "net/base/io_buffer.h"
#include "net/device_bound_sessions/session_binding_utils.h"
#include "net/device_bound_sessions/session_json_utils.h"
#include "net/traffic_annotation/network_traffic_annotation.h"
#include "net/url_request/url_request_context.h"
namespace net::device_bound_sessions {
namespace {
// Request header that carries the signed registration JWT; it is set on the
// registration POST in StartFetchingRegistration().
constexpr char kJwtSessionHeaderName[] = "Sec-Session-Response";
// Traffic annotation attached to the registration POST request.
constexpr net::NetworkTrafficAnnotationTag kRegistrationTrafficAnnotation =
net::DefineNetworkTrafficAnnotation("dbsc_registration", R"(
semantics {
sender: "Device Bound Session Credentials API"
description:
"Device Bound Session Credentials (DBSC) let a server create a "
"session with the local device. For more info see "
"https://github.com/WICG/dbsc."
trigger:
"Server sending a response with a Sec-Session-Registration header."
data: "A signed JWT with the new key created for this session."
destination: WEBSITE
last_reviewed: "2024-04-10"
user_data {
type: ACCESS_TOKEN
}
internal {
contacts {
email: "[email protected]"
}
contacts {
email: "[email protected]"
}
}
}
policy {
cookies_allowed: YES
cookies_store: "user"
setting: "There is no seperate setting for this feature, but it will "
"follow the cookie settings."
policy_exception_justification: "Not implemented."
})");
// Size in bytes of the read buffer; every URLRequest::Read() of the
// registration response body asks for at most this many bytes.
constexpr int kBufferSize = 4096;
// Signing algorithms the key service may choose from when generating the
// session's signing key.
// A server will provide a list of acceptable algorithms in the future.
constexpr crypto::SignatureVerifier::SignatureAlgorithm
kAcceptableAlgorithms[] = {crypto::SignatureVerifier::ECDSA_SHA256,
crypto::SignatureVerifier::RSA_PKCS1_SHA256};
// Priority used for both key generation and signing.
// New session registration doesn't block the user and can be done with a delay.
constexpr unexportable_keys::BackgroundTaskPriority kTaskPriority =
unexportable_keys::BackgroundTaskPriority::kBestEffort;
// Final step of registration-token creation. It combines `header_and_payload`
// with the signature in `result` into a complete registration token. Then it
// reports the token together with `key_id` through `callback`.
//
// A signing error, or a failure to append the signature, is reported as
// std::nullopt.
//
// `unexportable_key_service` is not used here. It is accepted only so the
// bound argument list matches what OnKeyGenerated() binds.
void OnDataSigned(
    crypto::SignatureVerifier::SignatureAlgorithm algorithm,
    unexportable_keys::UnexportableKeyService& unexportable_key_service,
    std::string header_and_payload,
    unexportable_keys::UnexportableKeyId key_id,
    base::OnceCallback<void(
        std::optional<RegistrationFetcher::RegistrationTokenResult>)> callback,
    unexportable_keys::ServiceErrorOr<std::vector<uint8_t>> result) {
  if (!result.has_value()) {
    std::move(callback).Run(std::nullopt);
    return;
  }

  const std::vector<uint8_t>& signature = result.value();
  std::optional<std::string> registration_token =
      AppendSignatureToHeaderAndPayload(header_and_payload, algorithm,
                                        signature);
  if (!registration_token.has_value()) {
    std::move(callback).Run(std::nullopt);
    return;
  }

  // The token is not used again here, so hand it over instead of copying it.
  std::move(callback).Run(RegistrationFetcher::RegistrationTokenResult(
      std::move(*registration_token), key_id));
}
// Second step of registration-token creation, run once the key service has
// generated a signing key.
//
// It looks up the new key's algorithm and public key, then builds the unsigned
// registration JWT (header and payload) for `challenge` and
// `registration_url`. It then asks the key service to sign it, and
// OnDataSigned() assembles the final token.
//
// Every failure along the way is reported through `callback` as std::nullopt.
void OnKeyGenerated(
    unexportable_keys::UnexportableKeyService& unexportable_key_service,
    std::string_view challenge,
    const GURL& registration_url,
    base::OnceCallback<void(
        std::optional<RegistrationFetcher::RegistrationTokenResult>)> callback,
    unexportable_keys::ServiceErrorOr<unexportable_keys::UnexportableKeyId>
        result) {
  if (!result.has_value()) {
    std::move(callback).Run(std::nullopt);
    return;
  }

  unexportable_keys::UnexportableKeyId key_id = result.value();
  auto expected_algorithm = unexportable_key_service.GetAlgorithm(key_id);
  auto expected_public_key =
      unexportable_key_service.GetSubjectPublicKeyInfo(key_id);
  if (!expected_algorithm.has_value() || !expected_public_key.has_value()) {
    std::move(callback).Run(std::nullopt);
    return;
  }

  std::optional<std::string> optional_header_and_payload =
      CreateKeyRegistrationHeaderAndPayload(
          challenge, registration_url, expected_algorithm.value(),
          expected_public_key.value(), base::Time::Now());
  if (!optional_header_and_payload.has_value()) {
    std::move(callback).Run(std::nullopt);
    return;
  }

  // Refer to the string in place rather than copying it out of the optional.
  //
  // It is still copied (not moved) into the bound OnDataSigned() arguments.
  // The span argument and the BindOnce() argument below are evaluated in an
  // unspecified order, so a move could empty the string before the span over
  // it is taken.
  const std::string& header_and_payload = *optional_header_and_payload;
  unexportable_key_service.SignSlowlyAsync(
      key_id, base::as_bytes(base::make_span(header_and_payload)),
      kTaskPriority,
      base::BindOnce(&OnDataSigned, expected_algorithm.value(),
                     std::ref(unexportable_key_service), header_and_payload,
                     key_id, std::move(callback)));
}
// Starts asynchronous creation of a registration token.
//
// It first generates a new signing key using one of kAcceptableAlgorithms. It
// then builds and signs the registration JWT for `challenge` and
// `registration_url` (see OnKeyGenerated() and OnDataSigned()).
//
// `callback` receives the token and key id, or std::nullopt on any failure.
void CreateTokenAsync(
    unexportable_keys::UnexportableKeyService& unexportable_key_service,
    std::string challenge,
    const GURL& registration_url,
    base::OnceCallback<
        void(std::optional<RegistrationFetcher::RegistrationTokenResult>)>
        callback) {
  // `challenge` is owned by this call and not used again, so move it into the
  // bound state instead of copying it.
  unexportable_key_service.GenerateSigningKeySlowlyAsync(
      kAcceptableAlgorithms, kTaskPriority,
      base::BindOnce(&OnKeyGenerated, std::ref(unexportable_key_service),
                     std::move(challenge), registration_url,
                     std::move(callback)));
}
// Performs a single registration request and owns itself for its whole
// lifetime.
//
// StartCreateTokenAndFetch() allocates it. Once the registration token has
// been created, OnRegistrationTokenCreated() POSTs the token to the
// registration endpoint. The response body is then parsed as session
// instructions.
//
// Every terminal path goes through RunCallbackAndDeleteSelf(), which runs
// `callback_` exactly once and then deletes `this`.
class RegistrationFetcherImpl : public URLRequest::Delegate {
 public:
  // URLRequest::Delegate

  // Refuses redirects to non-cryptographic schemes: the request is cancelled
  // and the registration fails. Other redirects are followed.
  void OnReceivedRedirect(URLRequest* request,
                          const RedirectInfo& redirect_info,
                          bool* defer_redirect) override {
    if (!redirect_info.new_url.SchemeIsCryptographic()) {
      request->Cancel();
      OnResponseCompleted();
      // *this is deleted here
    }
  }

  // TODO(kristianm): Look into if OnAuthRequired might need to be customized
  // for DBSC.
  // TODO(kristianm): Think about what to do for DBSC with
  // OnCertificateRequested, leaning towards not supporting it but not sure.
  // Always cancel requests on SSL errors, this is the default implementation
  // of OnSSLCertificateError.

  // This is always called unless the request is deleted before it is called.
  // A network error or a non-2xx status fails the registration. Otherwise the
  // first read of the response body is started.
  void OnResponseStarted(URLRequest* request, int net_error) override {
    if (net_error != OK) {
      OnResponseCompleted();
      // *this is deleted here
      return;
    }

    HttpResponseHeaders* headers = request->response_headers();
    const int response_code = headers ? headers->response_code() : 0;
    if (response_code < 200 || response_code >= 300) {
      OnResponseCompleted();
      // *this is deleted here
      return;
    }

    // Initiate the first read.
    const int bytes_read = request->Read(buf_.get(), kBufferSize);
    if (bytes_read >= 0) {
      OnReadCompleted(request, bytes_read);
    } else if (bytes_read != ERR_IO_PENDING) {
      OnResponseCompleted();
      // *this is deleted here
    }
  }

  // Invoked when a pending Read() finishes. It is also called directly from
  // OnResponseStarted() for a first read that completed synchronously.
  //
  // `bytes_read` is one of:
  //   - the number of bytes now in `buf_`;
  //   - 0 at the end of the body;
  //   - a negative net error code if the read failed.
  void OnReadCompleted(URLRequest* request, int bytes_read) override {
    // Keep reading while data is available synchronously. Only positive
    // counts are appended: a negative `bytes_read` is an error code, not a
    // length, and must never reach append().
    while (bytes_read > 0) {
      data_received_.append(buf_->data(), bytes_read);
      bytes_read = request->Read(buf_.get(), kBufferSize);
    }

    if (bytes_read == ERR_IO_PENDING) {
      // The pending read completes later through another call to this method.
      return;
    }

    if (bytes_read < 0) {
      // The read failed, so the body is truncated. Do not try to parse a
      // partial response; fail the registration instead.
      data_received_.clear();
    }

    OnResponseCompleted();
    // *this is deleted here
  }

  RegistrationFetcherImpl(
      RegistrationFetcherParam registration_params,
      unexportable_keys::UnexportableKeyService& key_service,
      const URLRequestContext* context,
      const IsolationInfo& isolation_info,
      RegistrationFetcher::RegistrationCompleteCallback callback)
      : registration_params_(std::move(registration_params)),
        key_service_(key_service),
        context_(context),
        isolation_info_(isolation_info),
        callback_(std::move(callback)),
        buf_(base::MakeRefCounted<IOBufferWithSize>(kBufferSize)) {}

  // `callback_` must have been run (and thereby consumed) before destruction.
  ~RegistrationFetcherImpl() override { CHECK(!callback_); }

  // Receives the outcome of registration-token creation.
  //
  // On failure, the registration is reported as failed and `this` is deleted.
  // On success, the key id is remembered and the registration request starts.
  void OnRegistrationTokenCreated(
      std::optional<RegistrationFetcher::RegistrationTokenResult> result) {
    if (!result) {
      RunCallbackAndDeleteSelf(std::nullopt);
      return;
    }
    key_id_ = result->key_id;
    StartFetchingRegistration(result->registration_token);
  }

 private:
  // Sends `registration_token` in a POST to the registration endpoint. The
  // request uses credentials and the caller's isolation info, and bypasses
  // the cache.
  void StartFetchingRegistration(const std::string& registration_token) {
    request_ =
        context_->CreateRequest(registration_params_.registration_endpoint(),
                                IDLE, this, kRegistrationTrafficAnnotation);
    request_->set_method("POST");
    request_->SetLoadFlags(LOAD_DISABLE_CACHE);
    request_->set_allow_credentials(true);
    request_->set_site_for_cookies(isolation_info_.site_for_cookies());
    // TODO(kristianm): Set initiator to the URL of the registration header
    request_->set_initiator(url::Origin());
    request_->set_isolation_info(isolation_info_);
    request_->SetExtraRequestHeaderByName(
        kJwtSessionHeaderName, registration_token, /*overwrite*/ true);
    request_->Start();
  }

  // Terminal step. Parses the accumulated response body as session
  // instructions and reports the result; an empty or unparsable body fails the
  // registration. Deletes `this`.
  void OnResponseCompleted() {
    if (data_received_.empty()) {
      RunCallbackAndDeleteSelf(std::nullopt);
      // *this is deleted here
      return;
    }

    std::optional<SessionParams> params =
        ParseSessionInstructionJson(data_received_);
    if (!params) {
      RunCallbackAndDeleteSelf(std::nullopt);
      // *this is deleted here
      return;
    }

    RunCallbackAndDeleteSelf(RegistrationFetcher::RegistrationCompleteParams(
        std::move(*params), *key_id_, request_->url()));
    // *this is deleted here
  }

  // Runs `callback_` with `params` and then deletes `this`. Callers must not
  // touch any member after calling this.
  void RunCallbackAndDeleteSelf(
      std::optional<RegistrationFetcher::RegistrationCompleteParams> params) {
    std::move(callback_).Run(std::move(params));
    delete this;
  }

  // State passed in to constructor
  RegistrationFetcherParam registration_params_;
  const raw_ref<unexportable_keys::UnexportableKeyService> key_service_;
  raw_ptr<const URLRequestContext> context_;
  IsolationInfo isolation_info_;
  RegistrationFetcher::RegistrationCompleteCallback callback_;

  // Set during key creation, before sending request to fetch data.
  // Should always be nullopt before that, and always a valid key after key
  // creation.
  std::optional<unexportable_keys::UnexportableKeyId> key_id_ = std::nullopt;

  // Created to fetch data
  std::unique_ptr<URLRequest> request_;
  // Read buffer of kBufferSize bytes, reused for every Read().
  scoped_refptr<IOBuffer> buf_;
  // Response body accumulated across reads.
  std::string data_received_;
};
// Test-only override installed by SetFetcherForTesting(). When it is non-null,
// StartCreateTokenAndFetch() skips key creation and networking. It completes
// immediately with whatever this function returns.
std::optional<RegistrationFetcher::RegistrationCompleteParams> (
*g_mock_fetcher)() = nullptr;
} // namespace
void RegistrationFetcher::StartCreateTokenAndFetch(
    RegistrationFetcherParam registration_params,
    unexportable_keys::UnexportableKeyService& key_service,
    // TODO(kristianm): Check the lifetime of context and make sure this use
    // is safe.
    const URLRequestContext* context,
    const IsolationInfo& isolation_info,
    RegistrationCompleteCallback callback) {
  // Tests may short-circuit the whole key-creation and network flow.
  if (g_mock_fetcher) {
    std::move(callback).Run(g_mock_fetcher());
    return;
  }

  // Grab what token creation needs before `registration_params` is moved into
  // the fetcher below.
  const GURL endpoint = registration_params.registration_endpoint();
  std::string challenge = registration_params.challenge();

  auto* fetcher =
      new RegistrationFetcherImpl(std::move(registration_params), key_service,
                                  context, isolation_info, std::move(callback));

  // The fetcher controls its own lifetime and cannot be destroyed before this
  // callback runs, so base::Unretained() is safe here.
  auto on_token_created =
      base::BindOnce(&RegistrationFetcherImpl::OnRegistrationTokenCreated,
                     base::Unretained(fetcher));
  CreateTokenAsync(key_service, std::move(challenge), endpoint,
                   std::move(on_token_created));
}
// Installs (`func` non-null) or removes (`func` null) the test-only fetcher
// used by StartCreateTokenAndFetch().
//
// While an override is installed, the only permitted call is one that removes
// it. Passing a non-null `func` at that point trips the CHECK.
void RegistrationFetcher::SetFetcherForTesting(FetcherType func) {
  if (!g_mock_fetcher) {
    g_mock_fetcher = func;
    return;
  }

  CHECK(!func);
  g_mock_fetcher = nullptr;
}
// Test hook that runs the same registration-token creation as production code,
// without any network fetch.
void RegistrationFetcher::CreateTokenAsyncForTesting(
    unexportable_keys::UnexportableKeyService& unexportable_key_service,
    std::string challenge,
    const GURL& registration_url,
    base::OnceCallback<
        void(std::optional<RegistrationFetcher::RegistrationTokenResult>)>
        callback) {
  // `challenge` is a by-value sink that is not used again; move it through.
  CreateTokenAsync(unexportable_key_service, std::move(challenge),
                   registration_url, std::move(callback));
}
} // namespace net::device_bound_sessions