chromium/services/network/brokered_udp_client_socket_unittest.cc

// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "services/network/brokered_udp_client_socket.h"

#include "base/test/bind.h"
#include "base/test/scoped_run_loop_timeout.h"
#include "base/test/task_environment.h"
#include "build/build_config.h"
#include "net/base/completion_repeating_callback.h"
#include "net/base/io_buffer.h"
#include "net/base/ip_address.h"
#include "net/base/ip_endpoint.h"
#include "net/base/net_errors.h"
#include "net/base/test_completion_callback.h"
#include "net/log/net_log_source.h"
#include "net/nqe/network_quality_estimator_test_util.h"
#include "net/socket/socket_tag.h"
#include "net/socket/socket_test_util.h"
#include "net/socket/udp_server_socket.h"
#include "net/test/gtest_util.h"
#include "net/test/test_with_task_environment.h"
#include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
#include "services/network/brokered_client_socket_factory.h"
#include "services/network/test/test_socket_broker_impl.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"

#if BUILDFLAG(IS_ANDROID)
#include "net/android/network_change_notifier_factory_android.h"
#include "net/base/network_change_notifier.h"
#endif

#if BUILDFLAG(IS_WIN)
#include "services/network/broker_helper_win.h"
#endif

using net::test::IsError;
using net::test::IsOk;
using testing::Not;

namespace network {

#if BUILDFLAG(IS_WIN)
// A BrokerHelper delegate to manually set whether a socket needs to be
// brokered. This is necessary to make sure we can test connecting unbrokered
// sockets on Windows, since otherwise ShouldBroker would return true for
// localhost addresses.
class TestBrokerHelperDelegate : public BrokerHelperWin::Delegate {
 public:
  explicit TestBrokerHelperDelegate(bool should_broker)
      : should_broker_(should_broker) {}

  bool ShouldBroker() const override { return should_broker_; }

 private:
  bool should_broker_;
};
#endif

// Factory subclass that lets the tests obtain the concrete
// BrokeredUdpClientSocket type rather than a DatagramClientSocket.
// BrokeredUdpClientSocket has unit-test helper methods that
// DatagramClientSocket does not need, so the tests must hold the concrete
// type to reach them.
class TestBrokeredClientSocketFactory : public BrokeredClientSocketFactory {
 public:
  explicit TestBrokeredClientSocketFactory(
      mojo::PendingRemote<mojom::SocketBroker> pending_remote)
      : BrokeredClientSocketFactory(std::move(pending_remote)) {}

