// Copyright 2012 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/base/file_stream_context.h"
#include <windows.h>
#include <utility>
#include "base/files/file_path.h"
#include "base/functional/bind.h"
#include "base/location.h"
#include "base/logging.h"
#include "base/message_loop/message_pump_for_io.h"
#include "base/task/current_thread.h"
#include "base/task/single_thread_task_runner.h"
#include "base/task/task_runner.h"
#include "net/base/io_buffer.h"
#include "net/base/net_errors.h"
namespace net {
namespace {
// Copies the low and high parts of |offset| into the matching offset fields
// of |overlapped|.
void SetOffset(OVERLAPPED* overlapped, const LARGE_INTEGER& offset) {
  OVERLAPPED& target = *overlapped;
  target.OffsetHigh = offset.HighPart;
  target.Offset = offset.LowPart;
}
// Advances the file position stored in |overlapped| by |count|. The two
// offset fields are combined into a LARGE_INTEGER so the addition is done on
// the QuadPart view and carries from the low part into the high part.
void IncrementOffset(OVERLAPPED* overlapped, DWORD count) {
  LARGE_INTEGER position;
  position.LowPart = overlapped->Offset;
  position.HighPart = overlapped->OffsetHigh;

  position.QuadPart = position.QuadPart + static_cast<LONGLONG>(count);

  SetOffset(overlapped, position);
}
} // namespace
// Constructs a context without a file: delegates to the two-argument
// constructor with a default-constructed base::File and forwards
// |task_runner|.
FileStream::Context::Context(scoped_refptr<base::TaskRunner> task_runner)
: Context(base::File(), std::move(task_runner)) {}
// Constructs a context that takes ownership of |file| and keeps
// |task_runner| for later use. If |file| is valid it must have been opened
// for async use, and OnFileOpened() is run on it right away.
FileStream::Context::Context(base::File file,
                             scoped_refptr<base::TaskRunner> task_runner)
    : base::MessagePumpForIO::IOHandler(FROM_HERE),
      file_(std::move(file)),
      task_runner_(std::move(task_runner)) {
  // Nothing more to do when no file was supplied.
  if (!file_.IsValid()) {
    return;
  }

  DCHECK(file_.async());
  OnFileOpened();
}
// Defaulted: no explicit teardown is performed here beyond destroying the
// members.
FileStream::Context::~Context() = default;
// Starts an asynchronous read of |buf_len| bytes into |buf|. The ReadFile
// call itself is made by ReadAsync(), which is posted to |task_runner_|.
// |callback| is stored for later and this method always returns
// ERR_IO_PENDING.
int FileStream::Context::Read(IOBuffer* buf,
                              int buf_len,
                              CompletionOnceCallback callback) {
  // No operation may be outstanding and all read-tracking flags must be idle.
  DCHECK(!async_in_progress_);
  DCHECK(!async_read_initiated_);
  DCHECK(!async_read_completed_);
  DCHECK(!io_complete_for_read_received_);

  IOCompletionIsPending(std::move(callback), buf);
  result_ = 0;
  async_read_initiated_ = true;

  // The task carries its own reference to |buf| and the current thread's
  // default task runner, which ReadAsync() uses to post the result back.
  // |this| is bound unretained.
  auto read_task = base::BindOnce(
      &FileStream::Context::ReadAsync, base::Unretained(this),
      file_.GetPlatformFile(), base::WrapRefCounted(buf), buf_len,
      &io_context_.overlapped,
      base::SingleThreadTaskRunner::GetCurrentDefault());
  task_runner_->PostTask(FROM_HERE, std::move(read_task));

  return ERR_IO_PENDING;
}
// Issues an overlapped WriteFile of |buf_len| bytes from |buf|.
// Returns ERR_IO_PENDING when WriteFile succeeds immediately or reports
// ERROR_IO_PENDING; in both cases |callback| is stored for the completion
// notification. Any other failure is logged and its mapped error is returned
// without storing |callback|.
int FileStream::Context::Write(IOBuffer* buf,
                               int buf_len,
                               CompletionOnceCallback callback) {
  DCHECK(!async_in_progress_);
  result_ = 0;

  DWORD bytes_written = 0;
  const BOOL write_succeeded =
      WriteFile(file_.GetPlatformFile(), buf->data(), buf_len, &bytes_written,
                &io_context_.overlapped);
  if (write_succeeded) {
    // Immediate success is treated like a pending operation: the callback is
    // registered and the result is reported later.
    IOCompletionIsPending(std::move(callback), buf);
    return ERR_IO_PENDING;
  }

  IOResult failure = IOResult::FromOSError(GetLastError());
  if (failure.os_error != ERROR_IO_PENDING) {
    LOG(WARNING) << "WriteFile failed: " << failure.os_error;
  } else {
    IOCompletionIsPending(std::move(callback), buf);
  }
  return static_cast<int>(failure.result);
}
// Starts an overlapped ConnectNamedPipe on |file_|.
// Returns OK if a client is already connected, ERR_IO_PENDING (after storing
// |callback|) if the connection is in progress, or the mapped system error
// otherwise.
int FileStream::Context::ConnectNamedPipe(CompletionOnceCallback callback) {
  DCHECK(!async_in_progress_);
  result_ = 0;

  // Always returns zero when making an asynchronous call, so the outcome is
  // taken from the last-error value instead of the return value.
  ::ConnectNamedPipe(file_.GetPlatformFile(), &io_context_.overlapped);
  const auto error = ::GetLastError();

  switch (error) {
    case ERROR_PIPE_CONNECTED:
      // The client has already connected; operation complete.
      return OK;
    case ERROR_IO_PENDING:
      // Wait for an I/O completion packet.
      IOCompletionIsPending(std::move(callback), /*buf=*/nullptr);
      return ERR_IO_PENDING;
    default:
      break;
  }

  // ERROR_INVALID_FUNCTION means that `file_` isn't a handle to a named pipe,
  // but to an actual file. This is a programming error.
  CHECK_NE(error, static_cast<DWORD>(ERROR_INVALID_FUNCTION));
  return static_cast<int>(MapSystemError(error));
}
// Records |offset| as the position used by subsequent overlapped operations
// by writing it into |io_context_.overlapped|. No file API is called here.
// Returns an IOResult carrying the offset and an OS error of 0.
FileStream::Context::IOResult FileStream::Context::SeekFileImpl(
    int64_t offset) {
  LARGE_INTEGER target;
  target.QuadPart = offset;

  SetOffset(&io_context_.overlapped, target);

  return IOResult(target.QuadPart, 0);
}
void FileStream::Context::OnFileOpened() {
HRESULT hr = base::CurrentIOThread::Get()->RegisterIOHandler(
file_.GetPlatformFile(), this);
if (!SUCCEEDED(hr))
file_.Close();
}
// Marks an asynchronous operation as outstanding: stores |callback| to be run
// on completion and keeps a reference to |buf| (which may be null) in
// |in_flight_buf_|. No callback may already be stored.
void FileStream::Context::IOCompletionIsPending(CompletionOnceCallback callback,
                                                IOBuffer* buf) {
  DCHECK(callback_.is_null());

  async_in_progress_ = true;
  callback_ = std::move(callback);
  // Hold until the async operation ends.
  in_flight_buf_ = buf;
}
// Handles the completion of the outstanding overlapped operation. This is
// also called directly by ReadAsyncResult() when ReadFile failed outright.
// Turns |error| and |bytes_read| into |result_|, advances the stored file
// offset on success, and then tries to run the user callback. An orphaned
// context is torn down here instead of reporting a result.
void FileStream::Context::OnIOCompleted(
    base::MessagePumpForIO::IOContext* context,
    DWORD bytes_read,
    DWORD error) {
  DCHECK_EQ(&io_context_, context);
  DCHECK(!callback_.is_null());
  DCHECK(async_in_progress_);

  // For anything other than a Read() the operation is over at this point.
  // For a Read() the flag is cleared in InvokeUserCallback().
  if (!async_read_initiated_) {
    async_in_progress_ = false;
  }

  if (orphaned_) {
    io_complete_for_read_received_ = true;
    // If a Read() was started and its asynchronous read task has not reported
    // back yet, the context has to stay alive until ReadAsyncResult() runs;
    // that function then performs the deletion.
    const bool read_task_outstanding =
        async_read_initiated_ && !async_read_completed_;
    if (!read_task_outstanding) {
      DeleteOrphanedContext();
    }
    return;
  }

  if (!error) {
    // A non-zero |result_| set earlier must agree with the reported count.
    if (result_) {
      DCHECK_EQ(result_, static_cast<int>(bytes_read));
    }
    result_ = bytes_read;
    IncrementOffset(&io_context_.overlapped, bytes_read);
  } else if (error == ERROR_HANDLE_EOF) {
    // End of file is reported as a zero result rather than an error.
    result_ = 0;
  } else {
    IOResult error_result = IOResult::FromOSError(error);
    result_ = static_cast<int>(error_result.result);
  }

  if (async_read_initiated_) {
    io_complete_for_read_received_ = true;
  }

  InvokeUserCallback();
}
void FileStream::Context::InvokeUserCallback() {
// For an asynchonous Read operation don't invoke the user callback until
// we receive the IO completion notification and the asynchronous Read
// completion notification.
if (async_read_initiated_) {
if (!io_complete_for_read_received_ || !async_read_completed_)
return;
async_read_initiated_ = false;
io_complete_for_read_received_ = false;
async_read_completed_ = false;
async_in_progress_ = false;
}
scoped_refptr<IOBuffer> temp_buf = in_flight_buf_;
in_flight_buf_ = nullptr;
std::move(callback_).Run(result_);
}
// Tears down an orphaned context: clears the in-progress flag, drops the
// stored callback without running it, releases the in-flight buffer
// reference, and finally calls CloseAndDelete().
// NOTE(review): CloseAndDelete() is defined elsewhere; presumably it destroys
// this object, so no member should be touched after this returns — confirm.
void FileStream::Context::DeleteOrphanedContext() {
async_in_progress_ = false;
callback_.Reset();
in_flight_buf_ = nullptr;
CloseAndDelete();
}
// static
// Performs the overlapped ReadFile for a Read() and posts the outcome back to
// |origin_thread_task_runner| as a call to ReadAsyncResult() on |context|.
// Read() posts this function to its task runner.
//
// |file|        handle to read from.
// |buf|         destination buffer; this task holds its own reference to it.
// |buf_len|     number of bytes requested.
// |overlapped|  OVERLAPPED structure passed to ReadFile.
//
// |context| is bound unretained; nothing in this function keeps it alive.
// NOTE(review): the orphan handling in ReadAsyncResult()/OnIOCompleted()
// appears to be what keeps it valid until the reply runs — confirm.
void FileStream::Context::ReadAsync(
    FileStream::Context* context,
    HANDLE file,
    scoped_refptr<IOBuffer> buf,
    int buf_len,
    OVERLAPPED* overlapped,
    scoped_refptr<base::SingleThreadTaskRunner> origin_thread_task_runner) {
  DWORD bytes_read = 0;
  const BOOL ret =
      ::ReadFile(file, buf->data(), buf_len, &bytes_read, overlapped);
  // Read the thread's last-error value immediately after ReadFile, before any
  // other call is made. Evaluating ::GetLastError() inside the BindOnce
  // argument list would leave its order relative to the other argument
  // expressions unspecified. The value is only meaningful when |ret| is
  // FALSE; ReadAsyncResult() ignores it otherwise.
  const DWORD os_error = ::GetLastError();

  origin_thread_task_runner->PostTask(
      FROM_HERE,
      base::BindOnce(&FileStream::Context::ReadAsyncResult,
                     base::Unretained(context), ret, bytes_read, os_error));
}
// Receives the outcome of the ReadFile call made by ReadAsync().
// |read_file_ret| is ReadFile's return value, |bytes_read| the count it
// reported, and |os_error| the last-error value captured after the call.
void FileStream::Context::ReadAsyncResult(BOOL read_file_ret,
                                          DWORD bytes_read,
                                          DWORD os_error) {
  // If the context is orphaned and the IO completion notification has already
  // arrived, this was the last event being waited on: delete the context.
  if (orphaned_ && io_complete_for_read_received_) {
    DeleteOrphanedContext();
    return;
  }

  async_read_completed_ = true;

  if (!read_file_ret) {
    const IOResult failure = IOResult::FromOSError(os_error);
    if (failure.os_error != ERROR_IO_PENDING) {
      // ReadFile failed outright: route the error through the completion
      // handler directly.
      OnIOCompleted(&io_context_, 0, failure.os_error);
      return;
    }
    // ERROR_IO_PENDING: the read is still in flight, so |result_| is left
    // untouched here.
  } else {
    result_ = bytes_read;
  }

  // Runs the callback only once the IO completion notification has also been
  // received; otherwise this is a no-op.
  InvokeUserCallback();
}
} // namespace net