// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifdef UNSAFE_BUFFERS_BUILD
// TODO(crbug.com/349653202): Remove this and spanify to fix the errors.
#pragma allow_unsafe_buffers
#endif
#include "services/webnn/dml/graph_impl_dml.h"
#include <winerror.h>
#include <algorithm>
#include <array>
#include <limits>
#include <numeric>
#include "base/bits.h"
#include "base/check.h"
#include "base/containers/fixed_flat_set.h"
#include "base/feature_list.h"
#include "base/memory/ptr_util.h"
#include "base/notreached.h"
#include "base/numerics/safe_conversions.h"
#include "base/ranges/algorithm.h"
#include "base/strings/string_number_conversions.h"
#include "base/strings/stringprintf.h"
#include "base/task/thread_pool.h"
#include "base/trace_event/trace_event.h"
#include "base/types/expected.h"
#include "base/types/expected_macros.h"
#include "base/types/optional_ref.h"
#include "mojo/public/cpp/bindings/self_owned_associated_receiver.h"
#include "services/webnn/dml/adapter.h"
#include "services/webnn/dml/buffer_impl_dml.h"
#include "services/webnn/dml/command_queue.h"
#include "services/webnn/dml/command_recorder.h"
#include "services/webnn/dml/context_impl_dml.h"
#include "services/webnn/dml/error.h"
#include "services/webnn/dml/graph_builder_dml.h"
#include "services/webnn/dml/tensor_desc.h"
#include "services/webnn/dml/utils.h"
#include "services/webnn/error.h"
#include "services/webnn/public/cpp/graph_validation_utils.h"
#include "services/webnn/public/cpp/operand_descriptor.h"
#include "services/webnn/public/mojom/webnn_error.mojom.h"
#include "services/webnn/webnn_context_impl.h"
#include "services/webnn/webnn_utils.h"
#include "third_party/abseil-cpp/absl/types/variant.h"
#include "third_party/fp16/src/include/fp16.h"
namespace webnn::dml {
namespace {
// The feature flag allows us to disable graph fusion in case it causes
// problems.
BASE_FEATURE(kApplyGraphFusion,
"ApplyGraphFusion",
base::FEATURE_ENABLED_BY_DEFAULT);
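// For local debugging, fusion can typically be toggled at runtime via the
// standard base::FeatureList switches, e.g. by launching with
// `--disable-features=ApplyGraphFusion` (illustrative usage).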
using Microsoft::WRL::ComPtr;
using mojom::ComputeResult;
using mojom::CreateGraphResult;
using mojom::Operand;
using mojom::OperandPtr;
using mojom::Operation;
// A map of all mojom operands in `mojom::GraphInfo` using the mojom operand id
// as key.
using IdToOperandMap = base::flat_map<uint64_t, OperandPtr>;
// A map of all node outputs in `dml::GraphBuilderDml` using the mojom operand
// id as key.
using IdToNodeOutputMap = std::map<uint64_t, const NodeOutput*>;
static constexpr auto kDmlFloatDataTypes =
base::MakeFixedFlatSet<DML_TENSOR_DATA_TYPE>(
{DML_TENSOR_DATA_TYPE_FLOAT32, DML_TENSOR_DATA_TYPE_FLOAT16});
DML_TENSOR_DATA_TYPE GetTensorDataType(OperandDataType type) {
switch (type) {
case OperandDataType::kFloat32:
return DML_TENSOR_DATA_TYPE_FLOAT32;
case OperandDataType::kFloat16:
return DML_TENSOR_DATA_TYPE_FLOAT16;
case OperandDataType::kInt8:
return DML_TENSOR_DATA_TYPE_INT8;
case OperandDataType::kUint8:
return DML_TENSOR_DATA_TYPE_UINT8;
case OperandDataType::kInt64:
return DML_TENSOR_DATA_TYPE_INT64;
case OperandDataType::kUint64:
return DML_TENSOR_DATA_TYPE_UINT64;
case OperandDataType::kInt32:
return DML_TENSOR_DATA_TYPE_INT32;
case OperandDataType::kUint32:
return DML_TENSOR_DATA_TYPE_UINT32;
}
}
OperandDataType DmlDataTypeToOperand(DML_TENSOR_DATA_TYPE type) {
switch (type) {
case DML_TENSOR_DATA_TYPE_FLOAT32:
return OperandDataType::kFloat32;
case DML_TENSOR_DATA_TYPE_FLOAT16:
return OperandDataType::kFloat16;
case DML_TENSOR_DATA_TYPE_INT8:
return OperandDataType::kInt8;
case DML_TENSOR_DATA_TYPE_UINT8:
return OperandDataType::kUint8;
case DML_TENSOR_DATA_TYPE_INT64:
return OperandDataType::kInt64;
case DML_TENSOR_DATA_TYPE_UINT64:
return OperandDataType::kUint64;
case DML_TENSOR_DATA_TYPE_INT32:
return OperandDataType::kInt32;
case DML_TENSOR_DATA_TYPE_UINT32:
return OperandDataType::kUint32;
default:
NOTREACHED() << "[WebNN] This data type is not supported.";
}
}
DML_REDUCE_FUNCTION MapReduceKindToReduceFuntion(mojom::Reduce::Kind kind) {
switch (kind) {
case mojom::Reduce::Kind::kL1:
return DML_REDUCE_FUNCTION_L1;
case mojom::Reduce::Kind::kL2:
return DML_REDUCE_FUNCTION_L2;
case mojom::Reduce::Kind::kLogSum:
return DML_REDUCE_FUNCTION_LOG_SUM;
case mojom::Reduce::Kind::kLogSumExp:
return DML_REDUCE_FUNCTION_LOG_SUM_EXP;
case mojom::Reduce::Kind::kMax:
return DML_REDUCE_FUNCTION_MAX;
case mojom::Reduce::Kind::kMean:
return DML_REDUCE_FUNCTION_AVERAGE;
case mojom::Reduce::Kind::kMin:
return DML_REDUCE_FUNCTION_MIN;
case mojom::Reduce::Kind::kProduct:
return DML_REDUCE_FUNCTION_MULTIPLY;
case mojom::Reduce::Kind::kSum:
return DML_REDUCE_FUNCTION_SUM;
case mojom::Reduce::Kind::kSumSquare:
return DML_REDUCE_FUNCTION_SUM_SQUARE;
}
}
void CheckInputDataTypeForReduce(const DataTypeLimits& data_type_limits,
mojom::Reduce::Kind kind,
OperandDataType data_type) {
switch (kind) {
case mojom::Reduce::Kind::kL1:
CHECK(data_type_limits.reduce_l1_input.Has(data_type));
break;
case mojom::Reduce::Kind::kL2:
CHECK(data_type_limits.reduce_l2_input.Has(data_type));
break;
case mojom::Reduce::Kind::kLogSum:
CHECK(data_type_limits.reduce_log_sum_input.Has(data_type));
break;
case mojom::Reduce::Kind::kLogSumExp:
CHECK(data_type_limits.reduce_log_sum_exp_input.Has(data_type));
break;
case mojom::Reduce::Kind::kMax:
CHECK(data_type_limits.reduce_max_input.Has(data_type));
break;
case mojom::Reduce::Kind::kMean:
CHECK(data_type_limits.reduce_mean_input.Has(data_type));
break;
case mojom::Reduce::Kind::kMin:
CHECK(data_type_limits.reduce_min_input.Has(data_type));
break;
case mojom::Reduce::Kind::kProduct:
CHECK(data_type_limits.reduce_product_input.Has(data_type));
break;
case mojom::Reduce::Kind::kSum:
CHECK(data_type_limits.reduce_sum_input.Has(data_type));
break;
case mojom::Reduce::Kind::kSumSquare:
CHECK(data_type_limits.reduce_sum_square_input.Has(data_type));
break;
}
}
DML_RECURRENT_NETWORK_DIRECTION MojoRecurrentNetworkDirectionToDml(
mojom::RecurrentNetworkDirection direction) {
switch (direction) {
case mojom::RecurrentNetworkDirection::kForward:
return DML_RECURRENT_NETWORK_DIRECTION_FORWARD;
case mojom::RecurrentNetworkDirection::kBackward:
return DML_RECURRENT_NETWORK_DIRECTION_BACKWARD;
case mojom::RecurrentNetworkDirection::kBoth:
return DML_RECURRENT_NETWORK_DIRECTION_BIDIRECTIONAL;
}
}
// TODO(crbug.com/354543926): All calls to CreateError can be replaced by
// CreateUnexpectedError.
base::expected<void, mojom::ErrorPtr> CreateUnexpectedError(
mojom::Error::Code error_code,
const std::string& error_message,
std::string_view label) {
return base::unexpected(CreateError(error_code, error_message, label));
}
// Calculate the total byte length of buffers and the D3D12_RANGE for each
// buffer, all with the required alignment.
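// For example, three 20-byte buffers would occupy the 16-byte-aligned ranges
// [0, 32), [32, 64), and [64, 96), giving a total byte length of 96 (an
// illustrative calculation).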
std::optional<AlignedByteLength<uint64_t>> CalculateAlignedByteLength(
const base::flat_map<uint64_t, mojo_base::BigBuffer>& ids_to_buffers) {
base::CheckedNumeric<size_t> total_byte_length(0);
std::map<uint64_t, D3D12_RANGE> key_to_d3d12_range_map;
for (const auto& [buffer_id, buffer] : ids_to_buffers) {
auto& d3d12_range = key_to_d3d12_range_map[buffer_id];
d3d12_range.Begin = total_byte_length.ValueOrDie();
    // The buffer has a minimum base address alignment requirement of 16 bytes,
    // defined by the macro `DML_MINIMUM_BUFFER_TENSOR_ALIGNMENT`:
    // https://learn.microsoft.com/en-us/windows/win32/direct3d12/direct3d-directml-constants
total_byte_length += base::bits::AlignUp<size_t>(
buffer.size(), DML_MINIMUM_BUFFER_TENSOR_ALIGNMENT);
if (!total_byte_length.IsValid()) {
LOG(ERROR) << "[WebNN] Failed to calculate the total byte length.";
return std::nullopt;
}
    // The aligned byte length, calculated as `End` minus `Begin`, is used to
    // set the `SizeInBytes` field of `DML_BUFFER_BINDING`.
d3d12_range.End = total_byte_length.ValueOrDie();
}
return AlignedByteLength<uint64_t>{
.total_byte_length = total_byte_length.ValueOrDie(),
.key_to_d3d12_range_map = std::move(key_to_d3d12_range_map)};
}
// Same as above, but given a map of names to descriptors.
std::optional<AlignedByteLength<std::string>>
CalculateAlignedByteLengthFromDescriptors(
const base::flat_map<std::string, OperandDescriptor>&
names_to_descriptors) {
base::CheckedNumeric<size_t> total_byte_length(0);
std::map<std::string, D3D12_RANGE> key_to_d3d12_range_map;
for (auto& [name, descriptor] : names_to_descriptors) {
auto& d3d12_range = key_to_d3d12_range_map[name];
d3d12_range.Begin = total_byte_length.ValueOrDie();
    // The buffer has a minimum base address alignment requirement of 16 bytes,
    // defined by the macro `DML_MINIMUM_BUFFER_TENSOR_ALIGNMENT`:
    // https://learn.microsoft.com/en-us/windows/win32/direct3d12/direct3d-directml-constants
total_byte_length += base::bits::AlignUp<size_t>(
descriptor.PackedByteLength(), DML_MINIMUM_BUFFER_TENSOR_ALIGNMENT);
if (!total_byte_length.IsValid()) {
LOG(ERROR) << "[WebNN] Failed to calculate the total byte length.";
return std::nullopt;
}
    // The aligned byte length, calculated as `End` minus `Begin`, is used to
    // set the `SizeInBytes` field of `DML_BUFFER_BINDING`.
d3d12_range.End = total_byte_length.ValueOrDie();
}
return AlignedByteLength<std::string>{
.total_byte_length = total_byte_length.ValueOrDie(),
.key_to_d3d12_range_map = std::move(key_to_d3d12_range_map)};
}
struct UploadAndDefaultBuffers {
ComPtr<ID3D12Resource> upload_buffer;
ComPtr<ID3D12Resource> default_buffer;
};
// Upload constant buffers in one Direct3D 12 committed resource; the
// DML_BUFFER_BINDING specifies a resource binding described by a range of
// bytes in the single buffer. For GPUs that support UMA, pass a custom upload
// buffer via `buffer_variant` for both constants uploading and binding. For
// GPUs that don't support UMA, pass an upload buffer and a default buffer via
// `buffer_variant` for uploading and binding separately.
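// For example, a caller on a UMA adapter would create a single custom-heap
// buffer of `total_byte_length` bytes and pass it as the
// ComPtr<ID3D12Resource> alternative, while a caller on a discrete GPU would
// create an upload-heap buffer plus a default-heap buffer of that size and
// pass both via `UploadAndDefaultBuffers` (illustrative usage).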
base::expected<std::map<uint64_t, DML_BUFFER_BINDING>, HRESULT>
UploadAndCreateConstantBufferBinding(
CommandRecorder* command_recorder,
const base::flat_map<uint64_t, mojo_base::BigBuffer>& key_to_buffer_map,
const AlignedByteLength<uint64_t>& aligned_byte_length,
absl::variant<UploadAndDefaultBuffers, ComPtr<ID3D12Resource>>
buffer_variant) {
  // Map the entire resource to copy each constant/input array buffer one by
  // one at its byte offset.
void* mapped_buffer = nullptr;
ID3D12Resource* buffer_to_map = nullptr;
ID3D12Resource* buffer_to_bind = nullptr;
ComPtr<ID3D12Resource> cpu_buffer;
ComPtr<ID3D12Resource> upload_buffer;
ComPtr<ID3D12Resource> default_buffer;
if (absl::holds_alternative<ComPtr<ID3D12Resource>>(buffer_variant)) {
cpu_buffer = std::move(absl::get<ComPtr<ID3D12Resource>>(buffer_variant));
buffer_to_map = cpu_buffer.Get();
buffer_to_bind = buffer_to_map;
} else {
upload_buffer = std::move(
absl::get<UploadAndDefaultBuffers>(buffer_variant).upload_buffer);
default_buffer = std::move(
absl::get<UploadAndDefaultBuffers>(buffer_variant).default_buffer);
buffer_to_map = upload_buffer.Get();
buffer_to_bind = default_buffer.Get();
}
CHECK(buffer_to_map);
CHECK(buffer_to_bind);
RETURN_UNEXPECTED_IF_FAILED(buffer_to_map->Map(0, nullptr, &mapped_buffer));
std::map<uint64_t, DML_BUFFER_BINDING> key_to_buffer_binding_map;
for (auto& [key, buffer] : key_to_buffer_map) {
    // Copy the input data to the upload heap at its byte offset.
const auto& d3d12_range =
aligned_byte_length.key_to_d3d12_range_map.at(key);
memcpy(static_cast<uint8_t*>(mapped_buffer) + d3d12_range.Begin,
buffer.data(), buffer.size());
    // Create the buffer binding for each constant/input and insert it into the
    // binding map.
auto size_in_bytes = d3d12_range.End - d3d12_range.Begin;
key_to_buffer_binding_map[key] =
DML_BUFFER_BINDING{.Buffer = buffer_to_bind,
.Offset = d3d12_range.Begin,
.SizeInBytes = size_in_bytes};
}
buffer_to_map->Unmap(0, nullptr);
if (absl::holds_alternative<ComPtr<ID3D12Resource>>(buffer_variant)) {
CHECK(cpu_buffer);
command_recorder->ReferenceCommandResources(std::move(cpu_buffer));
} else {
CHECK(default_buffer);
CHECK(upload_buffer);
UploadBufferWithBarrier(command_recorder, std::move(default_buffer),
std::move(upload_buffer),
aligned_byte_length.total_byte_length);
}
return key_to_buffer_binding_map;
}
HRESULT MapAndCopyInputDataToBuffer(
const base::flat_map<std::string, mojo_base::BigBuffer>& named_inputs,
const std::map<std::string, D3D12_RANGE>& input_name_to_d3d12_range_map,
ID3D12Resource* buffer) {
  // Map the entire resource to copy each input array buffer one by one at its
  // byte offset.
void* mapped_buffer = nullptr;
CHECK(buffer);
RETURN_IF_FAILED(buffer->Map(0, nullptr, &mapped_buffer));
for (auto& [name, input] : named_inputs) {
    // Copy the input data to the upload heap at its byte offset.
const auto& d3d12_range = input_name_to_d3d12_range_map.at(name);
memcpy(static_cast<uint8_t*>(mapped_buffer) + d3d12_range.Begin,
input.data(), input.size());
}
buffer->Unmap(0, nullptr);
return S_OK;
}
// Define methods like CreateInputNode and CreateOperatorNodeForRelu here to
// focus on converting the mojo graph structs to the corresponding DML graph
// nodes, using dml::GraphBuilderDml as a helper. dml::GraphBuilderDml should
// be decoupled from mojo graph structs and focus on manipulating DML graph
// structs.
//
// The return value is the GraphInputIndex assigned by graph builder.
uint32_t CreateInputNode(const IdToOperandMap& id_to_operand_map,
uint64_t input_id,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map) {
const OperandPtr& operand = id_to_operand_map.at(input_id);
  // If the operand is constant, the tensor is flagged with
  // DML_TENSOR_FLAG_OWNED_BY_DML, which means it must be bound to the binding
  // table during graph initialization rather than during execution.
DML_TENSOR_FLAGS flags = operand->kind == Operand::Kind::kConstant
? DML_TENSOR_FLAG_OWNED_BY_DML
: DML_TENSOR_FLAG_NONE;
TensorDesc input_tensor_desc(
GetTensorDataType(operand->descriptor.data_type()), flags,
operand->descriptor.shape());
const InputNode* input_node = graph_builder.CreateInputNode();
CHECK(input_node);
const NodeOutput* node_output =
graph_builder.CreateNodeOutput(input_node, std::move(input_tensor_desc));
CHECK(node_output);
  id_to_node_output_map[input_id] = node_output;
return input_node->GetGraphInputIndex();
}
const NodeOutput* GetNodeOutputForOperand(
const IdToNodeOutputMap& id_to_node_output_map,
uint64_t operand_id) {
const auto input_iterator = id_to_node_output_map.find(operand_id);
CHECK(input_iterator != id_to_node_output_map.end());
CHECK(input_iterator->second);
return input_iterator->second;
}
const NodeOutput* GetOptionalNodeOutputForOperand(
const IdToNodeOutputMap& id_to_node_output_map,
std::optional<uint64_t> operand_id) {
return operand_id.has_value() ? GetNodeOutputForOperand(id_to_node_output_map,
operand_id.value())
: nullptr;
}
const DML_TENSOR_DESC* GetOptionalDmlTensorDescPtr(
base::optional_ref<const TensorDesc> tensor_desc) {
return tensor_desc.has_value() ? &tensor_desc->GetDMLTensorDesc() : nullptr;
}
// Build a one-element constant operand of the given rank for a float value and
// add it into the graph info. For example, if the rank is 3, the operand
// dimensions would be {1, 1, 1}.
uint64_t BuildConstantOperandForFloatValue(mojom::GraphInfoPtr& graph_info,
uint64_t& next_operand_id,
OperandDataType data_type,
size_t rank,
float value) {
OperandPtr constant_operand = Operand::New();
constant_operand->kind = Operand::Kind::kConstant;
constant_operand->descriptor =
*OperandDescriptor::Create(data_type, std::vector<uint32_t>(rank, 1));
uint64_t constant_id = next_operand_id++;
CHECK(graph_info->id_to_operand_map
.try_emplace(constant_id, std::move(constant_operand))
.second);
mojo_base::BigBuffer buffer;
switch (data_type) {
case OperandDataType::kFloat32: {
buffer = mojo_base::BigBuffer(base::make_span(
reinterpret_cast<const uint8_t*>(&value), sizeof(value)));
break;
}
case OperandDataType::kFloat16: {
uint16_t fp16_value = fp16_ieee_from_fp32_value(value);
buffer = mojo_base::BigBuffer(base::make_span(
reinterpret_cast<const uint8_t*>(&fp16_value), sizeof(fp16_value)));
break;
}
default:
LOG(ERROR) << "[WebNN] The data type must be one of the floating point "
"data types.";
NOTREACHED();
}
CHECK(graph_info->constant_id_to_buffer_map
.try_emplace(constant_id, std::move(buffer))
.second);
return constant_id;
}
const TensorDesc CreateOutputTensorDesc(const IdToOperandMap& id_to_operand_map,
uint64_t output_id) {
const OperandPtr& output_operand = id_to_operand_map.at(output_id);
return TensorDesc(GetTensorDataType(output_operand->descriptor.data_type()),
output_operand->descriptor.shape());
}
void CreateOperatorNodeForArgMinMax(const IdToOperandMap& id_to_operand_map,
const mojom::ArgMinMaxPtr& arg_min_max,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map) {
const NodeOutput* input = GetNodeOutputForOperand(
id_to_node_output_map, arg_min_max->input_operand_id);
const auto& input_tensor_desc = input->GetTensorDesc();
const uint64_t output_id = arg_min_max->output_operand_id;
const auto& output_tensor_desc =
CreateOutputTensorDesc(id_to_operand_map, output_id);
const uint32_t axis = arg_min_max->axis;
  // Determine the output sizes. Ignore the dimensions of `output_tensor_desc`,
  // since DirectML expects the output dimensions to have the same rank as the
  // input, and `output_tensor_desc` may have removed dimensions if
  // keepDimensions was false.
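  // For example, a 4-D input of {2, 3, 4, 5} with axis = 1 yields DML output
  // dimensions {2, 1, 4, 5}, regardless of keepDimensions (illustrative
  // shapes).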
std::vector<uint32_t> output_dimensions = input_tensor_desc.GetDimensions();
CHECK_LT(axis, output_dimensions.size());
output_dimensions[axis] = 1u;
TensorDesc new_output_tensor_desc(output_tensor_desc.GetDataType(),
std::move(output_dimensions));
DML_OPERATOR_TYPE operator_type;
switch (arg_min_max->kind) {
case mojom::ArgMinMax_Kind::kMin: {
operator_type = DML_OPERATOR_ARGMIN;
break;
}
case mojom::ArgMinMax_Kind::kMax: {
operator_type = DML_OPERATOR_ARGMAX;
break;
}
}
const std::array<const uint32_t, 1> axes = {axis};
DML_ARGMAX_OPERATOR_DESC operator_desc = {};
  operator_desc.InputTensor = &input_tensor_desc.GetDMLTensorDesc();
  operator_desc.OutputTensor = &new_output_tensor_desc.GetDMLTensorDesc();
operator_desc.AxisCount = axes.size();
operator_desc.Axes = axes.data();
operator_desc.AxisDirection =
DML_AXIS_DIRECTION::DML_AXIS_DIRECTION_INCREASING;
std::array<const NodeOutput*, 1> inputs = {input};
const OperatorNode* arg_min_max_node = graph_builder.CreateOperatorNode(
operator_type, &operator_desc, inputs, arg_min_max->label);
const NodeOutput* output =
graph_builder.CreateNodeOutput(arg_min_max_node, output_tensor_desc);
// The output id must be unique in the map.
CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
}
struct ActivationOperatorDesc {
absl::variant<DML_ACTIVATION_ELU_OPERATOR_DESC,
DML_ACTIVATION_HARD_SIGMOID_OPERATOR_DESC,
DML_ACTIVATION_LEAKY_RELU_OPERATOR_DESC,
DML_ACTIVATION_LINEAR_OPERATOR_DESC,
DML_ACTIVATION_RELU_OPERATOR_DESC,
DML_ACTIVATION_SIGMOID_OPERATOR_DESC,
DML_ACTIVATION_SOFTMAX1_OPERATOR_DESC,
DML_ACTIVATION_SOFTPLUS_OPERATOR_DESC,
DML_ACTIVATION_SOFTSIGN_OPERATOR_DESC,
DML_ACTIVATION_TANH_OPERATOR_DESC>
desc;
DML_OPERATOR_DESC GetActivationDmlDesc() const {
if (absl::holds_alternative<DML_ACTIVATION_ELU_OPERATOR_DESC>(desc)) {
return {DML_OPERATOR_ACTIVATION_ELU,
&absl::get<DML_ACTIVATION_ELU_OPERATOR_DESC>(desc)};
} else if (absl::holds_alternative<
DML_ACTIVATION_HARD_SIGMOID_OPERATOR_DESC>(desc)) {
return {DML_OPERATOR_ACTIVATION_HARD_SIGMOID,
&absl::get<DML_ACTIVATION_HARD_SIGMOID_OPERATOR_DESC>(desc)};
} else if (absl::holds_alternative<DML_ACTIVATION_LEAKY_RELU_OPERATOR_DESC>(
desc)) {
return {DML_OPERATOR_ACTIVATION_LEAKY_RELU,
&absl::get<DML_ACTIVATION_LEAKY_RELU_OPERATOR_DESC>(desc)};
} else if (absl::holds_alternative<DML_ACTIVATION_LINEAR_OPERATOR_DESC>(
desc)) {
return {DML_OPERATOR_ACTIVATION_LINEAR,
&absl::get<DML_ACTIVATION_LINEAR_OPERATOR_DESC>(desc)};
} else if (absl::holds_alternative<DML_ACTIVATION_RELU_OPERATOR_DESC>(
desc)) {
return {DML_OPERATOR_ACTIVATION_RELU,
&absl::get<DML_ACTIVATION_RELU_OPERATOR_DESC>(desc)};
} else if (absl::holds_alternative<DML_ACTIVATION_SIGMOID_OPERATOR_DESC>(
desc)) {
return {DML_OPERATOR_ACTIVATION_SIGMOID,
&absl::get<DML_ACTIVATION_SIGMOID_OPERATOR_DESC>(desc)};
} else if (absl::holds_alternative<DML_ACTIVATION_SOFTMAX1_OPERATOR_DESC>(
desc)) {
return {DML_OPERATOR_ACTIVATION_SOFTMAX1,
&absl::get<DML_ACTIVATION_SOFTMAX1_OPERATOR_DESC>(desc)};
} else if (absl::holds_alternative<DML_ACTIVATION_SOFTPLUS_OPERATOR_DESC>(
desc)) {
return {DML_OPERATOR_ACTIVATION_SOFTPLUS,
&absl::get<DML_ACTIVATION_SOFTPLUS_OPERATOR_DESC>(desc)};
} else if (absl::holds_alternative<DML_ACTIVATION_SOFTSIGN_OPERATOR_DESC>(
desc)) {
return {DML_OPERATOR_ACTIVATION_SOFTSIGN,
&absl::get<DML_ACTIVATION_SOFTSIGN_OPERATOR_DESC>(desc)};
} else if (absl::holds_alternative<DML_ACTIVATION_TANH_OPERATOR_DESC>(
desc)) {
return {DML_OPERATOR_ACTIVATION_TANH,
&absl::get<DML_ACTIVATION_TANH_OPERATOR_DESC>(desc)};
} else {
NOTREACHED() << "The activation type is not supported.";
}
}
};
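// Illustrative usage of the helper above (not part of the original flow):
//
//   ActivationOperatorDesc relu{.desc = DML_ACTIVATION_RELU_OPERATOR_DESC{}};
//   DML_OPERATOR_DESC dml_desc = relu.GetActivationDmlDesc();
//   // dml_desc.Type is DML_OPERATOR_ACTIVATION_RELU and dml_desc.Desc points
//   // at the stored variant alternative.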
ActivationOperatorDesc CreateOperatorDescForActivation(
mojom::RecurrentNetworkActivation activation) {
switch (activation) {
case mojom::RecurrentNetworkActivation::kRelu:
return ActivationOperatorDesc{.desc =
DML_ACTIVATION_RELU_OPERATOR_DESC{}};
case mojom::RecurrentNetworkActivation::kSigmoid:
return ActivationOperatorDesc{.desc =
DML_ACTIVATION_SIGMOID_OPERATOR_DESC{}};
case mojom::RecurrentNetworkActivation::kTanh:
return ActivationOperatorDesc{.desc =
DML_ACTIVATION_TANH_OPERATOR_DESC{}};
}
}
std::optional<const Operation*> GetFusibleActivationFromOperation(
const std::map<const Operation*, const Operation*>&
operation_to_fusible_standalone_activation_map,
const Operation* operation) {
const auto activation_iterator =
operation_to_fusible_standalone_activation_map.find(operation);
if (activation_iterator !=
operation_to_fusible_standalone_activation_map.end()) {
return activation_iterator->second;
}
  return std::nullopt;
}
std::optional<uint64_t> GetFusibleTransposeInputId(
const std::map<uint64_t, const Operation*>&
output_id_to_fusible_transpose_map,
uint64_t input_id) {
const auto transpose_iterator =
output_id_to_fusible_transpose_map.find(input_id);
if (transpose_iterator != output_id_to_fusible_transpose_map.end()) {
return transpose_iterator->second->get_transpose()->input_operand_id;
}
  return std::nullopt;
}
// According to the DirectML documentation:
// https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_element_wise_add1_operator_desc,
// and
// https://learn.microsoft.com/en-us/windows/ai/directml/dml-fused-activations,
// among the element-wise binary operations, only
// `DML_OPERATOR_ELEMENT_WISE_ADD1` supports fused activation, and only when
// the output data type is FLOAT16 or FLOAT32.
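// For example, an add whose output is float32 or float16 can take a fused
// activation, while sub/mul/div or an integer add cannot (illustrative cases
// derived from the check below).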
bool CanElementWiseBinarySupportFusion(
const mojom::ElementWiseBinaryPtr& binary,
const IdToOperandMap& id_to_operand_map) {
const OperandPtr& output_operand =
id_to_operand_map.at(binary->output_operand_id);
OperandDataType output_data_type = output_operand->descriptor.data_type();
return binary->kind == mojom::ElementWiseBinary::Kind::kAdd &&
(output_data_type == OperandDataType::kFloat32 ||
output_data_type == OperandDataType::kFloat16);
}
// Return true if the operation can be fused with any of the following
// standalone activation operators according to
// https://learn.microsoft.com/en-us/windows/ai/directml/dml-fused-activations:
// DML_OPERATOR_BATCH_NORMALIZATION
// DML_OPERATOR_BATCH_NORMALIZATION_TRAINING
// DML_OPERATOR_CONVOLUTION
// DML_OPERATOR_ELEMENT_WISE_ADD1
// DML_OPERATOR_GEMM
// DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION
// DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1
bool CanFuseStandaloneActivation(const Operation* operation,
const IdToOperandMap& id_to_operand_map) {
switch (operation->which()) {
case Operation::Tag::kElementWiseBinary:
return CanElementWiseBinarySupportFusion(
operation->get_element_wise_binary(), id_to_operand_map);
case Operation::Tag::kConv2d:
case Operation::Tag::kBatchNormalization:
case Operation::Tag::kGemm:
case Operation::Tag::kInstanceNormalization:
case Operation::Tag::kLayerNormalization:
case Operation::Tag::kMatmul:
return true;
default:
return false;
}
}
// Return a valid output id if the operation is a fusible activation according
// to
// https://learn.microsoft.com/en-us/windows/ai/directml/dml-fused-activations.
// DML_OPERATOR_ELEMENT_WISE_CLIP will be supported once the DirectML feature
// level is higher than DML_FEATURE_LEVEL_6_0, according to
// https://learn.microsoft.com/en-us/windows/ai/directml/dml-feature-level-history#dml_feature_level_6_0.
//
// TODO(crbug.com/345640552): Fuse clip and other operators when possible.
std::optional<uint64_t> GetFusibleActivationOutputId(
const mojom::Operation& operation) {
switch (operation.which()) {
case mojom::Operation::Tag::kElu:
return operation.get_elu()->output_operand_id;
case mojom::Operation::Tag::kHardSigmoid:
return operation.get_hard_sigmoid()->output_operand_id;
case mojom::Operation::Tag::kLeakyRelu:
return operation.get_leaky_relu()->output_operand_id;
case mojom::Operation::Tag::kLinear:
return operation.get_linear()->output_operand_id;
case mojom::Operation::Tag::kRelu:
return operation.get_relu()->output_operand_id;
case mojom::Operation::Tag::kSigmoid:
return operation.get_sigmoid()->output_operand_id;
case mojom::Operation::Tag::kSoftplus:
return operation.get_softplus()->output_operand_id;
case mojom::Operation::Tag::kSoftsign:
return operation.get_softsign()->output_operand_id;
case mojom::Operation::Tag::kTanh:
return operation.get_tanh()->output_operand_id;
default:
      return std::nullopt;
}
}
ActivationOperatorDesc CreateOperatorDescForFusibleActivation(
const mojom::Operation& activation) {
CHECK(GetFusibleActivationOutputId(activation));
switch (activation.which()) {
case mojom::Operation::Tag::kElu:
return ActivationOperatorDesc{.desc = DML_ACTIVATION_ELU_OPERATOR_DESC{
.Alpha = activation.get_elu()->alpha}};
case mojom::Operation::Tag::kHardSigmoid:
return ActivationOperatorDesc{
.desc = DML_ACTIVATION_HARD_SIGMOID_OPERATOR_DESC{
.Alpha = activation.get_hard_sigmoid()->alpha,
.Beta = activation.get_hard_sigmoid()->beta}};
case mojom::Operation::Tag::kLeakyRelu:
return ActivationOperatorDesc{
.desc = DML_ACTIVATION_LEAKY_RELU_OPERATOR_DESC{
.Alpha = activation.get_leaky_relu()->alpha}};
case mojom::Operation::Tag::kLinear:
return ActivationOperatorDesc{.desc = DML_ACTIVATION_LINEAR_OPERATOR_DESC{
.Alpha = activation.get_linear()->alpha,
.Beta = activation.get_linear()->beta}};
case mojom::Operation::Tag::kRelu:
return ActivationOperatorDesc{.desc =
DML_ACTIVATION_RELU_OPERATOR_DESC{}};
case mojom::Operation::Tag::kSigmoid:
return ActivationOperatorDesc{.desc =
DML_ACTIVATION_SIGMOID_OPERATOR_DESC{}};
case mojom::Operation::Tag::kSoftplus:
return ActivationOperatorDesc{
.desc = DML_ACTIVATION_SOFTPLUS_OPERATOR_DESC{.Steepness = 1.0}};
case mojom::Operation::Tag::kSoftsign:
return ActivationOperatorDesc{
.desc = DML_ACTIVATION_SOFTSIGN_OPERATOR_DESC{}};
case mojom::Operation::Tag::kTanh:
return ActivationOperatorDesc{.desc =
DML_ACTIVATION_TANH_OPERATOR_DESC{}};
default:
NOTREACHED() << "The operation is not a fusible activation.";
}
}
// The struct contains the connectivity information of an operation in
// `mojom::GraphInfo::operations`. It helps to generate and represent the
// topological information about how all operations are connected.
struct OperationConnectivity {
  // The operation's input ids which are used to identify the input operands in
  // `mojom::GraphInfo::id_to_operand_map`.
std::vector<uint64_t> input_ids;
  // The operation's output ids which are used to identify the output operands
  // in `mojom::GraphInfo::id_to_operand_map`.
std::vector<uint64_t> output_ids;
};
void RetrieveOperationConnectivity(
const Operation* operation,
OperationConnectivity& out_operation_connectivity) {
std::vector<uint64_t>& input_ids = out_operation_connectivity.input_ids;
std::vector<uint64_t>& output_ids = out_operation_connectivity.output_ids;
input_ids.clear();
output_ids.clear();
switch (operation->which()) {
case Operation::Tag::kArgMinMax: {
const auto& arg_min_max = operation->get_arg_min_max();
input_ids = {arg_min_max->input_operand_id};
output_ids = {arg_min_max->output_operand_id};
break;
}
case Operation::Tag::kBatchNormalization: {
const auto& batch_norm = operation->get_batch_normalization();
input_ids = {batch_norm->input_operand_id, batch_norm->mean_operand_id,
batch_norm->variance_operand_id};
auto& scale_operand_id = batch_norm->scale_operand_id;
if (scale_operand_id) {
input_ids.push_back(scale_operand_id.value());
}
auto& bias_operand_id = batch_norm->bias_operand_id;
if (bias_operand_id) {
input_ids.push_back(bias_operand_id.value());
}
output_ids = {batch_norm->output_operand_id};
break;
}
case Operation::Tag::kClamp: {
const auto& clamp = operation->get_clamp();
input_ids = {clamp->input_operand_id};
output_ids = {clamp->output_operand_id};
break;
}
case Operation::Tag::kConcat: {
const auto& concat = operation->get_concat();
input_ids = {concat->input_operand_ids};
output_ids = {concat->output_operand_id};
break;
}
case Operation::Tag::kConv2d: {
const auto& conv2d = operation->get_conv2d();
input_ids = {conv2d->input_operand_id, conv2d->filter_operand_id};
auto& bias_operand_id = conv2d->bias_operand_id;
if (bias_operand_id) {
input_ids.push_back(bias_operand_id.value());
}
output_ids = {conv2d->output_operand_id};
break;
}
case Operation::Tag::kElementWiseBinary: {
const auto& binary = operation->get_element_wise_binary();
input_ids = {binary->lhs_operand_id, binary->rhs_operand_id};
output_ids = {binary->output_operand_id};
break;
}
case Operation::Tag::kElu: {
const auto& elu = operation->get_elu();
input_ids = {elu->input_operand_id};
output_ids = {elu->output_operand_id};
break;
}
case Operation::Tag::kElementWiseUnary: {
const auto& unary = operation->get_element_wise_unary();
input_ids = {unary->input_operand_id};
output_ids = {unary->output_operand_id};
break;
}
case Operation::Tag::kExpand: {
const auto& expand = operation->get_expand();
input_ids = {expand->input_operand_id};
output_ids = {expand->output_operand_id};
break;
}
case Operation::Tag::kGather: {
const auto& gather = operation->get_gather();
input_ids = {gather->input_operand_id, gather->indices_operand_id};
output_ids = {gather->output_operand_id};
break;
}
case Operation::Tag::kGatherElements: {
const auto& gather_elements = operation->get_gather_elements();
input_ids = {gather_elements->input_operand_id,
gather_elements->indices_operand_id};
output_ids = {gather_elements->output_operand_id};
break;
}
case Operation::Tag::kGelu: {
const auto& gelu = operation->get_gelu();
input_ids = {gelu->input_operand_id};
output_ids = {gelu->output_operand_id};
break;
}
case Operation::Tag::kGemm: {
const auto& gemm = operation->get_gemm();
input_ids = {gemm->a_operand_id, gemm->b_operand_id};
auto& c_operand_id = gemm->c_operand_id;
if (c_operand_id) {
input_ids.push_back(c_operand_id.value());
}
output_ids = {gemm->output_operand_id};
break;
}
case Operation::Tag::kGru: {
const auto& gru = operation->get_gru();
input_ids = {gru->input_operand_id, gru->weight_operand_id,
gru->recurrent_weight_operand_id};
auto& bias_operand_id = gru->bias_operand_id;
if (bias_operand_id) {
input_ids.push_back(bias_operand_id.value());
}
auto& recurrent_bias_operand_id = gru->recurrent_bias_operand_id;
if (recurrent_bias_operand_id) {
input_ids.push_back(recurrent_bias_operand_id.value());
}
auto& initial_hidden_state_operand_id =
gru->initial_hidden_state_operand_id;
if (initial_hidden_state_operand_id) {
input_ids.push_back(initial_hidden_state_operand_id.value());
}
output_ids = {gru->output_operand_ids};
break;
}
case Operation::Tag::kGruCell: {
const auto& gru_cell = operation->get_gru_cell();
input_ids = {gru_cell->input_operand_id, gru_cell->weight_operand_id,
gru_cell->recurrent_weight_operand_id,
gru_cell->hidden_state_operand_id};
auto& bias_operand_id = gru_cell->bias_operand_id;
if (bias_operand_id) {
input_ids.push_back(bias_operand_id.value());
}
auto& recurrent_bias_operand_id = gru_cell->recurrent_bias_operand_id;
if (recurrent_bias_operand_id) {
input_ids.push_back(recurrent_bias_operand_id.value());
}
output_ids = {gru_cell->output_operand_id};
break;
}
case Operation::Tag::kHardSigmoid: {
      const auto& hard_sigmoid = operation->get_hard_sigmoid();
      input_ids = {hard_sigmoid->input_operand_id};
      output_ids = {hard_sigmoid->output_operand_id};
break;
}
case Operation::Tag::kHardSwish: {
const auto& hard_swish = operation->get_hard_swish();
input_ids = {hard_swish->input_operand_id};
output_ids = {hard_swish->output_operand_id};
break;
}
case Operation::Tag::kInstanceNormalization: {
const auto& instance_norm = operation->get_instance_normalization();
input_ids = {instance_norm->input_operand_id};
auto& scale_operand_id = instance_norm->scale_operand_id;
if (scale_operand_id) {
input_ids.push_back(scale_operand_id.value());
}
auto& bias_operand_id = instance_norm->bias_operand_id;
if (bias_operand_id) {
input_ids.push_back(bias_operand_id.value());
}
output_ids = {instance_norm->output_operand_id};
break;
}
case Operation::Tag::kLayerNormalization: {
const auto& layer_norm = operation->get_layer_normalization();
input_ids = {layer_norm->input_operand_id};
auto& scale_operand_id = layer_norm->scale_operand_id;
if (scale_operand_id) {
input_ids.push_back(scale_operand_id.value());
}
auto& bias_operand_id = layer_norm->bias_operand_id;
if (bias_operand_id) {
input_ids.push_back(bias_operand_id.value());
}
output_ids = {layer_norm->output_operand_id};
break;
}
case Operation::Tag::kLeakyRelu: {
const auto& leaky_relu = operation->get_leaky_relu();
input_ids = {leaky_relu->input_operand_id};
output_ids = {leaky_relu->output_operand_id};
break;
}
case Operation::Tag::kLinear: {
const auto& linear = operation->get_linear();
input_ids = {linear->input_operand_id};
output_ids = {linear->output_operand_id};
break;
}
case Operation::Tag::kLstm: {
const auto& lstm = operation->get_lstm();
input_ids = {lstm->input_operand_id, lstm->weight_operand_id,
lstm->recurrent_weight_operand_id};
auto& bias_operand_id = lstm->bias_operand_id;
if (bias_operand_id) {
input_ids.push_back(bias_operand_id.value());
}
auto& recurrent_bias_operand_id = lstm->recurrent_bias_operand_id;
if (recurrent_bias_operand_id) {
input_ids.push_back(recurrent_bias_operand_id.value());
}
auto& peephole_weight_operand_id = lstm->peephole_weight_operand_id;
if (peephole_weight_operand_id) {
input_ids.push_back(peephole_weight_operand_id.value());
}
auto& initial_hidden_state_operand_id =
lstm->initial_hidden_state_operand_id;
if (initial_hidden_state_operand_id) {
input_ids.push_back(initial_hidden_state_operand_id.value());
}
auto& initial_cell_state_operand_id = lstm->initial_cell_state_operand_id;
if (initial_cell_state_operand_id) {
input_ids.push_back(initial_cell_state_operand_id.value());
}
output_ids = {lstm->output_operand_ids};
break;
}
case Operation::Tag::kLstmCell: {
const auto& lstm_cell = operation->get_lstm_cell();
input_ids = {lstm_cell->input_operand_id, lstm_cell->weight_operand_id,
lstm_cell->recurrent_weight_operand_id,
lstm_cell->hidden_state_operand_id,
lstm_cell->cell_state_operand_id};
auto& bias_operand_id = lstm_cell->bias_operand_id;
if (bias_operand_id) {
input_ids.push_back(bias_operand_id.value());
}
auto& recurrent_bias_operand_id = lstm_cell->recurrent_bias_operand_id;
if (recurrent_bias_operand_id) {
input_ids.push_back(recurrent_bias_operand_id.value());
}
auto& peephole_weight_operand_id = lstm_cell->peephole_weight_operand_id;
if (peephole_weight_operand_id) {
input_ids.push_back(peephole_weight_operand_id.value());
}
output_ids = {lstm_cell->output_operand_ids};
break;
}
case Operation::Tag::kMatmul: {
const auto& matmul = operation->get_matmul();
input_ids = {matmul->a_operand_id, matmul->b_operand_id};
output_ids = {matmul->output_operand_id};
break;
}
case Operation::Tag::kPad: {
const auto& pad = operation->get_pad();
input_ids = {pad->input_operand_id};
output_ids = {pad->output_operand_id};
break;
}
case Operation::Tag::kPool2d: {
const auto& pool2d = operation->get_pool2d();
input_ids = {pool2d->input_operand_id};
output_ids = {pool2d->output_operand_id};
break;
}
case Operation::Tag::kPrelu: {
const auto& prelu = operation->get_prelu();
input_ids = {prelu->input_operand_id, prelu->slope_operand_id};
output_ids = {prelu->output_operand_id};
break;
}
case Operation::Tag::kReduce: {
const auto& reduce = operation->get_reduce();
input_ids = {reduce->input_operand_id};
output_ids = {reduce->output_operand_id};
break;
}
case Operation::Tag::kRelu: {
const auto& relu = operation->get_relu();
input_ids = {relu->input_operand_id};
output_ids = {relu->output_operand_id};
break;
}
case Operation::Tag::kResample2d: {
const auto& resample2d = operation->get_resample2d();
input_ids = {resample2d->input_operand_id};
output_ids = {resample2d->output_operand_id};
break;
}
case Operation::Tag::kReshape: {
const auto& reshape = operation->get_reshape();
input_ids = {reshape->input_operand_id};
output_ids = {reshape->output_operand_id};
break;
}
case Operation::Tag::kSigmoid: {
const auto& sigmoid = operation->get_sigmoid();
input_ids = {sigmoid->input_operand_id};
output_ids = {sigmoid->output_operand_id};
break;
}
case Operation::Tag::kSlice: {
const auto& slice = operation->get_slice();
input_ids = {slice->input_operand_id};
output_ids = {slice->output_operand_id};
break;
}
case Operation::Tag::kSoftmax: {
const auto& softmax = operation->get_softmax();
input_ids = {softmax->input_operand_id};
output_ids = {softmax->output_operand_id};
break;
}
case Operation::Tag::kSoftplus: {
const auto& softplus = operation->get_softplus();
input_ids = {softplus->input_operand_id};
output_ids = {softplus->output_operand_id};
break;
}
case Operation::Tag::kSoftsign: {
const auto& softsign = operation->get_softsign();
input_ids = {softsign->input_operand_id};
output_ids = {softsign->output_operand_id};
break;
}
case Operation::Tag::kSplit: {
const auto& split = operation->get_split();
input_ids = {split->input_operand_id};
output_ids = {split->output_operand_ids};
break;
}
case Operation::Tag::kTanh: {
const auto& tanh = operation->get_tanh();
input_ids = {tanh->input_operand_id};
output_ids = {tanh->output_operand_id};
break;
}
case Operation::Tag::kTranspose: {
const auto& transpose = operation->get_transpose();
input_ids = {transpose->input_operand_id};
output_ids = {transpose->output_operand_id};
break;
}
case Operation::Tag::kTriangular: {
const auto& triangular = operation->get_triangular();
input_ids = {triangular->input_operand_id};
output_ids = {triangular->output_operand_id};
break;
}
case Operation::Tag::kWhere: {
const auto& where = operation->get_where();
input_ids = {where->condition_operand_id, where->true_value_operand_id,
where->false_value_operand_id};
output_ids = {where->output_operand_id};
break;
}
}
}
// The struct contains the graph fusion information. When the `CreateAndBuild`
// method goes through all operations to add each one into the final graph,
// this struct is used to apply graph fusion.
struct GraphFusionInfo {
// A map of all standalone activations in `mojom::GraphInfo` which can be
// fused into preceding operations.
// The key is the preceding operation which can support fusion. The value is
// the standalone activation which can be fused into the preceding operation.
std::map<const Operation*, const Operation*>
operation_to_fusible_standalone_activation_map;
// A map of all transposes that can be fused into the following matmul using
// transpose's output operand id as the key.
std::map<uint64_t, const Operation*> output_id_to_fusible_transpose_map;
// A set of all operations in `mojom::GraphInfo` which can be fused into
// another operation. No DirectML operator node will be created for operations
// in this set.
std::unordered_set<const Operation*> fusible_operations_set;
};
// The method gets the graph fusion information from `mojom::GraphInfo`,
// relying on the fact that the `operations` in `mojom::GraphInfo` are in
// topological order, which means if operation 'j' depends on 'i', 'i' must
// appear before 'j'.
// TODO(issues.chromium.org/41494177): Validate the topological order of
// operations in `mojom::GraphInfo` on services side.
GraphFusionInfo GetGraphFusionInfo(const mojom::GraphInfoPtr& graph_info) {
  // If the feature is disabled, just return an empty `GraphFusionInfo` object,
  // which means no graph fusion will be applied.
if (!base::FeatureList::IsEnabled(kApplyGraphFusion)) {
return GraphFusionInfo();
}
// A map of all fusible activations in `mojom::GraphInfo` using activation's
// input operand id as the key.
std::map<uint64_t, const Operation*> input_id_to_activation_map;
// The case we're interested in includes a fusible base operation with exactly
// one output edge, followed by a fusible activation operation:
//
// [input]
// |
// conv2d (fusible base operation)
// |
// relu (fusible activation operation)
// |
// [output]
//
  // If the base operation has more than one output edge, because its outputs
  // go to other operations or a graph output, then no fusion occurs. For
  // example, if `relu` were fused into `conv2d`, `elu` would lose its input,
  // so conv2d should be skipped, and similarly for graph `output2`:
//
// [input]
// |
// conv2d (unfusible base operation)
// / \
// relu elu
// | |
// [output1][output2]
//
// [input]
// |
// conv2d (unfusible base operation)
// / \
// relu \
// | \
// [output1] [output2]
//
// If the base operation is not followed by a fusible activation, skip
// it:
//
// [input]
// |
// conv2d (unfusible base operation)
// |
// pool2d
// |
// [output]
//
// A map of all matmul operations in `mojom::GraphInfo` using matmul's input
// operand id as the key.
std::map<uint64_t, const Operation*> input_id_to_matmul_map;
  // This is a scenario where a transpose can be fused into the following
  // matmul. The transpose's output solely feeds the matmul. The transposed
  // input can be either input a, input b, or both. The transpose should only
  // swap the last two axes (the row and column of the inner matrix), so it can
  // be fused into `TransA()` or `TransB()` of the following calculation that
  // the DirectML `GEMM` operator performs:
//
// Output = FusedActivation(Alpha * TransA(A) x TransB(B) + Beta * C)
//
// See more details at:
// https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_gemm_operator_desc
//
// [input a] [input b]
// | /
// transpose /
// \ /
// \ /
// matmul
//
  // TODO(crbug.com/340729469): Remove the complex operator fusions when the
  // underlying DirectML runtime can handle them.
GraphFusionInfo graph_fusion_info;
  // Because all the operand ids are contiguous, this vector is used to record
  // how many times each operand id is used as an output edge from an
  // operation. Notice that operand ids from the renderer start at 1, so
  // reserve `operand count + 1` entries for the vector.
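  // For example, in the conv2d -> {relu, elu} diagram above, conv2d's output
  // operand id would be counted twice, which blocks the activation fusion (an
  // illustrative reading of the counting below).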
std::vector<uint32_t> node_output_edge_counts(
graph_info->id_to_operand_map.size() + 1, 0);
for (uint64_t graph_output_id : graph_info->output_operands) {
++node_output_edge_counts.at(graph_output_id);
}
  // Iterate from the end of the operations instead of from the beginning, so
  // we can easily get the total output edge count of a fusible base operation
  // before visiting it.
OperationConnectivity operation_connectivity;
for (size_t operation_index = graph_info->operations.size();
operation_index-- > 0;) {
const auto& operation = graph_info->operations[operation_index];
RetrieveOperationConnectivity(
operation.get(),
/*out_operation_connectivity*/ operation_connectivity);
for (uint64_t input_id : operation_connectivity.input_ids) {
++node_output_edge_counts.at(input_id);
}
// Try to find standalone activations that can be fused into preceding
// operations.
if (GetFusibleActivationOutputId(*operation)) {
// We found a standalone activation operation that may need to be fused
// with a predecessor. So record its input edge to later check
// against any fusible base operation's corresponding output edge.
CHECK_EQ(operation_connectivity.input_ids.size(), 1U);
      // We needn't check the result of `try_emplace` here, because if the key
      // `output_id` is already in the container, there must be more than one
      // output edge from a predecessor, in which case the fusion must be
      // skipped.
input_id_to_activation_map.try_emplace(
operation_connectivity.input_ids[0], operation.get());
} else if (CanFuseStandaloneActivation(operation.get(),
graph_info->id_to_operand_map)) {
CHECK_EQ(operation_connectivity.output_ids.size(), 1U);
uint64_t output_id = operation_connectivity.output_ids[0];
// Add this operation to the fusion info if there's exactly one output
// edge to a fusible standalone activation.
const auto activation_iterator =
input_id_to_activation_map.find(output_id);
if (node_output_edge_counts[output_id] == 1 &&
activation_iterator != input_id_to_activation_map.end()) {
const auto* activation = activation_iterator->second;
graph_fusion_info.fusible_operations_set.insert(activation);
graph_fusion_info
.operation_to_fusible_standalone_activation_map[operation.get()] =
activation;
}
}
// Try to find transposes that can be fused into following matmul
// operations.
switch (operation->which()) {
case Operation::Tag::kMatmul: {
        // Map matmul's inputs to the operation, so the following algorithm can
        // find a transpose whose output is consumed by a matmul.
CHECK_EQ(operation_connectivity.input_ids.size(), 2U);
        // We needn't check the result of `try_emplace` here, because if the
        // key `input_id` is already in the container, there must be more than
        // one output edge from a predecessor, in which case the transpose
        // fusion won't happen.
input_id_to_matmul_map.try_emplace(operation_connectivity.input_ids[0],
operation.get());
input_id_to_matmul_map.try_emplace(operation_connectivity.input_ids[1],
operation.get());
break;
}
case Operation::Tag::kTranspose: {
        // If a transpose's output is solely used by a matmul and it only swaps
        // the last two axes, it can be fused into the DirectML GEMM operator
        // by setting the corresponding input tensor transformation attribute.
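        // For example, with a rank-3 input the only fusible permutation is
        // {0, 2, 1}: `swap_last_two_axes` below starts from the identity
        // permutation {0, 1, 2} and swaps its last two elements (illustrative
        // values).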
CHECK_EQ(operation_connectivity.output_ids.size(), 1U);
uint64_t output_id = operation_connectivity.output_ids[0];
if (!input_id_to_matmul_map.contains(output_id) ||
node_output_edge_counts[output_id] != 1) {
break;
}
const mojom::TransposePtr& transpose = operation->get_transpose();
const mojom::OperandPtr& input_operand =
graph_info->id_to_operand_map.at(transpose->input_operand_id);
uint32_t input_rank = input_operand->descriptor.shape().size();
if (input_rank < 2) {
break;
}
std::vector<uint32_t> swap_last_two_axes(input_rank);
std::iota(swap_last_two_axes.begin(), swap_last_two_axes.end(), 0);
std::swap(swap_last_two_axes[input_rank - 2],
swap_last_two_axes[input_rank - 1]);
if (swap_last_two_axes == transpose->permutation) {
graph_fusion_info.fusible_operations_set.insert(operation.get());
graph_fusion_info.output_id_to_fusible_transpose_map[output_id] =
operation.get();
}
break;
}
default: {
// Skip other operations.
break;
}
}
}
CHECK_EQ(
graph_fusion_info.operation_to_fusible_standalone_activation_map.size() +
graph_fusion_info.output_id_to_fusible_transpose_map.size(),
graph_fusion_info.fusible_operations_set.size());
return graph_fusion_info;
}
void CreateOperatorNodeForBatchNormalization(
const Operation* operation,
const std::map<const Operation*, const Operation*>&
operation_to_fusible_standalone_activation_map,
mojom::GraphInfoPtr& graph_info,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map,
std::unordered_map<uint64_t, uint32_t>& constant_id_to_input_index_map,
uint64_t& next_operand_id) {
const auto& batch_normalization = operation->get_batch_normalization();
const NodeOutput* input = GetNodeOutputForOperand(
id_to_node_output_map, batch_normalization->input_operand_id);
const TensorDesc& input_tensor_desc = input->GetTensorDesc();
const auto input_rank = input_tensor_desc.GetDimensions().size();
auto& id_to_operand_map = graph_info->id_to_operand_map;
uint64_t output_id = batch_normalization->output_operand_id;
const OperandPtr& output_operand = id_to_operand_map.at(output_id);
OperandDataType data_type = output_operand->descriptor.data_type();
const TensorDesc output_tensor_desc(GetTensorDataType(data_type),
output_operand->descriptor.shape());
const NodeOutput* mean = GetNodeOutputForOperand(
id_to_node_output_map, batch_normalization->mean_operand_id);
auto mean_tensor_desc = mean->GetTensorDesc();
auto mean_rank = mean_tensor_desc.GetDimensions().size();
CHECK_EQ(mean_rank, 1U);
auto axis = batch_normalization->axis;
uint32_t axes[1] = {axis};
  // In the WebNN spec, the mean operand is specified as a 1-D tensor whose
  // size equals the size of the input dimension denoted by axis. But for DML,
  // InputTensor and MeanTensor must have the same DimensionCount -
  // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_batch_normalization_operator_desc.
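  // For example, with a 4-D input and axis = 1, a 1-D mean of shape {C} is
  // made broadcast-compatible as {1, C, 1, 1}; the same reshaping applies to
  // the variance, scale, and bias tensors below (illustrative shapes).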
mean_tensor_desc.MakeBroadcastCompatible(input_rank, axes);
const NodeOutput* variance = GetNodeOutputForOperand(
id_to_node_output_map, batch_normalization->variance_operand_id);
auto variance_tensor_desc = variance->GetTensorDesc();
auto variance_rank = variance_tensor_desc.GetDimensions().size();
CHECK_EQ(variance_rank, 1U);
  // In the WebNN spec, the variance operand is specified as a 1-D tensor whose
  // size equals the size of the input dimension denoted by axis. But for DML,
  // InputTensor and VarianceTensor must have the same DimensionCount -
  // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_batch_normalization_operator_desc.
variance_tensor_desc.MakeBroadcastCompatible(input_rank, axes);
uint64_t scale_operand_id;
if (batch_normalization->scale_operand_id.has_value()) {
scale_operand_id = batch_normalization->scale_operand_id.value();
} else {
// If the scale is not present, create a constant operand for scale and
// insert the operand into the graph.
scale_operand_id = BuildConstantOperandForFloatValue(
graph_info, next_operand_id, data_type,
/*rank*/ 1, /*default scale*/ 1.0);
// Create an input node for the scale operand and store the assigned input
// index in `constant_id_to_input_index_map`, which will be used for
// constant buffer binding.
uint32_t scale_input_index =
CreateInputNode(id_to_operand_map, scale_operand_id, graph_builder,
id_to_node_output_map);
CHECK(constant_id_to_input_index_map
.try_emplace(scale_operand_id, scale_input_index)
.second);
}
const NodeOutput* scale =
GetNodeOutputForOperand(id_to_node_output_map, scale_operand_id);
auto scale_tensor_desc = scale->GetTensorDesc();
auto scale_rank = scale_tensor_desc.GetDimensions().size();
CHECK_EQ(scale_rank, 1U);
  // In the WebNN spec, the scale operand is specified as a 1-D tensor whose
  // size equals the size of the input dimension denoted by axis. But for DML,
  // InputTensor and ScaleTensor must have the same DimensionCount -
  // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_batch_normalization_operator_desc.
scale_tensor_desc.MakeBroadcastCompatible(input_rank, axes);
uint64_t bias_operand_id;
if (batch_normalization->bias_operand_id.has_value()) {
bias_operand_id = batch_normalization->bias_operand_id.value();
} else {
// If the bias is not present, create a constant operand for bias and insert
// the operand into the graph.
bias_operand_id = BuildConstantOperandForFloatValue(
graph_info, next_operand_id, data_type,
/*rank*/ 1, /*default bias*/ 0);
// Create an input node for the bias operand and store the assigned input
// index in `constant_id_to_input_index_map`, which will be used for
// constant buffer binding.
uint32_t bias_input_index =
CreateInputNode(id_to_operand_map, bias_operand_id, graph_builder,
id_to_node_output_map);
CHECK(constant_id_to_input_index_map
.try_emplace(bias_operand_id, bias_input_index)
.second);
}
const NodeOutput* bias =
GetNodeOutputForOperand(id_to_node_output_map, bias_operand_id);
auto bias_tensor_desc = bias->GetTensorDesc();
auto bias_rank = bias_tensor_desc.GetDimensions().size();
CHECK_EQ(bias_rank, 1U);
  // In the WebNN spec, the bias operand is specified as a 1-D tensor whose
  // size equals the size of the input dimension denoted by axis. But for DML,
  // InputTensor and BiasTensor must have the same DimensionCount -
  // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_batch_normalization_operator_desc.
bias_tensor_desc.MakeBroadcastCompatible(input_rank, axes);
std::array<const NodeOutput*, 5> inputs = {input, mean, variance, scale,
bias};
std::optional<const Operation*> fusible_activation =
GetFusibleActivationFromOperation(
operation_to_fusible_standalone_activation_map, operation);
std::optional<ActivationOperatorDesc> activation_operator_desc;
std::optional<DML_OPERATOR_DESC> activation_dml_desc;
if (fusible_activation) {
activation_operator_desc =
CreateOperatorDescForFusibleActivation(*fusible_activation.value());
output_id =
GetFusibleActivationOutputId(*fusible_activation.value()).value();
activation_dml_desc = activation_operator_desc->GetActivationDmlDesc();
}
DML_BATCH_NORMALIZATION_OPERATOR_DESC batch_normalization_operator_desc{
.InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
.MeanTensor = &mean_tensor_desc.GetDMLTensorDesc(),
.VarianceTensor = &variance_tensor_desc.GetDMLTensorDesc(),
.ScaleTensor = &scale_tensor_desc.GetDMLTensorDesc(),
.BiasTensor = &bias_tensor_desc.GetDMLTensorDesc(),
.OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
// Spatial is used to specify whether locations are spatial.
// This parameter was deprecated in DML_FEATURE_LEVEL_4_0, and has no
// effect.
.Spatial = true,
.Epsilon = batch_normalization->epsilon,
.FusedActivation =
activation_dml_desc ? &activation_dml_desc.value() : nullptr,
};
const std::string& label = batch_normalization->label;
const OperatorNode* batch_normalization_node =
graph_builder.CreateOperatorNode(DML_OPERATOR_BATCH_NORMALIZATION,
&batch_normalization_operator_desc,
inputs, label);
const NodeOutput* output = graph_builder.CreateNodeOutput(
batch_normalization_node, std::move(output_tensor_desc), 0);
// The output id must be unique in the map.
CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
}
void CreateOperatorNodeForClamp(const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::ClampPtr& clamp,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map) {
const NodeOutput* input =
GetNodeOutputForOperand(id_to_node_output_map, clamp->input_operand_id);
const auto& input_tensor_desc = input->GetTensorDesc();
CHECK(context_properties.data_type_limits.clamp_input.Has(
DmlDataTypeToOperand(input_tensor_desc.GetDataType())));
uint64_t output_id = clamp->output_operand_id;
auto output_tensor_desc =
CreateOutputTensorDesc(id_to_operand_map, output_id);
DML_ELEMENT_WISE_CLIP_OPERATOR_DESC clamp_operator_desc{
.InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
.OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
// No scale or bias applies to the input.
.ScaleBias = nullptr,
.Min = clamp->min_value,
.Max = clamp->max_value};
std::array<const NodeOutput*, 1> inputs = {input};
const OperatorNode* clamp_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_ELEMENT_WISE_CLIP, &clamp_operator_desc, inputs,
clamp->label);
const NodeOutput* output = graph_builder.CreateNodeOutput(
clamp_node, std::move(output_tensor_desc), 0);
// The output id must be unique in the map.
CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
}
void CreateOperatorNodeForConcat(const IdToOperandMap& id_to_operand_map,
const mojom::ConcatPtr& concat,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map) {
const auto& input_operand_ids = concat->input_operand_ids;
size_t input_num = input_operand_ids.size();
std::vector<const NodeOutput*> inputs;
std::vector<DML_TENSOR_DESC> input_dml_tensor_descs;
inputs.reserve(input_num);
input_dml_tensor_descs.reserve(input_num);
for (const auto& input_operand_id : input_operand_ids) {
const NodeOutput* input =
GetNodeOutputForOperand(id_to_node_output_map, input_operand_id);
inputs.push_back(input);
input_dml_tensor_descs.push_back(input->GetTensorDesc().GetDMLTensorDesc());
}
uint64_t output_id = concat->output_operand_id;
auto output_tensor_desc =
CreateOutputTensorDesc(id_to_operand_map, output_id);
DML_JOIN_OPERATOR_DESC concat_operator_desc{
.InputCount = base::checked_cast<uint32_t>(input_dml_tensor_descs.size()),
.InputTensors = input_dml_tensor_descs.data(),
.OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
.Axis = concat->axis};
const OperatorNode* concat_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_JOIN, &concat_operator_desc, inputs, concat->label);
const NodeOutput* output = graph_builder.CreateNodeOutput(
concat_node, std::move(output_tensor_desc), 0);
// The output id must be unique in the map.
CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
}
void CreateOperatorNodeForConv2d(
const IdToOperandMap& id_to_operand_map,
const Operation* operation,
const std::map<const Operation*, const Operation*>&
operation_to_fusible_standalone_activation_map,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map) {
const auto& conv2d = operation->get_conv2d();
const NodeOutput* input =
GetNodeOutputForOperand(id_to_node_output_map, conv2d->input_operand_id);
// The input tensor description may be transposed.
auto input_tensor_desc = input->GetTensorDesc();
CHECK_EQ(input_tensor_desc.GetDimensions().size(), 4u);
CHECK(kDmlFloatDataTypes.contains(input_tensor_desc.GetDataType()));
const NodeOutput* filter =
GetNodeOutputForOperand(id_to_node_output_map, conv2d->filter_operand_id);
auto filter_tensor_desc = filter->GetTensorDesc();
uint64_t output_id = conv2d->output_operand_id;
// The output tensor description may be transposed.
auto output_tensor_desc =
CreateOutputTensorDesc(id_to_operand_map, output_id);
CHECK_EQ(output_tensor_desc.GetDimensions().size(), 4u);
std::vector<const NodeOutput*> inputs = {input, filter};
std::optional<TensorDesc> reshaped_bias_tensor_desc;
auto& bias_operand_id = conv2d->bias_operand_id;
if (bias_operand_id) {
const auto bias_node_output_iterator =
id_to_node_output_map.find(bias_operand_id.value());
CHECK(bias_node_output_iterator != id_to_node_output_map.end());
const NodeOutput* bias_node_output = bias_node_output_iterator->second;
CHECK(bias_node_output);
const auto& bias_tensor_desc = bias_node_output->GetTensorDesc();
const auto& bias_dims = bias_tensor_desc.GetDimensions();
CHECK_EQ(bias_dims.size(), 1u);
// In the WebNN spec, the bias is an additional 1-D tensor with the shape
// {outputChannels}. But DML expects the BiasTensor dimensions to be
// { 1, OutputChannelCount, 1, 1 } for 4D. So reshape the bias:
// https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_convolution_operator_desc
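// For example, a WebNN bias of shape {256} is reinterpreted as a DML
// BiasTensor of shape {1, 256, 1, 1}.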
std::vector<uint32_t> reshaped_bias_dims = {1, bias_dims[0], 1, 1};
reshaped_bias_tensor_desc =
TensorDesc(bias_tensor_desc.GetDataType(), bias_tensor_desc.GetFlags(),
std::move(reshaped_bias_dims));
const NodeOutput* reshaped_bias_node_output =
graph_builder.CreateNodeOutput(&bias_node_output->GetNode(),
reshaped_bias_tensor_desc.value());
inputs.push_back(reshaped_bias_node_output);
}
std::array<uint32_t, 2> strides = {conv2d->strides->height,
conv2d->strides->width};
std::array<uint32_t, 2> dilations = {conv2d->dilations->height,
conv2d->dilations->width};
std::array<uint32_t, 2> start_padding = {conv2d->padding->beginning->height,
conv2d->padding->beginning->width};
std::array<uint32_t, 2> end_padding = {conv2d->padding->ending->height,
conv2d->padding->ending->width};
// The outputSizes option of WebNN convTranspose2d specifies the sizes of
// the last two dimensions of the output tensor, while the outputPadding of
// DirectML convolution applies a zero padding to the result of the
// operator. Since the graph builder explicitly passes in the output tensor
// shape, there is no ambiguity about the output shape, so set the
// output_padding to {0, 0}:
// https://www.w3.org/TR/webnn/#dom-mlconvtranspose2doptions-outputpadding
// https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_convolution_operator_desc
std::array<uint32_t, 2> default_out_padding = {0, 0};
std::optional<const Operation*> fusible_activation =
GetFusibleActivationFromOperation(
operation_to_fusible_standalone_activation_map, operation);
std::optional<ActivationOperatorDesc> activation_operator_desc;
std::optional<DML_OPERATOR_DESC> activation_dml_desc;
if (fusible_activation) {
activation_operator_desc =
CreateOperatorDescForFusibleActivation(*fusible_activation.value());
output_id =
GetFusibleActivationOutputId(*fusible_activation.value()).value();
activation_dml_desc = activation_operator_desc->GetActivationDmlDesc();
}
DML_CONVOLUTION_DIRECTION conv2d_direction;
switch (conv2d->kind) {
case mojom::Conv2d::Kind::kDirect:
conv2d_direction =
DML_CONVOLUTION_DIRECTION::DML_CONVOLUTION_DIRECTION_FORWARD;
break;
case mojom::Conv2d::Kind::kTransposed:
conv2d_direction =
DML_CONVOLUTION_DIRECTION::DML_CONVOLUTION_DIRECTION_BACKWARD;
break;
}
DML_CONVOLUTION_OPERATOR_DESC conv2d_operator_desc{
.InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
.FilterTensor = &filter_tensor_desc.GetDMLTensorDesc(),
.BiasTensor = GetOptionalDmlTensorDescPtr(reshaped_bias_tensor_desc),
.OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
.Mode = DML_CONVOLUTION_MODE_CROSS_CORRELATION,
.Direction = conv2d_direction,
// DimensionCount determines the size of the Strides, Dilations,
// StartPadding, EndPadding, and OutputPadding arrays.
.DimensionCount = 2u,
.Strides = strides.data(),
.Dilations = dilations.data(),
.StartPadding = start_padding.data(),
.EndPadding = end_padding.data(),
.OutputPadding = default_out_padding.data(),
.GroupCount = conv2d->groups,
.FusedActivation =
activation_dml_desc ? &activation_dml_desc.value() : nullptr,
};
const std::string& label = conv2d->label;
const OperatorNode* conv2d_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_CONVOLUTION, &conv2d_operator_desc, inputs, label);
const NodeOutput* output = graph_builder.CreateNodeOutput(
conv2d_node, std::move(output_tensor_desc), 0);
// The output id must be unique in the map.
CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
}
template <typename DML_OPERATOR_DESC>
const OperatorNode* CreateBinaryOperator(const TensorDesc& a_tensor,
const TensorDesc& b_tensor,
const TensorDesc& output_tensor,
GraphBuilderDml& graph_builder,
DML_OPERATOR_TYPE operator_type,
base::span<const NodeOutput*> inputs,
std::string_view label) {
DML_OPERATOR_DESC binary_operator_desc{
.ATensor = &a_tensor.GetDMLTensorDesc(),
.BTensor = &b_tensor.GetDMLTensorDesc(),
.OutputTensor = &output_tensor.GetDMLTensorDesc()};
return graph_builder.CreateOperatorNode(operator_type, &binary_operator_desc,
inputs, label);
}
void CreateOperatorNodeForBinary(
const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const Operation* operation,
const std::map<const Operation*, const Operation*>&
operation_to_fusible_standalone_activation_map,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map) {
const auto& binary = operation->get_element_wise_binary();
// The input a and b tensor descriptions may be broadcasted.
const NodeOutput* input_a =
GetNodeOutputForOperand(id_to_node_output_map, binary->lhs_operand_id);
auto input_a_tensor_desc = input_a->GetTensorDesc();
const NodeOutput* input_b =
GetNodeOutputForOperand(id_to_node_output_map, binary->rhs_operand_id);
auto input_b_tensor_desc = input_b->GetTensorDesc();
uint64_t output_id = binary->output_operand_id;
const auto output_tensor_desc =
CreateOutputTensorDesc(id_to_operand_map, output_id);
auto output_dimensions = output_tensor_desc.GetDimensions();
if (input_a_tensor_desc.GetDimensions() != output_dimensions) {
input_a_tensor_desc.BroadcastTo(output_dimensions);
}
if (input_b_tensor_desc.GetDimensions() != output_dimensions) {
input_b_tensor_desc.BroadcastTo(output_dimensions);
}
CHECK_EQ(input_a_tensor_desc.GetDataType(),
input_b_tensor_desc.GetDataType());
const OperandDataType input_data_type =
DmlDataTypeToOperand(input_a_tensor_desc.GetDataType());
const std::string& label = binary->label;
const OperatorNode* binary_node = nullptr;
std::array<const NodeOutput*, 2> inputs = {input_a, input_b};
switch (binary->kind) {
case mojom::ElementWiseBinary::Kind::kAdd: {
CHECK(context_properties.data_type_limits.add_input.Has(input_data_type));
std::optional<const Operation*> fusible_activation =
GetFusibleActivationFromOperation(
operation_to_fusible_standalone_activation_map, operation);
if (fusible_activation) {
ActivationOperatorDesc activation_operator_desc =
CreateOperatorDescForFusibleActivation(*fusible_activation.value());
DML_OPERATOR_DESC activation_dml_desc =
activation_operator_desc.GetActivationDmlDesc();
DML_ELEMENT_WISE_ADD1_OPERATOR_DESC add1_operator_desc{
.ATensor = &input_a_tensor_desc.GetDMLTensorDesc(),
.BTensor = &input_b_tensor_desc.GetDMLTensorDesc(),
.OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
.FusedActivation = &activation_dml_desc,
};
binary_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_ELEMENT_WISE_ADD1, &add1_operator_desc, inputs, label);
output_id =
GetFusibleActivationOutputId(*fusible_activation.value()).value();
}
// If no standalone activation needs to be fused, prefer
// `DML_OPERATOR_ELEMENT_WISE_ADD` which supports more data types than
// `DML_OPERATOR_ELEMENT_WISE_ADD1`.
else {
binary_node = CreateBinaryOperator<DML_ELEMENT_WISE_ADD_OPERATOR_DESC>(
input_a_tensor_desc, input_b_tensor_desc, output_tensor_desc,
graph_builder, DML_OPERATOR_ELEMENT_WISE_ADD, inputs, label);
}
break;
}
case mojom::ElementWiseBinary::Kind::kDiv: {
CHECK(context_properties.data_type_limits.div_input.Has(input_data_type));
binary_node = CreateBinaryOperator<DML_ELEMENT_WISE_DIVIDE_OPERATOR_DESC>(
input_a_tensor_desc, input_b_tensor_desc, output_tensor_desc,
graph_builder, DML_OPERATOR_ELEMENT_WISE_DIVIDE, inputs, label);
break;
}
case mojom::ElementWiseBinary::Kind::kMax: {
CHECK(context_properties.data_type_limits.max_input.Has(input_data_type));
binary_node = CreateBinaryOperator<DML_ELEMENT_WISE_MAX_OPERATOR_DESC>(
input_a_tensor_desc, input_b_tensor_desc, output_tensor_desc,
graph_builder, DML_OPERATOR_ELEMENT_WISE_MAX, inputs, label);
break;
}
case mojom::ElementWiseBinary::Kind::kMin: {
CHECK(context_properties.data_type_limits.min_input.Has(input_data_type));
binary_node = CreateBinaryOperator<DML_ELEMENT_WISE_MIN_OPERATOR_DESC>(
input_a_tensor_desc, input_b_tensor_desc, output_tensor_desc,
graph_builder, DML_OPERATOR_ELEMENT_WISE_MIN, inputs, label);
break;
}
case mojom::ElementWiseBinary::Kind::kMul: {
CHECK(context_properties.data_type_limits.mul_input.Has(input_data_type));
binary_node =
CreateBinaryOperator<DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC>(
input_a_tensor_desc, input_b_tensor_desc, output_tensor_desc,
graph_builder, DML_OPERATOR_ELEMENT_WISE_MULTIPLY, inputs, label);
break;
}
case mojom::ElementWiseBinary::Kind::kSub: {
CHECK(context_properties.data_type_limits.sub_input.Has(input_data_type));
binary_node =
CreateBinaryOperator<DML_ELEMENT_WISE_SUBTRACT_OPERATOR_DESC>(
input_a_tensor_desc, input_b_tensor_desc, output_tensor_desc,
graph_builder, DML_OPERATOR_ELEMENT_WISE_SUBTRACT, inputs, label);
break;
}
case mojom::ElementWiseBinary::Kind::kPow: {
CHECK(context_properties.data_type_limits.pow_input.Has(input_data_type));
DML_ELEMENT_WISE_POW_OPERATOR_DESC element_wise_operator_desc{
.InputTensor = &input_a_tensor_desc.GetDMLTensorDesc(),
.ExponentTensor = &input_b_tensor_desc.GetDMLTensorDesc(),
.OutputTensor = &output_tensor_desc.GetDMLTensorDesc()};
binary_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_ELEMENT_WISE_POW, &element_wise_operator_desc, inputs,
label);
break;
}
case mojom::ElementWiseBinary::Kind::kEqual: {
CHECK(
context_properties.data_type_limits.equal_input.Has(input_data_type));
binary_node =
CreateBinaryOperator<DML_ELEMENT_WISE_LOGICAL_EQUALS_OPERATOR_DESC>(
input_a_tensor_desc, input_b_tensor_desc, output_tensor_desc,
graph_builder, DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS, inputs,
label);
break;
}
case mojom::ElementWiseBinary::Kind::kGreater: {
CHECK(context_properties.data_type_limits.greater_input.Has(
input_data_type));
binary_node = CreateBinaryOperator<
DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OPERATOR_DESC>(
input_a_tensor_desc, input_b_tensor_desc, output_tensor_desc,
graph_builder, DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN, inputs,
label);
break;
}
case mojom::ElementWiseBinary::Kind::kGreaterOrEqual: {
CHECK(context_properties.data_type_limits.greater_or_equal_input.Has(
input_data_type));
binary_node = CreateBinaryOperator<
DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_DESC>(
input_a_tensor_desc, input_b_tensor_desc, output_tensor_desc,
graph_builder,
DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL, inputs,
label);
break;
}
case mojom::ElementWiseBinary::Kind::kLesser: {
CHECK(context_properties.data_type_limits.lesser_input.Has(
input_data_type));
binary_node = CreateBinaryOperator<
DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_DESC>(
input_a_tensor_desc, input_b_tensor_desc, output_tensor_desc,
graph_builder, DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN, inputs,
label);
break;
}
case mojom::ElementWiseBinary::Kind::kLesserOrEqual: {
CHECK(context_properties.data_type_limits.lesser_or_equal_input.Has(
input_data_type));
binary_node = CreateBinaryOperator<
DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_DESC>(
input_a_tensor_desc, input_b_tensor_desc, output_tensor_desc,
graph_builder, DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL,
inputs, label);
break;
}
}
const NodeOutput* output = graph_builder.CreateNodeOutput(
binary_node, std::move(output_tensor_desc), 0);
// The output id must be unique in the map.
CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
}
void CreateOperatorNodeForPad(const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::PadPtr& pad,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map) {
const NodeOutput* input =
GetNodeOutputForOperand(id_to_node_output_map, pad->input_operand_id);
const auto& input_tensor_desc = input->GetTensorDesc();
CHECK(context_properties.data_type_limits.pad_input.Has(
DmlDataTypeToOperand(input_tensor_desc.GetDataType())));
uint64_t output_id = pad->output_operand_id;
auto output_tensor_desc =
CreateOutputTensorDesc(id_to_operand_map, output_id);
DML_PADDING_MODE padding_mode;
// This value is ignored for other padding modes.
float padding_value = 0;
switch (pad->mode->which()) {
case mojom::PaddingMode::Tag::kConstant:
padding_mode = DML_PADDING_MODE::DML_PADDING_MODE_CONSTANT;
padding_value = pad->mode->get_constant()->value;
break;
case mojom::PaddingMode::Tag::kEdge:
padding_mode = DML_PADDING_MODE::DML_PADDING_MODE_EDGE;
break;
case mojom::PaddingMode::Tag::kReflection:
padding_mode = DML_PADDING_MODE::DML_PADDING_MODE_REFLECTION;
break;
case mojom::PaddingMode::Tag::kSymmetric:
padding_mode = DML_PADDING_MODE::DML_PADDING_MODE_SYMMETRIC;
break;
}
const auto& beginning_padding = pad->beginning_padding;
const auto& ending_padding = pad->ending_padding;
CHECK_EQ(beginning_padding.size(), ending_padding.size());
DML_PADDING_OPERATOR_DESC pad_operator_desc = {
.InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
.OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
.PaddingMode = padding_mode,
.PaddingValue = padding_value,
.DimensionCount = static_cast<uint32_t>(beginning_padding.size()),
.StartPadding = beginning_padding.data(),
.EndPadding = ending_padding.data()};
std::array<const NodeOutput*, 1> inputs = {input};
const OperatorNode* pad_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_PADDING, &pad_operator_desc, inputs, pad->label);
const NodeOutput* output =
graph_builder.CreateNodeOutput(pad_node, std::move(output_tensor_desc));
// The output id must be unique in the map.
CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
}
base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForPool2d(
const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::Pool2dPtr& pool2d,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map) {
const NodeOutput* input =
GetNodeOutputForOperand(id_to_node_output_map, pool2d->input_operand_id);
// The input tensor description may be transposed.
auto input_tensor_desc = input->GetTensorDesc();
uint64_t output_id = pool2d->output_operand_id;
// The output tensor description may be transposed.
auto output_tensor_desc =
CreateOutputTensorDesc(id_to_operand_map, output_id);
std::array<uint32_t, 2> strides = {pool2d->strides->height,
pool2d->strides->width};
std::array<uint32_t, 2> dilations = {pool2d->dilations->height,
pool2d->dilations->width};
std::array<uint32_t, 2> window_dimensions = {
pool2d->window_dimensions->height, pool2d->window_dimensions->width};
std::array<uint32_t, 2> start_padding = {pool2d->padding->beginning->height,
pool2d->padding->beginning->width};
std::array<uint32_t, 2> end_padding = {pool2d->padding->ending->height,
pool2d->padding->ending->width};
std::array<const NodeOutput*, 1> inputs = {input};
const OperatorNode* pool2d_node = nullptr;
const std::string& label = pool2d->label;
switch (pool2d->kind) {
case mojom::Pool2d::Kind::kAveragePool2d: {
CHECK(context_properties.data_type_limits.average_pool2d_input.Has(
DmlDataTypeToOperand(input_tensor_desc.GetDataType())));
// TODO(crbug.com/40206287): Work around the lack of dilation support for
// L2 and average pooling. According to the WebNN spec:
// https://www.w3.org/TR/webnn/#api-mlgraphbuilder-pool2d, dilations are
// supported by pooling operations, while DirectML's AVERAGE_POOLING and
// LP_POOLING don't support dilations.
// Spec issue tracked on
// https://github.com/webmachinelearning/webnn/issues/180.
if (dilations[0] != 1 || dilations[1] != 1) {
return base::unexpected(CreateError(
mojom::Error::Code::kNotSupportedError,
"Dilations are not supported for average pooling operator.",
label));
}
DML_AVERAGE_POOLING_OPERATOR_DESC average_pooling_desc = {
.InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
.OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
.DimensionCount =
base::checked_cast<uint32_t>(window_dimensions.size()),
.Strides = strides.data(),
.WindowSize = window_dimensions.data(),
.StartPadding = start_padding.data(),
.EndPadding = end_padding.data(),
// The padding elements are not counted as part of the averaging
// calculation.
.IncludePadding = false};
pool2d_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_AVERAGE_POOLING, &average_pooling_desc, inputs, label);
break;
}
case mojom::Pool2d::Kind::kL2Pool2d: {
CHECK(context_properties.data_type_limits.l2_pool2d_input.Has(
DmlDataTypeToOperand(input_tensor_desc.GetDataType())));
DML_LP_POOLING_OPERATOR_DESC l2_pooling_desc = {
.InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
.OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
.DimensionCount =
base::checked_cast<uint32_t>(window_dimensions.size()),
.Strides = strides.data(),
.WindowSize = window_dimensions.data(),
.StartPadding = start_padding.data(),
.EndPadding = end_padding.data(),
.P = 2};
pool2d_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_LP_POOLING, &l2_pooling_desc, inputs, label);
break;
}
case mojom::Pool2d::Kind::kMaxPool2d: {
CHECK(context_properties.data_type_limits.max_pool2d_input.Has(
DmlDataTypeToOperand(input_tensor_desc.GetDataType())));
// If the dilations are the default { 1, 1 }, prefer
// `DML_MAX_POOLING_OPERATOR_DESC`, which doesn't support dilations, for
// best compatibility.
// https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_max_pooling_operator_desc.
// TODO(issues.chromium.org/327244278): Remove the workaround of using
// `DML_MAX_POOLING_OPERATOR_DESC` without dilations.
if (dilations[0] == 1 && dilations[1] == 1) {
DML_MAX_POOLING_OPERATOR_DESC max_pooling_desc = {
.InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
.OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
.DimensionCount =
base::checked_cast<uint32_t>(window_dimensions.size()),
.Strides = strides.data(),
.WindowSize = window_dimensions.data(),
.StartPadding = start_padding.data(),
.EndPadding = end_padding.data()};
pool2d_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_MAX_POOLING, &max_pooling_desc, inputs, label);
} else {
DML_MAX_POOLING2_OPERATOR_DESC max_pooling2_desc = {
.InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
.OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
.OutputIndicesTensor = nullptr,
.DimensionCount =
base::checked_cast<uint32_t>(window_dimensions.size()),
.Strides = strides.data(),
.WindowSize = window_dimensions.data(),
.StartPadding = start_padding.data(),
.EndPadding = end_padding.data(),
.Dilations = dilations.data()};
pool2d_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_MAX_POOLING2, &max_pooling2_desc, inputs, label);
}
break;
}
default:
LOG(ERROR) << "[WebNN] Invalid Pool2d operator type";
NOTREACHED();
}
const NodeOutput* output = graph_builder.CreateNodeOutput(
pool2d_node, std::move(output_tensor_desc), 0);
// The output id must be unique in the map.
CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
return base::ok();
}
void CreateOperatorNodeForPrelu(const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::PreluPtr& prelu,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map) {
const NodeOutput* input =
GetNodeOutputForOperand(id_to_node_output_map, prelu->input_operand_id);
const auto& input_tensor_desc = input->GetTensorDesc();
CHECK(context_properties.data_type_limits.prelu_input.Has(
DmlDataTypeToOperand(input_tensor_desc.GetDataType())));
const NodeOutput* slope =
GetNodeOutputForOperand(id_to_node_output_map, prelu->slope_operand_id);
auto slope_tensor_desc = slope->GetTensorDesc();
CHECK_EQ(input_tensor_desc.GetDataType(), slope_tensor_desc.GetDataType());
uint64_t output_id = prelu->output_operand_id;
const auto output_tensor_desc =
CreateOutputTensorDesc(id_to_operand_map, output_id);
const auto& output_dimensions = output_tensor_desc.GetDimensions();
if (slope_tensor_desc.GetDimensions() != output_dimensions) {
slope_tensor_desc.BroadcastTo(output_dimensions);
}
DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_DESC prelu_desc{
.InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
.SlopeTensor = &slope_tensor_desc.GetDMLTensorDesc(),
.OutputTensor = &output_tensor_desc.GetDMLTensorDesc()};
const std::string& label = prelu->label;
std::array<const NodeOutput*, 2> inputs = {input, slope};
const OperatorNode* prelu_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU, &prelu_desc, inputs, label);
const NodeOutput* node_output =
graph_builder.CreateNodeOutput(prelu_node, std::move(output_tensor_desc));
// The output id must be unique in the map.
CHECK(id_to_node_output_map.try_emplace(output_id, node_output).second);
}
void CreateOperatorNodeForSlice(const IdToOperandMap& id_to_operand_map,
const mojom::SlicePtr& slice,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map) {
const NodeOutput* input =
GetNodeOutputForOperand(id_to_node_output_map, slice->input_operand_id);
const TensorDesc& input_tensor_desc = input->GetTensorDesc();
const auto& input_dimensions = input_tensor_desc.GetDimensions();
// Start and size attributes must be unpacked from the mojo interface.
std::vector<uint32_t> starts;
std::vector<uint32_t> sizes;
starts.reserve(slice->starts_and_sizes.size());
sizes.reserve(slice->starts_and_sizes.size());
for (size_t i = 0; i < slice->starts_and_sizes.size(); ++i) {
starts.push_back(slice->starts_and_sizes[i]->start);
sizes.push_back(slice->starts_and_sizes[i]->size);
}
CHECK_EQ(input_dimensions.size(), slice->starts_and_sizes.size());
TensorDesc output_tensor_desc =
CreateOutputTensorDesc(id_to_operand_map, slice->output_operand_id);
// WebNN doesn't support the strides parameter, but DML expects one. Create
// an appropriately sized array of 1s to produce the expected operation.
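// For example, slicing a {4, 4} input with starts {0, 1} and sizes {2, 3}
// copies the 2 x 3 window spanning rows 0-1 and columns 1-3.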
std::vector<uint32_t> strides(input_dimensions.size(), 1u);
DML_SLICE_OPERATOR_DESC slice_operator_desc{
.InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
.OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
.DimensionCount = static_cast<UINT>(input_dimensions.size()),
.Offsets = starts.data(),
.Sizes = sizes.data(),
.Strides = strides.data(),
};
std::array<const NodeOutput*, 1> input_node_output = {input};
const OperatorNode* slice_node =
graph_builder.CreateOperatorNode(DML_OPERATOR_SLICE, &slice_operator_desc,
input_node_output, slice->label);
const auto* slice_output =
graph_builder.CreateNodeOutput(slice_node, std::move(output_tensor_desc));
id_to_node_output_map[slice->output_operand_id] = slice_output;
}
void CreateOperatorNodeForSplit(const IdToOperandMap& id_to_operand_map,
const mojom::SplitPtr& split,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map) {
const NodeOutput* input =
GetNodeOutputForOperand(id_to_node_output_map, split->input_operand_id);
const auto& input_tensor_desc = input->GetTensorDesc();
// Since TensorDesc stores the dimensions and strides vectors, we need to
// keep the TensorDescs alive until CreateOperatorNode is called.
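// (Each DML_TENSOR_DESC holds a pointer to the buffer description owned by
// its TensorDesc, so the TensorDescs must outlive the call.)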
std::vector<TensorDesc> output_tensor_desc;
output_tensor_desc.reserve(split->output_operand_ids.size());
std::vector<DML_TENSOR_DESC> output_tensor_desc_dml;
output_tensor_desc_dml.reserve(split->output_operand_ids.size());
for (uint64_t output_id : split->output_operand_ids) {
output_tensor_desc.push_back(
CreateOutputTensorDesc(id_to_operand_map, output_id));
output_tensor_desc_dml.push_back(
output_tensor_desc.back().GetDMLTensorDesc());
}
auto output_count =
base::checked_cast<uint32_t>(output_tensor_desc_dml.size());
DML_SPLIT_OPERATOR_DESC split_desc{
.InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
.OutputCount = output_count,
.OutputTensors = output_tensor_desc_dml.data(),
.Axis = split->axis};
const std::string& label = split->label;
std::array<const NodeOutput*, 1> inputs = {input};
const OperatorNode* split_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_SPLIT, &split_desc, inputs, label);
for (uint32_t i = 0; i < output_count; ++i) {
uint64_t output_id = split->output_operand_ids[i];
const auto* output = graph_builder.CreateNodeOutput(
split_node, std::move(output_tensor_desc[i]), i);
CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
}
}
template <typename DML_OPERATOR_DESC, DML_OPERATOR_TYPE operator_type>
const OperatorNode* CreateUnaryOperator(const TensorDesc& input_tensor,
const TensorDesc& output_tensor,
const NodeOutput* input,
GraphBuilderDml& graph_builder,
std::string_view label = "") {
DML_OPERATOR_DESC unary_operator_desc{
.InputTensor = &input_tensor.GetDMLTensorDesc(),
.OutputTensor = &output_tensor.GetDMLTensorDesc()};
std::array<const NodeOutput*, 1> inputs = {input};
return graph_builder.CreateOperatorNode(operator_type, &unary_operator_desc,
inputs, label);
}
template <typename OperatorDesc,
DML_OPERATOR_TYPE operator_type,
typename Operation>
void CreateOperatorNodeForUnary(const IdToOperandMap& id_to_operand_map,
const Operation& operation,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map) {
const NodeOutput* input = GetNodeOutputForOperand(
id_to_node_output_map, operation->input_operand_id);
const auto& input_tensor_desc = input->GetTensorDesc();
uint64_t output_id = operation->output_operand_id;
const auto output_tensor_desc =
CreateOutputTensorDesc(id_to_operand_map, output_id);
const OperatorNode* unary_node =
CreateUnaryOperator<OperatorDesc, operator_type>(
input_tensor_desc, output_tensor_desc, input, graph_builder,
operation->label);
const NodeOutput* output = graph_builder.CreateNodeOutput(
unary_node, std::move(output_tensor_desc), 0);
// The output id must be unique in the map.
CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
}
void CreateOperatorNodeForNeg(const IdToOperandMap& id_to_operand_map,
const mojom::ElementWiseUnaryPtr& operation,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map) {
const NodeOutput* input = GetNodeOutputForOperand(
id_to_node_output_map, operation->input_operand_id);
const auto& input_tensor_desc = input->GetTensorDesc();
const uint64_t output_id = operation->output_operand_id;
const auto output_tensor_desc =
CreateOutputTensorDesc(id_to_operand_map, output_id);
// Set the scale and bias terms supplied to the identity operator. Scale and
// bias have the effect of applying the function g(x) = x * Scale + Bias.
// Setting Scale to -1 and Bias to 0 makes the identity operator compute the
// negation.
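// For example, g(2.5) = 2.5 * -1 + 0 = -2.5.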
DML_SCALE_BIAS scale_bias{.Scale = -1.f, .Bias = 0.f};
DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC identity_operator_desc{
.InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
.OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
.ScaleBias = &scale_bias};
std::array<const NodeOutput*, 1> inputs = {input};
const OperatorNode* identity_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_ELEMENT_WISE_IDENTITY, &identity_operator_desc, inputs,
operation->label);
const NodeOutput* output = graph_builder.CreateNodeOutput(
identity_node, std::move(output_tensor_desc), 0);
// The output id must be unique in the map.
CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
}
void CreateOperatorNodeForElementWiseUnary(
const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::ElementWiseUnaryPtr& operation,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map) {
const OperandDataType input_data_type =
DmlDataTypeToOperand(GetNodeOutputForOperand(id_to_node_output_map,
operation->input_operand_id)
->GetTensorDesc()
.GetDataType());
switch (operation->kind) {
case mojom::ElementWiseUnary::Kind::kAbs: {
CHECK(context_properties.data_type_limits.abs_input.Has(input_data_type));
return CreateOperatorNodeForUnary<DML_ELEMENT_WISE_ABS_OPERATOR_DESC,
DML_OPERATOR_ELEMENT_WISE_ABS>(
id_to_operand_map, operation, graph_builder, id_to_node_output_map);
}
case mojom::ElementWiseUnary::Kind::kCast: {
CHECK(
context_properties.data_type_limits.cast_input.Has(input_data_type));
return CreateOperatorNodeForUnary<DML_CAST_OPERATOR_DESC,
DML_OPERATOR_CAST>(
id_to_operand_map, operation, graph_builder, id_to_node_output_map);
}
case mojom::ElementWiseUnary::Kind::kCeil: {
CHECK(
context_properties.data_type_limits.ceil_input.Has(input_data_type));
return CreateOperatorNodeForUnary<DML_ELEMENT_WISE_CEIL_OPERATOR_DESC,
DML_OPERATOR_ELEMENT_WISE_CEIL>(
id_to_operand_map, operation, graph_builder, id_to_node_output_map);
}
case mojom::ElementWiseUnary::Kind::kCos: {
CHECK(context_properties.data_type_limits.cos_input.Has(input_data_type));
return CreateOperatorNodeForUnary<DML_ELEMENT_WISE_COS_OPERATOR_DESC,
DML_OPERATOR_ELEMENT_WISE_COS>(
id_to_operand_map, operation, graph_builder, id_to_node_output_map);
}
case mojom::ElementWiseUnary::Kind::kErf: {
CHECK(context_properties.data_type_limits.erf_input.Has(input_data_type));
return CreateOperatorNodeForUnary<DML_ELEMENT_WISE_ERF_OPERATOR_DESC,
DML_OPERATOR_ELEMENT_WISE_ERF>(
id_to_operand_map, operation, graph_builder, id_to_node_output_map);
}
case mojom::ElementWiseUnary::Kind::kExp: {
CHECK(context_properties.data_type_limits.exp_input.Has(input_data_type));
return CreateOperatorNodeForUnary<DML_ELEMENT_WISE_EXP_OPERATOR_DESC,
DML_OPERATOR_ELEMENT_WISE_EXP>(
id_to_operand_map, operation, graph_builder, id_to_node_output_map);
}
case mojom::ElementWiseUnary::Kind::kFloor: {
CHECK(
context_properties.data_type_limits.floor_input.Has(input_data_type));
return CreateOperatorNodeForUnary<DML_ELEMENT_WISE_FLOOR_OPERATOR_DESC,
DML_OPERATOR_ELEMENT_WISE_FLOOR>(
id_to_operand_map, operation, graph_builder, id_to_node_output_map);
}
case mojom::ElementWiseUnary::Kind::kIdentity: {
CHECK(context_properties.data_type_limits.identity_input.Has(
input_data_type));
return CreateOperatorNodeForUnary<DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC,
DML_OPERATOR_ELEMENT_WISE_IDENTITY>(
id_to_operand_map, operation, graph_builder, id_to_node_output_map);
}
case mojom::ElementWiseUnary::Kind::kLog: {
CHECK(context_properties.data_type_limits.log_input.Has(input_data_type));
return CreateOperatorNodeForUnary<DML_ELEMENT_WISE_LOG_OPERATOR_DESC,
DML_OPERATOR_ELEMENT_WISE_LOG>(
id_to_operand_map, operation, graph_builder, id_to_node_output_map);
}
case mojom::ElementWiseUnary::Kind::kLogicalNot: {
CHECK(context_properties.data_type_limits.logical_not_input.Has(
input_data_type));
return CreateOperatorNodeForUnary<
DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_DESC,
DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT>(
id_to_operand_map, operation, graph_builder, id_to_node_output_map);
}
// TODO(crbug.com/40943114): Implement the negate operator directly by
// DML_ELEMENT_WISE_NEGATE_OPERATOR_DESC which is available in
// DML_FEATURE_LEVEL_5_0.
// https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_element_wise_negate_operator_desc#availability
case mojom::ElementWiseUnary::Kind::kNeg: {
CHECK(context_properties.data_type_limits.neg_input.Has(input_data_type));
return CreateOperatorNodeForNeg(id_to_operand_map, operation,
graph_builder, id_to_node_output_map);
}
case mojom::ElementWiseUnary::Kind::kReciprocal: {
CHECK(context_properties.data_type_limits.reciprocal_input.Has(
input_data_type));
return CreateOperatorNodeForUnary<DML_ELEMENT_WISE_RECIP_OPERATOR_DESC,
DML_OPERATOR_ELEMENT_WISE_RECIP>(
id_to_operand_map, operation, graph_builder, id_to_node_output_map);
}
case mojom::ElementWiseUnary::Kind::kSign: {
CHECK(
context_properties.data_type_limits.sign_input.Has(input_data_type));
return CreateOperatorNodeForUnary<DML_ELEMENT_WISE_SIGN_OPERATOR_DESC,
DML_OPERATOR_ELEMENT_WISE_SIGN>(
id_to_operand_map, operation, graph_builder, id_to_node_output_map);
}
case mojom::ElementWiseUnary::Kind::kSin: {
CHECK(context_properties.data_type_limits.sin_input.Has(input_data_type));
return CreateOperatorNodeForUnary<DML_ELEMENT_WISE_SIN_OPERATOR_DESC,
DML_OPERATOR_ELEMENT_WISE_SIN>(
id_to_operand_map, operation, graph_builder, id_to_node_output_map);
}
case mojom::ElementWiseUnary::Kind::kSqrt: {
CHECK(
context_properties.data_type_limits.sqrt_input.Has(input_data_type));
return CreateOperatorNodeForUnary<DML_ELEMENT_WISE_SQRT_OPERATOR_DESC,
DML_OPERATOR_ELEMENT_WISE_SQRT>(
id_to_operand_map, operation, graph_builder, id_to_node_output_map);
}
case mojom::ElementWiseUnary::Kind::kTan: {
CHECK(context_properties.data_type_limits.tan_input.Has(input_data_type));
return CreateOperatorNodeForUnary<DML_ELEMENT_WISE_TAN_OPERATOR_DESC,
DML_OPERATOR_ELEMENT_WISE_TAN>(
id_to_operand_map, operation, graph_builder, id_to_node_output_map);
}
}
}
void CreateOperatorNodeForResample2d(
const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::Resample2dPtr& resample2d,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map) {
const NodeOutput* input = GetNodeOutputForOperand(
id_to_node_output_map, resample2d->input_operand_id);
const auto& input_tensor_desc = input->GetTensorDesc();
CHECK(context_properties.data_type_limits.resample2d_input.Has(
DmlDataTypeToOperand(input_tensor_desc.GetDataType())));
uint64_t output_id = resample2d->output_operand_id;
auto output_tensor_desc =
CreateOutputTensorDesc(id_to_operand_map, output_id);
const auto& input_dimensions = input_tensor_desc.GetDimensions();
const auto& output_dimensions = output_tensor_desc.GetDimensions();
size_t input_rank = input_dimensions.size();
CHECK_EQ(input_rank, output_dimensions.size());
// Use explicit scales if given; otherwise, compute the scales as output
// dimensions / input dimensions. Then expand the scales to full scales
// (one per dimension of the input rank) using the axes.
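// For example, for a {1, 1, 2, 4} input resampled to a {1, 1, 4, 8} output
// with axes = {2, 3}, the full scales become {1, 1, 2, 2}.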
std::vector<float> full_scales(input_rank, 1);
const auto& scales = resample2d->scales;
const auto& axes = resample2d->axes;
if (scales) {
for (size_t i = 0; i < axes.size(); ++i) {
auto axis = axes[i];
CHECK_LT(axis, full_scales.size());
full_scales[axis] = scales.value()[i];
}
} else {
for (size_t i = 0; i < input_rank; ++i) {
full_scales[i] =
base::checked_cast<float>(output_dimensions[i]) / input_dimensions[i];
}
}
DML_INTERPOLATION_MODE mode;
switch (resample2d->mode) {
case mojom::Resample2d::InterpolationMode::kNearestNeighbor:
mode = DML_INTERPOLATION_MODE_NEAREST_NEIGHBOR;
break;
case mojom::Resample2d::InterpolationMode::kLinear:
mode = DML_INTERPOLATION_MODE_LINEAR;
break;
}
DML_RESAMPLE_OPERATOR_DESC resample2d_operator_desc = {
.InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
.OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
.InterpolationMode = mode,
.ScaleCount = static_cast<uint32_t>(full_scales.size()),
.Scales = full_scales.data()};
const std::string& label = resample2d->label;
std::array<const NodeOutput*, 1> inputs = {input};
const OperatorNode* resample2d_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_RESAMPLE, &resample2d_operator_desc, inputs, label);
const NodeOutput* output = graph_builder.CreateNodeOutput(
resample2d_node, std::move(output_tensor_desc), 0);
// The output id must be unique in the map.
CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
}
void CreateOperatorNodeForReduce(const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::ReducePtr& reduce,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map) {
const NodeOutput* input =
GetNodeOutputForOperand(id_to_node_output_map, reduce->input_operand_id);
const auto& input_tensor_desc = input->GetTensorDesc();
CheckInputDataTypeForReduce(
context_properties.data_type_limits, reduce->kind,
DmlDataTypeToOperand(input_tensor_desc.GetDataType()));
uint64_t output_id = reduce->output_operand_id;
const auto& output_tensor_desc =
CreateOutputTensorDesc(id_to_operand_map, output_id);
const auto& axes = reduce->axes;
// Determine the output sizes. Ignore output_tensor_desc's dimensions, since
// DirectML expects the output dimensions to have the same rank as the
// input, and output_tensor_desc's dimensions may have removed dimensions if
// keepDimensions was false.
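// For example, reducing a {2, 3, 4} input over axes = {1} uses DML output
// dimensions {2, 1, 4}, even though the WebNN output shape is {2, 4} when
// keepDimensions is false.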
std::vector<uint32_t> output_dimensions = input_tensor_desc.GetDimensions();
for (uint32_t axis : axes) {
CHECK_LT(axis, output_dimensions.size());
output_dimensions[axis] = 1u;
}
TensorDesc new_output_tensor_desc(output_tensor_desc.GetDataType(),
output_dimensions);
std::array<const NodeOutput*, 1> inputs = {input};
DML_REDUCE_OPERATOR_DESC operator_desc = {};
operator_desc.Function = MapReduceKindToReduceFuntion(reduce->kind);
operator_desc.InputTensor = &input_tensor_desc.GetDMLTensorDesc();
operator_desc.OutputTensor = &new_output_tensor_desc.GetDMLTensorDesc();
operator_desc.AxisCount = static_cast<uint32_t>(axes.size());
operator_desc.Axes = axes.data();
const OperatorNode* reduce_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_REDUCE, &operator_desc, inputs, reduce->label);
const NodeOutput* output =
graph_builder.CreateNodeOutput(reduce_node, output_tensor_desc);
// The output id must be unique in the map.
CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
}
// Appends an identity node to the input node output. Returns the node output
// of the identity operator if it's successfully created, otherwise returns
// nullptr.
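// Note that the identity output is created with DML_TENSOR_FLAG_NONE, which
// drops any flags (e.g. DML_TENSOR_FLAG_OWNED_BY_DML) carried by the input
// tensor description.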
const NodeOutput* AppendIdentityNode(
GraphBuilderDml& graph_builder,
const NodeOutput* input,
const TensorDesc* input_tensor_desc = nullptr) {
CHECK(input);
if (!input_tensor_desc) {
input_tensor_desc = &input->GetTensorDesc();
}
TensorDesc identity_tensor_desc(input_tensor_desc->GetDataType(),
DML_TENSOR_FLAG_NONE,
input_tensor_desc->GetDimensions());
const OperatorNode* identity =
CreateUnaryOperator<DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC,
DML_OPERATOR_ELEMENT_WISE_IDENTITY>(
*input_tensor_desc, identity_tensor_desc, input, graph_builder);
return graph_builder.CreateNodeOutput(identity,
std::move(identity_tensor_desc));
}
// Create a reshape node with the given new shape.
const NodeOutput* CreateReshapeNode(GraphBuilderDml& graph_builder,
const NodeOutput* input,
base::span<const uint32_t> new_shape) {
CHECK(input);
const auto& input_tensor_desc = input->GetTensorDesc();
const TensorDesc reshaped_input_tensor_desc(
input_tensor_desc.GetDataType(), input_tensor_desc.GetFlags(),
std::vector<uint32_t>(new_shape.begin(), new_shape.end()));
const NodeOutput* reshape_node =
AppendIdentityNode(graph_builder, input, &reshaped_input_tensor_desc);
return reshape_node;
}
// The DirectML API does not have a real reshape operator. The WebNN reshape
// is implemented by a DirectML identity operator. The DirectML runtime is
// able to optimize away the unnecessary identity operators when compiling
// the graph.
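// For example, reshaping a {2, 3} tensor to {6} reinterprets the same six
// contiguous elements; only the tensor description changes.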
void CreateOperatorNodeForReshape(const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::ReshapePtr& reshape,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map) {
const NodeOutput* input =
GetNodeOutputForOperand(id_to_node_output_map, reshape->input_operand_id);
CHECK(context_properties.data_type_limits.reshape_input.Has(
DmlDataTypeToOperand(input->GetTensorDesc().GetDataType())));
uint64_t output_id = reshape->output_operand_id;
const OperandPtr& output_operand = id_to_operand_map.at(output_id);
base::span<const uint32_t> new_shape = output_operand->descriptor.shape();
const NodeOutput* output = CreateReshapeNode(graph_builder, input, new_shape);
// The output id must be unique in the map.
CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
}
void CreateOperatorNodeForElu(const IdToOperandMap& id_to_operand_map,
const mojom::EluPtr& elu,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map) {
const NodeOutput* input =
GetNodeOutputForOperand(id_to_node_output_map, elu->input_operand_id);
const auto& input_tensor_desc = input->GetTensorDesc();
uint64_t output_id = elu->output_operand_id;
const auto output_tensor_desc =
CreateOutputTensorDesc(id_to_operand_map, output_id);
DML_ACTIVATION_ELU_OPERATOR_DESC elu_desc{
.InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
.OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
.Alpha = elu->alpha};
std::array<const NodeOutput*, 1> inputs = {input};
const OperatorNode* elu_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_ACTIVATION_ELU, &elu_desc, inputs, elu->label);
const NodeOutput* node_output =
graph_builder.CreateNodeOutput(elu_node, std::move(output_tensor_desc));
// The output id must be unique in the map.
CHECK(id_to_node_output_map.try_emplace(output_id, node_output).second);
}
void CreateOperatorNodeForExpand(const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::ExpandPtr& expand,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map) {
const NodeOutput* input =
GetNodeOutputForOperand(id_to_node_output_map, expand->input_operand_id);
auto input_tensor_desc = input->GetTensorDesc();
CHECK(context_properties.data_type_limits.expand_input.Has(
DmlDataTypeToOperand(input_tensor_desc.GetDataType())));
const uint64_t output_id = expand->output_operand_id;
const auto output_tensor_desc =
CreateOutputTensorDesc(id_to_operand_map, output_id);
// Use identity to implement the expand operation with broadcasting strides
// https://learn.microsoft.com/en-us/windows/ai/directml/dml-strides#broadcasting-with-strides.
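// For example, broadcasting a {1, 3} input to a {2, 3} output gives the
// expanded dimension a stride of 0, so every output row reads the same
// three source elements.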
const auto& output_dimensions = output_tensor_desc.GetDimensions();
if (input_tensor_desc.GetDimensions() != output_dimensions) {
input_tensor_desc.BroadcastTo(output_dimensions);
}
const OperatorNode* identity_node =
CreateUnaryOperator<DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC,
DML_OPERATOR_ELEMENT_WISE_IDENTITY>(
input_tensor_desc, output_tensor_desc, input, graph_builder,
expand->label);
const NodeOutput* node_output = graph_builder.CreateNodeOutput(
identity_node, std::move(output_tensor_desc));
// The output id must be unique in the map.
CHECK(id_to_node_output_map.try_emplace(output_id, node_output).second);
}
base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForGather(
const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::GatherPtr& gather,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map) {
const NodeOutput* input =
GetNodeOutputForOperand(id_to_node_output_map, gather->input_operand_id);
auto input_tensor_desc = input->GetTensorDesc();
CHECK(context_properties.data_type_limits.gather_input.Has(
DmlDataTypeToOperand(input_tensor_desc.GetDataType())));
const NodeOutput* indices = GetNodeOutputForOperand(
id_to_node_output_map, gather->indices_operand_id);
auto indices_tensor_desc = indices->GetTensorDesc();
CHECK(context_properties.data_type_limits.gather_indices.Has(
DmlDataTypeToOperand(indices_tensor_desc.GetDataType())));
size_t indices_rank = indices_tensor_desc.GetDimensions().size();
if (!base::MakeCheckedNum(indices_rank).IsValid<uint32_t>()) {
return base::unexpected(
CreateError(mojom::Error::Code::kUnknownError,
"The indices rank of gather operator is too large."));
}
uint64_t output_id = gather->output_operand_id;
const auto original_output_tensor_desc =
CreateOutputTensorDesc(id_to_operand_map, output_id);
auto output_tensor_desc = original_output_tensor_desc;
size_t input_rank = input_tensor_desc.GetDimensions().size();
size_t output_rank = output_tensor_desc.GetDimensions().size();
size_t expanded_rank = std::max(input_rank, output_rank);
// According to the DirectML documentation
// https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_gather_operator_desc,
// the parameters `InputTensor`, `OutputTensor` and `IndicesTensor` must have
// the same dimension count.
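// For example, gathering from a {4, 5} input with indices of shape {3}
// along axis 0 produces a {3, 5} output; all three descriptions are
// trailing-aligned to rank 2, so the indices become {1, 3}.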
input_tensor_desc.EnsureMinimumRank(expanded_rank,
TensorDesc::Alignment::kTrailing);
indices_tensor_desc.EnsureMinimumRank(expanded_rank,
TensorDesc::Alignment::kTrailing);
uint32_t axis = gather->axis;
if (output_rank < input_rank) {
// There is only one case in which `output_rank` is less than `input_rank`:
// when the indices operand is a scalar. In this case, a value of 1 should
// be inserted at the `axis` position of the output dimensions, because the
// indices dimensions are set to {1}, since DirectML requires the tensor
// dimension count to be at least 1.
CHECK_EQ(indices_rank, 1u);
CHECK_EQ(output_rank, input_rank - 1);
auto output_dimensions = input_tensor_desc.GetDimensions();
CHECK_LT(axis, output_dimensions.size());
output_dimensions[axis] = 1;
output_tensor_desc = TensorDesc(output_tensor_desc.GetDataType(),
std::move(output_dimensions));
}
auto expanded_axis = base::MakeCheckedNum(expanded_rank) - input_rank +
base::checked_cast<size_t>(axis);
const std::string& label = gather->label;
if (!expanded_axis.AssignIfValid<uint32_t>(&axis)) {
return base::unexpected(
CreateError(mojom::Error::Code::kUnknownError,
"The axis of gather operator is too large.", label));
}
// TODO(crbug.com/40206287): Include a DirectML documentation link and a
// Chromium test that validates the out-of-bounds indices handling.
//
// The DirectML implementation of the gather operator already clamps the
// indices tensor in the shader to prevent out-of-bounds access.
DML_GATHER_OPERATOR_DESC gather_operator_desc{
.InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
.IndicesTensor = &indices_tensor_desc.GetDMLTensorDesc(),
.OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
// The axis dimension of InputTensor to gather on.
.Axis = axis,
// The number of actual index dimensions within the IndicesTensor.
.IndexDimensions = base::checked_cast<uint32_t>(indices_rank)};
std::array<const NodeOutput*, 2> inputs = {input, indices};
const OperatorNode* gather_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_GATHER, &gather_operator_desc, inputs, label);
const NodeOutput* output = graph_builder.CreateNodeOutput(
gather_node, std::move(original_output_tensor_desc), 0);
// The output id must be unique in the map.
CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
return base::ok();
}
void CreateOperatorNodeForGatherElements(
const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::GatherElementsPtr& gather_elements,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map) {
const NodeOutput* input = GetNodeOutputForOperand(
id_to_node_output_map, gather_elements->input_operand_id);
const TensorDesc& input_tensor_desc = input->GetTensorDesc();
CHECK(context_properties.data_type_limits.gather_elements_input.Has(
DmlDataTypeToOperand(input_tensor_desc.GetDataType())));
const NodeOutput* indices = GetNodeOutputForOperand(
id_to_node_output_map, gather_elements->indices_operand_id);
const TensorDesc& indices_tensor_desc = indices->GetTensorDesc();
CHECK(context_properties.data_type_limits.gather_elements_indices.Has(
DmlDataTypeToOperand(indices_tensor_desc.GetDataType())));
uint64_t output_id = gather_elements->output_operand_id;
const TensorDesc output_tensor_desc =
CreateOutputTensorDesc(id_to_operand_map, output_id);
// The DirectML implementation of the gatherElements operator already clamps
// the indices tensor in the shader to prevent out-of-bounds access.
DML_GATHER_ELEMENTS_OPERATOR_DESC gather_elements_desc{
.InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
.IndicesTensor = &indices_tensor_desc.GetDMLTensorDesc(),
.OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
// The dimension of InputTensor to gather along.
.Axis = gather_elements->axis};
std::array<const NodeOutput*, 2> inputs = {input, indices};
const OperatorNode* node = graph_builder.CreateOperatorNode(
DML_OPERATOR_GATHER_ELEMENTS, &gather_elements_desc, inputs,
gather_elements->label);
const NodeOutput* output =
graph_builder.CreateNodeOutput(node, std::move(output_tensor_desc), 0);
// The output id must be unique in the map.
CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
}
void CreateOperatorNodeForGelu(
Adapter* adapter,
const IdToOperandMap& id_to_operand_map,
const mojom::GeluPtr& gelu,
mojom::GraphInfoPtr& graph_info,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map,
std::unordered_map<uint64_t, uint32_t>& constant_id_to_input_index_map,
uint64_t& next_operand_id) {
// `DML_OPERATOR_ACTIVATION_GELU` requires DML_FEATURE_LEVEL_5_1; see:
// https://learn.microsoft.com/en-us/windows/ai/directml/api/ns-directml-dml_activation_gelu_operator_desc
if (adapter->IsDMLFeatureLevelSupported(DML_FEATURE_LEVEL_5_1)) {
return CreateOperatorNodeForUnary<DML_ACTIVATION_GELU_OPERATOR_DESC,
DML_OPERATOR_ACTIVATION_GELU>(
id_to_operand_map, gelu, graph_builder, id_to_node_output_map);
}
// Emulate gelu (0.5 * x * (1 + erf(x / sqrt(2)))) with decomposed
// operations on platforms with a lower feature level, as described in
// https://webmachinelearning.github.io/webnn/#api-mlgraphbuilder-gelu-method
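// The decomposition below emits, in order: a scalar constant 2.0, sqrt(2),
// x / sqrt(2), erf(x / sqrt(2)), 1 + erf(...), x * (1 + erf(...)), and
// finally 0.5 * (x * (1 + erf(...))).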
//
// Build constant operand (2.0)
const OperandPtr& input_operand =
id_to_operand_map.at(gelu->input_operand_id);
const OperandDataType data_type = input_operand->descriptor.data_type();
uint64_t constant_for_sqrt_operand_id = BuildConstantOperandForFloatValue(
graph_info, next_operand_id, data_type, /*rank*/ 1,
/*default value*/ 2.0);
uint32_t constant_for_sqrt_input_index =
CreateInputNode(id_to_operand_map, constant_for_sqrt_operand_id,
graph_builder, id_to_node_output_map);
CHECK(constant_id_to_input_index_map
.try_emplace(constant_for_sqrt_operand_id,
constant_for_sqrt_input_index)
.second);
const NodeOutput* constant_for_sqrt_output = GetNodeOutputForOperand(
id_to_node_output_map, constant_for_sqrt_operand_id);
// Formula: sqrt(2)
const TensorDesc sqrt_output_tensor_desc =
TensorDesc(GetTensorDataType(data_type), /*dimensions*/ {1});
DML_ELEMENT_WISE_SQRT_OPERATOR_DESC sqrt_operator_desc{
.InputTensor =
&constant_for_sqrt_output->GetTensorDesc().GetDMLTensorDesc(),
.OutputTensor = &sqrt_output_tensor_desc.GetDMLTensorDesc(),
};
const std::string& label = gelu->label;
std::array<const NodeOutput*, 1> sqrt_inputs = {constant_for_sqrt_output};
const OperatorNode* sqrt_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_ELEMENT_WISE_SQRT, &sqrt_operator_desc, sqrt_inputs, label);
const NodeOutput* sqrt_output =
graph_builder.CreateNodeOutput(sqrt_node, sqrt_output_tensor_desc);
// Formula: x / sqrt(2)
const NodeOutput* input =
GetNodeOutputForOperand(id_to_node_output_map, gelu->input_operand_id);
const TensorDesc& input_tensor_desc = input->GetTensorDesc();
const std::vector<uint32_t>& input_dimensions =
input_tensor_desc.GetDimensions();
TensorDesc div_divisor_tensor_desc = sqrt_output->GetTensorDesc();
div_divisor_tensor_desc.BroadcastTo(input_dimensions);
uint64_t output_id = gelu->output_operand_id;
const auto output_tensor_desc =
CreateOutputTensorDesc(id_to_operand_map, output_id);
const TensorDesc& div_output_tensor_desc = output_tensor_desc;
std::array<const NodeOutput*, 2> div_inputs = {input, sqrt_output};
const OperatorNode* div_node =
CreateBinaryOperator<DML_ELEMENT_WISE_DIVIDE_OPERATOR_DESC>(
input_tensor_desc, div_divisor_tensor_desc, div_output_tensor_desc,
graph_builder, DML_OPERATOR_ELEMENT_WISE_DIVIDE, div_inputs, label);
const NodeOutput* div_output =
graph_builder.CreateNodeOutput(div_node, div_output_tensor_desc);
// Formula: erf(x / sqrt(2))
const TensorDesc& erf_output_tensor_desc = output_tensor_desc;
DML_ELEMENT_WISE_ERF_OPERATOR_DESC erf_operator_desc{
.InputTensor = &div_output->GetTensorDesc().GetDMLTensorDesc(),
.OutputTensor = &erf_output_tensor_desc.GetDMLTensorDesc(),
};
std::array<const NodeOutput*, 1> erf_inputs = {div_output};
const OperatorNode* erf_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_ELEMENT_WISE_ERF, &erf_operator_desc, erf_inputs, label);
const NodeOutput* erf_output =
graph_builder.CreateNodeOutput(erf_node, erf_output_tensor_desc);
// Build constant operand (1.0)
uint64_t constant_for_add_operand_id = BuildConstantOperandForFloatValue(
graph_info, next_operand_id, data_type, /*rank*/ 1,
/*default value*/ 1.0);
uint32_t constant_for_add_input_index =
CreateInputNode(id_to_operand_map, constant_for_add_operand_id,
graph_builder, id_to_node_output_map);
CHECK(constant_id_to_input_index_map
.try_emplace(constant_for_add_operand_id,
constant_for_add_input_index)
.second);
const NodeOutput* constant_for_add_output = GetNodeOutputForOperand(
id_to_node_output_map, constant_for_add_operand_id);
// Formula: 1 + erf(x / sqrt(2))
const TensorDesc& add_output_tensor_desc = output_tensor_desc;
TensorDesc constant_for_add_tensor_desc =
constant_for_add_output->GetTensorDesc();
constant_for_add_tensor_desc.BroadcastTo(input_dimensions);
std::array<const NodeOutput*, 2> add_inputs = {erf_output,
constant_for_add_output};
const OperatorNode* add_node =
CreateBinaryOperator<DML_ELEMENT_WISE_ADD_OPERATOR_DESC>(
erf_output_tensor_desc, constant_for_add_tensor_desc,
add_output_tensor_desc, graph_builder, DML_OPERATOR_ELEMENT_WISE_ADD,
add_inputs, label);
const NodeOutput* add_output =
graph_builder.CreateNodeOutput(add_node, add_output_tensor_desc);
// Formula: x * (1 + erf(x / sqrt(2)))
const TensorDesc& second_mul_output_tensor_desc = output_tensor_desc;
std::array<const NodeOutput*, 2> second_mul_inputs = {input, add_output};
const OperatorNode* second_mul_node =
CreateBinaryOperator<DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC>(
input_tensor_desc, add_output_tensor_desc,
second_mul_output_tensor_desc, graph_builder,
DML_OPERATOR_ELEMENT_WISE_MULTIPLY, second_mul_inputs, label);
const NodeOutput* second_mul_output = graph_builder.CreateNodeOutput(
second_mul_node, second_mul_output_tensor_desc);
// Build constant operand (0.5)
uint64_t constant_for_mul_operand_id = BuildConstantOperandForFloatValue(
graph_info, next_operand_id, data_type, /*rank*/ 1,
/*default value*/ 0.5);
uint32_t constant_for_mul_input_index =
CreateInputNode(id_to_operand_map, constant_for_mul_operand_id,
graph_builder, id_to_node_output_map);
CHECK(constant_id_to_input_index_map
.try_emplace(constant_for_mul_operand_id,
constant_for_mul_input_index)
.second);
const NodeOutput* constant_for_mul_output = GetNodeOutputForOperand(
id_to_node_output_map, constant_for_mul_operand_id);
// Formula: 0.5 * x * (1 + erf(x / sqrt(2)))
TensorDesc constant_for_mul_tensor_desc =
constant_for_mul_output->GetTensorDesc();
constant_for_mul_tensor_desc.BroadcastTo(input_dimensions);
std::array<const NodeOutput*, 2> mul_constant_inputs = {
second_mul_output, constant_for_mul_output};
const OperatorNode* mul_constant_node =
CreateBinaryOperator<DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC>(
second_mul_output_tensor_desc, constant_for_mul_tensor_desc,
output_tensor_desc, graph_builder, DML_OPERATOR_ELEMENT_WISE_MULTIPLY,
mul_constant_inputs, label);
const NodeOutput* node_output = graph_builder.CreateNodeOutput(
mul_constant_node, std::move(output_tensor_desc));
// The output id must be unique in the map.
CHECK(id_to_node_output_map.try_emplace(output_id, node_output).second);
}
// Creates a DirectML operator for the WebNN general matrix multiplication
// (GEMM) of the expression alpha * A * B + beta * C.
void CreateOperatorNodeForGemm(
const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const Operation* operation,
const std::map<const Operation*, const Operation*>&
operation_to_fusible_standalone_activation_map,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map) {
const auto& gemm = operation->get_gemm();
const NodeOutput* input_a_node_output =
GetNodeOutputForOperand(id_to_node_output_map, gemm->a_operand_id);
auto input_a_tensor_desc = input_a_node_output->GetTensorDesc();
CHECK(context_properties.data_type_limits.gemm_input.Has(
DmlDataTypeToOperand(input_a_tensor_desc.GetDataType())));
const NodeOutput* input_b_node_output =
GetNodeOutputForOperand(id_to_node_output_map, gemm->b_operand_id);
auto input_b_tensor_desc = input_b_node_output->GetTensorDesc();
std::vector<const NodeOutput*> inputs{input_a_node_output,
input_b_node_output};
uint64_t output_id = gemm->output_operand_id;
const auto output_tensor_desc =
CreateOutputTensorDesc(id_to_operand_map, output_id);
// The input c tensor description may be broadcasted.
std::optional<TensorDesc> input_c_tensor_desc;
auto& c_operand_id = gemm->c_operand_id;
if (c_operand_id) {
uint64_t input_c_id = c_operand_id.value();
const auto input_c_node_output_iterator =
id_to_node_output_map.find(input_c_id);
CHECK(input_c_node_output_iterator != id_to_node_output_map.end());
const NodeOutput* input_c_node_output =
input_c_node_output_iterator->second;
CHECK(input_c_node_output);
input_c_tensor_desc = input_c_node_output->GetTensorDesc();
// Ensure the graph edge for the c operand will be created.
inputs.push_back(input_c_node_output);
auto output_dimensions = output_tensor_desc.GetDimensions();
if (input_c_tensor_desc->GetDimensions() != output_dimensions) {
input_c_tensor_desc->BroadcastTo(output_dimensions);
}
}
// Use 4D GEMM which is available since feature level 1.0 for best
// compatibility. There is no performance difference in the shader between
// 2D/3D/4D, as 2D is just a variant of 4D with a batch/channel size of 1.
// https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_gemm_operator_desc.
// TODO(issues.chromium.org/327244277): Remove the workaround of coercing
// GEMM's tensors to 4D.
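// For example, with trailing alignment a 2-D A tensor of shape [M, K] and a
// 2-D B tensor of shape [K, N] are coerced to [1, 1, M, K] and [1, 1, K, N]
// respectively, computing the same matrix product with batch and channel
// sizes of 1.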
input_a_tensor_desc.EnsureMinimumRank(4, TensorDesc::Alignment::kTrailing);
input_b_tensor_desc.EnsureMinimumRank(4, TensorDesc::Alignment::kTrailing);
if (input_c_tensor_desc) {
input_c_tensor_desc->EnsureMinimumRank(4, TensorDesc::Alignment::kTrailing);
}
auto expanded_output_tensor_desc = output_tensor_desc;
expanded_output_tensor_desc.EnsureMinimumRank(
4, TensorDesc::Alignment::kTrailing);
std::optional<const Operation*> fusible_activation =
GetFusibleActivationFromOperation(
operation_to_fusible_standalone_activation_map, operation);
std::optional<ActivationOperatorDesc> activation_operator_desc;
std::optional<DML_OPERATOR_DESC> activation_dml_desc;
if (fusible_activation) {
activation_operator_desc =
CreateOperatorDescForFusibleActivation(*fusible_activation.value());
activation_dml_desc = activation_operator_desc->GetActivationDmlDesc();
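// Adopt the fused activation's output id so that downstream operations
// consume the fused result directly and no standalone activation node is
// created.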
output_id =
GetFusibleActivationOutputId(*fusible_activation.value()).value();
}
DML_GEMM_OPERATOR_DESC gemm_operator_desc{
.ATensor = &input_a_tensor_desc.GetDMLTensorDesc(),
.BTensor = &input_b_tensor_desc.GetDMLTensorDesc(),
.CTensor = GetOptionalDmlTensorDescPtr(input_c_tensor_desc),
.OutputTensor = &expanded_output_tensor_desc.GetDMLTensorDesc(),
.TransA = (gemm->a_transpose) ? DML_MATRIX_TRANSFORM_TRANSPOSE
: DML_MATRIX_TRANSFORM_NONE,
.TransB = (gemm->b_transpose) ? DML_MATRIX_TRANSFORM_TRANSPOSE
: DML_MATRIX_TRANSFORM_NONE,
.Alpha = gemm->alpha,
.Beta = gemm->beta,
.FusedActivation =
activation_dml_desc ? &activation_dml_desc.value() : nullptr,
};
const std::string& label = gemm->label;
const OperatorNode* gemm_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_GEMM, &gemm_operator_desc, inputs, label);
const NodeOutput* output = graph_builder.CreateNodeOutput(
gemm_node, std::move(output_tensor_desc), 0);
// The output id must be unique in the map.
CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
}
// This helper checks whether the input node output is a constant operand. If
// so, it appends an identity node to the input node output by calling
// `AppendIdentityNode`; otherwise it does nothing and returns `input`
// directly.
const NodeOutput* AppendIdentityToConstantOperand(
GraphBuilderDml& graph_builder,
const NodeOutput* input) {
CHECK(input);
// Do nothing if the input doesn't have the DML_TENSOR_FLAG_OWNED_BY_DML flag.
if (!(input->GetTensorDesc().GetFlags() & DML_TENSOR_FLAG_OWNED_BY_DML)) {
return input;
}
// Append an identity node if the input has the DML_TENSOR_FLAG_OWNED_BY_DML
// flag. Certain operators like lstm and gru don't support this flag on their
// input tensors, so an identity node is needed to remove it.
return AppendIdentityNode(graph_builder, input);
}
// `GruType` must be `mojom::GruPtr` or `mojom::GruCellPtr`.
template <typename GruType>
base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForGru(
const IdToOperandMap& id_to_operand_map,
const GruType& gru,
mojom::GraphInfoPtr& graph_info,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map,
std::unordered_map<uint64_t, uint32_t>& constant_id_to_input_index_map,
uint64_t& next_operand_id) {
static_assert(std::is_same<GruType, mojom::GruPtr>::value ||
std::is_same<GruType, mojom::GruCellPtr>::value);
mojom::Operation::Tag op_tag;
std::optional<uint64_t> initial_hidden_state_operand_id;
bool return_sequence;
mojom::RecurrentNetworkDirection direction;
if constexpr (std::is_same<GruType, mojom::GruPtr>::value) {
op_tag = mojom::Operation::Tag::kGru;
initial_hidden_state_operand_id = gru->initial_hidden_state_operand_id;
return_sequence = gru->return_sequence;
direction = gru->direction;
  } else /* GruType is mojom::GruCellPtr */ {
op_tag = mojom::Operation::Tag::kGruCell;
initial_hidden_state_operand_id = gru->hidden_state_operand_id;
return_sequence = false;
direction = mojom::RecurrentNetworkDirection::kForward;
}
const NodeOutput* input =
GetNodeOutputForOperand(id_to_node_output_map, gru->input_operand_id);
// Since the InputTensor doesn't support the DML_TENSOR_FLAG_OWNED_BY_DML
// flag, add an identity operator to change the input type:
// https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_gru_operator_desc
input = AppendIdentityToConstantOperand(graph_builder, input);
TensorDesc input_tensor_desc = input->GetTensorDesc();
// The input tensor is 3-D for gru and 2-D for gruCell, while DirectML expects
// a 4-D tensor.
input_tensor_desc.EnsureMinimumRank(/*rank=*/4,
TensorDesc::Alignment::kTrailing);
const NodeOutput* weight =
GetNodeOutputForOperand(id_to_node_output_map, gru->weight_operand_id);
// Since the WeightTensor doesn't support the DML_TENSOR_FLAG_OWNED_BY_DML
// flag, add an identity operator to change the input type:
// https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_gru_operator_desc
weight = AppendIdentityToConstantOperand(graph_builder, weight);
TensorDesc weight_tensor_desc = weight->GetTensorDesc();
// The weight tensor is 3-D for gru and 2-D for gruCell, while DirectML
// expects a 4-D tensor.
weight_tensor_desc.EnsureMinimumRank(/*rank*/ 4,
TensorDesc::Alignment::kTrailing);
const NodeOutput* recurrent_weight = GetNodeOutputForOperand(
id_to_node_output_map, gru->recurrent_weight_operand_id);
// Since the RecurrenceTensor doesn't support the DML_TENSOR_FLAG_OWNED_BY_DML
// flag, add an identity operator to change the input type:
// https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_gru_operator_desc
recurrent_weight =
AppendIdentityToConstantOperand(graph_builder, recurrent_weight);
TensorDesc recurrent_weight_tensor_desc = recurrent_weight->GetTensorDesc();
// The recurrent weight tensor is 3-D for gru and 2-D for gruCell, while
// DirectML expects a 4-D tensor.
recurrent_weight_tensor_desc.EnsureMinimumRank(
/*rank=*/4, TensorDesc::Alignment::kTrailing);
std::vector<const NodeOutput*> inputs{input, weight, recurrent_weight};
const OperandPtr& input_operand = id_to_operand_map.at(gru->input_operand_id);
const OperandDataType data_type = input_operand->descriptor.data_type();
const std::string& label = gru->label;
std::optional<TensorDesc> concatenated_bias_tensor_desc;
if (!gru->bias_operand_id.has_value() &&
!gru->recurrent_bias_operand_id.has_value()) {
// Use a nullptr to indicate there is no input edge for BiasTensor.
inputs.push_back(nullptr);
} else {
// The DirectML bias tensor is the concatenation of bias and recurrent bias.
// Get or create the node output of bias and recurrent bias for the following
// concat operation.
std::optional<const NodeOutput*> zero_bias;
if (!gru->bias_operand_id.has_value() ||
!gru->recurrent_bias_operand_id.has_value()) {
uint64_t zero_bias_operand_id = BuildConstantOperandForFloatValue(
graph_info, next_operand_id, data_type, /*rank*/ 1,
/*default bias*/ 0);
uint32_t bias_input_index =
CreateInputNode(id_to_operand_map, zero_bias_operand_id,
graph_builder, id_to_node_output_map);
CHECK(constant_id_to_input_index_map
.try_emplace(zero_bias_operand_id, bias_input_index)
.second);
zero_bias =
GetNodeOutputForOperand(id_to_node_output_map, zero_bias_operand_id);
}
const NodeOutput* bias =
gru->bias_operand_id.has_value()
? GetOptionalNodeOutputForOperand(id_to_node_output_map,
gru->bias_operand_id)
: zero_bias.value();
const NodeOutput* recurrent_bias =
gru->recurrent_bias_operand_id.has_value()
? GetOptionalNodeOutputForOperand(id_to_node_output_map,
gru->recurrent_bias_operand_id)
: zero_bias.value();
const uint32_t num_directions =
direction == mojom::RecurrentNetworkDirection::kBoth ? 2 : 1;
uint32_t hidden_size = gru->hidden_size;
// 3 * hidden_size has already been validated, so this checked multiplication
// is expected to succeed.
auto checked_three_times_hidden_size =
base::MakeCheckedNum(hidden_size) * 3;
CHECK(checked_three_times_hidden_size.IsValid());
// The half bias dimensions are [1, 1, num_directions, 3 * hidden_size] for
// gru and [1, 1, 1, 3 * hidden_size] for gruCell.
const std::array<uint32_t, 4> half_bias_dimensions = {
1, 1, num_directions, checked_three_times_hidden_size.ValueOrDie()};
TensorDesc bias_tensor_desc = bias->GetTensorDesc();
// The bias tensor shape is [1] (the zero-bias constant), [3 * hidden_size]
// for gruCell, or [num_directions, 3 * hidden_size] for gru, all of which
// can be broadcast to [1, 1, num_directions, 3 * hidden_size] as DirectML
// requires.
bias_tensor_desc.BroadcastTo(half_bias_dimensions);
TensorDesc recurrent_bias_tensor_desc = recurrent_bias->GetTensorDesc();
recurrent_bias_tensor_desc.BroadcastTo(half_bias_dimensions);
std::array<DML_TENSOR_DESC, 2> concat_input_tensor_descs = {
bias_tensor_desc.GetDMLTensorDesc(),
recurrent_bias_tensor_desc.GetDMLTensorDesc()};
// The DirectML bias dimensions are [1, 1, num_directions, 6 * hidden_size].
// Ideally, 6 * hidden_size validation should be part of the spec and
// validated for all backends. Spec issue tracked on
// https://github.com/webmachinelearning/webnn/issues/625.
auto checked_six_times_hidden_size = base::MakeCheckedNum(hidden_size) * 6;
if (!checked_six_times_hidden_size.IsValid()) {
return CreateUnexpectedError(
mojom::Error::Code::kUnknownError,
base::StringPrintf("The hidden size is too large for %s operator.",
OpTagToString(op_tag).c_str()),
label);
}
std::vector<uint32_t> concatenated_bias_dimensions = {
1, 1, num_directions, checked_six_times_hidden_size.ValueOrDie()};
concatenated_bias_tensor_desc = TensorDesc(
GetTensorDataType(data_type), std::move(concatenated_bias_dimensions));
DML_JOIN_OPERATOR_DESC concat_operator_desc{
.InputCount = concat_input_tensor_descs.size(),
.InputTensors = concat_input_tensor_descs.data(),
.OutputTensor = &concatenated_bias_tensor_desc->GetDMLTensorDesc(),
.Axis = 3};
std::array<const NodeOutput*, 2> bias_outputs = {bias, recurrent_bias};
const OperatorNode* concat_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_JOIN, &concat_operator_desc, bias_outputs, label);
const NodeOutput* concatenated_bias = graph_builder.CreateNodeOutput(
concat_node, concatenated_bias_tensor_desc.value(), 0);
inputs.push_back(concatenated_bias);
}
std::optional<TensorDesc> initial_hidden_state_tensor_desc;
if (initial_hidden_state_operand_id.has_value()) {
const NodeOutput* initial_hidden_state = GetNodeOutputForOperand(
id_to_node_output_map, initial_hidden_state_operand_id.value());
// Since the HiddenInitTensor doesn't support the
// DML_TENSOR_FLAG_OWNED_BY_DML flag, add an identity operator to change the
// input type:
// https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_gru_operator_desc
initial_hidden_state =
AppendIdentityToConstantOperand(graph_builder, initial_hidden_state);
initial_hidden_state_tensor_desc = initial_hidden_state->GetTensorDesc();
// The initial hidden state tensor shape is `[num_directions, batch_size,
// hidden_size]`, while DirectML expects the shape to be `[1,
// num_directions, batch_size, hidden_size]`.
initial_hidden_state_tensor_desc->EnsureMinimumRank(
/*rank*/ 4, TensorDesc::Alignment::kTrailing);
inputs.push_back(initial_hidden_state);
} else {
// Use a nullptr to indicate there is no input edge for HiddenInitTensor.
inputs.push_back(nullptr);
}
// Use a nullptr to indicate all sequences in the batch have length
// seq_length:
// https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_gru_operator_desc
inputs.push_back(nullptr);
std::vector<uint64_t> output_ids;
uint64_t output_hidden_state_id;
if constexpr (std::is_same<GruType, mojom::GruPtr>::value) {
output_ids = gru->output_operand_ids;
output_hidden_state_id = output_ids[0];
} else {
output_hidden_state_id = gru->output_operand_id;
}
TensorDesc output_hidden_state_tensor_desc =
CreateOutputTensorDesc(id_to_operand_map, output_hidden_state_id);
// The output hidden state tensor is 3-D for gru and 2-D for gruCell, while
// DirectML expects a 4-D tensor.
output_hidden_state_tensor_desc.EnsureMinimumRank(
/*rank*/ 4, TensorDesc::Alignment::kTrailing);
std::optional<uint64_t> output_sequence_id;
std::optional<TensorDesc> output_sequence_tensor_desc;
if (return_sequence) {
CHECK_EQ(output_ids.size(), 2u);
output_sequence_id = output_ids[1];
output_sequence_tensor_desc =
CreateOutputTensorDesc(id_to_operand_map, output_sequence_id.value());
}
if (gru->layout != mojom::GruWeightLayout::kZrn) {
return CreateUnexpectedError(
mojom::Error::Code::kNotSupportedError,
"The gru weight layout (rzn) is not supported.", label);
}
// When the recurrent network is bidirectional, dual activations must be
// provided for the forward and backward directions.
const size_t number_of_activations =
direction == mojom::RecurrentNetworkDirection::kBoth
? gru->activations.size() * 2
: gru->activations.size();
std::vector<ActivationOperatorDesc> activation_operator_descs;
activation_operator_descs.reserve(number_of_activations);
for (mojom::RecurrentNetworkActivation activation : gru->activations) {
activation_operator_descs.push_back(
CreateOperatorDescForActivation(activation));
}
// For bidirectional, the activations must be provided as f() and g() for the
// forward direction followed by f() and g() for the backward direction.
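// For example, the activations {sigmoid, tanh} become
// {sigmoid, tanh, sigmoid, tanh}.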
if (direction == mojom::RecurrentNetworkDirection::kBoth) {
base::ranges::copy(activation_operator_descs,
std::back_inserter(activation_operator_descs));
}
std::vector<DML_OPERATOR_DESC> activation_dml_descs;
activation_dml_descs.reserve(activation_operator_descs.size());
base::ranges::transform(
activation_operator_descs, std::back_inserter(activation_dml_descs),
[](const auto& activation_operator_desc) {
return activation_operator_desc.GetActivationDmlDesc();
});
DML_GRU_OPERATOR_DESC gru_desc{
.InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
.WeightTensor = &weight_tensor_desc.GetDMLTensorDesc(),
.RecurrenceTensor = &recurrent_weight_tensor_desc.GetDMLTensorDesc(),
.BiasTensor = GetOptionalDmlTensorDescPtr(concatenated_bias_tensor_desc),
.HiddenInitTensor =
GetOptionalDmlTensorDescPtr(initial_hidden_state_tensor_desc),
.SequenceLengthsTensor = nullptr,
.OutputSequenceTensor =
GetOptionalDmlTensorDescPtr(output_sequence_tensor_desc),
.OutputSingleTensor = &output_hidden_state_tensor_desc.GetDMLTensorDesc(),
.ActivationDescCount = static_cast<uint32_t>(activation_dml_descs.size()),
.ActivationDescs = activation_dml_descs.data(),
.Direction = MojoRecurrentNetworkDirectionToDml(direction),
.LinearBeforeReset = gru->reset_after};
const OperatorNode* gru_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_GRU, &gru_desc, inputs, label);
const NodeOutput* output_hidden_state = graph_builder.CreateNodeOutput(
gru_node, output_hidden_state_tensor_desc, /*output_index*/ 1);
CHECK(id_to_node_output_map
.try_emplace(output_hidden_state_id, output_hidden_state)
.second);
if (return_sequence) {
const NodeOutput* output_sequence = graph_builder.CreateNodeOutput(
gru_node, output_sequence_tensor_desc.value(), /*output_index*/ 0);
CHECK(id_to_node_output_map
.try_emplace(output_sequence_id.value(), output_sequence)
.second);
}
return base::ok();
}
void CreateOperatorNodeForHardSigmoid(
const IdToOperandMap& id_to_operand_map,
const mojom::HardSigmoidPtr& hard_sigmoid,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map) {
const NodeOutput* input = GetNodeOutputForOperand(
id_to_node_output_map, hard_sigmoid->input_operand_id);
const auto& input_tensor_desc = input->GetTensorDesc();
const uint64_t output_id = hard_sigmoid->output_operand_id;
auto output_tensor_desc =
CreateOutputTensorDesc(id_to_operand_map, output_id);
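// DML_ACTIVATION_HARD_SIGMOID computes f(x) = max(0, min(1, alpha * x +
// beta)) elementwise, which matches the WebNN hardSigmoid definition.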
DML_ACTIVATION_HARD_SIGMOID_OPERATOR_DESC hard_sigmoid_desc{
.InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
.OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
.Alpha = hard_sigmoid->alpha,
.Beta = hard_sigmoid->beta};
std::array<const NodeOutput*, 1> inputs = {input};
const OperatorNode* hard_sigmoid_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_ACTIVATION_HARD_SIGMOID, &hard_sigmoid_desc, inputs,
hard_sigmoid->label);
const NodeOutput* node_output = graph_builder.CreateNodeOutput(
hard_sigmoid_node, std::move(output_tensor_desc));
// The output id must be unique in the map.
CHECK(id_to_node_output_map.try_emplace(output_id, node_output).second);
}
void CreateOperatorNodeForHardSwish(Adapter* adapter,
const IdToOperandMap& id_to_operand_map,
const mojom::HardSwishPtr& hard_swish,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map) {
const NodeOutput* input = GetNodeOutputForOperand(
id_to_node_output_map, hard_swish->input_operand_id);
const auto& input_tensor_desc = input->GetTensorDesc();
const uint64_t output_id = hard_swish->output_operand_id;
auto output_tensor_desc =
CreateOutputTensorDesc(id_to_operand_map, output_id);
const float scale = 1.0 / 6.0;
const float bias = 0.5;
const std::string& label = hard_swish->label;
if (adapter->IsDMLFeatureLevelSupported(DML_FEATURE_LEVEL_6_2)) {
std::array<const NodeOutput*, 1> inputs = {input};
DML_ACTIVATION_HARD_SWISH_OPERATOR_DESC hard_swish_desc{
.InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
.OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
.Alpha = scale,
.Beta = bias};
const OperatorNode* hard_swish_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_ACTIVATION_HARD_SWISH, &hard_swish_desc, inputs, label);
const NodeOutput* output =
graph_builder.CreateNodeOutput(hard_swish_node, output_tensor_desc);
// The output id must be unique in the map.
CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
return;
}
// If DirectML's feature level is before 6.2, we need to implement hardSwish
// by composing it from smaller operators:
// Output = input * clamp((input / 6) + 0.5, 0, 1).
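// For example, for an input element x = 3: 3 * clamp(3 / 6 + 0.5, 0, 1) =
// 3 * 1 = 3; and for x = -3: -3 * clamp(-3 / 6 + 0.5, 0, 1) = -3 * 0 = 0.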
// First step: build `clamp((x / 6) + 0.5, 0, 1)`.
DML_SCALE_BIAS scale_bias = {.Scale = scale, .Bias = bias};
DML_ELEMENT_WISE_CLIP_OPERATOR_DESC clamp_operator_desc{
.InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
.OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
// Applying the function `g(x) = x / 6 + 0.5` to each input element
// prior to clamp.
.ScaleBias = &scale_bias,
.Min = 0,
.Max = 1};
std::array<const NodeOutput*, 1> clamp_inputs = {input};
const OperatorNode* clamp_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_ELEMENT_WISE_CLIP, &clamp_operator_desc, clamp_inputs,
label);
const NodeOutput* clamp_output =
graph_builder.CreateNodeOutput(clamp_node, output_tensor_desc, 0);
const auto& clamp_output_tensor_desc = clamp_output->GetTensorDesc();
// Second step: build `x * first_step`.
std::array<const NodeOutput*, 2> mul_inputs = {input, clamp_output};
DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC binary_mul_desc{
.ATensor = &input_tensor_desc.GetDMLTensorDesc(),
.BTensor = &clamp_output_tensor_desc.GetDMLTensorDesc(),
.OutputTensor = &output_tensor_desc.GetDMLTensorDesc()};
const OperatorNode* binary_mul_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &binary_mul_desc, mul_inputs, label);
const NodeOutput* output =
graph_builder.CreateNodeOutput(binary_mul_node, output_tensor_desc);
// The output id must be unique in the map.
CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
}
template <typename NormalizationPtr>
base::expected<void, mojom::ErrorPtr>
CreateOperatorNodeForMeanVarianceNormalization(
const NormalizationPtr& normalization,
const Operation* operation,
const std::map<const Operation*, const Operation*>&
operation_to_fusible_standalone_activation_map,
mojom::GraphInfoPtr& graph_info,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map,
std::unordered_map<uint64_t, uint32_t>& constant_id_to_input_index_map,
uint64_t& next_operand_id,
base::span<const uint32_t> mean_variance_axes,
base::span<const uint32_t> scale_bias_broadcast_axes,
mojom::Operation::Tag op) {
const NodeOutput* input = GetNodeOutputForOperand(
id_to_node_output_map, normalization->input_operand_id);
const auto& input_tensor_desc = input->GetTensorDesc();
size_t input_rank = input_tensor_desc.GetDimensions().size();
auto& id_to_operand_map = graph_info->id_to_operand_map;
uint64_t output_id = normalization->output_operand_id;
const OperandPtr& output_operand = id_to_operand_map.at(output_id);
OperandDataType data_type = output_operand->descriptor.data_type();
const TensorDesc output_tensor_desc(GetTensorDataType(data_type),
output_operand->descriptor.shape());
const NodeOutput* scale = GetOptionalNodeOutputForOperand(
id_to_node_output_map, normalization->scale_operand_id);
const NodeOutput* bias = GetOptionalNodeOutputForOperand(
id_to_node_output_map, normalization->bias_operand_id);
// DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC requires `ScaleTensor` and
// `BiasTensor` to be both present or not present when DML_FEATURE_LEVEL is
// less than DML_FEATURE_LEVEL_5_2.
// https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_mean_variance_normalization1_operator_desc.
//
// If one of scale/bias is not present, create a constant operand for it and
// insert the operand into the graph.
if ((scale && !bias) || (!scale && bias)) {
if (!scale) {
uint64_t scale_operand_id = BuildConstantOperandForFloatValue(
graph_info, next_operand_id, data_type,
scale_bias_broadcast_axes.size(),
/*default scale*/ 1.0);
// Create an input node for the scale operand and store the assigned input
// index in `constant_id_to_input_index_map`, which will be used for
// constant buffer binding.
uint32_t scale_input_index =
CreateInputNode(id_to_operand_map, scale_operand_id, graph_builder,
id_to_node_output_map);
CHECK(constant_id_to_input_index_map
.try_emplace(scale_operand_id, scale_input_index)
.second);
scale = GetNodeOutputForOperand(id_to_node_output_map, scale_operand_id);
}
if (!bias) {
uint64_t bias_operand_id = BuildConstantOperandForFloatValue(
graph_info, next_operand_id, data_type,
scale_bias_broadcast_axes.size(),
/*default bias*/ 0);
// Create an input node for the bias operand and store the assigned input
// index in `constant_id_to_input_index_map`, which will be used for
// constant buffer binding.
uint32_t bias_input_index =
CreateInputNode(id_to_operand_map, bias_operand_id, graph_builder,
id_to_node_output_map);
CHECK(constant_id_to_input_index_map
.try_emplace(bias_operand_id, bias_input_index)
.second);
bias = GetNodeOutputForOperand(id_to_node_output_map, bias_operand_id);
}
}
const std::string& label = normalization->label;
if (!base::MakeCheckedNum(mean_variance_axes.size()).IsValid<uint32_t>()) {
return base::unexpected(CreateError(
mojom::Error::Code::kUnknownError,
OpTagToString(op) + ": The axes rank is too large.", label));
}
std::vector<const NodeOutput*> inputs = {input};
std::optional<TensorDesc> scale_tensor_desc;
std::optional<TensorDesc> bias_tensor_desc;
if (scale) {
inputs.push_back(scale);
scale_tensor_desc = scale->GetTensorDesc();
// The scale tensor should have the same rank as the input tensor, as
// required by DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC.
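// For example, assuming `MakeBroadcastCompatible` places the scale
// dimensions at `scale_bias_broadcast_axes` and sets the remaining axes to
// size 1, a normalization over axes {2, 3} of a 4-D input turns a scale of
// shape [H, W] into [1, 1, H, W].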
scale_tensor_desc->MakeBroadcastCompatible(input_rank,
scale_bias_broadcast_axes);
}
if (bias) {
inputs.push_back(bias);
bias_tensor_desc = bias->GetTensorDesc();
// The bias tensor should have the same rank as the input tensor, as required
// by DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC.
bias_tensor_desc->MakeBroadcastCompatible(input_rank,
scale_bias_broadcast_axes);
}
std::optional<const Operation*> fusible_activation =
GetFusibleActivationFromOperation(
operation_to_fusible_standalone_activation_map, operation);
std::optional<ActivationOperatorDesc> activation_operator_desc;
std::optional<DML_OPERATOR_DESC> activation_dml_desc;
if (fusible_activation) {
activation_operator_desc =
CreateOperatorDescForFusibleActivation(*fusible_activation.value());
activation_dml_desc = activation_operator_desc->GetActivationDmlDesc();
output_id =
GetFusibleActivationOutputId(*fusible_activation.value()).value();
}
DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC
normalization_operator_desc{
.InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
.ScaleTensor = GetOptionalDmlTensorDescPtr(scale_tensor_desc),
.BiasTensor = GetOptionalDmlTensorDescPtr(bias_tensor_desc),
.OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
.AxisCount = base::checked_cast<uint32_t>(mean_variance_axes.size()),
.Axes = mean_variance_axes.data(),
// Both layer normalization and instance normalization include variance
// normalization.
.NormalizeVariance = true,
.Epsilon = normalization->epsilon,
.FusedActivation =
activation_dml_desc ? &activation_dml_desc.value() : nullptr,
};
const OperatorNode* normalization_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1, &normalization_operator_desc,
inputs, label);
const NodeOutput* output = graph_builder.CreateNodeOutput(
normalization_node, std::move(output_tensor_desc));
// The output id must be unique in the map.
CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
return base::ok();
}
void CreateOperatorNodeForLeakyRelu(const IdToOperandMap& id_to_operand_map,
const mojom::LeakyReluPtr& leaky_relu,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map) {
const NodeOutput* input = GetNodeOutputForOperand(
id_to_node_output_map, leaky_relu->input_operand_id);
const auto& input_tensor_desc = input->GetTensorDesc();
uint64_t output_id = leaky_relu->output_operand_id;
const auto output_tensor_desc =
CreateOutputTensorDesc(id_to_operand_map, output_id);
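// DML_ACTIVATION_LEAKY_RELU computes f(x) = x if x >= 0, otherwise
// alpha * x, matching the WebNN leakyRelu definition.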
DML_ACTIVATION_LEAKY_RELU_OPERATOR_DESC leaky_relu_desc{
.InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
.OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
.Alpha = leaky_relu->alpha};
std::array<const NodeOutput*, 1> inputs = {input};
const OperatorNode* leaky_relu_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_ACTIVATION_LEAKY_RELU, &leaky_relu_desc, inputs,
leaky_relu->label);
const NodeOutput* node_output = graph_builder.CreateNodeOutput(
leaky_relu_node, std::move(output_tensor_desc));
// The output id must be unique in the map.
CHECK(id_to_node_output_map.try_emplace(output_id, node_output).second);
}
void CreateOperatorNodeForLinear(const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::LinearPtr& linear,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map) {
const NodeOutput* input =
GetNodeOutputForOperand(id_to_node_output_map, linear->input_operand_id);
const auto& input_tensor_desc = input->GetTensorDesc();
CHECK(context_properties.data_type_limits.linear_input.Has(
DmlDataTypeToOperand(input_tensor_desc.GetDataType())));
uint64_t output_id = linear->output_operand_id;
auto output_tensor_desc =
CreateOutputTensorDesc(id_to_operand_map, output_id);
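// DML_ACTIVATION_LINEAR computes f(x) = alpha * x + beta elementwise,
// matching the WebNN linear definition.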
DML_ACTIVATION_LINEAR_OPERATOR_DESC linear_desc{
.InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
.OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
.Alpha = linear->alpha,
.Beta = linear->beta};
std::array<const NodeOutput*, 1> inputs = {input};
const OperatorNode* linear_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_ACTIVATION_LINEAR, &linear_desc, inputs, linear->label);
const NodeOutput* node_output = graph_builder.CreateNodeOutput(
linear_node, std::move(output_tensor_desc));
// The output id must be unique in the map.
CHECK(id_to_node_output_map.try_emplace(output_id, node_output).second);
}
// `LstmType` must be `mojom::Lstm` or `mojom::LstmCell`.
template <typename LstmType>
base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForLstm(
const LstmType& lstm,
mojom::GraphInfoPtr& graph_info,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map,
std::unordered_map<uint64_t, uint32_t>& constant_id_to_input_index_map,
uint64_t& next_operand_id) {
static_assert(std::is_same<LstmType, mojom::Lstm>::value ||
std::is_same<LstmType, mojom::LstmCell>::value);
const std::string& label = lstm.label;
// TODO(crbug.com/329702350): Support the ifgo layout.
if (lstm.layout == mojom::LstmWeightLayout::kIfgo) {
return CreateUnexpectedError(
mojom::Error::Code::kNotSupportedError,
"The lstm weight layout (ifgo) is not supported.", label);
}
mojom::Operation::Tag op_tag;
std::optional<uint64_t> initial_hidden_state_operand_id;
std::optional<uint64_t> initial_cell_state_operand_id;
bool return_sequence;
mojom::RecurrentNetworkDirection direction;
if constexpr (std::is_same<LstmType, mojom::Lstm>::value) {
op_tag = mojom::Operation::Tag::kLstm;
initial_hidden_state_operand_id = lstm.initial_hidden_state_operand_id;
initial_cell_state_operand_id = lstm.initial_cell_state_operand_id;
return_sequence = lstm.return_sequence;
direction = lstm.direction;
} else /* `LstmType` is `mojom::LstmCell` */ {
op_tag = mojom::Operation::Tag::kLstmCell;
initial_hidden_state_operand_id = lstm.hidden_state_operand_id;
initial_cell_state_operand_id = lstm.cell_state_operand_id;
return_sequence = false;
direction = mojom::RecurrentNetworkDirection::kForward;
}
const NodeOutput* input =
GetNodeOutputForOperand(id_to_node_output_map, lstm.input_operand_id);
// Append an identity node if the input is a constant operand since
// InputTensor doesn't support the DML_TENSOR_FLAG_OWNED_BY_DML flag.
// https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_lstm_operator_desc
input = AppendIdentityToConstantOperand(graph_builder, input);
TensorDesc input_tensor_desc = input->GetTensorDesc();
// The input tensor is 2-D for lstmCell and 3-D for lstm, while DirectML
// expects a 4-D tensor.
input_tensor_desc.EnsureMinimumRank(/*rank=*/4,
TensorDesc::Alignment::kTrailing);
const NodeOutput* weight =
GetNodeOutputForOperand(id_to_node_output_map, lstm.weight_operand_id);
// Append an identity node if the weight is a constant operand since
// WeightTensor doesn't support the DML_TENSOR_FLAG_OWNED_BY_DML flag.
weight = AppendIdentityToConstantOperand(graph_builder, weight);
TensorDesc weight_tensor_desc = weight->GetTensorDesc();
// The weight tensor is 2-D for lstmCell and 3-D for lstm, while DirectML
// expects a 4-D tensor.
weight_tensor_desc.EnsureMinimumRank(/*rank=*/4,
TensorDesc::Alignment::kTrailing);
const NodeOutput* recurrent_weight = GetNodeOutputForOperand(
id_to_node_output_map, lstm.recurrent_weight_operand_id);
// Append an identity node if the recurrent weight is a constant operand since
// RecurrenceTensor doesn't support the DML_TENSOR_FLAG_OWNED_BY_DML flag.
recurrent_weight =
AppendIdentityToConstantOperand(graph_builder, recurrent_weight);
TensorDesc recurrent_weight_tensor_desc = recurrent_weight->GetTensorDesc();
// The recurrent weight tensor is 2-D for lstmCell and 3-D for lstm, while
// DirectML expects a 4-D tensor.
recurrent_weight_tensor_desc.EnsureMinimumRank(
/*rank=*/4, TensorDesc::Alignment::kTrailing);
IdToOperandMap& id_to_operand_map = graph_info->id_to_operand_map;
const std::vector<uint64_t>& output_ids = lstm.output_operand_ids;
const size_t output_count = output_ids.size();
CHECK_GE(output_count, 2u);
const uint64_t output_hidden_state_id = output_ids[0];
const OperandPtr& output_hidden_state_operand =
id_to_operand_map.at(output_hidden_state_id);
const OperandDataType output_data_type =
output_hidden_state_operand->descriptor.data_type();
TensorDesc output_hidden_state_tensor_desc(
GetTensorDataType(output_data_type),
output_hidden_state_operand->descriptor.shape());
// The output hidden state tensor is 2-D for lstmCell and 3-D for lstm, while
// DirectML expects a 4-D tensor.
output_hidden_state_tensor_desc.EnsureMinimumRank(
/*rank=*/4, TensorDesc::Alignment::kTrailing);
const uint64_t output_cell_state_id = output_ids[1];
TensorDesc output_cell_state_tensor_desc =
CreateOutputTensorDesc(id_to_operand_map, output_cell_state_id);
// The output cell state tensor is 2-D for lstmCell and 3-D for lstm, while
// DirectML expects a 4-D tensor.
output_cell_state_tensor_desc.EnsureMinimumRank(
/*rank=*/4, TensorDesc::Alignment::kTrailing);
std::optional<uint64_t> output_sequence_id;
std::optional<TensorDesc> output_sequence_tensor_desc;
if (return_sequence) {
CHECK_EQ(output_count, 3u);
output_sequence_id = output_ids[2];
output_sequence_tensor_desc =
CreateOutputTensorDesc(id_to_operand_map, output_sequence_id.value());
}
std::vector<const NodeOutput*> inputs{input, weight, recurrent_weight};
const NodeOutput* bias = GetOptionalNodeOutputForOperand(
id_to_node_output_map, lstm.bias_operand_id);
const NodeOutput* recurrent_bias = GetOptionalNodeOutputForOperand(
id_to_node_output_map, lstm.recurrent_bias_operand_id);
// DML_LSTM_OPERATOR_DESC only takes a concatenation of {bias, recurrent_bias}
// or none, so create a constant bias operand if one of the biases is not
// given.
if ((bias && !recurrent_bias) || (!bias && recurrent_bias)) {
uint64_t bias_operand_id = BuildConstantOperandForFloatValue(
graph_info, next_operand_id, output_data_type,
/*rank=*/1, /*default bias=*/0);
// Create an input node for the bias operand and store the assigned input
// index in `constant_id_to_input_index_map`, which will be used for
// constant buffer binding.
uint32_t bias_input_index =
CreateInputNode(id_to_operand_map, bias_operand_id, graph_builder,
id_to_node_output_map);
CHECK(constant_id_to_input_index_map
.try_emplace(bias_operand_id, bias_input_index)
.second);
if (!bias) {
bias = GetNodeOutputForOperand(id_to_node_output_map, bias_operand_id);
}
if (!recurrent_bias) {
recurrent_bias =
GetNodeOutputForOperand(id_to_node_output_map, bias_operand_id);
}
}
// Bias operands should be either both present or both absent.
CHECK((bias && recurrent_bias) || (!bias && !recurrent_bias));
// Concatenate the bias operands if they are both present.
std::optional<TensorDesc> concatenated_bias_tensor_desc;
if (bias && recurrent_bias) {
const uint32_t direction_count =
direction == mojom::RecurrentNetworkDirection::kBoth ? 2 : 1;
auto checked_four_times_hidden_size =
base::MakeCheckedNum(lstm.hidden_size) * 4;
// Four times hidden size should have already been validated.
CHECK(checked_four_times_hidden_size.IsValid());
const std::vector<uint32_t> bias_dimensions = {
1, 1, direction_count, checked_four_times_hidden_size.ValueOrDie()};
// The bias tensor shape is [1], [4 * hidden_size], or [direction_count,
// 4 * hidden_size], which can be broadcast to [1, 1, direction_count,
// 4 * hidden_size] as DirectML requires.
TensorDesc bias_tensor_desc = bias->GetTensorDesc();
bias_tensor_desc.BroadcastTo(bias_dimensions);
TensorDesc recurrent_bias_tensor_desc = recurrent_bias->GetTensorDesc();
recurrent_bias_tensor_desc.BroadcastTo(bias_dimensions);
std::array<DML_TENSOR_DESC, 2> bias_dml_tensor_descs = {
bias_tensor_desc.GetDMLTensorDesc(),
recurrent_bias_tensor_desc.GetDMLTensorDesc()};
auto checked_eight_times_hidden_size = checked_four_times_hidden_size * 2;
if (!checked_eight_times_hidden_size.IsValid()) {
return CreateUnexpectedError(
mojom::Error::Code::kUnknownError,
base::StringPrintf("The hidden size is too large for %s operator.",
OpTagToString(op_tag).c_str()),
label);
}
// The concatenated bias dimensions are [1, 1, direction_count, 8 *
// hidden_size].
std::vector<uint32_t> concatenated_dimensions = {
1, 1, direction_count, checked_eight_times_hidden_size.ValueOrDie()};
concatenated_bias_tensor_desc =
TensorDesc(GetTensorDataType(output_data_type),
std::move(concatenated_dimensions));
DML_JOIN_OPERATOR_DESC concat_operator_desc{
.InputCount = static_cast<uint32_t>(bias_dml_tensor_descs.size()),
.InputTensors = bias_dml_tensor_descs.data(),
.OutputTensor = &concatenated_bias_tensor_desc->GetDMLTensorDesc(),
.Axis = 3};
std::array<const NodeOutput*, 2> biases = {bias, recurrent_bias};
const OperatorNode* concat_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_JOIN, &concat_operator_desc, biases, label);
const NodeOutput* concatenated_bias = graph_builder.CreateNodeOutput(
concat_node, concatenated_bias_tensor_desc.value(), 0);
inputs.push_back(concatenated_bias);
} else {
// Use a nullptr to indicate there is no input edge for BiasTensor.
inputs.push_back(nullptr);
}
std::optional<TensorDesc> initial_hidden_state_tensor_desc;
if (initial_hidden_state_operand_id.has_value()) {
const NodeOutput* initial_hidden_state = GetNodeOutputForOperand(
id_to_node_output_map, initial_hidden_state_operand_id.value());
// Append an identity node if the initial hidden state is a constant operand
// since HiddenInitTensor doesn't support the DML_TENSOR_FLAG_OWNED_BY_DML
// flag.
initial_hidden_state =
AppendIdentityToConstantOperand(graph_builder, initial_hidden_state);
inputs.push_back(initial_hidden_state);
initial_hidden_state_tensor_desc = initial_hidden_state->GetTensorDesc();
// The initial hidden state tensor is 2-D for lstmCell and 3-D for lstm,
// while DirectML expects a 4-D tensor.
initial_hidden_state_tensor_desc->EnsureMinimumRank(
/*rank=*/4, TensorDesc::Alignment::kTrailing);
} else {
// Use a nullptr to indicate there is no input edge for HiddenInitTensor.
inputs.push_back(nullptr);
}
std::optional<TensorDesc> initial_cell_state_tensor_desc;
if (initial_cell_state_operand_id.has_value()) {
const NodeOutput* initial_cell_state = GetNodeOutputForOperand(
id_to_node_output_map, initial_cell_state_operand_id.value());
// Append an identity node if the initial cell state is a constant operand
// since CellMemInitTensor doesn't support the DML_TENSOR_FLAG_OWNED_BY_DML
// flag.
initial_cell_state =
AppendIdentityToConstantOperand(graph_builder, initial_cell_state);
inputs.push_back(initial_cell_state);
initial_cell_state_tensor_desc = initial_cell_state->GetTensorDesc();
// The initial cell state tensor is 2-D for lstmCell and 3-D for lstm, while
// DirectML expects a 4-D tensor.
initial_cell_state_tensor_desc->EnsureMinimumRank(
/*rank=*/4, TensorDesc::Alignment::kTrailing);
} else {
// Use a nullptr to indicate there is no input edge for CellMemInitTensor.
inputs.push_back(nullptr);
}
// Use a nullptr to indicate there is no input edge for SequenceLengthsTensor.
inputs.push_back(nullptr);
std::optional<TensorDesc> peephole_weight_tensor_desc;
if (lstm.peephole_weight_operand_id.has_value()) {
const NodeOutput* peephole_weight = GetNodeOutputForOperand(
id_to_node_output_map, lstm.peephole_weight_operand_id.value());
// Append an identity node if the peephole weight is a constant operand
// since PeepholeTensor doesn't support the DML_TENSOR_FLAG_OWNED_BY_DML
// flag.
peephole_weight =
AppendIdentityToConstantOperand(graph_builder, peephole_weight);
inputs.push_back(peephole_weight);
peephole_weight_tensor_desc = peephole_weight->GetTensorDesc();
// The peephole weight tensor is 1-D for lstmCell and 2-D for lstm, while
// DirectML expects a 4-D tensor.
peephole_weight_tensor_desc->EnsureMinimumRank(
/*rank=*/4, TensorDesc::Alignment::kTrailing);
}
// When the recurrent network is bidirectional, dual activations must be
// provided for the forward and backward directions.
const size_t number_of_activations =
direction == mojom::RecurrentNetworkDirection::kBoth
? lstm.activations.size() * 2
: lstm.activations.size();
std::vector<ActivationOperatorDesc> activation_operator_descs;
activation_operator_descs.reserve(number_of_activations);
for (mojom::RecurrentNetworkActivation activation : lstm.activations) {
activation_operator_descs.push_back(
CreateOperatorDescForActivation(activation));
}
// For bidirectional, the activations must be provided as f() and g() for the
// forward direction followed by f() and g() for the backward direction.
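// For example, the activations {sigmoid, tanh, tanh} become
// {sigmoid, tanh, tanh, sigmoid, tanh, tanh}.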
if (direction == mojom::RecurrentNetworkDirection::kBoth) {
base::ranges::copy(activation_operator_descs,
std::back_inserter(activation_operator_descs));
}
std::vector<DML_OPERATOR_DESC> activation_dml_descs;
activation_dml_descs.reserve(activation_operator_descs.size());
base::ranges::transform(
activation_operator_descs, std::back_inserter(activation_dml_descs),
[](const auto& activation_operator_desc) {
return activation_operator_desc.GetActivationDmlDesc();
});
DML_LSTM_OPERATOR_DESC lstm_desc{
.InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
.WeightTensor = &weight_tensor_desc.GetDMLTensorDesc(),
.RecurrenceTensor = &recurrent_weight_tensor_desc.GetDMLTensorDesc(),
.BiasTensor = GetOptionalDmlTensorDescPtr(concatenated_bias_tensor_desc),
.HiddenInitTensor =
GetOptionalDmlTensorDescPtr(initial_hidden_state_tensor_desc),
.CellMemInitTensor =
GetOptionalDmlTensorDescPtr(initial_cell_state_tensor_desc),
// All sequences in the batch have the same length.
.SequenceLengthsTensor = nullptr,
.PeepholeTensor =
GetOptionalDmlTensorDescPtr(peephole_weight_tensor_desc),
.OutputSequenceTensor =
GetOptionalDmlTensorDescPtr(output_sequence_tensor_desc),
.OutputSingleTensor = &output_hidden_state_tensor_desc.GetDMLTensorDesc(),
.OutputCellSingleTensor =
&output_cell_state_tensor_desc.GetDMLTensorDesc(),
.ActivationDescCount = static_cast<uint32_t>(activation_dml_descs.size()),
.ActivationDescs = activation_dml_descs.data(),
.Direction = MojoRecurrentNetworkDirectionToDml(direction),
// The cell clip threshold for the input of activations is not used.
.ClipThreshold = 0,
// The clip threshold is not used.
.UseClipThreshold = FALSE,
// The input and forget gates are not coupled.
.CoupleInputForget = FALSE};
const OperatorNode* lstm_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_LSTM, &lstm_desc, inputs, label);
if (return_sequence) {
const NodeOutput* output_sequence = graph_builder.CreateNodeOutput(
lstm_node, output_sequence_tensor_desc.value(), 0);
CHECK(id_to_node_output_map
.try_emplace(output_sequence_id.value(), output_sequence)
.second);
}
const NodeOutput* output_hidden_state = graph_builder.CreateNodeOutput(
lstm_node, output_hidden_state_tensor_desc, 1);
CHECK(id_to_node_output_map
.try_emplace(output_hidden_state_id, output_hidden_state)
.second);
const NodeOutput* output_cell_state = graph_builder.CreateNodeOutput(
lstm_node, output_cell_state_tensor_desc, 2);
CHECK(
id_to_node_output_map.try_emplace(output_cell_state_id, output_cell_state)
.second);
return base::ok();
}
// Using DML_GEMM_OPERATOR_DESC to implement WebNN matmul.
base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForMatmul(
const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const Operation* operation,
const std::map<const Operation*, const Operation*>&
operation_to_fusible_standalone_activation_map,
const std::map<uint64_t, const Operation*>&
output_id_to_fusible_transpose_map,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map) {
const auto& matmul = operation->get_matmul();
// If the transpose operation that produces input a (or b) is fusible, use
// the input operand of that transpose operation instead and set the `TransA`
// (or `TransB`) of the DirectML GEMM operator to
// `DML_MATRIX_TRANSFORM_TRANSPOSE`.
bool transpose_a = false;
uint64_t a_operand_id = matmul->a_operand_id;
std::optional<uint64_t> fusible_transpose_input_id =
GetFusibleTransposeInputId(output_id_to_fusible_transpose_map,
a_operand_id);
if (fusible_transpose_input_id) {
a_operand_id = fusible_transpose_input_id.value();
transpose_a = true;
}
const NodeOutput* input_a_node_output =
GetNodeOutputForOperand(id_to_node_output_map, a_operand_id);
auto input_a_tensor_desc = input_a_node_output->GetTensorDesc();
CHECK(kDmlFloatDataTypes.contains(input_a_tensor_desc.GetDataType()));
bool transpose_b = false;
uint64_t b_operand_id = matmul->b_operand_id;
fusible_transpose_input_id = GetFusibleTransposeInputId(
output_id_to_fusible_transpose_map, b_operand_id);
if (fusible_transpose_input_id) {
b_operand_id = fusible_transpose_input_id.value();
transpose_b = true;
}
const NodeOutput* input_b_node_output =
GetNodeOutputForOperand(id_to_node_output_map, b_operand_id);
auto input_b_tensor_desc = input_b_node_output->GetTensorDesc();
uint64_t output_id = matmul->output_operand_id;
const auto output_tensor_desc =
CreateOutputTensorDesc(id_to_operand_map, output_id);
const auto output_tensor_dims = output_tensor_desc.GetDimensions();
// Because DML_GEMM_OPERATOR_DESC requires input_a_tensor, input_b_tensor and
// output_tensor to have the same DimensionCount and doesn't support
// broadcasting, input_a_tensor and input_b_tensor may need to be broadcast.
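// For example, for a = [2, 3, 4] and b = [1, 4, 5] the output is [2, 3, 5];
// the batch dimensions (all but the last two) of b are broadcast to [2], so
// b becomes [2, 4, 5] while a is unchanged.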
if (output_tensor_dims.size() > 2) {
input_a_tensor_desc.BroadcastTo(output_tensor_dims, 2);
input_b_tensor_desc.BroadcastTo(output_tensor_dims, 2);
}
CHECK(context_properties.data_type_limits.matmul_input.Has(
DmlDataTypeToOperand(input_a_tensor_desc.GetDataType())));
CHECK_EQ(input_a_tensor_desc.GetDimensions().size(),
input_b_tensor_desc.GetDimensions().size());
CHECK_EQ(input_a_tensor_desc.GetDimensions().size(),
output_tensor_dims.size());
const std::string& label = matmul->label;
// TODO(issues.chromium.org/353856233): Flatten adjacent dimensions for GEMM >
// 4D because DML_GEMM_OPERATOR_DESC restricts the tensor rank to <= 4.
if (input_a_tensor_desc.GetDimensions().size() > 4) {
return CreateUnexpectedError(
mojom::Error::Code::kNotSupportedError,
"The input tensor rank is larger than 4 for matmul operator.", label);
}
// Use 4D GEMM which is available since feature level 1.0 for best
// compatibility. There is no performance difference in the shader between
// 2D/3D/4D, as 2D is just a variant of 4D with a batch/channel size of 1.
// https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_gemm_operator_desc.
// TODO(issues.chromium.org/327244277): Remove the workaround of coercing
// GEMM's tensors to 4D.
auto expanded_output_tensor_desc = output_tensor_desc;
if (output_tensor_dims.size() < 4) {
input_a_tensor_desc.EnsureMinimumRank(4, TensorDesc::Alignment::kTrailing);
input_b_tensor_desc.EnsureMinimumRank(4, TensorDesc::Alignment::kTrailing);
expanded_output_tensor_desc.EnsureMinimumRank(
4, TensorDesc::Alignment::kTrailing);
}
std::optional<const Operation*> fusible_activation =
GetFusibleActivationFromOperation(
operation_to_fusible_standalone_activation_map, operation);
std::optional<ActivationOperatorDesc> activation_operator_desc;
std::optional<DML_OPERATOR_DESC> activation_dml_desc;
if (fusible_activation) {
activation_operator_desc =
CreateOperatorDescForFusibleActivation(*fusible_activation.value());
activation_dml_desc = activation_operator_desc->GetActivationDmlDesc();
output_id =
GetFusibleActivationOutputId(*fusible_activation.value()).value();
}
DML_GEMM_OPERATOR_DESC matmul_operator_desc{
.ATensor = &input_a_tensor_desc.GetDMLTensorDesc(),
.BTensor = &input_b_tensor_desc.GetDMLTensorDesc(),
.CTensor = nullptr,
.OutputTensor = &expanded_output_tensor_desc.GetDMLTensorDesc(),
.TransA = transpose_a ? DML_MATRIX_TRANSFORM_TRANSPOSE
: DML_MATRIX_TRANSFORM_NONE,
.TransB = transpose_b ? DML_MATRIX_TRANSFORM_TRANSPOSE
: DML_MATRIX_TRANSFORM_NONE,
.Alpha = 1.0f,
.Beta = 0.0f,
.FusedActivation =
activation_dml_desc ? &activation_dml_desc.value() : nullptr,
};
std::array<const NodeOutput*, 2> inputs{input_a_node_output,
input_b_node_output};
const OperatorNode* matmul_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_GEMM, &matmul_operator_desc, inputs, label);
const NodeOutput* output = graph_builder.CreateNodeOutput(
matmul_node, std::move(output_tensor_desc), 0);
// The output id must be unique in the map.
CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
return base::ok();
}
// Create a transpose node with the given permutation.
const NodeOutput* CreateTransposeNode(GraphBuilderDml& graph_builder,
const NodeOutput* input,
base::span<const uint32_t> permutation) {
CHECK(input);
const TensorDesc& input_tensor_desc = input->GetTensorDesc();
TensorDesc transposed_input_tensor_desc = input_tensor_desc;
transposed_input_tensor_desc.Transpose(permutation);
// Append an identity node to consume the strides.
const NodeOutput* transpose_node =
AppendIdentityNode(graph_builder, input, &transposed_input_tensor_desc);
return transpose_node;
}
base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForSoftmax(
Adapter* adapter,
const IdToOperandMap& id_to_operand_map,
const mojom::SoftmaxPtr& softmax,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map) {
const NodeOutput* input =
GetNodeOutputForOperand(id_to_node_output_map, softmax->input_operand_id);
const auto& input_tensor_desc = input->GetTensorDesc();
uint64_t output_id = softmax->output_operand_id;
const auto output_tensor_desc =
CreateOutputTensorDesc(id_to_operand_map, output_id);
std::array<const NodeOutput*, 1> inputs = {input};
const uint32_t axis = softmax->axis;
const std::string& label = softmax->label;
if (adapter->IsDMLFeatureLevelSupported(DML_FEATURE_LEVEL_5_1)) {
std::array<uint32_t, 1> axes = {axis};
DML_ACTIVATION_SOFTMAX1_OPERATOR_DESC softmax1_operator_desc{
.InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
.OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
.AxisCount = base::checked_cast<uint32_t>(axes.size()),
.Axes = axes.data()};
const OperatorNode* softmax_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_ACTIVATION_SOFTMAX1, &softmax1_operator_desc, inputs,
label);
const NodeOutput* output = graph_builder.CreateNodeOutput(
softmax_node, std::move(output_tensor_desc));
// The output id must be unique in the map.
CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
} else {
// Emulate softmax with N-D input and an axis parameter when the feature
// level is less than DML_FEATURE_LEVEL_5_1:
// https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_activation_softmax_operator_desc.
//
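// For example, softmax over axis 1 of a [2, 3, 4] tensor is computed by
// transposing to [2, 4, 3], reshaping to [8, 3], applying the 2-D softmax
// over the last dimension, reshaping back to [2, 4, 3], and transposing
// back to [2, 3, 4].
//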
// Transpose the input tensor to make the axis the last dimension if needed.
const NodeOutput* axis_transposed_to_last_output = nullptr;
const uint32_t input_rank = input_tensor_desc.GetDimensions().size();
std::vector<uint32_t> permutation(input_rank);
std::iota(permutation.begin(), permutation.end(), 0);
if (axis == (input_rank - 1)) {
axis_transposed_to_last_output = input;
} else {
std::vector<uint32_t> transpose_axis_to_last(permutation);
std::swap(transpose_axis_to_last[axis],
transpose_axis_to_last[input_rank - 1]);
axis_transposed_to_last_output =
CreateTransposeNode(graph_builder, input, transpose_axis_to_last);
}
// Reshape the input tensor to 2D if needed.
const NodeOutput* reshaped_2d_output = nullptr;
if (axis_transposed_to_last_output->GetTensorDesc()
.GetDimensions()
.size() <= 2) {
reshaped_2d_output = axis_transposed_to_last_output;
} else {
const std::vector<uint32_t>& axis_transposed_to_last_output_dims =
axis_transposed_to_last_output->GetTensorDesc().GetDimensions();
auto reshaped_2d_dim_0 = base::MakeCheckedNum<uint32_t>(1);
for (uint32_t i = 0; i < axis_transposed_to_last_output_dims.size() - 1;
i++) {
reshaped_2d_dim_0 *= axis_transposed_to_last_output_dims[i];
if (!reshaped_2d_dim_0.IsValid<uint32_t>()) {
return CreateUnexpectedError(
mojom::Error::Code::kNotSupportedError,
"For softmax impl: failed to reshape the input to 2-D tensor.",
label);
}
}
std::vector<uint32_t> reshaped_2d_dims = {
reshaped_2d_dim_0.ValueOrDie(),
axis_transposed_to_last_output_dims.back()};
reshaped_2d_output = CreateReshapeNode(
graph_builder, axis_transposed_to_last_output, reshaped_2d_dims);
}
// Perform 2-D softmax.
const TensorDesc softmax_2d_output_tensor_desc =
TensorDesc(reshaped_2d_output->GetTensorDesc().GetDataType(),
reshaped_2d_output->GetTensorDesc().GetDimensions());
DML_ACTIVATION_SOFTMAX_OPERATOR_DESC softmax_2d_operator_desc{
.InputTensor = &reshaped_2d_output->GetTensorDesc().GetDMLTensorDesc(),
.OutputTensor = &softmax_2d_output_tensor_desc.GetDMLTensorDesc()};
std::array<const NodeOutput*, 1> softmax_2d_inputs = {reshaped_2d_output};
const OperatorNode* softmax_2d_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_ACTIVATION_SOFTMAX, &softmax_2d_operator_desc,
softmax_2d_inputs, label);
const NodeOutput* softmax_2d_output = graph_builder.CreateNodeOutput(
softmax_2d_node, softmax_2d_output_tensor_desc);
// Reshape the 2-D tensor back to N-D.
const NodeOutput* reshaped_nd_output = nullptr;
if (axis_transposed_to_last_output->GetTensorDesc()
.GetDimensions()
.size() <= 2) {
reshaped_nd_output = softmax_2d_output;
} else {
reshaped_nd_output = CreateReshapeNode(
graph_builder, softmax_2d_output,
axis_transposed_to_last_output->GetTensorDesc().GetDimensions());
}
// Transpose the output tensor back to the original shape.
const NodeOutput* last_transposed_to_axis_output = nullptr;
if (axis == (input_rank - 1)) {
last_transposed_to_axis_output = reshaped_nd_output;
} else {
std::vector<uint32_t> transpose_axis_back(permutation);
std::swap(transpose_axis_back[axis], transpose_axis_back[input_rank - 1]);
last_transposed_to_axis_output = CreateTransposeNode(
graph_builder, reshaped_nd_output, transpose_axis_back);
}
// The output id must be unique in the map.
CHECK(id_to_node_output_map
.try_emplace(output_id, last_transposed_to_axis_output)
.second);
}
return base::ok();
}
void CreateOperatorNodeForSoftplus(const IdToOperandMap& id_to_operand_map,
const mojom::SoftplusPtr& softplus,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map) {
const NodeOutput* input = GetNodeOutputForOperand(id_to_node_output_map,
softplus->input_operand_id);
const auto& input_tensor_desc = input->GetTensorDesc();
const uint64_t output_id = softplus->output_operand_id;
const auto output_tensor_desc =
CreateOutputTensorDesc(id_to_operand_map, output_id);
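// With Steepness = 1.0, DML_ACTIVATION_SOFTPLUS computes
// f(x) = ln(1 + e^x), matching the WebNN softplus definition.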
DML_ACTIVATION_SOFTPLUS_OPERATOR_DESC softplus_desc{
.InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
.OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
.Steepness = 1.0};
std::array<const NodeOutput*, 1> inputs = {input};
const OperatorNode* softplus_node =
graph_builder.CreateOperatorNode(DML_OPERATOR_ACTIVATION_SOFTPLUS,
&softplus_desc, inputs, softplus->label);
const NodeOutput* node_output = graph_builder.CreateNodeOutput(
softplus_node, std::move(output_tensor_desc));
// The output id must be unique in the map.
CHECK(id_to_node_output_map.try_emplace(output_id, node_output).second);
}
// Transpose is not a real DirectML operator. In this implementation, the
// input tensor is remapped so that elements are read following the strides
// after the permutation, and an identity operator is appended to consume the
// remapped strides.
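// For example, transposing a [2, 3] tensor (default strides {3, 1}) with
// permutation {1, 0} yields a tensor description with dimensions [3, 2] and
// strides {1, 3}; the identity operator then materializes the transposed
// layout.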
void CreateOperatorNodeForTranspose(const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::TransposePtr& transpose,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map) {
const NodeOutput* input = GetNodeOutputForOperand(
id_to_node_output_map, transpose->input_operand_id);
CHECK(context_properties.data_type_limits.transpose_input.Has(
DmlDataTypeToOperand(input->GetTensorDesc().GetDataType())));
uint64_t output_id = transpose->output_operand_id;
const NodeOutput* output =
CreateTransposeNode(graph_builder, input, transpose->permutation);
// The output id must be unique in the map.
CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
}
// For DirectML feature levels before 5.1, we need to compose triangular
// from smaller operators: identity, slice, bitwise and.
//
// 1. expand the basic mask into an expanded mask big enough for the input
// 2. shear the expanded mask
// 3. slice the sheared mask
// 4. mask the input via bitwise and
//
// A simple constant mask is created with two values, one to
// fully preserve input values and one to fully zero them. Then, expand the mask
// from [1, 2, 1] to [mask_height, 2, mask_width]. Note the mask_width is
// calculated according to the input width and the diagonal. Next, shear the
// mask to achieve a diagonal shape by reshaping the dimensions from
// [mask_height, 2, mask_width] to [mask_height, 2 * mask_width] and set strides
// = {2 * mask_width - 1, 1}. By changing the default strides, the shape of the
// mask looks like a rhomboid. Then, we can get a mask with bit values filled
// with 0 or 0xFFFF using DML_SLICE_OPERATOR_DESC.
// ----------------
// [ 0xFFFF, 0xFFFF, 0, 0 [0xFFFF, 0xFFFF, | 0, 0 |
// 0xFFFF, 0xFFFF, 0, 0 => 0xFFFF, | 0xFFFF, 0, | 0
// 0xFFFF, 0xFFFF, 0, 0] | 0xFFFF, 0xFFFF,| 0, 0]
// -----------------
// Finally, the mask is the matrix shown above, which has the same shape and
// data type as the input and whose bits are all 0 or all 1 per element. The
// mask can therefore extract either the upper or the lower triangular part of
// the input tensor via a bitwise AND between the mask and the input. For
// example:
// [ 2, 3 [0, 0,] [0, 0,
// 4, 5, bit_and [0xFFFF, 0,] => 4, 0,
// 6, 7] [0xFFFF, 0xFFFF] 6, 7]
base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForTriangular(
const ContextProperties& context_properties,
Adapter* adapter,
const mojom::TriangularPtr& triangular,
mojom::GraphInfoPtr& graph_info,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map,
std::unordered_map<uint64_t, uint32_t>& constant_id_to_input_index_map,
uint64_t& next_operand_id) {
const NodeOutput* input = GetNodeOutputForOperand(
id_to_node_output_map, triangular->input_operand_id);
const auto& input_tensor_desc = input->GetTensorDesc();
CHECK(context_properties.data_type_limits.triangular_input.Has(
DmlDataTypeToOperand(input_tensor_desc.GetDataType())));
auto& id_to_operand_map = graph_info->id_to_operand_map;
uint64_t output_id = triangular->output_operand_id;
auto output_tensor_desc =
CreateOutputTensorDesc(id_to_operand_map, output_id);
CHECK_EQ(input_tensor_desc.GetDimensions().size(),
output_tensor_desc.GetDimensions().size());
const auto& input_dimensions = input_tensor_desc.GetDimensions();
const auto input_rank = input_dimensions.size();
CHECK_GE(input_rank, 2U);
bool upper = triangular->upper;
int32_t diagonal = triangular->diagonal;
const std::string& label = triangular->label;
// Initialize scale union with a zero value.
DML_SCALAR_UNION scalar_union = {};
// DML_DIAGONAL_MATRIX1_OPERATOR_DESC was introduced in DML_FEATURE_LEVEL_5_1
// and supported input dimension count is from 2 to 4.
if (adapter->IsDMLFeatureLevelSupported(DML_FEATURE_LEVEL_5_1) &&
input_rank <= 4) {
    // DML_DIAGONAL_MATRIX1_OPERATOR_DESC generates an identity-like matrix
    // with zeros within the given diagonal span and the other elements filled
    // with the input values. The diagonal values may be shifted anywhere
    // between DiagonalFillBegin and DiagonalFillEnd, where a value greater
    // than zero shifts all values to the right, and less than zero shifts
    // them to the left.
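    // For example, with `upper` = true and `diagonal` = 0, the fill span below
    // is [INT32_MIN, 0), zeroing every element strictly below the main
    // diagonal while the elements on and above it keep the input values.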
DML_DIAGONAL_MATRIX1_OPERATOR_DESC diagonal_matrix1_desc{
.InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
.OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
.ValueDataType = output_tensor_desc.GetDataType(),
.Value = scalar_union,
.DiagonalFillBegin =
upper ? std::numeric_limits<int32_t>::min() : diagonal + 1,
.DiagonalFillEnd =
upper ? diagonal : std::numeric_limits<int32_t>::max()};
std::array<const NodeOutput*, 1> inputs = {input};
const OperatorNode* diagonal_matrix1_node =
graph_builder.CreateOperatorNode(DML_OPERATOR_DIAGONAL_MATRIX1,
&diagonal_matrix1_desc, inputs, label);
const NodeOutput* node_output = graph_builder.CreateNodeOutput(
diagonal_matrix1_node, std::move(output_tensor_desc));
// The output id must be unique in the map.
CHECK(id_to_node_output_map.try_emplace(output_id, node_output).second);
return base::ok();
}
// For DirectML feature levels before 5.1, we need to compose triangular
// from smaller operators: identity, slice, bitwise and.
const OperandPtr& output_operand = id_to_operand_map.at(output_id);
OperandDataType data_type = output_operand->descriptor.data_type();
const uint32_t height = input_dimensions[input_rank - 2];
const uint32_t width = input_dimensions[input_rank - 1];
uint32_t longest_dimension_length = std::max(height, width);
  // Handle the case where the diagonal shift pushes the kept region entirely
  // off the matrix: too far above when keeping the upper triangle, or too far
  // below when keeping the lower triangle, yielding an all-zero output.
// 1. Upper = true
// [ 1, 2, 3 \
// 4, 5, 6, \
// 7, 8, 9] \
// 2. Upper = false
// \ [ 1, 2, 3,
// \ 4, 5, 6,
// \ 7, 8, 9]
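  // For example, a 3x3 input with `upper` = true and `diagonal` = 3 (or
  // `upper` = false and `diagonal` = -3) keeps nothing, so the output is
  // produced directly as a zero-filled constant below.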
if ((diagonal > 0 &&
(base::checked_cast<uint32_t>(diagonal) >= longest_dimension_length) &&
upper) ||
(diagonal < 0 &&
(base::checked_cast<uint32_t>(-diagonal) >= longest_dimension_length) &&
!upper)) {
DML_FILL_VALUE_CONSTANT_OPERATOR_DESC fill_constant_operator_desc{
.OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
.ValueDataType = output_tensor_desc.GetDataType(),
.Value = scalar_union,
};
const OperatorNode* fill_constant_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_FILL_VALUE_CONSTANT, &fill_constant_operator_desc, {},
label);
const NodeOutput* constant = graph_builder.CreateNodeOutput(
fill_constant_node, std::move(output_tensor_desc), 0);
auto constant_tensor_desc = constant->GetTensorDesc();
std::array<const NodeOutput*, 2> inputs = {input, constant};
const OperatorNode* mul_node =
CreateBinaryOperator<DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC>(
input_tensor_desc, constant_tensor_desc, output_tensor_desc,
graph_builder, DML_OPERATOR_ELEMENT_WISE_MULTIPLY, inputs, label);
const NodeOutput* output =
graph_builder.CreateNodeOutput(mul_node, output_tensor_desc);
// The output id must be unique in the map.
CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
return base::ok();
}
  // Handle the case where the diagonal shift moves the zeroed region entirely
  // off the matrix: too far above when keeping the lower triangle, or too far
  // below when keeping the upper triangle, so the input tensor is returned
  // unchanged.
// 1. Upper = false
// [ 1, 2, 3 \
// 4, 5, 6, \
// 7, 8, 9] \
// 2. Upper = true
// \ [ 1, 2, 3,
// \ 4, 5, 6,
// \ 7, 8, 9]
if ((diagonal > 0 &&
(base::checked_cast<uint32_t>(diagonal) >= longest_dimension_length) &&
!upper) ||
(diagonal < 0 &&
(base::checked_cast<uint32_t>(-diagonal) >= longest_dimension_length) &&
upper)) {
// Return input matrix.
const Node& input_node = input->GetNode();
    // The output_index of this NodeOutput should be the same as that of the
    // input NodeOutput so that the intermediate edges of the graph are created
    // correctly.
const NodeOutput* output = graph_builder.CreateNodeOutput(
&input_node, std::move(output_tensor_desc), input->GetOutputIndex());
// The output id must be unique in the map.
CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
return base::ok();
}
// First step: create a simple constant mask with two values, one to
// fully preserve input values and one to fully zero them.
uint64_t lower_mask = 0;
uint64_t upper_mask = std::numeric_limits<uint64_t>::max();
if (!upper) {
std::swap(lower_mask, upper_mask);
}
OperandDataType webnn_mask_data_type;
DML_TENSOR_DATA_TYPE dml_mask_data_type;
mojo_base::BigBuffer buffer;
switch (data_type) {
case OperandDataType::kInt8:
case OperandDataType::kUint8: {
webnn_mask_data_type = OperandDataType::kUint8;
dml_mask_data_type = DML_TENSOR_DATA_TYPE_UINT8;
std::array<uint8_t, 2> values = {static_cast<uint8_t>(lower_mask),
static_cast<uint8_t>(upper_mask)};
buffer = mojo_base::BigBuffer(base::as_bytes(base::make_span(values)));
break;
}
case OperandDataType::kFloat16: {
// Here we create a mask with float16 data type since WebNN doesn't define
// uint16.
webnn_mask_data_type = OperandDataType::kFloat16;
dml_mask_data_type = DML_TENSOR_DATA_TYPE_UINT16;
std::array<uint16_t, 2> values = {static_cast<uint16_t>(lower_mask),
static_cast<uint16_t>(upper_mask)};
buffer = mojo_base::BigBuffer(base::as_bytes(base::make_span(values)));
break;
}
case OperandDataType::kFloat32:
case OperandDataType::kInt32:
case OperandDataType::kUint32: {
webnn_mask_data_type = OperandDataType::kUint32;
dml_mask_data_type = DML_TENSOR_DATA_TYPE_UINT32;
std::array<uint32_t, 2> values = {static_cast<uint32_t>(lower_mask),
static_cast<uint32_t>(upper_mask)};
buffer = mojo_base::BigBuffer(base::as_bytes(base::make_span(values)));
break;
}
case OperandDataType::kInt64:
case OperandDataType::kUint64: {
webnn_mask_data_type = OperandDataType::kUint64;
dml_mask_data_type = DML_TENSOR_DATA_TYPE_UINT64;
std::array<uint64_t, 2> values = {static_cast<uint64_t>(lower_mask),
static_cast<uint64_t>(upper_mask)};
buffer = mojo_base::BigBuffer(base::as_bytes(base::make_span(values)));
break;
}
}
OperandPtr constant_operand = Operand::New();
constant_operand->kind = Operand::Kind::kConstant;
constant_operand->descriptor = *OperandDescriptor::Create(
webnn_mask_data_type, std::array<uint32_t, 3>{1, 2, 1});
uint64_t constant_operand_id = next_operand_id++;
CHECK(graph_info->id_to_operand_map
.try_emplace(constant_operand_id, std::move(constant_operand))
.second);
CHECK(graph_info->constant_id_to_buffer_map
.try_emplace(constant_operand_id, std::move(buffer))
.second);
uint32_t constant_input_index =
CreateInputNode(id_to_operand_map, constant_operand_id, graph_builder,
id_to_node_output_map);
CHECK(constant_id_to_input_index_map
.try_emplace(constant_operand_id, constant_input_index)
.second);
const NodeOutput* constant =
GetNodeOutputForOperand(id_to_node_output_map, constant_operand_id);
auto constant_tensor_desc = constant->GetTensorDesc();
const auto mask_height = height;
const auto checked_mask_width =
(base::MakeCheckedNum<uint32_t>(longest_dimension_length) +
std::min(base::checked_cast<uint32_t>(std::abs(diagonal)),
longest_dimension_length)) *
2;
  // TODO(issues.chromium.org/335524385): All error handling of checked_math
  // values inside this triangular implementation should be removed in favor
  // of performing proper validation at graph creation time.
if (!checked_mask_width.IsValid<uint32_t>()) {
return base::unexpected(CreateError(
mojom::Error::Code::kUnknownError,
"For triangular impl: the mask width is too large.", label));
}
const uint32_t mask_width = checked_mask_width.ValueOrDie();
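  // For example, a 3x3 input with `diagonal` = 1 gives
  // mask_width = (3 + min(1, 3)) * 2 = 8; the steps below then expand the
  // [1, 2, 1] mask to [3, 2, 8] and shear it as a [3, 16] tensor with strides
  // {15, 1}.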
// Second step: expand the mask from [1, 2, 1] to [mask_height, 2,
// mask_width].
std::vector<uint32_t> expand_constant_dims = {mask_height, 2, mask_width};
if (constant_tensor_desc.GetDimensions() != expand_constant_dims) {
constant_tensor_desc.BroadcastTo(expand_constant_dims);
}
const auto expand_constant_tensor_desc = TensorDesc(
constant_tensor_desc.GetDataType(), std::move(expand_constant_dims));
const OperatorNode* expand_constant_node =
CreateUnaryOperator<DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC,
DML_OPERATOR_ELEMENT_WISE_IDENTITY>(
constant_tensor_desc, expand_constant_tensor_desc, constant,
graph_builder, label);
const auto* expand_constant_output = graph_builder.CreateNodeOutput(
expand_constant_node, std::move(expand_constant_tensor_desc));
auto expand_constant_output_tensor_desc =
expand_constant_output->GetTensorDesc();
// Third step: shear the mask to achieve a diagonal shape by reshaping
// the dimensions from [mask_height, 2, mask_width] to [mask_height,
// 2 * mask_width] and set strides = {2 * mask_width - 1, 1}. By changing
// the default strides, we can get the rhomboid to slice.
// For example:
// [ 1, 1, 0, 0 [1, 1, 0, 0
// 1, 1, 0, 0 => 1, 1, 0, 0
// 1, 1, 0, 0] 1, 1, 0, 0]
const auto checked_slice_input_width =
base::MakeCheckedNum<uint32_t>(mask_width) * 2;
if (!checked_slice_input_width.IsValid<uint32_t>()) {
return base::unexpected(CreateError(
mojom::Error::Code::kUnknownError,
"For triangular impl: the input width for slice is too large.", label));
}
const uint32_t slice_input_width = checked_slice_input_width.ValueOrDie();
std::vector<uint32_t> slice_input_dims = {mask_height, slice_input_width};
const auto checked_slice_input_stride = checked_slice_input_width - 1;
if (!checked_slice_input_stride.IsValid<uint32_t>()) {
return base::unexpected(CreateError(
mojom::Error::Code::kUnknownError,
"For triangular impl: the input stride for slice is invalid.", label));
}
const uint32_t slice_input_stride = checked_slice_input_stride.ValueOrDie();
std::vector<uint32_t> slice_input_strides = {slice_input_stride, 1};
auto slice_input_tensor_desc =
TensorDesc(expand_constant_output_tensor_desc.GetDataType(),
expand_constant_output_tensor_desc.GetFlags(),
std::move(slice_input_dims), std::move(slice_input_strides));
// Since we change both the output dims and strides of
// expand_constant_output to get the slice_input_tensor_desc, the
// total_tensor_size_in_bytes of expand_constant_tensor_desc and
// slice_input_tensor_desc are not the same.
slice_input_tensor_desc.SetTotalTensorSizeInBytes(
expand_constant_output_tensor_desc.GetTotalTensorSizeInBytes());
std::vector<uint32_t> slice_output_dims = {height, width};
auto slice_output_tensor_desc = TensorDesc(
expand_constant_tensor_desc.GetDataType(), std::move(slice_output_dims));
std::array<uint32_t, 2> sizes = {height, width};
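  // The slice starts at column (mask_width - diagonal) for the upper triangle;
  // the extra -1 in the lower case accounts for the swapped mask values so
  // that the diagonal itself stays on the preserved side in both cases.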
std::array<uint32_t, 2> offset =
upper ? std::array<uint32_t, 2>{0, mask_width - diagonal}
: std::array<uint32_t, 2>{0, mask_width - diagonal - 1};
std::array<uint32_t, 2> strides = {1, 1};
// Fourth step: get the sliced mask with bit values filled with 0 or
// 0xFFFF...
DML_SLICE_OPERATOR_DESC slice_operator_desc{
.InputTensor = &slice_input_tensor_desc.GetDMLTensorDesc(),
.OutputTensor = &slice_output_tensor_desc.GetDMLTensorDesc(),
.DimensionCount = 2,
.Offsets = offset.data(),
.Sizes = sizes.data(),
.Strides = strides.data(),
};
std::array<const NodeOutput*, 1> input_for_slice = {expand_constant_output};
const OperatorNode* slice_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_SLICE, &slice_operator_desc, input_for_slice, label);
const auto* slice_output = graph_builder.CreateNodeOutput(
slice_node, std::move(slice_output_tensor_desc));
slice_output_tensor_desc = slice_output->GetTensorDesc();
if (slice_output_tensor_desc.GetDimensions() != input_dimensions) {
slice_output_tensor_desc.BroadcastTo(input_dimensions);
}
// Fifth step: using bit_and_operator to do the bit computation between
// input and mask.
// Here we need to cast the input and mask tensor data type to the data type
// that DML elementwise-bit-and operator supports and has the same bit width.
// For example casting float16 to uint16, float32 to uint32.
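  // Since the cast only reinterprets bits, AND-ing with all-ones (e.g.
  // 0xFFFFFFFF for float32) preserves the element exactly, while AND-ing with
  // 0 clears every bit, producing +0.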
TensorDesc bit_and_operator_input_tensor_desc =
TensorDesc(dml_mask_data_type, input_tensor_desc.GetFlags(),
input_tensor_desc.GetDimensions());
TensorDesc bit_and_operator_mask_tensor_desc =
TensorDesc(dml_mask_data_type, slice_output_tensor_desc.GetFlags(),
slice_output_tensor_desc.GetDimensions(),
slice_output_tensor_desc.GetStrides());
TensorDesc bit_and_operator_output_tensor_desc =
TensorDesc(dml_mask_data_type, output_tensor_desc.GetFlags(),
output_tensor_desc.GetDimensions());
DML_ELEMENT_WISE_BIT_AND_OPERATOR_DESC bit_and_operator_desc{
.ATensor = &bit_and_operator_input_tensor_desc.GetDMLTensorDesc(),
.BTensor = &bit_and_operator_mask_tensor_desc.GetDMLTensorDesc(),
.OutputTensor = &bit_and_operator_output_tensor_desc.GetDMLTensorDesc()};
std::array<const NodeOutput*, 2> inputs{input, slice_output};
const OperatorNode* bit_and_operator_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_ELEMENT_WISE_BIT_AND, &bit_and_operator_desc, inputs, label);
const NodeOutput* bit_and_operator_output =
graph_builder.CreateNodeOutput(bit_and_operator_node, output_tensor_desc);
// The output id must be unique in the map.
CHECK(id_to_node_output_map.try_emplace(output_id, bit_and_operator_output)
.second);
return base::ok();
}
void CreateOperatorNodeForWhere(const IdToOperandMap& id_to_operand_map,
const mojom::WherePtr& where,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map) {
const NodeOutput* condition = GetNodeOutputForOperand(
id_to_node_output_map, where->condition_operand_id);
auto condition_tensor_desc = condition->GetTensorDesc();
const NodeOutput* true_value = GetNodeOutputForOperand(
id_to_node_output_map, where->true_value_operand_id);
auto true_value_tensor_desc = true_value->GetTensorDesc();
const NodeOutput* false_value = GetNodeOutputForOperand(
id_to_node_output_map, where->false_value_operand_id);
auto false_value_tensor_desc = false_value->GetTensorDesc();
uint64_t output_id = where->output_operand_id;
const auto output_tensor_desc =
CreateOutputTensorDesc(id_to_operand_map, output_id);
const auto output_tensor_dims = output_tensor_desc.GetDimensions();
// Broadcast each of the inputs to the output.
if (condition_tensor_desc.GetDimensions() != output_tensor_dims) {
condition_tensor_desc.BroadcastTo(output_tensor_dims);
}
if (true_value_tensor_desc.GetDimensions() != output_tensor_dims) {
true_value_tensor_desc.BroadcastTo(output_tensor_dims);
}
if (false_value_tensor_desc.GetDimensions() != output_tensor_dims) {
false_value_tensor_desc.BroadcastTo(output_tensor_dims);
}
DML_ELEMENT_WISE_IF_OPERATOR_DESC where_operator_desc{
.ConditionTensor = &condition_tensor_desc.GetDMLTensorDesc(),
.ATensor = &true_value_tensor_desc.GetDMLTensorDesc(),
.BTensor = &false_value_tensor_desc.GetDMLTensorDesc(),
.OutputTensor = &output_tensor_desc.GetDMLTensorDesc()};
std::array<const NodeOutput*, 3> inputs{condition, true_value, false_value};
const OperatorNode* where_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_ELEMENT_WISE_IF, &where_operator_desc, inputs, where->label);
const NodeOutput* output = graph_builder.CreateNodeOutput(
where_node, std::move(output_tensor_desc), 0);
// The output id must be unique in the map.
CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
}
// If graph creation fails, report the error message via `callback` and let
// `context` handle the error.
void HandleGraphCreationFailure(
const std::string& error_message,
WebNNContextImpl::CreateGraphImplCallback callback,
ContextImplDml* context,
HRESULT hr) {
std::move(callback).Run(base::unexpected(
CreateError(mojom::Error::Code::kUnknownError, error_message)));
context->HandleContextLostOrCrash(error_message, hr);
}
bool IsDispatchBindingValid(
const base::flat_map<std::string_view, WebNNBufferImpl*>& named_buffers,
const base::flat_map<std::string, base::WeakPtr<const WebNNBufferImpl>>&
prev_named_buffers) {
return base::ranges::equal(
named_buffers, prev_named_buffers,
[](const auto& pair, const auto& previous_pair) {
const auto& [name, buffer] = pair;
const auto& [prev_name, prev_buffer] = previous_pair;
return name == prev_name && buffer == prev_buffer.get();
});
}
} // namespace
GraphImplDml::GraphBufferBindingInfo::GraphBufferBindingInfo() = default;
GraphImplDml::GraphBufferBindingInfo::~GraphBufferBindingInfo() = default;
GraphImplDml::GraphBufferBindingInfo::GraphBufferBindingInfo(
const GraphBufferBindingInfo&) = default;
GraphImplDml::GraphBufferBindingInfo&
GraphImplDml::GraphBufferBindingInfo::operator=(const GraphBufferBindingInfo&) =
default;
GraphImplDml::GraphBufferBindingInfo::GraphBufferBindingInfo(
GraphBufferBindingInfo&&) = default;
GraphImplDml::GraphBufferBindingInfo&
GraphImplDml::GraphBufferBindingInfo::operator=(GraphBufferBindingInfo&&) =
default;
// static
scoped_refptr<GraphImplDml::PersistentResource>
GraphImplDml::PersistentResource::Create(
uint64_t persistent_buffer_byte_length,
ComPtr<ID3D12Resource> persistent_buffer) {
CHECK_GT(persistent_buffer_byte_length, 0u);
CHECK_NE(persistent_buffer.Get(), nullptr);
return base::WrapRefCounted(new PersistentResource(
persistent_buffer_byte_length, std::move(persistent_buffer)));
}
GraphImplDml::PersistentResource::PersistentResource(
uint64_t persistent_buffer_byte_length,
ComPtr<ID3D12Resource> persistent_buffer)
: persistent_buffer_(std::move(persistent_buffer)) {
persistent_buffer_binding_ =
DML_BUFFER_BINDING{.Buffer = persistent_buffer_.Get(),
.Offset = 0,
.SizeInBytes = persistent_buffer_byte_length};
persistent_buffer_binding_desc_ = DML_BINDING_DESC{
.Type = DML_BINDING_TYPE_BUFFER, .Desc = &persistent_buffer_binding_};
}
GraphImplDml::PersistentResource::~PersistentResource() = default;
GraphImplDml::GraphResources::GraphResources(
ComPtr<ID3D12DescriptorHeap> descriptor_heap,
uint64_t temporary_buffer_byte_length,
ComPtr<ID3D12Resource> temporary_resource)
: descriptor_heap(std::move(descriptor_heap)),
temporary_buffer(std::move(temporary_resource)) {
if (temporary_buffer_byte_length > 0) {
CHECK_NE(temporary_buffer.Get(), nullptr);
temporary_buffer_binding =
DML_BUFFER_BINDING{.Buffer = temporary_buffer.Get(),
.Offset = 0,
.SizeInBytes = temporary_buffer_byte_length};
temporary_buffer_binding_desc =
DML_BINDING_DESC{.Type = DML_BINDING_TYPE_BUFFER,
.Desc = &temporary_buffer_binding.value()};
}
}
GraphImplDml::GraphResources::~GraphResources() = default;
// static
base::expected<std::unique_ptr<GraphImplDml::GraphResources>, HRESULT>
GraphImplDml::AllocateGraphResources(Adapter* adapter,
IDMLCompiledOperator* compiled_operator) {
TRACE_EVENT0("gpu", "GraphImplDml::AllocateGraphResources");
// Create the descriptor heap.
DML_BINDING_PROPERTIES execution_binding_properties =
compiled_operator->GetBindingProperties();
ComPtr<ID3D12DescriptorHeap> descriptor_heap;
RETURN_UNEXPECTED_IF_FAILED(CreateDescriptorHeap(
adapter->d3d12_device(),
execution_binding_properties.RequiredDescriptorCount,
L"WebNN_Descriptor_Heap_For_Execution", descriptor_heap));
  // Create and bind the temporary resource if the operator execution requires
  // one.
ComPtr<ID3D12Resource> temporary_buffer;
uint64_t temporary_buffer_byte_length =
execution_binding_properties.TemporaryResourceSize;
if (temporary_buffer_byte_length > 0) {
RETURN_UNEXPECTED_IF_FAILED(CreateDefaultBuffer(
adapter->d3d12_device(), temporary_buffer_byte_length,
L"WebNN_Temporary_Buffer_For_Execution", temporary_buffer));
}
return base::WrapUnique(new GraphResources(std::move(descriptor_heap),
temporary_buffer_byte_length,
std::move(temporary_buffer)));
}
GraphImplDml::ComputeResources::ComputeResources(
ComPtr<ID3D12DescriptorHeap> descriptor_heap,
AlignedByteLength<std::string> input_aligned_byte_length,
ComPtr<ID3D12Resource> upload_buffer,
ComPtr<ID3D12Resource> input_buffer,
AlignedByteLength<std::string> output_aligned_byte_length,
ComPtr<ID3D12Resource> output_buffer,
ComPtr<ID3D12Resource> readback_buffer,
uint64_t temporary_buffer_byte_length,
ComPtr<ID3D12Resource> temporary_resource,
std::unique_ptr<CommandRecorder> command_recorder)
: input_aligned_byte_length(std::move(input_aligned_byte_length)),
upload_buffer(std::move(upload_buffer)),
input_buffer(std::move(input_buffer)),
output_aligned_byte_length(std::move(output_aligned_byte_length)),
output_buffer(std::move(output_buffer)),
readback_buffer(std::move(readback_buffer)),
graph_resources(std::move(descriptor_heap),
temporary_buffer_byte_length,
std::move(temporary_resource)),
command_recorder(std::move(command_recorder)) {}
GraphImplDml::ComputeResources::~ComputeResources() = default;
// static
base::expected<std::unique_ptr<GraphImplDml::ComputeResources>, HRESULT>
GraphImplDml::AllocateComputeResources(
Adapter* adapter,
IDMLCompiledOperator* compiled_operator,
const ComputeResourceInfo& compute_resource_info) {
TRACE_EVENT0("gpu", "GraphImplDml::AllocateComputeResources");
// Create the descriptor heap.
DML_BINDING_PROPERTIES execution_binding_properties =
compiled_operator->GetBindingProperties();
ComPtr<ID3D12DescriptorHeap> descriptor_heap;
RETURN_UNEXPECTED_IF_FAILED(CreateDescriptorHeap(
adapter->d3d12_device(),
execution_binding_properties.RequiredDescriptorCount,
L"WebNN_Descriptor_Heap_For_Execution", descriptor_heap));
  // Calculate the total byte length of the input array buffers needed to
  // create the GPU input buffer and upload buffer, and record the aligned
  // D3D12_RANGE for each input.
std::optional<AlignedByteLength<std::string>> aligned_byte_length_of_inputs =
CalculateAlignedByteLengthFromDescriptors(
compute_resource_info.input_names_to_descriptors);
if (!aligned_byte_length_of_inputs) {
LOG(ERROR)
<< "[WebNN] Failed to calculate the aligned byte length of inputs.";
return base::unexpected(E_INVALIDARG);
}
size_t total_byte_length_of_inputs =
aligned_byte_length_of_inputs.value().total_byte_length;
ComPtr<ID3D12Resource> upload_buffer;
ComPtr<ID3D12Resource> input_buffer;
// It is possible that a graph doesn't have any inputs. For example, a graph
// may only compute results given weights. For such graphs, there is no need
// to allocate upload and input buffers.
if (total_byte_length_of_inputs > 0) {
if (adapter->IsUMA()) {
      // For GPUs that support UMA, create a custom heap with the CPU memory
      // pool and a resource to map it. The CPU writes the input data into this
      // resource, which can then be bound as a graph input for the GPU to read
      // during execution.
RETURN_UNEXPECTED_IF_FAILED(CreateCustomUploadBuffer(
adapter->d3d12_device(), total_byte_length_of_inputs,
L"WebNN_Custom_Upload_Buffer_Inputs", input_buffer));
} else {
// Create the upload heap that can be written by CPU and read from GPU,
// and create a resource to map the heap.
RETURN_UNEXPECTED_IF_FAILED(CreateUploadBuffer(
adapter->d3d12_device(), total_byte_length_of_inputs,
L"WebNN_Upload_Buffer_Inputs", upload_buffer));
      // Create the default heap, which is GPU-accessible only and provides no
      // CPU access, and create a resource to map the heap.
RETURN_UNEXPECTED_IF_FAILED(CreateDefaultBuffer(
adapter->d3d12_device(), total_byte_length_of_inputs,
L"WebNN_Default_Buffer_Inputs", input_buffer));
}
}
  // Calculate the total byte length of the output array buffers needed to
  // create the output buffer and readback buffer, and record the aligned
  // D3D12_RANGE for each output.
std::optional<AlignedByteLength<std::string>> aligned_byte_length_of_outputs =
CalculateAlignedByteLengthFromDescriptors(
compute_resource_info.output_names_to_descriptors);
if (!aligned_byte_length_of_outputs) {
LOG(ERROR)
<< "[WebNN] Failed to calculate the aligned byte length of outputs.";
return base::unexpected(E_INVALIDARG);
}
// Create the output buffer which will be bound for the graph execution.
size_t total_byte_length_of_outputs =
aligned_byte_length_of_outputs.value().total_byte_length;
ComPtr<ID3D12Resource> readback_buffer;
ComPtr<ID3D12Resource> output_buffer;
if (adapter->IsUMA()) {
    // For GPUs that support UMA, create a custom heap with the CPU memory pool
    // and a resource to map it. This resource can be bound as a graph output
    // for the GPU to write, and the CPU can read the output data from it after
    // GPU execution.
RETURN_UNEXPECTED_IF_FAILED(CreateCustomReadbackBuffer(
adapter->d3d12_device(), total_byte_length_of_outputs,
L"WebNN_Custom_Readback_Buffer_Outputs", output_buffer));
} else {
// Create the output buffer which will be written by GPU.
RETURN_UNEXPECTED_IF_FAILED(CreateDefaultBuffer(
adapter->d3d12_device(), total_byte_length_of_outputs,
L"WebNN_Default_Buffer_Outputs", output_buffer));
// Create the readback buffer which will be read by CPU.
RETURN_UNEXPECTED_IF_FAILED(CreateReadbackBuffer(
adapter->d3d12_device(), total_byte_length_of_outputs,
L"WebNN_ReadBack_Buffer_Outputs", readback_buffer));
}
  // Create and bind the temporary resource if the operator execution requires
  // one.
ComPtr<ID3D12Resource> temporary_buffer;
uint64_t temporary_buffer_byte_length =
execution_binding_properties.TemporaryResourceSize;
if (temporary_buffer_byte_length > 0) {
RETURN_UNEXPECTED_IF_FAILED(CreateDefaultBuffer(
adapter->d3d12_device(), temporary_buffer_byte_length,
L"WebNN_Temporary_Buffer_For_Execution", temporary_buffer));
}
// Create a command recorder which may be re-used between compute() calls.
ASSIGN_OR_RETURN(
std::unique_ptr<CommandRecorder> command_recorder,
CommandRecorder::Create(adapter->command_queue(), adapter->dml_device()));
return base::WrapUnique(new ComputeResources(
std::move(descriptor_heap),
std::move(aligned_byte_length_of_inputs.value()),
std::move(upload_buffer), std::move(input_buffer),
std::move(aligned_byte_length_of_outputs.value()),
std::move(output_buffer), std::move(readback_buffer),
temporary_buffer_byte_length, std::move(temporary_buffer),
std::move(command_recorder)));
}
// static
HRESULT GraphImplDml::RecordGraphExecution(
Adapter* adapter,
IDMLCompiledOperator* compiled_operator,
const ComputeResources* compute_resources,
const PersistentResource* persistent_resource,
const GraphBufferBindingInfo& graph_buffer_binding_info) {
TRACE_EVENT0("gpu", "dml::GraphImpl::RecordGraphExecution");
// Open the command recorder for recording the graph execution commands.
RETURN_IF_FAILED(compute_resources->command_recorder->Open());
// Create the input buffer bindings for the graph execution.
std::map<std::string, DML_BUFFER_BINDING>
graph_input_name_to_buffer_binding_map;
for (auto& [name, d3d12_range] :
compute_resources->input_aligned_byte_length.key_to_d3d12_range_map) {
auto size_in_bytes = d3d12_range.End - d3d12_range.Begin;
graph_input_name_to_buffer_binding_map[name] =
DML_BUFFER_BINDING{.Buffer = compute_resources->input_buffer.Get(),
.Offset = d3d12_range.Begin,
.SizeInBytes = size_in_bytes};
}
std::vector<DML_BINDING_DESC> input_buffer_binding_desc(
graph_buffer_binding_info.input_buffer_binding_count,
DML_BINDING_DESC{.Type = DML_BINDING_TYPE_NONE, .Desc = nullptr});
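  // Entries not overwritten below stay DML_BINDING_TYPE_NONE; the binding
  // array still needs one slot per graph input, since constants owned by DML
  // were bound at initialization rather than at execution.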
// The graph input tensors must be bound to the binding table during the
// graph execution.
for (auto& [name, buffer_binding] : graph_input_name_to_buffer_binding_map) {
// Get the graph input index with the name.
const auto graph_input_index_iterator =
graph_buffer_binding_info.graph_input_name_to_index_map.find(name);
CHECK(graph_input_index_iterator !=
graph_buffer_binding_info.graph_input_name_to_index_map.end());
uint32_t graph_input_index = graph_input_index_iterator->second;
input_buffer_binding_desc[graph_input_index] = {DML_BINDING_TYPE_BUFFER,
&buffer_binding};
}
if (compute_resources->input_aligned_byte_length.total_byte_length > 0 &&
!adapter->IsUMA()) {
UploadBufferWithBarrier(
compute_resources->command_recorder.get(),
compute_resources->input_buffer, compute_resources->upload_buffer,
compute_resources->input_aligned_byte_length.total_byte_length);
}
// Create the output buffer bindings for the graph execution.
size_t output_buffer_binding_count =
graph_buffer_binding_info.graph_output_name_to_index_map.size();
std::vector<DML_BINDING_DESC> output_buffer_binding_desc(
output_buffer_binding_count,
DML_BINDING_DESC{.Type = DML_BINDING_TYPE_NONE, .Desc = nullptr});
std::vector<DML_BUFFER_BINDING> output_buffer_binding;
output_buffer_binding.reserve(output_buffer_binding_count);
for (auto& [name, graph_output_index] :
graph_buffer_binding_info.graph_output_name_to_index_map) {
const auto graph_output_range_iterator =
compute_resources->output_aligned_byte_length.key_to_d3d12_range_map
.find(name);
CHECK(graph_output_range_iterator !=
compute_resources->output_aligned_byte_length.key_to_d3d12_range_map
.end());
const auto& d3d12_range = graph_output_range_iterator->second;
output_buffer_binding.push_back(
DML_BUFFER_BINDING{.Buffer = compute_resources->output_buffer.Get(),
.Offset = d3d12_range.Begin,
.SizeInBytes = d3d12_range.End - d3d12_range.Begin});
output_buffer_binding_desc[graph_output_index] = {
DML_BINDING_TYPE_BUFFER, &output_buffer_binding.back()};
}
std::optional<DML_BINDING_DESC> persistent_buffer_binding_desc;
if (persistent_resource) {
persistent_buffer_binding_desc =
persistent_resource->persistent_buffer_binding_desc();
}
// Execute the graph with input, output and persistent buffer bindings.
RETURN_IF_FAILED(compute_resources->command_recorder->ExecuteOperator(
compiled_operator, compute_resources->graph_resources.descriptor_heap,
input_buffer_binding_desc, output_buffer_binding_desc,
persistent_buffer_binding_desc,
compute_resources->graph_resources.temporary_buffer_binding_desc));
if (!adapter->IsUMA()) {
ReadbackBufferWithBarrier(
compute_resources->command_recorder.get(),
compute_resources->readback_buffer, compute_resources->output_buffer,
compute_resources->output_aligned_byte_length.total_byte_length);
}
RETURN_IF_FAILED(compute_resources->command_recorder->Close());
return S_OK;
}
GraphImplDml::GraphImplDml(
scoped_refptr<Adapter> adapter,
ContextImplDml* context,
std::unique_ptr<CommandRecorder> command_recorder,
scoped_refptr<PersistentResource> persistent_resource,
ComPtr<IDMLCompiledOperator> compiled_operator,
ComputeResourceInfo compute_resource_info,
GraphBufferBindingInfo graph_buffer_binding_info,
std::unique_ptr<ComputeResources> compute_resources)
: WebNNGraphImpl(context, std::move(compute_resource_info)),
persistent_resource_(std::move(persistent_resource)),
adapter_(std::move(adapter)),
context_(context),
command_recorder_(std::move(command_recorder)),
compiled_operator_(std::move(compiled_operator)),
graph_buffer_binding_info_(std::move(graph_buffer_binding_info)),
compute_resources_(std::move(compute_resources)) {}
// Notice that it's the CommandQueue's responsibility to wait for all of the
// queued work to complete before destructing itself.
GraphImplDml::~GraphImplDml() = default;
base::expected<ComPtr<IDMLCompiledOperator>, HRESULT>
GraphImplDml::CompileOnBackgroundThread(
GraphBuilderDml graph_builder,
const bool pass_dml_execution_disable_meta_commands) {
TRACE_EVENT0("gpu", "dml::GraphImplDml::CompileOnBackgroundThread");
DML_EXECUTION_FLAGS flags = DML_EXECUTION_FLAG_NONE;
if (pass_dml_execution_disable_meta_commands) {
flags |= DML_EXECUTION_FLAG_DISABLE_META_COMMANDS;
}
return graph_builder.Compile(flags);
}
// static
HRESULT GraphImplDml::ExecuteAndWaitSyncOnBackgroundThread(
std::unique_ptr<CommandRecorder> init_command_recorder_for_npu) {
TRACE_EVENT0("gpu",
"dml::GraphImplDml::ExecuteAndWaitSyncOnBackgroundThread");
RETURN_IF_FAILED(init_command_recorder_for_npu->Execute());
RETURN_IF_FAILED(init_command_recorder_for_npu->command_queue()->WaitSync());
return S_OK;
}
// static
void GraphImplDml::OnCompilationComplete(
scoped_refptr<Adapter> adapter,
base::WeakPtr<ContextImplDml> context,
WebNNContextImpl::CreateGraphImplCallback callback,
base::flat_map<uint64_t, mojo_base::BigBuffer> constant_id_to_buffer_map,
std::unordered_map<uint64_t, uint32_t> constant_id_to_input_index_map,
GraphBufferBindingInfo graph_buffer_binding_info,
ComputeResourceInfo compute_resource_info,
base::expected<ComPtr<IDMLCompiledOperator>, HRESULT> compilation_result) {
TRACE_EVENT0("gpu", "dml::GraphImplDml::OnCompilationComplete");
if (!context) {
std::move(callback).Run(base::unexpected(CreateError(
mojom::Error::Code::kUnknownError,
"Failed to create graph because the context was destroyed.")));
return;
}
if (!compilation_result.has_value()) {
// Handle the unsupported error on NPU gracefully since it's expected.
if (adapter->IsNPU() &&
compilation_result.error() == DXGI_ERROR_UNSUPPORTED) {
LOG(ERROR)
<< "[WebNN] Failed to compile graph on NPU. Model is not supported.";
std::move(callback).Run(base::unexpected(CreateError(
mojom::Error::Code::kUnknownError,
"Failed to compile graph on NPU. Model is not supported.")));
} else {
HandleGraphCreationFailure("Failed to compile the graph.",
std::move(callback), context.get(),
compilation_result.error());
}
return;
}
ComPtr<IDMLCompiledOperator> compiled_operator =
std::move(compilation_result.value());
CommandQueue* command_queue = adapter->IsNPU()
? adapter->init_command_queue_for_npu()
: adapter->command_queue();
ASSIGN_OR_RETURN(
std::unique_ptr<CommandRecorder> initialization_command_recorder,
CommandRecorder::Create(command_queue, adapter->dml_device()),
&HandleGraphCreationFailure,
"Failed to create command recorder for graph initialization.",
std::move(callback), context.get());
HRESULT hr = initialization_command_recorder->Open();
if (FAILED(hr)) {
HandleGraphCreationFailure("Failed to open the command recorder.",
std::move(callback), context.get(), hr);
return;
}
  // Create the input resource binding for graph initialization. The number of
  // bindings must exactly match the number of inputs (including constants) of
  // the graph. Only the constant resources need to be bound; the inputs for
  // computation supply nullptr for the `Buffer` member to indicate 'no
  // binding'.
  //
  // Constant tensors specifying DML_TENSOR_FLAG_OWNED_BY_DML need their
  // resources bound in the buffer binding (DML_BUFFER_BINDING) array. The
  // index of a constant in the array is
  // DML_INPUT_GRAPH_EDGE_DESC.GraphInputIndex, which is obtained from
  // `constant_id_to_input_index_map`.
  //
  // Input tensors without the DML_TENSOR_FLAG_OWNED_BY_DML flag are expected
  // to be bound during execution, and not during initialization.
std::vector<DML_BUFFER_BINDING> input_buffer_binding(
graph_buffer_binding_info.input_buffer_binding_count,
DML_BUFFER_BINDING{.Buffer = nullptr, .Offset = 0, .SizeInBytes = 0});
if (!constant_id_to_buffer_map.empty()) {
std::optional<AlignedByteLength<uint64_t>>
aligned_byte_length_of_constants =
CalculateAlignedByteLength(constant_id_to_buffer_map);
if (!aligned_byte_length_of_constants) {
std::move(callback).Run(base::unexpected(CreateError(
mojom::Error::Code::kUnknownError,
"Failed to calculate the aligned byte length of constants.")));
return;
}
size_t total_byte_length_of_constants =
aligned_byte_length_of_constants.value().total_byte_length;
absl::variant<UploadAndDefaultBuffers, ComPtr<ID3D12Resource>>
buffer_variant;
if (adapter->IsUMA()) {
      // For GPUs that support UMA, create a custom heap with the CPU memory
      // pool and a resource to map it. The CPU writes constants into this
      // resource, which will be bound as a graph input for the GPU to read
      // during initialization.
ComPtr<ID3D12Resource> cpu_buffer;
hr = CreateCustomUploadBuffer(
adapter->d3d12_device(), total_byte_length_of_constants,
L"WebNN_Custom_Upload_Buffer_Constants", cpu_buffer);
if (FAILED(hr)) {
HandleGraphCreationFailure(
"Failed to create custom upload buffer for constants.",
std::move(callback), context.get(), hr);
return;
}
buffer_variant = std::move(cpu_buffer);
} else {
// Create the upload heap that can be written by CPU and read from GPU,
// and create a resource to map the heap.
ComPtr<ID3D12Resource> upload_buffer;
hr = CreateUploadBuffer(adapter->d3d12_device(),
total_byte_length_of_constants,
L"WebNN_Upload_Buffer_Constants", upload_buffer);
if (FAILED(hr)) {
HandleGraphCreationFailure(
"Failed to create upload buffer for constants.",
std::move(callback), context.get(), hr);
return;
}
      // Create the default heap, which is GPU-accessible only and provides no
      // CPU access, and create a resource to map the heap.
ComPtr<ID3D12Resource> default_buffer;
hr = CreateDefaultBuffer(
adapter->d3d12_device(), total_byte_length_of_constants,
L"WebNN_Default_Buffer_Constants", default_buffer);
if (FAILED(hr)) {
HandleGraphCreationFailure(
"Failed to create default input buffer for constants.",
std::move(callback), context.get(), hr);
return;
}
buffer_variant =
UploadAndDefaultBuffers{.upload_buffer = std::move(upload_buffer),
.default_buffer = std::move(default_buffer)};
}
ASSIGN_OR_RETURN(
(std::map<uint64_t, DML_BUFFER_BINDING> constant_buffer_binding),
UploadAndCreateConstantBufferBinding(
initialization_command_recorder.get(), constant_id_to_buffer_map,
aligned_byte_length_of_constants.value(),
std::move(buffer_variant)),
&HandleGraphCreationFailure, "Failed to upload constant weight data.",
std::move(callback), context.get());
// The constant tensor must be bound to the binding table during operator
// initialization, and not during execution.
for (auto& [constant_id, buffer_binding] : constant_buffer_binding) {
// Get the graph input index with the constant id.
const auto graph_input_index_iterator =
constant_id_to_input_index_map.find(constant_id);
CHECK(graph_input_index_iterator != constant_id_to_input_index_map.end());
input_buffer_binding[graph_input_index_iterator->second] =
std::move(buffer_binding);
}
}
DML_BUFFER_ARRAY_BINDING input_buffer_array_binding{
.BindingCount = base::checked_cast<uint32_t>(input_buffer_binding.size()),
.Bindings = input_buffer_binding.data()};
DML_BINDING_DESC input_buffer_binding_desc = {DML_BINDING_TYPE_BUFFER_ARRAY,
&input_buffer_array_binding};
// Create the persistent resource which is bound as output of operator
// initializer.
scoped_refptr<PersistentResource> persistent_resource;
std::optional<DML_BINDING_DESC> persistent_buffer_binding_desc;
DML_BINDING_PROPERTIES execution_binding_properties =
compiled_operator->GetBindingProperties();
uint64_t persistent_buffer_size =
execution_binding_properties.PersistentResourceSize;
if (persistent_buffer_size) {
ComPtr<ID3D12Resource> persistent_buffer;
hr = CreateDefaultBuffer(adapter->d3d12_device(), persistent_buffer_size,
L"WebNN_Default_Persistent_Buffer",
persistent_buffer);
if (FAILED(hr)) {
HandleGraphCreationFailure(
"Failed to create the default buffer for persistent resource.",
std::move(callback), context.get(), hr);
return;
}
persistent_resource = PersistentResource::Create(
persistent_buffer_size, std::move(persistent_buffer));
CHECK(persistent_resource);
persistent_buffer_binding_desc =
persistent_resource->persistent_buffer_binding_desc();
}
hr = initialization_command_recorder->InitializeOperator(
compiled_operator.Get(), input_buffer_binding_desc,
persistent_buffer_binding_desc);
if (FAILED(hr)) {
HandleGraphCreationFailure("Failed to initialize the operator.",
std::move(callback), context.get(), hr);
return;
}
hr = initialization_command_recorder->Close();
if (FAILED(hr)) {
HandleGraphCreationFailure("Failed to close the command list.",
std::move(callback), context.get(), hr);
return;
}
// TODO(crbug.com/344921705): Move other graph initialization tasks to the
// background thread: records the graph initialization onto the command list,
// binds all required resources and closes the command list.
if (adapter->IsNPU()) {
adapter->init_task_runner_for_npu()->PostTaskAndReplyWithResult(
FROM_HERE,
base::BindOnce(&GraphImplDml::ExecuteAndWaitSyncOnBackgroundThread,
std::move(initialization_command_recorder)),
base::BindOnce(
&GraphImplDml::OnInitializationComplete, std::move(adapter),
std::move(context), std::move(persistent_resource),
std::move(compiled_operator), std::move(compute_resource_info),
std::move(graph_buffer_binding_info), std::move(callback)));
return;
}
hr = initialization_command_recorder->Execute();
if (FAILED(hr)) {
HandleGraphCreationFailure("Failed to execute the command list.",
std::move(callback), context.get(), hr);
return;
}
// Since the initialization command recorder has given all of the resources
// needed for graph initialization to the command queue to hold onto until
// they're no longer needed, it won't need to be passed to
// `OnInitializationComplete()`.
initialization_command_recorder->command_queue()->WaitAsync(base::BindOnce(
&GraphImplDml::OnInitializationComplete, std::move(adapter),
std::move(context), std::move(persistent_resource),
std::move(compiled_operator), std::move(compute_resource_info),
std::move(graph_buffer_binding_info), std::move(callback)));
}
// static
base::expected<std::unique_ptr<GraphImplDml::ComputeResources>, HRESULT>
GraphImplDml::RecordGraphExecutionOnBackgroundThread(
scoped_refptr<Adapter> adapter,
scoped_refptr<PersistentResource> persistent_resource,
ComPtr<IDMLCompiledOperator> compiled_operator,
std::unique_ptr<ComputeResources> compute_resources,
GraphBufferBindingInfo graph_buffer_binding_info) {
TRACE_EVENT0("gpu",
"dml::GraphImplDml::RecordGraphExecutionOnBackgroundThread");
RETURN_UNEXPECTED_IF_FAILED(RecordGraphExecution(
adapter.get(), compiled_operator.Get(), compute_resources.get(),
persistent_resource.get(), graph_buffer_binding_info));
return compute_resources;
}
// static
void GraphImplDml::CreateWebNNGraphImpl(
scoped_refptr<Adapter> adapter,
base::WeakPtr<ContextImplDml> context,
scoped_refptr<PersistentResource> persistent_resource,
ComPtr<IDMLCompiledOperator> compiled_operator,
ComputeResourceInfo compute_resource_info,
GraphBufferBindingInfo graph_buffer_binding_info,
WebNNContextImpl::CreateGraphImplCallback callback,
base::expected<std::unique_ptr<ComputeResources>, HRESULT>
recording_result) {
if (!context) {
std::move(callback).Run(base::unexpected(CreateError(
mojom::Error::Code::kUnknownError,
"Failed to create graph because the context was destroyed.")));
return;
}
if (!recording_result.has_value()) {
HandleGraphCreationFailure(
"Failed to record commands and bind resources for execution.",
std::move(callback), context.get(), recording_result.error());
return;
}
std::unique_ptr<ComputeResources> compute_resources =
std::move(recording_result.value());
// Create a new command recorder and pass it to `GraphImplDml` for
// `dispatch()`. For `compute()`, a separate command recorder is created by
// `AllocateComputeResources()` and stored in `compute_resources`.
ASSIGN_OR_RETURN(
std::unique_ptr<CommandRecorder> command_recorder_for_dispatch,
CommandRecorder::Create(adapter->command_queue(), adapter->dml_device()),
&HandleGraphCreationFailure,
"Failed to create the command recorder for dispatch.",
std::move(callback), context.get());
// The receiver bound to GraphImplDml.
std::move(callback).Run(base::WrapUnique(new GraphImplDml(
std::move(adapter), context.get(),
std::move(command_recorder_for_dispatch), std::move(persistent_resource),
std::move(compiled_operator), std::move(compute_resource_info),
std::move(graph_buffer_binding_info), std::move(compute_resources))));
}
// static
void GraphImplDml::OnInitializationComplete(
scoped_refptr<Adapter> adapter,
base::WeakPtr<ContextImplDml> context,
scoped_refptr<PersistentResource> persistent_resource,
ComPtr<IDMLCompiledOperator> compiled_operator,
ComputeResourceInfo compute_resource_info,
GraphBufferBindingInfo graph_buffer_binding_info,
WebNNContextImpl::CreateGraphImplCallback callback,
HRESULT hr) {
TRACE_EVENT0("gpu", "dml::GraphImplDml::OnInitializationComplete");
if (!context) {
std::move(callback).Run(base::unexpected(CreateError(
mojom::Error::Code::kUnknownError,
"Failed to create graph because the context was destroyed.")));
return;
}
if (FAILED(hr)) {
HandleGraphCreationFailure(
"Failed to wait for the initialization to complete.",
std::move(callback), context.get(), hr);
return;
}
base::expected<std::unique_ptr<ComputeResources>, HRESULT>
compute_resources_allocation_result = AllocateComputeResources(
adapter.get(), compiled_operator.Get(), compute_resource_info);
if (!compute_resources_allocation_result.has_value()) {
HandleGraphCreationFailure("Failed to allocate compute resource.",
std::move(callback), context.get(),
compute_resources_allocation_result.error());
return;
}
std::unique_ptr<ComputeResources> compute_resources =
std::move(compute_resources_allocation_result.value());
CHECK(compute_resources);
if (adapter->IsNPU()) {
base::ThreadPool::PostTaskAndReplyWithResult(
FROM_HERE,
{base::TaskPriority::USER_BLOCKING,
base::TaskShutdownBehavior::CONTINUE_ON_SHUTDOWN},
base::BindOnce(&GraphImplDml::RecordGraphExecutionOnBackgroundThread,
adapter, persistent_resource, compiled_operator,
std::move(compute_resources), graph_buffer_binding_info),
base::BindOnce(&GraphImplDml::CreateWebNNGraphImpl, adapter,
std::move(context), persistent_resource,
compiled_operator, std::move(compute_resource_info),
graph_buffer_binding_info, std::move(callback)));
return;
}
hr = RecordGraphExecution(adapter.get(), compiled_operator.Get(),
compute_resources.get(), persistent_resource.get(),
graph_buffer_binding_info);
if (FAILED(hr)) {
HandleGraphCreationFailure(
"Failed to record commands and bind resources for execution.",
std::move(callback), context.get(), hr);
return;
}
CreateWebNNGraphImpl(
std::move(adapter), std::move(context), std::move(persistent_resource),
std::move(compiled_operator), std::move(compute_resource_info),
std::move(graph_buffer_binding_info), std::move(callback),
std::move(compute_resources));
}
// static
base::expected<void, mojom::ErrorPtr> GraphImplDml::CreateAndBuildInternal(
const ContextProperties& context_properties,
scoped_refptr<Adapter> adapter,
mojom::GraphInfoPtr& graph_info,
GraphBuilderDml& graph_builder,
std::unordered_map<uint64_t, uint32_t>& constant_id_to_input_index_map,
GraphBufferBindingInfo& graph_buffer_binding_info) {
IdToNodeOutputMap id_to_node_output_map;
const IdToOperandMap& id_to_operand_map = graph_info->id_to_operand_map;
// Add inputs.
for (auto& input_id : graph_info->input_operands) {
auto graph_input_index = CreateInputNode(
id_to_operand_map, input_id, graph_builder, id_to_node_output_map);
const OperandPtr& operand = id_to_operand_map.at(input_id);
CHECK(operand);
graph_buffer_binding_info
.graph_input_name_to_index_map[operand->name.value()] =
graph_input_index;
}
  // Constant operands in the WebNN graph are also treated as input nodes in
  // the graph desc.
for (auto& [constant_id, _] : graph_info->constant_id_to_buffer_map) {
auto graph_input_index = CreateInputNode(
id_to_operand_map, constant_id, graph_builder, id_to_node_output_map);
constant_id_to_input_index_map[constant_id] = graph_input_index;
}
// Find out the next operand id that can be used as the key in
// `id_to_operand_map`. It might be used for inserting new operands into maps
// when adding operations.
uint64_t next_operand_id = 0;
base::ranges::for_each(
id_to_operand_map, [&next_operand_id](auto& key_value) {
next_operand_id = std::max(next_operand_id, key_value.first + 1);
});
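  // For example, if `id_to_operand_map` holds the keys {0, 1, 5},
  // `next_operand_id` becomes 6.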
// Fuse the operations in `mojom::GraphInfo` wherever possible to optimize the
// graph's compute performance.
//
// 1. Go through all operations from the last one to the first one, record the
// output edges count from each operation.
// 2. Find the fusible operations and record them in `GraphFusionInfo`. For
// example, activations (such as relu/sigmoid) that can be fused into
// preceding operations that can support activation fusion (such as
  // conv2d/batch_norm), or transposes that can be fused into a following
  // matmul operation.
// 3. Go through all operations again, create corresponding DirectML operators
// and add them into the final DirectML graph. During the process, the
// `GraphFusionInfo` will be passed to DirectML operator creation methods to
// configure the operator fusion and re-wire the input/output edges. The fused
// operations will be skipped and no DirectML operators will be created for
// them.
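  // For example, a conv2d whose output feeds only a standalone relu can be
  // emitted as a single DirectML convolution with a fused activation, and the
  // relu itself is skipped.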
GraphFusionInfo graph_fusion_info = GetGraphFusionInfo(graph_info);
// Add operations.
for (auto& operation : graph_info->operations) {
// Skip the operations which are fused into another operation.
if (graph_fusion_info.fusible_operations_set.contains(operation.get())) {
continue;
}
    // For operators that deal with the DML API, operator creation may fail.
    // Use `mojom::ErrorPtr` to hold the resulting error message.
base::expected<void, mojom::ErrorPtr> create_operator_result;
switch (operation->which()) {
case Operation::Tag::kArgMinMax: {
CreateOperatorNodeForArgMinMax(id_to_operand_map,
operation->get_arg_min_max(),
graph_builder, id_to_node_output_map);
break;
}
case mojom::Operation::Tag::kBatchNormalization: {
CreateOperatorNodeForBatchNormalization(
operation.get(),
graph_fusion_info.operation_to_fusible_standalone_activation_map,
graph_info, graph_builder, id_to_node_output_map,
constant_id_to_input_index_map, next_operand_id);
break;
}
case Operation::Tag::kClamp: {
CreateOperatorNodeForClamp(context_properties, id_to_operand_map,
operation->get_clamp(), graph_builder,
id_to_node_output_map);
break;
}
case Operation::Tag::kConcat: {
CreateOperatorNodeForConcat(id_to_operand_map, operation->get_concat(),
graph_builder, id_to_node_output_map);
break;
}
case Operation::Tag::kConv2d: {
CreateOperatorNodeForConv2d(
id_to_operand_map, operation.get(),
graph_fusion_info.operation_to_fusible_standalone_activation_map,
graph_builder, id_to_node_output_map);
break;
}
case mojom::Operation::Tag::kElementWiseBinary: {
CreateOperatorNodeForBinary(
context_properties, id_to_operand_map, operation.get(),
graph_fusion_info.operation_to_fusible_standalone_activation_map,
graph_builder, id_to_node_output_map);
break;
}
case Operation::Tag::kElu: {
CreateOperatorNodeForElu(id_to_operand_map, operation->get_elu(),
graph_builder, id_to_node_output_map);
break;
}
case mojom::Operation::Tag::kElementWiseUnary: {
CreateOperatorNodeForElementWiseUnary(
context_properties, id_to_operand_map,
operation->get_element_wise_unary(), graph_builder,
id_to_node_output_map);
break;
}
case Operation::Tag::kExpand: {
CreateOperatorNodeForExpand(context_properties, id_to_operand_map,
operation->get_expand(), graph_builder,
id_to_node_output_map);
break;
}
case mojom::Operation::Tag::kGather: {
create_operator_result = CreateOperatorNodeForGather(
context_properties, id_to_operand_map, operation->get_gather(),
graph_builder, id_to_node_output_map);
break;
}
case mojom::Operation::Tag::kGatherElements: {
CreateOperatorNodeForGatherElements(
context_properties, id_to_operand_map,
operation->get_gather_elements(), graph_builder,
id_to_node_output_map);
break;
}
case mojom::Operation::Tag::kGelu: {
CreateOperatorNodeForGelu(
adapter.get(), id_to_operand_map, operation->get_gelu(), graph_info,
graph_builder, id_to_node_output_map,
constant_id_to_input_index_map, next_operand_id);
break;
}
case mojom::Operation::Tag::kGemm: {
CreateOperatorNodeForGemm(
context_properties, id_to_operand_map, operation.get(),
graph_fusion_info.operation_to_fusible_standalone_activation_map,
graph_builder, id_to_node_output_map);
break;
}
case mojom::Operation::Tag::kGru: {
create_operator_result = CreateOperatorNodeForGru<mojom::GruPtr>(
id_to_operand_map, operation->get_gru(), graph_info, graph_builder,
id_to_node_output_map, constant_id_to_input_index_map,
next_operand_id);
break;
}
case mojom::Operation::Tag::kGruCell: {
create_operator_result = CreateOperatorNodeForGru<mojom::GruCellPtr>(
id_to_operand_map, operation->get_gru_cell(), graph_info,
graph_builder, id_to_node_output_map,
constant_id_to_input_index_map, next_operand_id);
break;
}
case mojom::Operation::Tag::kHardSigmoid: {
CreateOperatorNodeForHardSigmoid(id_to_operand_map,
operation->get_hard_sigmoid(),
graph_builder, id_to_node_output_map);
break;
}
case mojom::Operation::Tag::kHardSwish: {
CreateOperatorNodeForHardSwish(adapter.get(), id_to_operand_map,
operation->get_hard_swish(),
graph_builder, id_to_node_output_map);
break;
}
case Operation::Tag::kInstanceNormalization: {
// The axes along which to calculate the Mean and Variance.
std::array<uint32_t, 2> mean_variance_axes;
std::array<uint32_t, 1> scale_bias_broadcast_axes;
const auto& instance_normalization =
operation->get_instance_normalization();
switch (instance_normalization->layout) {
case mojom::InputOperandLayout::kChannelsFirst: {
mean_variance_axes = {2, 3};
scale_bias_broadcast_axes = {1};
break;
}
case mojom::InputOperandLayout::kChannelsLast:
mean_variance_axes = {1, 2};
scale_bias_broadcast_axes = {3};
break;
}
create_operator_result = CreateOperatorNodeForMeanVarianceNormalization(
instance_normalization, operation.get(),
graph_fusion_info.operation_to_fusible_standalone_activation_map,
graph_info, graph_builder, id_to_node_output_map,
constant_id_to_input_index_map, next_operand_id, mean_variance_axes,
scale_bias_broadcast_axes, Operation::Tag::kInstanceNormalization);
break;
}
case Operation::Tag::kLayerNormalization: {
const auto& layer_normalization = operation->get_layer_normalization();
const auto axes = layer_normalization->axes;
create_operator_result = CreateOperatorNodeForMeanVarianceNormalization(
layer_normalization, operation.get(),
graph_fusion_info.operation_to_fusible_standalone_activation_map,
graph_info, graph_builder, id_to_node_output_map,
constant_id_to_input_index_map, next_operand_id, axes, axes,
Operation::Tag::kLayerNormalization);
break;
}
case Operation::Tag::kLeakyRelu: {
CreateOperatorNodeForLeakyRelu(id_to_operand_map,
operation->get_leaky_relu(),
graph_builder, id_to_node_output_map);
break;
}
case Operation::Tag::kLinear: {
CreateOperatorNodeForLinear(context_properties, id_to_operand_map,
operation->get_linear(), graph_builder,
id_to_node_output_map);
break;
}
case Operation::Tag::kLstm: {
create_operator_result = CreateOperatorNodeForLstm<mojom::Lstm>(
*operation->get_lstm(), graph_info, graph_builder,
id_to_node_output_map, constant_id_to_input_index_map,
next_operand_id);
break;
}
case Operation::Tag::kLstmCell: {
create_operator_result = CreateOperatorNodeForLstm<mojom::LstmCell>(
*operation->get_lstm_cell(), graph_info, graph_builder,
id_to_node_output_map, constant_id_to_input_index_map,
next_operand_id);
break;
}
case mojom::Operation::Tag::kMatmul: {
create_operator_result = CreateOperatorNodeForMatmul(
context_properties, id_to_operand_map, operation.get(),
graph_fusion_info.operation_to_fusible_standalone_activation_map,
graph_fusion_info.output_id_to_fusible_transpose_map, graph_builder,
id_to_node_output_map);
break;
}
case Operation::Tag::kPad: {
CreateOperatorNodeForPad(context_properties, id_to_operand_map,
operation->get_pad(), graph_builder,
id_to_node_output_map);
break;
}
case Operation::Tag::kPool2d: {
create_operator_result = CreateOperatorNodeForPool2d(
context_properties, id_to_operand_map, operation->get_pool2d(),
graph_builder, id_to_node_output_map);
break;
}
case Operation::Tag::kPrelu: {
CreateOperatorNodeForPrelu(context_properties, id_to_operand_map,
operation->get_prelu(), graph_builder,
id_to_node_output_map);
break;
}
case Operation::Tag::kReduce: {
CreateOperatorNodeForReduce(context_properties, id_to_operand_map,
operation->get_reduce(), graph_builder,
id_to_node_output_map);
break;
}
case Operation::Tag::kRelu: {
CreateOperatorNodeForUnary<DML_ACTIVATION_RELU_OPERATOR_DESC,
DML_OPERATOR_ACTIVATION_RELU>(
id_to_operand_map, operation->get_relu(), graph_builder,
id_to_node_output_map);
break;
}
case Operation::Tag::kResample2d: {
CreateOperatorNodeForResample2d(context_properties, id_to_operand_map,
operation->get_resample2d(),
graph_builder, id_to_node_output_map);
break;
}
case Operation::Tag::kReshape: {
CreateOperatorNodeForReshape(context_properties, id_to_operand_map,
operation->get_reshape(), graph_builder,
id_to_node_output_map);
break;
}
case Operation::Tag::kSigmoid: {
CreateOperatorNodeForUnary<DML_ACTIVATION_SIGMOID_OPERATOR_DESC,
DML_OPERATOR_ACTIVATION_SIGMOID>(
id_to_operand_map, operation->get_sigmoid(), graph_builder,
id_to_node_output_map);
break;
}
case Operation::Tag::kSlice: {
CreateOperatorNodeForSlice(id_to_operand_map, operation->get_slice(),
graph_builder, id_to_node_output_map);
break;
}
case Operation::Tag::kSoftmax: {
create_operator_result = CreateOperatorNodeForSoftmax(
adapter.get(), id_to_operand_map, operation->get_softmax(),
graph_builder, id_to_node_output_map);
break;
}
case mojom::Operation::Tag::kSoftplus: {
CreateOperatorNodeForSoftplus(id_to_operand_map,
operation->get_softplus(), graph_builder,
id_to_node_output_map);
break;
}
case Operation::Tag::kSoftsign: {
CreateOperatorNodeForUnary<DML_ACTIVATION_SOFTSIGN_OPERATOR_DESC,
DML_OPERATOR_ACTIVATION_SOFTSIGN>(
id_to_operand_map, operation->get_softsign(), graph_builder,
id_to_node_output_map);
break;
}
case mojom::Operation::Tag::kSplit: {
CreateOperatorNodeForSplit(id_to_operand_map, operation->get_split(),
graph_builder, id_to_node_output_map);
break;
}
case Operation::Tag::kTanh: {
CreateOperatorNodeForUnary<DML_ACTIVATION_TANH_OPERATOR_DESC,
DML_OPERATOR_ACTIVATION_TANH>(
id_to_operand_map, operation->get_tanh(), graph_builder,
id_to_node_output_map);
break;
}
case Operation::Tag::kTranspose: {
CreateOperatorNodeForTranspose(context_properties, id_to_operand_map,
operation->get_transpose(),
graph_builder, id_to_node_output_map);
break;
}
case mojom::Operation::Tag::kTriangular: {
create_operator_result = CreateOperatorNodeForTriangular(
context_properties, adapter.get(), operation->get_triangular(),
graph_info, graph_builder, id_to_node_output_map,
constant_id_to_input_index_map, next_operand_id);
break;
}
case Operation::Tag::kWhere: {
CreateOperatorNodeForWhere(id_to_operand_map, operation->get_where(),
graph_builder, id_to_node_output_map);
break;
}
default: {
std::string error_message = NotSupportedOperatorError(*operation);
create_operator_result = base::unexpected(CreateError(
mojom::Error::Code::kNotSupportedError, std::move(error_message)));
}
}
if (!create_operator_result.has_value()) {
return create_operator_result;
}
}
for (auto& output_id : graph_info->output_operands) {
const auto output_iterator = id_to_node_output_map.find(output_id);
CHECK(output_iterator != id_to_node_output_map.end());
const NodeOutput* output = output_iterator->second;
CHECK(output);
    // TODO: A DML graph's output tensor may have adjusted strides rather than
    // the default strides calculated from its dimensions. For example,
    // dimensions [1,2,3,4] should have default strides [24,12,4,1] according
    // to
    // https://docs.microsoft.com/en-us/windows/win32/direct3d12/dml-helper-functions#calculatestrides,
    // but the strides may be adjusted to support some ops such as transpose.
    // Append an identity operator to consume the adjusted strides and ensure
    // a correct output result.
    // Appending an identity operator (DML_OPERATOR_ELEMENT_WISE_IDENTITY),
    // which effectively copies the input tensor to the output tensor, also
    // avoids directly using a graph input as an output.
if (output->GetNode().GetType() == Node::Type::kInput) {
output = AppendIdentityNode(graph_builder, output);
}
std::string name = id_to_operand_map.at(output_id)->name.value();
graph_buffer_binding_info.graph_output_name_to_index_map[std::move(name)] =
graph_builder.CreateOutputEdge(output);
}
graph_buffer_binding_info.input_buffer_binding_count =
constant_id_to_input_index_map.size() +
graph_buffer_binding_info.graph_input_name_to_index_map.size();
return base::ok();
}
// static
void GraphImplDml::CreateAndBuild(
scoped_refptr<Adapter> adapter,
base::WeakPtr<ContextImplDml> context,
mojom::GraphInfoPtr graph_info,
ComputeResourceInfo compute_resource_info,
WebNNContextImpl::CreateGraphImplCallback callback,
const bool pass_dml_execution_disable_meta_commands) {
TRACE_EVENT0("gpu", "dml::GraphImplDml::CreateAndBuild");
GraphBuilderDml graph_builder(adapter->dml_device());
std::unordered_map<uint64_t, uint32_t> constant_id_to_input_index_map;
GraphBufferBindingInfo graph_buffer_binding_info;
base::expected<void, mojom::ErrorPtr> create_operator_result =
GraphImplDml::CreateAndBuildInternal(
context->properties(), adapter, graph_info, graph_builder,
constant_id_to_input_index_map, graph_buffer_binding_info);
// TODO(crbug.com/349649099): Handle context lost for operator creation
// failures.
if (!create_operator_result.has_value()) {
std::move(callback).Run(
base::unexpected(std::move(create_operator_result.error())));
return;
}
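  // Compiling the graph may take a long time, so run the compilation on a
  // background thread and reply with `OnCompilationComplete` on the current
  // sequence.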
base::ThreadPool::PostTaskAndReplyWithResult(
FROM_HERE,
{base::TaskPriority::USER_BLOCKING,
base::TaskShutdownBehavior::CONTINUE_ON_SHUTDOWN},
base::BindOnce(&GraphImplDml::CompileOnBackgroundThread,
std::move(graph_builder),
pass_dml_execution_disable_meta_commands),
base::BindOnce(&GraphImplDml::OnCompilationComplete, std::move(adapter),
std::move(context), std::move(callback),
std::move(graph_info->constant_id_to_buffer_map),
std::move(constant_id_to_input_index_map),
std::move(graph_buffer_binding_info),
std::move(compute_resource_info)));
}
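// Drops the cached compute resources, reports the failure to `callback`, and
// then treats the error as a lost context.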
void GraphImplDml::HandleComputationFailure(
const std::string& error_message,
HRESULT hr,
mojom::WebNNGraph::ComputeCallback callback) {
compute_resources_.reset();
std::move(callback).Run(ComputeResult::NewError(
CreateError(mojom::Error::Code::kUnknownError, error_message)));
context_->HandleContextLostOrCrash(error_message, hr);
}
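// Drops the command recorder and the buffers recorded for the previous
// dispatch, then treats the error as a lost context.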
void GraphImplDml::HandleDispatchFailure(std::string_view error_message,
HRESULT hr) {
command_recorder_.reset();
// Clear out previous buffers recorded for dispatch() so we don't mistakenly
// skip recording on failure.
previous_input_buffers_.clear();
previous_output_buffers_.clear();
context_->HandleContextLostOrCrash(error_message, hr);
}
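// Uploads the named inputs into the input (or upload) buffer, submits the
// recorded command list for execution, and waits asynchronously for the GPU
// to complete before reading back the outputs in `OnComputationComplete`.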
void GraphImplDml::ExecuteAndWaitAsync(
scoped_refptr<Adapter> adapter,
base::flat_map<std::string, mojo_base::BigBuffer> named_inputs,
mojom::WebNNGraph::ComputeCallback callback,
base::expected<std::unique_ptr<ComputeResources>, HRESULT>
recording_result) {
if (!recording_result.has_value()) {
HandleComputationFailure(
"Failed to record commands and bind resources for execution.",
std::move(recording_result.error()), std::move(callback));
return;
}
std::unique_ptr<ComputeResources> compute_resources =
std::move(recording_result.value());
HRESULT hr = S_OK;
if (compute_resources->input_aligned_byte_length.total_byte_length > 0) {
    // For GPUs that support UMA, the `input_buffer` is allocated in a custom
    // heap which the CPU can map and write efficiently.
auto* buffer = adapter->IsUMA() ? compute_resources->input_buffer.Get()
: compute_resources->upload_buffer.Get();
hr = MapAndCopyInputDataToBuffer(
named_inputs,
compute_resources->input_aligned_byte_length.key_to_d3d12_range_map,
buffer);
if (FAILED(hr)) {
HandleComputationFailure(
"Failed to copy the data from named inputs to the buffer.", hr,
std::move(callback));
return;
}
}
// Submit the command list for execution.
hr = compute_resources->command_recorder->Execute();
if (FAILED(hr)) {
HandleComputationFailure("Failed to execute the command list.", hr,
std::move(callback));
return;
}
compute_resources->command_recorder->command_queue()->WaitAsync(
base::BindOnce(&GraphImplDml::OnComputationComplete,
weak_factory_.GetWeakPtr(), std::move(callback),
std::move(compute_resources)));
}
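// Executes the compiled graph for compute(): reuses the cached compute
// resources when available, otherwise allocates new ones and re-records the
// command list before execution.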
void GraphImplDml::ComputeImpl(
base::flat_map<std::string, mojo_base::BigBuffer> named_inputs,
mojom::WebNNGraph::ComputeCallback callback) {
TRACE_EVENT0("gpu", "dml::GraphImplDml::ComputeImpl");
  // Indicates whether we need to record commands and bind resources again for
  // the graph execution by calling the `RecordGraphExecution` method. It must
  // be set to true if `compute_resources_` is not available during the graph
  // execution.
bool is_command_recording_needed = false;
// Use the existing compute resource if it is available, otherwise allocate
// a new one.
std::unique_ptr<ComputeResources> compute_resources =
std::move(compute_resources_);
if (!compute_resources) {
base::expected<std::unique_ptr<ComputeResources>, HRESULT>
compute_resources_allocation_result = AllocateComputeResources(
adapter_.get(), compiled_operator_.Get(), compute_resource_info());
if (!compute_resources_allocation_result.has_value()) {
HandleComputationFailure(
"Failed to allocate compute resource.",
std::move(compute_resources_allocation_result.error()),
std::move(callback));
return;
}
compute_resources = std::move(compute_resources_allocation_result.value());
is_command_recording_needed = true;
}
CHECK(compute_resources);
if (is_command_recording_needed) {
if (adapter_->IsNPU()) {
base::ThreadPool::PostTaskAndReplyWithResult(
FROM_HERE,
{base::TaskPriority::USER_BLOCKING,
base::TaskShutdownBehavior::CONTINUE_ON_SHUTDOWN},
base::BindOnce(&GraphImplDml::RecordGraphExecutionOnBackgroundThread,
adapter_, persistent_resource_, compiled_operator_,
std::move(compute_resources),
graph_buffer_binding_info_),
base::BindOnce(&GraphImplDml::ExecuteAndWaitAsync,
weak_factory_.GetWeakPtr(), adapter_,
std::move(named_inputs), std::move(callback)));
return;
}
HRESULT hr = RecordGraphExecution(
adapter_.get(), compiled_operator_.Get(), compute_resources.get(),
persistent_resource_.get(), graph_buffer_binding_info_);
if (FAILED(hr)) {
HandleComputationFailure(
"Failed to record and bind resources for execution.", hr,
std::move(callback));
return;
}
}
ExecuteAndWaitAsync(adapter_, std::move(named_inputs), std::move(callback),
std::move(compute_resources));
}
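// Called when the GPU completes the graph execution: maps the output (or
// readback) buffer, copies each output into the named outputs, and recycles
// the compute resources when possible.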
void GraphImplDml::OnComputationComplete(
mojom::WebNNGraph::ComputeCallback callback,
std::unique_ptr<ComputeResources> compute_resources,
HRESULT hr) {
TRACE_EVENT0("gpu", "dml::GraphImplDml::OnComputationComplete");
if (FAILED(hr)) {
HandleComputationFailure("Failed to wait for the computation to complete.",
hr, std::move(callback));
return;
}
  // Map the entire buffer to read back the output data one by one with byte
  // offsets. For GPUs that support UMA, the `output_buffer` is allocated in a
  // custom heap that the CPU can map and read efficiently.
void* mapped_buffer = nullptr;
auto* buffer_to_map = adapter_->IsUMA()
? compute_resources->output_buffer.Get()
: compute_resources->readback_buffer.Get();
CHECK(buffer_to_map);
hr = buffer_to_map->Map(0, nullptr, &mapped_buffer);
if (FAILED(hr)) {
HandleComputationFailure("Failed to map the buffer for outputs.", hr,
std::move(callback));
return;
}
const std::map<std::string, D3D12_RANGE>&
graph_output_name_to_d3d12_range_map =
compute_resources->output_aligned_byte_length.key_to_d3d12_range_map;
base::flat_map<std::string, mojo_base::BigBuffer> named_outputs;
named_outputs.reserve(graph_output_name_to_d3d12_range_map.size());
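  // Copy each output from the mapped buffer into a `BigBuffer`, using the
  // byte range recorded for that output's name.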
for (auto& [name, d3d12_range] : graph_output_name_to_d3d12_range_map) {
named_outputs[name] = mojo_base::BigBuffer(base::make_span(
static_cast<const uint8_t*>(mapped_buffer) + d3d12_range.Begin,
compute_resource_info()
.output_names_to_descriptors.at(name)
.PackedByteLength()));
}
buffer_to_map->Unmap(0, nullptr);
  // If another compute resource is already available, release this one.
  // Otherwise, recycle this compute resource for the next call.
if (!compute_resources_) {
compute_resources_ = std::move(compute_resources);
}
std::move(callback).Run(
ComputeResult::NewNamedOutputs(std::move(named_outputs)));
}
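// Executes the compiled graph for dispatch() against MLBuffer-backed inputs
// and outputs, re-recording the command list only when the bindings or
// resources have changed.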
void GraphImplDml::DispatchImpl(
const base::flat_map<std::string_view, WebNNBufferImpl*>& named_inputs,
const base::flat_map<std::string_view, WebNNBufferImpl*>& named_outputs) {
TRACE_EVENT0("gpu", "dml::GraphImplDml::DispatchImpl");
  // Indicates whether we need to record commands and bind resources again. It
  // must be set to true if either the I/O buffers change or
  // `graph_resources_` is not available during the graph execution.
bool is_command_recording_needed = false;
  // TODO(crbug.com/40278771): Avoid re-binding all buffers.
if (!IsDispatchBindingValid(named_inputs, previous_input_buffers_)) {
is_command_recording_needed = true;
}
if (!IsDispatchBindingValid(named_outputs, previous_output_buffers_)) {
is_command_recording_needed = true;
}
if (!command_recorder_) {
ASSIGN_OR_RETURN(command_recorder_,
CommandRecorder::Create(adapter_->command_queue(),
adapter_->dml_device()),
&GraphImplDml::HandleDispatchFailure, this,
"Failed to create the command recorder.");
is_command_recording_needed = true;
}
// Use the existing graph resource if it is available, otherwise allocate
// a new one.
// TODO(crbug.com/40278771): pre-allocate graph resources in graph
// initialization.
std::unique_ptr<GraphResources> graph_resources = std::move(graph_resources_);
if (!graph_resources) {
base::expected<std::unique_ptr<GraphResources>, HRESULT> result =
AllocateGraphResources(adapter_.get(), compiled_operator_.Get());
if (!result.has_value()) {
HandleDispatchFailure("Failed to allocate graph resources.",
std::move(result.error()));
return;
}
graph_resources = std::move(result.value());
is_command_recording_needed = true;
}
CHECK(graph_resources);
HRESULT hr = S_OK;
if (is_command_recording_needed) {
hr = command_recorder_->Open();
if (FAILED(hr)) {
HandleDispatchFailure("Failed to open the command recorder.", hr);
return;
}
// Create the MLBuffer input bindings needed for graph execution.
std::vector<DML_BUFFER_BINDING> graph_input_buffer_bindings(
graph_buffer_binding_info_.input_buffer_binding_count,
DML_BUFFER_BINDING{.Buffer = nullptr, .Offset = 0, .SizeInBytes = 0});
previous_input_buffers_.reserve(named_inputs.size());
// The graph input tensors must be bound to the binding table during the
// graph execution.
std::vector<DML_BINDING_DESC> input_buffer_binding_desc(
graph_buffer_binding_info_.input_buffer_binding_count,
DML_BINDING_DESC{.Type = DML_BINDING_TYPE_NONE, .Desc = nullptr});
for (auto& [name, input_buffer] : named_inputs) {
BufferImplDml* input_buffer_impl =
static_cast<BufferImplDml*>(input_buffer);
// Get the graph input index for the name.
const size_t graph_input_index =
graph_buffer_binding_info_.graph_input_name_to_index_map.at(
std::string(name));
graph_input_buffer_bindings[graph_input_index] = DML_BUFFER_BINDING{
.Buffer = input_buffer_impl->buffer(),
.Offset = 0,
.SizeInBytes = input_buffer_impl->PackedByteLength()};
input_buffer_binding_desc[graph_input_index] = {
DML_BINDING_TYPE_BUFFER,
&graph_input_buffer_bindings[graph_input_index]};
previous_input_buffers_[std::string(name)] =
input_buffer_impl->GetWeakPtr();
command_recorder_->OnBufferAccessed(input_buffer_impl);
}
// TODO(crbug.com/40278771): consider pre-computing the output binding
// count.
const size_t output_buffer_binding_count =
graph_buffer_binding_info_.graph_output_name_to_index_map.size();
// Create the MLBuffer output bindings needed for graph execution.
std::vector<DML_BUFFER_BINDING> graph_output_buffer_bindings(
output_buffer_binding_count,
DML_BUFFER_BINDING{.Buffer = nullptr, .Offset = 0, .SizeInBytes = 0});
// The graph output tensors must be bound to the binding table during the
// graph execution.
std::vector<DML_BINDING_DESC> output_buffer_binding_desc(
output_buffer_binding_count,
DML_BINDING_DESC{.Type = DML_BINDING_TYPE_NONE, .Desc = nullptr});
previous_output_buffers_.reserve(named_outputs.size());
for (auto& [name, output_buffer] : named_outputs) {
BufferImplDml* output_buffer_impl =
static_cast<BufferImplDml*>(output_buffer);
      // Get the graph output index for the name.
const size_t graph_output_index =
graph_buffer_binding_info_.graph_output_name_to_index_map.at(
std::string(name));
graph_output_buffer_bindings[graph_output_index] = DML_BUFFER_BINDING{
.Buffer = output_buffer_impl->buffer(),
.Offset = 0,
.SizeInBytes = output_buffer_impl->PackedByteLength()};
output_buffer_binding_desc[graph_output_index] = {
DML_BINDING_TYPE_BUFFER,
&graph_output_buffer_bindings[graph_output_index]};
previous_output_buffers_[std::string(name)] =
output_buffer_impl->GetWeakPtr();
      // Only output buffers can be modified during execution.
command_recorder_->OnBufferAccessed(output_buffer_impl);
}
std::optional<DML_BINDING_DESC> persistent_buffer_binding_desc;
if (persistent_resource_) {
persistent_buffer_binding_desc =
persistent_resource_->persistent_buffer_binding_desc();
}
// Execute the graph with input, output, temporary, and persistent bindings.
hr = command_recorder_->ExecuteOperator(
compiled_operator_.Get(), graph_resources->descriptor_heap,
input_buffer_binding_desc, output_buffer_binding_desc,
persistent_buffer_binding_desc,
graph_resources->temporary_buffer_binding_desc);
if (FAILED(hr)) {
HandleDispatchFailure("Failed to record execute operator.", hr);
return;
}
hr = command_recorder_->Close();
if (FAILED(hr)) {
HandleDispatchFailure("Failed to close the command recorder.", hr);
return;
}
}
// Submit the command list for execution.
hr = command_recorder_->Execute();
if (FAILED(hr)) {
HandleDispatchFailure("Failed to execute the command recorder.", hr);
return;
}
// Prepare for the next dispatch.
command_recorder_->command_queue()->WaitAsync(
base::BindOnce(&GraphImplDml::OnDispatchComplete,
weak_factory_.GetWeakPtr(), std::move(graph_resources)));
}
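// Called when the command queue signals that the dispatch has completed;
// recycles the graph resources for the next dispatch when possible.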
void GraphImplDml::OnDispatchComplete(
std::unique_ptr<GraphResources> graph_resources,
HRESULT hr) {
TRACE_EVENT0("gpu", "dml::GraphImplDml::OnDispatchComplete");
if (FAILED(hr)) {
HandleDispatchFailure("Failed to wait for the dispatch to complete.", hr);
return;
}
  // If another set of graph resources is already available, release this one.
  // Otherwise, recycle these graph resources for the next call.
if (!graph_resources_) {
graph_resources_ = std::move(graph_resources);
}
}
} // namespace webnn::dml