chromium/services/webnn/dml/command_queue.cc

// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "services/webnn/dml/command_queue.h"

#include "base/check_is_test.h"
#include "base/containers/span.h"
#include "base/logging.h"
#include "base/memory/ptr_util.h"
#include "base/numerics/safe_conversions.h"
#include "base/trace_event/trace_event.h"

namespace webnn::dml {

using Microsoft::WRL::ComPtr;

CommandQueue::PendingWorkDelegate::PendingWorkDelegate(
    std::deque<CommandQueue::QueuedObject> queued_objects,
    Microsoft::WRL::ComPtr<ID3D12CommandQueue> command_queue,
    uint64_t last_fence_value,
    Microsoft::WRL::ComPtr<ID3D12Fence> fence,
    base::win::ScopedHandle fence_event)
    : base::win::ObjectWatcher::Delegate(),
      queued_objects_(std::move(queued_objects)),
      command_queue_(std::move(command_queue)),
      last_fence_value_(last_fence_value),
      fence_(std::move(fence)),
      fence_event_(std::move(fence_event)) {
  CHECK(object_watcher_.StartWatchingOnce(fence_event_.get(), this));
}

CommandQueue::PendingWorkDelegate::~PendingWorkDelegate() = default;

//  Pending GPU work will be ensured done and signals the object watcher by
//  system eventually to delete PendingWorkDelegate, it's out of our control, we
//  have to trust that, if not, there will be memory leak.
void CommandQueue::PendingWorkDelegate::OnObjectSignaled(HANDLE object) {
  CHECK_EQ(object, fence_event_.get());

  CHECK_GE(fence_->GetCompletedValue(), last_fence_value_);
  delete this;
}

// static
void CommandQueue::ScheduleCleanupForPendingWork(
    std::deque<CommandQueue::QueuedObject> queued_objects,
    ComPtr<ID3D12CommandQueue> command_queue,
    uint64_t last_fence_value,
    ComPtr<ID3D12Fence> fence) {
  // Create a new fence_event that ensures it is only triggered by the
  // last_fence_value.
  base::win::ScopedHandle fence_event(
      CreateEvent(nullptr, /*bManualReset=*/FALSE,
                  /*bInitialState=*/FALSE, nullptr));
  CHECK(fence_event.is_valid());

  HRESULT hr = fence->SetEventOnCompletion(last_fence_value, fence_event.get());
  if (FAILED(hr)) {
    LOG(ERROR) << "[WebNN] Failed to set event on completion: "
               << logging::SystemErrorCodeToString(hr);
    return;
  }

  // It is by design we're not using std::unique_ptr because PendingWorkDelegate
  // deletes itself.
  new PendingWorkDelegate(std::move(queued_objects), std::move(command_queue),
                          last_fence_value, std::move(fence),
                          std::move(fence_event));
}

CommandQueue::CommandQueue(ComPtr<ID3D12CommandQueue> command_queue,
                           ComPtr<ID3D12Fence> fence)
    : base::win::ObjectWatcher::Delegate(),
      command_queue_(std::move(command_queue)),
      fence_(std::move(fence)) {
  fence_event_.Set(CreateEvent(nullptr, /*bManualReset=*/FALSE,
                               /*bInitialState=*/FALSE, nullptr));
  CHECK(fence_event_.is_valid());

  DETACH_FROM_SEQUENCE(sequence_checker_);
}

CommandQueue::~CommandQueue() {
  if (fence_->GetCompletedValue() >= last_fence_value_) {
    return;
  }

  ScheduleCleanupForPendingWork(std::move(queued_objects_),
                                std::move(command_queue_), last_fence_value_,
                                std::move(fence_));
}

// static
scoped_refptr<CommandQueue> CommandQueue::Create(ID3D12Device* d3d12_device) {
  ComPtr<ID3D12CommandQueue> command_queue;
  D3D12_COMMAND_QUEUE_DESC command_queue_desc = {};
  command_queue_desc.Type = D3D12_COMMAND_LIST_TYPE_COMPUTE;
  // To avoid TDRs on account of long compute work, create WebNN command queues
  // with the "disable timeout" flag. If other compute work on the system
  // competes with WebNN, the scheduler will TDR us anyways if it cannot preempt
  // our queue.
  command_queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_DISABLE_GPU_TIMEOUT;
  HRESULT hr = d3d12_device->CreateCommandQueue(&command_queue_desc,
                                                IID_PPV_ARGS(&command_queue));
  if (FAILED(hr)) {
    LOG(ERROR) << "[WebNN] Failed to create ID3D12CommandQueue: "
               << logging::SystemErrorCodeToString(hr);
    return nullptr;
  }

  ComPtr<ID3D12Fence> fence;
  hr =
      d3d12_device->CreateFence(0, D3D12_FENCE_FLAG_NONE, IID_PPV_ARGS(&fence));
  if (FAILED(hr)) {
    LOG(ERROR) << "[WebNN] Failed to create ID3D12Fence: "
               << logging::SystemErrorCodeToString(hr);
    return nullptr;
  }

  return base::WrapRefCounted(
      new CommandQueue(std::move(command_queue), std::move(fence)));
}

HRESULT CommandQueue::ExecuteCommandList(ID3D12CommandList* command_list) {
  return ExecuteCommandLists(base::span_from_ref(command_list));
}

HRESULT CommandQueue::ExecuteCommandLists(
    base::span<ID3D12CommandList*> command_lists) {
  TRACE_EVENT0("gpu", "dml::CommandQueue::ExecuteCommandLists");
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  command_queue_->ExecuteCommandLists(
      base::checked_cast<uint32_t>(command_lists.size()), command_lists.data());
  ++last_fence_value_;
  return command_queue_->Signal(fence_.Get(), last_fence_value_);
}

HRESULT CommandQueue::WaitSync() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  if (fence_->GetCompletedValue() >= last_fence_value_) {
    ReleaseCompletedResources();
    return S_OK;
  }

