chromium/third_party/mediapipe/src/mediapipe/tasks/cc/genai/inference/utils/xnn_utils/llm.h

// Copyright 2024 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MEDIAPIPE_TASKS_GENAI_INFERENCE_UTILS_XNN_UTILS_LLM_H_
#define MEDIAPIPE_TASKS_GENAI_INFERENCE_UTILS_XNN_UTILS_LLM_H_

#include <cstddef>
#include <memory>
#include <optional>
#include <utility>
#include <vector>

#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "mediapipe/tasks/cc/genai/inference/common/mdspan.h"
#include "mediapipe/tasks/cc/genai/inference/utils/xnn_utils/graph_builder.h"
#include "mediapipe/tasks/cc/genai/inference/utils/xnn_utils/llm_weights.h"
#include "mediapipe/tasks/cc/genai/inference/utils/xnn_utils/sampling.h"
#include "mediapipe/tasks/cc/genai/inference/utils/xnn_utils/xnn_tensor.h"

namespace mediapipe::tasks::genai {
namespace xnn_utils {

class LlmBuilder;

// The base class that hosts the XNNPACK graph for large language models. It is
// responsible for hosting the assets required to run the models, including
// pointers to the constructed tensors and the KV-cache, as well as for
// constructing the whole model. Note that this class is designed to serve
// models that share similar "structures", so please be mindful when you plan
// to inherit from it and perform customization. A general guideline is that if
// you are implementing a decoder-only model with prefix/decode graphs, you
// shouldn't need to modify this class; instead, perform the customization in
// the LlmBuilder. For example:
//   1) to implement Llama, one should make the changes in the LlmBuilder.
//   2) to implement an "encoder-only" model (i.e. one that only runs the
//      prefix graph, with no decode graphs and no KV-cache), one should
//      inherit from this class and adjust the logic. See llm_encoder_only.h
//      for more details.
class Llm : protected xnn_utils::XnnGraph {
 public:
  explicit Llm(XnnGraph&& other) : XnnGraph(std::move(other)) {}
  Llm(Llm&&) = default;
  ~Llm() override = default;

  // Used when `enable_kv_cache` is set.
  struct KVCache {
    std::shared_ptr<Tensor> k_cache;
    std::shared_ptr<Tensor> v_cache;
    std::shared_ptr<Tensor> k_slice;
    std::shared_ptr<Tensor> v_slice;
  };

  // An aggregation of all the data that can represent the context of the
  // model.
  struct Context {
    // Previous ids, including prompt.
    std::vector<std::vector<int>> batch_prev_ids;
    std::vector<KVCache> kv_cache;
  };

  // Reduces the number of previous ids, effectively undoing the last
  // `batch_num_tokens[i]` tokens for each batch entry `i`. Used for reverting
  // incorrect draft tokens in speculative decoding.
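  //
  // For example (a sketch): to roll back 3 rejected draft tokens for a
  // single-entry batch whose state is held in `context`:
  //
  //   MP_RETURN_IF_ERROR(Llm::ReduceContextPrevIds(context, {3}));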
  static absl::Status ReduceContextPrevIds(std::shared_ptr<Context> context,
                                           std::vector<int> batch_num_tokens);

  // Create LLM graph using the `DefaultLlmWeightsLoader` to load the model
  // from `weights_folder`.
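  //
  // For example (the weights path is illustrative):
  //
  // MP_ASSIGN_OR_RETURN(auto llm, CreateLlm(
  //     "/path/to/model_weights", llm_params));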
  static absl::StatusOr<std::unique_ptr<Llm>> CreateLlm(
      absl::string_view weights_folder, const LlmParams& llm_params,
      std::unique_ptr<xnn_utils::RuntimeConfigs> runtime_configs = nullptr);

  // Create LLM graph using provided `weight_loader`, which provides LlmParams
  // through llm_params() and LlmWeights through LoadWeights(). This is
  // typically used when you would like to load weights from somewhere other
  // than the filesystem (e.g. fake weights during benchmarking):
  //
  // MP_ASSIGN_OR_RETURN(auto llm, CreateLlm(
  //     std::make_unique<BenchmarkLlmWeightsLoader>(llm_params)));
  static absl::StatusOr<std::unique_ptr<Llm>> CreateLlm(
      std::unique_ptr<LlmWeightsLoader> weight_loader,
      std::unique_ptr<xnn_utils::RuntimeConfigs> runtime_configs = nullptr);

  // Create LLM graph using provided `weight_loader` and `builder`.
  // `weight_loader` is used the same way as in the version above. This is
  // typically used when you would like to customize the wiring logic of model
  // construction through `builder`:
  //
  // MP_ASSIGN_OR_RETURN(auto llm, CreateLlm(
  //     std::make_unique<LlmEncoderOnlyWeightsLoader>(llm_params),
  //     std::make_unique<LlmEncoderOnlyBuilder>(runtime_configs)));
  static absl::StatusOr<std::unique_ptr<Llm>> CreateLlm(
      std::unique_ptr<LlmWeightsLoader> weight_loader,
      std::unique_ptr<LlmBuilder> builder);

  // Appends the input token ids after all previously added tokens.
  virtual absl::Status AddInputTokens(
      absl::Span<const std::vector<int>> batch_input_ids);

  // Seeks to the given time step. This is typically used to rewind to a
  // certain state for speculative decoding. SeekTimeStep(0) effectively
  // resets the internal state.
  absl::Status SeekTimeStep(size_t time_step);

  // Samples the logits from ComputeLogits() and returns the sampled ids. This
  // also calls AddInputTokens() with the sampled ids.
  ABSL_DEPRECATED("Use ComputeLogits() and do your own sampling.")
  virtual absl::Status GetNextToken(std::vector<int>* output_ids);

  // Computes logits over all previously added tokens. The output has shape
  // [batch_B, expected_seq_len, vocab_size_V], representing the last
  // `expected_seq_len` steps along the sequence dimension.
  virtual absl::StatusOr<std::shared_ptr<Tensor>> ComputeLogits(
      size_t expected_seq_len);
  absl::StatusOr<std::shared_ptr<Tensor>> ComputeLogits() {
    return this->ComputeLogits(1);
  }
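
  // A minimal decode-loop sketch using the non-deprecated API. Sampling is
  // left to the caller; `PickTokens` below is an illustrative placeholder,
  // not part of this API:
  //
  //   MP_RETURN_IF_ERROR(llm->AddInputTokens({prompt_ids}));
  //   while (llm->TotalTokenSize() < max_tokens) {
  //     MP_ASSIGN_OR_RETURN(auto logits, llm->ComputeLogits());
  //     std::vector<int> next_ids = PickTokens(*logits);  // caller-defined
  //     MP_RETURN_IF_ERROR(llm->AddInputTokens({next_ids}));
  //   }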

  // The total number of tokens, including prompt and generated tokens.
  virtual size_t TotalTokenSize() const;

  const LlmParams& GetLlmParams() { return llm_params_; }

  // Creates a new context using the internal model parameters. The variables
  // in the context will have the proper batch size, sequence length, etc.
  virtual absl::StatusOr<Context> NewContext() const;

  // If `context` is non-null and different from the existing `context_`, loads
  // the context into the model.
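  //
  // For example (a sketch), to reset the model to a fresh context:
  //
  //   MP_ASSIGN_OR_RETURN(auto fresh, llm->NewContext());
  //   MP_RETURN_IF_ERROR(llm->LoadContext(
  //       std::make_shared<Context>(std::move(fresh))));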
  virtual absl::Status LoadContext(
      absl::Nullable<std::shared_ptr<Context>> context);

 protected:
  friend class PrefixDecodeLlm;
  friend class LlmTest;
  friend class LlmBuilder;

  Llm() : XnnGraph(XnnSubgraphPtr{nullptr, nullptr}, nullptr) {}

  // Internal parameters to control the prefix model.
  struct InternalLlmParams {
    // Stops at last KV cache, so we don't waste computation.
    bool stop_at_last_kv_cache = false;
  };

  // Creates a `Llm` instance with prefix-decoder architecture.
  static absl::StatusOr<std::unique_ptr<Llm>> CreatePrefixDecodeLlm(
      LlmWeights, std::shared_ptr<LlmBuilder> builder);

  std::shared_ptr<Tensor>& transformer_input();
  const std::shared_ptr<Tensor>& transformer_input() const;
  std::shared_ptr<Tensor>& logits_output();
  const std::shared_ptr<Tensor>& logits_output() const;
  // Previous ids, including prompt.
  std::vector<std::vector<int>>& batch_prev_ids();
  const std::vector<std::vector<int>>& batch_prev_ids() const;
  std::vector<KVCache>& kv_cache();
  const std::vector<KVCache>& kv_cache() const;

  // Fills `embedding` according to the given `ids`, by looking them up in the
  // token embedding table provided through weights. The first
  // ids.size() * model_dim_D elements pointed to by `embedding` will be
  // filled.
  absl::Status GetTokenEmbedding(const std::vector<int>& ids, float* embedding);

  absl::Status ReshapeInputResource();

  LlmWeights weights_;
  LlmParams llm_params_;

  std::shared_ptr<Tensor> pos_embedding_;
  std::shared_ptr<Tensor> atten_masks_;
  std::shared_ptr<Tensor> segment_pos_;

  // Embedding input to the model.
  std::shared_ptr<Tensor> transformer_input_;
  // Logits output from the model.
  std::shared_ptr<Tensor> logits_output_;
  std::shared_ptr<Context> context_;

  // Holds a shared_ptr to the LlmBuilder for initializing the input resources
  // as well as performing necessary wiring customizations at decoding time.
  std::shared_ptr<LlmBuilder> builder_;
};

// Responsible for creating the high-level components required by large
// language models. The high-level components are:
//   1) PreProcess: embedding lookup, attention mask, and positional embedding
//      preparation, etc.
//   2) SelfAttentionIncludeResidual: the self-attention module along with
//      residual connections and some normalizations.
//   3) FeedForwardIncludeResidual: the feedforward layers that follow the
//      attention outputs, including residual connections and normalizations.
//   4) PostProcess: the final projection layer after the stacked transformers.
// The LlmBuilder allows developers to override the logic of those components
// whenever needed (i.e. when the existing Llm/LlmBuilder configuration and
// settings don't capture the required changes).
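//
// A sketch of such a customization (the subclass name is illustrative, not
// part of this API):
//
//   class MyLlmBuilder : public LlmBuilder {
//    public:
//     using LlmBuilder::LlmBuilder;
//     // Swap in a custom feedforward block; the rest of the wiring stays.
//     absl::StatusOr<std::shared_ptr<Tensor>> FeedForwardIncludeResidual(
//         std::shared_ptr<Tensor> input,
//         const LlmWeights::FeedForwardWeights& ff_weights) override;
//   };
//
// The customized builder can then be passed to Llm::CreateLlm(weight_loader,
// builder).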
class LlmBuilder : protected XnnGraphBuilder {
 public:
  // The following struct defines the "resources" required by each high-level
  // module. For clarification, even though most of the inputs/outputs of
  // those high-level modules are all "xnn_utils::Tensor", they fall into the
  // following categories:
  //   1) Weight: refers to the model weights, which are static during
  //      initialization and at runtime. For example:
  //      LlmWeights::FeedForwardWeights.
  //   2) Resource: tensors that host values which can be "precomputed" and
  //      remain reusable/fixed during inference (i.e. independent of the
  //      input values). For example: pos_embedding, atten_mask.
  //   3) Tensor: data values that depend on the input data at runtime. For
  //      example: the return value of PreProcess.
  struct InputResource {
    std::shared_ptr<Tensor> pos_embedding;
    std::shared_ptr<Tensor> atten_mask;
    std::shared_ptr<Tensor> segment_pos;

    // The type of this field will be updated in the future. Please contact
    // odml-llm-support if you'd like to use this field.
    Llm::KVCache* cache = nullptr;
  };

  explicit LlmBuilder(LlmParams llm_params,
                      std::unique_ptr<RuntimeConfigs> runtime_configs = nullptr,
                      xnn_datatype datatype = xnn_datatype_fp32)
      : LlmBuilder(llm_params, nullptr, std::move(runtime_configs), datatype) {}
  LlmBuilder(LlmParams llm_params, std::unique_ptr<Sampler> sampler,
             std::unique_ptr<RuntimeConfigs> runtime_configs = nullptr,
             xnn_datatype datatype = xnn_datatype_fp32)
      : XnnGraphBuilder(std::move(runtime_configs), datatype),
        llm_params_(llm_params),
        sampler_(std::move(sampler)) {}
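
  // For example (a sketch; `weight_loader` is an illustrative
  // LlmWeightsLoader instance):
  //
  //   auto builder = std::make_unique<LlmBuilder>(llm_params);
  //   MP_ASSIGN_OR_RETURN(auto llm, Llm::CreateLlm(std::move(weight_loader),
  //                                                std::move(builder)));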

  using XnnGraphBuilder::Build;
  using XnnGraphBuilder::NewInput;

  // Applies pre-processing to the input before feeding it to the stacked
  // transformers, and prepares the InputResource that will be used by other
  // modules, e.g. positional embedding.
  // `token_embedding` represents the token embedding ([batch_B, S,
  // model_dim_D], where S varies from 1 to seq_size_T).
  // `is_prefix` indicates whether this function is called by the prefix graph,
  // as some resource preparation might differ between prefix and decode.
  virtual absl::StatusOr<std::pair<std::shared_ptr<Tensor>, InputResource>>
  PreProcess(std::shared_ptr<Tensor> token_embedding, bool is_prefix);

  // One transformer block consisting of self-attention and feedforward
  // modules. The default version builds a sequential SA and FF block. This
  // can be overridden for fine-grained control over each OneStackTransformer.
  virtual absl::StatusOr<std::shared_ptr<Tensor>> OneStackTransformer(
      int layer_index, std::shared_ptr<Tensor> input, InputResource resource,
      const LlmWeights::SelfAttentionWeights& sa_weights,
      const LlmWeights::FeedForwardWeights& ff_weights, bool is_prefix);

  // Building blocks used within `DecoderLayer`.
  virtual absl::StatusOr<std::shared_ptr<Tensor>> SelfAttentionIncludeResidual(
      std::shared_ptr<Tensor> input, InputResource resource,
      const LlmWeights::SelfAttentionWeights& sa_weights);
  virtual absl::StatusOr<std::shared_ptr<Tensor>> FeedForwardIncludeResidual(
      std::shared_ptr<Tensor> input,
      const LlmWeights::FeedForwardWeights& ff_weights);

  // Apply post-processing to the output of stacked transformers, e.g. final
  // norm, final projection, etc.
  virtual absl::StatusOr<std::shared_ptr<Tensor>> PostProcess(
      std::shared_ptr<Tensor> transformer_out, const LlmWeights& weights);

  // The following functions are related to the InputResource preparation and
  // handling.

  // Sets the value of `out_attn_mask` given that `current_seq_len` tokens
  // have been processed and `process_seq_len` tokens are about to be
  // processed.
  virtual absl::Status InitAttentionMask(size_t current_seq_len,
                                         size_t process_seq_len,
                                         Tensor& out_attn_mask);

  // Initializes the `out_pos_embedding` values given that `current_seq_len`
  // tokens have been processed and `process_seq_len` tokens are about to be
  // processed.
  virtual absl::Status InitPosEmbedding(size_t current_seq_len,
                                        size_t process_seq_len,
                                        Tensor& out_pos_embedding);

  // Initializes the `out_segment_pos` values given that `current_seq_len`
  // tokens have been processed and `process_seq_len` tokens are about to be
  // processed. E.g. in decoding mode, if 17 tokens have been processed, this
  // function is called with `current_seq_len` = 17 and `process_seq_len` = 1
  // (decoding one token). `out_segment_pos` will be reshaped to
  // [process_seq_len, rope_size].
  virtual absl::Status InitSegmentPos(size_t current_seq_len,
                                      size_t process_seq_len,
                                      Tensor& out_segment_pos);

  // Runs sampling on the model's output logits.
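  //
  // For example (assuming a sampler was provided at construction and `logits`
  // came from Llm::ComputeLogits()):
  //
  //   MP_ASSIGN_OR_RETURN(auto batch_ids, builder->Sample(*logits));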
  absl::StatusOr<std::vector<std::vector<int>>> Sample(const Tensor& logits);

 protected:
  friend class Llm;
  friend class LlmBuilderTest;
  friend absl::StatusOr<std::unique_ptr<Llm>> Llm::CreatePrefixDecodeLlm(
      LlmWeights, std::shared_ptr<LlmBuilder>);

  absl::Status InitAttentionMaskValues(size_t process_seq_len);
  absl::Status InitPosEmbeddingValues(size_t process_seq_len);
  absl::Status InitSegmentPosValues(size_t rope_size);

  absl::StatusOr<std::shared_ptr<Tensor>> DotAttention(
      std::shared_ptr<Tensor> query_proj, std::shared_ptr<Tensor> key_proj,
      std::shared_ptr<Tensor> value_proj, std::shared_ptr<Tensor> atten_mask,
      const LlmWeights::SelfAttentionWeights& sa_weights);

  // Applies normalization according to `norm_type`; generally the output
  // tensor should have the same shape as `input`.
  absl::StatusOr<std::shared_ptr<Tensor>> ApplyNorm(
      std::shared_ptr<Tensor> input,
      std::optional<LlmWeights::NormWeights> weights,
      LlmParams::Norm norm_type);

  virtual absl::StatusOr<std::shared_ptr<Tensor>> SelfAttentionExcludeNorm(
      std::shared_ptr<Tensor> input, InputResource resource,
      const LlmWeights::SelfAttentionWeights& sa_weights);

  virtual absl::StatusOr<std::shared_ptr<Tensor>> FeedForwardExcludeNorm(
      std::shared_ptr<Tensor> input,
      const LlmWeights::FeedForwardWeights& ff_weights);

  absl::Status BuildKVCache(std::shared_ptr<Tensor>& key,
                            std::shared_ptr<Tensor>& value,
                            InputResource& resource);

  LlmParams llm_params_;
  Llm::InternalLlmParams internal_llm_params_;

  // Stores the attention mask values, with shape [max_seq_len, max_seq_len].
  MdSpan<float, 2> attention_mask_values_;
  // Stores the positional embedding values, with shape [max_seq_len,
  // model_dimension].
  std::shared_ptr<std::vector<float>> position_embedding_values_;
  // Stores the segment position values, with shape [max_seq_len,
  // head_dimension].
  MdSpan<float, 2> segment_pos_values_;

  std::unique_ptr<Sampler> sampler_;
};

}  // namespace xnn_utils
}  // namespace mediapipe::tasks::genai

#endif  // MEDIAPIPE_TASKS_GENAI_INFERENCE_UTILS_XNN_UTILS_LLM_H_