Skip to content

[OpenMP][MLIR] Lowering task_reduction clause to LLVMIR #111788

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

harishch4
Copy link
Contributor

This patch lowers the task_reduction clause operation to LLVM IR, similar to how Clang handles it. It generates task reduction functions (red_init, red_comb, red_fini, etc.) and maps them to an internal structure, kmp_taskred_input_t, before invoking the runtime call __kmpc_taskred_init.

Currently, it supports task reduction variables passed by value (VAL). A TODO has been added for handling variables passed by reference.

@llvmbot
Copy link
Member

llvmbot commented Oct 10, 2024

@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir-openmp

Author: None (harishch4)

Changes

This patch lowers the task_reduction clause operation to LLVM IR, similar to how Clang handles it. It generates task reduction functions (red_init, red_comb, red_fini, etc.) and maps them to an internal structure, kmp_taskred_input_t, before invoking the runtime call __kmpc_taskred_init.

Currently, it supports task reduction variables passed by value (VAL). A TODO has been added for handling variables passed by reference.


Full diff: https://github.com/llvm/llvm-project/pull/111788.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td (+7)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+223-5)
  • (added) mlir/test/Target/LLVMIR/openmp-task-reduction.mlir (+79)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 886554f66afffc..4a30c35549903b 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1128,6 +1128,13 @@ class OpenMP_TaskReductionClauseSkip<
     unsigned numTaskReductionBlockArgs() {
       return getTaskReductionVars().size();
     }