  HRESULT hr =
      fence_->SetEventOnCompletion(last_fence_value_, fence_event_.get());
  if (FAILED(hr)) {
    ReleaseCompletedResources();
    LOG(ERROR) << "Failed to set event on completion : "
               << logging::SystemErrorCodeToString(hr);
    return hr;
  }
  CHECK_EQ(WaitForSingleObject(fence_event_.get(), INFINITE), WAIT_OBJECT_0);
  ReleaseCompletedResources();
  return S_OK;
}

void CommandQueue::OnObjectSignaled(HANDLE object) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  TRACE_EVENT_NESTABLE_ASYNC_END0("gpu", "dml::CommandQueue::WaitAsync",
                                  TRACE_ID_LOCAL(this));
  CHECK_EQ(object, fence_event_.get());
  ReleaseCompletedResources();

  // If the owning document has shutdown by the time we get here, the only
  // remaining references to CommandQueue are inside of the queued_callbacks_
  // callbacks. Running them while continuing to dereference member variables
  // will free "this" and cause use-after-free  The caller of OnObjectSignaled,
  // ObjectWatcher::Signal, has a similar problem. To fix, we take a reference
  // to ourselves while we run the callbacks.
  scoped_refptr<CommandQueue> self = this;

  while (!queued_callbacks_.empty() &&
         queued_callbacks_.front().fence_value <= fence_->GetCompletedValue()) {
    std::move(queued_callbacks_.front().callback).Run();
    queued_callbacks_.pop_front();
  }
}

void CommandQueue::WaitAsync(base::OnceCallback<void(HRESULT hr)> callback) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  if (!object_watcher_.IsWatching()) {
    CHECK(object_watcher_.StartWatchingMultipleTimes(fence_event_.get(), this));
  }

  HRESULT hr =
      fence_->SetEventOnCompletion(last_fence_value_, fence_event_.get());
  if (FAILED(hr)) {
    ReleaseCompletedResources();
    LOG(ERROR) << "[WebNN] Failed to set event on completion: "
               << logging::SystemErrorCodeToString(hr);
    std::move(callback).Run(hr);
    return;
  }
  TRACE_EVENT_NESTABLE_ASYNC_BEGIN0("gpu", "dml::CommandQueue::WaitAsync",
                                    TRACE_ID_LOCAL(this));
  queued_callbacks_.push_back(
      {last_fence_value_, base::BindOnce(std::move(callback), S_OK)});
}

void CommandQueue::ReferenceUntilCompleted(ComPtr<IUnknown> object) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  queued_objects_.push_back({last_fence_value_, std::move(object)});
}

void CommandQueue::ReleaseCompletedResources() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  uint64_t completed_value = fence_->GetCompletedValue();
  while (!queued_objects_.empty() &&
         queued_objects_.front().fence_value <= completed_value) {
    queued_objects_.pop_front();
  }
}

uint64_t CommandQueue::GetCompletedValue() const {
  return fence_->GetCompletedValue();
}

uint64_t CommandQueue::GetLastFenceValue() const {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  return last_fence_value_;
}

const std::deque<CommandQueue::QueuedObject>&
CommandQueue::GetQueuedObjectsForTesting() const {
  CHECK_IS_TEST();
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  return queued_objects_;
}

CommandQueue::QueuedObject::QueuedObject(uint64_t fence_value,
                                         ComPtr<IUnknown> object)
    : fence_value(fence_value), object(std::move(object)) {}
CommandQueue::QueuedObject::QueuedObject(QueuedObject&& other) = default;
CommandQueue::QueuedObject& CommandQueue::QueuedObject::operator=(
    QueuedObject&& other) = default;
CommandQueue::QueuedObject::~QueuedObject() = default;

CommandQueue::QueuedCallback::QueuedCallback(uint64_t fence_value,
                                             base::OnceClosure callback)
    : fence_value(fence_value), callback(std::move(callback)) {}
CommandQueue::QueuedCallback::QueuedCallback(QueuedCallback&& other) = default;
CommandQueue::QueuedCallback& CommandQueue::QueuedCallback::operator=(
    QueuedCallback&& other) = default;
CommandQueue::QueuedCallback::~QueuedCallback() = default;

}  // namespace webnn::dml