  // Builds a BrokeredUdpClientSocket, forwarding `bind_type`, `net_log` and
  // `source` unchanged and passing this factory as the final constructor
  // argument.
  std::unique_ptr<BrokeredUdpClientSocket> CreateBrokeredUdpClientSocket(
      net::DatagramSocket::BindType bind_type,
      net::NetLog* net_log,
      const net::NetLogSource& source) {
    auto brokered_socket = std::make_unique<BrokeredUdpClientSocket>(
        bind_type, net_log, source, this);
    return brokered_socket;
  }
};

// Fixture that connects a BrokeredUdpClientSocket to an in-process
// TestSocketBrokerImpl over a mojo pipe, and provides blocking read/write
// helpers for both the client socket and a net::UDPServerSocket peer.
class BrokeredUdpClientSocketTest : public testing::Test,
                                    public net::WithTaskEnvironment {
 public:
  BrokeredUdpClientSocketTest()
      : receiver_(&socket_broker_impl_),
        client_socket_factory_(receiver_.BindNewPipeAndPassRemote()),
        buffer_(base::MakeRefCounted<net::IOBufferWithSize>(kMaxRead)) {}

  ~BrokeredUdpClientSocketTest() override = default;

  // Creates `socket_` and checks that address queries fail gracefully while
  // the socket is still unconnected.
  void SetUp() override {
    socket_ = client_socket_factory_.CreateBrokeredUdpClientSocket(
        net::DatagramSocket::DEFAULT_BIND, net::NetLog::Get(),
        net::NetLogSource());

    // Confirm that we fail gracefully when making certain calls before
    // connecting.
    net::IPEndPoint test_address;
    EXPECT_EQ(socket_->GetPeerAddress(&test_address),
              net::ERR_SOCKET_NOT_CONNECTED);
    EXPECT_EQ(socket_->GetLocalAddress(&test_address),
              net::ERR_SOCKET_NOT_CONNECTED);
  }

  // Writes `msg` to `socket_` and waits for the write to finish. Returns the
  // final result of the write; SimpleReadAndWrite() expects this to equal the
  // message length on success.
  int WriteToClientSocket(const std::string& msg) {
    auto io_buffer = base::MakeRefCounted<net::StringIOBuffer>(msg);
    net::TestCompletionCallback callback;
    const int rv =
        socket_->Write(io_buffer.get(), io_buffer->size(), callback.callback(),
                       TRAFFIC_ANNOTATION_FOR_TESTS);
    return callback.GetResult(rv);
  }

  // Reads up to kMaxRead bytes from `socket_` into `buffer_`, waiting for the
  // read to finish. Returns the bytes read, or an empty string if the read
  // result is negative.
  std::string ReadFromClientSocket() {
    net::TestCompletionCallback callback;
    const int rv = callback.GetResult(
        socket_->Read(buffer_.get(), kMaxRead, callback.callback()));
    if (rv < 0) {
      return std::string();
    }
    return std::string(buffer_->data(), rv);
  }

  // Sends `msg` from the server `socket` to `address` and waits for the send
  // to finish. Returns the final result of the send.
  int SendToSocket(net::UDPServerSocket* socket,
                   const std::string& msg,
                   const net::IPEndPoint& address) {
    auto io_buffer = base::MakeRefCounted<net::StringIOBuffer>(msg);
    net::TestCompletionCallback callback;
    const int rv = socket->SendTo(io_buffer.get(), io_buffer->size(), address,
                                  callback.callback());
    return callback.GetResult(rv);
  }

  // Blocks until data is read from the server `socket` into `buffer_`.
  // `address` is passed to RecvFrom() as the out-parameter for the sender.
  // Returns the bytes read, or an empty string if the result is negative.
  std::string RecvFromSocket(net::UDPServerSocket* socket,
                             net::IPEndPoint& address) {
    net::TestCompletionCallback callback;
    const int rv = callback.GetResult(
        socket->RecvFrom(buffer_.get(), kMaxRead, &address,
                         callback.callback()));
    if (rv < 0) {
      return std::string();
    }
    return std::string(buffer_->data(), rv);
  }

  // Round-trips one datagram: the client writes a message, `server` receives
  // it and echoes it back to the sender, and the client reads the echo. Every
  // step is checked with EXPECT_EQ.
  void SimpleReadAndWrite(net::UDPServerSocket* server) {
    const std::string simple_message("hello world!");
    int rv = WriteToClientSocket(simple_message);
    EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));

    // Server waits for message.
    net::IPEndPoint address;
    std::string str = RecvFromSocket(server, address);
    EXPECT_EQ(simple_message, str);

    // Server echoes reply.
    rv = SendToSocket(server, simple_message, address);
    EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));

    // Client waits for response.
    str = ReadFromClientSocket();
    EXPECT_EQ(simple_message, str);
  }

 protected:
  // Size in bytes of `buffer_` and the maximum length of a single read.
  static const int kMaxRead = 1024;

  // Members are declared in dependency order so that they are destroyed in the
  // reverse order: `socket_` is given a pointer to `client_socket_factory_`,
  // whose remote is bound to `receiver_`, which in turn holds a pointer to
  // `socket_broker_impl_`. No member outlives something it points at.
  TestSocketBrokerImpl socket_broker_impl_;
  mojo::Receiver<mojom::SocketBroker> receiver_;
  TestBrokeredClientSocketFactory client_socket_factory_;
  // Scratch buffer shared by ReadFromClientSocket() and RecvFromSocket().
  scoped_refptr<net::IOBufferWithSize> buffer_;
  // Socket under test; created in SetUp().
  std::unique_ptr<BrokeredUdpClientSocket> socket_;
};

