// llvm/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp

//===- LowerQuantOps.cpp - Lower 'quant' dialect ops ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Transforms `quant.dcast` and `quant.qcast` into lower-level ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Quant/IR/Quant.h"
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/Dialect/Quant/Transforms/Passes.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace quant {

#define GEN_PASS_DEF_LOWERQUANTOPS
#include "mlir/Dialect/Quant/Transforms/Passes.h.inc"

namespace {

// If 'inputType' is a tensor, return its element type. If it is a scalar,
// return it as is.
Type getScalarType(Type inputType) {
  // Tensors (ranked or unranked) contribute their element type; any other
  // type is already scalar and is returned unchanged.
  if (auto tensorType = dyn_cast<TensorType>(inputType))
    return tensorType.getElementType();
  return inputType;
}

// Return the shape of an input value as a list of attributes (static dimensions)
// and values (dynamic dimensions). If 'input' is a scalar, an empty list is
// returned. If 'input' is a tensor, its shape is returned.
SmallVector<OpFoldResult>
getScalarOrTensorShape(OpBuilder &builder, Location loc, Value input) {
  // For tensors, produce a mixed static/dynamic dimension list; static
  // extents come back as attributes, dynamic ones as 'tensor.dim' values.
  if (isa<TensorType>(input.getType()))
    return tensor::getMixedSizes(builder, loc, input);
  // Scalars have no shape.
  return {};
}

// If 'referenceType' is a scalar, return 'elementType' as is. If
// 'referenceType' is a tensor, return another tensor with the same shape and
// elements of type 'elementType'.
Type getScalarOrTensorType(Type elementType, Type referenceType) {
  // Clone the reference tensor type with the new element type, preserving its
  // shape (or unranked-ness). Non-tensor references yield the bare element.
  if (auto tensorType = dyn_cast<TensorType>(referenceType))
    return tensorType.clone(elementType);
  return elementType;
}

// Return a constant with the given value. If 'referenceType' is a tensor, a
// tensor splat of shape 'referenceShape' is returned. If 'referenceType' is a
// scalar, 'referenceShape' is ignored and a scalar constant is returned.
Value getScalarOrTensorConstant(OpBuilder &builder, Location loc, Value scalar,
                                Type referenceType,
                                ArrayRef<OpFoldResult> referenceShape) {
  // Scalar reference: the given scalar constant is already the right value.
  auto tensorType = dyn_cast<TensorType>(referenceType);
  if (!tensorType) {
    assert(referenceShape.empty() &&
           "no shape expected for a scalar reference type");
    return scalar;
  }

  // Tensor reference: splat the scalar across the reference shape. The mixed
  // static/dynamic shape is consumed directly by 'tensor.splat'.
  return builder.create<tensor::SplatOp>(loc, scalar, referenceShape);
}

// Reshape an unranked tensor into a 1D ranked tensor.
//
// - input
//   Unranked tensor.
//
// Return values:
//
// - flatInput
//   1D ranked, dynamically shaped tensor.
//
// - inputShape
//   1D extent tensor containing the shape of the original unranked input.
//
std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
                                              Value input) {
  // Capture the original shape as a 1D extent tensor, plus its total element
  // count, so the original shape can be restored later.
  auto *context = builder.getContext();
  auto shapeType = shape::getExtentTensorType(context);
  auto inputShape = builder.create<shape::ShapeOfOp>(loc, shapeType, input);
  Value inputSize = builder.create<shape::NumElementsOp>(
      loc, builder.getIndexType(), inputShape);

  // Wrap the total size into a rank-1 extent tensor usable by tensor.reshape.
  auto flatShapeType = shape::getExtentTensorType(context, 1);
  auto flatInputShape =
      builder.create<tensor::FromElementsOp>(loc, flatShapeType, inputSize);

  // Collapse the unranked input into a dynamically sized 1D tensor.
  auto inputType = cast<UnrankedTensorType>(input.getType());
  auto elementType = inputType.getElementType();
  auto flatInputType =
      RankedTensorType::get({ShapedType::kDynamic}, elementType);
  Value flatInput = builder.create<tensor::ReshapeOp>(loc, flatInputType,
                                                      input, flatInputShape);
  return std::make_pair(flatInput, inputShape);
}

// Reshape an unranked tensor into a 3D ranked tensor where the central
// dimension of the result tensor corresponds to dimension 'axis' of the input
// tensor.
//
// - input
//   Unranked tensor.
//
// - axis
//   Index of the input dimension around which other input dimensions will be
//   collapsed.
//
// - axisSize
//   Size of input dimension 'axis'.
//
// Return values:
//
// - flatInput
//   3D ranked tensor of shape [?, axisSize, ?].
//
// - inputShape
//   1D extent tensor containing the shape of the original unranked input.
//
std::pair<Value, Value> flattenUnrankedTensorAroundAxis(OpBuilder &builder,
                                                        Location loc,
                                                        Value input,
                                                        int64_t axis,
                                                        int64_t axisSize) {
  // Capture the original shape as a 1D extent tensor for later restoration.
  auto *context = builder.getContext();
  auto indexType = builder.getIndexType();
  auto shapeType = shape::getExtentTensorType(context);
  auto inputShape = builder.create<shape::ShapeOfOp>(loc, shapeType, input);

  // Split the shape at 'axis' and at 'axis + 1' to obtain the sub-shapes to
  // the left and right of the axis, then collapse each side into a single
  // extent by multiplying its dimensions.
  auto axisValue = builder.create<arith::ConstantIndexOp>(loc, axis);
  auto axisNextValue = builder.create<arith::ConstantIndexOp>(loc, axis + 1);
  auto shapeLeft =
      builder
          .create<shape::SplitAtOp>(loc, TypeRange{shapeType, shapeType},
                                    inputShape, axisValue)
          .getResult(0);
  Value sizeLeft =
      builder.create<shape::NumElementsOp>(loc, indexType, shapeLeft);
  auto shapeRight =
      builder
          .create<shape::SplitAtOp>(loc, TypeRange{shapeType, shapeType},
                                    inputShape, axisNextValue)
          .getResult(1);
  Value sizeRight =
      builder.create<shape::NumElementsOp>(loc, indexType, shapeRight);

  // Assemble the target 3D shape [sizeLeft, axisSize, sizeRight].
  Value axisSizeValue = builder.create<arith::ConstantIndexOp>(loc, axisSize);
  auto flatShapeType = shape::getExtentTensorType(context, 3);
  Value flatInputShape = builder.create<tensor::FromElementsOp>(
      loc, flatShapeType, ValueRange{sizeLeft, axisSizeValue, sizeRight});

  // Reshape the unranked input into a [?, axisSize, ?] ranked tensor.
  auto inputType = cast<UnrankedTensorType>(input.getType());
  auto elementType = inputType.getElementType();
  auto flatInputType = RankedTensorType::get(
      {ShapedType::kDynamic, axisSize, ShapedType::kDynamic}, elementType);
  Value flatInput = builder.create<tensor::ReshapeOp>(loc, flatInputType,
                                                      input, flatInputShape);
  return std::make_pair(flatInput, inputShape);
}

// Reshape an input tensor into its original unranked shape.
//
// - input
//   Ranked tensor.
//
// - inputShape
//   1D extent tensor.
//
Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input,
                                 Value inputShape) {
  // The runtime shape is only known through 'inputShape', so the result is an
  // unranked tensor with the same element type as the ranked input.
  auto inputType = cast<RankedTensorType>(input.getType());
  auto elementType = inputType.getElementType();
  auto unrankedType = UnrankedTensorType::get(elementType);
  return builder.create<tensor::ReshapeOp>(loc, unrankedType, input,
                                           inputShape);
}

