Skip to content

Commit 4377bf0

Browse files
committed
yield replacement for multiple results
1 parent 3387e55 commit 4377bf0

File tree

5 files changed

+220
-50
lines changed

5 files changed

+220
-50
lines changed

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,10 +191,14 @@ tileAndFuseProducerOfSlice(RewriterBase &rewriter,
191191
/// where `%0` had other uses as well. If not reconstructed from within the loop
192192
/// body, uses of `%0` could not be replaced, making it still live and the
193193
/// fusion immaterial.
194+
///
195+
/// The @param `yieldResultNumber` decides which result would be yield. If not
196+
/// given, yield all `opResult` of fused producer.
194197
LogicalResult yieldReplacementForFusedProducer(
195198
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
196199
scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
197-
MutableArrayRef<LoopLikeOpInterface> loops);
200+
MutableArrayRef<LoopLikeOpInterface> loops,
201+
std::optional<ArrayRef<unsigned>> yieldResultNumber = std::nullopt);
198202

199203
/// Transformation information returned after tile and fuse.
200204
struct SCFTileAndFuseResult {

mlir/include/mlir/Interfaces/TilingInterface.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,25 @@ def TilingInterface : OpInterface<"TilingInterface"> {
115115
return failure();
116116
}]
117117
>,
118+
InterfaceMethod<
119+
/*desc=*/[{
120+
Method to return the tile of the iteration domain based
121+
on the given tile of the certain result.
122+
}],
123+
/*retType=*/"::mlir::LogicalResult",
124+
/*methodName=*/"getIterationDomainTileFromResultTile",
125+
/*args=*/(ins
126+
"OpBuilder &":$b,
127+
"unsigned":$resultNumber,
128+
"ArrayRef<OpFoldResult> ":$resultOffsets,
129+
"ArrayRef<OpFoldResult> ":$resultSizes,
130+
"SmallVectorImpl<OpFoldResult> &":$iterDomainOffsets,
131+
"SmallVectorImpl<OpFoldResult> &":$iterDomainSizes),
132+
/*methodBody=*/"",
133+
/*defaultImplementation=*/[{
134+
return failure();
135+
}]
136+
>,
118137
InterfaceMethod<
119138
/*desc=*/[{
120139
Method to generate the code that produces a tile of the result.

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -215,26 +215,39 @@ struct LinalgOpTilingInterface
215215
return success();
216216
}
217217

218-
FailureOr<TilingResult>
219-
generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
220-
ArrayRef<OpFoldResult> offsets,
221-
ArrayRef<OpFoldResult> sizes) const {
218+
LogicalResult getIterationDomainTileFromResultTile(
219+
Operation *op, OpBuilder &b, unsigned resultNumber,
220+
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
221+
SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
222+
SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
222223
auto linalgOp = cast<LinalgOp>(op);
223224

224-
// Check that the indexing map used for the output is a projected
225+
// Check that the indexing map used for the operand is a projected
225226
// permutation. This could be relaxed with a more general approach that can
226-
// map the offsets and sizes from the result to iteration space tiles
227+
// map the offsets and sizes from the operand to iteration space tiles
227228
// (filling in full extent for dimensions not used to access the result).
228229
AffineMap indexingMap =
229230
linalgOp.getIndexingMapMatchingResult(op->getResult(resultNumber));
230231
if (!indexingMap.isProjectedPermutation()) {
231-
return op->emitOpError(
232-
"unhandled tiled implementation generation when result is not "
233-
"accessed using a permuted projection");
232+
return op->emitError()
233+
<< "unhandled get iter domain position when operand is not "
234+
"accessed using a permuted projection";
234235
}
235-
SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
236+
236237
getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
237-
mappedOffsets, mappedSizes);
238+
iterDomainOffsets, iterDomainSizes);
239+
return success();
240+
}
241+
242+
FailureOr<TilingResult>
243+
generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
244+
ArrayRef<OpFoldResult> offsets,
245+
ArrayRef<OpFoldResult> sizes) const {
246+
SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
247+
if (failed(getIterationDomainTileFromResultTile(
248+
op, b, resultNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
249+
return failure();
250+
}
238251
auto tilingInterfaceOp = cast<TilingInterface>(op);
239252
FailureOr<TilingResult> tilingResult =
240253
tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 110 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -940,49 +940,114 @@ mlir::scf::tileAndFuseProducerOfSlice(
940940
LogicalResult mlir::scf::yieldReplacementForFusedProducer(
941941
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
942942
scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
943-
MutableArrayRef<LoopLikeOpInterface> loops) {
943+
MutableArrayRef<LoopLikeOpInterface> loops,
944+
std::optional<ArrayRef<unsigned>> yieldResultNumber) {
944945
if (loops.empty())
945946
return success();
946947

947-
OpResult fusableProducer = fusedProducerInfo.origProducer;
948-
Value tiledAndFusedProducer = fusedProducerInfo.tiledAndFusedProducer;
949-
FailureOr<Value> initValue = tensor::getOrCreateDestination(
950-
rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
951-
if (succeeded(initValue)) {
952-
953-
YieldTiledValuesFn newYieldValuesFn =
954-
[&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
955-
ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
956-
SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
957-
SmallVector<SmallVector<OpFoldResult>> &tiledSizes)
958-
-> LogicalResult {
959-
OpBuilder::InsertionGuard g(innerRewriter);
960-
if (auto tiledDestStyleOp =
961-
tiledAndFusedProducer
962-
.getDefiningOp<DestinationStyleOpInterface>()) {
963-
rewriter.setInsertionPoint(tiledDestStyleOp);
964-
Value newRegionArg = newRegionIterArgs.back();
948+
Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(),
949+
*tiledOwner = fusedProducerInfo.tiledOps[0];
950+
951+
Location loc = originalOwner->getLoc();
952+
// a. collect all init Value to be appended
953+
ArrayRef<unsigned> initNumberList =
954+
yieldResultNumber ? yieldResultNumber.value()
955+
: llvm::to_vector(llvm::seq<unsigned>(
956+
0, originalOwner->getNumResults()));
957+
SmallVector<Value> initValueList;
958+
for (const auto &resultNumber : initNumberList) {
959+
FailureOr<Value> initValue = tensor::getOrCreateDestination(
960+
rewriter, loc, originalOwner->getResult(resultNumber));
961+
if (succeeded(initValue)) {
962+
initValueList.push_back(initValue.value());
963+
} else {
964+
return failure();
965+
}
966+
}
967+
968+
YieldTiledValuesFn newYieldValuesFn =
969+
[&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
970+
ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
971+
SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
972+
SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
973+
OpBuilder::InsertionGuard g(innerRewriter);
974+
975+
// get sliceOp tile information
976+
SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(),
977+
sliceSizes = sliceOp.getMixedSizes();
978+
979+
// expect all strides of sliceOp being 1
980+
if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
981+
return !isConstantIntValue(ofr, 1);
982+
}))
983+
return failure();
984+
985+
unsigned sliceResultNumber =
986+
fusedProducerInfo.origProducer.getResultNumber();
987+
988+
auto tilableOp = cast<TilingInterface>(originalOwner);
989+
// b. get iterDomain Offset and Sizes based on sliceOp tile
990+
SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes;
991+
// skip tensor.pack/unpack/pad, which expects single opResult
992+
if (tilableOp->getNumResults() > 1 &&
993+
failed(tilableOp.getIterationDomainTileFromResultTile(
994+
rewriter, sliceResultNumber, sliceOffset, sliceSizes,
995+
iterDomainOffset, iterDomainSizes))) {
996+
return failure();
997+
}
998+
999+
// c. calculate offsets and sizes info of all OpResults respectively based
1000+
// on iteration Domain Tile
1001+
SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList;
1002+
for (const auto &resultNumber : initNumberList) {
1003+
if (resultNumber == fusedProducerInfo.origProducer.getResultNumber()) {
1004+
offsetList.push_back(sliceOffset);
1005+
sizesList.push_back(sliceSizes);
1006+
} else {
1007+
assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
1008+
// infer result tile according to the iteration domain tile
1009+
SmallVector<OpFoldResult> offset, sizes;
1010+
if (failed(tilableOp.getResultTilePosition(
1011+
rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
1012+
offset, sizes))) {
1013+
return failure();
1014+
}
1015+
offsetList.push_back(offset);
1016+
sizesList.push_back(sizes);
1017+
}
1018+
}
1019+
1020+
// d. create `extract_slice` for `iter_args` for DPS operation if necessary
1021+
if (auto tiledDestStyleOp =
1022+
dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
1023+
rewriter.setInsertionPoint(tiledDestStyleOp);
1024+
for (const auto &&[index, newRegionArg] :
1025+
llvm::enumerate(newRegionIterArgs)) {
9651026
auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
966-
sliceOp.getLoc(), newRegionArg, sliceOp.getMixedOffsets(),
967-
sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
968-
unsigned resultNumber = fusableProducer.getResultNumber();
1027+
loc, newRegionArg, offsetList[index], sizesList[index],
1028+
SmallVector<OpFoldResult>(offsetList[index].size(),
1029+
rewriter.getIndexAttr(1)));
1030+
unsigned resultNumber = initNumberList[index];
9691031
rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
9701032
tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
9711033
});
9721034
}
973-
Block *block = rewriter.getInsertionPoint()->getBlock();
974-
rewriter.setInsertionPoint(block->getTerminator());
975-
tiledResult.push_back(fusedProducerInfo.tiledAndFusedProducer);
976-
tiledOffset.emplace_back(sliceOp.getMixedOffsets());
977-
tiledSizes.emplace_back(sliceOp.getMixedSizes());
978-
return success();
979-
};
1035+
}
9801036

981-
return addInitOperandsToLoopNest(rewriter, loops,
982-
SmallVector<Value>{initValue.value()},
983-
newYieldValuesFn);
984-
}
985-
return success();
1037+
// e. prepare tiled offset and sizes for later `insert_slice` creation by
1038+
// caller
1039+
Block *block = rewriter.getInsertionPoint()->getBlock();
1040+
rewriter.setInsertionPoint(block->getTerminator());
1041+
for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) {
1042+
tiledResult.push_back(tiledOwner->getResult(resultNumber));
1043+
tiledOffset.emplace_back(offsetList[index]);
1044+
tiledSizes.emplace_back(sizesList[index]);
1045+
}
1046+
return success();
1047+
};
1048+
1049+
return addInitOperandsToLoopNest(rewriter, loops, initValueList,
1050+
newYieldValuesFn);
9861051
}
9871052

