Skip to content

Commit 4235f25

Browse files
committed
yield replacement for multiple results
1 parent ca478bc commit 4235f25

File tree

5 files changed

+222
-45
lines changed

5 files changed

+222
-45
lines changed

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,10 +190,14 @@ tileAndFuseProducerOfSlice(RewriterBase &rewriter,
190190
/// where `%0` had other uses as well. If not reconstructed from within the loop
191191
/// body, uses of `%0` could not be replaced, making it still live and the
192192
/// fusion immaterial.
193+
///
194+
/// The @param `yieldResultNumber` decides which results are yielded. If not
195+
/// given, all `opResult`s of the fused producer are yielded.
193196
LogicalResult yieldReplacementForFusedProducer(
194197
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
195198
scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
196-
MutableArrayRef<LoopLikeOpInterface> loops);
199+
MutableArrayRef<LoopLikeOpInterface> loops,
200+
std::optional<ArrayRef<unsigned>> yieldResultNumber = std::nullopt);
197201

198202
/// Transformation information returned after tile and fuse.
199203
struct SCFTileAndFuseResult {

mlir/include/mlir/Interfaces/TilingInterface.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,25 @@ def TilingInterface : OpInterface<"TilingInterface"> {
9696
return failure();
9797
}]
9898
>,
99+
InterfaceMethod<
100+
/*desc=*/[{
101+
Method to return the tile offsets and sizes of the iteration domain
102+
that correspond to the given tile of the specified result.
103+
}],
104+
/*retType=*/"LogicalResult",
105+
/*methodName=*/"getIterationDomainTileFromResultTile",
106+
/*args=*/(ins
107+
"OpBuilder &":$b,
108+
"unsigned":$resultNumber,
109+
"ArrayRef<OpFoldResult> ":$resultOffsets,
110+
"ArrayRef<OpFoldResult> ":$resultSizes,
111+
"SmallVector<OpFoldResult> &":$iterDomainOffsets,
112+
"SmallVector<OpFoldResult> &":$iterDomainSizes),
113+
/*methodBody=*/"",
114+
/*defaultImplementation=*/[{
115+
return failure();
116+
}]
117+
>,
99118
InterfaceMethod<
100119
/*desc=*/[{
101120
Method to generate the code that produces a tile of the result.

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -160,10 +160,11 @@ struct LinalgOpTilingInterface
160160
return success();
161161
}
162162

163-
FailureOr<TilingResult>
164-
generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
165-
ArrayRef<OpFoldResult> offsets,
166-
ArrayRef<OpFoldResult> sizes) const {
163+
LogicalResult getIterationDomainTileFromResultTile(
164+
Operation *op, OpBuilder &b, unsigned resultNumber,
165+
ArrayRef<OpFoldResult> resultOffsets, ArrayRef<OpFoldResult> resultSizes,
166+
SmallVector<OpFoldResult> &iterDomainOffsets,
167+
SmallVector<OpFoldResult> &iterDomainSizes) const {
167168
auto linalgOp = cast<LinalgOp>(op);
168169

169170
// Check that the indexing map used for the output is a projected
@@ -193,8 +194,27 @@ struct LinalgOpTilingInterface
193194
for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
194195
unsigned dimPosition =
195196
cast<AffineDimExpr>(resultExpr.value()).getPosition();
196-
iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
197-
iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
197+
iterationTileOffsets[dimPosition] = resultOffsets[resultExpr.index()];
198+
iterationTileSizes[dimPosition] = resultSizes[resultExpr.index()];
199+
}
200+
201+
iterDomainOffsets = iterationTileOffsets;
202+
iterDomainSizes = iterationTileSizes;
203+
204+
return success();
205+
}
206+
207+
FailureOr<TilingResult>
208+
generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
209+
ArrayRef<OpFoldResult> offsets,
210+
ArrayRef<OpFoldResult> sizes) const {
211+
auto tilingInterfaceOp = cast<TilingInterface>(op);
212+
213+
SmallVector<OpFoldResult> iterationTileOffsets, iterationTileSizes;
214+
if (failed(tilingInterfaceOp.getIterationDomainTileFromResultTile(
215+
b, resultNumber, offsets, sizes, iterationTileOffsets,
216+
iterationTileSizes))) {
217+
return failure();
198218
}
199219

200220
FailureOr<TilingResult> tilingResult =

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 110 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -939,49 +939,114 @@ mlir::scf::tileAndFuseProducerOfSlice(
939939
LogicalResult mlir::scf::yieldReplacementForFusedProducer(
940940
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
941941
scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
942-
MutableArrayRef<LoopLikeOpInterface> loops) {
942+
MutableArrayRef<LoopLikeOpInterface> loops,
943+
std::optional<ArrayRef<unsigned>> yieldResultNumber) {
943944
if (loops.empty())
944945
return success();
945946

946-
OpResult fusableProducer = fusedProducerInfo.origProducer;
947-
Value tiledAndFusedProducer = fusedProducerInfo.tiledAndFusedProducer;
948-
FailureOr<Value> initValue = tensor::getOrCreateDestination(
949-
rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
950-
if (succeeded(initValue)) {
951-
952-
YieldTiledValuesFn newYieldValuesFn =
953-
[&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
954-
ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
955-
SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
956-
SmallVector<SmallVector<OpFoldResult>> &tiledSizes)
957-
-> LogicalResult {
958-
OpBuilder::InsertionGuard g(innerRewriter);
959-
if (auto tiledDestStyleOp =
960-
tiledAndFusedProducer
961-
.getDefiningOp<DestinationStyleOpInterface>()) {
962-
rewriter.setInsertionPoint(tiledDestStyleOp);
963-
Value newRegionArg = newRegionIterArgs.back();
947+
Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(),
948+
*tiledOwner = fusedProducerInfo.tiledOps[0];
949+
950+
Location loc = originalOwner->getLoc();
951+
// a. collect all init Value to be appended
952+
ArrayRef<unsigned> initNumberList =
953+
yieldResultNumber ? yieldResultNumber.value()
954+
: llvm::to_vector(llvm::seq<unsigned>(
955+
0, originalOwner->getNumResults()));
956+
SmallVector<Value> initValueList;
957+
for (const auto &resultNumber : initNumberList) {
958+
FailureOr<Value> initValue = tensor::getOrCreateDestination(
959+
rewriter, loc, originalOwner->getResult(resultNumber));
960+
if (succeeded(initValue)) {
961+
initValueList.push_back(initValue.value());
962+
} else {
963+
return failure();
964+
}
965+
}
966+
967+
YieldTiledValuesFn newYieldValuesFn =
968+
[&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
969+
ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
970+
SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
971+
SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
972+
OpBuilder::InsertionGuard g(innerRewriter);
973+
974+
// get sliceOp tile information
975+
SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(),
976+
sliceSizes = sliceOp.getMixedSizes();
977+
978+
// expect all strides of sliceOp being 1
979+
if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
980+
return !isConstantIntValue(ofr, 1);
981+
}))
982+
return failure();
983+
984+
unsigned sliceResultNumber =
985+
fusedProducerInfo.origProducer.getResultNumber();
986+
987+
auto tilableOp = cast<TilingInterface>(originalOwner);
988+
// b. get iterDomain Offset and Sizes based on sliceOp tile
989+
SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes;
990+
// skip tensor.pack/unpack/pad, which expect a single opResult
991+
if (tilableOp->getNumResults() > 1 &&
992+
failed(tilableOp.getIterationDomainTileFromResultTile(
993+
rewriter, sliceResultNumber, sliceOffset, sliceSizes,
994+
iterDomainOffset, iterDomainSizes))) {
995+
return failure();
996+
}
997+
998+
// c. calculate offsets and sizes info of all OpResults respectively based
999+
// on iteration Domain Tile
1000+
SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList;
1001+
for (const auto &resultNumber : initNumberList) {
1002+
if (resultNumber == fusedProducerInfo.origProducer.getResultNumber()) {
1003+
offsetList.push_back(sliceOffset);
1004+
sizesList.push_back(sliceSizes);
1005+
} else {
1006+
assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
1007+
// infer result tile according to the iteration domain tile
1008+
SmallVector<OpFoldResult> offset, sizes;
1009+
if (failed(tilableOp.getResultTilePosition(
1010+
rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
1011+
offset, sizes))) {
1012+
return failure();
1013+
}
1014+
offsetList.push_back(offset);
1015+
sizesList.push_back(sizes);
1016+
}
1017+
}
1018+
1019+
// d. create `extract_slice` for `iter_args` for DPS operation if necessary
1020+
if (auto tiledDestStyleOp =
1021+
dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
1022+
rewriter.setInsertionPoint(tiledDestStyleOp);
1023+
for (const auto &&[index, newRegionArg] :
1024+
llvm::enumerate(newRegionIterArgs)) {
9641025
auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
965-
sliceOp.getLoc(), newRegionArg, sliceOp.getMixedOffsets(),
966-
sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
967-
unsigned resultNumber = fusableProducer.getResultNumber();
1026+
loc, newRegionArg, offsetList[index], sizesList[index],
1027+
SmallVector<OpFoldResult>(offsetList[index].size(),
1028+
rewriter.getIndexAttr(1)));
1029+
unsigned resultNumber = initNumberList[index];
9681030
rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
9691031
tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
9701032
});
9711033
}
972-
Block *block = rewriter.getInsertionPoint()->getBlock();
973-
rewriter.setInsertionPoint(block->getTerminator());
974-
tiledResult.push_back(fusedProducerInfo.tiledAndFusedProducer);
975-
tiledOffset.emplace_back(sliceOp.getMixedOffsets());
976-
tiledSizes.emplace_back(sliceOp.getMixedSizes());
977-
return success();
978-
};
1034+
}
9791035

980-
return addInitOperandsToLoopNest(rewriter, loops,
981-
SmallVector<Value>{initValue.value()},
982-
newYieldValuesFn);
983-
}
984-
return success();
1036+
// e. prepare tiled offset and sizes for later `insert_slice` creation by
1037+
// caller
1038+
Block *block = rewriter.getInsertionPoint()->getBlock();
1039+
rewriter.setInsertionPoint(block->getTerminator());
1040+
for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) {
1041+
tiledResult.push_back(tiledOwner->getResult(resultNumber));
1042+
tiledOffset.emplace_back(offsetList[index]);
1043+
tiledSizes.emplace_back(sizesList[index]);
1044+
}
1045+
return success();
1046+
};
1047+
1048+
return addInitOperandsToLoopNest(rewriter, loops, initValueList,
1049+
newYieldValuesFn);
9851050
}
9861051

