chromium/services/webnn/dml/graph_impl_dml.h

// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef SERVICES_WEBNN_DML_GRAPH_IMPL_DML_H_
#define SERVICES_WEBNN_DML_GRAPH_IMPL_DML_H_

#include <map>
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <unordered_map>
#include <vector>

#include "base/containers/flat_map.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/ref_counted.h"
#include "base/memory/scoped_refptr.h"
#include "base/memory/weak_ptr.h"
#include "base/types/expected.h"
#include "services/webnn/public/mojom/webnn_context_provider.mojom.h"
#include "services/webnn/public/mojom/webnn_graph.mojom.h"
#include "services/webnn/webnn_context_impl.h"
#include "services/webnn/webnn_graph_impl.h"
#include "third_party/microsoft_dxheaders/include/directml.h"

// Windows SDK headers should be included after DirectX headers.
#include <wrl.h>

namespace webnn::dml {

class Adapter;
class CommandRecorder;
class ContextImplDml;
class GraphBuilderDml;

// Byte-length bookkeeping for a group of buffers laid out with the required
// alignment: the aligned total size of the group, together with the aligned
// D3D12_RANGE occupied by each individual buffer, looked up by `Key`.
template <class Key>
struct AlignedByteLength {
  // Total byte length of all buffers in the group, with alignment applied.
  size_t total_byte_length{0};
  // Aligned D3D12_RANGE of every buffer in the group, keyed by `Key`.
  std::map<Key, D3D12_RANGE> key_to_d3d12_range_map;
};

// GraphImplDml inherits WebNNGraphImpl to represent a DML graph implementation.
// It is mainly responsible for building and compiling a DML graph from
// mojom::GraphInfo via GraphBuilderDml, then initializing and executing the
// graph represented by an IDMLCompiledOperator.
class GraphImplDml final : public WebNNGraphImpl {
 public:
  // It records the graph's buffer binding info to create the buffer binding
  // (DML_BUFFER_BINDING) for the graph execution.
  struct GraphBufferBindingInfo {
    GraphBufferBindingInfo();
    ~GraphBufferBindingInfo();

    GraphBufferBindingInfo(const GraphBufferBindingInfo&);
    GraphBufferBindingInfo& operator=(const GraphBufferBindingInfo&);

    GraphBufferBindingInfo(GraphBufferBindingInfo&&);
    GraphBufferBindingInfo& operator=(GraphBufferBindingInfo&&);

    // The count of input buffer bindings for the graph execution, which
    // should equal the total number of constants and inputs.
    size_t input_buffer_binding_count = 0;
    // The map is used to bind input buffers for the graph execution in
    // order.
    // The index is the DML_INPUT_GRAPH_EDGE_DESC::GraphInputIndex when
    // creating the DML_GRAPH_DESC.
    std::unordered_map<std::string, uint32_t> graph_input_name_to_index_map;
    // The map is used to bind output buffers for the graph execution in
    // order.
    // The index is the DML_OUTPUT_GRAPH_EDGE_DESC::GraphOutputIndex when
    // creating the DML_GRAPH_DESC.
    std::unordered_map<std::string, uint32_t> graph_output_name_to_index_map;
  };

  // Builds the DML graph described by `graph_info` with `graph_builder`.
  // NOTE(review): `constant_id_to_input_index_map` and
  // `graph_buffer_binding_info` are taken by non-const reference and appear
  // to be filled in as out-parameters; confirm against the implementation.
  static base::expected<void, mojom::ErrorPtr> CreateAndBuildInternal(
      const ContextProperties& context_properties,
      scoped_refptr<Adapter> adapter,
      mojom::GraphInfoPtr& graph_info,
      GraphBuilderDml& graph_builder,
      std::unordered_map<uint64_t, uint32_t>& constant_id_to_input_index_map,
      GraphBufferBindingInfo& graph_buffer_binding_info);

  // This method builds and compiles a DML graph from mojom::GraphInfo via
  // GraphBuilderDml, and then calls the CommandRecorder::InitializeOperator
  // method to initialize the DML graph. Next, it calls CommandQueue::WaitAsync
  // method to wait for the initialization work to be completed on GPU. The
  // GraphImplDml instance is only created after the initialization has
  // completed (see OnInitializationComplete and CreateWebNNGraphImpl below).
  static void CreateAndBuild(scoped_refptr<Adapter> adapter,
                             base::WeakPtr<ContextImplDml> context,
                             mojom::GraphInfoPtr graph_info,
                             ComputeResourceInfo compute_resource_info,
                             WebNNContextImpl::CreateGraphImplCallback callback,
                             bool pass_dml_execution_disable_meta_commands);

