
//===- BuiltinTypes.h - MLIR Builtin Type Classes ---------------*- C++ -*-===//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception


#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/Support/ADTExtras.h"
#include "llvm/ADT/DenseSet.h"

#include <optional>

// Forward declarations of LLVM types, so this header can refer to them by name
// without including their defining headers.
// NOTE(review): neither name is used in the visible part of this file;
// presumably the TableGen-generated declarations included below reference
// them — confirm before removing.
namespace llvm {
class BitVector;
struct fltSemantics;
} // namespace llvm

// Tablegen Interface Declarations

namespace mlir {
// Forward declarations so that the declarations in this header (and,
// presumably, the TableGen-generated ones included below) can name these
// classes without their full definitions being visible.
// `FloatType` is also defined further down in this namespace; declaring it
// here as well is redundant but harmless.
class AffineExpr;
class AffineMap;
class FloatType;
class IndexType;
class IntegerType;
class MemRefType;
class RankedTensorType;
class StringAttr;
class TypeRange;

// Storage classes for parameterized builtin types. Only their names are
// needed here; their definitions presumably live in the type-storage
// implementation files — confirm.
namespace detail {
struct FunctionTypeStorage;
struct IntegerTypeStorage;
struct TupleTypeStorage;
} // namespace detail

/// Marker trait for types that have value semantics.
///
/// It declares no members of its own: it only instantiates
/// `TypeTrait::TraitBase` with the concrete type and with itself, so attaching
/// it to a type is purely a tag.
template <class ConcreteType>
class ValueSemantics : public TypeTrait::TraitBase<ConcreteType,
                                                   ValueSemantics> {};

// FloatType

/// Handle class, derived from `Type`, for the builtin floating-point types.
/// The `FloatType::get*` factory methods defined at the end of this file
/// return values of this class.
/// NOTE(review): the class body is empty in this view, yet `classof` and the
/// `get*` factories are defined out of line below; members must be declared
/// in the class before they can be defined outside it — confirm the
/// declarations.
class FloatType : public Type {};

// TensorType

/// Tensor types represent multi-dimensional arrays, and have two variants:
/// RankedTensorType and UnrankedTensorType.
/// Note: This class attaches the ShapedType trait to act as a mixin to
///       provide many useful utility functions. This inheritance has no effect
///       on derived tensor types.
/// NOTE(review): the class body is empty in this view, yet
/// `TensorType::classof` is defined out of line later in this file; it must
/// be declared in the class first — confirm the declaration.
class TensorType : public Type, public ShapedType::Trait<TensorType> {};

// BaseMemRefType

/// This class provides a shared interface for ranked and unranked memref types.
/// Note: This class attaches the ShapedType trait to act as a mixin to
///       provide many useful utility functions. This inheritance has no effect
///       on derived memref types.
/// NOTE(review): the class body is empty in this view, yet `classof` and
/// `isValidElementType` are defined out of line later in this file; they must
/// be declared in the class first — confirm the declarations.
class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {};

} // namespace mlir

// Tablegen Type Declarations

#include "mlir/IR/BuiltinTypes.h.inc"

