#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
usingnamespacemlir;
usingnamespacemlir::memref;
//===----------------------------------------------------------------------===//
// MemRefDialect and shared memref utilities.
// (Function bodies are elided in this view; comments describe only what the
// signatures establish.)
//===----------------------------------------------------------------------===//

/// Dialect constant-materialization hook: builds an op producing `value` of
/// `type` at `loc`, or returns null on failure (standard hook contract —
/// body elided here).
Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) { … }
/// Folds memref.cast producers feeding `op` where legal; returns success if
/// anything was folded. (Body elided.)
LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) { … }
/// Maps a memref type to the corresponding tensor type. (Body elided.)
Type mlir::memref::getTensorTypeFromMemRefType(Type type) { … }
/// Returns dimension `dim` of `value` as an OpFoldResult (attribute if
/// static, Value otherwise — presumably; confirm against the header).
OpFoldResult memref::getMixedSize(OpBuilder &builder, Location loc, Value value,
int64_t dim) { … }
/// Returns all dimensions of `value` as OpFoldResults. (Body elided.)
SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
Location loc, Value value) { … }
/// Rewrites entries of `values` in place using static information queried
/// from `memRefTy` via `getAttributes`, with `isDynamic` deciding which
/// queried values are placeholders. (Body elided in this view.)
static void constifyIndexValues(
SmallVectorImpl<OpFoldResult> &values, MemRefType memRefTy,
MLIRContext *ctxt,
llvm::function_ref<SmallVector<int64_t>(MemRefType)> getAttributes,
llvm::function_ref<bool(int64_t)> isDynamic) { … }
/// Extracts the static sizes of `memRefTy`. (Body elided.)
static SmallVector<int64_t> getConstantSizes(MemRefType memRefTy) { … }
/// Extracts the static offset of `memrefType`. (Body elided.)
static SmallVector<int64_t> getConstantOffset(MemRefType memrefType) { … }
/// Extracts the static strides of `memrefType`. (Body elided.)
static SmallVector<int64_t> getConstantStrides(MemRefType memrefType) { … }
//===----------------------------------------------------------------------===//
// AllocOp / AllocaOp
//===----------------------------------------------------------------------===//

/// Suggests SSA result names for memref.alloc in the printed IR.
void AllocOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) { … }
/// Suggests SSA result names for memref.alloca in the printed IR.
void AllocaOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) { … }
/// Shared verifier for alloc-like ops (used by both verifiers below).
template <typename AllocLikeOp>
static LogicalResult verifyAllocLikeOp(AllocLikeOp op) { … }
LogicalResult AllocOp::verify() { … }
LogicalResult AllocaOp::verify() { … }
namespace {
/// Canonicalization pattern over alloc-like ops; exact rewrite elided.
template <typename AllocLikeOp>
struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> { … };
/// Removes allocations whose results are unused (per the struct name;
/// body elided — confirm in full source).
template <typename T>
struct SimplifyDeadAlloc : public OpRewritePattern<T> { … };
} // namespace
void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) { … }
void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) { … }
//===----------------------------------------------------------------------===//
// ReallocOp
//===----------------------------------------------------------------------===//

LogicalResult ReallocOp::verify() { … }
void ReallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) { … }
//===----------------------------------------------------------------------===//
// AllocaScopeOp
//===----------------------------------------------------------------------===//

