//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallBitVector.h" usingnamespacemlir; usingnamespacemlir::memref; /// Materialize a single constant operation from a given attribute value with /// the desired resultant type. Operation *MemRefDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { … } //===----------------------------------------------------------------------===// // Common canonicalization pattern support logic //===----------------------------------------------------------------------===// /// This is a common class used for patterns of the form /// "someop(memrefcast) -> someop". It folds the source of any memref.cast /// into the root operation directly. LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) { … } /// Return an unranked/ranked tensor type for the given unranked/ranked memref /// type. 
Type mlir::memref::getTensorTypeFromMemRefType(Type type) { … } OpFoldResult memref::getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim) { … } SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder, Location loc, Value value) { … } //===----------------------------------------------------------------------===// // Utility functions for propagating static information //===----------------------------------------------------------------------===// /// Helper function that infers the constant values from a list of \p values, /// a \p memRefTy, and another helper function \p getAttributes. /// The inferred constant values replace the related `OpFoldResult` in /// \p values. /// /// \note This function shouldn't be used directly, instead, use the /// `getConstifiedMixedXXX` methods from the related operations. /// /// \p getAttributes returns a list of potentially constant values, as determined /// by \p isDynamic, from the given \p memRefTy. The returned list must have as /// many elements as \p values or be empty. /// /// E.g., consider the following example: /// ``` /// memref.reinterpret_cast %base to <...> strides: [2, %dyn_stride] : /// memref<f32> to memref<?x?xf32, strided<[?, 1], offset: ?>> /// ``` /// `ReinterpretCastOp::getMixedStrides()` will return `[2, %dyn_stride]`. 
/// Now using this helper function with: /// - `values == [2, %dyn_stride]`, /// - `memRefTy == memref<?x?xf32, strided<[?, 1], offset: ?>>` /// - `getAttributes == getConstantStrides` (i.e., a wrapper around /// `getStridesAndOffset`), and /// - `isDynamic == ShapedType::isDynamic` /// Will yield: `values == [2, 1]` static void constifyIndexValues( SmallVectorImpl<OpFoldResult> &values, MemRefType memRefTy, MLIRContext *ctxt, llvm::function_ref<SmallVector<int64_t>(MemRefType)> getAttributes, llvm::function_ref<bool(int64_t)> isDynamic) { … } /// Wrapper around `getShape` that conforms to the function signature /// expected for `getAttributes` in `constifyIndexValues`. static SmallVector<int64_t> getConstantSizes(MemRefType memRefTy) { … } /// Wrapper around `getStridesAndOffset` that returns only the offset and /// conforms to the function signature expected for `getAttributes` in /// `constifyIndexValues`. static SmallVector<int64_t> getConstantOffset(MemRefType memrefType) { … } /// Wrapper around `getStridesAndOffset` that returns only the strides and /// conforms to the function signature expected for `getAttributes` in /// `constifyIndexValues`. static SmallVector<int64_t> getConstantStrides(MemRefType memrefType) { … } //===----------------------------------------------------------------------===// // AllocOp / AllocaOp //===----------------------------------------------------------------------===// void AllocOp::getAsmResultNames( function_ref<void(Value, StringRef)> setNameFn) { … } void AllocaOp::getAsmResultNames( function_ref<void(Value, StringRef)> setNameFn) { … } template <typename AllocLikeOp> static LogicalResult verifyAllocLikeOp(AllocLikeOp op) { … } LogicalResult AllocOp::verify() { … } LogicalResult AllocaOp::verify() { … } namespace { /// Fold constant dimensions into an alloc like operation. 
template <typename AllocLikeOp> struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> { … }; /// Fold alloc operations with no users or only store and dealloc uses. template <typename T> struct SimplifyDeadAlloc : public OpRewritePattern<T> { … }; } // namespace void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { … } void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { … } //===----------------------------------------------------------------------===// // ReallocOp //===----------------------------------------------------------------------===// LogicalResult ReallocOp::verify() { … } void ReallocOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { … } //===----------------------------------------------------------------------===// // AllocaScopeOp //===----------------------------------------------------------------------===// void AllocaScopeOp::print(OpAsmPrinter &p) { … } ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) { … } void AllocaScopeOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { … } /// Given an operation, return whether this op is guaranteed to /// allocate an AutomaticAllocationScopeResource static bool isGuaranteedAutomaticAllocation(Operation *op) { … } /// Given an operation, return whether this op itself could /// allocate an AutomaticAllocationScopeResource. Note that /// this will not check whether an operation contained within /// the op can allocate. static bool isOpItselfPotentialAutomaticAllocation(Operation *op) { … } /// Return whether this op is the last non terminating op /// in a region. That is to say, it is in a one-block region /// and is only followed by a terminator. This prevents /// extending the lifetime of allocations. 
static bool lastNonTerminatorInRegion(Operation *op) { … } /// Inline an AllocaScopeOp if either the direct parent is an allocation scope /// or it contains no allocation. struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> { … }; /// Move allocations into an allocation scope, if it is legal to /// move them (e.g. their operands are available at the location /// the op would be moved to). struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> { … }; void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { … } //===----------------------------------------------------------------------===// // AssumeAlignmentOp //===----------------------------------------------------------------------===// LogicalResult AssumeAlignmentOp::verify() { … } //===----------------------------------------------------------------------===// // CastOp //===----------------------------------------------------------------------===// void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) { … } /// Determines whether MemRef_CastOp casts to a more dynamic version of the /// source memref. This is useful to fold a memref.cast into a consuming op /// and implement canonicalization patterns for ops in different dialects that /// may consume the results of memref.cast operations. Such foldable memref.cast /// operations are typically inserted as `view` and `subview` ops are /// canonicalized, to preserve the type compatibility of their uses. /// /// Returns true when all conditions are met: /// 1. source and result are ranked memrefs with strided semantics and same /// element type and rank. /// 2. each of the source's size, offset or stride has more static information /// than the corresponding result's size, offset or stride. /// /// Example 1: /// ```mlir /// %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32> /// %2 = consumer %1 ... : memref<?x?xf32> ... 
/// ``` /// /// may fold into: /// /// ```mlir /// %2 = consumer %0 ... : memref<8x16xf32> ... /// ``` /// /// Example 2: /// ``` /// %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>> /// to memref<?x?xf32> /// consumer %1 : memref<?x?xf32> ... /// ``` /// /// may fold into: /// /// ``` /// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>> /// ``` bool CastOp::canFoldIntoConsumerOp(CastOp castOp) { … } bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { … } OpFoldResult CastOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // CopyOp //===----------------------------------------------------------------------===// namespace { /// If the source/target of a CopyOp is a CastOp that does not modify the shape /// and element type, the cast can be skipped. Such CastOps only cast the layout /// of the type. struct FoldCopyOfCast : public OpRewritePattern<CopyOp> { … }; /// Fold memref.copy(%x, %x). 
struct FoldSelfCopy : public OpRewritePattern<CopyOp> { … }; struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> { … }; } // namespace void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { … } LogicalResult CopyOp::fold(FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult> &results) { … } //===----------------------------------------------------------------------===// // DeallocOp //===----------------------------------------------------------------------===// LogicalResult DeallocOp::fold(FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult> &results) { … } //===----------------------------------------------------------------------===// // DimOp //===----------------------------------------------------------------------===// void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) { … } void DimOp::build(OpBuilder &builder, OperationState &result, Value source, int64_t index) { … } std::optional<int64_t> DimOp::getConstantIndex() { … } Speculation::Speculatability DimOp::getSpeculatability() { … } /// Return a map with key being elements in `vals` and data being number of /// occurrences of it. Use std::map, since the `vals` here are strides and the /// dynamic stride value is the same as the tombstone value for /// `DenseMap<int64_t>`. static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) { … } /// Given the `originalType` and a `candidateReducedType` whose shape is assumed /// to be a subset of `originalType` with some `1` entries erased, return the /// set of indices that specifies which of the entries of `originalShape` are /// dropped to obtain `reducedShape`. /// This accounts for cases where there are multiple unit-dims, but only a /// subset of those are dropped. For MemRefTypes these can be disambiguated /// using the strides. If a dimension is dropped the stride must be dropped too. 
static FailureOr<llvm::SmallBitVector> computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType, ArrayRef<OpFoldResult> sizes) { … } llvm::SmallBitVector SubViewOp::getDroppedDims() { … } OpFoldResult DimOp::fold(FoldAdaptor adaptor) { … } namespace { /// Fold dim of a memref reshape operation to a load into the reshape's shape /// operand. struct DimOfMemRefReshape : public OpRewritePattern<DimOp> { … }; } // namespace void DimOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { … } // --------------------------------------------------------------------------- // DmaStartOp // --------------------------------------------------------------------------- void DmaStartOp::build(OpBuilder &builder, OperationState &result, Value srcMemRef, ValueRange srcIndices, Value destMemRef, ValueRange destIndices, Value numElements, Value tagMemRef, ValueRange tagIndices, Value stride, Value elementsPerStride) { … } void DmaStartOp::print(OpAsmPrinter &p) { … } // Parse DmaStartOp. 
// Ex: // %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size, // %tag[%index], %stride, %num_elt_per_stride : // : memref<3076 x f32, 0>, // memref<1024 x f32, 2>, // memref<1 x i32> // ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) { … } LogicalResult DmaStartOp::verify() { … } LogicalResult DmaStartOp::fold(FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult> &results) { … } // --------------------------------------------------------------------------- // DmaWaitOp // --------------------------------------------------------------------------- LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult> &results) { … } LogicalResult DmaWaitOp::verify() { … } //===----------------------------------------------------------------------===// // ExtractAlignedPointerAsIndexOp //===----------------------------------------------------------------------===// void ExtractAlignedPointerAsIndexOp::getAsmResultNames( function_ref<void(Value, StringRef)> setNameFn) { … } //===----------------------------------------------------------------------===// // ExtractStridedMetadataOp //===----------------------------------------------------------------------===// /// The number and type of the results are inferred from the /// shape of the source. LogicalResult ExtractStridedMetadataOp::inferReturnTypes( MLIRContext *context, std::optional<Location> location, ExtractStridedMetadataOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) { … } void ExtractStridedMetadataOp::getAsmResultNames( function_ref<void(Value, StringRef)> setNameFn) { … } /// Helper function to perform the replacement of all constant uses of `values` /// by a materialized constant extracted from `maybeConstants`. /// `values` and `maybeConstants` are expected to have the same size. 
template <typename Container> static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc, Container values, ArrayRef<OpFoldResult> maybeConstants) { … } LogicalResult ExtractStridedMetadataOp::fold(FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult> &results) { … } SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() { … } SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedStrides() { … } OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() { … } //===----------------------------------------------------------------------===// // GenericAtomicRMWOp //===----------------------------------------------------------------------===// void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result, Value memref, ValueRange ivs) { … } LogicalResult GenericAtomicRMWOp::verify() { … } ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser, OperationState &result) { … } void GenericAtomicRMWOp::print(OpAsmPrinter &p) { … } //===----------------------------------------------------------------------===// // AtomicYieldOp //===----------------------------------------------------------------------===// LogicalResult AtomicYieldOp::verify() { … } //===----------------------------------------------------------------------===// // GlobalOp //===----------------------------------------------------------------------===// static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, TypeAttr type, Attribute initialValue) { … } static ParseResult parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValue) { … } LogicalResult GlobalOp::verify() { … } ElementsAttr GlobalOp::getConstantInitValue() { … } //===----------------------------------------------------------------------===// // GetGlobalOp //===----------------------------------------------------------------------===// LogicalResult GetGlobalOp::verifySymbolUses(SymbolTableCollection 
&symbolTable) { … } //===----------------------------------------------------------------------===// // LoadOp //===----------------------------------------------------------------------===// LogicalResult LoadOp::verify() { … } OpFoldResult LoadOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // MemorySpaceCastOp //===----------------------------------------------------------------------===// void MemorySpaceCastOp::getAsmResultNames( function_ref<void(Value, StringRef)> setNameFn) { … } bool MemorySpaceCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { … } OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // PrefetchOp //===----------------------------------------------------------------------===// void PrefetchOp::print(OpAsmPrinter &p) { … } ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) { … } LogicalResult PrefetchOp::verify() { … } LogicalResult PrefetchOp::fold(FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult> &results) { … } //===----------------------------------------------------------------------===// // RankOp //===----------------------------------------------------------------------===// OpFoldResult RankOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // ReinterpretCastOp //===----------------------------------------------------------------------===// void ReinterpretCastOp::getAsmResultNames( function_ref<void(Value, StringRef)> setNameFn) { … } /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`, /// `staticSizes` and `staticStrides` are automatically filled with /// source-memref-rank sentinel values that encode dynamic entries. 
void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, MemRefType resultType, Value source, OpFoldResult offset, ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides, ArrayRef<NamedAttribute> attrs) { … } void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, MemRefType resultType, Value source, int64_t offset, ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides, ArrayRef<NamedAttribute> attrs) { … } void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, MemRefType resultType, Value source, Value offset, ValueRange sizes, ValueRange strides, ArrayRef<NamedAttribute> attrs) { … } // TODO: ponder whether we want to allow missing trailing sizes/strides that are // completed automatically, like we have for subview and extract_slice. LogicalResult ReinterpretCastOp::verify() { … } OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) { … } SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() { … } SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() { … } OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() { … } namespace { /// Replace the sequence: /// ``` /// base, offset, sizes, strides = extract_strided_metadata src /// dst = reinterpret_cast base to offset, sizes, strides /// ``` /// With /// /// ``` /// dst = memref.cast src /// ``` /// /// Note: The cast operation is only inserted when the type of dst and src /// are not the same. E.g., when going from <4xf32> to <?xf32>. /// /// This pattern also matches when the offset, sizes, and strides don't come /// directly from the `extract_strided_metadata`'s results but it can be /// statically proven that they would hold the same values. 
/// /// For instance, the following sequence would be replaced: /// ``` /// base, offset, sizes, strides = /// extract_strided_metadata memref : memref<3x4xty> /// dst = reinterpret_cast base to 0, [3, 4], strides /// ``` /// Because we know (thanks to the type of the input memref) that variable /// `offset` and `sizes` will respectively hold 0 and [3, 4]. /// /// Similarly, the following sequence would be replaced: /// ``` /// c0 = arith.constant 0 /// c4 = arith.constant 4 /// base, offset, sizes, strides = /// extract_strided_metadata memref : memref<3x4xty> /// dst = reinterpret_cast base to c0, [3, c4], strides /// ``` /// Because we know that `offset` and `c0` will hold 0 /// and `c4` will hold 4. struct ReinterpretCastOpExtractStridedMetadataFolder : public OpRewritePattern<ReinterpretCastOp> { … }; } // namespace void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { … } //===----------------------------------------------------------------------===// // Reassociative reshape ops //===----------------------------------------------------------------------===// void CollapseShapeOp::getAsmResultNames( function_ref<void(Value, StringRef)> setNameFn) { … } void ExpandShapeOp::getAsmResultNames( function_ref<void(Value, StringRef)> setNameFn) { … } LogicalResult ExpandShapeOp::reifyResultShapes( OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) { … } /// Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp /// result and operand. Layout maps are verified separately. /// /// If `allowMultipleDynamicDimsPerGroup`, multiple dynamic dimensions are /// allowed in a reassociation group. 
static LogicalResult verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape, ArrayRef<ReassociationIndices> reassociation, bool allowMultipleDynamicDimsPerGroup) { … } SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() { … } SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() { … } SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() { … } SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() { … } /// Compute the layout map after expanding a given source MemRef type with the /// specified reassociation indices. static FailureOr<StridedLayoutAttr> computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape, ArrayRef<ReassociationIndices> reassociation) { … } FailureOr<MemRefType> ExpandShapeOp::computeExpandedType( MemRefType srcType, ArrayRef<int64_t> resultShape, ArrayRef<ReassociationIndices> reassociation) { … } FailureOr<SmallVector<OpFoldResult>> ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc, MemRefType expandedType, ArrayRef<ReassociationIndices> reassociation, ArrayRef<OpFoldResult> inputShape) { … } void ExpandShapeOp::build(OpBuilder &builder, OperationState &result, Type resultType, Value src, ArrayRef<ReassociationIndices> reassociation, ArrayRef<OpFoldResult> outputShape) { … } void ExpandShapeOp::build(OpBuilder &builder, OperationState &result, Type resultType, Value src, ArrayRef<ReassociationIndices> reassociation) { … } void ExpandShapeOp::build(OpBuilder &builder, OperationState &result, ArrayRef<int64_t> resultShape, Value src, ArrayRef<ReassociationIndices> reassociation) { … } void ExpandShapeOp::build(OpBuilder &builder, OperationState &result, ArrayRef<int64_t> resultShape, Value src, ArrayRef<ReassociationIndices> reassociation, ArrayRef<OpFoldResult> outputShape) { … } LogicalResult ExpandShapeOp::verify() { … } void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, 
MLIRContext *context) { … } /// Compute the layout map after collapsing a given source MemRef type with the /// specified reassociation indices. /// /// Note: All collapsed dims in a reassociation group must be contiguous. It is /// not possible to check this by inspecting a MemRefType in the general case. /// If non-contiguity cannot be checked statically, the collapse is assumed to /// be valid (and thus accepted by this function) unless `strict = true`. static FailureOr<StridedLayoutAttr> computeCollapsedLayoutMap(MemRefType srcType, ArrayRef<ReassociationIndices> reassociation, bool strict = false) { … } bool CollapseShapeOp::isGuaranteedCollapsible( MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) { … } MemRefType CollapseShapeOp::computeCollapsedType( MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) { … } void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src, ArrayRef<ReassociationIndices> reassociation, ArrayRef<NamedAttribute> attrs) { … } LogicalResult CollapseShapeOp::verify() { … } struct CollapseShapeOpMemRefCastFolder : public OpRewritePattern<CollapseShapeOp> { … }; void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { … } OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) { … } OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // ReshapeOp //===----------------------------------------------------------------------===// void ReshapeOp::getAsmResultNames( function_ref<void(Value, StringRef)> setNameFn) { … } LogicalResult ReshapeOp::verify() { … } //===----------------------------------------------------------------------===// // StoreOp //===----------------------------------------------------------------------===// LogicalResult StoreOp::verify() { … } LogicalResult StoreOp::fold(FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult> &results) { … } 
//===----------------------------------------------------------------------===// // SubViewOp //===----------------------------------------------------------------------===// void SubViewOp::getAsmResultNames( function_ref<void(Value, StringRef)> setNameFn) { … } /// A subview result type can be fully inferred from the source type and the /// static representation of offsets, sizes and strides. Special sentinels /// encode the dynamic case. Type SubViewOp::inferResultType(MemRefType sourceMemRefType, ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) { … } Type SubViewOp::inferResultType(MemRefType sourceMemRefType, ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) { … } Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType, ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides) { … } Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType, ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) { … } // Build a SubViewOp with mixed static and dynamic entries and custom result // type. If the type passed is nullptr, it is inferred. void SubViewOp::build(OpBuilder &b, OperationState &result, MemRefType resultType, Value source, ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides, ArrayRef<NamedAttribute> attrs) { … } // Build a SubViewOp with mixed static and dynamic entries and inferred result // type. void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides, ArrayRef<NamedAttribute> attrs) { … } // Build a SubViewOp with static entries and inferred result type. 
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides, ArrayRef<NamedAttribute> attrs) { … } // Build a SubViewOp with static entries and custom result type. If the // type passed is nullptr, it is inferred. void SubViewOp::build(OpBuilder &b, OperationState &result, MemRefType resultType, Value source, ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides, ArrayRef<NamedAttribute> attrs) { … } // Build a SubViewOp with dynamic entries and custom result type. If the type // passed is nullptr, it is inferred. void SubViewOp::build(OpBuilder &b, OperationState &result, MemRefType resultType, Value source, ValueRange offsets, ValueRange sizes, ValueRange strides, ArrayRef<NamedAttribute> attrs) { … } // Build a SubViewOp with dynamic entries and inferred result type. void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, ValueRange offsets, ValueRange sizes, ValueRange strides, ArrayRef<NamedAttribute> attrs) { … } /// For ViewLikeOpInterface. Value SubViewOp::getViewSource() { … } /// Return true if `t1` and `t2` have equal offsets (both dynamic or of same /// static value). static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) { … } /// Return true if `t1` and `t2` have equal strides (both dynamic or of same /// static value). Dimensions of `t1` may be dropped in `t2`; these must be /// marked as dropped in `droppedDims`. static bool haveCompatibleStrides(MemRefType t1, MemRefType t2, const llvm::SmallBitVector &droppedDims) { … } static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result, Operation *op, Type expectedType) { … } /// Verifier for SubViewOp. LogicalResult SubViewOp::verify() { … } raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) { … } /// Return the list of Range (i.e. offset, size, stride). 
Each Range /// entry contains either the dynamic value or a ConstantIndexOp constructed /// with `b` at location `loc`. SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op, OpBuilder &b, Location loc) { … } /// Compute the canonical result type of a SubViewOp. Call `inferResultType` /// to deduce the result type for the given `sourceType`. Additionally, reduce /// the rank of the inferred result type if `currentResultType` is lower rank /// than `currentSourceType`. Use this signature if `sourceType` is updated /// together with the result type. In this case, it is important to compute /// the dropped dimensions using `currentSourceType` whose strides align with /// `currentResultType`. static MemRefType getCanonicalSubViewResultType( MemRefType currentResultType, MemRefType currentSourceType, MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) { … } Value mlir::memref::createCanonicalRankReducingSubViewOp( OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) { … } FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value, ArrayRef<int64_t> desiredShape) { … } /// Helper method to check if a `subview` operation is trivially a no-op. This /// is the case if all offsets are zero, all strides are 1, and the source /// shape is same as the size of the subview. In such cases, the subview can /// be folded into its source. static bool isTrivialSubViewOp(SubViewOp subViewOp) { … } namespace { /// Pattern to rewrite a subview op with MemRefCast arguments. /// This essentially pushes memref.cast past its consuming subview when /// `canFoldIntoConsumerOp` is true. 
/// /// Example: /// ``` /// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32> /// %1 = memref.subview %0[0, 0][3, 4][1, 1] : /// memref<?x?xf32> to memref<3x4xf32, strided<[?, 1], offset: ?>> /// ``` /// is rewritten into: /// ``` /// %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]> /// %1 = memref.cast %0: memref<3x4xf32, strided<[16, 1], offset: 0>> to /// memref<3x4xf32, strided<[?, 1], offset: ?>> /// ``` class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> { … }; /// Canonicalize subview ops that are no-ops. When the source shape is not /// same as a result shape due to use of `affine_map`. class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> { … }; } // namespace /// Return the canonical type of the result of a subview. struct SubViewReturnTypeCanonicalizer { … }; /// A canonicalizer wrapper to replace SubViewOps. struct SubViewCanonicalizer { … }; void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { … } OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // TransposeOp //===----------------------------------------------------------------------===// void TransposeOp::getAsmResultNames( function_ref<void(Value, StringRef)> setNameFn) { … } /// Build a strided memref type by applying `permutationMap` to `memRefType`. 
static MemRefType inferTransposeResultType(MemRefType memRefType, AffineMap permutationMap) { … } void TransposeOp::build(OpBuilder &b, OperationState &result, Value in, AffineMapAttr permutation, ArrayRef<NamedAttribute> attrs) { … } // transpose $in $permutation attr-dict : type($in) `to` type(results) void TransposeOp::print(OpAsmPrinter &p) { … } ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) { … } LogicalResult TransposeOp::verify() { … } OpFoldResult TransposeOp::fold(FoldAdaptor) { … } //===----------------------------------------------------------------------===// // ViewOp //===----------------------------------------------------------------------===// void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) { … } LogicalResult ViewOp::verify() { … } Value ViewOp::getViewSource() { … } namespace { struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> { … }; struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> { … }; } // namespace void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { … } //===----------------------------------------------------------------------===// // AtomicRMWOp //===----------------------------------------------------------------------===// LogicalResult AtomicRMWOp::verify() { … } OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// #define GET_OP_CLASSES #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"