chromium/services/webnn/public/cpp/graph_validation_utils.h

// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef SERVICES_WEBNN_PUBLIC_CPP_GRAPH_VALIDATION_UTILS_H_
#define SERVICES_WEBNN_PUBLIC_CPP_GRAPH_VALIDATION_UTILS_H_

#include <cstdint>
#include <optional>
#include <string>
#include <string_view>
#include <vector>

#include "base/component_export.h"
#include "base/containers/enum_set.h"
#include "base/containers/span.h"
#include "base/types/expected.h"
#include "services/webnn/public/cpp/context_properties.h"
#include "services/webnn/public/cpp/operand_descriptor.h"
#include "third_party/abseil-cpp/absl/types/variant.h"

namespace webnn {

// Converts the set of data types in `constraint_set` into a string.
// NOTE(review): the exact format is defined in the .cc file; presumably this
// is used when composing validation error messages — confirm.
std::string COMPONENT_EXPORT(WEBNN_PUBLIC_CPP)
    DataTypeConstraintToString(const SupportedDataTypes& constraint_set);

// Represents the `MLConv2dFilterOperandLayout` that specifies the layout format
// of the filter tensor. O is output channels, I is input channels / groups, H
// is the filter height and W is the filter width.
// NOTE(review): presumably carried by Conv2dAttributes — confirm.
enum class Conv2dFilterOperandLayout {};

// Represents the `MLConvTranspose2dFilterOperandLayout` that specifies the
// layout format of the filter tensor. I is input channels, O is output channels
// / groups, H is the filter height and W is the filter width.
// NOTE(review): presumably carried by ConvTranspose2dAttributes — confirm.
enum class ConvTranspose2dFilterOperandLayout {};

// Identifies which 2-D pooling variant is being validated. Passed as the `kind`
// argument of ValidatePool2dAndInferOutput().
enum class Pool2dKind {};

// Represents the `MLRoundingType` that is used to compute the output shape.
// NOTE(review): presumably a field of Pool2dAttributes — confirm.
enum class RoundingType {};

// Represents the `MLRecurrentNetworkDirection` that specifies the processing
// direction of the input sequence.
// NOTE(review): presumably used by the gru/lstm attribute structs — confirm.
enum class RecurrentNetworkDirection {};

// Identifies which reduction operator is being validated. Passed as the `kind`
// argument of ValidateReduceAndInferOutput().
enum class ReduceKind {};

// A 2-D size holding a height and a width, each of type `T`.
template <typename T>
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) Size2d {};

// The number of additional rows and columns added to the beginning and the end
// of each spatial dimension of the input.
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) Padding2d {};

// Contains the attributes of the batchNormalization operator
// (https://www.w3.org/TR/webnn/#api-mlgraphbuilder-batchnorm). Consumed by
// ValidateBatchNormalizationAndInferOutput().
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) BatchNormalizationAttributes {};

// Contains the attributes shared by the conv2d and convTranspose2d operators.
// Both Conv2dAttributes and ConvTranspose2dAttributes derive from this struct.
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) Conv2dAttributesBase {};

// Contains the attributes of the conv2d operator, extending the common
// Conv2dAttributesBase. Consumed by ValidateConv2dAndInferOutput().
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) Conv2dAttributes
    : Conv2dAttributesBase {};

// Contains the attributes of the convTranspose2d operator, extending the common
// Conv2dAttributesBase. Consumed by ValidateConvTranspose2dAndInferOutput().
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) ConvTranspose2dAttributes
    : Conv2dAttributesBase {};

// Contains the attributes of the pool2d family of operators
// (https://www.w3.org/TR/webnn/#api-mlgraphbuilder-pool2d). Consumed by
// ValidatePool2dAndInferOutput().
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) Pool2dAttributes {};

// Contains the attributes of the gemm operator
// (https://www.w3.org/TR/webnn/#api-mlgraphbuilder-gemm). Consumed by
// ValidateGemmAndInferOutput().
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) GemmAttributes {};

// Contains the attributes of the gru operator
// (https://www.w3.org/TR/webnn/#api-mlgraphbuilder-gru). Consumed by
// ValidateGruAndInferOutput().
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) GruAttributes {};

// Contains the attributes of the gruCell operator
// (https://www.w3.org/TR/webnn/#api-mlgraphbuilder-grucell). Consumed by
// ValidateGruCellAndInferOutput().
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) GruCellAttributes {};