// Create a tensor constant containing all scales in a per-channel quantized
// type. Example:
//
//   !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
//
// produces
//
//   %cst = arith.constant dense<[2.0, 3.0]> : tensor<2xf32>
//
Value materializePerChannelScales(OpBuilder &builder, Location loc,
                                  UniformQuantizedPerAxisType quantizedType) {
  // Each per-channel scale becomes one element of a 1D dense constant whose
  // element type is the quantized type's expressed (floating-point) type.
  auto scales = quantizedType.getScales();
  auto expressedType = quantizedType.getExpressedType();
  auto scaleAttrs = llvm::map_to_vector(scales, [&](double scale) -> Attribute {
    return builder.getFloatAttr(expressedType, scale);
  });
  auto tensorType =
      RankedTensorType::get({(int64_t)scales.size()}, expressedType);
  auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs);
  return builder.create<arith::ConstantOp>(loc, tensorType, scalesAttr);
}

// Create a tensor constant containing all zero points in a per-channel
// quantized type. Example:
//
//   !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
//
// produces
//
//   %cst = arith.constant dense<[10, 20]> : tensor<2xi8>
//
Value materializePerChannelZeroPoints(
    OpBuilder &builder, Location loc,
    UniformQuantizedPerAxisType quantizedType) {
  // Each per-channel zero point becomes one element of a 1D dense constant
  // whose element type is the quantized type's storage (integer) type.
  auto zeroPoints = quantizedType.getZeroPoints();
  auto storageType = quantizedType.getStorageType();
  auto zeroPointAttrs =
      llvm::map_to_vector(zeroPoints, [&](int64_t zeroPoint) -> Attribute {
        return builder.getIntegerAttr(storageType, zeroPoint);
      });
  auto tensorType =
      RankedTensorType::get({(int64_t)zeroPoints.size()}, storageType);
  auto zeroPointsAttr = DenseElementsAttr::get(tensorType, zeroPointAttrs);
  return builder.create<arith::ConstantOp>(loc, tensorType, zeroPointsAttr);
}