9871052
/// Implementation of tile consumer and fuse producer greedily.
@@ -1071,14 +1136,21 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
10711136
continue;
10721137

10731138
if (yieldReplacement) {
1139+
// Reconstruct and yield all opResult of fusableProducerOp by default. The
1140+
// caller can specify which ones to yield via the optional argument
1141+
// named `yieldResultNumber` of `yieldReplacementForFusedProducer`.
1142+
Operation *fusableProducerOp = fusableProducer.getOwner();
10741143
if (failed(yieldReplacementForFusedProducer(
10751144
rewriter, candidateSliceOp, fusedResult.value(), loops))) {
10761145
return rewriter.notifyMatchFailure(
1077-
fusableProducer.getOwner(), "failed to replacement value for this "
1078-
"oepration from within the tiled loop");
1146+
fusableProducerOp, "failed to replacement value for this "
1147+
"operation from within the tiled loop");
1148+
}
1149+
for (const auto &result : fusableProducerOp->getResults()) {
1150+
origValToResultNumber[result] =
1151+
loops.front()->getNumResults() -
1152+
(fusableProducerOp->getNumResults() - result.getResultNumber());
10791153
}
1080-
origValToResultNumber[fusableProducer] =
1081-
loops.front()->getNumResults() - 1;
10821154
}
10831155

