Skip to content

Commit 2ffce48

Browse files
committed
extend consumer fuse to nested scf loop (v2)
1 parent a841446 commit 2ffce48

File tree

3 files changed

+311
-23
lines changed

3 files changed

+311
-23
lines changed

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,10 @@ struct SCFFuseConsumerOfSliceResult {
254254
*tiledAndFusedConsumerOperand; // Tiled and fused consumer's operand.
255255
SmallVector<Operation *> tiledOps;
256256
};
257+
FailureOr<scf::SCFFuseConsumerOfSliceResult>
258+
tileAndFuseConsumerOfSliceImpl(RewriterBase &rewriter,
259+
Operation *candidateSliceOp);
260+
257261
FailureOr<scf::SCFFuseConsumerOfSliceResult>
258262
tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp);
259263

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 211 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1228,12 +1228,24 @@ checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
12281228
/// failure otherwise.
12291229
static FailureOr<OpOperand *> getConsumerFromUses(Value val,
12301230
Block *containingOpBlock) {
1231-
// Step 1. Check that the value has exactly one use.
1232-
if (!llvm::hasSingleElement(val.getUses()))
1231+
// Step 1. Check that the value has exactly one use except for scf.yield.
1232+
OpOperand *operand = nullptr;
1233+
for (auto &use : val.getUses()) {
1234+
Operation *user = use.getOwner();
1235+
if (isa<tensor::InsertSliceOp>(user) ||
1236+
isa<tensor::ParallelInsertSliceOp>(user))
1237+
continue;
1238+
else {
1239+
if (operand)
1240+
return failure();
1241+
else
1242+
operand = &use;
1243+
}
1244+
}
1245+
if (!operand)
12331246
return failure();
12341247
// Step 2. Get uses.
1235-
OpOperand &operand = (*val.getUses().begin());
1236-
Operation *consumerOp = operand.getOwner();
1248+
Operation *consumerOp = operand->getOwner();
12371249
// TODO: We have to init result of consumer before scf.for, use
12381250
// DestinationStyleOpInterface to get result shape from init for now.
12391251
// Add support for other op such as op has InferTypeOpInterface.
@@ -1242,7 +1254,22 @@ static FailureOr<OpOperand *> getConsumerFromUses(Value val,
12421254
return failure();
12431255
if (containingOpBlock != consumerOp->getBlock())
12441256
return failure();
1245-
return &operand;
1257+
return operand;
1258+
}
1259+
1260+
/// Return perfectly outer loops of given ForOp(included), sorted from
1261+
/// outer to inner.
1262+
static SmallVector<scf::ForOp> getPerfectlyOuterLoops(scf::ForOp loop) {
1263+
SmallVector<scf::ForOp> outerLoops = {loop};
1264+
auto forOp = loop->getParentOfType<scf::ForOp>();
1265+
while (forOp) {
1266+
Block &body = forOp.getRegion().front();
1267+
if (body.begin() != std::prev(body.end(), 2))
1268+
break;
1269+
outerLoops.push_back(forOp);
1270+
forOp = forOp->getParentOfType<scf::ForOp>();
1271+
}
1272+
return {outerLoops.rbegin(), outerLoops.rend()};
12461273
}
12471274

12481275
/// Fetch the untiled consumer of a scf.for's result which is yielded by a
@@ -1262,9 +1289,10 @@ getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
12621289
auto forOp = dyn_cast<scf::ForOp>(containingOp);
12631290
if (!forOp)
12641291
return failure();
1265-
Value resultingValue = forOp->getResult(resultNumber);
1292+
scf::ForOp topLevelForOp = getPerfectlyOuterLoops(forOp).front();
1293+
Value resultingValue = topLevelForOp->getResult(resultNumber);
12661294

1267-
return getConsumerFromUses(resultingValue, containingOp->getBlock());
1295+
return getConsumerFromUses(resultingValue, topLevelForOp->getBlock());
12681296
}
12691297

