// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/nearby_sharing/tcp_socket/nearby_connections_tcp_socket_factory.h"
#include "base/containers/contains.h"
#include "base/functional/bind.h"
#include "base/metrics/histogram_functions.h"
#include "chromeos/ash/services/nearby/public/cpp/tcp_server_socket_port.h"
#include "net/base/ip_address.h"
#include "net/base/net_errors.h"
#include "net/traffic_annotation/network_traffic_annotation.h"
// Prepares, but does not yet issue, a CreateTCPConnectedSocket() request on
// |network_context|. The request is stored in |task_| and is only sent when
// Run() is called. |callback| is kept in |callback_| and is invoked from
// OnFinished().
NearbyConnectionsTcpSocketFactory::ConnectTask::ConnectTask(
    network::mojom::NetworkContext* network_context,
    const std::optional<net::IPEndPoint>& local_addr,
    const net::AddressList& remote_addr_list,
    network::mojom::TCPConnectedSocketOptionsPtr tcp_connected_socket_options,
    const net::MutableNetworkTrafficAnnotationTag& traffic_annotation,
    mojo::PendingReceiver<network::mojom::TCPConnectedSocket> receiver,
    mojo::PendingRemote<network::mojom::SocketObserver> observer,
    CreateTCPConnectedSocketCallback callback)
    : callback_(std::move(callback)) {
  DCHECK(network_context);

  // The response is routed through a weak pointer. OnTimeout() invalidates
  // the weak pointers, so a response arriving after a timeout is not
  // delivered to OnFinished().
  auto on_finished =
      base::BindOnce(&ConnectTask::OnFinished, weak_ptr_factory_.GetWeakPtr());

  // NOTE(review): |network_context| is bound with base::Unretained, so it
  // must still be alive when Run() executes |task_|.
  task_ = base::BindOnce(
      &network::mojom::NetworkContext::CreateTCPConnectedSocket,
      base::Unretained(network_context), local_addr, remote_addr_list,
      std::move(tcp_connected_socket_options), traffic_annotation,
      std::move(receiver), std::move(observer), std::move(on_finished));
}
// Asserts that destruction happens on the sequence tracked by
// |sequence_checker_|; performs no other explicit teardown.
NearbyConnectionsTcpSocketFactory::ConnectTask::~ConnectTask() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
}
// Starts the connect attempt: records the start time, arms |timer_| so that
// OnTimeout() fires after |timeout|, and then issues the request that was
// prepared in the constructor. Must be called at most once, because |task_|
// is consumed here.
void NearbyConnectionsTcpSocketFactory::ConnectTask::Run(
    base::TimeDelta timeout) {
  // Like every other method of this class, Run() touches sequence-affine
  // state (|timer_|, |start_time_|, |task_|), so enforce the same check.
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  // A second call would run an already-consumed callback.
  DCHECK(task_);

  // Record the start time before anything else so that the time-to-connect
  // metric reported from OnFinished() covers the whole attempt.
  start_time_ = base::TimeTicks::Now();

  // NOTE(review): base::Unretained(this) presumably relies on |timer_| being
  // owned by this object and therefore not firing after destruction.
  timer_.Start(FROM_HERE, timeout,
               base::BindOnce(&ConnectTask::OnTimeout, base::Unretained(this)));

  std::move(task_).Run();
}
// Completion handler for the connect attempt. It is bound as the response
// callback of the CreateTCPConnectedSocket() request and is also called
// directly by OnTimeout() with net::ERR_TIMED_OUT. Stops the timeout timer,
// records the time-to-connect histogram when |result| is net::OK, and
// forwards all arguments to |callback_|.
void NearbyConnectionsTcpSocketFactory::ConnectTask::OnFinished(
    int32_t result,
    const std::optional<net::IPEndPoint>& local_addr,
    const std::optional<net::IPEndPoint>& peer_addr,
    mojo::ScopedDataPipeConsumerHandle receive_stream,
    mojo::ScopedDataPipeProducerHandle send_stream) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  timer_.Stop();

  // Just to be safe, protect against finish/timeout race conditions. The
  // guard runs before the metric is recorded so that a redundant invocation
  // is a complete no-op and cannot emit the histogram a second time.
  if (!callback_) {
    return;
  }

  if (result == net::OK) {
    base::UmaHistogramTimes("Nearby.Connections.WifiLan.TimeToConnect",
                            base::TimeTicks::Now() - start_time_);
  }

  // Keep this as the last statement: the factory's callback erases this task
  // from its map, which destroys |this|, so no member may be touched after
  // the callback has run.
  std::move(callback_).Run(result, local_addr, peer_addr,
                           std::move(receive_stream), std::move(send_stream));
}
void NearbyConnectionsTcpSocketFactory::ConnectTask::OnTimeout() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
weak_ptr_factory_.InvalidateWeakPtrs();
OnFinished(net::ERR_TIMED_OUT, /*local_addr=*/std::nullopt,
/*peer_addr=*/std::nullopt,
/*receive_stream=*/mojo::ScopedDataPipeConsumerHandle(),
/*send_stream=*/mojo::ScopedDataPipeProducerHandle());
}
// Stores |network_context_getter|, which is invoked on every socket-creation
// request to obtain the NetworkContext to use.
NearbyConnectionsTcpSocketFactory::NearbyConnectionsTcpSocketFactory(
network::NetworkContextGetter network_context_getter)
: network_context_getter_(std::move(network_context_getter)) {}
// Defaulted destructor; no explicit teardown is performed here.
NearbyConnectionsTcpSocketFactory::~NearbyConnectionsTcpSocketFactory() =
default;
// Asks the current NetworkContext to create a TCP server socket listening on
// |local_addr| at the port carried by |port|, with the given |backlog|.
// |callback| receives the result code and, on the forwarding path, the local
// endpoint reported by the network context. If no NetworkContext is
// available, |callback| is run immediately with net::ERR_FAILED.
void NearbyConnectionsTcpSocketFactory::CreateTCPServerSocket(
    const net::IPAddress& local_addr,
    const ash::nearby::TcpServerSocketPort& port,
    uint32_t backlog,
    const net::MutableNetworkTrafficAnnotationTag& traffic_annotation,
    mojo::PendingReceiver<network::mojom::TCPServerSocket> receiver,
    CreateTCPServerSocketCallback callback) {
  network::mojom::NetworkContext* const context =
      network_context_getter_.Run();
  if (context == nullptr) {
    std::move(callback).Run(net::ERR_FAILED, /*local_addr=*/std::nullopt);
    return;
  }

  const net::IPEndPoint listen_endpoint(local_addr, port.port());

  auto socket_options = network::mojom::TCPServerSocketOptions::New();
  socket_options->backlog = backlog;

  // The response is bound to a weak pointer to this factory and simply
  // forwards to |callback|.
  auto on_created = base::BindOnce(
      &NearbyConnectionsTcpSocketFactory::OnTcpServerSocketCreated,
      weak_ptr_factory_.GetWeakPtr(), std::move(callback));

  context->CreateTCPServerSocket(listen_endpoint, std::move(socket_options),
                                 traffic_annotation, std::move(receiver),
                                 std::move(on_created));
}
// Starts a TCP connect attempt to |remote_addr_list| that is abandoned with
// net::ERR_TIMED_OUT if it has not completed within |timeout|. The attempt is
// represented by a ConnectTask that is stored in |connect_tasks_| under a
// freshly generated token and removed again in OnTcpConnectedSocketCreated().
// If no NetworkContext is available, |callback| is run immediately with
// net::ERR_FAILED, no addresses and empty data pipe handles.
void NearbyConnectionsTcpSocketFactory::CreateTCPConnectedSocket(
    base::TimeDelta timeout,
    const std::optional<net::IPEndPoint>& local_addr,
    const net::AddressList& remote_addr_list,
    network::mojom::TCPConnectedSocketOptionsPtr tcp_connected_socket_options,
    const net::MutableNetworkTrafficAnnotationTag& traffic_annotation,
    mojo::PendingReceiver<network::mojom::TCPConnectedSocket> receiver,
    mojo::PendingRemote<network::mojom::SocketObserver> observer,
    CreateTCPConnectedSocketCallback callback) {
  network::mojom::NetworkContext* network_context =
      network_context_getter_.Run();
  if (!network_context) {
    std::move(callback).Run(
        net::ERR_FAILED, /*local_addr=*/std::nullopt,
        /*peer_addr=*/std::nullopt,
        /*receive_stream=*/mojo::ScopedDataPipeConsumerHandle(),
        /*send_stream=*/mojo::ScopedDataPipeProducerHandle());
    return;
  }

  const base::UnguessableToken task_id = base::UnguessableToken::Create();

  // base::Unretained(this) is used because the bound callback is stored
  // inside the ConnectTask, and the task itself is owned by |connect_tasks_|.
  auto task = std::make_unique<ConnectTask>(
      network_context, local_addr, remote_addr_list,
      std::move(tcp_connected_socket_options), traffic_annotation,
      std::move(receiver), std::move(observer),
      base::BindOnce(
          &NearbyConnectionsTcpSocketFactory::OnTcpConnectedSocketCreated,
          base::Unretained(this), task_id, std::move(callback)));

  // Keep a non-owning pointer so the task can be started without looking it
  // up again; operator[] would also default-insert a null entry if the key
  // were ever missing.
  ConnectTask* const task_ptr = task.get();
  connect_tasks_.insert_or_assign(task_id, std::move(task));
  task_ptr->Run(timeout);
}
// Response handler bound in CreateTCPServerSocket(). Forwards |result| and
// |local_addr| unchanged to the caller-supplied |callback|.
void NearbyConnectionsTcpSocketFactory::OnTcpServerSocketCreated(
CreateTCPServerSocketCallback callback,
int32_t result,
const std::optional<net::IPEndPoint>& local_addr) {
std::move(callback).Run(result, local_addr);
}
// Completion handler for the ConnectTask identified by |task_id|; bound in
// CreateTCPConnectedSocket(). Forwards the connect result to the
// caller-supplied |callback| and then removes the finished task from
// |connect_tasks_|.
void NearbyConnectionsTcpSocketFactory::OnTcpConnectedSocketCreated(
base::UnguessableToken task_id,
CreateTCPConnectedSocketCallback callback,
int32_t result,
const std::optional<net::IPEndPoint>& local_addr,
const std::optional<net::IPEndPoint>& peer_addr,
mojo::ScopedDataPipeConsumerHandle receive_stream,
mojo::ScopedDataPipeProducerHandle send_stream) {
// NOTE(review): |callback| runs before |connect_tasks_| is accessed below, so
// it must not destroy this factory.
std::move(callback).Run(result, local_addr, peer_addr,
std::move(receive_stream), std::move(send_stream));
// Erasing the entry destroys the ConnectTask that invoked this method, so
// this is done last.
DCHECK(base::Contains(connect_tasks_, task_id));
connect_tasks_.erase(task_id);
}