// Clamp the given scalar or tensor input using the storage bounds encoded in
// the given quantized type, if present.
//
// - input
//   Scalar or ranked tensor input. The element type must match the storage type
//   of 'quantizedType'.
//
// - inputShape
//   If 'input' is a tensor, combination of attributes/values representing its
//   static/dynamic dimensions. If 'input' is a scalar, empty list.
//
// - quantizedType
//   Per-axis or per-channel quantized type.
Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input,
                          ArrayRef<OpFoldResult> inputShape,
                          QuantizedType quantizedType) {
  // If the quantized type does not narrow down the storage type range, no
  // clamping is needed.
  if (!quantizedType.hasStorageTypeBounds())
    return input;

  // Materialize the storage bounds as scalars, then broadcast them to the
  // input's type (splat tensor for tensor inputs, scalar otherwise).
  auto inputType = input.getType();
  auto storageType = quantizedType.getStorageType();
  auto storageMinScalar = builder.create<arith::ConstantIntOp>(
      loc, quantizedType.getStorageTypeMin(), storageType);
  auto storageMaxScalar = builder.create<arith::ConstantIntOp>(
      loc, quantizedType.getStorageTypeMax(), storageType);
  auto storageMin = getScalarOrTensorConstant(builder, loc, storageMinScalar,
                                              inputType, inputShape);
  auto storageMax = getScalarOrTensorConstant(builder, loc, storageMaxScalar,
                                              inputType, inputShape);

  // Clamp using signed or unsigned min/max according to the storage type.
  if (quantizedType.isSigned()) {
    input = builder.create<arith::MaxSIOp>(loc, input, storageMin);
    input = builder.create<arith::MinSIOp>(loc, input, storageMax);
  } else {
    input = builder.create<arith::MaxUIOp>(loc, input, storageMin);
    input = builder.create<arith::MinUIOp>(loc, input, storageMax);
  }
  return input;
}

// Emit op 'arith.fptosi' or 'arith.fptoui'.
Value convertFloatToInteger(OpBuilder &builder, Location loc, Value input,
                            Type resultType, bool isSigned) {
  // Signedness selects between the two arith float-to-int conversions.
  if (isSigned)
    return builder.create<arith::FPToSIOp>(loc, resultType, input);
  return builder.create<arith::FPToUIOp>(loc, resultType, input);
}

// Emit op 'arith.sitofp' or 'arith.uitofp'.
Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input,
                            Type resultType, bool isSigned) {
  // Signedness selects between the two arith int-to-float conversions.
  if (isSigned)
    return builder.create<arith::SIToFPOp>(loc, resultType, input);
  return builder.create<arith::UIToFPOp>(loc, resultType, input);
}

// Quantize a scalar or ranked tensor value. The stored value is clamped using 
// the storage bounds encoded in the given quantized type.
//
// See function 'convertRanked()' below for a description of the arguments.
Value quantizeValue(OpBuilder &builder, Location loc, Value input,
                    ArrayRef<OpFoldResult> inputShape, Value scale,
                    Value zeroPoint, QuantizedType quantizedType) {
  // Broadcast the scale to the input's type if the input is a tensor.
  auto inputType = input.getType();
  scale = getScalarOrTensorConstant(builder, loc, scale, inputType, inputShape);

  // storedValue = input / scale
  auto scaledValue = builder.create<arith::DivFOp>(loc, input, scale);

  // Skip the zero-point addition when the zero point is statically zero.
  Value storedValueFloat = scaledValue;
  if (!matchPattern(zeroPoint, m_Zero())) {
    // Broadcast the zero point to the input's type if necessary.
    zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType,
                                          inputShape);

    // Convert the zero point from storage to expressed type so it can be
    // added in floating point.
    zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, scale.getType(),
                                      quantizedType.isSigned());

    // storedValue = input / scale + zeroPoint
    storedValueFloat = builder.create<arith::AddFOp>(loc, scaledValue,
                                                     zeroPoint);
  }

  // Round/truncate into the storage integer type.
  auto storageScalarOrTensorType =
      getScalarOrTensorType(quantizedType.getStorageType(), inputType);
  auto storedValueInt =
      convertFloatToInteger(builder, loc, storedValueFloat,
                            storageScalarOrTensorType,
                            quantizedType.isSigned());

  // Clamp to the storage bounds if the quantized type narrows them.
  return clampScalarOrTensor(builder, loc, storedValueInt, inputShape,
                             quantizedType);
}

// Dequantize a scalar or ranked tensor input.
//
// See function 'convertRanked()' below for a description of the arguments.
Value dequantizeValue(OpBuilder &builder, Location loc, Value input,
                      ArrayRef<OpFoldResult> inputShape, Value scale,
                      Value zeroPoint, QuantizedType quantizedType) {
  // Broadcast the scale to the input's type if the input is a tensor.
  auto inputType = input.getType();
  scale = getScalarOrTensorConstant(builder, loc, scale, inputType, inputShape);

  // Convert the stored integer value to the expressed floating-point type.
  auto result = convertIntegerToFloat(builder, loc, input, scale.getType(),
                                      quantizedType.isSigned());

  // Skip the zero-point subtraction when the zero point is statically zero.
  if (!matchPattern(zeroPoint, m_Zero())) {
    // Broadcast the zero point to the input's type if necessary.
    zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType,
                                          inputShape);

    // Convert the zero point from storage to expressed type so it can be
    // subtracted in floating point.
    zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, scale.getType(),
                                      quantizedType.isSigned());

    // result = (storedValue - zeroPoint)
    result = builder.create<arith::SubFOp>(loc, result, zeroPoint);
  }

  // result = (storedValue - zeroPoint) * scale
  result = builder.create<arith::MulFOp>(loc, result, scale);
  return result;
}

// Convert a scalar or ranked tensor input with the given scale and zero point
// values.
//
// - input
//   Scalar or ranked tensor value.
//
// - inputShape
//   If 'input' is a tensor, combination of attributes/values representing its
//   static/dynamic dimensions. If 'input' is a scalar, empty list.
//
// - scale
//   Scale as a floating-point scalar value.
//
// - zeroPoint
//   Zero point as an integer scalar value.
//
// - quantizedType
//   Scalar quantized type of the result ('quant.qcast') or of the input
//   ('quant.dcast').
//
Value convertRanked(OpBuilder &builder, Location loc, Operation *op,
                    Value input, ArrayRef<OpFoldResult> inputShape, Value scale,
                    Value zeroPoint, QuantizedType quantizedType) {
  // Dispatch on the concrete op kind: quantize for 'quant.qcast',
  // dequantize for 'quant.dcast'.
  if (isa<QuantizeCastOp>(op))
    return quantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
                         quantizedType);
  if (isa<DequantizeCastOp>(op))
    return dequantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
                           quantizedType);
  llvm_unreachable("unexpected quant op");
}

// Convert an operation using per-layer quantization with a scalar or ranked
// tensor input.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Scalar or ranked tensor.
//
// - quantizedType
//   Per-layer quantized type.
//
Value convertPerLayerRanked(OpBuilder &builder, Location loc, Operation *op,
                            Value input, UniformQuantizedType quantizedType) {
  // Materialize the single scale (expressed type) and zero point (storage
  // type) of the per-layer quantized type as scalar constants.
  auto expressedType = quantizedType.getExpressedType();
  auto storageType = quantizedType.getStorageType();
  auto scaleAttr =
      builder.getFloatAttr(expressedType, quantizedType.getScale());
  auto scale = builder.create<arith::ConstantOp>(loc, expressedType, scaleAttr);
  auto zeroPointAttr =
      builder.getIntegerAttr(storageType, quantizedType.getZeroPoint());
  auto zeroPoint =
      builder.create<arith::ConstantOp>(loc, storageType, zeroPointAttr);

  // Delegate the elementwise conversion, passing the input's mixed shape so
  // the scalar constants can be splat to tensors when needed.
  auto inputShape = getScalarOrTensorShape(builder, loc, input);
  return convertRanked(builder, loc, op, input, inputShape, scale, zeroPoint,
                       quantizedType);
}