// Contains the attributes of the instanceNormalization operator
// (https://www.w3.org/TR/webnn/#api-mlgraphbuilder-instancenorm). Consumed by
// ValidateInstanceNormalizationAndInferOutput().
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) InstanceNormalizationAttributes {};

// Contains the attributes of the layerNormalization operator
// (https://www.w3.org/TR/webnn/#api-mlgraphbuilder-layernorm). Consumed by
// ValidateLayerNormalizationAndInferOutput().
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) LayerNormalizationAttributes {};

// Contains the attributes of the lstm operator
// (https://www.w3.org/TR/webnn/#api-mlgraphbuilder-lstm). Consumed by
// ValidateLstmAndInferOutput().
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) LstmAttributes {};

// Contains the attributes of the lstmCell operator
// (https://www.w3.org/TR/webnn/#api-mlgraphbuilder-lstmcell). Consumed by
// ValidateLstmCellAndInferOutput().
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) LstmCellAttributes {};

// Contains the attributes of the slice operator
// (https://www.w3.org/TR/webnn/#api-mlgraphbuilder-slice). Consumed by
// ValidateSliceAndInferOutput().
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) SliceAttributes {};

// Validate argMin and argMax operators defined in WebIDL here:
// https://www.w3.org/TR/webnn/#api-mlgraphbuilder-argminmax.
//
// `axis`, `output_data_type` and `keep_dimensions` (default false) correspond
// to the operator's arguments/options in the spec linked above. `label` is
// presumably the operator label used in error messages — confirm against the
// definition. Returns the inferred output descriptor on success, or an error
// string on failure.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateArgMinMaxAndInferOutput(const ContextProperties& context_properties,
                                    const OperandDescriptor& input,
                                    std::string_view label,
                                    uint32_t axis,
                                    OperandDataType output_data_type,
                                    bool keep_dimensions = false);

// Validate softmax operator defined in WebIDL here:
// https://www.w3.org/TR/webnn/#api-mlgraphbuilder-softmax.
//
// `axis` is the dimension of `input` the operator is applied along (see spec).
// `label` is presumably the operator label used in error messages. Returns the
// inferred output descriptor on success, or an error string on failure.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateSoftmaxAndInferOutput(const ContextProperties& context_properties,
                                  const OperandDescriptor& input,
                                  uint32_t axis,
                                  std::string_view label);

// Contains the attributes of the split operator
// (https://www.w3.org/TR/webnn/#api-mlgraphbuilder-split). Consumed by
// ValidateSplitAndInferOutput().
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) SplitAttribute {};

// Validate and infer the output tensors' ranks and sizes for the split operator
// based on the WebNN WebIDL
// https://www.w3.org/TR/webnn/#api-mlgraphbuilder-split
//
// Returns one descriptor per output of the split on success, or an error
// string on failure.
base::expected<std::vector<OperandDescriptor>, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateSplitAndInferOutput(const ContextProperties& context_properties,
                                const OperandDescriptor& input,
                                const SplitAttribute& attributes);

// Validate and infer output information of batchNormalization operator defined
// in WebIDL here https://www.w3.org/TR/webnn/#api-mlgraphbuilder-batchnorm.
//
// `mean` and `variance` are the operator's statistics operands; the remaining
// options come from `attributes`. Returns the inferred output descriptor on
// success, or an error string on failure.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateBatchNormalizationAndInferOutput(
        const OperandDescriptor& input,
        const OperandDescriptor& mean,
        const OperandDescriptor& variance,
        const BatchNormalizationAttributes& attributes);

// Validate and infer output information of 2-D convolution operator defined in
// WebIDL here https://www.w3.org/TR/webnn/#api-mlgraphbuilder-conv2d
//
// Returns the inferred output descriptor on success, or an error string on
// failure.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateConv2dAndInferOutput(const OperandDescriptor& input,
                                 const OperandDescriptor& filter,
                                 const Conv2dAttributes& attributes);

// Validate and infer output information of 2-D transposed convolution operator
// defined in WebIDL here
// https://www.w3.org/TR/webnn/#api-mlgraphbuilder-convtranspose2d
//
// Returns the inferred output descriptor on success, or an error string on
// failure.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateConvTranspose2dAndInferOutput(
        const OperandDescriptor& input,
        const OperandDescriptor& filter,
        const ConvTranspose2dAttributes& attributes);

