//===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with optimizing transfer_read and
// transfer_write ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE …
#define DBGS() …

using namespace mlir;

/// Return the ancestor of `op` in `region`, or nullptr if `region` is not an
/// ancestor of the op.
static Operation *findAncestorOpInRegion(Region *region, Operation *op) { … }

namespace {

class TransferOptimization { … };

} // namespace

/// Return true if there is a path from the `start` operation to the `dest`
/// operation, otherwise return false. The operations have to be in the same
/// region.
bool TransferOptimization::isReachable(Operation *start, Operation *dest) { … }

/// For a transfer_write to fully overwrite another transfer_write, it must:
/// 1. Access the same memref with the same indices and vector type.
/// 2. Post-dominate the other transfer_write operation.
/// If several candidates are available, one must be post-dominated by all the
/// others, since they all post-dominate the same transfer_write. We only
/// consider the candidate post-dominated by all the others, as it is the first
/// transfer_write executed after the potentially dead one.
/// If we find such an overwriting transfer_write, we know that the original
/// transfer_write is dead if all reads reachable from the potentially dead
/// transfer_write are dominated by the overwriting transfer_write.
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) { … }

/// A transfer_write candidate for store-to-load forwarding must:
/// 1. Access the same memref with the same indices and vector type as the
///    transfer_read.
/// 2. Dominate the transfer_read operation.
/// If several candidates are available, one must be dominated by all the
/// others, since they all dominate the same transfer_read. We only consider
/// the candidate dominated by all the others, as it is the last transfer_write
/// executed before the transfer_read.
/// If we find such a candidate, we can forward the stored value as long as all
/// other potentially aliasing ops that may reach the transfer_read are
/// post-dominated by the transfer_write.
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  …
}
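// An illustrative sketch of the forwarding above (hypothetical IR; %v, %buf,
// %pad, and the static shapes are made up for exposition). Given a write that
// dominates a read with matching memref, indices, and vector type:
//
//   vector.transfer_write %v, %buf[%c0, %c0] {in_bounds = [true, true]}
//       : vector<4x8xf32>, memref<4x8xf32>
//   %r = vector.transfer_read %buf[%c0, %c0], %pad {in_bounds = [true, true]}
//       : memref<4x8xf32>, vector<4x8xf32>
//
// all uses of %r are replaced with %v; if no other read can then reach the
// write, deadStoreOp can erase it as well.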
/// Converts OpFoldResults to an int64_t shape without unit dims.
static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
  …
}

/// Drops unit dimensions from the input MemRefType.
static MemRefType dropUnitDims(MemRefType inputType,
                               ArrayRef<OpFoldResult> offsets,
                               ArrayRef<OpFoldResult> sizes,
                               ArrayRef<OpFoldResult> strides) { … }

/// Creates a rank-reducing memref.subview op that drops unit dims from its
/// input, or just returns the input if it already has no unit dims.
static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
                                                 mlir::Location loc,
                                                 Value input) { … }

/// Returns the number of dims that aren't unit dims.
static int getReducedRank(ArrayRef<int64_t> shape) { … }

/// Trims non-scalable unit dimensions from `oldType` and returns the result
/// type.
static VectorType trimNonScalableUnitDims(VectorType oldType) { … }

/// Rewrites the vector.create_mask op `op` to drop non-scalable unit
/// dimensions.
static FailureOr<Value>
createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
                                  vector::CreateMaskOp op) { … }

namespace {

/// Rewrites `vector.transfer_read` ops where the source has unit dims, by
/// inserting a memref.subview that drops those unit dims. The vector shapes
/// are also reduced accordingly.
class TransferReadDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferReadOp> { … };

/// Rewrites `vector.transfer_write` ops where the "source" (i.e. the
/// destination) has unit dims, by inserting a `memref.subview` that drops
/// those unit dims. The vector shapes are also reduced accordingly.
class TransferWriteDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferWriteOp> { … };

} // namespace

/// Creates a memref.collapse_shape op collapsing all inner dimensions of the
/// input starting at `firstDimToCollapse`.
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
                               Value input, int64_t firstDimToCollapse) { … }

/// Returns the new indices after collapsing the inner dimensions starting from
/// the `firstDimToCollapse` dimension.
static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
                                              Location loc,
                                              ArrayRef<int64_t> shape,
                                              ValueRange indices,
                                              int64_t firstDimToCollapse) { … }

namespace {

/// Rewrites contiguous row-major vector.transfer_read ops by inserting a
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read has a 1D source. Requires the source shape to be
/// already reduced, i.e. without unit dims.
///
/// If `targetVectorBitwidth` is provided, the flattening only happens if the
/// trailing dimension of the vector read is smaller than the provided
/// bitwidth.
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> { … };

/// Rewrites contiguous row-major vector.transfer_write ops by inserting a
/// memref.collapse_shape on the destination so that the resulting
/// vector.transfer_write has a 1D destination. Requires the destination shape
/// to be already reduced, i.e. without unit dims.
///
/// If `targetVectorBitwidth` is provided, the flattening only happens if the
/// trailing dimension of the vector to be written is smaller than the provided
/// bitwidth.
class FlattenContiguousRowMajorTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> { … };

/// Base class for `vector.extract/vector.extractelement(vector.transfer_read)`
/// to `memref.load` patterns. The `match` method is shared by both
/// `vector.extract` and `vector.extractelement`.
template <class VectorExtractOp>
class RewriteScalarExtractOfTransferReadBase
    : public OpRewritePattern<VectorExtractOp> { … };
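// An illustrative sketch of the scalar-extract rewrite implemented by the two
// derived patterns below (hypothetical IR; %buf, %idx, and %pad are
// placeholders, and the exact affine map may differ):
//
//   %v = vector.transfer_read %buf[%idx], %pad : memref<?xf32>, vector<4xf32>
//   %e = vector.extract %v[1] : f32 from vector<4xf32>
//
// becomes a scalar load at the folded index:
//
//   %i = affine.apply affine_map<(d0) -> (d0 + 1)>(%idx)
//   %e = memref.load %buf[%i] : memref<?xf32>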
/// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
///
/// All users of the transfer op must be either `vector.extractelement` or
/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
/// transfer ops with any number of users. Otherwise, rewrite only if the
/// extract op is the single user of the transfer op. Rewriting a single
/// vector load as multiple scalar loads may negatively affect performance.
class RewriteScalarExtractElementOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
  …
};

/// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
///
/// All users of the transfer op must be either `vector.extractelement` or
/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
/// transfer ops with any number of users. Otherwise, rewrite only if the
/// extract op is the single user of the transfer op. Rewriting a single
/// vector load as multiple scalar loads may negatively affect performance.
class RewriteScalarExtractOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> { … };

/// Rewrite transfer_writes of single-element vectors (e.g., vector<1x1xf32>)
/// to memref.store.
class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
  …
};

} // namespace

void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
                                     Operation *rootOp) { … }

void mlir::vector::populateScalarVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit,
    bool allowMultipleUses) { … }

void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) { … }

void mlir::vector::populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns, unsigned targetVectorBitwidth,
    PatternBenefit benefit) { … }
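// A minimal usage sketch of the populate functions above (illustrative, not
// part of this file's API surface; assumes a pass that owns `rootOp` and uses
// applyPatternsAndFoldGreedily from
// "mlir/Transforms/GreedyPatternRewriteDriver.h"; the 128-bit target is an
// arbitrary example value):
//
//   RewritePatternSet patterns(rootOp->getContext());
//   vector::populateVectorTransferDropUnitDimsPatterns(patterns);
//   vector::populateFlattenVectorTransferPatterns(
//       patterns, /*targetVectorBitwidth=*/128);
//   if (failed(applyPatternsAndFoldGreedily(rootOp, std::move(patterns))))
//     return signalPassFailure();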