Skip to content

Commit 1ca6b44

Browse files
authored
[mlir][scf] scf.while uplifting: optimize op matching (#88813)
Instead of iterating over potential induction var uses looking for suitable `arith.addi`, try to trace it back from yield argument.
1 parent 61717c1 commit 1ca6b44

File tree

1 file changed

+14
-22
lines changed

1 file changed

+14
-22
lines changed

mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -101,38 +101,30 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
101101

102102
Block *afterBody = loop.getAfterBody();
103103
scf::YieldOp afterTerm = loop.getYieldOp();
104-
auto argNumber = inductionVar.getArgNumber();
105-
auto afterTermIndArg = afterTerm.getResults()[argNumber];
104+
unsigned argNumber = inductionVar.getArgNumber();
105+
Value afterTermIndArg = afterTerm.getResults()[argNumber];
106106

107-
auto inductionVarAfter = afterBody->getArgument(argNumber);
108-
109-
Value step;
107+
Value inductionVarAfter = afterBody->getArgument(argNumber);
110108

111109
// Find suitable `addi` op inside `after` block, one of the args must be an
112110
// Induction var passed from `before` block and second arg must be defined
113111
// outside of the loop and will be considered step value.
114112
// TODO: Add `subi` support?
115-
for (auto &use : inductionVarAfter.getUses()) {
116-
auto owner = dyn_cast<arith::AddIOp>(use.getOwner());
117-
if (!owner)
118-
continue;
119-
120-
auto other =
121-
(inductionVarAfter == owner.getLhs() ? owner.getRhs() : owner.getLhs());
122-
if (!dom.properlyDominates(other, loop))
123-
continue;
124-
125-
if (afterTermIndArg != owner.getResult())
126-
continue;
113+
auto addOp = afterTermIndArg.getDefiningOp<arith::AddIOp>();
114+
if (!addOp)
115+
return rewriter.notifyMatchFailure(loop, "Didn't found suitable 'addi' op");
127116

128-
step = other;
129-
break;
117+
Value step;
118+
if (addOp.getLhs() == inductionVarAfter) {
119+
step = addOp.getRhs();
120+
} else if (addOp.getRhs() == inductionVarAfter) {
121+
step = addOp.getLhs();
130122
}
131123

132-
if (!step)
133-
return rewriter.notifyMatchFailure(loop, "Didn't found suitable 'addi' op");
124+
if (!step || !dom.properlyDominates(step, loop))
125+
return rewriter.notifyMatchFailure(loop, "Invalid 'addi' form");
134126

135-
auto lb = loop.getInits()[argNumber];
127+
Value lb = loop.getInits()[argNumber];
136128

137129
assert(lb.getType().isIntOrIndex());
138130
assert(lb.getType() == ub.getType());

0 commit comments

Comments
 (0)