// Copyright 2024 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/tasks/cc/genai/inference/utils/xnn_utils/pack_weights_cache.h"
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/match.h"
#include "absl/strings/string_view.h"
#include "flatbuffers/buffer.h"
#include "flatbuffers/flatbuffer_builder.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/genai/inference/utils/llm_utils/memory_mapped_file.h"
#include "mediapipe/tasks/cc/genai/inference/utils/xnn_utils/named_buffer_generated.h"
#include "mediapipe/tasks/cc/genai/inference/utils/xnn_utils/xnn_tensor.h"
#include "xnnpack.h" // from @XNNPACK

namespace mediapipe::tasks::genai::xnn_utils {
namespace {

bool operator==(const xnn_weights_cache_look_up_key& lhs,
const xnn_weights_cache_look_up_key& rhs) {
return lhs.kernel == rhs.kernel && lhs.bias == rhs.bias &&
lhs.seed == rhs.seed;
}

}  // namespace
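
// Points `xnn_weights_cache` at the local cache provider for the lifetime of
// this object; the destructor resets it to nullptr.
//
// Expected call pattern (a sketch; the exact driver lives in the callers, and
// the path below is hypothetical):
//
//   PackWeightsCache cache("/path/to/weights.cache");
//   MP_RETURN_IF_ERROR(cache.Initialize());
//   // ... load weights, registering each via AddUnpackedWeight() ...
//   // ... build and pack the XNNPACK subgraph ...
//   MP_RETURN_IF_ERROR(cache.Finalize());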
PackWeightsCache::PackWeightsCache(absl::string_view cache_path)
: cache_path_(cache_path) {
xnn_weights_cache = &cache_provider_;
}

PackWeightsCache::~PackWeightsCache() { xnn_weights_cache = nullptr; }
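
// If a cache file already exists at `cache_path_`, memory-maps it and indexes
// its contents; otherwise starts a FlatBufferBuilder so a new cache can be
// assembled. In both cases the xnn_weights_cache_provider callbacks are wired
// to the static member functions defined below.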
absl::Status PackWeightsCache::Initialize() {
mmap_file_ = GetMmapFile(cache_path_);
if (mmap_file_) {
MP_RETURN_IF_ERROR(InitializeFromCache(mmap_file_));
} else {
builder_ = std::make_unique<flatbuffers::FlatBufferBuilder>();
}
cache_provider_.context = this;
cache_provider_.look_up = (size_t (*)(
void*, const xnn_weights_cache_look_up_key*))PackWeightsCache::look_up;
cache_provider_.reserve_space =
(void* (*)(void*, size_t))PackWeightsCache::reserve_space;
cache_provider_.look_up_or_insert =
(size_t (*)(void*, const xnn_weights_cache_look_up_key*, void*,
size_t))PackWeightsCache::look_up_or_insert;
cache_provider_.is_finalized =
(bool (*)(void*))PackWeightsCache::is_finalized;
cache_provider_.offset_to_addr =
(void* (*)(void*, size_t))PackWeightsCache::offset_to_addr;
cache_provider_.delete_cache =
(enum xnn_status(*)(void*))PackWeightsCache::delete_cache;
return absl::OkStatus();
}
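
// Records the mapping from the unpacked weight's data pointer to its name so
// that later cache look-ups, which only see the kernel pointer, can be
// resolved to a named entry.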
absl::Status PackWeightsCache::AddUnpackedWeight(
absl::string_view name, std::shared_ptr<Tensor> weight) {
RET_CHECK(!name.empty());
RET_CHECK(weight->Data());
RET_CHECK(!kernel_to_name_.contains(weight->Data()));
kernel_to_name_[weight->Data()] = name;
return absl::OkStatus();
}
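
// Serializes the name -> (offset, size) index into a NamedBuffers flatbuffer,
// pads it to a 64-byte boundary, records its final size inside the buffer,
// and prepends it to the cache file so that the packed weight blobs follow
// the header. The file is then re-mapped and re-indexed.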
absl::Status PackWeightsCache::Finalize() {
MP_RETURN_IF_ERROR(error_status_);
is_finalized_ = true;
if (!builder_) {
return absl::OkStatus();
}
std::vector<flatbuffers::Offset<Buffer>> buffers;
for (const auto& [name, offset_size] : name_to_offset_size_) {
auto new_buffer =
CreateBuffer(*builder_, builder_->CreateString(std::string(name)),
offset_size.first, offset_size.second);
buffers.push_back(new_buffer);
}
auto named_buffers = CreateNamedBuffers(
*builder_, builder_->CreateVector(buffers), /*flatbuffer_size=*/1);
FinishNamedBuffersBuffer(*builder_, named_buffers);
auto serialized =
std::string(reinterpret_cast<const char*>(builder_->GetBufferPointer()),
builder_->GetSize());
  // Pad with zeros until the serialized flatbuffer is 64-byte aligned.
while (serialized.size() % 64 != 0) serialized += '\0';
const size_t flatbuffer_size = serialized.size();
{
NamedBuffers* mutable_named_buffer =
GetMutableNamedBuffers(serialized.data());
RET_CHECK(mutable_named_buffer->mutate_flatbuffer_size(flatbuffer_size));
}
MP_RETURN_IF_ERROR(Prepend(serialized));
builder_.reset();
mmap_file_ = GetMmapFile(cache_path_);
RET_CHECK(mmap_file_);
return InitializeFromCache(mmap_file_);
}
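
// Returns true at most once: when serving look-ups from an existing cache
// file, the first fully-connected (".w") weight is flagged so that
// look_up_or_insert() can byte-compare the freshly packed data against the
// cached copy and detect an out-of-date cache.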
bool PackWeightsCache::ShouldDoubleCheckCompatibility(
const xnn_weights_cache_look_up_key* cache_key) {
if (builder_) return false;
if (key_sent_for_double_check_.has_value()) return false;
if (auto entry = kernel_to_name_.find(cache_key->kernel);
entry != kernel_to_name_.end()) {
absl::string_view name = entry->second;
    // Usually only the fully-connected ops in the LLM need packing, so we
    // double-check only the first fully-connected (".w") kernel and assume
    // the others are compatible.
if (absl::StrContains(name, ".w")) {
key_sent_for_double_check_ = *cache_key;
return true;
}
}
return false;
}
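
// Returns a mutable memory mapping of `filename`, or nullptr if the file does
// not exist or cannot be mapped.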
std::shared_ptr<llm_utils::MemoryMappedFile> PackWeightsCache::GetMmapFile(
absl::string_view filename) {
return mediapipe::file::Exists(filename).ok()
? llm_utils::MemoryMappedFile::CreateMutable(filename).value_or(
nullptr)
: nullptr;
}
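
// Rebuilds the in-memory name -> (offset, size) index from the NamedBuffers
// flatbuffer at the start of the mapped cache. The string_view keys point
// into the mapping, which is kept alive through the aliasing `named_buffers_`
// shared_ptr.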
absl::Status PackWeightsCache::InitializeFromCache(
std::shared_ptr<llm_utils::MemoryMappedFile> mmap_cache) {
name_to_offset_size_.clear();
named_buffers_ = std::shared_ptr<const NamedBuffers>(
mmap_cache, GetNamedBuffers(mmap_cache->data()));
for (const Buffer* buffer : *named_buffers_->buffers()) {
absl::string_view name =
absl::string_view(buffer->name()->c_str(), buffer->name()->size());
name_to_offset_size_[name] =
std::make_pair(buffer->offset(), buffer->size());
}
is_finalized_ = true;
return absl::OkStatus();
}

absl::Status PackWeightsCache::Append(absl::string_view filename,
absl::string_view data) {
return mediapipe::file::AppendStringToFile(filename, data);
}

absl::Status PackWeightsCache::Prepend(absl::string_view filename,
absl::string_view data) {
  // Append `data` to the end of the file so that the file is large enough.
  // Then shift the original contents towards the end of the file, moving at
  // most `data.size()` bytes per iteration. Finally, copy `data` to the
  // beginning of the file.
MP_RETURN_IF_ERROR(Append(filename, data));
auto mmap_file = GetMmapFile(filename);
RET_CHECK(mmap_file);
size_t src_offset = mmap_file->length() - data.size();
do {
size_t chunk_size = std::min(src_offset, data.size());
src_offset -= chunk_size;
memcpy(static_cast<char*>(mmap_file->data()) + src_offset + data.size(),
static_cast<char*>(mmap_file->data()) + src_offset, chunk_size);
} while (src_offset > 0);
memcpy(mmap_file->data(), data.data(), data.size());
return absl::OkStatus();
}

absl::Status PackWeightsCache::Append(absl::string_view data) {
return Append(cache_path_, data);
}

absl::Status PackWeightsCache::Prepend(absl::string_view data) {
return Prepend(cache_path_, data);
}
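
// xnn_weights_cache_provider callback. Resolves `cache_key->kernel` to a
// weight name and returns the offset of its packed data within the cache, or
// SIZE_MAX if the weight is unknown, has no cached entry, or has been flagged
// for a compatibility double-check.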
size_t PackWeightsCache::look_up(
PackWeightsCache* context, const xnn_weights_cache_look_up_key* cache_key) {
ABSL_CHECK(cache_key);
// TODO: b/319561597 - take seed and bias into consideration.
if (auto entry = context->kernel_to_name_.find(cache_key->kernel);
entry != context->kernel_to_name_.end()) {
absl::string_view name = entry->second;
if (auto entry = context->name_to_offset_size_.find(name);
entry != context->name_to_offset_size_.end() &&
!context->ShouldDoubleCheckCompatibility(cache_key)) {
return entry->second.first;
}
}
return SIZE_MAX;
}
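
// xnn_weights_cache_provider callback. Returns a scratch buffer of at least
// `n` bytes for XNNPACK to pack weights into before look_up_or_insert() is
// called.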
void* PackWeightsCache::reserve_space(PackWeightsCache* context, size_t n) {
context->tmp_buffer_to_pack_weight_.resize(n);
return context->tmp_buffer_to_pack_weight_.data();
}
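
// xnn_weights_cache_provider callback. Returns the cached offset for
// `cache_key` when present; for the key flagged by
// ShouldDoubleCheckCompatibility() the cached bytes are first compared
// against the freshly packed data. On a miss while building a new cache, the
// packed data is appended to the cache file and its offset recorded.
// SIZE_MAX signals failure; fatal inconsistencies are also recorded in
// `error_status_`.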
size_t PackWeightsCache::look_up_or_insert(
PackWeightsCache* context, const xnn_weights_cache_look_up_key* cache_key,
void* ptr, size_t size) {
ABSL_CHECK(cache_key);
if (context->key_sent_for_double_check_.has_value() &&
*cache_key == *context->key_sent_for_double_check_) {
size_t ref_offset = look_up(context, cache_key);
void* ref_ptr = offset_to_addr(context, ref_offset);
if (0 == memcmp(ptr, ref_ptr, size)) {
return ref_offset;
}
const absl::string_view error_message =
"Packed weights is different from cache, it's likely the cache is out "
"dated.";
ABSL_LOG(DFATAL) << error_message;
context->error_status_ = absl::FailedPreconditionError(error_message);
return SIZE_MAX;
}
size_t offset = look_up(context, cache_key);
if (offset != SIZE_MAX) {
return offset;
}
if (!context->builder_) {
const absl::string_view error_message =
"insersion is not supported for an existing cache, consider clear and "
"rebuild the cache.";
ABSL_LOG(DFATAL) << error_message;
context->error_status_ = absl::FailedPreconditionError(error_message);
return SIZE_MAX;
}
if (auto entry = context->kernel_to_name_.find(cache_key->kernel);
entry != context->kernel_to_name_.end()) {
absl::string_view name = entry->second;
size_t offset = context->blob_size_;
context->name_to_offset_size_[name] = std::make_pair(offset, size);
if (auto s =
context->Append(absl::string_view(static_cast<char*>(ptr), size));
!s.ok()) {
return SIZE_MAX;
}
context->blob_size_ += size;
return offset;
}
return SIZE_MAX;
}
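
// xnn_weights_cache_provider callback. Reports whether the cache has been
// finalized and can serve offset_to_addr() queries.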
bool PackWeightsCache::is_finalized(PackWeightsCache* context) {
return context->is_finalized_;
}
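
// xnn_weights_cache_provider callback. Translates a cache offset into an
// address inside the mapped file, skipping the NamedBuffers flatbuffer that
// prefixes the packed weight blobs.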
void* PackWeightsCache::offset_to_addr(PackWeightsCache* context,
size_t offset) {
ABSL_DCHECK(is_finalized(context));
ABSL_DCHECK(!context->builder_);
uint32_t fb_size = context->named_buffers_->flatbuffer_size();
void* r = static_cast<char*>(context->mmap_file_->data()) + fb_size + offset;
return r;
}
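
// Loads a weight through the wrapped accessor and registers it with the pack
// weights cache so its packed form can later be looked up by data pointer.
// LoadTransposedWeight() below does the same for transposed weights.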
absl::StatusOr<std::shared_ptr<Tensor>>
WeightAccessorCompositeWithCache::LoadWeight(absl::string_view tensor_name,
Tensor::DimsType expected_dims,
size_t dim_scale_if_any) const {
MP_ASSIGN_OR_RETURN(
auto r, accessor_->LoadWeight(tensor_name, std::move(expected_dims),
dim_scale_if_any));
// Some weights are not defined in some models and should be left empty.
if (r == nullptr) {
return r;
}
MP_RETURN_IF_ERROR(weights_cache_->AddUnpackedWeight(tensor_name, r));
return r;
}

absl::StatusOr<std::shared_ptr<Tensor>>
WeightAccessorCompositeWithCache::LoadTransposedWeight(
absl::string_view tensor_name, Tensor::DimsType expected_dims,
size_t dim_scale_if_any) const {
MP_ASSIGN_OR_RETURN(
auto r, accessor_->LoadTransposedWeight(
tensor_name, std::move(expected_dims), dim_scale_if_any));
// Some weights are not defined in some models and should be left empty.
if (r == nullptr) {
return r;
}
MP_RETURN_IF_ERROR(weights_cache_->AddUnpackedWeight(tensor_name, r));
return r;
}

}  // namespace mediapipe::tasks::genai::xnn_utils