//===- LowerVectorMultiReduction.cpp - Lower `vector.multi_reduction` op --===// // /// Part of the LLVM Project, under the Apache License v2.0 with LLVM /// Exceptions. See https://llvm.org/LICENSE.txt for license information. /// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements target-independent rewrites and utilities to lower the // 'vector.multi_reduction' operation. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Dialect/Vector/Transforms/Passes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { namespace vector { #define GEN_PASS_DEF_LOWERVECTORMULTIREDUCTION #include "mlir/Dialect/Vector/Transforms/Passes.h.inc" } // namespace vector } // namespace mlir #define DEBUG_TYPE … usingnamespacemlir; namespace { /// This file implements the following transformations as composable atomic /// patterns. /// Converts vector.multi_reduction into inner-most/outer-most reduction form /// by using vector.transpose class InnerOuterDimReductionConversion : public OpRewritePattern<vector::MultiDimReductionOp> { … }; /// Reduces the rank of vector.multi_reduction nd -> 2d given all reduction /// dimensions are either inner most or outer most. class ReduceMultiDimReductionRank : public OpRewritePattern<vector::MultiDimReductionOp> { … }; /// Unrolls vector.multi_reduction with outermost reductions /// and combines results struct TwoDimMultiReductionToElementWise : public OpRewritePattern<vector::MultiDimReductionOp> { … }; /// Converts 2d vector.multi_reduction with inner most reduction dimension into /// a sequence of vector.reduction ops. struct TwoDimMultiReductionToReduction : public OpRewritePattern<vector::MultiDimReductionOp> { … }; /// Converts 1d vector.multi_reduction with a single reduction dimension to a 2d /// form with both a single parallel and reduction dimension. /// This is achieved with a simple vector.shape_cast that inserts a leading 1. /// The case with a single parallel dimension is a noop and folds away /// separately. struct OneDimMultiReductionToTwoDim : public OpRewritePattern<vector::MultiDimReductionOp> { … }; struct LowerVectorMultiReductionPass : public vector::impl::LowerVectorMultiReductionBase< LowerVectorMultiReductionPass> { … }; } // namespace void mlir::vector::populateVectorMultiReductionLoweringPatterns( RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit) { … } std::unique_ptr<Pass> vector::createLowerVectorMultiReductionPass( vector::VectorMultiReductionLowering option) { … }