//===- VectorInsertExtractStridedSliceRewritePatterns.cpp - Rewrites ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::vector;

// Helper that picks the proper op sequence for inserting `from` into `into`
// at position `offset`, returning the resulting value.
// NOTE(review): body not visible here; presumably the updated `into` is
// returned — confirm.
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
                       Value into, int64_t offset) { … }

// Helper that picks the proper op sequence for extracting the element or
// subvector at position `offset` from `vector`.
// NOTE(review): body not visible here — confirm the exact op(s) emitted.
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
                        int64_t offset) { … }

/// RewritePattern for InsertStridedSliceOp where source and destination vectors
/// have different ranks.
///
/// When ranks are different, InsertStridedSlice needs to extract a properly
/// ranked vector from the destination vector into which to insert. This pattern
/// only takes care of this extraction part and forwards the rest to
/// [ConvertSameRankInsertStridedSliceIntoShuffle].
///
/// For a k-D source and n-D destination vector (k < n), we emit:
///   1. ExtractOp to extract the (unique) (n-1)-D subvector into which to
///      insert the k-D source.
///   2. k-D -> (n-1)-D InsertStridedSlice op
///   3. InsertOp that is the reverse of 1.
class DecomposeDifferentRankInsertStridedSlice
    : public OpRewritePattern<InsertStridedSliceOp> { … };

/// RewritePattern for InsertStridedSliceOp where source and destination vectors
/// have the same rank (k == n below). For each outermost index in the slice:
///     begin      end             stride
///   [offset : offset+size*stride : stride]
///   1. ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector.
///   2. InsertStridedSlice (k-1)-D into (n-1)-D.
///   3. InsertOp, the reverse of 1., inserts the updated (n-1)-D destination
///      subvector back in its proper place.
class ConvertSameRankInsertStridedSliceIntoShuffle
    : public OpRewritePattern<InsertStridedSliceOp> { … };

/// RewritePattern for ExtractStridedSliceOp where source and destination
/// vectors are 1-D. For such cases, we can lower it to a ShuffleOp.
class Convert1DExtractStridedSliceIntoShuffle
    : public OpRewritePattern<ExtractStridedSliceOp> { … };

/// For a 1-D ExtractStridedSlice, breaks it down into a chain of Extract ops
/// to extract each element from the source, and then a chain of Insert ops
/// to insert each extracted element into the result vector.
class Convert1DExtractStridedSliceIntoExtractInsertChain final
    : public OpRewritePattern<ExtractStridedSliceOp> { … };

/// RewritePattern for ExtractStridedSliceOp where the source vector is n-D.
/// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower
/// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case.
class DecomposeNDExtractStridedSlice
    : public OpRewritePattern<ExtractStridedSliceOp> { … };

/// Populate `patterns` with the patterns that decompose insert/extract strided
/// slice ops into lower-rank forms, registered with the given `benefit`.
/// NOTE(review): body not visible here; presumably this registers
/// DecomposeDifferentRankInsertStridedSlice and DecomposeNDExtractStridedSlice
/// — confirm.
void vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) { … }

/// Populate `patterns` with the pattern that rewrites 1-D ExtractStridedSliceOp
/// into a chain of extract/insert ops, registered with the given `benefit`.
/// NOTE(review): body not visible here; presumably `controlFn` selects which
/// ExtractStridedSliceOps the pattern may rewrite — confirm.
void vector::populateVectorExtractStridedSliceToExtractInsertChainPatterns(
    RewritePatternSet &patterns,
    std::function<bool(ExtractStridedSliceOp)> controlFn,
    PatternBenefit benefit) { … }

/// Populate the given list with patterns that lower InsertStridedSliceOp and
/// ExtractStridedSliceOp into simpler vector-dialect ops (ShuffleOp,
/// ExtractOp, InsertOp and friends, per the pattern docs in this file).
/// NOTE(review): the previous comment said these patterns "convert from
/// Vector to LLVM", but the patterns documented here rewrite within the
/// vector dialect; the body is not visible — confirm which patterns it adds.
void vector::populateVectorInsertExtractStridedSliceTransforms(
    RewritePatternSet &patterns, PatternBenefit benefit) { … }