llvm/mlir/lib/Bindings/Python/DialectQuant.cpp

//===- DialectQuant.cpp - 'quant' dialect submodule -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir-c/Dialect/Quant.h"
#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include <cstdint>
#include <pybind11/cast.h>
#include <pybind11/detail/common.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <vector>

namespace py = pybind11;
using namespace llvm;
using namespace mlir;
using namespace mlir::python::adaptors;

/// Registers Python bindings for the 'quant' dialect types on module `m`.
///
/// Every binding is a thin wrapper over the MLIR C API declared in
/// mlir-c/Dialect/Quant.h. `QuantizedType` is registered first because the
/// concrete types below (AnyQuantizedType, UniformQuantizedType,
/// UniformQuantizedPerAxisType, CalibratedQuantizedType) are declared as its
/// Python subclasses via `quantizedType.get_class()`.
static void populateDialectQuantSubmodule(const py::module &m) {
  //===-------------------------------------------------------------------===//
  // QuantizedType
  //===-------------------------------------------------------------------===//

  auto quantizedType =
      mlir_type_subclass(m, "QuantizedType", mlirTypeIsAQuantizedType);
  quantizedType.def_staticmethod(
      "default_minimum_for_integer",
      [](bool isSigned, unsigned integralWidth) {
        return mlirQuantizedTypeGetDefaultMinimumForInteger(isSigned,
                                                            integralWidth);
      },
      "Default minimum value for the integer with the specified signedness and "
      "bit width.",
      py::arg("is_signed"), py::arg("integral_width"));
  quantizedType.def_staticmethod(
      "default_maximum_for_integer",
      [](bool isSigned, unsigned integralWidth) {
        return mlirQuantizedTypeGetDefaultMaximumForInteger(isSigned,
                                                            integralWidth);
      },
      "Default maximum value for the integer with the specified signedness and "
      "bit width.",
      py::arg("is_signed"), py::arg("integral_width"));
  quantizedType.def_property_readonly(
      "expressed_type",
      [](MlirType type) { return mlirQuantizedTypeGetExpressedType(type); },
      "Type expressed by this quantized type.");
  quantizedType.def_property_readonly(
      "flags", [](MlirType type) { return mlirQuantizedTypeGetFlags(type); },
      "Flags of this quantized type (named accessors should be preferred to "
      "this)");
  quantizedType.def_property_readonly(
      "is_signed",
      [](MlirType type) { return mlirQuantizedTypeIsSigned(type); },
      "Signedness of this quantized type.");
  quantizedType.def_property_readonly(
      "storage_type",
      [](MlirType type) { return mlirQuantizedTypeGetStorageType(type); },
      "Storage type backing this quantized type.");
  quantizedType.def_property_readonly(
      "storage_type_min",
      [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMin(type); },
      "The minimum value held by the storage type of this quantized type.");
  quantizedType.def_property_readonly(
      "storage_type_max",
      [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMax(type); },
      "The maximum value held by the storage type of this quantized type.");
  quantizedType.def_property_readonly(
      "storage_type_integral_width",
      [](MlirType type) {
        return mlirQuantizedTypeGetStorageTypeIntegralWidth(type);
      },
      "The bitwidth of the storage type of this quantized type.");
  quantizedType.def(
      "is_compatible_expressed_type",
      [](MlirType type, MlirType candidate) {
        return mlirQuantizedTypeIsCompatibleExpressedType(type, candidate);
      },
      "Checks whether the candidate type can be expressed by this quantized "
      "type.",
      py::arg("candidate"));
  quantizedType.def_property_readonly(
      "quantized_element_type",
      [](MlirType type) {
        return mlirQuantizedTypeGetQuantizedElementType(type);
      },
      "Element type of this quantized type expressed as quantized type.");

  // The cast_* helpers below all follow one pattern: the C API signals an
  // invalid cast by returning a null MlirType, which is surfaced to Python as
  // a TypeError instead of leaking a null type object.
  quantizedType.def(
      "cast_from_storage_type",
      [](MlirType type, MlirType candidate) {
        MlirType castResult =
            mlirQuantizedTypeCastFromStorageType(type, candidate);
        if (!mlirTypeIsNull(castResult))
          return castResult;
        throw py::type_error("Invalid cast.");
      },
      "Casts from a type based on the storage type of this quantized type to a "
      "corresponding type based on the quantized type. Raises TypeError if the "
      "cast is not valid.",
      py::arg("candidate"));
  quantizedType.def_staticmethod(
      "cast_to_storage_type",
      [](MlirType type) {
        MlirType castResult = mlirQuantizedTypeCastToStorageType(type);
        if (!mlirTypeIsNull(castResult))
          return castResult;
        throw py::type_error("Invalid cast.");
      },
      "Casts from a type based on a quantized type to a corresponding type "
      "based on the storage type of this quantized type. Raises TypeError if "
      "the cast is not valid.",
      py::arg("type"));
  quantizedType.def(
      "cast_from_expressed_type",
      [](MlirType type, MlirType candidate) {
        MlirType castResult =
            mlirQuantizedTypeCastFromExpressedType(type, candidate);
        if (!mlirTypeIsNull(castResult))
          return castResult;
        throw py::type_error("Invalid cast.");
      },
      "Casts from a type based on the expressed type of this quantized type to "
      "a corresponding type based on the quantized type. Raises TypeError if "
      "the cast is not valid.",
      py::arg("candidate"));
  quantizedType.def_staticmethod(
      "cast_to_expressed_type",
      [](MlirType type) {
        MlirType castResult = mlirQuantizedTypeCastToExpressedType(type);
        if (!mlirTypeIsNull(castResult))
          return castResult;
        throw py::type_error("Invalid cast.");
      },
      "Casts from a type based on a quantized type to a corresponding type "
      "based on the expressed type of this quantized type. Raises TypeError if "
      "the cast is not valid.",
      py::arg("type"));
  quantizedType.def(
      "cast_expressed_to_storage_type",
      [](MlirType type, MlirType candidate) {
        MlirType castResult =
            mlirQuantizedTypeCastExpressedToStorageType(type, candidate);
        if (!mlirTypeIsNull(castResult))
          return castResult;
        throw py::type_error("Invalid cast.");
      },
      "Casts from a type based on the expressed type of this quantized type to "
      "a corresponding type based on the storage type. Raises TypeError if the "
      "cast is not valid.",
      py::arg("candidate"));

  // Expose the C API's signed flag as a class attribute so Python callers can
  // build the `flags` argument of the `get` constructors below.
  quantizedType.get_class().attr("FLAG_SIGNED") =
      mlirQuantizedTypeGetSignedFlag();

  //===-------------------------------------------------------------------===//
  // AnyQuantizedType
  //===-------------------------------------------------------------------===//

  // In every `get` classmethod, `cls` is the Python class the method was
  // invoked on; calling it wraps the raw MlirType returned by the C API in an
  // instance of that class.
  auto anyQuantizedType =
      mlir_type_subclass(m, "AnyQuantizedType", mlirTypeIsAAnyQuantizedType,
                         quantizedType.get_class());
  anyQuantizedType.def_classmethod(
      "get",
      [](py::object cls, unsigned flags, MlirType storageType,
         MlirType expressedType, int64_t storageTypeMin,
         int64_t storageTypeMax) {
        return cls(mlirAnyQuantizedTypeGet(flags, storageType, expressedType,
                                           storageTypeMin, storageTypeMax));
      },
      "Gets an instance of AnyQuantizedType in the same context as the "
      "provided storage type.",
      py::arg("cls"), py::arg("flags"), py::arg("storage_type"),
      py::arg("expressed_type"), py::arg("storage_type_min"),
      py::arg("storage_type_max"));

  //===-------------------------------------------------------------------===//
  // UniformQuantizedType
  //===-------------------------------------------------------------------===//

  auto uniformQuantizedType = mlir_type_subclass(
      m, "UniformQuantizedType", mlirTypeIsAUniformQuantizedType,
      quantizedType.get_class());
  uniformQuantizedType.def_classmethod(
      "get",
      [](py::object cls, unsigned flags, MlirType storageType,
         MlirType expressedType, double scale, int64_t zeroPoint,
         int64_t storageTypeMin, int64_t storageTypeMax) {
        return cls(mlirUniformQuantizedTypeGet(flags, storageType,
                                               expressedType, scale, zeroPoint,
                                               storageTypeMin, storageTypeMax));
      },
      "Gets an instance of UniformQuantizedType in the same context as the "
      "provided storage type.",
      py::arg("cls"), py::arg("flags"), py::arg("storage_type"),
      py::arg("expressed_type"), py::arg("scale"), py::arg("zero_point"),
      py::arg("storage_type_min"), py::arg("storage_type_max"));
  uniformQuantizedType.def_property_readonly(
      "scale",
      [](MlirType type) { return mlirUniformQuantizedTypeGetScale(type); },
      "The scale designates the difference between the real values "
      "corresponding to consecutive quantized values differing by 1.");
  uniformQuantizedType.def_property_readonly(
      "zero_point",
      [](MlirType type) { return mlirUniformQuantizedTypeGetZeroPoint(type); },
      "The storage value corresponding to the real value 0 in the affine "
      "equation.");
  uniformQuantizedType.def_property_readonly(
      "is_fixed_point",
      [](MlirType type) { return mlirUniformQuantizedTypeIsFixedPoint(type); },
      "Fixed point values are real numbers divided by a scale.");

  //===-------------------------------------------------------------------===//
  // UniformQuantizedPerAxisType
  //===-------------------------------------------------------------------===//
  auto uniformQuantizedPerAxisType = mlir_type_subclass(
      m, "UniformQuantizedPerAxisType", mlirTypeIsAUniformQuantizedPerAxisType,
      quantizedType.get_class());
  uniformQuantizedPerAxisType.def_classmethod(
      "get",
      [](py::object cls, unsigned flags, MlirType storageType,
         MlirType expressedType, std::vector<double> scales,
         std::vector<int64_t> zeroPoints, int32_t quantizedDimension,
         int64_t storageTypeMin, int64_t storageTypeMax) {
        // The C API takes a single count for both arrays, so they must agree.
        if (scales.size() != zeroPoints.size())
          throw py::value_error(
              "Mismatching number of scales and zero points.");
        auto nDims = static_cast<intptr_t>(scales.size());
        return cls(mlirUniformQuantizedPerAxisTypeGet(
            flags, storageType, expressedType, nDims, scales.data(),
            zeroPoints.data(), quantizedDimension, storageTypeMin,
            storageTypeMax));
      },
      "Gets an instance of UniformQuantizedPerAxisType in the same context as "
      "the provided storage type.",
      py::arg("cls"), py::arg("flags"), py::arg("storage_type"),
      py::arg("expressed_type"), py::arg("scales"), py::arg("zero_points"),
      py::arg("quantized_dimension"), py::arg("storage_type_min"),
      py::arg("storage_type_max"));
  uniformQuantizedPerAxisType.def_property_readonly(
      "scales",
      [](MlirType type) {
        intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type);
        std::vector<double> scales;
        scales.reserve(nDim);
        for (intptr_t i = 0; i < nDim; ++i)
          scales.push_back(mlirUniformQuantizedPerAxisTypeGetScale(type, i));
        // Must be returned: pybind11's STL casters turn the vector into a
        // Python list. A void lambda here would make the property read None.
        return scales;
      },
      "The scales designate the difference between the real values "
      "corresponding to consecutive quantized values differing by 1. The ith "
      "scale corresponds to the ith slice in the quantized_dimension.");
  uniformQuantizedPerAxisType.def_property_readonly(
      "zero_points",
      [](MlirType type) {
        intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type);
        std::vector<int64_t> zeroPoints;
        zeroPoints.reserve(nDim);
        for (intptr_t i = 0; i < nDim; ++i)
          zeroPoints.push_back(
              mlirUniformQuantizedPerAxisTypeGetZeroPoint(type, i));
        // Must be returned: pybind11's STL casters turn the vector into a
        // Python list. A void lambda here would make the property read None.
        return zeroPoints;
      },
      "the storage values corresponding to the real value 0 in the affine "
      "equation. The ith zero point corresponds to the ith slice in the "
      "quantized_dimension.");
  uniformQuantizedPerAxisType.def_property_readonly(
      "quantized_dimension",
      [](MlirType type) {
        return mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(type);
      },
      "Specifies the dimension of the shape that the scales and zero points "
      "correspond to.");
  uniformQuantizedPerAxisType.def_property_readonly(
      "is_fixed_point",
      [](MlirType type) {
        return mlirUniformQuantizedPerAxisTypeIsFixedPoint(type);
      },
      "Fixed point values are real numbers divided by a scale.");

  //===-------------------------------------------------------------------===//
  // CalibratedQuantizedType
  //===-------------------------------------------------------------------===//

  auto calibratedQuantizedType = mlir_type_subclass(
      m, "CalibratedQuantizedType", mlirTypeIsACalibratedQuantizedType,
      quantizedType.get_class());
  calibratedQuantizedType.def_classmethod(
      "get",
      [](py::object cls, MlirType expressedType, double min, double max) {
        return cls(mlirCalibratedQuantizedTypeGet(expressedType, min, max));
      },
      "Gets an instance of CalibratedQuantizedType in the same context as the "
      "provided expressed type.",
      py::arg("cls"), py::arg("expressed_type"), py::arg("min"),
      py::arg("max"));
  calibratedQuantizedType.def_property_readonly(
      "min",
      [](MlirType type) { return mlirCalibratedQuantizedTypeGetMin(type); },
      "The minimum value of this calibrated quantized type.");
  calibratedQuantizedType.def_property_readonly(
      "max",
      [](MlirType type) { return mlirCalibratedQuantizedTypeGetMax(type); },
      "The maximum value of this calibrated quantized type.");
}

/// Entry point of the `_mlirDialectsQuant` Python extension module: assigns
/// the module docstring, then registers all quant dialect type bindings.
PYBIND11_MODULE(_mlirDialectsQuant, m) {
  // Equivalent to `m.doc() = ...`; doc() is an accessor for `__doc__`.
  m.attr("__doc__") = "MLIR Quantization dialect";
  populateDialectQuantSubmodule(m);
}