  GraphImplDml(const GraphImplDml&) = delete;
  GraphImplDml& operator=(const GraphImplDml&) = delete;
  ~GraphImplDml() override;

 private:
  // Contains the persistent resource for the graph initialization and execution
  // if the graph needs it. The resource should be kept alive until the GPU has
  // completed the execution.
  class PersistentResource final
      : public base::RefCountedThreadSafe<PersistentResource> {
   public:
    static scoped_refptr<PersistentResource> Create(
        uint64_t persistent_buffer_byte_length,
        Microsoft::WRL::ComPtr<ID3D12Resource> persistent_buffer);

    PersistentResource(const PersistentResource&) = delete;
    PersistentResource& operator=(const PersistentResource&) = delete;

    DML_BINDING_DESC persistent_buffer_binding_desc() const {
      return persistent_buffer_binding_desc_;
    }

   private:
    friend class base::RefCountedThreadSafe<PersistentResource>;
    PersistentResource(
        uint64_t persistent_buffer_byte_length,
        Microsoft::WRL::ComPtr<ID3D12Resource> persistent_buffer);
    ~PersistentResource();

    Microsoft::WRL::ComPtr<ID3D12Resource> persistent_buffer_;
    // Value-initialized so these plain C structs never hold indeterminate
    // values before the constructor fills them in.
    DML_BUFFER_BINDING persistent_buffer_binding_{};
    DML_BINDING_DESC persistent_buffer_binding_desc_{};
  };

  // Contains the GPU descriptor heap and temporary buffer for graph
  // execution. These resources should be kept alive until the GPU has completed
  // the execution. After that, the resources could be reused for next graph
  // execution or be released.
  struct GraphResources {
    GraphResources(Microsoft::WRL::ComPtr<ID3D12DescriptorHeap> descriptor_heap,
                   uint64_t temporary_buffer_byte_length,
                   Microsoft::WRL::ComPtr<ID3D12Resource> temporary_buffer);
    ~GraphResources();
    GraphResources(const GraphResources&) = delete;
    GraphResources& operator=(const GraphResources&) = delete;
    GraphResources(GraphResources&&) = delete;
    GraphResources& operator=(GraphResources&&) = delete;

    Microsoft::WRL::ComPtr<ID3D12DescriptorHeap> descriptor_heap;

    // Temporary buffers can be reused between DML dispatches. However,
    // they cannot be used between multiple queues at a time.
    // https://learn.microsoft.com/en-us/windows/ai/directml/dml-binding
    Microsoft::WRL::ComPtr<ID3D12Resource> temporary_buffer;
    std::optional<DML_BUFFER_BINDING> temporary_buffer_binding;
    std::optional<DML_BINDING_DESC> temporary_buffer_binding_desc;
  };

  static base::expected<std::unique_ptr<GraphResources>, HRESULT>
  AllocateGraphResources(Adapter* adapter,
                         IDMLCompiledOperator* compiled_operator);

  // Contains the GPU resources for a graph execution, including the descriptor
  // heap, upload buffer, input buffer, output buffer, read-back buffer and
  // temporary buffer if the graph needs. These resources should be kept alive
  // until the GPU has completed the execution. After that, the resources could
  // be reused for next graph execution or be released.
  struct ComputeResources {
    ComputeResources(
        Microsoft::WRL::ComPtr<ID3D12DescriptorHeap> descriptor_heap,
        AlignedByteLength<std::string> input_aligned_byte_length,
        Microsoft::WRL::ComPtr<ID3D12Resource> upload_buffer,
        Microsoft::WRL::ComPtr<ID3D12Resource> input_buffer,
        AlignedByteLength<std::string> output_aligned_byte_length,
        Microsoft::WRL::ComPtr<ID3D12Resource> output_buffer,
        Microsoft::WRL::ComPtr<ID3D12Resource> readback_buffer,
        uint64_t temporary_buffer_byte_length,
        Microsoft::WRL::ComPtr<ID3D12Resource> temporary_buffer,
        std::unique_ptr<CommandRecorder> command_recorder);
    ~ComputeResources();
    ComputeResources(const ComputeResources&) = delete;
    ComputeResources& operator=(const ComputeResources&) = delete;
    ComputeResources(ComputeResources&&) = delete;
    ComputeResources& operator=(ComputeResources&&) = delete;

