Skip to content

[mlir][OpenMP] initialize (first)private variables before task exec #125304

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 73 additions & 25 deletions mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1339,8 +1339,7 @@ findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder,
static llvm::Error initPrivateVar(
llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
omp::PrivateClauseOp &privDecl, Value mlirPrivVar, BlockArgument &blockArg,
llvm::SmallVectorImpl<llvm::Value *>::iterator llvmPrivateVarIt,
llvm::BasicBlock *privInitBlock,
llvm::Value **llvmPrivateVarIt, llvm::BasicBlock *privInitBlock,
llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
Region &initRegion = privDecl.getInitRegion();
if (initRegion.empty()) {
Expand Down Expand Up @@ -1771,31 +1770,82 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
for (mlir::Value privateVar : taskOp.getPrivateVars())
mlirPrivateVars.push_back(privateVar);

auto bodyCB = [&](InsertPointTy allocaIP,
InsertPointTy codegenIP) -> llvm::Error {
// Save the alloca insertion point on ModuleTranslation stack for use in
// nested regions.
LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
moduleTranslation, allocaIP);
// Allocate and copy private variables before creating the task. This avoids
// accessing invalid memory if (after this scope ends) the private variables
// are initialized from host variables or if host variables are copied into
// them (firstprivate). The insertion point is just before
// where the code for creating and scheduling the task will go. That puts this
// code outside of the outlined task region, which is what we want because
// this way the initialization and copy regions are executed immediately while
// the host variable data are still live.

llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
builder, moduleTranslation, privateBlockArgs, privateDecls,
mlirPrivateVars, llvmPrivateVars, allocaIP);
if (handleError(afterAllocas, *taskOp).failed())
return llvm::make_error<PreviouslyReportedError>();
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
findAllocaInsertPoint(builder, moduleTranslation);

builder.restoreIP(codegenIP);
if (handleError(initPrivateVars(builder, moduleTranslation,
privateBlockArgs, privateDecls,
mlirPrivateVars, llvmPrivateVars),
*taskOp)
.failed())
return llvm::make_error<PreviouslyReportedError>();
// Not using splitBB() because that requires the current block to have a
// terminator.
assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
llvm::BasicBlock *taskStartBlock = llvm::BasicBlock::Create(
builder.getContext(), "omp.task.start",
/*Parent=*/builder.GetInsertBlock()->getParent());
llvm::Instruction *branchToTaskStartBlock = builder.CreateBr(taskStartBlock);
builder.SetInsertPoint(branchToTaskStartBlock);

if (failed(copyFirstPrivateVars(builder, moduleTranslation, mlirPrivateVars,
llvmPrivateVars, privateDecls)))
return llvm::make_error<PreviouslyReportedError>();
// The current block now ends in a branch, so splitBB() can be used to create
// the initialization and copy blocks.
llvm::BasicBlock *copyBlock =
splitBB(builder, /*CreateBranch=*/true, "omp.private.copy");
llvm::BasicBlock *initBlock =
splitBB(builder, /*CreateBranch=*/true, "omp.private.init");

// Now the control flow graph should look like
// starter_block:
// <---- where we started when convertOmpTaskOp was called
// br %omp.private.init
// omp.private.init:
// br %omp.private.copy
// omp.private.copy:
// br %omp.task.start
// omp.task.start:
// <---- where we want the insertion point to be when we call createTask()

// Save the alloca insertion point on ModuleTranslation stack for use in
// nested regions.
LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
moduleTranslation, allocaIP);

// Allocate and initialize private variables
// TODO: package private variables up in a structure
for (auto [privDecl, mlirPrivVar, blockArg] :
llvm::zip_equal(privateDecls, mlirPrivateVars, privateBlockArgs)) {
llvm::Type *llvmAllocType =
moduleTranslation.convertType(privDecl.getType());

// Allocations:
builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
llvm::Value *llvmPrivateVar = builder.CreateAlloca(
llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");

builder.SetInsertPoint(initBlock->getTerminator());
auto err = initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
blockArg, &llvmPrivateVar, initBlock);
if (err)
return handleError(std::move(err), *taskOp.getOperation());

llvmPrivateVars.push_back(llvmPrivateVar);
}