// Validate and infer output information of pad operator defined in
// WebIDL here https://www.w3.org/TR/webnn/#api-mlgraphbuilder-pad
//
// `beginning_padding` and `ending_padding` give the padding per dimension of
// `input` (see spec). `label` is presumably the operator label used in error
// messages. Returns the inferred output descriptor on success, or an error
// string on failure.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidatePadAndInferOutput(const ContextProperties& context_properties,
                              const OperandDescriptor& input,
                              base::span<const uint32_t> beginning_padding,
                              base::span<const uint32_t> ending_padding,
                              std::string_view label);

// Validate and infer output information of matmul operator defined in
// WebIDL here https://www.w3.org/TR/webnn/#api-mlgraphbuilder-matmul
//
// `a` and `b` are the two matrix operands. `label` is presumably the operator
// label used in error messages. Returns the inferred output descriptor on
// success, or an error string on failure.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateMatmulAndInferOutput(const ContextProperties& context_properties,
                                 const OperandDescriptor& a,
                                 const OperandDescriptor& b,
                                 std::string_view label);

// Validate and infer output information of 2-D pooling operator defined in
// WebIDL here https://www.w3.org/TR/webnn/#api-mlgraphbuilder-pool2d
//
// `kind` selects which pooling variant is being validated. Returns the inferred
// output descriptor on success, or an error string on failure.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidatePool2dAndInferOutput(const ContextProperties& context_properties,
                                 const OperandDescriptor& input,
                                 const Pool2dAttributes& attributes,
                                 Pool2dKind kind);

// Validate and infer output information of 2-D resample operator defined in
// WebIDL here https://www.w3.org/TR/webnn/#api-mlgraphbuilder-resample2d
//
// `scales_or_sizes` holds either float scales or uint32 target sizes (a
// variant of the two spans); `axes` names the dimensions they apply to.
// `label` is presumably the operator label used in error messages. Returns the
// inferred output descriptor on success, or an error string on failure.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateResample2dAndInferOutput(
        const ContextProperties& context_properties,
        const OperandDescriptor& input,
        const absl::variant<base::span<const float>,
                            base::span<const uint32_t>>& scales_or_sizes,
        base::span<const uint32_t> axes,
        std::string_view label);

// Validate and infer output information of gather operator defined in
// WebIDL here https://www.w3.org/TR/webnn/#api-mlgraphbuilder-gather
//
// `indices` selects entries of `input` along dimension `axis` (see spec).
// `label` is presumably the operator label used in error messages. Returns the
// inferred output descriptor on success, or an error string on failure.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateGatherAndInferOutput(const ContextProperties& context_properties,
                                 const OperandDescriptor& input,
                                 const OperandDescriptor& indices,
                                 uint32_t axis,
                                 std::string_view label);

// Validate and infer output information of gatherElements operator defined in
// WebIDL here https://www.w3.org/TR/webnn/#api-mlgraphbuilder-gatherElements
//
// `indices` selects elements of `input` along dimension `axis` (see spec).
// `label` is presumably the operator label used in error messages. Returns the
// inferred output descriptor on success, or an error string on failure.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateGatherElementsAndInferOutput(
        const ContextProperties& context_properties,
        const OperandDescriptor& input,
        const OperandDescriptor& indices,
        uint32_t axis,
        std::string_view label);

// Validate gemm operator defined in WebIDL here
// https://www.w3.org/TR/webnn/#api-mlgraphbuilder-gemm
//
// `a` and `b` are the two matrix operands; further options come from
// `attributes`. Returns the inferred output descriptor on success, or an error
// string on failure.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateGemmAndInferOutput(const ContextProperties& context_properties,
                               const OperandDescriptor& a,
                               const OperandDescriptor& b,
                               const GemmAttributes& attributes);

// Validate and infer output information of gru operator defined in WebIDL here
// https://www.w3.org/TR/webnn/#api-mlgraphbuilder-gru.
//
// `steps` and `hidden_size` are the operator's required integer arguments (see
// spec). Returns a vector of output descriptors on success (one per operator
// output), or an error string on failure.
base::expected<std::vector<OperandDescriptor>, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateGruAndInferOutput(const OperandDescriptor& input,
                              const OperandDescriptor& weight,
                              const OperandDescriptor& recurrent_weight,
                              uint32_t steps,
                              uint32_t hidden_size,
                              const GruAttributes& attributes);

