chromium/chromecast/net/fake_stream_socket.cc

// Copyright 2016 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chromecast/net/fake_stream_socket.h"

#include <algorithm>
#include <cstring>
#include <vector>

#include "base/check_op.h"
#include "base/functional/bind.h"
#include "base/functional/callback_helpers.h"
#include "base/location.h"
#include "base/memory/weak_ptr.h"
#include "base/task/sequenced_task_runner.h"
#include "net/base/io_buffer.h"
#include "net/base/net_errors.h"
#include "net/socket/next_proto.h"
#include "net/traffic_annotation/network_traffic_annotation.h"

namespace chromecast {

// Buffer used for communication between two FakeStreamSockets.
class SocketBuffer {
 public:
  SocketBuffer() : pending_read_data_(nullptr), pending_read_len_(0) {}

  SocketBuffer(const SocketBuffer&) = delete;
  SocketBuffer& operator=(const SocketBuffer&) = delete;

  ~SocketBuffer() {}

  // Reads |len| bytes from the buffer and writes it to |data|. Returns the
  // number of bytes written to |data| if the read is synchronous, or
  // ERR_IO_PENDING if the read is asynchronous. If the read is asynchronous,
  // |callback| is called with the number of bytes written to |data| once the
  // data has been written.
  int Read(char* data, size_t len, net::CompletionOnceCallback callback) {
    DCHECK(data);
    DCHECK_GT(len, 0u);
    DCHECK(callback);
    if (data_.empty()) {
      if (eos_) {
        return 0;
      }
      pending_read_data_ = data;
      pending_read_len_ = len;
      pending_read_callback_ = std::move(callback);
      return net::ERR_IO_PENDING;
    }
    return ReadInternal(data, len);
  }

  // Writes |len| bytes from |data| to the buffer. The write is always completed
  // synchronously and all bytes are guaranteed to be written.
  void Write(const char* data, size_t len) {
    DCHECK(data);
    DCHECK_GT(len, 0u);
    data_.insert(data_.end(), data, data + len);
    if (!pending_read_callback_.is_null()) {
      int result = ReadInternal(pending_read_data_, pending_read_len_);
      pending_read_data_ = nullptr;
      pending_read_len_ = 0;
      PostReadCallback(std::move(pending_read_callback_), result);
    }
  }

  // Called when the remote end of the fake connection disconnects.
  void ReceiveEOS() {
    eos_ = true;
    if (pending_read_callback_ && data_.empty()) {
      PostReadCallback(std::move(pending_read_callback_), 0);
    }
  }

 private:
  int ReadInternal(char* data, size_t len) {
    DCHECK(data);
    DCHECK_GT(len, 0u);
    len = std::min(len, data_.size());
    std::memcpy(data, data_.data(), len);
    data_.erase(data_.begin(), data_.begin() + len);
    return len;
  }

  void PostReadCallback(net::CompletionOnceCallback callback, int result) {
    base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
        FROM_HERE, base::BindOnce(&SocketBuffer::CallReadCallback,
                                  weak_factory_.GetWeakPtr(),
                                  std::move(callback), result));
  }

  // Need a member function to asynchronously call the read callback, so we
  // can use weak ptr.
  void CallReadCallback(net::CompletionOnceCallback callback, int result) {
    std::move(callback).Run(result);
  }

  std::vector<char> data_;
  char* pending_read_data_;
  size_t pending_read_len_;
  net::CompletionOnceCallback pending_read_callback_;
  bool eos_ = false;
  base::WeakPtrFactory<SocketBuffer> weak_factory_{this};
};

// Creates a socket whose local address is a default-constructed
// net::IPEndPoint, by delegating to the address-taking constructor.
FakeStreamSocket::FakeStreamSocket() : FakeStreamSocket(net::IPEndPoint()) {}

// Creates a socket that reports |local_address| as its local address. The
// socket starts with its own empty SocketBuffer (which receives the data a
// peer writes to this socket) and with no peer; use SetPeer() to link it to
// another FakeStreamSocket.
FakeStreamSocket::FakeStreamSocket(const net::IPEndPoint& local_address)
    : local_address_(local_address),
      buffer_(std::make_unique<SocketBuffer>()),
      peer_(nullptr) {}

// On destruction, tells the peer (if one is still set) that this end of the
// fake connection has gone away.
FakeStreamSocket::~FakeStreamSocket() {
  if (!peer_) {
    // Never linked, or the peer already disconnected: nobody to notify.
    return;
  }
  peer_->RemoteDisconnected();
}

// Sets the socket that Write() delivers data to, that GetPeerAddress()
// reports, and that is notified when this socket is destroyed. |peer| must be
// non-null. Only this socket's side of the link is set here; |peer| itself is
// not modified.
void FakeStreamSocket::SetPeer(FakeStreamSocket* peer) {
  DCHECK(peer);
  peer_ = peer;
}

// Invoked by the peer's destructor when the other end of the fake connection
// goes away.
void FakeStreamSocket::RemoteDisconnected() {
  // Signal end-of-stream to our receive buffer; this also completes a pending
  // read, if there is one.
  buffer_->ReceiveEOS();
  // Forget the peer so that it is no longer written to or notified.
  peer_ = nullptr;
}

