chromium/third_party/blink/renderer/modules/ml/webnn/ml_graph_type_converter.cc

// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "third_party/blink/renderer/modules/ml/webnn/ml_graph_type_converter.h"

#include <array>
#include <optional>
#include <vector>

#include "base/check_op.h"
#include "base/notreached.h"
#include "base/numerics/safe_conversions.h"
#include "base/ranges/algorithm.h"
#include "base/types/expected_macros.h"
#include "services/webnn/public/cpp/context_properties.h"
#include "services/webnn/public/cpp/graph_validation_utils.h"
#include "services/webnn/public/cpp/operand_descriptor.h"
#include "services/webnn/public/mojom/webnn_graph.mojom-blink-forward.h"
#include "services/webnn/public/mojom/webnn_graph.mojom-blink.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_arg_min_max_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_batch_normalization_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_clamp_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_conv_2d_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_conv_transpose_2d_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_elu_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_gather_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_gemm_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_gru_cell_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_gru_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_hard_sigmoid_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_input_operand_layout.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_instance_normalization_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_layer_normalization_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_leaky_relu_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_linear_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_lstm_cell_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_lstm_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_operator_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_pad_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_pool_2d_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_recurrent_network_activation.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_reduce_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_resample_2d_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_split_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_transpose_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_triangular_options.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_graph_utils.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_operand.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_operator.h"
#include "third_party/blink/renderer/platform/wtf/wtf_size_t.h"

blink_mojom;

namespace mojo {

// NOTE(review): every value-returning function in this namespace has an empty
// body in this revision. Falling off the end of a non-void function is
// undefined behavior (and a -Wreturn-type error); confirm the implementations
// were not elided.

// Converts a Blink IDL operand data type enum to `webnn::OperandDataType`.
webnn::OperandDataType ToOperandDataType(
    blink::V8MLOperandDataType::Enum data_type) {}

// Converts a Blink IDL recurrent-network activation to its mojom equivalent.
// Unlike the sibling converters below, this takes the IDL wrapper type rather
// than its `::Enum`.
webnn::mojom::blink::RecurrentNetworkActivation
BlinkRecurrentNetworkActivationToMojo(
    blink::V8MLRecurrentNetworkActivation activation) {}

// Converts a Blink IDL recurrent-network direction enum to its mojom
// equivalent.
blink_mojom::RecurrentNetworkDirection BlinkRecurrentNetworkDirectionToMojo(
    blink::V8MLRecurrentNetworkDirection::Enum direction) {}

// Converts a Blink IDL LSTM weight layout enum to its mojom equivalent.
blink_mojom::LstmWeightLayout BlinkLstmWeightLayoutToMojo(
    blink::V8MLLstmWeightLayout::Enum layout) {}

// Converts a Blink IDL GRU weight layout enum to its mojom equivalent.
blink_mojom::GruWeightLayout BlinkGruWeightLayoutToMojo(
    blink::V8MLGruWeightLayout::Enum layout) {}

// Converters from IDL to Mojo.
// Builds a mojom `Operand` describing `ml_operand`.
blink_mojom::OperandPtr
TypeConverter<blink_mojom::OperandPtr, blink::MLOperand*>::Convert(
    const blink::MLOperand* ml_operand) {}

// Get height and width of input operand.
// `type` is the input layout; presumably it selects which dimensions of
// `input` are read as height and width — confirm against the implementation.
webnn::Size2d<uint32_t> GetInputOperandSize2d(
    const blink::MLOperand* input,
    blink::V8MLInputOperandLayout::Enum type) {}

}  // namespace mojo

