#include "third_party/blink/renderer/modules/ml/webnn/ml_graph_builder.h"
#include <algorithm>
#include "base/containers/span.h"
#include "base/notimplemented.h"
#include "base/numerics/checked_math.h"
#include "base/ranges/algorithm.h"
#include "base/types/expected.h"
#include "base/types/expected_macros.h"
#include "base/types/pass_key.h"
#include "mojo/public/cpp/bindings/pending_associated_remote.h"
#include "services/webnn/public/cpp/operand_descriptor.h"
#include "services/webnn/public/cpp/webnn_errors.h"
#include "services/webnn/public/mojom/features.mojom-blink.h"
#include "services/webnn/public/mojom/webnn_context_provider.mojom-blink.h"
#include "services/webnn/public/mojom/webnn_graph.mojom-blink.h"
#include "third_party/abseil-cpp/absl/types/variant.h"
#include "third_party/blink/renderer/bindings/core/v8/script_promise_resolver.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_arg_min_max_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_batch_normalization_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_clamp_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_conv_2d_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_conv_transpose_2d_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_elu_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_gather_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_gemm_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_gru_cell_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_gru_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_hard_sigmoid_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_instance_normalization_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_layer_normalization_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_leaky_relu_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_linear_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_lstm_cell_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_lstm_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_operand_descriptor.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_operator_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_pad_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_pool_2d_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_recurrent_network_activation.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_reduce_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_resample_2d_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_split_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_transpose_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_triangular_options.h"
#include "third_party/blink/renderer/core/execution_context/execution_context.h"
#include "third_party/blink/renderer/core/inspector/console_message.h"
#include "third_party/blink/renderer/core/typed_arrays/dom_array_buffer_view.h"
#include "third_party/blink/renderer/modules/ml/ml_context.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_constant_operand.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_error.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_graph.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_graph_type_converter.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_graph_utils.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_operand.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_operator.h"
#include "third_party/blink/renderer/platform/bindings/exception_code.h"
#include "third_party/blink/renderer/platform/bindings/exception_state.h"
#include "third_party/blink/renderer/platform/bindings/script_state.h"
#include "third_party/blink/renderer/platform/heap/collection_support/heap_deque.h"
#include "third_party/blink/renderer/platform/heap/persistent.h"
#include "third_party/blink/renderer/platform/wtf/functional.h"
namespace blink {
namespace {
#define THROW_AND_RETURN_TYPE_IF_ERROR(func, return_value) …
#define THROW_AND_RETURN_IF_ERROR(func, return_value) …
#define ASSIGN_OR_THROW_AND_RETURN_IF_ERROR(lhs, rexpr) …
constexpr char kGraphAlreadyBuiltError[] = …;
void LogConsoleWarning(ScriptState* script_state, const String& message) { … }
webnn::InputOperandLayout BlinkInputOperandLayoutToComponent(
blink::V8MLInputOperandLayout::Enum type) { … }
webnn::Conv2dFilterOperandLayout BlinkConv2dFilterLayoutToComponent(
blink::V8MLConv2dFilterOperandLayout::Enum type) { … }
webnn::ConvTranspose2dFilterOperandLayout
BlinkConvTranspose2dFilterLayoutToComponent(
blink::V8MLConvTranspose2dFilterOperandLayout::Enum type) { … }
webnn::RoundingType BlinkRoundingTypeToComponent(
blink::V8MLRoundingType::Enum type) { … }
webnn::Pool2dKind FromMojoPool2dKind(webnn::mojom::blink::Pool2d::Kind kind) { … }
webnn::ReduceKind MojoReduceKindToComponent(
webnn::mojom::blink::Reduce::Kind kind) { … }
webnn::RecurrentNetworkDirection BlinkRecurrentNetworkDirectionToComponent(
blink::V8MLRecurrentNetworkDirection::Enum direction) { … }
webnn::BatchNormalizationAttributes ConvertToBatchNormalizationAttributes(
const blink::MLBatchNormalizationOptions* options) { … }
template <typename MLConv2dOptionsType, typename Conv2dAttributesType>
base::expected<Conv2dAttributesType, String> ConvertToConv2dAttributesBase(
const MLConv2dOptionsType* options) { … }
base::expected<webnn::Conv2dAttributes, String> ConvertToConv2dAttributes(
const blink::MLConv2dOptions* options) { … }
base::expected<webnn::ConvTranspose2dAttributes, String>
ConvertToConvTranspose2dAttributes(
const blink::MLConvTranspose2dOptions* options) { … }
base::expected<webnn::Pool2dAttributes, std::string> ConvertToPool2dAttributes(
const blink::MLPool2dOptions* options) { … }
webnn::GemmAttributes ConvertToGemmAttributes(
const blink::MLGemmOptions* options) { … }
webnn::GruAttributes ConvertToGruAttributes(MLGraphBuilder* builder,
blink::MLGruOptions* options) { … }
webnn::GruCellAttributes ConvertToGruCellAttributes(
MLGraphBuilder* builder,
blink::MLGruCellOptions* options) { … }
webnn::InstanceNormalizationAttributes ConvertToInstanceNormalizationAttributes(
const blink::MLInstanceNormalizationOptions* options) { … }
webnn::LayerNormalizationAttributes ConvertToLayerNormalizationAttributes(
const blink::MLLayerNormalizationOptions* options) { … }
webnn::LstmAttributes ConvertToLstmAttributes(
const blink::MLLstmOptions* options) { … }
webnn::LstmCellAttributes ConvertToLstmCellAttributes(
const blink::MLLstmCellOptions* options) { … }
bool ValidateClampOptions(const MLClampOptions* options,
ExceptionState& exception_state) { … }
MLOperand* BuildArgMinMax(MLGraphBuilder* builder,
webnn::mojom::blink::ArgMinMax::Kind sub_kind,
const MLOperand* input,
const uint32_t axis,
const MLArgMinMaxOptions* options,
ExceptionState& exception_state) { … }
MLOperand* BuildElementWiseBinary(
MLGraphBuilder* builder,
webnn::mojom::blink::ElementWiseBinary::Kind kind,
const webnn::SupportedDataTypes& data_type_constraint,
const MLOperand* a,
const MLOperand* b,
const MLOperatorOptions* options,
ExceptionState& exception_state) { … }
MLOperand* BuildUnaryOperator(
MLGraphBuilder* builder,
ExceptionState& exception_state,
webnn::mojom::blink::Operation::Tag kind,
const webnn::SupportedDataTypes& data_type_constraint,
const MLOperand* input,
const MLOperatorOptions* options) { … }
MLOperand* BuildElementWiseUnaryOperator(
MLGraphBuilder* builder,
ExceptionState& exception_state,
webnn::mojom::blink::ElementWiseUnary::Kind kind,
const webnn::SupportedDataTypes& data_type_constraint,
const MLOperand* input,
const MLOperatorOptions* options) { … }
MLOperand* BuildReduce(MLGraphBuilder* builder,
webnn::mojom::blink::Reduce::Kind kind,
const webnn::ContextProperties& context_properties,
const MLOperand* input,
const MLReduceOptions* options,
ExceptionState& exception_state) { … }
MLOperand* BuildPool2d(MLGraphBuilder* builder,
webnn::mojom::blink::Pool2d::Kind kind,
const webnn::ContextProperties& context_properties,
const MLOperand* input,
const MLPool2dOptions* options,
ExceptionState& exception_state) { … }
base::expected<std::pair<MLGraph::NamedOperandDescriptors,
MLGraph::NamedOperandDescriptors>,
String>
DetermineGraphConstraintsFromOutputs(const MLNamedOperands& named_outputs) { … }
base::expected<webnn::mojom::blink::GraphInfoPtr, String> BuildWebNNGraphInfo(
const MLNamedOperands& named_outputs,
const webnn::ContextProperties& context_properties) { … }
}
MLGraphBuilder* MLGraphBuilder::Create(ScriptState* script_state,
MLContext* context,
ExceptionState& exception_state) { … }
MLGraphBuilder::MLGraphBuilder(
ExecutionContext* execution_context,
MLContext* context,
mojo::PendingAssociatedRemote<webnn::mojom::blink::WebNNGraphBuilder>
pending_remote)
: … { … }
// Destructor, explicitly defaulted out-of-line in this translation unit.
// It performs no custom teardown; each member is released by its own
// destructor.
// NOTE(review): defining it here rather than in the header is presumably
// meant to keep the header from needing complete member types — confirm.
MLGraphBuilder::~MLGraphBuilder() = default;
void MLGraphBuilder::Trace(Visitor* visitor) const { … }
MLContext* MLGraphBuilder::GetContext() const { … }
MLOperand* MLGraphBuilder::input(String name,
const MLOperandDescriptor* desc,
ExceptionState& exception_state) { … }
MLOperand* MLGraphBuilder::constant(const MLOperandDescriptor* desc,
NotShared<DOMArrayBufferView> buffer_view,
ExceptionState& exception_state) { … }
MLOperand* MLGraphBuilder::argMin(const MLOperand* input,
const uint32_t axis,
const MLArgMinMaxOptions* options,
ExceptionState& exception_state) { … }
MLOperand* MLGraphBuilder::argMax(const MLOperand* input,
const uint32_t axis,
const MLArgMinMaxOptions* options,
ExceptionState& exception_state) { … }
MLOperand* MLGraphBuilder::batchNormalization(
const MLOperand* input,
const MLOperand* mean,
const MLOperand* variance,
const MLBatchNormalizationOptions* options,
ExceptionState& exception_state) { … }
MLOperand* MLGraphBuilder::concat(const HeapVector<Member<MLOperand>>& inputs,
const uint32_t axis,
const MLOperatorOptions* options,
ExceptionState& exception_state) { … }
MLOperand* MLGraphBuilder::clamp(const MLOperand* input,
const MLClampOptions* options,
ExceptionState& exception_state) { … }
MLOperand* MLGraphBuilder::conv2d(const MLOperand* input,
const MLOperand* filter,
const MLConv2dOptions* options,
ExceptionState& exception_state) { … }
MLOperand* MLGraphBuilder::convTranspose2d(
const MLOperand* input,
const MLOperand* filter,
const MLConvTranspose2dOptions* options,
ExceptionState& exception_state) { … }
#define BUILD_ELEMENTWISE_BINARY_OP(op, op_kind) …
BUILD_ELEMENTWISE_BINARY_OP(…) …
BUILD_ELEMENTWISE_BINARY_OP(…) …
BUILD_ELEMENTWISE_BINARY_OP(…) …
BUILD_ELEMENTWISE_BINARY_OP(…) …
BUILD_ELEMENTWISE_BINARY_OP(…) …
BUILD_ELEMENTWISE_BINARY_OP(…) …
BUILD_ELEMENTWISE_BINARY_OP(…) …
BUILD_ELEMENTWISE_BINARY_OP(…) …
BUILD_ELEMENTWISE_BINARY_OP(…) …
BUILD_ELEMENTWISE_BINARY_OP(…) …
MLOperand* MLGraphBuilder::greaterOrEqual(const MLOperand* a,
const MLOperand* b,
const MLOperatorOptions* options,
ExceptionState& exception_state) { … }
MLOperand* MLGraphBuilder::lesserOrEqual(const MLOperand* a,
const MLOperand* b,
const MLOperatorOptions* options,
ExceptionState& exception_state) { … }
#define BUILD_ELEMENTWISE_UNARY_OP(op, op_kind) …
BUILD_ELEMENTWISE_UNARY_OP(…) …
BUILD_ELEMENTWISE_UNARY_OP(…) …
BUILD_ELEMENTWISE_UNARY_OP(…) …
BUILD_ELEMENTWISE_UNARY_OP(…) …
BUILD_ELEMENTWISE_UNARY_OP(…) …
BUILD_ELEMENTWISE_UNARY_OP(…) …
BUILD_ELEMENTWISE_UNARY_OP(…) …
BUILD_ELEMENTWISE_UNARY_OP(…) …
BUILD_ELEMENTWISE_UNARY_OP(…) …
BUILD_ELEMENTWISE_UNARY_OP(…) …
BUILD_ELEMENTWISE_UNARY_OP(…) …
BUILD_ELEMENTWISE_UNARY_OP(…) …
BUILD_ELEMENTWISE_UNARY_OP(…) …
BUILD_ELEMENTWISE_UNARY_OP(…) …
MLOperand* MLGraphBuilder::logicalNot(const MLOperand* input,
const MLOperatorOptions* options,
ExceptionState& exception_state) { … }
MLOperand* MLGraphBuilder::cast(const MLOperand* input,
const V8MLOperandDataType output_data_type,
const MLOperatorOptions* options,
ExceptionState& exception_state) { … }
#define BUILD_REDUCE_OP(op, op_kind) …
BUILD_REDUCE_OP(…) …
BUILD_REDUCE_OP(…) …
BUILD_REDUCE_OP(…) …
BUILD_REDUCE_OP(…) …
BUILD_REDUCE_OP(…) …
BUILD_REDUCE_OP(…) …
BUILD_REDUCE_OP(…) …
BUILD_REDUCE_OP(…) …
BUILD_REDUCE_OP(…) …
BUILD_REDUCE_OP(…) …
MLOperand* MLGraphBuilder::elu(const MLOperand* input,
const MLEluOptions* options,
ExceptionState& exception_state) { … }
MLOperand* MLGraphBuilder::expand(const MLOperand* input,
const Vector<uint32_t>& new_shape,
const MLOperatorOptions* options,
ExceptionState& exception_state) { … }
MLOperand* MLGraphBuilder::gather(const MLOperand* input,
const MLOperand* indices,
const MLGatherOptions* options,
ExceptionState& exception_state) { … }
MLOperand* MLGraphBuilder::gatherElements(const MLOperand* input,
const MLOperand* indices,
const MLGatherOptions* options,
ExceptionState& exception_state) { … }
MLOperand* MLGraphBuilder::gelu(const MLOperand* input,
const MLOperatorOptions* options,
ExceptionState& exception_state) { … }
MLOperand* MLGraphBuilder::gemm(const MLOperand* a,
const MLOperand* b,
const MLGemmOptions* options,
ExceptionState& exception_state) { … }
HeapVector<Member<const MLOperand>> MLGraphBuilder::gru(
const MLOperand* input,
const MLOperand* weight,
const MLOperand* recurrent_weight,
const uint32_t steps,
const uint32_t hidden_size,
MLGruOptions* options,
ExceptionState& exception_state) { … }
MLOperand* MLGraphBuilder::gruCell(const MLOperand* input,
const MLOperand* weight,
const MLOperand* recurrent_weight,
const MLOperand* hidden_state,
const uint32_t hidden_size,
MLGruCellOptions* options,
ExceptionState& exception_state) { … }
MLOperand* MLGraphBuilder::hardSigmoid(const MLOperand* input,
const MLHardSigmoidOptions* options,
ExceptionState& exception_state) { … }
MLOperand* MLGraphBuilder::hardSwish(const MLOperand* input,
const MLOperatorOptions* options,
ExceptionState& exception_state) { … }
MLOperand* MLGraphBuilder::instanceNormalization(
const MLOperand* input,
const MLInstanceNormalizationOptions* options,
ExceptionState& exception_state) { … }
MLOperand* MLGraphBuilder::layerNormalization(
const MLOperand* input,
const MLLayerNormalizationOptions* options,
ExceptionState& exception_state) { … }
MLOperand* MLGraphBuilder::leakyRelu(const MLOperand* input,
const MLLeakyReluOptions* options,
ExceptionState& exception_state) { … }
MLOperand* MLGraphBuilder::linear(const MLOperand* input,
const MLLinearOptions* options,
ExceptionState& exception_state) { … }
HeapVector<Member<const MLOperand>> MLGraphBuilder::lstm(
const MLOperand* input,
const MLOperand* weight,
const MLOperand* recurrent_weight,
const uint32_t steps,
const uint32_t hidden_size,
MLLstmOptions* options,
ExceptionState& exception_state) { … }
HeapVector<Member<const MLOperand>> MLGraphBuilder::lstmCell(
const MLOperand* input,
const MLOperand* weight,
const MLOperand* recurrent_weight,
const MLOperand* hidden_state,
const MLOperand* cell_state,
const uint32_t hidden_size,
MLLstmCellOptions* options,
ExceptionState& exception_state) { … }
MLOperand* MLGraphBuilder::matmul(const MLOperand* a,
const MLOperand* b,
const MLOperatorOptions* options,
ExceptionState& exception_state) { … }
MLOperand* MLGraphBuilder::pad(ScriptState* script_state,
const MLOperand* input,
const Vector<uint32_t>& beginning_padding,
const Vector<uint32_t>& ending_padding,
const MLPadOptions* options,
ExceptionState& exception_state) { … }
MLOperand* MLGraphBuilder::averagePool2d(const MLOperand* input,
const MLPool2dOptions* options,
ExceptionState& exception_state) { … }
MLOperand* MLGraphBuilder::l2Pool2d(const MLOperand* input,
const MLPool2dOptions* options,
ExceptionState& exception_state) { … }
MLOperand* MLGraphBuilder::maxPool2d(const MLOperand* input,
const MLPool2dOptions* options,
ExceptionState& exception_state) { … }
MLOperand* MLGraphBuilder::prelu(const MLOperand* input,
const MLOperand* slope,
const MLOperatorOptions* options,
ExceptionState& exception_state) { … }
MLOperand* MLGraphBuilder::relu(const MLOperand* input,
const MLOperatorOptions* options,
ExceptionState& exception_state) { … }
MLOperand* MLGraphBuilder::reshape(const MLOperand* input,
const Vector<uint32_t>& new_shape,
const MLOperatorOptions* options,
ExceptionState& exception_state) { … }
MLOperand* MLGraphBuilder::resample2d(ScriptState* script_state,
const MLOperand* input,
const MLResample2dOptions* options,
ExceptionState& exception_state) { … }
MLOperand* MLGraphBuilder::sigmoid(const MLOperand* input,
const MLOperatorOptions* options,
ExceptionState& exception_state) { … }
MLOperand* MLGraphBuilder::slice(const MLOperand* input,
const Vector<uint32_t>& starts,
const Vector<uint32_t>& sizes,
const MLOperatorOptions* options,
ExceptionState& exception_state) { … }
MLOperand* MLGraphBuilder::softmax(const MLOperand* input,
uint32_t axis,
const MLOperatorOptions* options,
ExceptionState& exception_state) { … }
MLOperand* MLGraphBuilder::softmax(const MLOperand* input,
const MLOperatorOptions* options,
ExceptionState& exception_state) { … }
MLOperand* MLGraphBuilder::softplus(const MLOperand* input,
const MLOperatorOptions* options,
ExceptionState& exception_state) { … }
MLOperand* MLGraphBuilder::softsign(const MLOperand* input,
const MLOperatorOptions* options,
ExceptionState& exception_state) { … }
HeapVector<Member<const MLOperand>> MLGraphBuilder::split(
const MLOperand* input,
const uint32_t splits,
const MLSplitOptions* options,
ExceptionState& exception_state) { … }
HeapVector<Member<const MLOperand>> MLGraphBuilder::split(
const MLOperand* input,
const Vector<uint32_t>& splits,
const MLSplitOptions* options,
ExceptionState& exception_state) { … }
MLOperand* MLGraphBuilder::tanh(const MLOperand* input,
const MLOperatorOptions* options,
ExceptionState& exception_state) { … }
MLOperand* MLGraphBuilder::transpose(const MLOperand* input,
const MLTransposeOptions* options,
ExceptionState& exception_state) { … }
MLOperand* MLGraphBuilder::triangular(const MLOperand* input,
const MLTriangularOptions* options,
ExceptionState& exception_state) { … }
MLOperand* MLGraphBuilder::where(const MLOperand* condition,
const MLOperand* true_value,
const MLOperand* false_value,
const MLOperatorOptions* options,
ExceptionState& exception_state) { … }
ScriptPromise<MLGraph> MLGraphBuilder::build(
ScriptState* script_state,
const MLNamedOperands& named_outputs,
ExceptionState& exception_state) { … }
void MLGraphBuilder::DidCreateWebNNGraph(
ScriptPromiseResolver<blink::MLGraph>* resolver,
std::pair<MLGraph::NamedOperandDescriptors,
MLGraph::NamedOperandDescriptors> input_and_output_constraints,
webnn::mojom::blink::CreateGraphResultPtr result) { … }
void MLGraphBuilder::OnConnectionError() { … }
base::expected<void, String> MLGraphBuilder::ValidateGraphBuilderState() const { … }
base::expected<void, String> MLGraphBuilder::ValidateInput(
const MLOperand* input) { … }
base::expected<void, String> MLGraphBuilder::ValidateInputs(
const HeapVector<Member<const MLOperand>>& inputs) { … }
}