chromium/third_party/mediapipe/src/mediapipe/tasks/cc/genai/inference/utils/xnn_utils/tflite_weight_accessor.cc

// Copyright 2024 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mediapipe/tasks/cc/genai/inference/utils/xnn_utils/tflite_weight_accessor.h"

#include <cstddef>
#include <memory>
#include <string>
#include <utility>

#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "flatbuffers/buffer.h"
#include "flatbuffers/vector.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/tasks/cc/genai/inference/utils/llm_utils/memory_mapped_file.h"
#include "mediapipe/tasks/cc/genai/inference/utils/xnn_utils/xnn_tensor.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "xnnpack.h"  // from @XNNPACK

namespace mediapipe::tasks::genai::xnn_utils {

TfLiteWeightAccessor::TfLiteWeightAccessor(
    std::shared_ptr<const tflite::Model> tflite_model, char* data)
    : tflite_model_(tflite_model) {
  BuildWeightsMapFromTfliteModel(data);
}

TfLiteWeightAccessor::TfLiteWeightAccessor(absl::string_view filename) {
  std::shared_ptr<llm_utils::MemoryMappedFile> mmap_file =
      llm_utils::MemoryMappedFile::Create(filename).value_or(nullptr);
  if (mmap_file) {
    tflite_model_ = std::shared_ptr<const ::tflite::Model>(
        mmap_file, ::tflite::GetModel(mmap_file->data()));
    BuildWeightsMapFromTfliteModel(static_cast<char*>(mmap_file->data()));
  }
}

void TfLiteWeightAccessor::BuildWeightsMapFromTfliteModel(char* data) {
  if (!tflite_model_) return;
  const flatbuffers::Vector<flatbuffers::Offset<::tflite::Buffer>>& buffers =
      *tflite_model_->buffers();
  for (const tflite::SubGraph* subgraph : *tflite_model_->subgraphs()) {
    for (const tflite::Tensor* tfl_tensor : *subgraph->tensors()) {
      absl::string_view tensor_name = absl::string_view(
          tfl_tensor->name()->data(), tfl_tensor->name()->size());
      Tensor::DimsType dims(tfl_tensor->shape()->begin(),
                            tfl_tensor->shape()->end());
      std::shared_ptr<Tensor> tensor;
      ABSL_DCHECK_LT(tfl_tensor->buffer(), buffers.size());
      const tflite::Buffer& tfl_buffer = *buffers.Get(tfl_tensor->buffer());
      switch (tfl_tensor->type()) {
        case tflite::TensorType::TensorType_FLOAT32:
          tensor = std::make_shared<Tensor>(std::move(dims), xnn_datatype_fp32);
          ABSL_DCHECK_EQ(tfl_buffer.size(),
                         tensor->num_elements * sizeof(float));
          tensor->flat_data =
              std::shared_ptr<char>(tflite_model_, data + tfl_buffer.offset());
          break;
        case tflite::TensorType::TensorType_INT8:
          tensor = std::make_shared<QCTensor>(std::move(dims), /*dim_scale=*/0,
                                              xnn_datatype_qcint8);
          ABSL_DCHECK_EQ(tfl_buffer.size(), tensor->num_elements);
          tensor->flat_data =
              std::shared_ptr<char>(tflite_model_, data + tfl_buffer.offset());
          break;
        case tflite::TensorType::TensorType_INT4:
          tensor = std::make_shared<QCTensor>(std::move(dims), /*dim_scale=*/0,
                                              xnn_datatype_qcint4);
          ABSL_DCHECK_EQ(tfl_buffer.size(), tensor->num_elements / 2);
          tensor->flat_data =
              std::shared_ptr<char>(tflite_model_, data + tfl_buffer.offset());
          break;
        default:
          ABSL_LOG(DFATAL) << "Unsupported tensor type: " << tfl_tensor->type();
          break;
      }
      weights_[tensor_name] = tensor;
    }
  }
}

absl::StatusOr<std::shared_ptr<Tensor>> TfLiteWeightAccessor::LoadWeight(
    absl::string_view tensor_name, Tensor::DimsType expected_dims,
    size_t dim_scale_if_any) const {
  RET_CHECK(tflite_model_);
  if (!weights_.contains(tensor_name)) {
    ABSL_DLOG(WARNING) << "Tensor not found: " << tensor_name;
    return nullptr;
  }
  // qtensor means potentially quantized tensor.
  std::shared_ptr<Tensor> qtensor = weights_.at(tensor_name);
  // Check dimensions.
  {
    bool correct_dimension = true;
    correct_dimension &= (qtensor->dims.size() == expected_dims.size());
    for (int i = 0; i < qtensor->dims.size(); ++i) {
      correct_dimension &= (qtensor->dims[i] == expected_dims[i]);
    }
    if (!correct_dimension) {
      return absl::InvalidArgumentError(absl::StrCat(
          "Dimension mismatch at ", tensor_name, " with expected dimensions [",
          absl::StrJoin(expected_dims, ", "), "] actual dimensions [",
          absl::StrJoin(qtensor->dims, ", "), "]"));
    }
  }
  if (qtensor->datatype == xnn_datatype_fp32) {
    return qtensor;
  }

  // Following are logic for quantized weights.
  std::shared_ptr<QCTensor> result;
  std::string scale_tensor_name =
      absl::StrCat(tensor_name, kQuantizedScaleSuffix);
  if (!weights_.contains(scale_tensor_name)) {
    return absl::NotFoundError(
        absl::StrCat("Scale tensor not found: ", scale_tensor_name));
  }
  std::shared_ptr<Tensor> scale_tensor = weights_.at(scale_tensor_name);
  RET_CHECK_EQ(expected_dims[dim_scale_if_any], scale_tensor->num_elements);
  switch (qtensor->datatype) {
    case xnn_datatype_qcint8:
      result = std::make_shared<QCTensor>(
          std::move(expected_dims), dim_scale_if_any, xnn_datatype_qcint8);
      result->flat_data = qtensor->flat_data;
      result->scale_data = std::shared_ptr<float>(
          scale_tensor->flat_data,
          reinterpret_cast<float*>(scale_tensor->flat_data.get()));
      break;
    case xnn_datatype_qcint4:
      result = std::make_shared<QCTensor>(
          std::move(expected_dims), dim_scale_if_any, xnn_datatype_qcint4);
      result->flat_data = qtensor->flat_data;
      result->scale_data = std::shared_ptr<float>(
          scale_tensor->flat_data,
          reinterpret_cast<float*>(scale_tensor->flat_data.get()));
      break;
    default:
      return absl::InvalidArgumentError("Unsupported tensor type");
      break;
  }
  return result;
}

absl::StatusOr<std::shared_ptr<Tensor>>
TfLiteWeightAccessor::LoadTransposedWeight(absl::string_view tensor_name,
                                           Tensor::DimsType expected_dims,
                                           size_t dim_scale_if_any) const {
  RET_CHECK(tflite_model_);
  RET_CHECK_EQ(expected_dims.size(), 2);
  return LoadWeight(
      tensor_name,
      Tensor::DimsType(expected_dims.rbegin(), expected_dims.rend()),
      1 - dim_scale_if_any);
}

}  // namespace mediapipe::tasks::genai::xnn_utils