// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/policy/test_support/fake_dmserver.h"
#include <utility>
#include <vector>
#include "base/base64.h"
#include "base/files/file_util.h"
#include "base/json/json_file_value_serializer.h"
#include "base/logging.h"
#include "base/notreached.h"
#include "base/scoped_observation.h"
#include "base/strings/stringprintf.h"
#include "base/task/bind_post_task.h"
#include "base/time/time.h"
#include "base/timer/timer.h"
#include "components/policy/test_support/client_storage.h"
#include "components/policy/test_support/embedded_policy_test_server.h"
#include "components/policy/test_support/policy_storage.h"
#include "components/policy/test_support/request_handler_for_policy.h"
#include "components/policy/test_support/test_server_helpers.h"
// Returns false from the enclosing bool-returning function when `expr`
// evaluates to false.
// `expr` is parenthesized so compound conditions such as `a && b` are negated
// as a whole (an unparenthesized `!expr` would expand to `!a && b`). The
// do/while(0) wrapper makes the macro a single statement that requires a
// trailing semicolon and is safe inside unbraced if/else.
#define RETURN_IF_FALSE(expr) \
  do {                        \
    if (!(expr)) {            \
      return false;           \
    }                         \
  } while (0)
namespace fakedms {
namespace {
// --- Keys of an individual entry in the `policies` / `external_policies`
// lists of the policy blob file.
constexpr char kPolicyTypeKey[] = "policy_type";
constexpr char kEntityIdKey[] = "entity_id";
constexpr char kPolicyValueKey[] = "value";

// --- Top-level keys of the policy blob file.
constexpr char kPoliciesKey[] = "policies";
constexpr char kExternalPoliciesKey[] = "external_policies";
constexpr char kPolicyUserKey[] = "policy_user";
constexpr char kManagedUsersKey[] = "managed_users";
constexpr char kDeviceAffiliationIdsKey[] = "device_affiliation_ids";
constexpr char kUserAffiliationIdsKey[] = "user_affiliation_ids";
constexpr char kDirectoryApiIdKey[] = "directory_api_id";
constexpr char kRequestErrorsKey[] = "request_errors";
constexpr char kRobotApiAuthCodeKey[] = "robot_api_auth_code";
constexpr char kAllowSetDeviceAttributesKey[] = "allow_set_device_attributes";
constexpr char kUseUniversalSigningKeysKey[] = "use_universal_signing_keys";
constexpr char kCurrentKeyIndexKey[] = "current_key_index";

// --- Keys inside each `initial_enrollment_state` entry.
constexpr char kInitialEnrollmentStateKey[] = "initial_enrollment_state";
constexpr char kManagementDomainKey[] = "management_domain";
constexpr char kInitialEnrollmentModeKey[] = "initial_enrollment_mode";

// --- Keys of a serialized client in the client state file.
constexpr char kDeviceIdKey[] = "device_id";
constexpr char kDeviceTokenKey[] = "device_token";
constexpr char kMachineNameKey[] = "machine_name";
constexpr char kUsernameKey[] = "username";
constexpr char kStateKeysKey[] = "state_keys";
constexpr char kAllowedPolicyTypesKey[] = "allowed_policy_types";

// --- Defaults used by ParseFlags() when a switch is absent.
constexpr char kDefaultPolicyBlobFilename[] = "policy.json";
constexpr char kDefaultClientStateFilename[] = "state.json";
constexpr int kDefaultMinLogLevel = logging::LOGGING_INFO;
constexpr bool kDefaultLogToConsole = false;

// --- Command-line switch names.
constexpr char kPolicyBlobPathSwitch[] = "policy-blob-path";
constexpr char kClientStatePathSwitch[] = "client-state-path";
constexpr char kGrpcUnixSocketUriSwitch[] = "grpc-unix-socket-uri";
constexpr char kLogPathSwitch[] = "log-path";
constexpr char kStartupPipeSwitch[] = "startup-pipe";
constexpr char kMinLogLevelSwitch[] = "min-log-level";
constexpr char kLogToConsoleSwitch[] = "log-to-console";

// --- gRPC timing. Despite its name, kRemoteCommandTimeoutSeconds is a
// TimeDelta (10 seconds), not a raw seconds count.
constexpr base::TimeDelta kRemoteCommandTimeoutSeconds = base::Seconds(10);
constexpr int64_t kDefaultServerStopTimeoutMs = 100;
// Builds the gRPC response for a finished remote command.
// Only the `result`, `command_id`, `timestamp` and `payload` fields of
// `result` are copied; no other field is transferred.
// (Internal linkage already comes from the enclosing anonymous namespace.)
remote_commands::WaitRemoteCommandResultResponse
BuildWaitRemoteCommandResultResponse(const em::RemoteCommandResult& result) {
  remote_commands::WaitRemoteCommandResultResponse response;
  em::RemoteCommandResult& copied = *response.mutable_result();
  copied.set_result(result.result());
  copied.set_command_id(result.command_id());
  copied.set_timestamp(result.timestamp());
  copied.set_payload(result.payload());
  return response;
}
// Reads the optional `policy_user` string from the policy blob `dict`.
// When present, it is installed as the policy user of `policy_storage`.
// When absent, nothing is changed and a log line reports that
// policy::kDefaultUsername will be used.
void ParsePolicyUser(const base::Value::Dict* dict,
                     policy::PolicyStorage* policy_storage) {
  const std::string* const user = dict->FindString(kPolicyUserKey);
  if (!user) {
    LOG(INFO) << "The policy_user key isn't found and the default policy "
                 "user "
              << policy::kDefaultUsername << " will be used";
    return;
  }
  LOG(INFO) << "Adding " << *user << " as a policy user";
  policy_storage->set_policy_user(*user);
}
// Registers every string entry of the optional `managed_users` list in `dict`
// as a managed user of `policy_storage`. Non-string entries are skipped
// without an error, and an absent list is a no-op.
void ParseManagedUsers(const base::Value::Dict* dict,
                       policy::PolicyStorage* policy_storage) {
  const base::Value::List* const users = dict->FindList(kManagedUsersKey);
  if (!users) {
    return;
  }
  for (const base::Value& entry : *users) {
    const std::string* const user = entry.GetIfString();
    if (!user) {
      continue;
    }
    LOG(INFO) << "Adding " << *user << " as a managed user";
    policy_storage->add_managed_user(*user);
  }
}
// Adds every string entry of the optional `device_affiliation_ids` list in
// `dict` to `policy_storage`. Non-string entries are skipped without an error,
// and an absent list is a no-op.
void ParseDeviceAffiliationIds(const base::Value::Dict* dict,
                               policy::PolicyStorage* policy_storage) {
  const base::Value::List* const ids =
      dict->FindList(kDeviceAffiliationIdsKey);
  if (!ids) {
    return;
  }
  for (const base::Value& entry : *ids) {
    const std::string* const id = entry.GetIfString();
    if (!id) {
      continue;
    }
    LOG(INFO) << "Adding " << *id << " as a device affiliation id";
    policy_storage->add_device_affiliation_id(*id);
  }
}
// Adds every string entry of the optional `user_affiliation_ids` list in
// `dict` to `policy_storage`. Non-string entries are skipped without an error,
// and an absent list is a no-op.
void ParseUserAffiliationIds(const base::Value::Dict* dict,
                             policy::PolicyStorage* policy_storage) {
  const base::Value::List* const ids = dict->FindList(kUserAffiliationIdsKey);
  if (!ids) {
    return;
  }
  for (const base::Value& entry : *ids) {
    const std::string* const id = entry.GetIfString();
    if (!id) {
      continue;
    }
    LOG(INFO) << "Adding " << *id << " as a user affiliation id";
    policy_storage->add_user_affiliation_id(*id);
  }
}
// Copies the optional `directory_api_id` string from `dict` into
// `policy_storage`. An absent or non-string value is ignored.
void ParseDirectoryApiId(const base::Value::Dict* dict,
                         policy::PolicyStorage* policy_storage) {
  if (const std::string* id = dict->FindString(kDirectoryApiIdKey); id) {
    LOG(INFO) << "Adding " << *id << " as a directory API ID";
    policy_storage->set_directory_api_id(*id);
  }
}
// Applies the optional `allow_set_device_attributes` bool from `dict` to
// `policy_storage`.
// Returns true when the key is absent or is a bool. Returns false, after
// logging the offending type and value, when it has any other type.
bool ParseAllowSetDeviceAttributes(const base::Value::Dict* dict,
                                   policy::PolicyStorage* policy_storage) {
  const base::Value* const raw = dict->Find(kAllowSetDeviceAttributesKey);
  if (!raw) {
    return true;
  }
  const std::optional<bool> allow = raw->GetIfBool();
  if (!allow) {
    LOG(ERROR)
        << "The allow_set_device_attributes key isn't a bool, found type "
        << raw->type() << ", found value " << *raw;
    return false;
  }
  policy_storage->set_allow_set_device_attributes(*allow);
  return true;
}
// Reads the optional `use_universal_signing_keys` bool from `dict`. When it is
// true, the storage's signature provider is switched to the universal signing
// keys. A false value leaves the provider untouched.
// Returns false, after logging, only when the key exists but isn't a bool.
bool ParseUseUniversalSigningKeys(const base::Value::Dict* dict,
                                  policy::PolicyStorage* policy_storage) {
  const base::Value* const raw = dict->Find(kUseUniversalSigningKeysKey);
  if (!raw) {
    return true;
  }
  const std::optional<bool> use_universal = raw->GetIfBool();
  if (!use_universal) {
    LOG(ERROR)
        << "The use_universal_signing_keys key isn't a bool, found type "
        << raw->type() << ", found value " << *raw;
    return false;
  }
  if (*use_universal) {
    policy_storage->signature_provider()->SetUniversalSigningKeys();
  }
  return true;
}
// Copies the optional `robot_api_auth_code` string from `dict` into
// `policy_storage`. An absent or non-string value is ignored.
void ParseRobotApiAuthCode(const base::Value::Dict* dict,
                           policy::PolicyStorage* policy_storage) {
  if (const std::string* code = dict->FindString(kRobotApiAuthCodeKey);
      code) {
    LOG(INFO) << "Adding " << *code << " as a robot api auth code";
    policy_storage->set_robot_api_auth_code(*code);
  }
}
// Reads the optional `request_errors` dict, which maps a request name to an
// int HTTP status. Each entry is forwarded to
// FakeDMServer::ConfigureRequestError().
// Returns false at the first non-int value. Entries configured before that
// point stay configured.
bool ParseRequestErrors(const base::Value::Dict* dict,
                        FakeDMServer* fake_dmserver) {
  const base::Value::Dict* const errors = dict->FindDict(kRequestErrorsKey);
  if (!errors) {
    return true;
  }
  for (const auto [request_name, error_value] : *errors) {
    const std::optional<int> code = error_value.GetIfInt();
    if (!code) {
      LOG(ERROR) << "The error code isn't an int";
      return false;
    }
    LOG(INFO) << "Configuring request " << request_name << " to error "
              << *code;
    fake_dmserver->ConfigureRequestError(
        request_name, static_cast<net::HttpStatusCode>(*code));
  }
  return true;
}
// Reads the optional `initial_enrollment_state` dict.
// Each key maps to a dict holding a `management_domain` string and an
// `initial_enrollment_mode` int; each valid entry is stored in
// `policy_storage` under its key.
// Returns false, after logging, at the first malformed entry. Earlier entries
// remain stored.
bool ParseInitialEnrollmentState(const base::Value::Dict* dict,
                                 policy::PolicyStorage* policy_storage) {
  const base::Value::Dict* const states =
      dict->FindDict(kInitialEnrollmentStateKey);
  if (!states) {
    return true;
  }
  for (const auto [state_key, state_value] : *states) {
    const base::Value::Dict* const entry = state_value.GetIfDict();
    if (!entry) {
      LOG(ERROR) << "The current state value for key " << state_key
                 << " isn't a dict";
      return false;
    }
    const std::string* const domain = entry->FindString(kManagementDomainKey);
    if (!domain) {
      LOG(ERROR) << "The management_domain key isn't a string";
      return false;
    }
    const std::optional<int> mode = entry->FindInt(kInitialEnrollmentModeKey);
    if (!mode) {
      LOG(ERROR) << "The initial_enrollment_mode key isn't an int";
      return false;
    }
    policy::PolicyStorage::InitialEnrollmentState enrollment_state;
    enrollment_state.management_domain = *domain;
    enrollment_state.initial_enrollment_mode = static_cast<
        enterprise_management::DeviceInitialEnrollmentStateResponse::
            InitialEnrollmentMode>(*mode);
    policy_storage->SetInitialEnrollmentState(state_key, enrollment_state);
  }
  return true;
}
// Applies the optional `current_key_index` int from `dict` as the current key
// version of the storage's signature provider.
// Returns false, after logging the found type and value, when the key exists
// but isn't an int.
bool ParseCurrentKeyIndex(const base::Value::Dict* dict,
                          policy::PolicyStorage* policy_storage) {
  const base::Value* const raw = dict->Find(kCurrentKeyIndexKey);
  if (!raw) {
    return true;
  }
  const std::optional<int> key_index = raw->GetIfInt();
  if (!key_index) {
    LOG(ERROR) << "The current_key_index key isn't an int, found type "
               << raw->type() << ", found value " << *raw;
    return false;
  }
  policy_storage->signature_provider()->set_current_key_version(*key_index);
  return true;
}
} // namespace
// Configures process-wide logging for the fake DM server.
//
// `log_path`: when set, logs go to that file; otherwise they go to stderr.
// `log_to_console`: additionally sends logs to stderr and the system debug
//     log, on top of the destination chosen from `log_path`.
// `min_log_level`: passed to logging::SetMinLogLevel().
void InitLogging(const std::optional<std::string>& log_path,
                 bool log_to_console,
                 int min_log_level) {
  logging::LoggingSettings settings;
  if (log_path.has_value()) {
    settings.log_file_path = log_path.value().c_str();
    settings.logging_dest = logging::LOG_TO_FILE;
  } else {
    settings.logging_dest = logging::LOG_TO_STDERR;
  }
  // With --log-to-console, mirror logs to every available sink.
  if (log_to_console) {
    settings.logging_dest |=
        logging::LOG_TO_SYSTEM_DEBUG_LOG | logging::LOG_TO_STDERR;
  }
  logging::SetMinLogLevel(min_log_level);
  logging::InitLogging(settings);
  // Prefix log lines with process id, thread id and timestamp, but not the
  // tick count. (The last argument is `enable_tickcount`; it was previously
  // mislabelled as a second `enable_timestamp`.)
  logging::SetLogItems(/*enable_process_id=*/true, /*enable_thread_id=*/true,
                       /*enable_timestamp=*/true, /*enable_tickcount=*/false);
}
// Fills the FakeDMServer options from `command_line`.
//
// These outputs are always overwritten, with a default when their switch is
// absent: `policy_blob_path`, `client_state_path`, `log_to_console` and
// `min_log_level`.
// These outputs are written only when their switch is present:
// `grpc_unix_socket_uri`, `log_path` and `startup_pipe`.
// A non-integer value for --startup-pipe or --min-log-level is fatal (CHECK).
void ParseFlags(const base::CommandLine& command_line,
                std::string& policy_blob_path,
                std::string& client_state_path,
                std::string& grpc_unix_socket_uri,
                std::optional<std::string>& log_path,
                base::ScopedFD& startup_pipe,
                bool& log_to_console,
                int& min_log_level) {
  policy_blob_path =
      command_line.HasSwitch(kPolicyBlobPathSwitch)
          ? command_line.GetSwitchValueASCII(kPolicyBlobPathSwitch)
          : std::string(kDefaultPolicyBlobFilename);
  client_state_path =
      command_line.HasSwitch(kClientStatePathSwitch)
          ? command_line.GetSwitchValueASCII(kClientStatePathSwitch)
          : std::string(kDefaultClientStateFilename);
  log_to_console = command_line.HasSwitch(kLogToConsoleSwitch)
                       ? true
                       : kDefaultLogToConsole;

  if (command_line.HasSwitch(kLogPathSwitch)) {
    log_path = command_line.GetSwitchValueASCII(kLogPathSwitch);
  }
  if (command_line.HasSwitch(kGrpcUnixSocketUriSwitch)) {
    grpc_unix_socket_uri =
        command_line.GetSwitchValueASCII(kGrpcUnixSocketUriSwitch);
  }

  if (command_line.HasSwitch(kStartupPipeSwitch)) {
    const std::string fd_text =
        command_line.GetSwitchValueASCII(kStartupPipeSwitch);
    int fd = 0;
    CHECK(base::StringToInt(fd_text, &fd))
        << "Expected an int value for --startup-pipe switch, but got: "
        << fd_text;
    startup_pipe = base::ScopedFD(fd);
  }

  min_log_level = kDefaultMinLogLevel;
  if (command_line.HasSwitch(kMinLogLevelSwitch)) {
    const std::string level_text =
        command_line.GetSwitchValueASCII(kMinLogLevelSwitch);
    CHECK(base::StringToInt(level_text, &min_log_level))
        << "Expected an int value for --min-log-level switch, but got: "
        << level_text;
  }
}
enum class RemoteCommandsWaitType { kAcknowledged, kResultAvailable };
class RemoteCommandsWaitOperation
: public policy::RemoteCommandsState::Observer {
public:
// Callback for a RemoteCommandsWaitOperation.
// `wait_operation` will refer to the RemoteCommandsWaitOperation object that
// invoked the callback. If `success` is true, the `RemoteCommandsWaitType`
// has happened, otherwise the wait timed out.
using RemoteCommandsWaitCallback =
base::OnceCallback<void(RemoteCommandsWaitOperation* wait_operation,
bool success)>;
RemoteCommandsWaitOperation(
policy::RemoteCommandsState* remote_commands_state,
RemoteCommandsWaitType wait_type,
RemoteCommandsWaitOperation::RemoteCommandsWaitCallback wait_callback);
~RemoteCommandsWaitOperation() override;
void OnRemoteCommandResultAvailable(int64_t command_id) override;
void OnRemoteCommandAcked(int64_t command_id) override;
void OnTimeout();
private:
const raw_ptr<policy::RemoteCommandsState> remote_commands_state_;
const RemoteCommandsWaitType wait_type_;
RemoteCommandsWaitCallback wait_callback_;
base::ScopedObservation<policy::RemoteCommandsState,
policy::RemoteCommandsState::Observer>
state_observation_{this};
// Timer that fires to prevent indefinite wait if the remote command result
// takes too long.
base::OneShotTimer result_timeout_timer_;
base::WeakPtrFactory<RemoteCommandsWaitOperation> weak_ptr_factory_{this};
};
RemoteCommandsWaitOperation::RemoteCommandsWaitOperation(
policy::RemoteCommandsState* remote_commands_state,
RemoteCommandsWaitType wait_type,
RemoteCommandsWaitOperation::RemoteCommandsWaitCallback wait_callback)
: remote_commands_state_(remote_commands_state),
wait_type_(wait_type),
wait_callback_(std::move(wait_callback)) {
state_observation_.Observe(remote_commands_state);
// Start a timer for 10 seconds to wait for the remote command result.
result_timeout_timer_.Start(
FROM_HERE, kRemoteCommandTimeoutSeconds,
base::BindOnce(&RemoteCommandsWaitOperation::OnTimeout,
weak_ptr_factory_.GetWeakPtr()));
}
RemoteCommandsWaitOperation::~RemoteCommandsWaitOperation() = default;
void RemoteCommandsWaitOperation::OnRemoteCommandResultAvailable(
int64_t command_id) {
if (wait_type_ != RemoteCommandsWaitType::kResultAvailable) {
return;
}
const bool result_available =
remote_commands_state_->IsRemoteCommandResultAvailable(command_id);
// The result must be available now.
CHECK(result_available);
// Invoke the wait callback OnWaitRemoteCommandResultDone to write the result
// to the reactor.
std::move(wait_callback_).Run(this, result_available);
}
void RemoteCommandsWaitOperation::OnRemoteCommandAcked(int64_t command_id) {
if (wait_type_ != RemoteCommandsWaitType::kAcknowledged) {
return;
}
const bool command_acked =
remote_commands_state_->IsRemoteCommandAcked(command_id);
// The command must be acknowledged now.
CHECK(command_acked);
// Invoke the wait callback OnWaitRemoteCommandAckDone to write the ack to the
// reactor.
std::move(wait_callback_).Run(this, command_acked);
}
void RemoteCommandsWaitOperation::OnTimeout() {
std::move(wait_callback_).Run(this, false);
}
FakeDMServer::FakeDMServer(const std::string& policy_blob_path,
const std::string& client_state_path,
const std::string& grpc_unix_socket_uri,
base::OnceClosure shutdown_cb)
: policy_blob_path_(policy_blob_path),
client_state_path_(client_state_path),
grpc_unix_socket_uri_(grpc_unix_socket_uri) {
shut_down_on_main_task_runner_ =
base::BindPostTaskToCurrentDefault(base::BindOnce(
&FakeDMServer::TriggerShutdown, weak_ptr_factory_.GetWeakPtr()));
shut_down_server_ = base::BindPostTaskToCurrentDefault(
base::BindOnce(&FakeDMServer::OnShutdownGrpcServerDone,
weak_ptr_factory_.GetWeakPtr(), std::move(shutdown_cb)));
DETACH_FROM_SEQUENCE(embedded_server_sequence_checker_);
}
FakeDMServer::~FakeDMServer() {
DCHECK_CALLED_ON_VALID_SEQUENCE(fake_dmserver_main_sequence_checker_);
}
void FakeDMServer::EraseWaitOperation(RemoteCommandsWaitOperation* operation) {
DCHECK_CALLED_ON_VALID_SEQUENCE(fake_dmserver_main_sequence_checker_);
auto it = waiters_.find(operation);
CHECK(it != waiters_.end());
waiters_.erase(it);
}
void FakeDMServer::StartGrpcServer() {
DCHECK_CALLED_ON_VALID_SEQUENCE(fake_dmserver_main_sequence_checker_);
LOG(INFO) << "Starting the gRPC server on endpoint " << grpc_unix_socket_uri_;
grpc_server_.emplace();
grpc_server_->SetHandler<
remote_commands::RemoteCommandsServiceHandler::SendRemoteCommand>(
base::BindPostTask(
base::SingleThreadTaskRunner::GetCurrentDefault(),
base::BindRepeating(&FakeDMServer::HandleSendRemoteCommand,
weak_ptr_factory_.GetWeakPtr())));
grpc_server_->SetHandler<
remote_commands::RemoteCommandsServiceHandler::WaitRemoteCommandResult>(
base::BindPostTask(
base::SingleThreadTaskRunner::GetCurrentDefault(),
base::BindRepeating(&FakeDMServer::HandleWaitRemoteCommandResult,
weak_ptr_factory_.GetWeakPtr())));
grpc_server_->SetHandler<
remote_commands::RemoteCommandsServiceHandler::WaitRemoteCommandAcked>(
base::BindPostTask(
base::SingleThreadTaskRunner::GetCurrentDefault(),
base::BindRepeating(&FakeDMServer::HandleWaitRemoteCommandAcked,
weak_ptr_factory_.GetWeakPtr())));
auto status = grpc_server_->Start(grpc_unix_socket_uri_);
// Browser runtime must crash if the runtime service failed to start to avoid
// the process to dangle without any proper connection to the Cast Core.
CHECK(status.ok()) << "Failed to start DM gRPC server: status="
<< status.error_message();
}
void FakeDMServer::HandleSendRemoteCommand(
remote_commands::SendRemoteCommandRequest request,
remote_commands::RemoteCommandsServiceHandler::SendRemoteCommand::Reactor*
reactor) {
DCHECK_CALLED_ON_VALID_SEQUENCE(fake_dmserver_main_sequence_checker_);
LOG(INFO) << "Processing SendRemoteCommand grpc request.";
int64_t command_id = remote_commands_state()->AddPendingRemoteCommand(
request.remote_command());
remote_commands::SendRemoteCommandResponse resp;
resp.set_command_id(command_id);
reactor->Write(std::move(resp));
}
void FakeDMServer::OnWaitRemoteCommandResultDone(
remote_commands::RemoteCommandsServiceHandler::WaitRemoteCommandResult::
Reactor* reactor,
int64_t command_id,
RemoteCommandsWaitOperation* wait_operation,
bool wait_success) {
DCHECK_CALLED_ON_VALID_SEQUENCE(fake_dmserver_main_sequence_checker_);
auto it = waiters_.find(wait_operation);
CHECK(it != waiters_.end());
waiters_.erase(it);
if (!wait_success) {
reactor->Write(grpc::Status(
grpc::StatusCode::CANCELLED,
"Timeout waiting for remote command result took more than 10 seconds"));
return;
}
em::RemoteCommandResult result;
bool result_available =
remote_commands_state()->GetRemoteCommandResult(command_id, &result);
CHECK(result_available);
auto resp = BuildWaitRemoteCommandResultResponse(result);
reactor->Write(std::move(resp));
}
void FakeDMServer::HandleWaitRemoteCommandResult(
remote_commands::WaitRemoteCommandResultRequest request,
remote_commands::RemoteCommandsServiceHandler::WaitRemoteCommandResult::
Reactor* reactor) {
DCHECK_CALLED_ON_VALID_SEQUENCE(fake_dmserver_main_sequence_checker_);
LOG(INFO) << "Processing WaitRemoteCommandResult grpc request.";
int64_t command_id = request.command_id();
em::RemoteCommandResult result;
bool result_available =
remote_commands_state()->GetRemoteCommandResult(command_id, &result);
if (!result_available) {
LOG(INFO) << "Remote command result isn't available yet.";
// Insert the wait operation into the set and bind the erase function to
// erase it if the result is available.
waiters_.insert(std::make_unique<RemoteCommandsWaitOperation>(
remote_commands_state(), RemoteCommandsWaitType::kResultAvailable,
base::BindOnce(&FakeDMServer::OnWaitRemoteCommandResultDone,
weak_ptr_factory_.GetWeakPtr(),
base::Unretained(reactor), command_id)));
return;
}
LOG(INFO) << "Remote command result is available. Resolving the grpc call.";
auto resp = BuildWaitRemoteCommandResultResponse(result);
reactor->Write(std::move(resp));
}
void FakeDMServer::OnWaitRemoteCommandAckDone(
remote_commands::RemoteCommandsServiceHandler::WaitRemoteCommandAcked::
Reactor* reactor,
int64_t command_id,
RemoteCommandsWaitOperation* wait_operation,
bool wait_success) {
DCHECK_CALLED_ON_VALID_SEQUENCE(fake_dmserver_main_sequence_checker_);
auto it = waiters_.find(wait_operation);
CHECK(it != waiters_.end());
waiters_.erase(it);
if (!wait_success) {
reactor->Write(grpc::Status(grpc::StatusCode::CANCELLED,
"Timeout waiting for remote command "
"acknowledgement took more than 10 seconds"));
return;
}
bool command_acked =
remote_commands_state()->IsRemoteCommandAcked(command_id);
CHECK(command_acked);
remote_commands::WaitRemoteCommandAckedResponse resp;
reactor->Write(std::move(resp));
}
void FakeDMServer::HandleWaitRemoteCommandAcked(
remote_commands::WaitRemoteCommandAckedRequest request,
remote_commands::RemoteCommandsServiceHandler::WaitRemoteCommandAcked::
Reactor* reactor) {
DCHECK_CALLED_ON_VALID_SEQUENCE(fake_dmserver_main_sequence_checker_);
LOG(INFO) << "Processing WaitRemoteCommandAcked grpc request.";
int64_t command_id = request.command_id();
bool command_acked =
remote_commands_state()->IsRemoteCommandAcked(command_id);
if (!command_acked) {
LOG(INFO) << "Remote command isn't acknowledged yet.";
// Insert the wait operation into the set and bind the erase function to
// erase it if the command is acknowledged.
waiters_.insert(std::make_unique<RemoteCommandsWaitOperation>(
remote_commands_state(), RemoteCommandsWaitType::kAcknowledged,
base::BindOnce(&FakeDMServer::OnWaitRemoteCommandAckDone,
weak_ptr_factory_.GetWeakPtr(),
base::Unretained(reactor), command_id)));
return;
}
LOG(INFO) << "Remote command is acknowledged. Resolving the grpc call.";
remote_commands::WaitRemoteCommandAckedResponse resp;
reactor->Write(std::move(resp));
}
// Starts the embedded HTTP policy server. It also starts the gRPC
// remote-commands server, but only when a gRPC socket URI was configured.
// Returns false only when the HTTP server fails to start.
bool FakeDMServer::StartFakeServer() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(fake_dmserver_main_sequence_checker_);
  LOG(INFO) << "Starting the FakeDMServer with args policy_blob_path="
            << policy_blob_path_ << " client_state_path=" << client_state_path_
            << " grpc_unix_socket_uri=" << grpc_unix_socket_uri_;
  const bool http_started = policy::EmbeddedPolicyTestServer::Start();
  if (!http_started) {
    LOG(ERROR) << "Failed to start the EmbeddedPolicyTestServer";
    return false;
  }
  LOG(INFO) << "Server started running on URL: "
            << EmbeddedPolicyTestServer::GetServiceURL();

