chromium/services/device/hid/hid_connection_win.cc

// Copyright 2014 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "services/device/hid/hid_connection_win.h"

#include <cstring>
#include <utility>

#include "base/containers/contains.h"
#include "base/feature_list.h"
#include "base/files/file.h"
#include "base/functional/bind.h"
#include "base/memory/ref_counted_memory.h"
#include "base/not_fatal_until.h"
#include "base/numerics/safe_conversions.h"
#include "base/ranges/algorithm.h"
#include "base/win/object_watcher.h"
#include "components/device_event_log/device_event_log.h"
#include "services/device/public/cpp/device_features.h"
#include "services/device/public/cpp/hid/hid_report_type.h"

#define INITGUID

#include <windows.h>

#include <hidclass.h>

extern "C" {
#include <hidsdi.h>
}

#include <setupapi.h>
#include <winioctl.h>

namespace device {

namespace {

// Returns true when |handle| is neither null nor INVALID_HANDLE_VALUE, the two
// values this file treats as "no usable handle".
bool IsValidHandle(HANDLE handle) {
  if (handle == nullptr) {
    return false;
  }
  return handle != INVALID_HANDLE_VALUE;
}

}  // namespace

// Pairs a set of report IDs with the file handle used to transfer reports
// carrying those IDs. Both arguments are taken by value and moved into the
// members of the same name.
HidConnectionWin::HidDeviceEntry::HidDeviceEntry(
    base::flat_set<uint8_t> report_ids,
    base::win::ScopedHandle file_handle)
    : report_ids(std::move(report_ids)), file_handle(std::move(file_handle)) {}

// Nothing to release explicitly; the members clean up after themselves.
HidConnectionWin::HidDeviceEntry::~HidDeviceEntry() = default;

class PendingHidTransfer : public base::win::ObjectWatcher::Delegate {
 public:
  typedef base::OnceCallback<void(PendingHidTransfer*, bool)> Callback;

  PendingHidTransfer(scoped_refptr<base::RefCountedBytes> buffer,
                     Callback callback);
  PendingHidTransfer(PendingHidTransfer&) = delete;
  PendingHidTransfer& operator=(PendingHidTransfer&) = delete;
  ~PendingHidTransfer() override;

  void TakeResultFromWindowsAPI(BOOL result);

  OVERLAPPED* GetOverlapped() { return &overlapped_; }

  // Implements base::win::ObjectWatcher::Delegate.
  void OnObjectSignaled(HANDLE object) override;

