// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "remoting/host/it2me/it2me_native_messaging_host_ash.h"
#include <utility>
#include "base/feature_list.h"
#include "base/functional/callback.h"
#include "base/json/json_writer.h"
#include "base/notreached.h"
#include "base/time/time.h"
#include "base/values.h"
#include "extensions/browser/api/messaging/native_message_host.h"
#include "remoting/host/chromeos/chromeos_enterprise_params.h"
#include "remoting/host/chromeos/features.h"
#include "remoting/host/chromoting_host_context.h"
#include "remoting/host/it2me/it2me_native_messaging_host.h"
#include "remoting/host/it2me/reconnect_params.h"
#include "remoting/host/native_messaging/native_messaging_helpers.h"
#include "remoting/host/policy_watcher.h"
namespace remoting {
namespace {
// Returns whether host notifications should be suppressed for this session.
// Enterprise parameters, when present, are authoritative. Otherwise the
// Mojom-supplied value is honored only in debug builds.
bool ShouldSuppressNotifications(
    const mojom::SupportSessionParams& params,
    const std::optional<ChromeOsEnterpriseParams>& enterprise_params) {
  if (!enterprise_params.has_value()) {
#if defined(NDEBUG)
    // Release builds never let the Mojom API toggle this setting.
    return false;
#else
    return params.suppress_notifications;
#endif
  }
  return enterprise_params->suppress_notifications;
}
// Returns whether user-facing dialogs should be suppressed for this session.
// Enterprise parameters, when present, are authoritative. Otherwise the
// Mojom-supplied value is honored only in debug builds.
bool ShouldSuppressUserDialog(
    const mojom::SupportSessionParams& params,
    const std::optional<ChromeOsEnterpriseParams>& enterprise_params) {
  if (!enterprise_params.has_value()) {
#if defined(NDEBUG)
    // Release builds never let the Mojom API toggle this setting.
    return false;
#else
    return params.suppress_user_dialogs;
#endif
  }
  return enterprise_params->suppress_user_dialogs;
}
// Returns whether the session should end when local input is detected.
// Enterprise parameters, when present, are authoritative. Otherwise the
// Mojom-supplied value is honored only in debug builds.
bool ShouldTerminateUponInput(
    const mojom::SupportSessionParams& params,
    const std::optional<ChromeOsEnterpriseParams>& enterprise_params) {
  if (!enterprise_params.has_value()) {
#if defined(NDEBUG)
    // Release builds never let the Mojom API toggle this setting.
    return false;
#else
    return params.terminate_upon_input;
#endif
  }
  return enterprise_params->terminate_upon_input;
}
// Returns whether the local user session should be curtained. This is always
// false unless the kEnableCrdAdminRemoteAccess feature is enabled. With the
// feature on, enterprise parameters are authoritative when present. Otherwise
// the Mojom-supplied value is honored only in debug builds.
bool ShouldCurtainLocalUserSession(
    const mojom::SupportSessionParams& params,
    const std::optional<ChromeOsEnterpriseParams>& enterprise_params) {
  const bool feature_enabled =
      base::FeatureList::IsEnabled(features::kEnableCrdAdminRemoteAccess);
  if (!feature_enabled) {
    return false;
  }
  if (!enterprise_params.has_value()) {
#if defined(NDEBUG)
    // Release builds never let the Mojom API toggle this setting.
    return false;
#else
    return params.curtain_local_user_session;
#endif
  }
  return enterprise_params->curtain_local_user_session;
}
// Returns the enterprise `show_troubleshooting_tools` flag, or false when no
// enterprise parameters were supplied.
bool ShouldShowTroubleshootingTools(
    const std::optional<ChromeOsEnterpriseParams>& enterprise_params) {
  return enterprise_params.has_value() &&
         enterprise_params->show_troubleshooting_tools;
}
// Returns the enterprise `allow_troubleshooting_tools` flag, or false when no
// enterprise parameters were supplied.
bool ShouldAllowTroubleshootingTools(
    const std::optional<ChromeOsEnterpriseParams>& enterprise_params) {
  return enterprise_params.has_value() &&
         enterprise_params->allow_troubleshooting_tools;
}
// Returns the enterprise `allow_reconnections` flag, or false when no
// enterprise parameters were supplied.
bool ShouldAllowReconnections(
    const std::optional<ChromeOsEnterpriseParams>& enterprise_params) {
  return enterprise_params.has_value() &&
         enterprise_params->allow_reconnections;
}
// Returns the enterprise `allow_file_transfer` flag, or false when no
// enterprise parameters were supplied.
bool ShouldAllowFileTransfer(
    const std::optional<ChromeOsEnterpriseParams>& enterprise_params) {
  return enterprise_params.has_value() &&
         enterprise_params->allow_file_transfer;
}
} // namespace
It2MeNativeMessageHostAsh::It2MeNativeMessageHostAsh(
std::unique_ptr<It2MeHostFactory> host_factory)
: host_factory_(std::move(host_factory)) {}
It2MeNativeMessageHostAsh::~It2MeNativeMessageHostAsh() = default;
// Creates and starts the It2Me native messaging host. Returns the receiving
// end of the SupportHostObserver pipe, which the caller must bind. When the
// pipe disconnects, Disconnect() runs.
mojo::PendingReceiver<mojom::SupportHostObserver>
It2MeNativeMessageHostAsh::Start(
    std::unique_ptr<ChromotingHostContext> context,
    std::unique_ptr<PolicyWatcher> policy_watcher) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  DCHECK(!native_message_host_);

