Skip to content

Commit 93c4229

Browse files
[mlir][TilingInterface] NFC code changes separated out from introduction of scf::tileUsingSCFForallop. (#67081)
This patch contains NFC changes that are precursor to the introduction of `scf::tileUsingSCFForallOp` method introduced in #67083.
1 parent b163e52 commit 93c4229

File tree

8 files changed

+148
-127
lines changed

8 files changed

+148
-127
lines changed

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ struct SCFTilingResult {
6060
/// of the last op.
6161
SmallVector<Operation *> tiledOps;
6262
/// The `scf.for` operations that iterate over the tiles.
63-
SmallVector<scf::ForOp> loops;
63+
SmallVector<Operation *> loops;
6464
/// Values to use as replacements for the untiled op. Is the same size as the
6565
/// number of results of the untiled op.
6666
SmallVector<Value> replacements;
@@ -160,7 +160,7 @@ struct SCFTileAndFuseResult {
160160
/// generated operation.
161161
llvm::SetVector<Operation *> tiledAndFusedOps;
162162
/// The `scf.for` operations that iterate over the tiles.
163-
SmallVector<scf::ForOp> loops;
163+
SmallVector<Operation *> loops;
164164
/// The replacement values to use for the tiled and fused operations.
165165
llvm::DenseMap<Value, Value> replacements;
166166
};

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -434,16 +434,12 @@ static LogicalResult applyTilingToAll(
434434
SmallVector<Operation *> opsToReplace{target};
435435
llvm::append_range(opsToReplace, tiledResults->fusedProducers);
436436
for (Operation *toReplace : opsToReplace) {
437-
SmallVector<Value> replacements;
438-
replacements.reserve(toReplace->getNumResults());
439-
for (OpResult res : toReplace->getResults()) {
440-
auto it = tiledResults->replacements.find(res);
441-
if (it == tiledResults->replacements.end())
442-
replacements.push_back(res);
443-
else
444-
replacements.push_back(it->getSecond());
437+
for (OpResult res : toReplace->getResults())
438+
if (auto replacement = tiledResults->replacements.lookup(res))
439+
rewriter.replaceAllUsesWith(res, replacement);
440+
if (toReplace->use_empty()) {
441+
rewriter.eraseOp(toReplace);
445442
}
446-
rewriter.replaceOp(toReplace, replacements);
447443
}
448444

449445
// Report back the relevant handles to the transform op.

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 70 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,30 @@ fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
5555
return filledVector;
5656
}
5757

58+
/// Convert a list of ops of type `SrcOpTy` to list of `Operation *`.
59+
template <typename SrcOpTy>
60+
static SmallVector<Operation *> getAsOperations(ArrayRef<SrcOpTy> ops) {
61+
return llvm::to_vector(
62+
llvm::map_range(ops, [](auto op) -> Operation * { return op; }));
63+
}
64+
template <typename SrcOpTy>
65+
static SmallVector<Operation *>
66+
getAsOperations(const SmallVector<SrcOpTy> &ops) {
67+
return getAsOperations(ArrayRef<SrcOpTy>(ops));
68+
}
69+
70+
/// Convert a list of `Operation *` to a list of `DstOpTy.
71+
template <typename DstOpTy>
72+
static SmallVector<DstOpTy> castToTypedOperations(ArrayRef<Operation *> ops) {
73+
return llvm::to_vector(
74+
llvm::map_range(ops, [](Operation *op) { return cast<DstOpTy>(op); }));
75+
}
76+
template <typename DstOpTy>
77+
static SmallVector<DstOpTy>
78+
castToTypedOperations(const SmallVector<Operation *> &ops) {
79+
return castToTypedOperations<DstOpTy>(ArrayRef<Operation *>(ops));
80+
}
81+
5882
//===----------------------------------------------------------------------===//
5983
// tileUsingSCFForOp implementation.
6084
//===----------------------------------------------------------------------===//
@@ -77,10 +101,9 @@ static bool tileDividesIterationDomain(Range loopRange) {
77101
/// `tileSize`, i.e., `min(tileSize, range.end() - iv)`.
78102
static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
79103
Range loopRange, Value iv,
80-
Value tileSize) {
81-
std::optional<int64_t> ts = getConstantIntValue(tileSize);
82-
if (ts && ts.value() == 1)
83-
return getAsOpFoldResult(tileSize);
104+
OpFoldResult tileSize) {
105+
if (isConstantIntValue(tileSize, 1))
106+
return tileSize;
84107

85108
if (tileDividesIterationDomain(
86109
Range{loopRange.offset, loopRange.size, tileSize}))
@@ -296,8 +319,8 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
296319
tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
297320
}
298321

299-
scf::SCFTilingResult tilingResult;
300322
SmallVector<OpFoldResult> offsets, sizes;
323+
SmallVector<scf::ForOp> forLoops;
301324
{
302325
// If there is an interchange specified, permute the iteration domain and
303326
// the tile sizes.
@@ -320,8 +343,8 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
320343
// 3. Materialize an empty loop nest that iterates over the tiles. These
321344
// loops for now do not return any values even if the original operation has
322345
// results.
323-
tilingResult.loops = generateTileLoopNest(
324-
rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes);
346+
forLoops = generateTileLoopNest(rewriter, op.getLoc(), iterationDomain,
347+
tileSizeVector, offsets, sizes);
325348

326349
if (!interchangeVector.empty()) {
327350
auto inversePermutation = invertPermutationVector(interchangeVector);
@@ -331,30 +354,30 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
331354
}
332355

333356
LLVM_DEBUG({
334-
if (!tilingResult.loops.empty()) {
357+
if (!forLoops.empty()) {
335358
llvm::dbgs() << "LoopNest shell :\n";
336-
tilingResult.loops.front().dump();
359+
forLoops.front().dump();
337360
llvm::dbgs() << "\n";
338361
}
339362
});
340363

341364
// 4. Generate the tiled implementation within the inner most loop.
342-
if (!tilingResult.loops.empty())
343-
rewriter.setInsertionPoint(
344-
tilingResult.loops.back().getBody()->getTerminator());
365+
if (!forLoops.empty())
366+
rewriter.setInsertionPoint(forLoops.back().getBody()->getTerminator());
345367
FailureOr<TilingResult> tiledImplementation =
346368
op.getTiledImplementation(rewriter, offsets, sizes);
347-
tilingResult.tiledOps.append(tiledImplementation->tiledOps);
369+
348370
if (op->getNumResults() == 0) {
349-
// nothing more to do.
350-
return tilingResult;
371+
return scf::SCFTilingResult{
372+
tiledImplementation->tiledOps, getAsOperations(forLoops), {}};
351373
}
352374

353375
// If loops are empty, the tiled op is used as the replacement for the untiled
354376
// op.
355-
if (tilingResult.loops.empty()) {
356-
tilingResult.replacements = tiledImplementation->tiledValues;
357-
return tilingResult;
377+
if (forLoops.empty()) {
378+
return scf::SCFTilingResult{tiledImplementation->tiledOps,
379+
getAsOperations(forLoops),
380+
tiledImplementation->tiledValues};
358381
}
359382

360383
// 5. Yield all the results of the tiled operation. The surrounding loop
@@ -378,18 +401,18 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
378401
destinationTensors)))
379402
return rewriter.notifyMatchFailure(op, "failed to get destinations");
380403

381-
tilingResult.replacements = yieldTiledValues(
404+
SmallVector<Value> replacements = yieldTiledValues(
382405
rewriter, destinationTensors, tiledImplementation.value(),
383-
resultOffsetsList, resultSizesList, tilingResult.loops);
384-
406+
resultOffsetsList, resultSizesList, forLoops);
385407
LLVM_DEBUG({
386-
if (!tilingResult.loops.empty()) {
408+
if (!forLoops.empty()) {
387409
llvm::dbgs() << "After tiled implementation :\n";
388-
tilingResult.loops.front().dump();
410+
forLoops.front().dump();
389411
llvm::dbgs() << "\n";
390412
}
391413
});
392-
return tilingResult;
414+
return scf::SCFTilingResult{tiledImplementation->tiledOps,
415+
getAsOperations(forLoops), replacements};
393416
}
394417

395418
FailureOr<scf::SCFReductionTilingResult>
@@ -467,6 +490,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
467490
results.mergeOp = mergeOp;
468491
return results;
469492
}
493+
470494
//===----------------------------------------------------------------------===//
471495
// tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation.
472496
//===----------------------------------------------------------------------===//
@@ -637,28 +661,31 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
637661
}
638662

639663
// 1. First tile the consumer.
640-
scf::SCFTileAndFuseResult tileAndFuseResult;
664+
SmallVector<scf::ForOp> forLoops;
665+
SetVector<Operation *> fusedProducers, tiledAndFusedOps;
666+
DenseMap<Value, Value> replacements;
641667
llvm::SmallDenseMap<Value, int64_t> yieldedValueToResultNumber;
642668
{
643669
FailureOr<scf::SCFTilingResult> tilingResult =
644670
tileUsingSCFForOp(rewriter, consumer, options.tilingOptions);
645671
if (failed(tilingResult))
646672
return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
647673
for (auto *tiledOp : tilingResult->tiledOps)
648-
tileAndFuseResult.tiledAndFusedOps.insert(tiledOp);
649-
tileAndFuseResult.loops = std::move(tilingResult->loops);
650-
for (const auto &result : llvm::enumerate(
651-
llvm::zip(consumer->getResults(), tilingResult->replacements))) {
652-
tileAndFuseResult.replacements[std::get<0>(result.value())] =
653-
std::get<1>(result.value());
674+
tiledAndFusedOps.insert(tiledOp);
675+
forLoops = castToTypedOperations<scf::ForOp>(tilingResult->loops);
676+
for (auto [index, origValue, replacement] :
677+
llvm::enumerate(consumer->getResults(), tilingResult->replacements)) {
678+
replacements[origValue] = replacement;
654679
yieldedValueToResultNumber[tilingResult->tiledOps.back()->getResult(
655-
result.index())] = result.index();
680+
index)] = index;
656681
}
657682
}
658683