+    
+    /// Returns the number of reduction variables.
+    unsigned getNumReductionVars() { return getReductionVars().size(); }
+    
+    auto getReductionSyms() {
+      return getTaskReductionSyms();
+    }
   }];
 
   let description = [{
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 19d80fbbd699b0..d1162d6afcc4bb 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -1112,9 +1112,8 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
   LogicalResult bodyGenStatus = success();
   if (taskOp.getUntiedAttr() || taskOp.getMergeableAttr() ||
-      taskOp.getInReductionSyms() || taskOp.getPriority() ||
-      !taskOp.getAllocateVars().empty() || !taskOp.getPrivateVars().empty() ||
-      taskOp.getPrivateSyms()) {
+      taskOp.getPriority() || !taskOp.getAllocateVars().empty() ||
+      !taskOp.getPrivateVars().empty() || taskOp.getPrivateSyms()) {
     return taskOp.emitError("unhandled clauses for translation to LLVM IR");
   }
   auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
@@ -1142,19 +1141,238 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
   return bodyGenStatus;
 }
 
+template <typename OP>
+llvm::Value *createTaskReductionFunction(
+    llvm::IRBuilderBase &builder, const std::string &name, llvm::Type *redTy,
+    LLVM::ModuleTranslation &moduleTranslation,
+    SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls, Region &region,
+    OP &op, unsigned Cnt, llvm::ArrayRef<bool> &isByRef,
+    SmallVectorImpl<llvm::Value *> &privateReductionVariables,
+    DenseMap<Value, llvm::Value *> &reductionVariableMap) {
+  llvm::LLVMContext &Context = builder.getContext();
+  llvm::Type *OpaquePtrTy = llvm::PointerType::get(Context, 0);
+  // TODO: by-ref reduction variables are yet to be handled.
+  if (region.empty() || isByRef[Cnt]) {
+    return llvm::Constant::getNullValue(OpaquePtrTy);
+  }
+  llvm::FunctionType *funcType =
+      llvm::FunctionType::get(OpaquePtrTy, {OpaquePtrTy, OpaquePtrTy}, false);
+  llvm::Function *function =
+      llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, name,
+                             builder.GetInsertBlock()->getModule());
+  function->setDoesNotRecurse();
+  llvm::BasicBlock *entry =
+      llvm::BasicBlock::Create(Context, "entry", function);
+  llvm::IRBuilder<> bbBuilder(entry);
+
+  llvm::Value *arg0 = function->getArg(0);
+  llvm::Value *arg1 = function->getArg(1);
+
+  if (name == "red_init") {
+    function->addParamAttr(0, llvm::Attribute::NoAlias);
+    function->addParamAttr(1, llvm::Attribute::NoAlias);
+    mapInitializationArgs(op, moduleTranslation, reductionDecls,
+                          reductionVariableMap, Cnt);
+  } else if (name == "red_comb") {
+    llvm::Value *arg0L = bbBuilder.CreateLoad(redTy, arg0);
+    llvm::Value *arg1L = bbBuilder.CreateLoad(redTy, arg1);
+    moduleTranslation.mapValue(region.front().getArgument(0), arg0L);
+    moduleTranslation.mapValue(region.front().getArgument(1), arg1L);
+  }
+
+  SmallVector<llvm::Value *, 1> phis;
+  if (failed(inlineConvertOmpRegions(region, "", bbBuilder, moduleTranslation,
+                                     &phis)))
+    return nullptr;
+  assert(
+      phis.size() == 1 &&
+      "expected one value to be yielded from the reduction declaration region");
+
+  bbBuilder.CreateStore(phis[0], arg0);
+  bbBuilder.CreateRet(arg0); // Return from the function
+  return function;
+}
+
+void emitTaskRedInitCall(
+    llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
+    const llvm::OpenMPIRBuilder::LocationDescription &ompLoc, int arraySize,
+    llvm::Value *ArrayAlloca) {
+
+  llvm::LLVMContext &Context = builder.getContext();
+  uint32_t SrcLocStrSize;
+  llvm::Constant *SrcLocStr =
+      moduleTranslation.getOpenMPBuilder()->getOrCreateSrcLocStr(ompLoc,
+                                                                 SrcLocStrSize);
+  llvm::Value *Ident = moduleTranslation.getOpenMPBuilder()->getOrCreateIdent(
+      SrcLocStr, SrcLocStrSize);
+  llvm::Value *ThreadID =
+      moduleTranslation.getOpenMPBuilder()->getOrCreateThreadID(Ident);
+  llvm::Constant *ConstInt =
+      llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), arraySize);
+
+  llvm::Function *TaskRedInitFn =
+      moduleTranslation.getOpenMPBuilder()->getOrCreateRuntimeFunctionPtr(
+          llvm::omp::OMPRTL___kmpc_taskred_init);
+  builder.CreateCall(TaskRedInitFn, {ThreadID, ConstInt, ArrayAlloca});
+}
+
+template <typename OP>
+static LogicalResult allocAndInitializeTaskReductionVars(
+    OP op, ArrayRef<BlockArgument> reductionArgs, llvm::IRBuilderBase &builder,
+    LLVM::ModuleTranslation &moduleTranslation,
+    llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
+    SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
+    SmallVectorImpl<llvm::Value *> &privateReductionVariables,
+    DenseMap<Value, llvm::Value *> &reductionVariableMap,
+    llvm::ArrayRef<bool> isByRef) {
+
+  if (op.getNumReductionVars() == 0)
+    return success();
+
+  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+  llvm::LLVMContext &Context = builder.getContext();
+  SmallVector<DeferredStore> deferredStores;
+
+  // Save the current insertion point
+  auto oldIP = builder.saveIP();
+
+  // Set insertion point after the allocations
+  builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
+
+  // Define the kmp_taskred_input_t structure
+  llvm::StructType *kmp_taskred_input_t =
+      llvm::StructType::create(Context, "kmp_taskred_input_t");
+  llvm::Type *OpaquePtrTy = llvm::PointerType::get(Context, 0); // void*
+  llvm::Type *SizeTy = builder.getInt64Ty(); // size_t (assumed to be i64)
+  llvm::Type *FlagsTy = llvm::Type::getInt32Ty(Context); // flags (i32)
+
+  // Structure members
+  std::vector<llvm::Type *> structMembers = {
+      OpaquePtrTy, // reduce_shar (void*)
+      OpaquePtrTy, // reduce_orig (void*)
+      SizeTy,      // reduce_size (size_t)
+      OpaquePtrTy, // reduce_init (void*)
+      OpaquePtrTy, // reduce_fini (void*)
+      OpaquePtrTy, // reduce_comb (void*)
+      FlagsTy      // flags (i32)
+  };
+
+  kmp_taskred_input_t->setBody(structMembers);
+  int arraySize = op.getNumReductionVars();
+  llvm::ArrayType *ArrayTy =
+      llvm::ArrayType::get(kmp_taskred_input_t, arraySize);
+
+  // Allocate the array for kmp_taskred_input_t
+  llvm::AllocaInst *ArrayAlloca =
+      builder.CreateAlloca(ArrayTy, nullptr, "kmp_taskred_array");
+
+  // Restore the insertion point
+  builder.restoreIP(oldIP);
+  llvm::DataLayout DL = builder.GetInsertBlock()->getModule()->getDataLayout();
+
+  for (int Cnt = 0; Cnt < arraySize; ++Cnt) {
+    llvm::Value *shared =
+        moduleTranslation.lookupValue(op.getReductionVars()[Cnt]);
+    // Create a GEP to access the reduction element
+    llvm::Value *StructPtr = builder.CreateGEP(
+        ArrayTy, ArrayAlloca, {builder.getInt32(0), builder.getInt32(Cnt)},
+        "red_element");
+
+    llvm::Value *FieldPtrReduceShar = builder.CreateStructGEP(
+        kmp_taskred_input_t, StructPtr, 0, "reduce_shar");
+    builder.CreateStore(shared, FieldPtrReduceShar);
+
+    llvm::Value *FieldPtrReduceOrig = builder.CreateStructGEP(
+        kmp_taskred_input_t, StructPtr, 1, "reduce_orig");
+    builder.CreateStore(shared, FieldPtrReduceOrig);
+
+    // Store size of the reduction variable
+    llvm::Value *FieldPtrReduceSize = builder.CreateStructGEP(
+        kmp_taskred_input_t, StructPtr, 2, "reduce_size");
+    llvm::Type *redTy;
+    if (auto *alloca = dyn_cast<llvm::AllocaInst>(shared)) {
+      redTy = alloca->getAllocatedType();
+      uint64_t sizeInBytes = DL.getTypeAllocSize(redTy);
+
+      llvm::ConstantInt *sizeConst =
+          llvm::ConstantInt::get(llvm::Type::getInt64Ty(Context), sizeInBytes);
+      builder.CreateStore(sizeConst, FieldPtrReduceSize);
+    } else {
+      llvm_unreachable("Non alloca instruction found.");
+    }
+
+    // Initialize reduction variable
+    llvm::Value *FieldPtrReduceInit = builder.CreateStructGEP(
+        kmp_taskred_input_t, StructPtr, 3, "reduce_init");
+    llvm::Value *initFunction = createTaskReductionFunction(
+        builder, "red_init", redTy, moduleTranslation, reductionDecls,
+        reductionDecls[Cnt].getInitializerRegion(), op, Cnt, isByRef,
+        privateReductionVariables, reductionVariableMap);
+    builder.CreateStore(initFunction, FieldPtrReduceInit);
+
+    // Create finish and combine functions
+    llvm::Value *FieldPtrReduceFini = builder.CreateStructGEP(
+        kmp_taskred_input_t, StructPtr, 4, "reduce_fini");
+    llvm::Value *finiFunction = createTaskReductionFunction(
+        builder, "red_fini", redTy, moduleTranslation, reductionDecls,
+        reductionDecls[Cnt].getCleanupRegion(), op, Cnt, isByRef,
+        privateReductionVariables, reductionVariableMap);
+    builder.CreateStore(finiFunction, FieldPtrReduceFini);
+
+    llvm::Value *FieldPtrReduceComb = builder.CreateStructGEP(
+        kmp_taskred_input_t, StructPtr, 5, "reduce_comb");
+    llvm::Value *combFunction = createTaskReductionFunction(
+        builder, "red_comb", redTy, moduleTranslation, reductionDecls,
+        reductionDecls[Cnt].getReductionRegion(), op, Cnt, isByRef,
+        privateReductionVariables, reductionVariableMap);
+    builder.CreateStore(combFunction, FieldPtrReduceComb);
+
+    llvm::Value *FieldPtrFlags =
+        builder.CreateStructGEP(kmp_taskred_input_t, StructPtr, 6, "flags");
+    llvm::ConstantInt *flagVal =
+        llvm::ConstantInt::get(llvm::Type::getInt64Ty(Context), 0);
+    builder.CreateStore(flagVal, FieldPtrFlags);
+  }
+
+  // Emit the runtime call
+  emitTaskRedInitCall(builder, moduleTranslation, ompLoc, arraySize,
+                      ArrayAlloca);
+  return success();
+}
+
 /// Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
 static LogicalResult
 convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation) {
   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
   LogicalResult bodyGenStatus = success();
-  if (!tgOp.getTaskReductionVars().empty() || !tgOp.getAllocateVars().empty()) {
+  if (!tgOp.getAllocateVars().empty()) {
     return tgOp.emitError("unhandled clauses for translation to LLVM IR");
   }
+
+  llvm::ArrayRef<bool> isByRef = getIsByRef(tgOp.getTaskReductionByref());
+  assert(isByRef.size() == tgOp.getNumReductionVars());
+
+  SmallVector<omp::DeclareReductionOp> reductionDecls;
+  collectReductionDecls(tgOp, reductionDecls);
+  SmallVector<llvm::Value *> privateReductionVariables(
+      tgOp.getNumReductionVars());
+  DenseMap<Value, llvm::Value *> reductionVariableMap;
+
+  MutableArrayRef<BlockArgument> reductionArgs =
+      tgOp.getRegion().getArguments();
+
   auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
     builder.restoreIP(codegenIP);
+
+    if (failed(allocAndInitializeTaskReductionVars(
+            tgOp, reductionArgs, builder, moduleTranslation, allocaIP,
+            reductionDecls, privateReductionVariables, reductionVariableMap,
+            isByRef)))
+      bodyGenStatus = failure();
+    SmallVector<llvm::PHINode *> phis;
     convertOmpOpRegions(tgOp.getRegion(), "omp.taskgroup.region", builder,
-                        moduleTranslation, bodyGenStatus);
+                        moduleTranslation, bodyGenStatus, &phis);
   };
   InsertPointTy allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
