//===- LowerQuantOps.cpp - Lower 'quant' dialect ops ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Transforms `quant.dcast` and `quant.qcast` into lower-level ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Quant/IR/Quant.h"
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/Dialect/Quant/Transforms/Passes.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace quant {

#define GEN_PASS_DEF_LOWERQUANTOPS
#include "mlir/Dialect/Quant/Transforms/Passes.h.inc"

namespace {

// If 'inputType' is a tensor, return its element type. If it is a scalar,
// return it as is.
Type getScalarType(Type inputType) { … }

// Return the shape of an input value as a list of attributes (static
// dimensions) and values (dynamic dimensions). If 'input' is a scalar, an
// empty list is returned. If 'input' is a tensor, its shape is returned.
SmallVector<OpFoldResult> getScalarOrTensorShape(OpBuilder &builder,
                                                 Location loc, Value input) { … }

// If 'referenceType' is a scalar, return 'elementType' as is. If
// 'referenceType' is a tensor, return another tensor with the same shape and
// elements of type 'elementType'.
Type getScalarOrTensorType(Type elementType, Type referenceType) { … }

// Return a constant with the given value. If 'referenceType' is a tensor, a
// tensor splat of shape 'referenceShape' is returned. If 'referenceType' is a
// scalar, 'referenceShape' is ignored and a scalar constant is returned.
Value getScalarOrTensorConstant(OpBuilder &builder, Location loc, Value scalar,
                                Type referenceType,
                                ArrayRef<OpFoldResult> referenceShape) { … }

// Reshape an unranked tensor into a 1D ranked tensor.
//
// - input
//   Unranked tensor.
//
// Return values:
//
// - flatInput
//   1D ranked, dynamically shaped tensor.
//
// - inputShape
//   1D extent tensor containing the shape of the original unranked input.
//
std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
                                              Value input) { … }

// Reshape an unranked tensor into a 3D ranked tensor where the central
// dimension of the result tensor corresponds to dimension 'axis' of the input
// tensor.
//
// - input
//   Unranked tensor.
//
// - axis
//   Index of the input dimension around which other input dimensions will be
//   collapsed.
//
// - axisSize
//   Size of input dimension 'axis'.
//
// Return values:
//
// - flatInput
//   3D ranked tensor of shape [?, axisSize, ?].
//
// - inputShape
//   1D extent tensor containing the shape of the original unranked input.
//
std::pair<Value, Value>
flattenUnrankedTensorAroundAxis(OpBuilder &builder, Location loc, Value input,
                                int64_t axis, int64_t axisSize) { … }
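
// As an illustrative example (shapes chosen here for exposition only), an
// unranked input that happens to hold a tensor<2x4x3x5xf32>, flattened around
// axis = 1 with axisSize = 4, is reshaped into a tensor<?x4x?xf32> whose
// leading and trailing dimensions collapse the dimensions before and after
// the axis (2 and 3 * 5 = 15), while the returned extent tensor records the
// original shape [2, 4, 3, 5] so it can be restored afterwards.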
// Reshape an input tensor into its original unranked shape.
//
// - input
//   Ranked tensor.
//
// - inputShape
//   1D extent tensor.
//
Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input,
                                 Value inputShape) { … }

// Create a tensor constant containing all scales in a per-channel quantized
// type. Example:
//
//   !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
//
// produces
//
//   %cst = arith.constant dense<[2.0, 3.0]> : tensor<2xf32>
//
Value materializePerChannelScales(OpBuilder &builder, Location loc,
                                  UniformQuantizedPerAxisType quantizedType) { … }

// Create a tensor constant containing all zero points in a per-channel
// quantized type. Example:
//
//   !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
//
// produces
//
//   %cst = arith.constant dense<[10, 20]> : tensor<2xi8>
//
Value materializePerChannelZeroPoints(
    OpBuilder &builder, Location loc,
    UniformQuantizedPerAxisType quantizedType) { … }

// Clamp the given scalar or tensor input using the storage bounds encoded in
// the given quantized type, if present.
//
// - input
//   Scalar or ranked tensor input. The element type must match the storage
//   type of 'quantizedType'.
//
// - inputShape
//   If 'input' is a tensor, combination of attributes/values representing its
//   static/dynamic dimensions. If 'input' is a scalar, empty list.
//
// - quantizedType
//   Per-layer or per-channel quantized type.
Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input,
                          ArrayRef<OpFoldResult> inputShape,
                          QuantizedType quantizedType) { … }

// Emit op 'arith.fptosi' or 'arith.fptoui'.
Value convertFloatToInteger(OpBuilder &builder, Location loc, Value input,
                            Type resultType, bool isSigned) { … }

// Emit op 'arith.sitofp' or 'arith.uitofp'.
Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input,
                            Type resultType, bool isSigned) { … }

// Quantize a scalar or ranked tensor value. The stored value is clamped using
// the storage bounds encoded in the given quantized type.
//
// See function 'convertRanked()' below for a description of the arguments.
Value quantizeValue(OpBuilder &builder, Location loc, Value input,
                    ArrayRef<OpFoldResult> inputShape, Value scale,
                    Value zeroPoint, QuantizedType quantizedType) { … }

// Dequantize a scalar or ranked tensor input.
//
// See function 'convertRanked()' below for a description of the arguments.
Value dequantizeValue(OpBuilder &builder, Location loc, Value input,
                      ArrayRef<OpFoldResult> inputShape, Value scale,
                      Value zeroPoint, QuantizedType quantizedType) { … }

// Convert a scalar or ranked tensor input with the given scale and zero point
// values.
//
// - input
//   Scalar or ranked tensor value.
//
// - inputShape
//   If 'input' is a tensor, combination of attributes/values representing its
//   static/dynamic dimensions. If 'input' is a scalar, empty list.
//
// - scale
//   Scale as a floating-point scalar value.
//
// - zeroPoint
//   Zero point as an integer scalar value.
//
// - quantizedType
//   Scalar quantized type of the result ('quant.qcast') or of the input
//   ('quant.dcast').
//
Value convertRanked(OpBuilder &builder, Location loc, Operation *op,
                    Value input, ArrayRef<OpFoldResult> inputShape, Value scale,
                    Value zeroPoint, QuantizedType quantizedType) { … }

// Convert an operation using per-layer quantization with a scalar or ranked
// tensor input.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Scalar or ranked tensor.
//
// - quantizedType
//   Per-layer quantized type.
//
Value convertPerLayerRanked(OpBuilder &builder, Location loc, Operation *op,
                            Value input,
                            UniformQuantizedType quantizedType) { … }
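
// For reference, the helpers above realize the usual affine quantization
// mapping of the 'quant' dialect: quantization computes roughly
// storedValue = clamp(expressedValue / scale + zeroPoint), clamped against
// the storage bounds of the quantized type, and dequantization computes
// expressedValue = (storedValue - zeroPoint) * scale. The per-layer and
// per-channel converters below apply this mapping either with a single scale
// and zero point, or with one scale/zero-point pair per position along the
// channel axis.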
// Convert an operation using per-layer quantization.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Scalar, ranked tensor, or unranked tensor.
//
// - quantizedType
//   Per-layer quantized type.
//
Value convertPerLayer(OpBuilder &builder, Location loc, Operation *op,
                      Value input, UniformQuantizedType quantizedType) { … }

// Convert an operation using per-channel quantization and a scalar or ranked
// tensor as an input.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Scalar or ranked tensor.
//
// - quantizedType
//   Per-channel quantized type.
//
Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op,
                              Value input,
                              UniformQuantizedPerAxisType quantizedType,
                              int64_t channelAxis) { … }

// Convert an operation using per-channel quantization.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Scalar, ranked tensor, or unranked tensor.
//
// - quantizedType
//   Per-channel quantized type.
//
Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op,
                        Value input,
                        UniformQuantizedPerAxisType quantizedType) { … }

// Convert a quantization operation.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Scalar, ranked tensor, or unranked tensor. The element type matches
//   the storage type (quant.dcast) or expressed type (quant.qcast) of
//   'quantizedType'.
//
// - quantizedType
//   Per-layer or per-channel quantized type.
//
Value convertQuantized(OpBuilder &builder, Location loc, Operation *op,
                       Value input, Type quantizedType) { … }

// Lowering pattern for 'quant.dcast'
struct DequantizeCastOpConversion
    : public OpConversionPattern<quant::DequantizeCastOp> { … };

// Lowering pattern for 'quant.qcast'
struct QuantizeCastOpConversion
    : public OpConversionPattern<quant::QuantizeCastOp> { … };

struct LowerQuantOps : public impl::LowerQuantOpsBase<LowerQuantOps> { … };

} // namespace

void populateLowerQuantOpsPatterns(RewritePatternSet &patterns) { … }

} // namespace quant
} // namespace mlir
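
// A minimal usage sketch (hypothetical driver code, not part of this file):
// a custom conversion pass could reuse these patterns roughly as follows,
// assuming a dialect-conversion setup in the spirit of the LowerQuantOps pass
// above.
//
//   RewritePatternSet patterns(ctx);
//   quant::populateLowerQuantOpsPatterns(patterns);
//   ConversionTarget target(*ctx);
//   target.addIllegalDialect<quant::QuantDialect>();
//   target.addLegalOp<quant::StorageCastOp>(); // 'quant.scast' may remain
//   target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,
//                          shape::ShapeDialect, tensor::TensorDialect>();
//   if (failed(applyPartialConversion(op, target, std::move(patterns))))
//     signalPassFailure();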