//===- ExpandStridedMetadata.cpp - Expand memref metadata ops -----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // /// The pass expands memref operations that modify the metadata of a memref /// (sizes, offset, strides) into a sequence of easier-to-analyze constructs. /// In particular, this pass transforms each such operation into an explicit /// sequence of operations that model its effect on the different pieces of /// metadata. This pass uses affine constructs to materialize these effects. //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/MemRef/Transforms/Transforms.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallBitVector.h" #include <optional> namespace mlir { namespace memref { #define GEN_PASS_DEF_EXPANDSTRIDEDMETADATA #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc" } // namespace memref } // namespace mlir usingnamespacemlir; usingnamespacemlir::affine; namespace { struct StridedMetadata { … }; /// From `subview(memref, subOffset, subSizes, subStrides)` compute /// /// \verbatim /// baseBuffer, baseOffset, baseSizes, baseStrides = /// extract_strided_metadata(memref) /// strides#i = baseStrides#i * subStrides#i /// offset = baseOffset + sum(subOffset#i * baseStrides#i) /// sizes = subSizes /// \endverbatim /// /// and return {baseBuffer, offset, sizes, strides} static FailureOr<StridedMetadata> 
resolveSubviewStridedMetadata(RewriterBase &rewriter, memref::SubViewOp subview) { … } /// Replace `dst = subview(memref, subOffset, subSizes, subStrides)` /// With /// /// \verbatim /// baseBuffer, baseOffset, baseSizes, baseStrides = /// extract_strided_metadata(memref) /// strides#i = baseStrides#i * subStrides#i /// offset = baseOffset + sum(subOffset#i * baseStrides#i) /// sizes = subSizes /// dst = reinterpret_cast baseBuffer, offset, sizes, strides /// \endverbatim /// /// In other words, get rid of the subview in that expression and canonicalize /// on its effects on the offset, the sizes, and the strides using affine.apply. struct SubviewFolder : public OpRewritePattern<memref::SubViewOp> { … }; /// Pattern to replace `extract_strided_metadata(subview)` /// With /// /// \verbatim /// baseBuffer, baseOffset, baseSizes, baseStrides = /// extract_strided_metadata(memref) /// strides#i = baseStrides#i * subStrides#i /// offset = baseOffset + sum(subOffset#i * baseStrides#i) /// sizes = subSizes /// \endverbatim /// /// with `baseBuffer`, `offset`, `sizes` and `strides` being /// the replacements for the original `extract_strided_metadata`. struct ExtractStridedMetadataOpSubviewFolder : OpRewritePattern<memref::ExtractStridedMetadataOp> { … }; /// Compute the expanded sizes of the given \p expandShape for the /// \p groupId-th reassociation group. /// \p origSizes holds the sizes of the source shape as values. /// This is used to compute the new sizes in cases of dynamic shapes. /// /// sizes#i = /// baseSizes#groupId / product(expandShapeSizes#j, /// for j in group excluding reassIdx#i) /// Where reassIdx#i is the reassociation index at index i in \p groupId. /// /// \post result.size() == expandShape.getReassociationIndices()[groupId].size() /// /// TODO: Move this utility function directly within ExpandShapeOp. For now, /// this is not possible because this function uses the Affine dialect and the /// MemRef dialect cannot depend on the Affine dialect. 
static SmallVector<OpFoldResult> getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder, ArrayRef<OpFoldResult> origSizes, unsigned groupId) { … } /// Compute the expanded strides of the given \p expandShape for the /// \p groupId-th reassociation group. /// \p origStrides and \p origSizes hold respectively the strides and sizes /// of the source shape as values. /// This is used to compute the strides in cases of dynamic shapes and/or /// dynamic strides for this reassociation group. /// /// strides#i = /// origStrides#reassDim * product(expandShapeSizes#j, for j in /// reassIdx#(i+1)..reassIdx#(group.size-1)) /// /// i.e. the product runs over the dimensions that follow reassIdx#i within /// the group. Here reassIdx#i is the reassociation index at index i in /// \p groupId and expandShapeSizes#j is either: /// - The constant size at dimension j, derived directly from the result type of /// the expand_shape op, or /// - An affine expression: baseSizes#reassDim / product of all constant sizes /// in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic /// element.) /// /// \post result.size() == expandShape.getReassociationIndices()[groupId].size() /// /// TODO: Move this utility function directly within ExpandShapeOp. For now, /// this is not possible because this function uses the Affine dialect and the /// MemRef dialect cannot depend on the Affine dialect. SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape, OpBuilder &builder, ArrayRef<OpFoldResult> origSizes, ArrayRef<OpFoldResult> origStrides, unsigned groupId) { … } /// Produce an OpFoldResult object with \p builder at \p loc representing /// `prod(valueOrConstant#i, for i in {indices})`, /// where valueOrConstant#i is maybeConstants[i] when /// \p isDynamic(maybeConstants[i]) is false, and values[i] otherwise. 
/// /// \pre for every index in indices: index < values.size() /// \pre for every index in indices: index < maybeConstants.size() static OpFoldResult getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc, ArrayRef<int64_t> maybeConstants, ArrayRef<OpFoldResult> values, llvm::function_ref<bool(int64_t)> isDynamic) { … } /// Compute the collapsed size of the given \p collapseShape for the /// \p groupId-th reassociation group. /// \p origSizes holds the sizes of the source shape as values. /// This is used to compute the new sizes in cases of dynamic shapes. /// /// Conceptually this helper function computes: /// `prod(origSizes#i, for i in {reassociationGroup[groupId]})`. /// /// \post result.size() == 1, in other words, each group collapses to one /// dimension. /// /// TODO: Move this utility function directly within CollapseShapeOp. For now, /// this is not possible because this function uses the Affine dialect and the /// MemRef dialect cannot depend on the Affine dialect. static SmallVector<OpFoldResult> getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder, ArrayRef<OpFoldResult> origSizes, unsigned groupId) { … } /// Compute the collapsed stride of the given \p collapseShape for the /// \p groupId-th reassociation group. /// \p origStrides and \p origSizes hold respectively the strides and sizes /// of the source shape as values. /// This is used to compute the strides in cases of dynamic shapes and/or /// dynamic strides for this reassociation group. /// /// Conceptually this helper function returns the stride of the innermost /// dimension of that group in the original shape. /// /// \post result.size() == 1, in other words, each group collapses to one /// dimension. 
static SmallVector<OpFoldResult> getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder, ArrayRef<OpFoldResult> origSizes, ArrayRef<OpFoldResult> origStrides, unsigned groupId) { … } /// From `reshape_like(memref)` compute /// /// \verbatim /// baseBuffer, baseOffset, baseSizes, baseStrides = /// extract_strided_metadata(memref) /// sizes#g = getReshapedSizes(reshape, baseSizes, g) /// strides#g = getReshapedStrides(reshape, baseSizes, baseStrides, g) /// \endverbatim /// /// where `g` is the reassociation group index passed as `groupId`, and /// return {baseBuffer, baseOffset, sizes, strides}. A reshape does not move /// the data, so the base buffer and the offset are forwarded unchanged. template <typename ReassociativeReshapeLikeOp> static FailureOr<StridedMetadata> resolveReshapeStridedMetadata( RewriterBase &rewriter, ReassociativeReshapeLikeOp reshape, function_ref<SmallVector<OpFoldResult>( ReassociativeReshapeLikeOp, OpBuilder &, ArrayRef<OpFoldResult> /*origSizes*/, unsigned /*groupId*/)> getReshapedSizes, function_ref<SmallVector<OpFoldResult>( ReassociativeReshapeLikeOp, OpBuilder &, ArrayRef<OpFoldResult> /*origSizes*/, ArrayRef<OpFoldResult> /*origStrides*/, unsigned /*groupId*/)> getReshapedStrides) { … } /// Replace `baseBuffer, offset, sizes, strides = /// extract_strided_metadata(reshapeLike(memref))` /// With /// /// \verbatim /// baseBuffer, offset, baseSizes, baseStrides = /// extract_strided_metadata(memref) /// sizes = getReshapedSizes(reshapeLike) /// strides = getReshapedStrides(reshapeLike) /// \endverbatim /// /// /// Notice that `baseBuffer` and `offset` are unchanged. /// /// In other words, get rid of the reshape-like op (e.g., expand_shape) in that /// expression and materialize its effects on the sizes and the strides using /// affine.apply. 
template <typename ReassociativeReshapeLikeOp, SmallVector<OpFoldResult> (*getReshapedSizes)( ReassociativeReshapeLikeOp, OpBuilder &, ArrayRef<OpFoldResult> /*origSizes*/, unsigned /*groupId*/), SmallVector<OpFoldResult> (*getReshapedStrides)( ReassociativeReshapeLikeOp, OpBuilder &, ArrayRef<OpFoldResult> /*origSizes*/, ArrayRef<OpFoldResult> /*origStrides*/, unsigned /*groupId*/)> struct ReshapeFolder : public OpRewritePattern<ReassociativeReshapeLikeOp> { … }; /// Pattern to replace `extract_strided_metadata(collapse_shape(memref))` /// With /// /// \verbatim /// baseBuffer, baseOffset, baseSizes, baseStrides = /// extract_strided_metadata(memref) /// sizes#g = prod(baseSizes#i, for i in reassociationGroup[g]) /// strides#g = stride of the innermost dimension of /// reassociationGroup[g] in memref (conceptually) /// \endverbatim /// /// with `baseBuffer`, `baseOffset`, `sizes` and `strides` being /// the replacements for the original `extract_strided_metadata`. A /// collapse_shape does not move the data, so the offset is forwarded as-is. struct ExtractStridedMetadataOpCollapseShapeFolder : OpRewritePattern<memref::ExtractStridedMetadataOp> { … }; /// Pattern to replace `extract_strided_metadata(expand_shape)` /// with the results of computing the sizes and strides on the expanded shape /// and dividing up dimensions into static and dynamic parts as needed. struct ExtractStridedMetadataOpExpandShapeFolder : OpRewritePattern<memref::ExtractStridedMetadataOp> { … }; /// Replace `base, offset, sizes, strides = /// extract_strided_metadata(allocLikeOp)` /// /// With /// /// ``` /// base = reinterpret_cast allocLikeOp(allocSizes) to a flat memref<eltTy> /// offset = 0 /// sizes = allocSizes /// strides#i = prod(allocSizes#j, for j in {i+1..rank-1}) /// ``` /// /// The transformation only applies if the allocLikeOp has been normalized. /// In other words, its layout affine_map must be an identity. 
template <typename AllocLikeOp> struct ExtractStridedMetadataOpAllocFolder : public OpRewritePattern<memref::ExtractStridedMetadataOp> { … }; /// Replace `base, offset, sizes, strides = /// extract_strided_metadata(get_global)` /// /// With /// /// ``` /// base = reinterpret_cast get_global to a flat memref<eltTy> /// offset = 0 /// sizes = globalSizes /// strides#i = prod(globalSizes#j, for j in {i+1..rank-1}) /// ``` /// /// where globalSizes are the static sizes of the get_global result type. /// It is expected that the memref.get_global op has a static shape /// and an identity affine_map for the layout. struct ExtractStridedMetadataOpGetGlobalFolder : public OpRewritePattern<memref::ExtractStridedMetadataOp> { … }; /// Rewrite memref.extract_aligned_pointer_as_index of a ViewLikeOp so that it /// is applied to the source of the ViewLikeOp instead. class RewriteExtractAlignedPointerAsIndexOfViewLikeOp : public OpRewritePattern<memref::ExtractAlignedPointerAsIndexOp> { … }; /// Replace `base, offset, sizes, strides = /// extract_strided_metadata( /// reinterpret_cast(src, srcOffset, srcSizes, srcStrides))` /// With /// ``` /// base, ... = extract_strided_metadata(src) /// offset = srcOffset /// sizes = srcSizes /// strides = srcStrides /// ``` /// /// In other words, consume the `reinterpret_cast` and apply its effects /// on the offset, sizes, and strides. class ExtractStridedMetadataOpReinterpretCastFolder : public OpRewritePattern<memref::ExtractStridedMetadataOp> { … }; /// Replace `base, offset, sizes, strides = /// extract_strided_metadata( /// cast(src) to dstTy)` /// With /// ``` /// base, ... = extract_strided_metadata(src) /// offset = !dstTy.srcOffset.isDynamic() /// ? dstTy.srcOffset /// : extract_strided_metadata(src).offset /// sizes = for each srcSize in dstTy.srcSizes: /// !srcSize.isDynamic() /// ? srcSize /// : extract_strided_metadata(src).sizes[i] /// strides = for each srcStride in dstTy.srcStrides: /// !srcStride.isDynamic() /// ? 
srcStride /// : extract_strided_metadata(src).strides[i] /// ``` /// /// In other words, consume the `cast` and apply its effects /// on the offset, sizes, and strides or compute them directly from `src`. class ExtractStridedMetadataOpCastFolder : public OpRewritePattern<memref::ExtractStridedMetadataOp> { … }; /// Replace `base, offset, sizes, strides = extract_strided_metadata( /// memory_space_cast(src) to dstTy)` /// with /// ``` /// oldBase, offset, sizes, strides = extract_strided_metadata(src) /// dstBaseTy = type(oldBase) with memory space from dstTy /// base = memory_space_cast(oldBase) to dstBaseTy /// ``` /// /// In other words, propagate metadata extraction across memory space casts. class ExtractStridedMetadataOpMemorySpaceCastFolder : public OpRewritePattern<memref::ExtractStridedMetadataOp> { … }; /// Replace `base, offset = /// extract_strided_metadata(extract_strided_metadata(src)#0)` /// With /// ``` /// base, ... = extract_strided_metadata(src) /// offset = 0 /// ``` class ExtractStridedMetadataOpExtractStridedMetadataFolder : public OpRewritePattern<memref::ExtractStridedMetadataOp> { … }; } // namespace void memref::populateExpandStridedMetadataPatterns( RewritePatternSet &patterns) { … } void memref::populateResolveExtractStridedMetadataPatterns( RewritePatternSet &patterns) { … } //===----------------------------------------------------------------------===// // Pass registration //===----------------------------------------------------------------------===// namespace { struct ExpandStridedMetadataPass final : public memref::impl::ExpandStridedMetadataBase< ExpandStridedMetadataPass> { … }; } // namespace void ExpandStridedMetadataPass::runOnOperation() { … } std::unique_ptr<Pass> memref::createExpandStridedMetadataPass() { … }