Skip to content

Commit 8867182

Browse files
banach-space authored and legrosbuffle committed
[mlir][transform] Update transform.loop.peel (reland llvm#67482)
This patch updates `transform.loop.peel` so that this Op returns two handles rather than one:
* one for the peeled loop, and
* one for the remainder loop.

Also, following this change, this Op will fail if peeling fails. This is consistent with other similar Ops that also fail if no transformation takes place.

Relands llvm#67482 with an extra fix for transform_loop_ext.py.
1 parent 6a535b9 commit 8867182

File tree

7 files changed

+52
-24
lines changed

7 files changed

+52
-24
lines changed

mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -142,23 +142,22 @@ def LoopPeelOp : Op<Transform_Dialect, "loop.peel",
142142

143143
This operation ignores non-scf::ForOp ops and drops them in the return.
144144

145-
This operation always succeeds and returns the scf::ForOp with the
146-
postcondition: "the loop trip count is divisible by the step".
147-
This operation may return the same unmodified loop handle when peeling did
148-
not modify the IR (i.e. the loop trip count was already divisible).
145+
This operation returns two scf::ForOp Ops, with the first Op satisfying
146+
the postcondition: "the loop trip count is divisible by the step". The
147+
second loop Op contains the remaining iteration. Note that even though the
148+
Payload IR modification may be performed in-place, this operation consumes
149+
the operand handle and produces a new one.
149150

150-
Note that even though the Payload IR modification may be performed
151-
in-place, this operation consumes the operand handle and produces a new
152-
one.
151+
#### Return Modes
153152

154-
TODO: Return both the peeled loop and the remainder loop.
153+
Produces a definite failure if peeling fails.
155154
}];
156155

157156
let arguments =
158157
(ins Transform_ScfForOp:$target,
159158
DefaultValuedAttr<BoolAttr, "false">:$fail_if_already_divisible);
160-
// TODO: Return both the peeled loop and the remainder loop.
161-
let results = (outs TransformHandleTypeInterface:$transformed);
159+
let results = (outs TransformHandleTypeInterface:$peeled_loop,
160+
TransformHandleTypeInterface:$remainder_loop);
162161

163162
let assemblyFormat =
164163
"$target attr-dict `:` functional-type(operands, results)";

mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -226,14 +226,16 @@ transform::LoopPeelOp::applyToOne(transform::TransformRewriter &rewriter,
226226
transform::ApplyToEachResultList &results,
227227
transform::TransformState &state) {
228228
scf::ForOp result;
229-
// This helper returns failure when peeling does not occur (i.e. when the IR
230-
// is not modified). This is not a failure for the op as the postcondition:
231-
// "the loop trip count is divisible by the step"
232-
// is valid.
233229
LogicalResult status =
234230
scf::peelForLoopAndSimplifyBounds(rewriter, target, result);
235-
// TODO: Return both the peeled loop and the remainder loop.
236-
results.push_back(failed(status) ? target : result);
231+
if (failed(status)) {
232+
DiagnosedSilenceableFailure diag = emitSilenceableError()
233+
<< "failed to peel";
234+
return diag;
235+
}
236+
results.push_back(target);
237+
results.push_back(result);
238+
237239
return DiagnosedSilenceableFailure::success();
238240
}
239241

mlir/python/mlir/dialects/_loop_transform_ops_ext.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,15 +66,17 @@ class LoopPeelOp:
6666

6767
def __init__(
6868
self,
69-
result_type: Type,
69+
main_loop_type: Type,
70+
remainder_loop_type: Type,
7071
target: Union[Operation, Value],
7172
*,
7273
fail_if_already_divisible: Union[bool, BoolAttr] = False,
7374
ip=None,
7475
loc=None,
7576
):
7677
super().__init__(
77-
result_type,
78+
main_loop_type,
79+
remainder_loop_type,
7880
_get_op_result_or_value(target),
7981
fail_if_already_divisible=(
8082
fail_if_already_divisible

mlir/test/Dialect/Linalg/transform-op-fuse.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ transform.sequence failures(propagate) {
4848
%0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
4949
%1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]}
5050
: (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op)
51-
transform.loop.peel %loops#0 : (!transform.op<"scf.for">) -> !transform.any_op
51+
transform.loop.peel %loops#0 : (!transform.op<"scf.for">) -> (!transform.any_op, !transform.any_op)
5252
}
5353

