//===- LoopCanonicalization.cpp - Cross-dialect canonicalization patterns -===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file contains cross-dialect canonicalization patterns that cannot be // actual canonicalization patterns due to undesired additional dependencies. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/SCF/Transforms/Passes.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/Patterns.h" #include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/TypeSwitch.h" namespace mlir { #define GEN_PASS_DEF_SCFFORLOOPCANONICALIZATION #include "mlir/Dialect/SCF/Transforms/Passes.h.inc" } // namespace mlir usingnamespacemlir; usingnamespacemlir::scf; /// A simple, conservative analysis to determine if the loop is shape /// conserving. I.e., the type of the arg-th yielded value is the same as the /// type of the corresponding basic block argument of the loop. /// Note: This function handles only simple cases. Expand as needed. static bool isShapePreserving(ForOp forOp, int64_t arg) { … } namespace { /// Fold dim ops of iter_args to dim ops of their respective init args. E.g.: /// /// ``` /// %0 = ... : tensor<?x?xf32> /// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) { /// %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32> /// ... /// } /// ``` /// /// is folded to: /// /// ``` /// %0 = ... : tensor<?x?xf32> /// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) { /// %1 = tensor.dim %0, %c0 : tensor<?x?xf32> /// ... /// } /// ``` /// /// Note: Dim ops are folded only if it can be proven that the runtime type of /// the iter arg does not change with loop iterations. template <typename OpTy> struct DimOfIterArgFolder : public OpRewritePattern<OpTy> { … }; /// Fold dim ops of loop results to dim ops of their respective init args. E.g.: /// /// ``` /// %0 = ... : tensor<?x?xf32> /// %r = scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) { /// ... /// } /// %1 = tensor.dim %r, %c0 : tensor<?x?xf32> /// ``` /// /// is folded to: /// /// ``` /// %0 = ... : tensor<?x?xf32> /// %r = scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) { /// ... /// } /// %1 = tensor.dim %0, %c0 : tensor<?x?xf32> /// ``` /// /// Note: Dim ops are folded only if it can be proven that the runtime type of /// the iter arg does not change with loop iterations. template <typename OpTy> struct DimOfLoopResultFolder : public OpRewritePattern<OpTy> { … }; /// Canonicalize AffineMinOp/AffineMaxOp operations in the context of scf.for /// and scf.parallel loops with a known range. template <typename OpTy> struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern<OpTy> { … }; struct SCFForLoopCanonicalization : public impl::SCFForLoopCanonicalizationBase<SCFForLoopCanonicalization> { … }; } // namespace void mlir::scf::populateSCFForLoopCanonicalizationPatterns( RewritePatternSet &patterns) { … } std::unique_ptr<Pass> mlir::createSCFForLoopCanonicalizationPass() { … }