// Out-of-class definition for the in-class-initialized static member; it gives
// kMaxRead storage in case it is ODR-used (e.g. bound to a reference).
// NOTE(review): likely redundant if the member were declared constexpr under
// C++17 or later — confirm before removing.
const int BrokeredUdpClientSocketTest::kMaxRead;

// Checks that ConnectAsync() completes asynchronously with
// ERR_CONNECTION_FAILED once SetConnectionFailure(true) has been called on the
// test broker.
TEST_F(BrokeredUdpClientSocketTest, FailedConnectAsync) {
  net::TestCompletionCallback connect_callback;
  base::test::ScopedDisableRunLoopTimeout disable_timeout;
  const net::IPEndPoint server_address(net::IPAddress::IPv4Localhost(),
                                       /*port=*/8080);

  socket_broker_impl_.SetConnectionFailure(true);

  // The connect attempt must report pending before delivering its result.
  const int pending_result =
      socket_->ConnectAsync(server_address, connect_callback.callback());
  ASSERT_THAT(pending_result, IsError(net::ERR_IO_PENDING));

  EXPECT_THAT(connect_callback.WaitForResult(),
              IsError(net::ERR_CONNECTION_FAILED));
}

// Checks that ConnectAsync() succeeds asynchronously, that the socket reports
// no bound network afterwards, and that the global UDP socket count goes from
// 0 to 1 on connect and back to 0 on Close().
TEST_F(BrokeredUdpClientSocketTest, ConnectAsync) {
  ASSERT_EQ(0, net::GetGlobalUDPSocketCountForTesting());

  const net::IPEndPoint server_address(net::IPAddress::IPv4Localhost(),
                                       /*port=*/8080);
  net::TestCompletionCallback connect_callback;

  // The connect attempt must report pending before delivering its result.
  const int pending_result =
      socket_->ConnectAsync(server_address, connect_callback.callback());
  ASSERT_THAT(pending_result, IsError(net::ERR_IO_PENDING));

  EXPECT_THAT(connect_callback.WaitForResult(), IsOk());
  EXPECT_EQ(net::handles::kInvalidNetworkHandle, socket_->GetBoundNetwork());

  ASSERT_EQ(1, net::GetGlobalUDPSocketCountForTesting());
  socket_->Close();
  ASSERT_EQ(0, net::GetGlobalUDPSocketCountForTesting());
}

