@@ -3884,14 +3884,103 @@ struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
3884
3884
return success ();
3885
3885
}
3886
3886
};
3887
+
3888
+ // / If both ranges contain same values return mappping indices from args2 to
3889
+ // / args1. Otherwise return std::nullopt.
3890
+ static std::optional<SmallVector<unsigned >> getArgsMapping (ValueRange args1,
3891
+ ValueRange args2) {
3892
+ if (args1.size () != args2.size ())
3893
+ return std::nullopt;
3894
+
3895
+ SmallVector<unsigned > ret (args1.size ());
3896
+ for (auto &&[i, arg1] : llvm::enumerate (args1)) {
3897
+ auto it = llvm::find (args2, arg1);
3898
+ if (it == args2.end ())
3899
+ return std::nullopt;
3900
+
3901
+ ret[std::distance (args2.begin (), it)] = static_cast <unsigned >(i);
3902
+ }
3903
+
3904
+ return ret;
3905
+ }
3906
+
3907
+ static bool hasDuplicates (ValueRange args) {
3908
+ llvm::SmallDenseSet<Value> set;
3909
+ for (Value arg : args) {
3910
+ if (set.contains (arg))
3911
+ return true ;
3912
+
3913
+ set.insert (arg);
3914
+ }
3915
+ return false ;
3916
+ }
3917
+
3918
+ // / If `before` block args are directly forwarded to `scf.condition`, rearrange
3919
+ // / `scf.condition` args into same order as block args. Update `after` block
3920
+ // / args and op result values accordingly.
3921
+ // / Needed to simplify `scf.while` -> `scf.for` uplifting.
3922
+ struct WhileOpAlignBeforeArgs : public OpRewritePattern <WhileOp> {
3923
+ using OpRewritePattern::OpRewritePattern;
3924
+
3925
+ LogicalResult matchAndRewrite (WhileOp loop,
3926
+ PatternRewriter &rewriter) const override {
3927
+ auto oldBefore = loop.getBeforeBody ();
3928
+ ConditionOp oldTerm = loop.getConditionOp ();
3929
+ ValueRange beforeArgs = oldBefore->getArguments ();
3930
+ ValueRange termArgs = oldTerm.getArgs ();
3931
+ if (beforeArgs == termArgs)
3932
+ return failure ();
3933
+
3934
+ if (hasDuplicates (termArgs))
3935
+ return failure ();
3936
+
3937
+ auto mapping = getArgsMapping (beforeArgs, termArgs);
3938
+ if (!mapping)
3939
+ return failure ();
3940
+
3941
+ {
3942
+ OpBuilder::InsertionGuard g (rewriter);
3943
+ rewriter.setInsertionPoint (oldTerm);
3944
+ rewriter.replaceOpWithNewOp <ConditionOp>(oldTerm, oldTerm.getCondition (),
3945
+ beforeArgs);
3946
+ }
3947
+
3948
+ auto oldAfter = loop.getAfterBody ();
3949
+
3950
+ SmallVector<Type> newResultTypes (beforeArgs.size ());
3951
+ for (auto &&[i, j] : llvm::enumerate (*mapping))
3952
+ newResultTypes[j] = loop.getResult (i).getType ();
3953
+
3954
+ auto newLoop = rewriter.create <WhileOp>(
3955
+ loop.getLoc (), newResultTypes, loop.getInits (),
3956
+ /* beforeBuilder=*/ nullptr , /* afterBuilder=*/ nullptr );
3957
+ auto newBefore = newLoop.getBeforeBody ();
3958
+ auto newAfter = newLoop.getAfterBody ();
3959
+
3960
+ SmallVector<Value> newResults (beforeArgs.size ());
3961
+ SmallVector<Value> newAfterArgs (beforeArgs.size ());
3962
+ for (auto &&[i, j] : llvm::enumerate (*mapping)) {
3963
+ newResults[i] = newLoop.getResult (j);
3964
+ newAfterArgs[i] = newAfter->getArgument (j);
3965
+ }
3966
+
3967
+ rewriter.inlineBlockBefore (oldBefore, newBefore, newBefore->begin (),
3968
+ newBefore->getArguments ());
3969
+ rewriter.inlineBlockBefore (oldAfter, newAfter, newAfter->begin (),
3970
+ newAfterArgs);
3971
+
3972
+ rewriter.replaceOp (loop, newResults);
3973
+ return success ();
3974
+ }
3975
+ };
3887
3976
} // namespace
3888
3977
3889
3978
void WhileOp::getCanonicalizationPatterns (RewritePatternSet &results,
3890
3979
MLIRContext *context) {
3891
3980
results.add <RemoveLoopInvariantArgsFromBeforeBlock,
3892
3981
RemoveLoopInvariantValueYielded, WhileConditionTruth,
3893
3982
WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
3894
- WhileRemoveUnusedArgs>(context);
3983
+ WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs >(context);
3895
3984
}
3896
3985
3897
3986
// ===----------------------------------------------------------------------===//
0 commit comments