Skip to content

Commit 7b70fc7

Browse files
authored
[mlir][OpenMP] Convert omp.cancel sections to LLVMIR (#137193)
This is quite ugly, but it is the best I could think of. The old FiniCBWrapper was far too brittle: it depended upon the exact block structure inside of the section, and could be confused by any control flow in the section (e.g. an if clause on cancel). The wording in the comment and the variable names also didn't seem to match where it was actually branching to. Clang's (non-OpenMPIRBuilder) lowering for cancel inside of sections branches to a block containing __kmpc_for_static_fini. This was hard to achieve here because sometimes the FiniCBWrapper has to run before the worksharing loop finalization has been created. To get around this ordering issue, I created a dummy branch to a dummy block, which is then fixed up later once all of the information is available.
1 parent d20796d commit 7b70fc7

File tree

5 files changed

+98
-30
lines changed

5 files changed

+98
-30
lines changed

clang/lib/CodeGen/CGStmtOpenMP.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -4345,8 +4345,9 @@ void CodeGenFunction::EmitOMPSectionsDirective(const OMPSectionsDirective &S) {
43454345
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
43464346
using BodyGenCallbackTy = llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;
43474347

4348-
auto FiniCB = [this](InsertPointTy IP) {
4349-
OMPBuilderCBHelpers::FinalizeOMPRegion(*this, IP);
4348+
auto FiniCB = [](InsertPointTy IP) {
4349+
// Don't FinalizeOMPRegion because this is done inside of OMPIRBuilder for
4350+
// sections.
43504351
return llvm::Error::success();
43514352
};
43524353

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

+15-10
Original file line numberDiff line numberDiff line change
@@ -2172,23 +2172,19 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createSections(
21722172
if (!updateToLocation(Loc))
21732173
return Loc.IP;
21742174

2175+
// FiniCBWrapper needs to create a branch to the loop finalization block, but
2176+
// this has not been created yet at some times when this callback runs.
2177+
SmallVector<BranchInst *> CancellationBranches;
21752178
auto FiniCBWrapper = [&](InsertPointTy IP) {
21762179
if (IP.getBlock()->end() != IP.getPoint())
21772180
return FiniCB(IP);
21782181
// This must be done otherwise any nested constructs using FinalizeOMPRegion
21792182
// will fail because that function requires the Finalization Basic Block to
21802183
// have a terminator, which is already removed by EmitOMPRegionBody.
21812184
// IP is currently at cancelation block.
2182-
// We need to backtrack to the condition block to fetch
2183-
// the exit block and create a branch from cancelation
2184-
// to exit block.
2185-
IRBuilder<>::InsertPointGuard IPG(Builder);
2186-
Builder.restoreIP(IP);
2187-
auto *CaseBB = IP.getBlock()->getSinglePredecessor();
2188-
auto *CondBB = CaseBB->getSinglePredecessor()->getSinglePredecessor();
2189-
auto *ExitBB = CondBB->getTerminator()->getSuccessor(1);
2190-
Instruction *I = Builder.CreateBr(ExitBB);
2191-
IP = InsertPointTy(I->getParent(), I->getIterator());
2185+
BranchInst *DummyBranch = Builder.CreateBr(IP.getBlock());
2186+
IP = InsertPointTy(DummyBranch->getParent(), DummyBranch->getIterator());
2187+
CancellationBranches.push_back(DummyBranch);
21922188
return FiniCB(IP);
21932189
};
21942190

@@ -2251,6 +2247,9 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createSections(
22512247
return WsloopIP.takeError();
22522248
InsertPointTy AfterIP = *WsloopIP;
22532249

2250+
BasicBlock *LoopFini = AfterIP.getBlock()->getSinglePredecessor();
2251+
assert(LoopFini && "Bad structure of static workshare loop finalization");
2252+
22542253
// Apply the finalization callback in LoopAfterBB
22552254
auto FiniInfo = FinalizationStack.pop_back_val();
22562255
assert(FiniInfo.DK == OMPD_sections &&
@@ -2264,6 +2263,12 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createSections(
22642263
AfterIP = {FiniBB, FiniBB->begin()};
22652264
}
22662265

2266+
// Now we can fix the dummy branch to point to the right place
2267+
for (BranchInst *DummyBranch : CancellationBranches) {
2268+
assert(DummyBranch->getNumSuccessors() == 1);
2269+
DummyBranch->setSuccessor(0, LoopFini);
2270+
}
2271+
22672272
return AfterIP;
22682273
}
22692274

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,8 @@ static LogicalResult checkImplementationStatus(Operation &op) {
161161
auto checkCancelDirective = [&todo](auto op, LogicalResult &result) {
162162
omp::ClauseCancellationConstructType cancelledDirective =
163163
op.getCancelDirective();
164-
if (cancelledDirective != omp::ClauseCancellationConstructType::Parallel)
164+
if (cancelledDirective != omp::ClauseCancellationConstructType::Parallel &&
165+
cancelledDirective != omp::ClauseCancellationConstructType::Sections)
165166
result = todo("cancel directive construct type not yet supported");
166167
};
167168
auto checkDepend = [&todo](auto op, LogicalResult &result) {
@@ -1688,10 +1689,11 @@ convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder,
16881689
auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
16891690

16901691
allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
1692+
bool isCancellable = constructIsCancellable(sectionsOp);
16911693
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
16921694
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
16931695
moduleTranslation.getOpenMPBuilder()->createSections(
1694-
ompLoc, allocaIP, sectionCBs, privCB, finiCB, false,
1696+
ompLoc, allocaIP, sectionCBs, privCB, finiCB, isCancellable,
16951697
sectionsOp.getNowait());
16961698

16971699
if (failed(handleError(afterIP, opInst)))

mlir/test/Target/LLVMIR/openmp-cancel.mlir

+76
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,79 @@ llvm.func @cancel_parallel_if(%arg0 : i1) {
8080
// CHECK: br label %[[VAL_23]]
8181
// CHECK: omp.par.exit.exitStub: ; preds = %[[VAL_31]], %[[VAL_26]]
8282
// CHECK: ret void
83+
84+
llvm.func @cancel_sections_if(%cond : i1) {
85+
omp.sections {
86+
omp.section {
87+
omp.cancel cancellation_construct_type(sections) if(%cond)
88+
omp.terminator
89+
}
90+
omp.terminator
91+
}
92+
llvm.return
93+
}
94+
// CHECK-LABEL: define void @cancel_sections_if
95+
// CHECK: %[[VAL_0:.*]] = alloca i32, align 4
96+
// CHECK: %[[VAL_1:.*]] = alloca i32, align 4
97+
// CHECK: %[[VAL_2:.*]] = alloca i32, align 4
98+
// CHECK: %[[VAL_3:.*]] = alloca i32, align 4
99+
// CHECK: br label %[[VAL_4:.*]]
100+
// CHECK: entry: ; preds = %[[VAL_5:.*]]
101+
// CHECK: br label %[[VAL_6:.*]]
102+
// CHECK: omp_section_loop.preheader: ; preds = %[[VAL_4]]
103+
// CHECK: store i32 0, ptr %[[VAL_1]], align 4
104+
// CHECK: store i32 0, ptr %[[VAL_2]], align 4
105+
// CHECK: store i32 1, ptr %[[VAL_3]], align 4
106+
// CHECK: %[[VAL_7:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
107+
// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %[[VAL_7]], i32 34, ptr %[[VAL_0]], ptr %[[VAL_1]], ptr %[[VAL_2]], ptr %[[VAL_3]], i32 1, i32 0)
108+
// CHECK: %[[VAL_8:.*]] = load i32, ptr %[[VAL_1]], align 4
109+
// CHECK: %[[VAL_9:.*]] = load i32, ptr %[[VAL_2]], align 4
110+
// CHECK: %[[VAL_10:.*]] = sub i32 %[[VAL_9]], %[[VAL_8]]
111+
// CHECK: %[[VAL_11:.*]] = add i32 %[[VAL_10]], 1
112+
// CHECK: br label %[[VAL_12:.*]]
113+
// CHECK: omp_section_loop.header: ; preds = %[[VAL_13:.*]], %[[VAL_6]]
114+
// CHECK: %[[VAL_14:.*]] = phi i32 [ 0, %[[VAL_6]] ], [ %[[VAL_15:.*]], %[[VAL_13]] ]
115+
// CHECK: br label %[[VAL_16:.*]]
116+
// CHECK: omp_section_loop.cond: ; preds = %[[VAL_12]]
117+
// CHECK: %[[VAL_17:.*]] = icmp ult i32 %[[VAL_14]], %[[VAL_11]]
118+
// CHECK: br i1 %[[VAL_17]], label %[[VAL_18:.*]], label %[[VAL_19:.*]]
119+
// CHECK: omp_section_loop.body: ; preds = %[[VAL_16]]
120+
// CHECK: %[[VAL_20:.*]] = add i32 %[[VAL_14]], %[[VAL_8]]
121+
// CHECK: %[[VAL_21:.*]] = mul i32 %[[VAL_20]], 1
122+
// CHECK: %[[VAL_22:.*]] = add i32 %[[VAL_21]], 0
123+
// CHECK: switch i32 %[[VAL_22]], label %[[VAL_23:.*]] [
124+
// CHECK: i32 0, label %[[VAL_24:.*]]
125+
// CHECK: ]
126+
// CHECK: omp_section_loop.body.case: ; preds = %[[VAL_18]]
127+
// CHECK: br label %[[VAL_25:.*]]
128+
// CHECK: omp.section.region: ; preds = %[[VAL_24]]
129+
// CHECK: br i1 %[[VAL_26:.*]], label %[[VAL_27:.*]], label %[[VAL_28:.*]]
130+
// CHECK: 9: ; preds = %[[VAL_25]]
131+
// CHECK: %[[VAL_29:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
132+
// CHECK: %[[VAL_30:.*]] = call i32 @__kmpc_cancel(ptr @1, i32 %[[VAL_29]], i32 3)
133+
// CHECK: %[[VAL_31:.*]] = icmp eq i32 %[[VAL_30]], 0
134+
// CHECK: br i1 %[[VAL_31]], label %[[VAL_32:.*]], label %[[VAL_33:.*]]
135+
// CHECK: .split: ; preds = %[[VAL_27]]
136+
// CHECK: br label %[[VAL_34:.*]]
137+
// CHECK: 12: ; preds = %[[VAL_25]]
138+
// CHECK: br label %[[VAL_34]]
139+
// CHECK: 13: ; preds = %[[VAL_28]], %[[VAL_32]]
140+
// CHECK: br label %[[VAL_35:.*]]
141+
// CHECK: omp.region.cont: ; preds = %[[VAL_34]]
142+
// CHECK: br label %[[VAL_23]]
143+
// CHECK: omp_section_loop.body.sections.after: ; preds = %[[VAL_35]], %[[VAL_18]]
144+
// CHECK: br label %[[VAL_13]]
145+
// CHECK: omp_section_loop.inc: ; preds = %[[VAL_23]]
146+
// CHECK: %[[VAL_15]] = add nuw i32 %[[VAL_14]], 1
147+
// CHECK: br label %[[VAL_12]]
148+
// CHECK: omp_section_loop.exit: ; preds = %[[VAL_33]], %[[VAL_16]]
149+
// CHECK: call void @__kmpc_for_static_fini(ptr @1, i32 %[[VAL_7]])
150+
// CHECK: %[[VAL_36:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
151+
// CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_36]])
152+
// CHECK: br label %[[VAL_37:.*]]
153+
// CHECK: omp_section_loop.after: ; preds = %[[VAL_19]]
154+
// CHECK: br label %[[VAL_38:.*]]
155+
// CHECK: omp_section_loop.aftersections.fini: ; preds = %[[VAL_37]]
156+
// CHECK: ret void
157+
// CHECK: .cncl: ; preds = %[[VAL_27]]
158+
// CHECK: br label %[[VAL_19]]

mlir/test/Target/LLVMIR/openmp-todo.mlir

-16
Original file line numberDiff line numberDiff line change
@@ -42,22 +42,6 @@ llvm.func @cancel_wsloop(%lb : i32, %ub : i32, %step: i32) {
4242

4343
// -----
4444

45-
llvm.func @cancel_sections() {
46-
// expected-error@below {{LLVM Translation failed for operation: omp.sections}}
47-
omp.sections {
48-
omp.section {
49-
// expected-error@below {{not yet implemented: Unhandled clause cancel directive construct type not yet supported in omp.cancel operation}}
50-
// expected-error@below {{LLVM Translation failed for operation: omp.cancel}}
51-
omp.cancel cancellation_construct_type(sections)
52-
omp.terminator
53-
}
54-
omp.terminator
55-
}
56-
llvm.return
57-
}
58-
59-
// -----
60-
6145
llvm.func @cancel_taskgroup() {
6246
// expected-error@below {{LLVM Translation failed for operation: omp.taskgroup}}
6347
omp.taskgroup {

0 commit comments

Comments
 (0)