
[MLIR][SCF] Add canonicalization pattern to fold away iter args of scf.forall #90189

Merged
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -609,6 +609,10 @@ def ForallOp : SCF_Op<"forall", [

// Declare the shared_outs as inits/outs to DestinationStyleOpInterface.
MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }

/// Returns operations within scf.forall.in_parallel whose destination
/// operand is the block argument `bbArg`.
SmallVector<Operation*> getCombiningOps(BlockArgument bbArg);
}];
}
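For context, a minimal hand-written sketch of the IR shape this accessor inspects (illustrative names and shapes, not taken from the patch). For the shared_outs block argument %out below, calling getCombiningOps on %out would return the single tensor.parallel_insert_slice in the terminator whose destination is %out:

  func.func @combining_ops_sketch(%src: tensor<64xf32>, %init: tensor<64xf32>) -> tensor<64xf32> {
    %result = scf.forall (%i) in (8) shared_outs(%out = %init) -> (tensor<64xf32>) {
      %off = affine.apply affine_map<(d0) -> (d0 * 8)>(%i)
      %in_slice = tensor.extract_slice %src[%off] [8] [1] : tensor<64xf32> to tensor<8xf32>
      %out_slice = tensor.extract_slice %out[%off] [8] [1] : tensor<64xf32> to tensor<8xf32>
      %computed = linalg.elemwise_unary ins(%in_slice : tensor<8xf32>) outs(%out_slice : tensor<8xf32>) -> tensor<8xf32>
      scf.forall.in_parallel {
        // The combining op for %out: its destination operand is the block argument.
        tensor.parallel_insert_slice %computed into %out[%off] [8] [1] : tensor<8xf32> into tensor<64xf32>
      }
    }
    return %result : tensor<64xf32>
  }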

188 changes: 187 additions & 1 deletion mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1415,6 +1415,19 @@ InParallelOp ForallOp::getTerminator() {
return cast<InParallelOp>(getBody()->getTerminator());
}

SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
SmallVector<Operation *> storeOps;
InParallelOp inParallelOp = getTerminator();
for (Operation &yieldOp : inParallelOp.getYieldingOps()) {
Member:

I was hoping that ParallelCombiningOpInterface is implemented by the ops in the "combining" region (e.g., parallel_insert_slice), but it's actually implemented by the op that has the region (in_parallel). Then we wouldn't have to hard-code parallel_insert_slice here. But nvm, we can fix that later...

if (auto parallelInsertSliceOp =
dyn_cast<tensor::ParallelInsertSliceOp>(yieldOp);
parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) {
storeOps.push_back(parallelInsertSliceOp);
}
}
return storeOps;
}

std::optional<Value> ForallOp::getSingleInductionVar() {
if (getRank() != 1)
return std::nullopt;
@@ -1509,6 +1522,179 @@ class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
}
};

/// The following canonicalization pattern folds the iter arguments of an
/// scf.forall op if either :-
/// 1. The corresponding result has zero uses, or
/// 2. The iter argument is NOT being modified within the loop body.
///
/// Example of first case :-
/// INPUT:
/// %res:3 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b, %arg2 = %c)
/// {
/// ...
/// <SOME USE OF %arg0>
Contributor:

I don't know the semantics of things yet if you use a shared_outs in anything apart from the store op in the scf.forall.in_parallel. I'd just check for a single use in the scf.forall.in_parallel (and, as said in the other comment, just check for the use being a tensor.insert_in_parallel).

Contributor Author:

Regarding context, from the .td description:

> The actions of the in_parallel terminators specify how to combine the partial results
> of all parallel invocations into a full value, in some unspecified order.
> The "destination" of each such op must be a shared_out block argument of the scf.forall op.

So, my understanding is that "each such op" will have a shared_out block argument and it is always supposed to be in the "Destination" operand.

> check that the only use of the iter_args is in tensor.insert_in_parallel ops within the scf.forall.in_parallel (from previous comment)

I don't think we should constrain the use of the iter_args to be ONLY that. Ideally the use would be that a slice of the iter_arg is extracted, some computation is performed on it, and the result is then stored back into the same iter_arg.

What I'm doing instead, as per your comment about having an API defined on scf.forall that returns a unique tensor.parallel_insert_slice: I've added that and am making use of it.

Contributor:

I still don't think the semantics of the operation are well defined when a shared_out is used anywhere outside of the scf.forall.in_parallel region. So for now, unless you need it (in which case I'd like to know more), it's better to narrow the usage to cases where the shared_out is used only once, and only in the scf.forall.in_parallel region.

Member:

> So, my understanding is that "each such op" will have a shared_out block argument and it is always supposed to be in the "Destination" operand.

That is correct. If there is no op in the terminator that uses a certain bbarg as a destination, then that bbarg can be removed. You can think of the bbarg as the buffer that is being modified in parallel. If that buffer does not appear in the terminator, then we are not modifying the buffer at all; sure, we may have other ops inside the loop body that insert into the buffer, but if that tensor value is then not fed into a parallel_insert_slice, these insertions never become visible to the other threads. That's why it's safe to use the init arg instead of the bbarg in the loop body. The bufferization framework will privatize the buffer for each thread.
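For illustration, a minimal hand-written sketch of this situation (names and shapes are illustrative, not taken from the patch): %tmp below inserts into the bbarg %out, but %tmp is never fed into a tensor.parallel_insert_slice, so the update stays private to the thread and never becomes visible outside the loop. The bbarg can therefore be replaced by the init %a and folded away:

  func.func @invisible_update(%a: tensor<8xf32>, %v: tensor<1xf32>) -> tensor<8xf32> {
    %r = scf.forall (%i) in (8) shared_outs(%out = %a) -> (tensor<8xf32>) {
      // Local update only; its result %tmp never reaches the terminator.
      %tmp = tensor.insert_slice %v into %out[%i] [1] [1] : tensor<1xf32> into tensor<8xf32>
      scf.forall.in_parallel {
      }
    }
    return %r : tensor<8xf32>
  }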

> I don't think we should constrain the use of the iter_args to be ONLY that. Ideally the use would be that a slice of the iter_arg is extracted, some computation is performed on it, and the result is then stored back into the same iter_arg.

iter_args are typically also used in the loop body. It doesn't matter for your canonicalization pattern. You only need to check whether there is a use in the parallel_insert_slice. Potential other uses of the iter_arg are irrelevant.

> I still don't think the semantics of the operation are well defined when a shared_out is used anywhere outside of the scf.forall.in_parallel region.

Are you talking about the init value (operand of the scf.forall) or the iter_arg (bbarg of the region)? If you mean the former, you are right, such IR usually does not make sense. It is supported (it will not misbufferize), but a copy will be inserted. The latter (bbarg) is expected to appear in the loop body. The bbarg models the "future buffer of the tensor that is being updated in parallel". The terminator (parallel_insert_slice) just makes sure that the "updated" tensor is written back. Kind of like an "extract_slice, computation, insert_slice" pattern after tiling: the insert_slice doesn't really do anything; it is just needed so that the computation result has a use and does not fold away.

Contributor:

I think I mean both. If you update/use the bbArg within the body of the loop and in the terminator, that seems like an inherent race condition: if one thread is reading/writing data through the bbArg or the init value, isn't it reading a non-deterministic state of the result being computed in the loop?

EDIT: OK, I was mistaken. I guess you can use the bbArg in the body because it represents the value "before the iteration". But really it's only used as the destination of operations. All other uses are invalid.

Member (@matthias-springer, May 4, 2024):

Kind of. Another legitimate use case is to extract a slice from the bbarg and then pass that slice into a destination-style op (such as a linalg.matmul). (tensor.extract_slice is not a destination-style op, so the bbarg is not directly used as a destination.)

If two loop iterations extract overlapping slices, there would indeed be a race condition. But our tiling algorithm never generates such IR.

Long story short, it is possible to build racy IR with scf.forall. When we designed the scf.forall op, we tried to make race conditions impossible by construction; we didn't even have the shared_outs in the beginning. But proving that two different iterations of the loop are independent (which is needed during bufferization to decide whether a buffer must be privatized for each thread) was too difficult/brittle. So what we ended up with is a more explicit operation that gives the user more control, but that may have race conditions when used incorrectly.
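To make both points concrete, a minimal hand-written sketch of the legitimate pattern (illustrative names and static shapes, not taken from the patch): the bbarg %out is used in the body only to carve out the tile that serves as the matmul's destination, and each iteration writes a disjoint row block, so there is no race. If two iterations extracted/inserted overlapping blocks, the IR would be racy; the tiling algorithm does not generate such IR:

  func.func @tiled_matmul_sketch(%lhs: tensor<64x32xf32>, %rhs: tensor<32x16xf32>, %init: tensor<64x16xf32>) -> tensor<64x16xf32> {
    %0 = scf.forall (%i) in (8) shared_outs(%out = %init) -> (tensor<64x16xf32>) {
      %off = affine.apply affine_map<(d0) -> (d0 * 8)>(%i)
      %lhs_tile = tensor.extract_slice %lhs[%off, 0] [8, 32] [1, 1] : tensor<64x32xf32> to tensor<8x32xf32>
      // The bbarg is used in the body, but only to extract the destination tile.
      %out_tile = tensor.extract_slice %out[%off, 0] [8, 16] [1, 1] : tensor<64x16xf32> to tensor<8x16xf32>
      %mm = linalg.matmul ins(%lhs_tile, %rhs : tensor<8x32xf32>, tensor<32x16xf32>) outs(%out_tile : tensor<8x16xf32>) -> tensor<8x16xf32>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %mm into %out[%off, 0] [8, 16] [1, 1] : tensor<8x16xf32> into tensor<64x16xf32>
      }
    }
    return %0 : tensor<64x16xf32>
  }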

/// <SOME USE OF %arg1>
/// <SOME USE OF %arg2>
/// ...
/// scf.forall.in_parallel {
/// <STORE OP WITH DESTINATION %arg1>
/// <STORE OP WITH DESTINATION %arg0>
/// <STORE OP WITH DESTINATION %arg2>
/// }
/// }
/// return %res#1
///
/// OUTPUT:
/// %res = scf.forall ... shared_outs(%new_arg0 = %b)
/// {
/// ...
/// <SOME USE OF %a>
/// <SOME USE OF %new_arg0>
/// <SOME USE OF %c>
/// ...
/// scf.forall.in_parallel {
/// <STORE OP WITH DESTINATION %new_arg0>
/// }
/// }
/// return %res
///
/// NOTE: 1. All uses of the folded shared_outs (iter arguments) within the
/// scf.forall are replaced by their corresponding operands.
/// 2. Even if there are <STORE OP WITH DESTINATION *> ops within the body
/// of the scf.forall besides those within the scf.forall.in_parallel
/// terminator, this canonicalization remains valid. For more details,
/// please refer to:
/// https://github.com/llvm/llvm-project/pull/90189#discussion_r1589011124
/// 3. TODO(avarma): Generalize this for other store ops. Currently it
/// handles tensor.parallel_insert_slice ops only.
///
/// Example of second case :-
/// INPUT:
/// %res:2 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b)
/// {
/// ...
/// <SOME USE OF %arg0>
/// <SOME USE OF %arg1>
/// ...
/// scf.forall.in_parallel {
/// <STORE OP WITH DESTINATION %arg1>
/// }
/// }
/// return %res#0, %res#1
///
/// OUTPUT:
/// %res = scf.forall ... shared_outs(%new_arg0 = %b)
/// {
/// ...
/// <SOME USE OF %a>
/// <SOME USE OF %new_arg0>
/// ...
/// scf.forall.in_parallel {
/// <STORE OP WITH DESTINATION %new_arg0>
/// }
/// }
/// return %a, %res
struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
using OpRewritePattern<ForallOp>::OpRewritePattern;

LogicalResult matchAndRewrite(ForallOp forallOp,
PatternRewriter &rewriter) const final {
// Step 1: For a given i-th result of scf.forall, check the following :-
// a. If it has any use.
// b. If the corresponding iter argument is being modified within
// the loop, i.e. has at least one store op with the iter arg as
// its destination operand. For this we use
// ForallOp::getCombiningOps(iter_arg).
//
// Based on the check we maintain the following :-
// a. `resultToDelete` - i-th result of scf.forall that'll be
// deleted.
// b. `resultToReplace` - i-th result of the old scf.forall
// whose uses will be replaced by the new scf.forall.
// c. `newOuts` - the shared_outs' operand of the new scf.forall
// corresponding to the i-th result with at least one use.
SetVector<OpResult> resultToDelete;
SmallVector<Value> resultToReplace;
SmallVector<Value> newOuts;
for (OpResult result : forallOp.getResults()) {
OpOperand *opOperand = forallOp.getTiedOpOperand(result);
BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
resultToDelete.insert(result);
} else {
resultToReplace.push_back(result);
newOuts.push_back(opOperand->get());
}
}

// Return early if all results of scf.forall have at least one use and are
// being modified within the loop.
if (resultToDelete.empty())
return failure();

// Step 2: For the i-th result, do the following :-
// a. Fetch the corresponding BlockArgument.
// b. Look for store ops (currently tensor.parallel_insert_slice)
// with the BlockArgument as its destination operand.
// c. Remove the operations fetched in b.
for (OpResult result : resultToDelete) {
OpOperand *opOperand = forallOp.getTiedOpOperand(result);
BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
SmallVector<Operation *> combiningOps =
forallOp.getCombiningOps(blockArg);
for (Operation *combiningOp : combiningOps)
rewriter.eraseOp(combiningOp);
}

// Step 3. Create a new scf.forall op with the new shared_outs' operands
// fetched earlier.
auto newForallOp = rewriter.create<scf::ForallOp>(
forallOp.getLoc(), forallOp.getMixedLowerBound(),
forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
forallOp.getMapping(),
/*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {});

// Step 4. Merge the block of the old scf.forall into the newly created
// scf.forall using the new set of arguments.
Block *loopBody = forallOp.getBody();
Block *newLoopBody = newForallOp.getBody();
ArrayRef<BlockArgument> newBbArgs = newLoopBody->getArguments();
// Form initial new bbArg list with just the control operands of the new
// scf.forall op.
SmallVector<Value> newBlockArgs =
llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
[](BlockArgument b) -> Value { return b; });
Block::BlockArgListType newSharedOutsArgs = newForallOp.getRegionOutArgs();
unsigned index = 0;
// Take the new corresponding bbArg if the old bbArg was used as a
// destination in the in_parallel op. For all other bbArgs, use the
// corresponding init_arg from the old scf.forall op.
for (OpResult result : forallOp.getResults()) {
if (resultToDelete.count(result)) {
newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());
} else {
newBlockArgs.push_back(newSharedOutsArgs[index++]);
}
}
rewriter.mergeBlocks(loopBody, newLoopBody, newBlockArgs);

// Step 5. Replace the uses of the results of the old scf.forall with those
// of the new scf.forall.
for (auto &&[oldResult, newResult] :
llvm::zip(resultToReplace, newForallOp->getResults()))
rewriter.replaceAllUsesWith(oldResult, newResult);

// Step 6. Replace the uses of those results that either have no uses or are
// not being modified within the loop with the corresponding init
// operand.
for (OpResult oldResult : resultToDelete)
rewriter.replaceAllUsesWith(oldResult,
forallOp.getTiedOpOperand(oldResult)->get());
return success();
}
};

struct ForallOpSingleOrZeroIterationDimsFolder
: public OpRewritePattern<ForallOp> {
using OpRewritePattern<ForallOp>::OpRewritePattern;
@@ -1667,7 +1853,7 @@ struct FoldTensorCastOfOutputIntoForallOp
void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
ForallOpControlOperandsFolder,
ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
ForallOpSingleOrZeroIterationDimsFolder>(context);
}

81 changes: 81 additions & 0 deletions mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1735,6 +1735,87 @@ func.func @do_not_fold_tensor_cast_from_dynamic_to_static_type_into_forall(

// -----

#map = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
module {
func.func @fold_iter_args_not_being_modified_within_scfforall(%arg0: index, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
%c0 = arith.constant 0 : index
%cst = arith.constant 4.200000e+01 : f32
%0 = linalg.fill ins(%cst : f32) outs(%arg1 : tensor<?xf32>) -> tensor<?xf32>
%dim = tensor.dim %arg1, %c0 : tensor<?xf32>
%1 = affine.apply #map()[%dim, %arg0]
%2:2 = scf.forall (%arg3) in (%1) shared_outs(%arg4 = %arg1, %arg5 = %arg2) -> (tensor<?xf32>, tensor<?xf32>) {
%3 = affine.apply #map1(%arg3)[%arg0]
%4 = affine.min #map2(%arg3)[%dim, %arg0]
%extracted_slice0 = tensor.extract_slice %arg4[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
%extracted_slice1 = tensor.extract_slice %arg5[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
%5 = linalg.elemwise_unary ins(%extracted_slice0 : tensor<?xf32>) outs(%extracted_slice1 : tensor<?xf32>) -> tensor<?xf32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %5 into %arg5[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
}
}
return %2#0, %2#1 : tensor<?xf32>, tensor<?xf32>
}
}
// CHECK-LABEL: @fold_iter_args_not_being_modified_within_scfforall
// CHECK-SAME: (%{{.*}}: index, %[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
// CHECK: %[[RESULT:.*]] = scf.forall
// CHECK-SAME: shared_outs(%[[ITER_ARG_5:.*]] = %[[ARG2]]) -> (tensor<?xf32>) {
// CHECK: %[[OPERAND0:.*]] = tensor.extract_slice %[[ARG1]]
// CHECK: %[[OPERAND1:.*]] = tensor.extract_slice %[[ITER_ARG_5]]
// CHECK: %[[ELEM:.*]] = linalg.elemwise_unary ins(%[[OPERAND0]] : tensor<?xf32>) outs(%[[OPERAND1]] : tensor<?xf32>) -> tensor<?xf32>
// CHECK: scf.forall.in_parallel {
// CHECK-NEXT: tensor.parallel_insert_slice %[[ELEM]] into %[[ITER_ARG_5]]
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return %[[ARG1]], %[[RESULT]]

// -----

#map = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
module {
func.func @fold_iter_args_with_no_use_of_result_scfforall(%arg0: index, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>, %arg3: tensor<?xf32>) -> tensor<?xf32> {
%cst = arith.constant 4.200000e+01 : f32
%c0 = arith.constant 0 : index
%0 = linalg.fill ins(%cst : f32) outs(%arg1 : tensor<?xf32>) -> tensor<?xf32>
%dim = tensor.dim %arg1, %c0 : tensor<?xf32>
%1 = affine.apply #map()[%dim, %arg0]
%2:3 = scf.forall (%arg4) in (%1) shared_outs(%arg5 = %arg1, %arg6 = %arg2, %arg7 = %arg3) -> (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) {
%3 = affine.apply #map1(%arg4)[%arg0]
%4 = affine.min #map2(%arg4)[%dim, %arg0]
%extracted_slice = tensor.extract_slice %arg5[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
%extracted_slice_0 = tensor.extract_slice %arg6[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
%extracted_slice_1 = tensor.extract_slice %arg7[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
%extracted_slice_2 = tensor.extract_slice %0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
%5 = linalg.elemwise_unary ins(%extracted_slice : tensor<?xf32>) outs(%extracted_slice_1 : tensor<?xf32>) -> tensor<?xf32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %5 into %arg6[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
tensor.parallel_insert_slice %extracted_slice into %arg5[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
tensor.parallel_insert_slice %extracted_slice_0 into %arg7[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
tensor.parallel_insert_slice %5 into %arg7[%4] [%3] [1] : tensor<?xf32> into tensor<?xf32>
}
}
return %2#1 : tensor<?xf32>
}
}
// CHECK-LABEL: @fold_iter_args_with_no_use_of_result_scfforall
// CHECK-SAME: (%{{.*}}: index, %[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: tensor<?xf32>, %[[ARG3:.*]]: tensor<?xf32>) -> tensor<?xf32> {
// CHECK: %[[RESULT:.*]] = scf.forall
// CHECK-SAME: shared_outs(%[[ITER_ARG_6:.*]] = %[[ARG2]]) -> (tensor<?xf32>) {
// CHECK: %[[OPERAND0:.*]] = tensor.extract_slice %[[ARG1]]
// CHECK: %[[OPERAND1:.*]] = tensor.extract_slice %[[ARG3]]
// CHECK: %[[ELEM:.*]] = linalg.elemwise_unary ins(%[[OPERAND0]] : tensor<?xf32>) outs(%[[OPERAND1]] : tensor<?xf32>) -> tensor<?xf32>
// CHECK: scf.forall.in_parallel {
// CHECK-NEXT: tensor.parallel_insert_slice %[[ELEM]] into %[[ITER_ARG_6]]
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return %[[RESULT]]

// -----

func.func @index_switch_fold() -> (f32, f32) {
%switch_cst = arith.constant 1: index
%0 = scf.index_switch %switch_cst -> f32