Commit 973cb2c
[MLIR][OMP] Ensure nested scf.parallel execute all iterations
Presently, the lowering of nested scf.parallel loops to OpenMP creates one omp.parallel region with two (nested) OpenMP worksharing loops inside it. When lowered to LLVM and executed, this produces incorrect results. The reason is as follows: an OpenMP parallel region runs its code with whatever number of threads is available to OpenMP. Within a parallel region, a worksharing loop divides the total number of requested iterations by the available number of threads and distributes them accordingly. For a single worksharing loop in a parallel region, this works as intended. Now consider nested worksharing loops as follows:

    omp.parallel {
      A: omp.ws %i = 0...10 {
        B: omp.ws %j = 0...10 {
          code(%i, %j)
        }
      }
    }

Suppose we ran this on two threads. The first workshare loop would decide to execute iterations 0, 1, 2, 3, 4 on thread 0, and iterations 5, 6, 7, 8, 9 on thread 1. The second workshare loop would decide the same for its iterations. Thread 0 would therefore execute i in [0, 5) and j in [0, 5), while thread 1 would execute i in [5, 10) and j in [5, 10). This means that the iterations with i in [5, 10), j in [0, 5) and with i in [0, 5), j in [5, 10) never get executed, which is clearly wrong.

This leaves two options for a remedy:

1) Change the semantics of omp.wsloop to be distinct from those of the OpenMP runtime call (or, equivalently, #pragma omp for). A later lowering transformation could then fix the issue, but this is undesirable from an abstraction standpoint.

2) When lowering an scf.parallel, always surround the wsloop with a new parallel region, so that the innermost wsloop distributes its iterations over only the threads available to it.

This PR implements the latter change.

Reviewed By: jdoerfert

Differential Revision: https://reviews.llvm.org/D108426
1 parent 40aab04 commit 973cb2c
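The dropped iterations described in the commit message can be reproduced without OpenMP at all. The following sketch (plain Python; `static_chunks` and the two-thread setup are illustrative assumptions that model a default static schedule, not the OpenMP runtime's exact algorithm) shows that when both nested worksharing loops partition over the *same* team, half the iteration space is silently skipped:

```python
def static_chunks(n, nthreads):
    # Split range(n) into contiguous per-thread chunks, mimicking a
    # static worksharing schedule.
    base, rem = divmod(n, nthreads)
    chunks, start = [], 0
    for t in range(nthreads):
        size = base + (1 if t < rem else 0)
        chunks.append(range(start, start + size))
        start += size
    return chunks

def nested_wsloops_one_region(n, nthreads):
    # Both worksharing loops share the SAME team, so each thread only
    # combines its own chunk of i with its own chunk of j.
    executed = set()
    for t in range(nthreads):
        for i in static_chunks(n, nthreads)[t]:
            for j in static_chunks(n, nthreads)[t]:
                executed.add((i, j))
    return executed

done = nested_wsloops_one_region(10, 2)
missing = {(i, j) for i in range(10) for j in range(10)} - done
print(len(done), len(missing))  # 50 50: half the 100 iterations never run
```

For example, (i=0, j=7) belongs to thread 0's i-chunk but thread 1's j-chunk, so neither thread ever executes it.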

File tree

2 files changed: +6 −30 lines


mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp (+5 −29)
```diff
@@ -44,44 +44,21 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
     }

     // Replace the loop.
+    auto omp = rewriter.create<omp::ParallelOp>(parallelOp.getLoc());
+    Block *block = rewriter.createBlock(&omp.getRegion());
+    rewriter.setInsertionPointToStart(block);
     auto loop = rewriter.create<omp::WsLoopOp>(
         parallelOp.getLoc(), parallelOp.lowerBound(), parallelOp.upperBound(),
         parallelOp.step());
     rewriter.inlineRegionBefore(parallelOp.region(), loop.region(),
                                 loop.region().begin());
+    rewriter.create<omp::TerminatorOp>(parallelOp.getLoc());

     rewriter.eraseOp(parallelOp);
     return success();
   }
 };

-/// Inserts OpenMP "parallel" operations around top-level SCF "parallel"
-/// operations in the given function. This is implemented as a direct IR
-/// modification rather than as a conversion pattern because it does not
-/// modify the top-level operation it matches, which is a requirement for
-/// rewrite patterns.
-//
-// TODO: consider creating nested parallel operations when necessary.
-static void insertOpenMPParallel(FuncOp func) {
-  // Collect top-level SCF "parallel" ops.
-  SmallVector<scf::ParallelOp, 4> topLevelParallelOps;
-  func.walk([&topLevelParallelOps](scf::ParallelOp parallelOp) {
-    // Ignore ops that are already within OpenMP parallel construct.
-    if (!parallelOp->getParentOfType<scf::ParallelOp>())
-      topLevelParallelOps.push_back(parallelOp);
-  });
-
-  // Wrap SCF ops into OpenMP "parallel" ops.
-  for (scf::ParallelOp parallelOp : topLevelParallelOps) {
-    OpBuilder builder(parallelOp);
-    auto omp = builder.create<omp::ParallelOp>(parallelOp.getLoc());
-    Block *block = builder.createBlock(&omp.getRegion());
-    builder.create<omp::TerminatorOp>(parallelOp.getLoc());
-    block->getOperations().splice(block->begin(),
-                                  parallelOp->getBlock()->getOperations(),
-                                  parallelOp.getOperation());
-  }
-}
-
 /// Applies the conversion patterns in the given function.
 static LogicalResult applyPatterns(FuncOp func) {
   ConversionTarget target(*func.getContext());
@@ -100,7 +77,6 @@ static LogicalResult applyPatterns(FuncOp func) {
 struct SCFToOpenMPPass : public ConvertSCFToOpenMPBase<SCFToOpenMPPass> {
   /// Pass entry point.
   void runOnFunction() override {
-    insertOpenMPParallel(getFunction());
     if (failed(applyPatterns(getFunction())))
       signalPassFailure();
   }
```

mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir (+1 −1)

```diff
@@ -21,8 +21,8 @@ func @nested_loops(%arg0: index, %arg1: index, %arg2: index,
                    %arg3: index, %arg4: index, %arg5: index) {
   // CHECK: omp.parallel {
   // CHECK: omp.wsloop (%[[LVAR_OUT1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
-  // CHECK-NOT: omp.parallel
   scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) {
+    // CHECK: omp.parallel
     // CHECK: omp.wsloop (%[[LVAR_IN1:.*]]) : index = (%arg1) to (%arg3) step (%arg5) {
     scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) {
       // CHECK: "test.payload"(%[[LVAR_OUT1]], %[[LVAR_IN1]]) : (index, index) -> ()
```
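The updated test expects an omp.parallel nested inside the outer wsloop. To see why that restores full coverage, here is a hedged sketch (plain Python, not MLIR; the chunking helper and thread counts are illustrative assumptions): once the inner worksharing loop sits in its own parallel region, its iterations are distributed over that region's fresh team, so every (i, j) pair is reached no matter how the outer loop was partitioned.

```python
def static_chunks(n, nthreads):
    # Contiguous per-thread chunks, mimicking a static schedule.
    base, rem = divmod(n, nthreads)
    chunks, start = [], 0
    for t in range(nthreads):
        size = base + (1 if t < rem else 0)
        chunks.append(range(start, start + size))
        start += size
    return chunks

def nested_with_inner_parallel_region(n, outer_threads):
    # The inner wsloop now lives in its OWN parallel region, so its
    # team collectively covers all of range(n) for every outer
    # iteration, regardless of how i was partitioned.
    executed = set()
    for t in range(outer_threads):
        for i in static_chunks(n, outer_threads)[t]:
            for j in range(n):  # inner region covers the full j range
                executed.add((i, j))
    return executed

done = nested_with_inner_parallel_region(10, 2)
print(len(done))  # 100: every (i, j) pair executes exactly once
```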