diff --git a/mlir/test/Target/LLVMIR/openmp-task-reduction.mlir b/mlir/test/Target/LLVMIR/openmp-task-reduction.mlir
new file mode 100644
index 00000000000000..1d4d22d5413c61
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-task-reduction.mlir
@@ -0,0 +1,79 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+omp.declare_reduction @add_reduction_i32 : i32 init {
+^bb0(%arg0: i32):
+  %0 = llvm.mlir.constant(0 : i32) : i32
+  omp.yield(%0 : i32)
+} combiner {
+^bb0(%arg0: i32, %arg1: i32):
+  %0 = llvm.add %arg0, %arg1 : i32
+  omp.yield(%0 : i32)
+}
+llvm.func @_QPtest_task_reduciton() {
+  %0 = llvm.mlir.constant(1 : i64) : i64
+  %1 = llvm.alloca %0 x i32 {bindc_name = "x"} : (i64) -> !llvm.ptr
+  omp.taskgroup task_reduction(@add_reduction_i32 %1 -> %arg0 : !llvm.ptr) {
+      %2 = llvm.load %1 : !llvm.ptr -> i32
+      %3 = llvm.mlir.constant(1 : i32) : i32
+      %4 = llvm.add %2, %3 : i32
+      llvm.store %4, %1 : i32, !llvm.ptr
+      omp.terminator
+  }
+  llvm.return
+}
+
+//CHECK-LABEL: define void @_QPtest_task_reduciton() {
+//CHECK:   %[[VAL1:.*]] = alloca i32, i64 1, align 4
+//CHECK:   %[[RED_ARRY:.*]] = alloca [1 x %kmp_taskred_input_t], align 8
+//CHECK:   br label %entry
+
+//CHECK: entry:
+//CHECK:   %[[TID:.*]] = call i32 @__kmpc_global_thread_num(ptr @{{.*}})
+//CHECK:   call void @__kmpc_taskgroup(ptr @1, i32 %[[TID]])
+//CHECK:   %[[RED_ELEMENT:.*]] = getelementptr [1 x %kmp_taskred_input_t], ptr %[[RED_ARRY]], i32 0, i32 0
+//CHECK:   %[[RED_SHARED:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 0
+//CHECK:   store ptr %[[VAL1]], ptr %[[RED_SHARED]], align 8
+//CHECK:   %[[RED_ORIG:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 1
+//CHECK:   store ptr %[[VAL1]], ptr %[[RED_ORIG]], align 8
+//CHECK:   %[[RED_SIZE:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 2
+//CHECK:   store i64 4, ptr %[[RED_SIZE]], align 4
+//CHECK:   %[[RED_INIT:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 3
+//CHECK:   store ptr @red_init, ptr %[[RED_INIT]], align 8
+//CHECK:   %[[RED_FINI:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 4
+//CHECK:   store ptr null, ptr %[[RED_FINI]], align 8
+//CHECK:   %[[RED_COMB:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 5
+//CHECK:   store ptr @red_comb, ptr %[[RED_COMB]], align 8
+//CHECK:   %[[FLAGS:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 6
+//CHECK:   store i64 0, ptr %[[FLAGS]], align 4
+//CHECK:   %[[TID1:.*]] = call i32 @__kmpc_global_thread_num(ptr @{{.*}})
+//CHECK:   %2 = call ptr @__kmpc_taskred_init(i32 %[[TID1]], i32 1, ptr %[[RED_ARRY]])
+//CHECK:   br label %omp.taskgroup.region
+
+//CHECK: omp.taskgroup.region:
+//CHECK:   %[[VAL3:.*]] = load i32, ptr %[[VAL1]], align 4
+//CHECK:   %4 = add i32 %[[VAL3]], 1
+//CHECK:   store i32 %4, ptr %[[VAL1]], align 4
+//CHECK:   br label %omp.region.cont
+
+//CHECK: omp.region.cont:
+//CHECK:   br label %taskgroup.exit
+
+//CHECK: taskgroup.exit:
+//CHECK:   call void @__kmpc_end_taskgroup(ptr @{{.+}}, i32 %[[TID]])
+//CHECK:   ret void
+//CHECK: }
+
+//CHECK-LABEL: define ptr @red_init(ptr noalias %0, ptr noalias %1) #2 {
+//CHECK: entry:
+//CHECK:   store i32 0, ptr %0, align 4
+//CHECK:   ret ptr %0
+//CHECK: }
+
+//CHECK-LABEL: define ptr @red_comb(ptr %0, ptr %1) #2 {
+//CHECK: entry:
+//CHECK:   %[[LD0:.*]] = load i32, ptr %0, align 4
+//CHECK:   %[[LD1:.*]] = load i32, ptr %1, align 4
+//CHECK:   %[[RES:.*]] = add i32 %[[LD0]], %[[LD1]]
+//CHECK:   store i32 %[[RES]], ptr %0, align 4
+//CHECK:   ret ptr %0
+//CHECK: }

