//===- LoweringPatterns.h - Vector rewrite patterns ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
#define MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H

#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"

namespace mlir {
class RewritePatternSet;

namespace vector {

//===----------------------------------------------------------------------===//
// Lowering pattern populate functions
//===----------------------------------------------------------------------===//

/// Populate the pattern set with the following patterns:
///
/// [OuterProductOpLowering]
/// Progressively lower a `vector.outerproduct` to linearized
/// `vector.extract` + `vector.fma` + `vector.insert`.
///
/// [ContractionOpLowering]
/// Progressive lowering of ContractionOp.
/// One:
///   %x = vector.contract with at least one free/batch dimension
/// is replaced by:
///   %a = vector.contract with one less free/batch dimension
///   %b = vector.contract with one less free/batch dimension
///
/// [ContractionOpToMatmulOpLowering]
/// Progressively lower a `vector.contract` with row-major matmul semantics to
/// linearized `vector.shape_cast` + `vector.matmul` on the way to
/// `llvm.matrix.multiply`.
///
/// [ContractionOpToDotLowering]
/// Progressively lower a `vector.contract` with row-major matmul semantics to
/// linearized `vector.extract` + `vector.reduction` + `vector.insert`.
///
/// [ContractionOpToOuterProductOpLowering]
/// Progressively lower a `vector.contract` with row-major matmul semantics to
/// linearized `vector.extract` + `vector.outerproduct` + `vector.insert`.
void populateVectorContractLoweringPatterns(
    RewritePatternSet &patterns, VectorTransformsOptions options,
    PatternBenefit benefit = 1, bool disableOuterProductLowering = false);

/// Populate the pattern set with the following patterns:
///
/// [OuterProductOpLowering]
/// Progressively lower a `vector.outerproduct` to linearized
/// `vector.extract` + `vector.fma` + `vector.insert`.
void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns,
                                                PatternBenefit benefit = 1);

/// Collect a set of patterns to convert vector.multi_reduction ops into a
/// sequence of vector.reduction ops. The patterns comprise:
///
/// [InnerOuterDimReductionConversion]
/// Rewrites vector.multi_reduction such that all reduction dimensions are
/// either innermost or outermost, by adding the proper vector.transpose
/// operations.
///
/// [ReduceMultiDimReductionRank]
/// Once in innermost or outermost reduction form, rewrites an n-D
/// vector.multi_reduction into a 2-D vector.multi_reduction, by introducing
/// vector.shape_cast ops to collapse + multi-reduce + expand back.
///
/// [TwoDimMultiReductionToElementWise]
/// Once in 2-D vector.multi_reduction form, with an **outermost** reduction
/// dimension, unroll the outer dimension to obtain a sequence of 1-D vector
/// ops. This also has an opportunity for tree-reduction (in the future).
///
/// [TwoDimMultiReductionToReduction]
/// Once in 2-D vector.multi_reduction form, with an **innermost** reduction
/// dimension, unroll the outer dimension to obtain a sequence of extract +
/// vector.reduction + insert. This can further lower to horizontal reduction
/// ops.
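///
/// Ex (an illustrative sketch of the innermost-reduction case; MLIR syntax
/// abridged, SSA names hypothetical):
///   %0 = vector.multi_reduction <add>, %src, %acc [1]
///       : vector<2x8xf32> to vector<2xf32>
/// unrolls the outer dimension into, for each row i:
///   %row = vector.extract %src[i]          // vector<8xf32>
///   %red = vector.reduction <add>, %row    // f32
///   ...   vector.insert %red, ... [i]      // into vector<2xf32>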
///
/// [OneDimMultiReductionToTwoDim]
/// For cases that reduce to a 1-D vector<k> reduction (and are thus missing
/// either a parallel or a reduction dimension), lift back up to 2-D with a
/// simple vector.shape_cast to vector<1xk> so that the other patterns can
/// kick in, thus fully exiting out of the vector.multi_reduction abstraction.
void populateVectorMultiReductionLoweringPatterns(
    RewritePatternSet &patterns, VectorMultiReductionLowering options,
    PatternBenefit benefit = 1);

/// Populate the pattern set with the following patterns:
///
/// [BroadcastOpLowering]
/// Progressive lowering of BroadcastOp to ExtractOp + InsertOp + lower-D
/// BroadcastOp until dim 1.
void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns,
                                             PatternBenefit benefit = 1);

/// Populate the pattern set with the following patterns:
///
/// [CreateMaskOp]
/// Progressive lowering of CreateMaskOp to lower-D CreateMaskOp until dim 1.
///
/// [ConstantMaskOp]
/// Progressive lowering of ConstantMaskOp to lower-D ConstantMaskOp until
/// dim 1.
void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns,
                                          PatternBenefit benefit = 1);

/// Collects patterns that lower scalar vector transfer ops to memref loads
/// and stores when beneficial. If `allowMultipleUses` is set to true, the
/// patterns are applied to vector transfer reads with any number of uses.
/// Otherwise, only vector transfer reads with a single use will be lowered.
void populateScalarVectorTransferLoweringPatterns(RewritePatternSet &patterns,
                                                  PatternBenefit benefit,
                                                  bool allowMultipleUses);

/// Populate the pattern set with the following patterns:
///
/// [ShapeCastOp2DDownCastRewritePattern]
/// ShapeCastOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
/// vectors progressively.
///
/// [ShapeCastOp2DUpCastRewritePattern]
/// ShapeCastOp 1D -> 2D upcast serves the purpose of unflattening 2-D from
/// 1-D vectors progressively.
///
/// [ShapeCastOpRewritePattern]
/// Reference lowering to fully unrolled sequences of single-element ExtractOp
/// + InsertOp. Note that applying this pattern can almost always be considered
/// a performance bug.
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns,
                                             PatternBenefit benefit = 1);

/// Populate the pattern set with the following patterns:
///
/// [TransposeOpLowering]
/// Progressive lowering of a `vector.transpose` to unrolled
/// `vector.extract` + `vector.insert` sequences.
///
/// [TransposeOp2DToShuffleLowering]
/// Lowering of a 2-D `vector.transpose` to `vector.shape_cast` +
/// `vector.shuffle`.
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns,
                                             VectorTransformsOptions options,
                                             PatternBenefit benefit = 1);

/// Populate the pattern set with the following patterns:
///
/// [TransferReadToVectorLoadLowering]
/// Progressive lowering of transfer_read. This pattern supports lowering of
/// `vector.transfer_read` to a combination of `vector.load` and
/// `vector.broadcast`.
///
/// [TransferWriteToVectorStoreLowering]
/// Progressive lowering of transfer_write. This pattern supports lowering of
/// `vector.transfer_write` to `vector.store`.
///
/// [VectorLoadToMemrefLoadLowering]
/// Replace a 0-d vector.load with a memref.load + vector.broadcast.
///
/// [VectorStoreToMemrefStoreLowering]
/// Replace a 0-d vector.store with a vector.extractelement + memref.store.
///
/// These patterns lower transfer ops to simpler ops like `vector.load`,
/// `vector.store`, and `vector.broadcast`. Only transfers with a transfer
/// rank of at most `maxTransferRank` are lowered. This is useful when
/// combined with VectorToSCF, which reduces the rank of vector transfer ops.
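///
/// Ex ([TransferReadToVectorLoadLowering]; an illustrative sketch of an
/// in-bounds read with a minor identity map, MLIR syntax abridged):
///   %v = vector.transfer_read %mem[%i, %j], %pad
///       : memref<8x16xf32>, vector<4xf32>
/// lowers to:
///   %v = vector.load %mem[%i, %j] : memref<8x16xf32>, vector<4xf32>
/// with a trailing `vector.broadcast` when the read also broadcasts.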
void populateVectorTransferLoweringPatterns(
    RewritePatternSet &patterns,
    std::optional<unsigned> maxTransferRank = std::nullopt,
    PatternBenefit benefit = 1);

/// Collect a set of transfer read/write lowering patterns that simplify the
/// permutation map (e.g., converting it to a minor identity map) by inserting
/// broadcasts and transposes. More specifically:
///
/// [TransferReadPermutationLowering]
/// Lower transfer_read op with permutation into a transfer_read with a
/// permutation map composed of leading zeros followed by a minor identity +
/// vector.transpose op.
/// Ex:
///     vector.transfer_read ...
///         permutation_map: (d0, d1, d2) -> (0, d1)
///     into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2) -> (d1, 0)
///     vector.transpose %v, [1, 0]
///
///     vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
///     into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
///     vector.transpose %v, [0, 1, 3, 2, 4]
/// Note that an alternative is to transform it to linalg.transpose +
/// vector.transfer_read to do the transpose in memory instead.
///
/// [TransferWritePermutationLowering]
/// Lower transfer_write op with permutation into a transfer_write with a
/// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
/// Ex:
///     vector.transfer_write %v ...
///         permutation_map: (d0, d1, d2) -> (d2, d0, d1)
///     into:
///     %tmp = vector.transpose %v, [2, 0, 1]
///     vector.transfer_write %tmp ...
///         permutation_map: (d0, d1, d2) -> (d0, d1, d2)
///
///     vector.transfer_write %v ...
///         permutation_map: (d0, d1, d2, d3) -> (d3, d2)
///     into:
///     %tmp = vector.transpose %v, [1, 0]
///     vector.transfer_write %tmp ...
///         permutation_map: (d0, d1, d2, d3) -> (d2, d3)
///
/// [TransferOpReduceRank]
/// Lower transfer_read op with broadcast in the leading dimensions into
/// transfer_read of lower rank + vector.broadcast.
/// Ex:
///     vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
///     into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
///     vector.broadcast %v
void populateVectorTransferPermutationMapLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit = 1);

/// Populate the pattern set with the following patterns:
///
/// [ScanToArithOps]
/// Convert vector.scan op into arith ops and vector.insert_strided_slice /
/// vector.extract_strided_slice.
void populateVectorScanLoweringPatterns(RewritePatternSet &patterns,
                                        PatternBenefit benefit = 1);

/// Populate the pattern set with the following patterns:
///
/// [FlattenGather]
/// Flattens `vector.gather` ops of rank 2 or higher by unrolling the
/// outermost dimension.
///
/// [Gather1DToConditionalLoads]
/// Turns a 1-D `vector.gather` into a scalarized sequence of `vector.load`s
/// or `tensor.extract`s. To avoid out-of-bounds memory accesses, these
/// loads/extracts are made conditional using `scf.if` ops.
void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
                                          PatternBenefit benefit = 1);

/// Populates instances of `MaskOpRewritePattern` to lower masked operations
/// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
/// not its nested `MaskableOpInterface`.
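///
/// Ex (an illustrative sketch of the anchor these patterns match; MLIR syntax
/// abridged): a side-effecting maskable op wrapped in `vector.mask`:
///   %0 = vector.mask %m {
///     vector.transfer_read %mem[%i], %pad : memref<?xf32>, vector<8xf32>
///   } : vector<8xi1> -> vector<8xf32>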
void populateVectorMaskLoweringPatternsForSideEffectingOps(
    RewritePatternSet &patterns);

/// Populate the pattern set with the following patterns:
///
/// [VectorMaskedLoadOpConverter]
/// Turns vector.maskedload into scf.if + memref.load.
///
/// [VectorMaskedStoreOpConverter]
/// Turns vector.maskedstore into scf.if + memref.store.
void populateVectorMaskedLoadStoreEmulationPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit = 1);

/// Populate the pattern set with the following patterns:
///
/// [UnrollInterleaveOp]
/// A one-shot unrolling of InterleaveOp to (one or more) ExtractOp +
/// InterleaveOp (of `targetRank`) + InsertOp.
void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns,
                                              int64_t targetRank = 1,
                                              PatternBenefit benefit = 1);

/// Populate the pattern set with patterns that rewrite `vector.interleave`
/// as `vector.shuffle`.
void populateVectorInterleaveToShufflePatterns(RewritePatternSet &patterns,
                                               PatternBenefit benefit = 1);

/// Populates the pattern set with the following patterns:
///
/// [UnrollBitCastOp]
/// A one-shot unrolling of BitCastOp to (one or more) ExtractOp +
/// BitCastOp (of `targetRank`) + InsertOp.
void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns,
                                           int64_t targetRank = 1,
                                           PatternBenefit benefit = 1);

} // namespace vector
} // namespace mlir
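
// Usage sketch (illustrative only, not part of this header): applying some of
// the populate functions above from inside a pass via the greedy rewrite
// driver; assumes the pass includes GreedyPatternRewriteDriver.h.
//
//   void runOnOperation() override {
//     RewritePatternSet patterns(&getContext());
//     vector::populateVectorInterleaveLoweringPatterns(patterns);
//     vector::populateVectorBitCastLoweringPatterns(patterns);
//     if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                             std::move(patterns))))
//       signalPassFailure();
//   }

#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H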