chromium/services/webnn/webnn_context_provider_impl.cc

// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "services/webnn/webnn_context_provider_impl.h"

#include <memory>
#include <utility>

#include "base/check_is_test.h"
#include "base/types/expected_macros.h"
#include "mojo/public/cpp/bindings/self_owned_receiver.h"
#include "services/webnn/buildflags.h"
#include "services/webnn/error.h"
#include "services/webnn/public/cpp/context_properties.h"
#include "services/webnn/public/mojom/webnn_context_provider.mojom-forward.h"
#include "services/webnn/public/mojom/webnn_context_provider.mojom.h"
#include "services/webnn/public/mojom/webnn_error.mojom.h"
#include "services/webnn/webnn_context_impl.h"

#if BUILDFLAG(IS_WIN)
#include <wrl.h>

#include "base/notreached.h"
#include "services/webnn/dml/adapter.h"
#include "services/webnn/dml/command_queue.h"
#include "services/webnn/dml/command_recorder.h"
#include "services/webnn/dml/context_impl_dml.h"
#include "services/webnn/dml/utils.h"
#endif

#if BUILDFLAG(IS_MAC)
#include "services/webnn/coreml/context_impl_coreml.h"
#endif

#if BUILDFLAG(WEBNN_USE_TFLITE)
#if BUILDFLAG(IS_CHROMEOS)
#include "services/webnn/tflite/context_impl_cros.h"
#else
#include "services/webnn/tflite/context_impl_tflite.h"
#endif
#endif

