Skip to content

Commit 38956c2

Browse files
committed
rename getPerfectlyOuterNestedLoops
1 parent 2062960 commit 38956c2

File tree

1 file changed

+26
-28
lines changed

1 file changed

+26
-28
lines changed

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 26 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1481,13 +1481,11 @@ static FailureOr<OpOperand *> getConsumerFromUses(Value val,
14811481
return &operand;
14821482
}
14831483

1484-
/// Recursively find the outer nest loops of given loop(included) while the
1485-
/// predict function succeed, sorted from outer to inner.
1484+
/// Recursively find the outer nest loops of given loop(included) sorted from
1485+
/// outer to inner.
14861486
///
14871487
/// @param loop: target loop, note that this loop will be also included. I.e.
14881488
/// if no other nest loops were found, just return itself.
1489-
/// @param pred: predict function, the termination condition of recursive
1490-
/// process.
14911489
/// @return Outer Nest Loops: nest loops outside given target loop(included).
14921490
///
14931491
/// E.g.
@@ -1498,36 +1496,37 @@ static FailureOr<OpOperand *> getConsumerFromUses(Value val,
14981496
/// %2 = scf.for()
14991497
/// ```
15001498
///
1501-
/// If `%2 = scf.for` is given without specific prediction function, this
1502-
/// function will return three nest loops: %0 + %1 + %2.
1503-
static SmallVector<LoopLikeOpInterface> getOuterNestLoopsWhile(
1504-
LoopLikeOpInterface loop,
1505-
const std::function<LogicalResult(LoopLikeOpInterface)> &pred) {
1499+
/// This function will return three nest loops: %0 + %1 + %2.
1500+
static SmallVector<LoopLikeOpInterface>
1501+
getPerfectlyOuterNestedLoops(LoopLikeOpInterface loop) {
15061502
SmallVector<LoopLikeOpInterface> nestLoops = {loop};
15071503
auto outerLoop = dyn_cast<LoopLikeOpInterface>(loop->getParentOp());
1508-
while (outerLoop && succeeded(pred(outerLoop))) {
1504+
1505+
/// Check if it is the ForOp that yield the result of inner loop.
1506+
auto isForOpYieldResultOfInnerLoop =
1507+
[](LoopLikeOpInterface outerLoop) -> LogicalResult {
1508+
auto forOp = dyn_cast<scf::ForOp>(outerLoop.getOperation());
1509+
if (!forOp)
1510+
return failure();
1511+
Block *body = forOp.getBody();
1512+
if (!llvm::hasSingleElement(body->without_terminator()))
1513+
return failure();
1514+
auto yieldOp = cast<scf::YieldOp>(body->getTerminator());
1515+
auto innerForOp = dyn_cast<scf::ForOp>(body->front());
1516+
if (!innerForOp)
1517+
return failure();
1518+
// If any of the innerForOp results are not yielded.
1519+
return success(innerForOp->getNumResults() == yieldOp->getNumOperands());
1520+
};
1521+
1522+
while (outerLoop && succeeded(isForOpYieldResultOfInnerLoop(outerLoop))) {
15091523
nestLoops.push_back(outerLoop);
15101524
outerLoop = dyn_cast<LoopLikeOpInterface>(outerLoop->getParentOp());
15111525
}
15121526
// sorted from outer to inner
15131527
return {nestLoops.rbegin(), nestLoops.rend()};
15141528
}
15151529

1516-
/// Check if it is the ForOp that yield the result of inner loop
1517-
static LogicalResult isForOpYieldResultOfInnerLoop(LoopLikeOpInterface loop) {
1518-
if (auto forOp = dyn_cast<scf::ForOp>(loop.getOperation())) {
1519-
Block::OpListType &opsInLoopBody = forOp.getBody()->getOperations();
1520-
for (auto &&[index, op] : llvm::enumerate(opsInLoopBody)) {
1521-
// If the orderIndex of inner loop is the last second one before the
1522-
// yieldOp of ForOp, the given loop must yield the result of inner loop.
1523-
if (isa<LoopLikeOpInterface>(op)) {
1524-
return success((index + 2) == opsInLoopBody.size());
1525-
}
1526-
}
1527-
}
1528-
return failure();
1529-
}
1530-
15311530
/// Fetch the untiled consumer of a scf.for's result which is yielded by a
15321531
/// tensor.insert_slice. This function makes the following assumptions :
15331532
/// 1. tensor.insert_slice has scf.yield as its only user.
@@ -1546,7 +1545,7 @@ getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
15461545
if (!forOp)
15471546
return failure();
15481547
LoopLikeOpInterface topLevelForOp =
1549-
getOuterNestLoopsWhile(forOp, isForOpYieldResultOfInnerLoop).front();
1548+
getPerfectlyOuterNestedLoops(forOp).front();
15501549
Value resultingValue = topLevelForOp->getResult(resultNumber);
15511550

15521551
return getConsumerFromUses(resultingValue, topLevelForOp->getBlock());
@@ -1650,8 +1649,7 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
16501649
candidateSliceOp->getParentOfType<LoopLikeOpInterface>();
16511650
SmallVector<LoopLikeOpInterface> nestedLoops;
16521651
if (isInsertSliceOp) {
1653-
nestedLoops =
1654-
getOuterNestLoopsWhile(innerMostLoop, isForOpYieldResultOfInnerLoop);
1652+
nestedLoops = getPerfectlyOuterNestedLoops(innerMostLoop);
16551653
} else {
16561654
nestedLoops = {innerMostLoop};
16571655
}

0 commit comments

Comments
 (0)