// Copyright 2024 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/tasks/cc/genai/inference/utils/llm_utils/well_known_models.h"
#include "mediapipe/tasks/cc/genai/inference/proto/llm_params.pb.h"
#include "mediapipe/tasks/cc/genai/inference/proto/transformer_params.pb.h"
namespace mediapipe::tasks::genai::llm_utils {
namespace {

// File-local shorthands for the odml proto types used by the presets in this
// file.
using LlmParameters = odml::infra::proto::LlmParameters;
using TransformerParameters = odml::infra::proto::TransformerParameters;
// NOTE(review): not referenced by any preset in this file; kept in case it is
// wanted by future presets.
using LlmModelType = odml::infra::proto::LlmModelType;

// Batch size written into every preset below.
constexpr int kBatchSize{1};

}  // namespace
// Returns the LlmParameters preset for Gemma 2B.
//
// Tokenizer: start token id 2, stop token "<eos>", 256000-entry vocabulary.
// Transformer: 18 stacks, embedding dim 2048, 8 query heads of dimension 256
// sharing a single key/value head (MQA), RMS pre/final norms, GELU
// feed-forward, causal attention, and no bias terms anywhere. Soft capping is
// disabled for both attention and the final projection.
LlmParameters GetGemma2BParams() {
  LlmParameters params;
  params.set_start_token_id(2);
  params.add_stop_tokens("<eos>");
  params.set_vocab_size(256000);

  auto* const transformer = params.mutable_transformer_parameters();
  transformer->set_batch_size(kBatchSize);
  transformer->set_num_stacks(18);
  transformer->set_embedding_dim(2048);
  transformer->set_hidden_dimension(16384);
  transformer->set_num_heads(8);
  transformer->set_head_dimension(256);
  // MQA: all query heads share one key/value head.
  transformer->set_num_kv_heads(1);
  transformer->set_skip_absolute_positional_embeddings(true);
  transformer->set_pre_norm(TransformerParameters::RMS_NORM);
  transformer->set_post_norm(TransformerParameters::NO_NORM);
  transformer->set_final_norm(TransformerParameters::RMS_NORM);

  // Self-attention: causal mask, 1/sqrt(head_dim) scaling, bias-free.
  auto* const attention = transformer->mutable_self_attention_parameters();
  attention->set_attention_mask_type(TransformerParameters::CAUSAL);
  attention->set_attention_scale_type(
      TransformerParameters::SCALE_TYPE_INV_SQRT_HEAD_DIM);
  attention->set_qkv_no_bias(true);
  attention->set_post_proj_no_bias(true);
  attention->set_soft_cap_value(0.0f);  // 0 disables the soft cap.

  // Feed-forward: bias-free GELU with an RMS pre-norm only.
  auto* const feed_forward = transformer->mutable_feed_forward_parameters();
  feed_forward->set_activation(TransformerParameters::GELU);
  feed_forward->set_no_bias(true);
  feed_forward->set_pre_norm(TransformerParameters::RMS_NORM);
  feed_forward->set_post_norm(TransformerParameters::NO_NORM);

  // Final projection: bias-free, uncapped.
  auto* const final_project = transformer->mutable_final_project_parameters();
  final_project->set_no_bias(true);
  final_project->set_soft_cap_value(0.0f);  // 0 disables the soft cap.

  return params;
}
// Returns the LlmParameters preset for Gemma 7B.
//
// Tokenizer: start token id 2, stop token "<eos>", 256000-entry vocabulary.
// Transformer: 28 stacks, embedding dim 3072 (hidden dimension 8x that),
// 16 heads of dimension 256 with multi-head attention, RMS pre/final norms,
// GELU feed-forward, causal attention, and no bias terms anywhere. Soft
// capping is disabled for both attention and the final projection.
LlmParameters GetGemma7BParams() {
  constexpr int kEmbeddingDim = 3072;

  LlmParameters params;
  params.set_start_token_id(2);
  params.add_stop_tokens("<eos>");
  params.set_vocab_size(256000);

  auto* const transformer = params.mutable_transformer_parameters();
  transformer->set_batch_size(kBatchSize);
  transformer->set_num_stacks(28);
  transformer->set_embedding_dim(kEmbeddingDim);
  transformer->set_hidden_dimension(8 * kEmbeddingDim);
  transformer->set_num_heads(16);
  transformer->set_head_dimension(256);
  // MHA. NOTE(review): 0 presumably tells the consumer "one KV head per query
  // head"; confirm against the code that reads num_kv_heads.
  transformer->set_num_kv_heads(0);
  transformer->set_skip_absolute_positional_embeddings(true);
  transformer->set_pre_norm(TransformerParameters::RMS_NORM);
  transformer->set_post_norm(TransformerParameters::NO_NORM);
  transformer->set_final_norm(TransformerParameters::RMS_NORM);

  // Self-attention: causal mask, 1/sqrt(head_dim) scaling, bias-free.
  auto* const attention = transformer->mutable_self_attention_parameters();
  attention->set_attention_mask_type(TransformerParameters::CAUSAL);
  attention->set_attention_scale_type(
      TransformerParameters::SCALE_TYPE_INV_SQRT_HEAD_DIM);
  attention->set_qkv_no_bias(true);
  attention->set_post_proj_no_bias(true);
  attention->set_soft_cap_value(0.0f);  // 0 disables the soft cap.

  // Feed-forward: bias-free GELU with an RMS pre-norm only.
  auto* const feed_forward = transformer->mutable_feed_forward_parameters();
  feed_forward->set_activation(TransformerParameters::GELU);
  feed_forward->set_no_bias(true);
  feed_forward->set_pre_norm(TransformerParameters::RMS_NORM);
  feed_forward->set_post_norm(TransformerParameters::NO_NORM);

  // Final projection: bias-free, uncapped.
  auto* const final_project = transformer->mutable_final_project_parameters();
  final_project->set_no_bias(true);
  final_project->set_soft_cap_value(0.0f);  // 0 disables the soft cap.

  return params;
}
// Returns the LlmParameters preset for Falcon-RW 1B.
//
// Tokenizer: start token id 1, stop token "<|endoftext|>", 50304-entry
// vocabulary. Transformer: 24 stacks, embedding dim 2048 (hidden dimension 4x
// that), 32 heads of dimension 64 with multi-head attention, LayerNorm
// pre/final norms, GELU feed-forward and causal attention. Attention and
// feed-forward layers carry biases; the final projection does not. Soft
// capping is disabled for both attention and the final projection.
LlmParameters GetFalconRW1BParams() {
  constexpr int kEmbeddingDim = 2048;
  constexpr int kNumHeads = 32;

  LlmParameters params;
  params.set_start_token_id(1);
  params.add_stop_tokens("<|endoftext|>");
  params.set_vocab_size(50304);

  auto* const transformer = params.mutable_transformer_parameters();
  transformer->set_batch_size(kBatchSize);
  transformer->set_num_stacks(24);
  transformer->set_embedding_dim(kEmbeddingDim);
  transformer->set_hidden_dimension(4 * kEmbeddingDim);
  transformer->set_num_heads(kNumHeads);
  transformer->set_head_dimension(64);
  // MHA expressed explicitly: as many key/value heads as query heads. (The
  // other MHA presets in this file pass 0 instead.)
  transformer->set_num_kv_heads(kNumHeads);
  transformer->set_skip_absolute_positional_embeddings(true);
  transformer->set_pre_norm(TransformerParameters::LAYER_NORM);
  transformer->set_post_norm(TransformerParameters::NO_NORM);
  transformer->set_final_norm(TransformerParameters::LAYER_NORM);

  // Self-attention: causal mask, 1/sqrt(head_dim) scaling, with biases.
  auto* const attention = transformer->mutable_self_attention_parameters();
  attention->set_attention_mask_type(TransformerParameters::CAUSAL);
  attention->set_attention_scale_type(
      TransformerParameters::SCALE_TYPE_INV_SQRT_HEAD_DIM);
  attention->set_qkv_no_bias(false);
  attention->set_post_proj_no_bias(false);
  attention->set_soft_cap_value(0.0f);  // 0 disables the soft cap.

  // Feed-forward: GELU with biases and a LayerNorm pre-norm only.
  auto* const feed_forward = transformer->mutable_feed_forward_parameters();
  feed_forward->set_activation(TransformerParameters::GELU);
  feed_forward->set_no_bias(false);
  feed_forward->set_pre_norm(TransformerParameters::LAYER_NORM);
  feed_forward->set_post_norm(TransformerParameters::NO_NORM);

  // Final projection: bias-free, uncapped.
  auto* const final_project = transformer->mutable_final_project_parameters();
  final_project->set_no_bias(true);
  final_project->set_soft_cap_value(0.0f);  // 0 disables the soft cap.

  return params;
}
// Returns the LlmParameters preset for StableLM-4E1T 3B.
//
// Tokenizer: start token id 0, stop token "<|endoftext|>", 50304-entry
// vocabulary. Transformer: 32 stacks, embedding dim 2560, hidden dimension
// 6912, 32 heads of dimension 80 with multi-head attention, LayerNorm
// pre/final norms, SiLU feed-forward, causal attention, and no bias terms
// anywhere. Soft capping is disabled for both attention and the final
// projection.
LlmParameters GetStablelm4E1T3BParams() {
  LlmParameters params;
  params.set_start_token_id(0);
  params.add_stop_tokens("<|endoftext|>");
  params.set_vocab_size(50304);

  auto* const transformer = params.mutable_transformer_parameters();
  transformer->set_batch_size(kBatchSize);
  transformer->set_num_stacks(32);
  transformer->set_embedding_dim(2560);
  transformer->set_hidden_dimension(6912);
  transformer->set_num_heads(32);
  transformer->set_head_dimension(80);
  // MHA. NOTE(review): 0 presumably tells the consumer "one KV head per query
  // head"; confirm against the code that reads num_kv_heads.
  transformer->set_num_kv_heads(0);
  transformer->set_skip_absolute_positional_embeddings(true);
  transformer->set_pre_norm(TransformerParameters::LAYER_NORM);
  transformer->set_post_norm(TransformerParameters::NO_NORM);
  transformer->set_final_norm(TransformerParameters::LAYER_NORM);

  // Self-attention: causal mask, 1/sqrt(head_dim) scaling, bias-free.
  auto* const attention = transformer->mutable_self_attention_parameters();
  attention->set_attention_mask_type(TransformerParameters::CAUSAL);
  attention->set_attention_scale_type(
      TransformerParameters::SCALE_TYPE_INV_SQRT_HEAD_DIM);
  attention->set_qkv_no_bias(true);
  attention->set_post_proj_no_bias(true);
  attention->set_soft_cap_value(0.0f);  // 0 disables the soft cap.

  // Feed-forward: bias-free SiLU with a LayerNorm pre-norm only.
  auto* const feed_forward = transformer->mutable_feed_forward_parameters();
  feed_forward->set_activation(TransformerParameters::SILU);
  feed_forward->set_no_bias(true);
  feed_forward->set_pre_norm(TransformerParameters::LAYER_NORM);
  feed_forward->set_post_norm(TransformerParameters::NO_NORM);

  // Final projection: bias-free, uncapped.
  auto* const final_project = transformer->mutable_final_project_parameters();
  final_project->set_no_bias(true);
  final_project->set_soft_cap_value(0.0f);  // 0 disables the soft cap.

  return params;
}
// Returns the LlmParameters preset for Phi-2.
//
// Tokenizer: start token id 50256, stop token "<|endoftext|>", 51200-entry
// vocabulary. Transformer: 32 stacks, embedding dim 2560, hidden dimension
// 10240, 32 heads of dimension 80 with multi-head attention, LayerNorm
// pre/final norms, GELU feed-forward and causal attention. Unlike the other
// presets here, every projection (QKV, attention output, feed-forward, final)
// carries a bias, and the feed-forward block has no norm of its own. Soft
// capping is disabled for both attention and the final projection.
LlmParameters GetPhi2Params() {
  LlmParameters params;
  params.set_start_token_id(50256);
  params.add_stop_tokens("<|endoftext|>");
  params.set_vocab_size(51200);

  auto* const transformer = params.mutable_transformer_parameters();
  transformer->set_batch_size(kBatchSize);
  transformer->set_num_stacks(32);
  transformer->set_embedding_dim(2560);
  transformer->set_hidden_dimension(10240);
  transformer->set_num_heads(32);
  transformer->set_head_dimension(80);
  // MHA. NOTE(review): 0 presumably tells the consumer "one KV head per query
  // head"; confirm against the code that reads num_kv_heads.
  transformer->set_num_kv_heads(0);
  transformer->set_skip_absolute_positional_embeddings(true);
  transformer->set_pre_norm(TransformerParameters::LAYER_NORM);
  transformer->set_post_norm(TransformerParameters::NO_NORM);
  transformer->set_final_norm(TransformerParameters::LAYER_NORM);

  // Self-attention: causal mask, 1/sqrt(head_dim) scaling, with biases.
  auto* const attention = transformer->mutable_self_attention_parameters();
  attention->set_attention_mask_type(TransformerParameters::CAUSAL);
  attention->set_attention_scale_type(
      TransformerParameters::SCALE_TYPE_INV_SQRT_HEAD_DIM);
  attention->set_qkv_no_bias(false);
  attention->set_post_proj_no_bias(false);
  attention->set_soft_cap_value(0.0f);  // 0 disables the soft cap.

  // Feed-forward: GELU with biases and neither a pre- nor a post-norm.
  auto* const feed_forward = transformer->mutable_feed_forward_parameters();
  feed_forward->set_activation(TransformerParameters::GELU);
  feed_forward->set_no_bias(false);
  feed_forward->set_pre_norm(TransformerParameters::NO_NORM);
  feed_forward->set_post_norm(TransformerParameters::NO_NORM);

  // Final projection: biased, uncapped.
  auto* const final_project = transformer->mutable_final_project_parameters();
  final_project->set_no_bias(false);
  final_project->set_soft_cap_value(0.0f);  // 0 disables the soft cap.

  return params;
}
} // namespace mediapipe::tasks::genai::llm_utils