chromium/services/webnn/dml/platform_functions.h

// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef SERVICES_WEBNN_DML_PLATFORM_FUNCTIONS_H_
#define SERVICES_WEBNN_DML_PLATFORM_FUNCTIONS_H_

#include <windows.h>

#include "base/component_export.h"
#include "base/no_destructor.h"
#include "base/scoped_native_library.h"
#include "third_party/microsoft_dxheaders/include/directml.h"
#include "third_party/microsoft_dxheaders/src/include/directx/d3d12.h"
#include "third_party/microsoft_dxheaders/src/include/directx/dxcore.h"

namespace webnn::dml {

// Holds the Windows system libraries used by the WebNN DirectML backend
// (D3D12, DXCore and DirectML) together with the entry points resolved from
// them. The constructor is private and base::NoDestructor is a friend, so
// instances are obtained through GetInstance() (presumably a lazily created,
// never-destroyed singleton -- TODO confirm against platform_functions.cc).
class COMPONENT_EXPORT(WEBNN_SERVICE) PlatformFunctions {
 public:
  PlatformFunctions(const PlatformFunctions&) = delete;
  PlatformFunctions& operator=(const PlatformFunctions&) = delete;

  // Returns the shared instance. NOTE(review): the definition is not visible
  // here; confirm whether this may return nullptr when a required library or
  // export could not be loaded (see AllFunctionsLoaded()).
  static PlatformFunctions* GetInstance();

  // Pointer to D3D12CreateDevice from the D3D12 library.
  using D3d12CreateDeviceProc = PFN_D3D12_CREATE_DEVICE;
  D3d12CreateDeviceProc d3d12_create_device_proc() const {
    return d3d12_create_device_proc_;
  }

  // Pointer to D3D12GetDebugInterface from the D3D12 library.
  using D3d12GetDebugInterfaceProc = PFN_D3D12_GET_DEBUG_INTERFACE;
  D3d12GetDebugInterfaceProc d3d12_get_debug_interface_proc() const {
    return d3d12_get_debug_interface_proc_;
  }

  // Pointer to DXCoreCreateAdapterFactory. The static_cast selects the
  // (REFIID, void**) signature so decltype names exactly one function-pointer
  // type even if dxcore.h declares additional overloads.
  using DXCoreCreateAdapterFactoryProc =
      decltype(static_cast<STDMETHODIMP (*)(REFIID, void**)>(
          DXCoreCreateAdapterFactory));
  // May be nullptr (DXCore is optional); check IsDXCoreSupported() before use.
  DXCoreCreateAdapterFactoryProc dxcore_create_adapter_factory_proc() const {
    return dxcore_create_adapter_factory_proc_;
  }

  // Pointer to DMLCreateDevice1 from the DirectML library.
  using DmlCreateDevice1Proc = decltype(DMLCreateDevice1)*;
  DmlCreateDevice1Proc dml_create_device1_proc() const {
    return dml_create_device1_proc_;
  }

  // Returns true if the DXCore library handle is valid, i.e. it was loaded.
  bool IsDXCoreSupported() const { return dxcore_library_.is_valid(); }

 private:
  friend class base::NoDestructor<PlatformFunctions>;
  PlatformFunctions();
  ~PlatformFunctions();

  // NOTE(review): presumably reports whether every required entry point was
  // resolved; defined in the .cc file.
  bool AllFunctionsLoaded();

  // Every function pointer below has a nullptr default so it never holds an
  // indeterminate value, whichever path the constructor takes (e.g. if a
  // library fails to load before its exports are resolved).

  // D3D12
  base::ScopedNativeLibrary d3d12_library_;
  D3d12CreateDeviceProc d3d12_create_device_proc_ = nullptr;
  D3d12GetDebugInterfaceProc d3d12_get_debug_interface_proc_ = nullptr;

  // The DXCore library may fail to load (leaving this handle invalid) because
  // it is missing in older Windows 10 versions. It is needed for Microsoft
  // Compute Driver Model (MCDM) devices (NPUs), which are not enumerable via
  // DXGI.
  base::ScopedNativeLibrary dxcore_library_;
  DXCoreCreateAdapterFactoryProc dxcore_create_adapter_factory_proc_ = nullptr;

  // DirectML
  base::ScopedNativeLibrary dml_library_;
  DmlCreateDevice1Proc dml_create_device1_proc_ = nullptr;
};

}  // namespace webnn::dml

#endif  // SERVICES_WEBNN_DML_PLATFORM_FUNCTIONS_H_