/// Custom assembly printer / parser pair for memref.alloca_scope.
void AllocaScopeOp::print(OpAsmPrinter &p) { … }
ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) { … }
void AllocaScopeOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { … }
/// Predicate helpers consulted by the alloca-scope canonicalization
/// patterns below. (Bodies elided in this view.)
static bool isGuaranteedAutomaticAllocation(Operation *op) { … }
static bool isOpItselfPotentialAutomaticAllocation(Operation *op) { … }
static bool lastNonTerminatorInRegion(Operation *op) { … }
/// Inlines an alloca_scope into its parent (per the struct name; body
/// elided — confirm in full source).
struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> { … };
/// Hoists ops out of an alloca_scope (per the struct name; body elided).
struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> { … };
void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) { … }
//===----------------------------------------------------------------------===//
// AssumeAlignmentOp / CastOp / CopyOp / DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult AssumeAlignmentOp::verify() { … }
void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) { … }
/// Returns true if `castOp` can be folded into its consumer. (Body elided.)
bool CastOp::canFoldIntoConsumerOp(CastOp castOp) { … }
/// CastOpInterface hook: whether `inputs` -> `outputs` is a legal cast.
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { … }
OpFoldResult CastOp::fold(FoldAdaptor adaptor) { … }
namespace {
/// Copy canonicalization patterns; rewrites elided in this view.
struct FoldCopyOfCast : public OpRewritePattern<CopyOp> { … };
struct FoldSelfCopy : public OpRewritePattern<CopyOp> { … };
struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> { … };
} // namespace
void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) { … }
LogicalResult CopyOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) { … }
LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) { … }
//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) { … }
/// Convenience builder taking the dimension as a constant `index`.
void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
int64_t index) { … }
/// Returns the dimension index if it is a constant, std::nullopt otherwise
/// (per the return type; body elided).
std::optional<int64_t> DimOp::getConstantIndex() { … }
Speculation::Speculatability DimOp::getSpeculatability() { … }
/// Counts occurrences of each value in `vals`. (Body elided.)
static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) { … }
/// Computes which dimensions of `originalType` were dropped to obtain
/// `reducedType`; failure if no consistent mask exists (per the FailureOr
/// return; body elided).
static FailureOr<llvm::SmallBitVector>
computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
ArrayRef<OpFoldResult> sizes) { … }
// NOTE(review): this SubViewOp method lives here, next to its helper above,
// rather than in the SubViewOp section further down.
llvm::SmallBitVector SubViewOp::getDroppedDims() { … }
OpFoldResult DimOp::fold(FoldAdaptor adaptor) { … }
namespace {
struct DimOfMemRefReshape : public OpRewritePattern<DimOp> { … };
} // namespace
void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) { … }
//===----------------------------------------------------------------------===//
// DmaStartOp / DmaWaitOp
//===----------------------------------------------------------------------===//

void DmaStartOp::build(OpBuilder &builder, OperationState &result,
Value srcMemRef, ValueRange srcIndices, Value destMemRef,
ValueRange destIndices, Value numElements,
Value tagMemRef, ValueRange tagIndices, Value stride,
Value elementsPerStride) { … }
void DmaStartOp::print(OpAsmPrinter &p) { … }
ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) { … }
LogicalResult DmaStartOp::verify() { … }
LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) { … }
LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) { … }
LogicalResult DmaWaitOp::verify() { … }
//===----------------------------------------------------------------------===//
// ExtractAlignedPointerAsIndexOp / ExtractStridedMetadataOp
//===----------------------------------------------------------------------===//

void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) { … }
/// InferTypeOpInterface hook: computes the result types from the adaptor.
LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location,
ExtractStridedMetadataOp::Adaptor adaptor,
SmallVectorImpl<Type> &inferredReturnTypes) { … }
void ExtractStridedMetadataOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) { … }
/// Replaces uses of `values` with constants from `maybeConstants` where
/// available; returns whether anything changed (per the bool return; body
/// elided).
template <typename Container>
static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
Container values,
ArrayRef<OpFoldResult> maybeConstants) { … }
LogicalResult
ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) { … }
/// Accessors returning sizes/strides/offset with static information folded
/// into attributes where possible. (Bodies elided.)
SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() { … }
SmallVector<OpFoldResult>
ExtractStridedMetadataOp::getConstifiedMixedStrides() { … }
OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() { … }
//===----------------------------------------------------------------------===//
// GenericAtomicRMWOp / AtomicYieldOp / GlobalOp / GetGlobalOp
//===----------------------------------------------------------------------===//

void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
Value memref, ValueRange ivs) { … }
LogicalResult GenericAtomicRMWOp::verify() { … }
ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
OperationState &result) { … }
void GenericAtomicRMWOp::print(OpAsmPrinter &p) { … }
LogicalResult AtomicYieldOp::verify() { … }
/// Custom print/parse for the global op's `type = initial-value` clause.
static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
TypeAttr type,
Attribute initialValue) { … }
static ParseResult
parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
Attribute &initialValue) { … }
LogicalResult GlobalOp::verify() { … }
ElementsAttr GlobalOp::getConstantInitValue() { … }
/// Verifies the symbol reference against the symbol table (standard
/// SymbolUserOpInterface hook; body elided).
LogicalResult
GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) { … }
//===----------------------------------------------------------------------===//
// LoadOp / MemorySpaceCastOp / PrefetchOp / RankOp
//===----------------------------------------------------------------------===//

