// Copyright 2024 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/tasks/cc/genai/inference/utils/llm_utils/model_data.h"
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <limits>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "flatbuffers/buffer.h"
#include "flatbuffers/vector.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/genai/inference/proto/llm_file_metadata.pb.h"
#include "mediapipe/tasks/cc/genai/inference/proto/llm_params.pb.h"
#include "mediapipe/tasks/cc/genai/inference/utils/llm_utils/memory_mapped_file.h"
#include "mediapipe/tasks/cc/genai/inference/utils/llm_utils/metadata_utils.h"
#include "mediapipe/tasks/cc/genai/inference/utils/llm_utils/scoped_file.h"
#include "tensorflow/lite/model_builder.h"
#include "tensorflow/lite/schema/schema_generated.h"

namespace mediapipe::tasks::genai::llm_utils {
namespace {
// The maximum size of the tflite::Model (excluding buffers).
constexpr uint64_t kTfliteBaseSize = 1024 * 1024;
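
// Holds a caller-owned span of bytes; does not take ownership of the memory.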
class SpanHolder : public DataHolder<uint8_t> {
public:
explicit SpanHolder(absl::Span<uint8_t> data) : data_(data) {}
absl::Span<uint8_t> GetData() const override { return data_; }
private:
absl::Span<uint8_t> data_;
};
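
// Holds a heap-allocated span of bytes and frees it on destruction.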
class FreeingSpanHolder : public DataHolder<uint8_t> {
public:
explicit FreeingSpanHolder(absl::Span<uint8_t> data) : data_(data) {}
~FreeingSpanHolder() override { free(data_.data()); }
absl::Span<uint8_t> GetData() const override { return data_; }
private:
absl::Span<uint8_t> data_;
};

// Base class for loading models from a tflite file.
class TfliteModelData : public ModelData {
public:
explicit TfliteModelData(std::shared_ptr<tflite::FlatBufferModel> model)
: model_(std::move(model)) {}
~TfliteModelData() override = default;
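
// The model type enum value is stored directly in the metadata entry's
// `buffer` field rather than in a referenced buffer.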
std::optional<odml::infra::proto::LlmModelType> GetModelType() override {
const tflite::Metadata* metadata = GetMetadata(kLlmModelTypeName);
if (metadata == nullptr) {
return std::nullopt;
}
return static_cast<odml::infra::proto::LlmModelType>(metadata->buffer());
}
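
// The LoRA rank, when present, is likewise stored directly in the metadata
// entry's `buffer` field.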
std::optional<int> LoRARank() override {
const tflite::Metadata* metadata = GetMetadata(kLoRARank);
if (metadata == nullptr) return std::nullopt;
return static_cast<int>(metadata->buffer());
}
const odml::infra::proto::LlmParameters& GetLlmParameters() override {
return llm_parameters_;
}
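
// Looks up the named metadata entry and returns the contents of the buffer it
// references as a string.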
absl::StatusOr<std::string> ReadMetadata(absl::string_view name) override {
const tflite::Metadata* metadata = GetMetadata(name);
if (metadata == nullptr) {
return absl::NotFoundError(
absl::StrCat("Failed to get metadata: ", name));
}
const tflite::Buffer* backend_buffer =
model_->GetModel()->buffers()->Get(metadata->buffer());
MP_ASSIGN_OR_RETURN(
auto data, ReadData(backend_buffer->offset(), backend_buffer->size()));
return std::string(reinterpret_cast<const char*>(data->GetData().data()),
data->GetData().size());
}
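
// Scans all subgraphs and returns the size in bytes of the largest tensor
// buffer.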
uint64_t GetMaxTensorSize() const override {
uint64_t max_size = 0;
const tflite::Model* tflite_model = model_->GetModel();
const flatbuffers::Vector<flatbuffers::Offset<tflite::Buffer>>& buffers =
*tflite_model->buffers();
for (const tflite::SubGraph* subgraph : *tflite_model->subgraphs()) {
for (const tflite::Tensor* tfl_tensor : *subgraph->tensors()) {
if (tfl_tensor->buffer() >= buffers.size()) {
continue;
}
max_size =
std::max(max_size, buffers.Get(tfl_tensor->buffer())->size());
}
}
return max_size;
}
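
// Returns the size in bytes of the named tensor's buffer, or 0 if no such
// tensor exists.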
uint64_t GetTensorSize(absl::string_view name) const override {
const tflite::Buffer* buffer = GetBuffer(name);
if (buffer) {
return buffer->size();
}
return 0;
}
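
// Reads the named tensor's data, or returns nullptr if no such tensor exists.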
absl::StatusOr<std::unique_ptr<DataHolder<uint8_t>>> ReadTensor(
absl::string_view name) override {
const tflite::Buffer* buffer = GetBuffer(name);
if (buffer) {
return ReadData(buffer->offset(), buffer->size());
}
return nullptr;
}
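
// Parses the LlmParameters proto stored in the model metadata.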
absl::Status InitLlmParameters() {
MP_ASSIGN_OR_RETURN(std::string proto_str,
ReadMetadata(llm_parameters_.GetTypeName()));
RET_CHECK(llm_parameters_.ParseFromString(proto_str));
return absl::OkStatus();
}
protected:
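// Reads `size` bytes starting at `offset` within the underlying model data.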
virtual absl::StatusOr<std::unique_ptr<DataHolder<uint8_t>>> ReadData(
uint64_t offset, uint64_t size) = 0;
std::shared_ptr<tflite::FlatBufferModel> model_;
odml::infra::proto::LlmParameters llm_parameters_;
private:
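// Returns the buffer backing the tensor with the given name, or nullptr if it
// is not found.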
const tflite::Buffer* GetBuffer(absl::string_view name) const {
const tflite::Model* tflite_model = model_->GetModel();
const flatbuffers::Vector<flatbuffers::Offset<tflite::Buffer>>& buffers =
*tflite_model->buffers();
for (const tflite::SubGraph* subgraph : *tflite_model->subgraphs()) {
for (const tflite::Tensor* tfl_tensor : *subgraph->tensors()) {
if (name != tfl_tensor->name()->c_str()) {
continue;
}
if (tfl_tensor->buffer() >= buffers.size()) {
continue;
}
return buffers.Get(tfl_tensor->buffer());
}
}
return nullptr;
}
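
// Returns the metadata entry with the given name, or nullptr if it is not
// found.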
const tflite::Metadata* GetMetadata(absl::string_view name) {
const tflite::Model* tflite_model = model_->GetModel();
if (tflite_model->metadata() == nullptr) {
return nullptr;
}
for (const tflite::Metadata* metadata : *tflite_model->metadata()) {
if (name == metadata->name()->c_str()) {
return metadata;
}
}
return nullptr;
}
};

// Loads from a tflite model which includes all buffers in the allocation.
class InMemoryTfliteModelData : public TfliteModelData {
public:
explicit InMemoryTfliteModelData(
std::shared_ptr<tflite::FlatBufferModel> model)
: TfliteModelData(std::move(model)) {}
~InMemoryTfliteModelData() override = default;
void Clear() override {}
protected:
absl::StatusOr<std::unique_ptr<DataHolder<uint8_t>>> ReadData(
uint64_t offset, uint64_t size) override {
return std::make_unique<SpanHolder>(absl::MakeSpan(
const_cast<uint8_t*>(
static_cast<const uint8_t*>(model_->allocation()->base())) +
offset,
size));
}
};

// Loads tflite data from a file as needed.
class FileTfliteModelData : public TfliteModelData {
public:
FileTfliteModelData(std::shared_ptr<tflite::FlatBufferModel> model,
std::unique_ptr<DataHolder<const uint8_t>> model_data,
std::shared_ptr<const ScopedFile> file)
: TfliteModelData(std::move(model)),
file_(std::move(file)),
model_data_(std::move(model_data)) {}
~FileTfliteModelData() override = default;
void Clear() override {
file_.reset();
model_data_.reset();
}
protected:
absl::StatusOr<std::unique_ptr<DataHolder<uint8_t>>> ReadData(
uint64_t offset, uint64_t size) override {
RET_CHECK(file_);
return CreateMemoryMappedDataHolder<uint8_t>(file_->file(), offset, size,
key_);
}
private:
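// Each instance gets a distinct key, derived from a global counter, which is
// passed to the memory-mapped data holders created in ReadData.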
static uint32_t next_key_;
const std::string key_{absl::StrCat("FileTfliteModelData_", next_key_++)};
std::shared_ptr<const ScopedFile> file_;
std::unique_ptr<DataHolder<const uint8_t>> model_data_;
};

uint32_t FileTfliteModelData::next_key_ = 0;

// Loads tflite data via the provided function. Any buffer returned by the
// read function is owned by this class and freed when no longer needed.
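//
// As an illustrative sketch only (names marked below are hypothetical, and the
// exact ReadDataFn signature is the one declared in model_data.h), a callback
// reading from some `source` object could look like:
//
//   ModelData::ReadDataFn fn = [&source](uint64_t offset, uint64_t size,
//                                        ReadMode mode) -> void* {
//     if (mode == ReadMode::DISCARD_ALL) {
//       // Release any resources held by the callback; the return value is
//       // ignored for this mode.
//       return nullptr;
//     }
//     // Buffers handed back here are released with free(), so they must be
//     // allocated with malloc().
//     void* buffer = malloc(size);
//     source.ReadAt(offset, size, buffer);  // Hypothetical read helper.
//     return buffer;
//   };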
class FunctionTfliteModelData : public TfliteModelData {
public:
FunctionTfliteModelData(std::shared_ptr<tflite::FlatBufferModel> model,
ModelData::ReadDataFn fn)
: TfliteModelData(std::move(model)), fn_(std::move(fn)) {}
~FunctionTfliteModelData() override { Clear(); }
void Clear() override {
free(const_cast<void*>(model_->allocation()->base()));
fn_(0, 0, ReadMode::DISCARD_ALL);
}
protected:
absl::StatusOr<std::unique_ptr<DataHolder<uint8_t>>> ReadData(
uint64_t offset, uint64_t size) override {
void* data = fn_(offset, size, ReadMode::DISCARD);
RET_CHECK(data) << "Error fetching data.";
return std::make_unique<FreeingSpanHolder>(
absl::MakeSpan(static_cast<uint8_t*>(data), size));
}
private:
ModelData::ReadDataFn fn_;
};
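
// Rounds `number` up to the nearest multiple of `n`.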
uint64_t AlignByN(uint64_t number, uint64_t n) {
const uint64_t q = number / n;
return (number % n == 0 ? q : q + 1) * n;
}
} // namespace
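
// Expands [base_offset, base_offset + base_size) into a range whose offset and
// size are multiples of the memory-mapping offset alignment. For example,
// assuming a 4096-byte alignment, base_offset = 5000 and base_size = 100 yield
// offset = 4096 and size = 4096, which still covers the requested bytes.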
OffsetAndSize GetAlignedOffsetAndSize(uint64_t base_offset,
uint64_t base_size) {
const size_t kAlignment = MemoryMappedFile::GetOffsetAlignment();
uint64_t offset = (base_offset / kAlignment) * kAlignment;
uint64_t size = AlignByN(base_offset - offset + base_size, kAlignment);
return {.offset = offset, .size = size};
}

// static
absl::StatusOr<std::shared_ptr<ModelData>> ModelData::Create(
std::shared_ptr<tflite::FlatBufferModel> model) {
auto model_data = std::make_shared<InMemoryTfliteModelData>(std::move(model));
MP_RETURN_IF_ERROR(model_data->InitLlmParameters());
return model_data;
}

// static
absl::StatusOr<std::shared_ptr<ModelData>> ModelData::Create(ScopedFile file) {
return Create(std::make_shared<ScopedFile>(std::move(file)));
}

// static
absl::StatusOr<std::shared_ptr<ModelData>> ModelData::Create(
std::shared_ptr<const ScopedFile> file) {
// Load the first chunk of the file as a tflite model, and load the remaining
// data on demand.
MP_ASSIGN_OR_RETURN(
auto data, CreateMemoryMappedDataHolder<const uint8_t>(
file->file(), /*offset=*/0, /*size=*/kTfliteBaseSize));
auto model = tflite::FlatBufferModel::BuildFromBuffer(
reinterpret_cast<const char*>(data->GetData().data()),
data->GetData().size());
RET_CHECK(model) << "Error building tflite model.";
auto model_data = std::make_shared<FileTfliteModelData>(
std::move(model), std::move(data), std::move(file));
MP_RETURN_IF_ERROR(model_data->InitLlmParameters());
return model_data;
}

// static
absl::StatusOr<std::shared_ptr<ModelData>> ModelData::Create(ReadDataFn fn) {
// Load the first chunk of the data as a tflite model, and load the rest on
// demand.
void* data = fn(0, kTfliteBaseSize, ReadMode::KEEP);
RET_CHECK(data) << "Error fetching data.";
auto model = tflite::FlatBufferModel::BuildFromBuffer(
reinterpret_cast<const char*>(data), kTfliteBaseSize);
RET_CHECK(model) << "Error building tflite model.";
auto model_data = std::make_shared<FunctionTfliteModelData>(std::move(model),
std::move(fn));
MP_RETURN_IF_ERROR(model_data->InitLlmParameters());
return model_data;
}

// static
absl::StatusOr<std::shared_ptr<ModelData>> ModelData::Create(
absl::string_view weight_path) {
MP_ASSIGN_OR_RETURN(auto tflite_file, ScopedFile::Open(weight_path));
return ModelData::Create(std::move(tflite_file));
}

} // namespace mediapipe::tasks::genai::llm_utils