chromium/chrome/services/sharing/nearby/platform/wifi_direct_socket_unittest.cc

// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifdef UNSAFE_BUFFERS_BUILD
// TODO(crbug.com/40285824): Remove this and convert code to safer constructs.
#pragma allow_unsafe_buffers
#endif

#include "chrome/services/sharing/nearby/platform/wifi_direct_socket.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "base/task/thread_pool.h"
#include "base/test/metrics/histogram_tester.h"
#include "base/test/task_environment.h"
#include "net/base/io_buffer.h"
#include "net/base/net_errors.h"
#include "net/socket/stream_socket.h"
#include "net/socket/tcp_socket.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace {

// Histogram names checked by the stream tests below. The tests record a
// successful operation in bucket 1 and a failed one in bucket 0.
constexpr char kReadResultMetricName[] =
    "Nearby.Connections.WifiDirect.Socket.Read.Result";
constexpr char kWriteResultMetricName[] =
    "Nearby.Connections.WifiDirect.Socket.Write.Result";

// Small payload shared by the read and write round-trip tests.
const std::vector<uint8_t> kTestData{0x01, 0x02, 0x03, 0x04};

// Posts `task` to a fresh MayBlock() sequenced thread-pool task runner and
// spins a RunLoop on the calling thread until the task's reply arrives, so
// the caller's thread keeps processing tasks while `task` runs elsewhere.
void RunOnTaskRunner(base::OnceClosure task) {
  auto worker_runner =
      base::ThreadPool::CreateSequencedTaskRunner({base::MayBlock()});
  base::RunLoop wait_for_reply;
  worker_runner->PostTaskAndReply(FROM_HERE, std::move(task),
                                  wait_for_reply.QuitClosure());
  wait_for_reply.Run();
}

// Converts raw bytes into a nearby::ByteArray by way of a std::string that
// holds the same byte sequence.
nearby::ByteArray ToByteArray(const std::vector<uint8_t>& expected_data) {
  std::string as_string(expected_data.begin(), expected_data.end());
  return nearby::ByteArray(std::move(as_string));
}

class FakeStreamSocket : public net::StreamSocket {
 public:
  ~FakeStreamSocket() override = default;

  const std::vector<uint8_t>& GetWriteData() { return write_data_; }
  void SetReadData(std::vector<uint8_t> data) { data_to_read_ = data; }
  void SetReadError(int error) { read_error_ = error; }

  // net::Socket
  int Read(net::IOBuffer* buf,
           int buf_len,
           net::CompletionOnceCallback callback) override {
    if (read_error_) {
      return read_error_.value();
    }

    auto bytes_to_write = std::max(uint(buf_len), uint(data_to_read_.size()));
    std::copy(data_to_read_.data(), data_to_read_.data() + bytes_to_write,
              buf->data());
    return bytes_to_write;
  }

  int ReadIfReady(net::IOBuffer* buf,
                  int buf_len,
                  net::CompletionOnceCallback callback) override {
    return net::ERR_NOT_IMPLEMENTED;
  }

  int CancelReadIfReady() override { return net::ERR_NOT_IMPLEMENTED; }

  int Write(
      net::IOBuffer* buf,
      int buf_len,
      net::CompletionOnceCallback callback,
      const net::NetworkTrafficAnnotationTag& traffic_annotation) override {
    write_data_ = std::vector(buf->bytes(), buf->bytes() + buf_len);
    return buf_len;
  }

  int SetReceiveBufferSize(int32_t size) override {
    return net::ERR_NOT_IMPLEMENTED;
  }

  int SetSendBufferSize(int32_t size) override {
    return net::ERR_NOT_IMPLEMENTED;
  }

  // net::StreamSocket
  void SetBeforeConnectCallback(
      const BeforeConnectCallback& before_connect_callback) override {}

  int Connect(net::CompletionOnceCallback callback) override {
    return net::ERR_NOT_IMPLEMENTED;
  }

  void Disconnect() override {}

  bool IsConnected() const override { return false; }

  bool IsConnectedAndIdle() const override { return false; }

  int GetLocalAddress(net::IPEndPoint* address) const override {
    return net::ERR_NOT_IMPLEMENTED;
  }

  int GetPeerAddress(net::IPEndPoint* address) const override {
    return net::ERR_NOT_IMPLEMENTED;
  }

  bool WasEverUsed() const override { return false; }

  net::NextProto GetNegotiatedProtocol() const override {
    return net::NextProto::kProtoUnknown;
  }

  bool GetSSLInfo(net::SSLInfo* ssl_info) override { return false; }

  void ApplySocketTag(const net::SocketTag& tag) override {}

  const net::NetLogWithSource& NetLog() const override { return net_log_; }

  int64_t GetTotalReceivedBytes() const override { return 0; }

