llvm/mlir/include/mlir/IR/CommonTypeConstraints.td

//===-- CommonTypeConstraints.td - Common Type Constraints--*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains commonly used type constraints.
//
//===----------------------------------------------------------------------===//

#ifndef COMMON_TYPE_CONSTRAINTS_TD
#define COMMON_TYPE_CONSTRAINTS_TD

include "mlir/IR/Constraints.td"
include "mlir/IR/DialectBase.td"

//===----------------------------------------------------------------------===//
// Common predicates
//===----------------------------------------------------------------------===//

// Matches a VectorType whose rank is strictly positive.
// 0-D vectors are deliberately rejected for now, until coverage is good enough.
def IsVectorTypePred : And<[
  CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
  CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">
]>;

// Rank-agnostic vector predicate: accepts every VectorType, 0-D included.
// Temporary clone that allows a gradual transition to 0-D vectors.
// TODO: Remove this when all ops support 0-D vectors.
def IsVectorOfAnyRankTypePred
    : CPred<"::llvm::isa<::mlir::VectorType>($_self)">;

// Whether a type is a fixed-length VectorType, i.e. a VectorType for which
// `isScalable()` returns false.
// The cast names `::mlir::VectorType` with full qualification so the emitted
// C++ does not depend on a `using namespace mlir` at the point of use.
def IsFixedVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
                                  !::llvm::cast<::mlir::VectorType>($_self).isScalable()}]>;

// Whether a type is a scalable VectorType, i.e. a VectorType for which
// `isScalable()` returns true.
// The cast names `::mlir::VectorType` with full qualification so the emitted
// C++ does not depend on a `using namespace mlir` at the point of use.
def IsVectorTypeWithAnyDimScalablePred
        : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
                  ::llvm::cast<::mlir::VectorType>($_self).isScalable()}]>;

// Whether a type is a scalable VectorType, with a single trailing scalable dimension.
// Examples:
// Valid:
//   - vector<[4]xf32>, vector<2x3x[2]xi64>, vector<32x[8]xi32>
// Invalid
//   - vector<[4]x8xi32>, vector<[2]x[2]xf64>, vector<2x[8]x4xi32>
//
// The checks are ordered so that each one guards the next: the type must be a
// VectorType, then have rank > 0 (so `back()` is valid), then the last
// dimension must be scalable, and finally no leading dimension may be scalable.
def IsVectorTypeWithOnlyTrailingDimScalablePred : And<[
  CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
  CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">,
  CPred<"::llvm::cast<::mlir::VectorType>($_self).getScalableDims().back()">,
  CPred<"!::llvm::is_contained(::llvm::cast<::mlir::VectorType>($_self).getScalableDims().drop_back(), true)">
]>;

// Matches a VectorType accepted by IsVectorTypePred (rank > 0) for which
// `allDimsScalable()` holds.
def IsVectorTypeWithAllDimsScalablePred
    : And<[IsVectorTypePred,
           CPred<[{::llvm::cast<::mlir::VectorType>($_self).allDimsScalable()}]>]>;

// Matches any TensorType.
def IsTensorTypePred
    : CPred<"::llvm::isa<::mlir::TensorType>($_self)">;

// Matches any MemRefType.
def IsMemRefTypePred
    : CPred<"::llvm::isa<::mlir::MemRefType>($_self)">;

// Matches any UnrankedMemRefType.
def IsUnrankedMemRefTypePred :
    CPred<"::llvm::isa<::mlir::UnrankedMemRefType>($_self)">;

// Matches any UnrankedTensorType.
def IsUnrankedTensorTypePred :
    CPred<"::llvm::isa<::mlir::UnrankedTensorType>($_self)">;

// Matches any RankedTensorType.
def IsRankedTensorTypePred :
    CPred<"::llvm::isa<::mlir::RankedTensorType>($_self)">;

// Matches any BaseMemRefType.
def IsBaseMemRefTypePred :
    CPred<"::llvm::isa<::mlir::BaseMemRefType>($_self)">;

// Matches any ShapedType.
def IsShapedTypePred
    : CPred<"::llvm::isa<::mlir::ShapedType>($_self)">;

// Checks that a ShapedType has a static shape. The cast is unchecked, so this
// must only be evaluated on a type already known to be a ShapedType.
def HasStaticShapePred
    : CPred<"::llvm::cast<::mlir::ShapedType>($_self).hasStaticShape()">;

// Matches any TupleType.
def IsTupleTypePred
    : CPred<"::llvm::isa<::mlir::TupleType>($_self)">;

// Matches any type carrying the ValueSemantics trait.
def HasValueSemanticsPred
    : CPred<"$_self.hasTrait<::mlir::ValueSemantics>()">;

//===----------------------------------------------------------------------===//
// Type definitions
//===----------------------------------------------------------------------===//

// Base class of all type constraints defined in this file: a predicate
// (`condition`), a summary string (`descr`) and the C++ type it maps to.
// `description` and `builderCall` start out empty; an empty `builderCall` is
// what SameBuildabilityAs treats as "not buildable".
class Type<Pred condition, string descr = "", string cppType = "::mlir::Type">
    : TypeConstraint<condition, descr, cppType> {
  string description = "";
  string builderCall = "";
}

// Re-exposes an existing type def `t` under another name, optionally with a
// different summary. The predicate, C++ type, description and builder call are
// all taken from `t`.
class TypeAlias<Type t, string summary = t.summary>
    : Type<t.predicate, summary, t.cppType> {
  let builderCall = t.builderCall;
  let description = t.description;
}

// A type constraint that additionally records the dialect `d` it belongs to.
class DialectType<Dialect d, Pred condition, string descr = "",
                  string cppType = "::mlir::Type">
    : Type<condition, descr, cppType> {
  Dialect dialect = d;
}

// Variadic wrapper around a base type: stands for zero or more values of
// `type`. Used to model variadic operands/results. The summary is the base
// type's summary prefixed with "variadic of ".
class Variadic<Type type>
    : TypeConstraint<type.predicate, "variadic of " # type.summary,
                     type.cppType> {
  Type baseType = type;
  int minSize = 0;
}

// Nested variadic wrapper: stands for zero or more variadic ranges of the base
// type. Used to model variadic operands and results.
// `variadicSegmentAttrName` should name a DenseI32ArrayAttr argument that
// provides the sizes of the inner variadic operand groups.
class VariadicOfVariadic<Type type, string variadicSegmentAttrName>
    : Variadic<type> {
  string segmentAttrName = variadicSegmentAttrName;
}

// Optional wrapper around a base type: stands for either zero or one value of
// `type`. Used to model optional operands/results. Predicate, summary and C++
// type are forwarded unchanged from the base type.
class Optional<Type type>
    : TypeConstraint<type.predicate, type.summary, type.cppType> {
  Type baseType = type;
}

