@@ -910,61 +910,98 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
910
910
unsigned numTargetOuts = target.getNumResults ();
911
911
unsigned numSourceOuts = source.getNumResults ();
912
912
913
- OperandRange targetOuts = target.getOutputs ();
914
- OperandRange sourceOuts = source.getOutputs ();
915
-
916
913
// Create fused shared_outs.
917
914
SmallVector<Value> fusedOuts;
918
- fusedOuts.reserve (numTargetOuts + numSourceOuts);
919
- fusedOuts.append (targetOuts.begin (), targetOuts.end ());
920
- fusedOuts.append (sourceOuts.begin (), sourceOuts.end ());
915
+ llvm::append_range (fusedOuts, target.getOutputs ());
916
+ llvm::append_range (fusedOuts, source.getOutputs ());
921
917
922
- // Create a new scf:: forall op after the source loop.
918
+ // Create a new scf.forall op after the source loop.
923
919
rewriter.setInsertionPointAfter (source);
924
920
scf::ForallOp fusedLoop = rewriter.create <scf::ForallOp>(
925
921
source.getLoc (), source.getMixedLowerBound (), source.getMixedUpperBound (),
926
922
source.getMixedStep (), fusedOuts, source.getMapping ());
927
923
928
924
// Map control operands.
929
- IRMapping fusedMapping ;
930
- fusedMapping .map (target.getInductionVars (), fusedLoop.getInductionVars ());
931
- fusedMapping .map (source.getInductionVars (), fusedLoop.getInductionVars ());
925
+ IRMapping mapping ;
926
+ mapping .map (target.getInductionVars (), fusedLoop.getInductionVars ());
927
+ mapping .map (source.getInductionVars (), fusedLoop.getInductionVars ());
932
928
933
929
// Map shared outs.
934
- fusedMapping.map (target.getRegionIterArgs (),
935
- fusedLoop.getRegionIterArgs ().slice (0 , numTargetOuts));
936
- fusedMapping.map (
937
- source.getRegionIterArgs (),
938
- fusedLoop.getRegionIterArgs ().slice (numTargetOuts, numSourceOuts));
930
+ mapping.map (target.getRegionIterArgs (),
931
+ fusedLoop.getRegionIterArgs ().take_front (numTargetOuts));
932
+ mapping.map (source.getRegionIterArgs (),
933
+ fusedLoop.getRegionIterArgs ().take_back (numSourceOuts));
939
934
940
935
// Append everything except the terminator into the fused operation.
941
936
rewriter.setInsertionPointToStart (fusedLoop.getBody ());
942
937
for (Operation &op : target.getBody ()->without_terminator ())
943
- rewriter.clone (op, fusedMapping );
938
+ rewriter.clone (op, mapping );
944
939
for (Operation &op : source.getBody ()->without_terminator ())
945
- rewriter.clone (op, fusedMapping );
940
+ rewriter.clone (op, mapping );
946
941
947
942
// Fuse the old terminator in_parallel ops into the new one.
948
943
scf::InParallelOp targetTerm = target.getTerminator ();
949
944
scf::InParallelOp sourceTerm = source.getTerminator ();
950
945
scf::InParallelOp fusedTerm = fusedLoop.getTerminator ();
951
-
952
946
rewriter.setInsertionPointToStart (fusedTerm.getBody ());
953
947
for (Operation &op : targetTerm.getYieldingOps ())
954
- rewriter.clone (op, fusedMapping );
948
+ rewriter.clone (op, mapping );
955
949
for (Operation &op : sourceTerm.getYieldingOps ())
956
- rewriter.clone (op, fusedMapping);
957
-
958
- // Replace all uses of the old loops with the fused loop.
959
- rewriter.replaceAllUsesWith (target.getResults (),
960
- fusedLoop.getResults ().slice (0 , numTargetOuts));
961
- rewriter.replaceAllUsesWith (
962
- source.getResults (),
963
- fusedLoop.getResults ().slice (numTargetOuts, numSourceOuts));
964
-
965
- // Erase the old loops.
966
- rewriter.eraseOp (target);
967
- rewriter.eraseOp (source);
950
+ rewriter.clone (op, mapping);
951
+
952
+ // Replace old loops by substituting their uses by results of the fused loop.
953
+ rewriter.replaceOp (target, fusedLoop.getResults ().take_front (numTargetOuts));
954
+ rewriter.replaceOp (source, fusedLoop.getResults ().take_back (numSourceOuts));
955
+
956
+ return fusedLoop;
957
+ }
958
+
959
+ scf::ForOp mlir::fuseIndependentSiblingForLoops (scf::ForOp target,
960
+ scf::ForOp source,
961
+ RewriterBase &rewriter) {
962
+ unsigned numTargetOuts = target.getNumResults ();
963
+ unsigned numSourceOuts = source.getNumResults ();
964
+
965
+ // Create fused init_args, with target's init_args before source's init_args.
966
+ SmallVector<Value> fusedInitArgs;
967
+ llvm::append_range (fusedInitArgs, target.getInitArgs ());
968
+ llvm::append_range (fusedInitArgs, source.getInitArgs ());
969
+
970
+ // Create a new scf.for op after the source loop (with scf.yield terminator
971
+ // (without arguments) only in case its init_args is empty).
972
+ rewriter.setInsertionPointAfter (source);
973
+ scf::ForOp fusedLoop = rewriter.create <scf::ForOp>(
974
+ source.getLoc (), source.getLowerBound (), source.getUpperBound (),
975
+ source.getStep (), fusedInitArgs);
976
+
977
+ // Map original induction variables and operands to those of the fused loop.
978
+ IRMapping mapping;
979
+ mapping.map (target.getInductionVar (), fusedLoop.getInductionVar ());
980
+ mapping.map (target.getRegionIterArgs (),
981
+ fusedLoop.getRegionIterArgs ().take_front (numTargetOuts));
982
+ mapping.map (source.getInductionVar (), fusedLoop.getInductionVar ());
983
+ mapping.map (source.getRegionIterArgs (),
984
+ fusedLoop.getRegionIterArgs ().take_back (numSourceOuts));
985
+
986
+ // Merge target's body into the new (fused) for loop and then source's body.
987
+ rewriter.setInsertionPointToStart (fusedLoop.getBody ());
988
+ for (Operation &op : target.getBody ()->without_terminator ())
989
+ rewriter.clone (op, mapping);
990
+ for (Operation &op : source.getBody ()->without_terminator ())
991
+ rewriter.clone (op, mapping);
992
+
993
+ // Build fused yield results by appropriately mapping original yield operands.
994
+ SmallVector<Value> yieldResults;
995
+ for (Value operand : target.getBody ()->getTerminator ()->getOperands ())
996
+ yieldResults.push_back (mapping.lookupOrDefault (operand));
997
+ for (Value operand : source.getBody ()->getTerminator ()->getOperands ())
998
+ yieldResults.push_back (mapping.lookupOrDefault (operand));
999
+ if (!yieldResults.empty ())
1000
+ rewriter.create <scf::YieldOp>(source.getLoc (), yieldResults);
1001
+
1002
+ // Replace old loops by substituting their uses by results of the fused loop.
1003
+ rewriter.replaceOp (target, fusedLoop.getResults ().take_front (numTargetOuts));
1004
+ rewriter.replaceOp (source, fusedLoop.getResults ().take_back (numSourceOuts));
968
1005
969
1006
return fusedLoop;
970
1007
}
0 commit comments