Skip to content

Commit 7c90081

Browse files
[SCF][PIPELINE] Handle the case when values from the peeled prologue may escape out of the loop (#105755)
Previously, values in the peeled prologue that were not handled by `predicateFn` were passed to the loop body without any other predication. If those values are later used outside of the loop body, they may be incorrect when the number of iterations is smaller than the number of stages minus one. We need masking for those values similar to what is done in the main loop body, using the already existing predicates.
1 parent 7f37932 commit 7c90081

File tree

2 files changed

+32
-14
lines changed

2 files changed

+32
-14
lines changed

mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op,
268268
}
269269

270270
void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
271-
// Initialize the iteration argument to the loop initiale values.
271+
// Initialize the iteration argument to the loop initial values.
272272
for (auto [arg, operand] :
273273
llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) {
274274
setValueMapping(arg, operand.get(), 0);
@@ -320,16 +320,26 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
320320
if (annotateFn)
321321
annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i);
322322
for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
323-
setValueMapping(op->getResult(destId), newOp->getResult(destId),
324-
i - stages[op]);
323+
Value source = newOp->getResult(destId);
325324
// If the value is a loop carried dependency update the loop argument
326-
// mapping.
327325
for (OpOperand &operand : yield->getOpOperands()) {
328326
if (operand.get() != op->getResult(destId))
329327
continue;
328+
if (predicates[predicateIdx] &&
329+
!forOp.getResult(operand.getOperandNumber()).use_empty()) {
330+
// If the value is used outside the loop, we need to make sure we
331+
// return the correct version of it.
332+
Value prevValue = valueMapping
333+
[forOp.getRegionIterArgs()[operand.getOperandNumber()]]
334+
[i - stages[op]];
335+
source = rewriter.create<arith::SelectOp>(
336+
loc, predicates[predicateIdx], source, prevValue);
337+
}
330338
setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
331-
newOp->getResult(destId), i - stages[op] + 1);
339+
source, i - stages[op] + 1);
332340
}
341+
setValueMapping(op->getResult(destId), newOp->getResult(destId),
342+
i - stages[op]);
333343
}
334344
}
335345
}

mlir/test/Dialect/SCF/loop-pipelining.mlir

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -703,18 +703,26 @@ func.func @distance_1_use(%A: memref<?xf32>, %result: memref<?xf32>) {
703703
// -----
704704

705705
// NOEPILOGUE-LABEL: stage_0_value_escape(
706-
func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>) {
706+
func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub: index) {
707707
%c0 = arith.constant 0 : index
708708
%c1 = arith.constant 1 : index
709-
%c4 = arith.constant 4 : index
710709
%cf = arith.constant 1.0 : f32
711-
// NOEPILOGUE: %[[C3:.+]] = arith.constant 3 : index
712-
// NOEPILOGUE: %[[A:.+]] = arith.addf
713-
// NOEPILOGUE: scf.for %[[IV:.+]] = {{.*}} iter_args(%[[ARG:.+]] = %[[A]],
714-
// NOEPILOGUE: %[[C:.+]] = arith.cmpi slt, %[[IV]], %[[C3]] : index
715-
// NOEPILOGUE: %[[S:.+]] = arith.select %[[C]], %{{.+}}, %[[ARG]] : f32
716-
// NOEPILOGUE: scf.yield %[[S]]
717-
%r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf) -> (f32) {
710+
// NOEPILOGUE: %[[UB:[^,]+]]: index)
711+
// NOEPILOGUE-DAG: %[[C0:.+]] = arith.constant 0 : index
712+
// NOEPILOGUE-DAG: %[[C1:.+]] = arith.constant 1 : index
713+
// NOEPILOGUE-DAG: %[[CF:.+]] = arith.constant 1.000000e+00
714+
// NOEPILOGUE: %[[CND0:.+]] = arith.cmpi sgt, %[[UB]], %[[C0]]
715+
// NOEPILOGUE: scf.if
716+
// NOEPILOGUE: %[[IF:.+]] = scf.if %[[CND0]]
717+
// NOEPILOGUE: %[[A:.+]] = arith.addf
718+
// NOEPILOGUE: scf.yield %[[A]]
719+
// NOEPILOGUE: %[[S0:.+]] = arith.select %[[CND0]], %[[IF]], %[[CF]]
720+
// NOEPILOGUE: scf.for %[[IV:.+]] = {{.*}} iter_args(%[[ARG:.+]] = %[[S0]],
721+
// NOEPILOGUE: %[[UB_1:.+]] = arith.subi %[[UB]], %[[C1]] : index
722+
// NOEPILOGUE: %[[CND1:.+]] = arith.cmpi slt, %[[IV]], %[[UB_1]] : index
723+
// NOEPILOGUE: %[[S1:.+]] = arith.select %[[CND1]], %{{.+}}, %[[ARG]] : f32
724+
// NOEPILOGUE: scf.yield %[[S1]]
725+
%r = scf.for %i0 = %c0 to %ub step %c1 iter_args(%arg0 = %cf) -> (f32) {
718726
%A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 1 } : memref<?xf32>
719727
%A1_elem = arith.addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32
720728
memref.store %A1_elem, %result[%c0] { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 2 } : memref<?xf32>

0 commit comments

Comments
 (0)