    AlignedByteLength<std::string> input_aligned_byte_length;
    Microsoft::WRL::ComPtr<ID3D12Resource> upload_buffer;
    Microsoft::WRL::ComPtr<ID3D12Resource> input_buffer;

    AlignedByteLength<std::string> output_aligned_byte_length;
    Microsoft::WRL::ComPtr<ID3D12Resource> output_buffer;
    Microsoft::WRL::ComPtr<ID3D12Resource> readback_buffer;

    GraphResources graph_resources;
    std::unique_ptr<CommandRecorder> command_recorder;
  };

  static base::expected<std::unique_ptr<ComputeResources>, HRESULT>
  AllocateComputeResources(Adapter* adapter,
                           IDMLCompiledOperator* compiled_operator,
                           const ComputeResourceInfo& compute_resource_info);

  // `ExecuteAndWaitSyncOnBackgroundThread` accepts a `CommandRecorder` which
  // keeps a reference to the `init_command_queue_for_npu_` in `Adapter`. The
  // method submits the command list for execution and synchronously waits for
  // initialization to complete. Since `ID3D12CommandQueue::ExecuteCommandLists`
  // called in this method may take a long time on some adapters, e.g. NPU, this
  // method should run on non-gpuMain threads to avoid blocking the compositor.
  //
  // CommandQueue is not a thread-safe object and should only be used by one
  // task runner at a time to avoid race conditions with its member variables.
  static HRESULT ExecuteAndWaitSyncOnBackgroundThread(
      std::unique_ptr<CommandRecorder> init_command_recorder_for_npu);

  // This method mainly records the graph execution onto the command list, binds
  // all required resources and closes the command list.
  //
  // This method is called firstly after the graph initialization has been
  // completed to prepare for the first graph execution. For following graph
  // executions, the method only needs to be called if we need to record
  // commands and bind resources again. Thus, it avoids re-calling the
  // `IDMLCommandRecorder::RecordDispatch` and
  // `ID3D12GraphicsCommandList::Close` methods which may be time-consuming for
  // some devices during the first execution and following executions of a graph
  // if not needed.
  static HRESULT RecordGraphExecution(
      Adapter* adapter,
      IDMLCompiledOperator* compiled_operator,
      const ComputeResources* compute_resources,
      const PersistentResource* persistent_resource,
      const GraphBufferBindingInfo& graph_buffer_binding_info);

  // `RecordGraphExecutionOnBackgroundThread` calls the `RecordGraphExecution`
  // method above, but runs on a background thread. The `compute_resources` is
  // passed to this method and will be returned to the caller after the graph
  // execution is recorded. Since `IDMLCommandRecorder::RecordDispatch` and
  // `ID3D12GraphicsCommandList::Close` called in this method may take a long
  // time on some adapters, e.g. NPU, this method should run on non-gpuMain
  // threads to avoid blocking the compositor.
  static base::expected<std::unique_ptr<GraphImplDml::ComputeResources>,
                        HRESULT>
  RecordGraphExecutionOnBackgroundThread(
      scoped_refptr<Adapter> adapter,
      scoped_refptr<PersistentResource> persistent_resource,
      Microsoft::WRL::ComPtr<IDMLCompiledOperator> compiled_operator,
      std::unique_ptr<ComputeResources> compute_resources,
      GraphBufferBindingInfo graph_buffer_binding_info);

  // After the `RecordGraphExecutionOnBackgroundThread` task or
  // `RecordGraphExecution` task is completed, the `CreateWebNNGraphImpl`
  // method runs back on the gpuMain thread to create the `GraphImplDml`
  // instance.
  static void CreateWebNNGraphImpl(
      scoped_refptr<Adapter> adapter,
      base::WeakPtr<ContextImplDml> context,
      scoped_refptr<PersistentResource> persistent_resource,
      Microsoft::WRL::ComPtr<IDMLCompiledOperator> compiled_operator,
      ComputeResourceInfo compute_resource_info,
      GraphBufferBindingInfo graph_buffer_binding_info,
      WebNNContextImpl::CreateGraphImplCallback callback,
      base::expected<std::unique_ptr<ComputeResources>, HRESULT>
          recording_result);