namespace mlir {
#include "mlir/IR/BuiltinTypeConstraints.h.inc"

// MemRefType

/// Builder nested in `MemRefType`. This is a builder type that keeps local
/// references to arguments. Arguments that are passed into the builder must
/// outlive the builder; in particular, do not pass temporaries whose lifetime
/// ends before the builder is used, or the stored references will dangle.
class MemRefType::Builder {};

// RankedTensorType

/// Builder nested in `RankedTensorType`. This is a builder type that keeps
/// local references to arguments. Arguments that are passed into the builder
/// must outlive the builder; in particular, do not pass temporaries whose
/// lifetime ends before the builder is used, or the stored references will
/// dangle.
class RankedTensorType::Builder {};

// VectorType

/// Builder nested in `VectorType`. This is a builder type that keeps local
/// references to arguments. Arguments that are passed into the builder must
/// outlive the builder; in particular, do not pass temporaries whose lifetime
/// ends before the builder is used, or the stored references will dangle.
class VectorType::Builder {};

/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
/// `originalShape` with some `1` entries erased, return the set of indices
/// that specifies which of the entries of `originalShape` are dropped to obtain
/// `reducedShape`. The returned mask can be applied as a projection to
/// `originalShape` to obtain the `reducedShape`. This mask is useful to track
/// which dimensions must be kept when e.g. compute MemRef strides under
/// rank-reducing operations. Return std::nullopt if reducedShape cannot be
/// obtained by dropping only `1` entries in `originalShape`.
/// If `matchDynamic` is true, then dynamic dims in `originalShape` and
/// `reducedShape` will be considered matching with non-dynamic dims, unless
/// the non-dynamic dim is from `originalShape` and equal to 1. For example,
/// in ([1, 3, ?], [?, 5]), the mask would be {1, 0, 0}, since 3 and 5 will
/// match with the corresponding dynamic dims.
computeRankReductionMask(ArrayRef<int64_t> originalShape,
                         ArrayRef<int64_t> reducedShape,
                         bool matchDynamic = false);

/// Enum that captures information related to verifier error conditions on
/// slice insert/extract type of ops.
/// NOTE(review): the enumerator list is empty in this view, but
/// `isRankReducedType` is documented to return
/// `SliceVerificationResult::Success` on success — confirm the enumerators.
enum class SliceVerificationResult {};

/// Check whether `originalType` can be rank-reduced to the
/// `candidateReducedType` type by dropping some dimensions with static size
/// `1`.
/// @param originalType         the shaped type before rank reduction.
/// @param candidateReducedType the shaped type to check against.
/// @return `SliceVerificationResult::Success` on success or an appropriate
///         error code.
SliceVerificationResult isRankReducedType(ShapedType originalType,
                                          ShapedType candidateReducedType);

// Deferred Method Definitions

// NOTE(review): both bodies below are empty in this view. These functions
// return `bool`, so calling either one would flow off the end of a non-void
// function (undefined behavior) — confirm the real definitions.

/// LLVM-style RTTI hook queried by `isa`/`cast`/`dyn_cast`; presumably true
/// for the ranked and unranked memref types — confirm.
inline bool BaseMemRefType::classof(Type type) {}

/// Presumably reports whether `type` may be used as a memref element type —
/// confirm against the real definition.
inline bool BaseMemRefType::isValidElementType(Type type) {}

// NOTE(review): every body in this group is empty in this view. All of these
// functions are non-void, so calling one would flow off the end of the
// function (undefined behavior) — confirm the real definitions.

/// LLVM-style RTTI hook queried by `isa`/`cast`/`dyn_cast` to decide whether
/// `type` is a builtin floating-point type (presumably one of the formats
/// produced by the factories below — confirm).
inline bool FloatType::classof(Type type) {}

// Factory helpers, one per builtin floating-point format named by the method
// (e.g. `getF32`, `getBF16`, `getFloat8E4M3FN`). Each takes an `MLIRContext`
// and returns the requested format as a `FloatType`; presumably each forwards
// to the corresponding concrete type's `get(ctx)` — confirm.
inline FloatType FloatType::getFloat4E2M1FN(MLIRContext *ctx) {}

inline FloatType FloatType::getFloat6E2M3FN(MLIRContext *ctx) {}

inline FloatType FloatType::getFloat6E3M2FN(MLIRContext *ctx) {}

inline FloatType FloatType::getFloat8E5M2(MLIRContext *ctx) {}

inline FloatType FloatType::getFloat8E4M3(MLIRContext *ctx) {}

inline FloatType FloatType::getFloat8E4M3FN(MLIRContext *ctx) {}

inline FloatType FloatType::getFloat8E5M2FNUZ(MLIRContext *ctx) {}

inline FloatType FloatType::getFloat8E4M3FNUZ(MLIRContext *ctx) {}

inline FloatType FloatType::getFloat8E4M3B11FNUZ(MLIRContext *ctx) {}

inline FloatType FloatType::getFloat8E3M4(MLIRContext *ctx) {}

inline FloatType FloatType::getBF16(MLIRContext *ctx) {}

inline FloatType FloatType::getF16(MLIRContext *ctx) {}

inline FloatType FloatType::getTF32(MLIRContext *ctx) {}

inline FloatType FloatType::getF32(MLIRContext *ctx) {}

inline FloatType FloatType::getF64(MLIRContext *ctx) {}

inline FloatType FloatType::getF80(MLIRContext *ctx) {}

inline FloatType FloatType::getF128(MLIRContext *ctx) {}

/// LLVM-style RTTI hook queried by `isa`/`cast`/`dyn_cast`; presumably true for
/// RankedTensorType and UnrankedTensorType, the two variants named in the
/// TensorType class comment — confirm.
/// NOTE(review): the body is empty in this view; flowing off the end of this
/// `bool` function is undefined behavior if it is called.
inline bool TensorType::classof(Type type) {}

// Type Utilities

/// Returns the strides of the MemRef if the layout map is in strided form.
/// MemRefs with a layout map in strided form include:
///   1. empty or identity layout map, in which case the stride information is
///      the canonical form computed from sizes;
///   2. a StridedLayoutAttr layout;
///   3. any other layout that can be converted into a single affine map layout
///      of the form `K + k0 * d0 + ... + kn * dn`, where K and ki's are
///      constants or symbols.
/// A stride specification is a list of integer values that are either static
/// or dynamic (encoded with ShapedType::kDynamic). Strides encode
/// the distance in the number of elements between successive entries along a
/// particular dimension.
/// The strides and the offset are written to the `strides` and `offset`
/// out-parameters; the returned LogicalResult indicates whether `t`'s layout
/// is in strided form.
LogicalResult getStridesAndOffset(MemRefType t,
                                  SmallVectorImpl<int64_t> &strides,
                                  int64_t &offset);

/// Convenience wrapper around
/// getStridesAndOffset(MemRefType, SmallVectorImpl<int64_t> &, int64_t &) that
/// returns the strides and the offset as a pair and asserts that the
/// underlying call succeeded. Only call it when `t` is known to have a
/// strided layout.
std::pair<SmallVector<int64_t>, int64_t> getStridesAndOffset(MemRefType t);

/// Return a version of `t` with identity layout if it can be determined
/// statically that the layout is the canonical contiguous strided layout.
/// Otherwise, pass `t`'s layout into `simplifyAffineMap` and return a copy of
/// `t` with the simplified layout.
MemRefType canonicalizeStridedLayout(MemRefType t);

/// Given MemRef `sizes` that are either static or dynamic, returns the
/// canonical "contiguous" strides AffineExpr. Strides are multiplicative and
/// once a dynamic dimension is encountered, all canonical strides become
/// dynamic and need to be encoded with a different symbol.
/// For canonical strides expressions, the offset is always 0 and the fastest
/// varying stride is always `1`.
/// Examples:
///   - memref<3x4x5xf32> has canonical stride expression
///         `20*exprs[0] + 5*exprs[1] + exprs[2]`.
///   - memref<3x?x5xf32> has canonical stride expression
///         `s0*exprs[0] + 5*exprs[1] + exprs[2]`.
///   - memref<3x4x?xf32> has canonical stride expression
///         `s1*exprs[0] + s0*exprs[1] + exprs[2]`.
/// As the examples show, `exprs` presumably needs one entry per entry of
/// `sizes` — confirm against the implementation.
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                          ArrayRef<AffineExpr> exprs,
                                          MLIRContext *context);

/// Return the result of makeCanonicalStridedLayoutExpr for the common case
/// where `exprs` is {d0, d1, .., d_(sizes.size()-1)}.
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                          MLIRContext *context);

/// Return "true" if the layout for `t` is compatible with strided semantics,
/// i.e. if it can be described by strides and an offset. See
/// getStridesAndOffset for the layouts that qualify.
bool isStrided(MemRefType t);

/// Return "true" if the last dimension of the given type has a static unit
/// stride. Also return "true" for types with no strides.
/// NOTE(review): "types with no strides" presumably means rank-0 memrefs —
/// confirm against the implementation.
bool isLastMemrefDimUnitStride(MemRefType type);

/// Return "true" if the last N dimensions of the given type are contiguous.
/// @param type the memref type to inspect.
/// @param n    the number of trailing dimensions to consider.
/// Examples:
///   - memref<5x4x3x2xi8, strided<[24, 6, 2, 1]>> is contiguous when
///   considering both _all_ and _only_ the trailing 3 dims,
///   - memref<5x4x3x2xi8, strided<[48, 6, 2, 1]>> is _only_ contiguous when
///   considering the trailing 3 dims.
bool trailingNDimsContiguous(MemRefType type, int64_t n);

} // namespace mlir