Skip to content

[Flang] Add lowering support for depobj in depend clause #124523

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

Thirumalai-Shaktivel
Copy link
Member

@Thirumalai-Shaktivel Thirumalai-Shaktivel commented Jan 27, 2025

From Documentation:
depobj: The task dependences are derived from the depend clause specified in the depobj constructs that initialized dependencies represented by the depend objects specified in the depend clause as if the depend clauses of the depobj constructs were specified in the current construct.

Implementation details:

  • The variable is of type omp_depend_kind and is used as a locator_list.
  • Access the base address of obj and compute the clause size, based on the obj value and other clauses count.
  • Allocate struct.kmp_dep_info with the size computed before.
  • Now, populate all the depend clauses information into the alloca. the other clauses info is added first in the index, 0, 1, 2, ... and then all the depobj clauses info.
  • Then, the alloca and size is passed as argument to __kmpc_omp_task_with_deps runtime.
  • Stacksave and Stackrestore are used to restore the stack pointer to the state before the depobj operations. Basically, Cleaning up the block.

TODO:
Requires depobj construct support for checking runtime results. Also, test debobj modify and destroy clauses

From Documentation:
depobj: The task dependences are derived from the depend clause
specified in the depobj constructs that initialized dependences
represented by the depend objects specified in the depend clause
as if the depend clauses of the depobj constructs were specified in
the current construct.

Implementation details:
- The variable is of type omp_depend_kind and is used as a locator_list.
- Access the base address of obj and compute the clause size, based on
  the obj value and other clauses count.
- Allocate struct.kmp_dep_info with the size computed before.
- Now, populate all the depend clauses information into the alloca.
  the other clauses info is added first in the index, 0, 1, 2, ... and
  then all the depobj clauses info.
- Then, the alloca and size is passed as argument to
  __kmpc_omp_task_with_deps runtime.
- `Stacksave` and `Stackrestore` is used to restore the stack pointer to
the state before the depobj operations. Basically removing all the
alloca's used.

TODO:
Requires depobj construct support for checking runtime results.
Also, test debobj modify and destroy clauses
@llvmbot
Copy link
Member

llvmbot commented Jan 27, 2025

@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-flang-fir-hlfir

Author: Thirumalai Shaktivel (Thirumalai-Shaktivel)

Changes

From Documentation:
depobj: The task dependences are derived from the depend clause specified in the depobj constructs that initialized dependences represented by the depend objects specified in the depend clause as if the depend clauses of the depobj constructs were specified in the current construct.

Implementation details:

  • The variable is of type omp_depend_kind and is used as a locator_list.
  • Access the base address of obj and compute the clause size, based on the obj value and other clauses count.
  • Allocate struct.kmp_dep_info with the size computed before.
  • Now, populate all the depend clauses information into the alloca. the other clauses info is added first in the index, 0, 1, 2, ... and then all the depobj clauses info.
  • Then, the alloca and size is passed as argument to __kmpc_omp_task_with_deps runtime.
  • Stacksave and Stackrestore is used to restore the stack pointer to the state before the depobj operations. Basically removing all the alloca's used.

TODO:
Requires depobj construct support for checking runtime results. Also, test debobj modify and destroy clauses


Full diff: https://github.com/llvm/llvm-project/pull/124523.diff