  // After the `RecordGraphExecutionOnBackgroundThread` task or
  // `RecordGraphExecution` task is completed, the `ExecuteAndWaitAsync`
  // method runs back on the gpuMain thread to copy the input data and submit
  // the command list for execution.
  void ExecuteAndWaitAsync(
      scoped_refptr<Adapter> adapter,
      base::flat_map<std::string, mojo_base::BigBuffer> named_inputs,
      mojom::WebNNGraph::ComputeCallback callback,
      base::expected<std::unique_ptr<ComputeResources>, HRESULT>
          recording_result);

  GraphImplDml(scoped_refptr<Adapter> adapter,
               ContextImplDml* context,
               std::unique_ptr<CommandRecorder> command_recorder,
               scoped_refptr<PersistentResource> persistent_resource,
               Microsoft::WRL::ComPtr<IDMLCompiledOperator> compiled_operator,
               ComputeResourceInfo compute_resource_info,
               GraphBufferBindingInfo graph_buffer_binding_info,
               std::unique_ptr<ComputeResources> compute_resources);

  // The method compiles all DML operators into an IDMLCompiledOperator
  // which can be dispatched to GPU. Since IDMLDevice1::CompileGraph called in
  // this method may take long time to compile shaders (if not cached before),
  // this method should run on a background thread rather than the current GPU
  // main thread to avoid blocking.
  static base::expected<Microsoft::WRL::ComPtr<IDMLCompiledOperator>, HRESULT>
  CompileOnBackgroundThread(GraphBuilderDml graph_builder,
                            bool pass_dml_execution_disable_meta_commands);

  // After the CompileOnBackgroundThread task is completed on a background
  // thread, the OnCompilationComplete method should run back on the GPU main
  // thread since graph initialization commands are submitted to GPU. Notice
  // that `compilation_result` holds an error HRESULT instead of a compiled
  // operator if the graph compilation fails.
  //
  // The `constant_id_to_input_index_map` is used to bind constant buffers
  // for the graph initialization in order. The constant id is the key for
  // `id_to_operand_map` of `mojom::GraphInfo` interface, the input index is the
  // DML_INPUT_GRAPH_EDGE_DESC::GraphInputIndex when creating the
  // DML_GRAPH_DESC. DirectML graph treats both input tensors and constant
  // tensors to be graph inputs. The difference is the data of the constant
  // tensor is owned by DirectML and should be uploaded during the graph
  // initialization, while the data of the input tensor is uploaded for every
  // graph execution.
  static void OnCompilationComplete(
      scoped_refptr<Adapter> adapter,
      base::WeakPtr<ContextImplDml> context,
      WebNNContextImpl::CreateGraphImplCallback callback,
      base::flat_map<uint64_t, mojo_base::BigBuffer> constant_id_to_buffer_map,
      std::unordered_map<uint64_t, uint32_t> constant_id_to_input_index_map,
      GraphBufferBindingInfo graph_buffer_binding_info,
      ComputeResourceInfo compute_resource_info,
      base::expected<Microsoft::WRL::ComPtr<IDMLCompiledOperator>, HRESULT>
          compilation_result);

  // This method records the graph execution (see `RecordGraphExecution` and
  // `RecordGraphExecutionOnBackgroundThread`). Once recording has completed,
  // `CreateWebNNGraphImpl` creates the GraphImplDml instance, and `callback`
  // is run to send the pending remote of the mojom::WebNNGraph receiver to the
  // renderer process.
  // Notice that the `persistent_resource` could be nullptr which means
  // it isn't required by the graph.
  static void OnInitializationComplete(
      scoped_refptr<Adapter> adapter,
      base::WeakPtr<ContextImplDml> context,
      scoped_refptr<PersistentResource> persistent_resource,
      Microsoft::WRL::ComPtr<IDMLCompiledOperator> compiled_operator,
      ComputeResourceInfo compute_resource_info,
      GraphBufferBindingInfo graph_buffer_binding_info,
      WebNNContextImpl::CreateGraphImplCallback callback,
      HRESULT hr);

  // After the computation is completed, copy the output data from GPU readback
  // buffer and then run the callback to send it to the renderer process.
  //
  // The D3D12 ranges in `compute_resources->output_aligned_byte_length` are
  // the ranges in the readback output buffer and the default output buffer,
  // which indicate the aligned offset for each output of the graph.
  void OnComputationComplete(
      mojom::WebNNGraph::ComputeCallback callback,
      std::unique_ptr<ComputeResources> compute_resources,
      HRESULT hr);