@llvmbot
Copy link
Member

llvmbot commented Oct 10, 2024

@llvm/pr-subscribers-flang-openmp

Author: None (harishch4)

Changes

This patch lowers the task_reduction clause operation to LLVM IR, similar to how Clang handles it. It generates task reduction functions (red_init, red_comb, red_fini, etc.) and maps them to an internal structure, kmp_taskred_input_t, before invoking the runtime call __kmpc_taskred_init.

Currently, it supports task reduction variables passed by value (VAL). A TODO has been added for handling variables passed by reference.


Full diff: https://github.com/llvm/llvm-project/pull/111788.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td (+7)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+223-5)
  • (added) mlir/test/Target/LLVMIR/openmp-task-reduction.mlir (+79)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 886554f66afffc..4a30c35549903b 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1128,6 +1128,13 @@ class OpenMP_TaskReductionClauseSkip<
     unsigned numTaskReductionBlockArgs() {
       return getTaskReductionVars().size();
     }
+    
+    /// Returns the number of reduction variables.
+    unsigned getNumReductionVars() { return getReductionVars().size(); }
+    
+    auto getReductionSyms() {
+      return getTaskReductionSyms();
+    }
   }];
 
   let description = [{
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 19d80fbbd699b0..d1162d6afcc4bb 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -1112,9 +1112,8 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
   LogicalResult bodyGenStatus = success();
   if (taskOp.getUntiedAttr() || taskOp.getMergeableAttr() ||
-      taskOp.getInReductionSyms() || taskOp.getPriority() ||
-      !taskOp.getAllocateVars().empty() || !taskOp.getPrivateVars().empty() ||
-      taskOp.getPrivateSyms()) {
+      taskOp.getPriority() || !taskOp.getAllocateVars().empty() ||
+      !taskOp.getPrivateVars().empty() || taskOp.getPrivateSyms()) {
     return taskOp.emitError("unhandled clauses for translation to LLVM IR");
   }
   auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
@@ -1142,19 +1141,238 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
   return bodyGenStatus;
 }
 
+template <typename OP>
+llvm::Value *createTaskReductionFunction(
+    llvm::IRBuilderBase &builder, const std::string &name, llvm::Type *redTy,
+    LLVM::ModuleTranslation &moduleTranslation,
+    SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls, Region &region,
+    OP &op, unsigned Cnt, llvm::ArrayRef<bool> &isByRef,
+    SmallVectorImpl<llvm::Value *> &privateReductionVariables,
+    DenseMap<Value, llvm::Value *> &reductionVariableMap) {
+  llvm::LLVMContext &Context = builder.getContext();
+  llvm::Type *OpaquePtrTy = llvm::PointerType::get(Context, 0);
+  // TODO: by-ref reduction variables are yet to be handled.
+  if (region.empty() || isByRef[Cnt]) {
+    return llvm::Constant::getNullValue(OpaquePtrTy);
+  }
+  llvm::FunctionType *funcType =
+      llvm::FunctionType::get(OpaquePtrTy, {OpaquePtrTy, OpaquePtrTy}, false);
+  llvm::Function *function =
+      llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, name,
+                             builder.GetInsertBlock()->getModule());
+  function->setDoesNotRecurse();
+  llvm::BasicBlock *entry =
+      llvm::BasicBlock::Create(Context, "entry", function);
+  llvm::IRBuilder<> bbBuilder(entry);
+
+  llvm::Value *arg0 = function->getArg(0);
+  llvm::Value *arg1 = function->getArg(1);
+
+  if (name == "red_init") {
+    function->addParamAttr(0, llvm::Attribute::NoAlias);
+    function->addParamAttr(1, llvm::Attribute::NoAlias);
+    mapInitializationArgs(op, moduleTranslation, reductionDecls,
+                          reductionVariableMap, Cnt);
+  } else if (name == "red_comb") {
+    llvm::Value *arg0L = bbBuilder.CreateLoad(redTy, arg0);
+    llvm::Value *arg1L = bbBuilder.CreateLoad(redTy, arg1);
+    moduleTranslation.mapValue(region.front().getArgument(0), arg0L);
+    moduleTranslation.mapValue(region.front().getArgument(1), arg1L);
+  }
+
+  SmallVector<llvm::Value *, 1> phis;
+  if (failed(inlineConvertOmpRegions(region, "", bbBuilder, moduleTranslation,
+                                     &phis)))
+    return nullptr;
+  assert(
+      phis.size() == 1 &&
+      "expected one value to be yielded from the reduction declaration region");
+
+  bbBuilder.CreateStore(phis[0], arg0);
+  bbBuilder.CreateRet(arg0); // Return from the function
+  return function;
+}
+
+void emitTaskRedInitCall(
+    llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
+    const llvm::OpenMPIRBuilder::LocationDescription &ompLoc, int arraySize,
+    llvm::Value *ArrayAlloca) {
+
+  llvm::LLVMContext &Context = builder.getContext();
+  uint32_t SrcLocStrSize;
+  llvm::Constant *SrcLocStr =
+      moduleTranslation.getOpenMPBuilder()->getOrCreateSrcLocStr(ompLoc,
+                                                                 SrcLocStrSize);
+  llvm::Value *Ident = moduleTranslation.getOpenMPBuilder()->getOrCreateIdent(
+      SrcLocStr, SrcLocStrSize);
+  llvm::Value *ThreadID =
+      moduleTranslation.getOpenMPBuilder()->getOrCreateThreadID(Ident);
+  llvm::Constant *ConstInt =
+      llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), arraySize);
+
+  llvm::Function *TaskRedInitFn =
+      moduleTranslation.getOpenMPBuilder()->getOrCreateRuntimeFunctionPtr(
+          llvm::omp::OMPRTL___kmpc_taskred_init);
+  builder.CreateCall(TaskRedInitFn, {ThreadID, ConstInt, ArrayAlloca});
+}
+
+template <typename OP>
+static LogicalResult allocAndInitializeTaskReductionVars(
+    OP op, ArrayRef<BlockArgument> reductionArgs, llvm::IRBuilderBase &builder,
+    LLVM::ModuleTranslation &moduleTranslation,
+    llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
+    SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
+    SmallVectorImpl<llvm::Value *> &privateReductionVariables,
+    DenseMap<Value, llvm::Value *> &reductionVariableMap,
+    llvm::ArrayRef<bool> isByRef) {
+
+  if (op.getNumReductionVars() == 0)
+    return success();
+
+  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+  llvm::LLVMContext &Context = builder.getContext();
+  SmallVector<DeferredStore> deferredStores;
+
+  // Save the current insertion point
+  auto oldIP = builder.saveIP();
+
+  // Set insertion point after the allocations
+  builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
+
+  // Define the kmp_taskred_input_t structure
+  llvm::StructType *kmp_taskred_input_t =
+      llvm::StructType::create(Context, "kmp_taskred_input_t");
+  llvm::Type *OpaquePtrTy = llvm::PointerType::get(Context, 0); // void*
+  llvm::Type *SizeTy = builder.getInt64Ty(); // size_t (assumed to be i64)
+  llvm::Type *FlagsTy = llvm::Type::getInt32Ty(Context); // flags (i32)
+
+  // Structure members
+  std::vector<llvm::Type *> structMembers = {
+      OpaquePtrTy, // reduce_shar (void*)
+      OpaquePtrTy, // reduce_orig (void*)
+      SizeTy,      // reduce_size (size_t)
+      OpaquePtrTy, // reduce_init (void*)
+      OpaquePtrTy, // reduce_fini (void*)
+      OpaquePtrTy, // reduce_comb (void*)
+      FlagsTy      // flags (i32)
+  };
+
+  kmp_taskred_input_t->setBody(structMembers);
+  int arraySize = op.getNumReductionVars();
+  llvm::ArrayType *ArrayTy =
+      llvm::ArrayType::get(kmp_taskred_input_t, arraySize);
+
+  // Allocate the array for kmp_taskred_input_t
+  llvm::AllocaInst *ArrayAlloca =
+      builder.CreateAlloca(ArrayTy, nullptr, "kmp_taskred_array");
+
+  // Restore the insertion point
+  builder.restoreIP(oldIP);
+  llvm::DataLayout DL = builder.GetInsertBlock()->getModule()->getDataLayout();
+
+  for (int Cnt = 0; Cnt < arraySize; ++Cnt) {
+    llvm::Value *shared =
+        moduleTranslation.lookupValue(op.getReductionVars()[Cnt]);
+    // Create a GEP to access the reduction element
+    llvm::Value *StructPtr = builder.CreateGEP(
+        ArrayTy, ArrayAlloca, {builder.getInt32(0), builder.getInt32(Cnt)},
+        "red_element");
+
+    llvm::Value *FieldPtrReduceShar = builder.CreateStructGEP(
+        kmp_taskred_input_t, StructPtr, 0, "reduce_shar");
+    builder.CreateStore(shared, FieldPtrReduceShar);
+
+    llvm::Value *FieldPtrReduceOrig = builder.CreateStructGEP(
+        kmp_taskred_input_t, StructPtr, 1, "reduce_orig");
+    builder.CreateStore(shared, FieldPtrReduceOrig);
+
+    // Store size of the reduction variable
+    llvm::Value *FieldPtrReduceSize = builder.CreateStructGEP(
+        kmp_taskred_input_t, StructPtr, 2, "reduce_size");
+    llvm::Type *redTy;
+    if (auto *alloca = dyn_cast<llvm::AllocaInst>(shared)) {
+      redTy = alloca->getAllocatedType();
+      uint64_t sizeInBytes = DL.getTypeAllocSize(redTy);
+
+      llvm::ConstantInt *sizeConst =
+          llvm::ConstantInt::get(llvm::Type::getInt64Ty(Context), sizeInBytes);
+      builder.CreateStore(sizeConst, FieldPtrReduceSize);
+    } else {
+      llvm_unreachable("Non alloca instruction found.");
+    }
+
+    // Initialize reduction variable
+    llvm::Value *FieldPtrReduceInit = builder.CreateStructGEP(
+        kmp_taskred_input_t, StructPtr, 3, "reduce_init");
+    llvm::Value *initFunction = createTaskReductionFunction(
+        builder, "red_init", redTy, moduleTranslation, reductionDecls,
+        reductionDecls[Cnt].getInitializerRegion(), op, Cnt, isByRef,
+        privateReductionVariables, reductionVariableMap);
+    builder.CreateStore(initFunction, FieldPtrReduceInit);
+
+    // Create finish and combine functions
+    llvm::Value *FieldPtrReduceFini = builder.CreateStructGEP(
+        kmp_taskred_input_t, StructPtr, 4, "reduce_fini");
+    llvm::Value *finiFunction = createTaskReductionFunction(
+        builder, "red_fini", redTy, moduleTranslation, reductionDecls,
+        reductionDecls[Cnt].getCleanupRegion(), op, Cnt, isByRef,
+        privateReductionVariables, reductionVariableMap);
+    builder.CreateStore(finiFunction, FieldPtrReduceFini);
+
+    llvm::Value *FieldPtrReduceComb = builder.CreateStructGEP(
+        kmp_taskred_input_t, StructPtr, 5, "reduce_comb");
+    llvm::Value *combFunction = createTaskReductionFunction(
+        builder, "red_comb", redTy, moduleTranslation, reductionDecls,
+        reductionDecls[Cnt].getReductionRegion(), op, Cnt, isByRef,
+        privateReductionVariables, reductionVariableMap);
+    builder.CreateStore(combFunction, FieldPtrReduceComb);
+
+    llvm::Value *FieldPtrFlags =
+        builder.CreateStructGEP(kmp_taskred_input_t, StructPtr, 6, "flags");
+    llvm::ConstantInt *flagVal =
+        llvm::ConstantInt::get(llvm::Type::getInt64Ty(Context), 0);
+    builder.CreateStore(flagVal, FieldPtrFlags);
+  }
+
+  // Emit the runtime call
+  emitTaskRedInitCall(builder, moduleTranslation, ompLoc, arraySize,
+                      ArrayAlloca);
+  return success();
+}
+
 /// Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
 static LogicalResult
 convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation) {
   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
   LogicalResult bodyGenStatus = success();
-  if (!tgOp.getTaskReductionVars().empty() || !tgOp.getAllocateVars().empty()) {
+  if (!tgOp.getAllocateVars().empty()) {
     return tgOp.emitError("unhandled clauses for translation to LLVM IR");
   }
+
+  llvm::ArrayRef<bool> isByRef = getIsByRef(tgOp.getTaskReductionByref());
+  assert(isByRef.size() == tgOp.getNumReductionVars());
+
+  SmallVector<omp::DeclareReductionOp> reductionDecls;
+  collectReductionDecls(tgOp, reductionDecls);
+  SmallVector<llvm::Value *> privateReductionVariables(
+      tgOp.getNumReductionVars());
+  DenseMap<Value, llvm::Value *> reductionVariableMap;
+
+  MutableArrayRef<BlockArgument> reductionArgs =
+      tgOp.getRegion().getArguments();
+
   auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
     builder.restoreIP(codegenIP);
+
+    if (failed(allocAndInitializeTaskReductionVars(
+            tgOp, reductionArgs, builder, moduleTranslation, allocaIP,
+            reductionDecls, privateReductionVariables, reductionVariableMap,
+            isByRef)))
+      bodyGenStatus = failure();
+    SmallVector<llvm::PHINode *> phis;
     convertOmpOpRegions(tgOp.getRegion(), "omp.taskgroup.region", builder,
-                        moduleTranslation, bodyGenStatus);
+                        moduleTranslation, bodyGenStatus, &phis);
   };
   InsertPointTy allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