// Mixin marking a type as constructible through a builder call (the calls in
// this file are written against `$_builder`).
// This deliberately does not "inherit" from Type: doing so would require
// duplicating the Type subclasses for the buildable and non-buildable cases in
// order to avoid diamond "inheritance".
// TODO: we may extend this to a more general 'Buildable' trait, making some
// Types and some Attrs buildable.
class BuildableType<code builder> {
  // Builder call to invoke (if specified) to construct the BuildableType.
  code builderCall = builder;
}

// Mixin for types that are buildable exactly when `type` is buildable.
// Intended for container-like types, which can only be built when the type of
// their elements can be built.
class SameBuildabilityAs<Type type, code builder> {
  // An empty builder call on `type` means "not buildable", which is propagated;
  // otherwise the supplied `builder` is used.
  code builderCall = !if(!empty(type.builderCall),
                         "",
                         builder);
}

// Unconstrained: the predicate is the constant `true`.
def AnyType
    : Type<CPred<"true">, "any type">;

// The builtin none type; buildable through `$_builder`.
def NoneType
    : Type<CPred<"::llvm::isa<::mlir::NoneType>($_self)">,
           "none type", "::mlir::NoneType">,
      BuildableType<"$_builder.getType<::mlir::NoneType>()">;

// A type satisfying the constraint of at least one entry of `allowedTypeList`.
// When `summary` is empty, one is generated by joining the entries' summaries
// with " or ".
class AnyTypeOf<list<Type> allowedTypeList, string summary = "",
                string cppType = "::mlir::Type">
    : Type<
        // Disjunction of the candidates' predicates.
        Or<!foreach(candidate, allowedTypeList, candidate.predicate)>,
        !if(!eq(summary, ""),
            !interleave(!foreach(candidate, allowedTypeList, candidate.summary),
                        " or "),
            summary),
        cppType> {
  list<Type> allowedTypes = allowedTypeList;
}

// A type satisfying the constraints of every entry of `allowedTypeList`.
// When `summary` is empty, one is generated by joining the entries' summaries
// with " and ".
class AllOfType<list<Type> allowedTypeList, string summary = "",
                string cppType = "::mlir::Type">
    : Type<
        // Conjunction of the required types' predicates, in list order.
        And<!foreach(required, allowedTypeList, required.predicate)>,
        !if(!eq(summary, ""),
            !interleave(!foreach(required, allowedTypeList, required.summary),
                        " and "),
            summary),
        cppType> {
  list<Type> allowedTypes = allowedTypeList;
}

// Narrows `type` with extra predicates: the resulting predicate is the base
// type's predicate followed by every entry of `predicates`, all of which must
// hold. The C++ type defaults to that of the base type.
class ConfinedType<Type type, list<Pred> predicates, string summary = "",
                   string cppType = type.cppType>
    : Type<And<!listconcat([type.predicate],
                           !foreach(extra, predicates, extra))>,
           summary, cppType> {
    Type baseType = type;
    list<Pred> predicateList = predicates;
}

// Integer types.

// Matches every IntegerType, whatever its width or signedness semantics.
def AnyInteger
    : Type<CPred<"::llvm::isa<::mlir::IntegerType>($_self)">,
           "integer", "::mlir::IntegerType">;

// Matches an integer of exactly `width` bits, regardless of signedness
// semantics.
class AnyI<int width>
    : Type<CPred<"$_self.isInteger(" # width # ")">,
           width # "-bit integer"> {
  int bitwidth = width;
}

// Matches an integer (regardless of signedness semantics) whose width is one
// of `widths`.
class AnyIntOfWidths<list<int> widths>
    : AnyTypeOf<!foreach(bits, widths, AnyI<bits>),
                !interleave(widths, "/") # "-bit integer",
                "::mlir::IntegerType">;

def AnyI1 : AnyI<1>;
def AnyI8 : AnyI<8>;
def AnyI16 : AnyI<16>;
def AnyI32 : AnyI<32>;
def AnyI64 : AnyI<64>;

// Matches a signless integer of any width.
def AnySignlessInteger
    : Type<CPred<"$_self.isSignlessInteger()">,
           "signless integer", "::mlir::IntegerType">;

// Matches a signless integer of exactly `width` bits; buildable through
// `$_builder.getIntegerType(width)`.
class I<int width>
    : Type<CPred<"$_self.isSignlessInteger(" # width # ")">,
           width # "-bit signless integer",
           "::mlir::IntegerType">,
      BuildableType<"$_builder.getIntegerType(" # width # ")"> {
  int bitwidth = width;
}

// Matches a signless integer whose width is one of `widths`.
class SignlessIntOfWidths<list<int> widths>
    : AnyTypeOf<!foreach(bits, widths, I<bits>),
                !interleave(widths, "/") # "-bit signless integer">;

def I1 : I<1>;
def I8 : I<8>;
def I16 : I<16>;
def I32 : I<32>;
def I64 : I<64>;
def I128 : I<128>;

// Matches a signed integer of any width.
def AnySignedInteger
    : Type<CPred<"$_self.isSignedInteger()">, "signed integer">;

// Matches a signed integer of exactly `width` bits; buildable through
// `$_builder.getIntegerType(width, /*isSigned=*/true)`.
class SI<int width>
    : Type<CPred<"$_self.isSignedInteger(" # width # ")">,
           width # "-bit signed integer",
           "::mlir::IntegerType">,
      BuildableType<"$_builder.getIntegerType(" # width
                    # ", /*isSigned=*/true)"> {
  int bitwidth = width;
}

// Matches a signed integer whose width is one of `widths`.
class SignedIntOfWidths<list<int> widths>
    : AnyTypeOf<!foreach(bits, widths, SI<bits>),
                !interleave(widths, "/") # "-bit signed integer">;

def SI1 : SI<1>;
def SI8 : SI<8>;
def SI16 : SI<16>;
def SI32 : SI<32>;
def SI64 : SI<64>;

// Matches an unsigned integer of any width.
def AnyUnsignedInteger
    : Type<CPred<"$_self.isUnsignedInteger()">, "unsigned integer">;

// Matches an unsigned integer of exactly `width` bits; buildable through
// `$_builder.getIntegerType(width, /*isSigned=*/false)`.
class UI<int width>
    : Type<CPred<"$_self.isUnsignedInteger(" # width # ")">,
           width # "-bit unsigned integer",
           "::mlir::IntegerType">,
      BuildableType<"$_builder.getIntegerType(" # width
                    # ", /*isSigned=*/false)"> {
  int bitwidth = width;
}

