Skip to content

Commit e2d13cb

Browse files
committed
[mlir][OpenMP] initialize (first)private variables before task exec
This still doesn't fix the memory safety issues because the stack allocations created here for the private variables might go out of scope. I will add a more complete lit test later in this patch series.
1 parent db40592 commit e2d13cb

File tree

2 files changed

+99
-54
lines changed

2 files changed

+99
-54
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 73 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1339,8 +1339,7 @@ findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder,
13391339
static llvm::Error initPrivateVar(
13401340
llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
13411341
omp::PrivateClauseOp &privDecl, Value mlirPrivVar, BlockArgument &blockArg,
1342-
llvm::SmallVectorImpl<llvm::Value *>::iterator llvmPrivateVarIt,
1343-
llvm::BasicBlock *privInitBlock,
1342+
llvm::Value **llvmPrivateVarIt, llvm::BasicBlock *privInitBlock,
13441343
llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
13451344
Region &initRegion = privDecl.getInitRegion();
13461345
if (initRegion.empty()) {
@@ -1771,31 +1770,82 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
17711770
for (mlir::Value privateVar : taskOp.getPrivateVars())
17721771
mlirPrivateVars.push_back(privateVar);
17731772

1774-
auto bodyCB = [&](InsertPointTy allocaIP,
1775-
InsertPointTy codegenIP) -> llvm::Error {
1776-
// Save the alloca insertion point on ModuleTranslation stack for use in
1777-
// nested regions.
1778-
LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
1779-
moduleTranslation, allocaIP);
1773+
// Allocate and copy private variables before creating the task. This avoids
1774+
// accessing invalid memory if (after this scope ends) the private variables
1775+
// are initialized from host variables or if the variables are copied into
1776+
// from host variables (firstprivate). The insertion point is just before
1777+
// where the code for creating and scheduling the task will go. That puts this
1778+
// code outside of the outlined task region, which is what we want because
1779+
// this way the initialization and copy regions are executed immediately while
1780+
// the host variable data are still live.
17801781

1781-
llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
1782-
builder, moduleTranslation, privateBlockArgs, privateDecls,
1783-
mlirPrivateVars, llvmPrivateVars, allocaIP);
1784-
if (handleError(afterAllocas, *taskOp).failed())
1785-
return llvm::make_error<PreviouslyReportedError>();
1782+
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1783+
findAllocaInsertPoint(builder, moduleTranslation);
17861784

1787-
builder.restoreIP(codegenIP);
1788-
if (handleError(initPrivateVars(builder, moduleTranslation,
1789-
privateBlockArgs, privateDecls,
1790-
mlirPrivateVars, llvmPrivateVars),
1791-
*taskOp)
1792-
.failed())
1793-
return llvm::make_error<PreviouslyReportedError>();
1785+
// Not using splitBB() because that requires the current block to have a
1786+
// terminator.
1787+
assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
1788+
llvm::BasicBlock *taskStartBlock = llvm::BasicBlock::Create(
1789+
builder.getContext(), "omp.task.start",
1790+
/*Parent=*/builder.GetInsertBlock()->getParent());
1791+
llvm::Instruction *branchToTaskStartBlock = builder.CreateBr(taskStartBlock);
1792+
builder.SetInsertPoint(branchToTaskStartBlock);
17941793

1795-
if (failed(copyFirstPrivateVars(builder, moduleTranslation, mlirPrivateVars,
1796-
llvmPrivateVars, privateDecls)))
1797-
return llvm::make_error<PreviouslyReportedError>();
1794+
// Now do this again to make the initialization and copy blocks
1795+
llvm::BasicBlock *copyBlock =
1796+
splitBB(builder, /*CreateBranch=*/true, "omp.private.copy");
1797+
llvm::BasicBlock *initBlock =
1798+
splitBB(builder, /*CreateBranch=*/true, "omp.private.init");
1799+
1800+
// Now the control flow graph should look like
1801+
// starter_block:
1802+
// <---- where we started when convertOmpTaskOp was called
1803+
// br %omp.private.init
1804+
// omp.private.init:
1805+
// br %omp.private.copy
1806+
// omp.private.copy:
1807+
// br %omp.task.start
1808+
// omp.task.start:
1809+
// <---- where we want the insertion point to be when we call createTask()
1810+
1811+
// Save the alloca insertion point on ModuleTranslation stack for use in
1812+
// nested regions.
1813+
LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
1814+
moduleTranslation, allocaIP);
1815+
1816+
// Allocate and initialize private variables
1817+
// TODO: package private variables up in a structure
1818+
for (auto [privDecl, mlirPrivVar, blockArg] :
1819+
llvm::zip_equal(privateDecls, mlirPrivateVars, privateBlockArgs)) {
1820+
llvm::Type *llvmAllocType =
1821+
moduleTranslation.convertType(privDecl.getType());
1822+
1823+
// Allocations:
1824+
builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1825+
llvm::Value *llvmPrivateVar = builder.CreateAlloca(
1826+
llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");
1827+
1828+
builder.SetInsertPoint(initBlock->getTerminator());
1829+
auto err = initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
1830+
blockArg, &llvmPrivateVar, initBlock);
1831+
if (err)
1832+
return handleError(std::move(err), *taskOp.getOperation());
1833+
1834+
llvmPrivateVars.push_back(llvmPrivateVar);
1835+
}
1836+
1837+
// firstprivate copy region
1838+
builder.SetInsertPoint(copyBlock->getTerminator());
1839+
if (failed(copyFirstPrivateVars(builder, moduleTranslation, mlirPrivateVars,
1840+
llvmPrivateVars, privateDecls)))
1841+
return llvm::failure();
1842+
1843+
// Set up for call to createTask()
1844+
builder.SetInsertPoint(taskStartBlock);
17981845