  // After the dispatch is completed, recycle the graph resources for another
  // dispatch.
  void OnDispatchComplete(std::unique_ptr<GraphResources> graph_resources,
                          HRESULT hr);

  // If GraphImplDml::ComputeImpl fails, release the `compute_resources_`,
  // report the error message via `callback` and let `context_` handle the
  // error.
  void HandleComputationFailure(const std::string& error_message,
                                HRESULT hr,
                                mojom::WebNNGraph::ComputeCallback callback);

  // If GraphImplDml::DispatchImpl fails, report and log an error message and
  // release the command recorder since it may not have been closed normally by
  // CommandRecorder::CloseAndExecute.
  void HandleDispatchFailure(std::string_view error_message, HRESULT hr);

  // Execute the compiled platform graph asynchronously. The `named_inputs` were
  // validated in the base class so they can be used for computation directly;
  // the result of the execution is returned to the renderer process via
  // `callback`.
  void ComputeImpl(
      base::flat_map<std::string, mojo_base::BigBuffer> named_inputs,
      mojom::WebNNGraph::ComputeCallback callback) override;

  // Dispatch the compiled graph with the WebNN buffers in `named_inputs` and
  // `named_outputs` bound as its inputs and outputs. Failures are handled by
  // HandleDispatchFailure.
  void DispatchImpl(
      const base::flat_map<std::string_view, WebNNBufferImpl*>& named_inputs,
      const base::flat_map<std::string_view, WebNNBufferImpl*>& named_outputs)
      override;

  // The persistent resource is allocated after the compilation work is
  // completed for the graph initialization and will be used for the following
  // graph executions. It could be nullptr which means it isn't required by the
  // graph and won't need to be bound for graph executions.
  scoped_refptr<PersistentResource> persistent_resource_;

  // Adapter used to create the built graph.
  scoped_refptr<Adapter> adapter_;

  // ContextImplDml owns this object.
  raw_ptr<ContextImplDml> context_;

  // The command_recorder is created for the graph execution and recycled
  // after graph execution has completed. It avoids the resource allocation
  // overhead for the first execution and following executions when it is
  // available. A graph execution takes its ownership during the execution and
  // returns the ownership once the GPU has completed the execution. If it is
  // unavailable, e.g., being taken by previous uncompleted execution, a graph
  // execution will create a new one and release it after the execution is
  // done.
  std::unique_ptr<CommandRecorder> command_recorder_;
  // IDMLCompiledOperator represents a compiled and initialized DML graph to be
  // executed on GPU.
  Microsoft::WRL::ComPtr<IDMLCompiledOperator> compiled_operator_;
  GraphBufferBindingInfo graph_buffer_binding_info_;

  // Compute resources are allocated upon graph execution and
  // recycled after graph execution has completed. It avoids the resource
  // allocation overhead for the following executions when
  // it is available. A graph execution takes its ownership during the execution
  // and returns the ownership once the GPU has completed the execution. If it
  // is unavailable, e.g., being taken by previous uncompleted execution, a
  // graph execution will allocate a new one and release it after the execution
  // is done.
  std::unique_ptr<ComputeResources> compute_resources_;

  // Graph resources are allocated after graph initialization and
  // recycled after graph execution has completed. It avoids the resource
  // allocation overhead for the first execution and following executions when
  // it is available. A graph execution takes its ownership during the execution
  // and returns the ownership once the GPU has completed the execution. If it
  // is unavailable, e.g., being taken by previous uncompleted execution, a
  // graph execution will allocate a new one and release it after the execution
  // is done.
  std::unique_ptr<GraphResources> graph_resources_;

  // The input/output WebNN buffers used by a previous dispatch, keyed by name.
  // NOTE(review): presumably compared against the current dispatch's buffers
  // to decide whether resources must be bound/recorded again; confirm in
  // DispatchImpl.
  base::flat_map<std::string, base::WeakPtr<const WebNNBufferImpl>>
      previous_input_buffers_;
  base::flat_map<std::string, base::WeakPtr<const WebNNBufferImpl>>
      previous_output_buffers_;

  base::WeakPtrFactory<GraphImplDml> weak_factory_{this};
};

}  // namespace webnn::dml

#endif  // SERVICES_WEBNN_DML_GRAPH_IMPL_DML_H_