 private:
  net::NetLogWithSource net_log_;
  std::vector<uint8_t> write_data_;
  std::vector<uint8_t> data_to_read_;
  std::optional<int> read_error_;
};

}  // namespace

namespace nearby::chrome {

// Fixture that owns a WifiDirectSocket backed by a FakeStreamSocket and a
// dedicated IO-pump thread whose task runner is handed to the socket.
class WifiDirectSocketTest : public ::testing::Test {
 public:
  // ::testing::Test
  void SetUp() override {
    io_thread_ = std::make_unique<base::Thread>("wifi-direct-socket-test");
    base::Thread::Options io_options(base::MessagePumpType::IO, 0);
    io_thread_->StartWithOptions(std::move(io_options));

    // The subject under test wraps a fake stream socket.
    socket_ = std::make_unique<WifiDirectSocket>(
        io_thread_->task_runner(), std::make_unique<FakeStreamSocket>());
  }

  // ::testing::Test
  void TearDown() override {
    // Destroy the socket before stopping the thread it was given.
    socket_.reset();
    io_thread_->Stop();
  }

  WifiDirectSocket* socket() { return socket_.get(); }

 private:
  base::test::TaskEnvironment task_environment_;
  std::unique_ptr<base::Thread> io_thread_;
  std::unique_ptr<WifiDirectSocket> socket_;
};

// Closing a freshly created socket reports success.
TEST_F(WifiDirectSocketTest, Close) {
  WifiDirectSocket* const subject = socket();
  RunOnTaskRunner(base::BindOnce(
      [](WifiDirectSocket* s) { EXPECT_TRUE(s->Close()); }, subject));
}

// Only the first Close() succeeds; a repeated Close() reports failure.
TEST_F(WifiDirectSocketTest, Close_MultipleCalls) {
  WifiDirectSocket* const subject = socket();
  RunOnTaskRunner(base::BindOnce(
      [](WifiDirectSocket* s) {
        const bool first_close = s->Close();
        EXPECT_TRUE(first_close);
        const bool second_close = s->Close();
        EXPECT_FALSE(second_close);
      },
      subject));
}

// SocketInputStream
// Fixture for SocketInputStream. The stream reads from a FakeStreamSocket and
// is bound to the IO-type main thread task runner of the task environment.
class SocketInputStreamTest : public ::testing::Test {
 public:
  // ::testing::Test
  void SetUp() override {
    auto main_runner = task_environment_.GetMainThreadTaskRunner();
    stream_socket_ = std::make_unique<FakeStreamSocket>();
    input_stream_ = std::make_unique<SocketInputStream>(stream_socket_.get(),
                                                        std::move(main_runner));
  }

  SocketInputStream* input_stream() { return input_stream_.get(); }
  FakeStreamSocket* stream_socket() { return stream_socket_.get(); }
  base::HistogramTester& histogram_tester() { return histogram_tester_; }