659684
// If there are no loops generated, fusion is immaterial.
660-
if (tileAndFuseResult.loops.empty())
661-
return tileAndFuseResult;
685+
if (forLoops.empty()) {
686+
return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
687+
getAsOperations(forLoops), replacements};
688+
}
662689

663690
// 2. Typically, the operands of the tiled operation are slices of the
664691
// operands of the untiled operation. These are expressed in IR using
@@ -675,7 +702,7 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
675702
};
676703

677704
std::deque<tensor::ExtractSliceOp> candidates;
678-
addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates);
705+
addCandidateSlices(tiledAndFusedOps.back(), candidates);
679706
OpBuilder::InsertionGuard g(rewriter);
680707
while (!candidates.empty()) {
681708
// Traverse the slices in BFS fashion.
@@ -685,19 +712,20 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
685712
// The operands of the fused producer might themselved be slices of
686713
// values produced by operations that implement the `TilingInterface`.
687714
// Add these operations to the worklist.
688-
std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
689-
tileAndFuseProducerOfSlice(rewriter, candidateSliceOp,
690-
tileAndFuseResult.loops);
691-
if (!fusedProducer)
715+
std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
716+
tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, forLoops);
717+
if (!fusedResult)
692718
continue;
693719

694720
if (Operation *tiledAndFusedOp =
695-
fusedProducer->tiledAndFusedProducer.getDefiningOp()) {
696-
tileAndFuseResult.tiledAndFusedOps.insert(tiledAndFusedOp);
721+
fusedResult->tiledAndFusedProducer.getDefiningOp()) {
722+
fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
723+
tiledAndFusedOps.insert(tiledAndFusedOp);
697724
addCandidateSlices(tiledAndFusedOp, candidates);
698725
}
699726
}
700-
return tileAndFuseResult;
727+
return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
728+
getAsOperations(forLoops), replacements};
701729
}
702730