diff --git a/mlir/test/Target/LLVMIR/openmp-task-reduction.mlir b/mlir/test/Target/LLVMIR/openmp-task-reduction.mlir
new file mode 100644
index 00000000000000..1d4d22d5413c61
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-task-reduction.mlir
@@ -0,0 +1,79 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+omp.declare_reduction @add_reduction_i32 : i32 init {
+^bb0(%arg0: i32):
+  %0 = llvm.mlir.constant(0 : i32) : i32
+  omp.yield(%0 : i32)
+} combiner {
+^bb0(%arg0: i32, %arg1: i32):
+  %0 = llvm.add %arg0, %arg1 : i32
+  omp.yield(%0 : i32)
+}
+llvm.func @_QPtest_task_reduciton() {
+  %0 = llvm.mlir.constant(1 : i64) : i64
+  %1 = llvm.alloca %0 x i32 {bindc_name = "x"} : (i64) -> !llvm.ptr
+  omp.taskgroup task_reduction(@add_reduction_i32 %1 -> %arg0 : !llvm.ptr) {
+      %2 = llvm.load %1 : !llvm.ptr -> i32
+      %3 = llvm.mlir.constant(1 : i32) : i32
+      %4 = llvm.add %2, %3 : i32
+      llvm.store %4, %1 : i32, !llvm.ptr
+      omp.terminator
+  }
+  llvm.return
+}
+
+//CHECK-LABEL: define void @_QPtest_task_reduciton() {
+//CHECK:   %[[VAL1:.*]] = alloca i32, i64 1, align 4
+//CHECK:   %[[RED_ARRY:.*]] = alloca [1 x %kmp_taskred_input_t], align 8
+//CHECK:   br label %entry
+
+//CHECK: entry:
+//CHECK:   %[[TID:.*]] = call i32 @__kmpc_global_thread_num(ptr @{{.*}})
+//CHECK:   call void @__kmpc_taskgroup(ptr @1, i32 %[[TID]])
+//CHECK:   %[[RED_ELEMENT:.*]] = getelementptr [1 x %kmp_taskred_input_t], ptr %[[RED_ARRY]], i32 0, i32 0
+//CHECK:   %[[RED_SHARED:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 0
+//CHECK:   store ptr %[[VAL1]], ptr %[[RED_SHARED]], align 8
+//CHECK:   %[[RED_ORIG:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 1
+//CHECK:   store ptr %[[VAL1]], ptr %[[RED_ORIG]], align 8
+//CHECK:   %[[RED_SIZE:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 2
+//CHECK:   store i64 4, ptr %[[RED_SIZE]], align 4
+//CHECK:   %[[RED_INIT:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 3
+//CHECK:   store ptr @red_init, ptr %[[RED_INIT]], align 8
+//CHECK:   %[[RED_FINI:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 4
+//CHECK:   store ptr null, ptr %[[RED_FINI]], align 8
+//CHECK:   %[[RED_COMB:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 5
+//CHECK:   store ptr @red_comb, ptr %[[RED_COMB]], align 8
+//CHECK:   %[[FLAGS:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 6
+//CHECK:   store i64 0, ptr %[[FLAGS]], align 4
+//CHECK:   %[[TID1:.*]] = call i32 @__kmpc_global_thread_num(ptr @{{.*}})
+//CHECK:   %2 = call ptr @__kmpc_taskred_init(i32 %[[TID1]], i32 1, ptr %[[RED_ARRY]])
+//CHECK:   br label %omp.taskgroup.region
+
+//CHECK: omp.taskgroup.region:
+//CHECK:   %[[VAL3:.*]] = load i32, ptr %[[VAL1]], align 4
+//CHECK:   %4 = add i32 %[[VAL3]], 1
+//CHECK:   store i32 %4, ptr %[[VAL1]], align 4
+//CHECK:   br label %omp.region.cont
+
+//CHECK: omp.region.cont:
+//CHECK:   br label %taskgroup.exit
+
+//CHECK: taskgroup.exit:
+//CHECK:   call void @__kmpc_end_taskgroup(ptr @{{.+}}, i32 %[[TID]])
+//CHECK:   ret void
+//CHECK: }
+
+//CHECK-LABEL: define ptr @red_init(ptr noalias %0, ptr noalias %1) #2 {
+//CHECK: entry:
+//CHECK:   store i32 0, ptr %0, align 4
+//CHECK:   ret ptr %0
+//CHECK: }
+
+//CHECK-LABEL: define ptr @red_comb(ptr %0, ptr %1) #2 {
+//CHECK: entry:
+//CHECK:   %[[LD0:.*]] = load i32, ptr %0, align 4
+//CHECK:   %[[LD1:.*]] = load i32, ptr %1, align 4
+//CHECK:   %[[RES:.*]] = add i32 %[[LD0]], %[[LD1]]
+//CHECK:   store i32 %[[RES]], ptr %0, align 4
+//CHECK:   ret ptr %0
+//CHECK: }