// Exercises the synchronous connect entry points. On Windows, with brokering
// forced off through TestBrokerHelperDelegate, Connect() is expected to
// succeed; on other platforms it is expected to return ERR_NOT_IMPLEMENTED.
// ConnectUsingNetwork() and ConnectUsingDefaultNetwork() are expected to
// return ERR_NOT_IMPLEMENTED everywhere. In every case the socket must report
// no bound network.
TEST_F(BrokeredUdpClientSocketTest, Connect) {
  ASSERT_EQ(0, net::GetGlobalUDPSocketCountForTesting());
  const net::IPEndPoint server_address(net::IPAddress::IPv4Localhost(),
                                       /*port=*/8080);
  int rv = net::OK;

#if BUILDFLAG(IS_WIN)
  // Pretending we don't need to broker a localhost address to be able to
  // reliably test connecting synchronously.
  socket_->SetBrokerHelperDelegateForTesting(
      std::make_unique<TestBrokerHelperDelegate>(false));
  rv = socket_->Connect(server_address);
  ASSERT_EQ(rv, net::OK);
  EXPECT_EQ(net::handles::kInvalidNetworkHandle, socket_->GetBoundNetwork());

  // ConnectUsingNetwork and ConnectUsingDefaultNetwork should return
  // ERR_NOT_IMPLEMENTED even if brokering is not required on windows.
  auto socket2 = client_socket_factory_.CreateBrokeredUdpClientSocket(
      net::DatagramSocket::DEFAULT_BIND, net::NetLog::Get(),
      net::NetLogSource());
  socket2->SetBrokerHelperDelegateForTesting(
      std::make_unique<TestBrokerHelperDelegate>(false));
  rv = socket2->ConnectUsingNetwork(net::handles::kInvalidNetworkHandle,
                                    server_address);
  ASSERT_EQ(rv, net::ERR_NOT_IMPLEMENTED);
  EXPECT_EQ(net::handles::kInvalidNetworkHandle, socket2->GetBoundNetwork());

  auto socket3 = client_socket_factory_.CreateBrokeredUdpClientSocket(
      net::DatagramSocket::DEFAULT_BIND, net::NetLog::Get(),
      net::NetLogSource());
  socket3->SetBrokerHelperDelegateForTesting(
      std::make_unique<TestBrokerHelperDelegate>(false));
  rv = socket3->ConnectUsingDefaultNetwork(server_address);
  ASSERT_EQ(rv, net::ERR_NOT_IMPLEMENTED);
  EXPECT_EQ(net::handles::kInvalidNetworkHandle, socket3->GetBoundNetwork());
#else
  rv = socket_->Connect(server_address);
  ASSERT_EQ(rv, net::ERR_NOT_IMPLEMENTED);
  EXPECT_EQ(net::handles::kInvalidNetworkHandle, socket_->GetBoundNetwork());
#endif

  // ConnectUsingNetwork and ConnectUsingDefaultNetwork should also return
  // ERR_NOT_IMPLEMENTED on all platforms. These sockets come from the
  // factory's CreateDatagramClientSocket() rather than the test-only
  // CreateBrokeredUdpClientSocket() helper.
  auto socket4 = client_socket_factory_.CreateDatagramClientSocket(
      net::DatagramSocket::DEFAULT_BIND, net::NetLog::Get(),
      net::NetLogSource());
  rv = socket4->ConnectUsingNetwork(net::handles::kInvalidNetworkHandle,
                                    server_address);
  ASSERT_EQ(rv, net::ERR_NOT_IMPLEMENTED);
  EXPECT_EQ(net::handles::kInvalidNetworkHandle, socket4->GetBoundNetwork());

  auto socket5 = client_socket_factory_.CreateDatagramClientSocket(
      net::DatagramSocket::DEFAULT_BIND, net::NetLog::Get(),
      net::NetLogSource());
  rv = socket5->ConnectUsingDefaultNetwork(server_address);
  ASSERT_EQ(rv, net::ERR_NOT_IMPLEMENTED);
  EXPECT_EQ(net::handles::kInvalidNetworkHandle, socket5->GetBoundNetwork());
}

// Checks that socket options requested before connecting are visible through
// the *_for_testing() accessors after ConnectAsync(). On Windows it also
// checks that the options are still recorded on a second socket whose
// ConnectAsync() fails synchronously with ERR_ADDRESS_INVALID.
TEST_F(BrokeredUdpClientSocketTest, SetOptions) {
  net::TestCompletionCallback connect_callback;
  const net::IPEndPoint server_address(net::IPAddress::IPv4Localhost(),
                                       /*port=*/8080);

  // Request every option up front, before the socket is connected.
  EXPECT_THAT(socket_->SetMulticastInterface(1), IsOk());
  socket_->SetMsgConfirm(true);
  socket_->EnableRecvOptimization();
  socket_->UseNonBlockingIO();

  int result =
      socket_->ConnectAsync(server_address, connect_callback.callback());
  ASSERT_EQ(result, net::ERR_IO_PENDING);
  result = connect_callback.WaitForResult();
  EXPECT_EQ(result, net::OK);

  EXPECT_EQ(socket_->get_multicast_interface_for_testing(),
            static_cast<uint32_t>(1));
#if (!BUILDFLAG(IS_APPLE) && !BUILDFLAG(IS_WIN))
  EXPECT_TRUE(socket_->get_msg_confirm_for_testing());
#endif
#if BUILDFLAG(IS_POSIX)
  EXPECT_TRUE(socket_->get_recv_optimization_for_testing());
#endif
#if BUILDFLAG(IS_WIN)
  EXPECT_TRUE(socket_->get_use_non_blocking_io_for_testing());

  // Set up a new socket to check that options are set correctly when sockets
  // don't need to be brokered on win.
  auto unbrokered_socket = client_socket_factory_.CreateBrokeredUdpClientSocket(
      net::DatagramSocket::DEFAULT_BIND, net::NetLog::Get(),
      net::NetLogSource());

  net::TestCompletionCallback second_callback;
  const net::IPEndPoint all_zeros_address(net::IPAddress::IPv4AllZeros(),
                                          /*port=*/8080);
  EXPECT_THAT(unbrokered_socket->SetMulticastInterface(1), IsOk());
  unbrokered_socket->UseNonBlockingIO();
  result = unbrokered_socket->ConnectAsync(all_zeros_address,
                                           second_callback.callback());

  // `unbrokered_socket` shouldn't successfully connect since the address is
  // invalid, but the options should still be set.
  EXPECT_EQ(result, net::ERR_ADDRESS_INVALID);
  EXPECT_EQ(unbrokered_socket->get_multicast_interface_for_testing(),
            static_cast<uint32_t>(1));
  EXPECT_TRUE(unbrokered_socket->get_use_non_blocking_io_for_testing());
#endif
}