9881053
/// Implementation of tile consumer and fuse producer greedily.
@@ -1072,14 +1137,21 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
10721137
continue;
10731138

10741139
if (yieldReplacement) {
1140+
// Reconstruct and yield all opResult of fusableProducerOp by default. The
1141+
// caller can specific which one to yield by designating optional argument
1142+
// named `yieldResultNumber` of `yieldReplacementForFusedProducer`.
1143+
Operation *fusableProducerOp = fusableProducer.getOwner();
10751144
if (failed(yieldReplacementForFusedProducer(
10761145
rewriter, candidateSliceOp, fusedResult.value(), loops))) {
10771146
return rewriter.notifyMatchFailure(
1078-
fusableProducer.getOwner(), "failed to replacement value for this "
1079-
"oepration from within the tiled loop");
1147+
fusableProducerOp, "failed to replacement value for this "
1148+
"operation from within the tiled loop");
1149+
}
1150+
for (const auto &result : fusableProducerOp->getResults()) {
1151+
origValToResultNumber[result] =
1152+
loops.front()->getNumResults() -
1153+
(fusableProducerOp->getNumResults() - result.getResultNumber());
10801154
}
1081-
origValToResultNumber[fusableProducer] =
1082-
loops.front()->getNumResults() - 1;
10831155
}
10841156

