@@ -939,49 +939,114 @@ mlir::scf::tileAndFuseProducerOfSlice(
939
939
LogicalResult mlir::scf::yieldReplacementForFusedProducer (
940
940
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
941
941
scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
942
- MutableArrayRef<LoopLikeOpInterface> loops) {
942
+ MutableArrayRef<LoopLikeOpInterface> loops,
943
+ std::optional<ArrayRef<unsigned >> yieldResultNumber) {
943
944
if (loops.empty ())
944
945
return success ();
945
946
946
- OpResult fusableProducer = fusedProducerInfo.origProducer ;
947
- Value tiledAndFusedProducer = fusedProducerInfo.tiledAndFusedProducer ;
948
- FailureOr<Value> initValue = tensor::getOrCreateDestination (
949
- rewriter, fusableProducer.getOwner ()->getLoc (), fusableProducer);
950
- if (succeeded (initValue)) {
951
-
952
- YieldTiledValuesFn newYieldValuesFn =
953
- [&](RewriterBase &innerRewriter, Location loc, ValueRange /* ivs*/ ,
954
- ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
955
- SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
956
- SmallVector<SmallVector<OpFoldResult>> &tiledSizes)
957
- -> LogicalResult {
958
- OpBuilder::InsertionGuard g (innerRewriter);
959
- if (auto tiledDestStyleOp =
960
- tiledAndFusedProducer
961
- .getDefiningOp <DestinationStyleOpInterface>()) {
962
- rewriter.setInsertionPoint (tiledDestStyleOp);
963
- Value newRegionArg = newRegionIterArgs.back ();
947
+ Operation *originalOwner = fusedProducerInfo.origProducer .getOwner (),
948
+ *tiledOwner = fusedProducerInfo.tiledOps [0 ];
949
+
950
+ Location loc = originalOwner->getLoc ();
951
+ // a. collect all init Value to be appended
952
+ ArrayRef<unsigned > initNumberList =
953
+ yieldResultNumber ? yieldResultNumber.value ()
954
+ : llvm::to_vector (llvm::seq<unsigned >(
955
+ 0 , originalOwner->getNumResults ()));
956
+ SmallVector<Value> initValueList;
957
+ for (const auto &resultNumber : initNumberList) {
958
+ FailureOr<Value> initValue = tensor::getOrCreateDestination (
959
+ rewriter, loc, originalOwner->getResult (resultNumber));
960
+ if (succeeded (initValue)) {
961
+ initValueList.push_back (initValue.value ());
962
+ } else {
963
+ return failure ();
964
+ }
965
+ }
966
+
967
+ YieldTiledValuesFn newYieldValuesFn =
968
+ [&](RewriterBase &innerRewriter, Location loc, ValueRange /* ivs*/ ,
969
+ ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
970
+ SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
971
+ SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
972
+ OpBuilder::InsertionGuard g (innerRewriter);
973
+
974
+ // get sliceOp tile information
975
+ SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets (),
976
+ sliceSizes = sliceOp.getMixedSizes ();
977
+
978
+ // expect all strides of sliceOp being 1
979
+ if (llvm::any_of (sliceOp.getMixedStrides (), [](OpFoldResult ofr) {
980
+ return !isConstantIntValue (ofr, 1 );
981
+ }))
982
+ return failure ();
983
+
984
+ unsigned sliceResultNumber =
985
+ fusedProducerInfo.origProducer .getResultNumber ();
986
+
987
+ auto tilableOp = cast<TilingInterface>(originalOwner);
988
+ // b. get iterDomain Offset and Sizes based on sliceOp tile
989
+ SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes;
990
+ // skip tensor.pack/unpack/pad, which expects single opResult
991
+ if (tilableOp->getNumResults () > 1 &&
992
+ failed (tilableOp.getIterationDomainTileFromResultTile (
993
+ rewriter, sliceResultNumber, sliceOffset, sliceSizes,
994
+ iterDomainOffset, iterDomainSizes))) {
995
+ return failure ();
996
+ }
997
+
998
+ // c. calculate offsets and sizes info of all OpResults respectively based
999
+ // on iteration Domain Tile
1000
+ SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList;
1001
+ for (const auto &resultNumber : initNumberList) {
1002
+ if (resultNumber == fusedProducerInfo.origProducer .getResultNumber ()) {
1003
+ offsetList.push_back (sliceOffset);
1004
+ sizesList.push_back (sliceSizes);
1005
+ } else {
1006
+ assert (!iterDomainOffset.empty () && !iterDomainSizes.empty ());
1007
+ // infer result tile according to the iteration domain tile
1008
+ SmallVector<OpFoldResult> offset, sizes;
1009
+ if (failed (tilableOp.getResultTilePosition (
1010
+ rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
1011
+ offset, sizes))) {
1012
+ return failure ();
1013
+ }
1014
+ offsetList.push_back (offset);
1015
+ sizesList.push_back (sizes);
1016
+ }
1017
+ }
1018
+
1019
+ // d. create `extract_slice` for `iter_args` for DPS operation if necessary
1020
+ if (auto tiledDestStyleOp =
1021
+ dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
1022
+ rewriter.setInsertionPoint (tiledDestStyleOp);
1023
+ for (const auto &&[index , newRegionArg] :
1024
+ llvm::enumerate (newRegionIterArgs)) {
964
1025
auto destSlice = rewriter.create <tensor::ExtractSliceOp>(
965
- sliceOp.getLoc (), newRegionArg, sliceOp.getMixedOffsets (),
966
- sliceOp.getMixedSizes (), sliceOp.getMixedStrides ());
967
- unsigned resultNumber = fusableProducer.getResultNumber ();
1026
+ loc, newRegionArg, offsetList[index ], sizesList[index ],
1027
+ SmallVector<OpFoldResult>(offsetList[index ].size (),
1028
+ rewriter.getIndexAttr (1 )));
1029
+ unsigned resultNumber = initNumberList[index ];
968
1030
rewriter.modifyOpInPlace (tiledDestStyleOp, [&]() {
969
1031
tiledDestStyleOp.getDpsInitsMutable ()[resultNumber].set (destSlice);
970
1032
});
971
1033
}
972
- Block *block = rewriter.getInsertionPoint ()->getBlock ();
973
- rewriter.setInsertionPoint (block->getTerminator ());
974
- tiledResult.push_back (fusedProducerInfo.tiledAndFusedProducer );
975
- tiledOffset.emplace_back (sliceOp.getMixedOffsets ());
976
- tiledSizes.emplace_back (sliceOp.getMixedSizes ());
977
- return success ();
978
- };
1034
+ }
979
1035
980
- return addInitOperandsToLoopNest (rewriter, loops,
981
- SmallVector<Value>{initValue.value ()},
982
- newYieldValuesFn);
983
- }
984
- return success ();
1036
+ // e. prepare tiled offset and sizes for later `insert_slice` creation by
1037
+ // caller
1038
+ Block *block = rewriter.getInsertionPoint ()->getBlock ();
1039
+ rewriter.setInsertionPoint (block->getTerminator ());
1040
+ for (const auto &&[index , resultNumber] : llvm::enumerate (initNumberList)) {
1041
+ tiledResult.push_back (tiledOwner->getResult (resultNumber));
1042
+ tiledOffset.emplace_back (offsetList[index ]);
1043
+ tiledSizes.emplace_back (sizesList[index ]);
1044
+ }
1045
+ return success ();
1046
+ };
1047
+
1048
+ return addInitOperandsToLoopNest (rewriter, loops, initValueList,
1049
+ newYieldValuesFn);
985
1050
}
986
1051
987
1052
// / Implementation of tile consumer and fuse producer greedily.
@@ -1071,14 +1136,21 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
1071
1136
continue ;
1072
1137
1073
1138
if (yieldReplacement) {
1139
+ // Reconstruct and yield all opResult of fusableProducerOp by default. The
1140
+ // caller can specific which one to yield by designating optional argument
1141
+ // named `yieldResultNumber` of `yieldReplacementForFusedProducer`.
1142
+ Operation *fusableProducerOp = fusableProducer.getOwner ();
1074
1143
if (failed (yieldReplacementForFusedProducer (
1075
1144
rewriter, candidateSliceOp, fusedResult.value (), loops))) {
1076
1145
return rewriter.notifyMatchFailure (
1077
- fusableProducer.getOwner (), " failed to replacement value for this "
1078
- " oepration from within the tiled loop" );
1146
+ fusableProducerOp, " failed to replacement value for this "
1147
+ " operation from within the tiled loop" );
1148
+ }
1149
+ for (const auto &result : fusableProducerOp->getResults ()) {
1150
+ origValToResultNumber[result] =
1151
+ loops.front ()->getNumResults () -
1152
+ (fusableProducerOp->getNumResults () - result.getResultNumber ());
1079
1153
}
1080
- origValToResultNumber[fusableProducer] =
1081
- loops.front ()->getNumResults () - 1 ;
1082
1154
}
1083
1155
1084
1156
if (Operation *tiledAndFusedOp =
0 commit comments