// Enables or disables "bad sender" mode. While enabled, Write() sends only
// part of the requested data (buf_len / 2 + 1 bytes, capped at buf_len) and
// returns that smaller count.
void FakeStreamSocket::SetBadSenderMode(bool bad_sender) {
  bad_sender_mode_ = bad_sender;
}

// Reads up to |buf_len| bytes that were written to this socket into |buf|.
// Returns the number of bytes copied if data was already buffered, 0 if the
// buffer is empty and the remote end has disconnected, or net::ERR_IO_PENDING
// if no data is available yet; in the pending case |callback| is run
// asynchronously with the result. |buf_len| must be positive.
//
// NOTE: only the raw data pointer of |buf| is passed on and no reference to
// |buf| is taken here, so the caller has to keep |buf| alive until a pending
// read completes.
int FakeStreamSocket::Read(net::IOBuffer* buf,
                           int buf_len,
                           net::CompletionOnceCallback callback) {
  DCHECK(buf);
  // Validate while the length is still signed: after conversion to size_t a
  // negative value would wrap to a huge length and slip past SocketBuffer's
  // "len > 0" check.
  DCHECK_GT(buf_len, 0);
  return buffer_->Read(buf->data(), static_cast<size_t>(buf_len),
                       std::move(callback));
}

// Writes data from |buf| into the peer's receive buffer. The write always
// completes synchronously, so the completion callback is never used.
// Returns net::ERR_SOCKET_NOT_CONNECTED if there is no peer; otherwise returns
// the number of bytes delivered, which is |buf_len| normally and a smaller
// partial amount in bad sender mode. |buf_len| must be positive.
int FakeStreamSocket::Write(
    net::IOBuffer* buf,
    int buf_len,
    net::CompletionOnceCallback /* callback */,
    const net::NetworkTrafficAnnotationTag& /*traffic_annotation*/) {
  DCHECK(buf);
  if (!peer_) {
    return net::ERR_SOCKET_NOT_CONNECTED;
  }
  // Validate while the length is still signed: after conversion to size_t a
  // negative value would wrap to a huge length and make the peer's buffer
  // read far beyond |buf|.
  DCHECK_GT(buf_len, 0);
  int amount_to_send = buf_len;
  if (bad_sender_mode_) {
    // Simulate a short write: send about half of the data, but at least one
    // byte and never more than requested.
    amount_to_send = std::min(buf_len, buf_len / 2 + 1);
  }
  peer_->buffer_->Write(buf->data(), static_cast<size_t>(amount_to_send));
  return amount_to_send;
}

// No-op for the fake socket: the requested size is ignored and net::OK is
// always returned.
int FakeStreamSocket::SetReceiveBufferSize(int32_t /* size */) {
  return net::OK;
}

// No-op for the fake socket: the requested size is ignored and net::OK is
// always returned.
int FakeStreamSocket::SetSendBufferSize(int32_t /* size */) {
  return net::OK;
}

// Always reports success synchronously by returning net::OK. The completion
// callback is ignored and never run, and no peer is set up here.
int FakeStreamSocket::Connect(net::CompletionOnceCallback /* callback */) {
  return net::OK;
}

// No-op: the peer is not notified and this socket's state is left unchanged.
void FakeStreamSocket::Disconnect() {}

// Always returns true. Whether a peer is currently set is not taken into
// account.
bool FakeStreamSocket::IsConnected() const {
  return true;
}

// Always returns false; the fake does not track idleness.
bool FakeStreamSocket::IsConnectedAndIdle() const {
  return false;
}

// Stores the peer's local address in |address| and returns net::OK. If no
// peer is set, |address| is left untouched and
// net::ERR_SOCKET_NOT_CONNECTED is returned. |address| must be non-null.
int FakeStreamSocket::GetPeerAddress(net::IPEndPoint* address) const {
  DCHECK(address);
  if (peer_) {
    *address = peer_->local_address_;
    return net::OK;
  }
  return net::ERR_SOCKET_NOT_CONNECTED;
}

// Copies this socket's local address (the one supplied at construction) into
// |address| and returns net::OK. |address| must be non-null.
int FakeStreamSocket::GetLocalAddress(net::IPEndPoint* address) const {
  DCHECK(address);
  *address = local_address_;
  return net::OK;
}

// Returns a reference to this socket's |net_log_| member.
const net::NetLogWithSource& FakeStreamSocket::NetLog() const {
  return net_log_;
}

// Always returns false; reads and writes on the fake are not tracked.
bool FakeStreamSocket::WasEverUsed() const {
  return false;
}

// Always returns net::kProtoUnknown.
net::NextProto FakeStreamSocket::GetNegotiatedProtocol() const {
  return net::kProtoUnknown;
}

// Always returns false and leaves the output argument untouched.
bool FakeStreamSocket::GetSSLInfo(net::SSLInfo* /* ssl_info */) {
  return false;
}

// Always returns 0; the number of bytes received is not tracked by the fake.
int64_t FakeStreamSocket::GetTotalReceivedBytes() const {
  return 0;
}

void FakeStreamSocket::ApplySocketTag(const net::SocketTag& tag) {}

}  // namespace chromecast