@@ -158,6 +158,12 @@ static LogicalResult checkImplementationStatus(Operation &op) {
158
158
if (op.getBare ())
159
159
result = todo (" ompx_bare" );
160
160
};
161
+ auto checkCancelDirective = [&todo](auto op, LogicalResult &result) {
162
+ omp::ClauseCancellationConstructType cancelledDirective =
163
+ op.getCancelDirective ();
164
+ if (cancelledDirective != omp::ClauseCancellationConstructType::Parallel)
165
+ result = todo (" cancel directive construct type not yet supported" );
166
+ };
161
167
auto checkDepend = [&todo](auto op, LogicalResult &result) {
162
168
if (!op.getDependVars ().empty () || op.getDependKinds ())
163
169
result = todo (" depend" );
@@ -248,6 +254,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
248
254
249
255
LogicalResult result = success ();
250
256
llvm::TypeSwitch<Operation &>(op)
257
+ .Case ([&](omp::CancelOp op) { checkCancelDirective (op, result); })
251
258
.Case ([&](omp::DistributeOp op) {
252
259
checkAllocate (op, result);
253
260
checkDistSchedule (op, result);
@@ -1580,6 +1587,19 @@ cleanupPrivateVars(llvm::IRBuilderBase &builder,
1580
1587
return success ();
1581
1588
}
1582
1589
1590
+ // / Returns true if the construct contains omp.cancel or omp.cancellation_point
1591
+ static bool constructIsCancellable (Operation *op) {
1592
+ // omp.cancel must be "closely nested" so it will be visible and not inside of
1593
+ // funcion calls. This is enforced by the verifier.
1594
+ return op
1595
+ ->walk ([](Operation *child) {
1596
+ if (mlir::isa<omp::CancelOp>(child))
1597
+ return WalkResult::interrupt ();
1598
+ return WalkResult::advance ();
1599
+ })
1600
+ .wasInterrupted ();
1601
+ }
1602
+
1583
1603
static LogicalResult
1584
1604
convertOmpSections (Operation &opInst, llvm::IRBuilderBase &builder,
1585
1605
LLVM::ModuleTranslation &moduleTranslation) {
@@ -2524,8 +2544,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
2524
2544
auto pbKind = llvm::omp::OMP_PROC_BIND_default;
2525
2545
if (auto bind = opInst.getProcBindKind ())
2526
2546
pbKind = getProcBindKind (*bind);
2527
- // TODO: Is the Parallel construct cancellable?
2528
- bool isCancellable = false ;
2547
+ bool isCancellable = constructIsCancellable (opInst);
2529
2548
2530
2549
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2531
2550
findAllocaInsertPoint (builder, moduleTranslation);
@@ -2991,6 +3010,47 @@ convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
2991
3010
return success ();
2992
3011
}
2993
3012
3013
+ static llvm::omp::Directive convertCancellationConstructType (
3014
+ omp::ClauseCancellationConstructType directive) {
3015
+ switch (directive) {
3016
+ case omp::ClauseCancellationConstructType::Loop:
3017
+ return llvm::omp::Directive::OMPD_for;
3018
+ case omp::ClauseCancellationConstructType::Parallel:
3019
+ return llvm::omp::Directive::OMPD_parallel;
3020
+ case omp::ClauseCancellationConstructType::Sections:
3021
+ return llvm::omp::Directive::OMPD_sections;
3022
+ case omp::ClauseCancellationConstructType::Taskgroup:
3023
+ return llvm::omp::Directive::OMPD_taskgroup;
3024
+ }
3025
+ }
3026
+
3027
+ static LogicalResult
3028
+ convertOmpCancel (omp::CancelOp op, llvm::IRBuilderBase &builder,
3029
+ LLVM::ModuleTranslation &moduleTranslation) {
3030
+ llvm::OpenMPIRBuilder::LocationDescription ompLoc (builder);
3031
+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
3032
+
3033
+ if (failed (checkImplementationStatus (*op.getOperation ())))
3034
+ return failure ();
3035
+
3036
+ llvm::Value *ifCond = nullptr ;
3037
+ if (Value ifVar = op.getIfExpr ())
3038
+ ifCond = moduleTranslation.lookupValue (ifVar);
3039
+
3040
+ llvm::omp::Directive cancelledDirective =
3041
+ convertCancellationConstructType (op.getCancelDirective ());
3042
+
3043
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3044
+ ompBuilder->createCancel (ompLoc, ifCond, cancelledDirective);
3045
+
3046
+ if (failed (handleError (afterIP, *op.getOperation ())))
3047
+ return failure ();
3048
+
3049
+ builder.restoreIP (afterIP.get ());
3050
+
3051
+ return success ();
3052
+ }
3053
+
2994
3054
// / Converts an OpenMP Threadprivate operation into LLVM IR using
2995
3055
// / OpenMPIRBuilder.
2996
3056
static LogicalResult
@@ -5421,6 +5481,9 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
5421
5481
.Case ([&](omp::AtomicCaptureOp op) {
5422
5482
return convertOmpAtomicCapture (op, builder, moduleTranslation);
5423
5483
})
5484
+ .Case ([&](omp::CancelOp op) {
5485
+ return convertOmpCancel (op, builder, moduleTranslation);
5486
+ })
5424
5487
.Case ([&](omp::SectionsOp) {
5425
5488
return convertOmpSections (*op, builder, moduleTranslation);
5426
5489
})
0 commit comments