//===- LowerVectorMask.cpp - Lower 'vector.mask' operation ----------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements target-independent rewrites and utilities to lower the // 'vector.mask' operation. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Dialect/Vector/Transforms/Passes.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #define DEBUG_TYPE … namespace mlir { namespace vector { #define GEN_PASS_DEF_LOWERVECTORMASKPASS #include "mlir/Dialect/Vector/Transforms/Passes.h.inc" } // namespace vector } // namespace mlir usingnamespacemlir; usingnamespacemlir::vector; //===----------------------------------------------------------------------===// // populateVectorMaskOpLoweringPatterns //===----------------------------------------------------------------------===// namespace { /// Progressive lowering of CreateMaskOp. /// One: /// %x = vector.create_mask %a, ... : vector<dx...> /// is replaced by: /// %l = vector.create_mask ... : vector<...> ; one lower rank /// %0 = arith.cmpi "slt", %ci, %a | /// %1 = select %0, %l, %zeroes | /// %r = vector.insert %1, %pr [i] | d-times /// %x = .... /// until a one-dimensional vector is reached. class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> { … }; /// Progressive lowering of ConstantMaskOp. /// One: /// %x = vector.constant_mask [a,b] /// is replaced by: /// %z = zero-result /// %l = vector.constant_mask [b] /// %4 = vector.insert %l, %z[0] /// .. 
/// %x = vector.insert %l, %..[a-1] /// until a one-dimensional vector is reached. All these operations /// will be folded at LLVM IR level. class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> { … }; } // namespace void mlir::vector::populateVectorMaskOpLoweringPatterns( RewritePatternSet &patterns, PatternBenefit benefit) { … } //===----------------------------------------------------------------------===// // populateVectorMaskLoweringPatternsForSideEffectingOps //===----------------------------------------------------------------------===// namespace { /// The `MaskOpRewritePattern` implements a pattern that follows a two-fold /// matching: /// 1. It matches a `vector.mask` operation. /// 2. It invokes `matchAndRewriteMaskableOp` on `MaskableOpInterface` nested /// in the matched `vector.mask` operation. /// /// It is required that the replacement op in the pattern replaces the /// `vector.mask` operation and not the nested `MaskableOpInterface`. This /// approach allows having patterns that "stop" at every `vector.mask` operation /// and actually match the traits of its nested `MaskableOpInterface`. template <class SourceOp> struct MaskOpRewritePattern : OpRewritePattern<MaskOp> { … }; /// Lowers a masked `vector.transfer_read` operation. struct MaskedTransferReadOpPattern : public MaskOpRewritePattern<TransferReadOp> { … }; /// Lowers a masked `vector.transfer_write` operation. struct MaskedTransferWriteOpPattern : public MaskOpRewritePattern<TransferWriteOp> { … }; /// Lowers a masked `vector.gather` operation. struct MaskedGatherOpPattern : public MaskOpRewritePattern<GatherOp> { … }; struct LowerVectorMaskPass : public vector::impl::LowerVectorMaskPassBase<LowerVectorMaskPass> { … }; } // namespace /// Populates instances of `MaskOpRewritePattern` to lower masked operations /// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and /// not its nested `MaskableOpInterface`. 
void vector::populateVectorMaskLoweringPatternsForSideEffectingOps( RewritePatternSet &patterns) { … } std::unique_ptr<Pass> mlir::vector::createLowerVectorMaskPass() { … }