chromium/chrome/test/chromedriver/net/pipe_connection_win.cc

// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifdef UNSAFE_BUFFERS_BUILD
// TODO(crbug.com/40285824): Remove this and convert code to safer constructs.
#pragma allow_unsafe_buffers
#endif

#include "chrome/test/chromedriver/net/pipe_connection_win.h"

#include <windows.h>

#include <io.h>
#include <stdlib.h>

#include <list>
#include <memory>
#include <string>

#include "base/containers/span.h"
#include "base/json/json_reader.h"
#include "base/logging.h"
#include "base/threading/thread.h"
#include "base/values.h"
#include "chrome/test/chromedriver/net/command_id.h"
#include "chrome/test/chromedriver/net/sync_websocket.h"
#include "chrome/test/chromedriver/net/timeout.h"
#include "net/base/io_buffer.h"

namespace {

const size_t kWritePacketSize = 1 << 16;
const int kMinReadBufferCapacity = 4096;
const int kMaxReadBufferCapacity = 100 * 1024 * 1024;  // 100Mb

void DetermineRecipient(const std::string& message,
                        bool* send_to_chromedriver) {
  std::optional<base::Value> message_value =
      base::JSONReader::Read(message, base::JSON_REPLACE_INVALID_CHARACTERS);
  base::Value::Dict* message_dict =
      message_value ? message_value->GetIfDict() : nullptr;
  if (!message_dict) {
    *send_to_chromedriver = true;
    return;
  }
  base::Value* id = message_dict->Find("id");
  *send_to_chromedriver =
      id == nullptr ||
      (id->is_int() && CommandId::IsChromeDriverCommandId(id->GetInt()));
}

}  // namespace

// Reads NUL-delimited messages from the read end of the pipe on a dedicated
// IO thread and queues the ones addressed to ChromeDriver (as decided by
// DetermineRecipient) for consumption on the session thread.
//
// Threading: Start(), SetNotificationCallback(), HasNextMessage(),
// ReceiveNextMessage() and ClosePipe() are checked against the session thread;
// ReadLoopOnIOThread(), ReadBytes() and OnMessageReceivedOnIOThread() are
// checked against the internal IO thread. State shared by both is guarded by
// |lock_|.
class PipeReader {
 public:
  explicit PipeReader(base::WeakPtr<PipeConnectionWin> pipe_connection)
      : pipe_connection_(std::move(pipe_connection)),
        owning_sequence_(base::SequencedTaskRunner::GetCurrentDefault()),
        read_buffer_(base::MakeRefCounted<net::GrowableIOBuffer>()),
        thread_(std::make_unique<base::Thread>("PipeConnectionWinReadThread")) {
    DETACH_FROM_THREAD(io_thread_checker_);
    read_buffer_->SetCapacity(kMinReadBufferCapacity);
  }

  ~PipeReader() = default;

  // Takes ownership of |read_file|, starts the IO thread and schedules the
  // read loop on it. Returns false if the thread could not be started.
  bool Start(base::ScopedPlatformFile read_file) {
    DCHECK_CALLED_ON_VALID_THREAD(session_thread_checker_);
    base::Thread::Options options;
    options.message_pump_type = base::MessagePumpType::IO;
    is_connected_ = true;
    read_file_ = std::move(read_file);
    if (!thread_->StartWithOptions(std::move(options))) {
      is_connected_ = false;
      return false;
    }
    thread_->task_runner()->PostTask(
        FROM_HERE, base::BindOnce(&PipeReader::ReadLoopOnIOThread,
                                  base::Unretained(this)));
    return true;
  }

  bool IsConnected() const {
    base::AutoLock lock(lock_);
    return is_connected_;
  }

  // |callback| is posted to the owning sequence whenever the queue of
  // ChromeDriver-bound messages transitions from empty to non-empty.
  void SetNotificationCallback(base::RepeatingClosure callback) {
    DCHECK_CALLED_ON_VALID_THREAD(session_thread_checker_);
    base::AutoLock lock(lock_);
    notify_ = std::move(callback);
  }

  bool HasNextMessage() const {
    DCHECK_CALLED_ON_VALID_THREAD(session_thread_checker_);
    base::AutoLock lock(lock_);
    return !received_queue_.empty();
  }

