Skip to content

Commit 0b665c3

Browse files
sabaumaSpenser Bauman
and
Spenser Bauman
authored
[mlir][scf] Implement conversion from scf.forall to scf.parallel (#94109)
There is currently no path to lower scf.forall to scf.parallel with the goal of targeting the OpenMP dialect. In the SCF->ControlFlow conversion, scf.forall is briefly converted to scf.parallel, but the scf.parallel is lowered directly to a sequential loop. This makes experimenting with scf.forall for CPU execution difficult. This change factors out the rewrite in the SCF->ControlFlow pass into a utility function that can then be used in the SCF->ControlFlow lowering and via a separate -scf-forall-to-parallel pass. --------- Co-authored-by: Spenser Bauman <sabauma@fastmail>
1 parent e775efc commit 0b665c3

File tree

11 files changed

+313
-27
lines changed

11 files changed

+313
-27
lines changed

mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,32 @@ def ForallToForOp : Op<Transform_Dialect, "loop.forall_to_for",
6868
let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
6969
}
7070

71+
def ForallToParallelOp : Op<Transform_Dialect, "loop.forall_to_parallel",
72+
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
73+
DeclareOpInterfaceMethods<TransformOpInterface>]> {
74+
let summary = "Converts scf.forall into an scf.parallel operation";
75+
let description = [{
76+
Converts the `scf.forall` operation pointed to by the given handle into an
77+
`scf.parallel` operation.
78+
79+
The operand handle must be associated with exactly one payload operation.
80+
81+
Loops with outputs are not supported.
82+
83+
#### Return Modes
84+
85+
Consumes the operand handle. Produces a silenceable failure if the operand
86+
is not associated with a single `scf.forall` payload operation.
87+
Returns a handle to the new `scf.parallel` operation.
88+
Produces a silenceable failure if another number of resulting handles is
89+
requested.
90+
}];
91+
let arguments = (ins TransformHandleTypeInterface:$target);
92+
let results = (outs Variadic<TransformHandleTypeInterface>:$transformed);
93+
94+
let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
95+
}
96+
7197
def LoopOutlineOp : Op<Transform_Dialect, "loop.outline",
7298
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
7399
DeclareOpInterfaceMethods<TransformOpInterface>]> {

mlir/include/mlir/Dialect/SCF/Transforms/Passes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ std::unique_ptr<Pass> createForLoopRangeFoldingPass();
6262
/// Creates a pass that converts SCF forall loops to SCF for loops.
6363
std::unique_ptr<Pass> createForallToForLoopPass();
6464

65+
/// Creates a pass that converts SCF forall loops to SCF parallel loops.
66+
std::unique_ptr<Pass> createForallToParallelLoopPass();
67+
6568
// Creates a pass which lowers for loops into while loops.
6669
std::unique_ptr<Pass> createForToWhileLoopPass();
6770

mlir/include/mlir/Dialect/SCF/Transforms/Passes.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,11 @@ def SCFForallToForLoop : Pass<"scf-forall-to-for"> {
125125
let constructor = "mlir::createForallToForLoopPass()";
126126
}
127127

128+
def SCFForallToParallelLoop : Pass<"scf-forall-to-parallel"> {
129+
let summary = "Convert SCF forall loops to SCF parallel loops";
130+
let constructor = "mlir::createForallToParallelLoopPass()";
131+
}
132+
128133
def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
129134
let summary = "Convert SCF for loops to SCF while loops";
130135
let constructor = "mlir::createForToWhileLoopPass()";

mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ class WhileOp;
3939
LogicalResult forallToForLoop(RewriterBase &rewriter, ForallOp forallOp,
4040
SmallVectorImpl<Operation *> *results = nullptr);
4141

42+
/// Try converting scf.forall into an scf.parallel loop.
43+
/// The conversion is only supported for forall operations with no results.
44+
LogicalResult forallToParallelLoop(RewriterBase &rewriter, ForallOp forallOp,
45+
ParallelOp *result = nullptr);
46+
4247
/// Fuses all adjacent scf.parallel operations with identical bounds and step
4348
/// into one scf.parallel operations. Uses a naive aliasing and dependency
4449
/// analysis.

mlir/lib/Conversion/SCFToControlFlow/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,6 @@ add_mlir_conversion_library(MLIRSCFToControlFlow
1414
MLIRArithDialect
1515
MLIRControlFlowDialect
1616
MLIRSCFDialect
17+
MLIRSCFTransforms
1718
MLIRTransforms
1819
)

mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp

Lines changed: 2 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/Dialect/Arith/IR/Arith.h"
1717
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
1818
#include "mlir/Dialect/SCF/IR/SCF.h"
19+
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
1920
#include "mlir/IR/Builders.h"
2021
#include "mlir/IR/BuiltinOps.h"
2122
#include "mlir/IR/IRMapping.h"
@@ -688,33 +689,7 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
688689

689690
LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
690691
PatternRewriter &rewriter) const {
691-
Location loc = forallOp.getLoc();
692-
if (!forallOp.getOutputs().empty())
693-
return rewriter.notifyMatchFailure(
694-
forallOp,
695-
"only fully bufferized scf.forall ops can be lowered to scf.parallel");
696-
697-
// Convert mixed bounds and steps to SSA values.
698-
SmallVector<Value> lbs = getValueOrCreateConstantIndexOp(
699-
rewriter, loc, forallOp.getMixedLowerBound());
700-
SmallVector<Value> ubs = getValueOrCreateConstantIndexOp(
701-
rewriter, loc, forallOp.getMixedUpperBound());
702-
SmallVector<Value> steps =
703-
getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep());
704-
705-
// Create empty scf.parallel op.
706-
auto parallelOp = rewriter.create<ParallelOp>(loc, lbs, ubs, steps);
707-
rewriter.eraseBlock(&parallelOp.getRegion().front());
708-
rewriter.inlineRegionBefore(forallOp.getRegion(), parallelOp.getRegion(),
709-
parallelOp.getRegion().begin());
710-
// Replace the terminator.
711-
rewriter.setInsertionPointToEnd(&parallelOp.getRegion().front());
712-
rewriter.replaceOpWithNewOp<scf::ReduceOp>(
713-
parallelOp.getRegion().front().getTerminator());
714-
715-
// Erase the scf.forall op.
716-
rewriter.replaceOp(forallOp, parallelOp);
717-
return success();
692+
return scf::forallToParallelLoop(rewriter, forallOp);
718693
}
719694

