#ifdef UNSAFE_BUFFERS_BUILD
#pragma allow_unsafe_buffers
#endif
#include "net/socket/tcp_socket.h"
#include <stddef.h>
#include <string.h>
#include <memory>
#include <string>
#include <vector>
#include "base/functional/bind.h"
#include "base/memory/ref_counted.h"
#include "base/test/bind.h"
#include "base/time/time.h"
#include "build/build_config.h"
#include "net/base/address_list.h"
#include "net/base/io_buffer.h"
#include "net/base/ip_endpoint.h"
#include "net/base/net_errors.h"
#include "net/base/sockaddr_storage.h"
#include "net/base/sys_addrinfo.h"
#include "net/base/test_completion_callback.h"
#include "net/log/net_log_source.h"
#include "net/socket/socket_descriptor.h"
#include "net/socket/socket_performance_watcher.h"
#include "net/socket/socket_test_util.h"
#include "net/socket/tcp_client_socket.h"
#include "net/test/embedded_test_server/embedded_test_server.h"
#include "net/test/gtest_util.h"
#include "net/test/test_with_task_environment.h"
#include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "testing/platform_test.h"
#if BUILDFLAG(IS_ANDROID)
#include "base/android/build_info.h"
#include "net/android/network_change_notifier_factory_android.h"
#include "net/base/network_change_notifier.h"
#endif
#if BUILDFLAG(IS_WIN)
#include <winsock2.h>
#else
#include <sys/socket.h>
#endif
IsError;
IsOk;
namespace net {
namespace {
class IOBufferWithDestructionCallback : public IOBufferWithSize { … };
class TestSocketPerformanceWatcher : public SocketPerformanceWatcher { … };
const int kListenBacklog = …;
class TCPSocketTest : public PlatformTest, public WithTaskEnvironment { … };
TEST_F(TCPSocketTest, Accept) { … }
TEST_F(TCPSocketTest, AcceptAsync) { … }
TEST_F(TCPSocketTest, AdoptConnectedSocket) { … }
TEST_F(TCPSocketTest, AcceptForAdoptedUnconnectedSocket) { … }
TEST_F(TCPSocketTest, Accept2Connections) { … }
TEST_F(TCPSocketTest, AcceptIPv6) { … }
TEST_F(TCPSocketTest, ReadWrite) { … }
TEST_F(TCPSocketTest, DestroyWithPendingRead) { … }
TEST_F(TCPSocketTest, DestroyWithPendingWrite) { … }
TEST_F(TCPSocketTest, CancelPendingReadIfReady) { … }
TEST_F(TCPSocketTest, IsConnected) { … }
TEST_F(TCPSocketTest, BeforeConnectCallback) { … }
TEST_F(TCPSocketTest, BeforeConnectCallbackFails) { … }
TEST_F(TCPSocketTest, SetKeepAlive) { … }
TEST_F(TCPSocketTest, SetNoDelay) { … }
#if defined(TCP_INFO) || BUILDFLAG(IS_LINUX) || BUILDFLAG(IS_CHROMEOS)
TEST_F(TCPSocketTest, SPWNotInterested) { … }
TEST_F(TCPSocketTest, SPWNoAdvance) { … }
#endif
#if BUILDFLAG(IS_ANDROID)
// Verifies that a SocketTag applied before Connect() is charged for the
// connection traffic, and that re-tagging the connected socket redirects the
// bytes of each subsequent Write() to the newly applied tag.
TEST_F(TCPSocketTest, Tag) {
  if (!CanGetTaggedBytes()) {
    // Report the test as skipped rather than silently passing when per-tag
    // byte counters cannot be read on this device.
    GTEST_SKIP() << "Skipping test - GetTaggedBytes unsupported.";
  }

  EmbeddedTestServer test_server;
  test_server.AddDefaultHandlers(base::FilePath());
  ASSERT_TRUE(test_server.Start());

  AddressList addr_list;
  ASSERT_TRUE(test_server.GetAddressList(&addr_list));
  // Every following step needs an open socket, so stop here on failure.
  ASSERT_THAT(socket_.Open(addr_list[0].GetFamily()), IsOk());

  // Tag before connecting: the connect traffic should be counted under tag1.
  int32_t tag_val1 = 0x12345678;
  uint64_t old_traffic = GetTaggedBytes(tag_val1);
  SocketTag tag1(SocketTag::UNSET_UID, tag_val1);
  socket_.ApplySocketTag(tag1);
  TestCompletionCallback connect_callback;
  int connect_result =
      socket_.Connect(addr_list[0], connect_callback.callback());
  ASSERT_THAT(connect_callback.GetResult(connect_result), IsOk());
  EXPECT_GT(GetTaggedBytes(tag_val1), old_traffic);

  // Re-tag the connected socket with an explicit UID; the next write should
  // be counted under tag2.
  int32_t tag_val2 = 0x87654321;
  old_traffic = GetTaggedBytes(tag_val2);
  SocketTag tag2(getuid(), tag_val2);
  socket_.ApplySocketTag(tag2);
  const char kRequest1[] = "GET / HTTP/1.0";
  scoped_refptr<IOBuffer> write_buffer1 =
      base::MakeRefCounted<StringIOBuffer>(kRequest1);
  TestCompletionCallback write_callback1;
  // Write() may complete asynchronously (ERR_IO_PENDING); wait for the final
  // result so the byte counters are sampled only after the data was sent.
  EXPECT_EQ(write_callback1.GetResult(socket_.Write(
                write_buffer1.get(), strlen(kRequest1),
                write_callback1.callback(), TRAFFIC_ANNOTATION_FOR_TESTS)),
            static_cast<int>(strlen(kRequest1)));
  EXPECT_GT(GetTaggedBytes(tag_val2), old_traffic);

  // Switch back to tag1 and check that the next write is counted there again.
  old_traffic = GetTaggedBytes(tag_val1);
  socket_.ApplySocketTag(tag1);
  const char kRequest2[] = "\n\n";
  scoped_refptr<IOBuffer> write_buffer2 =
      base::MakeRefCounted<StringIOBuffer>(kRequest2);
  TestCompletionCallback write_callback2;
  EXPECT_EQ(write_callback2.GetResult(socket_.Write(
                write_buffer2.get(), strlen(kRequest2),
                write_callback2.callback(), TRAFFIC_ANNOTATION_FOR_TESTS)),
            static_cast<int>(strlen(kRequest2)));
  EXPECT_GT(GetTaggedBytes(tag_val1), old_traffic);

  socket_.Close();
}
// Verifies that tags applied only after the socket has connected still take
// effect: each Write() is counted under whichever tag was most recently
// applied.
TEST_F(TCPSocketTest, TagAfterConnect) {
  if (!CanGetTaggedBytes()) {
    // Report the test as skipped rather than silently passing when per-tag
    // byte counters cannot be read on this device.
    GTEST_SKIP() << "Skipping test - GetTaggedBytes unsupported.";
  }

  EmbeddedTestServer test_server;
  test_server.AddDefaultHandlers(base::FilePath());
  ASSERT_TRUE(test_server.Start());

  AddressList addr_list;
  ASSERT_TRUE(test_server.GetAddressList(&addr_list));
  // Every following step needs an open, connected socket, so stop on failure.
  ASSERT_THAT(socket_.Open(addr_list[0].GetFamily()), IsOk());
  TestCompletionCallback connect_callback;
  int connect_result =
      socket_.Connect(addr_list[0], connect_callback.callback());
  ASSERT_THAT(connect_callback.GetResult(connect_result), IsOk());

  // First tag, with an explicit UID, applied to the already-connected socket.
  int32_t tag_val2 = 0x87654321;
  uint64_t old_traffic = GetTaggedBytes(tag_val2);
  SocketTag tag2(getuid(), tag_val2);
  socket_.ApplySocketTag(tag2);
  const char kRequest1[] = "GET / HTTP/1.0";
  scoped_refptr<IOBuffer> write_buffer1 =
      base::MakeRefCounted<StringIOBuffer>(kRequest1);
  TestCompletionCallback write_callback1;
  // Write() may complete asynchronously (ERR_IO_PENDING); wait for the final
  // result so the byte counters are sampled only after the data was sent.
  EXPECT_EQ(write_callback1.GetResult(socket_.Write(
                write_buffer1.get(), strlen(kRequest1),
                write_callback1.callback(), TRAFFIC_ANNOTATION_FOR_TESTS)),
            static_cast<int>(strlen(kRequest1)));
  EXPECT_GT(GetTaggedBytes(tag_val2), old_traffic);

  // Re-tag with an unset UID; the next write should be counted under tag1.
  int32_t tag_val1 = 0x12345678;
  old_traffic = GetTaggedBytes(tag_val1);
  SocketTag tag1(SocketTag::UNSET_UID, tag_val1);
  socket_.ApplySocketTag(tag1);
  const char kRequest2[] = "\n\n";
  scoped_refptr<IOBuffer> write_buffer2 =
      base::MakeRefCounted<StringIOBuffer>(kRequest2);
  TestCompletionCallback write_callback2;
  EXPECT_EQ(write_callback2.GetResult(socket_.Write(
                write_buffer2.get(), strlen(kRequest2),
                write_callback2.callback(), TRAFFIC_ANNOTATION_FOR_TESTS)),
            static_cast<int>(strlen(kRequest2)));
  EXPECT_GT(GetTaggedBytes(tag_val1), old_traffic);

  socket_.Close();
}
// Checks that TCPClientSocket::Bind() respects the network handle passed at
// construction time. A handle that does not name a usable network must make
// Bind() fail with a real error (not ERR_NOT_IMPLEMENTED). When a default
// network exists, binding on it must succeed.
TEST_F(TCPSocketTest, BindToNetwork) {
  // Locals are destroyed in reverse declaration order, so the notifier is torn
  // down before the DisableForTest guard and the factory that produced it.
  NetworkChangeNotifierFactoryAndroid notifier_factory;
  NetworkChangeNotifier::DisableForTest disable_global_notifier;
  std::unique_ptr<NetworkChangeNotifier> notifier(
      notifier_factory.CreateInstance());
  if (!NetworkChangeNotifier::AreNetworkHandlesSupported()) {
    GTEST_SKIP() << "Network handles are required to test BindToNetwork.";
  }

  // Presumably an arbitrary handle that matches no real network.
  const handles::NetworkHandle bogus_handle = 65536;
  const IPEndPoint loopback_any_port(IPAddress::IPv4Localhost(), 0);

  TCPClientSocket bogus_network_socket(local_address_list(), nullptr, nullptr,
                                       nullptr, NetLogSource(), bogus_handle);
  const int bogus_bind_result = bogus_network_socket.Bind(loopback_any_port);
  EXPECT_NE(OK, bogus_bind_result);
  EXPECT_NE(ERR_NOT_IMPLEMENTED, bogus_bind_result);

  const handles::NetworkHandle default_handle =
      NetworkChangeNotifier::GetDefaultNetwork();
  if (default_handle == handles::kInvalidNetworkHandle) {
    // No default network is available, so there is nothing valid to bind to.
    return;
  }

  TCPClientSocket default_network_socket(local_address_list(), nullptr,
                                         nullptr, nullptr, NetLogSource(),
                                         default_handle);
  EXPECT_EQ(OK, default_network_socket.Bind(loopback_any_port));
}
#endif
}
}