8 Files Affected:

  • (modified) flang/lib/Lower/OpenMP/ClauseProcessor.cpp (+3-4)
  • (removed) flang/test/Lower/OpenMP/Todo/depend-clause-depobj.f90 (-10)
  • (modified) flang/test/Lower/OpenMP/task.f90 (+7-1)
  • (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+4-2)
  • (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+70-10)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td (+6-4)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+6-1)
  • (modified) mlir/test/Target/LLVMIR/openmp-llvm.mlir (+62-5)
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 299d9d438f1156..3378ea2fc2b414 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -139,7 +139,6 @@ static mlir::omp::ClauseTaskDependAttr
 genDependKindAttr(lower::AbstractConverter &converter,
                   const omp::clause::DependenceType kind) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-  mlir::Location currentLocation = converter.getCurrentLocation();
 
   mlir::omp::ClauseTaskDepend pbKind;
   switch (kind) {
@@ -152,15 +151,15 @@ genDependKindAttr(lower::AbstractConverter &converter,
   case omp::clause::DependenceType::Inout:
     pbKind = mlir::omp::ClauseTaskDepend::taskdependinout;
     break;
+  case omp::clause::DependenceType::Depobj:
+    pbKind = mlir::omp::ClauseTaskDepend::taskdependdepobj;
+    break;
   case omp::clause::DependenceType::Mutexinoutset:
     pbKind = mlir::omp::ClauseTaskDepend::taskdependmutexinoutset;
     break;
   case omp::clause::DependenceType::Inoutset:
     pbKind = mlir::omp::ClauseTaskDepend::taskdependinoutset;
     break;
-  case omp::clause::DependenceType::Depobj:
-    TODO(currentLocation, "DEPOBJ dependence-type");
-    break;
   case omp::clause::DependenceType::Sink:
   case omp::clause::DependenceType::Source:
     llvm_unreachable("unhandled parser task dependence type");
diff --git a/flang/test/Lower/OpenMP/Todo/depend-clause-depobj.f90 b/flang/test/Lower/OpenMP/Todo/depend-clause-depobj.f90
deleted file mode 100644
index 4e98d77d0bb3e3..00000000000000
--- a/flang/test/Lower/OpenMP/Todo/depend-clause-depobj.f90
+++ /dev/null
@@ -1,10 +0,0 @@
-!RUN: %not_todo_cmd bbc -emit-hlfir -fopenmp -fopenmp-version=52 -o - %s 2>&1 | FileCheck %s
-!RUN: %not_todo_cmd %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 -o - %s 2>&1 | FileCheck %s
-
-!CHECK: not yet implemented: DEPOBJ dependence-type
-
-subroutine f00(x)
-  integer :: x
-  !$omp task depend(depobj: x)
-  !$omp end task
-end
diff --git a/flang/test/Lower/OpenMP/task.f90 b/flang/test/Lower/OpenMP/task.f90
index 13ebf2acd91012..28d1b36a162a7e 100644
--- a/flang/test/Lower/OpenMP/task.f90
+++ b/flang/test/Lower/OpenMP/task.f90
@@ -150,12 +150,18 @@ subroutine task_depend_multi_task()
   x = x - 12
   !CHECK: omp.terminator
   !$omp end task
-    !CHECK: omp.task depend(taskdependinoutset -> %{{.+}} : !fir.ref<i32>)
+  !CHECK: omp.task depend(taskdependinoutset -> %{{.+}} : !fir.ref<i32>)
   !$omp task depend(inoutset : x)
   !CHECK: arith.subi
   x = x - 12
   !CHECK: omp.terminator
   !$omp end task
+  !CHECK: omp.task depend(taskdependdepobj -> %{{.+}} : !fir.ref<i32>)
+  !$omp task depend(depobj: obj)
+  ! CHECK: arith.addi
+      x = x + 73
+  ! CHECK:   omp.terminator
+  !$omp end task
 end subroutine task_depend_multi_task
 
 !===============================================================================
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 9802cbe8b7b943..2d996e5fe3554a 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1241,10 +1241,12 @@ class OpenMPIRBuilder {
     omp::RTLDependenceKindTy DepKind = omp::RTLDependenceKindTy::DepUnknown;
     Type *DepValueType;
     Value *DepVal;
+    bool isTypeDepObj;
     explicit DependData() = default;
     DependData(omp::RTLDependenceKindTy DepKind, Type *DepValueType,
-               Value *DepVal)
-        : DepKind(DepKind), DepValueType(DepValueType), DepVal(DepVal) {}
+               Value *DepVal, bool isTypeDepObj = false)
+        : DepKind(DepKind), DepValueType(DepValueType), DepVal(DepVal),
+          isTypeDepObj(isTypeDepObj) {}
   };
 
   /// Generator for `#omp task`
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 8cc3a99d92023d..476c0c80b985a9 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -2049,19 +2049,61 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTask(
       Builder.CreateStore(Priority, CmplrData);
     }
 
-    Value *DepArray = nullptr;
+    Value *DepAlloca = nullptr;
+    Value *stackSave = nullptr;
+    Value *depSize = Builder.getInt32(Dependencies.size());
     if (Dependencies.size()) {
       InsertPointTy OldIP = Builder.saveIP();
       Builder.SetInsertPoint(
           &OldIP.getBlock()->getParent()->getEntryBlock().back());
 
       Type *DepArrayTy = ArrayType::get(DependInfo, Dependencies.size());
-      DepArray = Builder.CreateAlloca(DepArrayTy, nullptr, ".dep.arr.addr");
+
+      // Used to keep a count of other dependence type apart from DEPOBJ
+      size_t otherDepTypeCount = 0;
+      SmallVector<Value *> objsVal;
+      // Load all the value of DEPOBJ object from omp_depend_t object
+      for (const DependData &dep : Dependencies) {
+        if (dep.isTypeDepObj) {
+          Value *loadDepVal = Builder.CreateLoad(VoidPtr, dep.DepVal);
+          Value *depValGEP =
+              Builder.CreateGEP(DependInfo, loadDepVal, Builder.getInt64(-1));
+          Value *obj =
+              Builder.CreateConstInBoundsGEP2_64(DependInfo, depValGEP, 0, 0);
+          Value *objVal = Builder.CreateLoad(Builder.getInt64Ty(), obj);
+          objsVal.push_back(objVal);
+        } else {
+          otherDepTypeCount++;
+        }
+      }
+
+      // Add all the values and use it as the size for DependInfo alloca
+      if (objsVal.size() > 0) {
+        depSize = objsVal[0];
+        for (size_t i = 1; i < objsVal.size(); i++)
+          depSize = Builder.CreateAdd(depSize, objsVal[i]);
+        if (otherDepTypeCount > 0)
+          depSize =
+              Builder.CreateAdd(depSize, Builder.getInt64(otherDepTypeCount));
+      }
+
+      if (!isa<ConstantInt>(depSize)) {
+        // stackSave to save the stack pointer
+        if (!stackSave)
+          stackSave = Builder.CreateStackSave();
+        DepAlloca = Builder.CreateAlloca(DependInfo, depSize, "dep.addr");
+        ((AllocaInst *)DepAlloca)->setAlignment(Align(16));
+        depSize = Builder.CreateTrunc(depSize, Builder.getInt32Ty());
+      } else {
+        DepAlloca = Builder.CreateAlloca(DepArrayTy, nullptr, ".dep.arr.addr");
+      }
 
       unsigned P = 0;
       for (const DependData &Dep : Dependencies) {
+        if (Dep.isTypeDepObj)
+          continue;
         Value *Base =
-            Builder.CreateConstInBoundsGEP2_64(DepArrayTy, DepArray, 0, P);
+            Builder.CreateGEP(DependInfo, DepAlloca, Builder.getInt64(P));
         // Store the pointer to the variable
         Value *Addr = Builder.CreateStructGEP(
             DependInfo, Base,
@@ -2087,6 +2129,23 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTask(
         ++P;
       }
 
+      P = 0;
+      Value *depAllocaIdx = Builder.getInt64(otherDepTypeCount);
+      for (const DependData &dep : Dependencies) {
+        if (dep.isTypeDepObj) {
+          Value *depAllocaPtr =
+              Builder.CreateGEP(DependInfo, DepAlloca, depAllocaIdx);
+          Align alignment = Align(8);
+          Value *loadDepVal = Builder.CreateLoad(VoidPtr, dep.DepVal);
+          Value *memCpySize =
+              Builder.CreateMul(Builder.getInt64(24), objsVal[P]);
+          Builder.CreateMemCpy(depAllocaPtr, alignment, loadDepVal, alignment,
+                               memCpySize);
+          depAllocaIdx = Builder.CreateAdd(depAllocaIdx, objsVal[P]);
+          ++P;
+        }
+      }
+
       Builder.restoreIP(OldIP);
     }
 
@@ -2124,7 +2183,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTask(
             getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_wait_deps);
         Builder.CreateCall(
             TaskWaitFn,
-            {Ident, ThreadID, Builder.getInt32(Dependencies.size()), DepArray,
+            {Ident, ThreadID, Builder.getInt32(Dependencies.size()), DepAlloca,
              ConstantInt::get(Builder.getInt32Ty(), 0),
              ConstantPointerNull::get(PointerType::getUnqual(M.getContext()))});
       }
@@ -2146,12 +2205,13 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTask(
     if (Dependencies.size()) {
       Function *TaskFn =
           getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_with_deps);
-      Builder.CreateCall(
-          TaskFn,
-          {Ident, ThreadID, TaskData, Builder.getInt32(Dependencies.size()),
-           DepArray, ConstantInt::get(Builder.getInt32Ty(), 0),
-           ConstantPointerNull::get(PointerType::getUnqual(M.getContext()))});
-
+      Builder.CreateCall(TaskFn, {Ident, ThreadID, TaskData, depSize, DepAlloca,
+                                  ConstantInt::get(Builder.getInt32Ty(), 0),
+                                  ConstantPointerNull::get(
+                                      PointerType::getUnqual(M.getContext()))});
+      // stackSave is used by depend(depobj: x) clause to save the stack pointer
+      if (stackSave)
+        Builder.CreateStackRestore(stackSave);
     } else {
       // Emit the @__kmpc_omp_task runtime call to spawn the task
       Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task);
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
index 690e3df1f685e3..bbe1174775184d 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
@@ -111,12 +111,14 @@ def ClauseTaskDependInOut : I32EnumAttrCase<"taskdependinout", 2>;
 def ClauseTaskDependMutexInOutSet
     : I32EnumAttrCase<"taskdependmutexinoutset", 3>;
 def ClauseTaskDependInOutSet : I32EnumAttrCase<"taskdependinoutset", 4>;
+def ClauseTaskDependDepObj : I32EnumAttrCase<"taskdependdepobj", 5>;
 
 def ClauseTaskDepend
-    : OpenMP_I32EnumAttr<
-          "ClauseTaskDepend", "depend clause in a target or task construct",
-          [ClauseTaskDependIn, ClauseTaskDependOut, ClauseTaskDependInOut,
-           ClauseTaskDependMutexInOutSet, ClauseTaskDependInOutSet]>;
+    : OpenMP_I32EnumAttr<"ClauseTaskDepend",
+                         "depend clause in a target or task construct",
+                         [ClauseTaskDependIn, ClauseTaskDependOut,
+                          ClauseTaskDependInOut, ClauseTaskDependMutexInOutSet,
+                          ClauseTaskDependInOutSet, ClauseTaskDependDepObj]>;
 
 def ClauseTaskDependAttr : OpenMP_EnumAttr<ClauseTaskDepend,
                                            "clause_task_depend"> {
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 3fcdefa8a2f673..de4bd108fff675 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -1715,6 +1715,7 @@ buildDependData(std::optional<ArrayAttr> dependKinds, OperandRange dependVars,
     return;
   for (auto dep : llvm::zip(dependVars, dependKinds->getValue())) {
     llvm::omp::RTLDependenceKindTy type;
+    bool isTypeDepObj = false;
     switch (
         cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
     case mlir::omp::ClauseTaskDepend::taskdependin:
@@ -1733,9 +1734,13 @@ buildDependData(std::optional<ArrayAttr> dependKinds, OperandRange dependVars,
     case mlir::omp::ClauseTaskDepend::taskdependinoutset:
       type = llvm::omp::RTLDependenceKindTy::DepInOutSet;
       break;
+    case mlir::omp::ClauseTaskDepend::taskdependdepobj:
+      isTypeDepObj = true;
+      break;
     };
     llvm::Value *depVal = moduleTranslation.lookupValue(std::get<0>(dep));
-    llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
+    llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal,
+                                         isTypeDepObj);
     dds.emplace_back(dd);
   }
 }
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index 9868ef227d49e0..b8ae3d0bec2c88 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -2653,7 +2653,7 @@ llvm.func @omp_task_attrs() -> () attributes {
 // CHECK-LABEL: define void @omp_task_with_deps
 // CHECK-SAME: (ptr %[[zaddr:.+]])
 // CHECK:  %[[dep_arr_addr:.+]] = alloca [1 x %struct.kmp_dep_info], align 8
-// CHECK:  %[[dep_arr_addr_0:.+]] = getelementptr inbounds [1 x %struct.kmp_dep_info], ptr %[[dep_arr_addr]], i64 0, i64 0
+// CHECK:  %[[dep_arr_addr_0:.+]] = getelementptr %struct.kmp_dep_info, ptr %[[dep_arr_addr]], i64 0
 // CHECK:  %[[dep_arr_addr_0_val:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[dep_arr_addr_0]], i32 0, i32 0
 // CHECK:  %[[dep_arr_addr_0_val_int:.+]] = ptrtoint ptr %0 to i64
 // CHECK:  store i64 %[[dep_arr_addr_0_val_int]], ptr %[[dep_arr_addr_0_val]], align 4
@@ -2664,28 +2664,28 @@ llvm.func @omp_task_attrs() -> () attributes {
 // -----
 // dependence_type: Out
 // CHECK:  %[[DEP_ARR_ADDR1:.+]] = alloca [1 x %struct.kmp_dep_info], align 8
-// CHECK:  %[[DEP_ARR_ADDR_1:.+]] = getelementptr inbounds [1 x %struct.kmp_dep_info], ptr %[[DEP_ARR_ADDR1]], i64 0, i64 0
+// CHECK:  %[[DEP_ARR_ADDR_1:.+]] = getelementptr %struct.kmp_dep_info, ptr %[[DEP_ARR_ADDR1]], i64 0
 //         [...]
 // CHECK:  %[[DEP_TYPE_1:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[DEP_ARR_ADDR_1]], i32 0, i32 2
 // CHECK:  store i8 3, ptr %[[DEP_TYPE_1]], align 1
 // -----
 // dependence_type: Inout
 // CHECK:  %[[DEP_ARR_ADDR2:.+]] = alloca [1 x %struct.kmp_dep_info], align 8
-// CHECK:  %[[DEP_ARR_ADDR_2:.+]] = getelementptr inbounds [1 x %struct.kmp_dep_info], ptr %[[DEP_ARR_ADDR2]], i64 0, i64 0
+// CHECK:  %[[DEP_ARR_ADDR_2:.+]] = getelementptr %struct.kmp_dep_info, ptr %[[DEP_ARR_ADDR2]], i64 0
 //         [...]
 // CHECK:  %[[DEP_TYPE_2:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[DEP_ARR_ADDR_2]], i32 0, i32 2
 // CHECK:  store i8 3, ptr %[[DEP_TYPE_2]], align 1
 // -----
 // dependence_type: Mutexinoutset
 // CHECK:  %[[DEP_ARR_ADDR3:.+]] = alloca [1 x %struct.kmp_dep_info], align 8
-// CHECK:  %[[DEP_ARR_ADDR_3:.+]] = getelementptr inbounds [1 x %struct.kmp_dep_info], ptr %[[DEP_ARR_ADDR3]], i64 0, i64 0
+// CHECK:  %[[DEP_ARR_ADDR_3:.+]] = getelementptr %struct.kmp_dep_info, ptr %[[DEP_ARR_ADDR3]], i64 0
 //         [...]
 // CHECK:  %[[DEP_TYPE_3:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[DEP_ARR_ADDR_3]], i32 0, i32 2
 // CHECK:  store i8 4, ptr %[[DEP_TYPE_3]], align 1
 // -----
 // dependence_type: Inoutset
 // CHECK:  %[[DEP_ARR_ADDR4:.+]] = alloca [1 x %struct.kmp_dep_info], align 8
-// CHECK:  %[[DEP_ARR_ADDR_4:.+]] = getelementptr inbounds [1 x %struct.kmp_dep_info], ptr %[[DEP_ARR_ADDR4]], i64 0, i64 0
+// CHECK:  %[[DEP_ARR_ADDR_4:.+]] = getelementptr %struct.kmp_dep_info, ptr %[[DEP_ARR_ADDR4]], i64 0
 //         [...]
 // CHECK:  %[[DEP_TYPE_4:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[DEP_ARR_ADDR_4]], i32 0, i32 2
 // CHECK:  store i8 8, ptr %[[DEP_TYPE_4]], align 1
@@ -2734,6 +2734,63 @@ llvm.func @omp_task_with_deps(%zaddr: !llvm.ptr) {
 
 // -----
 
+// CHECK-LABEL: define void @omp_task_with_deps_02(ptr %0, ptr %1) {
+
+// CHECK:   %[[obj:.+]] = alloca i64, i64 1, align 8
+// CHECK:   %[[obj_load_01:.+]] = load ptr, ptr %[[obj]], align 8
+// CHECK:   %[[gep_01:.+]] = getelementptr %struct.kmp_dep_info, ptr %[[obj_load_01]], i64 -1
+// CHECK:   %[[gep_02:.+]] = getelementptr inbounds %struct.kmp_dep_info, ptr %[[gep_01]], i64 0, i64 0
+// CHECK:   %[[obj_addr:.+]] = load i64, ptr %[[gep_02]], align 4
+
+// CHECK:   %[[size:.+]] = add i64 %[[obj_addr]], 2
+
+// CHECK:   %[[stack_ptr:.+]] = call ptr @llvm.stacksave.p0()
+// CHECK:   %[[dep_addr:.+]] = alloca %struct.kmp_dep_info, i64 %[[size]], align 16
+// CHECK:   %[[dep_size:.+]] = trunc i64 %[[size]] to i32
+
+// CHECK:   %[[gep_03:.+]] = getelementptr %struct.kmp_dep_info, ptr %[[dep_addr]], i64 0
+// CHECK:   %[[gep_04:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[gep_03]], i32 0, i32 0
+// CHECK:   %[[arg_01_int:.+]] = ptrtoint ptr %0 to i64
+// CHECK:   store i64 %[[arg_01_int]], ptr %[[gep_04]], align 4
+// CHECK:   %[[gep_05:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[gep_03]], i32 0, i32 1
+// CHECK:   store i64 8, ptr %[[gep_05]], align 4
+// CHECK:   %[[gep_06:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[gep_03]], i32 0, i32 2
+// CHECK:   store i8 1, ptr %[[gep_06]], align 1
+
+// CHECK:   %[[gep_07:.+]] = getelementptr %struct.kmp_dep_info, ptr %[[dep_addr]], i64 1
+// CHECK:   %[[gep_08:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[gep_07]], i32 0, i32 0
+// CHECK:   %[[arg_02_int:.+]] = ptrtoint ptr %1 to i64
+// CHECK:   store i64 %[[arg_02_int]], ptr %[[gep_08]], align 4
+// CHECK:   %[[gep_09:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[gep_07]], i32 0, i32 1
+// CHECK:   store i64 8, ptr %[[gep_09]], align 4
+// CHECK:   %[[gep_10:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[gep_07]], i32 0, i32 2
+// CHECK:   store i8 3, ptr %[[gep_10]], align 1
+
+// CHECK:   %[[gep_11:.+]] = getelementptr %struct.kmp_dep_info, ptr %[[dep_addr]], i64 2
+// CHECK:   %[[obj_load_02:.+]] = load ptr, ptr %[[obj]], align 8
+// CHECK:   %[[obj_size:.+]] = mul i64 24, %[[obj_addr]]
+// CHECK:   call void @llvm.memcpy.p0.p0.i64(ptr align 8 %[[gep_11]], ptr align 8 %[[obj_load_02]], i64 %[[obj_size]], i1 false)
+// CHECK:   %[[dep_size_idx:.+]] = add i64 2, %[[obj_addr]]
+
+// CHECK:   %[[task:.+]] = call i32 @__kmpc_omp_task_with_deps({{.*}}, i32 %[[dep_size]], ptr %[[dep_addr]], i32 0, ptr null)
+// CHECK:   call void @llvm.stackrestore.p0(ptr %[[stack_ptr]])
+// CHECK: }
+
+
+llvm.func @omp_task_with_deps_02(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
+  %c_1 = llvm.mlir.constant(1 : i64) : i64
+  %1 = llvm.alloca %c_1 x i64 : (i64) -> !llvm.ptr
+  omp.task depend(taskdependin -> %arg0 : !llvm.ptr, taskdependdepobj -> %1 : !llvm.ptr, taskdependout -> %arg1 : !llvm.ptr) {
+    %4 = llvm.load %arg0 : !llvm.ptr -> i64
+    %5 = llvm.add %4, %c_1 : i64
+    llvm.store %5, %arg1 : i64, !llvm.ptr
+    omp.terminator
+  }
+  llvm.return
+}
+
+// -----
+
 // CHECK-LABEL: define void @omp_task
 // CHECK-SAME: (i32 %[[x:.+]], i32 %[[y:.+]], ptr %[[zaddr:.+]])
 module attributes {llvm.target_triple = "x86_64-unknown-linux-gnu"} {

@llvmbot
Copy link
Member

llvmbot commented Jan 27, 2025

@llvm/pr-subscribers-flang-openmp

Author: Thirumalai Shaktivel (Thirumalai-Shaktivel)

Changes

From Documentation:
depobj: The task dependences are derived from the depend clause specified in the depobj constructs that initialized dependences represented by the depend objects specified in the depend clause as if the depend clauses of the depobj constructs were specified in the current construct.

Implementation details:

  • The variable is of type omp_depend_kind and is used as a locator_list.
  • Access the base address of obj and compute the clause size, based on the obj value and other clauses count.
  • Allocate struct.kmp_dep_info with the size computed before.
  • Now, populate all the depend clauses information into the alloca. the other clauses info is added first in the index, 0, 1, 2, ... and then all the depobj clauses info.
  • Then, the alloca and size is passed as argument to __kmpc_omp_task_with_deps runtime.
  • Stacksave and Stackrestore is used to restore the stack pointer to the state before the depobj operations. Basically removing all the alloca's used.

TODO:
Requires depobj construct support for checking runtime results. Also, test debobj modify and destroy clauses


Full diff: https://github.com/llvm/llvm-project/pull/124523.diff

8 Files Affected:

  • (modified) flang/lib/Lower/OpenMP/ClauseProcessor.cpp (+3-4)
  • (removed) flang/test/Lower/OpenMP/Todo/depend-clause-depobj.f90 (-10)
  • (modified) flang/test/Lower/OpenMP/task.f90 (+7-1)
  • (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+4-2)
  • (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+70-10)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td (+6-4)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+6-1)
  • (modified) mlir/test/Target/LLVMIR/openmp-llvm.mlir (+62-5)
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 299d9d438f1156..3378ea2fc2b414 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -139,7 +139,6 @@ static mlir::omp::ClauseTaskDependAttr
 genDependKindAttr(lower::AbstractConverter &converter,
                   const omp::clause::DependenceType kind) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-  mlir::Location currentLocation = converter.getCurrentLocation();
 
   mlir::omp::ClauseTaskDepend pbKind;
   switch (kind) {
@@ -152,15 +151,15 @@ genDependKindAttr(lower::AbstractConverter &converter,
   case omp::clause::DependenceType::Inout:
     pbKind = mlir::omp::ClauseTaskDepend::taskdependinout;
     break;
+  case omp::clause::DependenceType::Depobj:
+    pbKind = mlir::omp::ClauseTaskDepend::taskdependdepobj;
+    break;
   case omp::clause::DependenceType::Mutexinoutset:
     pbKind = mlir::omp::ClauseTaskDepend::taskdependmutexinoutset;
     break;
   case omp::clause::DependenceType::Inoutset:
     pbKind = mlir::omp::ClauseTaskDepend::taskdependinoutset;
     break;
-  case omp::clause::DependenceType::Depobj:
-    TODO(currentLocation, "DEPOBJ dependence-type");
-    break;
   case omp::clause::DependenceType::Sink:
   case omp::clause::DependenceType::Source:
     llvm_unreachable("unhandled parser task dependence type");
diff --git a/flang/test/Lower/OpenMP/Todo/depend-clause-depobj.f90 b/flang/test/Lower/OpenMP/Todo/depend-clause-depobj.f90
deleted file mode 100644
index 4e98d77d0bb3e3..00000000000000
--- a/flang/test/Lower/OpenMP/Todo/depend-clause-depobj.f90
+++ /dev/null
@@ -1,10 +0,0 @@
-!RUN: %not_todo_cmd bbc -emit-hlfir -fopenmp -fopenmp-version=52 -o - %s 2>&1 | FileCheck %s
-!RUN: %not_todo_cmd %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 -o - %s 2>&1 | FileCheck %s
-
-!CHECK: not yet implemented: DEPOBJ dependence-type
-
-subroutine f00(x)
-  integer :: x
-  !$omp task depend(depobj: x)
-  !$omp end task
-end
diff --git a/flang/test/Lower/OpenMP/task.f90 b/flang/test/Lower/OpenMP/task.f90
index 13ebf2acd91012..28d1b36a162a7e 100644
--- a/flang/test/Lower/OpenMP/task.f90
+++ b/flang/test/Lower/OpenMP/task.f90
@@ -150,12 +150,18 @@ subroutine task_depend_multi_task()
   x = x - 12
   !CHECK: omp.terminator
   !$omp end task
-    !CHECK: omp.task depend(taskdependinoutset -> %{{.+}} : !fir.ref<i32>)
+  !CHECK: omp.task depend(taskdependinoutset -> %{{.+}} : !fir.ref<i32>)
   !$omp task depend(inoutset : x)
   !CHECK: arith.subi
   x = x - 12
   !CHECK: omp.terminator
   !$omp end task
+  !CHECK: omp.task depend(taskdependdepobj -> %{{.+}} : !fir.ref<i32>)
+  !$omp task depend(depobj: obj)
+  ! CHECK: arith.addi
+      x = x + 73
+  ! CHECK:   omp.terminator
+  !$omp end task
 end subroutine task_depend_multi_task
 
 !===============================================================================
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 9802cbe8b7b943..2d996e5fe3554a 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1241,10 +1241,12 @@ class OpenMPIRBuilder {
     omp::RTLDependenceKindTy DepKind = omp::RTLDependenceKindTy::DepUnknown;
     Type *DepValueType;
     Value *DepVal;
+    bool isTypeDepObj;
     explicit DependData() = default;
     DependData(omp::RTLDependenceKindTy DepKind, Type *DepValueType,
-               Value *DepVal)
-        : DepKind(DepKind), DepValueType(DepValueType), DepVal(DepVal) {}
+               Value *DepVal, bool isTypeDepObj = false)
+        : DepKind(DepKind), DepValueType(DepValueType), DepVal(DepVal),
+          isTypeDepObj(isTypeDepObj) {}
   };
 
   /// Generator for `#omp task`
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 8cc3a99d92023d..476c0c80b985a9 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -2049,19 +2049,61 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTask(
       Builder.CreateStore(Priority, CmplrData);
     }
 
-    Value *DepArray = nullptr;
+    Value *DepAlloca = nullptr;
+    Value *stackSave = nullptr;
+    Value *depSize = Builder.getInt32(Dependencies.size());
     if (Dependencies.size()) {
       InsertPointTy OldIP = Builder.saveIP();
       Builder.SetInsertPoint(
           &OldIP.getBlock()->getParent()->getEntryBlock().back());
 
       Type *DepArrayTy = ArrayType::get(DependInfo, Dependencies.size());
-      DepArray = Builder.CreateAlloca(DepArrayTy, nullptr, ".dep.arr.addr");
+
+      // Used to keep a count of other dependence type apart from DEPOBJ
+      size_t otherDepTypeCount = 0;
+      SmallVector<Value *> objsVal;
+      // Load all the value of DEPOBJ object from omp_depend_t object
+      for (const DependData &dep : Dependencies) {
+        if (dep.isTypeDepObj) {
+          Value *loadDepVal = Builder.CreateLoad(VoidPtr, dep.DepVal);
+          Value *depValGEP =
+              Builder.CreateGEP(DependInfo, loadDepVal, Builder.getInt64(-1));
+          Value *obj =
+              Builder.CreateConstInBoundsGEP2_64(DependInfo, depValGEP, 0, 0);
+          Value *objVal = Builder.CreateLoad(Builder.getInt64Ty(), obj);
+          objsVal.push_back(objVal);
+        } else {
+          otherDepTypeCount++;
+        }
+      }
+
+      // Add all the values and use it as the size for DependInfo alloca
+      if (objsVal.size() > 0) {
+        depSize = objsVal[0];
+        for (size_t i = 1; i < objsVal.size(); i++)
+          depSize = Builder.CreateAdd(depSize, objsVal[i]);
+        if (otherDepTypeCount > 0)
+          depSize =
+              Builder.CreateAdd(depSize, Builder.getInt64(otherDepTypeCount));
+      }
+
+      if (!isa<ConstantInt>(depSize)) {
+        // stackSave to save the stack pointer
+        if (!stackSave)
+          stackSave = Builder.CreateStackSave();
+        DepAlloca = Builder.CreateAlloca(DependInfo, depSize, "dep.addr");
+        ((AllocaInst *)DepAlloca)->setAlignment(Align(16));
+        depSize = Builder.CreateTrunc(depSize, Builder.getInt32Ty());
+      } else {
+        DepAlloca = Builder.CreateAlloca(DepArrayTy, nullptr, ".dep.arr.addr");
+      }
 
       unsigned P = 0;
       for (const DependData &Dep : Dependencies) {
+        if (Dep.isTypeDepObj)
+          continue;
         Value *Base =
-            Builder.CreateConstInBoundsGEP2_64(DepArrayTy, DepArray, 0, P);
+            Builder.CreateGEP(DependInfo, DepAlloca, Builder.getInt64(P));
         // Store the pointer to the variable
         Value *Addr = Builder.CreateStructGEP(
             DependInfo, Base,
@@ -2087,6 +2129,23 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTask(
         ++P;
       }
 
+      P = 0;
+      Value *depAllocaIdx = Builder.getInt64(otherDepTypeCount);
+      for (const DependData &dep : Dependencies) {
+        if (dep.isTypeDepObj) {
+          Value *depAllocaPtr =
+              Builder.CreateGEP(DependInfo, DepAlloca, depAllocaIdx);
+          Align alignment = Align(8);
+          Value *loadDepVal = Builder.CreateLoad(VoidPtr, dep.DepVal);
+          Value *memCpySize =
+              Builder.CreateMul(Builder.getInt64(24), objsVal[P]);
+          Builder.CreateMemCpy(depAllocaPtr, alignment, loadDepVal, alignment,
+                               memCpySize);
+          depAllocaIdx = Builder.CreateAdd(depAllocaIdx, objsVal[P]);
+          ++P;
+        }
+      }
+
       Builder.restoreIP(OldIP);
     }
 
@@ -2124,7 +2183,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTask(
             getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_wait_deps);
         Builder.CreateCall(
             TaskWaitFn,
-            {Ident, ThreadID, Builder.getInt32(Dependencies.size()), DepArray,
+            {Ident, ThreadID, Builder.getInt32(Dependencies.size()), DepAlloca,
              ConstantInt::get(Builder.getInt32Ty(), 0),
              ConstantPointerNull::get(PointerType::getUnqual(M.getContext()))});
       }
@@ -2146,12 +2205,13 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTask(
     if (Dependencies.size()) {
       Function *TaskFn =
           getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_with_deps);
-      Builder.CreateCall(
-          TaskFn,
-          {Ident, ThreadID, TaskData, Builder.getInt32(Dependencies.size()),
-           DepArray, ConstantInt::get(Builder.getInt32Ty(), 0),
-           ConstantPointerNull::get(PointerType::getUnqual(M.getContext()))});
-
+      Builder.CreateCall(TaskFn, {Ident, ThreadID, TaskData, depSize, DepAlloca,
+                                  ConstantInt::get(Builder.getInt32Ty(), 0),
+                                  ConstantPointerNull::get(
+                                      PointerType::getUnqual(M.getContext()))});
+      // stackSave is used by depend(depobj: x) clause to save the stack pointer
+      if (stackSave)
+        Builder.CreateStackRestore(stackSave);
     } else {
       // Emit the @__kmpc_omp_task runtime call to spawn the task
       Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task);
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
index 690e3df1f685e3..bbe1174775184d 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
@@ -111,12 +111,14 @@ def ClauseTaskDependInOut : I32EnumAttrCase<"taskdependinout", 2>;
 def ClauseTaskDependMutexInOutSet
     : I32EnumAttrCase<"taskdependmutexinoutset", 3>;
 def ClauseTaskDependInOutSet : I32EnumAttrCase<"taskdependinoutset", 4>;
+def ClauseTaskDependDepObj : I32EnumAttrCase<"taskdependdepobj", 5>;
 
 def ClauseTaskDepend
-    : OpenMP_I32EnumAttr<
-          "ClauseTaskDepend", "depend clause in a target or task construct",
-          [ClauseTaskDependIn, ClauseTaskDependOut, ClauseTaskDependInOut,
-           ClauseTaskDependMutexInOutSet, ClauseTaskDependInOutSet]>;
+    : OpenMP_I32EnumAttr<"ClauseTaskDepend",
+                         "depend clause in a target or task construct",
+                         [ClauseTaskDependIn, ClauseTaskDependOut,
+                          ClauseTaskDependInOut, ClauseTaskDependMutexInOutSet,
+                          ClauseTaskDependInOutSet, ClauseTaskDependDepObj]>;
 
 def ClauseTaskDependAttr : OpenMP_EnumAttr<ClauseTaskDepend,
                                            "clause_task_depend"> {
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 3fcdefa8a2f673..de4bd108fff675 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -1715,6 +1715,7 @@ buildDependData(std::optional<ArrayAttr> dependKinds, OperandRange dependVars,
     return;
   for (auto dep : llvm::zip(dependVars, dependKinds->getValue())) {
     llvm::omp::RTLDependenceKindTy type;
+    bool isTypeDepObj = false;
     switch (
         cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
     case mlir::omp::ClauseTaskDepend::taskdependin:
@@ -1733,9 +1734,13 @@ buildDependData(std::optional<ArrayAttr> dependKinds, OperandRange dependVars,
     case mlir::omp::ClauseTaskDepend::taskdependinoutset:
       type = llvm::omp::RTLDependenceKindTy::DepInOutSet;
       break;
+    case mlir::omp::ClauseTaskDepend::taskdependdepobj:
+      isTypeDepObj = true;
+      break;
     };
     llvm::Value *depVal = moduleTranslation.lookupValue(std::get<0>(dep));
-    llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
+    llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal,
+                                         isTypeDepObj);
     dds.emplace_back(dd);
   }
 }
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index 9868ef227d49e0..b8ae3d0bec2c88 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -2653,7 +2653,7 @@ llvm.func @omp_task_attrs() -> () attributes {
 // CHECK-LABEL: define void @omp_task_with_deps
 // CHECK-SAME: (ptr %[[zaddr:.+]])
 // CHECK:  %[[dep_arr_addr:.+]] = alloca [1 x %struct.kmp_dep_info], align 8
-// CHECK:  %[[dep_arr_addr_0:.+]] = getelementptr inbounds [1 x %struct.kmp_dep_info], ptr %[[dep_arr_addr]], i64 0, i64 0
+// CHECK:  %[[dep_arr_addr_0:.+]] = getelementptr %struct.kmp_dep_info, ptr %[[dep_arr_addr]], i64 0
 // CHECK:  %[[dep_arr_addr_0_val:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[dep_arr_addr_0]], i32 0, i32 0
 // CHECK:  %[[dep_arr_addr_0_val_int:.+]] = ptrtoint ptr %0 to i64
 // CHECK:  store i64 %[[dep_arr_addr_0_val_int]], ptr %[[dep_arr_addr_0_val]], align 4
@@ -2664,28 +2664,28 @@ llvm.func @omp_task_attrs() -> () attributes {
 // -----
 // dependence_type: Out
 // CHECK:  %[[DEP_ARR_ADDR1:.+]] = alloca [1 x %struct.kmp_dep_info], align 8
-// CHECK:  %[[DEP_ARR_ADDR_1:.+]] = getelementptr inbounds [1 x %struct.kmp_dep_info], ptr %[[DEP_ARR_ADDR1]], i64 0, i64 0
+// CHECK:  %[[DEP_ARR_ADDR_1:.+]] = getelementptr %struct.kmp_dep_info, ptr %[[DEP_ARR_ADDR1]], i64 0
 //         [...]
 // CHECK:  %[[DEP_TYPE_1:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[DEP_ARR_ADDR_1]], i32 0, i32 2
 // CHECK:  store i8 3, ptr %[[DEP_TYPE_1]], align 1
 // -----
 // dependence_type: Inout
 // CHECK:  %[[DEP_ARR_ADDR2:.+]] = alloca [1 x %struct.kmp_dep_info], align 8
-// CHECK:  %[[DEP_ARR_ADDR_2:.+]] = getelementptr inbounds [1 x %struct.kmp_dep_info], ptr %[[DEP_ARR_ADDR2]], i64 0, i64 0
+// CHECK:  %[[DEP_ARR_ADDR_2:.+]] = getelementptr %struct.kmp_dep_info, ptr %[[DEP_ARR_ADDR2]], i64 0
 //         [...]
 // CHECK:  %[[DEP_TYPE_2:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[DEP_ARR_ADDR_2]], i32 0, i32 2
 // CHECK:  store i8 3, ptr %[[DEP_TYPE_2]], align 1
 // -----
 // dependence_type: Mutexinoutset
 // CHECK:  %[[DEP_ARR_ADDR3:.+]] = alloca [1 x %struct.kmp_dep_info], align 8
-// CHECK:  %[[DEP_ARR_ADDR_3:.+]] = getelementptr inbounds [1 x %struct.kmp_dep_info], ptr %[[DEP_ARR_ADDR3]], i64 0, i64 0
+// CHECK:  %[[DEP_ARR_ADDR_3:.+]] = getelementptr %struct.kmp_dep_info, ptr %[[DEP_ARR_ADDR3]], i64 0
 //         [...]
 // CHECK:  %[[DEP_TYPE_3:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[DEP_ARR_ADDR_3]], i32 0, i32 2
 // CHECK:  store i8 4, ptr %[[DEP_TYPE_3]], align 1
 // -----
 // dependence_type: Inoutset
 // CHECK:  %[[DEP_ARR_ADDR4:.+]] = alloca [1 x %struct.kmp_dep_info], align 8
-// CHECK:  %[[DEP_ARR_ADDR_4:.+]] = getelementptr inbounds [1 x %struct.kmp_dep_info], ptr %[[DEP_ARR_ADDR4]], i64 0, i64 0
+// CHECK:  %[[DEP_ARR_ADDR_4:.+]] = getelementptr %struct.kmp_dep_info, ptr %[[DEP_ARR_ADDR4]], i64 0
 //         [...]
 // CHECK:  %[[DEP_TYPE_4:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[DEP_ARR_ADDR_4]], i32 0, i32 2
 // CHECK:  store i8 8, ptr %[[DEP_TYPE_4]], align 1
@@ -2734,6 +2734,63 @@ llvm.func @omp_task_with_deps(%zaddr: !llvm.ptr) {
 
 // -----
 
+// CHECK-LABEL: define void @omp_task_with_deps_02(ptr %0, ptr %1) {
+
+// CHECK:   %[[obj:.+]] = alloca i64, i64 1, align 8
+// CHECK:   %[[obj_load_01:.+]] = load ptr, ptr %[[obj]], align 8
+// CHECK:   %[[gep_01:.+]] = getelementptr %struct.kmp_dep_info, ptr %[[obj_load_01]], i64 -1
+// CHECK:   %[[gep_02:.+]] = getelementptr inbounds %struct.kmp_dep_info, ptr %[[gep_01]], i64 0, i64 0
+// CHECK:   %[[obj_addr:.+]] = load i64, ptr %[[gep_02]], align 4
+
+// CHECK:   %[[size:.+]] = add i64 %[[obj_addr]], 2
+
+// CHECK:   %[[stack_ptr:.+]] = call ptr @llvm.stacksave.p0()
+// CHECK:   %[[dep_addr:.+]] = alloca %struct.kmp_dep_info, i64 %[[size]], align 16
+// CHECK:   %[[dep_size:.+]] = trunc i64 %[[size]] to i32
+
+// CHECK:   %[[gep_03:.+]] = getelementptr %struct.kmp_dep_info, ptr %[[dep_addr]], i64 0
+// CHECK:   %[[gep_04:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[gep_03]], i32 0, i32 0
+// CHECK:   %[[arg_01_int:.+]] = ptrtoint ptr %0 to i64
+// CHECK:   store i64 %[[arg_01_int]], ptr %[[gep_04]], align 4
+// CHECK:   %[[gep_05:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[gep_03]], i32 0, i32 1
+// CHECK:   store i64 8, ptr %[[gep_05]], align 4
+// CHECK:   %[[gep_06:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[gep_03]], i32 0, i32 2
+// CHECK:   store i8 1, ptr %[[gep_06]], align 1
+
+// CHECK:   %[[gep_07:.+]] = getelementptr %struct.kmp_dep_info, ptr %[[dep_addr]], i64 1
+// CHECK:   %[[gep_08:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[gep_07]], i32 0, i32 0
+// CHECK:   %[[arg_02_int:.+]] = ptrtoint ptr %1 to i64
+// CHECK:   store i64 %[[arg_02_int]], ptr %[[gep_08]], align 4
+// CHECK:   %[[gep_09:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[gep_07]], i32 0, i32 1
+// CHECK:   store i64 8, ptr %[[gep_09]], align 4
+// CHECK:   %[[gep_10:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[gep_07]], i32 0, i32 2
+// CHECK:   store i8 3, ptr %[[gep_10]], align 1
+
+// CHECK:   %[[gep_11:.+]] = getelementptr %struct.kmp_dep_info, ptr %[[dep_addr]], i64 2
+// CHECK:   %[[obj_load_02:.+]] = load ptr, ptr %[[obj]], align 8
+// CHECK:   %[[obj_size:.+]] = mul i64 24, %[[obj_addr]]
+// CHECK:   call void @llvm.memcpy.p0.p0.i64(ptr align 8 %[[gep_11]], ptr align 8 %[[obj_load_02]], i64 %[[obj_size]], i1 false)
+// CHECK:   %[[dep_size_idx:.+]] = add i64 2, %[[obj_addr]]
+
+// CHECK:   %[[task:.+]] = call i32 @__kmpc_omp_task_with_deps({{.*}}, i32 %[[dep_size]], ptr %[[dep_addr]], i32 0, ptr null)
+// CHECK:   call void @llvm.stackrestore.p0(ptr %[[stack_ptr]])
+// CHECK: }
+
+
+llvm.func @omp_task_with_deps_02(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
+  %c_1 = llvm.mlir.constant(1 : i64) : i64
+  %1 = llvm.alloca %c_1 x i64 : (i64) -> !llvm.ptr
+  omp.task depend(taskdependin -> %arg0 : !llvm.ptr, taskdependdepobj -> %1 : !llvm.ptr, taskdependout -> %arg1 : !llvm.ptr) {
+    %4 = llvm.load %arg0 : !llvm.ptr -> i64
+    %5 = llvm.add %4, %c_1 : i64
+    llvm.store %5, %arg1 : i64, !llvm.ptr
+    omp.terminator
+  }
+  llvm.return
+}
+
+// -----
+
 // CHECK-LABEL: define void @omp_task
 // CHECK-SAME: (i32 %[[x:.+]], i32 %[[y:.+]], ptr %[[zaddr:.+]])
 module attributes {llvm.target_triple = "x86_64-unknown-linux-gnu"} {

@llvmbot
Copy link
Member

llvmbot commented Jan 27, 2025

@llvm/pr-subscribers-mlir-openmp

Author: Thirumalai Shaktivel (Thirumalai-Shaktivel)

Changes

From Documentation:
depobj: The task dependences are derived from the depend clause specified in the depobj constructs that initialized dependences represented by the depend objects specified in the depend clause as if the depend clauses of the depobj constructs were specified in the current construct.

Implementation details:

  • The variable is of type omp_depend_kind and is used as a locator_list.
  • Access the base address of obj and compute the clause size, based on the obj value and other clauses count.
  • Allocate struct.kmp_dep_info with the size computed before.
  • Now, populate all the depend clauses information into the alloca. the other clauses info is added first in the index, 0, 1, 2, ... and then all the depobj clauses info.
  • Then, the alloca and size is passed as argument to __kmpc_omp_task_with_deps runtime.
  • Stacksave and Stackrestore is used to restore the stack pointer to the state before the depobj operations. Basically removing all the alloca's used.

TODO:
Requires depobj construct support for checking runtime results. Also, test debobj modify and destroy clauses


Full diff: https://github.com/llvm/llvm-project/pull/124523.diff

8 Files Affected:

  • (modified) flang/lib/Lower/OpenMP/ClauseProcessor.cpp (+3-4)
  • (removed) flang/test/Lower/OpenMP/Todo/depend-clause-depobj.f90 (-10)
  • (modified) flang/test/Lower/OpenMP/task.f90 (+7-1)
  • (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+4-2)
  • (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+70-10)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td (+6-4)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+6-1)
  • (modified) mlir/test/Target/LLVMIR/openmp-llvm.mlir (+62-5)
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 299d9d438f1156..3378ea2fc2b414 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -139,7 +139,6 @@ static mlir::omp::ClauseTaskDependAttr
 genDependKindAttr(lower::AbstractConverter &converter,
                   const omp::clause::DependenceType kind) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-  mlir::Location currentLocation = converter.getCurrentLocation();
 
   mlir::omp::ClauseTaskDepend pbKind;
   switch (kind) {
@@ -152,15 +151,15 @@ genDependKindAttr(lower::AbstractConverter &converter,
   case omp::clause::DependenceType::Inout:
     pbKind = mlir::omp::ClauseTaskDepend::taskdependinout;
     break;
+  case omp::clause::DependenceType::Depobj:
+    pbKind = mlir::omp::ClauseTaskDepend::taskdependdepobj;
+    break;
   case omp::clause::DependenceType::Mutexinoutset:
     pbKind = mlir::omp::ClauseTaskDepend::taskdependmutexinoutset;
     break;
   case omp::clause::DependenceType::Inoutset:
     pbKind = mlir::omp::ClauseTaskDepend::taskdependinoutset;
     break;
-  case omp::clause::DependenceType::Depobj:
-    TODO(currentLocation, "DEPOBJ dependence-type");
-    break;
   case omp::clause::DependenceType::Sink:
   case omp::clause::DependenceType::Source:
     llvm_unreachable("unhandled parser task dependence type");
diff --git a/flang/test/Lower/OpenMP/Todo/depend-clause-depobj.f90 b/flang/test/Lower/OpenMP/Todo/depend-clause-depobj.f90
deleted file mode 100644
index 4e98d77d0bb3e3..00000000000000
--- a/flang/test/Lower/OpenMP/Todo/depend-clause-depobj.f90
+++ /dev/null
@@ -1,10 +0,0 @@
-!RUN: %not_todo_cmd bbc -emit-hlfir -fopenmp -fopenmp-version=52 -o - %s 2>&1 | FileCheck %s
-!RUN: %not_todo_cmd %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 -o - %s 2>&1 | FileCheck %s
-
-!CHECK: not yet implemented: DEPOBJ dependence-type
-
-subroutine f00(x)
-  integer :: x
-  !$omp task depend(depobj: x)
-  !$omp end task
-end
diff --git a/flang/test/Lower/OpenMP/task.f90 b/flang/test/Lower/OpenMP/task.f90
index 13ebf2acd91012..28d1b36a162a7e 100644
--- a/flang/test/Lower/OpenMP/task.f90
+++ b/flang/test/Lower/OpenMP/task.f90
@@ -150,12 +150,18 @@ subroutine task_depend_multi_task()
   x = x - 12
   !CHECK: omp.terminator
   !$omp end task
-    !CHECK: omp.task depend(taskdependinoutset -> %{{.+}} : !fir.ref<i32>)
+  !CHECK: omp.task depend(taskdependinoutset -> %{{.+}} : !fir.ref<i32>)
   !$omp task depend(inoutset : x)
   !CHECK: arith.subi
   x = x - 12
   !CHECK: omp.terminator
   !$omp end task
+  !CHECK: omp.task depend(taskdependdepobj -> %{{.+}} : !fir.ref<i32>)
+  !$omp task depend(depobj: obj)
+  ! CHECK: arith.addi
+      x = x + 73
+  ! CHECK:   omp.terminator
+  !$omp end task
 end subroutine task_depend_multi_task
 
 !===============================================================================
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 9802cbe8b7b943..2d996e5fe3554a 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1241,10 +1241,12 @@ class OpenMPIRBuilder {
     omp::RTLDependenceKindTy DepKind = omp::RTLDependenceKindTy::DepUnknown;
     Type *DepValueType;
     Value *DepVal;
+    bool isTypeDepObj;
     explicit DependData() = default;
     DependData(omp::RTLDependenceKindTy DepKind, Type *DepValueType,
-               Value *DepVal)
-        : DepKind(DepKind), DepValueType(DepValueType), DepVal(DepVal) {}
+               Value *DepVal, bool isTypeDepObj = false)
+        : DepKind(DepKind), DepValueType(DepValueType), DepVal(DepVal),
+          isTypeDepObj(isTypeDepObj) {}
   };
 
   /// Generator for `#omp task`
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 8cc3a99d92023d..476c0c80b985a9 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -2049,19 +2049,61 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTask(
       Builder.CreateStore(Priority, CmplrData);
     }
 
-    Value *DepArray = nullptr;
+    Value *DepAlloca = nullptr;
+    Value *stackSave = nullptr;
+    Value *depSize = Builder.getInt32(Dependencies.size());
     if (Dependencies.size()) {
       InsertPointTy OldIP = Builder.saveIP();
       Builder.SetInsertPoint(
           &OldIP.getBlock()->getParent()->getEntryBlock().back());
 
       Type *DepArrayTy = ArrayType::get(DependInfo, Dependencies.size());
-      DepArray = Builder.CreateAlloca(DepArrayTy, nullptr, ".dep.arr.addr");
+
+      // Used to keep a count of other dependence type apart from DEPOBJ
+      size_t otherDepTypeCount = 0;
+      SmallVector<Value *> objsVal;
+      // Load all the value of DEPOBJ object from omp_depend_t object
+      for (const DependData &dep : Dependencies) {
+        if (dep.isTypeDepObj) {
+          Value *loadDepVal = Builder.CreateLoad(VoidPtr, dep.DepVal);
+          Value *depValGEP =
+              Builder.CreateGEP(DependInfo, loadDepVal, Builder.getInt64(-1));
+          Value *obj =
+              Builder.CreateConstInBoundsGEP2_64(DependInfo, depValGEP, 0, 0);
+          Value *objVal = Builder.CreateLoad(Builder.getInt64Ty(), obj);
+          objsVal.push_back(objVal);
+        } else {
+          otherDepTypeCount++;
+        }
+      }
+
+      // Add all the values and use it as the size for DependInfo alloca
+      if (objsVal.size() > 0) {
+        depSize = objsVal[0];
+        for (size_t i = 1; i < objsVal.size(); i++)
+          depSize = Builder.CreateAdd(depSize, objsVal[i]);
+        if (otherDepTypeCount > 0)
+          depSize =
+              Builder.CreateAdd(depSize, Builder.getInt64(otherDepTypeCount));
+      }
+
+      if (!isa<ConstantInt>(depSize)) {
+        // stackSave to save the stack pointer
+        if (!stackSave)
+          stackSave = Builder.CreateStackSave();
+        DepAlloca = Builder.CreateAlloca(DependInfo, depSize, "dep.addr");
+        ((AllocaInst *)DepAlloca)->setAlignment(Align(16));
+        depSize = Builder.CreateTrunc(depSize, Builder.getInt32Ty());
+      } else {
+        DepAlloca = Builder.CreateAlloca(DepArrayTy, nullptr, ".dep.arr.addr");
+      }
 
       unsigned P = 0;
       for (const DependData &Dep : Dependencies) {
+        if (Dep.isTypeDepObj)
+          continue;
         Value *Base =
-            Builder.CreateConstInBoundsGEP2_64(DepArrayTy, DepArray, 0, P);
+            Builder.CreateGEP(DependInfo, DepAlloca, Builder.getInt64(P));
         // Store the pointer to the variable
         Value *Addr = Builder.CreateStructGEP(
             DependInfo, Base,
@@ -2087,6 +2129,23 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTask(
         ++P;
       }
 
+      P = 0;
+      Value *depAllocaIdx = Builder.getInt64(otherDepTypeCount);
+      for (const DependData &dep : Dependencies) {
+        if (dep.isTypeDepObj) {
+          Value *depAllocaPtr =
+              Builder.CreateGEP(DependInfo, DepAlloca, depAllocaIdx);
+          Align alignment = Align(8);
+          Value *loadDepVal = Builder.CreateLoad(VoidPtr, dep.DepVal);
+          Value *memCpySize =
+              Builder.CreateMul(Builder.getInt64(24), objsVal[P]);
+          Builder.CreateMemCpy(depAllocaPtr, alignment, loadDepVal, alignment,
+                               memCpySize);
+          depAllocaIdx = Builder.CreateAdd(depAllocaIdx, objsVal[P]);
+          ++P;
+        }
+      }
+
       Builder.restoreIP(OldIP);
     }
 
@@ -2124,7 +2183,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTask(
             getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_wait_deps);
         Builder.CreateCall(
             TaskWaitFn,
-            {Ident, ThreadID, Builder.getInt32(Dependencies.size()), DepArray,
+            {Ident, ThreadID, Builder.getInt32(Dependencies.size()), DepAlloca,
              ConstantInt::get(Builder.getInt32Ty(), 0),
              ConstantPointerNull::get(PointerType::getUnqual(M.getContext()))});
       }
@@ -2146,12 +2205,13 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTask(
     if (Dependencies.size()) {
       Function *TaskFn =
           getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_with_deps);
-      Builder.CreateCall(
-          TaskFn,
-          {Ident, ThreadID, TaskData, Builder.getInt32(Dependencies.size()),
-           DepArray, ConstantInt::get(Builder.getInt32Ty(), 0),
-           ConstantPointerNull::get(PointerType::getUnqual(M.getContext()))});
-
+      Builder.CreateCall(TaskFn, {Ident, ThreadID, TaskData, depSize, DepAlloca,
+                                  ConstantInt::get(Builder.getInt32Ty(), 0),
+                                  ConstantPointerNull::get(
+                                      PointerType::getUnqual(M.getContext()))});
+      // stackSave is used by depend(depobj: x) clause to save the stack pointer
+      if (stackSave)
+        Builder.CreateStackRestore(stackSave);
     } else {
       // Emit the @__kmpc_omp_task runtime call to spawn the task
       Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task);
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
index 690e3df1f685e3..bbe1174775184d 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
@@ -111,12 +111,14 @@ def ClauseTaskDependInOut : I32EnumAttrCase<"taskdependinout", 2>;
 def ClauseTaskDependMutexInOutSet
     : I32EnumAttrCase<"taskdependmutexinoutset", 3>;
 def ClauseTaskDependInOutSet : I32EnumAttrCase<"taskdependinoutset", 4>;
+def ClauseTaskDependDepObj : I32EnumAttrCase<"taskdependdepobj", 5>;
 
 def ClauseTaskDepend
-    : OpenMP_I32EnumAttr<
-          "ClauseTaskDepend", "depend clause in a target or task construct",
-          [ClauseTaskDependIn, ClauseTaskDependOut, ClauseTaskDependInOut,
-           ClauseTaskDependMutexInOutSet, ClauseTaskDependInOutSet]>;
+    : OpenMP_I32EnumAttr<"ClauseTaskDepend",
+                         "depend clause in a target or task construct",
+                         [ClauseTaskDependIn, ClauseTaskDependOut,
+                          ClauseTaskDependInOut, ClauseTaskDependMutexInOutSet,
+                          ClauseTaskDependInOutSet, ClauseTaskDependDepObj]>;
 
 def ClauseTaskDependAttr : OpenMP_EnumAttr<ClauseTaskDepend,
                                            "clause_task_depend"> {
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 3fcdefa8a2f673..de4bd108fff675 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -1715,6 +1715,7 @@ buildDependData(std::optional<ArrayAttr> dependKinds, OperandRange dependVars,
     return;
   for (auto dep : llvm::zip(dependVars, dependKinds->getValue())) {
     llvm::omp::RTLDependenceKindTy type;
+    bool isTypeDepObj = false;
     switch (
         cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
     case mlir::omp::ClauseTaskDepend::taskdependin:
@@ -1733,9 +1734,13 @@ buildDependData(std::optional<ArrayAttr> dependKinds, OperandRange dependVars,
     case mlir::omp::ClauseTaskDepend::taskdependinoutset:
       type = llvm::omp::RTLDependenceKindTy::DepInOutSet;
       break;
+    case mlir::omp::ClauseTaskDepend::taskdependdepobj:
+      isTypeDepObj = true;
+      break;
     };
     llvm::Value *depVal = moduleTranslation.lookupValue(std::get<0>(dep));
-    llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
+    llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal,
+                                         isTypeDepObj);
     dds.emplace_back(dd);
   }
 }
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index 9868ef227d49e0..b8ae3d0bec2c88 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -2653,7 +2653,7 @@ llvm.func @omp_task_attrs() -> () attributes {
 // CHECK-LABEL: define void @omp_task_with_deps
 // CHECK-SAME: (ptr %[[zaddr:.+]])
 // CHECK:  %[[dep_arr_addr:.+]] = alloca [1 x %struct.kmp_dep_info], align 8
-// CHECK:  %[[dep_arr_addr_0:.+]] = getelementptr inbounds [1 x %struct.kmp_dep_info], ptr %[[dep_arr_addr]], i64 0, i64 0
+// CHECK:  %[[dep_arr_addr_0:.+]] = getelementptr %struct.kmp_dep_info, ptr %[[dep_arr_addr]], i64 0
 // CHECK:  %[[dep_arr_addr_0_val:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[dep_arr_addr_0]], i32 0, i32 0
 // CHECK:  %[[dep_arr_addr_0_val_int:.+]] = ptrtoint ptr %0 to i64
 // CHECK:  store i64 %[[dep_arr_addr_0_val_int]], ptr %[[dep_arr_addr_0_val]], align 4
@@ -2664,28 +2664,28 @@ llvm.func @omp_task_attrs() -> () attributes {
 // -----
 // dependence_type: Out
 // CHECK:  %[[DEP_ARR_ADDR1:.+]] = alloca [1 x %struct.kmp_dep_info], align 8
-// CHECK:  %[[DEP_ARR_ADDR_1:.+]] = getelementptr inbounds [1 x %struct.kmp_dep_info], ptr %[[DEP_ARR_ADDR1]], i64 0, i64 0
+// CHECK:  %[[DEP_ARR_ADDR_1:.+]] = getelementptr %struct.kmp_dep_info, ptr %[[DEP_ARR_ADDR1]], i64 0
 //         [...]
 // CHECK:  %[[DEP_TYPE_1:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[DEP_ARR_ADDR_1]], i32 0, i32 2
 // CHECK:  store i8 3, ptr %[[DEP_TYPE_1]], align 1
 // -----
 // dependence_type: Inout
 // CHECK:  %[[DEP_ARR_ADDR2:.+]] = alloca [1 x %struct.kmp_dep_info], align 8
-// CHECK:  %[[DEP_ARR_ADDR_2:.+]] = getelementptr inbounds [1 x %struct.kmp_dep_info], ptr %[[DEP_ARR_ADDR2]], i64 0, i64 0
+// CHECK:  %[[DEP_ARR_ADDR_2:.+]] = getelementptr %struct.kmp_dep_info, ptr %[[DEP_ARR_ADDR2]], i64 0
 //         [...]
 // CHECK:  %[[DEP_TYPE_2:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[DEP_ARR_ADDR_2]], i32 0, i32 2
 // CHECK:  store i8 3, ptr %[[DEP_TYPE_2]], align 1
 // -----
 // dependence_type: Mutexinoutset
 // CHECK:  %[[DEP_ARR_ADDR3:.+]] = alloca [1 x %struct.kmp_dep_info], align 8
-// CHECK:  %[[DEP_ARR_ADDR_3:.+]] = getelementptr inbounds [1 x %struct.kmp_dep_info], ptr %[[DEP_ARR_ADDR3]], i64 0, i64 0
+// CHECK:  %[[DEP_ARR_ADDR_3:.+]] = getelementptr %struct.kmp_dep_info, ptr %[[DEP_ARR_ADDR3]], i64 0
 //         [...]
 // CHECK:  %[[DEP_TYPE_3:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[DEP_ARR_ADDR_3]], i32 0, i32 2
 // CHECK:  store i8 4, ptr %[[DEP_TYPE_3]], align 1
 // -----
 // dependence_type: Inoutset
 // CHECK:  %[[DEP_ARR_ADDR4:.+]] = alloca [1 x %struct.kmp_dep_info], align 8
-// CHECK:  %[[DEP_ARR_ADDR_4:.+]] = getelementptr inbounds [1 x %struct.kmp_dep_info], ptr %[[DEP_ARR_ADDR4]], i64 0, i64 0
+// CHECK:  %[[DEP_ARR_ADDR_4:.+]] = getelementptr %struct.kmp_dep_info, ptr %[[DEP_ARR_ADDR4]], i64 0
 //         [...]
 // CHECK:  %[[DEP_TYPE_4:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[DEP_ARR_ADDR_4]], i32 0, i32 2
 // CHECK:  store i8 8, ptr %[[DEP_TYPE_4]], align 1
@@ -2734,6 +2734,63 @@ llvm.func @omp_task_with_deps(%zaddr: !llvm.ptr) {
 
 // -----
 
+// CHECK-LABEL: define void @omp_task_with_deps_02(ptr %0, ptr %1) {
+
+// CHECK:   %[[obj:.+]] = alloca i64, i64 1, align 8
+// CHECK:   %[[obj_load_01:.+]] = load ptr, ptr %[[obj]], align 8
+// CHECK:   %[[gep_01:.+]] = getelementptr %struct.kmp_dep_info, ptr %[[obj_load_01]], i64 -1
+// CHECK:   %[[gep_02:.+]] = getelementptr inbounds %struct.kmp_dep_info, ptr %[[gep_01]], i64 0, i64 0
+// CHECK:   %[[obj_addr:.+]] = load i64, ptr %[[gep_02]], align 4
+
+// CHECK:   %[[size:.+]] = add i64 %[[obj_addr]], 2
+
+// CHECK:   %[[stack_ptr:.+]] = call ptr @llvm.stacksave.p0()
+// CHECK:   %[[dep_addr:.+]] = alloca %struct.kmp_dep_info, i64 %[[size]], align 16
+// CHECK:   %[[dep_size:.+]] = trunc i64 %[[size]] to i32
+
+// CHECK:   %[[gep_03:.+]] = getelementptr %struct.kmp_dep_info, ptr %[[dep_addr]], i64 0
+// CHECK:   %[[gep_04:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[gep_03]], i32 0, i32 0
+// CHECK:   %[[arg_01_int:.+]] = ptrtoint ptr %0 to i64
+// CHECK:   store i64 %[[arg_01_int]], ptr %[[gep_04]], align 4
+// CHECK:   %[[gep_05:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[gep_03]], i32 0, i32 1
+// CHECK:   store i64 8, ptr %[[gep_05]], align 4
+// CHECK:   %[[gep_06:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[gep_03]], i32 0, i32 2
+// CHECK:   store i8 1, ptr %[[gep_06]], align 1
+
+// CHECK:   %[[gep_07:.+]] = getelementptr %struct.kmp_dep_info, ptr %[[dep_addr]], i64 1
+// CHECK:   %[[gep_08:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[gep_07]], i32 0, i32 0
+// CHECK:   %[[arg_02_int:.+]] = ptrtoint ptr %1 to i64
+// CHECK:   store i64 %[[arg_02_int]], ptr %[[gep_08]], align 4
+// CHECK:   %[[gep_09:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[gep_07]], i32 0, i32 1
+// CHECK:   store i64 8, ptr %[[gep_09]], align 4
+// CHECK:   %[[gep_10:.+]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[gep_07]], i32 0, i32 2
+// CHECK:   store i8 3, ptr %[[gep_10]], align 1
+
+// CHECK:   %[[gep_11:.+]] = getelementptr %struct.kmp_dep_info, ptr %[[dep_addr]], i64 2
+// CHECK:   %[[obj_load_02:.+]] = load ptr, ptr %[[obj]], align 8
+// CHECK:   %[[obj_size:.+]] = mul i64 24, %[[obj_addr]]
+// CHECK:   call void @llvm.memcpy.p0.p0.i64(ptr align 8 %[[gep_11]], ptr align 8 %[[obj_load_02]], i64 %[[obj_size]], i1 false)
+// CHECK:   %[[dep_size_idx:.+]] = add i64 2, %[[obj_addr]]
+
+// CHECK:   %[[task:.+]] = call i32 @__kmpc_omp_task_with_deps({{.*}}, i32 %[[dep_size]], ptr %[[dep_addr]], i32 0, ptr null)
+// CHECK:   call void @llvm.stackrestore.p0(ptr %[[stack_ptr]])
+// CHECK: }
+
+
+llvm.func @omp_task_with_deps_02(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
+  %c_1 = llvm.mlir.constant(1 : i64) : i64
+  %1 = llvm.alloca %c_1 x i64 : (i64) -> !llvm.ptr
+  omp.task depend(taskdependin -> %arg0 : !llvm.ptr, taskdependdepobj -> %1 : !llvm.ptr, taskdependout -> %arg1 : !llvm.ptr) {
+    %4 = llvm.load %arg0 : !llvm.ptr -> i64
+    %5 = llvm.add %4, %c_1 : i64
+    llvm.store %5, %arg1 : i64, !llvm.ptr
+    omp.terminator
+  }
+  llvm.return
+}
+
+// -----
+
 // CHECK-LABEL: define void @omp_task
 // CHECK-SAME: (i32 %[[x:.+]], i32 %[[y:.+]], ptr %[[zaddr:.+]])
 module attributes {llvm.target_triple = "x86_64-unknown-linux-gnu"} {

Copy link
Contributor

@tblah tblah left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the stack allocated dependency structure used only during task creation (which is synchronous) or might the data also be referred to (at least by the openmp runtime) during task execution (which is asynchronous)?

Tasks may not be executed immediately and so the task being passed DepAlloca might not execute until after the stack restore. Even without that stack restore, it might also outlive the stack frame which created the task.

@Thirumalai-Shaktivel
Copy link
Member Author

Thanks for the review, @tblah.

The DepAlloca data will required for the task creation, so I think it should be available during the runtime (Mostly, Asynchronous).
Regarding the StackSave and Restore usage, I think it must be a mistake and might not required. But, I have a doubt, how does the memory deallocation happen here?

I took reference from Clang's IR and it seems Clang uses stack save and restore to clean up the block.
They handle it here:

if (!VarAllocated) {
if (!DidCallStackSave) {
// Save the stack.
Address Stack =
CreateDefaultAlignTempAlloca(AllocaInt8PtrTy, "saved_stack");
llvm::Value *V = Builder.CreateStackSave();
assert(V->getType() == AllocaInt8PtrTy);
Builder.CreateStore(V, Stack);
DidCallStackSave = true;
// Push a cleanup block and restore the stack there.
// FIXME: in general circumstances, this should be an EH cleanup.
pushStackRestore(NormalCleanup, Stack);
}
auto VlaSize = getVLASize(Ty);
llvm::Type *llvmTy = ConvertTypeForMem(VlaSize.Type);
// Allocate memory for the array.
address = CreateTempAlloca(llvmTy, alignment, "vla", VlaSize.NumElts,
&AllocaAddr);
}
and the array is created here:
auto *PD = ImplicitParamDecl::Create(C, KmpTaskAffinityInfoArrayTy,
ImplicitParamKind::Other);
CGF.EmitVarDecl(*PD);

@tblah
Copy link
Contributor

tblah commented Jan 31, 2025

The DepAlloca data will required for the task creation, so I think it should be available during the runtime (Mostly, Asynchronous). Regarding the StackSave and Restore usage, I think it must be a mistake and might not required. But, I have a doubt, how does the memory deallocation happen here?

When a function is called, data are allocated on the stack (by adjusting the stack pointer), then as the function executes there might be later allocas which further adjust the stack pointer (allocating more memory). Your stack restore un-does one of these later adjustments. Otherwise, when the function returns the stack pointer is restored to its value before the function was called. So after the function return, all stack data are unallocated, and accessing them in any way is undefined behavior.

Task creation occurs synchronously. The function creating the task will not return until after the task has been created and so if data are only used during task creation, it is safe for those to be stack-allocated.

Task execution occurs asynchronously. It could happen any time between the creation of the task and the next synchronization event (e.g. taskwait). Therefore it is not safe for a task to refer to stack allocated data because the function containing that stack frame might have returned by the time the task uses the data.

@tblah
Copy link
Contributor

tblah commented Jan 31, 2025

Looking at the implementation of the runtime I think this array is only used in task creation:

kmp_int32 __kmpc_omp_task_with_deps(ident_t *loc_ref, kmp_int32 gtid,

if (!stackSave)
stackSave = Builder.CreateStackSave();
DepAlloca = Builder.CreateAlloca(DependInfo, depSize, "dep.addr");
((AllocaInst *)DepAlloca)->setAlignment(Align(16));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why 16?

Copy link
Member Author

@Thirumalai-Shaktivel Thirumalai-Shaktivel Feb 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I referred to the omp_depend_kind value as 16 from gfortran (I thought it was according to standards).
During the implementation, I didn't check Flang, my bad. It seems the kind value is compiler-specific as Flang uses the kind value as 8.

Maybe we can keep it as align 8, what do you think?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So your intention is for the alignment to match the size of the type?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, integer(omp_depend_kind).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that case I think it would be best to programatically get the type size in bytes and use that. Something like this (I haven't tried building it)

Suggested change
((AllocaInst *)DepAlloca)->setAlignment(Align(16));
llvm::Constant *size = llvm::ConstantExpr::getSizeOf(DependInfo);
((AllocaInst *)DepAlloca)->setAlignment(Align(size));

Also please use C++ style casts: https://llvm.org/docs/CodingStandards.html#prefer-c-style-casts

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense, I will push the required changes soon.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have strong reason keeping the alignment 16, so I thought to remove it.

KmpDependInfoArrayTy =
C.getVariableArrayType(KmpDependInfoTy, OVE, ArraySizeModifier::Normal,
/*IndexTypeQuals=*/0, SourceRange(Loc, Loc));
// CGF.EmitVariablyModifiedType(KmpDependInfoArrayTy);
// Properly emit variable-sized array.
auto *PD = ImplicitParamDecl::Create(C, KmpDependInfoArrayTy,
ImplicitParamKind::Other);
CGF.EmitVarDecl(*PD);

In clang, the alignment value is taking from PD, which is created using KmpDependInfoArrayTy

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:openmp OpenMP related changes to Clang flang:fir-hlfir flang:openmp flang Flang issues not falling into any other category mlir:llvm mlir:openmp mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants