// llvm/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h

//===- QuantTypes.h - Quantization Ops and Types ----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_QUANT_IR_QUANTTYPES_H
#define MLIR_DIALECT_QUANT_IR_QUANTTYPES_H

#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Types.h"
#include "llvm/Support/MathExtras.h"

namespace mlir {
namespace quant {
namespace detail {

// Forward declarations of the type-storage structs. Only the names are
// introduced; no definitions appear in this header as shown.
//
// AnyQuantizedTypeStorage, UniformQuantizedTypeStorage,
// UniformQuantizedPerAxisTypeStorage and CalibratedQuantizedTypeStorage are
// each passed as the storage argument of Type::TypeBase by the quantized type
// of the corresponding name declared later in this file.
// QuantizedTypeStorage is declared here but not referenced by any declaration
// visible in this header.
struct QuantizedTypeStorage;
struct AnyQuantizedTypeStorage;
struct UniformQuantizedTypeStorage;
struct UniformQuantizedPerAxisTypeStorage;
struct CalibratedQuantizedTypeStorage;

} // namespace detail

/// Enumeration of bit-mapped flags related to quantized types.
///
/// FlagValue is an unscoped enum wrapped in its own namespace: enumerators are
/// spelled `QuantizationFlags::<Name>` while still converting implicitly to
/// integer types, as is usual for values that are OR-ed into a bit mask.
///
/// NOTE(review): no enumerators are declared in FlagValue as shown here --
/// confirm whether the flag values are expected to be listed.
namespace QuantizationFlags {
enum FlagValue {};
} // namespace QuantizationFlags

/// Base class for all quantized types known to this dialect.
///
/// All quantized types have:
///   - storageType: The (narrower) numeric type that is being used to
///     approximate some expressed type.
///   - expressedType: The type that is being approximated.
///
/// The base class provides generic support for manipulating the types based
/// on these fields.
///
/// QuantizedType derives directly from Type. The concrete quantized types in
/// this header (AnyQuantizedType, UniformQuantizedType,
/// UniformQuantizedPerAxisType and CalibratedQuantizedType) each name
/// QuantizedType as the base-type argument of Type::TypeBase.
class QuantizedType : public Type {};

/// A quantized type that maps storage to/from expressed types in an
/// unspecified way.
///
/// Typical syntax:
///   quant.any<i8:f32>        (storage type and expressed type)
///   quant.any<i8>            (storage type only)
///   quant.any<i8<-16,15>>    (presumably a storage type restricted to the
///                             value range [-16, 15] -- confirm with parser)
///
/// Note that for the any type, the expressed type is optional.
///
/// Declared through Type::TypeBase with QuantizedType as the base type and
/// detail::AnyQuantizedTypeStorage as the storage class.
class AnyQuantizedType
    : public Type::TypeBase<AnyQuantizedType, QuantizedType,
                            detail::AnyQuantizedTypeStorage> {};

/// Represents a family of uniform, quantized types.
///
/// Each instance of this type expresses a mapping between real values (most
/// often expressed in floating point f32) and quantized values (either fixed
/// point or affine).
///
/// The relationship is:
///     real_value = scale * (quantized_value - zero_point)
///
/// It is used as part of high level graph transformations that have the goal
/// of re-expressing parts of a computation in terms of this common form for
/// more efficient execution at runtime. In addition, it is designed to be
/// expressive enough to facilitate lowering to precise types and operations
/// in target hardware.
///
/// As a high-level type, focused on intermediate passes, this type holds
/// opinions consistent with high-level usage. If lowering math kernels below
/// the high level arithmetic ops (e.g. to LLVM IR or hardware specific
/// instruction sets), it is expected that the information expressed here
/// will be used to drive low level codegen and target specific type selection,
/// but this type will likely be erased in the process.
///
/// Syntax synopsis:
///   Per-layer, all parameters expressed:
///     !quant<uniform[StorageType:ExpressedType]{Scale:ZeroPoint}>
///   Per-layer, optional parameters omitted:
///     !quant<uniform[StorageType]{Scale}>
///
///   StorageType: 'i'|'u' NumBits
///   ExpressedType: 'f16', 'f32', 'bf16', 'f64'
///   Scale: A legal double value
///   ZeroPoint: An integer value
///
/// NOTE(review): the synopsis above uses a `!quant<uniform[...]{...}>`
/// spelling, whereas other types in this header are documented with a
/// `quant.<name><...>` spelling (e.g. `quant.any<i8:f32>`). Confirm which form
/// the dialect's type parser currently accepts.
///
/// Declared through Type::TypeBase with QuantizedType as the base type and
/// detail::UniformQuantizedTypeStorage as the storage class.
class UniformQuantizedType
    : public Type::TypeBase<UniformQuantizedType, QuantizedType,
                            detail::UniformQuantizedTypeStorage> {};

/// Represents per-axis (also known as per-channel) quantization.
///
/// Syntax synopsis:
///   Per-axis, all parameters expressed:
///     !quant<uniform[StorageType:ExpressedType:QuantizedDim]{QuantParams}>
///   Per-axis, optional parameters omitted:
///     !quant<uniform[StorageType]{Scale}>
///
///   StorageType: 'i'|'u' NumBits
///   ExpressedType: 'f16', 'f32', 'bf16', 'f64'
///   QuantizedDim: An integer value
///   QuantParams: (Scale ':' ZeroPoint)+
///   Scale: A legal double value
///   ZeroPoint: An integer value
///
/// NOTE(review): the "optional parameters omitted" example is identical to the
/// per-layer form documented on UniformQuantizedType: it carries neither a
/// QuantizedDim nor a QuantParams list. Confirm whether that is really a valid
/// per-axis spelling or a copy of the per-layer synopsis.
///
/// Declared through Type::TypeBase with QuantizedType as the base type and
/// detail::UniformQuantizedPerAxisTypeStorage as the storage class.
class UniformQuantizedPerAxisType
    : public Type::TypeBase<UniformQuantizedPerAxisType, QuantizedType,
                            detail::UniformQuantizedPerAxisTypeStorage> {};

/// A quantized type that infers its range from given min/max values.
///
/// Typical syntax:
///   quant.calibrated<f32<-0.922,0.981>>
///
/// In the example, -0.922 and 0.981 are presumably the min and max values
/// attached to the f32 type -- confirm the operand order with the parser.
///
/// Declared through Type::TypeBase with QuantizedType as the base type and
/// detail::CalibratedQuantizedTypeStorage as the storage class.
class CalibratedQuantizedType
    : public Type::TypeBase<CalibratedQuantizedType, QuantizedType,
                            detail::CalibratedQuantizedTypeStorage> {};

} // namespace quant
} // namespace mlir

#endif // MLIR_DIALECT_QUANT_IR_QUANTTYPES_H