Skip to content

Commit ba6d59c

Browse files
committed
[mlir][OpenMP] Convert omp.cancel parallel to LLVMIR
Support for other constructs will follow in subsequent PRs.
1 parent 6900e90 commit ba6d59c

File tree

3 files changed

+191
-8
lines changed

3 files changed

+191
-8
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,12 @@ static LogicalResult checkImplementationStatus(Operation &op) {
158158
if (op.getBare())
159159
result = todo("ompx_bare");
160160
};
161+
auto checkCancelDirective = [&todo](auto op, LogicalResult &result) {
162+
omp::ClauseCancellationConstructType cancelledDirective =
163+
op.getCancelDirective();
164+
if (cancelledDirective != omp::ClauseCancellationConstructType::Parallel)
165+
result = todo("cancel directive");
166+
};
161167
auto checkDepend = [&todo](auto op, LogicalResult &result) {
162168
if (!op.getDependVars().empty() || op.getDependKinds())
163169
result = todo("depend");
@@ -248,6 +254,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
248254

249255
LogicalResult result = success();
250256
llvm::TypeSwitch<Operation &>(op)
257+
.Case([&](omp::CancelOp op) { checkCancelDirective(op, result); })
251258
.Case([&](omp::DistributeOp op) {
252259
checkAllocate(op, result);
253260
checkDistSchedule(op, result);
@@ -1580,6 +1587,21 @@ cleanupPrivateVars(llvm::IRBuilderBase &builder,
15801587
return success();
15811588
}
15821589

1590+
/// Returns true if the construct contains omp.cancel or omp.cancellation_point
1591+
static bool constructIsCancellable(Operation *op) {
1592+
// omp.cancel must be "closely nested" so it will be visible and not inside of
1593+
// funcion calls. This is enforced by the verifier.
1594+
bool containsCancel = false;
1595+
op->walk([&containsCancel](Operation *child) {
1596+
if (mlir::isa<omp::CancelOp>(child)) {
1597+
containsCancel = true;
1598+
return WalkResult::interrupt();
1599+
}
1600+
return WalkResult::advance();
1601+
});
1602+
return containsCancel;
1603+
}
1604+
15831605
static LogicalResult
15841606
convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder,
15851607
LLVM::ModuleTranslation &moduleTranslation) {
@@ -2524,8 +2546,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
25242546
auto pbKind = llvm::omp::OMP_PROC_BIND_default;
25252547
if (auto bind = opInst.getProcBindKind())
25262548
pbKind = getProcBindKind(*bind);
2527-
// TODO: Is the Parallel construct cancellable?
2528-
bool isCancellable = false;
2549+
bool isCancellable = constructIsCancellable(opInst);
25292550

25302551
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
25312552
findAllocaInsertPoint(builder, moduleTranslation);
@@ -2991,6 +3012,47 @@ convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
29913012
return success();
29923013
}
29933014

3015+
static llvm::omp::Directive convertCancellationConstructType(
3016+
omp::ClauseCancellationConstructType directive) {
3017+
switch (directive) {
3018+
case omp::ClauseCancellationConstructType::Loop:
3019+
return llvm::omp::Directive::OMPD_for;
3020+
case omp::ClauseCancellationConstructType::Parallel:
3021+
return llvm::omp::Directive::OMPD_parallel;
3022+
case omp::ClauseCancellationConstructType::Sections:
3023+
return llvm::omp::Directive::OMPD_sections;
3024+
case omp::ClauseCancellationConstructType::Taskgroup:
3025+
return llvm::omp::Directive::OMPD_taskgroup;
3026+
}
3027+
}
3028+
3029+
static LogicalResult
3030+
convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder,
3031+
LLVM::ModuleTranslation &moduleTranslation) {
3032+
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3033+
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3034+
3035+
if (failed(checkImplementationStatus(*op.getOperation())))
3036+
return failure();
3037+
3038+
llvm::Value *ifCond = nullptr;
3039+
if (Value ifVar = op.getIfExpr())
3040+
ifCond = moduleTranslation.lookupValue(ifVar);
3041+
3042+
llvm::omp::Directive cancelledDirective =
3043+
convertCancellationConstructType(op.getCancelDirective());
3044+
3045+
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3046+
ompBuilder->createCancel(ompLoc, ifCond, cancelledDirective);
3047+
3048+
if (failed(handleError(afterIP, *op.getOperation())))
3049+
return failure();
3050+
3051+
builder.restoreIP(afterIP.get());
3052+
3053+
return success();
3054+
}
3055+
29943056
/// Converts an OpenMP Threadprivate operation into LLVM IR using
29953057
/// OpenMPIRBuilder.
29963058
static LogicalResult
@@ -5421,6 +5483,9 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
54215483
.Case([&](omp::AtomicCaptureOp op) {
54225484
return convertOmpAtomicCapture(op, builder, moduleTranslation);
54235485
})
5486+
.Case([&](omp::CancelOp op) {
5487+
return convertOmpCancel(op, builder, moduleTranslation);
5488+
})
54245489
.Case([&](omp::SectionsOp) {
54255490
return convertOmpSections(*op, builder, moduleTranslation);
54265491
})
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s
2+
3+
llvm.func @cancel_parallel() {
4+
omp.parallel {
5+
omp.cancel cancellation_construct_type(parallel)
6+
omp.terminator
7+
}
8+
llvm.return
9+
}
10+
// CHECK-LABEL: define internal void @cancel_parallel..omp_par
11+
// CHECK: omp.par.entry:
12+
// CHECK: %[[VAL_5:.*]] = alloca i32, align 4
13+
// CHECK: %[[VAL_6:.*]] = load i32, ptr %[[VAL_7:.*]], align 4
14+
// CHECK: store i32 %[[VAL_6]], ptr %[[VAL_5]], align 4
15+
// CHECK: %[[VAL_8:.*]] = load i32, ptr %[[VAL_5]], align 4
16+
// CHECK: br label %[[VAL_9:.*]]
17+
// CHECK: omp.region.after_alloca: ; preds = %[[VAL_10:.*]]
18+
// CHECK: br label %[[VAL_11:.*]]
19+
// CHECK: omp.par.region: ; preds = %[[VAL_9]]
20+
// CHECK: br label %[[VAL_12:.*]]
21+
// CHECK: omp.par.region1: ; preds = %[[VAL_11]]
22+
// CHECK: %[[VAL_13:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
23+
// CHECK: %[[VAL_14:.*]] = call i32 @__kmpc_cancel(ptr @1, i32 %[[VAL_13]], i32 1)
24+
// CHECK: %[[VAL_15:.*]] = icmp eq i32 %[[VAL_14]], 0
25+
// CHECK: br i1 %[[VAL_15]], label %[[VAL_16:.*]], label %[[VAL_17:.*]]
26+
// CHECK: omp.par.region1.cncl: ; preds = %[[VAL_12]]
27+
// CHECK: %[[VAL_18:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
28+
// CHECK: %[[VAL_19:.*]] = call i32 @__kmpc_cancel_barrier(ptr @2, i32 %[[VAL_18]])
29+
// CHECK: br label %[[VAL_20:.*]]
30+
// CHECK: omp.par.region1.split: ; preds = %[[VAL_12]]
31+
// CHECK: br label %[[VAL_21:.*]]
32+
// CHECK: omp.region.cont: ; preds = %[[VAL_16]]
33+
// CHECK: br label %[[VAL_22:.*]]
34+
// CHECK: omp.par.pre_finalize: ; preds = %[[VAL_21]]
35+
// CHECK: br label %[[VAL_20]]
36+
// CHECK: omp.par.exit.exitStub: ; preds = %[[VAL_22]], %[[VAL_17]]
37+
// CHECK: ret void
38+
39+
llvm.func @cancel_parallel_if(%arg0 : i1) {
40+
omp.parallel {
41+
omp.cancel cancellation_construct_type(parallel) if(%arg0)
42+
omp.terminator
43+
}
44+
llvm.return
45+
}
46+
// CHECK-LABEL: define internal void @cancel_parallel_if..omp_par
47+
// CHECK: omp.par.entry:
48+
// CHECK: %[[VAL_9:.*]] = getelementptr { ptr }, ptr %[[VAL_10:.*]], i32 0, i32 0
49+
// CHECK: %[[VAL_11:.*]] = load ptr, ptr %[[VAL_9]], align 8
50+
// CHECK: %[[VAL_12:.*]] = alloca i32, align 4
51+
// CHECK: %[[VAL_13:.*]] = load i32, ptr %[[VAL_14:.*]], align 4
52+
// CHECK: store i32 %[[VAL_13]], ptr %[[VAL_12]], align 4
53+
// CHECK: %[[VAL_15:.*]] = load i32, ptr %[[VAL_12]], align 4
54+
// CHECK: %[[VAL_16:.*]] = load i1, ptr %[[VAL_11]], align 1
55+
// CHECK: br label %[[VAL_17:.*]]
56+
// CHECK: omp.region.after_alloca: ; preds = %[[VAL_18:.*]]
57+
// CHECK: br label %[[VAL_19:.*]]
58+
// CHECK: omp.par.region: ; preds = %[[VAL_17]]
59+
// CHECK: br label %[[VAL_20:.*]]
60+
// CHECK: omp.par.region1: ; preds = %[[VAL_19]]
61+
// CHECK: br i1 %[[VAL_16]], label %[[VAL_21:.*]], label %[[VAL_22:.*]]
62+
// CHECK: 3: ; preds = %[[VAL_20]]
63+
// CHECK: br label %[[VAL_23:.*]]
64+
// CHECK: 4: ; preds = %[[VAL_22]], %[[VAL_24:.*]]
65+
// CHECK: br label %[[VAL_25:.*]]
66+
// CHECK: omp.region.cont: ; preds = %[[VAL_23]]
67+
// CHECK: br label %[[VAL_26:.*]]
68+
// CHECK: omp.par.pre_finalize: ; preds = %[[VAL_25]]
69+
// CHECK: br label %[[VAL_27:.*]]
70+
// CHECK: 5: ; preds = %[[VAL_20]]
71+
// CHECK: %[[VAL_28:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
72+
// CHECK: %[[VAL_29:.*]] = call i32 @__kmpc_cancel(ptr @1, i32 %[[VAL_28]], i32 1)
73+
// CHECK: %[[VAL_30:.*]] = icmp eq i32 %[[VAL_29]], 0
74+
// CHECK: br i1 %[[VAL_30]], label %[[VAL_24]], label %[[VAL_31:.*]]
75+
// CHECK: .cncl: ; preds = %[[VAL_21]]
76+
// CHECK: %[[VAL_32:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
77+
// CHECK: %[[VAL_33:.*]] = call i32 @__kmpc_cancel_barrier(ptr @2, i32 %[[VAL_32]])
78+
// CHECK: br label %[[VAL_27]]
79+
// CHECK: .split: ; preds = %[[VAL_21]]
80+
// CHECK: br label %[[VAL_23]]
81+
// CHECK: omp.par.exit.exitStub: ; preds = %[[VAL_31]], %[[VAL_26]]
82+
// CHECK: ret void

mlir/test/Target/LLVMIR/openmp-todo.mlir

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,48 @@ llvm.func @atomic_hint(%v : !llvm.ptr, %x : !llvm.ptr, %expr : i32) {
2626

2727
// -----
2828

29-
llvm.func @cancel() {
30-
// expected-error@below {{LLVM Translation failed for operation: omp.parallel}}
31-
omp.parallel {
32-
// expected-error@below {{not yet implemented: omp.cancel}}
33-
// expected-error@below {{LLVM Translation failed for operation: omp.cancel}}
34-
omp.cancel cancellation_construct_type(parallel)
29+
llvm.func @cancel_wsloop(%lb : i32, %ub : i32, %step: i32) {
30+
// expected-error@below {{LLVM Translation failed for operation: omp.wsloop}}
31+
omp.wsloop {
32+
// expected-error@below {{LLVM Translation failed for operation: omp.loop_nest}}
33+
omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
34+
// expected-error@below {{not yet implemented: Unhandled clause cancel directive in omp.cancel operation}}
35+
// expected-error@below {{LLVM Translation failed for operation: omp.cancel}}
36+
omp.cancel cancellation_construct_type(loop)
37+
omp.yield
38+
}
39+
}
40+
llvm.return
41+
}
42+
43+
// -----
44+
45+
llvm.func @cancel_sections() {
46+
// expected-error@below {{LLVM Translation failed for operation: omp.sections}}
47+
omp.sections {
48+
omp.section {
49+
// expected-error@below {{not yet implemented: Unhandled clause cancel directive in omp.cancel operation}}
50+
// expected-error@below {{LLVM Translation failed for operation: omp.cancel}}
51+
omp.cancel cancellation_construct_type(sections)
52+
omp.terminator
53+
}
54+
omp.terminator
55+
}
56+
llvm.return
57+
}
58+
59+
// -----
60+
61+
llvm.func @cancel_taskgroup() {
62+
// expected-error@below {{LLVM Translation failed for operation: omp.taskgroup}}
63+
omp.taskgroup {
64+
// expected-error@below {{LLVM Translation failed for operation: omp.task}}
65+
omp.task {
66+
// expected-error@below {{not yet implemented: Unhandled clause cancel directive in omp.cancel operation}}
67+
// expected-error@below {{LLVM Translation failed for operation: omp.cancel}}
68+
omp.cancel cancellation_construct_type(taskgroup)
69+
omp.terminator
70+
}
3571
omp.terminator
3672
}
3773
llvm.return

0 commit comments

Comments
 (0)