namespace blink {

namespace {

ElementWiseBinary;
ElementWiseUnary;
Operation;
OperationPtr;
Size2d;

// Maps MLOperand to its id which is used to identify the `mojo::Operand` across
// processes.
OperandToIdMap;

uint64_t GetOperatorInputId(const MLOperator* op,
                            const OperandToIdMap& operand_to_id_map,
                            wtf_size_t index = 0) {}

uint64_t GetOperatorOutputId(const MLOperator* op,
                             const OperandToIdMap& operand_to_id_map,
                             wtf_size_t index = 0) {}

uint64_t InsertTemporaryOperand(const OperandToIdMap& operand_to_id_map,
                                webnn::OperandDescriptor descriptor,
                                blink_mojom::GraphInfo* graph_info) {}

Vector<uint32_t> PermuteShape(base::span<const uint32_t> shape,
                              base::span<const uint32_t> permutation) {}

// Insert a transpose operation after the given operand. Returns the ID of the
// operand holding the transposed result.
uint64_t InsertInputTranspose(const OperandToIdMap& operand_to_id_map,
                              const MLOperand* operand,
                              base::span<const uint32_t> permutation,
                              blink_mojom::GraphInfo* graph_info,
                              const String& label) {}

blink_mojom::ClampPtr CreateClamp(const OperandToIdMap& operand_to_id_map,
                                  const MLOperator* clamp) {}

blink_mojom::EluPtr CreateElu(const OperandToIdMap& operand_to_id_map,
                              const MLOperator* elu) {}

blink_mojom::HardSigmoidPtr CreateHardSigmoid(
    const OperandToIdMap& operand_to_id_map,
    const MLOperator* hard_sigmoid) {}

OperationPtr CreateExpandOperation(const OperandToIdMap& operand_to_id_map,
                                   const MLOperator* expand) {}

blink_mojom::LeakyReluPtr CreateLeakyRelu(
    const OperandToIdMap& operand_to_id_map,
    const MLOperator* leaky_relu) {}

blink_mojom::LinearPtr CreateLinear(const OperandToIdMap& operand_to_id_map,
                                    const MLOperator* linear) {}

OperationPtr CreateSoftmaxOperation(const OperandToIdMap& operand_to_id_map,
                                    const MLOperator* softmax) {}

OperationPtr CreateSoftplus(const OperandToIdMap& operand_to_id_map,
                            const MLOperator* softplus) {}

webnn::mojom::InputOperandLayout BlinkInputOperandLayoutToMojo(
    blink::V8MLInputOperandLayout::Enum type) {}

webnn::InputOperandLayout BlinkInputOperandLayoutToNative(
    blink::V8MLInputOperandLayout::Enum type) {}

constexpr std::array<uint32_t, 4> kNchwToNhwcPermutation =;
constexpr std::array<uint32_t, 4> kNhwcToNchwPermutation =;

std::optional<base::span<const uint32_t>> GetInputOperandPermutation(
    blink::V8MLInputOperandLayout::Enum input_layout,
    const webnn::ContextProperties& context_properties) {}

std::optional<base::span<const uint32_t>> GetOutputOperandPermutation(
    blink::V8MLInputOperandLayout::Enum input_layout,
    const webnn::ContextProperties& context_properties) {}

std::optional<base::span<const uint32_t>> GetConv2DFilterPermutation(
    webnn::InputOperandLayout input_layout,
    bool depthwise,
    blink::V8MLConv2dFilterOperandLayout filter_layout) {}

std::optional<base::span<const uint32_t>> GetConvTranspose2DFilterPermutation(
    webnn::InputOperandLayout input_layout,
    blink::V8MLConvTranspose2dFilterOperandLayout filter_layout) {}

constexpr std::array<uint32_t, 2> kResample2dChannelFirstAxes{};
constexpr std::array<uint32_t, 2> kResample2dChannelLastAxes{};
std::optional<std::vector<uint32_t>> GetResample2DPermutation(
    const Vector<uint32_t>& from_axes,
    const webnn::ContextProperties& context_properties) {}

std::vector<uint32_t> GetInversePermutation(
    base::span<const uint32_t> permutation) {}

OperationPtr CreateArgMinMaxOperation(const OperandToIdMap& operand_to_id_map,
                                      const MLOperator* op,
                                      blink_mojom::ArgMinMax::Kind kind) {}

OperationPtr CreateBatchNormalizationOperation(
    const OperandToIdMap& operand_to_id_map,
    const MLOperator* batch_normalization) {}

OperationPtr CreateConcatOperation(const OperandToIdMap& operand_to_id_map,
                                   const MLOperator* concat) {}

bool IsDepthwiseConv2d(const MLOperator* conv2d) {}

template <typename MLConv2dOptionsType>
std::optional<String> SerializeConv2dOperation(
    const OperandToIdMap& operand_to_id_map,
    const webnn::ContextProperties& context_properties,
    const MLOperator* conv2d,
    blink_mojom::GraphInfo* graph_info) {}

OperationPtr CreateElementWiseBinaryOperator(
    const OperandToIdMap& operand_to_id_map,
    const MLOperator* binary,
    const blink_mojom::ElementWiseBinary::Kind& kind) {}

OperationPtr CreateElementWiseUnaryOperator(
    const OperandToIdMap& operand_to_id_map,
    const MLOperator* unary,
    const blink_mojom::ElementWiseUnary::Kind& kind) {}

OperationPtr CreateGatherOperation(const OperandToIdMap& operand_to_id_map,
                                   const MLOperator* gather) {}

OperationPtr CreateGatherElementsOperation(
    const OperandToIdMap& operand_to_id_map,
    const MLOperator* gather_elements) {}

OperationPtr CreateGeluOperation(const OperandToIdMap& operand_to_id_map,
                                 const MLOperator* gelu) {}

OperationPtr CreateGemmOperation(const OperandToIdMap& operand_to_id_map,
                                 const MLOperator* gemm) {}

OperationPtr CreateGruOperation(const OperandToIdMap& operand_to_id_map,
                                const MLOperator* gru) {}

base::expected<OperationPtr, String> CreateGruCellOperation(
    const OperandToIdMap& operand_to_id_map,
    const MLOperator* gru_cell) {}

OperationPtr CreateHardSwishOperation(const OperandToIdMap& operand_to_id_map,
                                      const MLOperator* hard_swish) {}

OperationPtr CreateLayerNormalizationOperation(
    const OperandToIdMap& operand_to_id_map,
    const MLOperator* layer_normalization) {}

OperationPtr CreateInstanceNormalizationOperation(
    const OperandToIdMap& operand_to_id_map,
    const MLOperator* instance_normalization) {}

OperationPtr CreateLstmOperation(const OperandToIdMap& operand_to_id_map,
                                 const MLOperator* lstm) {}

base::expected<OperationPtr, String> CreateLstmCellOperation(
    const OperandToIdMap& operand_to_id_map,
    const MLOperator* lstm_cell) {}

OperationPtr CreateMatmulOperation(const OperandToIdMap& operand_to_id_map,
                                   const MLOperator* matmul) {}

OperationPtr CreatePadOperation(const OperandToIdMap& operand_to_id_map,
                                const MLOperator* op) {}

void SerializePool2dOperation(
    const OperandToIdMap& operand_to_id_map,
    const webnn::ContextProperties& context_properties,
    const MLOperator* pool2d,
    const blink_mojom::Pool2d::Kind& kind,
    blink_mojom::GraphInfo* graph_info) {}

OperationPtr CreatePreluOperation(const OperandToIdMap& operand_to_id_map,
                                  const MLOperator* prelu) {}

OperationPtr CreateReduceOperator(const OperandToIdMap& operand_to_id_map,
                                  const MLOperator* reduce,
                                  const blink_mojom::Reduce::Kind kind) {}

void SerializeResample2dOperation(
    const OperandToIdMap& operand_to_id_map,
    const webnn::ContextProperties& context_properties,
    const MLOperator* resample2d,
    blink_mojom::GraphInfo* graph_info) {}

OperationPtr CreateReluOperation(const OperandToIdMap& operand_to_id_map,
                                 const MLOperator* relu) {}

OperationPtr CreateReshapeOperation(const OperandToIdMap& operand_to_id_map,
                                    const MLOperator* reshape) {}

OperationPtr CreateSigmoidOperation(const OperandToIdMap& operand_to_id_map,
                                    const MLOperator* sigmoid) {}

OperationPtr CreateSliceOperation(const OperandToIdMap& operand_to_id_map,
                                  const MLOperator* slice) {}

OperationPtr CreateSoftsignOperation(const OperandToIdMap& operand_to_id_map,
                                     const MLOperator* softsign) {}

OperationPtr CreateSplitOperation(const OperandToIdMap& operand_to_id_map,
                                  const MLOperator* split) {}

OperationPtr CreateTanhOperation(const OperandToIdMap& operand_to_id_map,
                                 const MLOperator* tanh) {}

OperationPtr CreateTransposeOperation(const OperandToIdMap& operand_to_id_map,
                                      const MLOperator* transpose) {}

OperationPtr CreateTriangularOperation(const OperandToIdMap& operand_to_id_map,
                                       const MLOperator* triangular) {}

OperationPtr CreateWhereOperation(const OperandToIdMap& operand_to_id_map,
                                  const MLOperator* where) {}

}  // namespace

uint64_t NextOperandId(const webnn::mojom::blink::GraphInfo& graph_info) {}

// TODO(crbug.com/1504405): Use a lookup table to simplifie the switch logic.
std::optional<String> SerializeMojoOperation(
    const HeapHashMap<Member<const MLOperand>, uint64_t>& operand_to_id_map,
    const webnn::ContextProperties& context_properties,
    const MLOperator* op,
    webnn::mojom::blink::GraphInfo* graph_info) {}

}  // namespace blink