//===- Fusion.cpp - Implementation of linalg Fusion -----------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements the linalg dialect Fusion pass. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Dominance.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include <optional> #include <set> #define DEBUG_TYPE … usingnamespacemlir; usingnamespacemlir::linalg; /// Implements a simple high-level fusion pass on linalg structured operations. /// /// In each block, linalg ops are processed in reverse textual order. /// Given a linalg op `O`, fusion occurs by: /// 1. inspecting the linalg ops that write into the views read by `O`. There /// are 2 cases: /// a) buffer case: use the SSA value of the views and a simple alias /// analysis on subview ops to determine producer-consumer dependences; /// b) tensor case: use SSA use-def chains on extract_slice ops; /// 2. greedily fuse the linalg ops that produce the subview/extract_slice. /// 3. inspect the fused ops and determine whether they have other remaining /// LinalgOp uses. If not, then erase the original producing linalg op. /// /// More advanced use cases, analyses as well as profitability heuristics are /// left for future work. struct ShapeDimension { … }; // Given an `op`, returns the first (`shape`, `dimension`) pair that identifies // the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps // guarantees at least one such dimension is found. If multiple candidates exist // they must agree by construction (i.e. have the same size) and we just return // the first one. static ShapeDimension getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth, bool fromSubViewOpOnly = false) { … } static SmallVector<Value> getTiledOperands(LinalgOp producer) { … } /// Fuses the producer by cloning the `producer`. The `fusedLoopsAndRanges` /// provides the loop range information for the fused loops. The rest are /// obtained from the producer itself, since they are not tiled + fused. static LinalgOp fuse(OpBuilder &b, LinalgOp producer, const DenseMap<unsigned, Range> &fusedLoopsAndRanges) { … } /// Get the loop range for a dimension `dim` based on the `shapedOperand`. It is /// expected to be defined by a subview op or an extract_slice op. static Range getRangeFromOperandShape(OpBuilder &b, Location loc, Value shapedOperand, unsigned dim) { … } /// Fuses the producer into the loop immediately enclosing the consumer. /// This is achieved by "recomputing" the producer at the time it /// is needed just before the consumer. static LinalgOp fuse(OpBuilder &b, LinalgOp producerOp, AffineMap producerMap, OpOperand &consumerOpOperand) { … } /// Walk back use-def chain through scf::For yields. /// Sets `producer` and `outputIndex` if it finds a producer LinalgOp // TODO(ravishankarm, ntv): This can be moved into the dependence graphs // dependence tracking since the dependence tracking is similar to what is done // w.r.t to buffers. static void getProducerOfTensor(Value tensor, OpResult &opResult) { … } FailureOr<FusionInfo> mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) { … } FailureOr<FusionInfo> mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult, OpOperand &consumerOpOperand) { … }