// Matches an unsigned integer whose width is one of `widths`.
class UnsignedIntOfWidths<list<int> widths>
    : AnyTypeOf<!foreach(bits, widths, UI<bits>),
                !interleave(widths, "/") # "-bit unsigned integer">;

def UI1 : UI<1>;
def UI8 : UI<8>;
def UI16 : UI<16>;
def UI32 : UI<32>;
def UI64 : UI<64>;

// The index type; buildable through `$_builder.getIndexType()`.
def Index
    : Type<CPred<"::llvm::isa<::mlir::IndexType>($_self)">,
           "index", "::mlir::IndexType">,
      BuildableType<"$_builder.getIndexType()">;

// Matches either a signless integer or the index type.
def AnySignlessIntegerOrIndex
    : Type<CPred<"$_self.isSignlessIntOrIndex()">,
           "signless integer or index">;

// Floating point types.

// Matches every FloatType, whatever its width.
def AnyFloat
    : Type<CPred<"::llvm::isa<::mlir::FloatType>($_self)">,
           "floating-point", "::mlir::FloatType">;

// Matches the float type of exactly `width` bits. Both the predicate
// (`isF<width>()`) and the builder call (`getF<width>Type()`) are derived from
// `width`.
class F<int width>
    : Type<CPred<"$_self.isF" # width # "()">,
           width # "-bit float",
           "::mlir::FloatType">,
      BuildableType<"$_builder.getF" # width # "Type()"> {
  int bitwidth = width;
}

// Matches a float type whose width is one of `widths`.
class FloatOfWidths<list<int> widths>
    : AnyTypeOf<!foreach(bits, widths, F<bits>),
                !interleave(widths, "/") # "-bit float">;

def F16 : F<16>;
def F32 : F<32>;
def F64 : F<64>;
def F80 : F<80>;
def F128 : F<128>;

// Named floating-point formats. Each def pairs an `is<Format>()` predicate
// with the matching `$_builder.get<Format>Type()` builder call.
def BF16
    : Type<CPred<"$_self.isBF16()">, "bfloat16 type">,
      BuildableType<"$_builder.getBF16Type()">;
def TF32
    : Type<CPred<"$_self.isTF32()">, "tf32 type">,
      BuildableType<"$_builder.getTF32Type()">;
def F8E4M3FN
    : Type<CPred<"$_self.isFloat8E4M3FN()">, "f8E4M3FN type">,
      BuildableType<"$_builder.getFloat8E4M3FNType()">;
def F8E5M2
    : Type<CPred<"$_self.isFloat8E5M2()">, "f8E5M2 type">,
      BuildableType<"$_builder.getFloat8E5M2Type()">;
def F8E4M3
    : Type<CPred<"$_self.isFloat8E4M3()">, "f8E4M3 type">,
      BuildableType<"$_builder.getFloat8E4M3Type()">;
def F8E4M3FNUZ
    : Type<CPred<"$_self.isFloat8E4M3FNUZ()">, "f8E4M3FNUZ type">,
      BuildableType<"$_builder.getFloat8E4M3FNUZType()">;
def F8E4M3B11FNUZ
    : Type<CPred<"$_self.isFloat8E4M3B11FNUZ()">, "f8E4M3B11FNUZ type">,
      BuildableType<"$_builder.getFloat8E4M3B11FNUZType()">;
def F8E5M2FNUZ
    : Type<CPred<"$_self.isFloat8E5M2FNUZ()">, "f8E5M2FNUZ type">,
      BuildableType<"$_builder.getFloat8E5M2FNUZType()">;
def F8E3M4
    : Type<CPred<"$_self.isFloat8E3M4()">, "f8E3M4 type">,
      BuildableType<"$_builder.getFloat8E3M4Type()">;
def F4E2M1FN
    : Type<CPred<"$_self.isFloat4E2M1FN()">, "f4E2M1FN type">,
      BuildableType<"$_builder.getFloat4E2M1FNType()">;
def F6E2M3FN
    : Type<CPred<"$_self.isFloat6E2M3FN()">, "f6E2M3FN type">,
      BuildableType<"$_builder.getFloat6E2M3FNType()">;
def F6E3M2FN
    : Type<CPred<"$_self.isFloat6E3M2FN()">, "f6E3M2FN type">,
      BuildableType<"$_builder.getFloat6E3M2FNType()">;
def F8E8M0FNU
    : Type<CPred<"$_self.isFloat8E8M0FNU()">, "f8E8M0FNU type">,
      BuildableType<"$_builder.getFloat8E8M0FNUType()">;

// Any complex type, irrespective of its element type.
def AnyComplex : Type<CPred<"::llvm::isa<::mlir::ComplexType>($_self)">,
                      "complex-type", "::mlir::ComplexType">;

// A complex type whose element type satisfies `elType`.
//
// The element check reuses `elType.predicate` with every `$_self` leaf
// rewritten to the complex type's element type.
//
// The type is buildable only when `elType` is buildable (SameBuildabilityAs).
// NOTE(review): the builder call pastes `elType` itself into
// `$_builder.get<...>Type()`; presumably this yields the def name of the
// element type (e.g. `F32` -> `getF32Type()`), so it only produces a valid call
// for element types whose def name matches a builder getter -- confirm before
// using with other element types.
class Complex<Type elType>
    : ConfinedType<AnyComplex, [
          SubstLeaves<"$_self",
                      "::llvm::cast<::mlir::ComplexType>($_self).getElementType()",
           elType.predicate>],
           "complex type with " # elType.summary # " elements",
           "::mlir::ComplexType">,
      SameBuildabilityAs<elType, "::mlir::ComplexType::get($_builder.get" # elType #
                               "Type())"> {
  Type elementType = elType;
}

// An opaque type identified by a dialect namespace and a type name. The
// predicate delegates to `isOpaqueTypeWithName`; the builder call creates the
// type through `::mlir::OpaqueType::get` with the same dialect and name.
class OpaqueType<string dialect, string name, string summary>
    : Type<CPred<"isOpaqueTypeWithName($_self, \"" # dialect # "\", \""
                 # name # "\")">,
           summary, "::mlir::OpaqueType">,
      BuildableType<"::mlir::OpaqueType::get($_builder.getStringAttr(\""
                    # dialect # "\"), \"" # name # "\")">;

// Function Type

// Matches every FunctionType.
def FunctionType
    : Type<CPred<"::llvm::isa<::mlir::FunctionType>($_self)">,
           "function type", "::mlir::FunctionType">;

// A container type is a type that embeds another type.
// `containerPred` is checked first; then `etype.predicate` is evaluated with
// every `$_self` leaf replaced by `elementTypeCall`, the expression that
// extracts the embedded element type.
class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
                    string descr, string cppType = "::mlir::Type">
    : Type<And<[containerPred,
                SubstLeaves<"$_self", !cast<string>(elementTypeCall),
                            etype.predicate>]>,
           descr # " of " # etype.summary # " values",
           cppType>;