12701298
/// Fetch the first untiled consumer of a scf.forall's result which is yielded
@@ -1383,8 +1411,8 @@ fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
13831411
/// Implementation of fusing consumer of a single slice by computing the
13841412
/// slice of the consumer in-place for scf loop.
13851413
FailureOr<scf::SCFFuseConsumerOfSliceResult>
1386-
mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
1387-
Operation *candidateSliceOp) {
1414+
mlir::scf::tileAndFuseConsumerOfSliceImpl(RewriterBase &rewriter,
1415+
Operation *candidateSliceOp) {
13881416
if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
13891417
candidateSliceOp))
13901418
return failure();
@@ -1418,22 +1446,25 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
14181446
if (isInsertSliceOp) {
14191447
auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
14201448
oldLoopOp = forOp;
1421-
llvm::append_range(newOuts, forOp.getInits());
1422-
oldLoopBody = forOp.getBody();
14231449
initSize = forOp.getInits().size();
14241450
} else {
14251451
auto forallOp = candidateSliceOp->getParentOfType<scf::ForallOp>();
14261452
oldLoopOp = forallOp;
1427-
llvm::append_range(newOuts, forallOp.getOutputs());
1428-
oldLoopBody = forallOp.getBody();
14291453
initSize = forallOp.getOutputs().size();
14301454
rank = forallOp.getRank();
14311455
}
14321456

1433-
if (failed(checkAssumptionForLoop(oldLoopOp, consumerOp))) {
1457+
Operation *oldTopLevelLoop = oldLoopOp;
1458+
SmallVector<scf::ForOp> oldNestedForOps, newNestedForOps;
1459+
if (isInsertSliceOp) {
1460+
oldNestedForOps = getPerfectlyOuterLoops(cast<scf::ForOp>(oldLoopOp));
1461+
oldTopLevelLoop = oldNestedForOps.front();
1462+
}
1463+
if (failed(checkAssumptionForLoop(oldTopLevelLoop, consumerOp))) {
14341464
return rewriter.notifyMatchFailure(
1435-
oldLoopOp, "containing loop op should either yield just one value or "
1436-
"have the consumer op as its first user");
1465+
oldTopLevelLoop,
1466+
"containing loop op should either yield just one value or "
1467+
"have the consumer op as its first user");
14371468
}
14381469

14391470
OpBuilder::InsertionGuard g(rewriter);
@@ -1442,28 +1473,60 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
14421473
auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
14431474
SmallVector<Value> dpsInits =
14441475
llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
1445-
if (llvm::is_contained(dpsInits, oldLoopOp->getResult(resultNumber))) {
1476+
if (llvm::is_contained(dpsInits, oldTopLevelLoop->getResult(resultNumber))) {
14461477
return rewriter.notifyMatchFailure(
14471478
consumerOp,
14481479
"consumer op taking the result of scf.for as init is not supported");
14491480
}
1450-
newOuts.append(dpsInits);
1481+
SmallVector<Value> newInitAppend = dpsInits;
14511482

14521483
Location loc = oldLoopOp->getLoc();
14531484

14541485
// 3. Create new scf loop op.
14551486
rewriter.setInsertionPoint(consumerOp);
1487+
1488+
// 3.a Create new outer scf loops if necessary
1489+
bool isNestedForOps = isInsertSliceOp && oldNestedForOps.size() > 1;
1490+
if (isNestedForOps) {
1491+
for (auto &&[index, forOp] :
1492+
llvm::enumerate(MutableArrayRef(oldNestedForOps).drop_back())) {
1493+
SmallVector<Value> newInits;
1494+
newInits = llvm::to_vector(forOp.getInits());
1495+
newInits.append(newInitAppend.begin(), newInitAppend.end());
1496+
auto newLoop = rewriter.create<scf::ForOp>(
1497+
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
1498+
forOp.getStep(), newInits);
1499+
newInitAppend = llvm::map_to_vector(
1500+
newLoop.getRegionIterArgs().take_back(newInitAppend.size()),
1501+
[](BlockArgument bArg) -> Value { return bArg; });
1502+
rewriter.mergeBlocks(
1503+
forOp.getBody(), newLoop.getBody(),
1504+
newLoop.getBody()->getArguments().take_front(initSize + 1));
1505+
rewriter.replaceOp(
1506+
forOp, newLoop->getResults().take_front(forOp->getNumResults()));
1507+
newNestedForOps.push_back(newLoop);
1508+
rewriter.setInsertionPointAfter(oldNestedForOps[index + 1]);
1509+
}
1510+
}
1511+
1512+
// 3.b Create new inner most scf loop
14561513
Operation *newLoopOp = nullptr;
14571514
Block *newLoopBody = nullptr;
14581515
if (isInsertSliceOp) {
14591516
auto forOp = cast<scf::ForOp>(oldLoopOp);
1517+
llvm::append_range(newOuts, forOp.getInits());
1518+
newOuts.append(newInitAppend);
1519+
oldLoopBody = forOp.getBody();
14601520
auto newForOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
14611521
forOp.getUpperBound(),
14621522
forOp.getStep(), newOuts);
14631523
newLoopOp = newForOp;
14641524
newLoopBody = newForOp.getBody();
14651525
} else {
14661526
auto forallOp = cast<scf::ForallOp>(oldLoopOp);
1527+
llvm::append_range(newOuts, forallOp.getOutputs());
1528+
newOuts.append(newInitAppend);
1529+
oldLoopBody = forallOp.getBody();
14671530
auto newForallOp = rewriter.create<scf::ForallOp>(
14681531
loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
14691532
forallOp.getMixedStep(), newOuts, forallOp.getMapping());
@@ -1577,28 +1640,155 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
15771640
newForallOp.getBody()->getArguments().drop_front(rank + initSize));
15781641
}
15791642

