#include "services/webnn/webnn_context_provider_impl.h"
#include <memory>
#include <utility>
#include "base/check_is_test.h"
#include "base/types/expected_macros.h"
#include "mojo/public/cpp/bindings/self_owned_receiver.h"
#include "services/webnn/buildflags.h"
#include "services/webnn/error.h"
#include "services/webnn/public/cpp/context_properties.h"
#include "services/webnn/public/mojom/webnn_context_provider.mojom-forward.h"
#include "services/webnn/public/mojom/webnn_context_provider.mojom.h"
#include "services/webnn/public/mojom/webnn_error.mojom.h"
#include "services/webnn/webnn_context_impl.h"
#if BUILDFLAG(IS_WIN)
#include <wrl.h>
#include "base/notreached.h"
#include "services/webnn/dml/adapter.h"
#include "services/webnn/dml/command_queue.h"
#include "services/webnn/dml/command_recorder.h"
#include "services/webnn/dml/context_impl_dml.h"
#include "services/webnn/dml/utils.h"
#endif
#if BUILDFLAG(IS_MAC)
#include "services/webnn/coreml/context_impl_coreml.h"
#endif
#if BUILDFLAG(WEBNN_USE_TFLITE)
#if BUILDFLAG(IS_CHROMEOS)
#include "services/webnn/tflite/context_impl_cros.h"
#else
#include "services/webnn/tflite/context_impl_tflite.h"
#endif
#endif
namespace webnn {
#if BUILDFLAG(IS_WIN)
using Microsoft::WRL::ComPtr;
#endif
namespace {
// Process-wide pointer to a test-supplied backend.
// NOTE(review): presumably installed through
// WebNNContextProviderImpl::SetBackendForTesting() — confirm.
WebNNContextProviderImpl::BackendForTesting* g_backend_for_testing = …;
CreateContextOptionsPtr;
WebNNContextProvider;
#if BUILDFLAG(IS_WIN)
// Picks the DirectML GPU adapter used for WebNN on Windows.
//
// Yields a `kNotSupportedError` when the DISABLE_WEBNN_FOR_GPU workaround is
// enabled in `gpu_feature_info`, or when `shared_context_state` cannot provide
// a D3D11 device. Otherwise returns whatever `dml::Adapter::GetGpuInstance()`
// produces for the DXGI adapter obtained from that device.
//
// `shared_context_state` may only be null in tests; the adapter then comes
// from `dml::Adapter::GetGpuInstanceForTesting()`.
base::expected<scoped_refptr<dml::Adapter>, mojom::ErrorPtr> GetDmlGpuAdapter(
    gpu::SharedContextState* shared_context_state,
    const gpu::GpuFeatureInfo& gpu_feature_info) {
  using ErrorCode = mojom::Error::Code;

  // The blocklist check comes first so it applies to tests as well.
  const bool gpu_blocklisted =
      gpu_feature_info.IsWorkaroundEnabled(DISABLE_WEBNN_FOR_GPU);
  if (gpu_blocklisted) {
    auto blocklist_error = dml::CreateError(ErrorCode::kNotSupportedError,
                                            "WebNN is blocklisted for GPU.");
    return base::unexpected(std::move(blocklist_error));
  }

  if (shared_context_state == nullptr) {
    // Running without a SharedContextState is only legal in tests.
    CHECK_IS_TEST();
    return dml::Adapter::GetGpuInstanceForTesting();
  }

  ComPtr<ID3D11Device> d3d11_device = shared_context_state->GetD3D11Device();
  if (!d3d11_device) {
    auto device_error = dml::CreateError(
        ErrorCode::kNotSupportedError,
        "Failed to get D3D11 Device from SharedContextState.");
    return base::unexpected(std::move(device_error));
  }

  // Walk D3D11 device -> DXGI device -> DXGI adapter. Neither step is expected
  // to fail, so a failure crashes rather than being reported to the caller.
  ComPtr<IDXGIDevice> dxgi_device;
  CHECK_EQ(d3d11_device.As(&dxgi_device), S_OK);

  ComPtr<IDXGIAdapter> dxgi_adapter;
  CHECK_EQ(dxgi_device->GetAdapter(&dxgi_adapter), S_OK);

  return dml::Adapter::GetGpuInstance(std::move(dxgi_adapter));
}
#endif
#if BUILDFLAG(IS_WIN)
// Reports whether `options` requests a device that gets a DirectML context:
// true for GPU and NPU, false for CPU.
bool ShouldCreateDmlContext(const mojom::CreateContextOptions& options) {
  using Device = mojom::CreateContextOptions::Device;

  // Deliberately no `default:` so the compiler can flag newly added devices.
  switch (options.device) {
    case Device::kGpu:
    case Device::kNpu:
      return true;
    case Device::kCpu:
      return false;
  }
}
#endif
}
#if BUILDFLAG(IS_CHROMEOS)
// On ChromeOS the provider is constructed without any GPU state.
WebNNContextProviderImpl::WebNNContextProviderImpl() = default;
#else
// Non-ChromeOS constructor. Takes the shared GPU context state together with
// the GPU feature info and GPU info.
// NOTE(review): presumably all three are retained as members for later use
// when creating contexts — confirm against the initializer list.
WebNNContextProviderImpl::WebNNContextProviderImpl(
scoped_refptr<gpu::SharedContextState> shared_context_state,
gpu::GpuFeatureInfo gpu_feature_info,
gpu::GPUInfo gpu_info)
: … { … }
#endif
// No teardown beyond destroying the members is required.
WebNNContextProviderImpl::~WebNNContextProviderImpl() = default;
#if BUILDFLAG(IS_CHROMEOS)
// ChromeOS entry point: creates a new provider and hands it, together with
// `receiver`, to mojo::MakeSelfOwnedReceiver(), which takes ownership of the
// provider. Nothing is returned to the caller.
void WebNNContextProviderImpl::Create(
    mojo::PendingReceiver<WebNNContextProvider> receiver) {
  // Constructed with `new` + WrapUnique rather than std::make_unique.
  // NOTE(review): presumably because the constructor is not public — confirm.
  auto provider = base::WrapUnique(new WebNNContextProviderImpl());
  mojo::MakeSelfOwnedReceiver<WebNNContextProvider>(std::move(provider),
                                                    std::move(receiver));
}
#else
// Non-ChromeOS factory. Unlike the ChromeOS overload, it returns an owning
// pointer to the provider and takes no receiver.
// NOTE(review): presumably forwards all three arguments to the constructor —
// confirm.
std::unique_ptr<WebNNContextProviderImpl> WebNNContextProviderImpl::Create(
scoped_refptr<gpu::SharedContextState> shared_context_state,
gpu::GpuFeatureInfo gpu_feature_info,
gpu::GPUInfo gpu_info) { … }
// Accepts a pending mojom::WebNNContextProvider receiver (non-ChromeOS only).
// NOTE(review): presumably binds `receiver` to this provider instance —
// confirm.
void WebNNContextProviderImpl::BindWebNNContextProvider(
mojo::PendingReceiver<mojom::WebNNContextProvider> receiver) { … }
#endif
// Test-only entry point taking a receiver and a WebNNStatus.
// NOTE(review): `status` presumably selects the WebNN availability the
// provider should simulate — confirm against the WebNNStatus definition.
void WebNNContextProviderImpl::CreateForTesting(
mojo::PendingReceiver<mojom::WebNNContextProvider> receiver,
WebNNStatus status) { … }
// Called with the context whose connection reported an error.
// NOTE(review): presumably removes `impl` from the provider's set of live
// contexts, destroying it — confirm.
void WebNNContextProviderImpl::OnConnectionError(WebNNContextImpl* impl) { … }
// Test hook taking a backend override.
// NOTE(review): presumably stores `backend_for_testing` in the file-level
// `g_backend_for_testing` pointer — confirm.
void WebNNContextProviderImpl::SetBackendForTesting(
BackendForTesting* backend_for_testing) { … }
// Takes the requested context `options` and a completion `callback` of type
// WebNNContextProvider::CreateWebNNContextCallback.
// NOTE(review): presumably selects a backend for `options` and reports either
// the new context or an error through `callback` — confirm.
void WebNNContextProviderImpl::CreateWebNNContext(
CreateContextOptionsPtr options,
WebNNContextProvider::CreateWebNNContextCallback callback) { … }
}