5454
// -----

mlir/test/Dialect/SCF/transform-ops-invalid.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,23 @@ transform.sequence failures(propagate) {
5959
// expected-error @below {{failed to outline}}
6060
transform.loop.outline %0 {func_name = "foo"} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
6161
}
62+
63+
// -----
64+
65+
func.func @test_loops_do_not_get_peeled() {
66+
%lb = arith.constant 0 : index
67+
%ub = arith.constant 40 : index
68+
%step = arith.constant 5 : index
69+
scf.for %i = %lb to %ub step %step {
70+
arith.addi %i, %i : index
71+
}
72+
return
73+
}
74+
75+
transform.sequence failures(propagate) {
76+
^bb1(%arg1: !transform.any_op):
77+
%0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
78+
%1 = transform.loop.get_parent_for %0 : (!transform.any_op) -> !transform.op<"scf.for">
79+
// expected-error @below {{failed to peel}}
80+
transform.loop.peel %1 : (!transform.op<"scf.for">) -> (!transform.any_op, !transform.any_op)
81+
}

mlir/test/Dialect/SCF/transform-ops.mlir

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,16 +87,18 @@ transform.sequence failures(propagate) {
8787
// CHECK-LABEL: @loop_peel_op
8888
func.func @loop_peel_op() {
8989
// CHECK: %[[C0:.+]] = arith.constant 0
90-
// CHECK: %[[C42:.+]] = arith.constant 42
90+
// CHECK: %[[C41:.+]] = arith.constant 41
9191
// CHECK: %[[C5:.+]] = arith.constant 5
9292
// CHECK: %[[C40:.+]] = arith.constant 40
9393
// CHECK: scf.for %{{.+}} = %[[C0]] to %[[C40]] step %[[C5]]
9494
// CHECK: arith.addi
95-
// CHECK: scf.for %{{.+}} = %[[C40]] to %[[C42]] step %[[C5]]
95+
// CHECK: scf.for %{{.+}} = %[[C40]] to %[[C41]] step %[[C5]]
9696
// CHECK: arith.addi
9797
%0 = arith.constant 0 : index
98-
%1 = arith.constant 42 : index
98+
%1 = arith.constant 41 : index
9999
%2 = arith.constant 5 : index
100+
// expected-remark @below {{main loop}}
101+
// expected-remark @below {{remainder loop}}
100102
scf.for %i = %0 to %1 step %2 {
101103
arith.addi %i, %i : index
102104
}
@@ -107,7 +109,10 @@ transform.sequence failures(propagate) {
107109
^bb1(%arg1: !transform.any_op):
108110
%0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
109111
%1 = transform.loop.get_parent_for %0 : (!transform.any_op) -> !transform.op<"scf.for">
110-
transform.loop.peel %1 : (!transform.op<"scf.for">) -> !transform.any_op
112+
%main_loop, %remainder = transform.loop.peel %1 : (!transform.op<"scf.for">) -> (!transform.op<"scf.for">, !transform.op<"scf.for">)
113+
// Make sure that both returned handles are usable by printing a remark at each.
114+
transform.test_print_remark_at_operand %main_loop, "main loop" : !transform.op<"scf.for">
115+
transform.test_print_remark_at_operand %remainder, "remainder loop" : !transform.op<"scf.for">
111116
}
112117

113118
// -----

mlir/test/python/dialects/transform_loop_ext.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def loopPeel():
5959
transform.OperationType.get("scf.for"),
6060
)
6161
with InsertionPoint(sequence.body):
62-
loop.LoopPeelOp(pdl.OperationType.get(), sequence.bodyTarget)
62+
loop.LoopPeelOp(transform.AnyOpType.get(), transform.AnyOpType.get(), sequence.bodyTarget)
6363
transform.YieldOp()
6464
# CHECK-LABEL: TEST: loopPeel
6565
# CHECK: = transform.loop.peel %

0 commit comments

Comments
 (0)