#ifndef SERVICES_WEBNN_PUBLIC_CPP_GRAPH_VALIDATION_UTILS_H_
#define SERVICES_WEBNN_PUBLIC_CPP_GRAPH_VALIDATION_UTILS_H_
#include <cstdint>
#include <optional>
#include <string>
#include <string_view>
#include <vector>

#include "base/component_export.h"
#include "base/containers/enum_set.h"
#include "base/containers/span.h"
#include "base/types/expected.h"
#include "services/webnn/public/cpp/context_properties.h"
#include "services/webnn/public/cpp/operand_descriptor.h"
#include "third_party/abseil-cpp/absl/types/variant.h"
namespace webnn {
// Returns a string representation of the data types contained in
// `constraint_set`.
// NOTE(review): presumably used when composing validation error messages —
// confirm against callers.
std::string COMPONENT_EXPORT(WEBNN_PUBLIC_CPP)
    DataTypeConstraintToString(const SupportedDataTypes& constraint_set);
enum class Conv2dFilterOperandLayout { … };
enum class ConvTranspose2dFilterOperandLayout { … };
enum class Pool2dKind { … };
enum class RoundingType { … };
enum class RecurrentNetworkDirection { … };
enum class ReduceKind { … };
template <typename T>
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) Size2d { … };
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) Padding2d { … };
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) BatchNormalizationAttributes { … };
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) Conv2dAttributesBase { … };
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) Conv2dAttributes
: Conv2dAttributesBase { … };
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) ConvTranspose2dAttributes
: Conv2dAttributesBase { … };
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) Pool2dAttributes { … };
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) GemmAttributes { … };
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) GruAttributes { … };
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) GruCellAttributes { … };
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) InstanceNormalizationAttributes { … };
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) LayerNormalizationAttributes { … };
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) LstmAttributes { … };
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) LstmCellAttributes { … };
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) SliceAttributes { … };
// Validates an argMin/argMax operation on `input` and infers its output.
// On success the result holds the output OperandDescriptor; otherwise it
// holds an error message.
// NOTE(review): the checks are defined in the .cc. Presumably `axis` must be
// within the rank of `input`, `output_data_type` must be permitted by
// `context_properties`, `keep_dimensions` (default false) keeps the reduced
// axis as a size-1 dimension, and `label` tags error messages — confirm.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateArgMinMaxAndInferOutput(const ContextProperties& context_properties,
                                    const OperandDescriptor& input,
                                    std::string_view label,
                                    uint32_t axis,
                                    OperandDataType output_data_type,
                                    bool keep_dimensions = false);
// Validates a softmax operation over `axis` of `input` and infers its output.
// On success the result holds the output OperandDescriptor; otherwise it
// holds an error message.
// NOTE(review): presumably `axis` must be within the rank of `input` and
// `label` tags error messages — confirm in the .cc.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateSoftmaxAndInferOutput(const ContextProperties& context_properties,
                                  const OperandDescriptor& input,
                                  uint32_t axis,
                                  std::string_view label);
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) SplitAttribute { … };
// Validates a split of `input` described by `attributes` and infers its
// outputs. On success the result holds a vector of OperandDescriptors
// (presumably one per split output, in order — confirm); otherwise it holds
// an error message.
base::expected<std::vector<OperandDescriptor>, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateSplitAndInferOutput(const ContextProperties& context_properties,
                                const OperandDescriptor& input,
                                const SplitAttribute& attributes);
// Validates a batch normalization of `input` using `mean`, `variance` and
// `attributes`, and infers its output.
// On success the result holds the output OperandDescriptor; otherwise it
// holds an error message.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateBatchNormalizationAndInferOutput(
        const OperandDescriptor& input,
        const OperandDescriptor& mean,
        const OperandDescriptor& variance,
        const BatchNormalizationAttributes& attributes);
// Validates a 2-D convolution of `input` with `filter` under `attributes`,
// and infers its output.
// On success the result holds the output OperandDescriptor; otherwise it
// holds an error message.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateConv2dAndInferOutput(const OperandDescriptor& input,
                                 const OperandDescriptor& filter,
                                 const Conv2dAttributes& attributes);
