chromium/chrome/browser/ash/net/network_diagnostics/tls_prober.cc

// Copyright 2020 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/ash/net/network_diagnostics/tls_prober.h"

#include <optional>
#include <utility>

#include "base/functional/bind.h"
#include "base/functional/callback_helpers.h"
#include "base/task/task_runner.h"
#include "chrome/browser/ash/net/network_diagnostics/network_diagnostics_util.h"
#include "content/public/browser/browser_task_traits.h"
#include "content/public/browser/browser_thread.h"
#include "mojo/public/cpp/bindings/pending_receiver.h"
#include "mojo/public/cpp/bindings/receiver.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "net/base/address_list.h"
#include "net/base/host_port_pair.h"
#include "net/base/net_errors.h"
#include "services/network/public/cpp/resolve_host_client_base.h"
#include "services/network/public/cpp/simple_host_resolver.h"

namespace ash::network_diagnostics {

namespace {

// Returns the traffic annotation attached to every network operation made by
// TlsProber: both the TCP connection (CreateTCPConnectedSocket) and the TLS
// upgrade (UpgradeToTLS) are tagged with it.
// NOTE(review): the annotation text below is presumably validated by
// Chromium's traffic-annotation auditor tooling; change it only together with
// the corresponding annotation registry entry.
net::NetworkTrafficAnnotationTag GetTrafficAnnotationTag() {
  return net::DefineNetworkTrafficAnnotation("network_diagnostics_tls",
                                             R"(
      semantics {
        sender: "NetworkDiagnosticsRoutines"
        description:
            "Routines send network traffic to hosts in order to "
            "validate the internet connection on a device."
        trigger:
            "A routine attempts a socket connection or makes an http/s "
            "request."
        data:
            "No data other than the origin (scheme-host-port) is sent. "
            "No user identifier is sent along with the data."
        destination: GOOGLE_OWNED_SERVICE
      }
      policy {
        cookies_allowed: NO
        policy_exception_justification:
            "Not implemented. Does not contain user identifier."
      }
  )");
}

}  // namespace

// Starts a probe of |host_port_pair|. The host is first resolved with an
// A-record query sourced from DNS with the host cache disallowed. The probe
// then continues in OnHostResolutionComplete(), which opens a TCP connection.
// When |negotiate_tls| is true, that connection is later upgraded to TLS. The
// outcome is reported through |callback| from OnDone().
TlsProber::TlsProber(network::NetworkContextGetter network_context_getter,
                     net::HostPortPair host_port_pair,
                     bool negotiate_tls,
                     TlsProbeCompleteCallback callback)
    : network_context_getter_(std::move(network_context_getter)),
      host_port_pair_(std::move(host_port_pair)),
      negotiate_tls_(negotiate_tls),
      callback_(std::move(callback)) {
  DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
  DCHECK(callback_);
  DCHECK(!host_port_pair_.IsEmpty());

  network::mojom::NetworkContext* network_context =
      network_context_getter_.Run();
  // Use CHECK rather than DCHECK, consistent with OnHostResolutionComplete().
  // The context is used unconditionally below, so a null context must fail
  // loudly in release builds too.
  CHECK(network_context);

  host_resolver_ = network::SimpleHostResolver::Create(network_context);

  // Force a fresh A-record lookup from DNS. The host cache is disallowed so the
  // probe exercises the network rather than cached results.
  network::mojom::ResolveHostParametersPtr parameters =
      network::mojom::ResolveHostParameters::New();
  parameters->dns_query_type = net::DnsQueryType::A;
  parameters->source = net::HostResolverSource::DNS;
  parameters->cache_usage =
      network::mojom::ResolveHostParameters::CacheUsage::DISALLOWED;

  // Unretained(this) is safe here because the callback is invoked directly by
  // |host_resolver_| which is owned by |this|.
  host_resolver_->ResolveHost(
      network::mojom::HostResolverHost::NewHostPortPair(host_port_pair_),
      net::NetworkAnonymizationKey::CreateTransient(), std::move(parameters),
      base::BindOnce(&TlsProber::OnHostResolutionComplete,
                     base::Unretained(this)));
}

// Bare constructor: it holds a null network context getter, does not negotiate
// TLS, and starts no probe. It presumably exists for test doubles; verify
// against callers.
TlsProber::TlsProber()
    : network_context_getter_(base::NullCallback()),
      negotiate_tls_(false) {}

// Defaulted. Member destructors release |host_resolver_|, both socket remotes
// and |weak_factory_|, so callbacks still pending for this prober are dropped
// rather than run.
TlsProber::~TlsProber() = default;

// Continuation of the constructor once DNS resolution finishes.
//
// On failure, the probe ends with ProbeExitEnum::kDnsFailure. On success, a
// TCP connection to the resolved addresses is requested, and the probe resumes
// in OnConnectComplete(). A mojo disconnect of the TCP socket remote is
// reported through OnDisconnect().
void TlsProber::OnHostResolutionComplete(
    int result,
    const net::ResolveErrorInfo&,
    const std::optional<net::AddressList>& resolved_addresses,
    const std::optional<net::HostResolverEndpointResults>&) {
  DCHECK_CURRENTLY_ON(content::BrowserThread::UI);

  // The resolver is used for a single lookup; drop it now that it has answered.
  host_resolver_.reset();

  const bool resolved = (result == net::OK);
  // Addresses are present exactly when resolution succeeded.
  CHECK_EQ(resolved, resolved_addresses.has_value());
  if (!resolved) {
    OnDone(result, ProbeExitEnum::kDnsFailure);
    return;
  }

  auto socket_receiver =
      tcp_connected_socket_remote_.BindNewPipeAndPassReceiver();
  tcp_connected_socket_remote_.set_disconnect_handler(
      base::BindOnce(&TlsProber::OnDisconnect, weak_factory_.GetWeakPtr()));

  network::mojom::NetworkContext* const context =
      network_context_getter_.Run();
  CHECK(context);

  context->CreateTCPConnectedSocket(
      /*local_addr=*/std::nullopt, *resolved_addresses,
      /*tcp_connected_socket_options=*/nullptr,
      net::MutableNetworkTrafficAnnotationTag(GetTrafficAnnotationTag()),
      std::move(socket_receiver), /*observer=*/mojo::NullRemote(),
      base::BindOnce(&TlsProber::OnConnectComplete,
                     weak_factory_.GetWeakPtr()));
}

// Invoked when the TCP connection attempt finishes.
//
// A failed connection ends the probe with kTcpConnectionFailure. A successful
// one ends it with kSuccess unless TLS negotiation was requested. In that case
// the TCP socket is upgraded to TLS and the probe resumes in OnTlsUpgrade().
void TlsProber::OnConnectComplete(
    int result,
    const std::optional<net::IPEndPoint>& local_addr,
    const std::optional<net::IPEndPoint>& peer_addr,
    mojo::ScopedDataPipeConsumerHandle receive_stream,
    mojo::ScopedDataPipeProducerHandle send_stream) {
  DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
  DCHECK(tcp_connected_socket_remote_.is_bound());

  const bool connected = (result == net::OK);
  if (!connected || !negotiate_tls_) {
    OnDone(result, connected ? ProbeExitEnum::kSuccess
                             : ProbeExitEnum::kTcpConnectionFailure);
    return;
  }

  DCHECK(peer_addr.has_value());

  auto tls_receiver = tls_client_socket_remote_.BindNewPipeAndPassReceiver();

  // Upgrading disconnects the TCP remote from its receiver, so its disconnect
  // handler is cleared to avoid reporting that as a failure. Disconnects are
  // watched on the TLS remote from here on.
  tcp_connected_socket_remote_.set_disconnect_handler(base::NullCallback());
  tls_client_socket_remote_.set_disconnect_handler(
      base::BindOnce(&TlsProber::OnDisconnect, weak_factory_.GetWeakPtr()));

  tcp_connected_socket_remote_->UpgradeToTLS(
      host_port_pair_,
      /*options=*/nullptr,
      net::MutableNetworkTrafficAnnotationTag(GetTrafficAnnotationTag()),
      std::move(tls_receiver),
      /*observer=*/mojo::NullRemote(),
      base::BindOnce(&TlsProber::OnTlsUpgrade, weak_factory_.GetWeakPtr()));
}

void TlsProber::OnTlsUpgrade(int result,
                             mojo::ScopedDataPipeConsumerHandle receive_stream,
                             mojo::ScopedDataPipeProducerHandle send_stream,
                             const std::optional<net::SSLInfo>& ssl_info) {
  // |send_stream| and |receive_stream|, created on the TLS connection, fall out
  // of scope when this method completes.
  DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
  if (result == net::OK) {
    OnDone(result, ProbeExitEnum::kSuccess);
    return;
  }
  OnDone(result, ProbeExitEnum::kTlsUpgradeFailure);
}

void TlsProber::OnDisconnect() {
  DCHECK_CURRENTLY_ON(content::BrowserThread::UI);

  OnDone(net::ERR_FAILED, ProbeExitEnum::kMojoDisconnectFailure);
}

void TlsProber::OnDone(int result, ProbeExitEnum probe_exit_enum) {
  DCHECK_CURRENTLY_ON(content::BrowserThread::UI);

  // Invalidate pending callbacks.
  weak_factory_.InvalidateWeakPtrs();
  // Destroy the socket connection.
  tcp_connected_socket_remote_.reset();
  tls_client_socket_remote_.reset();

  std::move(callback_).Run(result, probe_exit_enum);
}

}  // namespace ash::network_diagnostics