// Validate and infer output information of gruCell operator defined in WebIDL
// here https://www.w3.org/TR/webnn/#api-mlgraphbuilder-grucell.
//
// `hidden_state` is the previous hidden state operand and `hidden_size` the
// operator's hidden size argument (see spec). Returns the inferred output
// descriptor on success, or an error string on failure.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateGruCellAndInferOutput(const OperandDescriptor& input,
                                  const OperandDescriptor& weight,
                                  const OperandDescriptor& recurrent_weight,
                                  const OperandDescriptor& hidden_state,
                                  uint32_t hidden_size,
                                  const GruCellAttributes& attributes);

// Validate and infer output information of instanceNormalization operator
// defined in WebIDL here
// https://www.w3.org/TR/webnn/#api-mlgraphbuilder-instancenorm.
//
// Returns the inferred output descriptor on success, or an error string on
// failure.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateInstanceNormalizationAndInferOutput(
        const OperandDescriptor& input,
        const InstanceNormalizationAttributes& attributes);

// Validate and infer output information of layerNormalization operator defined
// in WebIDL here https://www.w3.org/TR/webnn/#api-mlgraphbuilder-layernorm.
//
// `axes` lists the dimensions of `input` to normalize over (see spec). Returns
// the inferred output descriptor on success, or an error string on failure.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateLayerNormalizationAndInferOutput(
        const OperandDescriptor& input,
        base::span<const uint32_t> axes,
        const LayerNormalizationAttributes& attributes);

// Validate and infer output information of lstm operator defined
// in WebIDL here https://www.w3.org/TR/webnn/#api-mlgraphbuilder-lstm.
//
// `steps` and `hidden_size` are the operator's required integer arguments (see
// spec). Returns a vector of output descriptors on success (one per operator
// output), or an error string on failure.
base::expected<std::vector<OperandDescriptor>, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateLstmAndInferOutput(const OperandDescriptor& input,
                               const OperandDescriptor& weight,
                               const OperandDescriptor& recurrent_weight,
                               uint32_t steps,
                               uint32_t hidden_size,
                               const LstmAttributes& attributes);

// Validate and infer output information of lstmCell operator defined
// in WebIDL here https://www.w3.org/TR/webnn/#api-mlgraphbuilder-lstmcell.
//
// `hidden_state` and `cell_state` are the previous hidden and cell state
// operands; `hidden_size` is the operator's hidden size argument (see spec).
// Returns a vector of output descriptors on success (one per operator output),
// or an error string on failure.
base::expected<std::vector<OperandDescriptor>, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateLstmCellAndInferOutput(const OperandDescriptor& input,
                                   const OperandDescriptor& weight,
                                   const OperandDescriptor& recurrent_weight,
                                   const OperandDescriptor& hidden_state,
                                   const OperandDescriptor& cell_state,
                                   uint32_t hidden_size,
                                   const LstmCellAttributes& attributes);

// Validate concat operator defined in WebIDL here
// https://www.w3.org/TR/webnn/#api-mlgraphbuilder-concat
//
// `input` holds the operands to be concatenated along dimension `axis` (see
// spec). `label` is presumably the operator label used in error messages.
// Returns the inferred output descriptor on success, or an error string on
// failure.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateConcatAndInferOutput(const ContextProperties& context_properties,
                                 const std::vector<OperandDescriptor>& input,
                                 uint32_t axis,
                                 std::string_view label);

// Validate prelu operator defined in WebIDL here:
// https://www.w3.org/TR/webnn/#api-mlgraphbuilder-prelu
//
// `slope` is the operator's slope operand (see spec). `label` is presumably the
// operator label used in error messages. Returns the inferred output
// descriptor on success, or an error string on failure.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidatePreluAndInferOutput(const ContextProperties& context_properties,
                                const OperandDescriptor& input,
                                const OperandDescriptor& slope,
                                std::string_view label);

// Validate transpose operator defined in WebIDL here
// https://www.w3.org/TR/webnn/#api-mlgraphbuilder-transpose
//
// `permutation` is the requested reordering of `input`'s dimensions (see
// spec). `label` is presumably the operator label used in error messages.
// Returns the inferred output descriptor on success, or an error string on
// failure.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateTransposeAndInferOutput(const ContextProperties& context_properties,
                                    const OperandDescriptor& input,
                                    base::span<const uint32_t> permutation,
                                    std::string_view label);

// Validate slice operator defined in WebIDL here:
// https://www.w3.org/TR/webnn/#api-mlgraphbuilder-slice
//
// Returns the inferred output descriptor on success, or an error string on
// failure.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateSliceAndInferOutput(const ContextProperties& context_properties,
                                const OperandDescriptor& input,
                                const SliceAttributes& attributes);

