Skip to content

Commit 192cd68

Browse files
authored
Add checks before hoisting out in loop pipelining (#90872)
Currently, during a loop pipelining transformation, operations may be hoisted out without any checks on the loop bounds, which leads to incorrect transformations and unexpected behaviour. The following [issue ](#90870) describes the problem more extensively, including an example. The proposed fix adds some check in the loop bounds before and applies the maximum hoisting.
1 parent 1721c14 commit 192cd68

File tree

2 files changed

+67
-1
lines changed

2 files changed

+67
-1
lines changed

mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,8 @@ loopScheduling(scf::ForOp forOp,
261261
return 1;
262262
};
263263

264+
std::optional<int64_t> ubConstant = getConstantIntValue(forOp.getUpperBound());
265+
std::optional<int64_t> lbConstant = getConstantIntValue(forOp.getLowerBound());
264266
DenseMap<Operation *, unsigned> opCycles;
265267
std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
266268
for (Operation &op : forOp.getBody()->getOperations()) {
@@ -271,7 +273,14 @@ loopScheduling(scf::ForOp forOp,
271273
Operation *def = operand.getDefiningOp();
272274
if (!def)
273275
continue;
274-
earlyCycle = std::max(earlyCycle, opCycles[def] + getLatency(def));
276+
if (ubConstant && lbConstant) {
277+
unsigned ubInt = ubConstant.value();
278+
unsigned lbInt = lbConstant.value();
279+
auto minLatency = std::min(ubInt - lbInt - 1, getLatency(def));
280+
earlyCycle = std::max(earlyCycle, opCycles[def] + minLatency);
281+
} else {
282+
earlyCycle = std::max(earlyCycle, opCycles[def] + getLatency(def));
283+
}
275284
}
276285
opCycles[&op] = earlyCycle;
277286
wrappedSchedule[earlyCycle % iterationInterval].push_back(&op);

mlir/test/Dialect/SCF/transform-ops.mlir

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,3 +300,60 @@ module attributes {transform.with_named_sequence} {
300300
transform.yield
301301
}
302302
}
303+
304+
305+
// -----
306+
307+
// CHECK-LABEL: func.func @loop_pipeline
308+
func.func @loop_pipeline(%arg0: memref<4x16xf32>, %arg1: vector<16xf32>) -> vector<16xf32> {
309+
%c0 = arith.constant 0 : index
310+
%c1 = arith.constant 1 : index
311+
%cst = arith.constant 0.000000e+00 : f32
312+
%c3 = arith.constant 3 : index
313+
// CHECK: vector.transfer_read
314+
// CHECK: vector.transfer_read
315+
// CHECK: vector.transfer_read
316+
// CHECK: arith.addf
317+
// CHECK: arith.addf
318+
// CHECK: arith.addf
319+
%0 = scf.for %arg2 = %c0 to %c3 step %c1 iter_args(%arg3 = %arg1) -> (vector<16xf32>) {
320+
%1 = vector.transfer_read %arg0[%arg2, %c0], %cst {in_bounds = [true]} : memref<4x16xf32>, vector<16xf32>
321+
%2 = arith.addf %1, %arg3 : vector<16xf32>
322+
scf.yield %2 : vector<16xf32>
323+
}
324+
return %0 : vector<16xf32>
325+
}
326+
module attributes {transform.with_named_sequence} {
327+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
328+
%0 = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.for">
329+
%1 = transform.loop.pipeline %0 {iteration_interval = 1 : i64, read_latency = 5 : i64, scheduling_type = "full-loops"} : (!transform.op<"scf.for">) -> !transform.any_op
330+
transform.yield
331+
}
332+
}
333+
334+
335+
// -----
336+
337+
// CHECK-LABEL: func.func @loop_pipeline_lb_gt_0
338+
func.func @loop_pipeline_lb_gt_0(%arg0: memref<4x16xf32>, %arg1: vector<16xf32>) -> vector<16xf32> {
339+
%c1 = arith.constant 1 : index
340+
%cst = arith.constant 0.000000e+00 : f32
341+
%c3 = arith.constant 3 : index
342+
// CHECK: vector.transfer_read
343+
// CHECK: vector.transfer_read
344+
// CHECK: arith.addf
345+
// CHECK: arith.addf
346+
%0 = scf.for %arg2 = %c1 to %c3 step %c1 iter_args(%arg3 = %arg1) -> (vector<16xf32>) {
347+
%1 = vector.transfer_read %arg0[%arg2, %c1], %cst {in_bounds = [true]} : memref<4x16xf32>, vector<16xf32>
348+
%2 = arith.addf %1, %arg3 : vector<16xf32>
349+
scf.yield %2 : vector<16xf32>
350+
}
351+
return %0 : vector<16xf32>
352+
}
353+
module attributes {transform.with_named_sequence} {
354+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
355+
%0 = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.for">
356+
%1 = transform.loop.pipeline %0 {iteration_interval = 1 : i64, read_latency = 5 : i64, scheduling_type = "full-loops"} : (!transform.op<"scf.for">) -> !transform.any_op
357+
transform.yield
358+
}
359+
}

0 commit comments

Comments
 (0)