chromium/third_party/blink/renderer/modules/ml/webnn/ml_graph_builder.cc

// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "third_party/blink/renderer/modules/ml/webnn/ml_graph_builder.h"

#include <algorithm>

#include "base/containers/span.h"
#include "base/notimplemented.h"
#include "base/numerics/checked_math.h"
#include "base/ranges/algorithm.h"
#include "base/types/expected.h"
#include "base/types/expected_macros.h"
#include "base/types/pass_key.h"
#include "mojo/public/cpp/bindings/pending_associated_remote.h"
#include "services/webnn/public/cpp/operand_descriptor.h"
#include "services/webnn/public/cpp/webnn_errors.h"
#include "services/webnn/public/mojom/features.mojom-blink.h"
#include "services/webnn/public/mojom/webnn_context_provider.mojom-blink.h"
#include "services/webnn/public/mojom/webnn_graph.mojom-blink.h"
#include "third_party/abseil-cpp/absl/types/variant.h"
#include "third_party/blink/renderer/bindings/core/v8/script_promise_resolver.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_arg_min_max_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_batch_normalization_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_clamp_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_conv_2d_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_conv_transpose_2d_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_elu_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_gather_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_gemm_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_gru_cell_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_gru_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_hard_sigmoid_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_instance_normalization_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_layer_normalization_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_leaky_relu_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_linear_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_lstm_cell_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_lstm_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_operand_descriptor.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_operator_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_pad_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_pool_2d_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_recurrent_network_activation.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_reduce_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_resample_2d_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_split_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_transpose_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_triangular_options.h"
#include "third_party/blink/renderer/core/execution_context/execution_context.h"
#include "third_party/blink/renderer/core/inspector/console_message.h"
#include "third_party/blink/renderer/core/typed_arrays/dom_array_buffer_view.h"
#include "third_party/blink/renderer/modules/ml/ml_context.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_constant_operand.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_error.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_graph.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_graph_type_converter.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_graph_utils.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_operand.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_operator.h"
#include "third_party/blink/renderer/platform/bindings/exception_code.h"
#include "third_party/blink/renderer/platform/bindings/exception_state.h"
#include "third_party/blink/renderer/platform/bindings/script_state.h"
#include "third_party/blink/renderer/platform/heap/collection_support/heap_deque.h"
#include "third_party/blink/renderer/platform/heap/persistent.h"
#include "third_party/blink/renderer/platform/wtf/functional.h"

