// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "services/network/brokered_udp_client_socket.h"
#include "base/component_export.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/weak_ptr.h"
#include "base/sequence_checker.h"
#include "base/task/single_thread_task_runner.h"
#include "build/build_config.h"
#include "mojo/public/cpp/platform/platform_handle.h"
#include "net/base/address_list.h"
#include "net/base/completion_once_callback.h"
#include "net/base/ip_endpoint.h"
#include "net/base/net_errors.h"
#include "net/base/network_change_notifier.h"
#include "net/base/network_handle.h"
#include "net/log/net_log_source.h"
#include "net/nqe/network_quality_estimator.h"
#include "net/socket/datagram_client_socket.h"
#include "net/socket/datagram_socket.h"
#include "net/socket/socket_tag.h"
#include "net/socket/stream_socket.h"
#include "net/socket/udp_client_socket.h"
#include "net/socket/udp_socket.h"
#include "net/socket/udp_socket_global_limits.h"
#include "net/traffic_annotation/network_traffic_annotation.h"
#include "services/network/brokered_client_socket_factory.h"
#include "services/network/public/cpp/transferable_socket.h"
namespace network {
BrokeredUdpClientSocket::BrokeredUdpClientSocket(
net::DatagramSocket::BindType bind_type,
net::NetLog* net_log,
const net::NetLogSource& source,
BrokeredClientSocketFactory* client_socket_factory,
net::handles::NetworkHandle network)
: bind_type_(bind_type),
network_(network),
net_log_source_(
net::NetLogWithSource::Make(net_log, net::NetLogSourceType::SOCKET)),
client_socket_factory_(client_socket_factory) {
net_log_source_.BeginEventReferencingSource(
net::NetLogEventType::BROKERED_SOCKET_ALIVE, source);
}
BrokeredUdpClientSocket::~BrokeredUdpClientSocket() {
net_log_source_.EndEvent(net::NetLogEventType::BROKERED_SOCKET_ALIVE);
}
int BrokeredUdpClientSocket::Connect(const net::IPEndPoint& address) {
if (!broker_helper_.ShouldBroker(address.address())) {
return ConnectInternal(address);
}
// Brokered sockets can only support asynchronous connections so this does not
// need to be implemented. However, this path can still be hit if the sandbox
// is enabled and a caller attempts to call a synchronous Connect. Callers are
// expected to handle Connect failures themselves so we just return
// ERR_NOT_IMPLEMENTED.
return net::ERR_NOT_IMPLEMENTED;
}
int BrokeredUdpClientSocket::ConnectUsingNetwork(
net::handles::NetworkHandle network,
const net::IPEndPoint& address) {
// NetworkHandles are not supported on Windows, so this method and the
// following Connect*Network() methods don't need to return anything.
return net::ERR_NOT_IMPLEMENTED;
}
int BrokeredUdpClientSocket::ConnectUsingDefaultNetwork(
const net::IPEndPoint& address) {
return net::ERR_NOT_IMPLEMENTED;
}
int BrokeredUdpClientSocket::ConnectAsync(
const net::IPEndPoint& address,
net::CompletionOnceCallback callback) {
return ConnectAsyncInternal(address, std::move(callback));
}
int BrokeredUdpClientSocket::ConnectUsingNetworkAsync(
net::handles::NetworkHandle network,
const net::IPEndPoint& address,
net::CompletionOnceCallback callback) {
return net::ERR_NOT_IMPLEMENTED;
}
int BrokeredUdpClientSocket::ConnectUsingDefaultNetworkAsync(
const net::IPEndPoint& address,
net::CompletionOnceCallback callback) {
return net::ERR_NOT_IMPLEMENTED;
}
int BrokeredUdpClientSocket::ConnectAsyncInternal(
const net::IPEndPoint& address,
net::CompletionOnceCallback callback) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
DCHECK(callback);
DCHECK(!socket_);
CHECK(!connect_called_);
connect_called_ = true;
if (!broker_helper_.ShouldBroker(address.address())) {
return DidCompleteCreate(/*should_broker=*/false, address,
std::move(callback), network::TransferableSocket(),
net::OK);
}
net_log_source_.BeginEvent(net::NetLogEventType::BROKERED_CREATE_SOCKET);
client_socket_factory_->BrokerCreateUdpSocket(
address.GetFamily(),
base::BindOnce(
base::IgnoreResult(&BrokeredUdpClientSocket::DidCompleteCreate),
brokered_weak_ptr_factory_.GetWeakPtr(), /*should_broker=*/true,
address, std::move(callback)));
return net::ERR_IO_PENDING;
}
int BrokeredUdpClientSocket::ConnectInternal(const net::IPEndPoint& address) {
socket_ = std::make_unique<net::UDPClientSocket>(bind_type_, net_log_source_,
network_);
// These options must be set before opening a socket or adopting an opened
// socket.
if (use_non_blocking_io_) {
socket_->UseNonBlockingIO();
}
if (recv_optimization_) {
socket_->EnableRecvOptimization();
}
int set_multicast_rv = socket_->SetMulticastInterface(interface_index_);
if (set_multicast_rv != net::OK) {
return set_multicast_rv;
}
socket_->ApplySocketTag(tag_);
socket_->SetMsgConfirm(set_msg_confirm_);
int connect_rv = socket_->Connect(address);
return connect_rv;
}
int BrokeredUdpClientSocket::DidCompleteCreate(
bool should_broker,
const net::IPEndPoint& address,
net::CompletionOnceCallback callback,
network::TransferableSocket socket,
int result) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (should_broker) {
net_log_source_.EndEventWithNetErrorCode(
net::NetLogEventType::BROKERED_CREATE_SOCKET, result);
if (result != net::OK) {
std::move(callback).Run(result);
return result;
}
}
socket_ = std::make_unique<net::UDPClientSocket>(bind_type_, net_log_source_,
network_);
// These options must be set before opening a socket or adopting an opened
// socket.
if (use_non_blocking_io_) {
socket_->UseNonBlockingIO();
}
if (recv_optimization_) {
socket_->EnableRecvOptimization();
}
if (should_broker) {
int adopt_socket_rv =
socket_->AdoptOpenedSocket(address.GetFamily(), socket.TakeSocket());
if (adopt_socket_rv != net::OK) {
Close();
std::move(callback).Run(adopt_socket_rv);
return adopt_socket_rv;
}
}
int set_multicast_rv = socket_->SetMulticastInterface(interface_index_);
if (set_multicast_rv != net::OK) {
if (should_broker) {
std::move(callback).Run(set_multicast_rv);
}
return set_multicast_rv;
}
socket_->ApplySocketTag(tag_);
socket_->SetMsgConfirm(set_msg_confirm_);
auto split_callback = base::SplitOnceCallback(std::move(callback));
const int connect_rv =
socket_->ConnectAsync(address, std::move(split_callback.first));
if (should_broker && connect_rv != net::ERR_IO_PENDING) {
std::move(split_callback.second).Run(connect_rv);
}
return connect_rv;
}
net::handles::NetworkHandle BrokeredUdpClientSocket::GetBoundNetwork() const {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (!socket_) {
return net::handles::kInvalidNetworkHandle;
}
return socket_->GetBoundNetwork();
}
void BrokeredUdpClientSocket::ApplySocketTag(const net::SocketTag& tag) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (socket_) {
socket_->ApplySocketTag(tag);
}
tag_ = tag;
}
int BrokeredUdpClientSocket::Read(net::IOBuffer* buf,
int buf_len,
net::CompletionOnceCallback callback) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (!socket_) {
return net::ERR_SOCKET_NOT_CONNECTED;
}
return socket_->Read(buf, buf_len, std::move(callback));
}
int BrokeredUdpClientSocket::Write(
net::IOBuffer* buf,
int buf_len,
net::CompletionOnceCallback callback,
const net::NetworkTrafficAnnotationTag& traffic_annotation) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (!socket_) {
return net::ERR_SOCKET_NOT_CONNECTED;
}
return socket_->Write(buf, buf_len, std::move(callback), traffic_annotation);
}
void BrokeredUdpClientSocket::Close() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
socket_.reset();
brokered_weak_ptr_factory_.InvalidateWeakPtrs();
}
int BrokeredUdpClientSocket::GetPeerAddress(net::IPEndPoint* address) const {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (!socket_) {
return net::ERR_SOCKET_NOT_CONNECTED;
}
return socket_->GetPeerAddress(address);
}
int BrokeredUdpClientSocket::GetLocalAddress(net::IPEndPoint* address) const {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (!socket_) {
return net::ERR_SOCKET_NOT_CONNECTED;
}
return socket_->GetLocalAddress(address);
}
int BrokeredUdpClientSocket::SetReceiveBufferSize(int32_t size) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
DCHECK(socket_);
return socket_->SetReceiveBufferSize(size);
}
int BrokeredUdpClientSocket::SetSendBufferSize(int32_t size) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
DCHECK(socket_);
return socket_->SetSendBufferSize(size);
}
int BrokeredUdpClientSocket::SetDoNotFragment() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
DCHECK(socket_);
return socket_->SetDoNotFragment();
}
int BrokeredUdpClientSocket::SetRecvTos() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
DCHECK(socket_);
return socket_->SetRecvTos();
}
int BrokeredUdpClientSocket::SetTos(net::DiffServCodePoint dscp,
net::EcnCodePoint ecn) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
DCHECK(socket_);
return socket_->SetTos(dscp, ecn);
}
void BrokeredUdpClientSocket::SetMsgConfirm(bool confirm) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
set_msg_confirm_ = confirm;
}
const net::NetLogWithSource& BrokeredUdpClientSocket::NetLog() const {
return net_log_source_;
}
void BrokeredUdpClientSocket::UseNonBlockingIO() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
use_non_blocking_io_ = true;
}
int BrokeredUdpClientSocket::SetMulticastInterface(uint32_t interface_index) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (!socket_) {
interface_index_ = interface_index;
return net::OK;
}
return socket_->SetMulticastInterface(interface_index);
}
void BrokeredUdpClientSocket::EnableRecvOptimization() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
recv_optimization_ = true;
}
void BrokeredUdpClientSocket::SetIOSNetworkServiceType(
int ios_network_service_type) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
socket_->SetIOSNetworkServiceType(ios_network_service_type);
}
net::DscpAndEcn BrokeredUdpClientSocket::GetLastTos() const {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
return socket_->GetLastTos();
}
} // namespace network