 private:
  // The buffer isn't used by this object but it's important that a reference
  // to it is held until the transfer completes.
  scoped_refptr<base::RefCountedBytes> buffer_;
  Callback callback_;
  OVERLAPPED overlapped_;
  base::win::ScopedHandle event_;
  base::win::ObjectWatcher watcher_;
};

// Takes a reference to |buffer| (to keep it alive for the duration of the
// transfer) and stores |callback| for later completion reporting. Creates an
// unnamed, auto-reset, initially non-signaled event and wires it into a
// zero-initialized OVERLAPPED structure.
PendingHidTransfer::PendingHidTransfer(
    scoped_refptr<base::RefCountedBytes> buffer,
    PendingHidTransfer::Callback callback)
    // |buffer| is a by-value parameter that is not used again, so move it
    // rather than paying for an extra ref-count increment/decrement.
    : buffer_(std::move(buffer)),
      callback_(std::move(callback)),
      event_(CreateEvent(/*lpEventAttributes=*/nullptr, /*bManualReset=*/FALSE,
                         /*bInitialState=*/FALSE, /*lpName=*/nullptr)) {
  memset(&overlapped_, 0, sizeof(overlapped_));
  overlapped_.hEvent = event_.Get();
}

// If the transfer is torn down before its callback has been run, report the
// transfer as failed so that the callback is never silently dropped.
PendingHidTransfer::~PendingHidTransfer() {
  if (!callback_) {
    return;
  }
  std::move(callback_).Run(this, false);
}

// Interprets |result|, the return value of the Windows call that started this
// transfer:
//  - non-zero: the call succeeded immediately; run the callback with true.
//  - zero with ERROR_IO_PENDING: the request is in flight; wait for the event.
//  - zero with any other error: log it and run the callback with false.
void PendingHidTransfer::TakeResultFromWindowsAPI(BOOL result) {
  if (result) {
    std::move(callback_).Run(this, true);
    return;
  }

  if (GetLastError() != ERROR_IO_PENDING) {
    HID_PLOG(DEBUG) << "HID transfer failed";
    std::move(callback_).Run(this, false);
    return;
  }

  // Completion will be delivered through OnObjectSignaled().
  watcher_.StartWatchingOnce(event_.Get(), this);
}

// Called by |watcher_| once the watched event is signaled. The handle argument
// is not needed; the callback is simply told that the transfer was signaled.
void PendingHidTransfer::OnObjectSignaled(HANDLE /*event_handle*/) {
  std::move(callback_).Run(this, true);
}

// static
// Constructs a HidConnectionWin from the given device info and file handles
// and immediately queues an input report read on each handle before handing
// the connection back to the caller.
scoped_refptr<HidConnection> HidConnectionWin::Create(
    scoped_refptr<HidDeviceInfo> device_info,
    std::vector<std::unique_ptr<HidDeviceEntry>> file_handles,
    bool allow_protected_reports,
    bool allow_fido_reports) {
  scoped_refptr<HidConnectionWin> win_connection(
      new HidConnectionWin(std::move(device_info), std::move(file_handles),
                           allow_protected_reports, allow_fido_reports));

  win_connection->SetUpInitialReads();

  // The return type differs from the local's type, hence the explicit move.
  return std::move(win_connection);
}

// Forwards |device_info| and the two permission flags to the HidConnection
// base class and takes ownership of |file_handles|. No I/O is started here;
// see SetUpInitialReads().
HidConnectionWin::HidConnectionWin(
    scoped_refptr<HidDeviceInfo> device_info,
    std::vector<std::unique_ptr<HidDeviceEntry>> file_handles,
    bool allow_protected_reports,
    bool allow_fido_reports)
    : HidConnection(std::move(device_info),
                    allow_protected_reports,
                    allow_fido_reports),
      file_handles_(std::move(file_handles)) {}

// By the time the connection is destroyed both containers are expected to be
// empty; PlatformClose() is the method in this file that clears them.
HidConnectionWin::~HidConnectionWin() {
  DCHECK(file_handles_.empty());
  DCHECK(transfers_.empty());
}

// Shuts the connection down: cancels I/O on and closes every file handle, then
// drops all handle entries and all outstanding transfers.
void HidConnectionWin::PlatformClose() {
  for (const auto& device_entry : file_handles_) {
    auto& handle = device_entry->file_handle;
    CancelIo(handle.Get());
    handle.Close();
  }
  file_handles_.clear();

  // Destroying a transfer whose callback has not run yet runs that callback
  // with a failure result (see ~PendingHidTransfer).
  transfers_.clear();
}

// Sends an output report. |buffer| starts with the report ID (or zero) and is
// padded in place to the size Windows expects. |callback| receives false if no
// valid handle serves the report ID; otherwise it is run from
// OnWriteComplete() once the overlapped write finishes.
void HidConnectionWin::PlatformWrite(
    scoped_refptr<base::RefCountedBytes> buffer,
    WriteCallback callback) {
  // Windows requires every output report to be prefixed with a report ID (or
  // zero when IDs are unsupported) and to be exactly as long as the largest
  // output report of this collection, so grow the buffer to that length.
  size_t expected_size = device_info()->max_output_report_size() + 1;
  DCHECK(buffer->size() <= expected_size);
  buffer->as_vector().resize(expected_size);

  const uint8_t report_id = buffer->as_vector()[0];
  HANDLE file_handle = GetHandleForReportId(report_id);
  if (!IsValidHandle(file_handle)) {
    HID_LOG(DEBUG) << "HID write failed due to invalid handle.";
    std::move(callback).Run(false);
    return;
  }

  // Register the transfer before starting the I/O so that the completion
  // callback can find and unlink it.
  auto new_transfer = std::make_unique<PendingHidTransfer>(
      buffer, base::BindOnce(&HidConnectionWin::OnWriteComplete, this,
                             file_handle, std::move(callback)));
  PendingHidTransfer* pending = new_transfer.get();
  transfers_.push_back(std::move(new_transfer));

  const BOOL write_result =
      WriteFile(file_handle, buffer->data(), static_cast<DWORD>(buffer->size()),
                nullptr, pending->GetOverlapped());
  pending->TakeResultFromWindowsAPI(write_result);
}

// Requests the feature report identified by |report_id|. |callback| receives
// (false, nullptr, 0) if no valid handle serves the report ID; otherwise it is
// run from OnReadFeatureComplete() once the overlapped request finishes.
void HidConnectionWin::PlatformGetFeatureReport(uint8_t report_id,
                                                ReadCallback callback) {
  // Validate the handle first so that no buffer is allocated for a request
  // that cannot be issued.
  HANDLE file_handle = GetHandleForReportId(report_id);
  if (!IsValidHandle(file_handle)) {
    HID_LOG(DEBUG) << "HID read failed due to invalid handle.";
    std::move(callback).Run(false, nullptr, 0);
    return;
  }

  // The first byte of the destination buffer is the report ID being requested.
  auto buffer = base::MakeRefCounted<base::RefCountedBytes>(
      device_info()->max_feature_report_size() + 1);
  buffer->as_vector()[0] = report_id;

  // The transfer holds one reference to |buffer| to keep it alive while the
  // request is in flight; the bound callback holds another so that the data
  // can be handed to |callback| on completion.
  transfers_.push_back(std::make_unique<PendingHidTransfer>(
      buffer, base::BindOnce(&HidConnectionWin::OnReadFeatureComplete, this,
                             file_handle, buffer, std::move(callback))));
  transfers_.back()->TakeResultFromWindowsAPI(DeviceIoControl(
      file_handle, IOCTL_HID_GET_FEATURE, nullptr, 0,
      buffer->as_vector().data(),
      static_cast<DWORD>(buffer->as_vector().size()), nullptr,
      transfers_.back()->GetOverlapped()));
}

// Sends a feature report. The first byte of |buffer| selects the handle to use.
// |callback| receives false if no valid handle serves that report ID;
// otherwise it is run from OnWriteComplete() once the request finishes.
void HidConnectionWin::PlatformSendFeatureReport(
    scoped_refptr<base::RefCountedBytes> buffer,
    WriteCallback callback) {
  const uint8_t report_id = buffer->as_vector()[0];
  HANDLE file_handle = GetHandleForReportId(report_id);
  if (!IsValidHandle(file_handle)) {
    HID_LOG(DEBUG) << "HID write failed due to invalid handle.";
    std::move(callback).Run(false);
    return;
  }

  // Windows expects the report to start with the report ID (if supported) or
  // zero, which is how |buffer| is already laid out, so it is sent unchanged.
  auto new_transfer = std::make_unique<PendingHidTransfer>(
      buffer, base::BindOnce(&HidConnectionWin::OnWriteComplete, this,
                             file_handle, std::move(callback)));
  PendingHidTransfer* pending = new_transfer.get();
  transfers_.push_back(std::move(new_transfer));

  const BOOL ioctl_result = DeviceIoControl(
      file_handle, IOCTL_HID_SET_FEATURE, buffer->as_vector().data(),
      static_cast<DWORD>(buffer->as_vector().size()), nullptr, 0, nullptr,
      pending->GetOverlapped());
  pending->TakeResultFromWindowsAPI(ioctl_result);
}

void HidConnectionWin::SetUpInitialReads() {
  for (const auto& entry : file_handles_)
    ReadNextInputReportOnHandle(entry->file_handle.Get());
}

// Starts an overlapped read of one input report from |file_handle|. The result
// is delivered to OnReadInputReport().
void HidConnectionWin::ReadNextInputReportOnHandle(HANDLE file_handle) {
  // One extra byte is reserved because Windows always places the report ID
  // (zero when report IDs are not in use) at the front of the buffer.
  auto buffer = base::MakeRefCounted<base::RefCountedBytes>(
      device_info()->max_input_report_size() + 1);

  auto new_transfer = std::make_unique<PendingHidTransfer>(
      buffer, base::BindOnce(&HidConnectionWin::OnReadInputReport, this,
                             file_handle, buffer));
  PendingHidTransfer* pending = new_transfer.get();
  transfers_.push_back(std::move(new_transfer));

  const BOOL read_result =
      ReadFile(file_handle, buffer->as_vector().data(),
               static_cast<DWORD>(buffer->as_vector().size()), nullptr,
               pending->GetOverlapped());
  pending->TakeResultFromWindowsAPI(read_result);
}

// Completion handler for the read started by ReadNextInputReportOnHandle().
// On success the report is forwarded to ProcessInputReport() unless its report
// ID is protected, and another read is queued on the same handle. On any
// failure the function returns without queuing another read.
void HidConnectionWin::OnReadInputReport(
    HANDLE file_handle,
    scoped_refptr<base::RefCountedBytes> buffer,
    PendingHidTransfer* transfer_raw,
    bool signaled) {
  if (!signaled) {
    // The transfer is deliberately left in |transfers_| on this path.
    HID_LOG(DEBUG) << "HID read failed.";
    return;
  }

  // Take ownership of the transfer; it is destroyed when this function returns.
  std::unique_ptr<PendingHidTransfer> transfer = UnlinkTransfer(transfer_raw);

  DWORD bytes_transferred;
  const BOOL have_result = GetOverlappedResult(
      file_handle, transfer->GetOverlapped(), &bytes_transferred, FALSE);
  if (!have_result) {
    HID_PLOG(DEBUG) << "HID read failed";
    return;
  }

  if (bytes_transferred == 0) {
    HID_LOG(DEBUG) << "HID read too short.";
    return;
  }

  const uint8_t report_id = buffer->as_vector()[0];
  const bool is_protected = IsReportProtected(report_id, HidReportType::kInput);
  if (!is_protected) {
    // Hold a reference to |this| to prevent a callback executed by
    // ProcessInputReport from freeing this object.
    scoped_refptr<HidConnection> self(this);
    ProcessInputReport(buffer, bytes_transferred);
  }

  // NOTE(review): if a callback run by ProcessInputReport() can close the
  // connection, |file_handle| has already been closed by PlatformClose() at
  // this point - confirm that this case is handled by the caller.
  ReadNextInputReportOnHandle(file_handle);
}

// Completion handler for the request started by PlatformGetFeatureReport().
// Runs |callback| with (false, nullptr, 0) on failure, or with
// (true, buffer, byte count) on success.
void HidConnectionWin::OnReadFeatureComplete(
    HANDLE file_handle,
    scoped_refptr<base::RefCountedBytes> buffer,
    ReadCallback callback,
    PendingHidTransfer* transfer_raw,
    bool signaled) {
  if (!signaled) {
    // The transfer is deliberately left in |transfers_| on this path.
    HID_LOG(DEBUG) << "HID read failed.";
    std::move(callback).Run(false, nullptr, 0);
    return;
  }

  // Take ownership of the transfer; it is destroyed when this function returns.
  std::unique_ptr<PendingHidTransfer> transfer = UnlinkTransfer(transfer_raw);

  DWORD bytes_transferred;
  const BOOL have_result = GetOverlappedResult(
      file_handle, transfer->GetOverlapped(), &bytes_transferred, FALSE);
  if (!have_result) {
    HID_PLOG(DEBUG) << "HID read failed";
    std::move(callback).Run(false, nullptr, 0);
    return;
  }

  if (base::FeatureList::IsEnabled(features::kHidGetFeatureReportFix) &&
      buffer->size() > 0 && buffer->data()[0] == 0) {
    // A leading zero byte means the device does not use numbered reports.
    // That byte is not included in `bytes_transferred`, so hand back a copy of
    // the buffer with it stripped.
    auto without_report_id = base::MakeRefCounted<base::RefCountedBytes>(
        base::span(*buffer).subspan(/*offset=*/1));
    buffer = std::move(without_report_id);
  }

  DCHECK_LE(bytes_transferred, buffer->size());
  std::move(callback).Run(true, buffer, bytes_transferred);
}

// Completion handler shared by PlatformWrite() and
// PlatformSendFeatureReport(). Runs |callback| with true only if the transfer
// was signaled and GetOverlappedResult() reports success.
void HidConnectionWin::OnWriteComplete(HANDLE file_handle,
                                       WriteCallback callback,
                                       PendingHidTransfer* transfer_raw,
                                       bool signaled) {
  if (!signaled) {
    // The transfer is deliberately left in |transfers_| on this path.
    HID_LOG(DEBUG) << "HID write failed.";
    std::move(callback).Run(false);
    return;
  }

  // Take ownership of the transfer; it is destroyed when this function returns.
  std::unique_ptr<PendingHidTransfer> transfer = UnlinkTransfer(transfer_raw);

  // Only the success/failure status is needed; the byte count is ignored.
  DWORD bytes_transferred;
  const BOOL have_result = GetOverlappedResult(
      file_handle, transfer->GetOverlapped(), &bytes_transferred, FALSE);
  if (!have_result) {
    HID_PLOG(DEBUG) << "HID write failed";
    std::move(callback).Run(false);
    return;
  }

  std::move(callback).Run(true);
}

// Removes |transfer| from |transfers_| and returns ownership of it to the
// caller. |transfer| is expected to be present in the container.
std::unique_ptr<PendingHidTransfer> HidConnectionWin::UnlinkTransfer(
    PendingHidTransfer* transfer) {
  auto it = base::ranges::find_if(
      transfers_, [transfer](const auto& owned_transfer) {
        return owned_transfer.get() == transfer;
      });
  CHECK(it != transfers_.end(), base::NotFatalUntil::M130);

  // Move the owning pointer out before erasing its (now empty) slot.
  auto saved_transfer = std::move(*it);
  transfers_.erase(it);
  return saved_transfer;
}

// Returns the file handle of the first entry whose report ID set contains
// |report_id|, or nullptr if no entry matches.
HANDLE HidConnectionWin::GetHandleForReportId(uint8_t report_id) const {
  const auto match =
      base::ranges::find_if(file_handles_, [report_id](const auto& entry) {
        return base::Contains(entry->report_ids, report_id);
      });
  if (match == file_handles_.end()) {
    return nullptr;
  }
  return (*match)->file_handle.Get();
}

}  // namespace device