@@ -953,49 +953,122 @@ mlir::scf::tileAndFuseProducerOfSlice(
953
953
LogicalResult mlir::scf::yieldReplacementForFusedProducer (
954
954
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
955
955
scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
956
- MutableArrayRef<LoopLikeOpInterface> loops) {
956
+ MutableArrayRef<LoopLikeOpInterface> loops,
957
+ ArrayRef<unsigned > yieldResultNumber) {
957
958
if (loops.empty ())
958
959
return success ();
959
960
960
- OpResult fusableProducer = fusedProducerInfo.origProducer ;
961
- Value tiledAndFusedProducer = fusedProducerInfo.tiledAndFusedProducer ;
962
- FailureOr<Value> initValue = tensor::getOrCreateDestination (
963
- rewriter, fusableProducer.getOwner ()->getLoc (), fusableProducer);
964
- if (succeeded (initValue)) {
965
-
966
- YieldTiledValuesFn newYieldValuesFn =
967
- [&](RewriterBase &innerRewriter, Location loc, ValueRange /* ivs*/ ,
968
- ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
969
- SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
970
- SmallVector<SmallVector<OpFoldResult>> &tiledSizes)
971
- -> LogicalResult {
972
- OpBuilder::InsertionGuard g (innerRewriter);
973
- if (auto tiledDestStyleOp =
974
- tiledAndFusedProducer
975
- .getDefiningOp <DestinationStyleOpInterface>()) {
976
- rewriter.setInsertionPoint (tiledDestStyleOp);
977
- Value newRegionArg = newRegionIterArgs.back ();
961
+ Operation *originalOwner = fusedProducerInfo.origProducer .getOwner (),
962
+ *tiledOwner = fusedProducerInfo.tiledOps [0 ];
963
+
964
+ Location loc = originalOwner->getLoc ();
965
+ // a. collect all init Value to be appended
966
+ SmallVector<unsigned > initNumberList =
967
+ yieldResultNumber.empty () ? llvm::to_vector (llvm::seq<unsigned >(
968
+ 0 , originalOwner->getNumResults ()))
969
+ : llvm::to_vector (yieldResultNumber);
970
+ SmallVector<Value> initValueList;
971
+ for (const auto &resultNumber : initNumberList) {
972
+ FailureOr<Value> initValue = tensor::getOrCreateDestination (
973
+ rewriter, loc, originalOwner->getResult (resultNumber));
974
+ if (succeeded (initValue)) {
975
+ initValueList.push_back (initValue.value ());
976
+ } else {
977
+ return failure ();
978
+ }
979
+ }
980
+
981
+ YieldTiledValuesFn newYieldValuesFn =
982
+ [&](RewriterBase &innerRewriter, Location loc, ValueRange /* ivs*/ ,
983
+ ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
984
+ SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
985
+ SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
986
+ OpBuilder::InsertionGuard g (innerRewriter);
987
+
988
+ // get sliceOp tile information
989
+ SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets (),
990
+ sliceSizes = sliceOp.getMixedSizes ();
991
+
992
+ // expect all strides of sliceOp being 1
993
+ if (llvm::any_of (sliceOp.getMixedStrides (), [](OpFoldResult ofr) {
994
+ return !isConstantIntValue (ofr, 1 );
995
+ }))
996
+ return failure ();
997
+
998
+ unsigned sliceResultNumber =
999
+ fusedProducerInfo.origProducer .getResultNumber ();
1000
+
1001
+ auto tilableOp = cast<TilingInterface>(originalOwner);
1002
+ // b. get iterDomain Offset and Sizes based on sliceOp tile
1003
+ SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes;
1004
+ // skip tensor.pack/unpack/pad, which expects single opResult
1005
+ if (tilableOp->getNumResults () > 1 &&
1006
+ failed (tilableOp.getIterationDomainTileFromResultTile (
1007
+ rewriter, sliceResultNumber, sliceOffset, sliceSizes,
1008
+ iterDomainOffset, iterDomainSizes))) {
1009
+ // In theory, it is unnecessary to raise an error here. Actually although
1010
+ // it fails to reconstruct the result tensor, it should not broke current
1011
+ // fusion anyway. The reason why we must return failure currently is that
1012
+ // the callback function `newYieldValuesFn` will be called after new init
1013
+ // operand(s) has already been appended. It will take more refactoring to
1014
+ // make sure the init operands are added consistently in the future. For
1015
+ // more details, please refer to:
1016
+ // https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814
1017
+ return failure ();
1018
+ }
1019
+
1020
+ // c. calculate offsets and sizes info of all OpResults respectively based
1021
+ // on iteration Domain Tile
1022
+ SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList;
1023
+ for (const auto &resultNumber : initNumberList) {
1024
+ if (resultNumber == sliceResultNumber) {
1025
+ offsetList.push_back (sliceOffset);
1026
+ sizesList.push_back (sliceSizes);
1027
+ } else {
1028
+ assert (!iterDomainOffset.empty () && !iterDomainSizes.empty ());
1029
+ // infer result tile according to the iteration domain tile
1030
+ SmallVector<OpFoldResult> offset, sizes;
1031
+ if (failed (tilableOp.getResultTilePosition (
1032
+ rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
1033
+ offset, sizes))) {
1034
+ return failure ();
1035
+ }
1036
+ offsetList.push_back (offset);
1037
+ sizesList.push_back (sizes);
1038
+ }
1039
+ }
1040
+
1041
+ // d. create `extract_slice` for `iter_args` for DPS operation if necessary
1042
+ if (auto tiledDestStyleOp =
1043
+ dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
1044
+ rewriter.setInsertionPoint (tiledDestStyleOp);
1045
+ for (const auto &&[index , newRegionArg] :
1046
+ llvm::enumerate (newRegionIterArgs)) {
978
1047
auto destSlice = rewriter.create <tensor::ExtractSliceOp>(
979
- sliceOp.getLoc (), newRegionArg, sliceOp.getMixedOffsets (),
980
- sliceOp.getMixedSizes (), sliceOp.getMixedStrides ());
981
- unsigned resultNumber = fusableProducer.getResultNumber ();
1048
+ loc, newRegionArg, offsetList[index ], sizesList[index ],
1049
+ SmallVector<OpFoldResult>(offsetList[index ].size (),
1050
+ rewriter.getIndexAttr (1 )));
1051
+ unsigned resultNumber = initNumberList[index ];
982
1052
rewriter.modifyOpInPlace (tiledDestStyleOp, [&]() {
983
1053
tiledDestStyleOp.getDpsInitsMutable ()[resultNumber].set (destSlice);
984
1054
});
985
1055
}
986
- Block *block = rewriter.getInsertionPoint ()->getBlock ();
987
- rewriter.setInsertionPoint (block->getTerminator ());
988
- tiledResult.push_back (fusedProducerInfo.tiledAndFusedProducer );
989
- tiledOffset.emplace_back (sliceOp.getMixedOffsets ());
990
- tiledSizes.emplace_back (sliceOp.getMixedSizes ());
991
- return success ();
992
- };
1056
+ }
993
1057
994
- return addInitOperandsToLoopNest (rewriter, loops,
995
- SmallVector<Value>{initValue.value ()},
996
- newYieldValuesFn);
997
- }
998
- return success ();
1058
+ // e. prepare tiled offset and sizes for later `insert_slice` creation by
1059
+ // caller
1060
+ Block *block = rewriter.getInsertionPoint ()->getBlock ();
1061
+ rewriter.setInsertionPoint (block->getTerminator ());
1062
+ for (const auto &&[index , resultNumber] : llvm::enumerate (initNumberList)) {
1063
+ tiledResult.push_back (tiledOwner->getResult (resultNumber));
1064
+ tiledOffset.emplace_back (offsetList[index ]);
1065
+ tiledSizes.emplace_back (sizesList[index ]);
1066
+ }
1067
+ return success ();
1068
+ };
1069
+
1070
+ return addInitOperandsToLoopNest (rewriter, loops, initValueList,
1071
+ newYieldValuesFn);
999
1072
}
1000
1073
1001
1074
// / Implementation of tile consumer and fuse producer greedily.
@@ -1085,14 +1158,22 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
1085
1158
continue ;
1086
1159
1087
1160
if (yieldReplacement) {
1161
+ // Reconstruct and yield all opResult of fusableProducerOp by default. The
1162
+ // caller can specific which one to yield by designating optional argument
1163
+ // named `yieldResultNumber` of `yieldReplacementForFusedProducer`.
1164
+ Operation *fusableProducerOp = fusableProducer.getOwner ();
1088
1165
if (failed (yieldReplacementForFusedProducer (
1089
1166
rewriter, candidateSliceOp, fusedResult.value (), loops))) {
1090
1167
return rewriter.notifyMatchFailure (
1091
- fusableProducer.getOwner (), " failed to replacement value for this "
1092
- " oepration from within the tiled loop" );
1168
+ fusableProducerOp, " failed to replacement value for this "
1169
+ " operation from within the tiled loop" );
1170
+ }
1171
+ for (auto [index , result] :
1172
+ llvm::enumerate (fusableProducerOp->getResults ())) {
1173
+ origValToResultNumber[result] = loops.front ()->getNumResults () -
1174
+ fusableProducerOp->getNumResults () +
1175
+ index ;
1093
1176
}
1094
- origValToResultNumber[fusableProducer] =
1095
- loops.front ()->getNumResults () - 1 ;
1096
1177
}
1097
1178
1098
1179
if (Operation *tiledAndFusedOp =
0 commit comments