1846+
auto bodyCB = [&](InsertPointTy allocaIP,
1847+
InsertPointTy codegenIP) -> llvm::Error {
1848+
builder.restoreIP(codegenIP);
17991849
// translate the body of the task:
18001850
auto continuationBlockOrError = convertOmpOpRegions(
18011851
taskOp.getRegion(), "omp.task.region", builder, moduleTranslation);
@@ -1815,8 +1865,6 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
18151865
buildDependData(taskOp.getDependKinds(), taskOp.getDependVars(),
18161866
moduleTranslation, dds);
18171867

1818-
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1819-
findAllocaInsertPoint(builder, moduleTranslation);
18201868
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
18211869
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
18221870
moduleTranslation.getOpenMPBuilder()->createTask(

mlir/test/Target/LLVMIR/openmp-llvm.mlir

Lines changed: 26 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2790,11 +2790,14 @@ llvm.func @par_task_(%arg0: !llvm.ptr {fir.bindc_name = "a"}) {
27902790
}
27912791

27922792
// CHECK-LABEL: @par_task_
2793+
// CHECK: %[[ARG_ALLOC:.*]] = alloca { ptr }, align 8
27932794
// CHECK: %[[TASK_ALLOC:.*]] = call ptr @__kmpc_omp_task_alloc({{.*}}ptr @[[task_outlined_fn:.+]])
27942795
// CHECK: call i32 @__kmpc_omp_task({{.*}}, ptr %[[TASK_ALLOC]])
2795-
// CHECK: define internal void @[[task_outlined_fn]]
2796-
// CHECK: %[[ARG_ALLOC:.*]] = alloca { ptr }, align 8
2797-
// CHECK: call void ({{.*}}) @__kmpc_fork_call({{.*}}, ptr @[[parallel_outlined_fn:.+]], ptr %[[ARG_ALLOC]])
2796+
// CHECK: define internal void @[[task_outlined_fn]](i32 %[[GLOBAL_TID_VAL:.*]], ptr %[[STRUCT_ARG:.*]])
2797+
// CHECK: %[[LOADED_STRUCT_PTR:.*]] = load ptr, ptr %[[STRUCT_ARG]], align 8
2798+
// CHECK: %[[GEP_STRUCTARG:.*]] = getelementptr { ptr, ptr }, ptr %[[LOADED_STRUCT_PTR]], i32 0, i32 0
2799+
// CHECK: %[[LOADGEP_STRUCTARG:.*]] = load ptr, ptr %[[GEP_STRUCTARG]], align 8
2800+
// CHECK: call void ({{.*}}) @__kmpc_fork_call({{.*}}, ptr @[[parallel_outlined_fn:.+]], ptr %[[LOADGEP_STRUCTARG]])
27982801
// CHECK: define internal void @[[parallel_outlined_fn]]
27992802
// -----
28002803

@@ -2820,32 +2823,18 @@ llvm.func @task(%arg0 : !llvm.ptr) {
28202823
llvm.return
28212824
}
28222825
// CHECK-LABEL: @task..omp_par
2823-
// CHECK: task.alloca:
2824-
// CHECK: %[[VAL_11:.*]] = load ptr, ptr %[[VAL_12:.*]], align 8
2825-
// CHECK: %[[VAL_13:.*]] = getelementptr { ptr }, ptr %[[VAL_11]], i32 0, i32 0
2826+
// CHECK: task.alloca:
2827+
// CHECK: %[[VAL_12:.*]] = load ptr, ptr %[[STRUCT_ARG:.*]], align 8
2828+
// CHECK: %[[VAL_13:.*]] = getelementptr { ptr }, ptr %[[VAL_12]], i32 0, i32 0
28262829
// CHECK: %[[VAL_14:.*]] = load ptr, ptr %[[VAL_13]], align 8
2827-
// CHECK: %[[VAL_15:.*]] = alloca i32, align 4
2828-
// CHECK: br label %omp.region.after_alloca
2829-
2830-
// CHECK: omp.region.after_alloca:
28312830
// CHECK: br label %task.body
2832-
2833-
// CHECK: task.body: ; preds = %omp.region.after_alloca
2834-
// CHECK: br label %omp.private.init
2835-
2836-
// CHECK: omp.private.init: ; preds = %task.body
2837-
// CHECK: br label %omp.private.copy
2838-
2839-
// CHECK: omp.private.copy: ; preds = %omp.private.init
2840-
// CHECK: %[[VAL_19:.*]] = load i32, ptr %[[VAL_14]], align 4
2841-
// CHECK: store i32 %[[VAL_19]], ptr %[[VAL_15]], align 4
2831+
// CHECK: task.body: ; preds = %task.alloca
28422832
// CHECK: br label %omp.task.region
2843-
2844-
// CHECK: omp.task.region: ; preds = %omp.private.copy
2845-
// CHECK: call void @foo(ptr %[[VAL_15]])
2833+
// CHECK: omp.task.region: ; preds = %task.body
2834+
// CHECK: call void @foo(ptr %[[VAL_14]])
28462835
// CHECK: br label %omp.region.cont
28472836
// CHECK: omp.region.cont: ; preds = %omp.task.region
2848-
// CHECK: call void @destroy(ptr %[[VAL_15]])
2837+
// CHECK: call void @destroy(ptr %[[VAL_14]])
28492838
// CHECK: br label %task.exit.exitStub
28502839
// CHECK: task.exit.exitStub: ; preds = %omp.region.cont
28512840
// CHECK: ret void
@@ -2915,6 +2904,19 @@ llvm.func @omp_taskgroup_task(%x: i32, %y: i32, %zaddr: !llvm.ptr) {
29152904
// CHECK: br label %[[omp_region_cont:[^,]+]]
29162905
// CHECK: [[omp_taskgroup_region]]:
29172906
// CHECK: %{{.+}} = alloca i8, align 1
2907+
// CHECK: br label %[[omp_private_init:[^,]+]]
2908+
// CHECK: [[omp_private_init]]:
2909+
// CHECK: br label %[[omp_private_copy:[^,]+]]
2910+
// CHECK: [[omp_private_copy]]:
2911+
// CHECK: br label %[[omp_task_start:[^,]+]]
2912+
2913+
// CHECK: [[omp_region_cont:[^,]+]]:
2914+
// CHECK: br label %[[taskgroup_exit:[^,]+]]
2915+
// CHECK: [[taskgroup_exit]]:
2916+
// CHECK: call void @__kmpc_end_taskgroup(ptr @{{.+}}, i32 %[[omp_global_thread_num]])
2917+
// CHECK: ret void
2918+
2919+
// CHECK: [[omp_task_start]]:
29182920
// CHECK: br label %[[codeRepl:[^,]+]]
29192921
// CHECK: [[codeRepl]]:
29202922
// CHECK: %[[omp_global_thread_num_t1:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}})
@@ -2938,11 +2940,6 @@ llvm.func @omp_taskgroup_task(%x: i32, %y: i32, %zaddr: !llvm.ptr) {
29382940
// CHECK: br label %[[task_exit3:[^,]+]]
29392941
// CHECK: [[task_exit3]]:
29402942
// CHECK: br label %[[omp_taskgroup_region1]]
2941-
// CHECK: [[omp_region_cont]]:
2942-
// CHECK: br label %[[taskgroup_exit:[^,]+]]
2943-
// CHECK: [[taskgroup_exit]]:
2944-
// CHECK: call void @__kmpc_end_taskgroup(ptr @{{.+}}, i32 %[[omp_global_thread_num]])
2945-
// CHECK: ret void
29462943
// CHECK: }
29472944

29482945
// -----

0 commit comments

Comments
 (0)