// Copyright 2020 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/ash/cert_provisioning/cert_provisioning_serializer.h"
#include <optional>
#include <string>
#include "base/base64.h"
#include "base/logging.h"
#include "base/numerics/safe_conversions.h"
#include "base/time/time.h"
#include "chrome/browser/ash/cert_provisioning/cert_provisioning_common.h"
#include "components/prefs/pref_service.h"
#include "components/prefs/scoped_user_pref_update.h"
namespace ash {
namespace cert_provisioning {
namespace {
constexpr char kKeyNameProcessId[] = "process_id";
constexpr char kKeyNameCertScope[] = "cert_scope";
constexpr char kKeyNameCertProfile[] = "cert_profile";
constexpr char kKeyNameState[] = "state";
constexpr char kKeyNamePublicKey[] = "public_key";
constexpr char kKeyNameInvalidationTopic[] = "invalidation_topic";
constexpr char kKeyNameKeyLocation[] = "key_location";
constexpr char kKeyNameAttemptedVaChallenge[] = "attempted_va_challenge";
constexpr char kKeyNameAttemptedProofOfPossession[] =
"attempted_proof_of_possession";
constexpr char kKeyNameProofOfPossessionSignature[] =
"proof_of_possession_signature";
constexpr char kKeyNameCertProfileId[] = "profile_id";
constexpr char kKeyNameCertProfileName[] = "name";
constexpr char kKeyNameCertProfileVersion[] = "policy_version";
constexpr char kKeyNameCertProfileProtocolVersion[] = "protocol_version";
constexpr char kKeyNameCertProfileVaEnabled[] = "va_enabled";
constexpr char kKeyNameCertProfileRenewalPeriod[] = "renewal_period";
template <typename T>
bool ConvertToEnum(int value, T* dst) {
if ((value < 0) || (value > static_cast<int>(T::kMaxValue))) {
return false;
}
*dst = static_cast<T>(value);
return true;
}
template <typename T>
bool DeserializeEnumValue(const base::Value::Dict& parent_dict,
const char* value_name,
T* dst) {
std::optional<int> serialized_enum = parent_dict.FindInt(value_name);
if (!serialized_enum.has_value()) {
return false;
}
return ConvertToEnum<T>(*serialized_enum, dst);
}
bool DeserializeStringValue(const base::Value::Dict& parent_dict,
const char* value_name,
std::string* dst) {
const std::string* serialized_string = parent_dict.FindString(value_name);
if (!serialized_string) {
return false;
}
*dst = *serialized_string;
return true;
}
bool DeserializeBoolValue(const base::Value::Dict& parent_dict,
const char* value_name,
bool* dst) {
std::optional<bool> serialized_bool = parent_dict.FindBool(value_name);
if (!serialized_bool.has_value()) {
return false;
}
*dst = *serialized_bool;
return true;
}
bool DeserializeRenewalPeriod(const base::Value::Dict& parent_dict,
const char* value_name,
base::TimeDelta* dst) {
std::optional<int> serialized_time = parent_dict.FindInt(value_name);
*dst = base::Seconds(serialized_time.value_or(0));
return true;
}
bool DeserializeProtocolVersion(const base::Value::Dict& parent_value,
const char* value_name,
ProtocolVersion* dst) {
std::optional<int> protocol_version_value = parent_value.FindInt(value_name);
std::optional<ProtocolVersion> protocol_version =
ParseProtocolVersion(protocol_version_value);
if (!protocol_version.has_value()) {
return false;
}
*dst = *protocol_version;
return true;
}
base::Value::Dict SerializeCertProfile(const CertProfile& profile) {
static_assert(CertProfile::kVersion == 6, "This function should be updated");
base::Value::Dict result;
result.Set(kKeyNameCertProfileId, profile.profile_id);
result.Set(kKeyNameCertProfileName, profile.name);
result.Set(kKeyNameCertProfileVersion, profile.policy_version);
result.Set(kKeyNameCertProfileVaEnabled, profile.is_va_enabled);
if (profile.protocol_version != ProtocolVersion::kStatic) {
// Only set the protocol_version if it's not kStatic to avoid changing how
// "static flow" workers are serialized.
result.Set(kKeyNameCertProfileProtocolVersion,
static_cast<int>(profile.protocol_version));
}
if (!profile.renewal_period.is_zero()) {
result.Set(kKeyNameCertProfileRenewalPeriod,
base::saturated_cast<int>(profile.renewal_period.InSeconds()));
}
return result;
}
bool DeserializeCertProfile(const base::Value::Dict& parent_dict,
const char* value_name,
CertProfile* dst) {
static_assert(CertProfile::kVersion == 6, "This function should be updated");
const base::Value::Dict* serialized_profile =
parent_dict.FindDict(value_name);
if (!serialized_profile) {
return false;
}
bool is_ok = true;
is_ok = is_ok &&
DeserializeStringValue(*serialized_profile, kKeyNameCertProfileId,
&(dst->profile_id));
is_ok =
is_ok && DeserializeStringValue(*serialized_profile,
kKeyNameCertProfileName, &(dst->name));
is_ok = is_ok && DeserializeStringValue(*serialized_profile,
kKeyNameCertProfileVersion,
&(dst->policy_version));
is_ok = is_ok && DeserializeBoolValue(*serialized_profile,
kKeyNameCertProfileVaEnabled,
&(dst->is_va_enabled));
is_ok = is_ok && DeserializeRenewalPeriod(*serialized_profile,
kKeyNameCertProfileRenewalPeriod,
&(dst->renewal_period));
is_ok = is_ok && DeserializeProtocolVersion(
*serialized_profile, kKeyNameCertProfileProtocolVersion,
&(dst->protocol_version));
return is_ok;
}
std::string SerializeBase64Encoded(const std::vector<uint8_t>& public_key) {
return base::Base64Encode(public_key);
}
bool DeserializeBase64Encoded(const base::Value::Dict& parent_dict,
const char* value_name,
std::vector<uint8_t>* dst) {
const std::string* serialized_public_key = parent_dict.FindString(value_name);
if (!serialized_public_key) {
return false;
}
std::optional<std::vector<uint8_t>> public_key =
base::Base64Decode(*serialized_public_key);
if (!public_key) {
return false;
}
*dst = std::move(*public_key);
return true;
}
} // namespace
void CertProvisioningSerializer::SerializeWorkerToPrefs(
PrefService* pref_service,
const CertProvisioningWorkerStatic& worker) {
ScopedDictPrefUpdate scoped_dict_updater(
pref_service, GetPrefNameForSerialization(worker.cert_scope_));
base::Value::Dict& saved_workers = scoped_dict_updater.Get();
saved_workers.Set(worker.cert_profile_.profile_id, SerializeWorker(worker));
}
void CertProvisioningSerializer::SerializeWorkerToPrefs(
PrefService* pref_service,
const CertProvisioningWorkerDynamic& worker) {
ScopedDictPrefUpdate scoped_dict_updater(
pref_service, GetPrefNameForSerialization(worker.cert_scope_));
base::Value::Dict& saved_workers = scoped_dict_updater.Get();
saved_workers.Set(worker.cert_profile_.profile_id, SerializeWorker(worker));
}
void CertProvisioningSerializer::DeleteWorkerFromPrefs(
PrefService* pref_service,
const CertProvisioningWorkerStatic& worker) {
ScopedDictPrefUpdate scoped_dict_updater(
pref_service, GetPrefNameForSerialization(worker.cert_scope_));
base::Value::Dict& saved_workers = scoped_dict_updater.Get();
saved_workers.Remove(worker.cert_profile_.profile_id);
}
void CertProvisioningSerializer::DeleteWorkerFromPrefs(
PrefService* pref_service,
const CertProvisioningWorkerDynamic& worker) {
ScopedDictPrefUpdate scoped_dict_updater(
pref_service, GetPrefNameForSerialization(worker.cert_scope_));
base::Value::Dict& saved_workers = scoped_dict_updater.Get();
saved_workers.Remove(worker.cert_profile_.profile_id);
}
// Serialization scheme:
// {
// "cert_scope": <number>,
// "cert_profile": <CertProfile>,
// "state": <number>,
// "public_key": <string>,
// "invalidation_topic": <string>,
// }
base::Value::Dict CertProvisioningSerializer::SerializeWorker(
const CertProvisioningWorkerStatic& worker) {
static_assert(CertProvisioningWorkerStatic::kVersion == 2,
"This function should be updated");
base::Value::Dict result;
result.Set(kKeyNameProcessId, worker.process_id_);
result.Set(kKeyNameCertProfile, SerializeCertProfile(worker.cert_profile_));
result.Set(kKeyNameCertScope, static_cast<int>(worker.cert_scope_));
result.Set(kKeyNameState, static_cast<int>(worker.state_));
result.Set(kKeyNamePublicKey, SerializeBase64Encoded(worker.public_key_));
result.Set(kKeyNameInvalidationTopic, worker.invalidation_topic_);
return result;
}
// Serialization scheme:
// {
// "cert_scope": <number>,
// "cert_profile": <CertProfile>,
// "state": <number>,
// "public_key": <string>,
// "invalidation_topic": <string>,
// "key_location": <number>,
// "attempted_va_challenge": <bool>,
// "proof_of_possession_count": <number>,
// }
base::Value::Dict CertProvisioningSerializer::SerializeWorker(
const CertProvisioningWorkerDynamic& worker) {
static_assert(CertProvisioningWorkerDynamic::kVersion == 3,
"This function should be updated");
base::Value::Dict result;
result.Set(kKeyNameProcessId, worker.process_id_);
result.Set(kKeyNameCertProfile, SerializeCertProfile(worker.cert_profile_));
result.Set(kKeyNameCertScope, static_cast<int>(worker.cert_scope_));
result.Set(kKeyNameState, static_cast<int>(worker.state_));
result.Set(kKeyNamePublicKey, SerializeBase64Encoded(worker.public_key_));
result.Set(kKeyNameInvalidationTopic, worker.invalidation_topic_);
result.Set(kKeyNameKeyLocation, static_cast<int>(worker.key_location_));
result.Set(kKeyNameAttemptedVaChallenge, worker.attempted_va_challenge_);
result.Set(kKeyNameAttemptedProofOfPossession,
worker.attempted_proof_of_possession_);
result.Set(kKeyNameProofOfPossessionSignature,
SerializeBase64Encoded(worker.signature_));
return result;
}
bool CertProvisioningSerializer::DeserializeWorker(
const base::Value::Dict& saved_worker,
CertProvisioningWorkerStatic* worker) {
static_assert(CertProvisioningWorkerStatic::kVersion == 2,
"This function should be updated");
// This will show to the scheduler that the worker is not doing anything yet
// and that it should be continued manually.
worker->is_waiting_ = true;
bool is_ok = true;
int error_code = 0;
// Try to only add new deserialize statements at the end so error_code values
// are stable.
is_ok = is_ok && ++error_code &&
DeserializeEnumValue<CertScope>(saved_worker, kKeyNameCertScope,
&(worker->cert_scope_));
is_ok = is_ok && ++error_code &&
DeserializeCertProfile(saved_worker, kKeyNameCertProfile,
&(worker->cert_profile_));
is_ok = is_ok && ++error_code &&
DeserializeEnumValue<CertProvisioningWorkerState>(
saved_worker, kKeyNameState, &(worker->state_));
is_ok = is_ok && ++error_code &&
DeserializeBase64Encoded(saved_worker, kKeyNamePublicKey,
&(worker->public_key_));
is_ok = is_ok && ++error_code &&
DeserializeStringValue(saved_worker, kKeyNameInvalidationTopic,
&(worker->invalidation_topic_));
is_ok = is_ok && ++error_code &&
DeserializeStringValue(saved_worker, kKeyNameProcessId,
&(worker->process_id_));
if (!is_ok) {
LOG(ERROR)
<< " Failed to deserialize cert provisioning worker, error code: "
<< error_code;
return false;
}
worker->InitAfterDeserialization();
return true;
}
bool CertProvisioningSerializer::DeserializeWorker(
const base::Value::Dict& saved_worker,
CertProvisioningWorkerDynamic* worker) {
static_assert(CertProvisioningWorkerDynamic::kVersion == 3,
"This function should be updated");
// This will show to the scheduler that the worker is not doing anything yet
// and that it should be continued manually.
worker->is_waiting_ = true;
bool is_ok = true;
int error_code = 0;
// Try to only add new deserialize statements at the end so error_code values
// are stable.
is_ok = is_ok && ++error_code &&
DeserializeEnumValue<CertScope>(saved_worker, kKeyNameCertScope,
&(worker->cert_scope_));
is_ok = is_ok && ++error_code &&
DeserializeCertProfile(saved_worker, kKeyNameCertProfile,
&(worker->cert_profile_));
is_ok = is_ok && ++error_code &&
DeserializeEnumValue<CertProvisioningWorkerState>(
saved_worker, kKeyNameState, &(worker->state_));
is_ok = is_ok && ++error_code &&
DeserializeBase64Encoded(saved_worker, kKeyNamePublicKey,
&(worker->public_key_));
is_ok = is_ok && ++error_code &&
DeserializeStringValue(saved_worker, kKeyNameInvalidationTopic,
&(worker->invalidation_topic_));
is_ok = is_ok && ++error_code &&
DeserializeEnumValue<KeyLocation>(saved_worker, kKeyNameKeyLocation,
&(worker->key_location_));
is_ok = is_ok && ++error_code &&
DeserializeBoolValue(saved_worker, kKeyNameAttemptedVaChallenge,
&(worker->attempted_va_challenge_));
is_ok = is_ok && ++error_code &&
DeserializeBoolValue(saved_worker, kKeyNameAttemptedProofOfPossession,
&(worker->attempted_proof_of_possession_));
is_ok =
is_ok && ++error_code &&
DeserializeBase64Encoded(saved_worker, kKeyNameProofOfPossessionSignature,
&(worker->signature_));
is_ok = is_ok && ++error_code &&
DeserializeStringValue(saved_worker, kKeyNameProcessId,
&(worker->process_id_));
if (!is_ok) {
LOG(ERROR)
<< " Failed to deserialize cert provisioning worker, error code: "
<< error_code;
return false;
}
worker->InitAfterDeserialization();
return true;
}
std::optional<ProtocolVersion> CertProvisioningSerializer::GetProtocolVersion(
const base::Value::Dict& saved_worker) {
CertProfile cert_profile;
if (!DeserializeCertProfile(saved_worker, kKeyNameCertProfile,
&cert_profile)) {
return {};
}
return cert_profile.protocol_version;
}
} // namespace cert_provisioning
} // namespace ash