// chromium/third_party/mediapipe/src/mediapipe/tasks/cc/genai/inference/utils/llm_utils/model_data.h

// Copyright 2024 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MEDIAPIPE_TASKS_GENAI_INFERENCE_UTILS_LLM_UTILS_MODEL_DATA_H_
#define MEDIAPIPE_TASKS_GENAI_INFERENCE_UTILS_LLM_UTILS_MODEL_DATA_H_

#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>

#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/genai/inference/proto/llm_params.pb.h"
#include "mediapipe/tasks/cc/genai/inference/utils/llm_utils/memory_mapped_file.h"
#include "mediapipe/tasks/cc/genai/inference/utils/llm_utils/scoped_file.h"
#include "tensorflow/lite/model_builder.h"

namespace mediapipe::tasks::genai::llm_utils {

// Interface for a view over data whose storage is owned by some underlying
// resource. Destroying the holder may release that resource, at which point
// any span previously obtained from GetData() or GetRawData() must no longer
// be used.
template <typename T>
class DataHolder {
 public:
  // Virtual so that concrete holders can be destroyed through this interface.
  virtual ~DataHolder() = default;

  // Returns the span of data this holder exposes.
  virtual absl::Span<T> GetData() const = 0;

  // Returns the underlying data with no offsets applied, for holders whose
  // data lives at an offset into a larger buffer. Unless overridden, this is
  // the same span as GetData().
  virtual absl::Span<T> GetRawData() const { return this->GetData(); }
};

// An (offset, size) pair, returned by GetAlignedOffsetAndSize() and passed on
// to MemoryMappedFile::Create() by CreateMemoryMappedDataHolder(). Both fields
// default to 0.
// NOTE(review): units are presumably bytes within a file -- confirm.
struct OffsetAndSize {
  // Start of the range.
  uint64_t offset = 0;
  // Length of the range.
  uint64_t size = 0;
};
// Gets an offset and size which will be valid to pass to MemoryMappedFile.
// `base_offset` and `base_size` describe the range the caller wants; the
// result is the range to actually map. CreateMemoryMappedDataHolder() in this
// file subtracts the returned offset from `base_offset` to locate the
// requested data inside the mapping, so the returned offset is expected to be
// <= `base_offset`.
// NOTE(review): defined elsewhere; presumably rounds the offset down to the
// platform's mapping alignment and grows the size to compensate -- confirm.
OffsetAndSize GetAlignedOffsetAndSize(uint64_t base_offset, uint64_t base_size);

// Creates a DataHolder by memory mapping `file`.
//
// `offset` and `size` select the part of the file exposed through GetData().
// When either is non-zero they are passed through GetAlignedOffsetAndSize() to
// obtain the range actually handed to MemoryMappedFile::Create(); GetData()
// then skips the bytes between the aligned offset and the requested `offset`,
// while GetRawData() exposes the whole mapped region. When `size` is 0, the
// data extends from the requested `offset` to the end of the mapped region.
//
// `key` can be passed as an optimization when the same file is being mapped
// multiple times. It should be unique to `file`.
//
// Returns the holder on success, or propagates the error produced by
// MemoryMappedFile::Create().
template <typename T>
absl::StatusOr<std::unique_ptr<DataHolder<T>>> CreateMemoryMappedDataHolder(
    ScopedFile::PlatformFile file, uint64_t offset = 0, uint64_t size = 0,
    absl::string_view key = "") {
  // Keeps the mapping alive for as long as the spans handed out are in use.
  class MemoryMappedDataHolder final : public DataHolder<T> {
   public:
    // `offset` is the position of the requested data inside `region`; `size`
    // is the length of the span returned by GetData().
    explicit MemoryMappedDataHolder(std::unique_ptr<MemoryMappedFile> region,
                                    uint64_t offset, uint64_t size)
        : region_(std::move(region)), offset_(offset), size_(size) {}
    ~MemoryMappedDataHolder() override = default;

    // The requested window inside the mapped region.
    // NOTE(review): the pointer arithmetic and the span length are in units of
    // T, while `offset_`/`size_` are derived from file offsets; this only
    // lines up when sizeof(T) == 1 -- confirm no caller uses a wider T.
    absl::Span<T> GetData() const override {
      return absl::MakeSpan(static_cast<T*>(region_->data()) + offset_, size_);
    }

    // The whole mapped region, including any alignment padding that precedes
    // the requested data.
    absl::Span<T> GetRawData() const override {
      return absl::MakeSpan(static_cast<T*>(region_->data()),
                            region_->length());
    }

   private:
    std::unique_ptr<MemoryMappedFile> region_;
    uint64_t offset_;
    uint64_t size_;
  };

  // Left at {0, 0} when the caller asked for neither an offset nor a size.
  OffsetAndSize offset_and_size;
  if (offset != 0 || size != 0) {
    offset_and_size = GetAlignedOffsetAndSize(offset, size);
  }
  MP_ASSIGN_OR_RETURN(auto region,
                      MemoryMappedFile::Create(file, offset_and_size.offset,
                                               offset_and_size.size, key));

  // Position of the requested data relative to the start of the mapping.
  const uint64_t data_offset = offset - offset_and_size.offset;
  if (size == 0) {
    // No explicit size: expose everything from `data_offset` to the end of the
    // mapping. The bytes skipped at the front must be subtracted, otherwise
    // the span returned by GetData() would run past the end of the region
    // whenever `data_offset` is non-zero. Clamp so the subtraction can never
    // wrap around.
    // NOTE(review): what GetAlignedOffsetAndSize() returns for a zero
    // `base_size` is not visible here -- confirm it maps through end of file.
    const uint64_t mapped_length = region->length();
    size = mapped_length > data_offset ? mapped_length - data_offset : 0;
  }
  return std::make_unique<MemoryMappedDataHolder>(std::move(region),
                                                  data_offset, size);
}

// This class is responsible for accessing the underlying model data and
// abstracting out any differences in file formats.
//
// Instances are obtained through one of the static Create() overloads, each of
// which returns a shared_ptr to a ModelData on success or an error status.
class ModelData {
 public:
  // Loads from a single tflite flatbuffer. The allocation should contain the
  // whole model including buffers.
  static absl::StatusOr<std::shared_ptr<ModelData>> Create(
      std::shared_ptr<tflite::FlatBufferModel> model);

