@@ -101,38 +101,30 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
101
101
102
102
Block *afterBody = loop.getAfterBody ();
103
103
scf::YieldOp afterTerm = loop.getYieldOp ();
104
- auto argNumber = inductionVar.getArgNumber ();
105
- auto afterTermIndArg = afterTerm.getResults ()[argNumber];
104
+ unsigned argNumber = inductionVar.getArgNumber ();
105
+ Value afterTermIndArg = afterTerm.getResults ()[argNumber];
106
106
107
- auto inductionVarAfter = afterBody->getArgument (argNumber);
108
-
109
- Value step;
107
+ Value inductionVarAfter = afterBody->getArgument (argNumber);
110
108
111
109
// Find suitable `addi` op inside `after` block, one of the args must be an
112
110
// Induction var passed from `before` block and second arg must be defined
113
111
// outside of the loop and will be considered step value.
114
112
// TODO: Add `subi` support?
115
- for (auto &use : inductionVarAfter.getUses ()) {
116
- auto owner = dyn_cast<arith::AddIOp>(use.getOwner ());
117
- if (!owner)
118
- continue ;
119
-
120
- auto other =
121
- (inductionVarAfter == owner.getLhs () ? owner.getRhs () : owner.getLhs ());
122
- if (!dom.properlyDominates (other, loop))
123
- continue ;
124
-
125
- if (afterTermIndArg != owner.getResult ())
126
- continue ;
113
+ auto addOp = afterTermIndArg.getDefiningOp <arith::AddIOp>();
114
+ if (!addOp)
115
+ return rewriter.notifyMatchFailure (loop, " Didn't found suitable 'addi' op" );
127
116
128
- step = other;
129
- break ;
117
+ Value step;
118
+ if (addOp.getLhs () == inductionVarAfter) {
119
+ step = addOp.getRhs ();
120
+ } else if (addOp.getRhs () == inductionVarAfter) {
121
+ step = addOp.getLhs ();
130
122
}
131
123
132
- if (!step)
133
- return rewriter.notifyMatchFailure (loop, " Didn't found suitable 'addi' op " );
124
+ if (!step || !dom. properlyDominates (step, loop) )
125
+ return rewriter.notifyMatchFailure (loop, " Invalid 'addi' form " );
134
126
135
- auto lb = loop.getInits ()[argNumber];
127
+ Value lb = loop.getInits ()[argNumber];
136
128
137
129
assert (lb.getType ().isIntOrIndex ());
138
130
assert (lb.getType () == ub.getType ());
0 commit comments