chromium/third_party/mediapipe/src/mediapipe/tasks/cc/genai/inference/utils/xnn_utils/llm.cc

// Copyright 2024 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mediapipe/tasks/cc/genai/inference/utils/xnn_utils/llm.h"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstring>
#include <limits>
#include <memory>
#include <optional>
#include <utility>
#include <vector>

#include "absl/base/nullability.h"
#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/genai/inference/common/mdspan.h"
#include "mediapipe/tasks/cc/genai/inference/utils/xnn_utils/graph_builder.h"
#include "mediapipe/tasks/cc/genai/inference/utils/xnn_utils/llm_weights.h"
#include "mediapipe/tasks/cc/genai/inference/utils/xnn_utils/sampling.h"
#include "mediapipe/tasks/cc/genai/inference/utils/xnn_utils/utils.h"
#include "mediapipe/tasks/cc/genai/inference/utils/xnn_utils/xnn_tensor.h"
#include "xnnpack.h"  // from @XNNPACK

namespace mediapipe::tasks::genai {
namespace xnn_utils {
namespace {

using FeedForwardWeights = LlmWeights::FeedForwardWeights;
using SelfAttentionWeights = LlmWeights::SelfAttentionWeights;

}  // namespace

absl::StatusOr<std::unique_ptr<Llm>> Llm::CreateLlm(
    absl::string_view weights_folder, const LlmParams& llm_params,
    std::unique_ptr<xnn_utils::RuntimeConfigs> runtime_configs) {
  auto weight_loader =
      std::make_unique<DefaultLlmWeightsLoader>(weights_folder, llm_params);
  return CreateLlm(std::move(weight_loader), std::move(runtime_configs));
}

absl::StatusOr<std::unique_ptr<Llm>> Llm::CreateLlm(
    std::unique_ptr<LlmWeightsLoader> weight_loader,
    std::unique_ptr<xnn_utils::RuntimeConfigs> runtime_configs) {
  const auto& llm_params = weight_loader->llm_params();
  return CreateLlm(
      std::move(weight_loader),
      std::make_unique<LlmBuilder>(llm_params, std::move(runtime_configs)));
}

absl::StatusOr<std::unique_ptr<Llm>> Llm::CreateLlm(
    std::unique_ptr<LlmWeightsLoader> weight_loader,
    std::unique_ptr<LlmBuilder> builder) {
  const auto& llm_params = weight_loader->llm_params();
  RET_CHECK_EQ(llm_params.enable_kv_cache, llm_params.enable_dynamic_shape)
          .SetCode(absl::StatusCode::kInvalidArgument)
      << "Dynamic shape should be enabled together with KV cache.";
  MP_ASSIGN_OR_RETURN(auto weights, weight_loader->LoadWeights());
  return CreatePrefixDecodeLlm(std::move(weights), std::move(builder));
}

absl::StatusOr<std::unique_ptr<Llm>> Llm::CreatePrefixDecodeLlm(
    LlmWeights weights, std::shared_ptr<LlmBuilder> builder) {
  RET_CHECK(builder);
  const LlmParams& llm_params = builder->llm_params_;
  RET_CHECK_NE(llm_params.batch_size_B, 0);

  MP_ASSIGN_OR_RETURN(auto input, builder->NewInput({llm_params.batch_size_B,
                                                     llm_params.seq_size_T,
                                                     llm_params.model_dim_D},
                                                    "prefix_input"));

  MP_ASSIGN_OR_RETURN(auto preprocess_out,
                      builder->PreProcess(input, /*is_prefix=*/true));

  auto& inter_layer = preprocess_out.first;
  auto& resource = preprocess_out.second;

  std::vector<KVCache> kv_cache;
  std::shared_ptr<Tensor> logits_output;

  for (int i = 0; i < llm_params.num_transformer_M; ++i) {
    KVCache* cache = nullptr;
    if (llm_params.enable_kv_cache) {
      kv_cache.push_back(KVCache{});
      cache = &kv_cache.back();
    }
    resource.cache = cache;
    const auto& sa = weights.sas[i];
    const auto& ff = weights.ffs[i];
    MP_ASSIGN_OR_RETURN(inter_layer, builder->OneStackTransformer(
                                         i, inter_layer, resource, sa, ff,
                                         /*is_prefix=*/true));
  }

  if (builder->internal_llm_params_.stop_at_last_kv_cache) {
    logits_output = inter_layer;
  } else {
    MP_ASSIGN_OR_RETURN(logits_output,
                        builder->PostProcess(inter_layer, weights));
  }
  logits_output->MarkOutput();

  MP_ASSIGN_OR_RETURN(auto graph, builder->Build());
  auto llm = std::make_unique<Llm>(std::move(*graph));
  llm->transformer_input_ = input;
  llm->logits_output_ = logits_output;
  llm->context_ = std::make_shared<Context>(Context{
      .kv_cache = std::move(kv_cache),
  });
  llm->batch_prev_ids().resize(llm_params.batch_size_B);

  llm->pos_embedding_ = resource.pos_embedding;
  llm->segment_pos_ = resource.segment_pos;
  llm->atten_masks_ = resource.atten_mask;

  llm->weights_ = std::move(weights);
  llm->llm_params_ = llm_params;
  llm->builder_ = builder;

  return llm;
}

size_t Llm::TotalTokenSize() const {
  ABSL_CHECK(!batch_prev_ids().empty());
  // batch_prev_ids() has length llm_params.batch_size_B, and we assume all
  // batches decode simultaneously, so every prev_ids[i] has the same size,
  // which is the total token size.
  return batch_prev_ids()[0].size();
}

absl::Status Llm::ReshapeInputResource() {
  if (llm_params_.enable_dynamic_shape) {
    RET_CHECK_EQ(
        xnn_status_success,
        xnn_reshape_external_value(
            runtime_.get(), atten_masks_->tensor_id(owned_subgraph_.get()),
            atten_masks_->dims.size(), atten_masks_->dims.data()));
    if (!llm_params_.skip_absolute_positional_embeddings) {
      RET_CHECK_EQ(
          xnn_status_success,
          xnn_reshape_external_value(
              runtime_.get(), pos_embedding_->tensor_id(owned_subgraph_.get()),
              pos_embedding_->dims.size(), pos_embedding_->dims.data()));
    }
    if (segment_pos_) {
      RET_CHECK_EQ(
          xnn_status_success,
          xnn_reshape_external_value(
              runtime_.get(), segment_pos_->tensor_id(owned_subgraph_.get()),
              segment_pos_->dims.size(), segment_pos_->dims.data()));
    }
  }
  return absl::OkStatus();
}

std::shared_ptr<Tensor>& Llm::transformer_input() { return transformer_input_; }

const std::shared_ptr<Tensor>& Llm::transformer_input() const {
  return transformer_input_;
}

std::shared_ptr<Tensor>& Llm::logits_output() { return logits_output_; }

const std::shared_ptr<Tensor>& Llm::logits_output() const {
  return logits_output_;
}

std::vector<std::vector<int>>& Llm::batch_prev_ids() {
  ABSL_DCHECK(context_);
  return context_->batch_prev_ids;
}

const std::vector<std::vector<int>>& Llm::batch_prev_ids() const {
  ABSL_DCHECK(context_);
  return context_->batch_prev_ids;
}

std::vector<Llm::KVCache>& Llm::kv_cache() {
  ABSL_DCHECK(context_);
  return context_->kv_cache;
}

const std::vector<Llm::KVCache>& Llm::kv_cache() const {
  ABSL_DCHECK(context_);
  return context_->kv_cache;
}

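// Builds a fresh Context: one empty prev_ids vector per batch and, when the KV
// cache is enabled, newly allocated cache tensors matching the shapes and
// datatypes of the current cache, with each k/v slice initially borrowing the
// first slice of its cache.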
absl::StatusOr<Llm::Context> Llm::NewContext() const {
  RET_CHECK(runtime_configs_);
  return Llm::Context{
      .batch_prev_ids = std::vector<std::vector<int>>(batch_prev_ids().size()),
      .kv_cache =
          [this]() {
            std::vector<KVCache> kvs;
            if (!llm_params_.enable_kv_cache) return kvs;
            kvs.resize(kv_cache().size());
            for (size_t i = 0; i < kvs.size(); ++i) {
              auto& kv = kvs[i];
              const auto& current_kv = kv_cache()[i];
              kv.k_cache = std::make_shared<Tensor>(
                  current_kv.k_cache->dims, current_kv.k_cache->datatype);
              kv.k_cache->LoadFromVec({}).IgnoreError();
              kv.v_cache = std::make_shared<Tensor>(
                  current_kv.v_cache->dims, current_kv.v_cache->datatype);
              kv.v_cache->LoadFromVec({}).IgnoreError();
              kv.k_slice = std::make_shared<Tensor>(
                  current_kv.k_slice->dims, current_kv.k_slice->datatype);
              kv.k_slice->Borrow(kv.k_cache->Slice(0, 0));
              kv.v_slice = std::make_shared<Tensor>(
                  current_kv.v_slice->dims, current_kv.v_slice->datatype);
              kv.v_slice->Borrow(kv.v_cache->Slice(0, 0));
            }
            return kvs;
          }(),
  };
}

absl::Status Llm::LoadContext(
    absl::Nullable<std::shared_ptr<Context>> context) {
  if (!context || (context_ == context)) return absl::OkStatus();
  // There is some metadata we'd like to keep with the existing context, and we
  // use the pointer address to distinguish contexts. So the logic is:
  // 1) let the existing context's tensors borrow the buffers from the new
  // context; 2) move those tensors into the new context; 3) store the new
  // context.
  {
    for (size_t i = 0; i < kv_cache().size(); ++i) {
      kv_cache()[i].k_cache->Borrow(context->kv_cache[i].k_cache);
      kv_cache()[i].v_cache->Borrow(context->kv_cache[i].v_cache);
      kv_cache()[i].k_slice->Borrow(context->kv_cache[i].k_slice);
      kv_cache()[i].v_slice->Borrow(context->kv_cache[i].v_slice);
    }
    context->kv_cache = std::move(kv_cache());
  }
  context_ = std::move(context);
  return absl::OkStatus();
}

absl::Status Llm::ReduceContextPrevIds(std::shared_ptr<Context> context,
                                       std::vector<int> batch_num_tokens) {
  ABSL_CHECK_EQ(batch_num_tokens.size(), context->batch_prev_ids.size());
  for (size_t batch = 0; batch < context->batch_prev_ids.size(); ++batch) {
    auto& prev_ids = context->batch_prev_ids[batch];
    const auto& num_tokens = batch_num_tokens[batch];
    if (num_tokens == 0) continue;
    prev_ids.erase(prev_ids.end() - num_tokens, prev_ids.end());
  }
  return absl::OkStatus();
}

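// Feeds one batch of new token ids through the transformer graph: rebuilds the
// attention mask and (when used) positional and segment-position inputs for
// the window starting at the current token count, reshapes the dynamic tensors
// and KV caches if dynamic shapes are enabled, writes the token embeddings
// into transformer_input(), appends the ids to batch_prev_ids(), and runs the
// graph.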
absl::Status Llm::AddInputTokens(
    absl::Span<const std::vector<int>> batch_input_ids) {
  RET_CHECK_EQ(batch_input_ids.size(), batch_prev_ids().size());
  const size_t input_seq_len = batch_input_ids.at(0).size();
  if (input_seq_len == 0) {
    // In one of the CLs related to the bug below, we added an empty prompt in
    // LlmEngine::AddQueryChunk() to flush previous prompts.
    // TODO: b/343765969: Remove the empty prompt.
    return absl::OkStatus();
  }
  for (auto it = batch_input_ids.begin() + 1; it != batch_input_ids.end();
       ++it) {
    RET_CHECK_EQ(it->size(), input_seq_len);
  }

  RET_CHECK(!batch_prev_ids().empty());
  const size_t current_seq_len = TotalTokenSize();

  // Let builder re-populate the values of these tensors.
  MP_RETURN_IF_ERROR(builder_->InitAttentionMask(current_seq_len, input_seq_len,
                                                 *atten_masks_));
  if (!llm_params_.skip_absolute_positional_embeddings) {
    // Initialize the positional embedding data.
    MP_RETURN_IF_ERROR(builder_->InitPosEmbedding(
        current_seq_len, input_seq_len, *pos_embedding_));
  }
  if (segment_pos_) {
    // Initialize the segment pos.
    MP_RETURN_IF_ERROR(builder_->InitSegmentPos(current_seq_len, input_seq_len,
                                                *segment_pos_));
  }

  if (llm_params_.enable_dynamic_shape) {
    MP_RETURN_IF_ERROR(ReshapeInputResource());

    transformer_input()->Resize(Tensor::DimsType{
        batch_input_ids.size(), input_seq_len, llm_params_.model_dim_D});
    RET_CHECK_EQ(xnn_status_success,
                 xnn_reshape_external_value(
                     runtime_.get(),
                     transformer_input()->tensor_id(owned_subgraph_.get()),
                     transformer_input()->dims.size(),
                     transformer_input()->dims.data()));
    logits_output()->Resize(Tensor::DimsType{
        batch_input_ids.size(), input_seq_len, llm_params_.voc_size_V});
    RET_CHECK_EQ(
        xnn_status_success,
        xnn_reshape_external_value(
            runtime_.get(), logits_output()->tensor_id(owned_subgraph_.get()),
            logits_output()->dims.size(), logits_output()->dims.data()));
    for (auto& kv_cache : kv_cache()) {
      auto key = kv_cache.k_cache;
      auto value = kv_cache.v_cache;
      key->Resize({current_seq_len + input_seq_len, llm_params_.batch_size_B,
                   llm_params_.num_kv_heads, llm_params_.head_dim_H});
      value->Resize({current_seq_len + input_seq_len, llm_params_.batch_size_B,
                     llm_params_.num_kv_heads, llm_params_.head_dim_H});
      RET_CHECK_EQ(xnn_status_success,
                   xnn_reshape_external_value(
                       runtime_.get(), key->tensor_id(owned_subgraph_.get()),
                       key->dims.size(), key->dims.data()));
      RET_CHECK_EQ(xnn_status_success,
                   xnn_reshape_external_value(
                       runtime_.get(), value->tensor_id(owned_subgraph_.get()),
                       value->dims.size(), value->dims.data()));
    }
    RET_CHECK_EQ(xnn_status_success, xnn_reshape_runtime(runtime_.get()));
  }

  for (auto& kv_cache : kv_cache()) {
    ABSL_DCHECK(kv_cache.k_slice);
    ABSL_DCHECK(kv_cache.v_slice);
    kv_cache.k_slice->Borrow(kv_cache.k_cache->Slice(
        0, /*start=*/current_seq_len, /*end=*/current_seq_len + input_seq_len));
    kv_cache.v_slice->Borrow(kv_cache.v_cache->Slice(
        0, /*start=*/current_seq_len, /*end=*/current_seq_len + input_seq_len));
  }

  for (size_t batch = 0; batch < llm_params_.batch_size_B; ++batch) {
    auto slice = transformer_input()->Slice(0, batch);
    MP_RETURN_IF_ERROR(
        GetTokenEmbedding(batch_input_ids[batch], slice->DataAs<float>()));
  }

  for (size_t batch = 0; batch < llm_params_.batch_size_B; ++batch) {
    auto& prev_ids = batch_prev_ids()[batch];
    const auto& input_ids = batch_input_ids[batch];
    prev_ids.insert(prev_ids.end(), input_ids.begin(), input_ids.end());
  }
  MP_RETURN_IF_ERROR(SetupRuntime());
  return Run();
}

absl::Status Llm::SeekTimeStep(size_t time_step) {
  for (auto& prev_ids : batch_prev_ids()) {
    prev_ids.resize(time_step);
  }
  return absl::OkStatus();
}

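// Samples the next token for every batch from the current logits, returns them
// in `output_ids`, and immediately feeds them back through AddInputTokens() as
// the next decode step.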
absl::Status Llm::GetNextToken(std::vector<int>* output_ids) {
  MP_ASSIGN_OR_RETURN(auto logits, ComputeLogits());

  MP_ASSIGN_OR_RETURN(std::vector<std::vector<int>> tokens,
                      builder_->Sample(*logits));
  // Return only the first token for each draft.
  std::vector<int> output;
  output.reserve(tokens.size());
  for (int i = 0; i < tokens.size(); ++i) {
    output.push_back(tokens[i][0]);
  }

  *output_ids = output;
  RET_CHECK_EQ(output_ids->size(), llm_params_.batch_size_B);

  std::vector<std::vector<int>> next_token_ids(output_ids->size());
  for (size_t batch = 0; batch < llm_params_.batch_size_B; ++batch) {
    next_token_ids[batch].push_back(output_ids->at(batch));
  }

  return AddInputTokens(next_token_ids);
}

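// Returns the logits covering the last `expected_seq_len` positions. When the
// graph's logits output is longer than that, the relevant positions are
// extracted: a view slice for batch size 1, or a per-batch copy into a freshly
// allocated tensor otherwise.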
absl::StatusOr<std::shared_ptr<Tensor>> Llm::ComputeLogits(
    size_t expected_seq_len) {
  const size_t decode_step = TotalTokenSize();
  VLOG(2) << "Decode step " << decode_step;

  if (decode_step + llm_params_.draft_size_G >= llm_params_.seq_size_T) {
    return absl::OutOfRangeError(
        absl::StrCat("Hit max sequence length ", llm_params_.seq_size_T));
  }

  RET_CHECK(logits_output());
  const size_t logits_total_seq_len = logits_output()->dims[1];
  RET_CHECK_GE(logits_total_seq_len, expected_seq_len);
  if (logits_total_seq_len == expected_seq_len) {
    return logits_output();
  } else {
    if (logits_output()->dims[0] == 1) {
      return logits_output()->Slice(
          /*index=*/1, /*start=*/logits_total_seq_len - expected_seq_len,
          /*end=*/logits_total_seq_len);
    } else {
      Tensor::DimsType new_dims = logits_output()->dims;
      new_dims[1] = 1;
      std::shared_ptr<Tensor> last_slice(new Tensor(new_dims));
      MP_RETURN_IF_ERROR(last_slice->LoadFromVec({}));
      for (int batch = 0; batch < logits_output()->dims[0]; ++batch) {
        MP_RETURN_IF_ERROR(last_slice->Slice(0, batch)->LoadFromBuffer(
            logits_output()
                ->Slice(0, batch)
                ->Slice(1, logits_total_seq_len - expected_seq_len)
                ->Data()));
      }
      return last_slice;
    }
  }
}

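// Looks up the embedding row for each id in `ids` and writes model_dim_D
// floats per token into `embedding`, falling back to the softmax_linear
// weights as the embedding table when no separate token_embedding is present.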
absl::Status Llm::GetTokenEmbedding(const std::vector<int>& ids,
                                    float* embedding) {
  RET_CHECK_LE(ids.size(), llm_params_.seq_size_T);
  auto token_embedding = weights_.token_embedding ? weights_.token_embedding
                                                  : weights_.softmax_linear;
  RET_CHECK(token_embedding);
  RET_CHECK(token_embedding->dims[0] == llm_params_.voc_size_V)
      << "shape must be [vocab_size, _], so that the following Slice() makes "
         "sense.";
  for (int id : ids) {
    MP_ASSIGN_OR_RETURN(auto embedding_slice,
                        token_embedding->Slice(0, id)->ConvertToF32());
    memcpy(embedding, embedding_slice->Data(),
           llm_params_.model_dim_D * sizeof(float));
    embedding += llm_params_.model_dim_D;
  }
  return absl::OkStatus();
}

absl::StatusOr<std::pair<std::shared_ptr<Tensor>, LlmBuilder::InputResource>>
LlmBuilder::PreProcess(std::shared_ptr<Tensor> token_embedding,
                       bool is_prefix) {
  InputResource resource;
  constexpr absl::string_view kAttnMaskSource = "atten_mask";
  constexpr absl::string_view kPosEmbeddingSource = "pos_embedding";
  constexpr absl::string_view kSegmentPosSource = "segment_pos";
  if (is_prefix) {
    MP_ASSIGN_OR_RETURN(resource.atten_mask, NewInput({llm_params_.seq_size_T,
                                                       llm_params_.seq_size_T},
                                                      kAttnMaskSource));
    MP_ASSIGN_OR_RETURN(resource.segment_pos, NewInput({llm_params_.seq_size_T,
                                                        llm_params_.head_dim_H},
                                                       kSegmentPosSource));
    MP_RETURN_IF_ERROR(
        InitSegmentPos(0, llm_params_.seq_size_T, *resource.segment_pos));
    MP_ASSIGN_OR_RETURN(
        resource.pos_embedding,
        NewInput({llm_params_.seq_size_T, llm_params_.model_dim_D},
                 kPosEmbeddingSource));
  } else {
    MP_ASSIGN_OR_RETURN(
        resource.pos_embedding,
        NewInput({llm_params_.draft_size_G + 1, llm_params_.model_dim_D},
                 kPosEmbeddingSource));
    MP_ASSIGN_OR_RETURN(
        resource.atten_mask,
        NewInput({llm_params_.draft_size_G + 1, llm_params_.seq_size_T},
                 kAttnMaskSource));
    MP_ASSIGN_OR_RETURN(
        resource.segment_pos,
        NewInput({llm_params_.draft_size_G + 1, llm_params_.head_dim_H},
                 kSegmentPosSource));
    MP_RETURN_IF_ERROR(
        InitSegmentPos(0, llm_params_.draft_size_G + 1, *resource.segment_pos));
  }
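  // Scale the token embeddings by sqrt(model_dim_D), the standard transformer
  // embedding scaling.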
  const float dim_scale = std::sqrt(llm_params_.model_dim_D);
  MP_ASSIGN_OR_RETURN(auto scaled_embedding,
                      ElementMul(token_embedding, dim_scale));
  return std::make_pair(scaled_embedding, resource);
}

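// One transformer block: the self-attention sublayer with residual connection,
// followed by the feed-forward sublayer with residual connection. For prefix
// graphs built with stop_at_last_kv_cache, the last layer returns right after
// attention and skips its feed-forward sublayer.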
absl::StatusOr<std::shared_ptr<Tensor>> LlmBuilder::OneStackTransformer(
    int layer_index, std::shared_ptr<Tensor> input,
    LlmBuilder::InputResource resource,
    const LlmWeights::SelfAttentionWeights& sa_weights,
    const LlmWeights::FeedForwardWeights& ff_weights, bool is_prefix) {
  std::shared_ptr<Tensor> output;
  if (is_prefix) {
    MP_ASSIGN_OR_RETURN(
        output, SelfAttentionIncludeResidual(input, resource, sa_weights));
    if (internal_llm_params_.stop_at_last_kv_cache &&
        (layer_index == llm_params_.num_transformer_M - 1)) {
      return output;
    }
    MP_ASSIGN_OR_RETURN(output, FeedForwardIncludeResidual(output, ff_weights));
  } else {
    MP_ASSIGN_OR_RETURN(
        output, SelfAttentionIncludeResidual(input, resource, sa_weights));
    MP_ASSIGN_OR_RETURN(output, FeedForwardIncludeResidual(output, ff_weights));
  }
  return output;
}

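// Self-attention without the surrounding norms: projects the input to Q/K/V,
// applies RoPE to the query and key projections, writes the new K/V into the
// KV cache when one is attached to `resource` (attention then reads from the
// full cache), runs dot-product attention with the attention mask, and
// projects the merged heads back to the model dimension.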
absl::StatusOr<std::shared_ptr<Tensor>> LlmBuilder::SelfAttentionExcludeNorm(
    std::shared_ptr<Tensor> input, InputResource resource,
    const SelfAttentionWeights& sa_weights) {
  // [B, 1|T, N, H]
  MP_ASSIGN_OR_RETURN(auto k_proj,
                      SelfAttentionProj(input, sa_weights.k_weight));
  MP_ASSIGN_OR_RETURN(auto q_proj,
                      SelfAttentionProj(input, sa_weights.q_weight));
  MP_ASSIGN_OR_RETURN(auto v_proj,
                      SelfAttentionProj(input, sa_weights.v_weight));

  MP_ASSIGN_OR_RETURN(auto query_proj_after_rope,
                      Rope(q_proj, resource.segment_pos));
  MP_ASSIGN_OR_RETURN(auto key_proj_after_rope,
                      Rope(k_proj, resource.segment_pos));

  MP_RETURN_IF_ERROR(BuildKVCache(key_proj_after_rope, v_proj, resource));

  // encoded, [B, 1|T, N, H]
  MP_ASSIGN_OR_RETURN(auto kqv_merged,
                      DotAttention(query_proj_after_rope, key_proj_after_rope,
                                   v_proj, resource.atten_mask, sa_weights));

  const size_t B = kqv_merged->dims[0];
  const size_t NH = kqv_merged->dims[2] * kqv_merged->dims[3];
  MP_ASSIGN_OR_RETURN(auto outcome_reshaped, Reshape(kqv_merged, {B, 0, NH}));

  return MatMul(outcome_reshaped, sa_weights.post_proj_weight);
}

absl::StatusOr<std::shared_ptr<Tensor>>
LlmBuilder::SelfAttentionIncludeResidual(
    std::shared_ptr<Tensor> input, InputResource resource,
    const SelfAttentionWeights& sa_weights) {
  MP_ASSIGN_OR_RETURN(auto pre_attention,
                      ApplyNorm(input, sa_weights.pre_norm_weight,
                                llm_params_.sa_params.pre_norm));

  MP_ASSIGN_OR_RETURN(
      auto post_attention,
      SelfAttentionExcludeNorm(pre_attention, std::move(resource), sa_weights));

  MP_ASSIGN_OR_RETURN(auto post_norm,
                      ApplyNorm(post_attention, sa_weights.post_norm_weight,
                                llm_params_.sa_params.post_norm));

  return ElementAdd(input, post_norm);
}

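// Gated feed-forward sublayer without the surrounding norms, computing
// layer_2(activation(layer_1_gate(x)) * layer_1(x)).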
absl::StatusOr<std::shared_ptr<Tensor>> LlmBuilder::FeedForwardExcludeNorm(
    std::shared_ptr<Tensor> input, const FeedForwardWeights& ff_weights) {
  MP_ASSIGN_OR_RETURN(auto layer_1, FullConn(input, ff_weights.layer_1_weight,
                                             ff_weights.layer_1_bias));

  MP_ASSIGN_OR_RETURN(auto layer_1_gate_before_activation,
                      FullConn(input, ff_weights.layer_1_gate_weight,
                               ff_weights.layer_1_gate_bias));
  std::shared_ptr<Tensor> layer_1_gate;
  switch (llm_params_.ff_params.activation) {
    case LlmParams::Activation::UNSPECIFIED:
      layer_1_gate = layer_1_gate_before_activation;
      break;
    case LlmParams::Activation::GELU: {
      MP_ASSIGN_OR_RETURN(layer_1_gate, Gelu(layer_1_gate_before_activation));
      break;
    }
    case LlmParams::Activation::SILU: {
      MP_ASSIGN_OR_RETURN(layer_1_gate, Silu(layer_1_gate_before_activation));
      break;
    }
    case LlmParams::Activation::RELU: {
      MP_ASSIGN_OR_RETURN(layer_1_gate, Relu(layer_1_gate_before_activation));
      break;
    }
    default: {
      break;
    }
  }

  MP_ASSIGN_OR_RETURN(auto layer_1_and_gate, ElementMul(layer_1, layer_1_gate));
  MP_ASSIGN_OR_RETURN(auto layer_2,
                      FullConn(layer_1_and_gate, ff_weights.layer_2_weight,
                               ff_weights.layer_2_bias));

  return layer_2;
}

absl::StatusOr<std::shared_ptr<Tensor>> LlmBuilder::FeedForwardIncludeResidual(
    std::shared_ptr<Tensor> input, const FeedForwardWeights& ff_weights) {
  MP_ASSIGN_OR_RETURN(auto pre_ff, ApplyNorm(input, ff_weights.pre_norm_weight,
                                             llm_params_.ff_params.pre_norm));

  MP_ASSIGN_OR_RETURN(auto pre_norm,
                      FeedForwardExcludeNorm(pre_ff, ff_weights));

  MP_ASSIGN_OR_RETURN(auto post_norm,
                      ApplyNorm(pre_norm, ff_weights.post_norm_weight,
                                llm_params_.ff_params.post_norm));

  return ElementAdd(post_norm, input);
}

absl::StatusOr<std::shared_ptr<Tensor>> LlmBuilder::PostProcess(
    std::shared_ptr<Tensor> transformer_out, const LlmWeights& weights) {
  MP_ASSIGN_OR_RETURN(transformer_out,
                      ApplyNorm(transformer_out, weights.final_norm_weight,
                                llm_params_.final_norm));
  RET_CHECK(weights.softmax_linear);
  MP_ASSIGN_OR_RETURN(
      auto logits_output,
      FullConn(transformer_out, weights.softmax_linear, weights.softmax_bias));
  return logits_output;
}

absl::Status LlmBuilder::InitAttentionMask(size_t current_seq_len,
                                           size_t process_seq_len,
                                           Tensor& out_attn_mask) {
  if (!attention_mask_values_.data()) {
    MP_RETURN_IF_ERROR(InitAttentionMaskValues(process_seq_len));
  }

  if (llm_params_.enable_dynamic_shape) {
    out_attn_mask.Resize(
        Tensor::DimsType{process_seq_len, current_seq_len + process_seq_len});
    for (size_t r = 0; r < out_attn_mask.dims[0]; ++r) {
      auto slice = out_attn_mask.Slice(0, r);
      MP_RETURN_IF_ERROR(slice->LoadFromBuffer(
          attention_mask_values_[r + current_seq_len].data()));
    }
  } else {
    RET_CHECK_EQ(out_attn_mask.num_elements,
                 llm_params_.seq_size_T * llm_params_.seq_size_T);
    MP_RETURN_IF_ERROR(
        out_attn_mask.LoadFromBuffer(attention_mask_values_.data()));
  }

  return absl::OkStatus();
}

absl::Status LlmBuilder::InitAttentionMaskValues(size_t process_seq_len) {
  const size_t seq_size = llm_params_.seq_size_T;
  constexpr float neg_value = 0.8 * std::numeric_limits<float>::lowest();
  {
    std::vector<float> values(seq_size * seq_size, neg_value);
    float* values_ptr = values.data();
    attention_mask_values_ = MakeMdSpan(values_ptr, seq_size, seq_size,
                                        [values = std::move(values)]() {});
  }
  switch (llm_params_.model_type) {
    case LlmParams::ModelType::PREFIX: {
      RET_CHECK_LE(process_seq_len, seq_size);
      // Full (bidirectional) attention over the first process_seq_len input
      // positions, and a causal attention mask for all following tokens.
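      // For example, with seq_size = 4 and process_seq_len = 2 the mask is
      // (0 = attend, "-" = neg_value, i.e. masked):
      //   row 0: [0, 0, -, -]
      //   row 1: [0, 0, -, -]
      //   row 2: [0, 0, 0, -]
      //   row 3: [0, 0, 0, 0]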
      for (int i = 0; i < seq_size; ++i) {
        for (int j = 0; j < seq_size; ++j) {
          if (j <= i || std::max(j, i) < process_seq_len) {
            attention_mask_values_.at(i, j) = 0;
          } else {
            break;
          }
        }
      }
      break;
    }
    case LlmParams::ModelType::CAUSAL: {
      for (int i = 0; i < seq_size; ++i) {
        for (int j = 0; j < seq_size; ++j) {
          if (j <= i) {
            attention_mask_values_.at(i, j) = 0;
          } else {
            break;
          }
        }
      }
      break;
    }
    default: {
      return absl::InvalidArgumentError(
          absl::StrCat("Unsupported model type: ", llm_params_.model_type));
    }
  }
  return absl::OkStatus();
}

absl::Status LlmBuilder::InitPosEmbeddingValues(size_t process_seq_len) {
  return absl::OkStatus();
}

absl::Status LlmBuilder::InitPosEmbedding(size_t current_seq_len,
                                          size_t process_seq_len,
                                          Tensor& out_pos_embedding) {
  if (!position_embedding_values_) {
    MP_RETURN_IF_ERROR(InitPosEmbeddingValues(process_seq_len));
  }

  RET_CHECK_EQ(out_pos_embedding.dims.size(), 2);
  if (out_pos_embedding.dims[0] == 1) {
    RET_CHECK_EQ(out_pos_embedding.num_elements, llm_params_.model_dim_D);
    MP_RETURN_IF_ERROR(out_pos_embedding.LoadFromBuffer(
        position_embedding_values_->data() +
        llm_params_.model_dim_D * current_seq_len));
  } else {
    out_pos_embedding.Resize(
        Tensor::DimsType{process_seq_len, llm_params_.model_dim_D});
    MP_RETURN_IF_ERROR(out_pos_embedding.LoadFromBuffer(
        position_embedding_values_->data() +
        llm_params_.model_dim_D * current_seq_len));
  }
  return absl::OkStatus();
}

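// Builds the per-position RoPE table consumed by Rope(): a
// [seq_size_T, rope_size] buffer filled by FillXnnRoPEWeights, from which
// InitSegmentPos() copies the rows for the positions currently being
// processed.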
absl::Status LlmBuilder::InitSegmentPosValues(size_t rope_size) {
  std::vector<float> values =
      FillXnnRoPEWeights(llm_params_.seq_size_T, rope_size);
  float* values_ptr = values.data();
  segment_pos_values_ =
      MakeMdSpan(values_ptr, llm_params_.seq_size_T, rope_size,
                 [values = std::move(values)]() {});
  return absl::OkStatus();
}

absl::Status LlmBuilder::InitSegmentPos(size_t current_seq_len,
                                        size_t process_seq_len,
                                        Tensor& out_segment_pos) {
  RET_CHECK_EQ(out_segment_pos.dims.size(), 2);
  const size_t rope_size = out_segment_pos.dims[1];
  if (!segment_pos_values_.data()) {
    MP_RETURN_IF_ERROR(InitSegmentPosValues(rope_size));
  }

  out_segment_pos.Resize(Tensor::DimsType{process_seq_len, rope_size});
  MP_RETURN_IF_ERROR(out_segment_pos.LoadFromBuffer(
      segment_pos_values_[current_seq_len].data()));
  return absl::OkStatus();
}

absl::StatusOr<std::vector<std::vector<int>>> LlmBuilder::Sample(
    const Tensor& logits) {
  if (sampler_ == nullptr) {
    MP_ASSIGN_OR_RETURN(
        sampler_,
        Sampler::Create(Sampler::Type::kGreedy, /*top_k=*/0, /*top_p=*/0.0,
                        /*top_temperature=*/0.0, /*seed=*/0));
  }
  return sampler_->Sample(logits);
}

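// Scaled dot-product attention over the projected heads. Dimension notation in
// the comments below: B = batch, T = query length, S = key/value length,
// N = query heads, N' = KV heads, H = head dim. The query is scaled either
// per-dimension or by 1/sqrt(H), the logits are optionally soft-capped with
// tanh, and the attention mask is added before the softmax.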
absl::StatusOr<std::shared_ptr<Tensor>> LlmBuilder::DotAttention(
    std::shared_ptr<Tensor> query_proj, std::shared_ptr<Tensor> key_proj,
    std::shared_ptr<Tensor> value_proj, std::shared_ptr<Tensor> atten_mask,
    const SelfAttentionWeights& sa_weights) {
  // BTNH
  std::shared_ptr<Tensor> query_after_scale;
  switch (llm_params_.sa_params.attention_scale_type) {
    case LlmParams::AttentionScaleType::PER_DIM_SCALE: {
      MP_ASSIGN_OR_RETURN(query_after_scale,
                          PerDimScale(query_proj, sa_weights.per_dim_scale));
      break;
    }
    case LlmParams::AttentionScaleType::INV_SQRT_HEAD_DIM: {
      // Scale the query values by 1 / sqrt(dim_per_head).
      float scale = 1.0f / std::sqrt(llm_params_.head_dim_H);
      MP_ASSIGN_OR_RETURN(query_after_scale, ElementMul(query_proj, scale));
      break;
    }
    default:
      return absl::InvalidArgumentError(
          absl::StrCat("Unsupported attention scale type: ",
                       llm_params_.sa_params.attention_scale_type));
  }

  // Dot similarity
  // BTNH -> BNTH
  MP_ASSIGN_OR_RETURN(auto query_permuted,
                      Permute(query_after_scale, {0, 2, 1, 3}));
  // BSN'H -> BN'SH
  MP_ASSIGN_OR_RETURN(auto key_permuted, Permute(key_proj, {0, 2, 1, 3}));
  // einsum(BNTH.BN'SH -> BNTS)
  MP_ASSIGN_OR_RETURN(auto logits, QKVAttention(query_permuted, key_permuted,
                                                {0, llm_params_.head_dim_H}));

  // Cap, mask
  if (llm_params_.sa_params.soft_cap_value > 0.0f) {
    MP_ASSIGN_OR_RETURN(logits,
                        CapTanh(logits, llm_params_.sa_params.soft_cap_value));
  }
  MP_ASSIGN_OR_RETURN(auto padded_logits, ElementAdd(atten_mask, logits));
  MP_ASSIGN_OR_RETURN(auto probs, Softmax(padded_logits));
  MP_ASSIGN_OR_RETURN(auto value_permuted, Permute(value_proj, {0, 2, 3, 1}));

  // Outcome
  // einsum(BNTS.BN'HS) -> BNTH
  MP_ASSIGN_OR_RETURN(
      auto outcome_before_permute,
      QKVAttention(probs, value_permuted, {llm_params_.head_dim_H, 0}));
  // [B, N, T, H] -> BTNH
  return Permute(outcome_before_permute, {0, 2, 1, 3});
}

absl::StatusOr<std::shared_ptr<Tensor>> LlmBuilder::ApplyNorm(
    std::shared_ptr<Tensor> input,
    std::optional<LlmWeights::NormWeights> weights, LlmParams::Norm norm_type) {
  std::shared_ptr<Tensor> output = input;
  switch (norm_type) {
    case LlmParams::Norm::NO_NORM:
      break;
    case LlmParams::Norm::RMS_NORM: {
      MP_ASSIGN_OR_RETURN(
          output,
          RmsNorm(input,
                  std::get<RMSNormWeights>(weights.value()).norm_weight));
      break;
    }
    case LlmParams::Norm::LAYER_NORM: {
      const auto& layer_norm_weights =
          std::get<LayerNormWeights>(weights.value());
      MP_ASSIGN_OR_RETURN(
          output, LayerNorm(input, layer_norm_weights.epsilon,
                            layer_norm_weights.gamma, layer_norm_weights.beta));
      break;
    }
    default:
      return absl::NotFoundError("No norm specified.");
  }
  return output;
}

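// Wires the KV cache into the prefix graph when `resource.cache` is set: the
// full caches (sequence-major [S, B, N, H]) are declared as graph inputs, the
// current key/value projections are marked as graph outputs (at runtime their
// buffers are pointed at the slice of the cache being written), and
// `key`/`value` are rebound to cache-backed [B, S, N, H] tensors so attention
// reads the whole cache.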
absl::Status LlmBuilder::BuildKVCache(std::shared_ptr<Tensor>& key,
                                      std::shared_ptr<Tensor>& value,
                                      InputResource& resource) {
  if (resource.cache) {
    RET_CHECK_EQ(key->dims.size(), 4);
    RET_CHECK_EQ(key->dims[0], llm_params_.batch_size_B);
    RET_CHECK_EQ(value->dims.size(), 4);
    RET_CHECK_EQ(value->dims[0], llm_params_.batch_size_B);
    // Permute involves a memory copy; in some cases a reshape can mimic the
    // permute and avoid that copy.
    const bool quick_reshape = (key->dims[0] == 1 || key->dims[1] == 1);
    // BSNH -> SBNH
    if (quick_reshape) {
      MP_ASSIGN_OR_RETURN(
          resource.cache->k_slice,
          Reshape(key, {key->dims[1], llm_params_.batch_size_B,
                        llm_params_.num_kv_heads, llm_params_.head_dim_H}));
      MP_ASSIGN_OR_RETURN(
          resource.cache->v_slice,
          Reshape(value, {value->dims[1], llm_params_.batch_size_B,
                          llm_params_.num_kv_heads, llm_params_.head_dim_H}));
    } else {
      MP_ASSIGN_OR_RETURN(resource.cache->k_slice, Permute(key, {1, 0, 2, 3}));
      MP_ASSIGN_OR_RETURN(resource.cache->v_slice,
                          Permute(value, {1, 0, 2, 3}));
    }

    MP_ASSIGN_OR_RETURN(
        resource.cache->k_cache,
        NewInput(resource.cache->k_slice->dims, "prefix_k_cache"));
    MP_ASSIGN_OR_RETURN(
        resource.cache->v_cache,
        NewInput(resource.cache->v_slice->dims, "prefix_v_cache"));
    (resource.cache->k_slice = key)->MarkOutput().tag = "prefix_k_slice";
    (resource.cache->v_slice = value)->MarkOutput().tag = "prefix_v_slice";

    // TBNH -> BTNH
    if (quick_reshape) {
      MP_ASSIGN_OR_RETURN(
          key, Reshape(resource.cache->k_cache,
                       {llm_params_.batch_size_B, 0, llm_params_.num_kv_heads,
                        llm_params_.head_dim_H}));
      MP_ASSIGN_OR_RETURN(
          value, Reshape(resource.cache->v_cache,
                         {llm_params_.batch_size_B, 0, llm_params_.num_kv_heads,
                          llm_params_.head_dim_H}));
    } else {
      // TODO - b/329445989: Consolidate this permute with DotAttention.
      MP_ASSIGN_OR_RETURN(key, Permute(resource.cache->k_cache, {1, 0, 2, 3}));
      MP_ASSIGN_OR_RETURN(value,
                          Permute(resource.cache->v_cache, {1, 0, 2, 3}));
    }
  }

  return absl::OkStatus();
}

}  // namespace xnn_utils
}  // namespace mediapipe::tasks::genai