// A shaped container type: `containerPred` must hold for the type itself, and
// the element type of the ShapedType must satisfy at least one entry of
// `allowedTypes`.
//
// The element check is emitted as an immediately-invoked C++ lambda taking the
// element type as `elementType`; the predicate of `AnyTypeOf<allowedTypes>` is
// placed in its body with every `$_self` leaf rewritten to `elementType`.
// `containerPred` comes first in the `And`, ahead of the unchecked cast to
// ShapedType.
//
// The summary reads "<descr> of <allowed types' summary> values".
class ShapedContainerType<list<Type> allowedTypes,
                          Pred containerPred, string descr,
                          string cppType = "::mlir::Type"> :
    Type<And<[containerPred,
              Concat<"[](::mlir::Type elementType) { return ",
                SubstLeaves<"$_self", "elementType",
                AnyTypeOf<allowedTypes>.predicate>,
                "; }(::llvm::cast<::mlir::ShapedType>($_self).getElementType())">]>,
         descr # " of " # AnyTypeOf<allowedTypes>.summary # " values", cppType>;

// Whether a shaped type is ranked.
// The cast to ShapedType is unchecked, so this must only be evaluated on a
// type already known to be a ShapedType.
def HasRankPred : CPred<"::llvm::cast<::mlir::ShapedType>($_self).hasRank()">;