namespace blink {

namespace {

#define THROW_AND_RETURN_TYPE_IF_ERROR(func, return_value)

#define THROW_AND_RETURN_IF_ERROR(func, return_value)

#define ASSIGN_OR_THROW_AND_RETURN_IF_ERROR(lhs, rexpr)

constexpr char kGraphAlreadyBuiltError[] =;

void LogConsoleWarning(ScriptState* script_state, const String& message) {}

webnn::InputOperandLayout BlinkInputOperandLayoutToComponent(
    blink::V8MLInputOperandLayout::Enum type) {}

webnn::Conv2dFilterOperandLayout BlinkConv2dFilterLayoutToComponent(
    blink::V8MLConv2dFilterOperandLayout::Enum type) {}

webnn::ConvTranspose2dFilterOperandLayout
BlinkConvTranspose2dFilterLayoutToComponent(
    blink::V8MLConvTranspose2dFilterOperandLayout::Enum type) {}

webnn::RoundingType BlinkRoundingTypeToComponent(
    blink::V8MLRoundingType::Enum type) {}

webnn::Pool2dKind FromMojoPool2dKind(webnn::mojom::blink::Pool2d::Kind kind) {}

webnn::ReduceKind MojoReduceKindToComponent(
    webnn::mojom::blink::Reduce::Kind kind) {}

webnn::RecurrentNetworkDirection BlinkRecurrentNetworkDirectionToComponent(
    blink::V8MLRecurrentNetworkDirection::Enum direction) {}

webnn::BatchNormalizationAttributes ConvertToBatchNormalizationAttributes(
    const blink::MLBatchNormalizationOptions* options) {}

template <typename MLConv2dOptionsType, typename Conv2dAttributesType>
base::expected<Conv2dAttributesType, String> ConvertToConv2dAttributesBase(
    const MLConv2dOptionsType* options) {}

base::expected<webnn::Conv2dAttributes, String> ConvertToConv2dAttributes(
    const blink::MLConv2dOptions* options) {}

base::expected<webnn::ConvTranspose2dAttributes, String>
ConvertToConvTranspose2dAttributes(
    const blink::MLConvTranspose2dOptions* options) {}

base::expected<webnn::Pool2dAttributes, std::string> ConvertToPool2dAttributes(
    const blink::MLPool2dOptions* options) {}

webnn::GemmAttributes ConvertToGemmAttributes(
    const blink::MLGemmOptions* options) {}

webnn::GruAttributes ConvertToGruAttributes(MLGraphBuilder* builder,
                                            blink::MLGruOptions* options) {}

webnn::GruCellAttributes ConvertToGruCellAttributes(
    MLGraphBuilder* builder,
    blink::MLGruCellOptions* options) {}

webnn::InstanceNormalizationAttributes ConvertToInstanceNormalizationAttributes(
    const blink::MLInstanceNormalizationOptions* options) {}

webnn::LayerNormalizationAttributes ConvertToLayerNormalizationAttributes(
    const blink::MLLayerNormalizationOptions* options) {}

webnn::LstmAttributes ConvertToLstmAttributes(
    const blink::MLLstmOptions* options) {}

webnn::LstmCellAttributes ConvertToLstmCellAttributes(
    const blink::MLLstmCellOptions* options) {}

bool ValidateClampOptions(const MLClampOptions* options,
                          ExceptionState& exception_state) {}

MLOperand* BuildArgMinMax(MLGraphBuilder* builder,
                          webnn::mojom::blink::ArgMinMax::Kind sub_kind,
                          const MLOperand* input,
                          const uint32_t axis,
                          const MLArgMinMaxOptions* options,
                          ExceptionState& exception_state) {}

MLOperand* BuildElementWiseBinary(
    MLGraphBuilder* builder,
    webnn::mojom::blink::ElementWiseBinary::Kind kind,
    const webnn::SupportedDataTypes& data_type_constraint,
    const MLOperand* a,
    const MLOperand* b,
    const MLOperatorOptions* options,
    ExceptionState& exception_state) {}

MLOperand* BuildUnaryOperator(
    MLGraphBuilder* builder,
    ExceptionState& exception_state,
    webnn::mojom::blink::Operation::Tag kind,
    const webnn::SupportedDataTypes& data_type_constraint,
    const MLOperand* input,
    const MLOperatorOptions* options) {}

MLOperand* BuildElementWiseUnaryOperator(
    MLGraphBuilder* builder,
    ExceptionState& exception_state,
    webnn::mojom::blink::ElementWiseUnary::Kind kind,
    const webnn::SupportedDataTypes& data_type_constraint,
    const MLOperand* input,
    const MLOperatorOptions* options) {}

MLOperand* BuildReduce(MLGraphBuilder* builder,
                       webnn::mojom::blink::Reduce::Kind kind,
                       const webnn::ContextProperties& context_properties,
                       const MLOperand* input,
                       const MLReduceOptions* options,
                       ExceptionState& exception_state) {}

MLOperand* BuildPool2d(MLGraphBuilder* builder,
                       webnn::mojom::blink::Pool2d::Kind kind,
                       const webnn::ContextProperties& context_properties,
                       const MLOperand* input,
                       const MLPool2dOptions* options,
                       ExceptionState& exception_state) {}

// Determines the input and output resources required for this computational
// graph by traversing the graph from `named_outputs` to its inputs.
// This may fail if the graph is not valid.
base::expected<std::pair<MLGraph::NamedOperandDescriptors,
                         MLGraph::NamedOperandDescriptors>,
               String>
DetermineGraphConstraintsFromOutputs(const MLNamedOperands& named_outputs) {}

base::expected<webnn::mojom::blink::GraphInfoPtr, String> BuildWebNNGraphInfo(
    const MLNamedOperands& named_outputs,
    const webnn::ContextProperties& context_properties) {}

}  // namespace

// static
MLGraphBuilder* MLGraphBuilder::Create(ScriptState* script_state,
                                       MLContext* context,
                                       ExceptionState& exception_state) {}

MLGraphBuilder::MLGraphBuilder(
    ExecutionContext* execution_context,
    MLContext* context,
    mojo::PendingAssociatedRemote<webnn::mojom::blink::WebNNGraphBuilder>
        pending_remote)
    :{}

MLGraphBuilder::~MLGraphBuilder() = default;

void MLGraphBuilder::Trace(Visitor* visitor) const {}

MLContext* MLGraphBuilder::GetContext() const {}

MLOperand* MLGraphBuilder::input(String name,
                                 const MLOperandDescriptor* desc,
                                 ExceptionState& exception_state) {}

MLOperand* MLGraphBuilder::constant(const MLOperandDescriptor* desc,
                                    NotShared<DOMArrayBufferView> buffer_view,
                                    ExceptionState& exception_state) {}

MLOperand* MLGraphBuilder::argMin(const MLOperand* input,
                                  const uint32_t axis,
                                  const MLArgMinMaxOptions* options,
                                  ExceptionState& exception_state) {}

MLOperand* MLGraphBuilder::argMax(const MLOperand* input,
                                  const uint32_t axis,
                                  const MLArgMinMaxOptions* options,
                                  ExceptionState& exception_state) {}

MLOperand* MLGraphBuilder::batchNormalization(
    const MLOperand* input,
    const MLOperand* mean,
    const MLOperand* variance,
    const MLBatchNormalizationOptions* options,
    ExceptionState& exception_state) {}

MLOperand* MLGraphBuilder::concat(const HeapVector<Member<MLOperand>>& inputs,
                                  const uint32_t axis,
                                  const MLOperatorOptions* options,
                                  ExceptionState& exception_state) {}

MLOperand* MLGraphBuilder::clamp(const MLOperand* input,
                                 const MLClampOptions* options,
                                 ExceptionState& exception_state) {}

MLOperand* MLGraphBuilder::conv2d(const MLOperand* input,
                                  const MLOperand* filter,
                                  const MLConv2dOptions* options,
                                  ExceptionState& exception_state) {}

MLOperand* MLGraphBuilder::convTranspose2d(
    const MLOperand* input,
    const MLOperand* filter,
    const MLConvTranspose2dOptions* options,
    ExceptionState& exception_state) {}

#define BUILD_ELEMENTWISE_BINARY_OP(op, op_kind)

BUILD_ELEMENTWISE_BINARY_OP()
BUILD_ELEMENTWISE_BINARY_OP()
BUILD_ELEMENTWISE_BINARY_OP()
BUILD_ELEMENTWISE_BINARY_OP()
BUILD_ELEMENTWISE_BINARY_OP()
BUILD_ELEMENTWISE_BINARY_OP()
BUILD_ELEMENTWISE_BINARY_OP()
BUILD_ELEMENTWISE_BINARY_OP()
BUILD_ELEMENTWISE_BINARY_OP()
BUILD_ELEMENTWISE_BINARY_OP()

MLOperand* MLGraphBuilder::greaterOrEqual(const MLOperand* a,
                                          const MLOperand* b,
                                          const MLOperatorOptions* options,
                                          ExceptionState& exception_state) {}

MLOperand* MLGraphBuilder::lesserOrEqual(const MLOperand* a,
                                         const MLOperand* b,
                                         const MLOperatorOptions* options,
                                         ExceptionState& exception_state) {}

#define BUILD_ELEMENTWISE_UNARY_OP(op, op_kind)

BUILD_ELEMENTWISE_UNARY_OP()
BUILD_ELEMENTWISE_UNARY_OP()
BUILD_ELEMENTWISE_UNARY_OP()
BUILD_ELEMENTWISE_UNARY_OP()
BUILD_ELEMENTWISE_UNARY_OP()
BUILD_ELEMENTWISE_UNARY_OP()
BUILD_ELEMENTWISE_UNARY_OP()
BUILD_ELEMENTWISE_UNARY_OP()
BUILD_ELEMENTWISE_UNARY_OP()
BUILD_ELEMENTWISE_UNARY_OP()
BUILD_ELEMENTWISE_UNARY_OP()
BUILD_ELEMENTWISE_UNARY_OP()
BUILD_ELEMENTWISE_UNARY_OP()
BUILD_ELEMENTWISE_UNARY_OP()

MLOperand* MLGraphBuilder::logicalNot(const MLOperand* input,
                                      const MLOperatorOptions* options,
                                      ExceptionState& exception_state) {}

MLOperand* MLGraphBuilder::cast(const MLOperand* input,
                                const V8MLOperandDataType output_data_type,
                                const MLOperatorOptions* options,
                                ExceptionState& exception_state) {}

#define BUILD_REDUCE_OP(op, op_kind)

BUILD_REDUCE_OP()
BUILD_REDUCE_OP()
BUILD_REDUCE_OP()
BUILD_REDUCE_OP()
BUILD_REDUCE_OP()
BUILD_REDUCE_OP()
BUILD_REDUCE_OP()
BUILD_REDUCE_OP()
BUILD_REDUCE_OP()
BUILD_REDUCE_OP()

MLOperand* MLGraphBuilder::elu(const MLOperand* input,
                               const MLEluOptions* options,
                               ExceptionState& exception_state) {}

MLOperand* MLGraphBuilder::expand(const MLOperand* input,
                                  const Vector<uint32_t>& new_shape,
                                  const MLOperatorOptions* options,
                                  ExceptionState& exception_state) {}

MLOperand* MLGraphBuilder::gather(const MLOperand* input,
                                  const MLOperand* indices,
                                  const MLGatherOptions* options,
                                  ExceptionState& exception_state) {}

MLOperand* MLGraphBuilder::gatherElements(const MLOperand* input,
                                          const MLOperand* indices,
                                          const MLGatherOptions* options,
                                          ExceptionState& exception_state) {}

MLOperand* MLGraphBuilder::gelu(const MLOperand* input,
                                const MLOperatorOptions* options,
                                ExceptionState& exception_state) {}

MLOperand* MLGraphBuilder::gemm(const MLOperand* a,
                                const MLOperand* b,
                                const MLGemmOptions* options,
                                ExceptionState& exception_state) {}

HeapVector<Member<const MLOperand>> MLGraphBuilder::gru(
    const MLOperand* input,
    const MLOperand* weight,
    const MLOperand* recurrent_weight,
    const uint32_t steps,
    const uint32_t hidden_size,
    MLGruOptions* options,
    ExceptionState& exception_state) {}

MLOperand* MLGraphBuilder::gruCell(const MLOperand* input,
                                   const MLOperand* weight,
                                   const MLOperand* recurrent_weight,
                                   const MLOperand* hidden_state,
                                   const uint32_t hidden_size,
                                   MLGruCellOptions* options,
                                   ExceptionState& exception_state) {}

MLOperand* MLGraphBuilder::hardSigmoid(const MLOperand* input,
                                       const MLHardSigmoidOptions* options,
                                       ExceptionState& exception_state) {}

MLOperand* MLGraphBuilder::hardSwish(const MLOperand* input,
                                     const MLOperatorOptions* options,
                                     ExceptionState& exception_state) {}

MLOperand* MLGraphBuilder::instanceNormalization(
    const MLOperand* input,
    const MLInstanceNormalizationOptions* options,
    ExceptionState& exception_state) {}

MLOperand* MLGraphBuilder::layerNormalization(
    const MLOperand* input,
    const MLLayerNormalizationOptions* options,
    ExceptionState& exception_state) {}

MLOperand* MLGraphBuilder::leakyRelu(const MLOperand* input,
                                     const MLLeakyReluOptions* options,
                                     ExceptionState& exception_state) {}

MLOperand* MLGraphBuilder::linear(const MLOperand* input,
                                  const MLLinearOptions* options,
                                  ExceptionState& exception_state) {}

HeapVector<Member<const MLOperand>> MLGraphBuilder::lstm(
    const MLOperand* input,
    const MLOperand* weight,
    const MLOperand* recurrent_weight,
    const uint32_t steps,
    const uint32_t hidden_size,
    MLLstmOptions* options,
    ExceptionState& exception_state) {}

HeapVector<Member<const MLOperand>> MLGraphBuilder::lstmCell(
    const MLOperand* input,
    const MLOperand* weight,
    const MLOperand* recurrent_weight,
    const MLOperand* hidden_state,
    const MLOperand* cell_state,
    const uint32_t hidden_size,
    MLLstmCellOptions* options,
    ExceptionState& exception_state) {}

MLOperand* MLGraphBuilder::matmul(const MLOperand* a,
                                  const MLOperand* b,
                                  const MLOperatorOptions* options,
                                  ExceptionState& exception_state) {}

MLOperand* MLGraphBuilder::pad(ScriptState* script_state,
                               const MLOperand* input,
                               const Vector<uint32_t>& beginning_padding,
                               const Vector<uint32_t>& ending_padding,
                               const MLPadOptions* options,
                               ExceptionState& exception_state) {}

MLOperand* MLGraphBuilder::averagePool2d(const MLOperand* input,
                                         const MLPool2dOptions* options,
                                         ExceptionState& exception_state) {}

MLOperand* MLGraphBuilder::l2Pool2d(const MLOperand* input,
                                    const MLPool2dOptions* options,
                                    ExceptionState& exception_state) {}

MLOperand* MLGraphBuilder::maxPool2d(const MLOperand* input,
                                     const MLPool2dOptions* options,
                                     ExceptionState& exception_state) {}

MLOperand* MLGraphBuilder::prelu(const MLOperand* input,
                                 const MLOperand* slope,
                                 const MLOperatorOptions* options,
                                 ExceptionState& exception_state) {}

MLOperand* MLGraphBuilder::relu(const MLOperand* input,
                                const MLOperatorOptions* options,
                                ExceptionState& exception_state) {}

MLOperand* MLGraphBuilder::reshape(const MLOperand* input,
                                   const Vector<uint32_t>& new_shape,
                                   const MLOperatorOptions* options,
                                   ExceptionState& exception_state) {}

MLOperand* MLGraphBuilder::resample2d(ScriptState* script_state,
                                      const MLOperand* input,
                                      const MLResample2dOptions* options,
                                      ExceptionState& exception_state) {}

MLOperand* MLGraphBuilder::sigmoid(const MLOperand* input,
                                   const MLOperatorOptions* options,
                                   ExceptionState& exception_state) {}

MLOperand* MLGraphBuilder::slice(const MLOperand* input,
                                 const Vector<uint32_t>& starts,
                                 const Vector<uint32_t>& sizes,
                                 const MLOperatorOptions* options,
                                 ExceptionState& exception_state) {}

MLOperand* MLGraphBuilder::softmax(const MLOperand* input,
                                   uint32_t axis,
                                   const MLOperatorOptions* options,
                                   ExceptionState& exception_state) {}

MLOperand* MLGraphBuilder::softmax(const MLOperand* input,
                                   const MLOperatorOptions* options,
                                   ExceptionState& exception_state) {}

MLOperand* MLGraphBuilder::softplus(const MLOperand* input,
                                    const MLOperatorOptions* options,
                                    ExceptionState& exception_state) {}

MLOperand* MLGraphBuilder::softsign(const MLOperand* input,
                                    const MLOperatorOptions* options,
                                    ExceptionState& exception_state) {}

HeapVector<Member<const MLOperand>> MLGraphBuilder::split(
    const MLOperand* input,
    const uint32_t splits,
    const MLSplitOptions* options,
    ExceptionState& exception_state) {}

HeapVector<Member<const MLOperand>> MLGraphBuilder::split(
    const MLOperand* input,
    const Vector<uint32_t>& splits,
    const MLSplitOptions* options,
    ExceptionState& exception_state) {}

MLOperand* MLGraphBuilder::tanh(const MLOperand* input,
                                const MLOperatorOptions* options,
                                ExceptionState& exception_state) {}

MLOperand* MLGraphBuilder::transpose(const MLOperand* input,
                                     const MLTransposeOptions* options,
                                     ExceptionState& exception_state) {}

MLOperand* MLGraphBuilder::triangular(const MLOperand* input,
                                      const MLTriangularOptions* options,
                                      ExceptionState& exception_state) {}

MLOperand* MLGraphBuilder::where(const MLOperand* condition,
                                 const MLOperand* true_value,
                                 const MLOperand* false_value,
                                 const MLOperatorOptions* options,
                                 ExceptionState& exception_state) {}

ScriptPromise<MLGraph> MLGraphBuilder::build(
    ScriptState* script_state,
    const MLNamedOperands& named_outputs,
    ExceptionState& exception_state) {}

void MLGraphBuilder::DidCreateWebNNGraph(
    ScriptPromiseResolver<blink::MLGraph>* resolver,
    std::pair<MLGraph::NamedOperandDescriptors,
              MLGraph::NamedOperandDescriptors> input_and_output_constraints,
    webnn::mojom::blink::CreateGraphResultPtr result) {}

void MLGraphBuilder::OnConnectionError() {}

base::expected<void, String> MLGraphBuilder::ValidateGraphBuilderState() const {}

// As specified in https://www.w3.org/TR/webnn/#mlgraphbuilder-validate-operand.
base::expected<void, String> MLGraphBuilder::ValidateInput(
    const MLOperand* input) {}

base::expected<void, String> MLGraphBuilder::ValidateInputs(
    const HeapVector<Member<const MLOperand>>& inputs) {}

}  // namespace blink