@@ -1481,13 +1481,11 @@ static FailureOr<OpOperand *> getConsumerFromUses(Value val,
1481
1481
return &operand;
1482
1482
}
1483
1483
1484
- // / Recursively find the outer nest loops of given loop(included) while the
1485
- // / predict function succeed, sorted from outer to inner.
1484
+ // / Recursively find the outer nested loops of the given loop (inclusive), sorted from
1485
+ // / outer to inner.
1486
1486
// /
1487
1487
// / @param loop: target loop; note that this loop will also be included. I.e.
1488
1488
// / if no other nest loops were found, just return itself.
1489
- // / @param pred: predict function, the termination condition of recursive
1490
- // / process.
1491
1489
// / @return Outer nested loops: the loops enclosing the given target loop (inclusive).
1492
1490
// /
1493
1491
// / E.g.
@@ -1498,36 +1496,37 @@ static FailureOr<OpOperand *> getConsumerFromUses(Value val,
1498
1496
// / %2 = scf.for()
1499
1497
// / ```
1500
1498
// /
1501
- // / If `%2 = scf.for` is given without specific prediction function, this
1502
- // / function will return three nest loops: %0 + %1 + %2.
1503
- static SmallVector<LoopLikeOpInterface> getOuterNestLoopsWhile (
1504
- LoopLikeOpInterface loop,
1505
- const std::function<LogicalResult(LoopLikeOpInterface)> &pred) {
1499
+ // / If `%2 = scf.for` is given, this function will return three nested loops: %0 + %1 + %2.
1500
+ static SmallVector<LoopLikeOpInterface>
1501
+ getPerfectlyOuterNestedLoops (LoopLikeOpInterface loop) {
1506
1502
SmallVector<LoopLikeOpInterface> nestLoops = {loop};
1507
1503
auto outerLoop = dyn_cast<LoopLikeOpInterface>(loop->getParentOp ());
1508
- while (outerLoop && succeeded (pred (outerLoop))) {
1504
+
1505
+ // / Check whether it is a ForOp that yields the results of its inner loop.
1506
+ auto isForOpYieldResultOfInnerLoop =
1507
+ [](LoopLikeOpInterface outerLoop) -> LogicalResult {
1508
+ auto forOp = dyn_cast<scf::ForOp>(outerLoop.getOperation ());
1509
+ if (!forOp)
1510
+ return failure ();
1511
+ Block *body = forOp.getBody ();
1512
+ if (!llvm::hasSingleElement (body->without_terminator ()))
1513
+ return failure ();
1514
+ auto yieldOp = cast<scf::YieldOp>(body->getTerminator ());
1515
+ auto innerForOp = dyn_cast<scf::ForOp>(body->front ());
1516
+ if (!innerForOp)
1517
+ return failure ();
1518
+ // Fail if the number of innerForOp results differs from the number of yielded values.
1519
+ return success (innerForOp->getNumResults () == yieldOp->getNumOperands ());
1520
+ };
1521
+
1522
+ while (outerLoop && succeeded (isForOpYieldResultOfInnerLoop (outerLoop))) {
1509
1523
nestLoops.push_back (outerLoop);
1510
1524
outerLoop = dyn_cast<LoopLikeOpInterface>(outerLoop->getParentOp ());
1511
1525
}
1512
1526
// sorted from outer to inner
1513
1527
return {nestLoops.rbegin (), nestLoops.rend ()};
1514
1528
}
1515
1529
1516
- // / Check if it is the ForOp that yield the result of inner loop
1517
- static LogicalResult isForOpYieldResultOfInnerLoop (LoopLikeOpInterface loop) {
1518
- if (auto forOp = dyn_cast<scf::ForOp>(loop.getOperation ())) {
1519
- Block::OpListType &opsInLoopBody = forOp.getBody ()->getOperations ();
1520
- for (auto &&[index , op] : llvm::enumerate (opsInLoopBody)) {
1521
- // If the orderIndex of inner loop is the last second one before the
1522
- // yieldOp of ForOp, the given loop must yield the result of inner loop.
1523
- if (isa<LoopLikeOpInterface>(op)) {
1524
- return success ((index + 2 ) == opsInLoopBody.size ());
1525
- }
1526
- }
1527
- }
1528
- return failure ();
1529
- }
1530
-
1531
1530
// / Fetch the untiled consumer of a scf.for's result which is yielded by a
1532
1531
// / tensor.insert_slice. This function makes the following assumptions :
1533
1532
// / 1. tensor.insert_slice has scf.yield as its only user.
@@ -1546,7 +1545,7 @@ getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
1546
1545
if (!forOp)
1547
1546
return failure ();
1548
1547
LoopLikeOpInterface topLevelForOp =
1549
- getOuterNestLoopsWhile (forOp, isForOpYieldResultOfInnerLoop ).front ();
1548
+ getPerfectlyOuterNestedLoops (forOp).front ();
1550
1549
Value resultingValue = topLevelForOp->getResult (resultNumber);
1551
1550
1552
1551
return getConsumerFromUses (resultingValue, topLevelForOp->getBlock ());
@@ -1650,8 +1649,7 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
1650
1649
candidateSliceOp->getParentOfType <LoopLikeOpInterface>();
1651
1650
SmallVector<LoopLikeOpInterface> nestedLoops;
1652
1651
if (isInsertSliceOp) {
1653
- nestedLoops =
1654
- getOuterNestLoopsWhile (innerMostLoop, isForOpYieldResultOfInnerLoop);
1652
+ nestedLoops = getPerfectlyOuterNestedLoops (innerMostLoop);
1655
1653
} else {
1656
1654
nestedLoops = {innerMostLoop};
1657
1655
}
0 commit comments