// Convert an operation using per-layer quantization.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Scalar, ranked tensor, or unranked tensor.
//
// - quantizedType
//   Per-layer quantized type.
//
Value convertPerLayer(OpBuilder &builder, Location loc, Operation *op,
                      Value input, UniformQuantizedType quantizedType) {
  // Unranked inputs are first flattened to a dynamically shaped 1D tensor so
  // the ranked conversion path can handle them.
  bool isUnranked = isa<UnrankedTensorType>(input.getType());
  Value inputShape;
  if (isUnranked)
    std::tie(input, inputShape) = flattenUnrankedTensor(builder, loc, input);

  // Convert the scalar or ranked tensor.
  Value result = convertPerLayerRanked(builder, loc, op, input, quantizedType);

  // Restore the original unranked shape, captured before flattening.
  if (isUnranked)
    result = restoreUnrankedTensorShape(builder, loc, result, inputShape);

  return result;
}

// Convert an operation using per-channel quantization and a scalar or ranked
// tensor as an input.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Scalar or ranked tensor.
//
// - quantizedType
//   Per-channel quantized type.
//
Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op,
                              Value input,
                              UniformQuantizedPerAxisType quantizedType,
                              int64_t channelAxis) {
  auto *context = builder.getContext();

  auto inputType = cast<RankedTensorType>(input.getType());
  auto inputRank = inputType.getRank();

  // Materialize the per-channel scales and zero points as 1D constants.
  auto scales = materializePerChannelScales(builder, loc, quantizedType);
  auto zeroPoints =
      materializePerChannelZeroPoints(builder, loc, quantizedType);

  // The result element type is the opposite side of the quantized type: a
  // float input is being quantized (result: storage type), an integer input
  // is being dequantized (result: expressed type).
  auto elementType = isa<FloatType>(inputType.getElementType())
                         ? quantizedType.getStorageType()
                         : quantizedType.getExpressedType();
  auto initShape = tensor::getMixedSizes(builder, loc, input);
  Value init = builder.create<tensor::EmptyOp>(loc, initShape, elementType);

  // Build a fully parallel 'linalg.generic' where the input and result use
  // the identity map, while scales and zero points are indexed only by the
  // channel axis.
  SmallVector<utils::IteratorType> iteratorTypes(
      inputRank, utils::IteratorType::parallel);
  auto channelAxisAffineMap = AffineMap::get(
      inputRank, 0, builder.getAffineDimExpr(channelAxis), context);
  SmallVector<AffineMap> indexingMaps{
      builder.getMultiDimIdentityMap(inputRank),
      channelAxisAffineMap,
      channelAxisAffineMap,
      builder.getMultiDimIdentityMap(inputRank)
  };
  auto result = builder.create<linalg::GenericOp>(
      loc,
      init.getType(),                        // resultType
      ValueRange{input, scales, zeroPoints}, // inputs
      ValueRange{init},                      // outputs
      indexingMaps,
      iteratorTypes,
      [&](OpBuilder &builder, Location loc, ValueRange args) {
        assert(args.size() == 4);
        auto input = args[0];
        auto scale = args[1];
        auto zeroPoint = args[2];

        // Convert scalar element with its channel's scale and zero point.
        auto result = convertRanked(builder, loc, op, input, {}, scale,
                                    zeroPoint, quantizedType);

        builder.create<linalg::YieldOp>(loc, result);
      })
      .getResult(0);

  return result;
}

