chromium/services/webnn/dml/platform_functions.cc

// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "services/webnn/dml/platform_functions.h"

#include "base/files/file_path.h"
#include "base/logging.h"
#include "base/native_library.h"
#include "base/path_service.h"

namespace webnn::dml {

PlatformFunctions::PlatformFunctions() {
  // Loads the platform libraries WebNN needs and resolves their entry points.
  // D3D12 and DirectML are required: if any of them (or any required entry
  // point) is missing, the constructor returns early and leaves the members
  // unset, which AllFunctionsLoaded() reports as failure. DXCore is optional.

  // D3D12 (required).
  base::ScopedNativeLibrary d3d12_library(
      base::LoadSystemLibrary(L"D3D12.dll"));
  if (!d3d12_library.is_valid()) {
    LOG(ERROR) << "[WebNN] Failed to load D3D12.dll.";
    return;
  }
  D3d12CreateDeviceProc d3d12_create_device_proc =
      reinterpret_cast<D3d12CreateDeviceProc>(
          d3d12_library.GetFunctionPointer("D3D12CreateDevice"));
  if (!d3d12_create_device_proc) {
    LOG(ERROR) << "[WebNN] Failed to get D3D12CreateDevice function.";
    return;
  }

  D3d12GetDebugInterfaceProc d3d12_get_debug_interface_proc =
      reinterpret_cast<D3d12GetDebugInterfaceProc>(
          d3d12_library.GetFunctionPointer("D3D12GetDebugInterface"));
  if (!d3d12_get_debug_interface_proc) {
    LOG(ERROR) << "[WebNN] Failed to get D3D12GetDebugInterface function.";
    return;
  }

  // DirectML (required). First try to load DirectML.dll from the module
  // folder. This enables running unit tests that require DirectML feature
  // level 4.0+ on Windows 10.
  base::ScopedNativeLibrary dml_library;
  base::FilePath module_path;
  if (base::PathService::Get(base::DIR_MODULE, &module_path)) {
    dml_library = base::ScopedNativeLibrary(
        base::LoadNativeLibrary(module_path.Append(L"directml.dll"), nullptr));
  }
  // If loading from the module folder failed, fall back to the system folder.
  if (!dml_library.is_valid()) {
    dml_library =
        base::ScopedNativeLibrary(base::LoadSystemLibrary(L"directml.dll"));
  }
  if (!dml_library.is_valid()) {
    LOG(ERROR) << "[WebNN] Failed to load directml.dll.";
    return;
  }
  // On older versions of Windows, DMLCreateDevice was not publicly documented
  // and took a different number of arguments than the publicly documented
  // version of the function supported by later versions of the DLL. Use
  // DMLCreateDevice1 instead: it has always been publicly documented and
  // accepts a well-defined number of arguments.
  DmlCreateDevice1Proc dml_create_device1_proc =
      reinterpret_cast<DmlCreateDevice1Proc>(
          dml_library.GetFunctionPointer("DMLCreateDevice1"));
  if (!dml_create_device1_proc) {
    LOG(ERROR) << "[WebNN] Failed to get DMLCreateDevice1 function.";
    return;
  }

  // DXCore (optional): a failure here only logs a warning. The proc is
  // explicitly null-initialized so it is never read uninitialized when the
  // library cannot be loaded.
  base::ScopedNativeLibrary dxcore_library(
      base::LoadSystemLibrary(L"DXCore.dll"));
  DXCoreCreateAdapterFactoryProc dxcore_create_adapter_factory_proc = nullptr;
  if (!dxcore_library.is_valid()) {
    LOG(WARNING) << "[WebNN] Failed to load DXCore.dll.";
  } else {
    dxcore_create_adapter_factory_proc =
        reinterpret_cast<DXCoreCreateAdapterFactoryProc>(
            dxcore_library.GetFunctionPointer("DXCoreCreateAdapterFactory"));
    if (!dxcore_create_adapter_factory_proc) {
      LOG(WARNING)
          << "[WebNN] Failed to get DXCoreCreateAdapterFactory function.";
    }
  }

  // Every required entry point was resolved, so commit the results. The
  // libraries are stored as members so that the resolved function pointers
  // stay valid for as long as this object owns them.

  // D3D12
  d3d12_library_ = std::move(d3d12_library);
  d3d12_create_device_proc_ = d3d12_create_device_proc;
  d3d12_get_debug_interface_proc_ = d3d12_get_debug_interface_proc;

  // DXCore: keep it only when its entry point was actually resolved. A
  // non-null proc implies the library loaded successfully.
  if (dxcore_create_adapter_factory_proc) {
    dxcore_library_ = std::move(dxcore_library);
    dxcore_create_adapter_factory_proc_ = dxcore_create_adapter_factory_proc;
  }

  // DirectML
  dml_library_ = std::move(dml_library);
  dml_create_device1_proc_ = dml_create_device1_proc;
}

// Defaulted out of line. The base::ScopedNativeLibrary members own the loaded
// libraries and are responsible for releasing them.
PlatformFunctions::~PlatformFunctions() = default;

// static
// Returns the lazily-constructed, never-destroyed process-wide instance, or
// nullptr when any required platform function could not be loaded. The
// loading attempt happens once, on first call.
PlatformFunctions* PlatformFunctions::GetInstance() {
  static base::NoDestructor<PlatformFunctions> instance;
  if (instance->AllFunctionsLoaded()) {
    return instance.get();
  }
  LOG(ERROR) << "[WebNN] Failed to load all platform functions.";
  return nullptr;
}

// Reports whether every *required* entry point was resolved. The DXCore
// factory proc is optional and is deliberately not checked here.
bool PlatformFunctions::AllFunctionsLoaded() {
  const bool has_d3d12 = d3d12_create_device_proc_ != nullptr &&
                         d3d12_get_debug_interface_proc_ != nullptr;
  const bool has_directml = dml_create_device1_proc_ != nullptr;
  return has_d3d12 && has_directml;
}

}  // namespace webnn::dml