-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][OpenMP] Convert omp.cancel parallel to LLVMIR #137192
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -158,6 +158,12 @@ static LogicalResult checkImplementationStatus(Operation &op) { | |||
if (op.getBare()) | ||||
result = todo("ompx_bare"); | ||||
}; | ||||
auto checkCancelDirective = [&todo](auto op, LogicalResult &result) { | ||||
omp::ClauseCancellationConstructType cancelledDirective = | ||||
op.getCancelDirective(); | ||||
if (cancelledDirective != omp::ClauseCancellationConstructType::Parallel) | ||||
result = todo("cancel directive"); | ||||
}; | ||||
auto checkDepend = [&todo](auto op, LogicalResult &result) { | ||||
if (!op.getDependVars().empty() || op.getDependKinds()) | ||||
result = todo("depend"); | ||||
|
@@ -248,6 +254,7 @@ static LogicalResult checkImplementationStatus(Operation &op) { | |||
|
||||
LogicalResult result = success(); | ||||
llvm::TypeSwitch<Operation &>(op) | ||||
.Case([&](omp::CancelOp op) { checkCancelDirective(op, result); }) | ||||
.Case([&](omp::DistributeOp op) { | ||||
checkAllocate(op, result); | ||||
checkDistSchedule(op, result); | ||||
|
@@ -1580,6 +1587,21 @@ cleanupPrivateVars(llvm::IRBuilderBase &builder, | |||
return success(); | ||||
} | ||||
|
||||
/// Returns true if the construct contains omp.cancel or omp.cancellation_point | ||||
static bool constructIsCancellable(Operation *op) { | ||||
// omp.cancel must be "closely nested" so it will be visible and not inside of | ||||
// funcion calls. This is enforced by the verifier. | ||||
bool containsCancel = false; | ||||
op->walk([&containsCancel](Operation *child) { | ||||
if (mlir::isa<omp::CancelOp>(child)) { | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we pass an There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The verifier checks that the clause cancellation construct type matches so I am taking that for granted here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What I'm thinking is that if something like this is legal, we'd probably need to update verifiers and this function would need to know what type of !$omp parallel
! ...
!$omp do
do i = 1, N
! ...
!$omp cancel parallel
! ...
end do
! ...
!$omp end parallel I'm not familiar with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As I understand it, I do not think this is legal. There was some discussion of how to interpret the standard in my PR correcting the verifiers: #134084 |
||||
containsCancel = true; | ||||
return WalkResult::interrupt(); | ||||
} | ||||
return WalkResult::advance(); | ||||
}); | ||||
return containsCancel; | ||||
tblah marked this conversation as resolved.
Show resolved
Hide resolved
|
||||
} | ||||
|
||||
static LogicalResult | ||||
convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder, | ||||
LLVM::ModuleTranslation &moduleTranslation) { | ||||
|
@@ -2524,8 +2546,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, | |||
auto pbKind = llvm::omp::OMP_PROC_BIND_default; | ||||
if (auto bind = opInst.getProcBindKind()) | ||||
pbKind = getProcBindKind(*bind); | ||||
// TODO: Is the Parallel construct cancellable? | ||||
bool isCancellable = false; | ||||
bool isCancellable = constructIsCancellable(opInst); | ||||
|
||||
llvm::OpenMPIRBuilder::InsertPointTy allocaIP = | ||||
findAllocaInsertPoint(builder, moduleTranslation); | ||||
|
@@ -2991,6 +3012,47 @@ convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp, | |||
return success(); | ||||
} | ||||
|
||||
static llvm::omp::Directive convertCancellationConstructType( | ||||
omp::ClauseCancellationConstructType directive) { | ||||
switch (directive) { | ||||
case omp::ClauseCancellationConstructType::Loop: | ||||
return llvm::omp::Directive::OMPD_for; | ||||
case omp::ClauseCancellationConstructType::Parallel: | ||||
return llvm::omp::Directive::OMPD_parallel; | ||||
case omp::ClauseCancellationConstructType::Sections: | ||||
return llvm::omp::Directive::OMPD_sections; | ||||
case omp::ClauseCancellationConstructType::Taskgroup: | ||||
return llvm::omp::Directive::OMPD_taskgroup; | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this be a TODO? I noticed that cancel on taskgroup is a TODO, rest all are PR stack. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The TODO is generated by |
||||
} | ||||
} | ||||
|
||||
static LogicalResult | ||||
convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder, | ||||
LLVM::ModuleTranslation &moduleTranslation) { | ||||
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); | ||||
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); | ||||
|
||||
if (failed(checkImplementationStatus(*op.getOperation()))) | ||||
return failure(); | ||||
|
||||
llvm::Value *ifCond = nullptr; | ||||
if (Value ifVar = op.getIfExpr()) | ||||
ifCond = moduleTranslation.lookupValue(ifVar); | ||||
|
||||
llvm::omp::Directive cancelledDirective = | ||||
convertCancellationConstructType(op.getCancelDirective()); | ||||
|
||||
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||
ompBuilder->createCancel(ompLoc, ifCond, cancelledDirective); | ||||
|
||||
if (failed(handleError(afterIP, *op.getOperation()))) | ||||
return failure(); | ||||
|
||||
builder.restoreIP(afterIP.get()); | ||||
|
||||
return success(); | ||||
} | ||||
|
||||
/// Converts an OpenMP Threadprivate operation into LLVM IR using | ||||
/// OpenMPIRBuilder. | ||||
static LogicalResult | ||||
|
@@ -5421,6 +5483,9 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder, | |||
.Case([&](omp::AtomicCaptureOp op) { | ||||
return convertOmpAtomicCapture(op, builder, moduleTranslation); | ||||
}) | ||||
.Case([&](omp::CancelOp op) { | ||||
return convertOmpCancel(op, builder, moduleTranslation); | ||||
}) | ||||
.Case([&](omp::SectionsOp) { | ||||
return convertOmpSections(*op, builder, moduleTranslation); | ||||
}) | ||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s | ||
|
||
llvm.func @cancel_parallel() { | ||
omp.parallel { | ||
omp.cancel cancellation_construct_type(parallel) | ||
omp.terminator | ||
} | ||
llvm.return | ||
} | ||
// CHECK-LABEL: define internal void @cancel_parallel..omp_par | ||
// CHECK: omp.par.entry: | ||
// CHECK: %[[VAL_5:.*]] = alloca i32, align 4 | ||
// CHECK: %[[VAL_6:.*]] = load i32, ptr %[[VAL_7:.*]], align 4 | ||
// CHECK: store i32 %[[VAL_6]], ptr %[[VAL_5]], align 4 | ||
// CHECK: %[[VAL_8:.*]] = load i32, ptr %[[VAL_5]], align 4 | ||
// CHECK: br label %[[VAL_9:.*]] | ||
// CHECK: omp.region.after_alloca: ; preds = %[[VAL_10:.*]] | ||
// CHECK: br label %[[VAL_11:.*]] | ||
// CHECK: omp.par.region: ; preds = %[[VAL_9]] | ||
// CHECK: br label %[[VAL_12:.*]] | ||
// CHECK: omp.par.region1: ; preds = %[[VAL_11]] | ||
// CHECK: %[[VAL_13:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) | ||
// CHECK: %[[VAL_14:.*]] = call i32 @__kmpc_cancel(ptr @1, i32 %[[VAL_13]], i32 1) | ||
// CHECK: %[[VAL_15:.*]] = icmp eq i32 %[[VAL_14]], 0 | ||
// CHECK: br i1 %[[VAL_15]], label %[[VAL_16:.*]], label %[[VAL_17:.*]] | ||
// CHECK: omp.par.region1.cncl: ; preds = %[[VAL_12]] | ||
// CHECK: %[[VAL_18:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) | ||
// CHECK: %[[VAL_19:.*]] = call i32 @__kmpc_cancel_barrier(ptr @2, i32 %[[VAL_18]]) | ||
// CHECK: br label %[[VAL_20:.*]] | ||
// CHECK: omp.par.region1.split: ; preds = %[[VAL_12]] | ||
// CHECK: br label %[[VAL_21:.*]] | ||
// CHECK: omp.region.cont: ; preds = %[[VAL_16]] | ||
// CHECK: br label %[[VAL_22:.*]] | ||
// CHECK: omp.par.pre_finalize: ; preds = %[[VAL_21]] | ||
// CHECK: br label %[[VAL_20]] | ||
// CHECK: omp.par.exit.exitStub: ; preds = %[[VAL_22]], %[[VAL_17]] | ||
// CHECK: ret void | ||
|
||
llvm.func @cancel_parallel_if(%arg0 : i1) { | ||
omp.parallel { | ||
omp.cancel cancellation_construct_type(parallel) if(%arg0) | ||
omp.terminator | ||
} | ||
llvm.return | ||
} | ||
// CHECK-LABEL: define internal void @cancel_parallel_if..omp_par | ||
// CHECK: omp.par.entry: | ||
// CHECK: %[[VAL_9:.*]] = getelementptr { ptr }, ptr %[[VAL_10:.*]], i32 0, i32 0 | ||
// CHECK: %[[VAL_11:.*]] = load ptr, ptr %[[VAL_9]], align 8 | ||
// CHECK: %[[VAL_12:.*]] = alloca i32, align 4 | ||
// CHECK: %[[VAL_13:.*]] = load i32, ptr %[[VAL_14:.*]], align 4 | ||
// CHECK: store i32 %[[VAL_13]], ptr %[[VAL_12]], align 4 | ||
// CHECK: %[[VAL_15:.*]] = load i32, ptr %[[VAL_12]], align 4 | ||
// CHECK: %[[VAL_16:.*]] = load i1, ptr %[[VAL_11]], align 1 | ||
// CHECK: br label %[[VAL_17:.*]] | ||
// CHECK: omp.region.after_alloca: ; preds = %[[VAL_18:.*]] | ||
// CHECK: br label %[[VAL_19:.*]] | ||
// CHECK: omp.par.region: ; preds = %[[VAL_17]] | ||
// CHECK: br label %[[VAL_20:.*]] | ||
// CHECK: omp.par.region1: ; preds = %[[VAL_19]] | ||
// CHECK: br i1 %[[VAL_16]], label %[[VAL_21:.*]], label %[[VAL_22:.*]] | ||
// CHECK: 3: ; preds = %[[VAL_20]] | ||
// CHECK: br label %[[VAL_23:.*]] | ||
// CHECK: 4: ; preds = %[[VAL_22]], %[[VAL_24:.*]] | ||
// CHECK: br label %[[VAL_25:.*]] | ||
// CHECK: omp.region.cont: ; preds = %[[VAL_23]] | ||
// CHECK: br label %[[VAL_26:.*]] | ||
// CHECK: omp.par.pre_finalize: ; preds = %[[VAL_25]] | ||
// CHECK: br label %[[VAL_27:.*]] | ||
// CHECK: 5: ; preds = %[[VAL_20]] | ||
// CHECK: %[[VAL_28:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) | ||
// CHECK: %[[VAL_29:.*]] = call i32 @__kmpc_cancel(ptr @1, i32 %[[VAL_28]], i32 1) | ||
// CHECK: %[[VAL_30:.*]] = icmp eq i32 %[[VAL_29]], 0 | ||
// CHECK: br i1 %[[VAL_30]], label %[[VAL_24]], label %[[VAL_31:.*]] | ||
// CHECK: .cncl: ; preds = %[[VAL_21]] | ||
// CHECK: %[[VAL_32:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) | ||
// CHECK: %[[VAL_33:.*]] = call i32 @__kmpc_cancel_barrier(ptr @2, i32 %[[VAL_32]]) | ||
// CHECK: br label %[[VAL_27]] | ||
// CHECK: .split: ; preds = %[[VAL_21]] | ||
// CHECK: br label %[[VAL_23]] | ||
// CHECK: omp.par.exit.exitStub: ; preds = %[[VAL_31]], %[[VAL_26]] | ||
// CHECK: ret void |
Uh oh!
There was an error while loading. Please reload this page.