namespace webnn {

#if BUILDFLAG(IS_WIN)
using Microsoft::WRL::ComPtr;
#endif

namespace {

WebNNContextProviderImpl::BackendForTesting* g_backend_for_testing =;

CreateContextOptionsPtr;
WebNNContextProvider;

#if BUILDFLAG(IS_WIN)
// Returns the DirectML adapter used for GPU/NPU contexts.
//
// Fails with kNotSupportedError when:
// - the DISABLE_WEBNN_FOR_GPU workaround is enabled in `gpu_feature_info`, or
// - `shared_context_state` has no D3D11 device.
//
// A null `shared_context_state` is accepted only in tests. In that case the
// adapter comes from dml::Adapter::GetGpuInstanceForTesting().
//
// Otherwise the adapter is derived from the shared context's D3D11 device
// (ID3D11Device -> IDXGIDevice -> IDXGIAdapter). Per the original note, all
// `ContextImplDml` currently share the resulting instance.
//
// TODO(crbug.com/40277628): Support choosing the `Adapter` instance based on
// the context's creation options.
base::expected<scoped_refptr<dml::Adapter>, mojom::ErrorPtr> GetDmlGpuAdapter(
    gpu::SharedContextState* shared_context_state,
    const gpu::GpuFeatureInfo& gpu_feature_info) {
  const bool webnn_blocklisted =
      gpu_feature_info.IsWorkaroundEnabled(DISABLE_WEBNN_FOR_GPU);
  if (webnn_blocklisted) {
    return base::unexpected(
        dml::CreateError(mojom::Error::Code::kNotSupportedError,
                         "WebNN is blocklisted for GPU."));
  }

  // Obtaining a SharedContextState requires an initialized GpuServiceImpl,
  // which unit tests do not set up. They pass null and get the first DXGI
  // adapter instead.
  if (shared_context_state == nullptr) {
    CHECK_IS_TEST();
    return dml::Adapter::GetGpuInstanceForTesting();
  }

  ComPtr<ID3D11Device> device = shared_context_state->GetD3D11Device();
  if (!device) {
    return base::unexpected(dml::CreateError(
        mojom::Error::Code::kNotSupportedError,
        "Failed to get D3D11 Device from SharedContextState."));
  }

  // Both steps are expected to always succeed: QueryInterface (via As()) from
  // ID3D11Device to IDXGIDevice, and GetAdapter() on that IDXGIDevice.
  // A failure is therefore treated as a fatal invariant violation.
  ComPtr<IDXGIDevice> dxgi_device;
  CHECK_EQ(device.As(&dxgi_device), S_OK);
  ComPtr<IDXGIAdapter> adapter;
  CHECK_EQ(dxgi_device->GetAdapter(&adapter), S_OK);
  return dml::Adapter::GetGpuInstance(std::move(adapter));
}
#endif

#if BUILDFLAG(IS_WIN)
// Returns true when a context requested with `options` should be backed by
// DirectML, i.e. for the GPU and NPU device types. CPU requests return false.
// The switch intentionally has no `default:` so the compiler (-Wswitch) can
// flag any device type added to the enum later.
bool ShouldCreateDmlContext(const mojom::CreateContextOptions& options) {
  using Device = mojom::CreateContextOptions::Device;
  switch (options.device) {
    case Device::kGpu:
    case Device::kNpu:
      return true;
    case Device::kCpu:
      return false;
  }
}
#endif  // BUILDFLAG(IS_WIN)

}  // namespace

#if BUILDFLAG(IS_CHROMEOS)
// On ChromeOS the provider is constructed without any GPU-process arguments.
WebNNContextProviderImpl::WebNNContextProviderImpl() = default;
#else
// Constructs a provider from a SharedContextState, GPU feature info and GPU
// info, each taken by value.
// NOTE(review): the member-initializer list below is empty (a `:` followed
// directly by the body). That is not valid C++, and the parameters are
// otherwise unused. Presumably they are meant to be moved into members
// declared in the header — TODO restore the initializers.
WebNNContextProviderImpl::WebNNContextProviderImpl(
    scoped_refptr<gpu::SharedContextState> shared_context_state,
    gpu::GpuFeatureInfo gpu_feature_info,
    gpu::GPUInfo gpu_info)
    :{}
#endif  // BUILDFLAG(IS_CHROMEOS)

WebNNContextProviderImpl::~WebNNContextProviderImpl() = default;

#if BUILDFLAG(IS_CHROMEOS)
// static
// Creates a provider and binds it to `receiver` via MakeSelfOwnedReceiver.
// That call takes ownership of the provider, so nothing is returned.
void WebNNContextProviderImpl::Create(
    mojo::PendingReceiver<WebNNContextProvider> receiver) {
  auto provider = base::WrapUnique(new WebNNContextProviderImpl());
  mojo::MakeSelfOwnedReceiver<WebNNContextProvider>(std::move(provider),
                                                    std::move(receiver));
}

#else
// Creates a provider from the given SharedContextState, GPU feature info and
// GPU info, forwarding all three to the constructor.
// Returns an owning pointer to the new provider; it is never null.
std::unique_ptr<WebNNContextProviderImpl> WebNNContextProviderImpl::Create(
    scoped_refptr<gpu::SharedContextState> shared_context_state,
    gpu::GpuFeatureInfo gpu_feature_info,
    gpu::GPUInfo gpu_info) {
  // `new` + base::WrapUnique (rather than std::make_unique) keeps this working
  // even if the constructor is not public. The ChromeOS Create() follows the
  // same pattern.
  return base::WrapUnique(new WebNNContextProviderImpl(
      std::move(shared_context_state), std::move(gpu_feature_info),
      std::move(gpu_info)));
}

// Per its name, intended to bind `receiver` to this provider.
// NOTE(review): the body is empty, so the by-value `receiver` is simply
// dropped on return and never bound. Presumably it should be added to a
// receiver set owned by this object — TODO confirm against the header.
void WebNNContextProviderImpl::BindWebNNContextProvider(
    mojo::PendingReceiver<mojom::WebNNContextProvider> receiver) {}

#endif  // BUILDFLAG(IS_CHROMEOS)

// static
// Test-only entry point that takes a provider receiver and a WebNNStatus.
// NOTE(review): the body is empty, so `receiver` is dropped and `status` is
// ignored. TODO confirm the intended behavior (presumably it should bind a
// provider whose behavior depends on `status`).
void WebNNContextProviderImpl::CreateForTesting(
    mojo::PendingReceiver<mojom::WebNNContextProvider> receiver,
    WebNNStatus status) {}

// Handler for a connection error on the context `impl`.
// NOTE(review): the body is empty, so `impl` is not released or untracked
// here. Presumably this should drop the provider's reference to `impl` —
// TODO confirm.
void WebNNContextProviderImpl::OnConnectionError(WebNNContextImpl* impl) {}

// static
// Stores `backend_for_testing` in the file-scope `g_backend_for_testing`.
// Passing nullptr clears any previously installed override. The pointer is
// not owned: the caller must keep the object alive while it is installed.
void WebNNContextProviderImpl::SetBackendForTesting(
    BackendForTesting* backend_for_testing) {
  g_backend_for_testing = backend_for_testing;
}

// Mojo handler that is meant to create a WebNN context for `options` and
// report the result through `callback`.
// NOTE(review): the body is empty, so `callback` is destroyed without being
// run. Mojo flags this (in DCHECK builds) when the pipe is still connected,
// and the renderer never gets a reply. TODO restore the context-creation
// logic, or at least reply with an error.
void WebNNContextProviderImpl::CreateWebNNContext(
    CreateContextOptionsPtr options,
    WebNNContextProvider::CreateWebNNContextCallback callback) {}

}  // namespace webnn