//===- WrapInZeroTripCheck.cpp - Loop transforms to add zero-trip-check ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

/// Create a zero-trip-check around a `while` op and return the new loop op
/// inside the check. The while loop is rotated to avoid evaluating the
/// condition twice.
///
/// Given the example below:
///
///   scf.while (%arg0 = %init) : (i32) -> i64 {
///     %val = .., %arg0 : i64
///     %cond = arith.cmpi .., %arg0 : i32
///     scf.condition(%cond) %val : i64
///   } do {
///   ^bb0(%arg1: i64):
///     %next = .., %arg1 : i32
///     scf.yield %next : i32
///   }
///
/// First, clone the before block and place the clone in front of the loop:
///
///   %pre_val = .., %init : i64
///   %pre_cond = arith.cmpi .., %init : i32
///   scf.while (%arg0 = %init) : (i32) -> i64 {
///     %val = .., %arg0 : i64
///     %cond = arith.cmpi .., %arg0 : i32
///     scf.condition(%cond) %val : i64
///   } do {
///   ^bb0(%arg1: i64):
///     %next = .., %arg1 : i32
///     scf.yield %next : i32
///   }
///
/// Then create an `if` op with the cloned condition, rotate the loop, and move
/// it into the then branch:
///
///   %pre_val = .., %init : i64
///   %pre_cond = arith.cmpi .., %init : i32
///   scf.if %pre_cond -> i64 {
///     %res = scf.while (%arg1 = %pre_val) : (i64) -> i64 {
///       // Original after block
///       %next = .., %arg1 : i32
///       // Original before block
///       %val = .., %next : i64
///       %cond = arith.cmpi .., %next : i32
///       scf.condition(%cond) %val : i64
///     } do {
///     ^bb0(%arg2: i64):
///       scf.yield %arg2 : i64
///     }
///     scf.yield %res : i64
///   } else {
///     scf.yield %pre_val : i64
///   }
FailureOr<scf::WhileOp>
mlir::scf::wrapWhileLoopInZeroTripCheck( scf::WhileOp whileOp, RewriterBase &rewriter, bool forceCreateCheck) { … }