//===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements patterns/pass to remove usage of unit-extent dimensions // to specify broadcasting in favor of a more canonical representation of the // computation. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/MemRef/Transforms/Transforms.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" #include "mlir/Dialect/Tensor/Utils/Utils.h" #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/SetVector.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" namespace mlir { #define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMSPASS #include "mlir/Dialect/Linalg/Passes.h.inc" } // namespace mlir #define DEBUG_TYPE … usingnamespacemlir; usingnamespacemlir::linalg; namespace { /// Pattern to move init operands to ins when all the loops are parallel and /// the block argument corresponding to the init is used in the region. This is a fix-up /// when unit reduction dimensions are all folded away. In this context, it /// becomes an elementwise generic op.
E.g., it converts /// /// %0 = tensor.empty() : tensor<1x1xf32> /// %1 = linalg.fill /// ins(%cst : f32) /// outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32> /// %2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>, /// affine_map<(d0) -> (0, d0)>], /// iterator_types = ["parallel"]} /// ins(%arg0 : tensor<1x?x1x1xf32>) /// outs(%1 : tensor<1x1xf32>) { /// ^bb0(%in: f32, %out: f32): /// %3 = arith.addf %in, %out : f32 /// linalg.yield %3 : f32 /// } -> tensor<1x1xf32> /// /// into /// /// %0 = tensor.empty() : tensor<1x1xf32> /// %1 = linalg.fill /// ins(%cst : f32) /// outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32> /// %2 = tensor.empty() : tensor<1x1xf32> /// %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>, /// affine_map<(d0) -> (0, d0)>, /// affine_map<(d0) -> (0, d0)>], /// iterator_types = ["parallel"]} /// ins(%arg0, %1 : tensor<1x?x1x1xf32>, tensor<1x1xf32>) /// outs(%2 : tensor<1x1xf32>) { /// ^bb0(%in: f32, %in_0: f32, %out: f32): /// %4 = arith.addf %in, %in_0 : f32 /// linalg.yield %4 : f32 /// } -> tensor<1x1xf32> struct MoveInitOperandsToInput : public OpRewritePattern<GenericOp> { … }; } // namespace //===---------------------------------------------------------------------===// // Drop loops with unit extent within Linalg operations. //===---------------------------------------------------------------------===// /// Implements a pass that canonicalizes the uses of unit-extent dimensions for /// broadcasting.
For example, /// /// ```mlir /// #accesses = [ /// affine_map<(d0, d1) -> (0, d1)>, /// affine_map<(d0, d1) -> (d0, 0)>, /// affine_map<(d0, d1) -> (d0, d1)> /// ] /// /// #trait = { /// args_in = 2, /// args_out = 1, /// indexing_maps = #accesses, /// iterator_types = ["parallel", "parallel"], /// library_call = "some_external_fn" /// } /// /// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) -> /// tensor<5x5xf32> /// { /// %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] : /// tensor<5xf32> into tensor<1x5xf32> /// %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] : /// tensor<5xf32> into tensor<5x1xf32> /// %2 = linalg.generic #trait %0, %1 { /// ^bb0(%arg2: f32, %arg3: f32): /// %3 = arith.addf %arg2, %arg3 : f32 /// linalg.yield %3 : f32 /// } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32> /// return %2 : tensor<5x5xf32> /// } /// ``` /// /// would canonicalize to /// /// ```mlir /// #accesses = [ /// affine_map<(d0, d1) -> (d1)>, /// affine_map<(d0, d1) -> (d0)>, /// affine_map<(d0, d1) -> (d0, d1)> /// ] /// /// #trait = { /// args_in = 2, /// args_out = 1, /// indexing_maps = #accesses, /// iterator_types = ["parallel", "parallel"], /// library_call = "some_external_fn" /// } /// /// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) -> /// tensor<5x5xf32> /// { /// %0 = linalg.generic #trait %arg0, %arg1 { /// ^bb0(%arg2: f32, %arg3: f32): /// %3 = arith.addf %arg2, %arg3 : f32 /// linalg.yield %3 : f32 /// } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32> /// return %0 : tensor<5x5xf32> /// } /// ``` /// Update the index accesses of linalg operations having index semantics. static void replaceUnitDimIndexOps(GenericOp genericOp, const llvm::SmallDenseSet<unsigned> &unitDims, RewriterBase &rewriter) { … } /// Expand the given `result` so that its type matches the type of `origDest`.
/// The `reassociation` is used when `rankReductionStrategy` is set to /// `RankReductionStrategy::ReassociativeReshape`. static Value expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest, ArrayRef<ReassociationIndices> reassociation, ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) { … } /// Collapse the given `operand` so that its shape matches /// `targetShape`. The `reassociation` is used when `rankReductionStrategy` is /// set to `RankReductionStrategy::ReassociativeReshape`. static Value collapseValue( RewriterBase &rewriter, Location loc, Value operand, ArrayRef<int64_t> targetShape, ArrayRef<ReassociationIndices> reassociation, ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) { … } /// Compute the modified metadata for an operand of an operation /// whose unit dims are being dropped. Return the new indexing map /// to use, the shape of the operand in the replacement op /// and the `reassociation` to use to go from the original operand shape /// to the modified operand shape. struct UnitExtentReplacementInfo { … }; static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata( MLIRContext *context, GenericOp genericOp, OpOperand *opOperand, llvm::SmallDenseMap<unsigned, unsigned> &oldDimsToNewDimsMap, ArrayRef<AffineExpr> dimReplacements) { … } FailureOr<DropUnitDimsResult> linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, const ControlDropUnitDims &options) { … } namespace { struct DropUnitDims : public OpRewritePattern<GenericOp> { … }; } // namespace //===---------------------------------------------------------------------===// // Drop dimensions with unit extent within tensor operations. //===---------------------------------------------------------------------===// namespace { struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> { … }; } // namespace namespace { /// Convert `extract_slice` operations to rank-reduced versions.
struct RankReducedExtractSliceOp : public OpRewritePattern<tensor::ExtractSliceOp> { … }; /// Convert `insert_slice` operations to rank-reduced versions. /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp. template <typename InsertOpTy> struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> { … }; } // namespace /// Patterns that are used to canonicalize the use of unit-extent dims for /// broadcasting. static void populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options) { … } static void populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options) { … } void mlir::linalg::populateFoldUnitExtentDimsPatterns( RewritePatternSet &patterns, linalg::ControlDropUnitDims &options) { … } void mlir::linalg::populateMoveInitOperandsToInputPattern( RewritePatternSet &patterns) { … } namespace { /// Pass that removes unit-extent dims within generic ops. struct LinalgFoldUnitExtentDimsPass : public impl::LinalgFoldUnitExtentDimsPassBase< LinalgFoldUnitExtentDimsPass> { … }; } // namespace namespace { /// Returns reassociation indices for collapsing/expanding a /// tensor of rank `rank` at position `pos`. static SmallVector<ReassociationIndices> getReassociationForReshapeAtDim(int64_t rank, int64_t pos) { … } /// Returns a collapsed `val` where the collapsing occurs at dim `pos`. /// If `pos < 0`, then don't collapse. static Value collapseSingletonDimAt(PatternRewriter &rewriter, Value val, int64_t pos) { … } /// Base class for all rank reduction patterns for contraction ops /// with unit dimensions. All patterns should convert one named op /// to another named op. Intended to reduce only one iteration space dim /// at a time. /// Reducing multiple dims will happen with recursive application of /// pattern rewrites.
template <typename FromOpTy, typename ToOpTy> struct RankReduceContractionOps : OpRewritePattern<FromOpTy> { … }; /// Patterns for unbatching batched contraction ops. template <typename FromOpTy, typename ToOpTy> struct RankReduceToUnBatched : RankReduceContractionOps<FromOpTy, ToOpTy> { … }; /// Patterns for reducing non-batch dimensions. template <typename FromOpTy, typename ToOpTy> struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> { … }; } // namespace void mlir::linalg::populateContractionOpRankReducingPatterns( RewritePatternSet &patterns) { … }