// Connects the brokered client socket to a local UDP server listening on a
// port chosen at Listen() time (port 0 is requested), then round-trips one
// datagram through SimpleReadAndWrite().
TEST_F(BrokeredUdpClientSocketTest, SimpleReadWrite) {
  net::TestCompletionCallback connect_callback;

  // Start the server and find out which address it ended up listening on.
  net::UDPServerSocket server(nullptr, net::NetLogSource());
  const net::IPEndPoint listen_address(net::IPAddress::IPv4Localhost(), 0);
  ASSERT_THAT(server.Listen(listen_address), IsOk());
  net::IPEndPoint server_address;
  ASSERT_THAT(server.GetLocalAddress(&server_address), IsOk());

  // The connect attempt must report pending before delivering its result.
  const int pending_result =
      socket_->ConnectAsync(server_address, connect_callback.callback());
  ASSERT_THAT(pending_result, IsError(net::ERR_IO_PENDING));
  EXPECT_THAT(connect_callback.WaitForResult(), IsOk());

  SimpleReadAndWrite(&server);
}

// Exercises ConnectUsingNetworkAsync() and ConnectUsingDefaultNetworkAsync().
// On Android (when network handles are supported) a non-existent network must
// fail with something other than ERR_NOT_IMPLEMENTED, and a valid default
// network must connect and be reported by GetBoundNetwork(). On all other
// platforms both calls must return ERR_NOT_IMPLEMENTED synchronously.
TEST_F(BrokeredUdpClientSocketTest, ConnectUsingNetworkAsync) {
  // The specific value of this address doesn't really matter, and no
  // server needs to be running here. The test only needs to call
  // ConnectUsingNetworkAsync() and won't send any datagrams.
  net::IPEndPoint server_address(net::IPAddress::IPv4Localhost(),
                                 /*port=*/8080);
  const net::handles::NetworkHandle wrong_network_handle = 65536;
  net::TestCompletionCallback callback;
#if BUILDFLAG(IS_ANDROID)
  net::NetworkChangeNotifierFactoryAndroid ncn_factory;
  net::NetworkChangeNotifierDelegateAndroid::
      EnableNetworkChangeNotifierAutoDetectForTest();
  std::unique_ptr<net::NetworkChangeNotifier> ncn(ncn_factory.CreateInstance());
  if (!net::NetworkChangeNotifier::AreNetworkHandlesSupported()) {
    GTEST_SKIP() << "Network handles are required to test BindToNetwork.";
  }

  {
    // Connecting using a not existing network should fail but not report
    // ERR_NOT_IMPLEMENTED when network handles are supported.
    auto socket = client_socket_factory_.CreateDatagramClientSocket(
        net::DatagramSocket::DEFAULT_BIND, net::NetLog::Get(),
        net::NetLogSource());
    int rv = socket->ConnectUsingNetworkAsync(
        wrong_network_handle, server_address, callback.callback());
    // ASSERT rather than EXPECT: if the call did not return ERR_IO_PENDING the
    // callback will never run and WaitForResult() below would block.
    ASSERT_EQ(rv, net::ERR_IO_PENDING);
    rv = callback.WaitForResult();
    EXPECT_NE(net::ERR_NOT_IMPLEMENTED, rv);
    EXPECT_NE(net::OK, rv);
  }

  {
    // Connecting using an existing network should succeed when
    // NetworkChangeNotifier returns a valid default network.
    const net::handles::NetworkHandle network_handle =
        net::NetworkChangeNotifier::GetDefaultNetwork();
    if (network_handle != net::handles::kInvalidNetworkHandle) {
      auto socket2 = client_socket_factory_.CreateDatagramClientSocket(
          net::DatagramSocket::DEFAULT_BIND, net::NetLog::Get(),
          net::NetLogSource());
      int rv = socket2->ConnectUsingNetworkAsync(network_handle, server_address,
                                                 callback.callback());
      // See above: waiting is only safe once the call is known to be pending.
      ASSERT_EQ(rv, net::ERR_IO_PENDING);
      rv = callback.WaitForResult();
      EXPECT_EQ(net::OK, rv);
      EXPECT_EQ(network_handle, socket2->GetBoundNetwork());
      // Also check that connecting using the default network succeeds with a
      // valid default network.
      auto socket3 = client_socket_factory_.CreateDatagramClientSocket(
          net::DatagramSocket::DEFAULT_BIND, net::NetLog::Get(),
          net::NetLogSource());
      rv = socket3->ConnectUsingDefaultNetworkAsync(server_address,
                                                    callback.callback());
      ASSERT_EQ(rv, net::ERR_IO_PENDING);
      rv = callback.WaitForResult();
      EXPECT_EQ(net::OK, rv);
      EXPECT_EQ(network_handle, socket3->GetBoundNetwork());
    }
  }
#else
  EXPECT_EQ(net::ERR_NOT_IMPLEMENTED,
            socket_->ConnectUsingNetworkAsync(
                wrong_network_handle, server_address, callback.callback()));
  auto socket2 = client_socket_factory_.CreateDatagramClientSocket(
      net::DatagramSocket::DEFAULT_BIND, net::NetLog::Get(),
      net::NetLogSource());
  EXPECT_EQ(net::ERR_NOT_IMPLEMENTED, socket2->ConnectUsingDefaultNetworkAsync(
                                          server_address, callback.callback()));
#endif  // BUILDFLAG(IS_ANDROID)
}

