chromium/third_party/mediapipe/src/mediapipe/tasks/cc/genai/inference/utils/xnn_utils/utils.h

// Copyright 2024 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MEDIAPIPE_TASKS_GENAI_INFERENCE_UTILS_XNN_UTILS_UTILS_H_
#define MEDIAPIPE_TASKS_GENAI_INFERENCE_UTILS_XNN_UTILS_UTILS_H_

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/tasks/cc/genai/inference/utils/llm_utils/memory_mapped_file.h"

namespace mediapipe::tasks::genai {
namespace xnn_utils {

static constexpr absl::string_view kKeySelfAttentionReshapedWeight{
    "self_attention_reshaped_weight_N"};
// Usually fully connect is [K,M] dot [M,N] => [K,N]. Some code base by default
// expects [K,M] dot [N,M] => [K,N], in which case this metadata should be set.
static constexpr absl::string_view kKeyInDimLastInWeight{
    "in_dim_last_in_weight"};

std::vector<float> FillXnnRoPEWeights(size_t max_seq_len, size_t num_channels);

// expect_size_bytes == 0 means don't check size.
template <typename element_type = char>
static absl::StatusOr<std::shared_ptr<element_type>> LoadBufferFromFile(
    absl::string_view file_path, size_t* buffer_size, bool use_mmap = true,
    size_t expect_size_bytes = 0) {
  RET_CHECK(buffer_size);
  if (use_mmap) {
    MP_ASSIGN_OR_RETURN(auto mapped_file,
                        llm_utils::MemoryMappedFile::Create(file_path));
    if (expect_size_bytes) {
      RET_CHECK_EQ(expect_size_bytes, mapped_file->length())
          << "File size " << mapped_file->length() << ", expected "
          << expect_size_bytes << ", file path " << file_path;
    }

    *buffer_size = mapped_file->length();

    // shared_ptr deleter must be copy constructible, so wrap the unique_ptr in
    // another shared_ptr.
    void* data = mapped_file->data();
    auto file_deleter =
        std::make_shared<std::unique_ptr<llm_utils::MemoryMappedFile>>(
            std::move(mapped_file));
    return std::shared_ptr<element_type>(
        static_cast<element_type*>(data),
        [file_deleter](auto* p) { file_deleter->reset(); });
  } else {
    auto read_buffer = std::make_shared<std::string>();
    MP_RETURN_IF_ERROR(
        mediapipe::file::GetContents(file_path, read_buffer.get()));

    if (expect_size_bytes) {
      RET_CHECK_EQ(expect_size_bytes, read_buffer->size())
          << "File size " << read_buffer->size() << ", expected "
          << expect_size_bytes << ", file path " << file_path;
    }

    *buffer_size = read_buffer->size();
    return std::shared_ptr<element_type>(
        read_buffer, reinterpret_cast<element_type*>(read_buffer->data()));
  }
}

template <typename element_type = char>
static absl::StatusOr<std::shared_ptr<element_type>> LoadBufferFromFile(
    absl::string_view file_path, bool use_mmap = true,
    size_t expect_size_bytes = 0) {
  size_t buffer_size;
  return LoadBufferFromFile<element_type>(file_path, &buffer_size, use_mmap,
                                          expect_size_bytes);
}

// Assumes each element in `vec` is less than 16 (4-bit), and packs into an
// array with half the original length. The 4-bit element is stored in LSB
// first, then MSB.
absl::StatusOr<std::vector<uint8_t>> PackInt4ToInt8(absl::Span<uint8_t> vec);

// Unpack the compact 8-bit elements to an array (twice the length of given
// array) of 4-bit elements. The lower 4-bits are unpacked, then followed by
// higher 4-bits.
absl::StatusOr<std::vector<uint8_t>> UnpackInt8ToInt4(
    absl::Span<uint8_t> packed_vec);

absl::StatusOr<std::vector<float>> PositionEmbedding(
    int seq_length, int embedding_dim, float min_timescale = 1.0f,
    float max_timescale = 10000.0f);

// Like PositionEmbedding() but outputs a fully padded embedding for a fixed
// sequence length of `seq_length`, with an input of `input_length` tokens where
// `input_length` is no larger than `seq_length`.
absl::StatusOr<std::vector<float>> FullPositionEmbedding(
    int input_length, int seq_length, int embedding_dim,
    float min_timescale = 1.0f, float max_timescale = 10000.0f);

absl::Status SelectTopK(int top_k,
                        std::vector<std::pair<float, int>>* logits_ids);

}  // namespace xnn_utils
}  // namespace mediapipe::tasks::genai

#endif  // MEDIAPIPE_TASKS_GENAI_INFERENCE_UTILS_XNN_UTILS_UTILS_H_