703731
//===----------------------------------------------------------------------===//

mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ func.func @gemm_fill_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) ->
88
%d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
99
%init = tensor.empty(%d0, %d1) : tensor<?x?xf32>
1010
%fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
11-
%gemm = linalg.matmul {__internal_linalg_transform__ = "fusion"}
11+
%gemm = linalg.matmul {__internal_transform__ = "fusion"}
1212
ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
1313
outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
1414
return %gemm : tensor<?x?xf32>
@@ -47,7 +47,7 @@ func.func @gemm_generic_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
4747
ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
4848
outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
4949
%generic = linalg.generic {
50-
__internal_linalg_transform__ = "fusion",
50+
__internal_transform__ = "fusion",
5151
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>],
5252
iterator_types = ["parallel", "parallel"]}
5353
ins(%gemm, %arg2 : tensor<?x?xf32>, tensor<?xf32>) outs(%init : tensor<?x?xf32>) {
@@ -97,7 +97,7 @@ func.func @gemm_gemm_fusion(%lhs0 : tensor<?x?xf32>, %rhs0 : tensor<?x?xf32>, %r
9797
%d2 = tensor.dim %rhs1, %c1 : tensor<?x?xf32>
9898
%init1 = tensor.empty(%d0, %d2) : tensor<?x?xf32>
9999
%fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<?x?xf32>) -> tensor<?x?xf32>
100-
%gemm1 = linalg.matmul {__internal_linalg_transform__ = "gemm_fusion"}
100+
%gemm1 = linalg.matmul {__internal_transform__ = "gemm_fusion"}
101101
ins(%gemm0, %rhs1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill1 : tensor<?x?xf32>) -> tensor<?x?xf32>
102102
return %gemm1 : tensor<?x?xf32>
103103
}
@@ -147,7 +147,7 @@ func.func @gemm_transpose_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32
147147
outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
148148
%init1 = tensor.empty(%d1, %d0) : tensor<?x?xf32>
149149
%transpose = linalg.generic {
150-
__internal_linalg_transform__ = "fusion",
150+
__internal_transform__ = "fusion",
151151
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>],
152152
iterator_types = ["parallel", "parallel"]}
153153
ins(%gemm : tensor<?x?xf32>) outs(%init1 : tensor<?x?xf32>) {
@@ -198,7 +198,7 @@ func.func @interchange_matmul_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?
198198
ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
199199
outs(%1 : tensor<?x?xf32>) -> tensor<?x?xf32>
200200
%3 = linalg.generic {
201-
__internal_linalg_transform__ = "gemm_interchange_fusion",
201+
__internal_transform__ = "gemm_interchange_fusion",
202202
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
203203
iterator_types = ["parallel", "parallel"]}
204204
ins(%2 : tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) {
@@ -249,7 +249,7 @@ func.func @matmul_plus_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
249249
affine_map<(d0, d1) -> (d0, d1)>,
250250
affine_map<(d0, d1) -> (d0, d1)>],
251251
iterator_types = ["parallel", "parallel"],
252-
__internal_linalg_transform__ = "gemm_plus_gemm_fusion"}
252+
__internal_transform__ = "gemm_plus_gemm_fusion"}
253253
ins(%2, %2 : tensor<?x?xf32>, tensor<?x?xf32>)
254254
outs(%5 : tensor<?x?xf32>) {
255255
^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) :
@@ -302,7 +302,7 @@ func.func @matmul_plus_transpose_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x
302302
affine_map<(d0, d1) -> (d1, d0)>,
303303
affine_map<(d0, d1) -> (d0, d1)>],
304304
iterator_types = ["parallel", "parallel"],
305-
__internal_linalg_transform__ = "gemm_plus_gemm_fusion"}
305+
__internal_transform__ = "gemm_plus_gemm_fusion"}
306306
ins(%2, %2 : tensor<?x?xf32>, tensor<?x?xf32>)
307307
outs(%5 : tensor<?x?xf32>) {
308308
^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) :
@@ -352,7 +352,7 @@ func.func @matmul_sequence_fusion(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>
352352
%1 = linalg.matmul ins(%0, %arg3 : tensor<?x?xf32>, tensor<?x?xf32>)
353353
outs(%arg4 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N1] * [N1, N2]
354354
%2 = linalg.matmul
355-
{__internal_linalg_transform__ = "gemm_sequence_fusion"}
355+
{__internal_transform__ = "gemm_sequence_fusion"}
356356
ins(%1, %arg5 : tensor<?x?xf32>, tensor<?x?xf32>)
357357
outs(%arg6 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N2] * [N2, N3]
358358
return %2 : tensor<?x?xf32>
@@ -425,7 +425,7 @@ func.func @reduction_sequence(%arg0: tensor<30x3xf32>) -> tensor<30x3xf32> {
425425
linalg.yield %10, %9 : f32, f32
426426
} -> (tensor<30xf32>, tensor<30x3xf32>)
427427
%6 = linalg.generic {
428-
__internal_linalg_transform__ = "reduction_sequence_fusion",
428+
__internal_transform__ = "reduction_sequence_fusion",
429429
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>,
430430
affine_map<(d0, d1) -> (d0, d1)>],
431431
iterator_types = ["parallel", "parallel"]}

mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ func.func @gemm_gemm_fusion_yield_both(%lhs0 : tensor<?x?xf32>, %rhs0 : tensor<?
1313
ins(%lhs0, %rhs0 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill0 : tensor<?x?xf32>) -> tensor<?x?xf32>
1414
%d2 = tensor.dim %rhs1, %c1 : tensor<?x?xf32>
1515
%fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<?x?xf32>) -> tensor<?x?xf32>
16-
%gemm1 = linalg.matmul {__internal_linalg_transform__ = "gemm_sequence_fusion_and_yield"}
16+
%gemm1 = linalg.matmul {__internal_transform__ = "gemm_sequence_fusion_and_yield"}
1717
ins(%gemm0, %rhs1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill1 : tensor<?x?xf32>) -> tensor<?x?xf32>
1818
return %gemm0, %gemm1 : tensor<?x?xf32>, tensor<?x?xf32>
1919
}

0 commit comments

Comments
 (0)