 private:
  // Declaration order matters: `input_stream_` holds a raw pointer to
  // `stream_socket_`, so it must be destroyed first.
  base::test::TaskEnvironment task_environment_{
      base::test::TaskEnvironment::MainThreadType::IO};
  std::unique_ptr<FakeStreamSocket> stream_socket_;
  std::unique_ptr<SocketInputStream> input_stream_;
  base::HistogramTester histogram_tester_;
};

// A successful read returns the queued bytes and records a "true" sample.
TEST_F(SocketInputStreamTest, Read) {
  stream_socket()->SetReadData(kTestData);
  histogram_tester().ExpectTotalCount(kReadResultMetricName, 0);

  SocketInputStream* const stream = input_stream();
  RunOnTaskRunner(base::BindOnce(
      [](SocketInputStream* in) {
        // Read() waits on sync primitives, which requires this allowance.
        base::ScopedAllowBaseSyncPrimitivesForTesting allow;
        const auto expected = ToByteArray(kTestData);
        auto outcome = in->Read(kTestData.size());
        EXPECT_TRUE(outcome.ok());
        EXPECT_EQ(outcome.GetResult(), expected);
      },
      stream));

  histogram_tester().ExpectTotalCount(kReadResultMetricName, 1);
  histogram_tester().ExpectBucketCount(kReadResultMetricName,
                                       /*bucket:true=*/1, 1);
}

// A socket-level read error surfaces as kFailed and records a "false" sample.
TEST_F(SocketInputStreamTest, Read_Error) {
  stream_socket()->SetReadError(net::ERR_FAILED);
  histogram_tester().ExpectTotalCount(kReadResultMetricName, 0);

  SocketInputStream* const stream = input_stream();
  RunOnTaskRunner(base::BindOnce(
      [](SocketInputStream* in) {
        base::ScopedAllowBaseSyncPrimitivesForTesting allow;
        auto outcome = in->Read(1);
        EXPECT_FALSE(outcome.ok());
        const Exception expected_failure{Exception::kFailed};
        EXPECT_EQ(outcome.GetException(), expected_failure);
      },
      stream));

  histogram_tester().ExpectTotalCount(kReadResultMetricName, 1);
  histogram_tester().ExpectBucketCount(kReadResultMetricName,
                                       /*bucket:false=*/0, 1);
}

// Reading after Close() fails with kFailed even though data is queued, and
// records a "false" sample.
TEST_F(SocketInputStreamTest, Read_AfterClose) {
  stream_socket()->SetReadData(kTestData);
  histogram_tester().ExpectTotalCount(kReadResultMetricName, 0);

  SocketInputStream* const stream = input_stream();
  RunOnTaskRunner(base::BindOnce(
      [](SocketInputStream* in) {
        base::ScopedAllowBaseSyncPrimitivesForTesting allow;
        in->Close();
        auto outcome = in->Read(1);
        EXPECT_FALSE(outcome.ok());
        const Exception expected_failure{Exception::kFailed};
        EXPECT_EQ(outcome.GetException(), expected_failure);
      },
      stream));

  histogram_tester().ExpectTotalCount(kReadResultMetricName, 1);
  histogram_tester().ExpectBucketCount(kReadResultMetricName,
                                       /*bucket:false=*/0, 1);
}

// SocketOutputStream
// Fixture for SocketOutputStream. The stream writes to a FakeStreamSocket and
// is bound to the IO-type main thread task runner of the task environment.
class SocketOutputStreamTest : public ::testing::Test {
 public:
  // ::testing::Test
  void SetUp() override {
    auto main_runner = task_environment_.GetMainThreadTaskRunner();
    stream_socket_ = std::make_unique<FakeStreamSocket>();
    output_stream_ = std::make_unique<SocketOutputStream>(
        stream_socket_.get(), std::move(main_runner));
  }

  SocketOutputStream* output_stream() { return output_stream_.get(); }
  FakeStreamSocket* stream_socket() { return stream_socket_.get(); }
  base::HistogramTester& histogram_tester() { return histogram_tester_; }

 private:
  // Declaration order matters: `output_stream_` holds a raw pointer to
  // `stream_socket_`, so it must be destroyed first.
  base::test::TaskEnvironment task_environment_{
      base::test::TaskEnvironment::MainThreadType::IO};
  std::unique_ptr<FakeStreamSocket> stream_socket_;
  std::unique_ptr<SocketOutputStream> output_stream_;
  base::HistogramTester histogram_tester_;
};

// A successful write passes the bytes through to the socket and records a
// "true" sample.
TEST_F(SocketOutputStreamTest, Write) {
  histogram_tester().ExpectTotalCount(kWriteResultMetricName, 0);

  SocketOutputStream* const stream = output_stream();
  FakeStreamSocket* const fake_socket = stream_socket();
  RunOnTaskRunner(base::BindOnce(
      [](SocketOutputStream* out, FakeStreamSocket* sink) {
        // Write() waits on sync primitives, which requires this allowance.
        base::ScopedAllowBaseSyncPrimitivesForTesting allow;

        auto status = out->Write(ToByteArray(kTestData));
        EXPECT_TRUE(status.Ok());
        EXPECT_EQ(sink->GetWriteData(), kTestData);
      },
      stream, fake_socket));

  histogram_tester().ExpectTotalCount(kWriteResultMetricName, 1);
  histogram_tester().ExpectBucketCount(kWriteResultMetricName,
                                       /*bucket:true=*/1, 1);
}

// Writing after Close() fails without touching the socket and records a
// "false" sample.
TEST_F(SocketOutputStreamTest, Write_AfterClose) {
  // Establish a clean baseline, as the sibling metric tests do, so that the
  // final count is attributable to this Write() alone.
  histogram_tester().ExpectTotalCount(kWriteResultMetricName, 0);

  RunOnTaskRunner(base::BindOnce(
      [](SocketOutputStream* output_stream, FakeStreamSocket* socket) {
        base::ScopedAllowBaseSyncPrimitivesForTesting allow;
        output_stream->Close();
        auto result = output_stream->Write(ToByteArray(kTestData));
        EXPECT_FALSE(result.Ok());
        // Nothing may reach the underlying socket once the stream is closed.
        EXPECT_TRUE(socket->GetWriteData().empty());
      },
      output_stream(), stream_socket()));

  histogram_tester().ExpectTotalCount(kWriteResultMetricName, 1);
  histogram_tester().ExpectBucketCount(kWriteResultMetricName,
                                       /*bucket:false=*/0, 1);
}

}  // namespace nearby::chrome