// On Android, where socket tagging is supported, verify that the
// BrokeredUdpClientSocket sets tags correctly.
#if BUILDFLAG(IS_ANDROID)
// Applies a socket tag before connecting, and a second tag (with the current
// process's UID) after connecting; after each, round-trips a datagram and
// checks that the byte count reported for that tag value increased.
TEST_F(BrokeredUdpClientSocketTest, Tag) {
  if (!net::CanGetTaggedBytes()) {
    // Report the test as skipped rather than silently passing.
    GTEST_SKIP() << "Skipping test - GetTaggedBytes unsupported.";
  }

  net::TestCompletionCallback callback;
  net::UDPServerSocket server(nullptr, net::NetLogSource());
  ASSERT_THAT(
      server.Listen(net::IPEndPoint(net::IPAddress::IPv4Localhost(), 0)),
      IsOk());
  net::IPEndPoint server_address;
  ASSERT_THAT(server.GetLocalAddress(&server_address), IsOk());

  // Verify tag is properly set when ApplySocketTag is called before connecting.
  int32_t tag_val1 = 0x12345678;
  uint64_t old_traffic = net::GetTaggedBytes(tag_val1);
  net::SocketTag tag1(net::SocketTag::UNSET_UID, tag_val1);
  socket_->ApplySocketTag(tag1);
  int rv = socket_->ConnectAsync(server_address, callback.callback());

  ASSERT_EQ(rv, net::ERR_IO_PENDING);
  rv = callback.WaitForResult();
  EXPECT_EQ(rv, net::OK);
  SimpleReadAndWrite(&server);
  EXPECT_GT(net::GetTaggedBytes(tag_val1), old_traffic);

  // Verify socket can be retagged with a new value and the current process's
  // UID.
  int32_t tag_val2 = 0x87654321;
  old_traffic = net::GetTaggedBytes(tag_val2);
  net::SocketTag tag2(getuid(), tag_val2);
  socket_->ApplySocketTag(tag2);
  SimpleReadAndWrite(&server);
  EXPECT_GT(net::GetTaggedBytes(tag_val2), old_traffic);
}
#endif

}  // namespace network