[MLIR][SCF] Add canonicalization pattern to fold away iter args of scf.forall #90189
@@ -1415,6 +1415,19 @@ InParallelOp ForallOp::getTerminator() {
  return cast<InParallelOp>(getBody()->getTerminator());
}

SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
  SmallVector<Operation *> storeOps;
  InParallelOp inParallelOp = getTerminator();
  for (Operation &yieldOp : inParallelOp.getYieldingOps()) {
    if (auto parallelInsertSliceOp =
            dyn_cast<tensor::ParallelInsertSliceOp>(yieldOp);
        parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) {
      storeOps.push_back(parallelInsertSliceOp);
    }
  }
  return storeOps;
}

std::optional<Value> ForallOp::getSingleInductionVar() {
  if (getRank() != 1)
    return std::nullopt;
@@ -1509,6 +1522,179 @@ class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
  }
};
/// The following canonicalization pattern folds the iter arguments of the
/// scf.forall op if :-
/// 1. The corresponding result has zero uses, OR
/// 2. The iter argument is NOT being modified within the loop body.
///
/// Example of first case :-
/// INPUT:
/// %res:3 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b, %arg2 = %c)
/// {
///   ...
///   <SOME USE OF %arg0>
Review thread:

Reviewer: I don't know the semantics of things yet if you use a shared_outs in anything apart from the scf.forall.in_parallel terminator.

Author (Abhishek-Varma): Regarding context: my understanding is that "each such op" will have [...]. I don't think we should constrain the use of the iter_args to be ONLY that. Ideally the use would be that a slice of the iter_arg is extracted, some computation is performed on it, and the result is stored back into the same iter_arg. What I'm doing instead is, as per your comment about having an API defined in ForallOp, [...].

Reviewer: I still don't think the semantics of the operation are well defined when a shared_outs is used anywhere outside of the scf.forall.in_parallel terminator.

Second reviewer: That is correct. If there is no op in the terminator that uses a certain bbArg as a destination, then that bbArg can be removed. You can think of the bbArg as the buffer that is being modified in parallel. If that buffer does not appear in the terminator, then we are not modifying the buffer at all; sure, we may have other ops inside the loop body that insert into the buffer, but if that tensor value is then not fed into a parallel_insert_slice, these insertions never become visible to the other threads. That's why it is safe to use the init arg instead of the bbArg in the loop body. The bufferization framework will privatize the buffer for each thread.

iter_args are typically also used in the loop body. It doesn't matter for your canonicalization pattern: you only need to check whether there is a use in the parallel_insert_slice. Potential other uses of the iter_arg are irrelevant.

Are you talking about the init value (operand of the scf.forall) or the iter_arg (bbArg of the region)? If you mean the former, you are right, such IR usually does not make sense. It is supported (it will not misbufferize), but a copy will be inserted. The latter (bbArg) is expected to appear in the loop body. The bbArg models the "future buffer" of the tensor that is being updated in parallel. The terminator (parallel_insert_slice) just makes sure that the "updated" tensor is written back. It is kind of like an "extract_slice, computation, insert_slice" pattern after tiling: the insert_slice doesn't really do anything; it is just needed so that the computation result has a use and does not fold away.

Reviewer: I think I mean both. If you update/use the bbArg within the body of the loop and in the terminator, that seems like an inherent race condition: if one thread is reading/writing data to the bbArg or init value, then it is inherently reading a non-deterministic state of the result being computed in the loop? EDIT: Ok, I was mistaken. I guess you can use the bbArg in the body because that represents the value "before the iteration". But really it is only used as the destination of operations; all other uses are invalid.

Second reviewer: Kind of. Another legitimate use case is to extract a slice from the bbArg and then pass the slice into a destination-style op (such as a linalg.matmul). (tensor.extract_slice is not a destination-style op, and the bbArg is not directly used as a destination.) If two loop iterations extracted overlapping slices, there would indeed be a race condition, but our tiling algorithm never generates such IR. Long story short, it is possible to build racy IR with scf.forall. When we designed the scf.forall op, we tried to design it in such a way that race conditions are not possible; we didn't even have shared_outs in the beginning. But proving that two different iterations of the loop are independent (which is needed in the bufferization to decide whether a buffer must be privatized for each thread) was too difficult/brittle. So what we ended up with is a more explicit operation that gives the user more control, but which may have race conditions when used incorrectly.
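To illustrate the legitimate use case described in the last comment, here is a sketch of tiling-style IR (all names, shapes, and tile sizes invented for this example): a slice of the bbArg is extracted, used as the init of a destination-style linalg.matmul, and the result only becomes visible to other threads through the terminator.

    func.func @tiled_matmul(%lhs: tensor<256x128xf32>, %rhs: tensor<128x64xf32>,
                            %acc: tensor<256x64xf32>) -> tensor<256x64xf32> {
      %res = scf.forall (%i) in (8) shared_outs(%arg0 = %acc) -> (tensor<256x64xf32>) {
        %off = affine.apply affine_map<(d0) -> (d0 * 32)>(%i)
        %lhs_t = tensor.extract_slice %lhs[%off, 0] [32, 128] [1, 1]
            : tensor<256x128xf32> to tensor<32x128xf32>
        // Read-only use of the bbArg: extract this iteration's slice.
        %acc_t = tensor.extract_slice %arg0[%off, 0] [32, 64] [1, 1]
            : tensor<256x64xf32> to tensor<32x64xf32>
        // The slice (not %arg0 itself) is the destination of the matmul.
        %mm = linalg.matmul ins(%lhs_t, %rhs : tensor<32x128xf32>, tensor<128x64xf32>)
                            outs(%acc_t : tensor<32x64xf32>) -> tensor<32x64xf32>
        // The update is written back, and made visible, via the terminator.
        scf.forall.in_parallel {
          tensor.parallel_insert_slice %mm into %arg0[%off, 0] [32, 64] [1, 1]
              : tensor<32x64xf32> into tensor<256x64xf32>
        }
      }
      return %res : tensor<256x64xf32>
    }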
///   <SOME USE OF %arg1>
///   <SOME USE OF %arg2>
///   ...
///   scf.forall.in_parallel {
///     <STORE OP WITH DESTINATION %arg1>
///     <STORE OP WITH DESTINATION %arg0>
///     <STORE OP WITH DESTINATION %arg2>
///   }
/// }
/// return %res#1
///
/// OUTPUT:
/// %res = scf.forall ... shared_outs(%new_arg0 = %b)
/// {
///   ...
///   <SOME USE OF %a>
///   <SOME USE OF %new_arg0>
///   <SOME USE OF %c>
///   ...
///   scf.forall.in_parallel {
///     <STORE OP WITH DESTINATION %new_arg0>
///   }
/// }
/// return %res
///
/// NOTE: 1. All uses of the folded shared_outs (iter arguments) within the
///          scf.forall are replaced by their corresponding operands.
///       2. Even if there are <STORE OP WITH DESTINATION *> ops within the
///          body of the scf.forall besides those within the
///          scf.forall.in_parallel terminator, this canonicalization remains
///          valid. For more details, please refer to:
///          https://github.com/llvm/llvm-project/pull/90189#discussion_r1589011124
///       3. TODO(avarma): Generalize it for other store ops. Currently it
///          handles tensor.parallel_insert_slice ops only.
///
/// Example of second case :-
/// INPUT:
/// %res:2 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b)
/// {
///   ...
///   <SOME USE OF %arg0>
///   <SOME USE OF %arg1>
///   ...
///   scf.forall.in_parallel {
///     <STORE OP WITH DESTINATION %arg1>
///   }
/// }
/// return %res#0, %res#1
///
/// OUTPUT:
/// %res = scf.forall ... shared_outs(%new_arg0 = %b)
/// {
///   ...
///   <SOME USE OF %a>
///   <SOME USE OF %new_arg0>
///   ...
///   scf.forall.in_parallel {
///     <STORE OP WITH DESTINATION %new_arg0>
///   }
/// }
/// return %a, %res
struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
  using OpRewritePattern<ForallOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(ForallOp forallOp,
                                PatternRewriter &rewriter) const final {
    // Step 1: For a given i-th result of scf.forall, check the following :-
    //   a. If it has any use.
    //   b. If the corresponding iter argument is being modified within
    //      the loop, i.e. has at least one store op with the iter arg as
    //      its destination operand. For this we use
    //      ForallOp::getCombiningOps(iter_arg).
    //
    // Based on the check we maintain the following :-
    //   a. `resultToDelete` - i-th result of scf.forall that'll be
    //      deleted.
    //   b. `resultToReplace` - i-th result of the old scf.forall
    //      whose uses will be replaced by the new scf.forall.
    //   c. `newOuts` - the shared_outs' operand of the new scf.forall
    //      corresponding to the i-th result with at least one use.
    SetVector<OpResult> resultToDelete;
    SmallVector<Value> resultToReplace;
    SmallVector<Value> newOuts;
    for (OpResult result : forallOp.getResults()) {
      OpOperand *opOperand = forallOp.getTiedOpOperand(result);
      BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
      if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
        resultToDelete.insert(result);
      } else {
        resultToReplace.push_back(result);
        newOuts.push_back(opOperand->get());
      }
    }
    // Return early if all results of scf.forall have at least one use and
    // are being modified within the loop.
    if (resultToDelete.empty())
      return failure();
    // Step 2: For the i-th result, do the following :-
    //   a. Fetch the corresponding BlockArgument.
    //   b. Look for store ops (currently tensor.parallel_insert_slice)
    //      with the BlockArgument as its destination operand.
    //   c. Remove the operations fetched in b.
    for (OpResult result : resultToDelete) {
      OpOperand *opOperand = forallOp.getTiedOpOperand(result);
      BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
      SmallVector<Operation *> combiningOps =
          forallOp.getCombiningOps(blockArg);
      for (Operation *combiningOp : combiningOps)
        rewriter.eraseOp(combiningOp);
    }
    // Step 3. Create a new scf.forall op with the new shared_outs' operands
    //         fetched earlier.
    auto newForallOp = rewriter.create<scf::ForallOp>(
        forallOp.getLoc(), forallOp.getMixedLowerBound(),
        forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
        forallOp.getMapping(),
        /*bodyBuilderFn=*/[](OpBuilder &, Location, ValueRange) {});
    // Step 4. Merge the block of the old scf.forall into the newly created
    //         scf.forall using the new set of arguments.
    Block *loopBody = forallOp.getBody();
    Block *newLoopBody = newForallOp.getBody();
    ArrayRef<BlockArgument> newBbArgs = newLoopBody->getArguments();
    // Form initial new bbArg list with just the control operands of the new
    // scf.forall op.
    SmallVector<Value> newBlockArgs =
        llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
                            [](BlockArgument b) -> Value { return b; });
    Block::BlockArgListType newSharedOutsArgs = newForallOp.getRegionOutArgs();
    unsigned index = 0;
    // Take the new corresponding bbArg if the old bbArg was used as a
    // destination in the in_parallel op. For all other bbArgs, use the
    // corresponding init_arg from the old scf.forall op.
    for (OpResult result : forallOp.getResults()) {
      if (resultToDelete.count(result)) {
        newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());
      } else {
        newBlockArgs.push_back(newSharedOutsArgs[index++]);
      }
    }
    rewriter.mergeBlocks(loopBody, newLoopBody, newBlockArgs);
    // Step 5. Replace the uses of the results of the old scf.forall with
    //         those of the new scf.forall.
    for (auto &&[oldResult, newResult] :
         llvm::zip(resultToReplace, newForallOp->getResults()))
      rewriter.replaceAllUsesWith(oldResult, newResult);
    // Step 6. Replace the uses of those results that either have no use or
    //         are not being modified within the loop with the corresponding
    //         OpOperand.
    for (OpResult oldResult : resultToDelete)
      rewriter.replaceAllUsesWith(oldResult,
                                  forallOp.getTiedOpOperand(oldResult)->get());
    return success();
  }
};
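To make the abstract <SOME USE OF ...> examples above concrete, here is a hypothetical input (all names, shapes, and offsets invented for illustration) that this pattern would rewrite. %arg0 is only read and never appears as a parallel_insert_slice destination in the terminator, so its iter arg folds away:

    // Before canonicalization.
    func.func @before(%a: tensor<64xf32>, %b: tensor<64xf32>)
        -> (tensor<64xf32>, tensor<64xf32>) {
      %res:2 = scf.forall (%i) in (4)
          shared_outs(%arg0 = %a, %arg1 = %b) -> (tensor<64xf32>, tensor<64xf32>) {
        %off = affine.apply affine_map<(d0) -> (d0 * 16)>(%i)
        // Read-only use of %arg0; there is no combining op for it below.
        %read = tensor.extract_slice %arg0[%off] [16] [1]
            : tensor<64xf32> to tensor<16xf32>
        scf.forall.in_parallel {
          tensor.parallel_insert_slice %read into %arg1[%off] [16] [1]
              : tensor<16xf32> into tensor<64xf32>
        }
      }
      return %res#0, %res#1 : tensor<64xf32>, tensor<64xf32>
    }

    // After canonicalization (sketch): only %arg1 survives as a shared_out;
    // the read of %arg0 now reads %a directly, and %res#0 is replaced by %a.
    //   %res = scf.forall (%i) in (4) shared_outs(%new_arg0 = %b) ... {
    //     %read = tensor.extract_slice %a[...]
    //     scf.forall.in_parallel {
    //       tensor.parallel_insert_slice %read into %new_arg0[...]
    //     }
    //   }
    //   return %a, %res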
struct ForallOpSingleOrZeroIterationDimsFolder
    : public OpRewritePattern<ForallOp> {
  using OpRewritePattern<ForallOp>::OpRewritePattern;
@@ -1667,7 +1853,7 @@ struct FoldTensorCastOfOutputIntoForallOp
void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
              ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
              ForallOpSingleOrZeroIterationDimsFolder>(context);
}
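Because the pattern is registered as a canonicalization, it is exercised by the standard canonicalizer. A minimal lit-style sketch of how one might test it (the function name and checks here are invented; the PR's actual tests are not shown in this diff):

    // RUN: mlir-opt %s -canonicalize | FileCheck %s

    // The single result is unused, so the iter arg and its combining
    // tensor.parallel_insert_slice should fold away entirely.
    // CHECK-LABEL: func.func @fold_unused_forall_result
    //   CHECK-NOT:   tensor.parallel_insert_slice
    func.func @fold_unused_forall_result(%init: tensor<64xf32>,
                                         %val: tensor<16xf32>) {
      %res = scf.forall (%i) in (4) shared_outs(%o = %init) -> (tensor<64xf32>) {
        %off = affine.apply affine_map<(d0) -> (d0 * 16)>(%i)
        scf.forall.in_parallel {
          tensor.parallel_insert_slice %val into %o[%off] [16] [1]
              : tensor<16xf32> into tensor<64xf32>
        }
      }
      return
    }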
Review comment: I was hoping that ParallelCombiningOpInterface is implemented by the ops in the "combining" region (e.g., parallel_insert_slice), but it's actually implemented by the op that has the region (in_parallel). Then we wouldn't have to hard-code parallel_insert_slice here. But nvm, we can fix that later...