Skip to content

Commit c5f1005

Browse files
committed
[mlir][OpenMP] initialize (first)private variables before task exec
This still doesn't fix the memory safety issues, because the stack allocations created here for the private variables might go out of scope before the task executes. I will add a more complete lit test later in this patch series.
1 parent d3519f1 commit c5f1005

File tree

2 files changed

+96
-41
lines changed

2 files changed

+96
-41
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 71 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1750,25 +1750,80 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
17501750
for (mlir::Value privateVar : taskOp.getPrivateVars())
17511751
mlirPrivateVars.push_back(privateVar);
17521752

1753-
auto bodyCB = [&](InsertPointTy allocaIP,
1754-
InsertPointTy codegenIP) -> llvm::Error {
1755-
// Save the alloca insertion point on ModuleTranslation stack for use in
1756-
// nested regions.
1757-
LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
1758-
moduleTranslation, allocaIP);
1753+
// Allocate and copy private variables before creating the task. This avoids
1754+
// accessing invalid memory if (after this scope ends) the private variables
1755+
// are initialized from host variables or if their values are copied
1756+
// from host variables (firstprivate). The insertion point is just before
1757+
// where the code for creating and scheduling the task will go. That puts this
1758+
// code outside of the outlined task region, which is what we want because
1759+
// this way the initialization and copy regions are executed immediately while
1760+
// the host variable data are still live.
17591761

1760-
llvm::Expected<llvm::BasicBlock *> afterAllocas =
1761-
allocateAndInitPrivateVars(builder, moduleTranslation, privateBlockArgs,
1762-
privateDecls, mlirPrivateVars,
1763-
llvmPrivateVars, allocaIP);
1764-
if (handleError(afterAllocas, *taskOp).failed())
1765-
return llvm::make_error<PreviouslyReportedError>();
1762+
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1763+
findAllocaInsertPoint(builder, moduleTranslation);
17661764