// Convert an operation using per-channel quantization.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Scalar, ranked tensor, or unranked tensor.
//
// - quantizedType
//   Per-channel quantized type.
//
Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op,
                        Value input,
                        UniformQuantizedPerAxisType quantizedType) {
  // Unranked inputs are flattened to a 3D [?, axisSize, ?] tensor with the
  // channel axis in the middle; the axis index then becomes 1.
  bool isUnranked = isa<UnrankedTensorType>(input.getType());
  int64_t channelAxis = quantizedType.getQuantizedDimension();
  int64_t channelAxisSize = (int64_t)quantizedType.getScales().size();
  Value inputShape;
  if (isUnranked) {
    std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis(
        builder, loc, input, channelAxis, channelAxisSize);
    channelAxis = 1;
  }

  // Convert the now-ranked tensor.
  Value result = convertPerChannelRanked(builder, loc, op, input,
                                         quantizedType, channelAxis);

  // Restore the original unranked shape, captured before flattening.
  if (isUnranked)
    result = restoreUnrankedTensorShape(builder, loc, result, inputShape);

  return result;
}

// Convert a quantization operation.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Scalar, ranked tensor, or unranked tensor. The element type matches
//   the storage type (quant.dcast) or expressed type (quant.qcast) of
//   'quantizedType'.
//
// - quantizedType
//   Per-layer or per-channel quantized type.
//
Value convertQuantized(OpBuilder &builder, Location loc, Operation *op,
                       Value input, Type quantizedType) {
  // Dispatch on the concrete quantized type: per-layer vs per-channel.
  if (auto uniformQuantizedType = dyn_cast<UniformQuantizedType>(quantizedType))
    return convertPerLayer(builder, loc, op, input, uniformQuantizedType);

  if (auto uniformQuantizedPerAxisType =
          dyn_cast<UniformQuantizedPerAxisType>(quantizedType))
    return convertPerChannel(builder, loc, op, input,
                             uniformQuantizedPerAxisType);

  llvm_unreachable("unexpected quantized type");
}

// Lowering pattern for 'quant.dcast'
// Lowers 'quant.dcast' by storage-casting the quantized input to its raw
// storage type and emitting the arithmetic dequantization sequence.
struct DequantizeCastOpConversion
    : public OpConversionPattern<quant::DequantizeCastOp> {
  using OpConversionPattern<quant::DequantizeCastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(quant::DequantizeCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.getInput();
    auto quantizedType =
        cast<QuantizedType>(getScalarType(op.getInput().getType()));

    // Reinterpret the quantized input as its storage integer type so the
    // arithmetic lowering can operate on plain integers.
    auto storageScalarOrTensorType =
        getScalarOrTensorType(quantizedType.getStorageType(), input.getType());
    input = rewriter.create<quant::StorageCastOp>(
        loc, storageScalarOrTensorType, input);

    auto result = convertQuantized(rewriter, loc, op, input, quantizedType);

    rewriter.replaceOp(op, result);
    return success();
  }
};

// Lowering pattern for 'quant.qcast'
// Lowers 'quant.qcast' by emitting the arithmetic quantization sequence and
// storage-casting the resulting storage-typed value to the quantized type.
struct QuantizeCastOpConversion
    : public OpConversionPattern<quant::QuantizeCastOp> {
  using OpConversionPattern<quant::QuantizeCastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(quant::QuantizeCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.getInput();
    auto quantizedType = getScalarType(op.getResult().getType());

    // Emit the arithmetic quantization producing a storage-typed value.
    auto result = convertQuantized(rewriter, loc, op, input, quantizedType);

    // Reinterpret the storage-typed value as the quantized result type.
    rewriter.replaceOpWithNewOp<quant::StorageCastOp>(
        op, op.getResult().getType(), result);
    return success();
  }
};

// Pass driver: converts all 'quant' ops except 'quant.scast' into arith,
// linalg, shape, and tensor ops.
struct LowerQuantOps : public impl::LowerQuantOpsBase<LowerQuantOps> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateLowerQuantOpsPatterns(patterns);

    ConversionTarget target(getContext());
    // 'quant.scast' survives lowering (it is a pure reinterpretation);
    // everything else in the 'quant' dialect must be converted.
    target.addLegalOp<quant::StorageCastOp>();
    target.addIllegalDialect<quant::QuantDialect>();
    target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,
                           shape::ShapeDialect, tensor::TensorDialect>();

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

} // namespace

// Registers the 'quant.dcast' and 'quant.qcast' lowering patterns.
void populateLowerQuantOpsPatterns(RewritePatternSet &patterns) {
  patterns.add<DequantizeCastOpConversion, QuantizeCastOpConversion>(
      patterns.getContext());
}

} // namespace quant
} // namespace mlir