chromium/third_party/mediapipe/src/mediapipe/tasks/cc/genai/inference/utils/xnn_utils/falcon.cc

// Copyright 2024 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mediapipe/tasks/cc/genai/inference/utils/xnn_utils/falcon.h"

#include <cmath>
#include <cstddef>
#include <limits>
#include <memory>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/genai/inference/common/mdspan.h"
#include "mediapipe/tasks/cc/genai/inference/utils/xnn_utils/llm.h"
#include "mediapipe/tasks/cc/genai/inference/utils/xnn_utils/llm_weights.h"
#include "mediapipe/tasks/cc/genai/inference/utils/xnn_utils/xnn_tensor.h"

namespace mediapipe::tasks::genai::xnn_utils {

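// Builds the extra graph inputs needed by the attention layers. Falcon-RW-1B
// folds the ALiBi position bias into the attention mask, so the mask carries
// per-head values: [T, N, T] during prefill and [1, N, T] during decode,
// where T = seq_size_T and N = n_heads_N.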
absl::StatusOr<
    std::pair<std::shared_ptr<Tensor>, FalconRW1BBuilder::InputResource>>
FalconRW1BBuilder::PreProcess(std::shared_ptr<Tensor> token_embedding,
                              bool is_prefix) {
  InputResource resource;
  if (is_prefix) {
    // The fused attention mask also carries the ALiBi position bias.
    MP_ASSIGN_OR_RETURN(resource.atten_mask,
                        NewInput({llm_params_.seq_size_T, llm_params_.n_heads_N,
                                  llm_params_.seq_size_T}));
  } else {
    MP_ASSIGN_OR_RETURN(
        resource.atten_mask,
        NewInput({1, llm_params_.n_heads_N, llm_params_.seq_size_T}));
  }
  return std::make_pair(token_embedding, resource);
}

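// Self-attention block, excluding the preceding layer norm: projects the input
// to Q/K/V, updates the KV cache, runs dot-product attention with the fused
// ALiBi mask, then merges the heads and applies the output projection.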
absl::StatusOr<std::shared_ptr<Tensor>>
FalconRW1BBuilder::SelfAttentionExcludeNorm(
    std::shared_ptr<Tensor> input, InputResource resource,
    const LlmWeights::SelfAttentionWeights& sa_weights) {
  // [B, 1|T, N, H]
  MP_ASSIGN_OR_RETURN(auto k_proj, SelfAttentionProj(input, sa_weights.k_weight,
                                                     sa_weights.k_bias));
  MP_ASSIGN_OR_RETURN(auto q_proj, SelfAttentionProj(input, sa_weights.q_weight,
                                                     sa_weights.q_bias));
  MP_ASSIGN_OR_RETURN(auto v_proj, SelfAttentionProj(input, sa_weights.v_weight,
                                                     sa_weights.v_bias));

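  // Store the freshly computed key/value projections in the KV cache so later
  // decode steps can attend to them.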
  MP_RETURN_IF_ERROR(BuildKVCache(k_proj, v_proj, resource));

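  // The mask is built as [1|T, N, T] so it can be sliced per row and head;
  // swap the first two dimensions to [N, 1|T, T] before attention.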
  MP_ASSIGN_OR_RETURN(auto permuted_mask,
                      Permute(resource.atten_mask, {1, 0, 2}));
  // [B, 1|T, N, H]
  MP_ASSIGN_OR_RETURN(auto kqv_merged, DotAttention(q_proj, k_proj, v_proj,
                                                    permuted_mask, sa_weights));

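  // Merge the heads back into one feature dimension,
  // [B, 1|T, N, H] -> [B, 1|T, N*H], ahead of the output projection.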
  const size_t B = kqv_merged->dims[0];
  const size_t NH = kqv_merged->dims[2] * kqv_merged->dims[3];
  MP_ASSIGN_OR_RETURN(auto outcome_reshaped, Reshape(kqv_merged, {B, 0, NH}));
  return FullConn(outcome_reshaped, sa_weights.post_proj_weight,
                  sa_weights.post_proj_bias);
}

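// Feed-forward block, excluding the preceding layer norm: a two-layer MLP with
// a GELU activation in between.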
absl::StatusOr<std::shared_ptr<Tensor>>
FalconRW1BBuilder::FeedForwardExcludeNorm(
    std::shared_ptr<Tensor> input,
    const LlmWeights::FeedForwardWeights& ff_weights) {
  MP_ASSIGN_OR_RETURN(auto linear1, FullConn(input, ff_weights.layer_1_weight,
                                             ff_weights.layer_1_bias));
  MP_ASSIGN_OR_RETURN(auto gelu1, Gelu(linear1));
  return FullConn(gelu1, ff_weights.layer_2_weight, ff_weights.layer_2_bias);
}

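// Fills the attention mask input for the tokens being processed. The per-head
// ALiBi mask values are computed lazily on first use, then copied either in
// full or row by row depending on whether dynamic shapes are enabled.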
absl::Status FalconRW1BBuilder::InitAttentionMask(size_t current_seq_len,
                                                  size_t process_seq_len,
                                                  Tensor& out_attn_mask) {
  if (!attention_mask_values_.data()) {
    MP_RETURN_IF_ERROR(InitAlibiAttentionMaskValues());
  }

  if (llm_params_.enable_dynamic_shape) {
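    // Dynamic shape: shrink the mask to the rows being processed and copy the
    // matching row for each head from the precomputed [T, N, T] table.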
    out_attn_mask.Resize(Tensor::DimsType{process_seq_len,
                                          llm_params_.n_heads_N,
                                          current_seq_len + process_seq_len});
    for (size_t r = 0; r < out_attn_mask.dims[0]; ++r) {
      for (size_t n = 0; n < llm_params_.n_heads_N; ++n) {
        auto slice = out_attn_mask.Slice(0, r)->Slice(1, n);
        MP_RETURN_IF_ERROR(slice->LoadFromBuffer(
            attention_mask_values_[r + current_seq_len][n].data()));
      }
    }
  } else {
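    // Static shape: copy the full precomputed [T, N, T] table as-is.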
    RET_CHECK_EQ(out_attn_mask.num_elements, llm_params_.seq_size_T *
                                                 llm_params_.n_heads_N *
                                                 llm_params_.seq_size_T);
    MP_RETURN_IF_ERROR(
        out_attn_mask.LoadFromBuffer(attention_mask_values_.data()));
  }

  return absl::OkStatus();
}

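// Precomputes the fused causal + ALiBi attention mask. Masked-out (future)
// positions hold a large negative value so they vanish after softmax; visible
// positions hold the per-head linear position bias.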
absl::Status FalconRW1BBuilder::InitAlibiAttentionMaskValues() {
  RET_CHECK_EQ(llm_params_.n_heads_N, 32)
      << "The hardcoded ALiBi base only works with 32 heads.";
  const float base = 1.0f / std::sqrt(std::sqrt(2.0));
  const float scale = 1.0f / std::sqrt(llm_params_.head_dim_H);

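  // Allocate the [T, N, T] table, filled with a large negative value so every
  // position is masked out until explicitly overwritten below.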
  {
    std::vector<float> values(
        llm_params_.seq_size_T * llm_params_.n_heads_N * llm_params_.seq_size_T,
        0.8 * std::numeric_limits<float>::lowest());
    float* values_ptr = values.data();
    attention_mask_values_ =
        MakeMdSpan(values_ptr, llm_params_.seq_size_T, llm_params_.n_heads_N,
                   llm_params_.seq_size_T, [values = std::move(values)]() {});
  }

  // mask: T,N,T
  // Note: Since the mask has different values across the heads, we use an
  // alternative mask shape to allow tensor slicing. The mask gets transposed
  // before being added to the attention scores.
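  // For query row i, head j, and key position k <= i, the bias is
  // k * 2^(-8(j+1)/32) * scale; positions k > i keep the negative fill value.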
  for (size_t i = 0; i < llm_params_.seq_size_T; ++i) {
    float alibi = 1.0f;
    for (size_t j = 0; j < llm_params_.n_heads_N; ++j) {
      alibi *= base;
      for (size_t k = 0; k < llm_params_.seq_size_T; ++k) {
        if (k > i) {
          break;
        }
        attention_mask_values_.at(i, j, k) = k * alibi * scale;
      }
    }
  }
  return absl::OkStatus();
}

}  // namespace mediapipe::tasks::genai::xnn_utils