1767-
if (failed(initFirstPrivateVars(builder, moduleTranslation, mlirPrivateVars,
1768-
llvmPrivateVars, privateDecls,
1769-
afterAllocas.get())))
1770-
return llvm::make_error<PreviouslyReportedError>();
1765+
// Not using splitBB() because that requires the current block to have a
1766+
// terminator.
1767+
assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
1768+
llvm::BasicBlock *taskStartBlock = llvm::BasicBlock::Create(
1769+
builder.getContext(), "omp.task.start",
1770+
/*Parent=*/builder.GetInsertBlock()->getParent());
1771+
llvm::Instruction *branchToTaskStartBlock = builder.CreateBr(taskStartBlock);
1772+
builder.SetInsertPoint(branchToTaskStartBlock);
1773+
1774+
// Now do this again to make the initialization and copy blocks.
1775+
llvm::BasicBlock *copyBlock =
1776+
splitBB(builder, /*CreateBranch=*/true, "omp.private.copy");
1777+
llvm::BasicBlock *initBlock =
1778+
splitBB(builder, /*CreateBranch=*/true, "omp.private.init");
1779+
1780+
// Now the control flow graph should look like
1781+
// starter_block:
1782+
// <---- where we started when convertOmpTaskOp was called
1783+
// br %omp.private.init
1784+
// omp.private.init:
1785+
// br %omp.private.copy
1786+
// omp.private.copy:
1787+
// br %omp.task.start
1788+
// omp.task.start:
1789+
// <---- where we want the insertion point to be when we call createTask()
1790+
1791+
// Save the alloca insertion point on ModuleTranslation stack for use in
1792+
// nested regions.
1793+
LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
1794+
moduleTranslation, allocaIP);
1795+
1796+
// Allocate and initialize private variables
1797+
// TODO: package private variables up in a structure
1798+
builder.SetInsertPoint(initBlock->getTerminator());
1799+
for (auto [privDecl, mlirPrivVar, blockArg] :
1800+
llvm::zip_equal(privateDecls, mlirPrivateVars, privateBlockArgs)) {
1801+
llvm::Type *llvmAllocType =
1802+
moduleTranslation.convertType(privDecl.getType());
17711803

1804+
// Allocations:
1805+
builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1806+
llvm::Value *llvmPrivateVar = builder.CreateAlloca(
1807+
llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");
1808+
1809+
// builder.SetInsertPoint(initBlock->getTerminator());
1810+
auto err =
1811+
initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
1812+
blockArg, llvmPrivateVar, llvmPrivateVars, initBlock);
1813+
if (err)
1814+
return handleError(std::move(err), *taskOp.getOperation());
1815+
}
1816+
1817+
// firstprivate copy region
1818+
if (failed(initFirstPrivateVars(builder, moduleTranslation, mlirPrivateVars,
1819+
llvmPrivateVars, privateDecls, copyBlock)))
1820+
return llvm::failure();
1821+
1822+
// Set up for call to createTask()
1823+
builder.SetInsertPoint(taskStartBlock);
1824+
1825+
auto bodyCB = [&](InsertPointTy allocaIP,
1826+
InsertPointTy codegenIP) -> llvm::Error {
17721827
// translate the body of the task:
17731828
builder.restoreIP(codegenIP);
17741829
auto continuationBlockOrError = convertOmpOpRegions(
@@ -1789,8 +1844,6 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
17891844
buildDependData(taskOp.getDependKinds(), taskOp.getDependVars(),
17901845
moduleTranslation, dds);
17911846

1792-
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1793-
findAllocaInsertPoint(builder, moduleTranslation);
17941847
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
17951848
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
17961849
moduleTranslation.getOpenMPBuilder()->createTask(

mlir/test/Target/LLVMIR/openmp-llvm.mlir

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2790,11 +2790,14 @@ llvm.func @par_task_(%arg0: !llvm.ptr {fir.bindc_name = "a"}) {
27902790
}
27912791

27922792
// CHECK-LABEL: @par_task_
2793+
// CHECK: %[[ARG_ALLOC:.*]] = alloca { ptr }, align 8
27932794
// CHECK: %[[TASK_ALLOC:.*]] = call ptr @__kmpc_omp_task_alloc({{.*}}ptr @[[task_outlined_fn:.+]])
27942795
// CHECK: call i32 @__kmpc_omp_task({{.*}}, ptr %[[TASK_ALLOC]])
2795-
// CHECK: define internal void @[[task_outlined_fn]]
2796-
// CHECK: %[[ARG_ALLOC:.*]] = alloca { ptr }, align 8
2797-
// CHECK: call void ({{.*}}) @__kmpc_fork_call({{.*}}, ptr @[[parallel_outlined_fn:.+]], ptr %[[ARG_ALLOC]])
2796+
// CHECK: define internal void @[[task_outlined_fn]](i32 %[[GLOBAL_TID_VAL:.*]], ptr %[[STRUCT_ARG:.*]])
2797+
// CHECK: %[[LOADED_STRUCT_PTR:.*]] = load ptr, ptr %[[STRUCT_ARG]], align 8
2798+
// CHECK: %[[GEP_STRUCTARG:.*]] = getelementptr { ptr, ptr }, ptr %[[LOADED_STRUCT_PTR]], i32 0, i32 0
2799+
// CHECK: %[[LOADGEP_STRUCTARG:.*]] = load ptr, ptr %[[GEP_STRUCTARG]], align 8
2800+
// CHECK: call void ({{.*}}) @__kmpc_fork_call({{.*}}, ptr @[[parallel_outlined_fn:.+]], ptr %[[LOADGEP_STRUCTARG]])
27982801
// CHECK: define internal void @[[parallel_outlined_fn]]
27992802
// -----
28002803

@@ -2820,27 +2823,18 @@ llvm.func @task(%arg0 : !llvm.ptr) {
28202823
llvm.return
28212824
}
28222825
// CHECK-LABEL: @task..omp_par
2823-
// CHECK: task.alloca:
2824-
// CHECK: %[[VAL_11:.*]] = load ptr, ptr %[[VAL_12:.*]], align 8
2825-
// CHECK: %[[VAL_13:.*]] = getelementptr { ptr }, ptr %[[VAL_11]], i32 0, i32 0
2826+
// CHECK: task.alloca:
2827+
// CHECK: %[[VAL_12:.*]] = load ptr, ptr %[[STRUCT_ARG:.*]], align 8
2828+
// CHECK: %[[VAL_13:.*]] = getelementptr { ptr }, ptr %[[VAL_12]], i32 0, i32 0
28262829
// CHECK: %[[VAL_14:.*]] = load ptr, ptr %[[VAL_13]], align 8
2827-
// CHECK: %[[VAL_15:.*]] = alloca i32, align 4
2828-
// CHECK: br label %omp.private.init
2829-
// CHECK: omp.private.init: ; preds = %task.alloca
2830-
// CHECK: br label %omp.private.copy
2831-
// CHECK: omp.private.copy: ; preds = %omp.private.init
2832-
// CHECK: %[[VAL_19:.*]] = load i32, ptr %[[VAL_14]], align 4
2833-
// CHECK: store i32 %[[VAL_19]], ptr %[[VAL_15]], align 4
2834-
// CHECK: br label %[[VAL_20:.*]]
2835-
// CHECK: [[VAL_20]]:
28362830
// CHECK: br label %task.body
2837-
// CHECK: task.body: ; preds = %[[VAL_20]]
2831+
// CHECK: task.body: ; preds = %task.alloca
28382832
// CHECK: br label %omp.task.region
28392833
// CHECK: omp.task.region: ; preds = %task.body
2840-
// CHECK: call void @foo(ptr %[[VAL_15]])
2834+
// CHECK: call void @foo(ptr %[[VAL_14]])
28412835
// CHECK: br label %omp.region.cont
28422836
// CHECK: omp.region.cont: ; preds = %omp.task.region
2843-
// CHECK: call void @destroy(ptr %[[VAL_15]])
2837+
// CHECK: call void @destroy(ptr %[[VAL_14]])
28442838
// CHECK: br label %task.exit.exitStub
28452839
// CHECK: task.exit.exitStub: ; preds = %omp.region.cont
28462840
// CHECK: ret void
@@ -2910,6 +2904,19 @@ llvm.func @omp_taskgroup_task(%x: i32, %y: i32, %zaddr: !llvm.ptr) {
29102904
// CHECK: br label %[[omp_region_cont:[^,]+]]
29112905
// CHECK: [[omp_taskgroup_region]]:
29122906
// CHECK: %{{.+}} = alloca i8, align 1
2907+
// CHECK: br label %[[omp_private_init:[^,]+]]
2908+
// CHECK: [[omp_private_init]]:
2909+
// CHECK: br label %[[omp_private_copy:[^,]+]]
2910+
// CHECK: [[omp_private_copy]]:
2911+
// CHECK: br label %[[omp_task_start:[^,]+]]
2912+
2913+
// CHECK: [[omp_region_cont:[^,]+]]:
2914+
// CHECK: br label %[[taskgroup_exit:[^,]+]]
2915+
// CHECK: [[taskgroup_exit]]:
2916+
// CHECK: call void @__kmpc_end_taskgroup(ptr @{{.+}}, i32 %[[omp_global_thread_num]])
2917+
// CHECK: ret void
2918+
2919+
// CHECK: [[omp_task_start]]:
29132920
// CHECK: br label %[[codeRepl:[^,]+]]
29142921
// CHECK: [[codeRepl]]:
29152922
// CHECK: %[[omp_global_thread_num_t1:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}})
@@ -2933,11 +2940,6 @@ llvm.func @omp_taskgroup_task(%x: i32, %y: i32, %zaddr: !llvm.ptr) {
29332940
// CHECK: br label %[[task_exit3:[^,]+]]
29342941
// CHECK: [[task_exit3]]:
29352942
// CHECK: br label %[[omp_taskgroup_region1]]
2936-
// CHECK: [[omp_region_cont]]:
2937-
// CHECK: br label %[[taskgroup_exit:[^,]+]]
2938-
// CHECK: [[taskgroup_exit]]:
2939-
// CHECK: call void @__kmpc_end_taskgroup(ptr @{{.+}}, i32 %[[omp_global_thread_num]])
2940-
// CHECK: ret void
29412943
// CHECK: }
29422944

29432945
// -----

0 commit comments

Comments
 (0)