// Copyright 2024 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MEDIAPIPE_TASKS_GENAI_INFERENCE_UTILS_XNN_UTILS_LLM_WEIGHTS_H_
#define MEDIAPIPE_TASKS_GENAI_INFERENCE_UTILS_XNN_UTILS_LLM_WEIGHTS_H_

#include <cstddef>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <variant>
#include <vector>
#include "absl/base/attributes.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mediapipe/tasks/cc/genai/inference/proto/llm_params.pb.h"
#include "mediapipe/tasks/cc/genai/inference/utils/xnn_utils/graph_builder.h"
#include "mediapipe/tasks/cc/genai/inference/utils/xnn_utils/pack_weights_cache.h"
#include "mediapipe/tasks/cc/genai/inference/utils/xnn_utils/xnn_tensor.h"

namespace mediapipe::tasks::genai::xnn_utils {

struct LlmParams {
// Construct LlmParams from proto.
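// Example (a sketch; assumes `proto` is a populated LlmParameters):
//   LlmParams params = LlmParams::FromLLMParametersProto(proto);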
static LlmParams FromLLMParametersProto(
const odml::infra::proto::LlmParameters& llm_params);
// Number of transformer layers (M).
size_t num_transformer_M = 0;
// Batch size (B).
size_t batch_size_B = 0;
// Maximum sequence length (T).
size_t seq_size_T = 0;
// Model (embedding) dimension (D).
size_t model_dim_D = 0;
// Feed-forward hidden dimension (HD).
size_t hidden_dim_HD = 0;
// Dimension of each attention head (H).
size_t head_dim_H = 0;
// Number of (query) attention heads (N).
size_t n_heads_N = 0;
// Vocabulary size (V).
size_t voc_size_V = 0;
// Draft size (G), e.g. for speculative decoding.
size_t draft_size_G = 0;
// Number of KV heads. For Multi-Head-Attention (MHA), num_kv_heads is the
// same as n_heads_N, the number of query heads. For Multi-Query-Attention
// (MQA), key and value have a single head. Otherwise, this specifies the
// number of key/value heads, and Grouped-Query-Attention (GQA) is used. See
// https://arxiv.org/pdf/2305.13245.pdf for details.
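// Illustrative configurations (example values, not from any particular
// model):
//   MHA: n_heads_N = 16, num_kv_heads = 16 (one KV head per query head)
//   MQA: n_heads_N = 16, num_kv_heads = 1  (all query heads share one head)
//   GQA: n_heads_N = 16, num_kv_heads = 4  (each KV head serves 4 query heads)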
size_t num_kv_heads = 0;
// Meant to be a mapping of the pax LanguageModelType. This affects, e.g.,
// the attention mask shape.
enum class ModelType {
UNSPECIFIED = 0,
// The attention mask for the input prefix is bidirectional.
PREFIX = 1,
// The attention mask is forward-only (causal).
CAUSAL = 2,
} model_type = ModelType::CAUSAL;
enum class Activation {
UNSPECIFIED = 0,
// Gaussian Error Linear Unit.
GELU = 1,
// Sigmoid-Weighted Linear Unit.
SILU = 2,
// Rectified Linear Unit.
RELU = 3,
};
enum class Norm {
UNSPECIFIED = 0,
NO_NORM = 1,
RMS_NORM = 2,
LAYER_NORM = 3,
};
enum class AttentionScaleType {
UNSPECIFIED = 0,
// Per-dimension scale: the query is scaled by log_2(1 + exp(w)) /
// sqrt(head_dim), where w is a static weight.
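// For example, at w = 0 this reduces to log_2(1 + exp(0)) / sqrt(head_dim)
// = 1 / sqrt(head_dim), the same scale as INV_SQRT_HEAD_DIM.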
PER_DIM_SCALE = 1,
// Query is scaled by 1/sqrt(head_dim).
INV_SQRT_HEAD_DIM = 2,
};
// If false, absolute positional embeddings are added to the input.
bool skip_absolute_positional_embeddings = false;
struct SelfAttentionParams {
bool qkv_no_bias = false;
bool post_proj_no_bias = false;
Norm pre_norm = Norm::RMS_NORM;
Norm post_norm = Norm::RMS_NORM;
// If greater than 0, CapTanh is applied; otherwise, no cap is applied.
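// (A common soft-cap formulation, assumed here rather than confirmed by
// this header, is soft_cap_value * tanh(score / soft_cap_value), which
// bounds scores to (-soft_cap_value, soft_cap_value).)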
float soft_cap_value = 0.0f;
// Attention scale type to be applied within the transformer.
AttentionScaleType attention_scale_type = AttentionScaleType::UNSPECIFIED;
} sa_params;
struct FeedForwardParams {
// If `no_bias`, the fully connected layer degrades to a matrix multiply.
bool no_bias = false;
Activation activation = Activation::GELU;
Norm pre_norm = Norm::RMS_NORM;
Norm post_norm = Norm::RMS_NORM;
} ff_params;
Norm final_norm = Norm::RMS_NORM;
struct FinalProjectParams {
// If `no_bias`, the final fully connected layer degrades to a matrix
// multiply.
bool no_bias = false;
} final_proj_params;
/*
* Parameters below do NOT change the "correctness" of the model, they
* configure the acceleration of inference.
*/
bool enable_kv_cache = false;
// If true, the inference engine optimizes tensor shapes according to the
// current sequence length to avoid wasted computation.
bool enable_dynamic_shape ABSL_DEPRECATED(
"This is always enabled if enable_kv_cache is true.") = false;
// If provided, the runtime prepares the cache in this directory. Otherwise,
// the cache is placed beside the original model.
std::string cache_dir;
};

struct RMSNormWeights {
std::shared_ptr<Tensor> norm_weight;
};

struct LayerNormWeights {
float epsilon = 1e-5f;
std::shared_ptr<Tensor> gamma;
std::shared_ptr<Tensor> beta;
};

struct LlmWeights {
using NormWeights = std::variant<RMSNormWeights, LayerNormWeights>;
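// Example dispatch over the two norm variants (a sketch; assumes a
// populated LlmWeights named `weights`):
//   std::visit(
//       [](const auto& norm) { /* build the matching norm op */ },
//       *weights.sas[0].pre_norm_weight);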
struct SelfAttentionWeights {
std::optional<NormWeights> pre_norm_weight;
std::shared_ptr<Tensor> k_weight;
std::shared_ptr<Tensor> k_bias;
std::shared_ptr<Tensor> q_weight;
std::shared_ptr<Tensor> q_bias;
std::shared_ptr<Tensor> v_weight;
std::shared_ptr<Tensor> v_bias;
std::shared_ptr<Tensor> per_dim_scale;
std::shared_ptr<Tensor> post_proj_weight;
std::shared_ptr<Tensor> post_proj_bias;
std::optional<NormWeights> post_norm_weight;
};
struct FeedForwardWeights {
std::optional<NormWeights> pre_norm_weight;
std::shared_ptr<Tensor> layer_1_weight;
std::shared_ptr<Tensor> layer_1_bias;
std::shared_ptr<Tensor> layer_1_gate_weight;
std::shared_ptr<Tensor> layer_1_gate_bias;
std::shared_ptr<Tensor> layer_2_weight;
std::shared_ptr<Tensor> layer_2_bias;
std::optional<NormWeights> post_norm_weight;
};
// Per-layer feed-forward weights.
std::vector<FeedForwardWeights> ffs;
// Per-layer self-attention weights.
std::vector<SelfAttentionWeights> sas;
// Per-layer cross-attention weights, if any (reuses the self-attention
// weight layout).
std::vector<SelfAttentionWeights> cas;
std::optional<NormWeights> final_norm_weight;
std::shared_ptr<Tensor> softmax_linear;
std::shared_ptr<Tensor> softmax_bias;
// Usually the same as softmax_linear, but some models use an embedding
// table that differs from softmax_linear.
std::shared_ptr<Tensor> token_embedding;
// Models that inherit from Llm and need weights beyond those defined above
// can load custom weights through their custom weight loader and store them
// in this map. The builder can then access these custom weights.
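// Example lookup (a sketch; "my_extra_tensor" is a hypothetical key):
//   auto it = weights.custom_weights.find("my_extra_tensor");
//   if (it != weights.custom_weights.end()) {
//     std::shared_ptr<Tensor> extra = it->second;
//   }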
absl::flat_hash_map<std::string, std::shared_ptr<Tensor>> custom_weights;
};
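
// Loads LLM weights through the provided WeightAccessor. Example usage (a
// minimal sketch; `MakeWeightAccessor` is a hypothetical factory for a
// concrete WeightAccessor):
//
//   auto loader = std::make_unique<LlmWeightsLoader>(MakeWeightAccessor(),
//                                                    params);
//   absl::StatusOr<LlmWeights> weights = loader->LoadWeights();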
class LlmWeightsLoader {
public:
constexpr static absl::string_view kTokenEmbedding{
"params.lm.token_embedding.w"};
constexpr static absl::string_view kTransformerWeightPrefix{
"params.lm.transformer.x_layers_"};
constexpr static absl::string_view kLogitsFfnBiasFilename{
"params.lm.softmax.logits_ffn.bias.b"};
constexpr static absl::string_view kLogitsFfnWeightFilename{
"params.lm.softmax.logits_ffn.linear.w"};
LlmWeightsLoader(std::unique_ptr<WeightAccessor> weight_accessor,
const LlmParams& params)
: weight_accessor_(std::move(weight_accessor)), params_(params) {}
virtual ~LlmWeightsLoader() = default;
virtual absl::StatusOr<LlmWeights> LoadWeights();
LlmParams& llm_params() { return params_; }
const LlmParams& llm_params() const { return params_; }
// Returns the XnnWeightsCache that works with this weights loader, if any.
virtual std::shared_ptr<XnnWeightsCache> GetXnnWeightsCache() {
return nullptr;
}
protected:
absl::StatusOr<LlmWeights::SelfAttentionWeights> LoadSelfAttention(
int layer_id);
absl::StatusOr<LlmWeights::FeedForwardWeights> LoadFeedForward(int layer_id);
// is_query: whether the weight is for the query projection. Note that the
// key/value projection weights are handled differently for MHA vs. MQA.
absl::StatusOr<std::shared_ptr<Tensor>> TryCacheThenLoadSelfAttention(
absl::string_view filename_prefix, absl::string_view alt_filename_prefix,
bool is_query);
std::unique_ptr<WeightAccessor> weight_accessor_;
LlmParams params_;
};
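
// Weights loader that reads weights from a path on disk and maintains a
// PackWeightsCache. Example (a sketch; the path is illustrative):
//
//   DefaultLlmWeightsLoader loader("/path/to/weights", params);
//   absl::StatusOr<LlmWeights> weights = loader.LoadWeights();
//   std::shared_ptr<XnnWeightsCache> cache = loader.GetXnnWeightsCache();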
class DefaultLlmWeightsLoader : public LlmWeightsLoader {
public:
DefaultLlmWeightsLoader(absl::string_view weight_path,
const LlmParams& params);
std::shared_ptr<XnnWeightsCache> GetXnnWeightsCache() override {
return xnn_weights_cache_;
}
private:
std::shared_ptr<PackWeightsCache> xnn_weights_cache_;
};

}  // namespace mediapipe::tasks::genai::xnn_utils

#endif  // MEDIAPIPE_TASKS_GENAI_INFERENCE_UTILS_XNN_UTILS_LLM_WEIGHTS_H_