// Validates a 2-D transposed convolution of `input` with `filter` under
// `attributes`, and infers its output.
// On success the result holds the output OperandDescriptor; otherwise it
// holds an error message.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateConvTranspose2dAndInferOutput(
        const OperandDescriptor& input,
        const OperandDescriptor& filter,
        const ConvTranspose2dAttributes& attributes);
// Validates padding `input` by `beginning_padding` / `ending_padding` and
// infers its output.
// On success the result holds the output OperandDescriptor; otherwise it
// holds an error message.
// NOTE(review): presumably both spans must have one entry per dimension of
// `input`, and `label` tags error messages — confirm in the .cc.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidatePadAndInferOutput(const ContextProperties& context_properties,
                              const OperandDescriptor& input,
                              base::span<const uint32_t> beginning_padding,
                              base::span<const uint32_t> ending_padding,
                              std::string_view label);
// Validates a matrix multiplication of `a` by `b` and infers its output.
// On success the result holds the output OperandDescriptor; otherwise it
// holds an error message. `label` presumably tags error messages.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateMatmulAndInferOutput(const ContextProperties& context_properties,
                                 const OperandDescriptor& a,
                                 const OperandDescriptor& b,
                                 std::string_view label);
// Validates a 2-D pooling operation of the given `kind` on `input` under
// `attributes`, and infers its output.
// On success the result holds the output OperandDescriptor; otherwise it
// holds an error message.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidatePool2dAndInferOutput(const ContextProperties& context_properties,
                                 const OperandDescriptor& input,
                                 const Pool2dAttributes& attributes,
                                 Pool2dKind kind);
// Validates a 2-D resample of `input` along `axes` and infers its output.
// `scales_or_sizes` carries either float scale factors or explicit uint32_t
// output sizes.
// On success the result holds the output OperandDescriptor; otherwise it
// holds an error message. `label` presumably tags error messages.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateResample2dAndInferOutput(
        const ContextProperties& context_properties,
        const OperandDescriptor& input,
        const absl::variant<base::span<const float>,
                            base::span<const uint32_t>>& scales_or_sizes,
        base::span<const uint32_t> axes,
        std::string_view label);
// Validates a gather from `input` along `axis` using `indices`, and infers
// its output.
// On success the result holds the output OperandDescriptor; otherwise it
// holds an error message.
// NOTE(review): presumably `axis` must be within the rank of `input`,
// `indices` must have a permitted data type per `context_properties`, and
// `label` tags error messages — confirm in the .cc.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateGatherAndInferOutput(const ContextProperties& context_properties,
                                 const OperandDescriptor& input,
                                 const OperandDescriptor& indices,
                                 uint32_t axis,
                                 std::string_view label);
// Validates an element-wise gather from `input` along `axis` using `indices`,
// and infers its output.
// On success the result holds the output OperandDescriptor; otherwise it
// holds an error message. `label` presumably tags error messages.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateGatherElementsAndInferOutput(
        const ContextProperties& context_properties,
        const OperandDescriptor& input,
        const OperandDescriptor& indices,
        uint32_t axis,
        std::string_view label);
// Validates a general matrix multiplication (gemm) of `a` and `b` under
// `attributes`, and infers its output.
// On success the result holds the output OperandDescriptor; otherwise it
// holds an error message.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateGemmAndInferOutput(const ContextProperties& context_properties,
                               const OperandDescriptor& a,
                               const OperandDescriptor& b,
                               const GemmAttributes& attributes);
// Validates a GRU operation over `steps` time steps with `hidden_size` units,
// and infers its outputs.
// On success the result holds a vector of output OperandDescriptors;
// otherwise it holds an error message.
base::expected<std::vector<OperandDescriptor>, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateGruAndInferOutput(const OperandDescriptor& input,
                              const OperandDescriptor& weight,
                              const OperandDescriptor& recurrent_weight,
                              uint32_t steps,
                              uint32_t hidden_size,
                              const GruAttributes& attributes);
