[MLIR][SCF] Add canonicalization pattern to fold away iter args of scf.forall #90189
Conversation
@llvm/pr-subscribers-mlir-scf @llvm/pr-subscribers-mlir

Author: Abhishek Varma (Abhishek-Varma)

Changes

-- This commit adds a canonicalization pattern to fold away iter args of scf.forall if :-
a. The corresponding tied result has no use.
b. It is not being modified within the loop.

Signed-off-by: Abhishek Varma <[email protected]>

Full diff: https://github.com/llvm/llvm-project/pull/90189.diff

3 Files Affected:
diff --git a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
index 423e1c3e1e042c..6e5d80078e8022 100644
--- a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
@@ -17,6 +17,7 @@ add_mlir_dialect_library(MLIRSCFDialect
MLIRIR
MLIRLoopLikeInterface
MLIRSideEffectInterfaces
+ MLIRSubsetOpInterface
MLIRTensorDialect
MLIRValueBoundsOpInterface
)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 7a1aafc9f1c2f9..355cfc8b3ee626 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -20,6 +20,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Interfaces/SubsetOpInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/InliningUtils.h"
@@ -1509,6 +1510,203 @@ class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
}
};
+/// The following canonicalization pattern folds the iter arguments of an
+/// scf.forall op if :-
+/// 1. The corresponding result has zero uses.
+/// 2. The iter argument is NOT being modified within the loop body.
+///
+/// Example of first case :-
+/// INPUT:
+/// %res:3 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b, %arg2 = %c)
+/// {
+/// ...
+/// <SOME USE OF %arg0>
+/// <SOME USE OF %arg1>
+/// <SOME USE OF %arg2>
+/// ...
+/// scf.forall.in_parallel {
+/// <STORE OP WITH DESTINATION %arg1>
+/// <STORE OP WITH DESTINATION %arg0>
+/// <STORE OP WITH DESTINATION %arg2>
+/// }
+/// }
+/// return %res#1
+///
+/// OUTPUT:
+/// %res = scf.forall ... shared_outs(%new_arg0 = %b)
+/// {
+/// ...
+/// <SOME USE OF %a>
+/// <SOME USE OF %new_arg0>
+/// <SOME USE OF %c>
+/// ...
+/// scf.forall.in_parallel {
+/// <STORE OP WITH DESTINATION %new_arg0>
+/// }
+/// }
+/// return %res
+///
+/// NOTE: 1. All uses of the folded shared_outs (iter argument) within the
+/// scf.forall are replaced by their corresponding operands.
+/// 2. The canonicalization assumes that there are no <STORE OP WITH
+/// DESTINATION *> ops within the body of the scf.forall except within
+/// scf.forall.in_parallel terminator.
+/// 3. The order of the <STORE OP WITH DESTINATION *> can be arbitrary
+/// within scf.forall.in_parallel - the code below takes care of this
+/// by traversing the uses of the corresponding iter arg.
+///
+/// Example of second case :-
+/// INPUT:
+/// %res:2 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b)
+/// {
+/// ...
+/// <SOME USE OF %arg0>
+/// <SOME USE OF %arg1>
+/// ...
+/// scf.forall.in_parallel {
+/// <STORE OP WITH DESTINATION %arg1>
+/// }
+/// }
+/// return %res#0, %res#1
+///
+/// OUTPUT:
+/// %res = scf.forall ... shared_outs(%new_arg0 = %b)
+/// {
+/// ...
+/// <SOME USE OF %a>
+/// <SOME USE OF %new_arg0>
+/// ...
+/// scf.forall.in_parallel {
+/// <STORE OP WITH DESTINATION %new_arg0>
+/// }
+/// }
+/// return %a, %res
+struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
+ using OpRewritePattern<ForallOp>::OpRewritePattern;
+
+ /// Utility function that checks if a candidate value satisfies any of the
+ /// conditions (see above doc comment) to make it viable for folding away.
+ static bool isCandidateValueToDelete(Value result, BlockArgument blockArg) {
+ if (result.use_empty()) {
+ return true;
+ }
+ Value::user_range users = blockArg.getUsers();
+ return llvm::all_of(users, [&](Operation *user) {
+ return !isa<SubsetInsertionOpInterface>(user);
+ });
+ }
+
+ LogicalResult matchAndRewrite(ForallOp forallOp,
+ PatternRewriter &rewriter) const final {
+ scf::InParallelOp terminatorOp = forallOp.getTerminator();
+ SmallVector<Operation *> yieldingOps = llvm::map_to_vector(
+ terminatorOp.getYieldingOps(), [](Operation &op) { return &op; });
+
+ // The following check should indeed be part of SCF::ForallOp::verify.
+ SmallVector<SubsetInsertionOpInterface> subsetInsertionOpInterfaceOps;
+ for (Operation *op : yieldingOps) {
+ if (auto subsetInsertionOpInterfaceOp =
+ dyn_cast<SubsetInsertionOpInterface>(op)) {
+ subsetInsertionOpInterfaceOps.push_back(subsetInsertionOpInterfaceOp);
+ continue;
+ }
+ return failure();
+ }
+
+ // Step 1: For a given i-th result of scf.forall, check the following :-
+ // a. If it has any use.
+ // b. If the corresponding iter argument is being modified within
+ // the loop.
+ //
+ // Based on the check we maintain the following :-
+ // a. `resultToDelete` - i-th result of scf.forall that'll be
+ // deleted.
+ // b. `resultToReplace` - i-th result of the old scf.forall
+ // whose uses will be replaced by the new scf.forall.
+ // c. `newOuts` - the shared_outs' operand of the new scf.forall
+ // corresponding to the i-th result with at least one use.
+ // d. `mapping` - mapping the old iter block argument of scf.forall
+ // with the corresponding shared_outs' operand. This will be
+ // used when creating a new scf.forall op.
+ SmallVector<OpResult> resultToDelete;
+ SmallVector<Value> resultToReplace;
+ SmallVector<Value> newOuts;
+ IRMapping mapping;
+ for (OpResult result : forallOp.getResults()) {
+ OpOperand *opOperand = forallOp.getTiedOpOperand(result);
+ BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
+ if (isCandidateValueToDelete(result, blockArg)) {
+ resultToDelete.push_back(result);
+ mapping.map(blockArg, opOperand->get());
+ } else {
+ resultToReplace.push_back(result);
+ newOuts.push_back(opOperand->get());
+ }
+ }
+
+ // Return early if all results of the scf.forall have at least one use
+ // and are being modified within the loop.
+ if (resultToDelete.empty()) {
+ return failure();
+ }
+
+ // Step 2: For the i-th result, do the following :-
+ // a. Fetch the corresponding BlockArgument.
+ // b. Look for an op within scf.forall.in_parallel whose destination
+ // operand is the BlockArgument fetched in step a.
+ // c. Remove the operation fetched in b.
+ // d. For any use of the BlockArgument in the body of the scf.forall
+ // replace it with the corresponding Output value.
+ for (OpResult result : resultToDelete) {
+ OpOperand *opOperand = forallOp.getTiedOpOperand(result);
+ BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
+ Value::user_range users = blockArg.getUsers();
+ Operation *terminatorOperationToDelete = nullptr;
+ for (Operation *user : users) {
+ if (auto subsetInsertionOpInterfaceOp =
+ dyn_cast<SubsetInsertionOpInterface>(user)) {
+ if (subsetInsertionOpInterfaceOp.getDestinationOperand().get() ==
+ blockArg) {
+ terminatorOperationToDelete = subsetInsertionOpInterfaceOp;
+ break;
+ }
+ }
+ }
+ if (terminatorOperationToDelete)
+ rewriter.eraseOp(terminatorOperationToDelete);
+ }
+
+ // Step 3. Create a new scf.forall op with the new shared_outs' operands
+ // fetched earlier.
+ auto newforallOp = rewriter.create<scf::ForallOp>(
+ forallOp.getLoc(), forallOp.getMixedLowerBound(),
+ forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
+ forallOp.getMapping());
+
+ // Step 4. Clone the region of the old scf.forall into the newly created
+ // scf.forall using the IRMapping formed in Step 1.
+ newforallOp.getBodyRegion().getBlocks().clear();
+ rewriter.cloneRegionBefore(forallOp.getRegion(), newforallOp.getRegion(),
+ newforallOp.getRegion().begin(), mapping);
+
+ // Step 5. Replace the uses of the results of the old scf.forall with
+ // those of the new scf.forall.
+ for (auto &&[oldResult, newResult] :
+ llvm::zip(resultToReplace, newforallOp->getResults())) {
+ rewriter.replaceAllUsesWith(oldResult, newResult);
+ }
+ // Step 6. Replace the uses of those values that either have no use or
+ // are not being modified within the loop with the corresponding
+ // OpOperand.
+ for (OpResult oldResult : resultToDelete) {
+ rewriter.replaceAllUsesWith(oldResult,
+ forallOp.getTiedOpOperand(oldResult)->get());
+ }
+ return success();
+ }
+};
+
struct ForallOpSingleOrZeroIterationDimsFolder
: public OpRewritePattern<ForallOp> {
using OpRewritePattern<ForallOp>::OpRewritePattern;
@@ -1667,7 +1865,7 @@ struct FoldTensorCastOfOutputIntoForallOp
void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
- ForallOpControlOperandsFolder,
+ ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
ForallOpSingleOrZeroIterationDimsFolder>(context);
}
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index b4c9ed4db94e0e..9b379ad15f1ecf 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1735,6 +1735,86 @@ func.func @do_not_fold_tensor_cast_from_dynamic_to_static_type_into_forall(
// -----
+#map = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
+#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+module {
+ func.func @fold_iter_args_not_being_modified_within_scfforall(%arg0: index, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 4.200000e+01 : f32
+ %0 = linalg.fill ins(%cst : f32) outs(%arg1 : tensor<?xf32>) -> tensor<?xf32>
+ %dim = tensor.dim %arg1, %c0 : tensor<?xf32>
+ %1 = affine.apply #map()[%dim, %arg0]
+ %2:2 = scf.forall (%arg3) in (%1) shared_outs(%arg4 = %arg1, %arg5 = %arg2) -> (tensor<?xf32>, tensor<?xf32>) {
+ %3 = affine.apply #map1(%arg3)[%arg0]
+ %4 = affine.min #map2(%arg3)[%dim, %arg0]
+ %extracted_slice0 = tensor.extract_slice %arg4[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+ %extracted_slice1 = tensor.extract_slice %arg5[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+ %5 = linalg.elemwise_unary ins(%extracted_slice0 : tensor<?xf32>) outs(%extracted_slice1 : tensor<?xf32>) -> tensor<?xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %5 into %arg5[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
+ }
+ }
+ return %2#0, %2#1 : tensor<?xf32>, tensor<?xf32>
+ }
+}
+// CHECK-LABEL: @fold_iter_args_not_being_modified_within_scfforall
+// CHECK-SAME: (%{{.*}}: index, %[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
+// CHECK: %[[RESULT:.*]] = scf.forall
+// CHECK-SAME: shared_outs(%[[ITER_ARG_5:.*]] = %[[ARG2]]) -> (tensor<?xf32>) {
+// CHECK: %[[OPERAND0:.*]] = tensor.extract_slice %[[ARG1]]
+// CHECK: %[[OPERAND1:.*]] = tensor.extract_slice %[[ITER_ARG_5]]
+// CHECK: %[[ELEM:.*]] = linalg.elemwise_unary ins(%[[OPERAND0]] : tensor<?xf32>) outs(%[[OPERAND1]] : tensor<?xf32>) -> tensor<?xf32>
+// CHECK: scf.forall.in_parallel {
+// CHECK-NEXT: tensor.parallel_insert_slice %[[ELEM]] into %[[ITER_ARG_5]]
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[ARG1]], %[[RESULT]]
+
+// -----
+
+#map = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
+#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+module {
+ func.func @fold_iter_args_with_no_use_of_result_scfforall(%arg0: index, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>, %arg3: tensor<?xf32>) -> tensor<?xf32> {
+ %cst = arith.constant 4.200000e+01 : f32
+ %c0 = arith.constant 0 : index
+ %0 = linalg.fill ins(%cst : f32) outs(%arg1 : tensor<?xf32>) -> tensor<?xf32>
+ %dim = tensor.dim %arg1, %c0 : tensor<?xf32>
+ %1 = affine.apply #map()[%dim, %arg0]
+ %2:3 = scf.forall (%arg4) in (%1) shared_outs(%arg5 = %arg1, %arg6 = %arg2, %arg7 = %arg3) -> (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) {
+ %3 = affine.apply #map1(%arg4)[%arg0]
+ %4 = affine.min #map2(%arg4)[%dim, %arg0]
+ %extracted_slice = tensor.extract_slice %arg5[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+ %extracted_slice_0 = tensor.extract_slice %arg6[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+ %extracted_slice_1 = tensor.extract_slice %arg7[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+ %extracted_slice_2 = tensor.extract_slice %0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+ %5 = linalg.elemwise_unary ins(%extracted_slice : tensor<?xf32>) outs(%extracted_slice_1 : tensor<?xf32>) -> tensor<?xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %5 into %arg6[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
+ tensor.parallel_insert_slice %extracted_slice into %arg5[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
+ tensor.parallel_insert_slice %extracted_slice_0 into %arg7[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
+ }
+ }
+ return %2#1 : tensor<?xf32>
+ }
+}
+// CHECK-LABEL: @fold_iter_args_with_no_use_of_result_scfforall
+// CHECK-SAME: (%{{.*}}: index, %[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: tensor<?xf32>, %[[ARG3:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK: %[[RESULT:.*]] = scf.forall
+// CHECK-SAME: shared_outs(%[[ITER_ARG_6:.*]] = %[[ARG2]]) -> (tensor<?xf32>) {
+// CHECK: %[[OPERAND0:.*]] = tensor.extract_slice %[[ARG1]]
+// CHECK: %[[OPERAND1:.*]] = tensor.extract_slice %[[ARG3]]
+// CHECK: %[[ELEM:.*]] = linalg.elemwise_unary ins(%[[OPERAND0]] : tensor<?xf32>) outs(%[[OPERAND1]] : tensor<?xf32>) -> tensor<?xf32>
+// CHECK: scf.forall.in_parallel {
+// CHECK-NEXT: tensor.parallel_insert_slice %[[ELEM]] into %[[ITER_ARG_6]]
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[RESULT]]
+
+// -----
+
func.func @index_switch_fold() -> (f32, f32) {
%switch_cst = arith.constant 1: index
%0 = scf.index_switch %switch_cst -> f32
mlir/lib/Dialect/SCF/IR/SCF.cpp (Outdated)

    }
    Value::user_range users = blockArg.getUsers();
    return llvm::all_of(users, [&](Operation *user) {
      return !isa<SubsetInsertionOpInterface>(user);
We currently do not require subset insertion ops to implement the SubsetInsertionOpInterface. There may be ops that are inserting at a subset but do not implement the interface. Two examples:
- Unregistered ops.
- tensor.cast: it does not implement the SubsetInsertionOpInterface, but the result of the cast may be passed into a subset insertion op.

So checking for !isa<SubsetInsertionOpInterface> is not sufficient. (You could check that all users implement the subset extraction op interface. Ops currently cannot be extraction and insertion ops at the same time.)
But I think there's a simpler solution: if a shared_outs bbArg is not used as the "destination" of an op inside of the scf.forall.in_parallel terminator, it should be safe to use the init value inside the loop instead. Can you give that a try?
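A minimal sketch of the check being suggested here, hard-coding tensor.parallel_insert_slice as the combining op (the only one supported per the discussion below); the helper name is hypothetical and not part of the PR:

static bool isUsedAsTerminatorDestination(scf::ForallOp forallOp,
                                          BlockArgument bbArg) {
  // Walk the combining ops of the scf.forall.in_parallel terminator and
  // look for one whose destination is the given shared_outs bbArg.
  scf::InParallelOp terminator = forallOp.getTerminator();
  for (Operation &yieldingOp : terminator.getYieldingOps()) {
    auto insertOp = dyn_cast<tensor::ParallelInsertSliceOp>(&yieldingOp);
    if (insertOp && insertOp.getDest() == bbArg)
      return true; // Written back in the terminator; cannot fold.
  }
  return false; // No write-back; the init value can be used instead.
}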
Hi @matthias-springer. Thanks for the review and suggestion! I'm writing below a few corner cases which I feel might arise. Please let me know where I am going wrong with my understanding and what's the best way to deal with them, and I shall do that. :)

> There may be ops that are inserting at a subset but do not implement the interface.

Oh okay. I wasn't aware of this.
Regarding tensor.cast - the check I've added would only return true if at least one SubsetInsertionOpInterface is found for the bbArg. Even if tensor.cast's result may be passed into a SubsetInsertionOpInterface, it won't have the same bbArg (the element type would be different), so it would be okay.
> But I think there's a simpler solution: if a shared_outs bbArg is not used as the "destination" of an op inside of the scf.forall.in_parallel terminator, it should be safe to use the init value inside the loop instead. Can you give that a try?
That is definitely correct and a simpler check, but I was trying to address the following case as well :-
scf.forall ... shared_outs(%arg0 = %a)
{
...
<SOME USE OF %arg0>
...
%x = tensor.insert_slice <some_val> into %arg0 (or some unregistered op that inserts a value into a subset)
...
<some use of %x>
...
scf.forall.in_parallel {
<STORE OP WITH DESTINATION %arg0 or some other bbArg> (currently `tensor.parallel_insert_slice` and `tensor.insert_slice` do that)
}
}
Therefore, if I'm only checking within scf.forall.in_parallel, it won't cater to the above case.
So, three things unknown to me at this point are :-
1. How to deal with the unregistered ops which might appear anywhere in the loop body and insert a value into the bbArg?
2. Is there any other way to check if a value is used as a "destination" of an op besides SubsetInsertionOpInterface?
3. Why do we not require subset insertion ops to implement the SubsetInsertionOpInterface?
I don't know if we can get to a general solution right away. It's hard to handle the generality of all possible "insertion-like" ops. I would go the other way. To start with, I would check that the only use of the iter_args is in tensor.parallel_insert_slice ops within the scf.forall.in_parallel, and only then drop the result (and the corresponding iter_arg and tensor.parallel_insert_slice). We don't have any semantics for any other case right now, and it's not worth generalizing in a vacuum. Let's start small but well-defined and go from there.
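A sketch of this stricter variant (hypothetical helper, not the PR's final code), again hard-coding tensor.parallel_insert_slice: an iter_arg qualifies only if every use is the destination of a parallel_insert_slice inside the in_parallel terminator.

static bool allUsesAreTerminatorInserts(BlockArgument bbArg) {
  return llvm::all_of(bbArg.getUsers(), [&](Operation *user) {
    // Each user must be a parallel_insert_slice that lives inside the
    // scf.forall.in_parallel terminator and uses the bbArg as its
    // destination (not as a source).
    auto insertOp = dyn_cast<tensor::ParallelInsertSliceOp>(user);
    return insertOp && insertOp.getDest() == bbArg &&
           isa<scf::InParallelOp>(user->getParentOp());
  });
}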
> How to deal with the unregistered ops which might appear anywhere in the loop body and insert a value into the bbArg?
If you just look for the iter_arg being used as a destination in the terminator, it does not matter what the remaining loop body looks like. If there is an insertion into a tensor that is defined outside of the loop, then One-Shot Bufferize will allocate a thread-local buffer copy.
> Is there any other way to check if a value is used as a "destination" of an op besides SubsetInsertionOpInterface?
I think there is a special interface for ops that can appear in the in_parallel terminator region. You could query that interface. Hopefully the destination can be queried from it. If not, you can add an interface method for that. For the moment you could also just hard-code the implementation to parallel_insert_slice because that's the only terminator that we support at the moment anyway.
> Why do we not require subset insertion ops to implement the SubsetInsertionOpInterface?
One reason is that we could not handle unregistered ops correctly. Maybe there's a way to support that safely… it’s kind of like the MemoryEffectsOpInterface: if an op does not implement that interface, it does not mean that there is no side effect; it just means that we don’t know the side effects.
Assuming we only support parallel_insert_slice, does that handle all cases that you were thinking of? (I think it will work in the example that you posted.)
/// %res:3 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b, %arg2 = %c)
/// {
///   ...
///   <SOME USE OF %arg0>
I don't know the semantics of things yet if you use a shared_outs in anything apart from the store op in the scf.forall.in_parallel. I'd just check for a single use in the scf.forall.in_parallel (and as said in the other comment, just check for the use being a tensor.parallel_insert_slice).
Regarding context: from the .td description :-

> The actions of the in_parallel terminators specify how to combine the partial results
> of all parallel invocations into a full value, in some unspecified order.
> The "destination" of each such op must be a shared_out block argument of the scf.forall op.

So, my understanding is that "each such op" will have a shared_out block argument and it is always supposed to be in the "destination" operand.

> check that the only use of the iter_args is in tensor.parallel_insert_slice ops within the scf.forall.in_parallel (from previous comment)

I don't think we should constrain the use of the iter_args to be ONLY that. Ideally the use would be that a slice of the iter_arg is extracted, some computation is performed on it, and the result is then stored back into the same iter_arg.
What I'm doing instead, as per your comment about having an API defined in scf.forall that'll return a unique tensor.parallel_insert_slice: I've added that and am making use of it.
I still don't think the semantics of the operation are well defined when a shared_outs is used anywhere outside of the scf.forall.in_parallel region. So for now, unless you need it (in which case I'd like to know more), better to narrow the usage to cases where the shared_outs is used only once and in the scf.forall.in_parallel region.
> So, my understanding is that "each such op" will have a shared_out block argument and it is always supposed to be in the "destination" operand.

That is correct. If there is no op in the terminator that uses a certain bbarg as a destination, then that bbarg can be removed. You can think of the bbarg as the buffer that is being modified in parallel. If that buffer does not appear in the terminator, then we are not modifying the buffer at all; sure, we may have other ops inside the loop body that insert into the buffer, but if that tensor value is then not fed into a parallel_insert_slice, then these insertions never become visible to the other threads. That's why it's safe to use the init arg instead of the bbarg in the loop body. The bufferization framework will privatize the buffer for each thread.

> I don't think we should constrain the use of the iter_args to be ONLY that. Ideally the use would be that a slice of the iter_arg is extracted, some computation is performed on it, and the result is then stored back into the same iter_arg.

iter_args are typically also used in the loop body. It doesn't matter for your canonicalization pattern. You only need to check whether there is a use in the parallel_insert_slice. Potential other uses of the iter_arg are irrelevant.

> I still don't think the semantics of the operation are well defined when a shared_outs is used anywhere outside of the scf.forall.in_parallel region.

Are you talking about the init value (operand of the scf.forall) or the iter_arg (bbarg of the region)? If you mean the former, you are right, such IR usually does not make sense. It is supported (will not misbufferize), but a copy will be inserted. The latter (bbarg) is expected to appear in the loop body. The bbarg models the "future buffer of the tensor that is being updated in parallel". The terminator (parallel_insert_slice) just makes sure that the "updated" tensor is written back. Kind of like an "extract_slice, computation, insert_slice" pattern after tiling: the insert_slice doesn't really do anything; it is just needed so that the computation result has a use and does not fold away.
I think I mean both. If you update/use the bbArg within the body of the loop and in the terminator, that seems like an inherent race condition. If one thread is reading/writing data to the bbArg or the init value, then it is inherently reading a non-deterministic state of the result being computed in the loop?
EDIT: Ok, I was mistaken. I guess you can use the bbArg in the body because that represents the value "before the iteration". But really it's only used as the destination of operations. All other uses are invalid.
Kind of. Another legitimate use case is to extract a slice from the bbarg and then pass the slice into a destination-style op (such as a linalg.matmul). (tensor.extract_slice is not a destination-style op, and the bbarg is not directly used as a destination.)
If two loop iterations extract overlapping slices, there would indeed be a race condition. But our tiling algorithm never generates such IR.
Long story short, it is possible to build racy IR with scf.forall. When we designed the scf.forall op, we tried to design it in such a way that race conditions are not possible; we didn’t even have the shared_outs in the beginning. But proving that two different iterations of the loop are independent (which is needed in the bufferization to decide if a buffer must be privatized for each thread) was too difficult/brittle. So what we ended up with is a more explicit operation that gives the user more control, but when used incorrectly may have race conditions.
Hi @MaheshRavishankar @matthias-springer - I've made changes to this revision as per the review comments. Can you please re-review?
Hi @matthias-springer @MaheshRavishankar, thank you for your review comments! They were really helpful and helped clear out certain concepts. A few notes to help understand the current stance as per the discussion thread above :-
Please re-review/approve whenever you both get a chance. Thanks!
SmallVector<Operation *> ForallOp::getStoreOpUser(BlockArgument bbArg) {
  SmallVector<Operation *> storeOps;
  InParallelOp inParallelOp = getTerminator();
  for (Operation &yieldOp : inParallelOp.getYieldingOps()) {
I was hoping that ParallelCombiningOpInterface is implemented by the ops in the "combining" region (e.g., parallel_insert_slice), but it's actually implemented by the op that has the region (in_parallel). Then we wouldn't have to hard-code parallel_insert_slice here. But never mind, we can fix that later...
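For context, a hedged completion of the helper excerpted above, assuming the hard-coded parallel_insert_slice approach discussed earlier (the PR's actual body may differ):

SmallVector<Operation *> ForallOp::getStoreOpUser(BlockArgument bbArg) {
  SmallVector<Operation *> storeOps;
  InParallelOp inParallelOp = getTerminator();
  for (Operation &yieldOp : inParallelOp.getYieldingOps()) {
    // Collect every parallel_insert_slice in the terminator whose
    // destination is the given shared_outs bbArg.
    auto insertOp = dyn_cast<tensor::ParallelInsertSliceOp>(&yieldOp);
    if (insertOp && insertOp.getDest() == bbArg)
      storeOps.push_back(insertOp);
  }
  return storeOps;
}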
[MLIR][SCF] Add canonicalization pattern to fold away iter args of scf.forall

-- This commit adds a canonicalization pattern to fold away iter args of scf.forall if :-
a. The corresponding tied result has no use.
b. It is not being modified within the loop.

Signed-off-by: Abhishek Varma <[email protected]>
- With this work #351, we change the pack-peel pipeline to always peel the first and the last iteration and fuse the unpack op into the inner forall loop for matmul-only and matmul-elementwise dispatches.
- With this change, we have made progress on codegen for matmul-transpose-b and matmul-elementwise example dispatches to a reasonable state.
- A minor issue is that with the [upstream change](llvm/llvm-project#90189), there is redundant data allocation and copying after bufferization ([example IR](https://gist.github.com/yzhang93/55f448368db32cccd2af31c730cc878a#file-gistfile1-txt-L342)). So currently I have to disable the canonicalization pass [here](https://github.com/nod-ai/iree-amd-aie/pull/392/files#diff-42f0d0bb098689f25ee68e8f05ec6c2eaa89ce41a6394d7d99c1d1c912943b38L256), but in the long term we may want to fix this issue.