// chromium/services/webnn/webnn_graph_builder_impl.cc

// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "services/webnn/webnn_graph_builder_impl.h"

#include "base/containers/fixed_flat_map.h"
#include "base/containers/flat_map.h"
#include "base/functional/callback_forward.h"
#include "base/functional/callback_helpers.h"
#include "base/types/pass_key.h"
#include "services/webnn/error.h"
#include "services/webnn/public/cpp/graph_validation_utils.h"
#include "services/webnn/public/cpp/operand_descriptor.h"
#include "services/webnn/public/cpp/supported_data_types.h"
#include "services/webnn/public/mojom/webnn_error.mojom.h"
#include "services/webnn/webnn_context_impl.h"
#include "services/webnn/webnn_graph_impl.h"
#include "services/webnn/webnn_utils.h"

namespace webnn {

namespace {

// Maps the id to its `mojo::Operand`.
IdToOperandMap;

webnn::InputOperandLayout MojoInputOperandLayoutToComponent(
    webnn::mojom::InputOperandLayout layout) {}

webnn::Pool2dKind FromMojoPool2dType(mojom::Pool2d::Kind kind) {}

webnn::ReduceKind MojoReduceTypeToComponent(mojom::Reduce::Kind kind) {}

webnn::RecurrentNetworkDirection MojoRecurrentNetworkDirectionToComponent(
    mojom::RecurrentNetworkDirection direction) {}

bool ValidateClampAttributes(const mojom::Clamp& clamp) {}

bool ValidateEluAttributes(const mojom::Elu& elu) {}

bool ValidateHardSigmoidAttributes(const mojom::HardSigmoid& hard_sigmoid) {}

bool ValidateLeakyReluAttributes(const mojom::LeakyRelu& leaky_relu) {}

bool ValidateLinearAttributes(const mojom::Linear& linear) {}

const mojom::Operand* GetMojoOperand(const IdToOperandMap& id_to_operand_map,
                                     uint64_t operand_id) {}

webnn::BatchNormalizationAttributes ConvertToBatchNormalizationAttributes(
    const IdToOperandMap& id_to_operand_map,
    const mojom::BatchNormalization& batch_normalization) {}

template <typename Conv2dAttributesType>
Conv2dAttributesType ConvertToConv2dAttributes(
    const webnn::ContextProperties& context_properties,
    const IdToOperandMap& id_to_operand_map,
    const webnn::mojom::Conv2d& conv2d,
    std::optional<OperandDescriptor> bias_operand) {}

webnn::Conv2dAttributes ConvertToConv2dAttributes(
    const webnn::ContextProperties& context_properties,
    const IdToOperandMap& id_to_operand_map,
    const webnn::mojom::Conv2d& conv2d,
    std::optional<OperandDescriptor> bias_operand) {}

webnn::LstmAttributes ConvertToLstmAttributes(
    const IdToOperandMap& id_to_operand_map,
    const webnn::mojom::Lstm& lstm) {}

webnn::LstmCellAttributes ConvertToLstmCellAttributes(
    const IdToOperandMap& id_to_operand_map,
    const webnn::mojom::LstmCell& lstm_cell) {}

webnn::ConvTranspose2dAttributes ConvertToConvTranspose2dAttributes(
    const webnn::ContextProperties& context_properties,
    const IdToOperandMap& id_to_operand_map,
    const webnn::mojom::Conv2d& conv2d,
    std::optional<OperandDescriptor> bias_operand) {}

webnn::LayerNormalizationAttributes ConvertToLayerNormalizationAttributes(
    const IdToOperandMap& id_to_operand_map,
    const mojom::LayerNormalization& layer_normalization) {}

webnn::Pool2dAttributes ConvertToPool2dAttributes(
    const webnn::ContextProperties& context_properties,
    const webnn::mojom::Pool2d& pool2d,
    const mojom::Operand* output) {}

webnn::GemmAttributes ConvertToGemmAttributes(
    const IdToOperandMap& id_to_operand_map,
    const mojom::Gemm& gemm) {}

webnn::GruAttributes ConvertToGruAttributes(
    const IdToOperandMap& id_to_operand_map,
    const webnn::mojom::Gru& gru) {}

webnn::GruCellAttributes ConvertToGruCellAttributes(
    const IdToOperandMap& id_to_operand_map,
    const webnn::mojom::GruCell& gru_cell) {}

webnn::InstanceNormalizationAttributes ConvertToInstanceNormalizationAttributes(
    const IdToOperandMap& id_to_operand_map,
    const mojom::InstanceNormalization& instance_normalization) {}

webnn::SliceAttributes ConvertToSliceAttributes(
    const webnn::mojom::Slice& slice) {}

template <typename Operation>
bool ValidateUnaryOperation(const IdToOperandMap& id_to_operand_map,
                            const Operation& operation,
                            const webnn::SupportedDataTypes& input_constraint,
                            base::flat_set<uint64_t>& processed_operands) {}

bool ValidateCastOperation(const ContextProperties& context_properties,
                           const IdToOperandMap& id_to_operand_map,
                           const mojom::ElementWiseUnary& operation,
                           base::flat_set<uint64_t>& processed_operands) {}

bool ValidateBatchNormalization(
    const IdToOperandMap& id_to_operand_map,
    const mojom::BatchNormalization& batch_normalization,
    base::flat_set<uint64_t>& processed_operands) {}

bool ValidateArgMinMax(const ContextProperties& context_properties,
                       const IdToOperandMap& id_to_operand_map,
                       const mojom::ArgMinMax& arg_min_max,
                       base::flat_set<uint64_t>& processed_operands) {}

bool ValidateClamp(const ContextProperties& context_properties,
                   const IdToOperandMap& id_to_operand_map,
                   const mojom::Clamp& clamp,
                   base::flat_set<uint64_t>& processed_operands) {}

bool ValidateConcat(const ContextProperties& context_properties,
                    const IdToOperandMap& id_to_operand_map,
                    const mojom::Concat& concat,
                    base::flat_set<uint64_t>& processed_operands) {}

bool ValidateConv2d(const ContextProperties& context_properties,
                    const IdToOperandMap& id_to_operand_map,
                    const mojom::Conv2d& conv2d,
                    base::flat_set<uint64_t>& processed_operands) {}

bool ValidateElementWiseBinaryDataTypes(
    const ContextProperties& context_properties,
    const mojom::Operand* lhs,
    const mojom::Operand* rhs,
    const mojom::Operand* output,
    const mojom::ElementWiseBinary& operation) {}

bool ValidateElementWiseBinary(const ContextProperties& context_properties,
                               const IdToOperandMap& id_to_operand_map,
                               const mojom::ElementWiseBinary& operation,
                               base::flat_set<uint64_t>& processed_operands) {}

bool ValidateElu(const ContextProperties& context_properties,
                 const IdToOperandMap& id_to_operand_map,
                 const mojom::Elu& elu,
                 base::flat_set<uint64_t>& processed_operands) {}

bool ValidateElementWiseUnary(const ContextProperties& context_properties,
                              const IdToOperandMap& id_to_operand_map,
                              const mojom::ElementWiseUnary& operation,
                              base::flat_set<uint64_t>& processed_operands) {}

bool ValidateExpand(const ContextProperties& context_properties,
                    const IdToOperandMap& id_to_operand_map,
                    const mojom::Expand& expand,
                    base::flat_set<uint64_t>& processed_operands) {}

bool ValidateGather(const ContextProperties& context_properties,
                    const IdToOperandMap& id_to_operand_map,
                    const mojom::Gather& gather,
                    base::flat_set<uint64_t>& processed_operands) {}

bool ValidateGatherElements(const ContextProperties& context_properties,
                            const IdToOperandMap& id_to_operand_map,
                            const mojom::GatherElements& gather_elements,
                            base::flat_set<uint64_t>& processed_operands) {}

bool ValidateGemm(const ContextProperties& context_properties,
                  const IdToOperandMap& id_to_operand_map,
                  const mojom::Gemm& gemm,
                  base::flat_set<uint64_t>& processed_operands) {}

bool ValidateGru(const IdToOperandMap& id_to_operand_map,
                 const mojom::Gru& gru,
                 base::flat_set<uint64_t>& processed_operands) {}

bool ValidateGruCell(const IdToOperandMap& id_to_operand_map,
                     const mojom::GruCell& gru_cell,
                     base::flat_set<uint64_t>& processed_operands) {}

bool ValidateHardSigmoid(const IdToOperandMap& id_to_operand_map,
                         const mojom::HardSigmoid& hard_sigmoid,
                         base::flat_set<uint64_t>& processed_operands) {}

bool ValidateLayerNormalization(
    const IdToOperandMap& id_to_operand_map,
    const mojom::LayerNormalization& layer_normalization,
    base::flat_set<uint64_t>& processed_operands) {}

bool ValidateLeakyRelu(const ContextProperties& context_properties,
                       const IdToOperandMap& id_to_operand_map,
                       const mojom::LeakyRelu& leaky_relu,
                       base::flat_set<uint64_t>& processed_operands) {}

bool ValidateLinear(const ContextProperties& context_properties,
                    const IdToOperandMap& id_to_operand_map,
                    const mojom::Linear& linear,
                    base::flat_set<uint64_t>& processed_operands) {}

bool ValidateLstm(const IdToOperandMap& id_to_operand_map,
                  const mojom::Lstm& lstm,
                  base::flat_set<uint64_t>& processed_operands) {}

bool ValidateLstmCell(const IdToOperandMap& id_to_operand_map,
                      const mojom::LstmCell& lstm_cell,
                      base::flat_set<uint64_t>& processed_operands) {}

bool ValidateInstanceNormalization(
    const IdToOperandMap& id_to_operand_map,
    const mojom::InstanceNormalization& instance_normalization,
    base::flat_set<uint64_t>& processed_operands) {}

bool ValidateMatmul(const ContextProperties& context_properties,
                    const IdToOperandMap& id_to_operand_map,
                    const mojom::Matmul& matmul,
                    base::flat_set<uint64_t>& processed_operands) {}

bool ValidatePad(const ContextProperties& context_properties,
                 const IdToOperandMap& id_to_operand_map,
                 const mojom::Pad& pad,
                 base::flat_set<uint64_t>& processed_operands) {}

bool ValidatePool2d(const ContextProperties& context_properties,
                    const IdToOperandMap& id_to_operand_map,
                    const mojom::Pool2d& pool2d,
                    base::flat_set<uint64_t>& processed_operands) {}

bool ValidatePrelu(const ContextProperties& context_properties,
                   const IdToOperandMap& id_to_operand_map,
                   const mojom::Prelu& prelu,
                   base::flat_set<uint64_t>& processed_operands) {}

bool ValidateResample2d(const ContextProperties& context_properties,
                        const IdToOperandMap& id_to_operand_map,
                        const mojom::Resample2d& resample2d,
                        base::flat_set<uint64_t>& processed_operands) {}

bool ValidateReshape(const ContextProperties& context_properties,
                     const IdToOperandMap& id_to_operand_map,
                     const mojom::Reshape& reshape,
                     base::flat_set<uint64_t>& processed_operands) {}

bool ValidateSlice(const ContextProperties& context_properties,
                   const IdToOperandMap& id_to_operand_map,
                   const mojom::Slice& slice,
                   base::flat_set<uint64_t>& processed_operands) {}

bool ValidateSoftmax(const ContextProperties& context_properties,
                     const IdToOperandMap& id_to_operand_map,
                     const mojom::Softmax& softmax,
                     base::flat_set<uint64_t>& processed_operands) {}

bool ValidateSplit(const ContextProperties& context_properties,
                   const IdToOperandMap& id_to_operand_map,
                   const mojom::Split& split,
                   base::flat_set<uint64_t>& processed_operands) {}

bool ValidateTranspose(const ContextProperties& context_properties,
                       const IdToOperandMap& id_to_operand_map,
                       const mojom::Transpose& transpose,
                       base::flat_set<uint64_t>& processed_operands) {}

bool ValidateTriangular(const ContextProperties& context_properties,
                        const IdToOperandMap& id_to_operand_map,
                        const mojom::Triangular& triangular,
                        base::flat_set<uint64_t>& processed_operands) {}

bool ValidateWhere(const ContextProperties& context_properties,
                   const IdToOperandMap& id_to_operand_map,
                   const mojom::Where& where,
                   base::flat_set<uint64_t>& processed_operands) {}

bool ValidateReduce(const ContextProperties& context_properties,
                    const IdToOperandMap& id_to_operand_map,
                    const mojom::Reduce& reduce,
                    base::flat_set<uint64_t>& processed_operands) {}

bool ValidateOperation(const ContextProperties& context_properties,
                       const IdToOperandMap& id_to_operand_map,
                       const mojom::Operation& operation,
                       base::flat_set<uint64_t>& processed_operands) {}

}  // namespace

// Constructs a builder bound to `context`. The original `:{}` was an empty
// mem-initializer list (a syntax error) and left the reference parameter
// unused. Only a reference is stored, so `context` must outlive this builder.
// NOTE(review): assumes the header declares a non-owning `context_` member
// (e.g. `raw_ref<WebNNContextImpl>`) — confirm against the header.
WebNNGraphBuilderImpl::WebNNGraphBuilderImpl(WebNNContextImpl& context)
    : context_(context) {}

// Defaulted out-of-line destructor; no explicit cleanup is performed here.
WebNNGraphBuilderImpl::~WebNNGraphBuilderImpl() = default;

// Takes ownership of a `mojom::GraphInfo` description and a completion
// callback for the graph-creation result.
// NOTE(review): the body is empty in this copy, so `graph_info` and
// `callback` are destroyed without the callback ever being run — confirm the
// real implementation is restored.
void WebNNGraphBuilderImpl::CreateGraph(mojom::GraphInfoPtr graph_info,
                                        CreateGraphCallback callback) {}

// Receives this builder's `mojo::ReceiverId`. The `base::PassKey` parameter
// restricts callers to `WebNNContextImpl`. Presumably the id is stored for
// later use (TODO confirm — the body is empty in this copy).
void WebNNGraphBuilderImpl::SetId(
    mojo::ReceiverId id,
    base::PassKey<WebNNContextImpl> /*pass_key*/) {}

// Completion handler receiving either the created `WebNNGraphImpl` or a
// `mojom::Error`, along with the caller's `CreateGraphCallback`.
// NOTE(review): the body is empty in this copy, so `callback` is dropped
// without being run and `result` is discarded — confirm.
void WebNNGraphBuilderImpl::DidCreateGraph(
    CreateGraphCallback callback,
    base::expected<std::unique_ptr<WebNNGraphImpl>, mojom::ErrorPtr> result) {}

// static
// Validates `graph_info` against `context_properties`. Returns a
// `ComputeResourceInfo` wrapped in `std::optional`; presumably `std::nullopt`
// when the graph is invalid (TODO confirm).
// NOTE(review): this non-void function has an empty body in this copy, so
// calling it is undefined behavior (flowing off the end without `return`).
std::optional<WebNNGraphImpl::ComputeResourceInfo>
WebNNGraphBuilderImpl::ValidateGraph(
    const ContextProperties& context_properties,
    const mojom::GraphInfo& graph_info) {}

// static
// Test-only convenience wrapper: returns true iff `ValidateGraph` produces a
// value for the same inputs. The previous empty body made this non-void
// function fall off its end, which is undefined behavior.
bool WebNNGraphBuilderImpl::IsValidForTesting(
    const ContextProperties& context_properties,
    const mojom::GraphInfo& graph_info) {
  return ValidateGraph(context_properties, graph_info).has_value();
}

// NOTE(review): empty in this copy. The name suggests it asks the owning
// context to destroy this builder — TODO confirm against the implementation.
void WebNNGraphBuilderImpl::DestroySelf() {}

}  // namespace webnn