Copy link
Contributor

@Dinistro Dinistro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dropped a few drive by nit comments.

@@ -1142,19 +1141,238 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
return bodyGenStatus;
}

template <typename OP>
llvm::Value *createTaskReductionFunction(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: This should be static and requires a comment.

Comment on lines +1155 to +1157
if (region.empty() || isByRef[Cnt]) {
return llvm::Constant::getNullValue(OpaquePtrTy);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (region.empty() || isByRef[Cnt]) {
return llvm::Constant::getNullValue(OpaquePtrTy);
}
if (region.empty() || isByRef[Cnt])
return llvm::Constant::getNullValue(OpaquePtrTy);

Comment on lines +1158 to +1159
llvm::FunctionType *funcType =
llvm::FunctionType::get(OpaquePtrTy, {OpaquePtrTy, OpaquePtrTy}, false);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
llvm::FunctionType *funcType =
llvm::FunctionType::get(OpaquePtrTy, {OpaquePtrTy, OpaquePtrTy}, false);
auto *funcType =
llvm::FunctionType::get(OpaquePtrTy, {OpaquePtrTy, OpaquePtrTy}, false);

Nit: We use auto when the type is explicitly given on the RHS already.

@@ -1142,19 +1141,238 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
return bodyGenStatus;
}

template <typename OP>
llvm::Value *createTaskReductionFunction(
llvm::IRBuilderBase &builder, const std::string &name, llvm::Type *redTy,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
llvm::IRBuilderBase &builder, const std::string &name, llvm::Type *redTy,
llvm::IRBuilderBase &builder, StringRef name, llvm::Type *redTy,

llvm::Value *arg0 = function->getArg(0);
llvm::Value *arg1 = function->getArg(1);

if (name == "red_init") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Please avoid magic strings. Introduce global constexpr StringLiterals for such things.

DenseMap<Value, llvm::Value *> &reductionVariableMap) {
llvm::LLVMContext &Context = builder.getContext();
llvm::Type *OpaquePtrTy = llvm::PointerType::get(Context, 0);
// TODO: by-ref reduction variables are yet to be handled.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't it better to crash/assert than silently swallowing cases like this?

return function;
}

void emitTaskRedInitCall(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: As above, this should be static and requires a comment.

llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
const llvm::OpenMPIRBuilder::LocationDescription &ompLoc, int arraySize,
llvm::Value *ArrayAlloca) {

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change

void emitTaskRedInitCall(
llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
const llvm::OpenMPIRBuilder::LocationDescription &ompLoc, int arraySize,
llvm::Value *ArrayAlloca) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
llvm::Value *ArrayAlloca) {
llvm::Value *arrayAlloca) {

Ultra nit: Uppercase beginning variables violate the MLIR style guide.

//CHECK: %[[RED_ARRY:.*]] = alloca [1 x %kmp_taskred_input_t], align 8
//CHECK: br label %entry

//CHECK: entry:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a bit odd. Normally, "entry" is really the first block of a function.

Copy link
Contributor

@tblah tblah left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this. The implementation looks a lot cleaner than the reduction clause.

}

SmallVector<llvm::Value *, 1> phis;
if (failed(inlineConvertOmpRegions(region, "", bbBuilder, moduleTranslation,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Adding a block name makes this a lot easier to debug if the region is translated into multiple basic blocks.

Comment on lines +1213 to +1216
llvm::Function *TaskRedInitFn =
moduleTranslation.getOpenMPBuilder()->getOrCreateRuntimeFunctionPtr(
llvm::omp::OMPRTL___kmpc_taskred_init);
builder.CreateCall(TaskRedInitFn, {ThreadID, ConstInt, ArrayAlloca});
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For non-task reductions these sorts of function calls are generated by OpenMPIRBuilder so that we can share code with clang.

Are the clang people happy with us having diverging implementations here? If so I don't mind.

reductionDecls, privateReductionVariables, reductionVariableMap,
isByRef)))
bodyGenStatus = failure();
SmallVector<llvm::PHINode *> phis;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The phis seem unused here. Can the task reduction region omp.yield any values or is omp.terminator the only terminator?

%0 = llvm.mlir.constant(1 : i64) : i64
%1 = llvm.alloca %0 x i32 {bindc_name = "x"} : (i64) -> !llvm.ptr
omp.taskgroup task_reduction(@add_reduction_i32 %1 -> %arg0 : !llvm.ptr) {
%2 = llvm.load %1 : !llvm.ptr -> i32
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know this is only test code but shouldn't this be using the block argument?

//CHECK: ret void
//CHECK: }

//CHECK-LABEL: define ptr @red_init(ptr noalias %0, ptr noalias %1) #2 {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This name could collide with symbols in the program. clang uses .red_init. (and likewise .red_comb.).

@tblah
Copy link
Contributor

tblah commented Oct 10, 2024

Are multiple blocks allowed

  1. in the reduction init region?
  2. in the reduction combine region?
  3. in the taskgroup region?

Please could you add test(s) showing that these work correctly or asserts if these are TODO (I think you don't need multi block reduction regions until you support allocatable arrays).

@kiranchandramohan
Copy link
Contributor

@tblah Did we switch all flang reductions to byref? If so, is there value in having task reductions by value?

@tblah
Copy link
Contributor

tblah commented Oct 11, 2024

@tblah Did we switch all flang reductions to byref? If so, is there value in having task reductions by value?

No. We still support by value reductions where we can because there were concerns that the indirection involved in byref would regress performance in those cases

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants