// Copyright 2017 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chromeos/ash/services/secure_channel/secure_channel.h"
#include <memory>
#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/memory/ptr_util.h"
#include "chromeos/ash/components/multidevice/logging/logging.h"
#include "chromeos/ash/components/multidevice/secure_message_delegate_impl.h"
#include "chromeos/ash/services/secure_channel/file_transfer_update_callback.h"
#include "chromeos/ash/services/secure_channel/public/mojom/nearby_connector.mojom-shared.h"
#include "chromeos/ash/services/secure_channel/public/mojom/secure_channel.mojom-shared.h"
#include "chromeos/ash/services/secure_channel/public/mojom/secure_channel_types.mojom.h"
#include "chromeos/ash/services/secure_channel/wire_message.h"
namespace ash::secure_channel {
// static
// Optional factory override consulted by Factory::Create(). It starts out
// null, in which case Create() constructs a SecureChannel directly; it is
// assigned by SetFactoryForTesting().
SecureChannel::Factory* SecureChannel::Factory::factory_instance_ = nullptr;
// static
// Returns a new SecureChannel that takes ownership of |connection|. If a
// factory override has been installed through SetFactoryForTesting(), the
// creation is delegated to that factory instead.
std::unique_ptr<SecureChannel> SecureChannel::Factory::Create(
    std::unique_ptr<Connection> connection) {
  if (!factory_instance_) {
    // Default path: build the real implementation.
    return base::WrapUnique(new SecureChannel(std::move(connection)));
  }

  return factory_instance_->CreateInstance(std::move(connection));
}
// static
// Installs |factory| as the object Factory::Create() delegates to. The raw
// pointer is simply stored; passing nullptr makes Create() fall back to
// constructing a SecureChannel directly.
void SecureChannel::Factory::SetFactoryForTesting(Factory* factory) {
factory_instance_ = factory;
}
// static
// Converts |status| into a bracketed, human-readable label. Any value without
// a dedicated label is reported as "[unknown status]".
std::string SecureChannel::StatusToString(const Status& status) {
  // Start from the fallback and overwrite it for every recognized value.
  const char* label = "[unknown status]";

  switch (status) {
    case Status::DISCONNECTED:
      label = "[disconnected]";
      break;
    case Status::CONNECTING:
      label = "[connecting]";
      break;
    case Status::CONNECTED:
      label = "[connected]";
      break;
    case Status::AUTHENTICATING:
      label = "[authenticating]";
      break;
    case Status::AUTHENTICATED:
      label = "[authenticated]";
      break;
    case Status::DISCONNECTING:
      label = "[disconnecting]";
      break;
    default:
      break;
  }

  return label;
}
SecureChannel::PendingMessage::PendingMessage(const std::string& feature,
const std::string& payload,
int sequence_number)
: feature(feature), payload(payload), sequence_number(sequence_number) {}
// Nothing beyond member destruction is required.
SecureChannel::PendingMessage::~PendingMessage() = default;
SecureChannel::SecureChannel(std::unique_ptr<Connection> connection)
: status_(Status::DISCONNECTED), connection_(std::move(connection)) {
connection_->AddObserver(this);
connection_->AddNearbyConnectionObserver(this);
}
// Unregisters this object from every observer list it joined in the
// constructor. Both registrations (the regular connection observer and the
// Nearby connection observer) are undone so that |connection_| does not hold
// a pointer to a partially destroyed SecureChannel while it is torn down.
SecureChannel::~SecureChannel() {
  connection_->RemoveObserver(this);
  connection_->RemoveNearbyConnectionObserver(this);
}
// Starts the connection attempt. Must only be called while the channel is
// DISCONNECTED (enforced by the DCHECK below).
void SecureChannel::Initialize() {
DCHECK(status_ == Status::DISCONNECTED);
// Ask the underlying connection to connect, then report CONNECTING to
// observers of this channel.
connection_->Connect();
TransitionToStatus(Status::CONNECTING);
}
int SecureChannel::SendMessage(const std::string& feature,
const std::string& payload) {
DCHECK(status_ == Status::AUTHENTICATED);
int sequence_number = next_sequence_number_;
next_sequence_number_++;
queued_messages_.emplace(
std::make_unique<PendingMessage>(feature, payload, sequence_number));
ProcessMessageQueue();
return sequence_number;
}
// Forwards a payload-file registration to the underlying connection.
// All arguments are handed to Connection::RegisterPayloadFile() unchanged
// (the move-only ones are moved). The channel must be AUTHENTICATED, which is
// enforced by the DCHECK below.
void SecureChannel::RegisterPayloadFile(
int64_t payload_id,
mojom::PayloadFilesPtr payload_files,
FileTransferUpdateCallback file_transfer_update_callback,
base::OnceCallback<void(bool)> registration_result_callback) {
DCHECK(status_ == Status::AUTHENTICATED);
connection_->RegisterPayloadFile(payload_id, std::move(payload_files),
std::move(file_transfer_update_callback),
std::move(registration_result_callback));
}
// Tears the channel down. If the underlying connection is still connected the
// channel moves to DISCONNECTING and asks the connection to disconnect;
// otherwise it moves straight to DISCONNECTED.
void SecureChannel::Disconnect() {
  if (!connection_->IsConnected()) {
    // Nothing to close on the connection side.
    TransitionToStatus(Status::DISCONNECTED);
    return;
  }

  TransitionToStatus(Status::DISCONNECTING);

  // Since |connection_| is active, calling Disconnect() will eventually cause
  // its status to transition to DISCONNECTED, which will in turn cause this
  // class to transition to DISCONNECTED.
  connection_->Disconnect();
}
// Registers |observer| in |observer_list_| so it receives this channel's
// notifications.
void SecureChannel::AddObserver(Observer* observer) {
observer_list_.AddObserver(observer);
}
// Removes |observer| from |observer_list_|; it will no longer be notified by
// this channel.
void SecureChannel::RemoveObserver(Observer* observer) {
observer_list_.RemoveObserver(observer);
}
// Asks the underlying connection for its RSSI and reports it via |callback|.
// When there is no connection object, |callback| is run immediately with
// std::nullopt.
void SecureChannel::GetConnectionRssi(
    base::OnceCallback<void(std::optional<int32_t>)> callback) {
  if (connection_) {
    connection_->GetConnectionRssi(std::move(callback));
    return;
  }

  std::move(callback).Run(std::nullopt);
}
// Returns the channel binding data provided by |secure_context_|, or
// std::nullopt while no secure context exists.
std::optional<std::string> SecureChannel::GetChannelBindingData() {
  if (!secure_context_) {
    return std::nullopt;
  }

  return secure_context_->GetChannelBindingData();
}
// Connection observer callback. Reacts to the new status of |connection|:
// CONNECTED starts authentication, DISCONNECTED tears the channel down, and
// every other status is ignored. |old_status| is not used.
void SecureChannel::OnConnectionStatusChanged(Connection* connection,
                                              Connection::Status old_status,
                                              Connection::Status new_status) {
  DCHECK(connection == connection_.get());

  switch (new_status) {
    case Connection::Status::CONNECTED:
      TransitionToStatus(Status::CONNECTED);
      // Once the connection has succeeded, authenticate the connection by
      // initiating the security handshake.
      Authenticate();
      break;

    case Connection::Status::DISCONNECTED:
      // The connection is no longer active, so disconnect.
      Disconnect();
      break;

    default:
      // Other statuses require no action here.
      break;
  }
}
void SecureChannel::OnMessageReceived(const Connection& connection,
const WireMessage& wire_message) {
DCHECK(&connection == const_cast<const Connection*>(connection_.get()));
if (wire_message.feature() == Authenticator::kAuthenticationFeature) {
// If the message received was part of the authentication handshake, it
// is a low-level message and should not be forwarded to observers.
return;
}
if (!secure_context_) {
PA_LOG(WARNING) << "Received unexpected message before authentication "
<< "was complete. Feature: " << wire_message.feature()
<< ", Payload size: " << wire_message.payload().size()
<< " byte(s)";
return;
}
secure_context_->DecodeAndDequeue(
wire_message.payload(),
base::BindRepeating(&SecureChannel::OnMessageDecoded,
weak_ptr_factory_.GetWeakPtr(),
wire_message.feature()));
}
// Connection observer callback invoked when sending |wire_message| finished.
// Authentication messages are skipped. For regular messages, a successful
// send clears |pending_message_|, notifies observers (unless the sequence
// number is -1) and continues with the next queued message; any other outcome
// is logged and results in a disconnect.
void SecureChannel::OnSendCompleted(const Connection& connection,
                                    const WireMessage& wire_message,
                                    bool success) {
  if (wire_message.feature() == Authenticator::kAuthenticationFeature) {
    // Authentication messages are handled by |authenticator_|; nothing to do.
    return;
  }

  if (!pending_message_) {
    PA_LOG(ERROR) << "OnSendCompleted(), but a send was not expected to be in "
                  << "progress. Disconnecting from "
                  << connection_->GetDeviceAddress();
    Disconnect();
    return;
  }

  const bool send_failed = !success || status_ == Status::DISCONNECTED;
  if (send_failed) {
    PA_LOG(ERROR) << "Could not send message: {"
                  << "payload size: " << pending_message_->payload.size()
                  << " byte(s), feature: \"" << pending_message_->feature
                  << "\""
                  << "}";
    pending_message_.reset();

    // The connection automatically retries failed messages, so reaching this
    // point with |success| == false means a fatal error has occurred. There
    // is no point in retrying the message; disconnect instead.
    Disconnect();
    return;
  }

  pending_message_.reset();

  // Grab a WeakPtr to |this| before notifying observers: an observer may
  // react to OnMessageSent() by destroying the channel (e.g., a client that
  // only wanted to send a single message).
  base::WeakPtr<SecureChannel> weak_this = weak_ptr_factory_.GetWeakPtr();

  if (wire_message.sequence_number() != -1) {
    for (auto& observer : observer_list_) {
      observer.OnMessageSent(this, wire_message.sequence_number());
    }
  }

  // Move on to the next queued message, unless the SecureChannel was deleted
  // by an OnMessageSent() callback, in which case |weak_this| is invalidated
  // and nothing happens.
  if (weak_this) {
    weak_this->ProcessMessageQueue();
  }
}
// Relays a Nearby connection |step| and its |result| to every observer of
// this channel.
void SecureChannel::OnNearbyConnectionStateChagned(
    mojom::NearbyConnectionStep step,
    mojom::NearbyConnectionStepResult result) {
  for (auto& channel_observer : observer_list_)
    channel_observer.OnNearbyConnectionStateChanged(this, step, result);
}
// Relays an authentication state change (|secure_channel_state|) to every
// observer of this channel.
void SecureChannel::OnAuthenticationStateChanged(
    mojom::SecureChannelState secure_channel_state) {
  for (auto& channel_observer : observer_list_)
    channel_observer.OnSecureChannelAuthenticationStateChanged(
        this, secure_channel_state);
}
void SecureChannel::TransitionToStatus(const Status& new_status) {
if (new_status == status_) {
// Only report changes to state.
return;
}
Status old_status = status_;
status_ = new_status;
for (auto& observer : observer_list_)
observer.OnSecureChannelStatusChanged(this, old_status, status_);
}
// Starts the authentication handshake over |connection_|. Requires the
// channel to be CONNECTED with no authenticator already in place. The outcome
// is delivered to OnAuthenticationResult().
void SecureChannel::Authenticate() {
  DCHECK(status_ == Status::CONNECTED);
  DCHECK(!authenticator_);

  auto secure_message_delegate =
      multidevice::SecureMessageDelegateImpl::Factory::Create();
  authenticator_ = DeviceToDeviceAuthenticator::Factory::Create(
      connection_.get(), std::move(secure_message_delegate));
  authenticator_->AddObserver(this);

  // Bound to a WeakPtr so the result is dropped if this object is gone.
  auto on_authentication_done = base::BindOnce(
      &SecureChannel::OnAuthenticationResult, weak_ptr_factory_.GetWeakPtr());
  authenticator_->Authenticate(std::move(on_authentication_done));

  TransitionToStatus(Status::AUTHENTICATING);
}
void SecureChannel::ProcessMessageQueue() {
if (pending_message_ || queued_messages_.empty()) {
return;
}
DCHECK(!connection_->is_sending_message());
pending_message_ = std::move(queued_messages_.front());
queued_messages_.pop();
PA_LOG(INFO) << "Sending message to " << connection_->GetDeviceAddress()
<< ": {"
<< "feature: \"" << pending_message_->feature << "\", "
<< "payload size: " << pending_message_->payload.size()
<< " byte(s)"
<< "}";
secure_context_->Encode(
pending_message_->payload,
base::BindOnce(&SecureChannel::OnMessageEncoded,
weak_ptr_factory_.GetWeakPtr(), pending_message_->feature,
pending_message_->sequence_number));
}
// Callback for a finished encode: wraps |encoded_message| together with
// |feature| and |sequence_number| in a WireMessage and passes it to the
// connection for sending.
void SecureChannel::OnMessageEncoded(const std::string& feature,
                                     int sequence_number,
                                     const std::string& encoded_message) {
  auto outgoing_wire_message =
      std::make_unique<WireMessage>(encoded_message, feature, sequence_number);
  connection_->SendMessage(std::move(outgoing_wire_message));
}
// Callback for a finished decode: logs the arrival and delivers
// |decoded_message| for |feature| to every observer of this channel.
void SecureChannel::OnMessageDecoded(const std::string& feature,
                                     const std::string& decoded_message) {
  PA_LOG(VERBOSE) << "Received message from "
                  << connection_->GetDeviceAddress() << ": {"
                  << "feature: \"" << feature << "\", "
                  << "payload size: " << decoded_message.size() << " byte(s)"
                  << "}";

  for (auto& channel_observer : observer_list_) {
    channel_observer.OnMessageReceived(this, feature, decoded_message);
  }
}
// Completion callback of the authentication handshake. Releases the
// authenticator in every case. On success, |secure_context| is adopted and the
// channel becomes AUTHENTICATED; on any other |result| the failure is logged
// and the channel disconnects.
void SecureChannel::OnAuthenticationResult(
    Authenticator::Result result,
    std::unique_ptr<SecureContext> secure_context) {
  DCHECK(status_ == Status::AUTHENTICATING);

  // The handshake is over, so the authenticator is no longer needed.
  authenticator_->RemoveObserver(this);
  authenticator_.reset();

  if (result == Authenticator::Result::SUCCESS) {
    secure_context_ = std::move(secure_context);
    TransitionToStatus(Status::AUTHENTICATED);
    return;
  }

  PA_LOG(WARNING)
      << "Failed to authenticate connection to device with ID "
      << connection_->remote_device().GetTruncatedDeviceIdForLogs();
  Disconnect();
}
} // namespace ash::secure_channel