// Copyright 2017 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chromecast/net/small_message_socket.h"
#include <stdint.h>
#include <string.h>
#include <limits>
#include <utility>
#include "base/check_op.h"
#include "base/containers/span_writer.h"
#include "base/functional/bind.h"
#include "base/functional/callback_helpers.h"
#include "base/location.h"
#include "base/numerics/byte_conversions.h"
#include "base/task/sequenced_task_runner.h"
#include "chromecast/net/io_buffer_pool.h"
#include "net/base/net_errors.h"
#include "net/socket/socket.h"
namespace chromecast {
namespace {

// Upper bound on the number of consecutive synchronous reads/writes performed
// in one loop before the remaining work is reposted to the task runner, so that
// other tasks get a chance to run.
constexpr int kMaxIOLoop = 5;

// Initial capacity, in bytes, of the growable read and write storage buffers.
constexpr int kDefaultBufferSize = 2048;

// Largest value of the 2-byte size header (0xFFFF). Messages whose size is at
// least this value are encoded as this escape value followed by a 4-byte size.
constexpr size_t kMax2ByteSize = std::numeric_limits<uint16_t>::max();

}  // namespace
// Starts out wrapping no buffer; SetUnderlyingBuffer() must be called before
// use. Accessors such as StartOfBuffer() and DidConsume() CHECK that a buffer
// has been set.
SmallMessageSocket::BufferWrapper::BufferWrapper() = default;
// Destroys the wrapper. `buffer_` is released by member destruction after this
// body runs, so `data_` is cleared first.
SmallMessageSocket::BufferWrapper::~BufferWrapper() {
  // The `data_` pointer in the base class is pointing into the buffer in the
  // `buffer_` field. Stop pointing into the field before its buffer is freed.
  data_ = nullptr;
}
// Makes this wrapper expose the first `capacity` bytes of `buffer`, replacing
// any previously wrapped buffer. Afterwards the whole capacity is available
// (size() == capacity) and nothing is marked as used.
void SmallMessageSocket::BufferWrapper::SetUnderlyingBuffer(
    scoped_refptr<IOBuffer> buffer,
    size_t capacity) {
  CHECK_LE(capacity, static_cast<size_t>(buffer->size()));
  // Reset the bookkeeping for the new window.
  capacity_ = capacity;
  size_ = capacity_;
  used_ = 0;
  // Take ownership first, then point the base-class window at its start.
  buffer_ = std::move(buffer);
  data_ = buffer_->data();
}
// Transfers ownership of the wrapped buffer to the caller. The wrapper no
// longer refers to that storage afterwards; SetUnderlyingBuffer() must be
// called before the wrapper is used for I/O again.
scoped_refptr<net::IOBuffer>
SmallMessageSocket::BufferWrapper::TakeUnderlyingBuffer() {
  // Stop pointing into the buffer before handing off ownership, as
  // ClearUnderlyingBuffer() and the destructor do. Otherwise `data_` would
  // dangle if the caller releases the buffer before this wrapper is re-armed.
  data_ = nullptr;
  return std::move(buffer_);
}
// Drops this wrapper's reference to its buffer. `data_` is cleared first so it
// never points into storage that may be freed by the reset.
void SmallMessageSocket::BufferWrapper::ClearUnderlyingBuffer() {
  data_ = nullptr;
  buffer_.reset();
}
// Advances the current position by `bytes`. The remaining window
// (`data_`/`size_`) shrinks from the front and the used prefix grows by the
// same amount. CHECKs that `bytes` does not exceed what remains.
void SmallMessageSocket::BufferWrapper::DidConsume(size_t bytes) {
  CHECK(buffer_);
  CHECK_LE(bytes, static_cast<size_t>(size_));
  used_ += bytes;
  data_ += bytes;
  size_ -= bytes;
  // The window must always begin exactly `used_` bytes into the buffer.
  CHECK_EQ(data_, buffer_->data() + used_);
}
// Returns a pointer to the very beginning of the wrapped buffer, as opposed to
// the current position `data_`. CHECKs that a buffer is set.
char* SmallMessageSocket::BufferWrapper::StartOfBuffer() const {
  CHECK(buffer_);
  return buffer_->data();
}
// Returns the first `used_` bytes of the wrapped buffer: everything marked
// consumed via DidConsume() since the last SetUnderlyingBuffer(). CHECKs that
// a buffer is set.
base::span<const uint8_t> SmallMessageSocket::BufferWrapper::used_span() const {
  CHECK(buffer_);
  return base::span(buffer_->bytes(), used_);
}
// Wraps `socket` for exchanging size-prefixed messages and reports events to
// `delegate`, which must be non-null. Posted tasks go to the default sequenced
// task runner that is current at construction time. Both the read and write
// storage start with a capacity of kDefaultBufferSize bytes.
SmallMessageSocket::SmallMessageSocket(Delegate* delegate,
                                       std::unique_ptr<net::Socket> socket)
    : delegate_(delegate),
      socket_(std::move(socket)),
      task_runner_(base::SequencedTaskRunner::GetCurrentDefault()),
      write_storage_(base::MakeRefCounted<net::GrowableIOBuffer>()),
      write_buffer_(base::MakeRefCounted<BufferWrapper>()),
      read_storage_(base::MakeRefCounted<net::GrowableIOBuffer>()),
      read_buffer_(base::MakeRefCounted<BufferWrapper>()),
      weak_factory_(this) {
  DCHECK(delegate_);
  write_storage_->SetCapacity(kDefaultBufferSize);
  read_storage_->SetCapacity(kDefaultBufferSize);
}
// Defaulted; no explicit cleanup is performed here.
SmallMessageSocket::~SmallMessageSocket() = default;
// Switches reading to buffers drawn from `buffer_pool`.
// - If a pool is already in use, it is simply replaced.
// - Otherwise, data read so far is moved into a pool buffer immediately.
// - If a message is currently being delivered, that move is left to
//   HandleCompletedMessages() once OnMessage() returns.
void SmallMessageSocket::UseBufferPool(
    scoped_refptr<IOBufferPool> buffer_pool) {
  DCHECK(buffer_pool);
  const bool replacing_existing_pool = static_cast<bool>(buffer_pool_);
  buffer_pool_ = std::move(buffer_pool);
  // When replacing a pool, there is no need to copy data out of the existing
  // buffer: it remains valid until we are done using it.
  if (replacing_existing_pool || in_message_) {
    return;
  }
  ActivateBufferPool(read_storage_->span_before_offset());
}
// Moves `current_data` (bytes already read but not yet dispatched) into a new
// `read_buffer_` for pool-based operation.
// - If the data fits in a pool buffer, a pool buffer is used.
// - Otherwise a dedicated buffer of twice the data's size is allocated.
// The copied bytes are marked as consumed, so subsequent reads land after them.
void SmallMessageSocket::ActivateBufferPool(
    base::span<const uint8_t> current_data) {
  DCHECK(buffer_pool_);
  DCHECK(!in_message_);
  const size_t current_size = current_data.size();
  scoped_refptr<::net::IOBuffer> new_buffer;
  size_t new_buffer_size;
  if (current_size > buffer_pool_->buffer_size()) {
    // Too large for a pool buffer: allocate one of twice the pending size.
    new_buffer_size = current_size * 2;
    new_buffer =
        base::MakeRefCounted<::net::IOBufferWithSize>(new_buffer_size);
  } else {
    new_buffer = buffer_pool_->GetBuffer();
    CHECK(new_buffer);
    new_buffer_size = buffer_pool_->buffer_size();
  }
  base::as_writable_bytes(new_buffer->span()).copy_prefix_from(current_data);
  read_buffer_->SetUnderlyingBuffer(std::move(new_buffer), new_buffer_size);
  read_buffer_->DidConsume(current_size);
}
void SmallMessageSocket::RemoveBufferPool() {
if (!buffer_pool_) {
return;
}
if (static_cast<size_t>(read_storage_->capacity()) < read_buffer_->used()) {
read_storage_->SetCapacity(read_buffer_->used());
}
base::span<const uint8_t> used_span = read_buffer_->used_span();
read_storage_->everything().copy_prefix_from(used_span);
read_storage_->set_offset(used_span.size());
buffer_pool_.reset();
}
// Reserves write space for a message of `message_size` bytes and writes its
// size header. Returns where the caller should place the message body.
// Returns nullptr if a previous send still has unwritten bytes; in that case
// the sender is marked as blocked, so the delegate's OnSendUnblocked() runs
// once the pending write drains.
void* SmallMessageSocket::PrepareSend(size_t message_size) {
  if (write_buffer_->size()) {
    send_blocked_ = true;
    return nullptr;
  }
  const size_t header_size = SizeDataBytes(message_size);
  const size_t total_size = header_size + message_size;
  write_storage_->set_offset(0);
  // TODO(lethalantidote): Remove cast once capacity converted to size_t.
  const size_t current_capacity =
      static_cast<size_t>(write_storage_->capacity());
  if (current_capacity < total_size) {
    write_storage_->SetCapacity(total_size);
  }
  write_buffer_->SetUnderlyingBuffer(write_storage_, total_size);
  auto out = base::as_writable_bytes(write_buffer_->span());
  WriteSizeData(out, message_size);
  // The body goes immediately after the header.
  return out.subspan(header_size).data();
}
// Starts sending the first `size` bytes of `data` as-is. No size header is
// added here, so presumably the caller has already written one (e.g. via
// WriteSizeData()) — verify against callers. Returns false, marking the sender
// as blocked, if a previous send still has unwritten bytes.
bool SmallMessageSocket::SendBuffer(scoped_refptr<net::IOBuffer> data,
                                    size_t size) {
  const bool write_in_progress = write_buffer_->size() != 0;
  if (write_in_progress) {
    send_blocked_ = true;
    return false;
  }
  write_buffer_->SetUnderlyingBuffer(std::move(data), size);
  Send();
  return true;
}
// static
// Returns the length of the size header used for a message of `message_size`
// bytes:
// - 2 bytes for sizes below 0xFFFF;
// - otherwise 6 bytes (the 0xFFFF escape value plus a 32-bit size).
size_t SmallMessageSocket::SizeDataBytes(size_t message_size) {
  if (message_size >= kMax2ByteSize) {
    return sizeof(uint16_t) + sizeof(uint32_t);
  }
  return sizeof(uint16_t);
}
// static
// Writes the size header for `message_size` at the start of `buf` and returns
// the number of bytes written (matching SizeDataBytes()). Sizes of 0xFFFF or
// more are written as the escape value 0xFFFF followed by the size as a
// big-endian 32-bit integer; smaller sizes as a big-endian 16-bit integer.
size_t SmallMessageSocket::WriteSizeData(base::span<uint8_t> buf,
                                         size_t message_size) {
  base::SpanWriter writer(buf);
  const bool needs_long_form = message_size >= kMax2ByteSize;
  if (needs_long_form) {
    writer.WriteU16BigEndian(base::checked_cast<uint16_t>(kMax2ByteSize));
    writer.WriteU32BigEndian(
        // NOTE: This may truncate the message_size.
        static_cast<uint32_t>(message_size));
  } else {
    writer.WriteU16BigEndian(static_cast<uint16_t>(message_size));
  }
  const size_t bytes_written = buf.size() - writer.remaining();
  return bytes_written;
}
// Writes as much of `write_buffer_` as the socket will accept.
// - Performs at most kMaxIOLoop synchronous writes, then reposts itself so
//   other tasks can run.
// - Asynchronous completions continue in OnWriteComplete().
void SmallMessageSocket::Send() {
  int attempts_left = kMaxIOLoop;
  while (attempts_left-- > 0) {
    auto on_complete = base::BindOnce(&SmallMessageSocket::OnWriteComplete,
                                      base::Unretained(this));
    const int result =
        socket_->Write(write_buffer_.get(), write_buffer_->size(),
                       std::move(on_complete), MISSING_TRAFFIC_ANNOTATION);
    if (!HandleWriteResult(result)) {
      return;
    }
  }
  // Loop budget exhausted with data still pending: yield, then continue.
  task_runner_->PostTask(FROM_HERE, base::BindOnce(&SmallMessageSocket::Send,
                                                   weak_factory_.GetWeakPtr()));
}
// Completion callback for an asynchronous socket write. Keeps draining the
// write buffer if the write made progress and data remains.
void SmallMessageSocket::OnWriteComplete(int result) {
  if (!HandleWriteResult(result)) {
    return;
  }
  Send();
}
// Processes the outcome of one socket write.
// Returns true if bytes remain and the caller should write again immediately.
// Returns false in three cases:
// - the write is pending;
// - the write failed (the error is reported to the delegate from a posted
//   task);
// - the buffer has been fully written.
bool SmallMessageSocket::HandleWriteResult(int result) {
  if (result == net::ERR_IO_PENDING) {
    return false;
  }
  if (result <= 0) {
    // Report the failure from a posted task rather than calling OnError()
    // directly, so the delegate is never invoked synchronously from here.
    task_runner_->PostTask(FROM_HERE,
                           base::BindOnce(&SmallMessageSocket::OnError,
                                          weak_factory_.GetWeakPtr(), result));
    return false;
  }
  write_buffer_->DidConsume(result);
  const bool all_written = write_buffer_->size() == 0;
  if (!all_written) {
    return true;
  }
  // Everything was sent: release the buffer and wake up a blocked sender.
  write_buffer_->ClearUnderlyingBuffer();
  if (send_blocked_) {
    send_blocked_ = false;
    delegate_->OnSendUnblocked();
  }
  return false;
}
// Forwards `error` to the delegate. HandleWriteResult() reaches this through a
// posted task bound to a weak pointer, so the delegate is not called
// synchronously from the write path.
void SmallMessageSocket::OnError(int error) {
  delegate_->OnError(error);
}
// Starts (or resumes) receiving messages. The work runs in a posted task rather
// than here, so delegate methods are never invoked from within this call.
void SmallMessageSocket::ReceiveMessages() {
  auto receive =
      base::BindOnce(&SmallMessageSocket::ReceiveMessagesSynchronously,
                     weak_factory_.GetWeakPtr());
  task_runner()->PostTask(FROM_HERE, std::move(receive));
}
// Dispatches any complete messages already buffered. If the delegate still
// wants more, starts reading from the socket.
void SmallMessageSocket::ReceiveMessagesSynchronously() {
  // Pick the handler before calling it, and do not look at `buffer_pool_`
  // again afterwards. A delegate callback may add or remove the pool, or even
  // delete `this`, and the handler's return value already accounts for that:
  // HandleCompletedMessageBuffers() falls through to HandleCompletedMessages()
  // itself when the pool is removed. Re-checking `buffer_pool_` after a false
  // return would read a member of a deleted object, or keep delivering messages
  // after the delegate asked to stop.
  const bool keep_reading = buffer_pool_ ? HandleCompletedMessageBuffers()
                                         : HandleCompletedMessages();
  if (keep_reading) {
    Read();
  }
}
void SmallMessageSocket::Read() {
// Read in a loop for a few times while data is immediately available.
// This improves average packet receive delay as compared to always posting a
// new task for each call to Read().
for (int i = 0; i < kMaxIOLoop; ++i) {
net::IOBuffer* buffer;
int size;
if (buffer_pool_) {
buffer = read_buffer_.get();
size = read_buffer_->size();
} else {
buffer = read_storage_.get();
size = read_storage_->RemainingCapacity();
}
int read_result =
socket()->Read(buffer, size,
base::BindOnce(&SmallMessageSocket::OnReadComplete,
base::Unretained(this)));
if (!HandleReadResult(read_result)) {
return;
}
}
task_runner()->PostTask(
FROM_HERE,
base::BindOnce(&SmallMessageSocket::Read, weak_factory_.GetWeakPtr()));
}
// Completion callback for an asynchronous socket read. Resumes the read loop
// if the result was handled and the delegate wants more data.
void SmallMessageSocket::OnReadComplete(int result) {
  if (!HandleReadResult(result)) {
    return;
  }
  Read();
}
// Processes the outcome of one socket read.
// - Pending: returns false.
// - End of stream (0 or ERR_CONNECTION_CLOSED): calls OnEndOfStream(), returns
//   false.
// - Other errors: calls OnError(), returns false.
// - Otherwise: records the new bytes and dispatches any complete messages,
//   returning whether to keep reading.
bool SmallMessageSocket::HandleReadResult(int result) {
  if (result == net::ERR_IO_PENDING) {
    return false;
  }
  const bool end_of_stream =
      result == 0 || result == net::ERR_CONNECTION_CLOSED;
  if (end_of_stream) {
    delegate_->OnEndOfStream();
    return false;
  }
  if (result < 0) {
    delegate_->OnError(result);
    return false;
  }
  if (!buffer_pool_) {
    read_storage_->set_offset(read_storage_->offset() + result);
    return HandleCompletedMessages();
  }
  read_buffer_->DidConsume(result);
  return HandleCompletedMessageBuffers();
}
// static
// Parses the size header at the start of the `bytes_read` valid bytes at `ptr`.
// Returns false if those bytes do not yet contain the complete header.
// On success:
// - `data_offset` is set to the header length (2 or 6 bytes);
// - `message_size` is set to the length of the message body that follows.
bool SmallMessageSocket::ReadSize(char* ptr,
                                  size_t bytes_read,
                                  size_t& data_offset,
                                  size_t& message_size) {
  if (bytes_read < sizeof(uint16_t)) {
    return false;
  }
  // TODO(crbug.com/40284755): ReadSize() should receive a span instead of the
  // unbounded pointer `ptr`. Until then, bound the span by `bytes_read`: the
  // callers in this file pass it as the number of valid bytes at `ptr`.
  // Claiming a fixed 6 bytes could describe memory past the end of the data.
  auto header = UNSAFE_TODO(base::as_bytes(base::span(ptr, bytes_read)));
  const uint16_t first_size =
      base::numerics::U16FromBigEndian(header.first<2u>());
  data_offset = sizeof(uint16_t);
  if (first_size < kMax2ByteSize) {
    message_size = first_size;
    return true;
  }
  // Escape value: the real size follows as a big-endian 32-bit integer.
  if (bytes_read < sizeof(uint16_t) + sizeof(uint32_t)) {
    return false;
  }
  message_size = base::numerics::U32FromBigEndian(
      header.subspan(sizeof(uint16_t)).first<4u>());
  data_offset += sizeof(uint32_t);
  return true;
}
// Non-pool mode. Dispatches, via the delegate's OnMessage(), every complete
// message held in `read_storage_` (the bytes before its offset). Then moves any
// leftover partial message to the start of `read_storage_`.
// Special cases:
// - If the next message cannot fit, grows `read_storage_` and returns true so
//   more data is read.
// - If the delegate adds a buffer pool from within OnMessage(), switches to
//   pool mode and continues in HandleCompletedMessageBuffers().
// Returns true if the caller should keep reading from the socket. Returns false
// if the delegate asked to stop or `this` was deleted.
bool SmallMessageSocket::HandleCompletedMessages() {
  DCHECK(!buffer_pool_);
  bool keep_reading = true;
  // Unprocessed bytes; shrinks from the front as messages are dispatched.
  base::span<uint8_t> bytes_read = read_storage_->span_before_offset();
  while (keep_reading) {
    size_t data_offset;
    size_t message_size;
    if (!ReadSize(base::as_writable_chars(bytes_read).data(), bytes_read.size(),
                  data_offset, message_size)) {
      break;  // Size header not fully received yet.
    }
    size_t total_size = data_offset + message_size;
    if (static_cast<size_t>(read_storage_->capacity()) < total_size) {
      // The message cannot fit: compact the unprocessed bytes to the front
      // (if any messages were consumed) and grow the storage before reading
      // more.
      if (bytes_read != read_storage_->span_before_offset()) {
        read_storage_->everything().copy_prefix_from(bytes_read);
        read_storage_->set_offset(bytes_read.size());
      }
      read_storage_->SetCapacity(total_size);
      return true;
    }
    if (bytes_read.size() < total_size) {
      break;  // Haven't received the full message yet.
    }
    // Take a weak pointer in case OnMessage() causes this to be deleted.
    auto self = weak_factory_.GetWeakPtr();
    // Lets UseBufferPool() know not to activate the pool while `bytes_read`
    // (which points into `read_storage_`) is still in use here.
    in_message_ = true;
    auto data =
        base::as_writable_chars(bytes_read.subspan(data_offset, message_size));
    keep_reading = delegate_->OnMessage(data.data(), data.size());
    if (!self) {
      return false;
    }
    in_message_ = false;
    bytes_read = bytes_read.subspan(total_size);
    if (buffer_pool_) {
      // A buffer pool was added within OnMessage().
      // Move the remaining bytes into a pool buffer and continue in pool mode.
      ActivateBufferPool(bytes_read);
      return (keep_reading ? HandleCompletedMessageBuffers() : false);
    }
  }
  // Compact any leftover (partial) data to the front of `read_storage_` so
  // subsequent reads append after it.
  if (bytes_read != read_storage_->span_before_offset()) {
    read_storage_->everything().copy_prefix_from(bytes_read);
    read_storage_->set_offset(bytes_read.size());
  }
  return keep_reading;
}
// Pool mode. Dispatches every complete message held in `read_buffer_` via the
// delegate's OnMessageBuffer(). Each dispatched message is handed over in the
// buffer it was read into, header included. Any bytes following it are copied
// into a fresh buffer that becomes the new `read_buffer_`.
// Returns true if the caller should keep reading from the socket. Returns false
// if the delegate asked to stop or `this` was deleted.
bool SmallMessageSocket::HandleCompletedMessageBuffers() {
  DCHECK(buffer_pool_);
  size_t bytes_read;
  // Loop while the current buffer holds any bytes that have been read.
  while ((bytes_read = read_buffer_->used())) {
    size_t data_offset;
    size_t message_size;
    if (!ReadSize(read_buffer_->StartOfBuffer(), bytes_read, data_offset,
                  message_size)) {
      break;  // Size header not fully received yet.
    }
    size_t total_size = data_offset + message_size;
    if (read_buffer_->capacity() < total_size) {
      // Current buffer is not big enough.
      // Move the partial message into a buffer sized exactly for it and
      // resume reading into the remaining space.
      auto new_buffer =
          base::MakeRefCounted<::net::IOBufferWithSize>(total_size);
      memcpy(new_buffer->data(), read_buffer_->StartOfBuffer(), bytes_read);
      read_buffer_->SetUnderlyingBuffer(std::move(new_buffer), total_size);
      read_buffer_->DidConsume(bytes_read);
      return true;
    }
    if (bytes_read < total_size) {
      break; // Haven't received the full message yet.
    }
    // The current buffer starts with a complete message: detach it for the
    // delegate and continue reading into a fresh pool buffer.
    auto old_buffer = read_buffer_->TakeUnderlyingBuffer();
    auto new_buffer = buffer_pool_->GetBuffer();
    CHECK(new_buffer);
    size_t new_buffer_size = buffer_pool_->buffer_size();
    // Bytes read past the end of this message (start of later messages).
    size_t extra_size = bytes_read - total_size;
    if (extra_size > 0) {
      // Copy extra data to new buffer.
      // A dedicated buffer is allocated if it exceeds the pool buffer size.
      if (extra_size > buffer_pool_->buffer_size()) {
        new_buffer = base::MakeRefCounted<::net::IOBufferWithSize>(extra_size);
        new_buffer_size = extra_size;
      }
      memcpy(new_buffer->data(), old_buffer->data() + total_size, extra_size);
    }
    read_buffer_->SetUnderlyingBuffer(std::move(new_buffer), new_buffer_size);
    read_buffer_->DidConsume(extra_size);
    // Take a weak pointer in case OnMessageBuffer() causes this to be deleted.
    auto self = weak_factory_.GetWeakPtr();
    bool keep_reading =
        delegate_->OnMessageBuffer(std::move(old_buffer), total_size);
    if (!self || !keep_reading) {
      return false;
    }
    if (!buffer_pool_) {
      // The buffer pool was removed within OnMessageBuffer().
      // RemoveBufferPool() has copied the remaining bytes into
      // `read_storage_`; continue in non-pool mode.
      return HandleCompletedMessages();
    }
  }
  return true;
}
// Default handling of a pool-mode message. `buffer` starts with the message's
// size header, and `size` is the header length plus the body length. Strips
// the header and forwards the body to OnMessage().
bool SmallMessageSocket::Delegate::OnMessageBuffer(
    scoped_refptr<net::IOBuffer> buffer,
    size_t size) {
  // Parse the header actually present in the buffer instead of computing its
  // length from `size`. `size` already includes the header, so
  // SizeDataBytes(size) misclassifies bodies of 65533 or 65534 bytes: they use
  // a 2-byte header, but their total size is >= 0xFFFF, which would select the
  // 6-byte form.
  size_t data_offset = 0;
  size_t message_size = 0;
  const bool has_complete_header =
      ReadSize(buffer->data(), size, data_offset, message_size);
  CHECK(has_complete_header);
  return OnMessage(buffer->data() + data_offset, size - data_offset);
}
} // namespace chromecast