#include "services/webnn/public/cpp/graph_validation_utils.h"
#include <algorithm>
#include <numeric>
#include <set>
#include <vector>
#include "base/check_op.h"
#include "base/containers/contains.h"
#include "base/notreached.h"
#include "base/numerics/checked_math.h"
#include "base/numerics/safe_conversions.h"
#include "base/ranges/algorithm.h"
#include "base/strings/strcat.h"
#include "base/strings/string_util.h"
#include "base/strings/stringprintf.h"
#include "base/types/expected.h"
#include "base/types/expected_macros.h"
#include "services/webnn/public/cpp/context_properties.h"
#include "services/webnn/public/cpp/operand_descriptor.h"
#include "services/webnn/public/cpp/supported_data_types.h"
#include "services/webnn/public/cpp/webnn_errors.h"
namespace webnn {
namespace {
// Builds an error message string from an operator `label` and an
// `error_message`.
// NOTE(review): the exact format (e.g. prefixing, handling of an empty label)
// lives in the elided body — confirm before relying on it.
std::string ErrorWithLabel(std::string_view label,
std::string_view error_message) { … }
// Computes the conv2d output size along a single spatial dimension.
// Inputs are the input size, filter size, beginning/ending padding, stride and
// dilation for that dimension.
// Returns the size as a `double` (not an integer), or an error string on
// failure; `label` is presumably used to tag the error message.
// NOTE(review): the rejected cases (e.g. zero stride/dilation, arithmetic
// overflow, a dilated filter larger than the padded input) are in the elided
// body — confirm there.
base::expected<double, std::string> CalculateConv2dOutputSize(
const uint32_t input_size,
const uint32_t filter_size,
const uint32_t beginning_padding,
const uint32_t ending_padding,
const uint32_t stride,
const uint32_t dilation,
std::string_view label) { … }
// Validates conv2d spatial parameters and computes the output height and
// width.
// `padding`, `strides` and `dilations` carry per-dimension values.
// Returns a Size2d<double> holding {height, width}, or an error string.
// Presumably delegates each dimension to CalculateConv2dOutputSize — verify
// in the body.
base::expected<Size2d<double>, std::string>
ValidateAndCalculateConv2dOutputSizes(const uint32_t input_height,
const uint32_t input_width,
const uint32_t filter_height,
const uint32_t filter_width,
const Padding2d& padding,
const Size2d<uint32_t>& strides,
const Size2d<uint32_t>& dilations,
std::string_view label) { … }
// Validates convTranspose2d spatial parameters and computes the output height
// and width as integers.
// `output_padding` is an extra per-dimension term that only applies to
// transposed convolution.
// Returns a Size2d<uint32_t>, or an error string.
// Presumably uses CalculateConvTranspose2dOutputSize (defined later in this
// file) per dimension — confirm in the body.
base::expected<Size2d<uint32_t>, std::string>
ValidateAndCalculateConvTranspose2dOutputSizes(
const uint32_t input_height,
const uint32_t input_width,
const uint32_t filter_height,
const uint32_t filter_width,
const Padding2d& padding,
const Size2d<uint32_t>& strides,
const Size2d<uint32_t>& dilations,
const Size2d<uint32_t>& output_padding,
std::string_view label) { … }
// Bundles input/output information gathered while validating conv2d-family
// operators.
// It is produced by ValidateAndGetConv2dInputInfo and consumed by
// ValidateConv2dBiasAndCreateOutputOperand.
// NOTE(review): the field list is elided in this view.
struct Conv2dInputOutputInfo { … };
// Validates `input` against the attributes shared by conv2d and
// convTranspose2d (Conv2dAttributesBase).
// On success returns the extracted Conv2dInputOutputInfo; otherwise returns an
// error string.
// NOTE(review): the specific checks (rank, data type, layout handling) are in
// the elided body.
base::expected<Conv2dInputOutputInfo, std::string>
ValidateAndGetConv2dInputInfo(const OperandDescriptor& input,
const Conv2dAttributesBase& attributes) { … }
// Final step shared by conv2d-family validation.
// Presumably validates the bias operand carried in `attributes` (if present)
// and builds the output OperandDescriptor from `output_info` — confirm in the
// body.
// Returns the output descriptor, or an error string.
base::expected<OperandDescriptor, std::string>
ValidateConv2dBiasAndCreateOutputOperand(
const OperandDescriptor& input,
const Conv2dAttributesBase& attributes,
const Conv2dInputOutputInfo& output_info) { … }
// Validates reduction `axes` against `input_dimensions` and returns the
// resulting output dimensions, or an error string.
// `keep_dimensions` presumably selects whether reduced axes remain as size-1
// dimensions or are dropped — verify against the body.
base::expected<std::vector<uint32_t>, std::string>
ValidateReduceAxesAndInferOutput(base::span<const uint32_t> input_dimensions,
base::span<const uint32_t> axes,
bool keep_dimensions,
std::string_view label) { … }
// Checks one operand of a recurrent-network op (GRU/LSTM family) against an
// `expected_shape` and the input's data type (`input_data_type`).
// `operand_name` is presumably used to identify the operand in error
// messages.
// Returns success (void) or an error string.
base::expected<void, std::string> ValidateRecurrentNetworkOperand(
const OperandDescriptor& operand,
const char* operand_name,
base::span<const uint32_t> expected_shape,
OperandDataType input_data_type,
std::string_view label) { … }
}
// Formats a set of supported data types as a string, presumably for use in
// validation error messages.
// NOTE(review): separator/ordering are defined by the elided body.
std::string DataTypeConstraintToString(
const SupportedDataTypes& constraint_set) { … }
// Validates a softmax over `axis` of `input`, using the data-type limits in
// `context_properties`.
// Returns the output descriptor, or an error string.
base::expected<OperandDescriptor, std::string> ValidateSoftmaxAndInferOutput(
const ContextProperties& context_properties,
const OperandDescriptor& input,
uint32_t axis,
std::string_view label) { … }
// Validates argMin/argMax over `axis` of `input` and infers the output
// descriptor.
// The output uses `output_data_type`; `keep_dimensions` presumably controls
// whether the reduced axis is retained as size 1.
// Returns the descriptor or an error string.
// Note the parameter order differs from most siblings: `label` comes before
// `axis`.
base::expected<OperandDescriptor, std::string> ValidateArgMinMaxAndInferOutput(
const ContextProperties& context_properties,
const OperandDescriptor& input,
std::string_view label,
uint32_t axis,
OperandDataType output_data_type,
bool keep_dimensions) { … }
// Validates a split of `input` described by `attributes`.
// Returns one descriptor per output, or an error string.
base::expected<std::vector<OperandDescriptor>, std::string>
ValidateSplitAndInferOutput(const ContextProperties& context_properties,
const OperandDescriptor& input,
const SplitAttribute& attributes) { … }
// Checks that an auxiliary normalization operand (e.g. mean/variance/scale/
// bias — presumably) is compatible with the input.
// Compatibility is judged against the input's data type and the input size
// along the normalized axis (`input_size_on_axis`).
// Returns success (void) or an error string.
base::expected<void, std::string>
ValidateNormalizationOperandIsCompatibleWithInput(
const OperandDescriptor& operand,
const OperandDataType input_data_type,
size_t input_size_on_axis,
std::string_view label) { … }
// Out-of-line defaulted constructor, destructor and move operations.
// Presumably out-of-line to satisfy Chromium's rule for complex classes.
// No copy operations are defined in this file.
BatchNormalizationAttributes::BatchNormalizationAttributes() = default;
BatchNormalizationAttributes::~BatchNormalizationAttributes() = default;
BatchNormalizationAttributes::BatchNormalizationAttributes(
BatchNormalizationAttributes&& other) = default;
BatchNormalizationAttributes& BatchNormalizationAttributes::operator=(
BatchNormalizationAttributes&& other) = default;
// Validates batchNormalization of `input` with the given `mean` and
// `variance` operands and `attributes`.
// Returns the output descriptor, or an error string.
base::expected<OperandDescriptor, std::string>
ValidateBatchNormalizationAndInferOutput(
const OperandDescriptor& input,
const OperandDescriptor& mean,
const OperandDescriptor& variance,
const BatchNormalizationAttributes& attributes) { … }
// Out-of-line defaulted constructor, destructor and move operations for the
// shared conv attribute base and for Conv2dAttributes.
// Presumably out-of-line to satisfy Chromium's rule for complex classes.
Conv2dAttributesBase::Conv2dAttributesBase() = default;
Conv2dAttributesBase::~Conv2dAttributesBase() = default;
Conv2dAttributesBase::Conv2dAttributesBase(Conv2dAttributesBase&& other) =
default;
Conv2dAttributesBase& Conv2dAttributesBase::operator=(
Conv2dAttributesBase&& other) = default;
Conv2dAttributes::Conv2dAttributes() = default;
Conv2dAttributes::~Conv2dAttributes() = default;
Conv2dAttributes::Conv2dAttributes(Conv2dAttributes&& other) = default;
Conv2dAttributes& Conv2dAttributes::operator=(Conv2dAttributes&& other) =
default;
// Validates a conv2d of `input` with `filter` under `attributes`.
// Returns the output descriptor, or an error string.
base::expected<OperandDescriptor, std::string> ValidateConv2dAndInferOutput(
const OperandDescriptor& input,
const OperandDescriptor& filter,
const Conv2dAttributes& attributes) { … }
// Out-of-line defaulted constructor, destructor and move operations.
ConvTranspose2dAttributes::ConvTranspose2dAttributes() = default;
ConvTranspose2dAttributes::~ConvTranspose2dAttributes() = default;
ConvTranspose2dAttributes::ConvTranspose2dAttributes(
ConvTranspose2dAttributes&& other) = default;
ConvTranspose2dAttributes& ConvTranspose2dAttributes::operator=(
ConvTranspose2dAttributes&& other) = default;
// Validates a convTranspose2d of `input` with `filter` under `attributes`.
// Returns the output descriptor, or an error string.
base::expected<OperandDescriptor, std::string>
ValidateConvTranspose2dAndInferOutput(
const OperandDescriptor& input,
const OperandDescriptor& filter,
const ConvTranspose2dAttributes& attributes) { … }
// Validates a pad of `input` and infers the padded output descriptor.
// `beginning_padding` and `ending_padding` are presumably one entry per input
// dimension — confirm in the body.
// Returns the descriptor or an error string.
base::expected<OperandDescriptor, std::string> ValidatePadAndInferOutput(
const ContextProperties& context_properties,
const OperandDescriptor& input,
base::span<const uint32_t> beginning_padding,
base::span<const uint32_t> ending_padding,
std::string_view label) { … }
// Validates a matmul of `a` and `b` and infers the output descriptor.
// Returns the descriptor or an error string.
// NOTE(review): batch-dimension broadcasting rules are in the elided body.
base::expected<OperandDescriptor, std::string> ValidateMatmulAndInferOutput(
const ContextProperties& context_properties,
const OperandDescriptor& a,
const OperandDescriptor& b,
std::string_view label) { … }
// Out-of-line defaulted constructor, destructor and move operations.
Pool2dAttributes::Pool2dAttributes() = default;
Pool2dAttributes::~Pool2dAttributes() = default;
Pool2dAttributes::Pool2dAttributes(Pool2dAttributes&& other) = default;
Pool2dAttributes& Pool2dAttributes::operator=(Pool2dAttributes&& other) =
default;
// Validates a 2-D pooling of `input` of the given `kind` under `attributes`.
// Returns the output descriptor, or an error string.
base::expected<OperandDescriptor, std::string> ValidatePool2dAndInferOutput(
const ContextProperties& context_properties,
const OperandDescriptor& input,
const Pool2dAttributes& attributes,
Pool2dKind kind) { … }
// Computes one resample2d output dimension from `input_size` and a
// floating-point `scale`.
// Returns the size as uint32_t, or an error string.
// NOTE(review): rounding and the rejection of invalid/overflowing results are
// defined in the elided body — confirm there.
base::expected<uint32_t, std::string> CalculateResample2dOutputSize(
const uint32_t input_size,
const float scale,
std::string_view label) { … }
// Validates a resample2d of `input` and infers the output descriptor.
// `scales_or_sizes` holds either float scales or explicit uint32 output
// sizes. `axes` selects the resampled dimensions.
// Returns the descriptor or an error string.
base::expected<OperandDescriptor, std::string> ValidateResample2dAndInferOutput(
const ContextProperties& context_properties,
const OperandDescriptor& input,
const absl::variant<base::span<const float>, base::span<const uint32_t>>&
scales_or_sizes,
base::span<const uint32_t> axes,
std::string_view label) { … }
// Validates a gather of `input` along `axis` using `indices`.
// Returns the output descriptor, or an error string.
base::expected<OperandDescriptor, std::string> ValidateGatherAndInferOutput(
const ContextProperties& context_properties,
const OperandDescriptor& input,
const OperandDescriptor& indices,
const uint32_t axis,
std::string_view label) { … }
// Validates a gatherElements of `input` along `axis` using `indices`.
// Returns the output descriptor, or an error string.
base::expected<OperandDescriptor, std::string>
ValidateGatherElementsAndInferOutput(
const ContextProperties& context_properties,
const OperandDescriptor& input,
const OperandDescriptor& indices,
const uint32_t axis,
std::string_view label) { … }
// Out-of-line defaulted constructor, destructor and move operations.
GemmAttributes::GemmAttributes() = default;
GemmAttributes::~GemmAttributes() = default;
GemmAttributes::GemmAttributes(GemmAttributes&& other) = default;
GemmAttributes& GemmAttributes::operator=(GemmAttributes&& other) = default;
// Validates a gemm of `a` and `b` under `attributes`.
// Returns the output descriptor, or an error string.
base::expected<OperandDescriptor, std::string> ValidateGemmAndInferOutput(
const ContextProperties& context_properties,
const OperandDescriptor& a,
const OperandDescriptor& b,
const GemmAttributes& attributes) { … }
// Out-of-line defaulted constructor, destructor and move operations.
GruAttributes::GruAttributes() = default;
GruAttributes::~GruAttributes() = default;
GruAttributes::GruAttributes(GruAttributes&& other) = default;
GruAttributes& GruAttributes::operator=(GruAttributes&& other) = default;
// Validates a GRU over `steps` time steps with `hidden_size` units.
// Returns the output descriptors (more than one is possible), or an error
// string.
base::expected<std::vector<OperandDescriptor>, std::string>
ValidateGruAndInferOutput(const OperandDescriptor& input,
const OperandDescriptor& weight,
const OperandDescriptor& recurrent_weight,
uint32_t steps,
uint32_t hidden_size,
const GruAttributes& attributes) { … }
// Out-of-line defaulted constructor, destructor and move operations.
GruCellAttributes::GruCellAttributes() = default;
GruCellAttributes::~GruCellAttributes() = default;
GruCellAttributes::GruCellAttributes(GruCellAttributes&& other) = default;
GruCellAttributes& GruCellAttributes::operator=(GruCellAttributes&& other) =
default;
// Validates a single GRU cell step with `hidden_size` units.
// Returns the single output descriptor, or an error string.
base::expected<OperandDescriptor, std::string> ValidateGruCellAndInferOutput(
const OperandDescriptor& input,
const OperandDescriptor& weight,
const OperandDescriptor& recurrent_weight,
const OperandDescriptor& hidden_state,
uint32_t hidden_size,
const GruCellAttributes& attributes) { … }
// Out-of-line defaulted constructor, destructor and move operations.
InstanceNormalizationAttributes::InstanceNormalizationAttributes() = default;
InstanceNormalizationAttributes::~InstanceNormalizationAttributes() = default;
InstanceNormalizationAttributes::InstanceNormalizationAttributes(
InstanceNormalizationAttributes&& other) = default;
InstanceNormalizationAttributes& InstanceNormalizationAttributes::operator=(
InstanceNormalizationAttributes&& other) = default;
// Validates instanceNormalization of `input` under `attributes`.
// Returns the output descriptor, or an error string.
base::expected<OperandDescriptor, std::string>
ValidateInstanceNormalizationAndInferOutput(
const OperandDescriptor& input,
const InstanceNormalizationAttributes& attributes) { … }
// Out-of-line defaulted constructor, destructor and move operations.
LayerNormalizationAttributes::LayerNormalizationAttributes() = default;
LayerNormalizationAttributes::~LayerNormalizationAttributes() = default;
LayerNormalizationAttributes::LayerNormalizationAttributes(
LayerNormalizationAttributes&& other) = default;
LayerNormalizationAttributes& LayerNormalizationAttributes::operator=(
LayerNormalizationAttributes&& other) = default;
// Validates layerNormalization of `input` over the given `axes` under
// `attributes`.
// Returns the output descriptor, or an error string.
base::expected<OperandDescriptor, std::string>
ValidateLayerNormalizationAndInferOutput(
const OperandDescriptor& input,
base::span<const uint32_t> axes,
const LayerNormalizationAttributes& attributes) { … }
// Out-of-line defaulted constructor, destructor and move operations.
LstmAttributes::LstmAttributes() = default;
LstmAttributes::~LstmAttributes() = default;
LstmAttributes::LstmAttributes(LstmAttributes&& other) = default;
LstmAttributes& LstmAttributes::operator=(LstmAttributes&& other) = default;
// Validates an LSTM over `steps` time steps with `hidden_size` units.
// Returns the output descriptors, or an error string.
base::expected<std::vector<OperandDescriptor>, std::string>
ValidateLstmAndInferOutput(const OperandDescriptor& input,
const OperandDescriptor& weight,
const OperandDescriptor& recurrent_weight,
const uint32_t steps,
const uint32_t hidden_size,
const LstmAttributes& attributes) { … }
// Out-of-line defaulted constructor, destructor and move operations.
LstmCellAttributes::LstmCellAttributes() = default;
LstmCellAttributes::~LstmCellAttributes() = default;
LstmCellAttributes::LstmCellAttributes(LstmCellAttributes&& other) = default;
LstmCellAttributes& LstmCellAttributes::operator=(LstmCellAttributes&& other) =
default;
// Validates a single LSTM cell step with `hidden_size` units, given the
// previous `hidden_state` and `cell_state`.
// Returns the output descriptors (a vector, unlike the GRU cell's single
// descriptor), or an error string.
base::expected<std::vector<OperandDescriptor>, std::string>
ValidateLstmCellAndInferOutput(const OperandDescriptor& input,
const OperandDescriptor& weight,
const OperandDescriptor& recurrent_weight,
const OperandDescriptor& hidden_state,
const OperandDescriptor& cell_state,
const uint32_t hidden_size,
const LstmCellAttributes& attributes) { … }
// Validates concatenation of `inputs` along `axis`.
// Returns the output descriptor, or an error string.
// NOTE(review): handling of an empty `inputs` vector is in the elided body.
base::expected<OperandDescriptor, std::string> ValidateConcatAndInferOutput(
const ContextProperties& context_properties,
const std::vector<OperandDescriptor>& inputs,
const uint32_t axis,
std::string_view label) { … }
// Validates a prelu of `input` with `slope`.
// Returns the output descriptor, or an error string.
base::expected<OperandDescriptor, std::string> ValidatePreluAndInferOutput(
const ContextProperties& context_properties,
const OperandDescriptor& input,
const OperandDescriptor& slope,
std::string_view label) { … }
// Validates a transpose of `input` by `permutation`.
// Returns the permuted output descriptor, or an error string.
base::expected<OperandDescriptor, std::string> ValidateTransposeAndInferOutput(
const ContextProperties& context_properties,
const OperandDescriptor& input,
base::span<const uint32_t> permutation,
std::string_view label) { … }
// Out-of-line defaulted constructor, destructor and move operations.
SliceAttributes::SliceAttributes() = default;
SliceAttributes::~SliceAttributes() = default;
SliceAttributes::SliceAttributes(SliceAttributes&& other) = default;
SliceAttributes& SliceAttributes::operator=(SliceAttributes&& other) = default;
// Validates a slice of `input` described by `attributes`.
// Returns the output descriptor, or an error string.
base::expected<OperandDescriptor, std::string> ValidateSliceAndInferOutput(
const ContextProperties& context_properties,
const OperandDescriptor& input,
const SliceAttributes& attributes) { … }
// Validates a reduction of the given `kind` over `axes` of `input`.
// Returns the output descriptor, or an error string.
// Presumably uses ValidateReduceAxesAndInferOutput for the shape computation —
// confirm in the body.
base::expected<OperandDescriptor, std::string> ValidateReduceAndInferOutput(
const ContextProperties& context_properties,
ReduceKind kind,
const OperandDescriptor& input,
std::string_view label,
base::span<const uint32_t> axes,
bool keep_dimensions) { … }
// Validates a triangular op on `input`.
// Returns the output descriptor, or an error string.
base::expected<OperandDescriptor, std::string> ValidateTriangularAndInferOutput(
const ContextProperties& context_properties,
const OperandDescriptor& input,
std::string_view label) { … }
// Validates a where (select) op over `condition`, `true_value` and
// `false_value`.
// Returns the output descriptor, or an error string.
// NOTE(review): the output shape presumably comes from broadcasting the three
// inputs — confirm in the body.
base::expected<OperandDescriptor, std::string> ValidateWhereAndInferOutput(
const ContextProperties& context_properties,
const OperandDescriptor& condition,
const OperandDescriptor& true_value,
const OperandDescriptor& false_value,
std::string_view label) { … }
// Validates a list of `axes` for an operand of the given `rank`.
// Returns success (void) or an error string.
// NOTE(review): the exact checks (e.g. each axis < rank, no duplicates) are in
// the elided body.
base::expected<void, std::string> ValidateAxes(base::span<const uint32_t> axes,
uint32_t rank,
std::string_view label) { … }
// Validates a buffer `descriptor` against `context_properties`.
// Returns success (void) or an error string.
// Note: `descriptor` is taken by value.
base::expected<void, std::string> ValidateBuffer(
const ContextProperties& context_properties,
OperandDescriptor descriptor) { … }
// Computes the shape that results from broadcasting `dims_lhs` with
// `dims_rhs`.
// Presumably returns std::nullopt when the shapes are not broadcast-compatible.
// `bidirectional` presumably selects two-way (numpy-style) broadcasting versus
// broadcasting only lhs into rhs — confirm in the body.
std::optional<std::vector<uint32_t>> BroadcastShapes(
base::span<const uint32_t> dims_lhs,
base::span<const uint32_t> dims_rhs,
bool bidirectional) { … }
// Computes the convTranspose2d output size along a single spatial dimension.
// Inputs are the input size, filter size, paddings, stride, dilation and
// `output_padding`.
// Returns the size as uint32_t, or an error string.
// Unlike CalculateConv2dOutputSize this takes no `label`, so any error text is
// label-free.
base::expected<uint32_t, std::string> CalculateConvTranspose2dOutputSize(
const uint32_t input_size,
const uint32_t filter_size,
const uint32_t beginning_padding,
const uint32_t ending_padding,
const uint32_t stride,
const uint32_t dilation,
const uint32_t output_padding) { … }
// Returns whether `data_type` is a floating-point operand data type.
// NOTE(review): the exact set (e.g. float32/float16) is in the elided body.
bool IsFloatingPointType(OperandDataType data_type) { … }
// Returns whether a grouped conv2d with these channel counts and `groups` is a
// depthwise convolution.
// Presumably true when groups == input_channels == output_channels — confirm
// against the elided body.
bool IsDepthwiseConv2d(uint32_t input_channels,
uint32_t output_channels,
uint32_t groups) { … }
}