  // Blocks until a queued message is available, the reader is disconnected,
  // or |timeout| expires. Already-queued messages are returned even after the
  // reader has disconnected.
  SyncWebSocket::StatusCode ReceiveNextMessage(std::string* message,
                                               const Timeout& timeout) {
    DCHECK_CALLED_ON_VALID_THREAD(session_thread_checker_);
    base::AutoLock lock(lock_);
    while (received_queue_.empty() && is_connected_) {
      base::TimeDelta next_wait = timeout.GetRemainingTime();
      if (next_wait <= base::TimeDelta()) {
        return SyncWebSocket::StatusCode::kTimeout;
      }
      on_update_event_.TimedWait(next_wait);
    }
    if (!received_queue_.empty()) {
      *message = std::move(received_queue_.front());
      received_queue_.pop_front();
      return SyncWebSocket::StatusCode::kOk;
    }
    DCHECK(!is_connected_);
    return SyncWebSocket::StatusCode::kDisconnected;
  }

  // Reads from the pipe until it fails or the buffer limit is hit, splitting
  // the stream on '\0' into messages. When it stops, the reader is marked
  // disconnected and a Shutdown of the owning connection is posted.
  void ReadLoopOnIOThread() {
    DCHECK_CALLED_ON_VALID_THREAD(io_thread_checker_);
    while (true) {
      if (read_buffer_->RemainingCapacity() == 0) {
        if (read_buffer_->capacity() >= kMaxReadBufferCapacity) {
          VLOG(logging::LOGGING_ERROR)
              << "Connection closed, not enough capacity";
          break;
        }
        read_buffer_->SetCapacity(2 * read_buffer_->capacity());
      }

      size_t bytes_read = ReadBytes(read_buffer_->data(),
                                    read_buffer_->RemainingCapacity(), false);
      if (!bytes_read) {
        break;
      }
      read_buffer_->set_offset(read_buffer_->offset() + bytes_read);

      // Go over the last read chunk, look for \0, extract messages.
      // |offset| is the start of the first not-yet-extracted message.
      int offset = 0;
      const int chunk_start =
          read_buffer_->offset() - static_cast<int>(bytes_read);
      for (int i = chunk_start; i < read_buffer_->offset(); ++i) {
        if (read_buffer_->everything()[i] == '\0') {
          OnMessageReceivedOnIOThread(
              std::string(base::as_string_view(read_buffer_->everything())
                              .substr(offset, i - offset)));
          offset = i + 1;
        }
      }
      if (offset) {
        // Move the trailing partial message to the front of the buffer and
        // shrink the buffer if it is now much larger than needed.
        base::span<const uint8_t> subspan =
            read_buffer_->span_before_offset().subspan(offset);
        read_buffer_->everything().copy_prefix_from(subspan);
        read_buffer_->set_offset(subspan.size());
        int new_capacity = std::max(
            kMinReadBufferCapacity,
            std::min(read_buffer_->offset() * 2, read_buffer_->capacity()));
        if (new_capacity != read_buffer_->capacity()) {
          read_buffer_->SetCapacity(new_capacity);
        }
      }
    }
    // No further messages can arrive. Publish that and wake any waiter now:
    // the Shutdown task below runs on the session thread, which may itself be
    // blocked in ReceiveNextMessage() on |on_update_event_|. Without this, such
    // a caller would sit out its whole timeout and report kTimeout instead of
    // kDisconnected (e.g. after the "not enough capacity" exit above).
    {
      base::AutoLock lock(lock_);
      is_connected_ = false;
      on_update_event_.Signal();
    }
    owning_sequence_->PostTask(
        FROM_HERE,
        base::BindOnce(&PipeConnectionWin::Shutdown, pipe_connection_));
  }

