#include "services/webnn/webnn_graph_builder_impl.h"
#include "base/containers/fixed_flat_map.h"
#include "base/containers/flat_map.h"
#include "base/functional/callback_forward.h"
#include "base/functional/callback_helpers.h"
#include "base/types/pass_key.h"
#include "services/webnn/error.h"
#include "services/webnn/public/cpp/graph_validation_utils.h"
#include "services/webnn/public/cpp/operand_descriptor.h"
#include "services/webnn/public/cpp/supported_data_types.h"
#include "services/webnn/public/mojom/webnn_error.mojom.h"
#include "services/webnn/webnn_context_impl.h"
#include "services/webnn/webnn_graph_impl.h"
#include "services/webnn/webnn_utils.h"
namespace webnn {
namespace {
// NOTE(review): this line is not a valid declaration as written; it looks like
// a truncated type alias (e.g. `using IdToOperandMap = ...;`). Judging by
// GetMojoOperand below, it presumably maps uint64_t operand ids to
// mojom::Operand objects. The aliased type is not visible here: confirm and
// restore the full declaration.
IdToOperandMap;
// Converts a mojom::InputOperandLayout value (the IPC representation) into a
// webnn::InputOperandLayout value.
webnn::InputOperandLayout MojoInputOperandLayoutToComponent(
webnn::mojom::InputOperandLayout layout) { … }
// Converts a mojom::Pool2d::Kind into a webnn::Pool2dKind.
// NOTE(review): named `FromMojo...` unlike the sibling `Mojo...ToComponent`
// converters in this file; consider aligning the naming.
webnn::Pool2dKind FromMojoPool2dType(mojom::Pool2d::Kind kind) { … }
// Converts a mojom::Reduce::Kind into a webnn::ReduceKind.
webnn::ReduceKind MojoReduceTypeToComponent(mojom::Reduce::Kind kind) { … }
// Converts a mojom::RecurrentNetworkDirection into a
// webnn::RecurrentNetworkDirection.
webnn::RecurrentNetworkDirection MojoRecurrentNetworkDirectionToComponent(
mojom::RecurrentNetworkDirection direction) { … }
// Checks the attributes carried by a mojom::Clamp operation. Returns a bool,
// presumably true when the attributes are acceptable (body not visible here).
bool ValidateClampAttributes(const mojom::Clamp& clamp) { … }
// Checks the attributes carried by a mojom::Elu operation. Returns a bool,
// presumably true when the attributes are acceptable (body not visible here).
bool ValidateEluAttributes(const mojom::Elu& elu) { … }
// Checks the attributes carried by a mojom::HardSigmoid operation. Returns a
// bool, presumably true when the attributes are acceptable (body not visible).
bool ValidateHardSigmoidAttributes(const mojom::HardSigmoid& hard_sigmoid) { … }
// Checks the attributes carried by a mojom::LeakyRelu operation. Returns a
// bool, presumably true when the attributes are acceptable (body not visible).
bool ValidateLeakyReluAttributes(const mojom::LeakyRelu& leaky_relu) { … }
// Checks the attributes carried by a mojom::Linear operation. Returns a bool,
// presumably true when the attributes are acceptable (body not visible here).
bool ValidateLinearAttributes(const mojom::Linear& linear) { … }
// Returns a raw `const mojom::Operand*` for `operand_id`, presumably looked up
// in `id_to_operand_map`.
// NOTE(review): whether an unknown id yields nullptr (or is treated as a hard
// failure) is not visible here; callers should confirm before dereferencing.
const mojom::Operand* GetMojoOperand(const IdToOperandMap& id_to_operand_map,
uint64_t operand_id) { … }
// Builds a webnn::BatchNormalizationAttributes from a
// mojom::BatchNormalization. `id_to_operand_map` is presumably used to resolve
// operand ids referenced by the operation (body not visible here).
webnn::BatchNormalizationAttributes ConvertToBatchNormalizationAttributes(
const IdToOperandMap& id_to_operand_map,
const mojom::BatchNormalization& batch_normalization) { … }
// Generic builder for conv2d-style attributes of type `Conv2dAttributesType`
// from a mojom::Conv2d. Presumably this is the shared implementation behind
// the non-template ConvertToConv2dAttributes overload and
// ConvertToConvTranspose2dAttributes (body not visible here).
// `bias_operand` is taken by value and is presumably the descriptor of the
// optional bias input; `context_properties` usage is not visible here.
template <typename Conv2dAttributesType>
Conv2dAttributesType ConvertToConv2dAttributes(
const webnn::ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const webnn::mojom::Conv2d& conv2d,
std::optional<OperandDescriptor> bias_operand) { … }
// Builds webnn::Conv2dAttributes from a mojom::Conv2d.
// This non-template overload has the same parameter list as the template
// above, so an unqualified call with these argument types selects this
// overload (C++ prefers a non-template over an equally good template).
webnn::Conv2dAttributes ConvertToConv2dAttributes(
const webnn::ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const webnn::mojom::Conv2d& conv2d,
std::optional<OperandDescriptor> bias_operand) { … }
// Builds webnn::LstmAttributes from a mojom::Lstm; `id_to_operand_map` is
// presumably used to resolve operand ids the operation references.
webnn::LstmAttributes ConvertToLstmAttributes(
const IdToOperandMap& id_to_operand_map,
const webnn::mojom::Lstm& lstm) { … }
// Builds webnn::LstmCellAttributes from a mojom::LstmCell; `id_to_operand_map`
// is presumably used to resolve operand ids the operation references.
webnn::LstmCellAttributes ConvertToLstmCellAttributes(
const IdToOperandMap& id_to_operand_map,
const webnn::mojom::LstmCell& lstm_cell) { … }
// Builds webnn::ConvTranspose2dAttributes. Note that the input is a
// mojom::Conv2d: transposed convolution is carried in the same mojom struct
// as regular conv2d. `bias_operand` is presumably the optional bias
// descriptor (body not visible here).
webnn::ConvTranspose2dAttributes ConvertToConvTranspose2dAttributes(
const webnn::ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const webnn::mojom::Conv2d& conv2d,
std::optional<OperandDescriptor> bias_operand) { … }
// Builds webnn::LayerNormalizationAttributes from a mojom::LayerNormalization;
// `id_to_operand_map` is presumably used to resolve referenced operand ids.
webnn::LayerNormalizationAttributes ConvertToLayerNormalizationAttributes(
const IdToOperandMap& id_to_operand_map,
const mojom::LayerNormalization& layer_normalization) { … }
// Builds webnn::Pool2dAttributes from a mojom::Pool2d.
// Unlike most converters here it takes no operand map; instead it receives
// the output operand directly as a raw pointer.
// NOTE(review): whether `output` may be null is not visible here; confirm.
webnn::Pool2dAttributes ConvertToPool2dAttributes(
const webnn::ContextProperties& context_properties,
const webnn::mojom::Pool2d& pool2d,
const mojom::Operand* output) { … }
// Builds webnn::GemmAttributes from a mojom::Gemm; `id_to_operand_map` is
// presumably used to resolve operand ids the operation references.
webnn::GemmAttributes ConvertToGemmAttributes(
const IdToOperandMap& id_to_operand_map,
const mojom::Gemm& gemm) { … }
// Builds webnn::GruAttributes from a mojom::Gru; `id_to_operand_map` is
// presumably used to resolve operand ids the operation references.
webnn::GruAttributes ConvertToGruAttributes(
const IdToOperandMap& id_to_operand_map,
const webnn::mojom::Gru& gru) { … }
// Builds webnn::GruCellAttributes from a mojom::GruCell; `id_to_operand_map`
// is presumably used to resolve operand ids the operation references.
webnn::GruCellAttributes ConvertToGruCellAttributes(
const IdToOperandMap& id_to_operand_map,
const webnn::mojom::GruCell& gru_cell) { … }
// Builds webnn::InstanceNormalizationAttributes from a
// mojom::InstanceNormalization; `id_to_operand_map` is presumably used to
// resolve operand ids the operation references.
webnn::InstanceNormalizationAttributes ConvertToInstanceNormalizationAttributes(
const IdToOperandMap& id_to_operand_map,
const mojom::InstanceNormalization& instance_normalization) { … }
// Builds webnn::SliceAttributes from a mojom::Slice. Takes only the operation
// itself (no operand map or context properties).
webnn::SliceAttributes ConvertToSliceAttributes(
const webnn::mojom::Slice& slice) { … }
// Generic validator for a unary operation of type `Operation`.
// `input_constraint` presumably lists the data types accepted for the input.
// `processed_operands` is a mutable reference, so the function may record
// operand ids in it. Returns a bool, presumably true when the operation is
// valid (body not visible here).
template <typename Operation>
bool ValidateUnaryOperation(const IdToOperandMap& id_to_operand_map,
const Operation& operation,
const webnn::SupportedDataTypes& input_constraint,
base::flat_set<uint64_t>& processed_operands) { … }
// Validates a cast operation. Note that cast is represented as a
// mojom::ElementWiseUnary. `processed_operands` may be updated. Returns a bool,
// presumably true when valid.
bool ValidateCastOperation(const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::ElementWiseUnary& operation,
base::flat_set<uint64_t>& processed_operands) { … }
// Validates a mojom::BatchNormalization operation. Unlike most validators here
// it does not take ContextProperties. `processed_operands` may be updated.
// Returns a bool, presumably true when valid.
bool ValidateBatchNormalization(
const IdToOperandMap& id_to_operand_map,
const mojom::BatchNormalization& batch_normalization,
base::flat_set<uint64_t>& processed_operands) { … }
// Validates a mojom::ArgMinMax operation against `context_properties`.
// `processed_operands` may be updated. Returns a bool, presumably true when
// valid.
bool ValidateArgMinMax(const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::ArgMinMax& arg_min_max,
base::flat_set<uint64_t>& processed_operands) { … }
// Validates a mojom::Clamp operation against `context_properties`.
// `processed_operands` may be updated. Returns a bool, presumably true when
// valid.
bool ValidateClamp(const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::Clamp& clamp,
base::flat_set<uint64_t>& processed_operands) { … }
// Validates a mojom::Concat operation against `context_properties`.
// `processed_operands` may be updated. Returns a bool, presumably true when
// valid.
bool ValidateConcat(const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::Concat& concat,
base::flat_set<uint64_t>& processed_operands) { … }
// Validates a mojom::Conv2d operation against `context_properties`. The same
// mojom struct also carries transposed convolutions (see
// ConvertToConvTranspose2dAttributes). `processed_operands` may be updated.
// Returns a bool, presumably true when valid.
bool ValidateConv2d(const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::Conv2d& conv2d,
base::flat_set<uint64_t>& processed_operands) { … }
// Checks the data types involved in an element-wise binary operation, given
// already-resolved `lhs`, `rhs` and `output` operands (raw pointers).
// Presumably the check depends on the kind of `operation` and on what
// `context_properties` supports. Returns a bool, presumably true when valid.
// NOTE(review): whether the pointers may be null is not visible here.
bool ValidateElementWiseBinaryDataTypes(
const ContextProperties& context_properties,
const mojom::Operand* lhs,
const mojom::Operand* rhs,
const mojom::Operand* output,
const mojom::ElementWiseBinary& operation) { … }
// Validates a mojom::ElementWiseBinary operation against
// `context_properties`. `processed_operands` may be updated. Returns a bool,
// presumably true when valid.
bool ValidateElementWiseBinary(const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::ElementWiseBinary& operation,
base::flat_set<uint64_t>& processed_operands) { … }
// Validates a mojom::Elu operation against `context_properties`.
// `processed_operands` may be updated. Returns a bool, presumably true when
// valid.
bool ValidateElu(const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::Elu& elu,
base::flat_set<uint64_t>& processed_operands) { … }
// Validates a mojom::ElementWiseUnary operation against `context_properties`.
// `processed_operands` may be updated. Returns a bool, presumably true when
// valid.
bool ValidateElementWiseUnary(const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::ElementWiseUnary& operation,
base::flat_set<uint64_t>& processed_operands) { … }
// Validates a mojom::Expand operation against `context_properties`.
// `processed_operands` may be updated. Returns a bool, presumably true when
// valid.
bool ValidateExpand(const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::Expand& expand,
base::flat_set<uint64_t>& processed_operands) { … }
// Validates a mojom::Gather operation against `context_properties`.
// `processed_operands` may be updated. Returns a bool, presumably true when
// valid.
bool ValidateGather(const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::Gather& gather,
base::flat_set<uint64_t>& processed_operands) { … }
// Validates a mojom::GatherElements operation against `context_properties`.
// `processed_operands` may be updated. Returns a bool, presumably true when
// valid.
bool ValidateGatherElements(const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::GatherElements& gather_elements,
base::flat_set<uint64_t>& processed_operands) { … }
// Validates a mojom::Gemm operation against `context_properties`.
// `processed_operands` may be updated. Returns a bool, presumably true when
// valid.
bool ValidateGemm(const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::Gemm& gemm,
base::flat_set<uint64_t>& processed_operands) { … }
// Validates a mojom::Gru operation. Does not take ContextProperties.
// `processed_operands` may be updated. Returns a bool, presumably true when
// valid.
bool ValidateGru(const IdToOperandMap& id_to_operand_map,
const mojom::Gru& gru,
base::flat_set<uint64_t>& processed_operands) { … }
// Validates a mojom::GruCell operation. Does not take ContextProperties.
// `processed_operands` may be updated. Returns a bool, presumably true when
// valid.
bool ValidateGruCell(const IdToOperandMap& id_to_operand_map,
const mojom::GruCell& gru_cell,
base::flat_set<uint64_t>& processed_operands) { … }
// Validates a mojom::HardSigmoid operation. Does not take ContextProperties.
// `processed_operands` may be updated. Returns a bool, presumably true when
// valid.
bool ValidateHardSigmoid(const IdToOperandMap& id_to_operand_map,
const mojom::HardSigmoid& hard_sigmoid,
base::flat_set<uint64_t>& processed_operands) { … }
// Validates a mojom::LayerNormalization operation. Does not take
// ContextProperties. `processed_operands` may be updated. Returns a bool,
// presumably true when valid.
bool ValidateLayerNormalization(
const IdToOperandMap& id_to_operand_map,
const mojom::LayerNormalization& layer_normalization,
base::flat_set<uint64_t>& processed_operands) { … }
// Validates a mojom::LeakyRelu operation against `context_properties`.
// `processed_operands` may be updated. Returns a bool, presumably true when
// valid.
bool ValidateLeakyRelu(const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::LeakyRelu& leaky_relu,
base::flat_set<uint64_t>& processed_operands) { … }
// Validates a mojom::Linear operation against `context_properties`.
// `processed_operands` may be updated. Returns a bool, presumably true when
// valid.
bool ValidateLinear(const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::Linear& linear,
base::flat_set<uint64_t>& processed_operands) { … }
// Validates a mojom::Lstm operation. Does not take ContextProperties.
// `processed_operands` may be updated. Returns a bool, presumably true when
// valid.
bool ValidateLstm(const IdToOperandMap& id_to_operand_map,
const mojom::Lstm& lstm,
base::flat_set<uint64_t>& processed_operands) { … }
// Validates a mojom::LstmCell operation. Does not take ContextProperties.
// `processed_operands` may be updated. Returns a bool, presumably true when
// valid.
bool ValidateLstmCell(const IdToOperandMap& id_to_operand_map,
const mojom::LstmCell& lstm_cell,
base::flat_set<uint64_t>& processed_operands) { … }
// Validates a mojom::InstanceNormalization operation. Does not take
// ContextProperties. `processed_operands` may be updated. Returns a bool,
// presumably true when valid.
bool ValidateInstanceNormalization(
const IdToOperandMap& id_to_operand_map,
const mojom::InstanceNormalization& instance_normalization,
base::flat_set<uint64_t>& processed_operands) { … }
// Validates a mojom::Matmul operation against `context_properties`.
// `processed_operands` may be updated. Returns a bool, presumably true when
// valid.
bool ValidateMatmul(const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::Matmul& matmul,
base::flat_set<uint64_t>& processed_operands) { … }
// Validates a mojom::Pad operation against `context_properties`.
// `processed_operands` may be updated. Returns a bool, presumably true when
// valid.
bool ValidatePad(const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::Pad& pad,
base::flat_set<uint64_t>& processed_operands) { … }
// Validates a mojom::Pool2d operation against `context_properties`.
// `processed_operands` may be updated. Returns a bool, presumably true when
// valid.
bool ValidatePool2d(const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::Pool2d& pool2d,
base::flat_set<uint64_t>& processed_operands) { … }
// Validates a mojom::Prelu operation against `context_properties`.
// `processed_operands` may be updated. Returns a bool, presumably true when
// valid.
bool ValidatePrelu(const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::Prelu& prelu,
base::flat_set<uint64_t>& processed_operands) { … }
// Validates a mojom::Resample2d operation against `context_properties`.
// `processed_operands` may be updated. Returns a bool, presumably true when
// valid.
bool ValidateResample2d(const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::Resample2d& resample2d,
base::flat_set<uint64_t>& processed_operands) { … }
// Validates a mojom::Reshape operation against `context_properties`.
// `processed_operands` may be updated. Returns a bool, presumably true when
// valid.
bool ValidateReshape(const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::Reshape& reshape,
base::flat_set<uint64_t>& processed_operands) { … }
// Validates a mojom::Slice operation against `context_properties`.
// `processed_operands` may be updated. Returns a bool, presumably true when
// valid.
bool ValidateSlice(const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::Slice& slice,
base::flat_set<uint64_t>& processed_operands) { … }
// Validates a mojom::Softmax operation against `context_properties`.
// `processed_operands` may be updated. Returns a bool, presumably true when
// valid.
bool ValidateSoftmax(const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::Softmax& softmax,
base::flat_set<uint64_t>& processed_operands) { … }
// Validates a mojom::Split operation against `context_properties`.
// `processed_operands` may be updated. Returns a bool, presumably true when
// valid.
bool ValidateSplit(const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::Split& split,
base::flat_set<uint64_t>& processed_operands) { … }
// Validates a mojom::Transpose operation against `context_properties`.
// `processed_operands` may be updated. Returns a bool, presumably true when
// valid.
bool ValidateTranspose(const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::Transpose& transpose,
base::flat_set<uint64_t>& processed_operands) { … }
// Validates a mojom::Triangular operation against `context_properties`.
// `processed_operands` may be updated. Returns a bool, presumably true when
// valid.
bool ValidateTriangular(const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::Triangular& triangular,
base::flat_set<uint64_t>& processed_operands) { … }
// Validates a mojom::Where operation against `context_properties`.
// `processed_operands` may be updated. Returns a bool, presumably true when
// valid.
bool ValidateWhere(const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::Where& where,
base::flat_set<uint64_t>& processed_operands) { … }
// Validates a mojom::Reduce operation against `context_properties`.
// `processed_operands` may be updated. Returns a bool, presumably true when
// valid.
bool ValidateReduce(const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::Reduce& reduce,
base::flat_set<uint64_t>& processed_operands) { … }
// Validates a single mojom::Operation. Presumably dispatches on the
// operation's kind to the per-operation Validate* helpers above (body not
// visible here). `processed_operands` is a mutable reference shared across
// calls, so state may carry over between operations. Returns a bool,
// presumably true when the operation is valid.
bool ValidateOperation(const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::Operation& operation,
base::flat_set<uint64_t>& processed_operands) { … }
}
// Constructs a graph builder associated with `context`, which is taken by
// non-const reference.
// NOTE(review): the lifetime relationship between the builder and `context`
// is not visible here; presumably the context outlives the builder.
WebNNGraphBuilderImpl::WebNNGraphBuilderImpl(WebNNContextImpl& context)
: … { … }
// Defaulted destructor: members clean themselves up.
WebNNGraphBuilderImpl::~WebNNGraphBuilderImpl() = default;
// Requests creation of a graph described by `graph_info`. The outcome is
// reported through `callback`. `graph_info` (a mojom::GraphInfoPtr) and
// `callback` are both taken by value, so the method owns them.
// Presumably validates the graph (see ValidateGraph) before building it;
// body not visible here.
void WebNNGraphBuilderImpl::CreateGraph(mojom::GraphInfoPtr graph_info,
CreateGraphCallback callback) { … }
// Associates this builder with the mojo receiver `id` (presumably stored for
// later use; body not visible here). The unnamed
// base::PassKey<WebNNContextImpl> parameter restricts callers to code that can
// construct that key, i.e. WebNNContextImpl.
void WebNNGraphBuilderImpl::SetId(
mojo::ReceiverId id,
base::PassKey<WebNNContextImpl> ) { … }
// Completion handler for graph creation. `result` holds either the created
// WebNNGraphImpl (as a unique_ptr) or a mojom::Error. The outcome is presumably
// forwarded to `callback` (body not visible here).
void WebNNGraphBuilderImpl::DidCreateGraph(
CreateGraphCallback callback,
base::expected<std::unique_ptr<WebNNGraphImpl>, mojom::ErrorPtr> result) { … }
// Validates `graph_info` against `context_properties`. Returns an optional
// WebNNGraphImpl::ComputeResourceInfo, presumably std::nullopt when the graph
// is invalid (body not visible here).
std::optional<WebNNGraphImpl::ComputeResourceInfo>
WebNNGraphBuilderImpl::ValidateGraph(
const ContextProperties& context_properties,
const mojom::GraphInfo& graph_info) { … }
// Test-only helper that reports whether `graph_info` is valid for
// `context_properties`. Presumably equivalent to
// ValidateGraph(...).has_value(); body not visible here.
bool WebNNGraphBuilderImpl::IsValidForTesting(
const ContextProperties& context_properties,
const mojom::GraphInfo& graph_info) { … }
// Takes no arguments and returns nothing. From its name this presumably tears
// down this builder (body not visible here).
// NOTE(review): if this ends up deleting `this`, callers must not touch
// members after calling it; confirm.
void WebNNGraphBuilderImpl::DestroySelf() { … }
}