10841156
if (Operation *tiledAndFusedOp =

mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,65 @@ module attributes {transform.with_named_sequence} {
5858
// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[GEMM0_TILE]] into %[[ITERARG1]][%[[IV]], 0]
5959
// CHECK: scf.yield %[[INSERT0]], %[[INSERT1]]
6060
// CHECK: return %[[RESULT]]#1, %[[RESULT]]#0
61+
62+
// -----
63+
64+
func.func @multiple_outputs_fusion_yield_all(%lhs0: tensor<32x32xf32>,
65+
%rhs0: tensor<32x32xf32>, %init0: tensor<32x32xf32>, %init1: tensor<32x32xf32>,
66+
%rhs1: tensor<32x32xf32>, %init2: tensor<32x32xf32>)
67+
-> (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) {
68+
%out0, %out1 = linalg.generic {
69+
indexing_maps = [affine_map<(i, j) -> (i, j)>,
70+
affine_map<(i, j) -> (i, j)>,
71+
affine_map<(i, j) -> (i, j)>,
72+
affine_map<(i, j) -> (j, i)>],
73+
iterator_types = ["parallel", "parallel"]
74+
}
75+
ins(%lhs0, %rhs0: tensor<32x32xf32>, tensor<32x32xf32>)
76+
outs(%init0, %init1: tensor<32x32xf32>, tensor<32x32xf32>) {
77+
^bb0(%0: f32, %1: f32, %2: f32, %3: f32):
78+
%4 = arith.mulf %0, %1 : f32
79+
%5 = arith.addf %0, %1 : f32
80+
linalg.yield %4, %5: f32, f32
81+
} -> (tensor<32x32xf32>, tensor<32x32xf32>)
82+
83+
%out3 = linalg.add ins(%out0, %rhs1: tensor<32x32xf32>, tensor<32x32xf32>) outs(%init2: tensor<32x32xf32>) -> tensor<32x32xf32>
84+
85+
return %out0, %out1, %out3 : tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>
86+
}
87+
88+
module attributes {transform.with_named_sequence} {
89+
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
90+
%add = transform.structured.match ops{["linalg.add"]} in %arg0
91+
: (!transform.any_op) -> !transform.any_op
92+
%a, %b = transform.test.fuse_and_yield %add [16]
93+
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
94+
transform.yield
95+
}
96+
}
97+
// CHECK: func.func @multiple_outputs_fusion_yield_all(
98+
// CHECK-SAME: %[[LHS0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
99+
// CHECK-SAME: %[[RHS0:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
100+
// CHECK-SAME: %[[INIT0:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
101+
// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
102+
// CHECK-SAME: %[[RHS1:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
103+
// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<32x32xf32>)
104+
// CHECK: %[[RESULT:.+]]:3 = scf.for %[[IV:[a-zA-Z0-9]+]] =
105+
// CHECK-SAME: iter_args(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT2]], %[[ITERARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ITERARG2:[a-zA-Z0-9]+]] = %[[INIT1]])
106+
// CHECK-DAG: %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0]
107+
// CHECK-DAG: %[[RHS0_TILE:.+]] = tensor.extract_slice %[[RHS0]][%[[IV]], 0]
108+
// CHECK-DAG: %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV]], 0]
109+
// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG2]][0, %[[IV]]]
110+
// CHECK: %[[GENERIC_TILE:.+]]:2 = linalg.generic
111+
// CHECK-SAME: ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
112+
// CHECK-SAME: outs(%[[INIT0_TILE]], %[[INIT1_TILE]] :
113+
// CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][%[[IV]], 0]
114+
// CHECK-DAG: %[[INIT2_TILE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0]
115+
// CHECK: %[[ADD_TILE:.+]] = linalg.add
116+
// CHECK-SAME: ins(%[[GENERIC_TILE]]#0, %[[RHS1_TILE]] :
117+
// CHECK-SAME: outs(%[[INIT2_TILE]] :
118+
// CHECK: %[[INSERT0:.+]] = tensor.insert_slice %[[ADD_TILE]] into %[[ITERARG0]][%[[IV]], 0]
119+
// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[GENERIC_TILE]]#0 into %[[ITERARG1]][%[[IV]], 0]
120+
// CHECK: %[[INSERT2:.+]] = tensor.insert_slice %[[GENERIC_TILE]]#1 into %[[ITERARG2]][0, %[[IV]]]
121+
// CHECK: scf.yield %[[INSERT0]], %[[INSERT1]], %[[INSERT2]]
122+
// CHECK: return %[[RESULT]]#1, %[[RESULT]]#2, %[[RESULT]]#0

0 commit comments

Comments
 (0)