// Whether a shaped type is ranked and its rank is one of `ranks`.
// One `getRank() == <rank>` comparison is generated per entry of `ranks`;
// HasRankPred is listed first, ahead of the `getRank()` comparisons.
class HasAnyRankOfPred<list<int> ranks> : And<[
    HasRankPred,
    Or<!foreach(rank, ranks,
                CPred<[{::llvm::cast<::mlir::ShapedType>($_self).getRank()
                         == }]
                      # rank>)>]>;

// Whether a shaped type is ranked and its rank is greater than or equal to the
// specified `rank`.
class HasRankGreaterOrEqualPred<int rank> : And<[
    HasRankPred,
    CPred<[{::llvm::cast<::mlir::ShapedType>($_self).getRank() >= }] # rank>
]>;

// A shaped container carrying the ValueSemantics trait, whose element type is
// from the given `allowedTypes` list.
class ValueSemanticsContainerOf<list<Type> allowedTypes>
    : ShapedContainerType<allowedTypes, HasValueSemanticsPred,
                          "container with value semantics">;

// Vector types.

// Vector (rank > 0, see IsVectorTypePred) whose element type is from the given
// `allowedTypes` list.
class VectorOf<list<Type> allowedTypes>
    : ShapedContainerType<allowedTypes, IsVectorTypePred,
                          "vector", "::mlir::VectorType">;

// Temporary vector type clone that allows gradual transition to 0-D vectors.
// TODO: Remove this when all ops support 0-D vectors.
class VectorOfAnyRankOf<list<Type> allowedTypes>
    : ShapedContainerType<allowedTypes, IsVectorOfAnyRankTypePred,
                          "vector", "::mlir::VectorType">;

// Fixed-length vector whose element type is from the given `allowedTypes` list.
class FixedVectorOf<list<Type> allowedTypes>
    : ShapedContainerType<allowedTypes, IsFixedVectorTypePred,
                          "fixed-length vector", "::mlir::VectorType">;

// Scalable vector whose element type is from the given `allowedTypes` list.
class ScalableVectorOf<list<Type> allowedTypes>
    : ShapedContainerType<allowedTypes, IsVectorTypeWithAnyDimScalablePred,
                          "scalable vector", "::mlir::VectorType">;

// Any vector with a single trailing scalable dimension, with an element type in
// the `allowedTypes` list.
//
// Note: This is similar to ScalableVectorOf, with the extra requirement that
// only the trailing dim is scalable.
class VectorWithTrailingDimScalableOf<list<Type> allowedTypes>
    : ShapedContainerType<allowedTypes,
                          IsVectorTypeWithOnlyTrailingDimScalablePred,
                          "trailing scalable vector", "::mlir::VectorType">;

// Whether the rank of a vector is from the given `allowedRanks` list.
// The type must first satisfy IsVectorTypePred; one `getRank() == <rank>`
// comparison is then generated per entry of `allowedRanks`.
class IsVectorOfRankPred<list<int> allowedRanks> :
  And<[IsVectorTypePred,
       Or<!foreach(allowedlength, allowedRanks,
                   CPred<[{::llvm::cast<::mlir::VectorType>($_self).getRank()
                           == }]
                         # allowedlength>)>]>;

// Whether the rank of a fixed-length vector is from the given `allowedRanks`
// list. The type must first satisfy IsFixedVectorTypePred.
class IsFixedVectorOfRankPred<list<int> allowedRanks> :
  And<[IsFixedVectorTypePred,
       Or<!foreach(allowedlength, allowedRanks,
                   CPred<[{::llvm::cast<::mlir::VectorType>($_self).getRank()
                           == }]
                         # allowedlength>)>]>;

// Whether the rank of a scalable vector is from the given `allowedRanks` list.
// The type must first satisfy IsVectorTypeWithAnyDimScalablePred.
class IsScalableVectorOfRankPred<list<int> allowedRanks> :
  And<[IsVectorTypeWithAnyDimScalablePred,
       Or<!foreach(allowedlength, allowedRanks,
                   CPred<[{::llvm::cast<::mlir::VectorType>($_self).getRank()
                           == }]
                         # allowedlength>)>]>;

// Vector whose rank is one of `allowedRanks`.
// The summary starts with a space because it is written to be appended to
// another summary (see the *RankAndType classes).
class VectorOfRank<list<int> allowedRanks>
    : Type<IsVectorOfRankPred<allowedRanks>,
           " of ranks " # !interleave(allowedRanks, "/"),
           "::mlir::VectorType">;

// Fixed-length vector whose rank is one of `allowedRanks`.
class FixedVectorOfRank<list<int> allowedRanks>
    : Type<IsFixedVectorOfRankPred<allowedRanks>,
           " of ranks " # !interleave(allowedRanks, "/"),
           "::mlir::VectorType">;

// Scalable vector whose rank is one of `allowedRanks`.
class ScalableVectorOfRank<list<int> allowedRanks>
    : Type<IsScalableVectorOfRankPred<allowedRanks>,
           " of ranks " # !interleave(allowedRanks, "/"),
           "::mlir::VectorType">;

// Vector whose rank is one of `allowedRanks` and whose element type is one of
// `allowedTypes`. The summary is the element-type summary followed by the rank
// suffix.
class VectorOfRankAndType<list<int> allowedRanks, list<Type> allowedTypes>
    : AllOfType<[VectorOf<allowedTypes>, VectorOfRank<allowedRanks>],
                VectorOf<allowedTypes>.summary
                    # VectorOfRank<allowedRanks>.summary,
                "::mlir::VectorType">;

// Fixed-width vector whose rank is one of `allowedRanks` and whose element
// type is one of `allowedTypes`. The rank constraint is expressed with
// VectorOfRank; the fixed-length requirement comes from FixedVectorOf.
class FixedVectorOfRankAndType<list<int> allowedRanks, list<Type> allowedTypes>
    : AllOfType<[FixedVectorOf<allowedTypes>, VectorOfRank<allowedRanks>],
                FixedVectorOf<allowedTypes>.summary
                    # VectorOfRank<allowedRanks>.summary,
                "::mlir::VectorType">;

// Whether the number of elements of a vector is from the given
// `allowedLengths` list.
// The type must first satisfy IsVectorTypePred; one
// `getNumElements() == <length>` comparison is then generated per entry of
// `allowedLengths`.
class IsVectorOfLengthPred<list<int> allowedLengths> :
  And<[IsVectorTypePred,
       Or<!foreach(allowedlength, allowedLengths,
                   CPred<[{::llvm::cast<::mlir::VectorType>($_self).getNumElements()
                           == }]
                         # allowedlength>)>]>;

// Whether the number of elements of a fixed-length vector is from the given
// `allowedLengths` list. The type must first satisfy IsFixedVectorTypePred.
class IsFixedVectorOfLengthPred<list<int> allowedLengths> :
  And<[IsFixedVectorTypePred,
       Or<!foreach(allowedlength, allowedLengths,
                   CPred<[{::llvm::cast<::mlir::VectorType>($_self).getNumElements()
                           == }]
                         # allowedlength>)>]>;

// Whether the number of elements of a scalable vector is from the given
// `allowedLengths` list. The type must first satisfy
// IsVectorTypeWithAnyDimScalablePred.
class IsScalableVectorOfLengthPred<list<int> allowedLengths> :
  And<[IsVectorTypeWithAnyDimScalablePred,
       Or<!foreach(allowedlength, allowedLengths,
                   CPred<[{::llvm::cast<::mlir::VectorType>($_self).getNumElements()
                           == }]
                         # allowedlength>)>]>;

// Maps a forward (non-negative) or reverse (negative) index to the same
// 1-based "normalized" position counted from its own end: forward index 2 and
// reverse index -3 both normalize to 3. This is the minimum rank a shape needs
// for the index to be valid, which keeps the bounds check in
// IsNthDimSizeIsOneOfPred simple (see its first CPred).
class NormalizeIndex<int value> {
  int ret = !if(!ge(value, 0),
    !add(value, 1)  /* value + 1 when non-negative */,
    !sub(0, value)  /* -value when negative */);
}

// Whether the n-th dim of the shape is contained within `allowedSizes`.
// Negative values for `n` index in reverse.
//
// Both checks cast to ShapedType and query the rank without any prior check,
// so this predicate must only be evaluated on a type already known to be a
// ranked ShapedType. The first check verifies that the rank is large enough
// for dim `n` to exist; the second looks the dim size up in `allowedSizes`.
//
// `ArrayRef` is spelled `::llvm::ArrayRef` so the emitted C++ does not depend
// on a using-declaration at the point of use.
//
// Examples:
// IsNthDimSizeIsOneOfPred<0, [2, 3, 4]>
//  - Accepts any shape where the first dim is 2, 3, or 4.
//    * This means shapes like: 2x8x9x5, 4, 3x1, 4x?, etc
// IsNthDimSizeIsOneOfPred<-1, [16]>
//  - Accepts any shape where the last dim is 16.
//    * This means shapes like 2x16, 16, 1x2x3x4x16, etc
// IsNthDimSizeIsOneOfPred<-2, [10, 5]>
//  - Accepts any shape where the second to last dim is 10 or 5.
//    * This means shapes like: 1x10x2, 2x1x4x5x6, 8x10x?, etc
class IsNthDimSizeIsOneOfPred<int n, list<int> allowedSizes>
  : And<[
      CPred<"::llvm::cast<::mlir::ShapedType>($_self).getRank() >= " # NormalizeIndex<n>.ret>,
      CPred<"::llvm::is_contained(::llvm::ArrayRef<int64_t>({" # !interleave(allowedSizes, ", ") # "}), "
        # "::llvm::cast<::mlir::ShapedType>($_self).getDimSize("
        #   !if(!lt(n, 0),
              "::llvm::cast<::mlir::ShapedType>($_self).getRank() + " # n,
              "" # n)
        # "))">]>;

// Whether the shape of a vector matches the given `shape` list exactly.
// The cast to VectorType is unchecked, so this must only be evaluated on a
// type already known to be a VectorType. `ArrayRef` is spelled
// `::llvm::ArrayRef` so the emitted C++ does not depend on a using-declaration
// at the point of use.
class IsVectorOfShape<list<int> shape>
  : CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ::llvm::ArrayRef<int64_t>({" # !interleave(shape, ", ") # "})">;

// Vector whose number of elements is one of `allowedLengths`.
// The summary starts with a space because it is written to be appended to
// another summary (see the *LengthAndType classes).
class VectorOfLength<list<int> allowedLengths>
    : Type<IsVectorOfLengthPred<allowedLengths>,
           " of length " # !interleave(allowedLengths, "/"),
           "::mlir::VectorType">;

// Fixed-length vector whose number of elements is one of `allowedLengths`.
class FixedVectorOfLength<list<int> allowedLengths>
    : Type<IsFixedVectorOfLengthPred<allowedLengths>,
           " of length " # !interleave(allowedLengths, "/"),
           "::mlir::VectorType">;

// Scalable vector whose number of elements is one of `allowedLengths`.
class ScalableVectorOfLength<list<int> allowedLengths>
    : Type<IsScalableVectorOfLengthPred<allowedLengths>,
           " of length " # !interleave(allowedLengths, "/"),
           "::mlir::VectorType">;

// Vector whose number of elements is one of `allowedLengths` and whose element
// type is one of `allowedTypes`. The summary is the element-type summary
// followed by the length suffix.
class VectorOfLengthAndType<list<int> allowedLengths, list<Type> allowedTypes>
    : AllOfType<[VectorOf<allowedTypes>, VectorOfLength<allowedLengths>],
                VectorOf<allowedTypes>.summary
                    # VectorOfLength<allowedLengths>.summary,
                "::mlir::VectorType">;

// Fixed-length vector whose number of elements is one of `allowedLengths` and
// whose element type is one of `allowedTypes`.
class FixedVectorOfLengthAndType<list<int> allowedLengths,
                                 list<Type> allowedTypes>
    : AllOfType<[FixedVectorOf<allowedTypes>,
                 FixedVectorOfLength<allowedLengths>],
                FixedVectorOf<allowedTypes>.summary
                    # FixedVectorOfLength<allowedLengths>.summary,
                "::mlir::VectorType">;

// Scalable vector whose number of elements is one of `allowedLengths` and
// whose element type is one of `allowedTypes`.
class ScalableVectorOfLengthAndType<list<int> allowedLengths,
                                    list<Type> allowedTypes>
    : AllOfType<[ScalableVectorOf<allowedTypes>,
                 ScalableVectorOfLength<allowedLengths>],
                ScalableVectorOf<allowedTypes>.summary
                    # ScalableVectorOfLength<allowedLengths>.summary,
                "::mlir::VectorType">;

// Any scalable vector where the rank is from the given `allowedRanks` list and
// the number of elements is from the given `allowedLengths` list and the type
// is from the given `allowedTypes` list.
//
// The summaries of the rank and length constraints are suffixes (" of ranks
// ..." / " of length ..."), so the element-type summary must come first to
// form a readable message such as
// "scalable vector of <T> values of ranks <R> of length <L>".
// The order of the predicates themselves is left unchanged.
class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
                                           list<int> allowedLengths,
                                           list<Type> allowedTypes> : AllOfType<
  [ScalableVectorOfRank<allowedRanks>, ScalableVectorOf<allowedTypes>,
   ScalableVectorOfLength<allowedLengths>],
  ScalableVectorOf<allowedTypes>.summary #
  ScalableVectorOfRank<allowedRanks>.summary #
  ScalableVectorOfLength<allowedLengths>.summary,
  "::mlir::VectorType">;

// Any ShapedType where the size of the n-th dim is contained in `allowedSizes`.
// Negative values for `n` index in reverse.
//
// IsNthDimSizeIsOneOfPred casts to ShapedType and queries the rank without
// checking either, so this constraint first requires the type to be a
// ShapedType and to be ranked. A non-shaped or unranked type is then rejected
// by the constraint instead of reaching the unchecked cast / rank query.
//
// The summary starts with a space because it is written to be appended to
// another summary.
class ShapedTypeWithNthDimOfSize<int n, list<int> allowedSizes> : Type<
  And<[IsShapedTypePred, HasRankPred,
       IsNthDimSizeIsOneOfPred<n, allowedSizes>]>,
  " with dim " # n # " having a size of {" # !interleave(allowedSizes, ", ") # "}",
  "::mlir::ShapedType">;

// Any scalable vector with a single trailing scalable dimension, where the
// size of the trailing dimension is in the `allowedTrailingSizes` list and the
// element type is in the `allowedTypes` list.
// The vector constraint is listed first, ahead of the dim-size constraint.
class VectorWithTrailingDimScalableOfSizeAndType<
    list<int> allowedTrailingSizes, list<Type> allowedTypes>
    : AllOfType<[VectorWithTrailingDimScalableOf<allowedTypes>,
                 ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>],
                VectorWithTrailingDimScalableOf<allowedTypes>.summary
                    # ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>.summary,
                "::mlir::VectorType">;

// Vector (rank > 0) with any element type.
def AnyVector
    : VectorOf<[AnyType]>;

// Temporary vector type clone that allows gradual transition to 0-D vectors.
def AnyVectorOfAnyRank
    : VectorOfAnyRankOf<[AnyType]>;

// Fixed-length vector with any element type.
def AnyFixedVector
    : FixedVectorOf<[AnyType]>;

// Scalable vector with any element type.
def AnyScalableVector
    : ScalableVectorOf<[AnyType]>;

// Shaped types.

// Any ShapedType, with any element type.
def AnyShaped
    : ShapedContainerType<[AnyType], IsShapedTypePred,
                          "shaped", "::mlir::ShapedType">;

//===----------------------------------------------------------------------===//
// Tensor types.

// Unranked tensor whose element type is one of `allowedTypes`; every predicate
// in the optional `preds` list must hold as well.
class UnrankedTensorOf<list<Type> allowedTypes, list<Pred> preds = [],
                       string summary = "unranked tensor">
    : ShapedContainerType<allowedTypes,
                          And<!listconcat([IsUnrankedTensorTypePred], preds)>,
                          summary, "::mlir::UnrankedTensorType">;

// Ranked tensor whose element type is one of `allowedTypes`; every predicate
// in the optional `preds` list must hold as well.
class RankedTensorOf<list<Type> allowedTypes, list<Pred> preds = [],
                     string summary = "ranked tensor">
    : ShapedContainerType<allowedTypes,
                          And<!listconcat([IsRankedTensorTypePred], preds)>,
                          summary, "::mlir::RankedTensorType">;

// Tensor (ranked or unranked) whose element type is one of `allowedTypes`;
// every predicate in the optional `preds` list must hold as well.
//
// TODO: use `Constraint` instead of `Pred`, so we can generate a better
// default summary (a la `ConfinedAttr`).
class TensorOf<list<Type> allowedTypes, list<Pred> preds = [],
               string summary = "tensor">
    : ShapedContainerType<allowedTypes,
                          And<!listconcat([IsTensorTypePred], preds)>,
                          summary, "::mlir::TensorType">;

// Tensor with any element type.
def AnyTensor : TensorOf<[AnyType]>;

// Tensors of a specific signless integer or index element type.
def I1Tensor : TensorOf<[I1]>;
def I8Tensor : TensorOf<[I8]>;
def I16Tensor : TensorOf<[I16]>;
def I32Tensor : TensorOf<[I32]>;
def I64Tensor : TensorOf<[I64]>;
def IndexTensor : TensorOf<[Index]>;

// Tensors of a specific floating-point element type.
def BF16Tensor : TensorOf<[BF16]>;
def F16Tensor : TensorOf<[F16]>;
def F32Tensor : TensorOf<[F32]>;
def F64Tensor : TensorOf<[F64]>;

// Ranked tensor of rank >= 1 whose element type is one of `allowedTypes`.
class Non0RankedTensorOf<list<Type> allowedTypes>
    : TensorOf<allowedTypes,
               [HasRankGreaterOrEqualPred<1>],
               "non-0-ranked.tensor">;

def AnyRankedTensor : RankedTensorOf<[AnyType]>;
def AnyNon0RankedTensor : Non0RankedTensorOf<[AnyType]>;
def AnyUnrankedTensor : UnrankedTensorOf<[AnyType]>;

// Either an unranked tensor or a ranked tensor of rank >= 1.
def AnyNon0RankedOrUnrankedTensor
    : AnyTypeOf<[AnyUnrankedTensor, AnyNon0RankedTensor],
                "non-0-ranked or unranked tensor",
                "::mlir::TensorType">;

// Ranked tensor whose element type is one of `allowedTypes` and whose rank is
// one of `ranks`. The summary lists the ranks as "<r>D" joined by "/".
class TensorRankOf<list<Type> allowedTypes, list<int> ranks>
    : RankedTensorOf<allowedTypes, [HasAnyRankOfPred<ranks>],
                     !interleave(!foreach(r, ranks, r # "D"), "/")
                         # " tensor">;

// Fixed-rank shorthands.
class 0DTensorOf<list<Type> allowedTypes>
    : TensorRankOf<allowedTypes, [0]>;
class 1DTensorOf<list<Type> allowedTypes>
    : TensorRankOf<allowedTypes, [1]>;
class 2DTensorOf<list<Type> allowedTypes>
    : TensorRankOf<allowedTypes, [2]>;
class 3DTensorOf<list<Type> allowedTypes>
    : TensorRankOf<allowedTypes, [3]>;
class 4DTensorOf<list<Type> allowedTypes>
    : TensorRankOf<allowedTypes, [4]>;

// Ranked tensor with a static shape whose element type is one of
// `allowedTypes`.
class StaticShapeTensorOf<list<Type> allowedTypes>
    : RankedTensorOf<allowedTypes,
                     [HasStaticShapePred],
                     "statically shaped tensor">;

// Statically shaped ranked tensor with any element type.
def AnyStaticShapeTensor
    : StaticShapeTensorOf<[AnyType]>;

//===----------------------------------------------------------------------===//
// Memref type.

// Unranked memref whose element type is one of `allowedTypes`.
class UnrankedMemRefOf<list<Type> allowedTypes>
    : ShapedContainerType<allowedTypes, IsUnrankedMemRefTypePred,
                          "unranked.memref", "::mlir::UnrankedMemRefType">;

// Unranked memref with any element type.
def AnyUnrankedMemRef
    : UnrankedMemRefOf<[AnyType]>;

// Ranked memref constraint: combines IsMemRefTypePred with an element type
// taken from the given `allowedTypes` list. The C++ type is
// `::mlir::MemRefType`.
class MemRefOf<list<Type> allowedTypes>
  : ShapedContainerType<allowedTypes,
                        IsMemRefTypePred,
                        "memref",
                        "::mlir::MemRefType">;

// MemRefOf `allowedTypes`, additionally constrained by
// HasRankGreaterOrEqualPred<1>. The summary is the MemRefOf summary prefixed
// with "non-0-ranked.".
class Non0RankedMemRefOf<list<Type> allowedTypes>
  : ConfinedType<MemRefOf<allowedTypes>,
                 [HasRankGreaterOrEqualPred<1>],
                 "non-0-ranked." # MemRefOf<allowedTypes>.summary,
                 "::mlir::MemRefType">;

// Element-type-agnostic instantiations (allowed element type is `AnyType`).
def AnyMemRef
  : MemRefOf<[AnyType]>;
def AnyNon0RankedMemRef
  : Non0RankedMemRefOf<[AnyType]>;

// Memref constraint covering both ranked and unranked memrefs. The container
// predicate is IsBaseMemRefTypePred AND-ed with every entry of `preds`; the
// element type must come from `allowedTypes`.
//   - allowedTypes: permitted element types.
//   - preds:        optional extra predicates (default: none).
//   - summary:      constraint summary (default: "ranked or unranked memref").
class RankedOrUnrankedMemRefOf<
    list<Type> allowedTypes,
    list<Pred> preds = [],
    string summary = "ranked or unranked memref">
  : ShapedContainerType<
        allowedTypes,
        And<!listconcat([IsBaseMemRefTypePred], preds)>,
        summary,
        "::mlir::BaseMemRefType">;

// RankedOrUnrankedMemRefOf with `AnyType` elements, no extra predicates and
// the default summary.
def AnyRankedOrUnrankedMemRef : RankedOrUnrankedMemRefOf<[AnyType]>;

// Matches either AnyUnrankedMemRef or AnyNon0RankedMemRef.
// NOTE(review): unlike AnyNon0RankedOrUnrankedTensor, no explicit summary or
// C++ type is passed to AnyTypeOf here -- confirm whether that is intended.
def AnyNon0RankedOrUnrankedMemRef
  : AnyTypeOf<[AnyUnrankedMemRef, AnyNon0RankedMemRef]>;

// Memref declarations handle any memref, independent of rank, size (static or
// dynamic), layout, or memory space. Each one only fixes the element type.
def I1MemRef : MemRefOf<[I1]>;
def I8MemRef : MemRefOf<[I8]>;
def I16MemRef : MemRefOf<[I16]>;
def I32MemRef : MemRefOf<[I32]>;
def I64MemRef : MemRefOf<[I64]>;

def BF16MemRef : MemRefOf<[BF16]>;
def F16MemRef : MemRefOf<[F16]>;
def F32MemRef : MemRefOf<[F32]>;
def F64MemRef : MemRefOf<[F64]>;

// MemRefOf `allowedTypes` whose rank is one of `ranks`. The summary is the
// "/"-joined list of ranks (each suffixed with "D"), a space, and the MemRefOf
// summary.
// TODO: Have an easy way to add another constraint to a type.
class MemRefRankOf<list<Type> allowedTypes, list<int> ranks>
  : ConfinedType<
        MemRefOf<allowedTypes>,
        [HasAnyRankOfPred<ranks>],
        !interleave(!foreach(rank, ranks, rank # "D"), "/") # " " #
            MemRefOf<allowedTypes>.summary,
        "::mlir::MemRefType">;

// MemRefOf `allowedTypes` that additionally satisfies HasStaticShapePred. The
// summary is the MemRefOf summary prefixed with "statically shaped ".
class StaticShapeMemRefOf<list<Type> allowedTypes>
  : ConfinedType<MemRefOf<allowedTypes>,
                 [HasStaticShapePred],
                 "statically shaped " # MemRefOf<allowedTypes>.summary,
                 "::mlir::MemRefType">;

// StaticShapeMemRefOf with `AnyType` as the allowed element type.
def AnyStaticShapeMemRef : StaticShapeMemRefOf<[AnyType]>;

// For a MemRefType, verify that it has strides.
// The predicate uses `::llvm::cast`, so it must only be evaluated on types
// already known to be a MemRefType (e.g. after a MemRefOf predicate). The
// callee is fully qualified so the generated C++ does not depend on
// argument-dependent lookup or on the namespace it is emitted into.
def HasStridesPred
  : CPred<[{ ::mlir::isStrided(::llvm::cast<::mlir::MemRefType>($_self)) }]>;

// MemRefOf `allowedTypes` that additionally satisfies HasStridesPred. The
// summary is the MemRefOf summary prefixed with "strided "; no C++ type is
// passed explicitly.
class StridedMemRefOf<list<Type> allowedTypes>
  : ConfinedType<MemRefOf<allowedTypes>,
                 [HasStridesPred],
                 "strided " # MemRefOf<allowedTypes>.summary>;

// StridedMemRefOf with `AnyType` as the allowed element type.
def AnyStridedMemRef : StridedMemRefOf<[AnyType]>;

// Type that satisfies both AnyStridedMemRef and
// MemRefRankOf<[AnyType], [rank]>. The summary is the AnyStridedMemRef summary
// followed by " of rank " and the rank.
class AnyStridedMemRefOfRank<int rank>
  : AllOfType<[AnyStridedMemRef, MemRefRankOf<[AnyType], [rank]>],
              AnyStridedMemRef.summary # " of rank " # rank>;

// MemRefOf `allowedTypes` whose rank is one of `ranks`. The summary is the
// "/"-joined list of ranks (each suffixed with "D"), a space, and the MemRefOf
// summary.
// NOTE(review): despite the name, only HasAnyRankOfPred is applied here;
// HasStridesPred is not part of this constraint -- confirm this is intended.
class StridedMemRefRankOf<list<Type> allowedTypes, list<int> ranks>
  : ConfinedType<
        MemRefOf<allowedTypes>,
        [HasAnyRankOfPred<ranks>],
        !interleave(!foreach(rank, ranks, rank # "D"), "/") # " " #
            MemRefOf<allowedTypes>.summary>;

// Generic tuple: only IsTupleTypePred is checked, so no constraint is placed
// on the element types. The C++ type is `::mlir::TupleType`.
def AnyTuple
  : Type<IsTupleTypePred, "tuple", "::mlir::TupleType">;

// Container constraint whose elements may have a mix of types (unlike
// ContainerType). `elementTypesCall` is a C++ expression producing the list of
// all element types. The resulting predicate requires `containerPred` to hold
// and every produced type to be non-null and to satisfy `etype`'s predicate.
// The summary is `descr` followed by " with any combination of ", `etype`'s
// summary and " values".
class MixedContainerType<Type etype, Pred containerPred, code elementTypesCall,
                         string descr>
  : Type<And<[containerPred,
              Concat<"::llvm::all_of(" # elementTypesCall #
                         ", [](::mlir::Type t) { return t && (",
                     // Re-target the element predicate from `$_self` to the
                     // lambda parameter `t`.
                     SubstLeaves<"$_self", "t", etype.predicate>,
                     "); })">]>,
         descr # " with any combination of " # etype.summary # " values"> {
  // Constraint that each element type has to satisfy.
  Type elementType = etype;

  // C++ expression used to retrieve the element types.
  code getElementTypesCall = elementTypesCall;
}

// Tuple whose elements may be any mix of the allowed types. Element types are
// obtained with `getTypes()` on the TupleType.
class TupleOf<list<Type> allowedTypes> :
    MixedContainerType<
        AnyTypeOf<allowedTypes>,
        IsTupleTypePred,
        "::llvm::cast<::mlir::TupleType>($_self).getTypes()",
        "tuple">;

// A Tuple with arbitrary nesting, where all elements are a mix of the allowed
// types. The element types are obtained by calling `getFlattenedTypes` on the
// TupleType; the call is fully qualified so the generated C++ does not depend
// on argument-dependent lookup or on the namespace it is emitted into.
class NestedTupleOf<list<Type> allowedTypes> :
    MixedContainerType<AnyTypeOf<allowedTypes>, IsTupleTypePred,
                       "::mlir::getFlattenedTypes("
                       "::llvm::cast<::mlir::TupleType>($_self))",
                       "nested tuple">;

//===----------------------------------------------------------------------===//
// Common type constraints
//===----------------------------------------------------------------------===//
// Type constraint for types that are "like" some type or set of types T: the
// type either satisfies T's predicate directly or satisfies
// ValueSemanticsContainerOf<[T]> (i.e. a value-semantics container of Ts, such
// as a vector of Ts or a tensor of Ts).
class TypeOrContainer<Type allowedType, string name>
  : TypeConstraint<
        Or<[allowedType.predicate,
            ValueSemanticsContainerOf<[allowedType]>.predicate]>,
        name>;

// Type constraint for types that are "like" some type or set of types T: the
// type is either a T or a mappable container of Ts, expressed as
// ValueSemanticsContainerOf<[T]>.
class TypeOrValueSemanticsContainer<Type allowedType, string name> :
    TypeConstraint<
        Or<[allowedType.predicate,
            ValueSemanticsContainerOf<[allowedType]>.predicate]>,
        name>;

// Temporary constraint to allow gradual transition to supporting 0-D vectors:
// accepts `allowedType` itself, a VectorOfAnyRankOf it, or a TensorOf it.
// TODO: Remove this when all ops support 0-D vectors.
class TypeOrContainerOfAnyRank<Type allowedType, string name>
  : TypeConstraint<
        Or<[allowedType.predicate,
            VectorOfAnyRankOf<[allowedType]>.predicate,
            TensorOf<[allowedType]>.predicate]>,
        name>;


// Type constraint for bool-like types: I1 itself or a container of I1 as
// accepted by TypeOrContainer (e.g. vectors of bools, tensors of bools).
def BoolLike
  : TypeOrContainer<I1, "bool-like">;

// Same summary as BoolLike, but built on TypeOrContainerOfAnyRank.
def BoolLikeOfAnyRank
  : TypeOrContainerOfAnyRank<I1, "bool-like">;

// Type constraint for signless-integer-like types: signless integers, indices,
// and value-semantics containers (e.g. vectors, tensors) of signless integers
// or indices.
def SignlessIntegerLike
  : TypeOrValueSemanticsContainer<AnySignlessIntegerOrIndex,
                                  "signless-integer-like">;

// Same element constraint and summary, but built on TypeOrContainerOfAnyRank.
def SignlessIntegerLikeOfAnyRank
  : TypeOrContainerOfAnyRank<AnySignlessIntegerOrIndex,
                             "signless-integer-like">;

// Type constraint for float-like types: AnyFloat itself or a container of
// AnyFloat as accepted by TypeOrContainer (e.g. vectors or tensors thereof).
def FloatLike
  : TypeOrContainer<AnyFloat, "floating-point-like">;

// Type constraint satisfied by anything that matches SignlessIntegerLike or
// FloatLike.
def SignlessIntegerOrFloatLike
  : TypeConstraint<
        Or<[SignlessIntegerLike.predicate, FloatLike.predicate]>,
        "signless-integer-like or floating-point-like">;

#endif // COMMON_TYPE_CONSTRAINTS_TD