LogicalResult LoadOp::verify() { … }
OpFoldResult LoadOp::fold(FoldAdaptor adaptor) { … }
void MemorySpaceCastOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) { … }
bool MemorySpaceCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { … }
OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) { … }
void PrefetchOp::print(OpAsmPrinter &p) { … }
ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) { … }
LogicalResult PrefetchOp::verify() { … }
LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) { … }
OpFoldResult RankOp::fold(FoldAdaptor adaptor) { … }
//===----------------------------------------------------------------------===//
// ReinterpretCastOp
//===----------------------------------------------------------------------===//

void ReinterpretCastOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) { … }
/// Builder overloads: mixed (OpFoldResult), fully static (int64_t), and
/// fully dynamic (Value/ValueRange) offset/size/stride forms, with and
/// without an explicit result type.
void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
MemRefType resultType, Value source,
OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides,
ArrayRef<NamedAttribute> attrs) { … }
void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
Value source, OpFoldResult offset,
ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides,
ArrayRef<NamedAttribute> attrs) { … }
void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
MemRefType resultType, Value source,
int64_t offset, ArrayRef<int64_t> sizes,
ArrayRef<int64_t> strides,
ArrayRef<NamedAttribute> attrs) { … }
void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
MemRefType resultType, Value source, Value offset,
ValueRange sizes, ValueRange strides,
ArrayRef<NamedAttribute> attrs) { … }
LogicalResult ReinterpretCastOp::verify() { … }
// NOTE(review): the FoldAdaptor parameter is unnamed here (unused in the
// elided body, presumably).
OpFoldResult ReinterpretCastOp::fold(FoldAdaptor ) { … }
SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() { … }
SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() { … }
OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() { … }
namespace {
struct ReinterpretCastOpExtractStridedMetadataFolder
: public OpRewritePattern<ReinterpretCastOp> { … };
} // namespace
void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) { … }
//===----------------------------------------------------------------------===//
// ExpandShapeOp / CollapseShapeOp / ReshapeOp / StoreOp
//===----------------------------------------------------------------------===//

void CollapseShapeOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) { … }
void ExpandShapeOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) { … }
/// ReifyRankedShapedTypeOpInterface hook. (Body elided.)
LogicalResult ExpandShapeOp::reifyResultShapes(
OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) { … }
/// Shared verifier for the collapsed/expanded shape relationship under a
/// reassociation. (Body elided.)
static LogicalResult
verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape,
ArrayRef<int64_t> expandedShape,
ArrayRef<ReassociationIndices> reassociation,
bool allowMultipleDynamicDimsPerGroup) { … }
SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() { … }
SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() { … }
SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() { … }
SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() { … }
/// Computes the strided layout of the expanded type; failure if the source
/// layout cannot be expanded (per the FailureOr return; body elided).
static FailureOr<StridedLayoutAttr>
computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
ArrayRef<ReassociationIndices> reassociation) { … }
FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
MemRefType srcType, ArrayRef<int64_t> resultShape,
ArrayRef<ReassociationIndices> reassociation) { … }
FailureOr<SmallVector<OpFoldResult>>
ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
MemRefType expandedType,
ArrayRef<ReassociationIndices> reassociation,
ArrayRef<OpFoldResult> inputShape) { … }
/// Builder overloads: with/without an explicit output shape, and taking
/// either a result type or a static result shape.
void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
Type resultType, Value src,
ArrayRef<ReassociationIndices> reassociation,
ArrayRef<OpFoldResult> outputShape) { … }
void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
Type resultType, Value src,
ArrayRef<ReassociationIndices> reassociation) { … }
void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
ArrayRef<int64_t> resultShape, Value src,
ArrayRef<ReassociationIndices> reassociation) { … }
void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
ArrayRef<int64_t> resultShape, Value src,
ArrayRef<ReassociationIndices> reassociation,
ArrayRef<OpFoldResult> outputShape) { … }
LogicalResult ExpandShapeOp::verify() { … }
void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) { … }
/// Computes the strided layout of the collapsed type; `strict` tightens the
/// check in a way elided from this view — confirm in full source.
static FailureOr<StridedLayoutAttr>
computeCollapsedLayoutMap(MemRefType srcType,
ArrayRef<ReassociationIndices> reassociation,
bool strict = false) { … }
bool CollapseShapeOp::isGuaranteedCollapsible(
MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) { … }
MemRefType CollapseShapeOp::computeCollapsedType(
MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) { … }
void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
ArrayRef<ReassociationIndices> reassociation,
ArrayRef<NamedAttribute> attrs) { … }
LogicalResult CollapseShapeOp::verify() { … }
struct CollapseShapeOpMemRefCastFolder
: public OpRewritePattern<CollapseShapeOp> { … };
void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) { … }
OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) { … }
OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) { … }
void ReshapeOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) { … }
LogicalResult ReshapeOp::verify() { … }
LogicalResult StoreOp::verify() { … }
LogicalResult StoreOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) { … }
//===----------------------------------------------------------------------===//
// SubViewOp
//===----------------------------------------------------------------------===//

void SubViewOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) { … }
/// Result-type inference from static or mixed offsets/sizes/strides.
Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
ArrayRef<int64_t> staticOffsets,
ArrayRef<int64_t> staticSizes,
ArrayRef<int64_t> staticStrides) { … }
Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides) { … }
// NOTE(review): the parameter is named `sourceRankedTensorType` but is typed
// MemRefType — likely a copy-paste from the tensor dialect; harmless but
// worth renaming in a follow-up.
Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
MemRefType sourceRankedTensorType,
ArrayRef<int64_t> offsets,
ArrayRef<int64_t> sizes,
ArrayRef<int64_t> strides) { … }
Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
MemRefType sourceRankedTensorType,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides) { … }
/// Builder overloads: mixed, static, and dynamic offset/size/stride forms,
/// with and without an explicit result type.
void SubViewOp::build(OpBuilder &b, OperationState &result,
MemRefType resultType, Value source,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides,
ArrayRef<NamedAttribute> attrs) { … }
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides,
ArrayRef<NamedAttribute> attrs) { … }
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
ArrayRef<int64_t> strides,
ArrayRef<NamedAttribute> attrs) { … }
void SubViewOp::build(OpBuilder &b, OperationState &result,
MemRefType resultType, Value source,
ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
ArrayRef<int64_t> strides,
ArrayRef<NamedAttribute> attrs) { … }
void SubViewOp::build(OpBuilder &b, OperationState &result,
MemRefType resultType, Value source, ValueRange offsets,
ValueRange sizes, ValueRange strides,
ArrayRef<NamedAttribute> attrs) { … }
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
ValueRange offsets, ValueRange sizes, ValueRange strides,
ArrayRef<NamedAttribute> attrs) { … }
Value SubViewOp::getViewSource() { … }
/// Layout-compatibility predicates used by the verifier. (Bodies elided.)
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) { … }
static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
const llvm::SmallBitVector &droppedDims) { … }
static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
Operation *op, Type expectedType) { … }
LogicalResult SubViewOp::verify() { … }
raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) { … }
SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
OpBuilder &b, Location loc) { … }
static MemRefType getCanonicalSubViewResultType(
MemRefType currentResultType, MemRefType currentSourceType,
MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) { … }
Value mlir::memref::createCanonicalRankReducingSubViewOp(
OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) { … }
FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
Value value,
ArrayRef<int64_t> desiredShape) { … }
static bool isTrivialSubViewOp(SubViewOp subViewOp) { … }
namespace {
/// SubView canonicalization patterns; rewrites elided in this view.
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> { … };
class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> { … };
} // namespace
struct SubViewReturnTypeCanonicalizer { … };
struct SubViewCanonicalizer { … };
void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) { … }
OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) { … }
//===----------------------------------------------------------------------===//
// TransposeOp / ViewOp / AtomicRMWOp
//===----------------------------------------------------------------------===//

void TransposeOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) { … }
/// Infers the result type of a transpose under `permutationMap`. (Body
/// elided.)
static MemRefType inferTransposeResultType(MemRefType memRefType,
AffineMap permutationMap) { … }
void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
AffineMapAttr permutation,
ArrayRef<NamedAttribute> attrs) { … }
void TransposeOp::print(OpAsmPrinter &p) { … }
ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) { … }
LogicalResult TransposeOp::verify() { … }
OpFoldResult TransposeOp::fold(FoldAdaptor) { … }
void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) { … }
LogicalResult ViewOp::verify() { … }
Value ViewOp::getViewSource() { … }
namespace {
/// View canonicalization patterns; rewrites elided in this view.
struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> { … };
struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> { … };
} // namespace
void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) { … }
LogicalResult AtomicRMWOp::verify() { … }
OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) { … }
#define GET_OP_CLASSES
#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"