  // Reads into |buffer|. With |exact_size| false, performs a single ReadFile
  // call; otherwise keeps reading until |size| bytes arrive. Returns the number
  // of bytes read, or 0 on error (in which case, unless shutting down, the
  // reader is marked disconnected and waiters are woken).
  //
  // NOTE(review): per the ReadFile documentation, a TRUE return with zero
  // bytes on a pipe means the peer wrote zero bytes rather than end of stream;
  // the caller currently treats a 0 result as the end of the connection.
  size_t ReadBytes(char* buffer, size_t size, bool exact_size) {
    DCHECK_CALLED_ON_VALID_THREAD(io_thread_checker_);
    size_t bytes_read = 0;
    base::PlatformFile file = base::kInvalidPlatformFile;
    {
      base::AutoLock lock(lock_);
      file = read_file_.get();
    }
    while (bytes_read < size) {
      DWORD size_read = 0;
      bool had_error = !ReadFile(file, buffer + bytes_read, size - bytes_read,
                                 &size_read, nullptr);
      if (had_error) {
        if (!shutting_down_.IsSet()) {
          VLOG(logging::LOGGING_ERROR)
              << "Connection terminated while reading from pipe";
          base::AutoLock lock(lock_);
          is_connected_ = false;
          on_update_event_.Signal();
        }
        return 0;
      }
      bytes_read += size_read;
      if (!exact_size) {
        break;
      }
    }
    return bytes_read;
  }

  // Queues |message| if it is addressed to ChromeDriver and wakes waiters.
  void OnMessageReceivedOnIOThread(std::string message) {
    DCHECK_CALLED_ON_VALID_THREAD(io_thread_checker_);

    // JSON parsing touches no shared state, so do it before taking |lock_| to
    // avoid blocking the session thread on large messages.
    bool send_to_chromedriver = true;
    DetermineRecipient(message, &send_to_chromedriver);

    base::AutoLock lock(lock_);

    bool notification_is_needed = false;
    if (send_to_chromedriver) {
      notification_is_needed = received_queue_.empty();
      received_queue_.push_back(std::move(message));
    }
    on_update_event_.Signal();

    // The notification can be emitted sporadically but we explicitly allow
    // this.
    if (notification_is_needed && notify_) {
      owning_sequence_->PostTask(FROM_HERE, notify_);
    }
  }

  // Stops reading, closes the pipe and destroys |pipe_io| on its IO thread
  // (after the read loop has returned) when that thread is running.
  static void Shutdown(std::unique_ptr<PipeReader> pipe_io) {
    if (!pipe_io) {
      return;
    }
    auto thread = std::move(pipe_io->thread_);
    pipe_io->shutting_down_.Set();
    pipe_io->ClosePipe();
    // Post self destruction on the custom thread if it's running.
    if (thread->task_runner()) {
      thread->task_runner()->DeleteSoon(FROM_HERE, std::move(pipe_io));
    } else {
      pipe_io.reset();
    }
  }

 protected:
  // Concurrently discard the pipe handles to successfully join threads.
  void ClosePipe() {
    DCHECK_CALLED_ON_VALID_THREAD(session_thread_checker_);
    base::AutoLock lock(lock_);
    // Cancel pending synchronous read.
    CancelIoEx(read_file_.get(), nullptr);
    read_file_ = base::ScopedPlatformFile();
  }

  mutable base::Lock lock_;
  // Protected by |lock_|.
  bool is_connected_ = false;
  base::AtomicFlag shutting_down_;
  THREAD_CHECKER(session_thread_checker_);
  THREAD_CHECKER(io_thread_checker_);
  base::WeakPtr<PipeConnectionWin> pipe_connection_;
  // Sequence where the instance was created.
  // The notifications about new data are emitted in this sequence.
  scoped_refptr<base::SequencedTaskRunner> owning_sequence_;
  base::ScopedPlatformFile read_file_;
  // Protected by |lock_|.
  std::list<std::string> received_queue_;
  // Protected by |lock_|.
  // Signaled when the pipe closes or a message is received.
  base::ConditionVariable on_update_event_{&lock_};
  // Protected by |lock_|.
  // Notifies that the queue is not empty.
  base::RepeatingClosure notify_;
  scoped_refptr<net::GrowableIOBuffer> read_buffer_;
  // Thread is the last member, to be destroyed first.
  // This ensures that there will be no races in the destructor.
  std::unique_ptr<base::Thread> thread_;
};

class PipeWriter {
 public:
  explicit PipeWriter(base::WeakPtr<PipeConnectionWin> pipe_connection)
      : owning_sequence_(base::SequencedTaskRunner::GetCurrentDefault()),
        pipe_connection_(std::move(pipe_connection)),
        thread_(new base::Thread("PipeConnectionWinWriteThread")) {
    DETACH_FROM_THREAD(io_thread_checker_);
  }

  virtual ~PipeWriter() = default;

