Skip to content

Commit 0c21dfd

Browse files
authored
[mlir][spirv] SCFToSPIRV: fix WhileOp block args types conversion (#68588)
WhileOp before/after block args types weren't converted, resulting in invalid IR.
1 parent 475d687 commit 0c21dfd

File tree

2 files changed

+65
-2
lines changed

2 files changed

+65
-2
lines changed

mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -344,11 +344,16 @@ struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> {
344344
auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
345345
loopOp.addEntryAndMergeBlock();
346346

347-
OpBuilder::InsertionGuard guard(rewriter);
348-
349347
Region &beforeRegion = whileOp.getBefore();
350348
Region &afterRegion = whileOp.getAfter();
351349

350+
if (failed(rewriter.convertRegionTypes(&beforeRegion, typeConverter)) ||
351+
failed(rewriter.convertRegionTypes(&afterRegion, typeConverter)))
352+
return rewriter.notifyMatchFailure(whileOp,
353+
"Failed to convert region types");
354+
355+
OpBuilder::InsertionGuard guard(rewriter);
356+
352357
Block &entryBlock = *loopOp.getEntryBlock();
353358
Block &beforeBlock = beforeRegion.front();
354359
Block &afterBlock = afterRegion.front();

mlir/test/Conversion/SCFToSPIRV/while.mlir

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,4 +69,62 @@ func.func @while_loop2(%arg0: f32) -> i64 {
6969
return %res : i64
7070
}
7171

72+
// -----
73+
74+
// CHECK-LABEL: @while_loop_before_typeconv
75+
func.func @while_loop_before_typeconv(%arg0: index) -> i64 {
76+
// CHECK-SAME: (%[[ARG:.*]]: i32)
77+
// CHECK: %[[VAR:.*]] = spirv.Variable : !spirv.ptr<i64, Function>
78+
// CHECK: spirv.mlir.loop {
79+
// CHECK: spirv.Branch ^[[HEADER:.*]](%[[ARG]] : i32)
80+
// CHECK: ^[[HEADER]](%[[INDVAR1:.*]]: i32):
81+
// CHECK: spirv.BranchConditional %{{.*}}, ^[[BODY:.*]](%{{.*}} : i64), ^[[MERGE:.*]]
82+
// CHECK: ^[[BODY]](%[[INDVAR2:.*]]: i64):
83+
// CHECK: spirv.Branch ^[[HEADER]](%{{.*}} : i32)
84+
// CHECK: ^[[MERGE]]:
85+
// CHECK: spirv.mlir.merge
86+
// CHECK: }
87+
%res = scf.while (%arg1 = %arg0) : (index) -> i64 {
88+
%shared = "foo.shared_compute"(%arg1) : (index) -> i64
89+
%condition = "foo.evaluate_condition"(%arg1, %shared) : (index, i64) -> i1
90+
scf.condition(%condition) %shared : i64
91+
} do {
92+
^bb0(%arg2: i64):
93+
%res = "foo.payload"(%arg2) : (i64) -> index
94+
scf.yield %res : index
95+
}
96+
// CHECK: %[[OUT:.*]] = spirv.Load "Function" %[[VAR]] : i64
97+
// CHECK: spirv.ReturnValue %[[OUT]] : i64
98+
return %res : i64
99+
}
100+
101+
// -----
102+
103+
// CHECK-LABEL: @while_loop_after_typeconv
104+
func.func @while_loop_after_typeconv(%arg0: f32) -> index {
105+
// CHECK-SAME: (%[[ARG:.*]]: f32)
106+
// CHECK: %[[VAR:.*]] = spirv.Variable : !spirv.ptr<i32, Function>
107+
// CHECK: spirv.mlir.loop {
108+
// CHECK: spirv.Branch ^[[HEADER:.*]](%[[ARG]] : f32)
109+
// CHECK: ^[[HEADER]](%[[INDVAR1:.*]]: f32):
110+
// CHECK: spirv.BranchConditional %{{.*}}, ^[[BODY:.*]](%{{.*}} : i32), ^[[MERGE:.*]]
111+
// CHECK: ^[[BODY]](%[[INDVAR2:.*]]: i32):
112+
// CHECK: spirv.Branch ^[[HEADER]](%{{.*}} : f32)
113+
// CHECK: ^[[MERGE]]:
114+
// CHECK: spirv.mlir.merge
115+
// CHECK: }
116+
%res = scf.while (%arg1 = %arg0) : (f32) -> index {
117+
%shared = "foo.shared_compute"(%arg1) : (f32) -> index
118+
%condition = "foo.evaluate_condition"(%arg1, %shared) : (f32, index) -> i1
119+
scf.condition(%condition) %shared : index
120+
} do {
121+
^bb0(%arg2: index):
122+
%res = "foo.payload"(%arg2) : (index) -> f32
123+
scf.yield %res : f32
124+
}
125+
// CHECK: %[[OUT:.*]] = spirv.Load "Function" %[[VAR]] : i32
126+
// CHECK: spirv.ReturnValue %[[OUT]] : i32
127+
return %res : index
128+
}
129+
72130
} // end module

0 commit comments

Comments
 (0)