@@ -55,6 +55,30 @@ fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
55
55
return filledVector;
56
56
}
57
57
58
+ // / Convert a list of ops of type `SrcOpTy` to list of `Operation *`.
59
+ template <typename SrcOpTy>
60
+ static SmallVector<Operation *> getAsOperations (ArrayRef<SrcOpTy> ops) {
61
+ return llvm::to_vector (
62
+ llvm::map_range (ops, [](auto op) -> Operation * { return op; }));
63
+ }
64
+ template <typename SrcOpTy>
65
+ static SmallVector<Operation *>
66
+ getAsOperations (const SmallVector<SrcOpTy> &ops) {
67
+ return getAsOperations (ArrayRef<SrcOpTy>(ops));
68
+ }
69
+
70
+ // / Convert a list of `Operation *` to a list of `DstOpTy.
71
+ template <typename DstOpTy>
72
+ static SmallVector<DstOpTy> castToTypedOperations (ArrayRef<Operation *> ops) {
73
+ return llvm::to_vector (
74
+ llvm::map_range (ops, [](Operation *op) { return cast<DstOpTy>(op); }));
75
+ }
76
+ template <typename DstOpTy>
77
+ static SmallVector<DstOpTy>
78
+ castToTypedOperations (const SmallVector<Operation *> &ops) {
79
+ return castToTypedOperations<DstOpTy>(ArrayRef<Operation *>(ops));
80
+ }
81
+
58
82
// ===----------------------------------------------------------------------===//
59
83
// tileUsingSCFForOp implementation.
60
84
// ===----------------------------------------------------------------------===//
@@ -77,10 +101,9 @@ static bool tileDividesIterationDomain(Range loopRange) {
77
101
// / `tileSize`, i.e., `min(tileSize, range.end() - iv)`.
78
102
static OpFoldResult getBoundedTileSize (OpBuilder &b, Location loc,
79
103
Range loopRange, Value iv,
80
- Value tileSize) {
81
- std::optional<int64_t > ts = getConstantIntValue (tileSize);
82
- if (ts && ts.value () == 1 )
83
- return getAsOpFoldResult (tileSize);
104
+ OpFoldResult tileSize) {
105
+ if (isConstantIntValue (tileSize, 1 ))
106
+ return tileSize;
84
107
85
108
if (tileDividesIterationDomain (
86
109
Range{loopRange.offset , loopRange.size , tileSize}))
@@ -296,8 +319,8 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
296
319
tileSizeVector.append (numLoops - tileSizeVector.size (), zero);
297
320
}
298
321
299
- scf::SCFTilingResult tilingResult;
300
322
SmallVector<OpFoldResult> offsets, sizes;
323
+ SmallVector<scf::ForOp> forLoops;
301
324
{
302
325
// If there is an interchange specified, permute the iteration domain and
303
326
// the tile sizes.
@@ -320,8 +343,8 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
320
343
// 3. Materialize an empty loop nest that iterates over the tiles. These
321
344
// loops for now do not return any values even if the original operation has
322
345
// results.
323
- tilingResult. loops = generateTileLoopNest (
324
- rewriter, op. getLoc (), iterationDomain, tileSizeVector, offsets, sizes);
346
+ forLoops = generateTileLoopNest (rewriter, op. getLoc (), iterationDomain,
347
+ tileSizeVector, offsets, sizes);
325
348
326
349
if (!interchangeVector.empty ()) {
327
350
auto inversePermutation = invertPermutationVector (interchangeVector);
@@ -331,30 +354,30 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
331
354
}
332
355
333
356
LLVM_DEBUG ({
334
- if (!tilingResult. loops .empty ()) {
357
+ if (!forLoops .empty ()) {
335
358
llvm::dbgs () << " LoopNest shell :\n " ;
336
- tilingResult. loops .front ().dump ();
359
+ forLoops .front ().dump ();
337
360
llvm::dbgs () << " \n " ;
338
361
}
339
362
});
340
363
341
364
// 4. Generate the tiled implementation within the inner most loop.
342
- if (!tilingResult.loops .empty ())
343
- rewriter.setInsertionPoint (
344
- tilingResult.loops .back ().getBody ()->getTerminator ());
365
+ if (!forLoops.empty ())
366
+ rewriter.setInsertionPoint (forLoops.back ().getBody ()->getTerminator ());
345
367
FailureOr<TilingResult> tiledImplementation =
346
368
op.getTiledImplementation (rewriter, offsets, sizes);
347
- tilingResult. tiledOps . append (tiledImplementation-> tiledOps );
369
+
348
370
if (op->getNumResults () == 0 ) {
349
- // nothing more to do.
350
- return tilingResult ;
371
+ return scf::SCFTilingResult{
372
+ tiledImplementation-> tiledOps , getAsOperations (forLoops), {}} ;
351
373
}
352
374
353
375
// If loops are empty, the tiled op is used as the replacement for the untiled
354
376
// op.
355
- if (tilingResult.loops .empty ()) {
356
- tilingResult.replacements = tiledImplementation->tiledValues ;
357
- return tilingResult;
377
+ if (forLoops.empty ()) {
378
+ return scf::SCFTilingResult{tiledImplementation->tiledOps ,
379
+ getAsOperations (forLoops),
380
+ tiledImplementation->tiledValues };
358
381
}
359
382
360
383
// 5. Yield all the results of the tiled operation. The surrounding loop
@@ -378,18 +401,18 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
378
401
destinationTensors)))
379
402
return rewriter.notifyMatchFailure (op, " failed to get destinations" );
380
403
381
- tilingResult. replacements = yieldTiledValues (
404
+ SmallVector<Value> replacements = yieldTiledValues (
382
405
rewriter, destinationTensors, tiledImplementation.value (),
383
- resultOffsetsList, resultSizesList, tilingResult.loops );
384
-
406
+ resultOffsetsList, resultSizesList, forLoops);
385
407
LLVM_DEBUG ({
386
- if (!tilingResult. loops .empty ()) {
408
+ if (!forLoops .empty ()) {
387
409
llvm::dbgs () << " After tiled implementation :\n " ;
388
- tilingResult. loops .front ().dump ();
410
+ forLoops .front ().dump ();
389
411
llvm::dbgs () << " \n " ;
390
412
}
391
413
});
392
- return tilingResult;
414
+ return scf::SCFTilingResult{tiledImplementation->tiledOps ,
415
+ getAsOperations (forLoops), replacements};
393
416
}
394
417
395
418
FailureOr<scf::SCFReductionTilingResult>
@@ -467,6 +490,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
467
490
results.mergeOp = mergeOp;
468
491
return results;
469
492
}
493
+
470
494
// ===----------------------------------------------------------------------===//
471
495
// tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation.
472
496
// ===----------------------------------------------------------------------===//
@@ -637,28 +661,31 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
637
661
}
638
662
639
663
// 1. First tile the consumer.
640
- scf::SCFTileAndFuseResult tileAndFuseResult;
664
+ SmallVector<scf::ForOp> forLoops;
665
+ SetVector<Operation *> fusedProducers, tiledAndFusedOps;
666
+ DenseMap<Value, Value> replacements;
641
667
llvm::SmallDenseMap<Value, int64_t > yieldedValueToResultNumber;
642
668
{
643
669
FailureOr<scf::SCFTilingResult> tilingResult =
644
670
tileUsingSCFForOp (rewriter, consumer, options.tilingOptions );
645
671
if (failed (tilingResult))
646
672
return rewriter.notifyMatchFailure (consumer, " failed to tile consumer" );
647
673
for (auto *tiledOp : tilingResult->tiledOps )
648
- tileAndFuseResult.tiledAndFusedOps .insert (tiledOp);
649
- tileAndFuseResult.loops = std::move (tilingResult->loops );
650
- for (const auto &result : llvm::enumerate (
651
- llvm::zip (consumer->getResults (), tilingResult->replacements ))) {
652
- tileAndFuseResult.replacements [std::get<0 >(result.value ())] =
653
- std::get<1 >(result.value ());
674
+ tiledAndFusedOps.insert (tiledOp);
675
+ forLoops = castToTypedOperations<scf::ForOp>(tilingResult->loops );
676
+ for (auto [index, origValue, replacement] :
677
+ llvm::enumerate (consumer->getResults (), tilingResult->replacements )) {
678
+ replacements[origValue] = replacement;
654
679
yieldedValueToResultNumber[tilingResult->tiledOps .back ()->getResult (
655
- result. index ()) ] = result. index () ;
680
+ index) ] = index;
656
681
}
657
682
}
658
683
659
684
// If there are no loops generated, fusion is immaterial.
660
- if (tileAndFuseResult.loops .empty ())
661
- return tileAndFuseResult;
685
+ if (forLoops.empty ()) {
686
+ return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
687
+ getAsOperations (forLoops), replacements};
688
+ }
662
689
663
690
// 2. Typically, the operands of the tiled operation are slices of the
664
691
// operands of the untiled operation. These are expressed in IR using
@@ -675,7 +702,7 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
675
702
};
676
703
677
704
std::deque<tensor::ExtractSliceOp> candidates;
678
- addCandidateSlices (tileAndFuseResult. tiledAndFusedOps .back (), candidates);
705
+ addCandidateSlices (tiledAndFusedOps.back (), candidates);
679
706
OpBuilder::InsertionGuard g (rewriter);
680
707
while (!candidates.empty ()) {
681
708
// Traverse the slices in BFS fashion.
@@ -685,19 +712,20 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
685
712
// The operands of the fused producer might themselves be slices of
686
713
// values produced by operations that implement the `TilingInterface`.
687
714
// Add these operations to the worklist.
688
- std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
689
- tileAndFuseProducerOfSlice (rewriter, candidateSliceOp,
690
- tileAndFuseResult.loops );
691
- if (!fusedProducer)
715
+ std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
716
+ tileAndFuseProducerOfSlice (rewriter, candidateSliceOp, forLoops);
717
+ if (!fusedResult)
692
718
continue ;
693
719
694
720
if (Operation *tiledAndFusedOp =
695
- fusedProducer->tiledAndFusedProducer .getDefiningOp ()) {
696
- tileAndFuseResult.tiledAndFusedOps .insert (tiledAndFusedOp);
721
+ fusedResult->tiledAndFusedProducer .getDefiningOp ()) {
722
+ fusedProducers.insert (fusedResult->origProducer .getDefiningOp ());
723
+ tiledAndFusedOps.insert (tiledAndFusedOp);
697
724
addCandidateSlices (tiledAndFusedOp, candidates);
698
725
}
699
726
}
700
- return tileAndFuseResult;
727
+ return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
728
+ getAsOperations (forLoops), replacements};
701
729
}
702
730
703
731
// ===----------------------------------------------------------------------===//
0 commit comments