Skip to content

Commit bb1af01

Browse files
committed
[mlir][OpenMP] convert wsloop cancellation to LLVMIR
Taskloop support will follow in a later patch.
1 parent 4d28bd8 commit bb1af01

File tree

3 files changed

+125
-18
lines changed

3 files changed

+125
-18
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

+38-2
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
161161
auto checkCancelDirective = [&todo](auto op, LogicalResult &result) {
162162
omp::ClauseCancellationConstructType cancelledDirective =
163163
op.getCancelDirective();
164-
if (cancelledDirective != omp::ClauseCancellationConstructType::Parallel &&
165-
cancelledDirective != omp::ClauseCancellationConstructType::Sections)
164+
if (cancelledDirective == omp::ClauseCancellationConstructType::Taskgroup)
166165
result = todo("cancel directive construct type not yet supported");
167166
};
168167
auto checkDepend = [&todo](auto op, LogicalResult &result) {
@@ -2356,6 +2355,30 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
23562355
? llvm::omp::WorksharingLoopType::DistributeForStaticLoop
23572356
: llvm::omp::WorksharingLoopType::ForStaticLoop;
23582357

2358+
SmallVector<llvm::BranchInst *> cancelTerminators;
2359+
// This callback is invoked only if there is cancellation inside of the wsloop
2360+
// body.
2361+
auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
2362+
llvm::IRBuilderBase &llvmBuilder = ompBuilder->Builder;
2363+
llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
2364+
2365+
// ip is currently in the block branched to if cancellation occurred.
2366+
// We need to create a branch to terminate that block.
2367+
llvmBuilder.restoreIP(ip);
2368+
2369+
// We must still clean up the wsloop after cancelling it, so we need to
2370+
// branch to the block that finalizes the wsloop.
2371+
// That block has not been created yet, so use this block as a dummy for now
2372+
// and fix this after creating the wsloop.
2373+
cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
2374+
return llvm::Error::success();
2375+
};
2376+
// We have to add the cleanup to the OpenMPIRBuilder before the body gets
2377+
// created in case the body contains omp.cancel (which will then expect to be
2378+
// able to find this cleanup callback).
2379+
ompBuilder->pushFinalizationCB({finiCB, llvm::omp::Directive::OMPD_for,
2380+
constructIsCancellable(wsloopOp)});
2381+
23592382
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
23602383
llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
23612384
wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);
@@ -2377,6 +2400,19 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
23772400
if (failed(handleError(wsloopIP, opInst)))
23782401
return failure();
23792402

