chromium/third_party/mediapipe/src/mediapipe/tasks/cc/genai/inference/utils/xnn_utils/xnn_tensor.h

// Copyright 2024 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MEDIAPIPE_TASKS_GENAI_INFERENCE_UTILS_XNN_UTILS_XNN_TENSOR_H_
#define MEDIAPIPE_TASKS_GENAI_INFERENCE_UTILS_XNN_UTILS_XNN_TENSOR_H_

#include <fcntl.h>

#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <numeric>
#include <optional>
#include <ostream>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/logging.h"
#include "xnnpack.h"  // from @XNNPACK

namespace mediapipe::tasks::genai {
namespace xnn_utils {

// Suffix appended to a weight's file path to locate its quantization scales
// (see QCTensor::LoadFromFile).
// `inline constexpr` (C++17) gives one definition shared by every translation
// unit that includes this header, instead of a per-TU `static` copy.
inline constexpr absl::string_view kQuantizedScaleSuffix{"_quantized_scale"};
// Suffix for sparsity parameter files. NOTE(review): presumably appended to a
// weight's file path like kQuantizedScaleSuffix; confirm against users.
inline constexpr absl::string_view kSparsityParamsSuffix{"_sparsity_params"};

struct Tensor {
  using DimsType = std::vector<size_t>;

  explicit Tensor(DimsType in_dims, xnn_datatype datatype_ = xnn_datatype_fp32,
                  bool is_sparse_ = false)
      : datatype(datatype_),
        internal_dims(std::move(in_dims)),
        internal_num_elements(dims.empty()
                                  ? 0
                                  : std::accumulate(std::begin(dims),
                                                    std::end(dims), size_t(1),
                                                    std::multiplies<size_t>())),
        is_sparse_tensor(is_sparse_) {
    elements_capacity = internal_num_elements;
  }
  Tensor(Tensor&& other) = default;

  Tensor& operator=(const Tensor& other) = delete;
  Tensor& operator=(Tensor&& other) = default;

  virtual ~Tensor() = default;

  bool operator==(const Tensor& other) const;

  void SetMetadata(absl::string_view key, int value) { metadata[key] = value; }

  std::optional<int> GetMetadata(absl::string_view key) const;
  int GetMetadata(absl::string_view key, int default_value) const;

  // Add the tensor into subgraph.
  absl::Status DefineAsInput(xnn_subgraph& subgraph);
  absl::Status DefineAsOutput(xnn_subgraph& subgraph);
  absl::Status DefineAsIntermediateTensor(xnn_subgraph& subgraph);
  virtual absl::Status DefineWeight(xnn_subgraph& subgraph, uint32_t flags);
  absl::Status DefineWeight(xnn_subgraph& subgraph);

  // Load the tensor from buffer, assuming the buffer is long enough.
  absl::Status LoadFromBuffer(const void* buffer);
  // Load the tensor from vector of data. If not exact_match, data can hold less
  // than num_elements.
  absl::Status LoadFromVec(const std::vector<float>& data,
                           bool exact_match = false);
  // Load the tensor from file.
  //   file_path: a string representing the path to the file to load from.
  //   use_mmap: whether or not to use mmap to access the file.
  //   exact_match: if true, the file should contain exactly num_elements of
  //       data.
  virtual absl::Status LoadFromFile(absl::string_view file_path, bool use_mmap,
                                    bool exact_match);
  absl::Status LoadFromFile(absl::string_view file_path) {
    return LoadFromFile(file_path, true, true);
  }

  // Dump the tensor to buffer, assuming the buffer is long enough.
  absl::Status DumpToBuffer(void* buffer);
  // Dump the tensor to vector. If exact_match is set to false, out_data may be
  // resized.
  absl::Status DumpToVec(std::vector<float>& out_data, bool exact_match = true);
  // Dump the tensor to a file specified by file_path.
  virtual absl::Status DumpToFile(absl::string_view file_path);