1580-
// 12. Replace the result of scf loop and consumer op with new loop's results.
1643+
// 12. Restore outer loops from inner to outer
1644+
if (isNestedForOps) {
1645+
newNestedForOps.push_back(cast<scf::ForOp>(newLoopOp));
1646+
for (auto [outerLoop, innerLoop] :
1647+
llvm::zip_equal(MutableArrayRef(newNestedForOps).drop_back(),
1648+
MutableArrayRef(newNestedForOps).drop_front())) {
1649+
auto forOp = cast<scf::ForOp>(outerLoop);
1650+
auto outerLoopYield =
1651+
cast<scf::YieldOp>(forOp.getBody()->getTerminator());
1652+
SmallVector<Value> newYields =
1653+
llvm::to_vector(outerLoopYield.getOperands());
1654+
ValueRange additionalYields =
1655+
innerLoop->getResults().take_back(newInitAppend.size());
1656+
newYields.append(additionalYields.begin(), additionalYields.end());
1657+
rewriter.setInsertionPoint(outerLoopYield);
1658+
rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
1659+
}
1660+
}
1661+
1662+
// 13. Replace the result of scf loop and consumer op with new loop's results.
15811663
for (auto &&[oldResult, newResult] :
15821664
llvm::zip_first(oldLoopOp->getResults(), newLoopOp->getResults())) {
15831665
rewriter.replaceAllUsesWith(oldResult, newResult);
15841666
}
15851667

1668+
Operation *newTopLevelLoop =
1669+
isNestedForOps ? newNestedForOps.front() : newLoopOp;
15861670
for (auto &&[oldResult, newResult] :
15871671
llvm::zip(consumerOp->getResults(),
1588-
newLoopOp->getResults().drop_front(initSize))) {
1672+
newTopLevelLoop->getResults().drop_front(initSize))) {
15891673
rewriter.replaceAllUsesWith(oldResult, newResult);
15901674
}
15911675

1592-
// 13. Need to erase the old scf loop and the cloned consumer op.
1676+
// 14. Need to erase the old scf loop and the cloned consumer op.
15931677
rewriter.eraseOp(oldLoopOp);
15941678
rewriter.eraseOp(clonedConsumerOp);
15951679

1680+
// 15. Need to erase the cloned insertSliceOp and unused extractSliceOp in
1681+
// avoid of complex domination analysis
1682+
assert(clonedInsertSliceOp->hasOneUse());
1683+
auto unUsedExtractOp =
1684+
cast<tensor::ExtractSliceOp>((*clonedInsertSliceOp->getUsers().begin()));
1685+
rewriter.eraseOp(unUsedExtractOp);
1686+
rewriter.eraseOp(clonedInsertSliceOp);
1687+
15961688
return scf::SCFFuseConsumerOfSliceResult{
15971689
consumerOpOperand,
15981690
&(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
15991691
tileAndFuseResult->tiledOps};
16001692
}
16011693

