//===- SCF.cpp - Structured Control Flow Operations -----------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/ValueBoundsOpInterface.h" #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/TypeSwitch.h" usingnamespacemlir; usingnamespacemlir::scf; #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc" //===----------------------------------------------------------------------===// // SCFDialect Dialect Interfaces //===----------------------------------------------------------------------===// namespace { struct SCFInlinerInterface : public DialectInlinerInterface { … }; } // namespace //===----------------------------------------------------------------------===// // SCFDialect //===----------------------------------------------------------------------===// void SCFDialect::initialize() { … } /// Default callback for IfOp builders. Inserts a yield without arguments. 
void mlir::scf::buildTerminatedBody(OpBuilder &builder, Location loc) { … } /// Verifies that the first block of the given `region` is terminated by a /// TerminatorTy. Reports errors on the given operation if it is not the case. template <typename TerminatorTy> static TerminatorTy verifyAndGetTerminator(Operation *op, Region ®ion, StringRef errorMessage) { … } //===----------------------------------------------------------------------===// // ExecuteRegionOp //===----------------------------------------------------------------------===// /// Replaces the given op with the contents of the given single-block region, /// using the operands of the block terminator to replace operation results. static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region ®ion, ValueRange blockArgs = { … } /// /// (ssa-id `=`)? `execute_region` `->` function-result-type `{` /// block+ /// `}` /// /// Example: /// scf.execute_region -> i32 { /// %idx = load %rI[%i] : memref<128xi32> /// return %idx : i32 /// } /// ParseResult ExecuteRegionOp::parse(OpAsmParser &parser, OperationState &result) { … } void ExecuteRegionOp::print(OpAsmPrinter &p) { … } LogicalResult ExecuteRegionOp::verify() { … } // Inline an ExecuteRegionOp if it only contains one block. // "test.foo"() : () -> () // %v = scf.execute_region -> i64 { // %x = "test.val"() : () -> i64 // scf.yield %x : i64 // } // "test.bar"(%v) : (i64) -> () // // becomes // // "test.foo"() : () -> () // %x = "test.val"() : () -> i64 // "test.bar"(%x) : (i64) -> () // struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> { … }; // Inline an ExecuteRegionOp if its parent can contain multiple blocks. // TODO generalize the conditions for operations which can be inlined into. 
// func @func_execute_region_elim() { // "test.foo"() : () -> () // %v = scf.execute_region -> i64 { // %c = "test.cmp"() : () -> i1 // cf.cond_br %c, ^bb2, ^bb3 // ^bb2: // %x = "test.val1"() : () -> i64 // cf.br ^bb4(%x : i64) // ^bb3: // %y = "test.val2"() : () -> i64 // cf.br ^bb4(%y : i64) // ^bb4(%z : i64): // scf.yield %z : i64 // } // "test.bar"(%v) : (i64) -> () // return // } // // becomes // // func @func_execute_region_elim() { // "test.foo"() : () -> () // %c = "test.cmp"() : () -> i1 // cf.cond_br %c, ^bb1, ^bb2 // ^bb1: // pred: ^bb0 // %x = "test.val1"() : () -> i64 // cf.br ^bb3(%x : i64) // ^bb2: // pred: ^bb0 // %y = "test.val2"() : () -> i64 // cf.br ^bb3(%y : i64) // ^bb3(%z: i64): // 2 preds: ^bb1, ^bb2 // "test.bar"(%z) : (i64) -> () // return // } // struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> { … }; void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { … } void ExecuteRegionOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { … } //===----------------------------------------------------------------------===// // ConditionOp //===----------------------------------------------------------------------===// MutableOperandRange ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) { … } void ConditionOp::getSuccessorRegions( ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> ®ions) { … } //===----------------------------------------------------------------------===// // ForOp //===----------------------------------------------------------------------===// void ForOp::build(OpBuilder &builder, OperationState &result, Value lb, Value ub, Value step, ValueRange initArgs, BodyBuilderFn bodyBuilder) { … } LogicalResult ForOp::verify() { … } LogicalResult ForOp::verifyRegions() { … } std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() { … } std::optional<SmallVector<OpFoldResult>> 
ForOp::getLoopLowerBounds() { … } std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() { … } std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() { … } std::optional<ResultRange> ForOp::getLoopResults() { … } /// Promotes the loop body of a forOp to its containing block if it can be /// determined that the loop has a single iteration. LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) { … } /// Prints the initialization list in the form of /// <prefix>(%inner = %outer, %inner2 = %outer2, <...>) /// where 'inner' values are assumed to be region arguments and 'outer' values /// are regular SSA values. static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix = "") { … } void ForOp::print(OpAsmPrinter &p) { … } ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) { … } SmallVector<Region *> ForOp::getLoopRegions() { … } Block::BlockArgListType ForOp::getRegionIterArgs() { … } MutableArrayRef<OpOperand> ForOp::getInitsMutable() { … } FailureOr<LoopLikeOpInterface> ForOp::replaceWithAdditionalYields(RewriterBase &rewriter, ValueRange newInitOperands, bool replaceInitOperandUsesInLoop, const NewYieldValuesFn &newYieldValuesFn) { … } ForOp mlir::scf::getForInductionVarOwner(Value val) { … } OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) { … } void ForOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { … } SmallVector<Region *> ForallOp::getLoopRegions() { … } /// Promotes the loop body of a forallOp to its containing block if it can be /// determined that the loop has a single iteration. LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) { … } Block::BlockArgListType ForallOp::getRegionIterArgs() { … } MutableArrayRef<OpOperand> ForallOp::getInitsMutable() { … } /// Promotes the loop body of a scf::ForallOp to its containing block.
void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) { … } LoopNest mlir::scf::buildLoopNest( OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref<ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder) { … } LoopNest mlir::scf::buildLoopNest( OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) { … } SmallVector<Value> mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, OpOperand &operand, Value replacement, const ValueTypeCastFnTy &castFn) { … } namespace { // Fold away ForOp iter arguments when: // 1) The op yields the iter arguments. // 2) The iter arguments have no use and the corresponding outer region // iterators (inputs) are yielded. // 3) The iter arguments have no use and the corresponding (operation) results // have no use. // // These arguments must be defined outside of // the ForOp region and can just be forwarded after simplifying the op inits, // yields and returns. // // The implementation uses `inlineBlockBefore` to steal the content of the // original ForOp and avoid cloning. struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> { … }; /// Util function that tries to compute a constant diff between u and l. /// Returns std::nullopt when the difference between two AffineValueMap is /// dynamic. static std::optional<int64_t> computeConstDiff(Value l, Value u) { … } /// Rewriting pattern that erases loops that are known not to iterate, replaces /// single-iteration loops with their bodies, and removes empty loops that /// iterate at least once and only return values defined outside of the loop. 
struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> { … }; /// Fold scf.for iter_arg/result pairs that go through an incoming/outgoing /// tensor.cast op pair so as to pull the tensor.cast inside the scf.for: /// /// ``` /// %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32> /// %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0) /// -> (tensor<?x?xf32>) { /// %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32> /// scf.yield %2 : tensor<?x?xf32> /// } /// use_of(%1) /// ``` /// /// folds into: /// /// ``` /// %0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %t0) /// -> (tensor<32x1024xf32>) { /// %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32> /// %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32> /// %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32> /// scf.yield %4 : tensor<32x1024xf32> /// } /// %1 = tensor.cast %0 : tensor<32x1024xf32> to tensor<?x?xf32> /// use_of(%1) /// ``` struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> { … }; } // namespace void ForOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { … } std::optional<APInt> ForOp::getConstantStep() { … } std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() { … } Speculation::Speculatability ForOp::getSpeculatability() { … } //===----------------------------------------------------------------------===// // ForallOp //===----------------------------------------------------------------------===// LogicalResult ForallOp::verify() { … } void ForallOp::print(OpAsmPrinter &p) { … } ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) { … } // Builder that takes loop bounds.
void ForallOp::build( mlir::OpBuilder &b, mlir::OperationState &result, ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs, ArrayRef<OpFoldResult> steps, ValueRange outputs, std::optional<ArrayAttr> mapping, function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) { … } // Builder that takes only upper bounds (no explicit lower bounds or steps). void ForallOp::build( mlir::OpBuilder &b, mlir::OperationState &result, ArrayRef<OpFoldResult> ubs, ValueRange outputs, std::optional<ArrayAttr> mapping, function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) { … } // Checks if the lbs are zeros and steps are ones. bool ForallOp::isNormalized() { … } // The ensureTerminator method generated by SingleBlockImplicitTerminator is // unaware of the fact that our terminator also needs a region to be // well-formed. We override it here to ensure that we do the right thing. void ForallOp::ensureTerminator(Region ®ion, OpBuilder &builder, Location loc) { … } InParallelOp ForallOp::getTerminator() { … } SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) { … } std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() { … } // Get lower bounds as OpFoldResult. std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() { … } // Get upper bounds as OpFoldResult. std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() { … } // Get steps as OpFoldResult. std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() { … } ForallOp mlir::scf::getForallOpThreadIndexOwner(Value val) { … } namespace { /// Fold tensor.dim(forall shared_outs(... = %t)) to tensor.dim(%t). struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> { … }; class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> { … }; /// The following canonicalization pattern folds the iter arguments of /// scf.forall op if :- /// 1. The corresponding result has zero uses. /// 2. The iter argument is NOT being modified within the loop body.
/// /// Example of first case :- /// INPUT: /// %res:3 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b, %arg2 = %c) /// { /// ... /// <SOME USE OF %arg0> /// <SOME USE OF %arg1> /// <SOME USE OF %arg2> /// ... /// scf.forall.in_parallel { /// <STORE OP WITH DESTINATION %arg1> /// <STORE OP WITH DESTINATION %arg0> /// <STORE OP WITH DESTINATION %arg2> /// } /// } /// return %res#1 /// /// OUTPUT: /// %res:3 = scf.forall ... shared_outs(%new_arg0 = %b) /// { /// ... /// <SOME USE OF %a> /// <SOME USE OF %new_arg0> /// <SOME USE OF %c> /// ... /// scf.forall.in_parallel { /// <STORE OP WITH DESTINATION %new_arg0> /// } /// } /// return %res /// /// NOTE: 1. All uses of the folded shared_outs (iter argument) within the /// scf.forall is replaced by their corresponding operands. /// 2. Even if there are <STORE OP WITH DESTINATION *> ops within the body /// of the scf.forall besides within scf.forall.in_parallel terminator, /// this canonicalization remains valid. For more details, please refer /// to : /// https://github.com/llvm/llvm-project/pull/90189#discussion_r1589011124 /// 3. TODO(avarma): Generalize it for other store ops. Currently it /// handles tensor.parallel_insert_slice ops only. /// /// Example of second case :- /// INPUT: /// %res:2 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b) /// { /// ... /// <SOME USE OF %arg0> /// <SOME USE OF %arg1> /// ... /// scf.forall.in_parallel { /// <STORE OP WITH DESTINATION %arg1> /// } /// } /// return %res#0, %res#1 /// /// OUTPUT: /// %res = scf.forall ... shared_outs(%new_arg0 = %b) /// { /// ... /// <SOME USE OF %a> /// <SOME USE OF %new_arg0> /// ... 
/// scf.forall.in_parallel { /// <STORE OP WITH DESTINATION %new_arg0> /// } /// } /// return %a, %res struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> { … }; struct ForallOpSingleOrZeroIterationDimsFolder : public OpRewritePattern<ForallOp> { … }; struct FoldTensorCastOfOutputIntoForallOp : public OpRewritePattern<scf::ForallOp> { … }; } // namespace void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { … } /// Given the branch point `point` (either the parent operation or one of its /// regions), report in `regions` the successor regions. These are the regions /// that may be selected during the flow of control. void ForallOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { … } //===----------------------------------------------------------------------===// // InParallelOp //===----------------------------------------------------------------------===// // Build an InParallelOp; this overload takes no arguments beyond the builder // and the operation state.
void InParallelOp::build(OpBuilder &b, OperationState &result) { … } LogicalResult InParallelOp::verify() { … } void InParallelOp::print(OpAsmPrinter &p) { … } ParseResult InParallelOp::parse(OpAsmParser &parser, OperationState &result) { … } OpResult InParallelOp::getParentResult(int64_t idx) { … } SmallVector<BlockArgument> InParallelOp::getDests() { … } llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() { … } //===----------------------------------------------------------------------===// // IfOp //===----------------------------------------------------------------------===// bool mlir::scf::insideMutuallyExclusiveBranches(Operation *a, Operation *b) { … } LogicalResult IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc, IfOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) { … } void IfOp::build(OpBuilder &builder, OperationState &result, TypeRange resultTypes, Value cond) { … } void IfOp::build(OpBuilder &builder, OperationState &result, TypeRange resultTypes, Value cond, bool addThenBlock, bool addElseBlock) { … } void IfOp::build(OpBuilder &builder, OperationState &result, Value cond, bool withElseRegion) { … } void IfOp::build(OpBuilder &builder, OperationState &result, TypeRange resultTypes, Value cond, bool withElseRegion) { … } void IfOp::build(OpBuilder &builder, OperationState &result, Value cond, function_ref<void(OpBuilder &, Location)> thenBuilder, function_ref<void(OpBuilder &, Location)> elseBuilder) { … } LogicalResult IfOp::verify() { … } ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) { … } void IfOp::print(OpAsmPrinter &p) { … } void IfOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { … } void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> ®ions) { … } LogicalResult IfOp::fold(FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult> &results) { … } void IfOp::getRegionInvocationBounds( 
ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &invocationBounds) { … } namespace { // Pattern to remove unused IfOp results. struct RemoveUnusedResults : public OpRewritePattern<IfOp> { … }; struct RemoveStaticCondition : public OpRewritePattern<IfOp> { … }; /// Hoist any yielded results whose operands are defined outside /// the if, to a select instruction. struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> { … }; /// Allow the true region of an if to assume the condition is true /// and vice versa. For example: /// /// scf.if %cmp { /// print(%cmp) /// } /// /// becomes /// /// scf.if %cmp { /// print(true) /// } /// struct ConditionPropagation : public OpRewritePattern<IfOp> { … }; /// Remove any statements from an if that are equivalent to the condition /// or its negation. For example: /// /// %res:2 = scf.if %cmp { /// yield something(), true /// } else { /// yield something2(), false /// } /// print(%res#1) /// /// becomes /// %res = scf.if %cmp { /// yield something() /// } else { /// yield something2() /// } /// print(%cmp) /// /// Additionally if both branches yield the same value, replace all uses /// of the result with the yielded value. /// /// %res:2 = scf.if %cmp { /// yield something(), %arg1 /// } else { /// yield something2(), %arg1 /// } /// print(%res#1) /// /// becomes /// %res = scf.if %cmp { /// yield something() /// } else { /// yield something2() /// } /// print(%arg1) /// struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> { … }; /// Merge any consecutive scf.if's with the same condition. /// /// scf.if %cond { /// firstCodeTrue();... /// } else { /// firstCodeFalse();... /// } /// %res = scf.if %cond { /// secondCodeTrue();... /// } else { /// secondCodeFalse();... /// } /// /// becomes /// %res = scf.if %cond { /// firstCodeTrue();... /// secondCodeTrue();... /// } else { /// firstCodeFalse();... /// secondCodeFalse();...
/// } struct CombineIfs : public OpRewritePattern<IfOp> { … }; /// Pattern to remove an empty else branch. struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> { … }; /// Convert nested `if`s into `arith.andi` + single `if`. /// /// scf.if %arg0 { /// scf.if %arg1 { /// ... /// scf.yield /// } /// scf.yield /// } /// becomes /// /// %0 = arith.andi %arg0, %arg1 /// scf.if %0 { /// ... /// scf.yield /// } struct CombineNestedIfs : public OpRewritePattern<IfOp> { … }; } // namespace void IfOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { … } Block *IfOp::thenBlock() { … } YieldOp IfOp::thenYield() { … } Block *IfOp::elseBlock() { … } YieldOp IfOp::elseYield() { … } //===----------------------------------------------------------------------===// // ParallelOp //===----------------------------------------------------------------------===// void ParallelOp::build( OpBuilder &builder, OperationState &result, ValueRange lowerBounds, ValueRange upperBounds, ValueRange steps, ValueRange initVals, function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilderFn) { … } void ParallelOp::build( OpBuilder &builder, OperationState &result, ValueRange lowerBounds, ValueRange upperBounds, ValueRange steps, function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) { … } LogicalResult ParallelOp::verify() { … } ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) { … } void ParallelOp::print(OpAsmPrinter &p) { … } SmallVector<Region *> ParallelOp::getLoopRegions() { … } std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() { … } std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() { … } std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() { … } std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() { … } ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) { … } namespace { // Collapse loop dimensions that perform 
a single iteration. struct ParallelOpSingleOrZeroIterationDimsFolder : public OpRewritePattern<ParallelOp> { … }; struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> { … }; } // namespace void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { … } /// Given the branch point `point` (either the parent operation or one of its /// regions), report in `regions` the successor regions. These are the regions /// that may be selected during the flow of control. void ParallelOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { … } //===----------------------------------------------------------------------===// // ReduceOp //===----------------------------------------------------------------------===// void ReduceOp::build(OpBuilder &builder, OperationState &result) { … } void ReduceOp::build(OpBuilder &builder, OperationState &result, ValueRange operands) { … } LogicalResult ReduceOp::verifyRegions() { … } MutableOperandRange ReduceOp::getMutableSuccessorOperands(RegionBranchPoint point) { … } //===----------------------------------------------------------------------===// // ReduceReturnOp //===----------------------------------------------------------------------===// LogicalResult ReduceReturnOp::verify() { … } //===----------------------------------------------------------------------===// // WhileOp //===----------------------------------------------------------------------===// void WhileOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, TypeRange resultTypes, ValueRange inits, BodyBuilderFn beforeBuilder, BodyBuilderFn afterBuilder) { … } ConditionOp WhileOp::getConditionOp() { … } YieldOp WhileOp::getYieldOp() { … } std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() { … } Block::BlockArgListType
WhileOp::getBeforeArguments() { … } Block::BlockArgListType WhileOp::getAfterArguments() { … } Block::BlockArgListType WhileOp::getRegionIterArgs() { … } OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) { … } void WhileOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { … } SmallVector<Region *> WhileOp::getLoopRegions() { … } /// Parses a `while` op. /// /// op ::= `scf.while` assignments `:` function-type region `do` region /// `attributes` attribute-dict /// initializer ::= /* empty */ | `(` assignment-list `)` /// assignment-list ::= assignment | assignment `,` assignment-list /// assignment ::= ssa-value `=` ssa-value ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) { … } /// Prints a `while` op. void scf::WhileOp::print(OpAsmPrinter &p) { … } /// Verifies that two ranges of types match, i.e. have the same number of /// entries and that types are pairwise equals. Reports errors on the given /// operation in case of mismatch. template <typename OpTy> static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message) { … } LogicalResult scf::WhileOp::verify() { … } namespace { /// Replace uses of the condition within the do block with true, since otherwise /// the block would not be evaluated. /// /// scf.while (..) : (i1, ...) -> ... { /// %condition = call @evaluate_condition() : () -> i1 /// scf.condition(%condition) %condition : i1, ... /// } do { /// ^bb0(%arg0: i1, ...): /// use(%arg0) /// ... /// /// becomes /// scf.while (..) : (i1, ...) -> ... { /// %condition = call @evaluate_condition() : () -> i1 /// scf.condition(%condition) %condition : i1, ... /// } do { /// ^bb0(%arg0: i1, ...): /// use(%true) /// ... struct WhileConditionTruth : public OpRewritePattern<WhileOp> { … }; /// Remove loop invariant arguments from `before` block of scf.while. /// A before block argument is considered loop invariant if :- /// 1. 
i-th yield operand is equal to the i-th while operand. /// 2. i-th yield operand is k-th after block argument which is (k+1)-th /// condition operand AND this (k+1)-th condition operand is equal to i-th /// iter argument/while operand. /// For the arguments which are removed, their uses inside scf.while /// are replaced with their corresponding initial value. /// /// Eg: /// INPUT :- /// %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b, /// ..., %argN_before = %N) /// { /// ... /// scf.condition(%cond) %arg1_before, %arg0_before, /// %arg2_before, %arg0_before, ... /// } do { /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2, /// ..., %argK_after): /// ... /// scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN /// } /// /// OUTPUT :- /// %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before = /// %N) /// { /// ... /// scf.condition(%cond) %b, %a, %arg2_before, %a, ... /// } do { /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2, /// ..., %argK_after): /// ... /// scf.yield %arg1_after, ..., %argN /// } /// /// EXPLANATION: /// We iterate over each yield operand. /// 1. 0-th yield operand %arg0_after_2 is 4-th condition operand /// %arg0_before, which in turn is the 0-th iter argument. So we /// remove 0-th before block argument and yield operand, and replace /// all uses of the 0-th before block argument with its initial value /// %a. /// 2. 1-th yield operand %b is equal to the 1-th iter arg's initial /// value. So we remove this operand and the corresponding before /// block argument and replace all uses of 1-th before block argument /// with %b. struct RemoveLoopInvariantArgsFromBeforeBlock : public OpRewritePattern<WhileOp> { … }; /// Remove loop invariant value from result (condition op) of scf.while. /// A value is considered loop invariant if the final value yielded by /// scf.condition is defined outside of the `before` block. 
We remove the /// corresponding argument in `after` block and replace the use with the value. /// We also replace the use of the corresponding result of scf.while with the /// value. /// /// Eg: /// INPUT :- /// %res_input:K = scf.while <...> iter_args(%arg0_before = , ..., /// %argN_before = %N) { /// ... /// scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ... /// } do { /// ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after): /// ... /// some_func(%arg1_after) /// ... /// scf.yield %arg0_after, %arg2_after, ..., %argN_after /// } /// /// OUTPUT :- /// %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) { /// ... /// scf.condition(%cond) %arg0, %arg1, ..., %argM /// } do { /// ^bb0(%arg0, %arg3, ..., %argM): /// ... /// some_func(%a) /// ... /// scf.yield %arg0, %b, ..., %argN /// } /// /// EXPLANATION: /// 1. The 1-th and 2-th operand of scf.condition are defined outside the /// before block of scf.while, so they get removed. /// 2. %res_input#1's uses are replaced by %a and %res_input#2's uses are /// replaced by %b. /// 3. The corresponding after block argument %arg1_after's uses are /// replaced by %a and %arg2_after's uses are replaced by %b. struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> { … }; /// Remove WhileOp results that are also unused in 'after' block. 
/// /// %0:2 = scf.while () : () -> (i32, i64) { /// %condition = "test.condition"() : () -> i1 /// %v1 = "test.get_some_value"() : () -> i32 /// %v2 = "test.get_some_value"() : () -> i64 /// scf.condition(%condition) %v1, %v2 : i32, i64 /// } do { /// ^bb0(%arg0: i32, %arg1: i64): /// "test.use"(%arg0) : (i32) -> () /// scf.yield /// } /// return %0#0 : i32 /// /// becomes /// %0 = scf.while () : () -> (i32) { /// %condition = "test.condition"() : () -> i1 /// %v1 = "test.get_some_value"() : () -> i32 /// %v2 = "test.get_some_value"() : () -> i64 /// scf.condition(%condition) %v1 : i32 /// } do { /// ^bb0(%arg0: i32): /// "test.use"(%arg0) : (i32) -> () /// scf.yield /// } /// return %0 : i32 struct WhileUnusedResult : public OpRewritePattern<WhileOp> { … }; /// Replace operations equivalent to the condition in the do block with true, /// since otherwise the block would not be evaluated. /// /// scf.while (..) : (i32, ...) -> ... { /// %z = ... : i32 /// %condition = cmpi pred %z, %a /// scf.condition(%condition) %z : i32, ... /// } do { /// ^bb0(%arg0: i32, ...): /// %condition2 = cmpi pred %arg0, %a /// use(%condition2) /// ... /// /// becomes /// scf.while (..) : (i32, ...) -> ... { /// %z = ... : i32 /// %condition = cmpi pred %z, %a /// scf.condition(%condition) %z : i32, ... /// } do { /// ^bb0(%arg0: i32, ...): /// use(%true) /// ... struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> { … }; /// Remove unused init/yield args. struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> { … }; /// Remove duplicated ConditionOp args. struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> { … }; /// If both ranges contain the same values, return the mapping of indices from /// args2 to args1. Otherwise return std::nullopt.
static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1, ValueRange args2) { … } static bool hasDuplicates(ValueRange args) { … } /// If `before` block args are directly forwarded to `scf.condition`, rearrange /// `scf.condition` args into same order as block args. Update `after` block /// args and op result values accordingly. /// Needed to simplify `scf.while` -> `scf.for` uplifting. struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> { … }; } // namespace void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { … } //===----------------------------------------------------------------------===// // IndexSwitchOp //===----------------------------------------------------------------------===// /// Parse the case regions and values. static ParseResult parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases, SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) { … } /// Print the case regions and values. static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions) { … } LogicalResult scf::IndexSwitchOp::verify() { … } unsigned scf::IndexSwitchOp::getNumCases() { … } Block &scf::IndexSwitchOp::getDefaultBlock() { … } Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) { … } void IndexSwitchOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) { … } void IndexSwitchOp::getEntrySuccessorRegions( ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &successors) { … } void IndexSwitchOp::getRegionInvocationBounds( ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) { … } struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> { … }; void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { … } //===----------------------------------------------------------------------===// // TableGen'd op method definitions 
//===----------------------------------------------------------------------===// #define GET_OP_CLASSES #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"