  // If i'th offset is 0, view's ith dim equals to original i'th dim,
  // otherwise 1. e.g. Tensor[A,B,C,D].Slice([0,b,0,0]) returns a tensor of
  // shape [A,1,C,D].
  std::shared_ptr<Tensor> Slice(DimsType offset);
  // Slice along the `index`th dimension, offset at this dimension.
  virtual std::shared_ptr<Tensor> Slice(size_t index, size_t offset);
  // Slice along the `index`th dimension from index start to index end. e.g.
  // Tensor[A,B,C,D].Slice(1, 0, 5) returns a tensor of shape [A,5,C,D].
  virtual std::shared_ptr<Tensor> Slice(size_t index, size_t start, size_t end);

  // Point the underline data to the borrowed tensor's data.
  Tensor& Borrow(std::shared_ptr<Tensor>, size_t element_offset = 0);

  Tensor& Resize(DimsType new_dims);

  // Hint that this is an output of the graph.
  Tensor& MarkOutput() {
    AllocateBufferIfNeeded();
    is_output_tensor = true;
    return *this;
  }

  // Access the tensor data.
  virtual void* Data();
  const void* Data() const;

  // Access the tensor data as certain type.
  template <typename T>
  T* DataAs() {
    ABSL_DCHECK_EQ(ElementSize(1), sizeof(T));
    return static_cast<T*>(Data());
  }
  template <typename T>
  const T* DataAs() const {
    return static_cast<const T*>(Data());
  }

  // Transpose the tensor.
  virtual std::shared_ptr<Tensor> Transpose();

  // Print the tensor values. Tensors with dims > 4 unsupported.
  void PrintSpan();

  // Convert the tensor to f32 format.
  virtual absl::StatusOr<std::shared_ptr<Tensor>> ConvertToF32();

  // Convert the tensor to ::mediapipe::Tensor.
  virtual absl::StatusOr<::mediapipe::Tensor> ConvertToMediapipeTensor();

  // Indicates whether the tensor data is sparse i.e. contains a lot of zeros.
  bool is_sparse() const { return is_sparse_tensor; }

  // Check if the tensor is close to the expected tensor, only used in test.
  absl::Status IsCloseTo(const Tensor& expected_tensor, float atol = 0,
                         float rtol = 2e-3);

  const DimsType& dims = internal_dims;
  const size_t& num_elements = internal_num_elements;
  const xnn_datatype datatype = xnn_datatype_invalid;

  // Get and set id to a given subgraph.
  uint32_t tensor_id(xnn_subgraph_t);
  void set_tensor_id(xnn_subgraph_t, uint32_t id);

  // shared_ptr to make TensorMetadata copyable.
  std::shared_ptr<char> flat_data;
  size_t elements_capacity = 0;

  // Optional, annotates where the tensor comes from. E.g. the filename where
  // the tensor is loaded from.
  std::string tag;

 protected:
  friend class XnnGraphBuilder;
  friend class XnnGraph;
  friend std::ostream& operator<<(std::ostream& os, const Tensor& tensor);

  // Invoke xnn_define_*tensor_value to add this tensor to the `subgraph`.
  virtual absl::Status DefineInSubgraph(xnn_subgraph& subgraph, uint32_t flags);

  // Actually allocate buffer unless necessary.
  virtual void AllocateBufferIfNeeded();

  virtual size_t ElementSize(size_t num_elements) const {
    return num_elements * 4;
  }

  DimsType internal_dims;
  size_t internal_num_elements;

  bool is_output_tensor = false;
  bool is_sparse_tensor = false;

  absl::flat_hash_map<std::string, int> metadata;

  // The same tensor can be used in multiple subgraphs, this is mapping from
  // subgraph to a per-subgraph id.
  absl::flat_hash_map<xnn_subgraph_t, uint32_t> map_subgraph_to_tensor_id;
};

// Streams a description of `tensor` to `os`. Declared a friend of Tensor, so
// the implementation may read its protected members.
std::ostream& operator<<(std::ostream& os, const Tensor& tensor);

// Channelwise Quantized.
struct QCTensor : public Tensor {
  // in_dims[dim_scale_] == dims of scale data.
  QCTensor(DimsType in_dims, size_t dim_scale_,
           xnn_datatype datatype_ = xnn_datatype_qcint8,
           bool is_sparse_ = false)
      : Tensor(std::move(in_dims), datatype_, is_sparse_),
        dim_scale(dim_scale_) {
    ABSL_CHECK_LT(dim_scale, 4);
    if (datatype == xnn_datatype_qcint4) {
      zero_point = 8;
    } else {
      zero_point = 0;
    }
  }

