@@ -1228,12 +1228,24 @@ checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
1228
1228
// / failure otherwise.
1229
1229
static FailureOr<OpOperand *> getConsumerFromUses (Value val,
1230
1230
Block *containingOpBlock) {
1231
- // Step 1. Check that the value has exactly one use.
1232
- if (!llvm::hasSingleElement (val.getUses ()))
1231
+ // Step 1. Check that the value has exactly one use except for scf.yield.
1232
+ OpOperand *operand = nullptr ;
1233
+ for (auto &use : val.getUses ()) {
1234
+ Operation *user = use.getOwner ();
1235
+ if (isa<tensor::InsertSliceOp>(user) ||
1236
+ isa<tensor::ParallelInsertSliceOp>(user))
1237
+ continue ;
1238
+ else {
1239
+ if (operand)
1240
+ return failure ();
1241
+ else
1242
+ operand = &use;
1243
+ }
1244
+ }
1245
+ if (!operand)
1233
1246
return failure ();
1234
1247
// Step 2. Get uses.
1235
- OpOperand &operand = (*val.getUses ().begin ());
1236
- Operation *consumerOp = operand.getOwner ();
1248
+ Operation *consumerOp = operand->getOwner ();
1237
1249
// TODO: We have to init result of consumer before scf.for, use
1238
1250
// DestinationStyleOpInterface to get result shape from init for now.
1239
1251
// Add support for other op such as op has InferTypeOpInterface.
@@ -1242,7 +1254,22 @@ static FailureOr<OpOperand *> getConsumerFromUses(Value val,
1242
1254
return failure ();
1243
1255
if (containingOpBlock != consumerOp->getBlock ())
1244
1256
return failure ();
1245
- return &operand;
1257
+ return operand;
1258
+ }
1259
+
1260
+ // / Return perfectly outer loops of given ForOp(included), sorted from
1261
+ // / outer to inner.
1262
+ static SmallVector<scf::ForOp> getPerfectlyOuterLoops (scf::ForOp loop) {
1263
+ SmallVector<scf::ForOp> outerLoops = {loop};
1264
+ auto forOp = loop->getParentOfType <scf::ForOp>();
1265
+ while (forOp) {
1266
+ Block &body = forOp.getRegion ().front ();
1267
+ if (body.begin () != std::prev (body.end (), 2 ))
1268
+ break ;
1269
+ outerLoops.push_back (forOp);
1270
+ forOp = forOp->getParentOfType <scf::ForOp>();
1271
+ }
1272
+ return {outerLoops.rbegin (), outerLoops.rend ()};
1246
1273
}
1247
1274
1248
1275
// / Fetch the untiled consumer of a scf.for's result which is yielded by a
@@ -1262,9 +1289,10 @@ getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
1262
1289
auto forOp = dyn_cast<scf::ForOp>(containingOp);
1263
1290
if (!forOp)
1264
1291
return failure ();
1265
- Value resultingValue = forOp->getResult (resultNumber);
1292
+ scf::ForOp topLevelForOp = getPerfectlyOuterLoops (forOp).front ();
1293
+ Value resultingValue = topLevelForOp->getResult (resultNumber);
1266
1294
1267
- return getConsumerFromUses (resultingValue, containingOp ->getBlock ());
1295
+ return getConsumerFromUses (resultingValue, topLevelForOp ->getBlock ());
1268
1296
}
1269
1297
1270
1298
// / Fetch the first untiled consumer of a scf.forall's result which is yielded
@@ -1383,8 +1411,8 @@ fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
1383
1411
// / Implementation of fusing consumer of a single slice by computing the
1384
1412
// / slice of the consumer in-place for scf loop.
1385
1413
FailureOr<scf::SCFFuseConsumerOfSliceResult>
1386
- mlir::scf::tileAndFuseConsumerOfSlice (RewriterBase &rewriter,
1387
- Operation *candidateSliceOp) {
1414
+ mlir::scf::tileAndFuseConsumerOfSliceImpl (RewriterBase &rewriter,
1415
+ Operation *candidateSliceOp) {
1388
1416
if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
1389
1417
candidateSliceOp))
1390
1418
return failure ();
@@ -1418,22 +1446,25 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
1418
1446
if (isInsertSliceOp) {
1419
1447
auto forOp = candidateSliceOp->getParentOfType <scf::ForOp>();
1420
1448
oldLoopOp = forOp;
1421
- llvm::append_range (newOuts, forOp.getInits ());
1422
- oldLoopBody = forOp.getBody ();
1423
1449
initSize = forOp.getInits ().size ();
1424
1450
} else {
1425
1451
auto forallOp = candidateSliceOp->getParentOfType <scf::ForallOp>();
1426
1452
oldLoopOp = forallOp;
1427
- llvm::append_range (newOuts, forallOp.getOutputs ());
1428
- oldLoopBody = forallOp.getBody ();
1429
1453
initSize = forallOp.getOutputs ().size ();
1430
1454
rank = forallOp.getRank ();
1431
1455
}
1432
1456
1433
- if (failed (checkAssumptionForLoop (oldLoopOp, consumerOp))) {
1457
+ Operation *oldTopLevelLoop = oldLoopOp;
1458
+ SmallVector<scf::ForOp> oldNestedForOps, newNestedForOps;
1459
+ if (isInsertSliceOp) {
1460
+ oldNestedForOps = getPerfectlyOuterLoops (cast<scf::ForOp>(oldLoopOp));
1461
+ oldTopLevelLoop = oldNestedForOps.front ();
1462
+ }
1463
+ if (failed (checkAssumptionForLoop (oldTopLevelLoop, consumerOp))) {
1434
1464
return rewriter.notifyMatchFailure (
1435
- oldLoopOp, " containing loop op should either yield just one value or "
1436
- " have the consumer op as its first user" );
1465
+ oldTopLevelLoop,
1466
+ " containing loop op should either yield just one value or "
1467
+ " have the consumer op as its first user" );
1437
1468
}
1438
1469
1439
1470
OpBuilder::InsertionGuard g (rewriter);
@@ -1442,28 +1473,60 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
1442
1473
auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
1443
1474
SmallVector<Value> dpsInits =
1444
1475
llvm::map_to_vector (dstOp.getDpsInits (), [](Value v) { return v; });
1445
- if (llvm::is_contained (dpsInits, oldLoopOp ->getResult (resultNumber))) {
1476
+ if (llvm::is_contained (dpsInits, oldTopLevelLoop ->getResult (resultNumber))) {
1446
1477
return rewriter.notifyMatchFailure (
1447
1478
consumerOp,
1448
1479
" consumer op taking the result of scf.for as init is not supported" );
1449
1480
}
1450
- newOuts. append ( dpsInits) ;
1481
+ SmallVector<Value> newInitAppend = dpsInits;
1451
1482
1452
1483
Location loc = oldLoopOp->getLoc ();
1453
1484
1454
1485
// 3. Create new scf loop op.
1455
1486
rewriter.setInsertionPoint (consumerOp);
1487
+
1488
+ // 3.a Create new outer scf loops if necessary
1489
+ bool isNestedForOps = isInsertSliceOp && oldNestedForOps.size () > 1 ;
1490
+ if (isNestedForOps) {
1491
+ for (auto &&[index , forOp] :
1492
+ llvm::enumerate (MutableArrayRef (oldNestedForOps).drop_back ())) {
1493
+ SmallVector<Value> newInits;
1494
+ newInits = llvm::to_vector (forOp.getInits ());
1495
+ newInits.append (newInitAppend.begin (), newInitAppend.end ());
1496
+ auto newLoop = rewriter.create <scf::ForOp>(
1497
+ forOp.getLoc (), forOp.getLowerBound (), forOp.getUpperBound (),
1498
+ forOp.getStep (), newInits);
1499
+ newInitAppend = llvm::map_to_vector (
1500
+ newLoop.getRegionIterArgs ().take_back (newInitAppend.size ()),
1501
+ [](BlockArgument bArg) -> Value { return bArg; });
1502
+ rewriter.mergeBlocks (
1503
+ forOp.getBody (), newLoop.getBody (),
1504
+ newLoop.getBody ()->getArguments ().take_front (initSize + 1 ));
1505
+ rewriter.replaceOp (
1506
+ forOp, newLoop->getResults ().take_front (forOp->getNumResults ()));
1507
+ newNestedForOps.push_back (newLoop);
1508
+ rewriter.setInsertionPointAfter (oldNestedForOps[index + 1 ]);
1509
+ }
1510
+ }
1511
+
1512
+ // 3.b Create new inner most scf loop
1456
1513
Operation *newLoopOp = nullptr ;
1457
1514
Block *newLoopBody = nullptr ;
1458
1515
if (isInsertSliceOp) {
1459
1516
auto forOp = cast<scf::ForOp>(oldLoopOp);
1517
+ llvm::append_range (newOuts, forOp.getInits ());
1518
+ newOuts.append (newInitAppend);
1519
+ oldLoopBody = forOp.getBody ();
1460
1520
auto newForOp = rewriter.create <scf::ForOp>(loc, forOp.getLowerBound (),
1461
1521
forOp.getUpperBound (),
1462
1522
forOp.getStep (), newOuts);
1463
1523
newLoopOp = newForOp;
1464
1524
newLoopBody = newForOp.getBody ();
1465
1525
} else {
1466
1526
auto forallOp = cast<scf::ForallOp>(oldLoopOp);
1527
+ llvm::append_range (newOuts, forallOp.getOutputs ());
1528
+ newOuts.append (newInitAppend);
1529
+ oldLoopBody = forallOp.getBody ();
1467
1530
auto newForallOp = rewriter.create <scf::ForallOp>(
1468
1531
loc, forallOp.getMixedLowerBound (), forallOp.getMixedUpperBound (),
1469
1532
forallOp.getMixedStep (), newOuts, forallOp.getMapping ());
@@ -1577,28 +1640,155 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
1577
1640
newForallOp.getBody ()->getArguments ().drop_front (rank + initSize));
1578
1641
}
1579
1642
1580
- // 12. Replace the result of scf loop and consumer op with new loop's results.
1643
+ // 12. Restore outer loops from inner to outer
1644
+ if (isNestedForOps) {
1645
+ newNestedForOps.push_back (cast<scf::ForOp>(newLoopOp));
1646
+ for (auto [outerLoop, innerLoop] :
1647
+ llvm::zip_equal (MutableArrayRef (newNestedForOps).drop_back (),
1648
+ MutableArrayRef (newNestedForOps).drop_front ())) {
1649
+ auto forOp = cast<scf::ForOp>(outerLoop);
1650
+ auto outerLoopYield =
1651
+ cast<scf::YieldOp>(forOp.getBody ()->getTerminator ());
1652
+ SmallVector<Value> newYields =
1653
+ llvm::to_vector (outerLoopYield.getOperands ());
1654
+ ValueRange additionalYields =
1655
+ innerLoop->getResults ().take_back (newInitAppend.size ());
1656
+ newYields.append (additionalYields.begin (), additionalYields.end ());
1657
+ rewriter.setInsertionPoint (outerLoopYield);
1658
+ rewriter.replaceOpWithNewOp <scf::YieldOp>(outerLoopYield, newYields);
1659
+ }
1660
+ }
1661
+
1662
+ // 13. Replace the result of scf loop and consumer op with new loop's results.
1581
1663
for (auto &&[oldResult, newResult] :
1582
1664
llvm::zip_first (oldLoopOp->getResults (), newLoopOp->getResults ())) {
1583
1665
rewriter.replaceAllUsesWith (oldResult, newResult);
1584
1666
}
1585
1667
1668
+ Operation *newTopLevelLoop =
1669
+ isNestedForOps ? newNestedForOps.front () : newLoopOp;
1586
1670
for (auto &&[oldResult, newResult] :
1587
1671
llvm::zip (consumerOp->getResults (),
1588
- newLoopOp ->getResults ().drop_front (initSize))) {
1672
+ newTopLevelLoop ->getResults ().drop_front (initSize))) {
1589
1673
rewriter.replaceAllUsesWith (oldResult, newResult);
1590
1674
}
1591
1675
1592
- // 13 . Need to erase the old scf loop and the cloned consumer op.
1676
+ // 14 . Need to erase the old scf loop and the cloned consumer op.
1593
1677
rewriter.eraseOp (oldLoopOp);
1594
1678
rewriter.eraseOp (clonedConsumerOp);
1595
1679
1680
+ // 15. Need to erase the cloned insertSliceOp and unused extractSliceOp in
1681
+ // avoid of complex domination analysis
1682
+ assert (clonedInsertSliceOp->hasOneUse ());
1683
+ auto unUsedExtractOp =
1684
+ cast<tensor::ExtractSliceOp>((*clonedInsertSliceOp->getUsers ().begin ()));
1685
+ rewriter.eraseOp (unUsedExtractOp);
1686
+ rewriter.eraseOp (clonedInsertSliceOp);
1687
+
1596
1688
return scf::SCFFuseConsumerOfSliceResult{
1597
1689
consumerOpOperand,
1598
1690
&(tileAndFuseResult->tiledOps [0 ]->getOpOperand (operandNumber)),
1599
1691
tileAndFuseResult->tiledOps };
1600
1692
}
1601
1693
1694
+ // / Get the result of top-level loop which yields the target InsertSliceOp. E.g
1695
+ // / ```
1696
+ // / %1 = scf.for
1697
+ // / %2 = scf.for
1698
+ // / %3 = scf.for
1699
+ // / ...
1700
+ // / %4 = insert
1701
+ // / yield %4
1702
+ // / %5 = insert %3
1703
+ // / yield %5
1704
+ // / yield %2
1705
+ // / ```
1706
+ // / @param targetSliceOp: %4 = insert
1707
+ // / @param insertSliceOpChain: chain of all related insert sliceOp
1708
+ // / @return resultValue: %1
1709
+ static FailureOr<Value> getResultOfTopLevelLoopYieldInsertSliceOp (
1710
+ Operation *targetSliceOp,
1711
+ SmallVectorImpl<OffsetSizeAndStrideOpInterface> &insertSliceOpChain,
1712
+ int curDepth = 0 , int maxDepth = 5 ) {
1713
+ assert (isa<OffsetSizeAndStrideOpInterface>(targetSliceOp));
1714
+ // Control recursive time in avoid of stack overflow
1715
+ if (curDepth > maxDepth)
1716
+ return failure ();
1717
+
1718
+ insertSliceOpChain.push_back (
1719
+ cast<OffsetSizeAndStrideOpInterface>(targetSliceOp));
1720
+ Value resultOfLoop;
1721
+ if (auto sliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(targetSliceOp)) {
1722
+ Value destValue = sliceOp.getDest ();
1723
+ auto iterArg = cast<BlockArgument>(destValue);
1724
+ auto forallOp = dyn_cast<scf::ForallOp>(iterArg.getOwner ()->getParentOp ());
1725
+ if (!forallOp)
1726
+ return failure ();
1727
+ resultOfLoop = forallOp.getTiedOpResult (forallOp.getTiedOpOperand (iterArg));
1728
+ } else if (auto sliceOp = dyn_cast<tensor::InsertSliceOp>(targetSliceOp)) {
1729
+ Value resultValue = sliceOp.getResult ();
1730
+ for (auto &useOperand : resultValue.getUses ()) {
1731
+ if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner ())) {
1732
+ if (llvm::detail::isPresent (resultOfLoop))
1733
+ return failure ();
1734
+ auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp ());
1735
+ if (!forOp)
1736
+ return failure ();
1737
+ resultOfLoop = forOp->getResult (useOperand.getOperandNumber ());
1738
+ }
1739
+ }
1740
+ }
1741
+
1742
+ if (!llvm::detail::isPresent (resultOfLoop))
1743
+ return failure ();
1744
+
1745
+ while (true ) {
1746
+ bool walkThroughOuterLoop = false ;
1747
+ for (OpOperand &useOperand : resultOfLoop.getUses ()) {
1748
+ if (auto sliceOp =
1749
+ dyn_cast<OffsetSizeAndStrideOpInterface>(useOperand.getOwner ())) {
1750
+ return getResultOfTopLevelLoopYieldInsertSliceOp (
1751
+ sliceOp, insertSliceOpChain, curDepth + 1 );
1752
+ } else if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner ())) {
1753
+ // walk through outer loop
1754
+ auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp ());
1755
+ if (!forOp)
1756
+ return failure ();
1757
+ resultOfLoop = forOp->getResult (useOperand.getOperandNumber ());
1758
+ walkThroughOuterLoop = true ;
1759
+ break ;
1760
+ }
1761
+ }
1762
+ if (!walkThroughOuterLoop)
1763
+ break ;
1764
+ }
1765
+ return resultOfLoop;
1766
+ }
1767
+
1768
+ // / Fusing real consumer of a single slice even within complex nested loops via
1769
+ // / multiple application of `tileAndFuseConsumerOfSliceImpl`.
1770
+ FailureOr<scf::SCFFuseConsumerOfSliceResult>
1771
+ mlir::scf::tileAndFuseConsumerOfSlice (RewriterBase &rewriter,
1772
+ Operation *candidateSliceOp) {
1773
+ SmallVector<OffsetSizeAndStrideOpInterface> sliceOpChain;
1774
+ if (failed (getResultOfTopLevelLoopYieldInsertSliceOp (candidateSliceOp,
1775
+ sliceOpChain)))
1776
+ return failure ();
1777
+
1778
+ FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResult;
1779
+ // reverse from outer to inner
1780
+ std::reverse (sliceOpChain.begin (), sliceOpChain.end ());
1781
+ // multiple application of `tileAndFuseConsumerOfSliceImpl`
1782
+ for (auto &sliceOp : sliceOpChain) {
1783
+ fuseConsumerResult = tileAndFuseConsumerOfSliceImpl (rewriter, sliceOp);
1784
+ if (failed (fuseConsumerResult)) {
1785
+ return rewriter.notifyMatchFailure (sliceOp,
1786
+ " could not fuse consumer of sliceOp" );
1787
+ }
1788
+ }
1789
+ return fuseConsumerResult;
1790
+ }
1791
+
1602
1792
// ===----------------------------------------------------------------------===//
1603
1793
// lowerToLoopsUsingSCFForOp implementation.
1604
1794
// ===----------------------------------------------------------------------===//
0 commit comments