2403+
ompBuilder->popFinalizationCB();
2404+
if (!cancelTerminators.empty()) {
2405+
// If we cancelled the loop, we should branch to the finalization block of
2406+
// the wsloop (which is always immediately before the loop continuation
2407+
// block). Now that the finalization has been created, we can fix the branch.
2408+
llvm::BasicBlock *wsloopFini = wsloopIP->getBlock()->getSinglePredecessor();
2409+
for (llvm::BranchInst *cancelBranch : cancelTerminators) {
2410+
assert(cancelBranch->getNumSuccessors() == 1 &&
2411+
"cancel branch should have one target");
2412+
cancelBranch->setSuccessor(0, wsloopFini);
2413+
}
2414+
}
2415+
23802416
// Process the reductions if required.
23812417
if (failed(createReductionsAndCleanup(
23822418
wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,

mlir/test/Target/LLVMIR/openmp-cancel.mlir

+87
Original file line numberDiff line numberDiff line change
@@ -156,3 +156,90 @@ llvm.func @cancel_sections_if(%cond : i1) {
156156
// CHECK: ret void
157157
// CHECK: .cncl: ; preds = %[[VAL_27]]
158158
// CHECK: br label %[[VAL_19]]
159+
160+
llvm.func @cancel_wsloop_if(%lb : i32, %ub : i32, %step : i32, %cond : i1) {
161+
omp.wsloop {
162+
omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
163+
omp.cancel cancellation_construct_type(loop) if(%cond)
164+
omp.yield
165+
}
166+
}
167+
llvm.return
168+
}
169+
// CHECK-LABEL: define void @cancel_wsloop_if
170+
// CHECK: %[[VAL_0:.*]] = alloca i32, align 4
171+
// CHECK: %[[VAL_1:.*]] = alloca i32, align 4
172+
// CHECK: %[[VAL_2:.*]] = alloca i32, align 4
173+
// CHECK: %[[VAL_3:.*]] = alloca i32, align 4
174+
// CHECK: br label %[[VAL_4:.*]]
175+
// CHECK: omp.region.after_alloca: ; preds = %[[VAL_5:.*]]
176+
// CHECK: br label %[[VAL_6:.*]]
177+
// CHECK: entry: ; preds = %[[VAL_4]]
178+
// CHECK: br label %[[VAL_7:.*]]
179+
// CHECK: omp.wsloop.region: ; preds = %[[VAL_6]]
180+
// CHECK: %[[VAL_8:.*]] = icmp slt i32 %[[VAL_9:.*]], 0
181+
// CHECK: %[[VAL_10:.*]] = sub i32 0, %[[VAL_9]]
182+
// CHECK: %[[VAL_11:.*]] = select i1 %[[VAL_8]], i32 %[[VAL_10]], i32 %[[VAL_9]]
183+
// CHECK: %[[VAL_12:.*]] = select i1 %[[VAL_8]], i32 %[[VAL_13:.*]], i32 %[[VAL_14:.*]]
184+
// CHECK: %[[VAL_15:.*]] = select i1 %[[VAL_8]], i32 %[[VAL_14]], i32 %[[VAL_13]]
185+
// CHECK: %[[VAL_16:.*]] = sub nsw i32 %[[VAL_15]], %[[VAL_12]]
186+
// CHECK: %[[VAL_17:.*]] = icmp sle i32 %[[VAL_15]], %[[VAL_12]]
187+
// CHECK: %[[VAL_18:.*]] = sub i32 %[[VAL_16]], 1
188+
// CHECK: %[[VAL_19:.*]] = udiv i32 %[[VAL_18]], %[[VAL_11]]
189+
// CHECK: %[[VAL_20:.*]] = add i32 %[[VAL_19]], 1
190+
// CHECK: %[[VAL_21:.*]] = icmp ule i32 %[[VAL_16]], %[[VAL_11]]
191+
// CHECK: %[[VAL_22:.*]] = select i1 %[[VAL_21]], i32 1, i32 %[[VAL_20]]
192+
// CHECK: %[[VAL_23:.*]] = select i1 %[[VAL_17]], i32 0, i32 %[[VAL_22]]
193+
// CHECK: br label %[[VAL_24:.*]]
194+
// CHECK: omp_loop.preheader: ; preds = %[[VAL_7]]
195+
// CHECK: store i32 0, ptr %[[VAL_1]], align 4
196+
// CHECK: %[[VAL_25:.*]] = sub i32 %[[VAL_23]], 1
197+
// CHECK: store i32 %[[VAL_25]], ptr %[[VAL_2]], align 4
198+
// CHECK: store i32 1, ptr %[[VAL_3]], align 4
199+
// CHECK: %[[VAL_26:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
200+
// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %[[VAL_26]], i32 34, ptr %[[VAL_0]], ptr %[[VAL_1]], ptr %[[VAL_2]], ptr %[[VAL_3]], i32 1, i32 0)
201+
// CHECK: %[[VAL_27:.*]] = load i32, ptr %[[VAL_1]], align 4
202+
// CHECK: %[[VAL_28:.*]] = load i32, ptr %[[VAL_2]], align 4
203+
// CHECK: %[[VAL_29:.*]] = sub i32 %[[VAL_28]], %[[VAL_27]]
204+
// CHECK: %[[VAL_30:.*]] = add i32 %[[VAL_29]], 1
205+
// CHECK: br label %[[VAL_31:.*]]
206+
// CHECK: omp_loop.header: ; preds = %[[VAL_32:.*]], %[[VAL_24]]
207+
// CHECK: %[[VAL_33:.*]] = phi i32 [ 0, %[[VAL_24]] ], [ %[[VAL_34:.*]], %[[VAL_32]] ]
208+
// CHECK: br label %[[VAL_35:.*]]
209+
// CHECK: omp_loop.cond: ; preds = %[[VAL_31]]
210+
// CHECK: %[[VAL_36:.*]] = icmp ult i32 %[[VAL_33]], %[[VAL_30]]
211+
// CHECK: br i1 %[[VAL_36]], label %[[VAL_37:.*]], label %[[VAL_38:.*]]
212+
// CHECK: omp_loop.body: ; preds = %[[VAL_35]]
213+
// CHECK: %[[VAL_39:.*]] = add i32 %[[VAL_33]], %[[VAL_27]]
214+
// CHECK: %[[VAL_40:.*]] = mul i32 %[[VAL_39]], %[[VAL_9]]
215+
// CHECK: %[[VAL_41:.*]] = add i32 %[[VAL_40]], %[[VAL_14]]
216+
// CHECK: br label %[[VAL_42:.*]]
217+
// CHECK: omp.loop_nest.region: ; preds = %[[VAL_37]]
218+
// CHECK: br i1 %[[VAL_43:.*]], label %[[VAL_44:.*]], label %[[VAL_45:.*]]
219+
// CHECK: 25: ; preds = %[[VAL_42]]
220+
// CHECK: %[[VAL_46:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
221+
// CHECK: %[[VAL_47:.*]] = call i32 @__kmpc_cancel(ptr @1, i32 %[[VAL_46]], i32 2)
222+
// CHECK: %[[VAL_48:.*]] = icmp eq i32 %[[VAL_47]], 0
223+
// CHECK: br i1 %[[VAL_48]], label %[[VAL_49:.*]], label %[[VAL_50:.*]]
224+
// CHECK: .split: ; preds = %[[VAL_44]]
225+
// CHECK: br label %[[VAL_51:.*]]
226+
// CHECK: 28: ; preds = %[[VAL_42]]
227+
// CHECK: br label %[[VAL_51]]
228+
// CHECK: 29: ; preds = %[[VAL_45]], %[[VAL_49]]
229+
// CHECK: br label %[[VAL_52:.*]]
230+
// CHECK: omp.region.cont1: ; preds = %[[VAL_51]]
231+
// CHECK: br label %[[VAL_32]]
232+
// CHECK: omp_loop.inc: ; preds = %[[VAL_52]]
233+
// CHECK: %[[VAL_34]] = add nuw i32 %[[VAL_33]], 1
234+
// CHECK: br label %[[VAL_31]]
235+
// CHECK: omp_loop.exit: ; preds = %[[VAL_50]], %[[VAL_35]]
236+
// CHECK: call void @__kmpc_for_static_fini(ptr @1, i32 %[[VAL_26]])
237+
// CHECK: %[[VAL_53:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
238+
// CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_53]])
239+
// CHECK: br label %[[VAL_54:.*]]
240+
// CHECK: omp_loop.after: ; preds = %[[VAL_38]]
241+
// CHECK: br label %[[VAL_55:.*]]
242+
// CHECK: omp.region.cont: ; preds = %[[VAL_54]]
243+
// CHECK: ret void
244+
// CHECK: .cncl: ; preds = %[[VAL_44]]
245+
// CHECK: br label %[[VAL_38]]

mlir/test/Target/LLVMIR/openmp-todo.mlir

-16
Original file line numberDiff line numberDiff line change
@@ -26,22 +26,6 @@ llvm.func @atomic_hint(%v : !llvm.ptr, %x : !llvm.ptr, %expr : i32) {
2626

2727
// -----
2828

29-
llvm.func @cancel_wsloop(%lb : i32, %ub : i32, %step: i32) {
30-
// expected-error@below {{LLVM Translation failed for operation: omp.wsloop}}
31-
omp.wsloop {
32-
// expected-error@below {{LLVM Translation failed for operation: omp.loop_nest}}
33-
omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
34-
// expected-error@below {{not yet implemented: Unhandled clause cancel directive construct type not yet supported in omp.cancel operation}}
35-
// expected-error@below {{LLVM Translation failed for operation: omp.cancel}}
36-
omp.cancel cancellation_construct_type(loop)
37-
omp.yield
38-
}
39-
}
40-
llvm.return
41-
}
42-
43-
// -----
44-
4529
llvm.func @cancel_taskgroup() {
4630
// expected-error@below {{LLVM Translation failed for operation: omp.taskgroup}}
4731
omp.taskgroup {

0 commit comments

Comments
 (0)