  bool IsConnected() {
    base::AutoLock lock(lock_);
    return is_connected_;
  }

  void WriteIntoPipeOnIOThread(std::string message,
                               bool* success,
                               base::WaitableEvent* event) {
    DCHECK_CALLED_ON_VALID_THREAD(io_thread_checker_);
    // Trying to guess if the connection is still there
    {
      base::AutoLock lock(lock_);
      *success = is_connected_;
    }
    event->Signal();
    // The rest is done without blocking the session thread
    bool ok = WriteBytesOnIOThread(message.data(), message.size());
    ok = ok && WriteBytesOnIOThread("\0", 1);

    if (!ok) {
      owning_sequence_->PostTask(
          FROM_HERE,
          base::BindOnce(&PipeConnectionWin::Shutdown, pipe_connection_));
    }
  }

  bool Start(base::ScopedPlatformFile write_file) {
    base::Thread::Options options;
    options.message_pump_type = base::MessagePumpType::IO;
    is_connected_ = true;
    write_file_ = std::move(write_file);
    if (!thread_->StartWithOptions(std::move(options))) {
      is_connected_ = false;
      return false;
    }
    return true;
  }

  bool Write(std::string message) {
    DCHECK_CALLED_ON_VALID_THREAD(session_thread_checker_);
    // This is mostly for the case when the thread is not yet / no longer
    // running. Otherwise PostTask would crash.
    if (!IsConnected()) {
      return false;
    }
    base::TaskRunner* task_runner = thread_->task_runner().get();
    base::WaitableEvent event{base::WaitableEvent::ResetPolicy::AUTOMATIC,
                              base::WaitableEvent::InitialState::NOT_SIGNALED};
    bool success = false;
    if (!task_runner->PostTask(
            FROM_HERE, base::BindOnce(&PipeWriter::WriteIntoPipeOnIOThread,
                                      base::Unretained(this),
                                      std::move(message), &success, &event))) {
      return false;
    }
    event.Wait();
    return success;
  }

  void ClosePipe() {
    base::AutoLock lock(lock_);
    DCHECK_CALLED_ON_VALID_THREAD(session_thread_checker_);
    write_file_ = base::ScopedPlatformFile();
  }

  bool WriteBytesOnIOThread(const char* bytes, size_t size) {
    DCHECK_CALLED_ON_VALID_THREAD(io_thread_checker_);
    size_t total_written = 0;
    base::PlatformFile file = base::kInvalidPlatformFile;
    {
      base::AutoLock lock(lock_);
      file = write_file_.get();
    }
    while (total_written < size) {
      size_t length = size - total_written;
      if (length > kWritePacketSize) {
        length = kWritePacketSize;
      }
      DWORD bytes_written = 0;
      bool had_error =
          !WriteFile(file, bytes + total_written, static_cast<DWORD>(length),
                     &bytes_written, nullptr);
      if (had_error) {
        if (!shutting_down_.IsSet()) {
          VLOG(logging::LOGGING_ERROR) << "Could not write into pipe";
        }
        base::AutoLock lock(lock_);
        is_connected_ = false;
        return false;
      }
      total_written += bytes_written;
    }
    return true;
  }

  static void Shutdown(std::unique_ptr<PipeWriter> pipe_io) {
    if (!pipe_io) {
      return;
    }
    auto thread = std::move(pipe_io->thread_);
    pipe_io->shutting_down_.Set();
    pipe_io->ClosePipe();
    // Post self destruction on the custom thread if it's running.
    if (thread->task_runner()) {
      thread->task_runner()->DeleteSoon(FROM_HERE, std::move(pipe_io));
    } else {
      pipe_io.reset();
    }
  }

 private:
  base::Lock lock_;
  // Protected by |lock_|.
  bool is_connected_ = false;
  // Sequence where the instance was created.
  // The notifications about new data are emitted in this sequence.
  scoped_refptr<base::SequencedTaskRunner> owning_sequence_;
  base::AtomicFlag shutting_down_;
  THREAD_CHECKER(session_thread_checker_);
  THREAD_CHECKER(io_thread_checker_);
  base::WeakPtr<PipeConnectionWin> pipe_connection_;
  base::ScopedPlatformFile write_file_;
  // Thread is the last member, to be destroyed first.
  // This ensures that there will be no races in the destructor.
  std::unique_ptr<base::Thread> thread_;
};