// Validates a single GRU cell step from `input` and `hidden_state` with
// `hidden_size` units, and infers its output.
// On success the result holds the output OperandDescriptor; otherwise it
// holds an error message.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateGruCellAndInferOutput(const OperandDescriptor& input,
                                  const OperandDescriptor& weight,
                                  const OperandDescriptor& recurrent_weight,
                                  const OperandDescriptor& hidden_state,
                                  uint32_t hidden_size,
                                  const GruCellAttributes& attributes);
// Validates an instance normalization of `input` under `attributes`, and
// infers its output.
// On success the result holds the output OperandDescriptor; otherwise it
// holds an error message.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateInstanceNormalizationAndInferOutput(
        const OperandDescriptor& input,
        const InstanceNormalizationAttributes& attributes);
// Validates a layer normalization of `input` over `axes` under `attributes`,
// and infers its output.
// On success the result holds the output OperandDescriptor; otherwise it
// holds an error message.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateLayerNormalizationAndInferOutput(
        const OperandDescriptor& input,
        base::span<const uint32_t> axes,
        const LayerNormalizationAttributes& attributes);
// Validates an LSTM operation over `steps` time steps with `hidden_size`
// units, and infers its outputs.
// On success the result holds a vector of output OperandDescriptors;
// otherwise it holds an error message.
base::expected<std::vector<OperandDescriptor>, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateLstmAndInferOutput(const OperandDescriptor& input,
                               const OperandDescriptor& weight,
                               const OperandDescriptor& recurrent_weight,
                               uint32_t steps,
                               uint32_t hidden_size,
                               const LstmAttributes& attributes);
// Validates a single LSTM cell step from `input`, `hidden_state` and
// `cell_state` with `hidden_size` units, and infers its outputs.
// On success the result holds a vector of output OperandDescriptors;
// otherwise it holds an error message.
base::expected<std::vector<OperandDescriptor>, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateLstmCellAndInferOutput(const OperandDescriptor& input,
                                   const OperandDescriptor& weight,
                                   const OperandDescriptor& recurrent_weight,
                                   const OperandDescriptor& hidden_state,
                                   const OperandDescriptor& cell_state,
                                   uint32_t hidden_size,
                                   const LstmCellAttributes& attributes);
// Validates concatenating the operands in `input` along `axis`, and infers
// the output.
// On success the result holds the output OperandDescriptor; otherwise it
// holds an error message.
// NOTE(review): presumably all inputs must agree in data type and in every
// dimension except `axis`, and `label` tags error messages — confirm in the
// .cc.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateConcatAndInferOutput(const ContextProperties& context_properties,
                                 const std::vector<OperandDescriptor>& input,
                                 uint32_t axis,
                                 std::string_view label);
// Validates a PReLU of `input` with `slope`, and infers its output.
// On success the result holds the output OperandDescriptor; otherwise it
// holds an error message. `label` presumably tags error messages.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidatePreluAndInferOutput(const ContextProperties& context_properties,
                                const OperandDescriptor& input,
                                const OperandDescriptor& slope,
                                std::string_view label);
// Validates transposing `input` by `permutation`, and infers its output.
// On success the result holds the output OperandDescriptor; otherwise it
// holds an error message.
// NOTE(review): presumably `permutation` must be a permutation of
// [0, rank(input)) and `label` tags error messages — confirm in the .cc.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateTransposeAndInferOutput(const ContextProperties& context_properties,
                                    const OperandDescriptor& input,
                                    base::span<const uint32_t> permutation,
                                    std::string_view label);
// Validates slicing `input` as described by `attributes`, and infers its
// output.
// On success the result holds the output OperandDescriptor; otherwise it
// holds an error message.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateSliceAndInferOutput(const ContextProperties& context_properties,
                                const OperandDescriptor& input,
                                const SliceAttributes& attributes);
// Validates a reduction of the given `kind` over `axes` of `input`, and
// infers its output.
// `keep_dimensions` defaults to false.
// NOTE(review): presumably, when true, reduced axes are kept as size-1
// dimensions, and `label` tags error messages — confirm in the .cc.
// On success the result holds the output OperandDescriptor; otherwise it
// holds an error message.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateReduceAndInferOutput(const ContextProperties& context_properties,
                                 ReduceKind kind,
                                 const OperandDescriptor& input,
                                 std::string_view label,
                                 base::span<const uint32_t> axes,
                                 bool keep_dimensions = false);