  // Set up the observer pipe before the NMH starts. Anything the NMH reports
  // during startup (such as errors) is then queued on `remote_` until the
  // receiver end gets bound.
  auto pending_observer = remote_.BindNewPipeAndPassReceiver();
  remote_.set_disconnect_handler(base::BindOnce(
      &It2MeNativeMessageHostAsh::Disconnect, base::Unretained(this)));

  auto it2me_host = std::make_unique<It2MeNativeMessagingHost>(
      /*needs_elevation=*/false, std::move(policy_watcher), std::move(context),
      host_factory_->Clone());
  native_message_host_ = std::move(it2me_host);
  native_message_host_->Start(this);

  return pending_observer;
}
// Stores the session lifecycle callbacks and sends a "connect" message to the
// native messaging host. The message carries the session params and the
// effective enterprise/debug flags. It also includes reconnect params when
// they are supplied, in which case `params.authorized_helper` must be set.
void It2MeNativeMessageHostAsh::Connect(
    const mojom::SupportSessionParams& params,
    const std::optional<ChromeOsEnterpriseParams>& enterprise_params,
    const std::optional<ReconnectParams>& reconnect_params,
    base::OnceClosure connected_callback,
    HostStateConnectedCallback host_state_connected_callback,
    base::OnceClosure host_state_disconnected_callback,
    base::OnceClosure disconnected_callback) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  DCHECK(native_message_host_);
  DCHECK(!connected_callback_);
  DCHECK(!disconnected_callback_);

  // These callbacks run later, as responses and state changes come back from
  // the NMH.
  connected_callback_ = std::move(connected_callback);
  disconnected_callback_ = std::move(disconnected_callback);
  host_state_connected_callback_ = std::move(host_state_connected_callback);
  host_state_disconnected_callback_ =
      std::move(host_state_disconnected_callback);

  // The version of Lacros is guaranteed to be at least as new as the code
  // running in ash so we can remove this shim in M124, however we need it until
  // then as Lacros will continue sending the oauth2 prefix for back-compat
  // until that milestone. Basically the shim code in Lacros and Ash can be
  // removed in the same milestone but the Lacros code needs to stay in place
  // until the back-compat behavior is no longer required.
  constexpr char kOAuth2ServicePrefix[] = "oauth2:";
  std::string oauth_token = params.oauth_access_token;
  if (oauth_token.starts_with(kOAuth2ServicePrefix)) {
    // Drop the legacy prefix in place.
    oauth_token.erase(0, strlen(kOAuth2ServicePrefix));
  }

  base::Value::Dict connect_message;
  connect_message.Set(kMessageType, kConnectMessage);
  connect_message.Set(kUserName, params.user_name);
  connect_message.Set(kAccessToken, oauth_token);
  connect_message.Set(kSuppressUserDialogs,
                      ShouldSuppressUserDialog(params, enterprise_params));
  connect_message.Set(kSuppressNotifications,
                      ShouldSuppressNotifications(params, enterprise_params));
  connect_message.Set(kTerminateUponInput,
                      ShouldTerminateUponInput(params, enterprise_params));
  connect_message.Set(
      kCurtainLocalUserSession,
      ShouldCurtainLocalUserSession(params, enterprise_params));
  connect_message.Set(kShowTroubleshootingTools,
                      ShouldShowTroubleshootingTools(enterprise_params));
  connect_message.Set(kAllowTroubleshootingTools,
                      ShouldAllowTroubleshootingTools(enterprise_params));
  connect_message.Set(kAllowReconnections,
                      ShouldAllowReconnections(enterprise_params));
  connect_message.Set(kAllowFileTransfer,
                      ShouldAllowFileTransfer(enterprise_params));
  connect_message.Set(kIsEnterpriseAdminUser, enterprise_params.has_value());

  if (params.authorized_helper.has_value()) {
    connect_message.Set(kAuthorizedHelper, *params.authorized_helper);
  }

  if (reconnect_params.has_value()) {
    // We persist the previously connected user as the `authorized_helper`, to
    // prevent anyone else from snooping in and connecting to the session.
    CHECK(params.authorized_helper.has_value());
    connect_message.Set(kReconnectParamsDict,
                        ReconnectParams::ToDict(*reconnect_params));
  }

  native_message_host_->OnMessage(*base::WriteJson(connect_message));
}
void It2MeNativeMessageHostAsh::Disconnect() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
native_message_host_->OnMessage(*base::WriteJson(
base::Value::Dict().Set(kMessageType, kDisconnectMessage)));
// Notify the owner that the host has been disconnected. This will result in
// the destruction of this object so do not access member variables after this
// callback is run.
std::move(disconnected_callback_).Run();
}
void It2MeNativeMessageHostAsh::PostMessageFromNativeHost(
const std::string& message) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
std::string type;
base::Value::Dict contents;
if (!ParseNativeMessageJson(message, type, contents)) {
CloseChannel(std::string());
return;
}
if (type.empty()) {
LOG(ERROR) << "'type' not found in request.";
CloseChannel(ErrorCodeToString(protocol::ErrorCode::INCOMPATIBLE_PROTOCOL));
return;
}
if (type == kConnectResponse) {
HandleConnectResponse();
} else if (type == kDisconnectResponse) {
HandleDisconnectResponse();
} else if (type == kIncomingIqResponse) {
// These responses do not need to be handled as the Lacros NMH sends a
// response when the request message is first received.
} else if (type == kHostStateChangedMessage) {
HandleHostStateChangeMessage(std::move(contents));
} else if (type == kNatPolicyChangedMessage) {
HandleNatPolicyChangedMessage(std::move(contents));
} else if (type == kPolicyErrorMessage) {
HandlePolicyErrorMessage(std::move(contents));
} else if (type == kErrorMessage) {
HandleErrorMessage(std::move(contents));
} else {
LOG(ERROR) << "Unsupported message type: " << type;
CloseChannel(ErrorCodeToString(protocol::ErrorCode::INCOMPATIBLE_PROTOCOL));
}
}
void It2MeNativeMessageHostAsh::CloseChannel(const std::string& error_message) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
LOG_IF(ERROR, !error_message.empty())
<< "CloseChannel called with error: " << error_message;
Disconnect();
}
void It2MeNativeMessageHostAsh::HandleConnectResponse() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
std::move(connected_callback_).Run();
}
void It2MeNativeMessageHostAsh::HandleDisconnectResponse() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
remote_->OnHostStateDisconnected(ErrorCodeToString(protocol::ErrorCode::OK));
}
// Translates a "hostStateChanged" message from the NMH into the matching
// SupportHostObserver call on `remote_`. The channel is closed with
// INCOMPATIBLE_PROTOCOL if a required field is missing or invalid, or if the
// state is unknown. The "disconnected" and "connected" states also run the
// host-state callbacks stored by Connect().
void It2MeNativeMessageHostAsh::HandleHostStateChangeMessage(
    base::Value::Dict message) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  const std::string* new_state = message.FindString(kState);
  if (!new_state) {
    LOG(ERROR) << "Missing |" << kState << "| value in message.";
    CloseChannel(ErrorCodeToString(protocol::ErrorCode::INCOMPATIBLE_PROTOCOL));
    return;
  }

  if (*new_state == kHostStateStarting) {
    remote_->OnHostStateStarting();
  } else if (*new_state == kHostStateDisconnected) {
    // If no disconnect reason is given, report OK.
    const std::string* disconnect_reason =
        message.FindString(kDisconnectReason);
    remote_->OnHostStateDisconnected(
        disconnect_reason ? *disconnect_reason
                          : ErrorCodeToString(protocol::ErrorCode::OK));
    std::move(host_state_disconnected_callback_).Run();
  } else if (*new_state == kHostStateRequestedAccessCode) {
    remote_->OnHostStateRequestedAccessCode();
  } else if (*new_state == kHostStateReceivedAccessCode) {
    const std::string* access_code = message.FindString(kAccessCode);
    if (!access_code) {
      LOG(ERROR) << "Missing |" << kAccessCode << "| value in message.";
      CloseChannel(
          ErrorCodeToString(protocol::ErrorCode::INCOMPATIBLE_PROTOCOL));
      return;
    }
    // The lifetime is sent as an integer number of seconds.
    std::optional<int> access_code_lifetime =
        message.FindInt(kAccessCodeLifetime);
    if (!access_code_lifetime) {
      LOG(ERROR) << "Missing |" << kAccessCodeLifetime << "| value in message.";
      CloseChannel(
          ErrorCodeToString(protocol::ErrorCode::INCOMPATIBLE_PROTOCOL));
      return;
    }
    remote_->OnHostStateReceivedAccessCode(
        *access_code, base::Seconds(*access_code_lifetime));
  } else if (*new_state == kHostStateConnecting) {
    remote_->OnHostStateConnecting();
  } else if (*new_state == kHostStateConnected) {
    const std::string* remote_username = message.FindString(kClient);
    if (!remote_username) {
      LOG(ERROR) << "Missing |" << kClient << "| value in message.";
      CloseChannel(
          ErrorCodeToString(protocol::ErrorCode::INCOMPATIBLE_PROTOCOL));
      return;
    }
    remote_->OnHostStateConnected(*remote_username);

    // Reconnect params are optional. Forward them only if the NMH sent them.
    std::optional<ReconnectParams> reconnect_params;
    const auto* reconnect_params_ptr = message.FindDict(kReconnectParamsDict);
    if (reconnect_params_ptr) {
      reconnect_params.emplace(
          ReconnectParams::FromDict(*reconnect_params_ptr));
    }
    std::move(host_state_connected_callback_).Run(std::move(reconnect_params));
  } else if (*new_state == kHostStateError) {
    const std::string* error_code_string =
        message.FindString(kErrorMessageCode);
    if (!error_code_string) {
      LOG(ERROR) << "Missing |" << kErrorMessageCode << "| value in message.";
      CloseChannel(
          ErrorCodeToString(protocol::ErrorCode::INCOMPATIBLE_PROTOCOL));
      return;
    }
    protocol::ErrorCode error_code;
    if (!ParseErrorCode(*error_code_string, &error_code)) {
      // Note the leading space so the code isn't fused with the next word.
      LOG(ERROR) << "Invalid |" << kErrorMessageCode << "| value "
                 << *error_code_string << " in message.";
      CloseChannel(
          ErrorCodeToString(protocol::ErrorCode::INCOMPATIBLE_PROTOCOL));
      return;
    }
    remote_->OnHostStateError(static_cast<int64_t>(error_code));
  } else if (*new_state == kHostStateDomainError) {
    remote_->OnInvalidDomainError();
  } else {
    NOTREACHED_IN_MIGRATION() << "Unknown state: " << *new_state;
    CloseChannel(ErrorCodeToString(protocol::ErrorCode::INCOMPATIBLE_PROTOCOL));
    return;
  }
}
// Translates a "natPolicyChanged" message from the NMH into an
// OnNatPolicyChanged() call on `remote_`. Both the NAT-enabled and
// relay-enabled booleans are required. If either is missing, the channel is
// closed with INCOMPATIBLE_PROTOCOL.
void It2MeNativeMessageHostAsh::HandleNatPolicyChangedMessage(
    base::Value::Dict message) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  std::optional<bool> nat_enabled =
      message.FindBool(kNatPolicyChangedMessageNatEnabled);
  if (!nat_enabled.has_value()) {
    LOG(ERROR) << "Missing |" << kNatPolicyChangedMessageNatEnabled
               << "| value in message.";
    CloseChannel(ErrorCodeToString(protocol::ErrorCode::INCOMPATIBLE_PROTOCOL));
    return;
  }
  std::optional<bool> relay_enabled =
      message.FindBool(kNatPolicyChangedMessageRelayEnabled);
  // Validate `relay_enabled` itself. Dereferencing an empty optional below
  // would be undefined behavior.
  if (!relay_enabled.has_value()) {
    LOG(ERROR) << "Missing |" << kNatPolicyChangedMessageRelayEnabled
               << "| value in message.";
    CloseChannel(ErrorCodeToString(protocol::ErrorCode::INCOMPATIBLE_PROTOCOL));
    return;
  }
  mojom::NatPolicyStatePtr nat_policy = mojom::NatPolicyState::New();
  nat_policy->nat_enabled = *nat_enabled;
  nat_policy->relay_enabled = *relay_enabled;
  remote_->OnNatPolicyChanged(std::move(nat_policy));
}
// Forwards a "policyError" message from the NMH to the observer. The message
// payload is not read here.
void It2MeNativeMessageHostAsh::HandlePolicyErrorMessage(
    base::Value::Dict message) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  remote_->OnPolicyError();
}
// Translates an "error" message from the NMH into an OnHostStateError() call
// on `remote_`. The error code string is required and must parse as a
// protocol::ErrorCode. Otherwise the channel is closed with
// INCOMPATIBLE_PROTOCOL.
void It2MeNativeMessageHostAsh::HandleErrorMessage(base::Value::Dict message) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  const std::string* error_code_string = message.FindString(kErrorMessageCode);
  if (!error_code_string) {
    LOG(ERROR) << "Missing |" << kErrorMessageCode << "| value in message.";
    CloseChannel(ErrorCodeToString(protocol::ErrorCode::INCOMPATIBLE_PROTOCOL));
    return;
  }
  protocol::ErrorCode error_code;
  if (!ParseErrorCode(*error_code_string, &error_code)) {
    // Note the leading space so the code isn't fused with the next word.
    LOG(ERROR) << "Invalid |" << kErrorMessageCode << "| value "
               << *error_code_string << " in message.";
    CloseChannel(ErrorCodeToString(protocol::ErrorCode::INCOMPATIBLE_PROTOCOL));
    return;
  }
  remote_->OnHostStateError(static_cast<int64_t>(error_code));
}
} // namespace remoting