// Copyright 2024 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/tasks/cc/genai/inference/utils/xnn_utils/llm_weights.h"
#include <sys/stat.h>
#include <cstddef>
#include <memory>
#include <optional>
#include <utility>
#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/genai/inference/proto/transformer_params.pb.h"
#include "mediapipe/tasks/cc/genai/inference/utils/xnn_utils/pack_weights_cache.h"
#include "mediapipe/tasks/cc/genai/inference/utils/xnn_utils/tflite_weight_accessor.h"
#include "mediapipe/tasks/cc/genai/inference/utils/xnn_utils/utils.h"
#include "mediapipe/tasks/cc/genai/inference/utils/xnn_utils/xnn_tensor.h"
namespace mediapipe::tasks::genai::xnn_utils {
namespace {
using FeedForwardWeights = LlmWeights::FeedForwardWeights;
using SelfAttentionWeights = LlmWeights::SelfAttentionWeights;
using TransformerParameters = odml::infra::proto::TransformerParameters;
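
// Converts a TransformerParameters::Norm value from the proto into the
// corresponding LlmParams::Norm value.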
LlmParams::Norm TransformerParametersProtoNormTypeToLlmParamsNormType(
TransformerParameters::Norm norm_type) {
switch (norm_type) {
case TransformerParameters::NORM_UNSPECIFIED:
ABSL_LOG(DFATAL) << "Unspecified norm type.";
return LlmParams::Norm::UNSPECIFIED;
case TransformerParameters::NO_NORM:
return LlmParams::Norm::NO_NORM;
case TransformerParameters::RMS_NORM:
return LlmParams::Norm::RMS_NORM;
case TransformerParameters::LAYER_NORM:
return LlmParams::Norm::LAYER_NORM;
default:
ABSL_LOG(DFATAL) << "Unknown norm type: " << norm_type;
}
return LlmParams::Norm::UNSPECIFIED;
}

// Loads the norm weights required by `norm_type` (if any), using `basename`
// as the weight-name prefix. Returns std::nullopt when no weights are needed
// (NO_NORM or UNSPECIFIED).
absl::StatusOr<std::optional<LlmWeights::NormWeights>> LoadNormWeights(
LlmParams::Norm norm_type, const LlmParams& params,
absl::string_view basename, WeightAccessor& weight_accessor) {
switch (norm_type) {
case LlmParams::Norm::UNSPECIFIED:
break;
case LlmParams::Norm::NO_NORM:
break;
case LlmParams::Norm::RMS_NORM: {
auto rms_norm_weights = RMSNormWeights();
MP_ASSIGN_OR_RETURN(
rms_norm_weights.norm_weight,
weight_accessor.LoadWeight(absl::StrCat(basename, ".scale"),
{params.model_dim_D}));
return rms_norm_weights;
}
case LlmParams::Norm::LAYER_NORM: {
auto layer_norm_weights = LayerNormWeights();
MP_ASSIGN_OR_RETURN(
layer_norm_weights.beta,
weight_accessor.LoadWeight(absl::StrCat(basename, ".bias"),
{1, 1, params.model_dim_D}));
MP_ASSIGN_OR_RETURN(
layer_norm_weights.gamma,
weight_accessor.LoadWeight(absl::StrCat(basename, ".scale"),
{1, 1, params.model_dim_D}));
return layer_norm_weights;
}
default:
break;
}
return std::nullopt;
}
} // namespace
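
// Translates the proto LlmParameters into an LlmParams struct, copying the
// transformer dimensions, mapping the enum fields (attention mask type,
// activation, norm types, attention scale type), and filling in defaults
// (e.g. num_kv_heads falls back to num_heads when it is 0).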
LlmParams LlmParams::FromLLMParametersProto(
const odml::infra::proto::LlmParameters& llm_params) {
const auto& transformer_params = llm_params.transformer_parameters();
LlmParams params = {
.num_transformer_M = static_cast<size_t>(transformer_params.num_stacks()),
.batch_size_B = static_cast<size_t>(transformer_params.batch_size()),
.seq_size_T = static_cast<size_t>(transformer_params.max_seq_length()),
.model_dim_D = static_cast<size_t>(transformer_params.embedding_dim()),
.hidden_dim_HD =
static_cast<size_t>(transformer_params.hidden_dimension()),
.head_dim_H = static_cast<size_t>(transformer_params.head_dimension()),
.n_heads_N = static_cast<size_t>(transformer_params.num_heads()),
.voc_size_V = static_cast<size_t>(llm_params.vocab_size()),
.num_kv_heads =
static_cast<size_t>(transformer_params.num_kv_heads() == 0
? transformer_params.num_heads()
: transformer_params.num_kv_heads()),
.enable_kv_cache = true,
.enable_dynamic_shape = true};
if (llm_params.has_num_draft_tokens()) {
params.draft_size_G = llm_params.num_draft_tokens();
}
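  // Map the attention mask type onto the model type.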
switch (
transformer_params.self_attention_parameters().attention_mask_type()) {
case TransformerParameters::UNSPECIFIED:
ABSL_LOG(DFATAL) << "Unspecified attention_mask_type, assuming causal";
params.model_type = LlmParams::ModelType::UNSPECIFIED;
break;
case TransformerParameters::CAUSAL:
params.model_type = LlmParams::ModelType::CAUSAL;
break;
case TransformerParameters::PREFIX:
params.model_type = LlmParams::ModelType::PREFIX;
break;
default:
ABSL_LOG(DFATAL) << "Unknown attention_mask_type: "
<< transformer_params.self_attention_parameters()
.attention_mask_type();
}
params.ff_params = LlmParams::FeedForwardParams{
.no_bias = transformer_params.feed_forward_parameters().no_bias(),
};
params.final_proj_params = LlmParams::FinalProjectParams{
.no_bias = transformer_params.final_project_parameters().no_bias(),
};
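  // Map the feed-forward activation function.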
switch (transformer_params.feed_forward_parameters().activation()) {
case TransformerParameters::ACTIVATION_UNSPECIFIED:
ABSL_LOG(DFATAL) << "Unspecified feed_forward_parameters.activation.";
params.ff_params.activation = LlmParams::Activation::UNSPECIFIED;
break;
case TransformerParameters::GELU:
params.ff_params.activation = LlmParams::Activation::GELU;
break;
case TransformerParameters::SILU:
params.ff_params.activation = LlmParams::Activation::SILU;
break;
case TransformerParameters::RELU:
params.ff_params.activation = LlmParams::Activation::RELU;
break;
default:
ABSL_LOG(DFATAL)
<< "Unknown feed_forward_parameters.activation: "
<< transformer_params.feed_forward_parameters().activation();
}
params.sa_params.qkv_no_bias =
transformer_params.self_attention_parameters().qkv_no_bias();
params.sa_params.post_proj_no_bias =
transformer_params.self_attention_parameters().post_proj_no_bias();
params.sa_params.pre_norm =
TransformerParametersProtoNormTypeToLlmParamsNormType(
transformer_params.pre_norm());
params.sa_params.post_norm =
TransformerParametersProtoNormTypeToLlmParamsNormType(
transformer_params.post_norm());
params.sa_params.soft_cap_value =
transformer_params.self_attention_parameters().soft_cap_value();
params.ff_params.pre_norm =
TransformerParametersProtoNormTypeToLlmParamsNormType(
transformer_params.feed_forward_parameters().pre_norm());
params.ff_params.post_norm =
TransformerParametersProtoNormTypeToLlmParamsNormType(
transformer_params.feed_forward_parameters().post_norm());
params.final_norm = TransformerParametersProtoNormTypeToLlmParamsNormType(
transformer_params.final_norm());
params.skip_absolute_positional_embeddings =
transformer_params.skip_absolute_positional_embeddings();
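  // Use the explicitly configured attention scale type when present; otherwise
  // infer it from the head configuration below.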
if (transformer_params.self_attention_parameters()
.has_attention_scale_type()) {
switch (
transformer_params.self_attention_parameters().attention_scale_type()) {
case TransformerParameters::SCALE_TYPE_UNSPECIFIED:
ABSL_LOG(DFATAL) << "Unspecified attention_scale_type.";
params.sa_params.attention_scale_type =
LlmParams::AttentionScaleType::UNSPECIFIED;
break;
case TransformerParameters::SCALE_TYPE_PER_DIM_SCALE:
params.sa_params.attention_scale_type =
LlmParams::AttentionScaleType::PER_DIM_SCALE;
break;
case TransformerParameters::SCALE_TYPE_INV_SQRT_HEAD_DIM:
params.sa_params.attention_scale_type =
LlmParams::AttentionScaleType::INV_SQRT_HEAD_DIM;
break;
default:
ABSL_LOG(DFATAL) << "Unknown attention_scale_type: "
<< transformer_params.self_attention_parameters()
.attention_scale_type();
}
} else {
if (transformer_params.num_kv_heads() == 0 ||
transformer_params.num_heads() == transformer_params.num_kv_heads()) {
// If MHA, PER_DIM_SCALE is used.
params.sa_params.attention_scale_type =
LlmParams::AttentionScaleType::PER_DIM_SCALE;
} else {
// If MQA or GQA, INV_SQRT_HEAD_DIM is used.
params.sa_params.attention_scale_type =
LlmParams::AttentionScaleType::INV_SQRT_HEAD_DIM;
}
}
return params;
}
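
// Loads one self-attention projection weight in transposed form, trying
// `filename_prefix` first and `alt_filename_prefix` as a fallback. Query
// weights cover all N attention heads, while key/value weights cover the
// (possibly smaller) number of KV heads; the loaded tensor is tagged with
// metadata recording that head count and the weight layout.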
absl::StatusOr<std::shared_ptr<Tensor>>
LlmWeightsLoader::TryCacheThenLoadSelfAttention(
absl::string_view filename_prefix, absl::string_view alt_filename_prefix,
bool is_query) {
std::shared_ptr<Tensor> r;
if (!is_query) {
MP_ASSIGN_OR_RETURN(
r, weight_accessor_->LoadTransposedWeight(
filename_prefix,
{params_.model_dim_D, params_.num_kv_heads * params_.head_dim_H},
1));
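    // Fall back to the alternate weight name if the primary one is absent.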
if (!r) {
MP_ASSIGN_OR_RETURN(r, weight_accessor_->LoadTransposedWeight(
alt_filename_prefix,
{params_.model_dim_D,
params_.num_kv_heads * params_.head_dim_H},
1));
}
RET_CHECK(r) << "Could not load " << filename_prefix << " (or "
<< alt_filename_prefix << ")";
r->SetMetadata(xnn_utils::kKeySelfAttentionReshapedWeight,
params_.num_kv_heads);
} else {
MP_ASSIGN_OR_RETURN(
r,
weight_accessor_->LoadTransposedWeight(
filename_prefix,
{params_.model_dim_D, params_.n_heads_N * params_.head_dim_H}, 1));
if (!r) {
MP_ASSIGN_OR_RETURN(
r, weight_accessor_->LoadTransposedWeight(
alt_filename_prefix,
{params_.model_dim_D, params_.n_heads_N * params_.head_dim_H},
1));
}
RET_CHECK(r) << "Could not load " << filename_prefix << " (or "
<< alt_filename_prefix << ")";
r->SetMetadata(xnn_utils::kKeySelfAttentionReshapedWeight,
params_.n_heads_N);
}
r->SetMetadata(kKeyInDimLastInWeight, 1);
return r;
}
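
// Loads the feed-forward weights for transformer layer `layer_id`: optional
// pre/post norm weights, the layer-1, layer-1 gate, and layer-2 projection
// matrices, and their biases when the model uses biases. Each projection is
// looked up under its plain `.w` name first and the `.linear.w` variant
// second.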
absl::StatusOr<FeedForwardWeights> LlmWeightsLoader::LoadFeedForward(
int layer_id) {
const auto& params = params_;
auto ff_file_prefix =
absl::StrCat(kTransformerWeightPrefix, layer_id, ".ff_layer.");
FeedForwardWeights feed_forward;
MP_ASSIGN_OR_RETURN(
feed_forward.pre_norm_weight,
LoadNormWeights(params.ff_params.pre_norm, params,
absl::StrCat(ff_file_prefix, "pre_layer_norm"),
*weight_accessor_));
MP_ASSIGN_OR_RETURN(
feed_forward.post_norm_weight,
LoadNormWeights(params.ff_params.post_norm, params,
absl::StrCat(ff_file_prefix, "post_layer_norm"),
*weight_accessor_));
MP_ASSIGN_OR_RETURN(feed_forward.layer_1_weight,
weight_accessor_->LoadTransposedWeight(
absl::StrCat(ff_file_prefix, "ffn_layer1.w"),
{params.model_dim_D, params.hidden_dim_HD},
/*original_dim_scale=*/1));
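  // Fall back to the `.linear.w` naming when the plain `.w` tensor is absent;
  // the same fallback applies to the gate and layer-2 projections below.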
if (!feed_forward.layer_1_weight) {
MP_ASSIGN_OR_RETURN(feed_forward.layer_1_weight,
weight_accessor_->LoadTransposedWeight(
absl::StrCat(ff_file_prefix, "ffn_layer1.linear.w"),
{params.model_dim_D, params.hidden_dim_HD},
/*original_dim_scale=*/1));
}
MP_ASSIGN_OR_RETURN(feed_forward.layer_1_gate_weight,
weight_accessor_->LoadTransposedWeight(
absl::StrCat(ff_file_prefix, "ffn_layer1_gate.w"),
{params.model_dim_D, params.hidden_dim_HD},
/*original_dim_scale=*/1));
if (!feed_forward.layer_1_gate_weight) {
MP_ASSIGN_OR_RETURN(
feed_forward.layer_1_gate_weight,
weight_accessor_->LoadTransposedWeight(
absl::StrCat(ff_file_prefix, "ffn_layer1_gate.linear.w"),
{params.model_dim_D, params.hidden_dim_HD},
/*original_dim_scale=*/1));
}
MP_ASSIGN_OR_RETURN(
feed_forward.layer_2_weight,
weight_accessor_->LoadTransposedWeight(
absl::StrCat(ff_file_prefix, "ffn_layer2.w"),
Tensor::DimsType{params.hidden_dim_HD, params.model_dim_D},
/*original_dim_scale=*/1));
if (!feed_forward.layer_2_weight) {
MP_ASSIGN_OR_RETURN(
feed_forward.layer_2_weight,
weight_accessor_->LoadTransposedWeight(
absl::StrCat(ff_file_prefix, "ffn_layer2.linear.w"),
Tensor::DimsType{params.hidden_dim_HD, params.model_dim_D},
/*original_dim_scale=*/1));
}
if (!params.ff_params.no_bias) {
MP_ASSIGN_OR_RETURN(feed_forward.layer_1_bias,
weight_accessor_->LoadWeight(
absl::StrCat(ff_file_prefix, "ffn_layer1.bias.b"),
{params.hidden_dim_HD}));
MP_ASSIGN_OR_RETURN(
feed_forward.layer_1_gate_bias,
weight_accessor_->LoadWeight(
absl::StrCat(ff_file_prefix, "ffn_layer1_gate.bias.b"),
{params.hidden_dim_HD}));
MP_ASSIGN_OR_RETURN(feed_forward.layer_2_bias,
weight_accessor_->LoadWeight(
absl::StrCat(ff_file_prefix, "ffn_layer2.bias.b"),
{params.model_dim_D}));
}
return feed_forward;
}
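
// Loads the self-attention weights for transformer layer `layer_id`: optional
// pre/post norm weights, the Q/K/V projections (with alternate-name fallback),
// their biases when present, the per-dimension scale for PER_DIM_SCALE models,
// and the output (post) projection with its optional bias.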
absl::StatusOr<SelfAttentionWeights> LlmWeightsLoader::LoadSelfAttention(
int layer_id) {
const auto& params = params_;
SelfAttentionWeights self_attention;
auto sa_file_prefix = absl::StrCat(kTransformerWeightPrefix, layer_id);
MP_ASSIGN_OR_RETURN(
self_attention.pre_norm_weight,
LoadNormWeights(params.sa_params.pre_norm, params,
absl::StrCat(sa_file_prefix, ".pre_layer_norm"),
*weight_accessor_));
MP_ASSIGN_OR_RETURN(
self_attention.post_norm_weight,
LoadNormWeights(params.sa_params.post_norm, params,
absl::StrCat(sa_file_prefix, ".post_layer_norm"),
*weight_accessor_));
absl::StrAppend(&sa_file_prefix, ".self_attention.");
MP_ASSIGN_OR_RETURN(
self_attention.k_weight,
TryCacheThenLoadSelfAttention(absl::StrCat(sa_file_prefix, "k.w"),
absl::StrCat(sa_file_prefix, "k.linear.w"),
/*is_query=*/false));
MP_ASSIGN_OR_RETURN(
self_attention.q_weight,
TryCacheThenLoadSelfAttention(absl::StrCat(sa_file_prefix, "q.w"),
absl::StrCat(sa_file_prefix, "q.linear.w"),
/*is_query=*/true));
MP_ASSIGN_OR_RETURN(
self_attention.v_weight,
TryCacheThenLoadSelfAttention(absl::StrCat(sa_file_prefix, "v.w"),
absl::StrCat(sa_file_prefix, "v.linear.w"),
/*is_query=*/false));
if (!params.sa_params.qkv_no_bias) {
MP_ASSIGN_OR_RETURN(
self_attention.q_bias,
weight_accessor_->LoadWeight(absl::StrCat(sa_file_prefix, "q.bias.b"),
{params.n_heads_N * params.head_dim_H}));
MP_ASSIGN_OR_RETURN(
self_attention.k_bias,
weight_accessor_->LoadWeight(absl::StrCat(sa_file_prefix, "k.bias.b"),
{params.n_heads_N * params.head_dim_H}));
MP_ASSIGN_OR_RETURN(
self_attention.v_bias,
weight_accessor_->LoadWeight(absl::StrCat(sa_file_prefix, "v.bias.b"),
{params.n_heads_N * params.head_dim_H}));
}
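  // The learned per-dimension scale tensor only exists for PER_DIM_SCALE
  // models.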
if (params.sa_params.attention_scale_type ==
LlmParams::AttentionScaleType::PER_DIM_SCALE) {
MP_ASSIGN_OR_RETURN(
self_attention.per_dim_scale,
weight_accessor_->LoadWeight(
absl::StrCat(sa_file_prefix, "per_dim_scale.per_dim_scale"),
{params.head_dim_H}));
}
MP_ASSIGN_OR_RETURN(
self_attention.post_proj_weight,
weight_accessor_->LoadWeight(
absl::StrCat(sa_file_prefix, "post.w"),
{params.model_dim_D, params.n_heads_N * params.head_dim_H},
/*dim_scale_if_any=*/0));
if (!self_attention.post_proj_weight) {
MP_ASSIGN_OR_RETURN(
self_attention.post_proj_weight,
weight_accessor_->LoadWeight(
absl::StrCat(sa_file_prefix, "post.linear.w"),
{params.model_dim_D, params.n_heads_N * params.head_dim_H},
/*dim_scale_if_any=*/0));
}
if (!params.sa_params.post_proj_no_bias) {
MP_ASSIGN_OR_RETURN(
self_attention.post_proj_bias,
weight_accessor_->LoadWeight(
absl::StrCat(sa_file_prefix, "post.bias.b"), {params.model_dim_D}));
}
return self_attention;
}
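
// Loads the full set of model weights: the per-layer feed-forward and
// self-attention weights, the final norm, the softmax/logits projection (with
// its optional bias), and the token embedding table.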
absl::StatusOr<LlmWeights> LlmWeightsLoader::LoadWeights() {
RET_CHECK(weight_accessor_);
LlmWeights result;
for (int layer_id = 0; layer_id < params_.num_transformer_M; ++layer_id) {
MP_ASSIGN_OR_RETURN(auto ff, LoadFeedForward(layer_id));
result.ffs.push_back(std::move(ff));
MP_ASSIGN_OR_RETURN(auto sa, LoadSelfAttention(layer_id));
result.sas.push_back(std::move(sa));
}
MP_ASSIGN_OR_RETURN(result.final_norm_weight,
LoadNormWeights(params_.final_norm, params_,
"params.lm.final_ln", *weight_accessor_));
MP_ASSIGN_OR_RETURN(
result.softmax_linear,
weight_accessor_->LoadTransposedWeight(
absl::StrReplaceAll(kLogitsFfnWeightFilename, {{".linear.", "."}}),
{params_.model_dim_D, params_.voc_size_V}, 1));
if (!result.softmax_linear) {
MP_ASSIGN_OR_RETURN(result.softmax_linear,
weight_accessor_->LoadTransposedWeight(
kLogitsFfnWeightFilename,
{params_.model_dim_D, params_.voc_size_V}, 1));
}
if (!params_.final_proj_params.no_bias) {
MP_ASSIGN_OR_RETURN(result.softmax_bias,
weight_accessor_->LoadWeight(kLogitsFfnBiasFilename,
{params_.voc_size_V}));
}
RET_CHECK(result.softmax_linear) << kLogitsFfnWeightFilename;
MP_ASSIGN_OR_RETURN(
result.token_embedding,
weight_accessor_->LoadWeight(kTokenEmbedding,
{params_.voc_size_V, params_.model_dim_D},
/*dim_scale_if_any=*/0));
return result;
}
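
// Constructs a loader that reads weights via a TfLiteWeightAccessor over
// `weight_path` and packs them through an XNNPACK weights cache. The cache
// file is `<weight_path>.cache` next to the weights, or
// `<basename(weight_path)>.cache` inside `params.cache_dir` when a cache
// directory is configured.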
DefaultLlmWeightsLoader::DefaultLlmWeightsLoader(absl::string_view weight_path,
const LlmParams& params)
: LlmWeightsLoader(nullptr, params) {
xnn_weights_cache_ = std::make_shared<PackWeightsCache>(
params.cache_dir.empty()
? absl::StrCat(weight_path, ".cache")
: mediapipe::file::JoinPath(
params.cache_dir,
absl::StrCat(mediapipe::file::Basename(weight_path),
".cache")));
ABSL_CHECK_OK(xnn_weights_cache_->Initialize());
weight_accessor_ = std::make_unique<WeightAccessorCompositeWithCache>(
std::make_shared<TfLiteWeightAccessor>(weight_path),
xnn_weights_cache_.get());
}
} // namespace mediapipe::tasks::genai::xnn_utils