  // Loads a tflite model from a file. This is more efficient than the above
  // method since the data can be read into memory as needed.
  static absl::StatusOr<std::shared_ptr<ModelData>> Create(ScopedFile file);

  // Similar to the above, but accept shared_ptr. The smart pointer is pointing
  // to a constant object, indicating there's only read access.
  static absl::StatusOr<std::shared_ptr<ModelData>> Create(
      std::shared_ptr<const ScopedFile> file);

  // Loads `ModelData` from the provided `weight_path`, which contains a tflite
  // file.
  static absl::StatusOr<std::shared_ptr<ModelData>> Create(
      absl::string_view weight_path);

  // Mode values associated with ReadDataFn.
  // NOTE(review): the meaning of each value is not documented in this header;
  // presumably they tell the reader whether data it has handed out should be
  // kept or may be discarded -- confirm against the implementation.
  enum ReadMode {
    KEEP = 0,
    DISCARD = 1,
    DISCARD_ALL = 2,
  };
  // Callback used to read model data: takes an `offset`, a `size` and a `mode`
  // and returns a pointer to the data.
  // NOTE(review): `mode` is a plain int, presumably carrying a ReadMode value;
  // ownership and lifetime of the returned pointer are not specified here --
  // confirm.
  using ReadDataFn =
      std::function<void*(uint64_t offset, uint64_t size, int mode)>;
  // Loads a tflite model using the passed `fn`, and reads buffers as needed.
  static absl::StatusOr<std::shared_ptr<ModelData>> Create(ReadDataFn fn);

  // Virtual so that implementations can be destroyed through this interface.
  virtual ~ModelData() = default;

  // Get the type for the model. If a type is not specified by the model files,
  // std::nullopt will be returned.
  virtual std::optional<odml::infra::proto::LlmModelType> GetModelType() = 0;

  // Get the LoRA rank of the model, or std::nullopt if this is not a set of
  // LoRA weights.
  virtual std::optional<int> LoRARank() = 0;

  // Get the parameters to define the model. Returned by const reference.
  // NOTE(review): presumably the reference stays valid only while this
  // ModelData is alive -- confirm.
  virtual const odml::infra::proto::LlmParameters& GetLlmParameters() = 0;

  // Read a metadata string about the model. `name` identifies the metadata
  // entry; the result is the string value or an error status.
  virtual absl::StatusOr<std::string> ReadMetadata(absl::string_view name) = 0;

  // Returns the maximum tensor size for this model.
  virtual uint64_t GetMaxTensorSize() const = 0;

  // Gets the size of the tensor with `name` or 0 if it does not exist.
  virtual uint64_t GetTensorSize(absl::string_view name) const = 0;

  // Returns the tensor data of the tensor with `name`, wrapped in a DataHolder
  // that controls how long the returned bytes remain accessible, or an error
  // status.
  virtual absl::StatusOr<std::unique_ptr<DataHolder<uint8_t>>> ReadTensor(
      absl::string_view name) = 0;

  // Frees the underlying data.
  virtual void Clear() = 0;
};

// Holds data referring to a set of LoRA weights.
struct LoRAData {
  // The ID used to refer to this LoRA. Zero-initialized so that a
  // default-constructed LoRAData never carries an indeterminate value.
  uint32_t id = 0;

  // The weight data for this LoRA. Null until assigned.
  std::shared_ptr<ModelData> model_data;
};

}  // namespace mediapipe::tasks::genai::llm_utils

#endif  // MEDIAPIPE_TASKS_GENAI_INFERENCE_UTILS_LLM_UTILS_MODEL_DATA_H_