Skip to content

[Flang] Add lowering support for depobj in depend clause #124523

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions flang/lib/Lower/OpenMP/ClauseProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@ static mlir::omp::ClauseTaskDependAttr
genDependKindAttr(lower::AbstractConverter &converter,
const omp::clause::DependenceType kind) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::Location currentLocation = converter.getCurrentLocation();

mlir::omp::ClauseTaskDepend pbKind;
switch (kind) {
Expand All @@ -152,15 +151,15 @@ genDependKindAttr(lower::AbstractConverter &converter,
case omp::clause::DependenceType::Inout:
pbKind = mlir::omp::ClauseTaskDepend::taskdependinout;
break;
case omp::clause::DependenceType::Depobj:
pbKind = mlir::omp::ClauseTaskDepend::taskdependdepobj;
break;
case omp::clause::DependenceType::Mutexinoutset:
pbKind = mlir::omp::ClauseTaskDepend::taskdependmutexinoutset;
break;
case omp::clause::DependenceType::Inoutset:
pbKind = mlir::omp::ClauseTaskDepend::taskdependinoutset;
break;
case omp::clause::DependenceType::Depobj:
TODO(currentLocation, "DEPOBJ dependence-type");
break;
case omp::clause::DependenceType::Sink:
case omp::clause::DependenceType::Source:
llvm_unreachable("unhandled parser task dependence type");
Expand Down
10 changes: 0 additions & 10 deletions flang/test/Lower/OpenMP/Todo/depend-clause-depobj.f90

This file was deleted.

8 changes: 7 additions & 1 deletion flang/test/Lower/OpenMP/task.f90
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,18 @@ subroutine task_depend_multi_task()
x = x - 12
!CHECK: omp.terminator
!$omp end task
!CHECK: omp.task depend(taskdependinoutset -> %{{.+}} : !fir.ref<i32>)
!CHECK: omp.task depend(taskdependinoutset -> %{{.+}} : !fir.ref<i32>)
!$omp task depend(inoutset : x)
!CHECK: arith.subi
x = x - 12
!CHECK: omp.terminator
!$omp end task
!CHECK: omp.task depend(taskdependdepobj -> %{{.+}} : !fir.ref<i32>)
!$omp task depend(depobj: obj)
! CHECK: arith.addi
x = x + 73
! CHECK: omp.terminator
!$omp end task
end subroutine task_depend_multi_task

!===============================================================================
Expand Down
6 changes: 4 additions & 2 deletions llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -1241,10 +1241,12 @@ class OpenMPIRBuilder {
omp::RTLDependenceKindTy DepKind = omp::RTLDependenceKindTy::DepUnknown;
Type *DepValueType;
Value *DepVal;
bool isTypeDepObj;
explicit DependData() = default;
DependData(omp::RTLDependenceKindTy DepKind, Type *DepValueType,
Value *DepVal)
: DepKind(DepKind), DepValueType(DepValueType), DepVal(DepVal) {}
Value *DepVal, bool isTypeDepObj = false)
: DepKind(DepKind), DepValueType(DepValueType), DepVal(DepVal),
isTypeDepObj(isTypeDepObj) {}
};

/// Generator for `#omp task`
Expand Down
80 changes: 70 additions & 10 deletions llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2049,19 +2049,61 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTask(
Builder.CreateStore(Priority, CmplrData);
}

Value *DepArray = nullptr;
Value *DepAlloca = nullptr;
Value *stackSave = nullptr;
Value *depSize = Builder.getInt32(Dependencies.size());
if (Dependencies.size()) {
InsertPointTy OldIP = Builder.saveIP();
Builder.SetInsertPoint(
&OldIP.getBlock()->getParent()->getEntryBlock().back());

Type *DepArrayTy = ArrayType::get(DependInfo, Dependencies.size());
DepArray = Builder.CreateAlloca(DepArrayTy, nullptr, ".dep.arr.addr");

// Used to keep a count of other dependence type apart from DEPOBJ
size_t otherDepTypeCount = 0;
SmallVector<Value *> objsVal;
// Load all the value of DEPOBJ object from omp_depend_t object
for (const DependData &dep : Dependencies) {
if (dep.isTypeDepObj) {
Value *loadDepVal = Builder.CreateLoad(VoidPtr, dep.DepVal);
Value *depValGEP =
Builder.CreateGEP(DependInfo, loadDepVal, Builder.getInt64(-1));
Value *obj =
Builder.CreateConstInBoundsGEP2_64(DependInfo, depValGEP, 0, 0);
Value *objVal = Builder.CreateLoad(Builder.getInt64Ty(), obj);
objsVal.push_back(objVal);
} else {
otherDepTypeCount++;
}
}

// Add all the values and use it as the size for DependInfo alloca
if (objsVal.size() > 0) {
depSize = objsVal[0];
for (size_t i = 1; i < objsVal.size(); i++)
depSize = Builder.CreateAdd(depSize, objsVal[i]);
if (otherDepTypeCount > 0)
depSize =
Builder.CreateAdd(depSize, Builder.getInt64(otherDepTypeCount));
}

if (!isa<ConstantInt>(depSize)) {
// stackSave to save the stack pointer
if (!stackSave)
stackSave = Builder.CreateStackSave();
DepAlloca = Builder.CreateAlloca(DependInfo, depSize, "dep.addr");
((AllocaInst *)DepAlloca)->setAlignment(Align(16));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why 16?

Copy link
Member Author

@Thirumalai-Shaktivel Thirumalai-Shaktivel Feb 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I referred to the omp_depend_kind value as 16 from gfortran (I thought it was according to standards).
During the implementation, I didn't check Flang, my bad. It seems the kind value is compiler-specific as Flang uses the kind value as 8.

Maybe we can keep it as align 8, what do you think?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So your intention is for the alignment to match the size of the type?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, integer(omp_depend_kind).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that case I think it would be best to programatically get the type size in bytes and use that. Something like this (I haven't tried building it)

Suggested change
((AllocaInst *)DepAlloca)->setAlignment(Align(16));
llvm::Constant *size = llvm::ConstantExpr::getSizeOf(DependInfo);
((AllocaInst *)DepAlloca)->setAlignment(Align(size));

Also please use C++ style casts: https://llvm.org/docs/CodingStandards.html#prefer-c-style-casts

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense, I will push the required changes soon.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have strong reason keeping the alignment 16, so I thought to remove it.

KmpDependInfoArrayTy =
C.getVariableArrayType(KmpDependInfoTy, OVE, ArraySizeModifier::Normal,
/*IndexTypeQuals=*/0, SourceRange(Loc, Loc));
// CGF.EmitVariablyModifiedType(KmpDependInfoArrayTy);
// Properly emit variable-sized array.
auto *PD = ImplicitParamDecl::Create(C, KmpDependInfoArrayTy,
ImplicitParamKind::Other);
CGF.EmitVarDecl(*PD);

In clang, the alignment value is taking from PD, which is created using KmpDependInfoArrayTy

depSize = Builder.CreateTrunc(depSize, Builder.getInt32Ty());
} else {
DepAlloca = Builder.CreateAlloca(DepArrayTy, nullptr, ".dep.arr.addr");
}

unsigned P = 0;
for (const DependData &Dep : Dependencies) {
if (Dep.isTypeDepObj)
continue;
Value *Base =
Builder.CreateConstInBoundsGEP2_64(DepArrayTy, DepArray, 0, P);
Builder.CreateGEP(DependInfo, DepAlloca, Builder.getInt64(P));
// Store the pointer to the variable
Value *Addr = Builder.CreateStructGEP(
DependInfo, Base,
Expand All @@ -2087,6 +2129,23 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTask(
++P;
}

P = 0;
Value *depAllocaIdx = Builder.getInt64(otherDepTypeCount);
for (const DependData &dep : Dependencies) {
if (dep.isTypeDepObj) {
Value *depAllocaPtr =
Builder.CreateGEP(DependInfo, DepAlloca, depAllocaIdx);
Align alignment = Align(8);
Value *loadDepVal = Builder.CreateLoad(VoidPtr, dep.DepVal);
Value *memCpySize =
Builder.CreateMul(Builder.getInt64(24), objsVal[P]);
Builder.CreateMemCpy(depAllocaPtr, alignment, loadDepVal, alignment,
memCpySize);
depAllocaIdx = Builder.CreateAdd(depAllocaIdx, objsVal[P]);
++P;
}
}

Builder.restoreIP(OldIP);
}

Expand Down Expand Up @@ -2124,7 +2183,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTask(
getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_wait_deps);
Builder.CreateCall(
TaskWaitFn,
{Ident, ThreadID, Builder.getInt32(Dependencies.size()), DepArray,
{Ident, ThreadID, Builder.getInt32(Dependencies.size()), DepAlloca,
ConstantInt::get(Builder.getInt32Ty(), 0),
ConstantPointerNull::get(PointerType::getUnqual(M.getContext()))});
}
Expand All @@ -2146,12 +2205,13 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTask(
if (Dependencies.size()) {
Function *TaskFn =
getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_with_deps);
Builder.CreateCall(
TaskFn,
{Ident, ThreadID, TaskData, Builder.getInt32(Dependencies.size()),
DepArray, ConstantInt::get(Builder.getInt32Ty(), 0),
ConstantPointerNull::get(PointerType::getUnqual(M.getContext()))});

Builder.CreateCall(TaskFn, {Ident, ThreadID, TaskData, depSize, DepAlloca,
ConstantInt::get(Builder.getInt32Ty(), 0),
ConstantPointerNull::get(
PointerType::getUnqual(M.getContext()))});
// stackSave is used by depend(depobj: x) clause to save the stack pointer
if (stackSave)
Builder.CreateStackRestore(stackSave);
} else {
// Emit the @__kmpc_omp_task runtime call to spawn the task
Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task);
Expand Down
10 changes: 6 additions & 4 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,14 @@ def ClauseTaskDependInOut : I32EnumAttrCase<"taskdependinout", 2>;
def ClauseTaskDependMutexInOutSet
: I32EnumAttrCase<"taskdependmutexinoutset", 3>;
def ClauseTaskDependInOutSet : I32EnumAttrCase<"taskdependinoutset", 4>;
def ClauseTaskDependDepObj : I32EnumAttrCase<"taskdependdepobj", 5>;

def ClauseTaskDepend
: OpenMP_I32EnumAttr<
"ClauseTaskDepend", "depend clause in a target or task construct",
[ClauseTaskDependIn, ClauseTaskDependOut, ClauseTaskDependInOut,
ClauseTaskDependMutexInOutSet, ClauseTaskDependInOutSet]>;
: OpenMP_I32EnumAttr<"ClauseTaskDepend",
"depend clause in a target or task construct",
[ClauseTaskDependIn, ClauseTaskDependOut,
ClauseTaskDependInOut, ClauseTaskDependMutexInOutSet,
ClauseTaskDependInOutSet, ClauseTaskDependDepObj]>;

def ClauseTaskDependAttr : OpenMP_EnumAttr<ClauseTaskDepend,
"clause_task_depend"> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1715,6 +1715,7 @@ buildDependData(std::optional<ArrayAttr> dependKinds, OperandRange dependVars,
return;
for (auto dep : llvm::zip(dependVars, dependKinds->getValue())) {
llvm::omp::RTLDependenceKindTy type;
bool isTypeDepObj = false;
switch (
cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
case mlir::omp::ClauseTaskDepend::taskdependin:
Expand All @@ -1733,9 +1734,13 @@ buildDependData(std::optional<ArrayAttr> dependKinds, OperandRange dependVars,
case mlir::omp::ClauseTaskDepend::taskdependinoutset:
type = llvm::omp::RTLDependenceKindTy::DepInOutSet;
break;
case mlir::omp::ClauseTaskDepend::taskdependdepobj:
isTypeDepObj = true;
break;
};
llvm::Value *depVal = moduleTranslation.lookupValue(std::get<0>(dep));
llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal,
isTypeDepObj);
dds.emplace_back(dd);
}
}
Expand Down
67 changes: 62 additions & 5 deletions mlir/test/Target/LLVMIR/openmp-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2653,7 +2653,7 @@ llvm.func @omp_task_attrs() -> () attributes {
// CHECK-LABEL: define void @omp_task_with_deps
// CHECK-SAME: (ptr %[[zaddr:.+]])
// CHECK: %[[dep_arr_addr:.+]] = alloca [1 x %struct.kmp_dep_info], align 8
// CHECK: %[[dep_arr_addr_0:.+]] = getelementptr inbounds [1 x %struct.kmp_dep_info], ptr %[[dep_arr_addr]], i64 0, i64 0
// CHECK: %[[dep_arr_addr_0:.+]] = getelementptr %struct.kmp_dep_info, ptr %[[dep_arr_addr]], i64 0
// CHECK: %[[dep_arr_addr_0_val:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[dep_arr_addr_0]], i32 0, i32 0
// CHECK: %[[dep_arr_addr_0_val_int:.+]] = ptrtoint ptr %0 to i64
// CHECK: store i64 %[[dep_arr_addr_0_val_int]], ptr %[[dep_arr_addr_0_val]], align 4
Expand All @@ -2664,28 +2664,28 @@ llvm.func @omp_task_attrs() -> () attributes {
// -----
// dependence_type: Out
// CHECK: %[[DEP_ARR_ADDR1:.+]] = alloca [1 x %struct.kmp_dep_info], align 8
// CHECK: %[[DEP_ARR_ADDR_1:.+]] = getelementptr inbounds [1 x %struct.kmp_dep_info], ptr %[[DEP_ARR_ADDR1]], i64 0, i64 0
// CHECK: %[[DEP_ARR_ADDR_1:.+]] = getelementptr %struct.kmp_dep_info, ptr %[[DEP_ARR_ADDR1]], i64 0
// [...]
// CHECK: %[[DEP_TYPE_1:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[DEP_ARR_ADDR_1]], i32 0, i32 2
// CHECK: store i8 3, ptr %[[DEP_TYPE_1]], align 1
// -----
// dependence_type: Inout
// CHECK: %[[DEP_ARR_ADDR2:.+]] = alloca [1 x %struct.kmp_dep_info], align 8
// CHECK: %[[DEP_ARR_ADDR_2:.+]] = getelementptr inbounds [1 x %struct.kmp_dep_info], ptr %[[DEP_ARR_ADDR2]], i64 0, i64 0
// CHECK: %[[DEP_ARR_ADDR_2:.+]] = getelementptr %struct.kmp_dep_info, ptr %[[DEP_ARR_ADDR2]], i64 0
// [...]
// CHECK: %[[DEP_TYPE_2:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[DEP_ARR_ADDR_2]], i32 0, i32 2
// CHECK: store i8 3, ptr %[[DEP_TYPE_2]], align 1
// -----
// dependence_type: Mutexinoutset
// CHECK: %[[DEP_ARR_ADDR3:.+]] = alloca [1 x %struct.kmp_dep_info], align 8
// CHECK: %[[DEP_ARR_ADDR_3:.+]] = getelementptr inbounds [1 x %struct.kmp_dep_info], ptr %[[DEP_ARR_ADDR3]], i64 0, i64 0
// CHECK: %[[DEP_ARR_ADDR_3:.+]] = getelementptr %struct.kmp_dep_info, ptr %[[DEP_ARR_ADDR3]], i64 0
// [...]
// CHECK: %[[DEP_TYPE_3:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[DEP_ARR_ADDR_3]], i32 0, i32 2
// CHECK: store i8 4, ptr %[[DEP_TYPE_3]], align 1
// -----
// dependence_type: Inoutset
// CHECK: %[[DEP_ARR_ADDR4:.+]] = alloca [1 x %struct.kmp_dep_info], align 8
// CHECK: %[[DEP_ARR_ADDR_4:.+]] = getelementptr inbounds [1 x %struct.kmp_dep_info], ptr %[[DEP_ARR_ADDR4]], i64 0, i64 0
// CHECK: %[[DEP_ARR_ADDR_4:.+]] = getelementptr %struct.kmp_dep_info, ptr %[[DEP_ARR_ADDR4]], i64 0
// [...]
// CHECK: %[[DEP_TYPE_4:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[DEP_ARR_ADDR_4]], i32 0, i32 2
// CHECK: store i8 8, ptr %[[DEP_TYPE_4]], align 1
Expand Down Expand Up @@ -2734,6 +2734,63 @@ llvm.func @omp_task_with_deps(%zaddr: !llvm.ptr) {

// -----

// CHECK-LABEL: define void @omp_task_with_deps_02(ptr %0, ptr %1) {

// CHECK: %[[obj:.+]] = alloca i64, i64 1, align 8
// CHECK: %[[obj_load_01:.+]] = load ptr, ptr %[[obj]], align 8
// CHECK: %[[gep_01:.+]] = getelementptr %struct.kmp_dep_info, ptr %[[obj_load_01]], i64 -1
// CHECK: %[[gep_02:.+]] = getelementptr inbounds %struct.kmp_dep_info, ptr %[[gep_01]], i64 0, i64 0
// CHECK: %[[obj_addr:.+]] = load i64, ptr %[[gep_02]], align 4

// CHECK: %[[size:.+]] = add i64 %[[obj_addr]], 2

// CHECK: %[[stack_ptr:.+]] = call ptr @llvm.stacksave.p0()
// CHECK: %[[dep_addr:.+]] = alloca %struct.kmp_dep_info, i64 %[[size]], align 16
// CHECK: %[[dep_size:.+]] = trunc i64 %[[size]] to i32

// CHECK: %[[gep_03:.+]] = getelementptr %struct.kmp_dep_info, ptr %[[dep_addr]], i64 0
// CHECK: %[[gep_04:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[gep_03]], i32 0, i32 0
// CHECK: %[[arg_01_int:.+]] = ptrtoint ptr %0 to i64
// CHECK: store i64 %[[arg_01_int]], ptr %[[gep_04]], align 4
// CHECK: %[[gep_05:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[gep_03]], i32 0, i32 1
// CHECK: store i64 8, ptr %[[gep_05]], align 4
// CHECK: %[[gep_06:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[gep_03]], i32 0, i32 2
// CHECK: store i8 1, ptr %[[gep_06]], align 1

// CHECK: %[[gep_07:.+]] = getelementptr %struct.kmp_dep_info, ptr %[[dep_addr]], i64 1
// CHECK: %[[gep_08:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[gep_07]], i32 0, i32 0
// CHECK: %[[arg_02_int:.+]] = ptrtoint ptr %1 to i64
// CHECK: store i64 %[[arg_02_int]], ptr %[[gep_08]], align 4
// CHECK: %[[gep_09:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[gep_07]], i32 0, i32 1
// CHECK: store i64 8, ptr %[[gep_09]], align 4
// CHECK: %[[gep_10:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[gep_07]], i32 0, i32 2
// CHECK: store i8 3, ptr %[[gep_10]], align 1

// CHECK: %[[gep_11:.+]] = getelementptr %struct.kmp_dep_info, ptr %[[dep_addr]], i64 2
// CHECK: %[[obj_load_02:.+]] = load ptr, ptr %[[obj]], align 8
// CHECK: %[[obj_size:.+]] = mul i64 24, %[[obj_addr]]
// CHECK: call void @llvm.memcpy.p0.p0.i64(ptr align 8 %[[gep_11]], ptr align 8 %[[obj_load_02]], i64 %[[obj_size]], i1 false)
// CHECK: %[[dep_size_idx:.+]] = add i64 2, %[[obj_addr]]

// CHECK: %[[task:.+]] = call i32 @__kmpc_omp_task_with_deps({{.*}}, i32 %[[dep_size]], ptr %[[dep_addr]], i32 0, ptr null)
// CHECK: call void @llvm.stackrestore.p0(ptr %[[stack_ptr]])
// CHECK: }


llvm.func @omp_task_with_deps_02(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
%c_1 = llvm.mlir.constant(1 : i64) : i64
%1 = llvm.alloca %c_1 x i64 : (i64) -> !llvm.ptr
omp.task depend(taskdependin -> %arg0 : !llvm.ptr, taskdependdepobj -> %1 : !llvm.ptr, taskdependout -> %arg1 : !llvm.ptr) {
%4 = llvm.load %arg0 : !llvm.ptr -> i64
%5 = llvm.add %4, %c_1 : i64
llvm.store %5, %arg1 : i64, !llvm.ptr
omp.terminator
}
llvm.return
}

// -----

// CHECK-LABEL: define void @omp_task
// CHECK-SAME: (i32 %[[x:.+]], i32 %[[y:.+]], ptr %[[zaddr:.+]])
module attributes {llvm.target_triple = "x86_64-unknown-linux-gnu"} {
Expand Down
Loading