// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifdef UNSAFE_BUFFERS_BUILD
// TODO(crbug.com/349653202): Remove this and spanify to fix the errors.
#pragma allow_unsafe_buffers
#endif
#include "services/webnn/dml/context_impl_dml.h"
#include <limits>
#include "base/bits.h"
#include "base/check.h"
#include "base/compiler_specific.h"
#include "base/containers/span.h"
#include "base/strings/strcat.h"
#include "base/types/expected_macros.h"
#include "gpu/config/gpu_driver_bug_workaround_type.h"
#include "services/webnn/dml/adapter.h"
#include "services/webnn/dml/buffer_impl_dml.h"
#include "services/webnn/dml/command_queue.h"
#include "services/webnn/dml/graph_impl_dml.h"
#include "services/webnn/dml/utils.h"
#include "services/webnn/error.h"
#include "services/webnn/public/cpp/context_properties.h"
#include "services/webnn/public/cpp/operand_descriptor.h"
#include "services/webnn/public/cpp/supported_data_types.h"
#include "services/webnn/public/mojom/webnn_buffer.mojom.h"
#include "services/webnn/webnn_context_impl.h"
namespace webnn::dml {
using Microsoft::WRL::ComPtr;
namespace {
void HandleBufferCreationFailure(
const std::string& error_message,
WebNNContextImpl::CreateBufferImplCallback callback) {
std::move(callback).Run(base::unexpected(
CreateError(mojom::Error::Code::kUnknownError, error_message)));
}
} // namespace
// The context properties follow the supported feature level on the platform.
// https://learn.microsoft.com/en-us/windows/ai/directml/dml-feature-level-history
//
// TODO(crbug.com/345271830): update the context properties based on a certain
// feature level once there is a bundled DirectML.dll.
// static
ContextProperties ContextImplDml::GetProperties(
    DML_FEATURE_LEVEL feature_level) {
  // `feature_level` must be at least 4.0. The properties are first built with
  // the limits used for that baseline, then the per-operator limits are
  // overwritten for each higher feature level that `feature_level` reaches.
  // The checks below are cumulative (`>=`) and run in ascending order, so a
  // limit assigned by a higher level replaces the one from a lower level.
  CHECK_GE(feature_level, DML_FEATURE_LEVEL_4_0);

  static constexpr SupportedDataTypes kFloat16To32Ints8{
      OperandDataType::kFloat16, OperandDataType::kFloat32,
      OperandDataType::kInt8, OperandDataType::kUint8};

  static constexpr SupportedDataTypes kFloat16To32Ints32{
      OperandDataType::kFloat16, OperandDataType::kFloat32,
      OperandDataType::kInt32, OperandDataType::kUint32};

  static constexpr SupportedDataTypes kFloat16To32Ints8To32{
      OperandDataType::kFloat16, OperandDataType::kFloat32,
      OperandDataType::kInt8,    OperandDataType::kUint8,
      OperandDataType::kInt32,   OperandDataType::kUint32};

  static constexpr SupportedDataTypes kFloat16To32Int8To64{
      OperandDataType::kFloat16, OperandDataType::kFloat32,
      OperandDataType::kInt8, OperandDataType::kInt32, OperandDataType::kInt64};

  static constexpr SupportedDataTypes kUint8To32{OperandDataType::kUint8,
                                                 OperandDataType::kUint32};

  static constexpr SupportedDataTypes kGatherIndicesSupportedDataTypes{
      OperandDataType::kInt32, OperandDataType::kUint32,
      OperandDataType::kInt64, OperandDataType::kUint64};

  // TODO: crbug.com/345271830 - specify data types for all parameters.
  ContextProperties properties(
      /*input_operand_layout=*/InputOperandLayout::kNchw, Resample2DAxes::kAny,
      {/*input=*/SupportedDataTypes::All(),
       /*constant=*/SupportedDataTypes::All(),
       /*arg_min_max_input=*/SupportedDataTypes::All(),
       /*arg_min_max_output=*/DataTypeConstraint::kInt32To64,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_cast_operator_desc#tensor-support
       /*cast_input=*/SupportedDataTypes::All(),
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_element_wise_clip_operator_desc#tensor-support
       /*clamp_input=*/kFloat16To32Ints8To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_join_operator_desc#tensor-support
       /*concat_inputs=*/kFloat16To32Ints8To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_element_wise_add_operator_desc#tensor-support
       /*add_input=*/kFloat16To32Ints32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_element_wise_subtract_operator_desc#tensor-support
       /*sub_input=*/kFloat16To32Ints32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_element_wise_multiply_operator_desc#tensor-support
       /*mul_input=*/kFloat16To32Ints32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_element_wise_divide_operator_desc#tensor-support
       /*div_input=*/kFloat16To32Ints32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_element_wise_max_operator_desc#tensor-support
       /*max_input=*/kFloat16To32Ints8To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_element_wise_min_operator_desc#tensor-support
       /*min_input=*/kFloat16To32Ints8To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_element_wise_pow_operator_desc#tensor-support
       /*pow_input=*/kFloat16To32Ints8To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_element_wise_logical_equals_operator_desc#tensor-support
       /*equal_input=*/kFloat16To32Ints8To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_element_wise_logical_greater_than_operator_desc#tensor-support
       /*greater_input=*/kFloat16To32Ints8To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_element_wise_logical_greater_than_or_equal_operator_desc#tensor-support
       /*greater_or_equal_input=*/kFloat16To32Ints8To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_element_wise_logical_less_than_operator_desc#tensor-support
       /*lesser_input=*/kFloat16To32Ints8To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_element_wise_logical_less_than_or_equal_operator_desc#tensor-support
       /*lesser_or_equal_input=*/kFloat16To32Ints8To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_element_wise_logical_not_operator_desc#tensor-support
       /*logical_not_input=*/kUint8To32,
       /*logical_output=*/kUint8To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_element_wise_abs_operator_desc#tensor-support
       /*abs_input=*/DataTypeConstraint::kFloat16To32Int8To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_element_wise_ceil_operator_desc#tensor-support
       /*ceil_input=*/DataTypeConstraint::kFloat16To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_element_wise_cos_operator_desc#tensor-support
       /*cos_input=*/DataTypeConstraint::kFloat16To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_element_wise_erf_operator_desc#tensor-support
       /*erf_input=*/DataTypeConstraint::kFloat16To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_element_wise_exp_operator_desc#tensor-support
       /*exp_input=*/DataTypeConstraint::kFloat16To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_element_wise_floor_operator_desc#tensor-support
       /*floor_input=*/DataTypeConstraint::kFloat16To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_element_wise_identity_operator_desc#tensor-support
       /*identity_input=*/kFloat16To32Ints8To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_element_wise_log_operator_desc#tensor-support
       /*log_input=*/DataTypeConstraint::kFloat16To32,
       // Neg is emulated by DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC, so the
       // data type limits are set based on the spec.
       // DML_ELEMENT_WISE_NEGATE_OPERATOR_DESC introduced in feature level 5.0
       // also supports int64.
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_element_wise_negate_operator_desc#tensor-support
       /*neg_input=*/DataTypeConstraint::kFloat16To32Int8To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_element_wise_recip_operator_desc#tensor-support
       /*reciprocal_input=*/DataTypeConstraint::kFloat16To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_element_wise_sign_operator_desc#tensor-support
       /*sign_input=*/DataTypeConstraint::kFloat16To32Int8To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_element_wise_sin_operator_desc#tensor-support
       /*sin_input=*/DataTypeConstraint::kFloat16To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_element_wise_sqrt_operator_desc#tensor-support
       /*sqrt_input=*/DataTypeConstraint::kFloat16To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_element_wise_tan_operator_desc#tensor-support
       /*tan_input=*/DataTypeConstraint::kFloat16To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_activation_elu_operator_desc
       /*elu_input=*/DataTypeConstraint::kFloat16To32,
       // Expand is emulated by identity.
       /*expand_input=*/kFloat16To32Ints8To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_gather_operator_desc#tensor-support
       /*gather_input=*/kFloat16To32Ints8To32,
       /*gather_indices=*/kGatherIndicesSupportedDataTypes,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_gather_elements_operator_desc#tensor-support
       /*gather_elements_input=*/kFloat16To32Ints8To32,
       /*gather_elements_indices=*/kGatherIndicesSupportedDataTypes,
       // Gelu is emulated when the feature level is less than 5.1.
       // https://learn.microsoft.com/en-us/windows/ai/directml/api/ns-directml-dml_activation_gelu_operator_desc
       /*gelu_input=*/DataTypeConstraint::kFloat16To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_gemm_operator_desc#tensor-support
       /*gemm_input=*/DataTypeConstraint::kFloat16To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_activation_hard_sigmoid_operator_desc
       /*hard_sigmoid_input=*/DataTypeConstraint::kFloat16To32,
       // HardSwish is emulated when the feature level is less than 6.2.
       // https://learn.microsoft.com/en-us/windows/ai/directml/api/ns-directml-dml_activation_hard_swish_operator_desc
       /*hard_swish_input=*/DataTypeConstraint::kFloat16To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_activation_leaky_relu_operator_desc
       /*leaky_relu_input=*/DataTypeConstraint::kFloat16To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_activation_linear_operator_desc#tensor-support
       /*linear_input=*/DataTypeConstraint::kFloat16To32,
       // Matmul is emulated by gemm.
       /*matmul_input=*/DataTypeConstraint::kFloat16To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_padding_operator_desc#tensor-support
       /*pad_input=*/kFloat16To32Ints8To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_average_pooling_operator_desc
       /*average_pool2d_input=*/DataTypeConstraint::kFloat16To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_lp_pooling_operator_desc
       /*l2_pool2d_input=*/DataTypeConstraint::kFloat16To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_max_pooling_operator_desc
       /*max_pool2d_input=*/kFloat16To32Ints8,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_activation_parameterized_relu_operator_desc#tensor-support
       /*prelu_input=*/DataTypeConstraint::kFloat16To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_reduce_operator_desc#tensor-support-according-to-function
       /*reduce_l1_input=*/DataTypeConstraint::kFloat16To32,
       /*reduce_l2_input=*/DataTypeConstraint::kFloat16To32,
       /*reduce_log_sum_input=*/DataTypeConstraint::kFloat16To32,
       /*reduce_log_sum_exp_input=*/DataTypeConstraint::kFloat16To32,
       /*reduce_max_input=*/kFloat16To32Ints8To32,
       /*reduce_mean_input=*/DataTypeConstraint::kFloat16To32,
       /*reduce_min_input=*/kFloat16To32Ints8To32,
       /*reduce_product_input=*/DataTypeConstraint::kFloat16To32,
       /*reduce_sum_input=*/kFloat16To32Ints32,
       /*reduce_sum_square_input=*/DataTypeConstraint::kFloat16To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_activation_relu_operator_desc
       /*relu_input=*/DataTypeConstraint::kFloat16To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_resample_operator_desc#tensor-support
       /*resample2d_input=*/DataTypeConstraint::kFloat16To32,
       // Reshape is emulated by identity.
       /*reshape_input=*/kFloat16To32Ints8To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_activation_sigmoid_operator_desc#tensor-support
       /*sigmoid_input=*/DataTypeConstraint::kFloat16To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_slice_operator_desc#tensor-support
       /*slice_input=*/kFloat16To32Ints8To32,
       // Softmax is emulated when the feature level is less than 5.1.
       // https://learn.microsoft.com/en-us/windows/ai/directml/api/ns-directml-dml_activation_softmax1_operator_desc
       /*softmax_input=*/DataTypeConstraint::kFloat16To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_activation_softplus_operator_desc#tensor-support
       /*softplus_input=*/DataTypeConstraint::kFloat16To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_activation_softsign_operator_desc#tensor-support
       /*softsign_input=*/DataTypeConstraint::kFloat16To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_split_operator_desc#tensor-support
       /*split_input=*/kFloat16To32Ints8To32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_activation_tanh_operator_desc#tensor-support
       /*tanh_input=*/DataTypeConstraint::kFloat16To32,
       // Transpose is emulated by identity.
       /*transpose_input=*/kFloat16To32Ints8To32,
       // Triangular is emulated by DML_FILL_VALUE_CONSTANT_OPERATOR_DESC,
       // DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC,
       // DML_ELEMENT_WISE_BIT_AND_OPERATOR_DESC,
       // DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC and DML_SLICE_OPERATOR_DESC
       // when the feature level is less than 5.1, so the data type limit is set
       // based on these ops.
       // https://learn.microsoft.com/en-us/windows/ai/directml/api/ns-directml-dml_diagonal_matrix1_operator_desc#tensor-support
       /*triangular_input=*/kFloat16To32Ints32,
       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_element_wise_if_operator_desc
       /*where_condition=*/DataTypeConstraint::kUint8,
       /*where_value=*/kFloat16To32Ints8To32});

  if (feature_level >= DML_FEATURE_LEVEL_4_1) {
    properties.data_type_limits.concat_inputs = SupportedDataTypes::All();
    properties.data_type_limits.add_input =
        DataTypeConstraint::kFloat16To32Ints32To64;
    properties.data_type_limits.sub_input =
        DataTypeConstraint::kFloat16To32Ints32To64;
    properties.data_type_limits.mul_input =
        DataTypeConstraint::kFloat16To32Ints32To64;
    properties.data_type_limits.equal_input = SupportedDataTypes::All();
    properties.data_type_limits.greater_input = SupportedDataTypes::All();
    properties.data_type_limits.greater_or_equal_input =
        SupportedDataTypes::All();
    properties.data_type_limits.lesser_input = SupportedDataTypes::All();
    properties.data_type_limits.lesser_or_equal_input =
        SupportedDataTypes::All();
    properties.data_type_limits.abs_input = kFloat16To32Int8To64;
    properties.data_type_limits.identity_input = SupportedDataTypes::All();
    properties.data_type_limits.expand_input = SupportedDataTypes::All();
    properties.data_type_limits.gather_input = SupportedDataTypes::All();
    properties.data_type_limits.gather_elements_input =
        SupportedDataTypes::All();
    properties.data_type_limits.reshape_input = SupportedDataTypes::All();
    properties.data_type_limits.sign_input =
        DataTypeConstraint::kFloat16To32Int8To64;
    properties.data_type_limits.slice_input = SupportedDataTypes::All();
    properties.data_type_limits.split_input = SupportedDataTypes::All();
    properties.data_type_limits.transpose_input = SupportedDataTypes::All();
    properties.data_type_limits.triangular_input =
        DataTypeConstraint::kFloat16To32Ints32To64;
  }

  if (feature_level >= DML_FEATURE_LEVEL_5_0) {
    properties.data_type_limits.clamp_input = SupportedDataTypes::All();
    properties.data_type_limits.max_input = SupportedDataTypes::All();
    properties.data_type_limits.min_input = SupportedDataTypes::All();
    // Each assignment is its own statement; the previous revision joined the
    // next two with a stray comma operator.
    properties.data_type_limits.pad_input = SupportedDataTypes::All();
    properties.data_type_limits.reduce_l1_input =
        DataTypeConstraint::kFloat16To32Ints32To64;
    properties.data_type_limits.reduce_max_input = SupportedDataTypes::All();
    properties.data_type_limits.reduce_min_input = SupportedDataTypes::All();
    properties.data_type_limits.reduce_sum_input =
        DataTypeConstraint::kFloat16To32Ints32To64;
    properties.data_type_limits.reduce_sum_square_input =
        DataTypeConstraint::kFloat16To32Ints32To64;
    properties.data_type_limits.where_value = SupportedDataTypes::All();
    properties.data_type_limits.max_pool2d_input = SupportedDataTypes::All();
  }

  if (feature_level >= DML_FEATURE_LEVEL_5_1) {
    properties.data_type_limits.add_input = SupportedDataTypes::All();
    properties.data_type_limits.sub_input = SupportedDataTypes::All();
    properties.data_type_limits.mul_input = SupportedDataTypes::All();
    properties.data_type_limits.div_input = kFloat16To32Ints8To32;
    properties.data_type_limits.prelu_input =
        DataTypeConstraint::kFloat16To32Int8To32;
    properties.data_type_limits.relu_input =
        DataTypeConstraint::kFloat16To32Int8To32;
    properties.data_type_limits.triangular_input = SupportedDataTypes::All();
  }

  if (feature_level >= DML_FEATURE_LEVEL_6_0) {
    properties.data_type_limits.div_input = SupportedDataTypes::All();
  }

  if (feature_level >= DML_FEATURE_LEVEL_6_2) {
    properties.data_type_limits.resample2d_input = kFloat16To32Ints8;
  }

  return properties;
}
// Constructs the DirectML context. The base WebNNContextImpl receives the
// mojo receiver, the context provider, the context properties computed by
// GetProperties() from the adapter's max supported feature level, and the
// creation options. `adapter`, `command_recorder` and `gpu_feature_info` are
// stored in the corresponding members. `command_recorder` must be non-null.
ContextImplDml::ContextImplDml(
scoped_refptr<Adapter> adapter,
mojo::PendingReceiver<mojom::WebNNContext> receiver,
WebNNContextProviderImpl* context_provider,
mojom::CreateContextOptionsPtr options,
std::unique_ptr<CommandRecorder> command_recorder,
const gpu::GpuFeatureInfo& gpu_feature_info)
: WebNNContextImpl(std::move(receiver),
context_provider,
// `adapter` is dereferenced here, before it is moved into
// `adapter_` below.
GetProperties(adapter->max_supported_feature_level()),
std::move(options)),
adapter_(std::move(adapter)),
command_recorder_(std::move(command_recorder)),
gpu_feature_info_(gpu_feature_info) {
// A context always starts out with a recorder supplied by its creator.
CHECK(command_recorder_);
}
// Defaulted destructor; no explicit teardown is performed here.
ContextImplDml::~ContextImplDml() = default;
// Returns a weak pointer to this context, typed as the base WebNNContextImpl.
// Must be called on the sequence guarded by `sequence_checker_`.
base::WeakPtr<WebNNContextImpl> ContextImplDml::AsWeakPtr() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  auto weak_self = weak_factory_.GetWeakPtr();
  return weak_self;
}
// Forwards graph creation to GraphImplDml::CreateAndBuild(), handing it this
// context's adapter, a weak pointer to this context, and whether the
// DML_EXECUTION_DISABLE_META_COMMANDS GPU driver bug workaround is enabled.
void ContextImplDml::CreateGraphImpl(
    mojom::GraphInfoPtr graph_info,
    WebNNGraphImpl::ComputeResourceInfo compute_resource_info,
    WebNNContextImpl::CreateGraphImplCallback callback) {
  const bool meta_commands_workaround_enabled =
      gpu_feature_info_->IsWorkaroundEnabled(
          gpu::DML_EXECUTION_DISABLE_META_COMMANDS);

  GraphImplDml::CreateAndBuild(adapter_, weak_factory_.GetWeakPtr(),
                               std::move(graph_info),
                               std::move(compute_resource_info),
                               std::move(callback),
                               meta_commands_workaround_enabled);
}
// Creates the D3D12 resource backing a WebNN buffer and replies through
// `callback` with a BufferImplDml bound to `receiver`, or with an error if the
// requested size is too large or the resource cannot be created.
void ContextImplDml::CreateBufferImpl(
    mojo::PendingAssociatedReceiver<mojom::WebNNBuffer> receiver,
    mojom::BufferInfoPtr buffer_info,
    CreateBufferImplCallback callback) {
  // DML requires resources to be in multiple of 4 bytes.
  // https://learn.microsoft.com/en-us/windows/ai/directml/dml-helper-functions#dmlcalcbuffertensorsize
  constexpr uint64_t kDMLBufferAlignment = 4ull;

  const uint64_t packed_byte_length =
      static_cast<uint64_t>(buffer_info->descriptor.PackedByteLength());

  // Refuse lengths so close to the uint64_t maximum that rounding up to the
  // alignment could wrap around.
  if (packed_byte_length >
      std::numeric_limits<uint64_t>::max() - kDMLBufferAlignment) {
    LOG(ERROR) << "[WebNN] Buffer is too large to create.";
    HandleBufferCreationFailure("Failed to create buffer.",
                                std::move(callback));
    return;
  }

  const uint64_t aligned_buffer_byte_size =
      base::bits::AlignUp(packed_byte_length, kDMLBufferAlignment);

  ComPtr<ID3D12Resource> buffer;
  HRESULT hr = S_OK;
  if (!adapter_->IsUMA()) {
    // Without UMA the buffer is GPU-only; the CPU must go through a staging
    // buffer to read from or write to it.
    hr = CreateDefaultBuffer(adapter_->d3d12_device(), aligned_buffer_byte_size,
                             L"WebNN_Default_Buffer_External", buffer);
  } else if (buffer_info->usage.Has(MLBufferUsageFlags::kWriteTo)) {
    // On UMA the heap type is picked from the requested usage so the CPU can
    // access the buffer directly while the GPU isn't using it. An upload
    // buffer suits a buffer the CPU mostly writes (but may also read): it
    // trades CPU read bandwidth for optimal GPU writes.
    hr = CreateCustomUploadBuffer(
        adapter_->d3d12_device(), aligned_buffer_byte_size,
        L"WebNN_Custom_Upload_Buffer_External", buffer);
  } else if (buffer_info->usage.Has(MLBufferUsageFlags::kReadFrom)) {
    // The buffer only needs CPU reads, so a readback buffer is used.
    hr = CreateCustomReadbackBuffer(
        adapter_->d3d12_device(), aligned_buffer_byte_size,
        L"WebNN_Custom_Readback_Buffer_External", buffer);
  } else {
    // No CPU access was requested, so a default buffer is used to keep GPU
    // access optimal.
    hr = CreateDefaultBuffer(adapter_->d3d12_device(),
                             aligned_buffer_byte_size,
                             L"WebNN_Default_Buffer_External", buffer);
  }

  if (FAILED(hr)) {
    // Reply to the caller first, then report the context as lost.
    HandleBufferCreationFailure("Failed to create buffer.",
                                std::move(callback));
    HandleContextLostOrCrash("Failed to create the external buffer.", hr);
    return;
  }

  // The receiver gets bound to the new buffer implementation.
  //
  // Passing a raw ContextImplDml* is safe because this context owns the buffer
  // being connected and cannot be destroyed before it.
  auto buffer_impl = std::make_unique<BufferImplDml>(
      std::move(receiver), std::move(buffer), this, std::move(buffer_info));
  std::move(callback).Run(std::move(buffer_impl));
}
// Reads back the contents of `src_buffer` and delivers them (or an error)
// through `callback`. If the adapter is UMA and the command queue has already
// completed the buffer's last submission, the buffer is handed straight to
// OnReadbackComplete(); otherwise its contents are copied into a staging
// readback buffer and OnReadbackComplete() runs once the queue wait finishes.
void ContextImplDml::ReadBuffer(
    BufferImplDml* src_buffer,
    mojom::WebNNBuffer::ReadBufferCallback callback) {
  const size_t src_buffer_size = src_buffer->PackedByteLength();

  // Direct path: map the entire source buffer to read back the output data.
  if (adapter_->IsUMA() &&
      src_buffer->last_submission_fence_value() <=
          adapter_->command_queue()->GetCompletedValue()) {
    OnReadbackComplete(src_buffer->buffer(), src_buffer_size,
                       std::move(callback), S_OK);
    return;
  }

  // Every failure below first answers the caller with the same error.
  auto reply_with_read_failure = [&callback]() {
    std::move(callback).Run(ToError<mojom::ReadBufferResult>(
        mojom::Error::Code::kUnknownError, "Failed to read buffer."));
  };

  // Staging path: copy the buffer into a readback buffer first.
  ComPtr<ID3D12Resource> download_buffer;
  HRESULT hr = CreateReadbackBuffer(adapter_->d3d12_device(),
                                    static_cast<uint64_t>(src_buffer_size),
                                    L"WebNN_Readback_Buffer", download_buffer);
  if (FAILED(hr)) {
    reply_with_read_failure();
    HandleContextLostOrCrash("Failed to create the download buffer.", hr);
    return;
  }

  hr = StartRecordingIfNecessary();
  if (FAILED(hr)) {
    reply_with_read_failure();
    HandleRecordingError("Failed to start recording.", hr);
    return;
  }

  command_recorder_->ReadbackBufferWithBarrier(download_buffer, src_buffer,
                                               src_buffer_size);

  // Submit the copy and schedule a GPU wait.
  hr = command_recorder_->CloseAndExecute();
  if (FAILED(hr)) {
    reply_with_read_failure();
    HandleRecordingError("Failed to close and execute the command list.", hr);
    return;
  }

  // The recorder keeps the source and readback buffers alive during execution
  // through `ReadbackBufferWithBarrier()` followed by
  // CommandRecorder::CloseAndExecute().
  adapter_->command_queue()->WaitAsync(base::BindOnce(
      &ContextImplDml::OnReadbackComplete, weak_factory_.GetWeakPtr(),
      std::move(download_buffer), src_buffer_size, std::move(callback)));
}
// Finishes a buffer read: maps `download_buffer`, copies its first
// `read_byte_size` bytes into a BigBuffer and passes that to `callback`.
// `hr` is the result of the work that preceded this call; on failure the
// callback receives a "Failed to read buffer." error and the failure is
// forwarded to the recording/context-lost handlers.
void ContextImplDml::OnReadbackComplete(
    ComPtr<ID3D12Resource> download_buffer,
    size_t read_byte_size,
    mojom::WebNNBuffer::ReadBufferCallback callback,
    HRESULT hr) {
  if (FAILED(hr)) {
    std::move(callback).Run(ToError<mojom::ReadBufferResult>(
        mojom::Error::Code::kUnknownError, "Failed to read buffer."));
    HandleRecordingError("Failed to download the buffer.", hr);
    return;
  }

  CHECK(download_buffer);

  // Copy over data from the download buffer to the destination buffer.
  void* mapped_download_data = nullptr;
  hr = download_buffer->Map(0, nullptr, &mapped_download_data);
  if (FAILED(hr)) {
    std::move(callback).Run(ToError<mojom::ReadBufferResult>(
        mojom::Error::Code::kUnknownError, "Failed to read buffer."));
    HandleContextLostOrCrash("Failed to map the download buffer.", hr);
    return;
  }

  // Mirror WriteBuffer(): a successful Map() must yield a usable pointer.
  CHECK(mapped_download_data);

  // SAFETY: ReadBuffer() passes the same byte count it used to size the
  // staging buffer. NOTE(review): when the source buffer itself is mapped,
  // this presumes the resource holds at least `read_byte_size` bytes —
  // confirm against the buffer's creation size.
  mojo_base::BigBuffer dst_buffer(UNSAFE_BUFFERS(base::span(
      static_cast<const uint8_t*>(mapped_download_data), read_byte_size)));

  // The CPU only read from the mapping, so report an empty written range
  // instead of nullptr (which would mean the whole resource may have been
  // modified).
  const D3D12_RANGE empty_written_range = {0, 0};
  download_buffer->Unmap(0, &empty_written_range);

  std::move(callback).Run(
      mojom::ReadBufferResult::NewBuffer(std::move(dst_buffer)));
}
// Writes the bytes of `src_buffer` into `dst_buffer`. If the adapter is UMA
// and the command queue has completed the buffer's last submission, the
// destination resource is mapped and written directly. Otherwise the data is
// written to a staging upload buffer and a GPU copy into `dst_buffer` is
// recorded and submitted.
void ContextImplDml::WriteBuffer(BufferImplDml* dst_buffer,
                                 mojo_base::BigBuffer src_buffer) {
  const size_t write_byte_size = src_buffer.size();
  HRESULT hr = S_OK;
  ComPtr<ID3D12Resource> buffer_to_map = dst_buffer->buffer();

  // The CPU may update the existing buffer in place only on UMA adapters and
  // only once the GPU has finished the buffer's last submission.
  const bool cpu_can_write_in_place =
      adapter_->IsUMA() && adapter_->command_queue()->GetCompletedValue() >=
                               dst_buffer->last_submission_fence_value();
  if (!cpu_can_write_in_place) {
    // Otherwise upload through a staging buffer.
    hr = CreateUploadBuffer(adapter_->d3d12_device(), write_byte_size,
                            L"WebNN_Upload_Buffer", buffer_to_map);
    if (FAILED(hr)) {
      HandleContextLostOrCrash("Failed to create the upload buffer.", hr);
      return;
    }
  }

  CHECK(buffer_to_map);

  // Copy over data from the source buffer to the mapped buffer.
  void* mapped_buffer_data = nullptr;
  hr = buffer_to_map->Map(0, nullptr, &mapped_buffer_data);
  if (FAILED(hr)) {
    HandleContextLostOrCrash("Failed to map the buffer.", hr);
    return;
  }

  CHECK(mapped_buffer_data);

  // SAFETY: a staging buffer is created above with exactly
  // `src_buffer.size()` bytes. NOTE(review): when `dst_buffer`'s own resource
  // is mapped, this presumes the caller has already ensured `src_buffer` fits
  // in it — confirm.
  UNSAFE_BUFFERS(
      base::span(static_cast<uint8_t*>(mapped_buffer_data), write_byte_size))
      .copy_from(src_buffer);

  buffer_to_map->Unmap(0, nullptr);

  // Uploads are only required when the mapped buffer was a staging buffer.
  const bool wrote_to_staging_buffer =
      dst_buffer->buffer() != buffer_to_map.Get();
  if (!wrote_to_staging_buffer) {
    return;
  }

  hr = StartRecordingIfNecessary();
  if (FAILED(hr)) {
    HandleRecordingError("Failed to start recording.", hr);
    return;
  }

  command_recorder_->UploadBufferWithBarrier(
      dst_buffer, std::move(buffer_to_map), write_byte_size);

  // TODO(crbug.com/40278771): consider not submitting after every write.
  // CloseAndExecute() only needs to be called once, when the buffer is read
  // by another context operation (ex. input into dispatch). Submitting
  // immediately prevents memory usage from increasing; however, it also
  // incurs more overhead due to a near empty command-list getting executed
  // every time.
  hr = command_recorder_->CloseAndExecute();
  if (FAILED(hr)) {
    HandleRecordingError("Failed to close and execute the command list.", hr);
    return;
  }

  // Since the queue owns the upload buffer, it does not need to be provided
  // to OnUploadComplete() and will be finally released once the wait is
  // satisfied.
  adapter_->command_queue()->WaitAsync(base::BindOnce(
      &ContextImplDml::OnUploadComplete, weak_factory_.GetWeakPtr()));
}
// Completion handler for the upload scheduled by WriteBuffer(). Nothing is
// done on success; a failed `hr` is reported as a recording error.
void ContextImplDml::OnUploadComplete(HRESULT hr) {
  if (SUCCEEDED(hr)) {
    return;
  }
  HandleRecordingError("Failed to upload the buffer.", hr);
}
// Ensures `command_recorder_` exists and is open for recording. Returns S_OK
// on success, or the failure produced by creating or opening the recorder.
HRESULT ContextImplDml::StartRecordingIfNecessary() {
  // HandleRecordingError() drops the recorder after an error. It is recreated
  // here because resources that were recorded but never executed would
  // otherwise stay alive until this context is destroyed, and the context
  // could not record new commands.
  if (!command_recorder_) {
    ASSIGN_OR_RETURN(command_recorder_,
                     CommandRecorder::Create(adapter_->command_queue(),
                                             adapter_->dml_device()));
  }
  CHECK(command_recorder_);

  // Only open the recorder when it is not already recording the context
  // execution commands.
  if (!command_recorder_->IsOpen()) {
    RETURN_IF_FAILED(command_recorder_->Open());
    CHECK(command_recorder_->IsOpen());
  }
  return S_OK;
}
// Handles a failure while recording or executing commands: releases the
// current command recorder (StartRecordingIfNecessary() creates a new one
// when it finds none), then reports `error_message` and `hr` through
// HandleContextLostOrCrash().
void ContextImplDml::HandleRecordingError(std::string_view error_message,
                                          HRESULT hr) {
  command_recorder_ = nullptr;
  HandleContextLostOrCrash(error_message, hr);
}
void ContextImplDml::HandleContextLostOrCrash(std::string_view message_for_log,
HRESULT hr) {
LOG(ERROR) << "[WebNN] " << message_for_log << " "
<< logging::SystemErrorCodeToString(hr);
HRESULT device_removed_reason =
adapter_->d3d12_device()->GetDeviceRemovedReason();
if (FAILED(device_removed_reason)) {
LOG(ERROR) << "[WebNN] Device Removed Reason: "
<< logging::SystemErrorCodeToString(device_removed_reason);
}
std::string_view message_for_promise;
switch (hr) {
case E_OUTOFMEMORY:
message_for_promise = "out of memory.";
break;
case DXGI_ERROR_DEVICE_REMOVED:
message_for_promise = "device removed.";
break;
case DXGI_ERROR_DEVICE_RESET:
message_for_promise = "device reset.";
break;
default:
message_for_promise = "internal error.";
}
OnLost(base::StrCat({"WebNN context is lost due to ", message_for_promise}));
CHECK(hr == E_OUTOFMEMORY || hr == DXGI_ERROR_DEVICE_REMOVED ||
hr == DXGI_ERROR_DEVICE_RESET);
}
} // namespace webnn::dml