@@ -158,6 +158,12 @@ static LogicalResult checkImplementationStatus(Operation &op) {
158
158
if (op.getBare ())
159
159
result = todo (" ompx_bare" );
160
160
};
161
+ auto checkCancelDirective = [&todo](auto op, LogicalResult &result) {
162
+ omp::ClauseCancellationConstructType cancelledDirective =
163
+ op.getCancelDirective ();
164
+ if (cancelledDirective != omp::ClauseCancellationConstructType::Parallel)
165
+ result = todo (" cancel directive" );
166
+ };
161
167
auto checkDepend = [&todo](auto op, LogicalResult &result) {
162
168
if (!op.getDependVars ().empty () || op.getDependKinds ())
163
169
result = todo (" depend" );
@@ -248,6 +254,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
248
254
249
255
LogicalResult result = success ();
250
256
llvm::TypeSwitch<Operation &>(op)
257
+ .Case ([&](omp::CancelOp op) { checkCancelDirective (op, result); })
251
258
.Case ([&](omp::DistributeOp op) {
252
259
checkAllocate (op, result);
253
260
checkDistSchedule (op, result);
@@ -1580,6 +1587,21 @@ cleanupPrivateVars(llvm::IRBuilderBase &builder,
1580
1587
return success ();
1581
1588
}
1582
1589
1590
+ // / Returns true if the construct contains omp.cancel or omp.cancellation_point
1591
+ static bool constructIsCancellable (Operation *op) {
1592
+ // omp.cancel must be "closely nested" so it will be visible and not inside of
1593
+ // funcion calls. This is enforced by the verifier.
1594
+ bool containsCancel = false ;
1595
+ op->walk ([&containsCancel](Operation *child) {
1596
+ if (mlir::isa<omp::CancelOp>(child)) {
1597
+ containsCancel = true ;
1598
+ return WalkResult::interrupt ();
1599
+ }
1600
+ return WalkResult::advance ();
1601
+ });
1602
+ return containsCancel;
1603
+ }
1604
+
1583
1605
static LogicalResult
1584
1606
convertOmpSections (Operation &opInst, llvm::IRBuilderBase &builder,
1585
1607
LLVM::ModuleTranslation &moduleTranslation) {
@@ -2524,8 +2546,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
2524
2546
auto pbKind = llvm::omp::OMP_PROC_BIND_default;
2525
2547
if (auto bind = opInst.getProcBindKind ())
2526
2548
pbKind = getProcBindKind (*bind);
2527
- // TODO: Is the Parallel construct cancellable?
2528
- bool isCancellable = false ;
2549
+ bool isCancellable = constructIsCancellable (opInst);
2529
2550
2530
2551
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2531
2552
findAllocaInsertPoint (builder, moduleTranslation);
@@ -2991,6 +3012,47 @@ convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
2991
3012
return success ();
2992
3013
}
2993
3014
3015
+ static llvm::omp::Directive convertCancellationConstructType (
3016
+ omp::ClauseCancellationConstructType directive) {
3017
+ switch (directive) {
3018
+ case omp::ClauseCancellationConstructType::Loop:
3019
+ return llvm::omp::Directive::OMPD_for;
3020
+ case omp::ClauseCancellationConstructType::Parallel:
3021
+ return llvm::omp::Directive::OMPD_parallel;
3022
+ case omp::ClauseCancellationConstructType::Sections:
3023
+ return llvm::omp::Directive::OMPD_sections;
3024
+ case omp::ClauseCancellationConstructType::Taskgroup:
3025
+ return llvm::omp::Directive::OMPD_taskgroup;
3026
+ }
3027
+ }
3028
+
3029
+ static LogicalResult
3030
+ convertOmpCancel (omp::CancelOp op, llvm::IRBuilderBase &builder,
3031
+ LLVM::ModuleTranslation &moduleTranslation) {
3032
+ llvm::OpenMPIRBuilder::LocationDescription ompLoc (builder);
3033
+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
3034
+
3035
+ if (failed (checkImplementationStatus (*op.getOperation ())))
3036
+ return failure ();
3037
+
3038
+ llvm::Value *ifCond = nullptr ;
3039
+ if (Value ifVar = op.getIfExpr ())
3040
+ ifCond = moduleTranslation.lookupValue (ifVar);
3041
+
3042
+ llvm::omp::Directive cancelledDirective =
3043
+ convertCancellationConstructType (op.getCancelDirective ());
3044
+
3045
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3046
+ ompBuilder->createCancel (ompLoc, ifCond, cancelledDirective);
3047
+
3048
+ if (failed (handleError (afterIP, *op.getOperation ())))
3049
+ return failure ();
3050
+
3051
+ builder.restoreIP (afterIP.get ());
3052
+
3053
+ return success ();
3054
+ }
3055
+
2994
3056
// / Converts an OpenMP Threadprivate operation into LLVM IR using
2995
3057
// / OpenMPIRBuilder.
2996
3058
static LogicalResult
@@ -5421,6 +5483,9 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
5421
5483
.Case ([&](omp::AtomicCaptureOp op) {
5422
5484
return convertOmpAtomicCapture (op, builder, moduleTranslation);
5423
5485
})
5486
+ .Case ([&](omp::CancelOp op) {
5487
+ return convertOmpCancel (op, builder, moduleTranslation);
5488
+ })
5424
5489
.Case ([&](omp::SectionsOp) {
5425
5490
return convertOmpSections (*op, builder, moduleTranslation);
5426
5491
})
0 commit comments