// firstprivate copy region
builder.SetInsertPoint(copyBlock->getTerminator());
if (failed(copyFirstPrivateVars(builder, moduleTranslation, mlirPrivateVars,
llvmPrivateVars, privateDecls)))
return llvm::failure();

// Set up for call to createTask()
builder.SetInsertPoint(taskStartBlock);

auto bodyCB = [&](InsertPointTy allocaIP,
InsertPointTy codegenIP) -> llvm::Error {
builder.restoreIP(codegenIP);
// translate the body of the task:
auto continuationBlockOrError = convertOmpOpRegions(
taskOp.getRegion(), "omp.task.region", builder, moduleTranslation);
Expand All @@ -1815,8 +1865,6 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
buildDependData(taskOp.getDependKinds(), taskOp.getDependVars(),
moduleTranslation, dds);

llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
findAllocaInsertPoint(builder, moduleTranslation);
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
moduleTranslation.getOpenMPBuilder()->createTask(
Expand Down
72 changes: 43 additions & 29 deletions mlir/test/Target/LLVMIR/openmp-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2790,11 +2790,14 @@ llvm.func @par_task_(%arg0: !llvm.ptr {fir.bindc_name = "a"}) {
}

// CHECK-LABEL: @par_task_
// CHECK: %[[ARG_ALLOC:.*]] = alloca { ptr }, align 8
// CHECK: %[[TASK_ALLOC:.*]] = call ptr @__kmpc_omp_task_alloc({{.*}}ptr @[[task_outlined_fn:.+]])
// CHECK: call i32 @__kmpc_omp_task({{.*}}, ptr %[[TASK_ALLOC]])
// CHECK: define internal void @[[task_outlined_fn]]
// CHECK: %[[ARG_ALLOC:.*]] = alloca { ptr }, align 8
// CHECK: call void ({{.*}}) @__kmpc_fork_call({{.*}}, ptr @[[parallel_outlined_fn:.+]], ptr %[[ARG_ALLOC]])
// CHECK: define internal void @[[task_outlined_fn]](i32 %[[GLOBAL_TID_VAL:.*]], ptr %[[STRUCT_ARG:.*]])
// CHECK: %[[LOADED_STRUCT_PTR:.*]] = load ptr, ptr %[[STRUCT_ARG]], align 8
// CHECK: %[[GEP_STRUCTARG:.*]] = getelementptr { ptr, ptr }, ptr %[[LOADED_STRUCT_PTR]], i32 0, i32 0
// CHECK: %[[LOADGEP_STRUCTARG:.*]] = load ptr, ptr %[[GEP_STRUCTARG]], align 8
// CHECK: call void ({{.*}}) @__kmpc_fork_call({{.*}}, ptr @[[parallel_outlined_fn:.+]], ptr %[[LOADGEP_STRUCTARG]])
// CHECK: define internal void @[[parallel_outlined_fn]]
// -----

Expand All @@ -2819,33 +2822,36 @@ llvm.func @task(%arg0 : !llvm.ptr) {
}
llvm.return
}
// CHECK-LABEL: @task
// CHECK-SAME: (ptr %[[ARG:.*]])
// CHECK: %[[STRUCT_ARG:.*]] = alloca { ptr }, align 8
// CHECK: %[[OMP_PRIVATE_ALLOC:.*]] = alloca i32, align 4
// ...
// CHECK: br label %omp.private.init
// CHECK: omp.private.init:
// CHECK: br label %omp.private.copy1
// CHECK: omp.private.copy1:
// CHECK: %[[LOADED:.*]] = load i32, ptr %[[ARG]], align 4
// CHECK: store i32 %[[LOADED]], ptr %[[OMP_PRIVATE_ALLOC]], align 4
// ...
// CHECK: br label %omp.task.start
// CHECK: omp.task.start:
// CHECK: br label %[[CODEREPL:.*]]
// CHECK: [[CODEREPL]]:

// CHECK-LABEL: @task..omp_par
// CHECK: task.alloca:
// CHECK: %[[VAL_11:.*]] = load ptr, ptr %[[VAL_12:.*]], align 8
// CHECK: %[[VAL_13:.*]] = getelementptr { ptr }, ptr %[[VAL_11]], i32 0, i32 0
// CHECK: task.alloca:
// CHECK: %[[VAL_12:.*]] = load ptr, ptr %[[STRUCT_ARG:.*]], align 8
// CHECK: %[[VAL_13:.*]] = getelementptr { ptr }, ptr %[[VAL_12]], i32 0, i32 0
// CHECK: %[[VAL_14:.*]] = load ptr, ptr %[[VAL_13]], align 8
// CHECK: %[[VAL_15:.*]] = alloca i32, align 4
// CHECK: br label %omp.region.after_alloca

// CHECK: omp.region.after_alloca:
// CHECK: br label %task.body

// CHECK: task.body: ; preds = %omp.region.after_alloca
// CHECK: br label %omp.private.init

// CHECK: omp.private.init: ; preds = %task.body
// CHECK: br label %omp.private.copy

// CHECK: omp.private.copy: ; preds = %omp.private.init
// CHECK: %[[VAL_19:.*]] = load i32, ptr %[[VAL_14]], align 4
// CHECK: store i32 %[[VAL_19]], ptr %[[VAL_15]], align 4
// CHECK: task.body: ; preds = %task.alloca
// CHECK: br label %omp.task.region

// CHECK: omp.task.region: ; preds = %omp.private.copy
// CHECK: call void @foo(ptr %[[VAL_15]])
// CHECK: omp.task.region: ; preds = %task.body
// CHECK: call void @foo(ptr %[[VAL_14]])
// CHECK: br label %omp.region.cont
// CHECK: omp.region.cont: ; preds = %omp.task.region
// CHECK: call void @destroy(ptr %[[VAL_15]])
// CHECK: call void @destroy(ptr %[[VAL_14]])
// CHECK: br label %task.exit.exitStub
// CHECK: task.exit.exitStub: ; preds = %omp.region.cont
// CHECK: ret void
Expand Down Expand Up @@ -2915,6 +2921,19 @@ llvm.func @omp_taskgroup_task(%x: i32, %y: i32, %zaddr: !llvm.ptr) {
// CHECK: br label %[[omp_region_cont:[^,]+]]
// CHECK: [[omp_taskgroup_region]]:
// CHECK: %{{.+}} = alloca i8, align 1
// CHECK: br label %[[omp_private_init:[^,]+]]
// CHECK: [[omp_private_init]]:
// CHECK: br label %[[omp_private_copy:[^,]+]]
// CHECK: [[omp_private_copy]]:
// CHECK: br label %[[omp_task_start:[^,]+]]

// CHECK: [[omp_region_cont:[^,]+]]:
// CHECK: br label %[[taskgroup_exit:[^,]+]]
// CHECK: [[taskgroup_exit]]:
// CHECK: call void @__kmpc_end_taskgroup(ptr @{{.+}}, i32 %[[omp_global_thread_num]])
// CHECK: ret void

// CHECK: [[omp_task_start]]:
// CHECK: br label %[[codeRepl:[^,]+]]
// CHECK: [[codeRepl]]:
// CHECK: %[[omp_global_thread_num_t1:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}})
Expand All @@ -2938,11 +2957,6 @@ llvm.func @omp_taskgroup_task(%x: i32, %y: i32, %zaddr: !llvm.ptr) {
// CHECK: br label %[[task_exit3:[^,]+]]
// CHECK: [[task_exit3]]:
// CHECK: br label %[[omp_taskgroup_region1]]
// CHECK: [[omp_region_cont]]:
// CHECK: br label %[[taskgroup_exit:[^,]+]]
// CHECK: [[taskgroup_exit]]:
// CHECK: call void @__kmpc_end_taskgroup(ptr @{{.+}}, i32 %[[omp_global_thread_num]])
// CHECK: ret void
// CHECK: }

// -----
Expand Down