// Validates a triangular operation on `input`, and infers its output.
// On success the result holds the output OperandDescriptor; otherwise it
// holds an error message. `label` presumably tags error messages.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateTriangularAndInferOutput(
        const ContextProperties& context_properties,
        const OperandDescriptor& input,
        std::string_view label);
// Validates a `where` selection between `true_value` and `false_value`
// driven by `condition`, and infers its output.
// On success the result holds the output OperandDescriptor; otherwise it
// holds an error message.
// NOTE(review): presumably the three operands must be mutually broadcastable
// and `label` tags error messages — confirm in the .cc.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateWhereAndInferOutput(const ContextProperties& context_properties,
                                const OperandDescriptor& condition,
                                const OperandDescriptor& true_value,
                                const OperandDescriptor& false_value,
                                std::string_view label);
// Checks `descriptor` against `context_properties`. The result is empty on
// success; otherwise it holds an error message.
// NOTE(review): `descriptor` is taken by value, unlike every other
// OperandDescriptor parameter in this header. Switching to a const reference
// changes the exported symbol, so it must be done together with the .cc
// definition.
base::expected<void, std::string> COMPONENT_EXPORT(WEBNN_PUBLIC_CPP)
    ValidateBuffer(const ContextProperties& context_properties,
                   OperandDescriptor descriptor);
// Checks that `axes` are valid for an operand of rank `rank`. The result is
// empty on success; otherwise it holds an error message.
// NOTE(review): presumably each axis must be < `rank` with no duplicates, and
// `label` tags error messages — confirm in the .cc.
base::expected<void, std::string> COMPONENT_EXPORT(WEBNN_PUBLIC_CPP)
    ValidateAxes(base::span<const uint32_t> axes,
                 uint32_t rank,
                 std::string_view label);
// Computes the shape obtained by broadcasting `dims_lhs` with `dims_rhs`.
// Returns std::nullopt when no broadcast shape exists.
// NOTE(review): `bidirectional` defaults to true; presumably, when false,
// only `dims_lhs` may be broadcast up to `dims_rhs` — confirm in the .cc.
std::optional<std::vector<uint32_t>> COMPONENT_EXPORT(WEBNN_PUBLIC_CPP)
    BroadcastShapes(base::span<const uint32_t> dims_lhs,
                    base::span<const uint32_t> dims_rhs,
                    bool bidirectional = true);
// Computes the output size of a transposed 2-D convolution along a single
// spatial dimension (all arguments are scalars for that dimension): input and
// filter size, the two paddings, stride, dilation and output padding.
// On success the result holds the size; otherwise it holds an error message.
// NOTE(review): failure presumably covers arithmetic overflow or inconsistent
// parameters — confirm in the .cc.
base::expected<uint32_t, std::string> COMPONENT_EXPORT(WEBNN_PUBLIC_CPP)
    CalculateConvTranspose2dOutputSize(uint32_t input_size,
                                       uint32_t filter_size,
                                       uint32_t beginning_padding,
                                       uint32_t ending_padding,
                                       uint32_t stride,
                                       uint32_t dilation,
                                       uint32_t output_padding);
// Returns whether `data_type` is a floating-point OperandDataType.
// NOTE(review): the exact set of types counted as floating point is defined
// in the .cc — confirm there.
bool COMPONENT_EXPORT(WEBNN_PUBLIC_CPP)
    IsFloatingPointType(OperandDataType data_type);
// Returns whether a conv2d with `input_channels` input channels,
// `output_channels` output channels and `groups` groups is a depthwise
// convolution.
// NOTE(review): the exact criterion is defined in the .cc (presumably
// relating `groups` to the channel counts) — confirm there.
bool COMPONENT_EXPORT(WEBNN_PUBLIC_CPP)
    IsDepthwiseConv2d(uint32_t input_channels,
                      uint32_t output_channels,
                      uint32_t groups);
}  // namespace webnn
#endif  // SERVICES_WEBNN_PUBLIC_CPP_GRAPH_VALIDATION_UTILS_H_