
Commit 7ef08ea

[mlir][scf] Extend option to yield replacement for multiple results case (#93144)
This patch extends the yield-replacement functionality to the multiple-results case and adds an optional argument called `yieldResultNumber` indicating which result(s) need to be yielded. If not given, all results are yielded by default.
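As a rough usage sketch of the extended signature (the `rewriter`, `sliceOp`, `fusedProducerInfo`, and `loops` values are placeholders assumed to come from an earlier tile-and-fuse step, not part of this patch), a caller that only wants result #0 of the fused producer reconstructed could now write:

  // Yield only result 0 of the fused producer out of the loop nest; omitting
  // the last argument keeps the default of yielding every result.
  if (failed(scf::yieldReplacementForFusedProducer(
          rewriter, sliceOp, fusedProducerInfo, loops,
          /*yieldResultNumber=*/ArrayRef<unsigned>{0})))
    return failure();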
1 parent 4169338 commit 7ef08ea

5 files changed (+242, -46 lines)

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 5 additions & 1 deletion
@@ -191,10 +191,14 @@ tileAndFuseProducerOfSlice(RewriterBase &rewriter,
 /// where `%0` had other uses as well. If not reconstructed from within the loop
 /// body, uses of `%0` could not be replaced, making it still live and the
 /// fusion immaterial.
+///
+/// The @param `yieldResultNumber` decides which result would be yield. If not
+/// given, yield all `opResult` of fused producer.
 LogicalResult yieldReplacementForFusedProducer(
     RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
     scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
-    MutableArrayRef<LoopLikeOpInterface> loops);
+    MutableArrayRef<LoopLikeOpInterface> loops,
+    ArrayRef<unsigned> yieldResultNumber = ArrayRef<unsigned>{});

 /// Transformation information returned after tile and fuse.
 struct SCFTileAndFuseResult {

mlir/include/mlir/Interfaces/TilingInterface.td

Lines changed: 37 additions & 1 deletion
@@ -51,7 +51,8 @@ def TilingInterface : OpInterface<"TilingInterface"> {
     For an operation to be "tiled and fused" with its (already tiled) consumer,
     an operation has to implement the following additional method (see
     description below):
-    - `generateResultTileValue
+    - `generateResultTileValue`
+    - `getIterationDomainTileFromResultTile`

     For an operation to be "tiled and fused" with its (already tiled) producer,
     an operation has to implement the following additional methods (see
@@ -302,6 +303,41 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           return failure();
         }]
       >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to return the tile of the iteration domain based
+          on the given tile of the certain result.
+
+          This method is required to allow operations to be "tiled and fused"
+          with an (already tiled) consumer. Given a tile of an result,
+          returns the tile of the iteration space that uses this tile.
+          - `resultNumber` is the result of the producer used by the consumer.
+          - `offsets` is the offset of the slice of the producer result used by
+            the tiled implementation of the consumer.
+          - `sizes` is the size of the slice of the producer result used by the
+            consumer.
+          If fusion of the producer with the consumer is not legal for the
+          result, or if this mapping cannot be computed, the implementation
+          should return a failure.
+
+          For most cases `generateResultTileValue` could be a implemented using
+          `getIterationDomainTileFromResultTile` + `getTiledImplementation`
+          methods.
+        }],
+        /*retType=*/"::mlir::LogicalResult",
+        /*methodName=*/"getIterationDomainTileFromResultTile",
+        /*args=*/(ins
+            "OpBuilder &":$b,
+            "unsigned":$resultNumber,
+            "ArrayRef<OpFoldResult> ":$offsets,
+            "ArrayRef<OpFoldResult> ":$sizes,
+            "SmallVectorImpl<OpFoldResult> &":$iterDomainOffsets,
+            "SmallVectorImpl<OpFoldResult> &":$iterDomainSizes),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
+      >,
       InterfaceMethod<
         /*desc=*/[{
           Generates the scalar implementation of the operation.
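A minimal caller-side sketch of the new interface method (the `tilingOp`, `builder`, `resultOffsets`, and `resultSizes` names are assumptions for illustration; error handling is reduced to the failure check):

  // Map a tile of result #0 back to a tile of the op's iteration domain.
  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
  if (failed(tilingOp.getIterationDomainTileFromResultTile(
          builder, /*resultNumber=*/0, resultOffsets, resultSizes,
          iterDomainOffsets, iterDomainSizes)))
    return failure();
  // The iteration-domain tile can then be handed to getTiledImplementation,
  // or to getResultTilePosition to recover tiles of the op's other results.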

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

Lines changed: 19 additions & 6 deletions
@@ -215,10 +215,11 @@ struct LinalgOpTilingInterface
     return success();
   }

-  FailureOr<TilingResult>
-  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
-                          ArrayRef<OpFoldResult> offsets,
-                          ArrayRef<OpFoldResult> sizes) const {
+  LogicalResult getIterationDomainTileFromResultTile(
+      Operation *op, OpBuilder &b, unsigned resultNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+      SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
+      SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
     auto linalgOp = cast<LinalgOp>(op);

     // Check that the indexing map used for the output is a projected
@@ -232,9 +233,21 @@ struct LinalgOpTilingInterface
           "unhandled tiled implementation generation when result is not "
           "accessed using a permuted projection");
     }
-    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+
     getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
-                           mappedOffsets, mappedSizes);
+                           iterDomainOffsets, iterDomainSizes);
+    return success();
+  }
+
+  FailureOr<TilingResult>
+  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
+                          ArrayRef<OpFoldResult> offsets,
+                          ArrayRef<OpFoldResult> sizes) const {
+    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+    if (failed(getIterationDomainTileFromResultTile(
+            op, b, resultNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
+      return failure();
+    }
     auto tilingInterfaceOp = cast<TilingInterface>(op);
     FailureOr<TilingResult> tilingResult =
         tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);
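To make the mapping above concrete: when the output indexing map is a (projected) permutation, every result dimension reads exactly one iteration-domain dimension, so the result tile's offsets and sizes can be scattered into the corresponding iteration-domain slots. A simplified sketch of that idea for a full permutation (illustrative only, with `perm`, `resultOffsets`, and `resultSizes` assumed; in the patch this work is done by the `getMappedOffsetAndSize` helper used above):

  // perm[resultDim] is the iteration-domain dim read by result dim resultDim,
  // e.g. perm = {1, 0} for an output indexing map (d0, d1) -> (d1, d0).
  SmallVector<OpFoldResult> iterOffsets(perm.size()), iterSizes(perm.size());
  for (unsigned resultDim = 0; resultDim < perm.size(); ++resultDim) {
    unsigned iterDim = perm[resultDim];
    iterOffsets[iterDim] = resultOffsets[resultDim];
    iterSizes[iterDim] = resultSizes[resultDim];
  }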

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 119 additions & 38 deletions
@@ -953,49 +953,122 @@ mlir::scf::tileAndFuseProducerOfSlice(
 LogicalResult mlir::scf::yieldReplacementForFusedProducer(
     RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
     scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
-    MutableArrayRef<LoopLikeOpInterface> loops) {
+    MutableArrayRef<LoopLikeOpInterface> loops,
+    ArrayRef<unsigned> yieldResultNumber) {
   if (loops.empty())
     return success();

-  OpResult fusableProducer = fusedProducerInfo.origProducer;
-  Value tiledAndFusedProducer = fusedProducerInfo.tiledAndFusedProducer;
-  FailureOr<Value> initValue = tensor::getOrCreateDestination(
-      rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
-  if (succeeded(initValue)) {
-
-    YieldTiledValuesFn newYieldValuesFn =
-        [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
-            ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
-            SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
-            SmallVector<SmallVector<OpFoldResult>> &tiledSizes)
-        -> LogicalResult {
-      OpBuilder::InsertionGuard g(innerRewriter);
-      if (auto tiledDestStyleOp =
-              tiledAndFusedProducer
-                  .getDefiningOp<DestinationStyleOpInterface>()) {
-        rewriter.setInsertionPoint(tiledDestStyleOp);
-        Value newRegionArg = newRegionIterArgs.back();
+  Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(),
+            *tiledOwner = fusedProducerInfo.tiledOps[0];
+
+  Location loc = originalOwner->getLoc();
+  // a. collect all init Value to be appended
+  SmallVector<unsigned> initNumberList =
+      yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>(
+                                      0, originalOwner->getNumResults()))
+                                : llvm::to_vector(yieldResultNumber);
+  SmallVector<Value> initValueList;
+  for (const auto &resultNumber : initNumberList) {
+    FailureOr<Value> initValue = tensor::getOrCreateDestination(
+        rewriter, loc, originalOwner->getResult(resultNumber));
+    if (succeeded(initValue)) {
+      initValueList.push_back(initValue.value());
+    } else {
+      return failure();
+    }
+  }
+
+  YieldTiledValuesFn newYieldValuesFn =
+      [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
+          ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
+          SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
+          SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
+    OpBuilder::InsertionGuard g(innerRewriter);
+
+    // get sliceOp tile information
+    SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(),
+                              sliceSizes = sliceOp.getMixedSizes();
+
+    // expect all strides of sliceOp being 1
+    if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
+          return !isConstantIntValue(ofr, 1);
+        }))
+      return failure();
+
+    unsigned sliceResultNumber =
+        fusedProducerInfo.origProducer.getResultNumber();
+
+    auto tilableOp = cast<TilingInterface>(originalOwner);
+    // b. get iterDomain Offset and Sizes based on sliceOp tile
+    SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes;
+    // skip tensor.pack/unpack/pad, which expects single opResult
+    if (tilableOp->getNumResults() > 1 &&
+        failed(tilableOp.getIterationDomainTileFromResultTile(
+            rewriter, sliceResultNumber, sliceOffset, sliceSizes,
+            iterDomainOffset, iterDomainSizes))) {
+      // In theory, it is unnecessary to raise an error here. Actually although
+      // it fails to reconstruct the result tensor, it should not broke current
+      // fusion anyway. The reason why we must return failure currently is that
+      // the callback function `newYieldValuesFn` will be called after new init
+      // operand(s) has already been appended. It will take more refactoring to
+      // make sure the init operands are added consistently in the future. For
+      // more details, please refer to:
+      // https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814
+      return failure();
+    }
+
+    // c. calculate offsets and sizes info of all OpResults respectively based
+    // on iteration Domain Tile
+    SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList;
+    for (const auto &resultNumber : initNumberList) {
+      if (resultNumber == sliceResultNumber) {
+        offsetList.push_back(sliceOffset);
+        sizesList.push_back(sliceSizes);
+      } else {
+        assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
+        // infer result tile according to the iteration domain tile
+        SmallVector<OpFoldResult> offset, sizes;
+        if (failed(tilableOp.getResultTilePosition(
+                rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
+                offset, sizes))) {
+          return failure();
+        }
+        offsetList.push_back(offset);
+        sizesList.push_back(sizes);
+      }
+    }
+
+    // d. create `extract_slice` for `iter_args` for DPS operation if necessary
+    if (auto tiledDestStyleOp =
+            dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
+      rewriter.setInsertionPoint(tiledDestStyleOp);
+      for (const auto &&[index, newRegionArg] :
+           llvm::enumerate(newRegionIterArgs)) {
         auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
-            sliceOp.getLoc(), newRegionArg, sliceOp.getMixedOffsets(),
-            sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
-        unsigned resultNumber = fusableProducer.getResultNumber();
+            loc, newRegionArg, offsetList[index], sizesList[index],
+            SmallVector<OpFoldResult>(offsetList[index].size(),
+                                      rewriter.getIndexAttr(1)));
+        unsigned resultNumber = initNumberList[index];
         rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
           tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
         });
       }
-      Block *block = rewriter.getInsertionPoint()->getBlock();
-      rewriter.setInsertionPoint(block->getTerminator());
-      tiledResult.push_back(fusedProducerInfo.tiledAndFusedProducer);
-      tiledOffset.emplace_back(sliceOp.getMixedOffsets());
-      tiledSizes.emplace_back(sliceOp.getMixedSizes());
-      return success();
-    };
+    }

-    return addInitOperandsToLoopNest(rewriter, loops,
-                                     SmallVector<Value>{initValue.value()},
-                                     newYieldValuesFn);
-  }
-  return success();
+    // e. prepare tiled offset and sizes for later `insert_slice` creation by
+    // caller
+    Block *block = rewriter.getInsertionPoint()->getBlock();
+    rewriter.setInsertionPoint(block->getTerminator());
+    for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) {
+      tiledResult.push_back(tiledOwner->getResult(resultNumber));
+      tiledOffset.emplace_back(offsetList[index]);
+      tiledSizes.emplace_back(sizesList[index]);
+    }
+    return success();
+  };
+
+  return addInitOperandsToLoopNest(rewriter, loops, initValueList,
+                                   newYieldValuesFn);
 }

 /// Implementation of tile consumer and fuse producer greedily.
@@ -1085,14 +1158,22 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
       continue;

     if (yieldReplacement) {
+      // Reconstruct and yield all opResult of fusableProducerOp by default. The
+      // caller can specific which one to yield by designating optional argument
+      // named `yieldResultNumber` of `yieldReplacementForFusedProducer`.
+      Operation *fusableProducerOp = fusableProducer.getOwner();
       if (failed(yieldReplacementForFusedProducer(
               rewriter, candidateSliceOp, fusedResult.value(), loops))) {
         return rewriter.notifyMatchFailure(
-            fusableProducer.getOwner(), "failed to replacement value for this "
-                                        "oepration from within the tiled loop");
+            fusableProducerOp, "failed to replacement value for this "
+                               "operation from within the tiled loop");
+      }
+      for (auto [index, result] :
+           llvm::enumerate(fusableProducerOp->getResults())) {
+        origValToResultNumber[result] = loops.front()->getNumResults() -
+                                        fusableProducerOp->getNumResults() +
+                                        index;
       }
-      origValToResultNumber[fusableProducer] =
-          loops.front()->getNumResults() - 1;
     }

     if (Operation *tiledAndFusedOp =

mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir

Lines changed: 62 additions & 0 deletions
@@ -58,3 +58,65 @@ module attributes {transform.with_named_sequence} {
 // CHECK:       %[[INSERT1:.+]] = tensor.insert_slice %[[GEMM0_TILE]] into %[[ITERARG1]][%[[IV]], 0]
 // CHECK:       scf.yield %[[INSERT0]], %[[INSERT1]]
 // CHECK:   return %[[RESULT]]#1, %[[RESULT]]#0
+
+// -----
+
+func.func @multiple_outputs_fusion_yield_all(%lhs0: tensor<32x32xf32>,
+    %rhs0: tensor<32x32xf32>, %init0: tensor<32x32xf32>, %init1: tensor<32x32xf32>,
+    %rhs1: tensor<32x32xf32>, %init2: tensor<32x32xf32>)
+    -> (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) {
+  %out0, %out1 = linalg.generic {
+      indexing_maps = [affine_map<(i, j) -> (i, j)>,
+                       affine_map<(i, j) -> (i, j)>,
+                       affine_map<(i, j) -> (i, j)>,
+                       affine_map<(i, j) -> (j, i)>],
+      iterator_types = ["parallel", "parallel"]
+    }
+    ins(%lhs0, %rhs0: tensor<32x32xf32>, tensor<32x32xf32>)
+    outs(%init0, %init1: tensor<32x32xf32>, tensor<32x32xf32>) {
+    ^bb0(%0: f32, %1: f32, %2: f32, %3: f32):
+      %4 = arith.mulf %0, %1 : f32
+      %5 = arith.addf %0, %1 : f32
+      linalg.yield %4, %5: f32, f32
+    } -> (tensor<32x32xf32>, tensor<32x32xf32>)
+
+  %out3 = linalg.add ins(%out0, %rhs1: tensor<32x32xf32>, tensor<32x32xf32>) outs(%init2: tensor<32x32xf32>) -> tensor<32x32xf32>
+
+  return %out0, %out1, %out3 : tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+    %add = transform.structured.match ops{["linalg.add"]} in %arg0
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_and_yield %add [16]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// CHECK: func.func @multiple_outputs_fusion_yield_all(
+// CHECK-SAME:   %[[LHS0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME:   %[[RHS0:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
+// CHECK-SAME:   %[[INIT0:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
+// CHECK-SAME:   %[[INIT1:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
+// CHECK-SAME:   %[[RHS1:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
+// CHECK-SAME:   %[[INIT2:[a-zA-Z0-9]+]]: tensor<32x32xf32>)
+// CHECK:   %[[RESULT:.+]]:3 = scf.for %[[IV:[a-zA-Z0-9]+]] =
+// CHECK-SAME:   iter_args(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT2]], %[[ITERARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ITERARG2:[a-zA-Z0-9]+]] = %[[INIT1]])
+// CHECK-DAG:     %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0]
+// CHECK-DAG:     %[[RHS0_TILE:.+]] = tensor.extract_slice %[[RHS0]][%[[IV]], 0]
+// CHECK-DAG:     %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV]], 0]
+// CHECK-DAG:     %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG2]][0, %[[IV]]]
+// CHECK:         %[[GENERIC_TILE:.+]]:2 = linalg.generic
+// CHECK-SAME:      ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
+// CHECK-SAME:      outs(%[[INIT0_TILE]], %[[INIT1_TILE]] :
+// CHECK-DAG:     %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][%[[IV]], 0]
+// CHECK-DAG:     %[[INIT2_TILE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0]
+// CHECK:         %[[ADD_TILE:.+]] = linalg.add
+// CHECK-SAME:      ins(%[[GENERIC_TILE]]#0, %[[RHS1_TILE]] :
+// CHECK-SAME:      outs(%[[INIT2_TILE]] :
+// CHECK:         %[[INSERT0:.+]] = tensor.insert_slice %[[ADD_TILE]] into %[[ITERARG0]][%[[IV]], 0]
+// CHECK:         %[[INSERT1:.+]] = tensor.insert_slice %[[GENERIC_TILE]]#0 into %[[ITERARG1]][%[[IV]], 0]
+// CHECK:         %[[INSERT2:.+]] = tensor.insert_slice %[[GENERIC_TILE]]#1 into %[[ITERARG2]][0, %[[IV]]]
+// CHECK:         scf.yield %[[INSERT0]], %[[INSERT1]], %[[INSERT2]]
+// CHECK:   return %[[RESULT]]#1, %[[RESULT]]#2, %[[RESULT]]#0
