Skip to content

Commit e8f07cd

Browse files
authored
[MLIR][SCF] Define -scf-rotate-while pass (llvm#99850)
Define SCF dialect patterns rotating `scf.while` loops leveraging existing `mlir::scf::wrapWhileLoopInZeroTripCheck`. `forceCreateCheck` is always `false` as the pattern would lead to an infinite recursion otherwise. This pattern rotates `scf.while` ops, mutating them from "while" loops to "do-while" loops. A guard checking the condition for the first iteration is inserted. Note this guard can be optimized away if the compiler can prove the loop will be executed at least once. Using this pattern, the following while loop: ```mlir scf.while (%arg0 = %init) : (i32) -> i64 { %val = .., %arg0 : i64 %cond = arith.cmpi .., %arg0 : i32 scf.condition(%cond) %val : i64 } do { ^bb0(%arg1: i64): %next = .., %arg1 : i32 scf.yield %next : i32 } ``` Can be transformed into: ``` mlir %pre_val = .., %init : i64 %pre_cond = arith.cmpi .., %init : i32 scf.if %pre_cond -> i64 { %res = scf.while (%arg1 = %va0) : (i64) -> i64 { // Original after block %next = .., %arg1 : i32 // Original before block %val = .., %next : i64 %cond = arith.cmpi .., %next : i32 scf.condition(%cond) %val : i64 } do { ^bb0(%arg2: i64): %scf.yield %arg2 : i32 } scf.yield %res : i64 } else { scf.yield %pre_val : i64 } ``` The test pass for `wrapWhileLoopInZeroTripCheck` has been modified to use the new pattern when `forceCreateCheck=false`. --------- Signed-off-by: Victor Perez <[email protected]>
1 parent 78e7ec3 commit e8f07cd

File tree

6 files changed

+70
-9
lines changed

6 files changed

+70
-9
lines changed

mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@ void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns);
8585
/// * `after` block containing arith.addi
8686
void populateUpliftWhileToForPatterns(RewritePatternSet &patterns);
8787

88+
/// Populate patterns to rotate `scf.while` ops, constructing `do-while` loops
89+
/// from `while` loops.
90+
void populateSCFRotateWhileLoopPatterns(RewritePatternSet &patterns);
8891
} // namespace scf
8992
} // namespace mlir
9093

mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,11 @@ FailureOr<ForOp> pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
228228
/// } else {
229229
/// scf.yield %pre_val : i64
230230
/// }
231+
///
232+
/// Failure mechanism is not implemented for this function, so it currently
233+
/// always returns a `WhileOp` operation: a new one if the transformation took
234+
/// place or the input `whileOp` if the loop was already in a `do-while` form
235+
/// and `forceCreateCheck` is `false`.
231236
FailureOr<WhileOp> wrapWhileLoopInZeroTripCheck(WhileOp whileOp,
232237
RewriterBase &rewriter,
233238
bool forceCreateCheck = false);

mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
1313
ParallelLoopCollapsing.cpp
1414
ParallelLoopFusion.cpp
1515
ParallelLoopTiling.cpp
16+
RotateWhileLoop.cpp
1617
StructuralTypeConversions.cpp
1718
TileUsingInterface.cpp
1819
WrapInZeroTripCheck.cpp
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
//===- RotateWhileLoop.cpp - scf.while loop rotation ----------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Rotates `scf.while` loops.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
14+
15+
#include "mlir/Dialect/SCF/IR/SCF.h"
16+
17+
using namespace mlir;
18+
19+
namespace {
20+
struct RotateWhileLoopPattern : OpRewritePattern<scf::WhileOp> {
21+
using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
22+
23+
LogicalResult matchAndRewrite(scf::WhileOp whileOp,
24+
PatternRewriter &rewriter) const final {
25+
// Setting this option would lead to infinite recursion on a greedy driver
26+
// as 'do-while' loops wouldn't be skipped.
27+
constexpr bool forceCreateCheck = false;
28+
FailureOr<scf::WhileOp> result =
29+
scf::wrapWhileLoopInZeroTripCheck(whileOp, rewriter, forceCreateCheck);
30+
// scf::wrapWhileLoopInZeroTripCheck hasn't yet implemented a failure
31+
// mechanism. 'do-while' loops are simply returned unmodified. In order to
32+
// stop recursion, we check input and output operations differ.
33+
return success(succeeded(result) && *result != whileOp);
34+
}
35+
};
36+
} // namespace
37+
38+
namespace mlir {
39+
namespace scf {
40+
void populateSCFRotateWhileLoopPatterns(RewritePatternSet &patterns) {
41+
patterns.add<RotateWhileLoopPattern>(patterns.getContext());
42+
}
43+
} // namespace scf
44+
} // namespace mlir

mlir/test/Dialect/SCF/wrap-while-loop-in-zero-trip-check.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ func.func @wrap_while_loop_in_zero_trip_check(%bound : i32) -> i32 {
2020
// CHECK-SAME: %[[BOUND:.*]]: i32) -> i32 {
2121
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
2222
// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : i32
23-
// CHECK-DAG: %[[PRE_COND:.*]] = arith.cmpi slt, %[[C0]], %[[BOUND]] : i32
23+
// CHECK-DAG: %[[PRE_COND:.*]] = arith.cmpi sgt, %[[BOUND]], %[[C0]] : i32
2424
// CHECK-DAG: %[[PRE_INV:.*]] = arith.addi %[[BOUND]], %[[C5]] : i32
2525
// CHECK: %[[IF:.*]]:2 = scf.if %[[PRE_COND]] -> (i32, i32) {
2626
// CHECK: %[[WHILE:.*]]:2 = scf.while (

mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- TestWrapInZeroTripCheck.cpp -- Passes to test SCF zero-trip-check --===//
1+
//===- TestSCFWrapInZeroTripCheck.cpp -- Pass to test SCF zero-trip-check -===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -13,9 +13,11 @@
1313

1414
#include "mlir/Dialect/Func/IR/FuncOps.h"
1515
#include "mlir/Dialect/SCF/IR/SCF.h"
16+
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
1617
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
1718
#include "mlir/IR/PatternMatch.h"
1819
#include "mlir/Pass/Pass.h"
20+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1921

2022
using namespace mlir;
2123

@@ -46,13 +48,19 @@ struct TestWrapWhileLoopInZeroTripCheckPass
4648
func::FuncOp func = getOperation();
4749
MLIRContext *context = &getContext();
4850
IRRewriter rewriter(context);
49-
func.walk([&](scf::WhileOp op) {
50-
FailureOr<scf::WhileOp> result =
51-
scf::wrapWhileLoopInZeroTripCheck(op, rewriter, forceCreateCheck);
52-
// Ignore not implemented failure in tests. The expected output should
53-
// catch problems (e.g. transformation doesn't happen).
54-
(void)result;
55-
});
51+
if (forceCreateCheck) {
52+
func.walk([&](scf::WhileOp op) {
53+
FailureOr<scf::WhileOp> result =
54+
scf::wrapWhileLoopInZeroTripCheck(op, rewriter, forceCreateCheck);
55+
// Ignore not implemented failure in tests. The expected output should
56+
// catch problems (e.g. transformation doesn't happen).
57+
(void)result;
58+
});
59+
} else {
60+
RewritePatternSet patterns(context);
61+
scf::populateSCFRotateWhileLoopPatterns(patterns);
62+
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
63+
}
5664
}
5765

5866
Option<bool> forceCreateCheck{

0 commit comments

Comments
 (0)