@@ -940,49 +940,114 @@ mlir::scf::tileAndFuseProducerOfSlice(
940
940
LogicalResult mlir::scf::yieldReplacementForFusedProducer (
941
941
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
942
942
scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
943
- MutableArrayRef<LoopLikeOpInterface> loops) {
943
+ MutableArrayRef<LoopLikeOpInterface> loops,
944
+ std::optional<ArrayRef<unsigned >> yieldResultNumber) {
944
945
if (loops.empty ())
945
946
return success ();
946
947
947
- OpResult fusableProducer = fusedProducerInfo.origProducer ;
948
- Value tiledAndFusedProducer = fusedProducerInfo.tiledAndFusedProducer ;
949
- FailureOr<Value> initValue = tensor::getOrCreateDestination (
950
- rewriter, fusableProducer.getOwner ()->getLoc (), fusableProducer);
951
- if (succeeded (initValue)) {
952
-
953
- YieldTiledValuesFn newYieldValuesFn =
954
- [&](RewriterBase &innerRewriter, Location loc, ValueRange /* ivs*/ ,
955
- ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
956
- SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
957
- SmallVector<SmallVector<OpFoldResult>> &tiledSizes)
958
- -> LogicalResult {
959
- OpBuilder::InsertionGuard g (innerRewriter);
960
- if (auto tiledDestStyleOp =
961
- tiledAndFusedProducer
962
- .getDefiningOp <DestinationStyleOpInterface>()) {
963
- rewriter.setInsertionPoint (tiledDestStyleOp);
964
- Value newRegionArg = newRegionIterArgs.back ();
948
+ Operation *originalOwner = fusedProducerInfo.origProducer .getOwner (),
949
+ *tiledOwner = fusedProducerInfo.tiledOps [0 ];
950
+
951
+ Location loc = originalOwner->getLoc ();
952
+ // a. collect all init Value to be appended
953
+ ArrayRef<unsigned > initNumberList =
954
+ yieldResultNumber ? yieldResultNumber.value ()
955
+ : llvm::to_vector (llvm::seq<unsigned >(
956
+ 0 , originalOwner->getNumResults ()));
957
+ SmallVector<Value> initValueList;
958
+ for (const auto &resultNumber : initNumberList) {
959
+ FailureOr<Value> initValue = tensor::getOrCreateDestination (
960
+ rewriter, loc, originalOwner->getResult (resultNumber));
961
+ if (succeeded (initValue)) {
962
+ initValueList.push_back (initValue.value ());
963
+ } else {
964
+ return failure ();
965
+ }
966
+ }
967
+
968
+ YieldTiledValuesFn newYieldValuesFn =
969
+ [&](RewriterBase &innerRewriter, Location loc, ValueRange /* ivs*/ ,
970
+ ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
971
+ SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
972
+ SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
973
+ OpBuilder::InsertionGuard g (innerRewriter);
974
+
975
+ // get sliceOp tile information
976
+ SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets (),
977
+ sliceSizes = sliceOp.getMixedSizes ();
978
+
979
+ // expect all strides of sliceOp being 1
980
+ if (llvm::any_of (sliceOp.getMixedStrides (), [](OpFoldResult ofr) {
981
+ return !isConstantIntValue (ofr, 1 );
982
+ }))
983
+ return failure ();
984
+
985
+ unsigned sliceResultNumber =
986
+ fusedProducerInfo.origProducer .getResultNumber ();
987
+
988
+ auto tilableOp = cast<TilingInterface>(originalOwner);
989
+ // b. get iterDomain Offset and Sizes based on sliceOp tile
990
+ SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes;
991
+ // skip tensor.pack/unpack/pad, which expects single opResult
992
+ if (tilableOp->getNumResults () > 1 &&
993
+ failed (tilableOp.getIterationDomainTileFromResultTile (
994
+ rewriter, sliceResultNumber, sliceOffset, sliceSizes,
995
+ iterDomainOffset, iterDomainSizes))) {
996
+ return failure ();
997
+ }
998
+
999
+ // c. calculate offsets and sizes info of all OpResults respectively based
1000
+ // on iteration Domain Tile
1001
+ SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList;
1002
+ for (const auto &resultNumber : initNumberList) {
1003
+ if (resultNumber == fusedProducerInfo.origProducer .getResultNumber ()) {
1004
+ offsetList.push_back (sliceOffset);
1005
+ sizesList.push_back (sliceSizes);
1006
+ } else {
1007
+ assert (!iterDomainOffset.empty () && !iterDomainSizes.empty ());
1008
+ // infer result tile according to the iteration domain tile
1009
+ SmallVector<OpFoldResult> offset, sizes;
1010
+ if (failed (tilableOp.getResultTilePosition (
1011
+ rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
1012
+ offset, sizes))) {
1013
+ return failure ();
1014
+ }
1015
+ offsetList.push_back (offset);
1016
+ sizesList.push_back (sizes);
1017
+ }
1018
+ }
1019
+
1020
+ // d. create `extract_slice` for `iter_args` for DPS operation if necessary
1021
+ if (auto tiledDestStyleOp =
1022
+ dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
1023
+ rewriter.setInsertionPoint (tiledDestStyleOp);
1024
+ for (const auto &&[index , newRegionArg] :
1025
+ llvm::enumerate (newRegionIterArgs)) {
965
1026
auto destSlice = rewriter.create <tensor::ExtractSliceOp>(
966
- sliceOp.getLoc (), newRegionArg, sliceOp.getMixedOffsets (),
967
- sliceOp.getMixedSizes (), sliceOp.getMixedStrides ());
968
- unsigned resultNumber = fusableProducer.getResultNumber ();
1027
+ loc, newRegionArg, offsetList[index ], sizesList[index ],
1028
+ SmallVector<OpFoldResult>(offsetList[index ].size (),
1029
+ rewriter.getIndexAttr (1 )));
1030
+ unsigned resultNumber = initNumberList[index ];
969
1031
rewriter.modifyOpInPlace (tiledDestStyleOp, [&]() {
970
1032
tiledDestStyleOp.getDpsInitsMutable ()[resultNumber].set (destSlice);
971
1033
});
972
1034
}
973
- Block *block = rewriter.getInsertionPoint ()->getBlock ();
974
- rewriter.setInsertionPoint (block->getTerminator ());
975
- tiledResult.push_back (fusedProducerInfo.tiledAndFusedProducer );
976
- tiledOffset.emplace_back (sliceOp.getMixedOffsets ());
977
- tiledSizes.emplace_back (sliceOp.getMixedSizes ());
978
- return success ();
979
- };
1035
+ }
980
1036
981
- return addInitOperandsToLoopNest (rewriter, loops,
982
- SmallVector<Value>{initValue.value ()},
983
- newYieldValuesFn);
984
- }
985
- return success ();
1037
+ // e. prepare tiled offset and sizes for later `insert_slice` creation by
1038
+ // caller
1039
+ Block *block = rewriter.getInsertionPoint ()->getBlock ();
1040
+ rewriter.setInsertionPoint (block->getTerminator ());
1041
+ for (const auto &&[index , resultNumber] : llvm::enumerate (initNumberList)) {
1042
+ tiledResult.push_back (tiledOwner->getResult (resultNumber));
1043
+ tiledOffset.emplace_back (offsetList[index ]);
1044
+ tiledSizes.emplace_back (sizesList[index ]);
1045
+ }
1046
+ return success ();
1047
+ };
1048
+
1049
+ return addInitOperandsToLoopNest (rewriter, loops, initValueList,
1050
+ newYieldValuesFn);
986
1051
}
987
1052
988
1053
// / Implementation of tile consumer and fuse producer greedily.
@@ -1072,14 +1137,21 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
1072
1137
continue ;
1073
1138
1074
1139
if (yieldReplacement) {
1140
+ // Reconstruct and yield all opResult of fusableProducerOp by default. The
1141
+ // caller can specific which one to yield by designating optional argument
1142
+ // named `yieldResultNumber` of `yieldReplacementForFusedProducer`.
1143
+ Operation *fusableProducerOp = fusableProducer.getOwner ();
1075
1144
if (failed (yieldReplacementForFusedProducer (
1076
1145
rewriter, candidateSliceOp, fusedResult.value (), loops))) {
1077
1146
return rewriter.notifyMatchFailure (
1078
- fusableProducer.getOwner (), " failed to replacement value for this "
1079
- " oepration from within the tiled loop" );
1147
+ fusableProducerOp, " failed to replacement value for this "
1148
+ " operation from within the tiled loop" );
1149
+ }
1150
+ for (const auto &result : fusableProducerOp->getResults ()) {
1151
+ origValToResultNumber[result] =
1152
+ loops.front ()->getNumResults () -
1153
+ (fusableProducerOp->getNumResults () - result.getResultNumber ());
1080
1154
}
1081
- origValToResultNumber[fusableProducer] =
1082
- loops.front ()->getNumResults () - 1 ;
1083
1155
}
1084
1156
1085
1157
if (Operation *tiledAndFusedOp =
0 commit comments