// Validate and infer output information of reduce operator defined in
// WebIDL here https://www.w3.org/TR/webnn/#api-mlgraphbuilder-reduce
//
// `kind` selects which reduction is being validated and `axes` lists the
// dimensions of `input` to reduce. `keep_dimensions` (default false)
// corresponds to the spec option of the same meaning. `label` is presumably
// the operator label used in error messages. Returns the inferred output
// descriptor on success, or an error string on failure.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateReduceAndInferOutput(const ContextProperties& context_properties,
                                 ReduceKind kind,
                                 const OperandDescriptor& input,
                                 std::string_view label,
                                 base::span<const uint32_t> axes,
                                 bool keep_dimensions = false);

// Validate triangular operator defined in WebIDL here:
// https://www.w3.org/TR/webnn/#api-mlgraphbuilder-triangular.
//
// `label` is presumably the operator label used in error messages. Returns the
// inferred output descriptor on success, or an error string on failure.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateTriangularAndInferOutput(
        const ContextProperties& context_properties,
        const OperandDescriptor& input,
        std::string_view label);

// Validate where operator defined in WebIDL here:
// https://www.w3.org/TR/webnn/#api-mlgraphbuilder-where.
//
// `condition` chooses between `true_value` and `false_value` (see spec).
// `label` is presumably the operator label used in error messages. Returns the
// inferred output descriptor on success, or an error string on failure.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateWhereAndInferOutput(const ContextProperties& context_properties,
                                const OperandDescriptor& condition,
                                const OperandDescriptor& true_value,
                                const OperandDescriptor& false_value,
                                std::string_view label);

// Validate the creation of an MLBuffer given `descriptor`, checked against
// `context_properties`. Returns an empty (void) value on success, or an error
// string on failure.
base::expected<void, std::string> COMPONENT_EXPORT(WEBNN_PUBLIC_CPP)
    ValidateBuffer(const ContextProperties& context_properties,
                   OperandDescriptor descriptor);

// Validate that every entry of `axes` is within the range [0, rank - 1] and
// that no axis appears more than once. Returns an empty (void) value on
// success, or an error string otherwise; `label` is presumably the operator
// label included in that message — confirm against the definition.
base::expected<void, std::string> COMPONENT_EXPORT(WEBNN_PUBLIC_CPP)
    ValidateAxes(base::span<const uint32_t> axes,
                 uint32_t rank,
                 std::string_view label);

// Broadcast the input shapes and return the output shape.
// If bidirectional is true (the default), its behavior follows the
// numpy-broadcasting-rule:
// https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules.
// Otherwise, it unidirectionally broadcasts the lhs to the rhs.
// Presumably returns std::nullopt when the shapes cannot be broadcast —
// confirm against the definition.
std::optional<std::vector<uint32_t>> COMPONENT_EXPORT(WEBNN_PUBLIC_CPP)
    BroadcastShapes(base::span<const uint32_t> dims_lhs,
                    base::span<const uint32_t> dims_rhs,
                    bool bidirectional = true);

// Calculate the output size of one spatial dimension for convTranspose2d
// based on the WebNN spec:
// https://www.w3.org/TR/webnn/#api-mlgraphbuilder-convtranspose2d
// Returns the calculated output size if no error, otherwise an error string.
base::expected<uint32_t, std::string> COMPONENT_EXPORT(WEBNN_PUBLIC_CPP)
    CalculateConvTranspose2dOutputSize(uint32_t input_size,
                                       uint32_t filter_size,
                                       uint32_t beginning_padding,
                                       uint32_t ending_padding,
                                       uint32_t stride,
                                       uint32_t dilation,
                                       uint32_t output_padding);

// Returns true if `data_type` is one of the floating-point OperandDataType
// values.
bool COMPONENT_EXPORT(WEBNN_PUBLIC_CPP)
    IsFloatingPointType(OperandDataType data_type);

// A depthwise conv2d operation is a variant of grouped convolution where
// options.groups == input_channels == output_channels, according to the WebNN
// conv2d spec: https://www.w3.org/TR/webnn/#api-mlgraphbuilder-conv2d.
// Returns whether the given channel counts and `groups` describe such an
// operation.
bool COMPONENT_EXPORT(WEBNN_PUBLIC_CPP)
    IsDepthwiseConv2d(uint32_t input_channels,
                      uint32_t output_channels,
                      uint32_t groups);

}  // namespace webnn

#endif  // SERVICES_WEBNN_PUBLIC_CPP_GRAPH_VALIDATION_UTILS_H_