chromium/services/webnn/dml/adapter.h

// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef SERVICES_WEBNN_DML_ADAPTER_H_
#define SERVICES_WEBNN_DML_ADAPTER_H_

#include "base/component_export.h"
#include "base/gtest_prod_util.h"
#include "base/memory/ref_counted.h"
#include "base/task/thread_pool.h"
#include "base/types/expected.h"
#include "gpu/config/gpu_feature_info.h"
#include "gpu/config/gpu_info.h"
#include "services/webnn/dml/error.h"
#include "services/webnn/public/mojom/webnn_context_provider.mojom.h"
#include "third_party/microsoft_dxheaders/include/directml.h"
#include "third_party/microsoft_dxheaders/src/include/directx/d3d12.h"

// Windows SDK headers should be included after DirectX headers.
#include <dxgi.h>
#include <wrl.h>

namespace webnn::dml {

class CommandQueue;

// Adapters represent physical devices and are responsible for device discovery.
// An `Adapter` instance creates and maintains the corresponding `IDXGIAdapter`
// or `IDXCoreAdapter`, `ID3D12Device`, `IDMLDevice1` and
// `webnn::dml::CommandQueue` for a physical adapter. A single `Adapter`
// instance is shared and reference-counted by all `GraphImplDml` objects of the
// same adapter. The `Adapter` instance is created when the first `GraphImplDml`
// calls `Adapter::GetGpuInstance()` or `Adapter::GetNpuInstance()`, and is
// released when the last `GraphImplDml` is destroyed.
class COMPONENT_EXPORT(WEBNN_SERVICE) Adapter final
    : public base::RefCountedThreadSafe<Adapter> {
 public:
  // Get the shared `Adapter` instance. If an `Adapter` instance already exists,
  // that one is returned regardless of whether the `dxgi_adapter` matches.
  // TODO(crbug.com/40277628): Support `Adapter` instance for other adapters.
  //
  // This method is not thread-safe and should only be called on the GPU main
  // thread.
  //
  // On success, the returned `Adapter` is guaranteed to support a feature level
  // equal to or greater than DML_FEATURE_LEVEL_4_0. On failure, the
  // `base::expected` carries a `mojom::ErrorPtr` instead.
  static base::expected<scoped_refptr<Adapter>, mojom::ErrorPtr> GetGpuInstance(
      Microsoft::WRL::ComPtr<IDXGIAdapter> dxgi_adapter);

  // Similar to `GetGpuInstance()` but uses the first enumerated DXGI adapter.
  // The returned `Adapter` is guaranteed to support a feature level equal to or
  // greater than DML_FEATURE_LEVEL_2_0 because that is where DMLCreateDevice1
  // was introduced.
  static base::expected<scoped_refptr<Adapter>, mojom::ErrorPtr>
  GetGpuInstanceForTesting();

  // Similar to the `GetGpuInstance` method above, get the shared `Adapter`
  // instance for the NPU. The returned `Adapter` is guaranteed to support a
  // feature level equal to or greater than DML_FEATURE_LEVEL_6_4.
  // NOTE(review): how `gpu_feature_info` and `gpu_info` influence adapter
  // selection is decided in the implementation file - confirm there.
  static base::expected<scoped_refptr<Adapter>, mojom::ErrorPtr> GetNpuInstance(
      const gpu::GpuFeatureInfo& gpu_feature_info,
      const gpu::GPUInfo& gpu_info);

  // Similar to the `GetNpuInstance` method above, get the shared NPU `Adapter`
  // instance for testing.
  static base::expected<scoped_refptr<Adapter>, mojom::ErrorPtr>
  GetNpuInstanceForTesting();

  // `Adapter` is neither copyable nor copy-assignable.
  Adapter(const Adapter&) = delete;
  Adapter& operator=(const Adapter&) = delete;

  // Returns the raw `ID3D12Device` pointer held by this adapter.
  ID3D12Device* d3d12_device() const { return d3d12_device_.Get(); }

  // Returns the raw `IDMLDevice1` pointer held by this adapter.
  IDMLDevice1* dml_device() const { return dml_device_.Get(); }

  // Returns the raw pointer to the `CommandQueue` held by this adapter.
  CommandQueue* command_queue() const { return command_queue_.get(); }

  // Returns the maximum DML feature level recorded for this adapter.
  DML_FEATURE_LEVEL max_supported_feature_level() const {
    return max_supported_dml_feature_level_;
  }

  // Returns the command queue dedicated to graph initialization on the
  // background thread for the NPU adapter; nullptr for the GPU adapter.
  CommandQueue* init_command_queue_for_npu() const {
    return init_command_queue_for_npu_.get();
  }

  // Returns the task runner used to post graph initialization tasks for the
  // NPU adapter; nullptr for the GPU adapter.
  base::SequencedTaskRunner* init_task_runner_for_npu() const {
    return init_task_runner_for_npu_.get();
  }

  // Returns true if this object is the shared NPU instance, i.e. it is the
  // object `npu_instance_` currently points to.
  bool IsNPU() const { return npu_instance_ == this; }

  // Enable the debug layer (requires the Graphics Tools "optional feature").
  // Must be called prior to Adapter::GetGpuInstance() since the D3D12 device
  // must be created after the debug layer is enabled.
  // TODO(crbug.com/40206287): move this once adapter enumeration is
  // implemented.
  static void EnableDebugLayerForTesting();

  // Returns whether `feature_level` is supported by this adapter.
  // NOTE(review): presumably checked against the value returned by
  // `max_supported_feature_level()` - confirm in the implementation file.
  bool IsDMLFeatureLevelSupported(DML_FEATURE_LEVEL feature_level) const;

  // Determines if IDMLDevice1::CompileGraph can be used.
  bool IsDMLDeviceCompileGraphSupportedForTesting() const;

  // Indicates whether the underlying D3D12 device supports UMA (Unified Memory
  // Architecture).
  bool IsUMA() const { return is_uma_; }

 private:
  // These tests are granted access to the private members of `Adapter`.
  FRIEND_TEST_ALL_PREFIXES(WebNNAdapterTest, GetGpuInstance);
  FRIEND_TEST_ALL_PREFIXES(WebNNAdapterTest, GetNpuInstance);

  // The constructor and destructor are private; `base::RefCountedThreadSafe`
  // is befriended so that it can reach the private destructor.
  friend class base::RefCountedThreadSafe<Adapter>;
  Adapter(Microsoft::WRL::ComPtr<IUnknown> dxgi_or_dxcore_adapter,
          Microsoft::WRL::ComPtr<ID3D12Device> d3d12_device,
          Microsoft::WRL::ComPtr<IDMLDevice1> dml_device,
          scoped_refptr<CommandQueue> command_queue,
          scoped_refptr<CommandQueue> init_command_queue_for_npu,
          DML_FEATURE_LEVEL max_supported_dml_feature_level,
          bool is_uma);
  ~Adapter();

  // `dxgi_or_dxcore_adapter` must be either `IDXGIAdapter` or `IDXCoreAdapter`.
  // DXGI is older and more broadly supported, capable of enumerating GPUs,
  // while NPUs are only able to be enumerated by DXCore.
  // The `min_required_dml_feature_level` parameter allows us to run a portion
  // of the test suite on machines that have feature levels less than the one
  // that will be required for full WebNN.
  // TODO(issues.chromium.org/331369802): Remove min/max DML feature level
  // parameters for `dml::Adapter` creation in production code.
  static base::expected<scoped_refptr<Adapter>, mojom::ErrorPtr> Create(
      Microsoft::WRL::ComPtr<IUnknown> dxgi_or_dxcore_adapter,
      DML_FEATURE_LEVEL min_required_dml_feature_level);

  // The underlying adapter, held as `IUnknown` because it is either an
  // `IDXGIAdapter` or an `IDXCoreAdapter`.
  Microsoft::WRL::ComPtr<IUnknown> dxgi_or_dxcore_adapter_;

  // The D3D12 device, DirectML device and command queue maintained for this
  // physical adapter.
  Microsoft::WRL::ComPtr<ID3D12Device> d3d12_device_;
  Microsoft::WRL::ComPtr<IDMLDevice1> dml_device_;
  scoped_refptr<CommandQueue> command_queue_;
  // Dedicated to initializing graphs on the background thread for the NPU
  // adapter; it is nullptr for the GPU adapter.
  scoped_refptr<CommandQueue> init_command_queue_for_npu_;
  // Used to post the graph initialization tasks to the background thread and
  // to guarantee that the tasks are executed in order for the NPU adapter; it
  // is nullptr for the GPU adapter.
  scoped_refptr<base::SequencedTaskRunner> init_task_runner_for_npu_;

  // Exposed through `max_supported_feature_level()`.
  DML_FEATURE_LEVEL max_supported_dml_feature_level_ = DML_FEATURE_LEVEL_1_0;

  // NOTE(review): presumably set by `EnableDebugLayerForTesting()` and read
  // when the D3D12 device is created - confirm in the implementation file.
  static bool enable_d3d12_debug_layer_for_testing_;

  // Whether the D3D12 device supports UMA, as reported by
  // D3D12_FEATURE_DATA_ARCHITECTURE. Exposed through `IsUMA()`.
  const bool is_uma_ = false;

  static Adapter* gpu_instance_;  // Shared instance for the DXGI adapter.
  static Adapter* npu_instance_;  // Shared instance for the DXCore adapter.
};

}  // namespace webnn::dml

#endif  // SERVICES_WEBNN_DML_ADAPTER_H_