  void AllocateBufferIfNeeded() override;
  size_t ElementSize(size_t num_elements) const override {
    switch (datatype) {
      case xnn_datatype_qcint8:
        return num_elements;
      case xnn_datatype_qcint4:
        return (num_elements + 1) / 2;
      default:
        ABSL_LOG(FATAL) << "Unsupported datatype: " << datatype;
        return 0;
    }
  }

  virtual absl::Status LoadFromFile(absl::string_view quantized_weight_filename,
                                    absl::string_view scale_filename,
                                    bool use_mmap, bool exact_match);
  // Append kQuantizedScaleSuffix to use as scale filename.
  absl::Status LoadFromFile(absl::string_view file_path, bool use_mmap,
                            bool exact_match) override {
    return LoadFromFile(file_path,
                        absl::StrCat(file_path, kQuantizedScaleSuffix),
                        use_mmap, exact_match);
  }

  absl::Status DumpToFile(absl::string_view file_path) override;

  absl::Status DefineWeight(xnn_subgraph& subgraph, uint32_t flags) override;

  std::shared_ptr<Tensor> Transpose() override;

  absl::StatusOr<std::shared_ptr<Tensor>> ConvertToF32() override;

  std::shared_ptr<Tensor> Slice(size_t index, size_t offset) override;

  std::shared_ptr<float> scale_data;
  // Index of the dimension to scale.
  size_t dim_scale;
  int32_t zero_point;

 private:
  friend std::ostream& operator<<(std::ostream& os, const QCTensor& tensor);
};

// Streams a description of the quantized `tensor` to `os`. Declared a friend
// of QCTensor.
std::ostream& operator<<(std::ostream& os, const QCTensor& tensor);

// Abstract source of model weights. Having weights come through this interface
// lets callers such as benchmark tests hand back randomly-initialized contents
// instead of preparing real weight files.
class WeightAccessor {
 public:
  virtual ~WeightAccessor() = default;

  // Loads the static weight tensor named by `filename_prefix`, doing its best
  // to make the result's dimensions match `expected_dims`.
  // NOTE(review): `dim_scale_if_any` presumably selects the scale dimension
  // when the weight turns out to be quantized (cf. QCTensor::dim_scale).
  virtual absl::StatusOr<std::shared_ptr<Tensor>> LoadWeight(
      absl::string_view filename_prefix, Tensor::DimsType expected_dims,
      size_t dim_scale_if_any) const = 0;

  // Shorthand for LoadWeight(filename_prefix, expected_dims, 0).
  // Subclasses overriding the three-argument form hide this overload unless
  // they add `using WeightAccessor::LoadWeight;`.
  absl::StatusOr<std::shared_ptr<Tensor>> LoadWeight(
      absl::string_view filename_prefix, Tensor::DimsType expected_dims) const {
    return LoadWeight(filename_prefix, std::move(expected_dims),
                      /*dim_scale_if_any=*/0);
  }

  // Like LoadWeight, but the loaded tensor is transposed before it is
  // returned.
  virtual absl::StatusOr<std::shared_ptr<Tensor>> LoadTransposedWeight(
      absl::string_view filename_prefix, Tensor::DimsType expected_dims,
      size_t dim_scale_if_any) const = 0;
};

// May be attached to an LLM graph as a side input to override how weights are
// accessed.
// Each invocation returns a WeightAccessor owned by the caller (unique_ptr).
using WeightAccessorProvider = std::function<std::unique_ptr<WeightAccessor>()>;

}  // namespace xnn_utils
}  // namespace mediapipe::tasks::genai

#endif  // MEDIAPIPE_TASKS_GENAI_INFERENCE_UTILS_XNN_UTILS_XNN_TENSOR_H_