#include "third_party/blink/renderer/modules/ml/webnn/ml_graph_type_converter.h"
#include <array>
#include <optional>
#include "base/notreached.h"
#include "base/numerics/safe_conversions.h"
#include "base/ranges/algorithm.h"
#include "base/types/expected_macros.h"
#include "services/webnn/public/cpp/context_properties.h"
#include "services/webnn/public/cpp/graph_validation_utils.h"
#include "services/webnn/public/cpp/operand_descriptor.h"
#include "services/webnn/public/mojom/webnn_graph.mojom-blink-forward.h"
#include "services/webnn/public/mojom/webnn_graph.mojom-blink.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_arg_min_max_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_batch_normalization_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_clamp_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_conv_2d_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_conv_transpose_2d_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_elu_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_gather_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_gemm_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_gru_cell_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_gru_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_hard_sigmoid_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_input_operand_layout.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_instance_normalization_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_layer_normalization_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_leaky_relu_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_linear_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_lstm_cell_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_lstm_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_operator_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_pad_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_pool_2d_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_recurrent_network_activation.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_reduce_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_resample_2d_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_split_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_transpose_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_triangular_options.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_graph_utils.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_operand.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_operator.h"
#include "third_party/blink/renderer/platform/wtf/wtf_size_t.h"
// NOTE(review): this line looks like the residue of a stripped declaration.
// The rest of the file qualifies names with `blink_mojom::` alongside
// `webnn::mojom::blink::`, so it was presumably a namespace alias such as
// `namespace blink_mojom = webnn::mojom::blink;`. TODO confirm against the
// original source; as written it is not a valid declaration.
blink_mojom;
// Conversions from Blink/V8 WebNN binding types into webnn mojom / native
// types.
// NOTE(review): all function bodies are elided (`{ … }`) in this view, so the
// comments below describe only what the signatures show.
namespace mojo {
// Maps a V8 `MLOperandDataType` enum value to `webnn::OperandDataType`.
webnn::OperandDataType ToOperandDataType(
blink::V8MLOperandDataType::Enum data_type) { … }
// Maps a V8 recurrent-network activation value to the mojom
// `RecurrentNetworkActivation` type.
webnn::mojom::blink::RecurrentNetworkActivation
BlinkRecurrentNetworkActivationToMojo(
blink::V8MLRecurrentNetworkActivation activation) { … }
// Maps a V8 recurrent-network direction enum value to the mojom
// `RecurrentNetworkDirection`.
blink_mojom::RecurrentNetworkDirection BlinkRecurrentNetworkDirectionToMojo(
blink::V8MLRecurrentNetworkDirection::Enum direction) { … }
// Maps a V8 LSTM weight-layout enum value to the mojom `LstmWeightLayout`.
blink_mojom::LstmWeightLayout BlinkLstmWeightLayoutToMojo(
blink::V8MLLstmWeightLayout::Enum layout) { … }
// Maps a V8 GRU weight-layout enum value to the mojom `GruWeightLayout`.
blink_mojom::GruWeightLayout BlinkGruWeightLayoutToMojo(
blink::V8MLGruWeightLayout::Enum layout) { … }
// mojo::TypeConverter specialization that builds a mojom `OperandPtr` from a
// Blink `MLOperand`.
blink_mojom::OperandPtr
TypeConverter<blink_mojom::OperandPtr, blink::MLOperand*>::Convert(
const blink::MLOperand* ml_operand) { … }
// Returns a `webnn::Size2d<uint32_t>` derived from `input`. `type` is the
// input layout used to interpret its dimensions (presumably it selects which
// axes are height/width; TODO confirm).
webnn::Size2d<uint32_t> GetInputOperandSize2d(
const blink::MLOperand* input,
blink::V8MLInputOperandLayout::Enum type) { … }
}  // namespace mojo
// Serialization of Blink `MLOperator`s into webnn mojom graph operations.
// NOTE(review): all function bodies and several initializers are elided
// (`{ … }` / `…`) in this view; comments describe only what the signatures
// show, and name-based inferences are marked as such.
namespace blink {
// File-local helpers used by SerializeMojoOperation() at the end of this
// namespace.
namespace {
// NOTE(review): the next six lines look like residue of stripped `using`
// declarations (presumably `using blink_mojom::ElementWiseBinary;` ...
// `using webnn::Size2d;`, plus an `OperandToIdMap` alias). TODO confirm; as
// written they are not valid declarations. `OperandToIdMap` is presumably an
// alias for the `HeapHashMap<Member<const MLOperand>, uint64_t>` type that
// SerializeMojoOperation() takes.
ElementWiseBinary;
ElementWiseUnary;
Operation;
OperationPtr;
Size2d;
OperandToIdMap;
// Returns the id of `op`'s input operand at position `index` (default 0),
// presumably looked up in `operand_to_id_map`.
uint64_t GetOperatorInputId(const MLOperator* op,
const OperandToIdMap& operand_to_id_map,
wtf_size_t index = 0) { … }
// Returns the id of `op`'s output operand at position `index` (default 0),
// presumably looked up in `operand_to_id_map`.
uint64_t GetOperatorOutputId(const MLOperator* op,
const OperandToIdMap& operand_to_id_map,
wtf_size_t index = 0) { … }
// Presumably adds a new intermediate operand described by `descriptor` to
// `graph_info` and returns its id. TODO confirm how `operand_to_id_map` is
// used.
uint64_t InsertTemporaryOperand(const OperandToIdMap& operand_to_id_map,
webnn::OperandDescriptor descriptor,
blink_mojom::GraphInfo* graph_info) { … }
// Returns `shape` reordered according to `permutation`.
// NOTE(review): exact indexing convention (result[i] = shape[permutation[i]]?)
// is not visible here; TODO confirm.
Vector<uint32_t> PermuteShape(base::span<const uint32_t> shape,
base::span<const uint32_t> permutation) { … }
// Presumably appends to `graph_info` a transpose of `operand` by
// `permutation` (tagged with `label`) and returns the id of the transposed
// operand.
uint64_t InsertInputTranspose(const OperandToIdMap& operand_to_id_map,
const MLOperand* operand,
base::span<const uint32_t> permutation,
blink_mojom::GraphInfo* graph_info,
const String& label) { … }
// Builders for activation-style operators. Each returns the mojom struct
// pointer named by its return type for the given MLOperator.
blink_mojom::ClampPtr CreateClamp(const OperandToIdMap& operand_to_id_map,
const MLOperator* clamp) { … }
blink_mojom::EluPtr CreateElu(const OperandToIdMap& operand_to_id_map,
const MLOperator* elu) { … }
blink_mojom::HardSigmoidPtr CreateHardSigmoid(
const OperandToIdMap& operand_to_id_map,
const MLOperator* hard_sigmoid) { … }
// Returns a mojom Operation for the `expand` operator.
OperationPtr CreateExpandOperation(const OperandToIdMap& operand_to_id_map,
const MLOperator* expand) { … }
blink_mojom::LeakyReluPtr CreateLeakyRelu(
const OperandToIdMap& operand_to_id_map,
const MLOperator* leaky_relu) { … }
blink_mojom::LinearPtr CreateLinear(const OperandToIdMap& operand_to_id_map,
const MLOperator* linear) { … }
// Returns a mojom Operation for the `softmax` operator.
OperationPtr CreateSoftmaxOperation(const OperandToIdMap& operand_to_id_map,
const MLOperator* softmax) { … }
// Returns a mojom Operation for the `softplus` operator. (Unlike its
// siblings, the name has no `Operation` suffix.)
OperationPtr CreateSoftplus(const OperandToIdMap& operand_to_id_map,
const MLOperator* softplus) { … }
// Maps a V8 input-operand layout enum to the mojom `InputOperandLayout`.
webnn::mojom::InputOperandLayout BlinkInputOperandLayoutToMojo(
blink::V8MLInputOperandLayout::Enum type) { … }
// Maps a V8 input-operand layout enum to the native
// `webnn::InputOperandLayout`.
webnn::InputOperandLayout BlinkInputOperandLayoutToNative(
blink::V8MLInputOperandLayout::Enum type) { … }
// 4-D axis permutations between NCHW and NHWC layouts, as named. Values are
// elided in this view.
constexpr std::array<uint32_t, 4> kNchwToNhwcPermutation = …;
constexpr std::array<uint32_t, 4> kNhwcToNchwPermutation = …;
// Returns the permutation to apply to an input given in `input_layout` for
// the backend described by `context_properties`, or std::nullopt
// (presumably when no transpose is required).
std::optional<base::span<const uint32_t>> GetInputOperandPermutation(
blink::V8MLInputOperandLayout::Enum input_layout,
const webnn::ContextProperties& context_properties) { … }
// Output-side counterpart of GetInputOperandPermutation(); std::nullopt
// presumably means no transpose is required.
std::optional<base::span<const uint32_t>> GetOutputOperandPermutation(
blink::V8MLInputOperandLayout::Enum input_layout,
const webnn::ContextProperties& context_properties) { … }
// Returns the permutation for a conv2d filter in `filter_layout`, given the
// backend `input_layout` and whether the convolution is `depthwise`, or
// std::nullopt.
std::optional<base::span<const uint32_t>> GetConv2DFilterPermutation(
webnn::InputOperandLayout input_layout,
bool depthwise,
blink::V8MLConv2dFilterOperandLayout filter_layout) { … }
// Returns the permutation for a convTranspose2d filter in `filter_layout`,
// given the backend `input_layout`, or std::nullopt.
std::optional<base::span<const uint32_t>> GetConvTranspose2DFilterPermutation(
webnn::InputOperandLayout input_layout,
blink::V8MLConvTranspose2dFilterOperandLayout filter_layout) { … }
// Pairs of spatial axes for resample2d in channel-first and channel-last
// layouts (values elided in this view).
constexpr std::array<uint32_t, 2> kResample2dChannelFirstAxes{ … };
constexpr std::array<uint32_t, 2> kResample2dChannelLastAxes{ … };
// Returns a permutation for resample2d based on the user-supplied `from_axes`
// and `context_properties`, or std::nullopt (presumably when the axes already
// match the backend's preferred layout).
std::optional<std::vector<uint32_t>> GetResample2DPermutation(
const Vector<uint32_t>& from_axes,
const webnn::ContextProperties& context_properties) { … }
// Returns the inverse of `permutation` (as the name indicates).
std::vector<uint32_t> GetInversePermutation(
base::span<const uint32_t> permutation) { … }
// Returns a mojom ArgMinMax Operation for `op`. `kind` selects the variant.
OperationPtr CreateArgMinMaxOperation(const OperandToIdMap& operand_to_id_map,
const MLOperator* op,
blink_mojom::ArgMinMax::Kind kind) { … }
OperationPtr CreateBatchNormalizationOperation(
const OperandToIdMap& operand_to_id_map,
const MLOperator* batch_normalization) { … }
OperationPtr CreateConcatOperation(const OperandToIdMap& operand_to_id_map,
const MLOperator* concat) { … }
// Returns whether `conv2d` is a depthwise convolution.
bool IsDepthwiseConv2d(const MLOperator* conv2d) { … }
// Serializes a conv2d-family operator (options type given by
// `MLConv2dOptionsType`) into `graph_info`. Returns std::optional<String>;
// presumably std::nullopt on success and an error message otherwise.
// TODO confirm.
template <typename MLConv2dOptionsType>
std::optional<String> SerializeConv2dOperation(
const OperandToIdMap& operand_to_id_map,
const webnn::ContextProperties& context_properties,
const MLOperator* conv2d,
blink_mojom::GraphInfo* graph_info) { … }
// Element-wise operator builders; `kind` selects the specific operation.
OperationPtr CreateElementWiseBinaryOperator(
const OperandToIdMap& operand_to_id_map,
const MLOperator* binary,
const blink_mojom::ElementWiseBinary::Kind& kind) { … }
OperationPtr CreateElementWiseUnaryOperator(
const OperandToIdMap& operand_to_id_map,
const MLOperator* unary,
const blink_mojom::ElementWiseUnary::Kind& kind) { … }
// Per-operator builders: each returns a mojom Operation for the MLOperator
// named by its second parameter.
OperationPtr CreateGatherOperation(const OperandToIdMap& operand_to_id_map,
const MLOperator* gather) { … }
OperationPtr CreateGatherElementsOperation(
const OperandToIdMap& operand_to_id_map,
const MLOperator* gather_elements) { … }
OperationPtr CreateGeluOperation(const OperandToIdMap& operand_to_id_map,
const MLOperator* gelu) { … }
OperationPtr CreateGemmOperation(const OperandToIdMap& operand_to_id_map,
const MLOperator* gemm) { … }
OperationPtr CreateGruOperation(const OperandToIdMap& operand_to_id_map,
const MLOperator* gru) { … }
// Returns the mojom Operation for `gru_cell`, or a String error.
base::expected<OperationPtr, String> CreateGruCellOperation(
const OperandToIdMap& operand_to_id_map,
const MLOperator* gru_cell) { … }
OperationPtr CreateHardSwishOperation(const OperandToIdMap& operand_to_id_map,
const MLOperator* hard_swish) { … }
OperationPtr CreateLayerNormalizationOperation(
const OperandToIdMap& operand_to_id_map,
const MLOperator* layer_normalization) { … }
OperationPtr CreateInstanceNormalizationOperation(
const OperandToIdMap& operand_to_id_map,
const MLOperator* instance_normalization) { … }
OperationPtr CreateLstmOperation(const OperandToIdMap& operand_to_id_map,
const MLOperator* lstm) { … }
// Returns the mojom Operation for `lstm_cell`, or a String error.
base::expected<OperationPtr, String> CreateLstmCellOperation(
const OperandToIdMap& operand_to_id_map,
const MLOperator* lstm_cell) { … }
OperationPtr CreateMatmulOperation(const OperandToIdMap& operand_to_id_map,
const MLOperator* matmul) { … }
OperationPtr CreatePadOperation(const OperandToIdMap& operand_to_id_map,
const MLOperator* op) { … }
// Serializes a pool2d operator of the given `kind` directly into
// `graph_info` (returns nothing).
void SerializePool2dOperation(
const OperandToIdMap& operand_to_id_map,
const webnn::ContextProperties& context_properties,
const MLOperator* pool2d,
const blink_mojom::Pool2d::Kind& kind,
blink_mojom::GraphInfo* graph_info) { … }
OperationPtr CreatePreluOperation(const OperandToIdMap& operand_to_id_map,
const MLOperator* prelu) { … }
// Returns a mojom reduction Operation for `reduce`; `kind` selects the
// reduction.
OperationPtr CreateReduceOperator(const OperandToIdMap& operand_to_id_map,
const MLOperator* reduce,
const blink_mojom::Reduce::Kind kind) { … }
// Serializes a resample2d operator directly into `graph_info` (returns
// nothing).
void SerializeResample2dOperation(
const OperandToIdMap& operand_to_id_map,
const webnn::ContextProperties& context_properties,
const MLOperator* resample2d,
blink_mojom::GraphInfo* graph_info) { … }
OperationPtr CreateReluOperation(const OperandToIdMap& operand_to_id_map,
const MLOperator* relu) { … }
OperationPtr CreateReshapeOperation(const OperandToIdMap& operand_to_id_map,
const MLOperator* reshape) { … }
OperationPtr CreateSigmoidOperation(const OperandToIdMap& operand_to_id_map,
const MLOperator* sigmoid) { … }
OperationPtr CreateSliceOperation(const OperandToIdMap& operand_to_id_map,
const MLOperator* slice) { … }
OperationPtr CreateSoftsignOperation(const OperandToIdMap& operand_to_id_map,
const MLOperator* softsign) { … }
OperationPtr CreateSplitOperation(const OperandToIdMap& operand_to_id_map,
const MLOperator* split) { … }
OperationPtr CreateTanhOperation(const OperandToIdMap& operand_to_id_map,
const MLOperator* tanh) { … }
OperationPtr CreateTransposeOperation(const OperandToIdMap& operand_to_id_map,
const MLOperator* transpose) { … }
OperationPtr CreateTriangularOperation(const OperandToIdMap& operand_to_id_map,
const MLOperator* triangular) { … }
OperationPtr CreateWhereOperation(const OperandToIdMap& operand_to_id_map,
const MLOperator* where) { … }
}  // namespace
// Returns the id to assign to the next operand added to `graph_info`
// (presumably derived from the operands already present; TODO confirm).
uint64_t NextOperandId(const webnn::mojom::blink::GraphInfo& graph_info) { … }
// Serializes `op` into `graph_info`, using `operand_to_id_map` for operand ids
// and `context_properties` for backend-specific choices such as layout
// permutations. Returns std::optional<String>; presumably std::nullopt on
// success and an error message on failure. TODO confirm.
std::optional<String> SerializeMojoOperation(
const HeapHashMap<Member<const MLOperand>, uint64_t>& operand_to_id_map,
const webnn::ContextProperties& context_properties,
const MLOperator* op,
webnn::mojom::blink::GraphInfo* graph_info) { … }
}  // namespace blink