Skip to content

Commit a6d932b

Browse files
authored
[mlir][scf] Align scf.while before block args in canonicalizer (#76195)
If `before` block args are directly forwarded to `scf.condition` make sure they are passed in the same order. This is needed for `scf.while` uplifting #76108
1 parent cf61e34 commit a6d932b

File tree

2 files changed

+119
-1
lines changed

2 files changed

+119
-1
lines changed

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3884,14 +3884,103 @@ struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
38843884
return success();
38853885
}
38863886
};
3887+
3888+
/// If both ranges contain same values return mappping indices from args2 to
3889+
/// args1. Otherwise return std::nullopt.
3890+
static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
3891+
ValueRange args2) {
3892+
if (args1.size() != args2.size())
3893+
return std::nullopt;
3894+
3895+
SmallVector<unsigned> ret(args1.size());
3896+
for (auto &&[i, arg1] : llvm::enumerate(args1)) {
3897+
auto it = llvm::find(args2, arg1);
3898+
if (it == args2.end())
3899+
return std::nullopt;
3900+
3901+
ret[std::distance(args2.begin(), it)] = static_cast<unsigned>(i);
3902+
}
3903+
3904+
return ret;
3905+
}
3906+
3907+
static bool hasDuplicates(ValueRange args) {
3908+
llvm::SmallDenseSet<Value> set;
3909+
for (Value arg : args) {
3910+
if (set.contains(arg))
3911+
return true;
3912+
3913+
set.insert(arg);
3914+
}
3915+
return false;
3916+
}
3917+
3918+
/// If `before` block args are directly forwarded to `scf.condition`, rearrange
3919+
/// `scf.condition` args into same order as block args. Update `after` block
3920+
/// args and op result values accordingly.
3921+
/// Needed to simplify `scf.while` -> `scf.for` uplifting.
3922+
struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
3923+
using OpRewritePattern::OpRewritePattern;
3924+
3925+
LogicalResult matchAndRewrite(WhileOp loop,
3926+
PatternRewriter &rewriter) const override {
3927+
auto oldBefore = loop.getBeforeBody();
3928+
ConditionOp oldTerm = loop.getConditionOp();
3929+
ValueRange beforeArgs = oldBefore->getArguments();
3930+
ValueRange termArgs = oldTerm.getArgs();
3931+
if (beforeArgs == termArgs)
3932+
return failure();
3933+
3934+
if (hasDuplicates(termArgs))
3935+
return failure();
3936+
3937+
auto mapping = getArgsMapping(beforeArgs, termArgs);
3938+
if (!mapping)
3939+
return failure();
3940+
3941+
{
3942+
OpBuilder::InsertionGuard g(rewriter);
3943+
rewriter.setInsertionPoint(oldTerm);
3944+
rewriter.replaceOpWithNewOp<ConditionOp>(oldTerm, oldTerm.getCondition(),
3945+
beforeArgs);
3946+
}
3947+
3948+
auto oldAfter = loop.getAfterBody();
3949+
3950+
SmallVector<Type> newResultTypes(beforeArgs.size());
3951+
for (auto &&[i, j] : llvm::enumerate(*mapping))
3952+
newResultTypes[j] = loop.getResult(i).getType();
3953+
3954+
auto newLoop = rewriter.create<WhileOp>(
3955+
loop.getLoc(), newResultTypes, loop.getInits(),
3956+
/*beforeBuilder=*/nullptr, /*afterBuilder=*/nullptr);
3957+
auto newBefore = newLoop.getBeforeBody();
3958+
auto newAfter = newLoop.getAfterBody();
3959+
3960+
SmallVector<Value> newResults(beforeArgs.size());
3961+
SmallVector<Value> newAfterArgs(beforeArgs.size());
3962+
for (auto &&[i, j] : llvm::enumerate(*mapping)) {
3963+
newResults[i] = newLoop.getResult(j);
3964+
newAfterArgs[i] = newAfter->getArgument(j);
3965+
}
3966+
3967+
rewriter.inlineBlockBefore(oldBefore, newBefore, newBefore->begin(),
3968+
newBefore->getArguments());
3969+
rewriter.inlineBlockBefore(oldAfter, newAfter, newAfter->begin(),
3970+
newAfterArgs);
3971+
3972+
rewriter.replaceOp(loop, newResults);
3973+
return success();
3974+
}
3975+
};
38873976
} // namespace
38883977

38893978
void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
38903979
MLIRContext *context) {
38913980
results.add<RemoveLoopInvariantArgsFromBeforeBlock,
38923981
RemoveLoopInvariantValueYielded, WhileConditionTruth,
38933982
WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
3894-
WhileRemoveUnusedArgs>(context);
3983+
WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
38953984
}
38963985

38973986
//===----------------------------------------------------------------------===//

mlir/test/Dialect/SCF/canonicalize.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,6 +1198,35 @@ func.func @while_unused_arg2(%val0: i32) -> i32 {
11981198
// CHECK: return %[[RES]] : i32
11991199

12001200

1201+
// -----
1202+
1203+
// CHECK-LABEL: func @test_align_args
1204+
// CHECK: %[[RES:.*]]:3 = scf.while (%[[ARG0:.*]] = %{{.*}}, %[[ARG1:.*]] = %{{.*}}, %[[ARG2:.*]] = %{{.*}}) : (f32, i32, i64) -> (f32, i32, i64) {
1205+
// CHECK: scf.condition(%{{.*}}) %[[ARG0]], %[[ARG1]], %[[ARG2]] : f32, i32, i64
1206+
// CHECK: ^bb0(%[[ARG3:.*]]: f32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i64):
1207+
// CHECK: %[[R1:.*]] = "test.test"(%[[ARG5]]) : (i64) -> f32
1208+
// CHECK: %[[R2:.*]] = "test.test"(%[[ARG3]]) : (f32) -> i32
1209+
// CHECK: %[[R3:.*]] = "test.test"(%[[ARG4]]) : (i32) -> i64
1210+
// CHECK: scf.yield %[[R1]], %[[R2]], %[[R3]] : f32, i32, i64
1211+
// CHECK: return %[[RES]]#2, %[[RES]]#0, %[[RES]]#1
1212+
func.func @test_align_args() -> (i64, f32, i32) {
1213+
%0 = "test.test"() : () -> (f32)
1214+
%1 = "test.test"() : () -> (i32)
1215+
%2 = "test.test"() : () -> (i64)
1216+
%3:3 = scf.while (%arg0 = %0, %arg1 = %1, %arg2 = %2) : (f32, i32, i64) -> (i64, f32, i32) {
1217+
%cond = "test.test"() : () -> (i1)
1218+
scf.condition(%cond) %arg2, %arg0, %arg1 : i64, f32, i32
1219+
} do {
1220+
^bb0(%arg3: i64, %arg4: f32, %arg5: i32):
1221+
%4 = "test.test"(%arg3) : (i64) -> (f32)
1222+
%5 = "test.test"(%arg4) : (f32) -> (i32)
1223+
%6 = "test.test"(%arg5) : (i32) -> (i64)
1224+
scf.yield %4, %5, %6 : f32, i32, i64
1225+
}
1226+
return %3#0, %3#1, %3#2 : i64, f32, i32
1227+
}
1228+
1229+
12011230
// -----
12021231

12031232
// CHECK-LABEL: @combineIfs

0 commit comments

Comments
 (0)