//===- VectorRewritePatterns.h - Vector rewrite patterns --------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H #define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H #include <optional> #include <utility> #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.h.inc" namespace mlir { class ConversionTarget; class RewritePatternSet; class TypeConverter; namespace arith { class AndIOp; class NarrowTypeEmulationConverter; class TruncIOp; } // namespace arith namespace vector { struct VectorTransformsOptions; /// Options that control the vector unrolling. struct UnrollVectorOptions { … }; /// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul /// semantics to a contraction with MMT semantics (matrix matrix multiplication /// with the RHS transposed). This specific form is meant to have the vector /// operands are organized such that the reduction dimension is contiguous. /// Example: /// ``` /// vector.contract {indexing_maps = [affine_map<(m, n, k) -> (m, k)>, /// affine_map<(m, n, k) -> (n, k)>, /// affine_map<(m, n, k) -> (m, n)>], /// iterator_types = ["parallel", "parallel", "reduction"], /// kind = #vector.kind<add>} %a, %b, %c : ... /// ``` /// /// The `constraint` predicate is used to decide which `vector.contraction` ops /// to filter out. void populateVectorContractCanonicalizeMatmulToMMT( RewritePatternSet &patterns, std::function<LogicalResult(vector::ContractionOp)> constraint = [](vector::ContractionOp) { … }; /// Collect patterns to convert reduction op to vector.contract and fold /// transpose/broadcast ops into the contract. void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit = 1); /// Populate `patterns` with the following patterns. /// /// - VectorTransferFullPartialRewriter /// /// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds /// masking) fast path and a slow path. /// /// Example (a 2-D vector.transfer_read): /// ``` /// %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...> /// ``` /// is transformed into: /// ``` /// %1:3 = scf.if (%inBounds) { /// // fast path, direct cast /// memref.cast %A: memref<A...> to compatibleMemRefType /// scf.yield %view : compatibleMemRefType, index, index /// } else { /// // slow path, not in-bounds vector.transfer or linalg.copy. /// memref.cast %alloc: memref<B...> to compatibleMemRefType /// scf.yield %4 : compatibleMemRefType, index, index // } /// %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]} /// ``` /// where `alloc` is a top of the function alloca'ed buffer of one vector. /// /// Preconditions: /// 1. `xferOp.permutation_map()` must be a minor identity map /// 2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()` /// must be equal. This will be relaxed in the future but requires /// rank-reducing subviews. void populateVectorTransferFullPartialPatterns( RewritePatternSet &patterns, const VectorTransformsOptions &options); /// Collect a set of patterns to reduce the rank of the operands of vector /// transfer ops to operate on the largest contigious vector. /// These patterns are useful when lowering to dialects with 1d vector type /// such as llvm and it will result fewer memory reads. void populateVectorTransferCollapseInnerMostContiguousDimsPatterns( RewritePatternSet &patterns, PatternBenefit benefit = 1); /// Patterns that remove redundant Vector Ops by re-ordering them with /// e.g. elementwise Ops: /// ``` /// %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32> /// %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32> /// %r = arith.addf %at, %bt : vector<2x4xf32> /// ``` /// gets converted to: /// ``` /// %0 = arith.addf %a, %b : vector<4x2xf32> /// %r = vector.transpose %0, [1, 0] : vector<2x4xf32> /// ``` /// At the moment, these patterns are limited to vector.broadcast and /// vector.transpose. void populateSinkVectorOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit = 1); /// Patterns that fold chained vector reductions. These patterns assume that /// elementwise operations (e.g., `arith.addf` with vector operands) are /// cheaper than vector reduction. /// Note that these patterns change the order of reduction which may not always /// produce bit-identical results on some floating point inputs. /// /// Example: /// ``` /// %a = vector.reduction <add> %x, %acc /// %b = vector.reduction <add> %y, %a /// ``` /// is transformed into: /// ``` /// %a = arith.addf %x, %y /// %b = vector.reduction <add> %a, %acc /// ``` void populateChainedVectorReductionFoldingPatterns(RewritePatternSet &patterns, PatternBenefit benefit = 1); /// Patterns to break down vector reductions into a series of arith reductions /// over vector elements. This is intended to be simplify code with reductions /// over small vector types and avoid more specialized reduction lowering when /// possible. /// /// Example: /// ``` /// %a = vector.reduction <add> %x : vector<2xf32> into f32 /// ``` /// is transformed into: /// ``` /// %y = vector.extract %x[0] : f32 from vector<2xf32> /// %z = vector.extract %x[1] : f32 from vector<2xf32> /// %a = arith.addf %y, %z : f32 /// ``` void populateBreakDownVectorReductionPatterns( RewritePatternSet &patterns, unsigned maxNumElementsToExtract = 2, PatternBenefit benefit = 1); /// Populate `patterns` with the following patterns. /// /// [DecomposeDifferentRankInsertStridedSlice] /// ========================================== /// RewritePattern for InsertStridedSliceOp where source and destination vectors /// have different ranks. /// /// When ranks are different, InsertStridedSlice needs to extract a properly /// ranked vector from the destination vector into which to insert. This pattern /// only takes care of this extraction part and forwards the rest to /// [VectorInsertStridedSliceOpSameRankRewritePattern]. /// /// For a k-D source and n-D destination vector (k < n), we emit: /// 1. ExtractOp to extract the (unique) (n-1)-D subvector into which to /// insert the k-D source. /// 2. k-D -> (n-1)-D InsertStridedSlice op /// 3. InsertOp that is the reverse of 1. /// /// [DecomposeNDExtractStridedSlice] /// ================================ /// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower /// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case. void populateVectorInsertExtractStridedSliceDecompositionPatterns( RewritePatternSet &patterns, PatternBenefit benefit = 1); /// Populate `patterns` with a pattern to breaks down 1-D extract_strided_slice /// ops into a chain of Extract ops to extract each element from the source, and /// then a chain of Insert ops to insert to the target vector. /// /// If `controlFn` is not nullptr, the pattern will only be invoked on ops that /// `controlFn` returns true. Otherwise runs on ops. void populateVectorExtractStridedSliceToExtractInsertChainPatterns( RewritePatternSet &patterns, std::function<bool(ExtractStridedSliceOp)> controlFn = nullptr, PatternBenefit benefit = 1); /// Populate `patterns` with a pattern to break down 1-D vector.bitcast ops /// based on the destination vector shape. Bitcasts from a lower bitwidth /// element type to a higher bitwidth one are extracted from the lower bitwidth /// based on the native destination vector shape and inserted based on the ratio /// of the bitwidths. /// /// This acts as a last resort way to break down vector.bitcast ops to smaller /// vector sizes. Because this pattern composes until it is bitcasting to a /// single element of the higher bitwidth, the is an optional control function. /// If `controlFn` is not nullptr, the pattern will only apply to ops where /// `controlFn` returns true, otherwise applies to all bitcast ops. void populateBreakDownVectorBitCastOpPatterns( RewritePatternSet &patterns, std::function<bool(BitCastOp)> controlFn = nullptr, PatternBenefit benefit = 1); /// Populate `patterns` with the following patterns. /// /// Patterns in populateVectorInsertExtractStridedSliceDecompositionPatterns(); /// /// [ConvertSameRankInsertStridedSliceIntoShuffle] /// ============================================== /// RewritePattern for InsertStridedSliceOp where source and destination vectors /// have the same rank. For each outermost index in the slice: /// begin end stride /// [offset : offset+size*stride : stride] /// 1. ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector. /// 2. InsertStridedSlice (k-1)-D into (n-1)-D /// 3. the destination subvector is inserted back in the proper place /// 3. InsertOp that is the reverse of 1. /// /// [Convert1DExtractStridedSliceIntoShuffle] /// ========================================= /// For such cases, we can lower it to a ShuffleOp. void populateVectorInsertExtractStridedSliceTransforms( RewritePatternSet &patterns, PatternBenefit benefit = 1); /// Collect a set of pattern to unroll vector operations to a smaller shapes. /// `options` structure controls which operations are unrolled and the target /// shape. /// `op` is unrolled to the `targetShape` as follows, for each of its operands: /// 1. the unrolled type `unrolledVectorType` and number of unrolled instances /// `numUnrolledInstances` are computed from the `targetShape`. For now it is /// assumed the unrolling factors divide the vector sizes. /// 2. ExtractStridedSlice are created to break-up the vector operands. /// 3. the original op is cloned `numUnrolledInstances` times, once for each /// result. /// 4. InsertStridedSlice are inserted to re-assemble the slices into the /// original vectore shape. /// /// Example: /// /// opA(operand0, operand1) // numUnrolledInstances = 3 /// /// operand0 operand1 /// | | /// fork fork /// <----------gather all fork ops ---------> /// /|\ /|\ /// f00 f01 f02 f10 f11 f12 /// <---------- clone op 3 times ---------> /// opA0(f00, f10), opA1(f01, f11), opA2(f02, f12) /// \ | / /// <-------------------- join -------------------------> /// /// Other local patterns then kick in iteratively (including DCE) and compose /// to combine the ExtractStridedSlice/InsertStridedSlice. void populateVectorUnrollPatterns(RewritePatternSet &patterns, const UnrollVectorOptions &options, PatternBenefit benefit = 1); /// Collect a set of vector.shape_cast folding patterns. void populateShapeCastFoldingPatterns(RewritePatternSet &patterns, PatternBenefit benefit = 1); /// Collect a set of leading one dimension removal patterns. /// /// These patterns insert vector.shape_cast to remove leading one dimensions /// to expose more canonical forms of read/write/insert/extract operations. /// With them, there are more chances that we can cancel out extract-insert /// pairs or forward write-read pairs. void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns, PatternBenefit benefit = 1); /// Collect a set of one dimension removal patterns. /// /// These patterns insert rank-reducing memref.subview ops to remove one /// dimensions. With them, there are more chances that we can avoid /// potentially expensive vector.shape_cast operations. void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns, PatternBenefit benefit = 1); /// Collect a set of patterns that use vector.shape_cast to help fold unit dims. /// /// These patterns use vector.shape_cast to remove unit dims from e.g. /// arithmetic operations on Vectors. The newly inserted shape_casts will either /// cancel each other out or will be folded away when combined with other /// patterns. void populateDropUnitDimWithShapeCastPatterns(RewritePatternSet &patterns, PatternBenefit benefit = 1); /// Collect a set of patterns to flatten n-D vector transfers on contiguous /// memref. /// /// These patterns insert memref.collapse_shape + vector.shape_cast patterns /// to transform multiple small n-D transfers into a larger 1-D transfer where /// the memref contiguity properties allow it. /// /// Flattening is only applied if the bitwidth of the trailing vector dimension /// is smaller or equal to `targetVectorBitwidth`. void populateFlattenVectorTransferPatterns( RewritePatternSet &patterns, unsigned targetVectorBitwidth = std::numeric_limits<unsigned>::max(), PatternBenefit benefit = 1); /// Collect a set of patterns that bubble up/down bitcast ops. /// /// These patterns move vector.bitcast ops to be before insert ops or after /// extract ops where suitable. With them, bitcast will happen on smaller /// vectors and there are more chances to share extract/insert ops. void populateBubbleVectorBitCastOpPatterns(RewritePatternSet &patterns, PatternBenefit benefit = 1); /// These patterns materialize masks for various vector ops such as transfers. void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns, bool force32BitVectorIndices, PatternBenefit benefit = 1); /// Appends patterns for emulating vector operations over narrow types with ops /// over wider types. void populateVectorNarrowTypeEmulationPatterns( arith::NarrowTypeEmulationConverter &typeConverter, RewritePatternSet &patterns); /// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of /// vector operations comprising `shuffle` and `bitwise` ops. /// Warning: these patterns currently only work for little endian targets. FailureOr<Value> rewriteBitCastOfTruncI(RewriterBase &rewriter, vector::BitCastOp bitCastOp, arith::TruncIOp truncOp, vector::BroadcastOp maybeBroadcastOp); /// Rewrite a vector `ext(bitcast)` to use a more efficient sequence of /// vector operations comprising `shuffle` and `bitwise` ops. /// Warning: these patterns currently only work for little endian targets. FailureOr<Value> rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp, vector::BitCastOp bitCastOp, vector::BroadcastOp maybeBroadcastOp); /// Appends patterns for rewriting vector operations over narrow types with /// ops over wider types. /// Warning: these patterns currently only work for little endian targets. void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns, PatternBenefit benefit = 1); /// Appends patterns for emulating a sub-byte vector transpose. void populateVectorTransposeNarrowTypeRewritePatterns( RewritePatternSet &patterns, PatternBenefit benefit = 1); /// Populates patterns for ND vectors (N >= 2) linearization and sets up the /// provided ConversionTarget with the appropriate legality configuration for /// the ops to get converted properly. void populateVectorLinearizeTypeConversionsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, unsigned targetBitWidth); /// Populates patterns for linearizing ND (N >= 2) vector operations to 1D /// vector shuffle operations. void populateVectorLinearizeShuffleLikeOpsPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, unsigned targetBitWidth); } // namespace vector } // namespace mlir #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H