#include "services/webnn/tflite/graph_builder_tflite.h"
#include <cstdint>
#include <numeric>
#include <vector>
#include "base/containers/fixed_flat_set.h"
#include "base/containers/span.h"
#include "base/numerics/checked_math.h"
#include "base/numerics/safe_conversions.h"
#include "base/ranges/algorithm.h"
#include "base/strings/stringprintf.h"
#include "base/types/expected.h"
#include "base/types/expected_macros.h"
#include "services/webnn/public/cpp/context_properties.h"
#include "services/webnn/public/cpp/graph_validation_utils.h"
#include "services/webnn/public/cpp/webnn_errors.h"
#include "services/webnn/public/mojom/webnn_context_provider.mojom.h"
#include "services/webnn/public/mojom/webnn_graph.mojom.h"
#include "services/webnn/webnn_utils.h"
#include "third_party/tflite/src/tensorflow/lite/schema/schema_generated.h"
namespace webnn::tflite {
// File-local helpers. NOTE(review): every body in this namespace is elided in
// this view; the comments below are derived from signatures and names only and
// should be confirmed against the implementations.
namespace {
// NOTE(review): value elided here; presumably the TFLite schema version
// recorded in the serialized model - confirm. Preprocessor macros ignore
// namespace scope, so defining it inside this anonymous namespace does NOT make
// it file-local; a `constexpr` constant would be scoped.
#define TFLITE_SCHEMA_VERSION …
// Trait declared only for element types satisfying
// internal::IsSupportedTensorType; the specializations below cover float,
// int32_t, uint32_t and int64_t. Presumably each exposes the matching
// ::tflite::TensorType - confirm.
template <typename DataType>
requires internal::IsSupportedTensorType<DataType>
struct TensorTypeMap;
template <>
struct TensorTypeMap<float> { … };
template <>
struct TensorTypeMap<int32_t> { … };
template <>
struct TensorTypeMap<uint32_t> { … };
template <>
struct TensorTypeMap<int64_t> { … };
// Returns `input_dimensions` as int32_t values, or an error message.
// Presumably fails when a dimension does not fit in int32_t - confirm.
base::expected<std::vector<int32_t>, std::string> ToSignedDimensions(
base::span<const uint32_t> input_dimensions) { … }
// Maps a WebNN OperandDataType to a ::tflite::TensorType. The signature has no
// error channel; how types without a TFLite equivalent are handled is not
// visible here.
::tflite::TensorType OperandDataTypeToTFLite(OperandDataType data_type) { … }
// Enumerators elided; presumably classifies a clamp's bounds into ranges that
// can be lowered specially - confirm.
enum class ClampRange { … };
// Classifies the bounds of `clamp` into a ClampRange, or returns an error
// message (presumably for bounds that cannot be lowered - confirm).
base::expected<ClampRange, std::string> GetClampRange(
const mojom::Clamp& clamp) { … }
// Maps a WebNN recurrent-network activation to a ::tflite::BuiltinOperator
// (presumably the builtin that implements that activation).
::tflite::BuiltinOperator GetRecurrentNetworkActivation(
mojom::RecurrentNetworkActivation activation) { … }
// Fields elided; presumably the begin/end padding of one spatial dimension.
struct PaddingSizes { … };
// Computes explicit padding for one spatial dimension (for the "same" auto-pad
// mode, per its name) from the input size, filter size, stride and dilation;
// `is_transposed_conv2d` selects transposed-convolution semantics. Returns
// std::nullopt on failure (presumably arithmetic overflow - confirm).
std::optional<PaddingSizes> CalculateExplicitPaddingForSamePaddingMode(
uint32_t input_size,
uint32_t filter_size,
uint32_t stride,
uint32_t dilation,
bool is_transposed_conv2d) { … }
// Fields elided; presumably a TFLite padding mode plus any explicit padding
// that has to be applied separately - confirm.
struct TfLitePadding { … };
// Selects a TfLitePadding for `padding2d` given the 2-D input and filter
// sizes, stride and dilation, or returns an error message.
base::expected<TfLitePadding, std::string> GetTfLitePaddingMode(
const mojom::Padding2d& padding2d,
const webnn::Size2d<uint32_t>& input,
const webnn::Size2d<uint32_t>& filter,
const mojom::Size2d& stride,
const mojom::Size2d& dilation,
bool is_transposed_conv2d) { … }
// Returns a vector of indices derived from `axes`; presumably the permutation
// that would sort `axes` (an argsort) - confirm.
std::vector<uint32_t> GetIndexOfSortedValue(base::span<const uint32_t> axes) { … }
// Returns a mask buffer, presumably one element per entry of a tensor shaped
// `dimensions`. `upper` and `diagonal` select the triangle; presumably `mask`
// is the value written for kept elements - confirm.
template <typename DataType>
std::vector<DataType> FillMaskTriangular(base::span<const int32_t> dimensions,
bool upper,
int32_t diagonal,
DataType mask) { … }
}  // namespace
// Returns the serialized TFLite model for `graph_info` under
// `context_properties` as a flatbuffers::DetachedBuffer, or an error message.
// NOTE(review): body elided; presumably constructs a GraphBuilderTflite,
// serializes every operand and operation, then finishes the buffer - confirm.
base::expected<flatbuffers::DetachedBuffer, std::string>
GraphBuilderTflite::CreateAndBuild(ContextProperties context_properties,
const mojom::GraphInfo& graph_info) { … }
// Returns the ContextProperties of this backend. It takes no arguments and
// reads no visible state, so it is presumably declared static - confirm in the
// header.
ContextProperties GraphBuilderTflite::GetContextProperties() { … }
// Constructor. The member-initializer list and body are elided in this view.
// NOTE(review): if `graph_info` is retained by reference, the caller must keep
// it alive for the builder's lifetime - confirm in the header.
GraphBuilderTflite::GraphBuilderTflite(ContextProperties context_properties,
const mojom::GraphInfo& graph_info)
: … { … }
// Defaulted out of line (presumably so that member types only forward-declared
// in the header are complete at the point of destruction).
GraphBuilderTflite::~GraphBuilderTflite() = default;
// Serializes the operand `operand_id`, described by `operand`, into the model
// (presumably as a TFLite tensor). Returns an error message on failure.
base::expected<void, std::string> GraphBuilderTflite::SerializeOperand(
uint64_t operand_id,
const mojom::Operand& operand) { … }
// Serializes one graph operation and returns an error message on failure.
// Presumably dispatches on the operation kind to the per-op Serialize*()
// helpers in this file - confirm.
base::expected<void, std::string> GraphBuilderTflite::SerializeOperation(
const mojom::Operation& op) { … }
// Finalizes the model, using the operand ids in `input_operands` and
// `output_operands` as the graph inputs and outputs, and returns the finished
// buffer. The return type has no error channel, so validation presumably
// happens before this is called.
flatbuffers::DetachedBuffer GraphBuilderTflite::FinishAndTakeFlatBuffer(
base::span<const uint64_t> input_operands,
base::span<const uint64_t> output_operands) { … }
// Adds the bytes of `constant` to the model. The returned uint32_t is
// presumably the index of the new model buffer - confirm.
uint32_t GraphBuilderTflite::SerializeBuffer(
const mojo_base::BigBuffer& constant) { … }
// Creates a tensor of shape `dimensions` whose data comes from `buffer`.
// Only instantiable for element types satisfying
// internal::IsSupportedTensorType. The returned int32_t is presumably the new
// tensor's index - confirm.
template <typename DataType>
requires internal::IsSupportedTensorType<DataType>
int32_t GraphBuilderTflite::SerializeTensorWithBuffer(
base::span<const DataType> buffer,
base::span<const int32_t> dimensions) { … }
// Creates an intermediate tensor of shape `dimensions` and type `tensor_type`
// (presumably without constant data). The returned int32_t is presumably the
// new tensor's index.
int32_t GraphBuilderTflite::SerializeTemporaryTensor(
base::span<const int32_t> dimensions,
::tflite::TensorType tensor_type) { … }
// Returns the index of the operator-code entry for (`code`, `version`).
// Presumably adds a new entry the first time a pair is requested - confirm.
uint32_t GraphBuilderTflite::GetOperatorCodeIndex(
::tflite::BuiltinOperator code,
int32_t version) { … }
// Returns the operand descriptor for `operand_id`, presumably looked up in the
// graph info. Behavior for an unknown id is not visible here. The result is a
// reference, so it must not outlive the data it refers to.
const mojom::Operand& GraphBuilderTflite::GetOperand(
uint64_t operand_id) const { … }
// Emits an operator with builtin `code`, a single input tensor and a single
// output tensor. `builtin_options_type` and `builtin_options` carry any
// operator-specific options. Returns the operator offset.
auto GraphBuilderTflite::SerializeUnaryOperation(
::tflite::BuiltinOperator code,
int32_t input_tensor_index,
int32_t output_tensor_index,
::tflite::BuiltinOptions builtin_options_type,
flatbuffers::Offset<void> builtin_options) -> OperatorOffset { … }
// Emits an operator that converts tensor `input_tensor_index` from
// `input_tensor_type` to `output_tensor_type`, writing `output_tensor_index`
// (presumably a TFLite CAST - confirm).
auto GraphBuilderTflite::SerializeCastOperation(
int32_t input_tensor_index,
::tflite::TensorType input_tensor_type,
int32_t output_tensor_index,
::tflite::TensorType output_tensor_type) -> OperatorOffset { … }
// Emits builtin operator `code` with inputs (`lhs_tensor_index`,
// `rhs_tensor_index`), writing `output_tensor_index`.
auto GraphBuilderTflite::SerializeBinaryOperation(
::tflite::BuiltinOperator code,
int32_t lhs_tensor_index,
int32_t rhs_tensor_index,
int32_t output_tensor_index) -> OperatorOffset { … }
// Emits a concatenation of the tensors in `input_tensor_indices` along `axis`,
// writing `output_tensor_index`.
auto GraphBuilderTflite::SerializeConcatOperation(
base::span<const int32_t> input_tensor_indices,
int32_t output_tensor_index,
uint32_t axis) -> OperatorOffset { … }
// Emits a matrix multiplication of tensors `a_tensor_index` and
// `b_tensor_index`, writing `output_tensor_index` (presumably TFLite
// BATCH_MATMUL - confirm).
auto GraphBuilderTflite::SerializeMatmulOperation(int32_t a_tensor_index,
int32_t b_tensor_index,
int32_t output_tensor_index)
-> OperatorOffset { … }
// Emits operators that presumably compute `alpha * x + beta` for the input
// tensor, writing `output_tensor_index`, and returns the last operator's
// offset. `input_dimensions` and `input_tensor_type` describe the input,
// presumably so that intermediate tensors can be created - confirm.
auto GraphBuilderTflite::SerializeLinearOperation(
base::span<const int32_t> input_dimensions,
::tflite::TensorType input_tensor_type,
int32_t input_tensor_index,
int32_t output_tensor_index,
float alpha,
float beta) -> OperatorOffset { … }
// Emits the normalization of the input tensor from precomputed mean and
// variance tensors and `epsilon`, optionally applying `scale_tensor_index` and
// `bias_tensor_index`. Presumably computes
// (x - mean) / sqrt(variance + epsilon) * scale + bias - confirm.
auto GraphBuilderTflite::SerializeNormalizationOperation(
base::span<const int32_t> input_dimensions,
::tflite::TensorType input_tensor_type,
int32_t input_tensor_index,
int32_t output_tensor_index,
int32_t mean_tensor_index,
int32_t variance_tensor_index,
float epsilon,
std::optional<int32_t> scale_tensor_index,
std::optional<int32_t> bias_tensor_index) -> OperatorOffset { … }
// Emits reduction operator `operator_code` over `axes`, writing
// `output_tensor_index`. `keep_dimensions` presumably maps to the TFLite
// keep_dims option.
auto GraphBuilderTflite::SerializeReduceOperation(
::tflite::BuiltinOperator operator_code,
int32_t input_tensor_index,
int32_t output_tensor_index,
base::span<const int32_t> axes,
bool keep_dimensions) -> OperatorOffset { … }
// Emits a reshape of the input tensor to `new_shape`, writing
// `output_tensor_index`.
auto GraphBuilderTflite::SerializeReshapeOperation(
int32_t input_tensor_index,
int32_t output_tensor_index,
base::span<const int32_t> new_shape) -> OperatorOffset { … }
// Emits a slice of the input tensor with per-dimension `slice_starts` and
// `slice_sizes`, writing `output_tensor_index`. Unlike most Serialize*Operation
// helpers this one can fail and return an error message (presumably when the
// starts or sizes cannot be represented - confirm).
auto GraphBuilderTflite::SerializeSliceOperation(
int32_t input_tensor_index,
int32_t output_tensor_index,
base::span<const int32_t> slice_starts,
base::span<const int32_t> slice_sizes)
-> base::expected<OperatorOffset, std::string> { … }
// Emits a transpose of the input tensor by `permutation`, writing
// `output_tensor_index`.
auto GraphBuilderTflite::SerializeTransposeOperation(
int32_t input_tensor_index,
int32_t output_tensor_index,
base::span<const uint32_t> permutation) -> OperatorOffset { … }
// Inserts a pad of `input_operand` (tensor `input_tensor_index`) by `paddings`.
// Returns an int32_t that is presumably the padded tensor's index, or an error
// message. The layout of `paddings` is not visible here - confirm.
auto GraphBuilderTflite::InsertPadOperation(const mojom::Operand& input_operand,
int32_t input_tensor_index,
base::span<const uint32_t> paddings)
-> base::expected<int32_t, std::string> { … }
// Inserts a transpose of the input tensor by `permutation`. The returned
// int32_t is presumably the index of a newly created output tensor.
int32_t GraphBuilderTflite::InsertTransposeOperation(
base::span<const int32_t> input_dimensions,
::tflite::TensorType input_tensor_type,
int32_t input_tensor_index,
base::span<const uint32_t> permutation) { … }
// Emits a small subgraph that presumably computes
// pow(x, pow_exponent) * mul_alpha for the input tensor. The returned int32_t
// is presumably the result tensor's index - confirm.
int32_t GraphBuilderTflite::SerializeSubGraphPowMul(
base::span<const int32_t> input_dimensions,
::tflite::TensorType input_tensor_type,
int32_t input_tensor_index,
float pow_exponent,
float mul_alpha) { … }
// Serializes a WebNN argMin/argMax operation (mojom::ArgMinMax). Returns the
// operator offset or an error message.
auto GraphBuilderTflite::SerializeArgMinMax(const mojom::ArgMinMax& arg_min_max)
-> base::expected<OperatorOffset, std::string> { … }
// Serializes a WebNN batchNormalization and returns the operator offset or an
// error message. Presumably builds on SerializeNormalizationOperation -
// confirm.
auto GraphBuilderTflite::SerializeBatchNormalization(
const mojom::BatchNormalization& batch_normalization)
-> base::expected<OperatorOffset, std::string> { … }
// Serializes a WebNN clamp and returns the operator offset or an error message
// (presumably propagated from GetClampRange - confirm).
auto GraphBuilderTflite::SerializeClamp(const mojom::Clamp& clamp)
-> base::expected<OperatorOffset, std::string> { … }
// Serializes a WebNN concat. The return type has no error channel.
auto GraphBuilderTflite::SerializeConcat(const mojom::Concat& concat)
-> OperatorOffset { … }
// Serializes a WebNN conv2d (presumably including any variants carried by
// mojom::Conv2d, such as transposed convolution - confirm). Returns the
// operator offset or an error message.
auto GraphBuilderTflite::SerializeConv2d(const mojom::Conv2d& conv2d)
-> base::expected<OperatorOffset, std::string> { … }
// Serializes a WebNN element-wise binary operation. The return type has no
// error channel.
auto GraphBuilderTflite::SerializeElementWiseBinary(
const mojom::ElementWiseBinary& op) -> OperatorOffset { … }
// Serializes a WebNN element-wise unary operation and returns the operator
// offset or an error message. Other helpers in this file also take
// mojom::ElementWiseUnary (SerializeErf, SerializeLogicalNot,
// SerializeReciprocal, SerializeTan); they are presumably dispatched from
// here - confirm.
auto GraphBuilderTflite::SerializeElementWiseUnary(
const mojom::ElementWiseUnary& op)
-> base::expected<OperatorOffset, std::string> { … }
// Serializes a WebNN elu. Returns the operator offset or an error message.
auto GraphBuilderTflite::SerializeElu(const mojom::Elu& elu)
-> base::expected<OperatorOffset, std::string> { … }
// Serializes erf, which arrives as a mojom::ElementWiseUnary. Returns the
// operator offset or an error message. It may emit more than one operator; the
// body is elided here - confirm.
auto GraphBuilderTflite::SerializeErf(const mojom::ElementWiseUnary& erf)
-> base::expected<OperatorOffset, std::string> { … }
// Serializes a WebNN expand (presumably a broadcast to the output shape). The
// return type has no error channel.
auto GraphBuilderTflite::SerializeExpand(const mojom::Expand& expand)
-> OperatorOffset { … }
// Serializes a WebNN gather. Returns the operator offset or an error message.
auto GraphBuilderTflite::SerializeGather(const mojom::Gather& gather)
-> base::expected<OperatorOffset, std::string> { … }
// Serializes a WebNN gelu. Returns the operator offset or an error message.
auto GraphBuilderTflite::SerializeGelu(const mojom::Gelu& gelu)
-> base::expected<OperatorOffset, std::string> { … }
// Serializes a WebNN gemm (presumably alpha * A * B + beta * C with optional
// transposes of A and B - confirm). Returns the operator offset or an error
// message.
auto GraphBuilderTflite::SerializeGemm(const mojom::Gemm& gemm)
-> base::expected<OperatorOffset, std::string> { … }
// Emits operators that presumably compute input x weight, plus the bias when
// `bias_tensor_index` is set. The returned int32_t is presumably the result
// tensor's index - confirm.
int32_t GraphBuilderTflite::SerializeSubGraphMatmulAdd(
base::span<const int32_t> input_dimensions,
::tflite::TensorType input_tensor_type,
int32_t input_tensor_index,
int32_t weight_tensor_index,
std::optional<int32_t> bias_tensor_index) { … }
// Slices the input tensor with `slice_starts` / `slice_sizes` and then, per
// its name, transposes the slice. Returns an int32_t that is presumably the
// result tensor's index, or an error message.
auto GraphBuilderTflite::SerializeSubGraphSliceTranspose(
::tflite::TensorType input_tensor_type,
int32_t input_tensor_index,
base::span<const int32_t> slice_starts,
base::span<const int32_t> slice_sizes)
-> base::expected<int32_t, std::string> { … }
// Emits the subgraph for one GRU gate of kind `type` within `gru_cell`.
// `reset_gate_tensor_index` is optional and presumably only supplied for the
// gate that depends on the reset gate - confirm. Returns an int32_t that is
// presumably the gate's output tensor index, or an error message.
auto GraphBuilderTflite::SerializeGruGate(
const GruCellOperation& gru_cell,
GruGateType type,
std::optional<int32_t> reset_gate_tensor_index)
-> base::expected<int32_t, std::string> { … }
// Captures the tensor indices and settings of a recurrent cell, presumably the
// state shared by the GRU and LSTM cell descriptions. The biases are optional.
// The initializer list is elided in this view. NOTE(review):
// `input_dimensions` and `activations` are spans; if they are stored as spans,
// the caller's storage must outlive this object - confirm in the header.
GraphBuilderTflite::RecurrentNetworkBase::RecurrentNetworkBase(
base::span<const int32_t> input_dimensions,
::tflite::TensorType input_tensor_type,
int32_t input_tensor_index,
int32_t weight_tensor_index,
int32_t recurrent_weight_tensor_index,
std::optional<int32_t> bias_tensor_index,
std::optional<int32_t> recurrent_bias_tensor_index,
int32_t hidden_state_tensor_index,
int32_t hidden_size,
base::span<const mojom::RecurrentNetworkActivation> activations)
: … { … }
// Defaulted out-of-line destructor.
GraphBuilderTflite::RecurrentNetworkBase::~RecurrentNetworkBase() = default;
// Describes one GRU cell step: input and output tensors, weights, optional
// biases, hidden state, `hidden_size`, `reset_after`, the weight `layout` and
// `activations`. The initializer list is elided in this view.
GraphBuilderTflite::GruCellOperation::GruCellOperation(
base::span<const int32_t> input_dimensions,
::tflite::TensorType input_tensor_type,
int32_t input_tensor_index,
int32_t output_tensor_index,
int32_t weight_tensor_index,
int32_t recurrent_weight_tensor_index,
std::optional<int32_t> bias_tensor_index,
std::optional<int32_t> recurrent_bias_tensor_index,
int32_t hidden_state_tensor_index,
int32_t hidden_size,
bool reset_after,
mojom::GruWeightLayout layout,
base::span<const mojom::RecurrentNetworkActivation> activations)
: … { … }
// Defaulted out-of-line destructor.
GraphBuilderTflite::GruCellOperation::~GruCellOperation() = default;
// Serializes a WebNN gruCell and returns the operator offset or an error
// message. Presumably builds a GruCellOperation and forwards it to
// SerializeGruCellOperation - confirm.
auto GraphBuilderTflite::SerializeGruCell(const mojom::GruCell& gru_cell)
-> base::expected<OperatorOffset, std::string> { … }
// Emits the operators for the GRU cell step described by `gru_cell`. Returns
// the final operator offset or an error message.
auto GraphBuilderTflite::SerializeGruCellOperation(
const GruCellOperation& gru_cell)
-> base::expected<OperatorOffset, std::string> { … }
// Describes one LSTM cell step: the input tensor, `output_tensor_indices`
// (several outputs, presumably the new hidden and cell states - confirm),
// weights, optional biases and peephole weight, hidden and cell state,
// `hidden_size`, the weight `layout` and `activations`. The initializer list
// is elided in this view.
GraphBuilderTflite::LstmCellOperation::LstmCellOperation(
base::span<const int32_t> input_dimensions,
::tflite::TensorType input_tensor_type,
int32_t input_tensor_index,
base::span<const int32_t> output_tensor_indices,
int32_t weight_tensor_index,
int32_t recurrent_weight_tensor_index,
std::optional<int32_t> bias_tensor_index,
std::optional<int32_t> recurrent_bias_tensor_index,
int32_t hidden_state_tensor_index,
int32_t hidden_size,
int32_t cell_state_tensor_index,
std::optional<int32_t> peephole_weight_tensor_index,
mojom::LstmWeightLayout layout,
base::span<const mojom::RecurrentNetworkActivation> activations)
: … { … }
// Defaulted out-of-line destructor.
GraphBuilderTflite::LstmCellOperation::~LstmCellOperation() = default;
// Emits the subgraph for one LSTM gate of kind `type` within `lstm_cell`.
// Returns an int32_t that is presumably the gate's output tensor index, or an
// error message.
base::expected<int32_t, std::string> GraphBuilderTflite::SerializeLstmGate(
const LstmCellOperation& lstm_cell,
LstmGateType type) { … }
// Emits the operators for the LSTM cell step described by `lstm_cell`. Returns
// the final operator offset or an error message.
auto GraphBuilderTflite::SerializeLstmCellOperation(
const LstmCellOperation& lstm_cell)
-> base::expected<OperatorOffset, std::string> { … }
// Slices the input tensor with `slice_starts` / `slice_sizes` and then, per
// its name, squeezes dimension `squeeze_axis`. Returns an int32_t that is
// presumably the result tensor's index, or an error message.
auto GraphBuilderTflite::SerializeSubGraphSliceSqueeze(
::tflite::TensorType input_tensor_type,
int32_t input_tensor_index,
base::span<const int32_t> slice_starts,
base::span<const int32_t> slice_sizes,
int32_t squeeze_axis) -> base::expected<int32_t, std::string> { … }
// Generic serializer for a recurrent network, templated on
// `RecurrentNetworkType` (presumably the mojom GRU and LSTM network types -
// confirm the instantiations). Returns the operator offset or an error
// message.
template <typename RecurrentNetworkType>
auto GraphBuilderTflite::SerializeRecurrentNetwork(
const RecurrentNetworkType& recurrent_network)
-> base::expected<OperatorOffset, std::string> { … }
// Serializes a WebNN hardSigmoid (presumably max(0, min(1, alpha * x + beta))
// - confirm). The return type has no error channel.
auto GraphBuilderTflite::SerializeHardSigmoid(
const mojom::HardSigmoid& hard_sigmoid) -> OperatorOffset { … }
// Serializes a WebNN hardSwish. The return type has no error channel.
auto GraphBuilderTflite::SerializeHardSwish(const mojom::HardSwish& hard_swish)
-> OperatorOffset { … }
// Emits operators that compute statistics of the input tensor over
// `spatial_dimensions` and returns two int32_t values, presumably the
// (mean, variance) tensor indices in that order - confirm.
std::tuple<int32_t, int32_t>
GraphBuilderTflite::ComputeMeanAndVarianceForNormalization(
base::span<const int32_t> input_dimensions,
::tflite::TensorType input_tensor_type,
int32_t input_tensor_index,
base::span<const int32_t> spatial_dimensions) { … }
// Returns an int32_t that is presumably the index of a tensor holding the scale
// or bias operand `scale_or_bias_operand_id`, transposed and reshaped (per the
// name) to line up with an input of `input_dimensions` normalized over
// `axes` - confirm.
int32_t GraphBuilderTflite::TransposeAndReshapeLayerNormalizationScaleBias(
base::span<const int32_t> input_dimensions,
uint64_t scale_or_bias_operand_id,
base::span<const uint32_t> axes) { … }
// Serializes a WebNN instanceNormalization. Returns the operator offset or an
// error message.
auto GraphBuilderTflite::SerializeInstanceNormalization(
const mojom::InstanceNormalization& instance_normalization)
-> base::expected<OperatorOffset, std::string> { … }
// Serializes a WebNN layerNormalization. Returns the operator offset or an
// error message.
auto GraphBuilderTflite::SerializeLayerNormalization(
const mojom::LayerNormalization& layer_normalization)
-> base::expected<OperatorOffset, std::string> { … }
// Serializes a WebNN leakyRelu. The return type has no error channel.
auto GraphBuilderTflite::SerializeLeakyRelu(const mojom::LeakyRelu& leaky_relu)
-> OperatorOffset { … }
// Serializes a WebNN linear (presumably by forwarding to
// SerializeLinearOperation - confirm). The return type has no error channel.
auto GraphBuilderTflite::SerializeLinear(const mojom::Linear& linear)
-> OperatorOffset { … }
// Serializes logicalNot, which arrives as a mojom::ElementWiseUnary. The
// return type has no error channel.
auto GraphBuilderTflite::SerializeLogicalNot(
const mojom::ElementWiseUnary& logical_not) -> OperatorOffset { … }
// Serializes a WebNN lstmCell and returns the operator offset or an error
// message. Presumably builds an LstmCellOperation and forwards it to
// SerializeLstmCellOperation - confirm.
auto GraphBuilderTflite::SerializeLstmCell(const mojom::LstmCell& lstm_cell)
-> base::expected<OperatorOffset, std::string> { … }
// Returns an int32_t that is presumably the tensor index of an initial
// recurrent state: the given operand when `state_operand_id` is set, otherwise
// presumably a zero-filled tensor shaped `state_dimensions` - confirm.
int32_t GraphBuilderTflite::GetInitialHiddenAndCellState(
std::optional<uint64_t> state_operand_id,
base::span<const int32_t> state_dimensions) { … }
// Reshapes a state tensor to `new_shape`. When `concat_input_tensor_index` is
// set, it presumably concatenates that tensor with the reshaped result into a
// tensor shaped `concat_output_shape` - confirm. The returned int32_t is
// presumably the resulting tensor's index.
int32_t GraphBuilderTflite::ReshapeHiddenAndCellState(
::tflite::TensorType input_tensor_type,
int32_t input_tensor_index,
base::span<const int32_t> new_shape,
std::optional<int32_t> concat_input_tensor_index,
base::span<const int32_t> concat_output_shape) { … }
// Serializes a WebNN matmul. The return type has no error channel.
auto GraphBuilderTflite::SerializeMatmul(const mojom::Matmul& matmul)
-> OperatorOffset { … }
// Serializes a WebNN pad. Returns the operator offset or an error message.
auto GraphBuilderTflite::SerializePad(const mojom::Pad& pad)
-> base::expected<OperatorOffset, std::string> { … }
// Serializes a WebNN 2-D pooling operation (the pooling kind is presumably
// carried in mojom::Pool2d). Returns the operator offset or an error message.
auto GraphBuilderTflite::SerializePool2d(const mojom::Pool2d& pool2d)
-> base::expected<OperatorOffset, std::string> { … }
// Serializes a WebNN prelu. Returns the operator offset or an error message.
auto GraphBuilderTflite::SerializePrelu(const mojom::Prelu& prelu)
-> base::expected<OperatorOffset, std::string> { … }
// Serializes reciprocal, which arrives as a mojom::ElementWiseUnary. Returns
// the operator offset or an error message.
auto GraphBuilderTflite::SerializeReciprocal(
const mojom::ElementWiseUnary& reciprocal)
-> base::expected<OperatorOffset, std::string> { … }
// Serializes a WebNN reduction and returns the operator offset or an error
// message. Some reduce kinds are presumably routed to helpers such as
// SerializeReduceSumSquare - confirm.
auto GraphBuilderTflite::SerializeReduce(const mojom::Reduce& reduce)
-> base::expected<OperatorOffset, std::string> { … }
// Emits a sum-of-squares reduction (per the name) for `reduce`, writing
// `output_tensor_index`. Returns the operator offset or an error message.
auto GraphBuilderTflite::SerializeReduceSumSquare(const mojom::Reduce& reduce,
int32_t output_tensor_index)
-> base::expected<OperatorOffset, std::string> { … }
// Serializes a WebNN relu. The return type has no error channel.
auto GraphBuilderTflite::SerializeRelu(const mojom::Relu& relu)
-> OperatorOffset { … }
// Serializes a WebNN resample2d. Returns the operator offset or an error
// message.
auto GraphBuilderTflite::SerializeResample2d(
const mojom::Resample2d& resample2d)
-> base::expected<OperatorOffset, std::string> { … }
// Serializes a reshape from operand `input_operand_id` to operand
// `output_operand_id`. Unlike most op serializers it takes operand ids rather
// than a mojom op struct. Returns the operator offset or an error message.
auto GraphBuilderTflite::SerializeReshape(uint64_t input_operand_id,
uint64_t output_operand_id)
-> base::expected<OperatorOffset, std::string> { … }
// Serializes a WebNN sigmoid. The return type has no error channel.
auto GraphBuilderTflite::SerializeSigmoid(const mojom::Sigmoid& sigmoid)
-> OperatorOffset { … }
// Serializes a WebNN slice (presumably via SerializeSliceOperation - confirm).
// Returns the operator offset or an error message.
auto GraphBuilderTflite::SerializeSlice(const mojom::Slice& slice)
-> base::expected<OperatorOffset, std::string> { … }
// Serializes a WebNN softmax. The return type has no error channel.
auto GraphBuilderTflite::SerializeSoftmax(const mojom::Softmax& softmax)
-> OperatorOffset { … }
// Serializes a WebNN softplus. Returns the operator offset or an error
// message.
auto GraphBuilderTflite::SerializeSoftplus(const mojom::Softplus& softplus)
-> base::expected<OperatorOffset, std::string> { … }
// Serializes a WebNN softsign. Returns the operator offset or an error
// message.
auto GraphBuilderTflite::SerializeSoftsign(const mojom::Softsign& softsign)
-> base::expected<OperatorOffset, std::string> { … }
// Serializes a WebNN split. It returns a single operator offset, so it
// presumably emits one operator with several outputs - confirm. Returns the
// operator offset or an error message.
auto GraphBuilderTflite::SerializeSplit(const mojom::Split& split)
-> base::expected<OperatorOffset, std::string> { … }
// Serializes tan, which arrives as a mojom::ElementWiseUnary. The return type
// has no error channel.
auto GraphBuilderTflite::SerializeTan(const mojom::ElementWiseUnary& tan)
-> OperatorOffset { … }
// Serializes a WebNN tanh. The return type has no error channel.
auto GraphBuilderTflite::SerializeTanh(const mojom::Tanh& tanh)
-> OperatorOffset { … }
// Serializes a WebNN triangular (presumably by building a mask with
// FillMaskTriangular - confirm). Returns the operator offset or an error
// message.
auto GraphBuilderTflite::SerializeTriangular(
const mojom::Triangular& triangular)
-> base::expected<OperatorOffset, std::string> { … }
// Serializes a WebNN transpose. The return type has no error channel.
auto GraphBuilderTflite::SerializeTranspose(const mojom::Transpose& transpose)
-> OperatorOffset { … }
// Serializes a WebNN where (element-wise select). The return type has no error
// channel.
auto GraphBuilderTflite::SerializeWhere(const mojom::Where& where)
-> OperatorOffset { … }
}