1694+
/// Get the result of top-level loop which yields the target InsertSliceOp. E.g
1695+
/// ```
1696+
/// %1 = scf.for
1697+
/// %2 = scf.for
1698+
/// %3 = scf.for
1699+
/// ...
1700+
/// %4 = insert
1701+
/// yield %4
1702+
/// %5 = insert %3
1703+
/// yield %5
1704+
/// yield %2
1705+
/// ```
1706+
/// @param targetSliceOp: %4 = insert
1707+
/// @param insertSliceOpChain: chain of all related insert sliceOp
1708+
/// @return resultValue: %1
1709+
static FailureOr<Value> getResultOfTopLevelLoopYieldInsertSliceOp(
1710+
Operation *targetSliceOp,
1711+
SmallVectorImpl<OffsetSizeAndStrideOpInterface> &insertSliceOpChain,
1712+
int curDepth = 0, int maxDepth = 5) {
1713+
assert(isa<OffsetSizeAndStrideOpInterface>(targetSliceOp));
1714+
// Control recursive time in avoid of stack overflow
1715+
if (curDepth > maxDepth)
1716+
return failure();
1717+
1718+
insertSliceOpChain.push_back(
1719+
cast<OffsetSizeAndStrideOpInterface>(targetSliceOp));
1720+
Value resultOfLoop;
1721+
if (auto sliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(targetSliceOp)) {
1722+
Value destValue = sliceOp.getDest();
1723+
auto iterArg = cast<BlockArgument>(destValue);
1724+
auto forallOp = dyn_cast<scf::ForallOp>(iterArg.getOwner()->getParentOp());
1725+
if (!forallOp)
1726+
return failure();
1727+
resultOfLoop = forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
1728+
} else if (auto sliceOp = dyn_cast<tensor::InsertSliceOp>(targetSliceOp)) {
1729+
Value resultValue = sliceOp.getResult();
1730+
for (auto &useOperand : resultValue.getUses()) {
1731+
if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) {
1732+
if (llvm::detail::isPresent(resultOfLoop))
1733+
return failure();
1734+
auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp());
1735+
if (!forOp)
1736+
return failure();
1737+
resultOfLoop = forOp->getResult(useOperand.getOperandNumber());
1738+
}
1739+
}
1740+
}
1741+
1742+
if (!llvm::detail::isPresent(resultOfLoop))
1743+
return failure();
1744+
1745+
while (true) {
1746+
bool walkThroughOuterLoop = false;
1747+
for (OpOperand &useOperand : resultOfLoop.getUses()) {
1748+
if (auto sliceOp =
1749+
dyn_cast<OffsetSizeAndStrideOpInterface>(useOperand.getOwner())) {
1750+
return getResultOfTopLevelLoopYieldInsertSliceOp(
1751+
sliceOp, insertSliceOpChain, curDepth + 1);
1752+
} else if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) {
1753+
// walk through outer loop
1754+
auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp());
1755+
if (!forOp)
1756+
return failure();
1757+
resultOfLoop = forOp->getResult(useOperand.getOperandNumber());
1758+
walkThroughOuterLoop = true;
1759+
break;
1760+
}
1761+
}
1762+
if (!walkThroughOuterLoop)
1763+
break;
1764+
}
1765+
return resultOfLoop;
1766+
}
1767+
1768+
/// Fusing real consumer of a single slice even within complex nested loops via
1769+
/// multiple application of `tileAndFuseConsumerOfSliceImpl`.
1770+
FailureOr<scf::SCFFuseConsumerOfSliceResult>
1771+
mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
1772+
Operation *candidateSliceOp) {
1773+
SmallVector<OffsetSizeAndStrideOpInterface> sliceOpChain;
1774+
if (failed(getResultOfTopLevelLoopYieldInsertSliceOp(candidateSliceOp,
1775+
sliceOpChain)))
1776+
return failure();
1777+
1778+
FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResult;
1779+
// reverse from outer to inner
1780+
std::reverse(sliceOpChain.begin(), sliceOpChain.end());
1781+
// multiple application of `tileAndFuseConsumerOfSliceImpl`
1782+
for (auto &sliceOp : sliceOpChain) {
1783+
fuseConsumerResult = tileAndFuseConsumerOfSliceImpl(rewriter, sliceOp);
1784+
if (failed(fuseConsumerResult)) {
1785+
return rewriter.notifyMatchFailure(sliceOp,
1786+
"could not fuse consumer of sliceOp");
1787+
}
1788+
}
1789+
return fuseConsumerResult;
1790+
}
1791+
16021792
//===----------------------------------------------------------------------===//
16031793
// lowerToLoopsUsingSCFForOp implementation.
16041794
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)