// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "services/network/brokered_tcp_client_socket.h"
#include "base/functional/bind.h"
#include "base/functional/callback_helpers.h"
#include "base/memory/weak_ptr.h"
#include "build/build_config.h"
#include "net/base/address_list.h"
#include "net/base/completion_once_callback.h"
#include "net/socket/tcp_client_socket.h"
#include "net/traffic_annotation/network_traffic_annotation.h"
#include "services/network/brokered_client_socket_factory.h"
#include "services/network/public/cpp/transferable_socket.h"
namespace network {
// Captures everything needed to open and connect a socket later. No OS socket
// is created here: Connect() asks |client_socket_factory| to create one.
// Starts a BROKERED_SOCKET_ALIVE NetLog event that references |source|; the
// destructor ends it.
BrokeredTcpClientSocket::BrokeredTcpClientSocket(
    const net::AddressList& addresses,
    std::unique_ptr<net::SocketPerformanceWatcher> socket_performance_watcher,
    net::NetworkQualityEstimator* network_quality_estimator,
    net::NetLog* net_log,
    const net::NetLogSource& source,
    BrokeredClientSocketFactory* client_socket_factory)
    : addresses_(addresses),
      socket_performance_watcher_(std::move(socket_performance_watcher)),
      network_quality_estimator_(network_quality_estimator),
      net_log_source_(net::NetLogWithSource::Make(
          net_log, net::NetLogSourceType::SOCKET)),
      client_socket_factory_(client_socket_factory) {
  const net::NetLogEventType alive_event =
      net::NetLogEventType::BROKERED_SOCKET_ALIVE;
  net_log_source_.BeginEventReferencingSource(alive_event, source);
}
// Closes the BROKERED_SOCKET_ALIVE NetLog event opened by the constructor,
// then tears down any underlying connection via Disconnect().
BrokeredTcpClientSocket::~BrokeredTcpClientSocket() {
  const net::NetLogEventType alive_event =
      net::NetLogEventType::BROKERED_SOCKET_ALIVE;
  net_log_source_.EndEvent(alive_event);
  Disconnect();
}
// Records |address| as the local address to bind to. The actual bind is
// deferred: opening the socket is an asynchronous IPC, so DidCompleteCreate()
// binds the freshly opened socket before connecting it.
// Returns net::OK, or net::ERR_UNEXPECTED if called while connected or while a
// connect is in progress.
int BrokeredTcpClientSocket::Bind(const net::IPEndPoint& address) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  const bool too_late_to_bind = IsConnected() || is_connect_in_progress_;
  if (too_late_to_bind) {
    // Binding only makes sense before any connection attempt has started.
    NOTREACHED_IN_MIGRATION();
    return net::ERR_UNEXPECTED;
  }
  bind_address_ = std::make_unique<net::IPEndPoint>(address);
  return net::OK;
}
// Forwards to the underlying TCPClientSocket. Returns false when no socket has
// been created yet.
bool BrokeredTcpClientSocket::SetKeepAlive(bool enable, int delay) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  return brokered_socket_ && brokered_socket_->SetKeepAlive(enable, delay);
}
// Forwards to the underlying TCPClientSocket. Returns false when no socket has
// been created yet.
bool BrokeredTcpClientSocket::SetNoDelay(bool no_delay) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  return brokered_socket_ && brokered_socket_->SetNoDelay(no_delay);
}
// Installs a hook that DidCompleteCreate() runs after the socket is created
// but before it is connected. A non-OK result from the hook aborts the
// connect. May only be set once, and only before any connect attempt.
void BrokeredTcpClientSocket::SetBeforeConnectCallback(
    const BeforeConnectCallback& before_connect_callback) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  // Preconditions: not already installed, and no connection yet or underway.
  DCHECK(!before_connect_callback_);
  DCHECK(!IsConnected());
  DCHECK(!is_connect_in_progress_);
  before_connect_callback_ = before_connect_callback;
}
// Starts an asynchronous connect. The socket is first opened through
// |client_socket_factory_| (an IPC), then DidCompleteCreate() binds and
// connects it.
// Returns:
//   net::OK                  if already connected or a connect is underway;
//   net::ERR_ADDRESS_INVALID if there are no addresses to connect to;
//   net::ERR_IO_PENDING      otherwise, and |callback| later receives the
//                            final result.
int BrokeredTcpClientSocket::Connect(net::CompletionOnceCallback callback) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  // TODO(liza): add support for reconnecting disconnected socket, or look into
  // removing support for reconnection from TCPClientSocket if it's not needed.
  DCHECK(!callback.is_null());

  // If connecting or already connected, then just return OK.
  if (IsConnected() || is_connect_in_progress_) {
    return net::OK;
  }

  // The address family of the socket to open comes from the first address.
  // With an empty list, begin() would be the end iterator and dereferencing
  // it is undefined behavior, so fail synchronously instead.
  if (addresses_.empty()) {
    return net::ERR_ADDRESS_INVALID;
  }

  is_connect_in_progress_ = true;
  net_log_source_.BeginEvent(net::NetLogEventType::BROKERED_CREATE_SOCKET);
  // TODO(crbug.com/40223835): Pass in AddressFamily of single IPEndPoint
  client_socket_factory_->BrokerCreateTcpSocket(
      addresses_.begin()->GetFamily(),
      base::BindOnce(&BrokeredTcpClientSocket::DidCompleteCreate,
                     brokered_weak_ptr_factory_.GetWeakPtr(),
                     std::move(callback)));
  return net::ERR_IO_PENDING;
}
// Final step of every connect attempt: clears the in-progress flag, then
// reports |result| (which must be a final result, not ERR_IO_PENDING) to
// |callback|.
void BrokeredTcpClientSocket::DidCompleteConnect(
    net::CompletionOnceCallback callback,
    int result) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  DCHECK_NE(result, net::ERR_IO_PENDING);
  // Update state first. Running |callback| may destroy |this|, so nothing may
  // touch members afterwards.
  is_connect_in_progress_ = false;
  std::move(callback).Run(result);
}
// Continuation of Connect() once the socket has been opened out of process.
// On success, adopts |socket| into a net::TCPSocket, applies any pending Bind()
// address, wraps it in a net::TCPClientSocket, applies the stored socket tag,
// runs the before-connect hook, and starts the real connect.
// Every outcome, success or failure, reaches |callback| through
// DidCompleteConnect(). Failure paths must also go through it; otherwise
// |is_connect_in_progress_| would stay true forever after a failed attempt,
// making later Connect() calls falsely return OK and Bind() fail.
void BrokeredTcpClientSocket::DidCompleteCreate(
    net::CompletionOnceCallback callback,
    network::TransferableSocket socket,
    int result) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  net_log_source_.EndEventWithNetErrorCode(
      net::NetLogEventType::BROKERED_CREATE_SOCKET, result);
  if (result != net::OK) {
    // Creating the socket failed.
    DidCompleteConnect(std::move(callback), result);
    return;
  }

  // Create an unconnected TCPSocket with the socket fd that was opened in the
  // browser process.
  std::unique_ptr<net::TCPSocket> tcp_socket = std::make_unique<net::TCPSocket>(
      std::move(socket_performance_watcher_), net_log_source_);
  tcp_socket->AdoptUnconnectedSocket(socket.TakeSocket());

  // If Bind() was called prior to connecting, attempt to bind now that a socket
  // has been opened.
  if (bind_address_) {
    int bind_result = tcp_socket->Bind(*bind_address_);
    if (bind_result != net::OK) {
      tcp_socket->Close();
      DidCompleteConnect(std::move(callback), bind_result);
      return;
    }
  }

  // TODO(liza): Pass through the NetworkHandle.
  brokered_socket_ = std::make_unique<net::TCPClientSocket>(
      std::move(tcp_socket), addresses_, std::move(bind_address_),
      network_quality_estimator_);
  brokered_socket_->ApplySocketTag(tag_);

  if (before_connect_callback_) {
    int callback_result = before_connect_callback_.Run();
    DCHECK_NE(net::ERR_IO_PENDING, callback_result);
    if (callback_result != net::OK) {
      net_log_source_.AddEventWithNetErrorCode(net::NetLogEventType::FAILED,
                                               callback_result);
      DidCompleteConnect(std::move(callback), callback_result);
      return;
    }
  }

  // TCPClientSocket::Connect() either completes synchronously, without running
  // its callback, or returns ERR_IO_PENDING and runs it later. Split the
  // callback so that exactly one half is used in either case.
  auto split_connect_callback = base::SplitOnceCallback(std::move(callback));
  int connect_result = brokered_socket_->Connect(
      base::BindOnce(&BrokeredTcpClientSocket::DidCompleteConnect,
                     brokered_weak_ptr_factory_.GetWeakPtr(),
                     std::move(split_connect_callback.first)));
  if (connect_result != net::ERR_IO_PENDING) {
    DidCompleteConnect(std::move(split_connect_callback.second),
                       connect_result);
  }
}
// Disconnects the underlying socket if one exists. Also forgets any pending
// Bind() address and clears the in-progress flag.
void BrokeredTcpClientSocket::Disconnect() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  if (auto* socket = brokered_socket_.get()) {
    socket->Disconnect();
  }
  bind_address_ = nullptr;
  is_connect_in_progress_ = false;
}
// True only when an underlying socket exists and reports itself connected.
bool BrokeredTcpClientSocket::IsConnected() const {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  return brokered_socket_ && brokered_socket_->IsConnected();
}
// True only when an underlying socket exists and reports itself connected and
// idle.
bool BrokeredTcpClientSocket::IsConnectedAndIdle() const {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  return brokered_socket_ && brokered_socket_->IsConnectedAndIdle();
}
// Writes the remote endpoint into |address|. Returns
// net::ERR_SOCKET_NOT_CONNECTED if no underlying socket has been created;
// otherwise returns the underlying socket's result.
int BrokeredTcpClientSocket::GetPeerAddress(net::IPEndPoint* address) const {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  if (!brokered_socket_) {
    return net::ERR_SOCKET_NOT_CONNECTED;
  }
  // |address| is a raw out-pointer; std::move on it would only copy the
  // pointer, so pass it through directly.
  return brokered_socket_->GetPeerAddress(address);
}
// Writes the local endpoint into |address|. Returns
// net::ERR_SOCKET_NOT_CONNECTED if no underlying socket has been created;
// otherwise returns the underlying socket's result.
int BrokeredTcpClientSocket::GetLocalAddress(net::IPEndPoint* address) const {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  if (!brokered_socket_) {
    return net::ERR_SOCKET_NOT_CONNECTED;
  }
  // |address| is a raw out-pointer; std::move on it would only copy the
  // pointer, so pass it through directly.
  return brokered_socket_->GetLocalAddress(address);
}
// Returns the NetLog source created in the constructor. The same source is
// handed to the underlying TCPSocket in DidCompleteCreate().
const net::NetLogWithSource& BrokeredTcpClientSocket::NetLog() const {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  const net::NetLogWithSource& log = net_log_source_;
  return log;
}
// Forwards to the underlying socket. Returns false if none has been created.
bool BrokeredTcpClientSocket::WasEverUsed() const {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  return brokered_socket_ && brokered_socket_->WasEverUsed();
}
// This socket performs no protocol negotiation, so the protocol is always
// unknown.
net::NextProto BrokeredTcpClientSocket::GetNegotiatedProtocol() const {
  constexpr net::NextProto kNoNegotiatedProtocol = net::kProtoUnknown;
  return kNoNegotiatedProtocol;
}
// This is a plain TCP socket with no SSL state. Always returns false and
// leaves |ssl_info| untouched.
bool BrokeredTcpClientSocket::GetSSLInfo(net::SSLInfo* ssl_info) {
  constexpr bool kHasSSLInfo = false;
  return kHasSSLInfo;
}
// Forwards to the underlying socket. Reports 0 bytes if none has been created.
int64_t BrokeredTcpClientSocket::GetTotalReceivedBytes() const {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  return brokered_socket_ ? brokered_socket_->GetTotalReceivedBytes()
                          : int64_t{0};
}
// Applies |tag| to the underlying socket if one exists. In all cases, remembers
// |tag| so that DidCompleteCreate() applies it to any socket created later.
// Previously the tag was remembered only when no socket existed, so a
// reconnect after Disconnect() would reapply a stale tag to the new socket.
void BrokeredTcpClientSocket::ApplySocketTag(const net::SocketTag& tag) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  tag_ = tag;
  if (brokered_socket_) {
    brokered_socket_->ApplySocketTag(tag);
  }
}
// Forwards the read to the underlying socket. Returns
// net::ERR_SOCKET_NOT_CONNECTED if no socket has been created yet.
int BrokeredTcpClientSocket::Read(net::IOBuffer* buf,
                                  int buf_len,
                                  net::CompletionOnceCallback callback) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  if (brokered_socket_) {
    return brokered_socket_->Read(buf, buf_len, std::move(callback));
  }
  return net::ERR_SOCKET_NOT_CONNECTED;
}
// Forwards ReadIfReady() to the underlying socket. Returns
// net::ERR_SOCKET_NOT_CONNECTED if no socket has been created yet.
int BrokeredTcpClientSocket::ReadIfReady(net::IOBuffer* buf,
                                         int buf_len,
                                         net::CompletionOnceCallback callback) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  if (brokered_socket_) {
    return brokered_socket_->ReadIfReady(buf, buf_len, std::move(callback));
  }
  return net::ERR_SOCKET_NOT_CONNECTED;
}
// Forwards CancelReadIfReady() to the underlying socket. Returns
// net::ERR_SOCKET_NOT_CONNECTED if no socket has been created yet.
int BrokeredTcpClientSocket::CancelReadIfReady() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  return brokered_socket_ ? brokered_socket_->CancelReadIfReady()
                          : net::ERR_SOCKET_NOT_CONNECTED;
}
// Forwards the write to the underlying socket. Returns
// net::ERR_SOCKET_NOT_CONNECTED if no socket has been created yet.
int BrokeredTcpClientSocket::Write(
    net::IOBuffer* buf,
    int buf_len,
    net::CompletionOnceCallback callback,
    const net::NetworkTrafficAnnotationTag& traffic_annotation) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  if (!brokered_socket_) {
    return net::ERR_SOCKET_NOT_CONNECTED;
  }
  // |buf| is a raw pointer; std::move on it would only copy the pointer, so
  // pass it through directly. Only the move-only |callback| needs moving.
  return brokered_socket_->Write(buf, buf_len, std::move(callback),
                                 traffic_annotation);
}
// Forwards to the underlying socket. Returns net::ERR_SOCKET_NOT_CONNECTED if
// no socket has been created yet.
int BrokeredTcpClientSocket::SetReceiveBufferSize(int32_t size) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  return brokered_socket_ ? brokered_socket_->SetReceiveBufferSize(size)
                          : net::ERR_SOCKET_NOT_CONNECTED;
}
// Forwards to the underlying socket. Returns net::ERR_SOCKET_NOT_CONNECTED if
// no socket has been created yet.
int BrokeredTcpClientSocket::SetSendBufferSize(int32_t size) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  return brokered_socket_ ? brokered_socket_->SetSendBufferSize(size)
                          : net::ERR_SOCKET_NOT_CONNECTED;
}
} // namespace network