chromium/services/webnn/tflite/graph_builder_tflite.cc

// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "services/webnn/tflite/graph_builder_tflite.h"

#include <cstdint>
#include <numeric>
#include <vector>

#include "base/containers/fixed_flat_set.h"
#include "base/containers/span.h"
#include "base/numerics/checked_math.h"
#include "base/numerics/safe_conversions.h"
#include "base/ranges/algorithm.h"
#include "base/strings/stringprintf.h"
#include "base/types/expected.h"
#include "base/types/expected_macros.h"
#include "services/webnn/public/cpp/context_properties.h"
#include "services/webnn/public/cpp/graph_validation_utils.h"
#include "services/webnn/public/cpp/webnn_errors.h"
#include "services/webnn/public/mojom/webnn_context_provider.mojom.h"
#include "services/webnn/public/mojom/webnn_graph.mojom.h"
#include "services/webnn/webnn_utils.h"
#include "third_party/tflite/src/tensorflow/lite/schema/schema_generated.h"

namespace webnn::tflite {

namespace {

// The version number of the Schema. Ideally all changes will be backward
// compatible. If that ever changes, we must ensure that version is the first
// entry in the new tflite root so that we can see that version is not 1.
//
// NOTE(review): in this view the macro is defined with no replacement text, so
// it expands to nothing; confirm the intended schema version value.
#define TFLITE_SCHEMA_VERSION

// Maps a DataType to a `::tflite::TensorType`. Other `TensorTypeMap`
// specializations may be declared below as needed.
//
// Example: TensorTypeMap<uint32_t>::value -> ::tflite::TensorType_UINT32
//
// The primary template is only declared (never defined) and is constrained by
// `internal::IsSupportedTensorType`, so only the explicit specializations
// below can be instantiated.
// NOTE(review): the specialization bodies are empty in this view, so the
// `value` member used in the example above is not visible here.
template <typename DataType>
  requires internal::IsSupportedTensorType<DataType>
struct TensorTypeMap;

// Specialization for `float`.
template <>
struct TensorTypeMap<float> {};
// Specialization for `int32_t`.
template <>
struct TensorTypeMap<int32_t> {};
// Specialization for `uint32_t`.
template <>
struct TensorTypeMap<uint32_t> {};
// Specialization for `int64_t`.
template <>
struct TensorTypeMap<int64_t> {};

// Useful for converting dimension arrays coming from mojo as uint32 to the
// int32 vectors used by TFLite.
//
// Returns either the converted `std::vector<int32_t>` or an error string (per
// the return type).
// NOTE(review): the body is not visible in this view; presumably the error
// case is a dimension that does not fit in int32_t — TODO confirm.
base::expected<std::vector<int32_t>, std::string> ToSignedDimensions(
    base::span<const uint32_t> input_dimensions) {}

// Converts an `OperandDataType` to a `::tflite::TensorType` (per the
// signature).
// NOTE(review): the per-type mapping is not visible in this view.
::tflite::TensorType OperandDataTypeToTFLite(OperandDataType data_type) {}

// Scoped enum used as the success type of `GetClampRange()`.
// NOTE(review): the enumerators are not visible in this view.
enum class ClampRange {};

// Derives a `ClampRange` from a `mojom::Clamp`, or yields an error string (per
// the return type).
// NOTE(review): the body is not visible in this view; which clamp bounds map
// to which range, and which are rejected, must be confirmed in the full file.
base::expected<ClampRange, std::string> GetClampRange(
    const mojom::Clamp& clamp) {}

// Converts a `mojom::RecurrentNetworkActivation` to a
// `::tflite::BuiltinOperator` (per the signature).
// NOTE(review): the per-activation mapping is not visible in this view.
::tflite::BuiltinOperator GetRecurrentNetworkActivation(
    mojom::RecurrentNetworkActivation activation) {}

// Result type of `CalculateExplicitPaddingForSamePaddingMode()`.
// NOTE(review): the members are not visible in this view.
struct PaddingSizes {};

// Helper to calculate the explicit padding for tflite::Padding_SAME mode, as
// described in
// https://www.tensorflow.org/versions/r2.14/api_docs/python/tf/nn#notes_on_padding_2.
//
// Takes the input size, filter size, stride and dilation (all uint32_t) and a
// flag saying whether the operation is a transposed conv2d. Returns a
// `std::optional<PaddingSizes>`.
// NOTE(review): the body is not visible in this view; presumably
// `std::nullopt` signals that the padding could not be computed — TODO confirm.
std::optional<PaddingSizes> CalculateExplicitPaddingForSamePaddingMode(
    uint32_t input_size,
    uint32_t filter_size,
    uint32_t stride,
    uint32_t dilation,
    bool is_transposed_conv2d) {}

// Success type of `GetTfLitePaddingMode()`.
// NOTE(review): the members are not visible in this view.
struct TfLitePadding {};

// Helper to get tflite padding mode for convolution 2d or pooling 2d.
//
// Takes the WebNN padding, the input and filter sizes, the stride and dilation,
// and a flag saying whether the operation is a transposed conv2d. Returns a
// `TfLitePadding` or an error string (per the return type).
// NOTE(review): the body is not visible in this view, so the rules for
// choosing a padding mode and the error conditions are not documented here.
base::expected<TfLitePadding, std::string> GetTfLitePaddingMode(
    const mojom::Padding2d& padding2d,
    const webnn::Size2d<uint32_t>& input,
    const webnn::Size2d<uint32_t>& filter,
    const mojom::Size2d& stride,
    const mojom::Size2d& dilation,
    bool is_transposed_conv2d) {}

// Sorts the indices of the elements in `axes` by the elements' values and
// returns the sorted index array, which is used for adding a transpose
// operation if needed.
//
// For example, with an input shape of [2, 1, 4, 3] and axes [3, 1, 2], the
// shape of the scale and bias is [3, 1, 4]. The sorted axes are [1, 2, 3], so
// the returned permutation (the sorted indices array) is [1, 2, 0].
std::vector<uint32_t> GetIndexOfSortedValue(base::span<const uint32_t> axes) {}

// An element in row `i` and column `j` of a matrix is in the upper-triangular
// portion if `j >= i + diagonal`. It is in the lower-triangular portion if
// `j <= i + diagonal`.
//
// This function generates an upper-triangular or lower-triangular matrix
// with the given mask value. For example, the matrices below are the upper-
// and lower-triangular [3, 3] tensors with a mask value of 1:
// [ 1, 1, 1                    [ 1, 0, 0
//   0, 1, 1,                     1, 1, 0,
//   0, 0, 1]                     1, 1, 1]
//
// `upper` selects the upper- or lower-triangular form, `diagonal` is the
// offset used in the inequalities above, and `mask` is the mask value. The
// result is returned as a `std::vector<DataType>`.
// NOTE(review): the body is not visible in this view; presumably `dimensions`
// is the shape of the generated tensor — TODO confirm.
template <typename DataType>
std::vector<DataType> FillMaskTriangular(base::span<const int32_t> dimensions,
                                         bool upper,
                                         int32_t diagonal,
                                         DataType mask) {}

}  // namespace

// static
// Takes the context properties and a `mojom::GraphInfo` and yields either a
// `flatbuffers::DetachedBuffer` or an error string (per the return type).
// NOTE(review): the body is not visible in this view; presumably it builds the
// TFLite flatbuffer for the whole graph — TODO confirm.
base::expected<flatbuffers::DetachedBuffer, std::string>
GraphBuilderTflite::CreateAndBuild(ContextProperties context_properties,
                                   const mojom::GraphInfo& graph_info) {}

// static
// Returns a `ContextProperties` value and takes no arguments.
// NOTE(review): the body is not visible in this view; presumably the result
// describes what this TFLite backend supports — TODO confirm.
ContextProperties GraphBuilderTflite::GetContextProperties() {}

// Constructs a builder from the given context properties and graph info.
// NOTE(review): the member initializer list is elided in this view (only `:`
// remains), so which members are initialized cannot be determined here.
GraphBuilderTflite::GraphBuilderTflite(ContextProperties context_properties,
                                       const mojom::GraphInfo& graph_info)
    :{}

// Out-of-line defaulted destructor.
GraphBuilderTflite::~GraphBuilderTflite() = default;

// Processes the operand `operand` identified by `operand_id`. Yields no value
// on success or an error string on failure (per the return type).
// NOTE(review): the body is not visible in this view; presumably it records
// the operand as a TFLite tensor — TODO confirm.
base::expected<void, std::string> GraphBuilderTflite::SerializeOperand(
    uint64_t operand_id,
    const mojom::Operand& operand) {}

// Processes a single `mojom::Operation`. Yields no value on success or an
// error string on failure (per the return type).
// NOTE(review): the body is not visible in this view; presumably it dispatches
// to the operation-specific `Serialize*` member functions — TODO confirm.
base::expected<void, std::string> GraphBuilderTflite::SerializeOperation(
    const mojom::Operation& op) {}

// Takes the ids of the graph's input and output operands and returns a
// `flatbuffers::DetachedBuffer`.
// NOTE(review): the body is not visible in this view; presumably it finalizes
// the model and hands over ownership of the built buffer — TODO confirm.
flatbuffers::DetachedBuffer GraphBuilderTflite::FinishAndTakeFlatBuffer(
    base::span<const uint64_t> input_operands,
    base::span<const uint64_t> output_operands) {}

// Takes constant data as a `mojo_base::BigBuffer` and returns a uint32_t.
// NOTE(review): the body is not visible in this view; presumably the return
// value is the index of the serialized buffer — TODO confirm.
uint32_t GraphBuilderTflite::SerializeBuffer(
    const mojo_base::BigBuffer& constant) {}

// Takes a span of `DataType` values and the tensor dimensions and returns an
// int32_t. `DataType` must satisfy `internal::IsSupportedTensorType`.
// NOTE(review): the body is not visible in this view; presumably it creates a
// tensor backed by `buffer` and returns the tensor index — TODO confirm.
template <typename DataType>
  requires internal::IsSupportedTensorType<DataType>
int32_t GraphBuilderTflite::SerializeTensorWithBuffer(
    base::span<const DataType> buffer,
    base::span<const int32_t> dimensions) {}

// Takes tensor dimensions and a `::tflite::TensorType` and returns an int32_t.
// NOTE(review): the body is not visible in this view; presumably it creates an
// intermediate tensor and returns its index — TODO confirm.
int32_t GraphBuilderTflite::SerializeTemporaryTensor(
    base::span<const int32_t> dimensions,
    ::tflite::TensorType tensor_type) {}

// Takes a `::tflite::BuiltinOperator` code and a version and returns a
// uint32_t.
// NOTE(review): the body is not visible in this view; presumably the result is
// an index into the model's operator-code table — TODO confirm.
uint32_t GraphBuilderTflite::GetOperatorCodeIndex(
    ::tflite::BuiltinOperator code,
    int32_t version) {}

// Returns a const reference to the `mojom::Operand` for `operand_id`. Const
// member function.
// NOTE(review): the body is not visible in this view, so where the operand is
// looked up and how unknown ids are handled cannot be documented here.
const mojom::Operand& GraphBuilderTflite::GetOperand(
    uint64_t operand_id) const {}

// Takes a builtin operator code, one input and one output tensor index, and a
// builtin-options type/offset pair; returns an `OperatorOffset`.
// NOTE(review): the body is not visible in this view; presumably it emits a
// single-input, single-output TFLite operator — TODO confirm.
auto GraphBuilderTflite::SerializeUnaryOperation(
    ::tflite::BuiltinOperator code,
    int32_t input_tensor_index,
    int32_t output_tensor_index,
    ::tflite::BuiltinOptions builtin_options_type,
    flatbuffers::Offset<void> builtin_options) -> OperatorOffset {}

// Takes input and output tensor indices together with their
// `::tflite::TensorType`s and returns an `OperatorOffset`.
// NOTE(review): the body is not visible in this view; presumably it emits a
// TFLite cast from the input type to the output type — TODO confirm.
auto GraphBuilderTflite::SerializeCastOperation(
    int32_t input_tensor_index,
    ::tflite::TensorType input_tensor_type,
    int32_t output_tensor_index,
    ::tflite::TensorType output_tensor_type) -> OperatorOffset {}

// Takes a builtin operator code, left- and right-hand-side tensor indices and
// an output tensor index; returns an `OperatorOffset`.
// NOTE(review): the body is not visible in this view; presumably it emits a
// two-input TFLite operator — TODO confirm.
auto GraphBuilderTflite::SerializeBinaryOperation(
    ::tflite::BuiltinOperator code,
    int32_t lhs_tensor_index,
    int32_t rhs_tensor_index,
    int32_t output_tensor_index) -> OperatorOffset {}

// Takes a list of input tensor indices, an output tensor index and an axis;
// returns an `OperatorOffset`.
// NOTE(review): the body is not visible in this view; presumably it emits a
// TFLite concatenation along `axis` — TODO confirm.
auto GraphBuilderTflite::SerializeConcatOperation(
    base::span<const int32_t> input_tensor_indices,
    int32_t output_tensor_index,
    uint32_t axis) -> OperatorOffset {}

// Takes the tensor indices of operands `a` and `b` and an output tensor index;
// returns an `OperatorOffset`.
// NOTE(review): the body is not visible in this view; presumably it emits a
// TFLite matrix multiplication — TODO confirm which builtin operator is used.
auto GraphBuilderTflite::SerializeMatmulOperation(int32_t a_tensor_index,
                                                  int32_t b_tensor_index,
                                                  int32_t output_tensor_index)
    -> OperatorOffset {}

// Takes the input's dimensions, tensor type and index, an output tensor index,
// and two float coefficients `alpha` and `beta`; returns an `OperatorOffset`.
// NOTE(review): the body is not visible in this view; presumably it computes
// `alpha * input + beta` — TODO confirm.
auto GraphBuilderTflite::SerializeLinearOperation(
    base::span<const int32_t> input_dimensions,
    ::tflite::TensorType input_tensor_type,
    int32_t input_tensor_index,
    int32_t output_tensor_index,
    float alpha,
    float beta) -> OperatorOffset {}

// Takes the input's dimensions, tensor type and index, an output tensor index,
// mean and variance tensor indices, an epsilon, and optional scale and bias
// tensor indices; returns an `OperatorOffset`.
// NOTE(review): the body is not visible in this view; presumably it emits the
// operators that normalize the input with the given mean and variance and then
// apply scale and bias when present — TODO confirm the exact formula.
auto GraphBuilderTflite::SerializeNormalizationOperation(
    base::span<const int32_t> input_dimensions,
    ::tflite::TensorType input_tensor_type,
    int32_t input_tensor_index,
    int32_t output_tensor_index,
    int32_t mean_tensor_index,
    int32_t variance_tensor_index,
    float epsilon,
    std::optional<int32_t> scale_tensor_index,
    std::optional<int32_t> bias_tensor_index) -> OperatorOffset {}

// Takes a builtin operator code, input and output tensor indices, the axes and
// a keep-dimensions flag; returns an `OperatorOffset`.
// NOTE(review): the body is not visible in this view; presumably it emits a
// TFLite reduction over `axes` — TODO confirm.
auto GraphBuilderTflite::SerializeReduceOperation(
    ::tflite::BuiltinOperator operator_code,
    int32_t input_tensor_index,
    int32_t output_tensor_index,
    base::span<const int32_t> axes,
    bool keep_dimensions) -> OperatorOffset {}

// Takes input and output tensor indices and the new shape; returns an
// `OperatorOffset`.
// NOTE(review): the body is not visible in this view; presumably it emits a
// TFLite reshape to `new_shape` — TODO confirm.
auto GraphBuilderTflite::SerializeReshapeOperation(
    int32_t input_tensor_index,
    int32_t output_tensor_index,
    base::span<const int32_t> new_shape) -> OperatorOffset {}

// Takes input and output tensor indices plus the slice starts and sizes.
// Yields an `OperatorOffset` or an error string (per the return type).
// NOTE(review): the body is not visible in this view, so the conditions under
// which an error is produced are not documented here.
auto GraphBuilderTflite::SerializeSliceOperation(
    int32_t input_tensor_index,
    int32_t output_tensor_index,
    base::span<const int32_t> slice_starts,
    base::span<const int32_t> slice_sizes)
    -> base::expected<OperatorOffset, std::string> {}

// Takes input and output tensor indices and a permutation; returns an
// `OperatorOffset`.
// NOTE(review): the body is not visible in this view; presumably it emits a
// TFLite transpose using `permutation` — TODO confirm.
auto GraphBuilderTflite::SerializeTransposeOperation(
    int32_t input_tensor_index,
    int32_t output_tensor_index,
    base::span<const uint32_t> permutation) -> OperatorOffset {}

// Takes the input operand, its tensor index and the paddings. Yields an
// int32_t or an error string (per the return type).
// NOTE(review): the body is not visible in this view; presumably it inserts a
// pad operator and returns the index of the padded tensor — TODO confirm. The
// layout of `paddings` is also not visible here.
auto GraphBuilderTflite::InsertPadOperation(const mojom::Operand& input_operand,
                                            int32_t input_tensor_index,
                                            base::span<const uint32_t> paddings)
    -> base::expected<int32_t, std::string> {}

// Takes the input's dimensions, tensor type and index, and a permutation;
// returns an int32_t.
// NOTE(review): the body is not visible in this view; presumably it inserts a
// transpose operator and returns the index of the transposed tensor — TODO
// confirm.
int32_t GraphBuilderTflite::InsertTransposeOperation(
    base::span<const int32_t> input_dimensions,
    ::tflite::TensorType input_tensor_type,
    int32_t input_tensor_index,
    base::span<const uint32_t> permutation) {}

// Takes the input's dimensions, tensor type and index, a float exponent and a
// float multiplier; returns an int32_t.
// NOTE(review): the body is not visible in this view; presumably it emits a
// pow followed by a mul (`mul_alpha * input ^ pow_exponent`) and returns the
// index of the result tensor — TODO confirm.
int32_t GraphBuilderTflite::SerializeSubGraphPowMul(
    base::span<const int32_t> input_dimensions,
    ::tflite::TensorType input_tensor_type,
    int32_t input_tensor_index,
    float pow_exponent,
    float mul_alpha) {}

// Takes a `mojom::ArgMinMax` and yields an `OperatorOffset` or an error string
// (per the return type).
// NOTE(review): the body is not visible in this view; presumably it serializes
// the WebNN argMin/argMax operation as a TFLite operator — TODO confirm.
auto GraphBuilderTflite::SerializeArgMinMax(const mojom::ArgMinMax& arg_min_max)
    -> base::expected<OperatorOffset, std::string> {}

// Takes a `mojom::BatchNormalization` and yields an `OperatorOffset` or an
// error string (per the return type).
// NOTE(review): the body is not visible in this view; presumably it serializes
// the WebNN batchNormalization operation into TFLite operators — TODO confirm.
auto GraphBuilderTflite::SerializeBatchNormalization(
    const mojom::BatchNormalization& batch_normalization)
    -> base::expected<OperatorOffset, std::string> {}

// Takes a `mojom::Clamp` and yields an `OperatorOffset` or an error string
// (per the return type).
// NOTE(review): the body is not visible in this view; presumably it serializes
// the WebNN clamp operation as a TFLite operator — TODO confirm.
auto GraphBuilderTflite::SerializeClamp(const mojom::Clamp& clamp)
    -> base::expected<OperatorOffset, std::string> {}

// Takes a `mojom::Concat` and returns an `OperatorOffset`; the signature has
// no error channel.
// NOTE(review): the body is not visible in this view; presumably it serializes
// the WebNN concat operation as a TFLite operator — TODO confirm.
auto GraphBuilderTflite::SerializeConcat(const mojom::Concat& concat)
    -> OperatorOffset {}

// Takes a `mojom::Conv2d` and yields an `OperatorOffset` or an error string
// (per the return type).
// NOTE(review): the body is not visible in this view; presumably it serializes
// the WebNN conv2d operation as a TFLite operator — TODO confirm.
auto GraphBuilderTflite::SerializeConv2d(const mojom::Conv2d& conv2d)
    -> base::expected<OperatorOffset, std::string> {}

// Takes a `mojom::ElementWiseBinary` and returns an `OperatorOffset`; the
// signature has no error channel.
// NOTE(review): the body is not visible in this view; presumably it serializes
// a WebNN element-wise binary operation as a TFLite operator — TODO confirm.
auto GraphBuilderTflite::SerializeElementWiseBinary(
    const mojom::ElementWiseBinary& op) -> OperatorOffset {}

// Takes a `mojom::ElementWiseUnary` and yields an `OperatorOffset` or an error
// string (per the return type).
// NOTE(review): the body is not visible in this view; presumably it serializes
// a WebNN element-wise unary operation as a TFLite operator — TODO confirm.
auto GraphBuilderTflite::SerializeElementWiseUnary(
    const mojom::ElementWiseUnary& op)
    -> base::expected<OperatorOffset, std::string> {}

// Takes a `mojom::Elu` and yields an `OperatorOffset` or an error string (per
// the return type).
// NOTE(review): the body is not visible in this view; presumably it serializes
// the WebNN elu operation as a TFLite operator — TODO confirm.
auto GraphBuilderTflite::SerializeElu(const mojom::Elu& elu)
    -> base::expected<OperatorOffset, std::string> {}

// Takes a `mojom::ElementWiseUnary` (named `erf`) and yields an
// `OperatorOffset` or an error string (per the return type).
// NOTE(review): the body is not visible in this view; presumably it serializes
// the WebNN erf operation into TFLite operators — TODO confirm.
auto GraphBuilderTflite::SerializeErf(const mojom::ElementWiseUnary& erf)
    -> base::expected<OperatorOffset, std::string> {}

// Takes a `mojom::Expand` and returns an `OperatorOffset`; the signature has
// no error channel.
// NOTE(review): the body is not visible in this view; presumably it serializes
// the WebNN expand operation as a TFLite operator — TODO confirm.
auto GraphBuilderTflite::SerializeExpand(const mojom::Expand& expand)
    -> OperatorOffset {}

// Takes a `mojom::Gather` and yields an `OperatorOffset` or an error string
// (per the return type).
// NOTE(review): the body is not visible in this view; presumably it serializes
// the WebNN gather operation as a TFLite operator — TODO confirm.
auto GraphBuilderTflite::SerializeGather(const mojom::Gather& gather)
    -> base::expected<OperatorOffset, std::string> {}

// Takes a `mojom::Gelu` and yields an `OperatorOffset` or an error string (per
// the return type).
// NOTE(review): the body is not visible in this view; presumably it serializes
// the WebNN gelu operation as a TFLite operator — TODO confirm.
auto GraphBuilderTflite::SerializeGelu(const mojom::Gelu& gelu)
    -> base::expected<OperatorOffset, std::string> {}

// Takes a `mojom::Gemm` and yields an `OperatorOffset` or an error string (per
// the return type).
// NOTE(review): the body is not visible in this view; presumably it serializes
// the WebNN gemm operation into TFLite operators — TODO confirm.
auto GraphBuilderTflite::SerializeGemm(const mojom::Gemm& gemm)
    -> base::expected<OperatorOffset, std::string> {}

// Serialize a sub graph (input * weight + bias) for gru cell.
//
//     [input]   [weight]
//         \        /
//           Matmul   [bias]
//             \        /
//                 add
//                  |
//              [output]
//
// `bias_tensor_index` is optional. Returns an int32_t.
// NOTE(review): the body is not visible in this view; presumably the add is
// skipped when no bias is given and the return value is the index of the
// output tensor — TODO confirm.
int32_t GraphBuilderTflite::SerializeSubGraphMatmulAdd(
    base::span<const int32_t> input_dimensions,
    ::tflite::TensorType input_tensor_type,
    int32_t input_tensor_index,
    int32_t weight_tensor_index,
    std::optional<int32_t> bias_tensor_index) {}

// Serialize a sub graph (slice appending transpose operation) for gru cell.
//
//     [input]
//        |
//      slice
//        |
//     transpose
//        |
//     [output]
//
// Yields an int32_t or an error string (per the return type).
// NOTE(review): the body is not visible in this view; presumably the int32_t
// is the index of the output tensor — TODO confirm. The transpose permutation
// is also not visible here.
auto GraphBuilderTflite::SerializeSubGraphSliceTranspose(
    ::tflite::TensorType input_tensor_type,
    int32_t input_tensor_index,
    base::span<const int32_t> slice_starts,
    base::span<const int32_t> slice_sizes)
    -> base::expected<int32_t, std::string> {}

// Takes a `GruCellOperation`, a `GruGateType` and an optional reset-gate
// tensor index. Yields an int32_t or an error string (per the return type).
// NOTE(review): the body is not visible in this view; presumably it serializes
// the sub graph for the selected gate and returns the index of the gate's
// output tensor — TODO confirm, including when `reset_gate_tensor_index` must
// be provided.
auto GraphBuilderTflite::SerializeGruGate(
    const GruCellOperation& gru_cell,
    GruGateType type,
    std::optional<int32_t> reset_gate_tensor_index)
    -> base::expected<int32_t, std::string> {}

// Constructs a `RecurrentNetworkBase` from the input's dimensions, tensor type
// and index, the weight and recurrent-weight tensor indices, optional bias and
// recurrent-bias tensor indices, the hidden-state tensor index, the hidden
// size and the activations.
// NOTE(review): the member initializer list is elided in this view (only `:`
// remains), so how the arguments are stored cannot be determined here.
GraphBuilderTflite::RecurrentNetworkBase::RecurrentNetworkBase(
    base::span<const int32_t> input_dimensions,
    ::tflite::TensorType input_tensor_type,
    int32_t input_tensor_index,
    int32_t weight_tensor_index,
    int32_t recurrent_weight_tensor_index,
    std::optional<int32_t> bias_tensor_index,
    std::optional<int32_t> recurrent_bias_tensor_index,
    int32_t hidden_state_tensor_index,
    int32_t hidden_size,
    base::span<const mojom::RecurrentNetworkActivation> activations)
    :{}

// Out-of-line defaulted destructor.
GraphBuilderTflite::RecurrentNetworkBase::~RecurrentNetworkBase() = default;

// Constructs a `GruCellOperation` from the input's dimensions, tensor type and
// index, the output tensor index, the weight and recurrent-weight tensor
// indices, optional bias and recurrent-bias tensor indices, the hidden-state
// tensor index, the hidden size, the `reset_after` flag, the weight layout and
// the activations.
// NOTE(review): the member initializer list is elided in this view (only `:`
// remains), so how the arguments are stored cannot be determined here.
GraphBuilderTflite::GruCellOperation::GruCellOperation(
    base::span<const int32_t> input_dimensions,
    ::tflite::TensorType input_tensor_type,
    int32_t input_tensor_index,
    int32_t output_tensor_index,
    int32_t weight_tensor_index,
    int32_t recurrent_weight_tensor_index,
    std::optional<int32_t> bias_tensor_index,
    std::optional<int32_t> recurrent_bias_tensor_index,
    int32_t hidden_state_tensor_index,
    int32_t hidden_size,
    bool reset_after,
    mojom::GruWeightLayout layout,
    base::span<const mojom::RecurrentNetworkActivation> activations)
    :{}

// Out-of-line defaulted destructor.
GraphBuilderTflite::GruCellOperation::~GruCellOperation() = default;

// Takes a `mojom::GruCell` and yields an `OperatorOffset` or an error string
// (per the return type).
// NOTE(review): the body is not visible in this view; presumably it serializes
// the WebNN gruCell operation into TFLite operators — TODO confirm.
auto GraphBuilderTflite::SerializeGruCell(const mojom::GruCell& gru_cell)
    -> base::expected<OperatorOffset, std::string> {}

// Takes a `GruCellOperation` and yields an `OperatorOffset` or an error string
// (per the return type).
// NOTE(review): the body is not visible in this view; presumably it emits the
// TFLite operators for a single GRU cell step — TODO confirm.
auto GraphBuilderTflite::SerializeGruCellOperation(
    const GruCellOperation& gru_cell)
    -> base::expected<OperatorOffset, std::string> {}

// Constructs an `LstmCellOperation` from the input's dimensions, tensor type
// and index, the output tensor indices, the weight and recurrent-weight tensor
// indices, optional bias and recurrent-bias tensor indices, the hidden-state
// tensor index, the hidden size, the cell-state tensor index, an optional
// peephole-weight tensor index, the weight layout and the activations.
// NOTE(review): the member initializer list is elided in this view (only `:`
// remains), so how the arguments are stored cannot be determined here.
GraphBuilderTflite::LstmCellOperation::LstmCellOperation(
    base::span<const int32_t> input_dimensions,
    ::tflite::TensorType input_tensor_type,
    int32_t input_tensor_index,
    base::span<const int32_t> output_tensor_indices,
    int32_t weight_tensor_index,
    int32_t recurrent_weight_tensor_index,
    std::optional<int32_t> bias_tensor_index,
    std::optional<int32_t> recurrent_bias_tensor_index,
    int32_t hidden_state_tensor_index,
    int32_t hidden_size,
    int32_t cell_state_tensor_index,
    std::optional<int32_t> peephole_weight_tensor_index,
    mojom::LstmWeightLayout layout,
    base::span<const mojom::RecurrentNetworkActivation> activations)
    :{}

// Out-of-line defaulted destructor.
GraphBuilderTflite::LstmCellOperation::~LstmCellOperation() = default;

// Takes an `LstmCellOperation` and an `LstmGateType`. Yields an int32_t or an
// error string (per the return type).
// NOTE(review): the body is not visible in this view; presumably it serializes
// the sub graph for the selected gate and returns the index of the gate's
// output tensor — TODO confirm.
base::expected<int32_t, std::string> GraphBuilderTflite::SerializeLstmGate(
    const LstmCellOperation& lstm_cell,
    LstmGateType type) {}

// Takes an `LstmCellOperation` and yields an `OperatorOffset` or an error
// string (per the return type).
// NOTE(review): the body is not visible in this view; presumably it emits the
// TFLite operators for a single LSTM cell step — TODO confirm.
auto GraphBuilderTflite::SerializeLstmCellOperation(
    const LstmCellOperation& lstm_cell)
    -> base::expected<OperatorOffset, std::string> {}

// Serialize a sub graph (slice appending squeeze operation) for gru.
//
//     [input]
//        |
//      slice
//        |
//     squeeze
//        |
//     [output]
//
// Yields an int32_t or an error string (per the return type).
// NOTE(review): the body is not visible in this view; presumably
// `squeeze_axis` is the dimension removed by the squeeze and the int32_t is
// the index of the output tensor — TODO confirm.
auto GraphBuilderTflite::SerializeSubGraphSliceSqueeze(
    ::tflite::TensorType input_tensor_type,
    int32_t input_tensor_index,
    base::span<const int32_t> slice_starts,
    base::span<const int32_t> slice_sizes,
    int32_t squeeze_axis) -> base::expected<int32_t, std::string> {}

// `RecurrentNetworkType` must be `mojom::Gru` or `mojom::Lstm`.
//
// Takes the recurrent network description and yields an `OperatorOffset` or an
// error string (per the return type).
// NOTE(review): the body is not visible in this view; presumably it serializes
// the whole recurrent network by emitting its cell operations — TODO confirm.
template <typename RecurrentNetworkType>
auto GraphBuilderTflite::SerializeRecurrentNetwork(
    const RecurrentNetworkType& recurrent_network)
    -> base::expected<OperatorOffset, std::string> {}

// Takes a `mojom::HardSigmoid` and returns an `OperatorOffset`; the signature
// has no error channel.
// NOTE(review): the body is not visible in this view; presumably it serializes
// the WebNN hardSigmoid operation into TFLite operators — TODO confirm.
auto GraphBuilderTflite::SerializeHardSigmoid(
    const mojom::HardSigmoid& hard_sigmoid) -> OperatorOffset {}

// Takes a `mojom::HardSwish` and returns an `OperatorOffset`; the signature
// has no error channel.
// NOTE(review): the body is not visible in this view; presumably it serializes
// the WebNN hardSwish operation as a TFLite operator — TODO confirm.
auto GraphBuilderTflite::SerializeHardSwish(const mojom::HardSwish& hard_swish)
    -> OperatorOffset {}

// Takes the input's dimensions, tensor type and index, and the spatial
// dimensions; returns a tuple of two int32_t values.
// NOTE(review): the body is not visible in this view; presumably the tuple
// holds the mean and variance tensor indices, in that order, computed over
// `spatial_dimensions` — TODO confirm.
std::tuple<int32_t, int32_t>
GraphBuilderTflite::ComputeMeanAndVarianceForNormalization(
    base::span<const int32_t> input_dimensions,
    ::tflite::TensorType input_tensor_type,
    int32_t input_tensor_index,
    base::span<const int32_t> spatial_dimensions) {}

// Takes the input dimensions, the operand id of a scale or bias operand, and
// the axes; returns an int32_t.
// NOTE(review): the body is not visible in this view; presumably it transposes
// and reshapes the scale/bias tensor for layer normalization and returns the
// index of the resulting tensor — TODO confirm.
int32_t GraphBuilderTflite::TransposeAndReshapeLayerNormalizationScaleBias(
    base::span<const int32_t> input_dimensions,
    uint64_t scale_or_bias_operand_id,
    base::span<const uint32_t> axes) {}

// Takes a `mojom::InstanceNormalization` and yields an `OperatorOffset` or an
// error string (per the return type).
// NOTE(review): the body is not visible in this view; presumably it serializes
// the WebNN instanceNormalization operation into TFLite operators — TODO
// confirm.
auto GraphBuilderTflite::SerializeInstanceNormalization(
    const mojom::InstanceNormalization& instance_normalization)
    -> base::expected<OperatorOffset, std::string> {}

// Takes a `mojom::LayerNormalization` and yields an `OperatorOffset` or an
// error string (per the return type).
// NOTE(review): the body is not visible in this view; presumably it serializes
// the WebNN layerNormalization operation into TFLite operators — TODO confirm.
auto GraphBuilderTflite::SerializeLayerNormalization(
    const mojom::LayerNormalization& layer_normalization)
    -> base::expected<OperatorOffset, std::string> {}

// Takes a `mojom::LeakyRelu` and returns an `OperatorOffset`; the signature
// has no error channel.
// NOTE(review): the body is not visible in this view; presumably it serializes
// the WebNN leakyRelu operation as a TFLite operator — TODO confirm.
auto GraphBuilderTflite::SerializeLeakyRelu(const mojom::LeakyRelu& leaky_relu)
    -> OperatorOffset {}

// Takes a `mojom::Linear` and returns an `OperatorOffset`; the signature has
// no error channel.
// NOTE(review): the body is not visible in this view; presumably it serializes
// the WebNN linear operation into TFLite operators — TODO confirm.
auto GraphBuilderTflite::SerializeLinear(const mojom::Linear& linear)
    -> OperatorOffset {}

// Takes a `mojom::ElementWiseUnary` (named `logical_not`) and returns an
// `OperatorOffset`; the signature has no error channel.
// NOTE(review): the body is not visible in this view; presumably it serializes
// the WebNN logicalNot operation into TFLite operators — TODO confirm.
auto GraphBuilderTflite::SerializeLogicalNot(
    const mojom::ElementWiseUnary& logical_not) -> OperatorOffset {}

// Takes a `mojom::LstmCell` and yields an `OperatorOffset` or an error string
// (per the return type).
// NOTE(review): the body is not visible in this view; presumably it serializes
// the WebNN lstmCell operation into TFLite operators — TODO confirm.
auto GraphBuilderTflite::SerializeLstmCell(const mojom::LstmCell& lstm_cell)
    -> base::expected<OperatorOffset, std::string> {}

// Takes an optional state operand id and the state dimensions; returns an
// int32_t.
// NOTE(review): the body is not visible in this view; presumably it returns
// the tensor index of the initial state, creating a default one when no
// operand id is given — TODO confirm.
int32_t GraphBuilderTflite::GetInitialHiddenAndCellState(
    std::optional<uint64_t> state_operand_id,
    base::span<const int32_t> state_dimensions) {}

// Takes the input's tensor type and index, a new shape, an optional tensor
// index to concatenate with, and the concat output shape; returns an int32_t.
// NOTE(review): the body is not visible in this view; presumably it reshapes
// the state tensor, concatenates it with `concat_input_tensor_index` when that
// is provided, and returns the index of the resulting tensor — TODO confirm.
int32_t GraphBuilderTflite::ReshapeHiddenAndCellState(
    ::tflite::TensorType input_tensor_type,
    int32_t input_tensor_index,
    base::span<const int32_t> new_shape,
    std::optional<int32_t> concat_input_tensor_index,
    base::span<const int32_t> concat_output_shape) {}

// Takes a `mojom::Matmul` and returns an `OperatorOffset`; the signature has
// no error channel.
// NOTE(review): the body is not visible in this view; presumably it serializes
// the WebNN matmul operation as a TFLite operator — TODO confirm.
auto GraphBuilderTflite::SerializeMatmul(const mojom::Matmul& matmul)
    -> OperatorOffset {}

// Takes a `mojom::Pad` and yields an `OperatorOffset` or an error string (per
// the return type).
// NOTE(review): the body is not visible in this view; presumably it serializes
// the WebNN pad operation as a TFLite operator — TODO confirm.
auto GraphBuilderTflite::SerializePad(const mojom::Pad& pad)
    -> base::expected<OperatorOffset, std::string> {}

// Takes a `mojom::Pool2d` and yields an `OperatorOffset` or an error string
// (per the return type).
// NOTE(review): the body is not visible in this view; presumably it serializes
// the WebNN pool2d operation as a TFLite operator — TODO confirm.
auto GraphBuilderTflite::SerializePool2d(const mojom::Pool2d& pool2d)
    -> base::expected<OperatorOffset, std::string> {}

// Takes a `mojom::Prelu` and yields an `OperatorOffset` or an error string
// (per the return type).
// NOTE(review): the body is not visible in this view; presumably it serializes
// the WebNN prelu operation as a TFLite operator — TODO confirm.
auto GraphBuilderTflite::SerializePrelu(const mojom::Prelu& prelu)
    -> base::expected<OperatorOffset, std::string> {}

// Takes a `mojom::ElementWiseUnary` (named `reciprocal`) and yields an
// `OperatorOffset` or an error string (per the return type).
// NOTE(review): the body is not visible in this view; presumably it serializes
// the WebNN reciprocal operation into TFLite operators — TODO confirm.
auto GraphBuilderTflite::SerializeReciprocal(
    const mojom::ElementWiseUnary& reciprocal)
    -> base::expected<OperatorOffset, std::string> {}

// Takes a `mojom::Reduce` and yields an `OperatorOffset` or an error string
// (per the return type).
// NOTE(review): the body is not visible in this view; presumably it serializes
// a WebNN reduce operation into TFLite operators — TODO confirm.
auto GraphBuilderTflite::SerializeReduce(const mojom::Reduce& reduce)
    -> base::expected<OperatorOffset, std::string> {}

// Takes a `mojom::Reduce` and an output tensor index, and yields an
// `OperatorOffset` or an error string (per the return type).
// NOTE(review): the body is not visible in this view; presumably it serializes
// the WebNN reduceSumSquare operation into TFLite operators — TODO confirm.
auto GraphBuilderTflite::SerializeReduceSumSquare(const mojom::Reduce& reduce,
                                                  int32_t output_tensor_index)
    -> base::expected<OperatorOffset, std::string> {}

// Takes a `mojom::Relu` and returns an `OperatorOffset`; the signature has no
// error channel.
// NOTE(review): the body is not visible in this view; presumably it serializes
// the WebNN relu operation as a TFLite operator — TODO confirm.
auto GraphBuilderTflite::SerializeRelu(const mojom::Relu& relu)
    -> OperatorOffset {}

// Takes a `mojom::Resample2d` and yields an `OperatorOffset` or an error
// string (per the return type).
// NOTE(review): the body is not visible in this view; presumably it serializes
// the WebNN resample2d operation as a TFLite operator — TODO confirm.
auto GraphBuilderTflite::SerializeResample2d(
    const mojom::Resample2d& resample2d)
    -> base::expected<OperatorOffset, std::string> {}

// Takes the ids of an input operand and an output operand, and yields an
// `OperatorOffset` or an error string (per the return type).
// NOTE(review): the body is not visible in this view; presumably it serializes
// a reshape of the input operand to the output operand's shape — TODO confirm.
auto GraphBuilderTflite::SerializeReshape(uint64_t input_operand_id,
                                          uint64_t output_operand_id)
    -> base::expected<OperatorOffset, std::string> {}

// Takes a `mojom::Sigmoid` and returns an `OperatorOffset`; the signature has
// no error channel.
// NOTE(review): the body is not visible in this view; presumably it serializes
// the WebNN sigmoid operation as a TFLite operator — TODO confirm.
auto GraphBuilderTflite::SerializeSigmoid(const mojom::Sigmoid& sigmoid)
    -> OperatorOffset {}

// Takes a `mojom::Slice` and yields an `OperatorOffset` or an error string
// (per the return type).
// NOTE(review): the body is not visible in this view; presumably it serializes
// the WebNN slice operation as a TFLite operator — TODO confirm.
auto GraphBuilderTflite::SerializeSlice(const mojom::Slice& slice)
    -> base::expected<OperatorOffset, std::string> {}

// Takes a `mojom::Softmax` and returns an `OperatorOffset`; the signature has
// no error channel.
// NOTE(review): the body is not visible in this view; presumably it serializes
// the WebNN softmax operation as a TFLite operator — TODO confirm.
auto GraphBuilderTflite::SerializeSoftmax(const mojom::Softmax& softmax)
    -> OperatorOffset {}

// Takes a `mojom::Softplus` and yields an `OperatorOffset` or an error string
// (per the return type).
// NOTE(review): the body is not visible in this view; presumably it serializes
// the WebNN softplus operation into TFLite operators — TODO confirm.
auto GraphBuilderTflite::SerializeSoftplus(const mojom::Softplus& softplus)
    -> base::expected<OperatorOffset, std::string> {}

// Takes a `mojom::Softsign` and yields an `OperatorOffset` or an error string
// (per the return type).
// NOTE(review): the body is not visible in this view; presumably it serializes
// the WebNN softsign operation into TFLite operators — TODO confirm.
auto GraphBuilderTflite::SerializeSoftsign(const mojom::Softsign& softsign)
    -> base::expected<OperatorOffset, std::string> {}

// Takes a `mojom::Split` and yields an `OperatorOffset` or an error string
// (per the return type).
// NOTE(review): the body is not visible in this view; presumably it serializes
// the WebNN split operation as a TFLite operator — TODO confirm.
auto GraphBuilderTflite::SerializeSplit(const mojom::Split& split)
    -> base::expected<OperatorOffset, std::string> {}

// Takes a `mojom::ElementWiseUnary` (named `tan`) and returns an
// `OperatorOffset`; the signature has no error channel.
// NOTE(review): the body is not visible in this view; presumably it serializes
// the WebNN tan operation into TFLite operators — TODO confirm.
auto GraphBuilderTflite::SerializeTan(const mojom::ElementWiseUnary& tan)
    -> OperatorOffset {}

// Takes a `mojom::Tanh` and returns an `OperatorOffset`; the signature has no
// error channel.
// NOTE(review): the body is not visible in this view; presumably it serializes
// the WebNN tanh operation as a TFLite operator — TODO confirm.
auto GraphBuilderTflite::SerializeTanh(const mojom::Tanh& tanh)
    -> OperatorOffset {}

// Takes a `mojom::Triangular` and yields an `OperatorOffset` or an error
// string (per the return type).
// NOTE(review): the body is not visible in this view; presumably it serializes
// the WebNN triangular operation into TFLite operators — TODO confirm.
auto GraphBuilderTflite::SerializeTriangular(
    const mojom::Triangular& triangular)
    -> base::expected<OperatorOffset, std::string> {}

// Takes a `mojom::Transpose` and returns an `OperatorOffset`; the signature
// has no error channel.
// NOTE(review): the body is not visible in this view; presumably it serializes
// the WebNN transpose operation as a TFLite operator — TODO confirm.
auto GraphBuilderTflite::SerializeTranspose(const mojom::Transpose& transpose)
    -> OperatorOffset {}

// Takes a `mojom::Where` and returns an `OperatorOffset`; the signature has no
// error channel.
// NOTE(review): the body is not visible in this view; presumably it serializes
// the WebNN where operation as a TFLite operator — TODO confirm.
auto GraphBuilderTflite::SerializeWhere(const mojom::Where& where)
    -> OperatorOffset {}

}  // namespace webnn::tflite