chromium/third_party/mediapipe/src/mediapipe/tasks/cc/genai/inference/utils/xnn_utils/xnn_tensor.cc

// Copyright 2024 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mediapipe/tasks/cc/genai/inference/utils/xnn_utils/xnn_tensor.h"

#include <fcntl.h>

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <memory>
#include <numeric>
#include <optional>
#include <ostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/genai/inference/common/mdspan.h"
#include "mediapipe/tasks/cc/genai/inference/utils/xnn_utils/utils.h"
#include "xnnpack.h"  // from @XNNPACK

namespace mediapipe::tasks::genai {
namespace xnn_utils {
namespace {

// Same as numpy.isclose(): |actual - expected| <= atol + rtol * |expected|.
bool IsClose(float actual, float expected, float atol, float rtol) {
  float tolerance = std::abs(expected * rtol) + std::abs(atol);
  float diff = std::abs(actual - expected);
  return diff <= tolerance;
}

}  // namespace

std::ostream& operator<<(std::ostream& os,
                         const absl::flat_hash_map<std::string, int> map) {
  os << "{";
  size_t cnt = 0;
  for (const auto& [key, value] : map) {
    os << key << ":" << value;
    // Separate entries with ", " except after the last one.
    if (++cnt != map.size()) {
      os << ", ";
    }
  }
  os << "}";
  return os;
}

std::ostream& operator<<(std::ostream& os, const Tensor& tensor) {
  os << "Tensor{dims=[" << tensor.dims << "], datatype=" << tensor.datatype
     << ", num_elements=" << tensor.num_elements
     << ", metadata=" << tensor.metadata;
  if (!tensor.tag.empty()) {
    os << ", tag=" << tensor.tag;
  }
  os << "}";
  return os;
}

std::ostream& operator<<(std::ostream& os, const QCTensor& tensor) {
  os << "QCTensor{dims=[" << tensor.dims << "], dim_scale=" << tensor.dim_scale
     << " datatype=" << tensor.datatype
     << ", num_elements=" << tensor.num_elements
     << ", metadata=" << tensor.metadata << "}";
  return os;
}

bool Tensor::operator==(const Tensor& other) const {
  if (dims.size() != other.dims.size()) {
    return false;
  } else if (datatype != other.datatype) {
    return false;
  } else {
    for (size_t i = 0; i < dims.size(); ++i) {
      if (dims[i] != other.dims[i]) {
        return false;
      }
    }
  }
  return 0 == memcmp(Data(), other.Data(), ElementSize(num_elements));
}

std::optional<int> Tensor::GetMetadata(absl::string_view key) const {
  if (metadata.contains(key)) {
    return metadata.at(key);
  }
  return std::nullopt;
}

int Tensor::GetMetadata(absl::string_view key, int default_value) const {
  if (metadata.contains(key)) {
    return metadata.at(key);
  }
  return default_value;
}

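// Lazily allocates the backing buffer, zero-initialized and padded with
// XNN_EXTRA_BYTES so XNNPACK kernels may safely read past the logical end.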
void Tensor::AllocateBufferIfNeeded() {
  if (!flat_data) {
    auto real_buffer = std::make_shared<std::vector<char>>(
        ElementSize(num_elements) + XNN_EXTRA_BYTES, 0x00);
    elements_capacity = num_elements;
    flat_data = std::shared_ptr<char>(real_buffer, real_buffer->data());
  }
}

void* Tensor::Data() {
  ABSL_DCHECK(flat_data)
      << "If this is a weight tensor, call one of the LoadFrom*() methods "
         "first to populate the buffer.";
  return flat_data.get();
}

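// Returns a view at the single non-zero entry of `offset`; all other entries
// of `offset` must be zero. Delegates to Slice(index, offset) below.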
std::shared_ptr<Tensor> Tensor::Slice(DimsType offset) {
  ABSL_DCHECK(flat_data);
  ABSL_CHECK_EQ(offset.size(), dims.size()) << offset << " vs. " << dims;
  // offset: [0, k, 0, 0], dims: [1, K, _, _]. dims before k must be 1.
  bool found_non_zero_offset = false;
  int index_k = -1;
  for (int i = 0; i < dims.size(); ++i) {
    if (found_non_zero_offset) {
      ABSL_DCHECK_EQ(offset[i], 0);
    } else if (offset[i] != 0) {
      found_non_zero_offset = true;
      index_k = i;
    }
  }
  ABSL_DCHECK(found_non_zero_offset) << offset;

  return Slice(index_k, offset[index_k]);
}

void Tensor::PrintSpan() {
  if (dims.size() == 1) {
    ABSL_LOG(INFO) << MakeMdSpan(DataAs<float>(), dims[0]);
  } else if (dims.size() == 2) {
    ABSL_LOG(INFO) << MakeMdSpan(DataAs<float>(), dims[0], dims[1]);
  } else if (dims.size() == 3) {
    ABSL_LOG(INFO) << MakeMdSpan(DataAs<float>(), dims[0], dims[1], dims[2]);
  } else if (dims.size() == 4) {
    ABSL_LOG(INFO) << MakeMdSpan(DataAs<float>(), dims[0], dims[1], dims[2],
                                 dims[3]);
  } else {
    ABSL_LOG(FATAL) << "Unsupported dims size: " << dims.size();
  }
}

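// Returns a view of the slab at position `offset` along dimension `index`.
// All dims before `index` must be 1; the view aliases this tensor's buffer.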
std::shared_ptr<Tensor> Tensor::Slice(size_t index, size_t offset) {
  size_t num_elements_offset = 1;
  DimsType new_dim = dims;
  for (int i = 0; i < dims.size(); ++i) {
    if (i < index) {
      ABSL_DCHECK_EQ(dims[i], 1);
    } else if (i == index) {
      ABSL_DCHECK_LT(offset, dims[i]) << "i = " << i;
      num_elements_offset *= offset;
      new_dim[i] = 1;
    } else {
      num_elements_offset *= dims[i];
    }
  }

  auto result =
      std::make_shared<Tensor>(std::move(new_dim), datatype, is_sparse());
  result->flat_data = std::shared_ptr<char>(
      flat_data, flat_data.get() + ElementSize(num_elements_offset));
  result->elements_capacity = result->num_elements;
  return result;
}

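// Like Slice(index, offset), but keeps the half-open range [start, end) along
// dimension `index`.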
std::shared_ptr<Tensor> Tensor::Slice(size_t index, size_t start, size_t end) {
  size_t num_elements_offset = 1;
  DimsType new_dim = dims;
  for (int i = 0; i < dims.size(); ++i) {
    if (i < index) {
      ABSL_DCHECK_EQ(dims[i], 1);
    } else if (i == index) {
      ABSL_DCHECK_LT(start, end);
      ABSL_DCHECK_LE(end, dims[i]);
      num_elements_offset *= start;
      new_dim[i] = end - start;
    } else {
      num_elements_offset *= dims[i];
    }
  }

  auto result =
      std::make_shared<Tensor>(std::move(new_dim), datatype, is_sparse());
  result->flat_data = std::shared_ptr<char>(
      flat_data, flat_data.get() + ElementSize(num_elements_offset));
  result->elements_capacity = result->num_elements;
  return result;
}

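// Aliases `other`'s buffer starting at `element_offset` without copying; the
// aliasing shared_ptr keeps the underlying buffer alive.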
Tensor& Tensor::Borrow(std::shared_ptr<Tensor> other, size_t element_offset) {
  ABSL_DCHECK_EQ(datatype, other->datatype);
  ABSL_DCHECK_EQ(dims.size(), other->dims.size());
  flat_data = std::shared_ptr<char>(
      other->flat_data, other->flat_data.get() + ElementSize(element_offset));
  elements_capacity = other->elements_capacity - element_offset;
  return *this;
}

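// Reinterprets the tensor with `new_dims`. The buffer is only reallocated
// (with the old contents copied over) when the new element count exceeds the
// current capacity.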
Tensor& Tensor::Resize(DimsType new_dims) {
  ABSL_DCHECK(!new_dims.empty());
  const size_t old_num_elements = num_elements;
  internal_dims = std::move(new_dims);
  internal_num_elements = std::accumulate(dims.begin(), dims.end(), size_t(1),
                                          std::multiplies<size_t>());
  ABSL_DCHECK_NE(internal_num_elements, 0);
  if (num_elements > elements_capacity) {
    auto old_flat_data = std::move(flat_data);
    AllocateBufferIfNeeded();
    if (old_flat_data) {
      memcpy(Data(), old_flat_data.get(), ElementSize(old_num_elements));
    }
  }
  return *this;
}

const void* Tensor::Data() const { return const_cast<Tensor*>(this)->Data(); }

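// Defines this tensor as a value in `subgraph` and caches the assigned value
// id per subgraph; only fp32 and dynamically quantized int8 are supported.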
absl::Status Tensor::DefineInSubgraph(xnn_subgraph& subgraph, uint32_t flags) {
  uint32_t id;
  switch (datatype) {
    case xnn_datatype_fp32: {
      RET_CHECK_EQ(xnn_status_success,
                   xnn_define_tensor_value(
                       &subgraph, datatype, dims.size(), dims.data(),
                       /*data=*/nullptr,
                       /*external_id=*/tensor_id(&subgraph), flags, &id));
      break;
    }
    case xnn_datatype_qdint8: {
      // With num_non_batch_dims=1, the last dim is the number of channels;
      // the remaining dims are flattened and treated as the batch size.
      RET_CHECK_EQ(xnn_status_success,
                   xnn_define_dynamically_quantized_tensor_value(
                       &subgraph, datatype, dims.size(),
                       /*num_non_batch_dims=*/1, dims.data(),
                       /*external_id=*/tensor_id(&subgraph), flags, &id))
          << dims;
      break;
    }
    default:
      return absl::InvalidArgumentError(
          absl::StrCat("Unsupported datatype: ", datatype));
  }
  if (tensor_id(&subgraph) == XNN_INVALID_VALUE_ID) {
    RET_CHECK_NE(id, XNN_INVALID_VALUE_ID);
    map_subgraph_to_tensor_id[&subgraph] = id;
  } else {
    RET_CHECK_EQ(id, tensor_id(&subgraph));
  }
  return absl::OkStatus();
}

absl::Status Tensor::DefineAsInput(xnn_subgraph& subgraph) {
  return DefineInSubgraph(subgraph, XNN_VALUE_FLAG_EXTERNAL_INPUT);
}

absl::Status Tensor::DefineAsOutput(xnn_subgraph& subgraph) {
  return DefineInSubgraph(subgraph, XNN_VALUE_FLAG_EXTERNAL_OUTPUT);
}

absl::Status Tensor::DefineAsIntermediateTensor(xnn_subgraph& subgraph) {
  RET_CHECK_EQ(tensor_id(&subgraph), XNN_INVALID_VALUE_ID);
  return DefineInSubgraph(subgraph, 0);
}

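// Defines this tensor as a static weight in `subgraph`, passing the data
// pointer directly to XNNPACK, and records the assigned value id.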
absl::Status Tensor::DefineWeight(xnn_subgraph& subgraph, uint32_t flags) {
  uint32_t assigned_tensor_id;
  RET_CHECK_EQ(xnn_status_success,
               xnn_define_tensor_value(
                   &subgraph, datatype, dims.size(), dims.data(), Data(),
                   tensor_id(&subgraph), flags, &assigned_tensor_id));
  RET_CHECK_NE(assigned_tensor_id, XNN_INVALID_VALUE_ID);
  map_subgraph_to_tensor_id[&subgraph] = assigned_tensor_id;
  return absl::OkStatus();
}

absl::Status Tensor::DefineWeight(xnn_subgraph& subgraph) {
  RET_CHECK_EQ(tensor_id(&subgraph), XNN_INVALID_VALUE_ID);
  return DefineWeight(subgraph, 0);
}

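// Returns the value id assigned to this tensor in `subgraph`, or
// XNN_INVALID_VALUE_ID if it has not been defined there yet.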
uint32_t Tensor::tensor_id(xnn_subgraph_t subgraph) {
  if (map_subgraph_to_tensor_id.contains(subgraph)) {
    return map_subgraph_to_tensor_id.at(subgraph);
  }
  return XNN_INVALID_VALUE_ID;
}

void Tensor::set_tensor_id(xnn_subgraph_t subgraph, uint32_t id) {
  map_subgraph_to_tensor_id[subgraph] = id;
}

absl::Status Tensor::LoadFromBuffer(const void* buffer) {
  AllocateBufferIfNeeded();
  memcpy(Data(), buffer, ElementSize(num_elements));
  return absl::OkStatus();
}

absl::Status Tensor::LoadFromVec(const std::vector<float>& data,
                                 bool exact_match) {
  AllocateBufferIfNeeded();
  if (exact_match) {
    RET_CHECK_EQ(ElementSize(num_elements), data.size() * sizeof(float));
  }

  memcpy(Data(), data.data(), data.size() * sizeof(float));

  return absl::OkStatus();
}

absl::Status Tensor::DumpToBuffer(void* buffer) {
  memcpy(buffer, Data(), ElementSize(num_elements));
  return absl::OkStatus();
}

absl::Status Tensor::DumpToVec(std::vector<float>& out_data, bool exact_match) {
  if (exact_match) {
    RET_CHECK_EQ(ElementSize(num_elements), out_data.size() * sizeof(float));
  } else {
    out_data.resize(num_elements);
  }
  memcpy(out_data.data(), Data(), ElementSize(num_elements));
  return absl::OkStatus();
}

absl::Status Tensor::DumpToFile(absl::string_view file_path) {
  return mediapipe::file::SetContents(
      file_path, absl::string_view(flat_data.get(), ElementSize(num_elements)));
}

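// Loads tensor contents from `file_path` (optionally via mmap). If a buffer
// already exists its contents are overwritten; otherwise the loaded buffer is
// adopted directly.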
absl::Status Tensor::LoadFromFile(absl::string_view file_path, bool use_mmap,
                                  bool exact_match) {
  const size_t expected_size_in_bytes =
      exact_match ? ElementSize(num_elements) : 0;

  size_t buffer_size;
  MP_ASSIGN_OR_RETURN(auto tmp_flat_data,
                      LoadBufferFromFile(file_path, &buffer_size, use_mmap,
                                         expected_size_in_bytes));
  if (!flat_data) {
    flat_data = tmp_flat_data;
    elements_capacity = num_elements;
  } else {
    memcpy(flat_data.get(), tmp_flat_data.get(), buffer_size);
  }
  tag = mediapipe::file::Basename(file_path);

  return absl::OkStatus();
}

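// Returns a transposed copy of this 2-D tensor using XNNPACK's transpose
// kernel; only fp32 is supported.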
std::shared_ptr<Tensor> Tensor::Transpose() {
  ABSL_DCHECK_EQ(dims.size(), 2);
  DimsType out_dims{dims.rbegin(), dims.rend()};
  auto result =
      std::make_shared<Tensor>(std::move(out_dims), datatype, is_sparse());
  result->AllocateBufferIfNeeded();
  xnn_status s;
  const DimsType perm{1, 0};
  if (datatype == xnn_datatype_fp32) {
    s = xnn_run_transpose_nd_x32(Data(), result->Data(), dims.size(),
                                 dims.data(), perm.data(),
                                 /*flags=*/0, /*threadpool=*/nullptr);
  } else {
    ABSL_LOG(FATAL) << "Need update to support new type";
  }
  ABSL_DCHECK_EQ(s, xnn_status_success);
  return (s == xnn_status_success) ? result : nullptr;
}

absl::StatusOr<std::shared_ptr<Tensor>> Tensor::ConvertToF32() {
  auto result = std::make_shared<Tensor>(dims, xnn_datatype_fp32, is_sparse());
  MP_RETURN_IF_ERROR(result->LoadFromBuffer(Data()));
  return result;
}

absl::StatusOr<::mediapipe::Tensor> Tensor::ConvertToMediapipeTensor() {
  RET_CHECK_EQ(datatype, xnn_datatype_fp32) << "Try ConvertToF32 then convert";
  ::mediapipe::Tensor mp_tensor(
      ::mediapipe::Tensor::ElementType::kFloat32,
      ::mediapipe::Tensor::Shape(std::vector<int>(dims.begin(), dims.end())));
  void* mp_tensor_buffer = mp_tensor.GetCpuWriteView().buffer<float>();
  std::memcpy(mp_tensor_buffer, Data(), ElementSize(num_elements));
  return mp_tensor;
}

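// Element-wise comparison against `expected_tensor` with numpy-style
// tolerances. Returns an error listing the first mismatching indices; NaNs
// always count as mismatches.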
absl::Status Tensor::IsCloseTo(const Tensor& expected_tensor, float atol,
                               float rtol) {
  RET_CHECK_EQ(datatype, xnn_datatype_fp32) << "Try ConvertToF32";
  RET_CHECK_EQ(dims.size(), expected_tensor.dims.size());
  for (int i = 0; i < dims.size(); ++i) {
    RET_CHECK_EQ(dims[i], expected_tensor.dims[i])
        << dims << " v.s. " << expected_tensor.dims;
  }

  const auto* actual = static_cast<const float*>(Data());
  const auto* expected = static_cast<const float*>(expected_tensor.Data());
  std::stringstream ss;

  size_t total_print = 0;
#define LOG_AND_COUNT()                                                   \
  do {                                                                    \
    ++total_print;                                                        \
    ss << "\n"                                                            \
       << i << ", expect: " << expected[i] << ", actual: " << actual[i];  \
  } while (0)

  for (size_t i = 0; i < expected_tensor.num_elements; ++i) {
    if (std::isnan(actual[i]) || std::isnan(expected[i])) {
      LOG_AND_COUNT();
    } else if (!IsClose(actual[i], expected[i], atol, rtol)) {
      LOG_AND_COUNT();
    }
    // Cap the report at 100 mismatches to keep the message readable.
    if (total_print > 100) {
      ss << "\nand more...";
      break;
    }
  }
#undef LOG_AND_COUNT

  // Any mismatch (including NaNs) fails the comparison.
  if (total_print > 0) {
    return absl::InternalError(ss.str());
  }
  return absl::OkStatus();
}

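// Loads the quantized weights and their per-channel scales from two separate
// files.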
absl::Status QCTensor::LoadFromFile(
    absl::string_view quantized_weight_filename,
    absl::string_view scale_filename, bool use_mmap, bool exact_match) {
  size_t scale_element_size = dims[dim_scale];

  size_t buffer_size, scale_buffer_size;
  MP_ASSIGN_OR_RETURN(
      auto tmp_flat_data,
      LoadBufferFromFile(quantized_weight_filename, &buffer_size, use_mmap,
                         exact_match ? ElementSize(num_elements) : 0));
  MP_ASSIGN_OR_RETURN(
      auto tmp_scale_data,
      LoadBufferFromFile<float>(
          scale_filename, &scale_buffer_size, use_mmap,
          exact_match ? scale_element_size * sizeof(float) : 0));
  if (!flat_data) {
    flat_data = tmp_flat_data;
    scale_data = tmp_scale_data;
    elements_capacity = num_elements;
  } else {
    memcpy(flat_data.get(), tmp_flat_data.get(), buffer_size);
    memcpy(scale_data.get(), tmp_scale_data.get(), scale_buffer_size);
  }
  tag = mediapipe::file::Basename(quantized_weight_filename);

  return absl::OkStatus();
}

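// Writes the quantized data to `file_path` and the per-channel scales to
// `file_path` + kQuantizedScaleSuffix.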
absl::Status QCTensor::DumpToFile(absl::string_view file_path) {
  MP_RETURN_IF_ERROR(mediapipe::file::SetContents(
      file_path,
      absl::string_view(flat_data.get(), ElementSize(num_elements))));
  return mediapipe::file::SetContents(
      absl::StrCat(file_path, kQuantizedScaleSuffix),
      absl::string_view(reinterpret_cast<char*>(scale_data.get()),
                        dims[dim_scale] * sizeof(float)));
}

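// Defines this tensor as a channelwise-quantized static weight; `scale_data`
// holds one scale per entry along dimension `dim_scale`.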
absl::Status QCTensor::DefineWeight(xnn_subgraph& subgraph, uint32_t flags) {
  uint32_t assigned_tensor_id;
  RET_CHECK_EQ(xnn_status_success,
               xnn_define_channelwise_quantized_tensor_value_v2(
                   &subgraph, datatype, zero_point, scale_data.get(),
                   dims.size(), dim_scale, dims.data(), Data(),
                   XNN_INVALID_VALUE_ID, flags, &assigned_tensor_id))
      << *this;
  RET_CHECK_NE(assigned_tensor_id, XNN_INVALID_VALUE_ID);
  map_subgraph_to_tensor_id[&subgraph] = assigned_tensor_id;
  return absl::OkStatus();
}

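// Also allocates the per-channel scale buffer alongside the quantized data.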
void QCTensor::AllocateBufferIfNeeded() {
  Tensor::AllocateBufferIfNeeded();
  if (!scale_data) {
    auto real_buffer = std::make_shared<std::vector<float>>(dims[dim_scale]);
    scale_data = std::shared_ptr<float>(real_buffer, real_buffer->data());
  }
}

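// Returns a transposed copy of the 2-D quantized tensor, flipping `dim_scale`
// and copying the per-channel scales. int4 data is unpacked to one value per
// byte, transposed, then repacked.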
std::shared_ptr<Tensor> QCTensor::Transpose() {
  ABSL_DCHECK_EQ(dims.size(), 2);
  size_t channel_size = dims[dim_scale];
  DimsType out_dims{dims.rbegin(), dims.rend()};
  auto result = std::make_shared<QCTensor>(std::move(out_dims), 1 - dim_scale,
                                           datatype, is_sparse());
  result->zero_point = zero_point;
  result->AllocateBufferIfNeeded();
  memcpy(result->scale_data.get(), scale_data.get(),
         channel_size * sizeof(float));
  xnn_status s;
  const DimsType perm{1, 0};
  switch (datatype) {
    case xnn_datatype_qcint8:
      s = xnn_run_transpose_nd_x8(Data(), result->Data(), dims.size(),
                                  dims.data(), perm.data(),
                                  /*flags=*/0, /*threadpool=*/nullptr);
      break;
    case xnn_datatype_qcint4: {
      std::vector<uint8_t> unpacked =
          xnn_utils::UnpackInt8ToInt4(
              absl::Span<uint8_t>(static_cast<uint8_t*>(Data()),
                                  ElementSize(num_elements)))
              .value();
      std::vector<uint8_t> transposed_unpacked(unpacked.size());
      s = xnn_run_transpose_nd_x8(unpacked.data(), transposed_unpacked.data(),
                                  dims.size(), dims.data(), perm.data(),
                                  /*flags=*/0, /*threadpool=*/nullptr);
      std::vector<uint8_t> packed =
          xnn_utils::PackInt4ToInt8(
              absl::Span<uint8_t>(transposed_unpacked.data(),
                                  transposed_unpacked.size()))
              .value();
      ABSL_CHECK_OK(result->LoadFromBuffer(packed.data()));
      break;
    }
    default:
      ABSL_LOG(FATAL) << "Need update to support new type";
  }
  ABSL_DCHECK_EQ(s, xnn_status_success);
  return (s == xnn_status_success) ? result : nullptr;
}

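// Dequantizes the 2-D weight to fp32. Each element is multiplied by the
// per-channel scale selected by `dim_scale`; int4 data additionally subtracts
// `zero_point` before scaling.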
absl::StatusOr<std::shared_ptr<Tensor>> QCTensor::ConvertToF32() {
  RET_CHECK_EQ(dims.size(), 2)
      << "QCTensor is usually weight for FullConn" << dims;
  auto result = std::make_shared<Tensor>(dims, xnn_datatype_fp32, is_sparse());
  MP_RETURN_IF_ERROR(result->LoadFromVec({}, /*exact_match=*/false));
  float* scaled_data = result->DataAs<float>();

  switch (datatype) {
    case xnn_datatype_qcint8: {
      auto* quantized_data = DataAs<int8_t>();
      for (size_t i = 0; i < dims[0]; ++i) {
        for (size_t j = 0; j < dims[1]; ++j) {
          float scale = dim_scale ? scale_data.get()[j] : scale_data.get()[i];
          *scaled_data = *quantized_data * scale;
          ++scaled_data;
          ++quantized_data;
        }
      }
      break;
    }
    case xnn_datatype_qcint4: {
      uint8_t* quantized_data = static_cast<uint8_t*>(Data());
      RET_CHECK_EQ(dims[1] % 2, 0);
      for (size_t i = 0; i < dims[0]; ++i) {
        for (size_t j = 0; j < dims[1] / 2; ++j) {
          {
            // first element
            float scale =
                dim_scale ? scale_data.get()[j * 2] : scale_data.get()[i];
            *scaled_data =
                (static_cast<int32_t>(*quantized_data & 0x0f) - zero_point) *
                scale;
            ++scaled_data;
          }
          {
            // second element
            float scale =
                dim_scale ? scale_data.get()[j * 2 + 1] : scale_data.get()[i];
            *scaled_data =
                (static_cast<int32_t>(*quantized_data >> 4) - zero_point) *
                scale;
            ++scaled_data;
          }
          ++quantized_data;
        }
      }
      break;
    }
    default: {
      return absl::InvalidArgumentError("Need update to support new type");
    }
  }

  return result;
}

std::shared_ptr<Tensor> QCTensor::Slice(size_t index, size_t offset) {
  ABSL_CHECK_LE(index, 1);
  ABSL_CHECK_EQ(index, dim_scale);

  std::shared_ptr<QCTensor> result;
  if (index == 0) {
    result = std::make_shared<QCTensor>(DimsType{1, dims[1]}, dim_scale,
                                        datatype, is_sparse());
    result->flat_data = std::shared_ptr<char>(
        flat_data, flat_data.get() + ElementSize(dims[1] * offset));
    result->scale_data = std::make_shared<float>(*(scale_data.get() + offset));
  } else {
    result = std::make_shared<QCTensor>(DimsType{dims[0], 1}, dim_scale,
                                        datatype, is_sparse());
    result->flat_data = std::shared_ptr<char>(
        flat_data, flat_data.get() + ElementSize(dims[1] * offset));
    result->scale_data = std::make_shared<float>(*(scale_data.get() + offset));
  }
  result->elements_capacity = result->num_elements;
  return result;
}

}  // namespace xnn_utils
}  // namespace mediapipe::tasks::genai