-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][scf] Extend consumer fuse to single nested scf.for
#108318
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][scf] Extend consumer fuse to single nested scf.for
#108318
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-scf Author: None (Yun-Fly) Changes: This is a mirror PR of #94190 with a tiny build fix. Sorry for the inconvenience. Patch is 28.34 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/108318.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index e404c01010a325..f4cf92201068ae 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1481,6 +1481,50 @@ static FailureOr<OpOperand *> getConsumerFromUses(Value val,
return &operand;
}
+/// Find the perfectly nested loops outside of given loop(included) sorted from
+/// outer to inner.
+///
+/// E.g.
+///
+/// ```
+/// %0 = scf.for()
+/// %1 = scf.for()
+/// %2 = scf.for()
+/// %3 = ...
+/// yield %3
+/// yield %2
+/// yield %1
+/// ```
+///
+/// This function will return three perfectly nested loops: %0 + %1 + %2, when
+/// target inner loop is %2.
+static SmallVector<scf::ForOp>
+getPerfectlyNestedLoopsOutsideOf(scf::ForOp loop) {
+ SmallVector<scf::ForOp> nestLoops = {loop};
+ auto outerLoop = dyn_cast<scf::ForOp>(loop->getParentOp());
+
+ // Check if it is the ForOp that yield the result of inner loop.
+ auto isForOpYieldResultOfInnerLoop =
+ [](scf::ForOp outerLoop) -> LogicalResult {
+ Block *body = outerLoop.getBody();
+ if (!llvm::hasSingleElement(body->without_terminator()))
+ return failure();
+ auto yieldOp = cast<scf::YieldOp>(body->getTerminator());
+ auto innerForOp = dyn_cast<scf::ForOp>(body->front());
+ if (!innerForOp)
+ return failure();
+ // All of innerForOp results should be yielded.
+ return success(innerForOp->getNumResults() == yieldOp->getNumOperands());
+ };
+
+ while (outerLoop && succeeded(isForOpYieldResultOfInnerLoop(outerLoop))) {
+ nestLoops.push_back(outerLoop);
+ outerLoop = dyn_cast<scf::ForOp>(outerLoop->getParentOp());
+ }
+ // sorted from outer to inner
+ return {nestLoops.rbegin(), nestLoops.rend()};
+}
+
/// Fetch the untiled consumer of a scf.for's result which is yielded by a
/// tensor.insert_slice. This function makes the following assumptions :
/// 1. tensor.insert_slice has scf.yield as its only user.
@@ -1498,9 +1542,10 @@ getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
auto forOp = dyn_cast<scf::ForOp>(containingOp);
if (!forOp)
return failure();
- Value resultingValue = forOp->getResult(resultNumber);
+ scf::ForOp topLevelForOp = getPerfectlyNestedLoopsOutsideOf(forOp).front();
+ Value resultingValue = topLevelForOp->getResult(resultNumber);
- return getConsumerFromUses(resultingValue, containingOp->getBlock());
+ return getConsumerFromUses(resultingValue, topLevelForOp->getBlock());
}
/// Fetch the first untiled consumer of a scf.forall's result which is yielded
@@ -1563,59 +1608,6 @@ static FailureOr<OpOperand *> getUntiledConsumerFromSlice(Operation *sliceOp) {
}
}
-/// After fusing consumer into scf.for we want to modify the scf.yield operation
-/// to reflect the same by returning the values yielded by the tiled consumer.
-static void
-fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
- TilingResult &tilingResult,
- ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
- ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
- ArrayRef<BlockArgument> bbArgs) {
- scf::YieldOp oldTerminatorOp =
- cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
- unsigned totalOldResults = oldTerminatorOp->getNumResults();
- unsigned totalTiledResults = tilingResult.tiledOps[0]->getNumResults();
- SmallVector<Value> newYieldOperands;
- newYieldOperands.reserve(totalOldResults + totalTiledResults);
- for (auto oldResult : oldTerminatorOp.getResults()) {
- newYieldOperands.push_back(oldResult);
- }
- rewriter.setInsertionPointAfter(oldTerminatorOp);
- Location loc = newForOp.getLoc();
- for (auto [tiledResult, bbArg, resultOffset, resultSize] :
- llvm::zip_equal(tilingResult.tiledOps[0]->getResults(), bbArgs,
- resultOffsets, resultSizes)) {
- SmallVector<OpFoldResult> strides(resultOffset.size(),
- rewriter.getIndexAttr(1));
- Value newInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
- loc, tiledResult, bbArg, resultOffset, resultSize, strides);
- newYieldOperands.push_back(newInsertSliceOp);
- }
- rewriter.create<scf::YieldOp>(loc, newYieldOperands);
- rewriter.eraseOp(oldTerminatorOp);
-}
-
-/// After fusing consumer into scf.forall we want to yield each of the resulting
-/// values by the tiled consumer within scf.forall.in_parallel region.
-static void
-fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
- SmallVector<Value> tiledResults,
- ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
- ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
- ArrayRef<BlockArgument> bbArgs) {
- scf::InParallelOp newTerminatorOp = newForallOp.getTerminator();
- rewriter.setInsertionPointToStart(newTerminatorOp.getBody());
- Location firstYieldOpLoc =
- (*(newTerminatorOp.getYieldingOps().begin())).getLoc();
- for (auto [tiledResult, bbArg, resultOffset, resultSize] :
- llvm::zip_equal(tiledResults, bbArgs, resultOffsets, resultSizes)) {
- SmallVector<OpFoldResult> strides(resultOffset.size(),
- rewriter.getIndexAttr(1));
- rewriter.create<tensor::ParallelInsertSliceOp>(
- firstYieldOpLoc, tiledResult, bbArg, resultOffset, resultSize, strides);
- }
-}
-
/// Implementation of fusing consumer of a single slice by computing the
/// slice of the consumer in-place for scf loop.
FailureOr<scf::SCFFuseConsumerOfSliceResult>
@@ -1646,81 +1638,63 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
consumerOp, "consumer op's operand doesn't seem to be an OpResult");
}
- Operation *oldLoopOp = nullptr;
- SmallVector<Value> newOuts;
- Block *oldLoopBody = nullptr;
- unsigned initSize = 0;
- unsigned rank = 1;
+ // There are two possible cases regarding `oldLoopOp` here:
+ // 1. single `scf.forall` or `scf.for`.
+ // 2. inner-most `scf.for` insider nest `scf.loop` structure, where the
+ // top-level loop is the outer-most one of these nested loops.
+ LoopLikeOpInterface innerMostLoop =
+ candidateSliceOp->getParentOfType<LoopLikeOpInterface>();
+ SmallVector<LoopLikeOpInterface> nestedLoops;
if (isInsertSliceOp) {
- auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
- oldLoopOp = forOp;
- llvm::append_range(newOuts, forOp.getInits());
- oldLoopBody = forOp.getBody();
- initSize = forOp.getInits().size();
+ nestedLoops = llvm::map_to_vector(
+ getPerfectlyNestedLoopsOutsideOf(
+ cast<scf::ForOp>(innerMostLoop.getOperation())),
+ [](scf::ForOp forOp) {
+ return cast<LoopLikeOpInterface>(forOp.getOperation());
+ });
} else {
- auto forallOp = candidateSliceOp->getParentOfType<scf::ForallOp>();
- oldLoopOp = forallOp;
- llvm::append_range(newOuts, forallOp.getOutputs());
- oldLoopBody = forallOp.getBody();
- initSize = forallOp.getOutputs().size();
- rank = forallOp.getRank();
+ nestedLoops = {innerMostLoop};
}
- if (failed(checkAssumptionForLoop(oldLoopOp, consumerOp))) {
+ LoopLikeOpInterface outerMostLoop = nestedLoops.front();
+
+ if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp))) {
return rewriter.notifyMatchFailure(
- oldLoopOp, "containing loop op should either yield just one value or "
- "have the consumer op as its first user");
+ outerMostLoop,
+ "containing loop op should either yield just one value or "
+ "have the consumer op as its first user");
}
OpBuilder::InsertionGuard g(rewriter);
// 2. Check consumer is not using scf loop's output as init.
- auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
+ auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
+ if (!dstOp)
+ return rewriter.notifyMatchFailure(consumerOp,
+ "consumer op is not DPS operation");
SmallVector<Value> dpsInits =
llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
- if (llvm::is_contained(dpsInits, oldLoopOp->getResult(resultNumber))) {
+ if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
return rewriter.notifyMatchFailure(
consumerOp,
"consumer op taking the result of scf.for as init is not supported");
}
- newOuts.append(dpsInits);
-
- Location loc = oldLoopOp->getLoc();
+ SmallVector<Value> newInits = dpsInits;
- // 3. Create new scf loop op.
- rewriter.setInsertionPoint(consumerOp);
- Operation *newLoopOp = nullptr;
- Block *newLoopBody = nullptr;
- if (isInsertSliceOp) {
- auto forOp = cast<scf::ForOp>(oldLoopOp);
- auto newForOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
- forOp.getUpperBound(),
- forOp.getStep(), newOuts);
- newLoopOp = newForOp;
- newLoopBody = newForOp.getBody();
- } else {
- auto forallOp = cast<scf::ForallOp>(oldLoopOp);
- auto newForallOp = rewriter.create<scf::ForallOp>(
- loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
- forallOp.getMixedStep(), newOuts, forallOp.getMapping());
- newLoopOp = newForallOp;
- rewriter.eraseOp(newForallOp.getTerminator());
- newLoopBody = newForallOp.getBody();
- }
+ Location loc = outerMostLoop->getLoc();
- // 4. Move the loop body to the new op.
- unsigned oldNumArguments = oldLoopBody->getNumArguments();
- rewriter.mergeBlocks(oldLoopBody, newLoopBody,
- newLoopBody->getArguments().take_front(oldNumArguments));
+ // 3. Move the whole loop structure right before consumer Op, the dominance
+ // should be already ensured by `checkAssumptionForLoop`.
+ rewriter.moveOpBefore(outerMostLoop, consumerOp);
- // 5. Set insertion point before terminator op of the loop and create a new
+ // 4. Set insertion point before terminator op of the loop and create a new
// tensor.insert_slice. In the scf.for case this is a clone of the
// candidateSliceOp whereas in the scf.forall case this is created from the
// operands of tensor.parallel_insert_slice.
tensor::InsertSliceOp clonedInsertSliceOp;
if (auto sliceOp =
dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
- auto newForallOp = cast<scf::ForallOp>(newLoopOp);
+ auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
rewriter.setInsertionPoint(newForallOp.getTerminator());
clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
@@ -1731,20 +1705,17 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
}
- // 6.a. Clone consumer op.
- auto newForOpBlockArgsForConsumerDest =
- newLoopBody->getArguments().drop_front(oldNumArguments);
- auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
- rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+ // 5.a. Clone consumer op.
+ auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));
- // 6.b. Replace all uses of the loop result with the result of the cloned
+ // 5.b. Replace all uses of the loop result with the result of the cloned
// tensor.insert_slice.
OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
operandToReplace.set(clonedInsertSliceOp.getResult());
});
- // 7 - Perform tiling of the cloned consumer and replace the operand at
+ // 6. Perform tiling of the cloned consumer and replace the operand at
// `operandNumber` with the source of the cloned tensor.insert_slice op.
auto ossSliceOp =
cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
@@ -1754,79 +1725,108 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
if (failed(tileAndFuseResult)) {
return failure();
}
- rewriter.replaceAllUsesWith(
- tileAndFuseResult->tiledOps[0]->getOperand(operandNumber),
- clonedInsertSliceOp.getSource());
-
- // 8 - Extract offset/sizes/strides required to create the
- // tensor.insert_slice/parallel_insert_slice for each result of the consumer.
- SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
- SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
- SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
-
- // 9. Check all insert stride is 1.
- if (llvm::any_of(strides, [](OpFoldResult stride) {
- return !isConstantIntValue(stride, 1);
- })) {
- return rewriter.notifyMatchFailure(
- candidateSliceOp, "containingOp's result yield with stride");
- }
+ auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
+ rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber),
+ clonedInsertSliceOp.getSource());
- // 10. Try to get iter domain position from input position.
- SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
- if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
- rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
- iterDomainSizes))) {
- return rewriter.notifyMatchFailure(
- clonedConsumerOp, "can't get iter domain position from input position");
- }
+ // 7. Reconstruct [nested] loop with new inits.
+ YieldTiledValuesFn newYieldValuesFn =
+ [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
+ ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
+ SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
+ SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
+ OpBuilder::InsertionGuard g(innerRewriter);
+ // 8. Set inner insertPoint right before tiled consumer op.
+ innerRewriter.setInsertionPoint(tiledConsumerOp);
- // 11. Try to fetch the offset and size for all results of the cloned
- // consumer. This would then be used to form the corresponding
- // tensor.insert_slice/parallel_insert_slice later.
- unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();
- SmallVector<SmallVector<OpFoldResult>> resultOffsets(
- totalNumResultsOfConsumer);
- SmallVector<SmallVector<OpFoldResult>> resultSizes(totalNumResultsOfConsumer);
- for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
- if (failed(clonedConsumerOp.getResultTilePosition(
- rewriter, idx, iterDomainOffsets, iterDomainSizes,
- resultOffsets[idx], resultSizes[idx]))) {
+ SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
+ SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
+
+ // 9. Check all insert stride is 1.
+ if (llvm::any_of(strides, [](OpFoldResult stride) {
+ return !isConstantIntValue(stride, 1);
+ })) {
return rewriter.notifyMatchFailure(
- clonedConsumerOp,
- "can't get result domain position from iter domain position");
+ candidateSliceOp, "containingOp's result yield with stride");
}
- }
- auto arrayRefOffsets = ArrayRef<SmallVector<OpFoldResult>>(resultOffsets);
- auto arrayRefSizes = ArrayRef<SmallVector<OpFoldResult>>(resultSizes);
- if (isInsertSliceOp) {
- auto newForOp = cast<scf::ForOp>(newLoopOp);
- fixTerminatorSCFYield(
- rewriter, newForOp, *tileAndFuseResult, arrayRefOffsets, arrayRefSizes,
- newForOp.getBody()->getArguments().drop_front(1 + initSize));
- } else {
- auto newForallOp = cast<scf::ForallOp>(newLoopOp);
- fixTerminatorSCFInParallel(
- rewriter, newForallOp, tileAndFuseResult->tiledOps[0]->getResults(),
- arrayRefOffsets, arrayRefSizes,
- newForallOp.getBody()->getArguments().drop_front(rank + initSize));
- }
+ // 10. Try to get iter domain position from input position.
+ SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+ if (failed(tiledConsumerOp.getIterationDomainTileFromOperandTile(
+ rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
+ iterDomainSizes))) {
+ return rewriter.notifyMatchFailure(
+ tiledConsumerOp,
+ "can't get iter domain position from input position");
+ }
- // 12. Replace the result of scf loop and consumer op with new loop's results.
- for (auto &&[oldResult, newResult] :
- llvm::zip_first(oldLoopOp->getResults(), newLoopOp->getResults())) {
- rewriter.replaceAllUsesWith(oldResult, newResult);
+ // 11. Try to fetch the offset and size for all results of the cloned
+ // consumer. This would then be used to form the corresponding
+ // tensor.insert_slice/parallel_insert_slice later.
+ unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults();
+ SmallVector<SmallVector<OpFoldResult>> resultOffsets(
+ totalNumResultsOfConsumer);
+ SmallVector<SmallVector<OpFoldResult>> resultSizes(
+ totalNumResultsOfConsumer);
+ for (auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) {
+ if (failed(tiledConsumerOp.getResultTilePosition(
+ rewriter, idx, iterDomainOffsets, iterDomainSizes,
+ resultOffsets[idx], resultSizes[idx]))) {
+ return rewriter.notifyMatchFailure(
+ tiledConsumerOp,
+ "can't get result domain position from iter domain position");
+ }
+ }
+
+ // 12. Create `extract_slice` for `iter_args` for DPS operation if
+ // necessary.
+ if (auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>(
+ tiledConsumerOp.getOperation())) {
+ rewriter.setInsertionPoint(tiledDestStyleOp);
+ for (const auto &&[index, newRegionArg] :
+ llvm::enumerate(newRegionIterArgs)) {
+ auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
+ loc, newRegionArg, resultOffsets[index], resultSizes[index],
+ SmallVector<OpFoldResult>(resultOffsets[index].size(),
+ rewriter.getIndexAttr(1)));
+ // Make C++ 17 happy, otherwise it will throw error `captured structured
+ // bindings are a C++20 extension`.
+ auto dstNumber = index;
+ rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
+ tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice);
+ });
+ }
+ }
+
+ // 13. Prepare tiled offset and sizes for later `insert_slice` creation by
+ // caller.
+ Block *block = rewriter.getInsertionPoint()->getBlock();
+ rewriter.setInsertionPoint(block->getTerminator());
+ for (const auto &&[index, result] :
+ llvm::enumerate(tiledConsumerOp->getResults())) {
+ tiledResult.push_back(result);
+ tiledOffset.emplace_back(resultOffsets[index]);
+ tiledSizes.emplace_back(resultSizes[index]);
+ }
+ return success();
+ };
+ // 14. Add new inits to [nested] loops.
+ if (failed(addInitOperandsToLoopNest(rewriter, nestedLoops, newInits,
+ newYieldValuesFn))) {
+ return rewriter.notifyMatchFailure(tiledConsumerOp,
+ "unable to add new inits to nest loop");
}
- for (auto &&[oldResult, newResult] :
- llvm::zip(consumerOp->getResults(),
- newLoopOp->getResults().drop_front(initSize))) {
+ // 15. Replace the result of scf loop and consumer op with new loop's results.
+
+ for (auto &&[oldResult, newResult] : llvm::zip(
+ consumerOp->getResults(),
+ nestedLoops.front()->getResults().take_back(newInits.size()))) {
rewriter.replaceAllUsesWith(oldResult, newResult);
}
- // 13. Need to erase the old scf loop and the cloned consumer op.
- rewriter.eraseOp(oldLoopOp);
+ // 16. Need to erase the old scf loop and the cloned consumer op.
rewriter.eraseOp(clonedConsumerOp);
return scf::SCF...
[truncated]
|
// Make C++ 17 happy, otherwise it will throw error `captured structured | ||
// bindings are a C++20 extension`. | ||
auto dstNumber = index; | ||
rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Compared with the previous PR #94190, the only difference is these three lines, which were added to make C++17 happy.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM with a comment update. Thanks!
115e58d
to
bb8b636
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please make sure all files have a newline at the end of the file.
Thanks for the kind reminder. |
This is a mirror PR of #94190 with a tiny build fix.
Sorry for the inconvenience.