720695
void mlir::populateSCFToControlFlowConversionPatterns(

mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,50 @@ transform::ForallToForOp::apply(transform::TransformRewriter &rewriter,
9898
return DiagnosedSilenceableFailure::success();
9999
}
100100

101+
//===----------------------------------------------------------------------===//
102+
// ForallToParallelOp
103+
//===----------------------------------------------------------------------===//
104+
105+
DiagnosedSilenceableFailure
106+
transform::ForallToParallelOp::apply(transform::TransformRewriter &rewriter,
107+
transform::TransformResults &results,
108+
transform::TransformState &state) {
109+
auto payload = state.getPayloadOps(getTarget());
110+
if (!llvm::hasSingleElement(payload))
111+
return emitSilenceableError() << "expected a single payload op";
112+
113+
auto target = dyn_cast<scf::ForallOp>(*payload.begin());
114+
if (!target) {
115+
DiagnosedSilenceableFailure diag =
116+
emitSilenceableError() << "expected the payload to be scf.forall";
117+
diag.attachNote((*payload.begin())->getLoc()) << "payload op";
118+
return diag;
119+
}
120+
121+
if (!target.getOutputs().empty()) {
122+
return emitSilenceableError()
123+
<< "unsupported shared outputs (didn't bufferize?)";
124+
}
125+
126+
if (getNumResults() != 1) {
127+
DiagnosedSilenceableFailure diag = emitSilenceableError()
128+
<< "op expects one result, given "
129+
<< getNumResults();
130+
diag.attachNote(target.getLoc()) << "payload op";
131+
return diag;
132+
}
133+
134+
scf::ParallelOp opResult;
135+
if (failed(scf::forallToParallelLoop(rewriter, target, &opResult))) {
136+
DiagnosedSilenceableFailure diag =
137+
emitSilenceableError() << "failed to convert forall into parallel";
138+
return diag;
139+
}
140+
141+
results.set(cast<OpResult>(getTransformed()[0]), {opResult});
142+
return DiagnosedSilenceableFailure::success();
143+
}
144+
101145
//===----------------------------------------------------------------------===//
102146
// LoopOutlineOp
103147
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
33
BufferizableOpInterfaceImpl.cpp
44
Bufferize.cpp
55
ForallToFor.cpp
6+
ForallToParallel.cpp
67
ForToWhile.cpp
78
LoopCanonicalization.cpp
89
LoopPipelining.cpp
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
//===- ForallToParallel.cpp - scf.forall to scf.parallel loop conversion --===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Transforms SCF.ForallOp's into SCF.ParallelOp's.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Dialect/SCF/IR/SCF.h"
14+
#include "mlir/Dialect/SCF/Transforms/Passes.h"
15+
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
16+
#include "mlir/IR/PatternMatch.h"
17+
18+
namespace mlir {
19+
#define GEN_PASS_DEF_SCFFORALLTOPARALLELLOOP
20+
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
21+
} // namespace mlir
22+
23+
using namespace mlir;
24+
25+
LogicalResult mlir::scf::forallToParallelLoop(RewriterBase &rewriter,
26+
scf::ForallOp forallOp,
27+
scf::ParallelOp *result) {
28+
OpBuilder::InsertionGuard guard(rewriter);
29+
rewriter.setInsertionPoint(forallOp);
30+
31+
Location loc = forallOp.getLoc();
32+
if (!forallOp.getOutputs().empty())
33+
return rewriter.notifyMatchFailure(
34+
forallOp,
35+
"only fully bufferized scf.forall ops can be lowered to scf.parallel");
36+
37+
// Convert mixed bounds and steps to SSA values.
38+
SmallVector<Value> lbs = getValueOrCreateConstantIndexOp(
39+
rewriter, loc, forallOp.getMixedLowerBound());
40+
SmallVector<Value> ubs = getValueOrCreateConstantIndexOp(
41+
rewriter, loc, forallOp.getMixedUpperBound());
42+
SmallVector<Value> steps =
43+
getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep());
44+
45+
// Create empty scf.parallel op.
46+
auto parallelOp = rewriter.create<scf::ParallelOp>(loc, lbs, ubs, steps);
47+
rewriter.eraseBlock(&parallelOp.getRegion().front());
48+
rewriter.inlineRegionBefore(forallOp.getRegion(), parallelOp.getRegion(),
49+
parallelOp.getRegion().begin());
50+
// Replace the terminator.
51+
rewriter.setInsertionPointToEnd(&parallelOp.getRegion().front());
52+
rewriter.replaceOpWithNewOp<scf::ReduceOp>(
53+
parallelOp.getRegion().front().getTerminator());
54+
55+
// If the mapping attribute is present, propagate to the new parallelOp.
56+
if (forallOp.getMapping())
57+
parallelOp->setAttr("mapping", *forallOp.getMapping());
58+
59+
// Replace the scf.forall op with the new scf.parallel op.
60+
rewriter.replaceOp(forallOp, parallelOp);
61+
62+
if (result)
63+
*result = parallelOp;
64+
65+
return success();
66+
}
67+
68+
namespace {
69+
struct ForallToParallelLoop final
70+
: public impl::SCFForallToParallelLoopBase<ForallToParallelLoop> {
71+
void runOnOperation() override {
72+
Operation *parentOp = getOperation();
73+
IRRewriter rewriter(parentOp->getContext());
74+
75+
parentOp->walk([&](scf::ForallOp forallOp) {
76+
if (failed(scf::forallToParallelLoop(rewriter, forallOp))) {
77+
return signalPassFailure();
78+
}
79+
});
80+
}
81+
};
82+
} // namespace
83+
84+
std::unique_ptr<Pass> mlir::createForallToParallelLoopPass() {
85+
return std::make_unique<ForallToParallelLoop>();
86+
}
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-forall-to-parallel))' -split-input-file | FileCheck %s
2+
3+
func.func private @callee(%i: index, %j: index)
4+
5+
// CHECK-LABEL: @two_iters
6+
// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
7+
func.func @two_iters(%ub1: index, %ub2: index) {
8+
scf.forall (%i, %j) in (%ub1, %ub2) {
9+
func.call @callee(%i, %j) : (index, index) -> ()
10+
}
11+
12+
// CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]])
13+
// CHECK: func.call @callee(%[[IV1]], %[[IV2]]) : (index, index) -> ()
14+
// CHECK: scf.reduce
15+
return
16+
}
17+
18+
// -----
19+
20+
func.func private @callee(%i: index, %j: index)
21+
22+
// CHECK-LABEL: @repeated
23+
// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
24+
func.func @repeated(%ub1: index, %ub2: index) {
25+
scf.forall (%i, %j) in (%ub1, %ub2) {
26+
func.call @callee(%i, %j) : (index, index) -> ()
27+
}
28+
29+
// CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]])
30+
// CHECK: func.call @callee(%[[IV1]], %[[IV2]]) : (index, index) -> ()
31+
// CHECK: scf.reduce
32+
scf.forall (%i, %j) in (%ub1, %ub2) {
33+
func.call @callee(%i, %j) : (index, index) -> ()
34+
}
35+
36+
// CHECK: scf.parallel (%[[IV3:.+]], %[[IV4:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]])
37+
// CHECK: func.call @callee(%[[IV3]], %[[IV4]])
38+
// CHECK: scf.reduce
39+
return
40+
}
41+
42+
// -----
43+
44+
func.func private @callee(%i: index, %j: index, %k: index, %l: index)
45+
46+
// CHECK-LABEL: @nested
47+
// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index, %[[UB3:.+]]: index, %[[UB4:.+]]: index
48+
func.func @nested(%ub1: index, %ub2: index, %ub3: index, %ub4: index) {
49+
// CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]]) step (%{{.*}}, %{{.*}}) {
50+
// CHECK: scf.parallel (%[[IV3:.+]], %[[IV4:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB3]], %[[UB4]]) step (%{{.*}}, %{{.*}}) {
51+
// CHECK: func.call @callee(%[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]])
52+
// CHECK: scf.reduce
53+
// CHECK: }
54+
// CHECK: scf.reduce
55+
// CHECK: }
56+
scf.forall (%i, %j) in (%ub1, %ub2) {
57+
scf.forall (%k, %l) in (%ub3, %ub4) {
58+
func.call @callee(%i, %j, %k, %l) : (index, index, index, index) -> ()
59+
}
60+
}
61+
return
62+
}
63+
64+
// -----
65+
66+
// CHECK-LABEL: @mapping_attr
67+
func.func @mapping_attr() -> () {
68+
// CHECK: scf.parallel
69+
// CHECK: scf.reduce
70+
// CHECK: {mapping = [#gpu.thread<x>]}
71+
72+
%num_threads = arith.constant 100 : index
73+
74+
scf.forall (%thread_idx) in (%num_threads) {
75+
scf.forall.in_parallel {
76+
}
77+
} {mapping = [#gpu.thread<x>]}
78+
return
79+
80+
}

0 commit comments

Comments
 (0)