10851157
if (Operation *tiledAndFusedOp =

mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,65 @@ module attributes {transform.with_named_sequence} {
5858
// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[GEMM0_TILE]] into %[[ITERARG1]][%[[IV]], 0]
5959
// CHECK: scf.yield %[[INSERT0]], %[[INSERT1]]
6060
// CHECK: return %[[RESULT]]#1, %[[RESULT]]#0
61+
62+
// -----
63+
64+
func.func @multiple_outputs_fusion_yield_all(%lhs0: tensor<32x32xf32>,
65+
%rhs0: tensor<32x32xf32>, %init0: tensor<32x32xf32>, %init1: tensor<32x32xf32>,
66+
%rhs1: tensor<32x32xf32>, %init2: tensor<32x32xf32>)
67+
-> (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) {
68+
%out0, %out1 = linalg.generic {
69+
indexing_maps = [affine_map<(i, j) -> (i, j)>,
70+
affine_map<(i, j) -> (i, j)>,
71+
affine_map<(i, j) -> (i, j)>,
72+
affine_map<(i, j) -> (j, i)>],
73+
iterator_types = ["parallel", "parallel"]
74+
}
75+
ins(%lhs0, %rhs0: tensor<32x32xf32>, tensor<32x32xf32>)
76+
outs(%init0, %init1: tensor<32x32xf32>, tensor<32x32xf32>) {
77+
^bb0(%0: f32, %1: f32, %2: f32, %3: f32):
78+
%4 = arith.mulf %0, %1 : f32
79+
%5 = arith.addf %0, %1 : f32
80+
linalg.yield %4, %5: f32, f32
81+
} -> (tensor<32x32xf32>, tensor<32x32xf32>)
82+
83+
%out3 = linalg.add ins(%out0, %rhs1: tensor<32x32xf32>, tensor<32x32xf32>) outs(%init2: tensor<32x32xf32>) -> tensor<32x32xf32>
84+
85+
return %out0, %out1, %out3 : tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>
86+
}
87+
88+
module attributes {transform.with_named_sequence} {
89+
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
90+
%add = transform.structured.match ops{["linalg.add"]} in %arg0
91+
: (!transform.any_op) -> !transform.any_op
92+
%a, %b = transform.test.fuse_and_yield %add [16]
93+
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
94+
transform.yield
95+
}
96+
}
97+
// CHECK: func.func @multiple_outputs_fusion_yield_all(
98+
// CHECK-SAME: %[[LHS0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
99+
// CHECK-SAME: %[[RHS0:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
100+
// CHECK-SAME: %[[INIT0:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
101+
// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
102+
// CHECK-SAME: %[[RHS1:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
103+
// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<32x32xf32>)
104+
// CHECK: %[[RESULT:.+]]:3 = scf.for %[[IV:[a-zA-Z0-9]+]] =
105+
// CHECK-SAME: iter_args(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT2]], %[[ITERARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ITERARG2:[a-zA-Z0-9]+]] = %[[INIT1]])
106+
// CHECK-DAG: %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0]
107+
// CHECK-DAG: %[[RHS0_TILE:.+]] = tensor.extract_slice %[[RHS0]][%[[IV]], 0]
108+
// CHECK-DAG: %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV]], 0]
109+
// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG2]][0, %[[IV]]]
110+
// CHECK: %[[GENERIC_TILE:.+]]:2 = linalg.generic
111+
// CHECK-SAME: ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
112+
// CHECK-SAME: outs(%[[INIT0_TILE]], %[[INIT1_TILE]] :
113+
// CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][%[[IV]], 0]
114+
// CHECK-DAG: %[[INIT2_TILE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0]
115+
// CHECK: %[[ADD_TILE:.+]] = linalg.add
116+
// CHECK-SAME: ins(%[[GENERIC_TILE]]#0, %[[RHS1_TILE]] :
117+
// CHECK-SAME: outs(%[[INIT2_TILE]] :
118+
// CHECK: %[[INSERT0:.+]] = tensor.insert_slice %[[ADD_TILE]] into %[[ITERARG0]][%[[IV]], 0]
119+
// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[GENERIC_TILE]]#0 into %[[ITERARG1]][%[[IV]], 0]
120+
// CHECK: %[[INSERT2:.+]] = tensor.insert_slice %[[GENERIC_TILE]]#1 into %[[ITERARG2]][0, %[[IV]]]
121+
// CHECK: scf.yield %[[INSERT0]], %[[INSERT1]], %[[INSERT2]]
122+
// CHECK: return %[[RESULT]]#1, %[[RESULT]]#2, %[[RESULT]]#0

0 commit comments

Comments
 (0)