//===- ExtractAddressComputations.cpp - Extract address computations -----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // /// This file provides rewrite patterns that turn loading/storing from/to a /// memref with offsets into loading/storing from/to a subview, without any /// offset on the instruction itself. Supported ops: memref.load, /// memref.store, nvgpu.ldmatrix, vector.transfer_read and /// vector.transfer_write. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Transforms/Transforms.h" #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/PatternMatch.h" usingnamespacemlir; namespace { //===----------------------------------------------------------------------===// // Helper functions for the `load base[off0...]` // => `load (subview base[off0...])[0...]` pattern. //===----------------------------------------------------------------------===// // Matches getFailureOrSrcMemRef specs for LoadOp. // \see LoadStoreLikeOpRewriter. static FailureOr<Value> getLoadOpSrcMemRef(memref::LoadOp loadOp) { … } // Matches rebuildOpFromAddressAndIndices specs for LoadOp. // \see LoadStoreLikeOpRewriter. static memref::LoadOp rebuildLoadOp(RewriterBase &rewriter, memref::LoadOp loadOp, Value srcMemRef, ArrayRef<Value> indices) { … } // Matches getViewSizeForEachDim specs for LoadOp. // \see LoadStoreLikeOpRewriter.
static SmallVector<OpFoldResult> getLoadOpViewSizeForEachDim(RewriterBase &rewriter, memref::LoadOp loadOp) { … } //===----------------------------------------------------------------------===// // Helper functions for the `store val, base[off0...]` // => `store val, (subview base[off0...])[0...]` pattern. //===----------------------------------------------------------------------===// // Matches getFailureOrSrcMemRef specs for StoreOp. // \see LoadStoreLikeOpRewriter. static FailureOr<Value> getStoreOpSrcMemRef(memref::StoreOp storeOp) { … } // Matches rebuildOpFromAddressAndIndices specs for StoreOp. // \see LoadStoreLikeOpRewriter. static memref::StoreOp rebuildStoreOp(RewriterBase &rewriter, memref::StoreOp storeOp, Value srcMemRef, ArrayRef<Value> indices) { … } // Matches getViewSizeForEachDim specs for StoreOp. // \see LoadStoreLikeOpRewriter. static SmallVector<OpFoldResult> getStoreOpViewSizeForEachDim(RewriterBase &rewriter, memref::StoreOp storeOp) { … } //===----------------------------------------------------------------------===// // Helper functions for the `ldmatrix base[off0...]` // => `ldmatrix (subview base[off0...])[0...]` pattern. //===----------------------------------------------------------------------===// // Matches getFailureOrSrcMemRef specs for LdMatrixOp. // \see LoadStoreLikeOpRewriter. static FailureOr<Value> getLdMatrixOpSrcMemRef(nvgpu::LdMatrixOp ldMatrixOp) { … } // Matches rebuildOpFromAddressAndIndices specs for LdMatrixOp. // \see LoadStoreLikeOpRewriter. static nvgpu::LdMatrixOp rebuildLdMatrixOp(RewriterBase &rewriter, nvgpu::LdMatrixOp ldMatrixOp, Value srcMemRef, ArrayRef<Value> indices) { … } // NOTE(review): no LdMatrixOp-specific getViewSizeForEachDim is defined; // presumably the LdMatrixOp pattern relies on the generic default size // computation -- confirm at the pattern registration site. //===----------------------------------------------------------------------===// // Helper functions for the `transfer_read base[off0...]` // => `transfer_read (subview base[off0...])[0...]` pattern.
//===----------------------------------------------------------------------===// // Matches getFailureOrSrcMemRef specs for transfer-like ops. // Templated on TransferLikeOp; presumably instantiated for both // TransferReadOp and TransferWriteOp, since no write-specific getter exists. // \see LoadStoreLikeOpRewriter. template <typename TransferLikeOp> static FailureOr<Value> getTransferLikeOpSrcMemRef(TransferLikeOp transferLikeOp) { … } // Matches rebuildOpFromAddressAndIndices specs for TransferReadOp. // \see LoadStoreLikeOpRewriter. static vector::TransferReadOp rebuildTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp transferReadOp, Value srcMemRef, ArrayRef<Value> indices) { … } //===----------------------------------------------------------------------===// // Helper functions for the `transfer_write base[off0...]` // => `transfer_write (subview base[off0...])[0...]` pattern. //===----------------------------------------------------------------------===// // Matches rebuildOpFromAddressAndIndices specs for TransferWriteOp. // \see LoadStoreLikeOpRewriter. static vector::TransferWriteOp rebuildTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp transferWriteOp, Value srcMemRef, ArrayRef<Value> indices) { … } //===----------------------------------------------------------------------===// // Generic helper functions used as default implementation in // LoadStoreLikeOpRewriter. //===----------------------------------------------------------------------===// /// Helper function to get the src memref. /// It uses the provided getFailureOrSrcMemRef callback but asserts /// that the source is a memref. template <typename LoadStoreLikeOp, FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp)> static Value getSrcMemRef(LoadStoreLikeOp loadStoreLikeOp) { … } /// Helper function to get the sizes of the resulting view. /// This function gets the sizes of the source memref then subtracts the /// offsets used within \p loadStoreLikeOp. This gives the maximal (for /// in-bounds accesses) sizes for the view. /// The source memref is retrieved using getSrcMemRef on \p loadStoreLikeOp.
template <typename LoadStoreLikeOp, Value (*getSrcMemRef)(LoadStoreLikeOp)> static SmallVector<OpFoldResult> getGenericOpViewSizeForEachDim(RewriterBase &rewriter, LoadStoreLikeOp loadStoreLikeOp) { … } /// Rewrite a store/load-like op so that all its indices are zeros. /// E.g., %ld = memref.load %base[%off0]...[%offN] /// => /// %new_base = subview %base[%off0,.., %offN][1,..,1][1,..,1] /// %ld = memref.load %new_base[0,..,0] : /// memref<1x..x1xTy, strided<[1,..,1], offset: ?>> /// /// `getFailureOrSrcMemRef` returns the source memref for the given /// load/store-like operation, or a failure. /// /// `getViewSizeForEachDim` returns the sizes of the view that is going to /// feed the new operation. This must return one size per dimension of the /// view. The sizes of the view need to be at least as big as what is actually /// going to be accessed. Use the provided `loadStoreOp` to get the right /// sizes. When omitted, it defaults to getGenericOpViewSizeForEachDim, which /// derives the source memref through getSrcMemRef. /// /// Using the given rewriter, `rebuildOpFromAddressAndIndices` creates a new /// LoadStoreLikeOp that accesses (reads from or writes to) /// srcMemRef[indices]. /// The returned operation will be used to replace loadStoreOp. template <typename LoadStoreLikeOp, FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp), LoadStoreLikeOp (*rebuildOpFromAddressAndIndices)( RewriterBase & /*rewriter*/, LoadStoreLikeOp /*loadStoreOp*/, Value /*srcMemRef*/, ArrayRef<Value> /*indices*/), SmallVector<OpFoldResult> (*getViewSizeForEachDim)( RewriterBase & /*rewriter*/, LoadStoreLikeOp /*loadStoreOp*/) = getGenericOpViewSizeForEachDim< LoadStoreLikeOp, getSrcMemRef<LoadStoreLikeOp, getFailureOrSrcMemRef>>> struct LoadStoreLikeOpRewriter : public OpRewritePattern<LoadStoreLikeOp> { using OpRewritePattern<LoadStoreLikeOp>::OpRewritePattern; LogicalResult matchAndRewrite(LoadStoreLikeOp loadStoreLikeOp, PatternRewriter &rewriter) const override { … } }; } // namespace /// Populates \p patterns with the address-extraction rewrites of this file. // NOTE(review): presumably registers one LoadStoreLikeOpRewriter // instantiation per supported op; confirm against the body. void memref::populateExtractAddressComputationsPatterns( RewritePatternSet &patterns) { … }