  const bool grpc_requested = !grpc_unix_socket_uri_.empty();
  if (grpc_requested) {
    StartGrpcServer();
  } else {
    LOG(INFO) << "grpc_unix_socket_uri is empty the grpc server won't start";
  }
  return true;
}
// Asks the running gRPC server to stop, giving in-flight calls up to
// kDefaultServerStopTimeoutMs. `server_stopped_callback` runs once it has
// stopped. The gRPC server must exist.
void FakeDMServer::ShutdownGrpcServer(
    base::OnceClosure server_stopped_callback) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(fake_dmserver_main_sequence_checker_);
  CHECK(grpc_server_.has_value());
  auto& server = *grpc_server_;
  server.Stop(kDefaultServerStopTimeoutMs, std::move(server_stopped_callback));
}
// Final shutdown step: drops the (stopped or never started) gRPC server and
// then notifies the owner through `server_stopped_callback`.
void FakeDMServer::OnShutdownGrpcServerDone(
    base::OnceClosure server_stopped_callback) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(fake_dmserver_main_sequence_checker_);
  grpc_server_ = std::nullopt;
  std::move(server_stopped_callback).Run();
}
// Starts shutting the server down on the main sequence.
// If a gRPC server is running, it is stopped first and then
// `shut_down_server_` runs. Otherwise `shut_down_server_` runs immediately.
void FakeDMServer::TriggerShutdown() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(fake_dmserver_main_sequence_checker_);
  if (grpc_server_) {
    ShutdownGrpcServer(std::move(shut_down_server_));
  } else {
    std::move(shut_down_server_).Run();
  }
}
// Writes the embedded server address as `{"host": "<host>", "port": <port>}`
// to `startup_pipe`. Ownership of the descriptor moves into a base::File,
// which closes it when this function returns.
// Returns false (after logging) if the write fails.
bool FakeDMServer::WriteURLToPipe(base::ScopedFD&& startup_pipe) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(fake_dmserver_main_sequence_checker_);
  GURL server_url = EmbeddedPolicyTestServer::GetServiceURL();
  // GURL::port() is an empty string when the URL has no explicit port, which
  // would emit invalid JSON (`"port": }`). EffectiveIntPort() falls back to
  // the scheme's default port and is identical otherwise.
  std::string server_data = base::StringPrintf(
      "{\"host\": \"%s\", \"port\": %d}", server_url.host().c_str(),
      server_url.EffectiveIntPort());
  base::File pipe_writer(startup_pipe.release());
  if (!pipe_writer.WriteAtCurrentPosAndCheck(
          base::as_bytes(base::make_span(server_data)))) {
    LOG(ERROR) << "Failed to write the server url data to the pipe, data: "
               << server_data;
    return false;
  }
  return true;
}
// Serves one HTTP request on the embedded server's sequence.
//
// `/test/exit` schedules a shutdown on the main sequence, and `/test/ping` is
// a liveness check. Every other request:
//   1. resets the server state;
//   2. reloads it from the policy blob and client state files;
//   3. delegates to EmbeddedPolicyTestServer::HandleRequest();
//   4. writes the client state back to disk.
// A file read/write failure is reported as HTTP 500.
std::unique_ptr<net::test_server::HttpResponse> FakeDMServer::HandleRequest(
    const net::test_server::HttpRequest& request) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(embedded_server_sequence_checker_);
  GURL url = request.GetURL();
  if (url.path() == "/test/exit") {
    LOG(INFO) << "Stopping the FakeDMServer";
    // The shutdown callback is one-shot and is consumed by the first
    // /test/exit. Running the null callback again on a repeated exit request
    // would crash the server.
    if (!shut_down_on_main_task_runner_.is_null()) {
      std::move(shut_down_on_main_task_runner_).Run();
    }
    return policy::CreateHttpResponse(net::HTTP_OK, "Policy Server exited.");
  }
  if (url.path() == "/test/ping") {
    return policy::CreateHttpResponse(net::HTTP_OK, "Pong.");
  }
  EmbeddedPolicyTestServer::ResetServerState();
  if (!ReadPolicyBlobFile()) {
    return policy::CreateHttpResponse(net::HTTP_INTERNAL_SERVER_ERROR,
                                      "Failed to read policy blob file.");
  }
  if (!ReadClientStateFile()) {
    return policy::CreateHttpResponse(net::HTTP_INTERNAL_SERVER_ERROR,
                                      "Failed to read client state file.");
  }
  auto resp = policy::EmbeddedPolicyTestServer::HandleRequest(request);
  if (!WriteClientStateFile()) {
    return policy::CreateHttpResponse(net::HTTP_INTERNAL_SERVER_ERROR,
                                      "Failed to write client state file.");
  }
  return resp;
}
// Stores a base64-encoded policy proto in the policy storage.
// `policy_type` and `serialized_proto` are required. `entity_id` is optional;
// without it the payload is stored without an entity id.
// Returns false (after logging) when a required field is missing or the value
// isn't valid base64.
bool FakeDMServer::SetPolicyPayload(const std::string* policy_type,
                                    const std::string* entity_id,
                                    const std::string* serialized_proto) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(embedded_server_sequence_checker_);
  const bool has_required_fields = policy_type && serialized_proto;
  if (!has_required_fields) {
    LOG(ERROR) << "Couldn't find the policy type or value fields";
    return false;
  }
  std::string payload;
  if (!base::Base64Decode(*serialized_proto, &payload)) {
    LOG(ERROR) << "Unable to base64 decode validation value from "
               << *serialized_proto;
    return false;
  }
  if (!entity_id) {
    policy_storage()->SetPolicyPayload(*policy_type, payload);
    return true;
  }
  policy_storage()->SetPolicyPayload(*policy_type, *entity_id, payload);
  return true;
}
// Registers a base64-encoded raw external policy with the embedded server.
// All three fields are required. Returns false (after logging) when one is
// missing or the value isn't valid base64.
bool FakeDMServer::SetExternalPolicyPayload(
    const std::string* policy_type,
    const std::string* entity_id,
    const std::string* serialized_raw_policy) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(embedded_server_sequence_checker_);
  const bool has_required_fields =
      policy_type && entity_id && serialized_raw_policy;
  if (!has_required_fields) {
    LOG(ERROR) << "Couldn't find the policy type or entity id or value fields";
    return false;
  }
  std::string raw_policy;
  if (!base::Base64Decode(*serialized_raw_policy, &raw_policy)) {
    LOG(ERROR) << "Unable to base64 decode validation value from "
               << *serialized_raw_policy;
    return false;
  }
  EmbeddedPolicyTestServer::UpdateExternalPolicy(*policy_type, *entity_id,
                                                 raw_policy);
  return true;
}
// Applies the `policies` and then the `external_policies` lists from the
// policy blob `dict`.
//
// Each entry must be a dict. Its `policy_type`, `entity_id` and `value`
// strings are forwarded to SetPolicyPayload() or SetExternalPolicyPayload(),
// respectively. Absent lists are ignored.
// Returns false (after logging) at the first entry that is not a dict or that
// the setter rejects. Entries applied before that point stay applied.
bool FakeDMServer::ParsePolicies(const base::Value::Dict* dict) {
  auto apply_entries = [](const base::Value::List* entries,
                          const char* not_a_dict_error,
                          const char* setter_error, auto&& set_payload) {
    if (!entries) {
      return true;
    }
    for (const base::Value& entry : *entries) {
      const base::Value::Dict* const fields = entry.GetIfDict();
      if (!fields) {
        LOG(ERROR) << not_a_dict_error;
        return false;
      }
      if (!set_payload(fields->FindString(kPolicyTypeKey),
                       fields->FindString(kEntityIdKey),
                       fields->FindString(kPolicyValueKey))) {
        LOG(ERROR) << setter_error;
        return false;
      }
    }
    return true;
  };

  const bool policies_applied = apply_entries(
      dict->FindList(kPoliciesKey), "The current policy isn't a dict",
      "Failed to set the policy",
      [this](const std::string* type, const std::string* entity,
             const std::string* value) {
        return SetPolicyPayload(type, entity, value);
      });
  if (!policies_applied) {
    return false;
  }
  return apply_entries(
      dict->FindList(kExternalPoliciesKey),
      "The current external policy isn't a dict",
      "Failed to set the external policy",
      [this](const std::string* type, const std::string* entity,
             const std::string* value) {
        return SetExternalPolicyPayload(type, entity, value);
      });
}
// Loads the policy blob JSON file (if it exists) into the server state.
// A missing file is not an error. Returns false (after logging) when the file
// is not valid JSON, is not a dict, or when one of the strict parsers rejects
// its contents.
bool FakeDMServer::ReadPolicyBlobFile() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(embedded_server_sequence_checker_);
  if (!base::PathExists(policy_blob_path_)) {
    LOG(INFO) << "Policy blob file doesn't exist yet.";
    return true;
  }
  int error_code = 0;
  std::string error_msg;
  const std::unique_ptr<base::Value> blob =
      JSONFileValueDeserializer(policy_blob_path_)
          .Deserialize(&error_code, &error_msg);
  if (!blob) {
    LOG(ERROR) << "Failed to read the policy blob file " << policy_blob_path_
               << ": " << error_msg;
    return false;
  }
  LOG(INFO) << "Deserialized value of the policy blob: " << *blob;
  const base::Value::Dict* const dict = blob->GetIfDict();
  if (!dict) {
    LOG(ERROR) << "Policy blob isn't a dict";
    return false;
  }

  // Lenient parsers: they never fail; absent keys are tolerated.
  ParsePolicyUser(dict, policy_storage());
  ParseManagedUsers(dict, policy_storage());
  ParseDeviceAffiliationIds(dict, policy_storage());
  ParseUserAffiliationIds(dict, policy_storage());
  ParseDirectoryApiId(dict, policy_storage());
  ParseRobotApiAuthCode(dict, policy_storage());

  // Strict parsers: run in order, stopping at the first failure.
  return ParseAllowSetDeviceAttributes(dict, policy_storage()) &&
         ParseUseUniversalSigningKeys(dict, policy_storage()) &&
         ParseRequestErrors(dict, this) &&
         ParseInitialEnrollmentState(dict, policy_storage()) &&
         ParseCurrentKeyIndex(dict, policy_storage()) && ParsePolicies(dict);
}
// Serializes one registered client into the dict layout that
// GetClientFromValue() reads back from the client state file. A missing
// username is written as an empty string.
base::Value::Dict FakeDMServer::GetValueFromClient(
    const policy::ClientStorage::ClientInfo& c) {
  base::Value::List state_keys;
  for (const auto& state_key : c.state_keys) {
    state_keys.Append(state_key);
  }
  base::Value::List allowed_policy_types;
  for (const auto& policy_type : c.allowed_policy_types) {
    allowed_policy_types.Append(policy_type);
  }

  base::Value::Dict client;
  client.Set(kDeviceIdKey, c.device_id);
  client.Set(kDeviceTokenKey, c.device_token);
  client.Set(kMachineNameKey, c.machine_name);
  client.Set(kUsernameKey, c.username.value_or(""));
  client.Set(kStateKeysKey, std::move(state_keys));
  client.Set(kAllowedPolicyTypesKey, std::move(allowed_policy_types));
  return client;
}
bool FakeDMServer::WriteClientStateFile() {
DCHECK_CALLED_ON_VALID_SEQUENCE(embedded_server_sequence_checker_);
std::vector<policy::ClientStorage::ClientInfo> clients =
client_storage()->GetAllClients();
base::Value::Dict dict_clients;
for (auto& c : clients) {
dict_clients.Set(c.device_id, GetValueFromClient(c));
}
JSONFileValueSerializer serializer(client_state_path_);
return serializer.Serialize(base::ValueView(dict_clients));
}
// Returns true when `dict` holds `key` with the given `type`. Otherwise it
// logs an error and returns false.
// Only STRING and LIST are supported; any other `type` hits NOTREACHED().
bool FakeDMServer::FindKey(const base::Value::Dict& dict,
                           const std::string& key,
                           base::Value::Type type) {
  if (type == base::Value::Type::STRING) {
    if (dict.FindString(key)) {
      return true;
    }
    LOG(ERROR) << "Key `" << key << "` is missing or not a string.";
    return false;
  }
  if (type == base::Value::Type::LIST) {
    if (dict.FindList(key)) {
      return true;
    }
    LOG(ERROR) << "Key `" << key << "` is missing or not a list.";
    return false;
  }
  NOTREACHED() << "Unsupported type for client file key";
}
// Parses one client entry of the client state file.
// Returns std::nullopt (after logging) in these cases:
//   - `v` is not a dict;
//   - a required key is missing or has the wrong type;
//   - a state key or policy type list entry is not a string.
std::optional<policy::ClientStorage::ClientInfo>
FakeDMServer::GetClientFromValue(const base::Value& v) {
  const base::Value::Dict* const dict = v.GetIfDict();
  if (!dict) {
    LOG(ERROR) << "Client value isn't a dict";
    return std::nullopt;
  }
  // Validate every required key first (stopping at the first failure) so the
  // dereferences below cannot hit a null pointer.
  const bool keys_valid =
      FindKey(*dict, kDeviceIdKey, base::Value::Type::STRING) &&
      FindKey(*dict, kDeviceTokenKey, base::Value::Type::STRING) &&
      FindKey(*dict, kMachineNameKey, base::Value::Type::STRING) &&
      FindKey(*dict, kUsernameKey, base::Value::Type::STRING) &&
      FindKey(*dict, kStateKeysKey, base::Value::Type::LIST) &&
      FindKey(*dict, kAllowedPolicyTypesKey, base::Value::Type::LIST);
  if (!keys_valid) {
    return std::nullopt;
  }

  policy::ClientStorage::ClientInfo client_info;
  client_info.device_id = *dict->FindString(kDeviceIdKey);
  client_info.device_token = *dict->FindString(kDeviceTokenKey);
  client_info.machine_name = *dict->FindString(kMachineNameKey);
  client_info.username = *dict->FindString(kUsernameKey);

  for (const base::Value& entry : *dict->FindList(kStateKeysKey)) {
    const std::string* const state_key = entry.GetIfString();
    if (!state_key) {
      LOG(ERROR) << "State key list entry is not a string: " << entry;
      return std::nullopt;
    }
    client_info.state_keys.emplace_back(*state_key);
  }

  for (const base::Value& entry : *dict->FindList(kAllowedPolicyTypesKey)) {
    const std::string* const policy_type = entry.GetIfString();
    if (!policy_type) {
      LOG(ERROR) << "Policy type list entry is not a string: " << entry;
      return std::nullopt;
    }
    client_info.allowed_policy_types.insert(*policy_type);
  }
  return client_info;
}
bool FakeDMServer::ReadClientStateFile() {
DCHECK_CALLED_ON_VALID_SEQUENCE(embedded_server_sequence_checker_);
if (!base::PathExists(client_state_path_)) {
LOG(INFO) << "Client state file doesn't exist yet.";
return true;
}
JSONFileValueDeserializer deserializer(client_state_path_);
int error_code = 0;
std::string error_msg;
std::unique_ptr<base::Value> value =
deserializer.Deserialize(&error_code, &error_msg);
if (!value) {
LOG(ERROR) << "Failed to read client state file " << client_state_path_
<< ": " << error_msg;
return false;
}
const base::Value::Dict* dict = value->GetIfDict();
if (!dict) {
LOG(ERROR) << "The client state file isn't a dict.";
return false;
}
for (auto it : *dict) {
std::optional<policy::ClientStorage::ClientInfo> c =
GetClientFromValue(it.second);
if (!c.has_value()) {
LOG(ERROR) << "The client isn't configured correctly.";
return false;
}
client_storage()->RegisterClient(c.value());
}
return true;
}
} // namespace fakedms