[mlir][scf] Extend consumer fuse to single nested scf.for
#94190
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write If you have received no comments on your PR for a week, you can request a review If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-scf
Author: None (Yun-Fly)
Changes
Hi, based on the earlier discussion in this thread, this patch aims to extend the new consumer fusion feature to more complex nested loop structures. E.g.
What's New in this PR:
NOTE that: this PR DOES NOT deal with the refactor of The resulting IR will finally appear like below:
Looking forward to your suggestion and review, thanks. Patch is 46.88 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/94190.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index f3d6b7a530117..9dd730e64a030 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -1103,98 +1104,6 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
// tileAndFuseConsumerUsingSCF implementation.
//===----------------------------------------------------------------------===//
-/// A utility function that checks whether the only use of the result of a
-/// tensor.insert_slice op is in a scf.yield op.
-static LogicalResult
-checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
- Value result = candidateSliceOp.getResult();
- Value::use_range uses = result.getUses();
- if (!llvm::hasSingleElement(uses)) {
- LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
- return failure();
- }
- OpOperand &operandUse = (*uses.begin());
- Operation *userOp = operandUse.getOwner();
- if (!isa<scf::YieldOp>(userOp)) {
- LLVM_DEBUG(llvm::dbgs()
- << "Expected scf.yield to be the only user, but got -> "
- << (*userOp));
- return failure();
- }
- if (result.getDefiningOp()->getBlock() != userOp->getBlock()) {
- LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
- "be in the same block\n");
- return failure();
- }
- return success();
-}
-
-/// Fetches the OpOperand of the only user (and use) of the value `val` which
-/// implements `TilingInterface` and `DestinationStyleOpInterface`. Returns
-/// failure otherwise.
-static FailureOr<OpOperand *> getConsumerFromUses(Value val,
- Block *containingOpBlock) {
- // Step 1. Check that the value has exactly one use.
- if (!llvm::hasSingleElement(val.getUses()))
- return failure();
- // Step 2. Get uses.
- OpOperand &operand = (*val.getUses().begin());
- Operation *consumerOp = operand.getOwner();
- // TODO: We have to init result of consumer before scf.for, use
- // DestinationStyleOpInterface to get result shape from init for now.
- // Add support for other op such as op has InferTypeOpInterface.
- if (!isa<TilingInterface>(consumerOp) ||
- !isa<DestinationStyleOpInterface>(consumerOp))
- return failure();
- if (containingOpBlock != consumerOp->getBlock())
- return failure();
- return &operand;
-}
-
-/// Fetch the untiled consumer of a scf.for's result which is yielded by a
-/// tensor.insert_slice. This function makes the following assumptions :
-/// 1. tensor.insert_slice has scf.yield as its only user.
-/// 2. scf.for's corresponding result has only one use.
-static FailureOr<OpOperand *>
-getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
- if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
- return failure();
- Value sliceResult = candidateSliceOp.getResult();
- // Step 1. Fetch the corresponding output.
- OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
- unsigned resultNumber = yieldOpOperand.getOperandNumber();
- // Step 2. Check containing op is scf.for.
- Operation *containingOp = candidateSliceOp->getParentOp();
- auto forOp = dyn_cast<scf::ForOp>(containingOp);
- if (!forOp)
- return failure();
- Value resultingValue = forOp->getResult(resultNumber);
-
- return getConsumerFromUses(resultingValue, containingOp->getBlock());
-}
-
-/// Fetch the first untiled consumer of a scf.forall's result which is yielded
-/// by a tensor.parallel_insert_slice.
-static FailureOr<OpOperand *>
-getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
- // Step 1. Fetch the corresponding output
- Value sliceDest = candidateSliceOp.getDest();
- auto iterArg = dyn_cast<BlockArgument>(sliceDest);
- if (!iterArg)
- return failure();
- Operation *containingOp = iterArg.getOwner()->getParentOp();
- if (containingOp != candidateSliceOp->getParentOp()->getParentOp())
- return failure();
- // Step 2. Check that the containing op is scf.forall.
- auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
- if (!forallOp)
- return failure();
- Value resultingValue =
- forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
-
- return getConsumerFromUses(resultingValue, containingOp->getBlock());
-}
-
/// This utility currently checks whether the loop either :-
/// 1. Yields exactly one result.
/// 2. Has consumer op as its first user and other users to be in the same
@@ -1220,31 +1129,116 @@ static LogicalResult checkAssumptionForLoop(Operation *loopOp,
return success();
}
-/// A utility to fetch an untiled consumer of
-/// tensor.insert_slice/tensor.parallel_insert_slice.
-static FailureOr<OpOperand *> getUntiledConsumerFromSlice(Operation *sliceOp) {
- if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
- return getUntiledConsumerFromSlice(insertSlice);
- } else if (auto parallelInsertSlice =
- dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
- return getUntiledConsumerFromSlice(parallelInsertSlice);
- } else {
+// Traverse and collect all outer loops of given sliceOp, sorted by
+// outer-to-inner. If `untilLoop` found, stop walk through in advance.
+static SmallVector<LoopLikeOpInterface> getOuterLoopsOfSliceOp(
+ OffsetSizeAndStrideOpInterface sliceOp,
+ std::optional<LoopLikeOpInterface> untilLoop = std::nullopt) {
+ SmallVector<LoopLikeOpInterface> outerLoops;
+ auto forOp = sliceOp->getParentOfType<LoopLikeOpInterface>();
+ while (forOp) {
+ outerLoops.push_back(forOp);
+ if (untilLoop.has_value() && *untilLoop == forOp)
+ break;
+ forOp = forOp->getParentOfType<LoopLikeOpInterface>();
+ }
+ return {outerLoops.rbegin(), outerLoops.rend()};
+}
+
+// Get the Result of top-level Loop which yield the target InsertSliceOp. E.g
+// ```
+// %1 = scf.for
+// %2 = scf.for
+// %3 = scf.for
+// ...
+// %4 = insert
+// yield %4
+// %5 = insert %3
+// yield %5
+// yield %2
+// ```
+// @param targetSliceOp: %4 = insert
+// @return Result Value: %1
+// Collected insertSliceOp List during walk including targetSliceOp:
+// %4 = insert and %5 = insert %3
+static FailureOr<std::pair<Value, SmallVector<OffsetSizeAndStrideOpInterface>>>
+getResultOfTopLevelLoopYieldInsertSliceOp(
+ OffsetSizeAndStrideOpInterface targetSliceOp, int curDepth = 0,
+ int maxDepth = 5) {
+ // control recursive time in avoid of stack overflow
+ if (curDepth > maxDepth)
+ return failure();
+
+ SmallVector<OffsetSizeAndStrideOpInterface> candidateSliceOpList;
+ candidateSliceOpList.push_back(targetSliceOp);
+ Value resultOfLoop;
+ if (auto sliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(
+ targetSliceOp.getOperation())) {
+ Value destValue = sliceOp.getDest();
+ auto iterArg = cast<BlockArgument>(destValue);
+ auto forallOp = dyn_cast<scf::ForallOp>(iterArg.getOwner()->getParentOp());
+ if (!forallOp)
+ return failure();
+ resultOfLoop = forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
+ } else if (auto sliceOp = dyn_cast<tensor::InsertSliceOp>(
+ targetSliceOp.getOperation())) {
+ Value resultValue = sliceOp.getResult();
+ for (auto &useOperand : resultValue.getUses()) {
+ if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) {
+ if (llvm::detail::isPresent(resultOfLoop))
+ return failure();
+ auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp());
+ if (!forOp)
+ return failure();
+ resultOfLoop = forOp->getResult(useOperand.getOperandNumber());
+ }
+ }
+ }
+
+ if (!llvm::detail::isPresent(resultOfLoop))
return failure();
+
+ while (true) {
+ bool walkThroughOuterLoop = false;
+ for (auto &useOperand : resultOfLoop.getUses()) {
+ if (auto sliceOp =
+ dyn_cast<OffsetSizeAndStrideOpInterface>(useOperand.getOwner())) {
+ auto resultAndSliceOpsPair =
+ getResultOfTopLevelLoopYieldInsertSliceOp(sliceOp, curDepth + 1);
+ if (failed(resultAndSliceOpsPair))
+ return failure();
+ candidateSliceOpList.append((*resultAndSliceOpsPair).second.begin(),
+ (*resultAndSliceOpsPair).second.end());
+ return std::make_pair((*resultAndSliceOpsPair).first,
+ candidateSliceOpList);
+ } else if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) {
+ // walk through outer loop
+ auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp());
+ if (!forOp)
+ return failure();
+ resultOfLoop = forOp->getResult(useOperand.getOperandNumber());
+ walkThroughOuterLoop = true;
+ break;
+ }
+ }
+ if (!walkThroughOuterLoop)
+ break;
}
+ return std::make_pair(resultOfLoop, candidateSliceOpList);
}
/// After fusing consumer into scf.for we want to modify the scf.yield operation
/// to reflect the same by returning the values yielded by the tiled consumer.
static void
fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
- TilingResult &tilingResult,
- ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
- ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
+ ResultRange tilingResult,
+ SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+ SmallVector<SmallVector<OpFoldResult>> &resultSizes,
ArrayRef<BlockArgument> bbArgs) {
scf::YieldOp oldTerminatorOp =
cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
unsigned totalOldResults = oldTerminatorOp->getNumResults();
- unsigned totalTiledResults = tilingResult.tiledOps[0]->getNumResults();
+ unsigned totalTiledResults = tilingResult.size();
SmallVector<Value> newYieldOperands;
newYieldOperands.reserve(totalOldResults + totalTiledResults);
for (auto oldResult : oldTerminatorOp.getResults()) {
@@ -1253,8 +1247,7 @@ fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
rewriter.setInsertionPointAfter(oldTerminatorOp);
Location loc = newForOp.getLoc();
for (auto [tiledResult, bbArg, resultOffset, resultSize] :
- llvm::zip_equal(tilingResult.tiledOps[0]->getResults(), bbArgs,
- resultOffsets, resultSizes)) {
+ llvm::zip_equal(tilingResult, bbArgs, resultOffsets, resultSizes)) {
SmallVector<OpFoldResult> strides(resultOffset.size(),
rewriter.getIndexAttr(1));
Value newInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
@@ -1267,18 +1260,17 @@ fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
/// After fusing consumer into scf.forall we want to yield each of the resulting
/// values by the tiled consumer within scf.forall.in_parallel region.
-static void
-fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
- SmallVector<Value> tiledResults,
- ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
- ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
- ArrayRef<BlockArgument> bbArgs) {
+static void fixTerminatorSCFInParallel(
+ RewriterBase &rewriter, scf::ForallOp newForallOp, ResultRange tilingResult,
+ SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+ SmallVector<SmallVector<OpFoldResult>> &resultSizes,
+ ArrayRef<BlockArgument> bbArgs) {
scf::InParallelOp newTerminatorOp = newForallOp.getTerminator();
rewriter.setInsertionPointToStart(newTerminatorOp.getBody());
Location firstYieldOpLoc =
(*(newTerminatorOp.getYieldingOps().begin())).getLoc();
for (auto [tiledResult, bbArg, resultOffset, resultSize] :
- llvm::zip_equal(tiledResults, bbArgs, resultOffsets, resultSizes)) {
+ llvm::zip_equal(tilingResult, bbArgs, resultOffsets, resultSizes)) {
SmallVector<OpFoldResult> strides(resultOffset.size(),
rewriter.getIndexAttr(1));
rewriter.create<tensor::ParallelInsertSliceOp>(
@@ -1286,6 +1278,180 @@ fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
}
}
+// If the top level loop of nested loop structure is scf.forall, need to create
+// additional tensor.extract_slice for its new appended `shared_outs` in order
+// to pass correct local memory for inner loops. E.g.
+//
+// scf.forall shared_outs(%o1=..., %o2=...) {
+// %local_o1 = extract_slice %o1
+// // fix new appended `shared_out` %o2
+// %local_o2 = extract_slice %o2
+// scf.for init_args(%init1=%local_o1, %init2=%local_o2) {
+// ...
+// }
+// ...
+// }
+static void
+fixSharedOutSCFForall(RewriterBase &rewriter, scf::ForallOp outerLoop,
+ LoopLikeOpInterface innerLoop,
+ SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+ SmallVector<SmallVector<OpFoldResult>> &resultSizes,
+ unsigned newInitSize,
+ SmallVector<tensor::ExtractSliceOp> &newExtractOps) {
+ rewriter.setInsertionPoint(innerLoop);
+ Location Loc = outerLoop.getLoc();
+ MutableArrayRef<BlockArgument> bbArgs = outerLoop.getBody()->getArguments();
+
+ SmallVector<tensor::ExtractSliceOp> newOps;
+ newOps.reserve(resultOffsets.size());
+ for (auto [bbArg, offset, sizes] : llvm::zip_equal(
+ bbArgs.take_back(newInitSize), resultOffsets, resultSizes)) {
+ SmallVector<OpFoldResult> strides(offset.size(), rewriter.getIndexAttr(1));
+ auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
+ Loc, bbArg, offset, sizes, strides);
+ newOps.push_back(newExtractOp);
+ }
+ newExtractOps = newOps;
+}
+
+// If outerMost loop of nested loop structure is `scf.forall`, need to deal with
+// DpsInit of tiled consumer
+static void fixDpsInitsOfTiledConsumer(
+ RewriterBase &rewriter, Operation *tiledConsumer,
+ ArrayRef<BlockArgument> bbArgs,
+ SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+ SmallVector<SmallVector<OpFoldResult>> &resultSizes) {
+ rewriter.setInsertionPoint(tiledConsumer);
+ Location Loc = tiledConsumer->getLoc();
+ for (auto &&[bbArg, offset, sizes, dpsInit] :
+ llvm::zip_equal(bbArgs, resultOffsets, resultSizes,
+ cast<DestinationStyleOpInterface>(tiledConsumer)
+ .getDpsInitsMutable())) {
+ SmallVector<OpFoldResult> strides(offset.size(), rewriter.getIndexAttr(1));
+ auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
+ Loc, bbArg, offset, sizes, strides);
+ dpsInit.set(newExtractOp.getResult());
+ }
+}
+
+// compute all results tile by given SliceOp along operand
+static LogicalResult computeAllResultTileForOpGivenOperandSliceOp(
+ RewriterBase &rewriter, TilingInterface tilableOp, unsigned operandNumber,
+ OffsetSizeAndStrideOpInterface ossSliceOp,
+ SmallVector<SmallVector<OpFoldResult>> &allResultOffsets,
+ SmallVector<SmallVector<OpFoldResult>> &allResultSizes) {
+ // 1. check all stride all 1
+ if (llvm::any_of(ossSliceOp.getMixedStrides(), [](OpFoldResult stride) {
+ return !isConstantIntValue(stride, 1);
+ })) {
+ return rewriter.notifyMatchFailure(ossSliceOp, "ossSliceOp has stride");
+ }
+ // 2. compute iteration domain Tile from input position
+ SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+ if (failed(tilableOp.getIterationDomainTileFromOperandTile(
+ rewriter, operandNumber, ossSliceOp.getMixedOffsets(),
+ ossSliceOp.getMixedSizes(), iterDomainOffsets, iterDomainSizes))) {
+ return rewriter.notifyMatchFailure(
+ tilableOp, "can't get iter domain position from input position");
+ }
+ unsigned totalNumResultsOfConsumer = tilableOp->getNumResults();
+ SmallVector<SmallVector<OpFoldResult>> resultOffsets(
+ totalNumResultsOfConsumer);
+ SmallVector<SmallVector<OpFoldResult>> resultSizes(totalNumResultsOfConsumer);
+ // 3. compute result Tile by resultNumber
+ for (auto [idx, v] : llvm::enumerate(tilableOp->getResults())) {
+ if (failed(tilableOp.getResultTilePosition(
+ rewriter, idx, iterDomainOffsets, iterDomainSizes,
+ resultOffsets[idx], resultSizes[idx]))) {
+ return rewriter.notifyMatchFailure(
+ tilableOp,
+ "can't get result domain position from iter domain position");
+ }
+ }
+ allResultOffsets = resultOffsets;
+ allResultSizes = resultSizes;
+ return success();
+}
+
+// Considering multi-level tensor.*SliceOp maybe based on different
+// coordination, this utility computes the real OFFSET coordinated on ROOT
+// SliceOp. E.g
+// %0 = insert_slice %1 into %2[OFFSET1] [SIZE1]
+// %3 = insert_slice %4 into %5[OFFSET2] [SIZE2]
+//
+// where the coordination can be illustrated as follow:
+//
+// %3 ----------------------------------
+// | | |
+// | OFFSET2 | OFFSET1 |
+// | ------ %0 |
+// | |
+// | |
+// |------------------ %1 ------ |
+// | | SIZE1 |
+// | | |
+// | | |
+// | | ------- |
+// |
+//
+// The real OFFSET of %1 coordinated on %3 is actually `OFFSET1` + `OFFSET2`
+static FailureOr<SmallVector<OpFoldResult>>
+computeRealOffsetsCoordinatedRootSliceOp(
+ RewriterBase &rewriter, Location loc,
+ OffsetSizeAndStrideOpInterface candidateSliceOp,
+ MutableArrayRef<OffsetSizeAndStrideOpInterface> candidateSliceOpList) {
+ if (llvm::any_of(candidateSliceOp.getMixedStrides(), [](OpFoldResult stride) {
+ return !isConstantIntValue(stride, 1);
+ })) {
+ return rewriter.notifyMatchFailure(candidateSliceOp,
+ "candidateSliceOp has stride");
+ }
+ SmallVector<OpFoldResult> realOffsets = candidateSliceOp.getMixedOffsets();
+ // real offsets equals to accumulative offsets of outer candidates
+ for (auto iter = candidateSliceOpList.rbegin(); *iter != candidateSliceOp;
+ iter++) {
+ // assert each outer candidate slice has no stride
+ if (llvm::any_of(iter->getMixedStrides(), [](OpFoldResult stride) {
+ return !isConstantIntValue(stride, 1);
+ })) {
+ return failure();
+ }
+ for (auto &&[ofr1, ofr2] :
+ llvm::zip_equal(realOffsets, iter->getMixedOffsets())) {
+ using AVE = affine::AffineValueExpr;
+ affine::AffineBuilder ab(rewriter, loc);
+ AffineExpr dim0, dim1, sym;
+ bindDims(rewriter.getContext(), dim0, dim1);
+ bindSymbols(rewriter.getContext(), sym);
+ auto aveOffset1 = AVE(dim0).bind(ofr1), aveOffset2 = AVE(dim1).bind(ofr2);
+ ofr1 = ab.add(aveOffset1, aveOffset2);
+ }
+ }
+ return realOffsets;
+}
+
+// Get the first tilable user of given Value and check its domination at the
+// same time
+static FailureOr<OpOperand *>
+getTilableConsumerOperandFirstUseVal(Value val, Operation *loopOp) {
+ for (auto &useOfval : val.getUses()) {
+ Operation *consumerOp = useOfval.getOwner();
+ // 1. Check whether consumerOp is tilable
+ if (!isa<TilingInterface>(consumerOp) ||
+ !isa<DestinationStyleOpInterface>(consumerOp))
+ continue;
+ // 2. check stay in same block with loopOp
+ if (loopOp->getBlock() != consumerOp->getBlock())
+ continue;
+ // 3. check no other user before it
+ if (failed(checkAssumptionForLoop(loopOp, consumerOp))) {
+ continue;
+ }
+ return &useOfval;
+ }
+ return failure();
+}
+
/// Implementation of fusing consumer of a single slice by computing the
/// slice of the consumer in-place for scf loop.
FailureOr<scf::SCFFuseConsumerOfSliceResult>
@@ -1297,10 +1463,29 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
bool is...
[truncated]
|
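The offset accumulation described in the comment on `computeRealOffsetsCoordinatedRootSliceOp` can be illustrated outside MLIR. The following is a minimal sketch in plain C++, not the patch's implementation: `SliceLevel`, `hasUnitStrides` and `offsetsRelativeToRoot` are hypothetical names, offsets are assumed to be static integers rather than `OpFoldResult`s, and the slice list is assumed to be ordered inner-to-outer as in the patch's `candidateSliceOpList`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for one slice op: per-dimension offsets and strides.
struct SliceLevel {
  std::vector<int64_t> offsets;
  std::vector<int64_t> strides;
};

// True when every stride equals 1, the only case this sketch accepts.
static bool hasUnitStrides(const SliceLevel &level) {
  for (int64_t s : level.strides)
    if (s != 1)
      return false;
  return true;
}

// `levels` is ordered inner-to-outer. Returns the offsets of
// `levels[candidate]` expressed in the coordinates of the outermost level,
// or std::nullopt if a stride is not 1 or the offset ranks disagree.
std::optional<std::vector<int64_t>>
offsetsRelativeToRoot(const std::vector<SliceLevel> &levels,
                      std::size_t candidate) {
  if (candidate >= levels.size())
    return std::nullopt;
  std::vector<int64_t> real(levels[candidate].offsets.size(), 0);
  // Sum the candidate's own offsets and those of every enclosing level.
  for (std::size_t i = candidate; i < levels.size(); ++i) {
    const SliceLevel &level = levels[i];
    if (!hasUnitStrides(level) || level.offsets.size() != real.size())
      return std::nullopt;
    for (std::size_t d = 0; d < real.size(); ++d)
      real[d] += level.offsets[d];
  }
  return real;
}
```

With an inner slice at offsets (2, 3) and an outer slice at offsets (10, 20), the inner slice sits at (12, 23) in the root tensor, which is the `OFFSET1` + `OFFSET2` relation from the diagram.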
Hi @Yun-Fly - thanks for starting on this!
A few starter nit comments from my end.
I need to look deeper. This is changing things much more than I would expect.
1ab45c1
to
74e3119
Compare
74e3119
to
9c04ad4
Compare
CI issue has been solved, ready for review :) |
Seems to me like this should be done by multiple application of existing transformations and not by creating a new custom transformation that unrolls both in C++
I agree with Nicolas' comment here. This is the way tile and fuse is supposed to work. You start with
First tile the consumer
then you fuse
So you are fusing the operation into an "immediately created" tiled loop nest. The more general case you are looking for can be done through repeated application. |
Hi, @nicolasvasilache @MaheshRavishankar , try to reply both in one thread.
Could you detail more about how to apply multiple existing transformations by an example?
Again, this patch is the extension of already merged PR involving producer-to-consumer fusion as well. CC: @ZhennanQin. |
Lets start with your example above. I think your input is
You can first tile the
You can fuse the
Now you can apply the same two steps again for the second level of tiling and use scf.for instead. Doesn't that give you what you are looking for? |
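The "repeated application" idea can be sketched without MLIR. The example below is only an analogy under stated assumptions: `producer` and `consumer` are hypothetical elementwise functions, the data is one-dimensional, and each entry of `tileSizes` stands for one application of "tile, then fuse" at one loop level. It is not the code of `tileConsumerAndFuseProducersUsingSCF`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical elementwise ops standing in for a tilable producer/consumer.
static int producer(int x) { return 2 * x; }
static int consumer(int x) { return x + 1; }

// Reference: materialize the whole intermediate, then run the consumer.
std::vector<int> runUnfused(const std::vector<int> &in) {
  std::vector<int> mid(in.size()), out(in.size());
  for (std::size_t i = 0; i < in.size(); ++i)
    mid[i] = producer(in[i]);
  for (std::size_t i = 0; i < in.size(); ++i)
    out[i] = consumer(mid[i]);
  return out;
}

// Each recursion level tiles [begin, end) with tileSizes[level]; once all
// levels are applied, producer and consumer run together on the tile.
static void tileAndFuse(const std::vector<int> &in, std::vector<int> &out,
                        std::size_t begin, std::size_t end,
                        const std::vector<std::size_t> &tileSizes,
                        std::size_t level) {
  if (level == tileSizes.size()) {
    for (std::size_t i = begin; i < end; ++i)
      out[i] = consumer(producer(in[i]));
    return;
  }
  // Treat a tile size of 0 as 1 so the loop always makes progress.
  const std::size_t step = std::max<std::size_t>(tileSizes[level], 1);
  for (std::size_t lo = begin; lo < end; lo += step)
    tileAndFuse(in, out, lo, std::min(lo + step, end), tileSizes, level + 1);
}

std::vector<int> runTiledAndFused(const std::vector<int> &in,
                                  const std::vector<std::size_t> &tileSizes) {
  std::vector<int> out(in.size());
  tileAndFuse(in, out, 0, in.size(), tileSizes, 0);
  return out;
}
```

Calling `runTiledAndFused(in, {4, 2})` corresponds to two rounds of tiling (outer tiles of 4, inner tiles of 2) and yields the same values as `runUnfused(in)`.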
@MaheshRavishankar Thanks for your explanation! I can get both your points now.
I agree with you if input starts from this way, which couples tiling and fusion step by step, recursively call
With current implementation, although it is possible to fuse
Compared with fusing consumer from outer to inner step by step with multiple application, the overall logic of this patch can be simplified into three steps for your review:
As you may see, only |
9c04ad4
to
ec9640c
Compare
Hi, @MaheshRavishankar @nicolasvasilache. I have refactored the overall implementation as you advised, using multiple applications of the existing transform. To solve the problem I mentioned in the thread above, some of the previous code has to stay, say
In this PR, the original This version may be much friendlier to you, with fewer changes. This PR is quite important for further development. Please help continue the review. Thanks! |
Hi, @MaheshRavishankar @Abhishek-Varma @nicolasvasilache. Sincerely looking forward to your new comments! |
ec9640c
to
2ffce48
Compare
74437c4
to
e11eae4
Compare
Thanks for the changes. I really request that we decouple the "start with any loop nest and tile and fuse into it" part of the changes here and go more incrementally. The end goal seems to be mixed in with all the changes, which adds complexity from the get-go. Could we start with just adding support for consumer fusion with a single nested scf.for before generalizing it? That itself has enough complexity.
return failure();
// Step 1. Check that the value has exactly one use excluding `insertSliceOp`
// or `ParallelInsertSliceOp`.
OpOperand *operand = nullptr;
Ah that is annoying... this is just dead code. Maybe we should figure out how to remove those `extract_slice` and `insert_slice`.
return failure();
// Step 1. Check that the value has exactly one use excluding `insertSliceOp`
// or `ParallelInsertSliceOp`.
OpOperand *operand = nullptr;
Would this be easier if we just erase the `candidateSliceOp` on the first fusion (which probably has some uses which are `extract_slices`) so that we don't have to make this more complicated?
I see. I will try to further decouple the current changes to support single nested |
0884a18
to
0b9355b
Compare
Hi, @MaheshRavishankar, I have further cleaned up the irrelevant code to merely support single nested IMO, this patch falls back to focusing on how we reconstruct nested |
scf.for
Ok, I just have one last minor comment after which this can land.
A follow-up to this would be to change the consumer fusion code to use the `addInitOperandsToLoopNest` method, because that already accounts for a lot of the complexity here of adding new inits to the tile loop nest.
0b9355b
to
9ff4abc
Compare
9ff4abc
to
38956c2
Compare
Nice! Thanks for cleaning this up! I left one more comment on the use of rewriter. Please address before landing, but this looks good to me.
@Yun-Fly Congratulations on having your first Pull Request (PR) merged into the LLVM Project! Your changes will be combined with recent changes from other authors, then tested by our build bots. If there is a problem with a build, you may receive a report in an email or a comment on this PR. Please check whether problems have been caused by your change specifically, as the builds can include changes from many authors. It is not uncommon for your change to be included in a build that fails due to someone else's changes, or infrastructure issues. How to do this, and the rest of the post-merge process, is covered in detail here. If your change does cause a problem, it may be reverted, or you can revert it yourself. This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again. If you don't get any reports, no action is required from you. Your changes are working as expected, well done! |
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/138/builds/3524 Here is the relevant piece of the build log for the reference
|
…94190)" This reverts commit 2d4bdfb. A build breakage is reported at: https://lab.llvm.org/buildbot/#/builders/138/builds/3524
I've reverted your patch because of a build failure reported at: https://lab.llvm.org/buildbot/#/builders/138/builds/3524 Meanwhile, in my environment with clang-16.0.6 as the host compiler, I see:
|
Sorry about that. I have opened a new mirror PR (#108318) with the build fix. |
Refactor current consumer fusion based on `addInitOperandsToLoopNest` to support single nested `scf.for`, e.g.
```
%0 = scf.for() {
  %1 = scf.for() {
    tiledProducer
  }
  yield %1
}
%2 = consumer ins(%0)
```
Compared with #94190, this PR fixes the build failure by making C++17 happy.
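To make the shape of the fused result more concrete, here is a rough plain C++ analogy of a two-level loop nest that carries one extra value for the fused consumer. It assumes elementwise computations and uses hypothetical names (`Carried`, `innerLoop`, `outerLoop`); it does not call or reproduce `addInitOperandsToLoopNest`, and only mirrors the idea that every loop level must accept the new init and yield it back.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Values carried through the loop nest, analogous to loop iter_args.
struct Carried {
  std::vector<int> produced; // the original loop-carried tensor
  std::vector<int> consumed; // extra init appended for the fused consumer
};

// Inner loop: receives the carried state and returns it (the "yield").
static Carried innerLoop(Carried state, std::size_t lo, std::size_t hi) {
  for (std::size_t j = lo; j < hi; ++j) {
    state.produced[j] = 2 * static_cast<int>(j); // stand-in tiled producer
    state.consumed[j] = state.produced[j] + 1;   // stand-in tiled consumer
  }
  return state;
}

// Outer loop: forwards both carried values to the inner loop and yields
// whatever the inner loop returned, so the new value reaches the top level.
Carried outerLoop(Carried state, std::size_t tile) {
  const std::size_t n = state.produced.size();
  assert(state.consumed.size() == n && "both carried values must match");
  // Treat a tile size of 0 as 1 so the loop always makes progress.
  const std::size_t step = std::max<std::size_t>(tile, 1);
  for (std::size_t lo = 0; lo < n; lo += step)
    state = innerLoop(std::move(state), lo, std::min(lo + step, n));
  return state;
}
```

After the nest finishes, `consumed` plays the role of the loop result that replaces the untiled consumer's output, while `produced` remains available as before.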
Hi, based on the earlier discussion in this thread, this patch aims to extend the new consumer fusion feature to more complex nested loop structures. E.g.
What's New in this PR:
`scf.for` and `scf.forall`.
`insert_slice` or `parallel_insert_slice`.
NOTE that: this PR DOES NOT deal with the refactor of `getTiledImplementation` we have talked about before, but just focuses on the functionality enhancement. BTW, in the above example, you can also find a similar issue related to the mismatched semantics between the tiled operand and the assumption of the current `getTiledImplementation`, even on `dpsInits`. To unblock this necessary patch, I temporarily follow the method @MaheshRavishankar suggested, using a dummy `insert_slice` to bridge the gap.
The resulting IR will finally appear like below:
Looking forward to your suggestion and review, thanks.