chromium/third_party/mediapipe/src/mediapipe/tasks/cc/genai/inference/proto/transformer_params.proto

// Copyright 2023 The ODML Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto3";

package odml.infra.proto;

option java_package = "com.google.odml.infra.proto";
option java_outer_classname = "TransformerParametersProto";

// The parameters of the transformer (https://arxiv.org/pdf/1706.03762.pdf).
message TransformerParameters {
  // Batch size of tensors.
  int32 batch_size = 1;

  // TODO: b/319312256 - Deprecate parameter.
  // Maximum sequence length of the input/output tensor.
  int32 max_seq_length = 2;

  // Embedding dimension (or model dimension), `d_model` in the paper.
  // `d_k` == `d_v` == `d_model`/`h`.
  int32 embedding_dim = 3;

  // Hidden dimension used in the feedforward layer, `d_ff` in the paper.
  int32 hidden_dimension = 4;

  // Head dimension, `d_k` or `d_v` in the paper.
  int32 head_dimension = 5;

  // Number of heads, `h` in the paper.
  int32 num_heads = 6;

  // Number of stacked transformers, `N` in the paper.
  int32 num_stacks = 7;

  // Deprecated: bool use_mqa. Use num_kv_heads below.
  reserved 8;

  // Number of key/value heads. 0 means Multi-Head-Attention (MHA): key and
  // value have the same number of heads as the query; 1 means
  // Multi-Query-Attention (MQA): key and value have a single head; otherwise,
  // this specifies the number of heads for key and value, and
  // Grouped-Query-Attention (GQA) is used. See
  // https://arxiv.org/pdf/2305.13245.pdf for details.
  int32 num_kv_heads = 9;
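
  // Illustrative example (hypothetical values, not part of the schema): with
  // num_heads = 8 and num_kv_heads = 2, each key/value head is shared by 4
  // query heads (GQA); num_kv_heads = 8 would be equivalent to MHA, and
  // num_kv_heads = 1 to MQA.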

  // Supported attention mask types.
  enum AttentionMaskType {
    UNSPECIFIED = 0;
    CAUSAL = 1;
    PREFIX = 2;
    BIDIRECTIONAL = 3;
  }
  // Deprecated, use SelfAttentionParameters.
  reserved 10;

  enum Activation {
    ACTIVATION_UNSPECIFIED = 0;
    // GELU stands for Gaussian Error Linear Unit, see
    // https://arxiv.org/pdf/1606.08415.pdf for details.
    GELU = 1;
    // SILU stands for Sigmoid-Weighted Linear Unit, see
    // https://arxiv.org/pdf/1702.03118v3.pdf for details.
    SILU = 2;
    // RELU stands for Rectified Linear Unit, see
    // https://dl.acm.org/doi/10.5555/3104322.3104425 for details.
    RELU = 3;
  }

  enum Norm {
    NORM_UNSPECIFIED = 0;
    // No normalization operation will be performed.
    NO_NORM = 1;
    // RMSNORM stands for Root Mean Square Layer Normalization, see
    // https://arxiv.org/pdf/1910.07467.pdf for details.
    RMS_NORM = 2;
    // LAYERNORM stands for Layer Normalization, see
    // https://arxiv.org/pdf/1607.06450v1.pdf for details.
    LAYER_NORM = 3;
  }

  message FeedForwardParameters {
    // If `no_bias` is true, the fully connected layer degrades to a plain
    // matrix multiply (no bias term is added).
    bool no_bias = 1;
    Activation activation = 2;
    // Normalization before the dense layer.
    Norm pre_norm = 3;
    // Normalization after the dense layer.
    Norm post_norm = 4;
  }

  FeedForwardParameters feed_forward_parameters = 11;
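
  // Sketch of how these feed-forward parameters are commonly interpreted (an
  // assumption about the runtime, not guaranteed by this schema; residual
  // connections are omitted):
  //   h   = pre_norm(x)
  //   y   = W2 * activation(W1 * h + b1) + b2   (b1, b2 dropped if no_bias)
  //   out = post_norm(y)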

  message FinalProjectParameters {
    // If `no_bias` is true, the fully connected layer degrades to a plain
    // matrix multiply (no bias term is added).
    bool no_bias = 1;

    // The value of the soft cap (tanh) applied before calling the final
    // projection layer. A value <= 0 means no cap is applied.
    float soft_cap_value = 2;
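
    // A typical tanh soft cap (an assumed formulation, not dictated by this
    // schema): capped = soft_cap_value * tanh(x / soft_cap_value).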
  }

  FinalProjectParameters final_project_parameters = 12;

  // Normalization before the transformer block.
  Norm pre_norm = 13;
  // Normalization after the transformer block.
  Norm post_norm = 14;
  // Normalization applied to the output of the final transformer block.
  Norm final_norm = 15;

  enum AttentionScaleType {
    SCALE_TYPE_UNSPECIFIED = 0;

    // Per-dimension scale: the query is scaled by log_2(1 + exp(w)) /
    // sqrt(head_dim), where w is a static weight.
    SCALE_TYPE_PER_DIM_SCALE = 1;

    // Query is scaled by 1/sqrt(head_dim).
    SCALE_TYPE_INV_SQRT_HEAD_DIM = 2;

    // Query is scaled by 1/sqrt(model_dim/num_heads). Note that
    // model_dim/num_heads is not always equal to head_dim.
    SCALE_TYPE_INV_SQRT_D_MODEL_DIV_NUM_HEADS = 3;
  }
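
  // Illustrative example (hypothetical values): with embedding_dim = 2048,
  // num_heads = 16, and head_dimension = 256, SCALE_TYPE_INV_SQRT_HEAD_DIM
  // scales the query by 1/sqrt(256), while
  // SCALE_TYPE_INV_SQRT_D_MODEL_DIV_NUM_HEADS scales it by
  // 1/sqrt(2048 / 16) = 1/sqrt(128).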

  message SelfAttentionParameters {
    // If true, no bias term is used in the Q, K, and V projections.
    bool qkv_no_bias = 1;
    // If true, no bias term is used in the post-attention projection.
    bool post_proj_no_bias = 2;

    AttentionMaskType attention_mask_type = 3;

    // The value of the soft cap (tanh) applied before the attention softmax.
    // A value <= 0 means no cap is applied.
    float soft_cap_value = 4;

    // If specified, the inference pipeline uses the given scale type.
    // Otherwise, SCALE_TYPE_PER_DIM_SCALE is used by default for
    // Multi-Query-Attention, and SCALE_TYPE_INV_SQRT_HEAD_DIM is used by
    // default for Multi-Head-Attention.
    optional AttentionScaleType attention_scale_type = 5;

    reserved 6;
  }

  SelfAttentionParameters self_attention_parameters = 16;

  reserved 17;
  // Whether to skip absolute positional embeddings. If false, absolute
  // positional embeddings are applied to the token embeddings before the
  // attention.
  bool skip_absolute_positional_embeddings = 18;

  reserved 19, 20, 21, 23, 24;

  // Audio parameters.
  // Describes where to interleave residual adapters with transformer layers.
  enum WhereToInterleave {
    INTERLEAVE_UNSPECIFIED = 0;
    // Add a residual adapter after every transformer layer.
    ALL = 1;
    // Add a residual adapter after every 4th transformer layer, starting with
    // index 0 (i.e. after layers 0, 4, 8, ...).
    EVERY_OTHER_4 = 2;
  }

  // Holds information on audio residual adapters.
  message ResidualAdapterParameters {
    WhereToInterleave where_to_interleave = 1;
    // `bottleneck_dimension` is analogous to `hidden_dimension` in the
    // feed-forward layer.
    int32 bottleneck_dimension = 2;
  }

  optional ResidualAdapterParameters residual_adapter_parameters = 22;
}
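
// Illustrative configuration sketch in text format (hypothetical values for a
// small decoder-only model; not taken from any shipped model):
//
//   batch_size: 1
//   embedding_dim: 2048
//   hidden_dimension: 16384
//   head_dimension: 256
//   num_heads: 8
//   num_kv_heads: 1  # Multi-Query-Attention
//   num_stacks: 18
//   feed_forward_parameters { no_bias: true activation: GELU pre_norm: RMS_NORM }
//   final_project_parameters { no_bias: true soft_cap_value: 0.0 }
//   final_norm: RMS_NORM
//   self_attention_parameters {
//     qkv_no_bias: true
//     post_proj_no_bias: true
//     attention_mask_type: CAUSAL
//   }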