// Keeps both pipe ends until Connect() hands them to the reader and writer,
// and routes the reader's "queue became non-empty" signal to
// SendNotification() through a weak pointer.
PipeConnectionWin::PipeConnectionWin(base::ScopedPlatformFile read_file,
                                     base::ScopedPlatformFile write_file)
    : read_file_(std::move(read_file)), write_file_(std::move(write_file)) {
  pipe_reader_ = std::make_unique<PipeReader>(weak_factory_.GetWeakPtr());
  pipe_writer_ = std::make_unique<PipeWriter>(weak_factory_.GetWeakPtr());
  auto on_new_message = base::BindRepeating(
      &PipeConnectionWin::SendNotification, weak_factory_.GetWeakPtr());
  pipe_reader_->SetNotificationCallback(std::move(on_new_message));
}

// Releases the reader and writer; Shutdown() is a no-op if it already ran.
PipeConnectionWin::~PipeConnectionWin() {
  this->Shutdown();
}

// Connected only while both ends exist and each reports itself connected.
bool PipeConnectionWin::IsConnected() {
  if (!pipe_reader_ || !pipe_writer_) {
    return false;
  }
  return pipe_reader_->IsConnected() && pipe_writer_->IsConnected();
}

// Starts the reader and writer threads on the first call; later calls only
// report IsConnected(). |url| is not used by the pipe transport.
bool PipeConnectionWin::Connect(const GURL& url) {
  if (connection_requested_) {
    return IsConnected();
  }
  connection_requested_ = true;
  const bool have_endpoints = pipe_reader_ && pipe_writer_;
  if (!have_endpoints) {
    return false;
  }
  // Both ends are started before either result is inspected.
  const bool reader_started = pipe_reader_->Start(std::move(read_file_));
  const bool writer_started = pipe_writer_->Start(std::move(write_file_));
  if (reader_started && writer_started) {
    return true;
  }
  Shutdown();
  return false;
}

// Sends |message| through the writer. Returns false when the writer is gone
// or disconnected, and tears the connection down when the reader has lost
// its connection.
bool PipeConnectionWin::Send(const std::string& message) {
  // No point writing once the remote reading end is closed.
  const bool can_write = pipe_writer_ && pipe_writer_->IsConnected();
  if (!can_write) {
    return false;
  }
  // A closed remote writing end means the whole exchange should end.
  const bool reader_lost = pipe_reader_ && !pipe_reader_->IsConnected();
  if (reader_lost) {
    Shutdown();
    return false;
  }
  return pipe_writer_->Write(message);
}

// Delegates to the reader; once the reader has been released (after
// Shutdown()) the connection reports kDisconnected.
SyncWebSocket::StatusCode PipeConnectionWin::ReceiveNextMessage(
    std::string* message,
    const Timeout& timeout) {
  return pipe_reader_ ? pipe_reader_->ReceiveNextMessage(message, timeout)
                      : SyncWebSocket::StatusCode::kDisconnected;
}

// True when the reader exists and has at least one queued message.
bool PipeConnectionWin::HasNextMessage() {
  return pipe_reader_ != nullptr && pipe_reader_->HasNextMessage();
}

// Stores the callback that SendNotification() runs when new messages are
// queued; replaces any previously stored callback.
void PipeConnectionWin::SetNotificationCallback(
    base::RepeatingClosure callback) {
  this->notify_ = std::move(callback);
}

// Releases the writer and then the reader exactly once. Repeated calls (from
// the destructor, Connect(), Send() or tasks posted by the IO threads) are
// ignored.
void PipeConnectionWin::Shutdown() {
  if (!shutting_down_) {
    shutting_down_ = true;
    PipeWriter::Shutdown(std::move(pipe_writer_));
    PipeReader::Shutdown(std::move(pipe_reader_));
  }
}

// True when neither the reader nor the writer is held, i.e. after Shutdown().
bool PipeConnectionWin::IsNull() const {
  return !(pipe_reader_ || pipe_writer_);
}

// Forwards the reader's "new message queued" signal to the registered
// callback, unless the connection is shutting down or no callback is set.
void PipeConnectionWin::SendNotification() {
  const bool should_notify = !shutting_down_ && notify_;
  if (should_notify) {
    notify_.Run();
  }
}