[flang][mlir] Add support for translating task_reduction to LLVMIR #120957

Open · NimishMishra wants to merge 2 commits into base: main

Conversation

NimishMishra
Contributor

This PR adds support for translating task_reduction to LLVM IR for both pass-by-value and pass-by-reference reduction variables. Depending on how each reduction variable is passed, the appropriate red_init and red_comb functions are emitted and the fields of kmp_taskred_input_t are populated. Finally, a call to the runtime function __kmpc_taskred_init is emitted with the populated kmp_taskred_input_t array.
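
For readers unfamiliar with the descriptor, the following is a minimal C++ sketch of the idea, not the code this PR emits and not the libomp implementation. The struct mirrors the field order given in the diff's `structMembers` comment, and the two callbacks follow the by-value shape checked in the new test (two pointer arguments, first pointer returned). The names `TaskRedInput`, `redInit`, `redComb`, `simulateOneTask`, and `demoTaskReduction` are hypothetical, the callback fields are typed as function pointers rather than `void*` for portability, and the driver only imitates how a runtime might use a descriptor.

```cpp
#include <cstddef>
#include <cstdint>

// Callback shape used for by-value reductions in this sketch.
using RedFn = void *(*)(void *, void *);

// Hypothetical mirror of kmp_taskred_input_t, in the field order from the diff.
struct TaskRedInput {
  void *reduce_shar;        // shared reduction variable
  void *reduce_orig;        // original reduction variable
  std::size_t reduce_size;  // size of the variable in bytes
  RedFn reduce_init;        // initializer callback
  RedFn reduce_fini;        // finalizer callback (null in the new test)
  RedFn reduce_comb;        // combiner callback
  std::uint32_t flags;      // flags, declared as i32 in the diff
};

// Initializer for an i32 add reduction: store 0 and return the first argument.
static void *redInit(void *priv, void * /*orig*/) {
  *static_cast<std::int32_t *>(priv) = 0;
  return priv;
}

// Combiner for an i32 add reduction: lhs += rhs, return lhs.
static void *redComb(void *lhs, void *rhs) {
  auto *l = static_cast<std::int32_t *>(lhs);
  *l += *static_cast<const std::int32_t *>(rhs);
  return lhs;
}

// Assumption: imitates one task that gets a private copy, adds its
// contribution, and is then folded back into the shared variable.
static void simulateOneTask(const TaskRedInput &in, std::int32_t contribution) {
  std::int32_t priv;
  in.reduce_init(&priv, in.reduce_orig);
  priv += contribution;
  in.reduce_comb(in.reduce_shar, &priv);
}

// Fills one descriptor for a shared i32 and runs two simulated tasks.
std::int32_t demoTaskReduction() {
  std::int32_t x = 5;
  TaskRedInput in{&x, &x, sizeof(x), redInit, nullptr, redComb, 0};
  simulateOneTask(in, 1);
  simulateOneTask(in, 2);
  return x;
}
```

In this sketch `demoTaskReduction()` returns 8 (the initial 5 plus contributions 1 and 2). The real runtime decides when private copies are created and combined, so the ordering here is illustrative only.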

@llvmbot
Member

llvmbot commented Dec 23, 2024

@llvm/pr-subscribers-mlir-openmp

@llvm/pr-subscribers-mlir

Author: None (NimishMishra)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/120957.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td (+6)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+250-10)
  • (added) mlir/test/Target/LLVMIR/openmp-task-reduction.mlir (+79)
  • (modified) mlir/test/Target/LLVMIR/openmp-todo.mlir (-28)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 98d2e80ed2d81d..4b9f85e8baf468 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1207,6 +1207,12 @@ class OpenMP_TaskReductionClauseSkip<
     unsigned numTaskReductionBlockArgs() {
       return getTaskReductionVars().size();
     }
+
+    /// Returns the number of reduction variables.
+    unsigned getNumReductionVars() { return getReductionVars().size(); }
+
+
+   auto getReductionSyms() { return getTaskReductionSyms(); }
   }];
 
   let description = [{
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index d591c98a5497f8..7126a0803189a2 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -228,11 +228,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
     if (op.getThreadLimit())
       result = todo("thread_limit");
   };
-  auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
-    if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
-        op.getTaskReductionSyms())
-      result = todo("task_reduction");
-  };
   auto checkUntied = [&todo](auto op, LogicalResult &result) {
     if (op.getUntied())
       result = todo("untied");
@@ -259,10 +254,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
         checkInReduction(op, result);
         checkPriority(op, result);
       })
-      .Case([&](omp::TaskgroupOp op) {
-        checkAllocate(op, result);
-        checkTaskReduction(op, result);
-      })
+      .Case([&](omp::TaskgroupOp op) { checkAllocate(op, result); })
       .Case([&](omp::TaskwaitOp op) {
         checkDepend(op, result);
         checkNowait(op, result);
@@ -1787,6 +1779,238 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
   return success();
 }
 
+template <typename OP>
+llvm::Value *createTaskReductionFunction(
+    llvm::IRBuilderBase &builder, const std::string &name, llvm::Type *redTy,
+    LLVM::ModuleTranslation &moduleTranslation,
+    SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls, Region &region,
+    OP &op, unsigned Cnt, llvm::ArrayRef<bool> &isByRef,
+    SmallVectorImpl<llvm::Value *> &privateReductionVariables,
+    DenseMap<Value, llvm::Value *> &reductionVariableMap) {
+  llvm::LLVMContext &Context = builder.getContext();
+  llvm::Type *OpaquePtrTy = llvm::PointerType::get(Context, 0);
+  if (region.empty()) {
+    return llvm::Constant::getNullValue(OpaquePtrTy);
+  }
+  llvm::FunctionType *funcType = nullptr;
+  if (isByRef[Cnt])
+    funcType = llvm::FunctionType::get(builder.getVoidTy(),
+                                       {OpaquePtrTy, OpaquePtrTy}, false);
+  else
+    funcType =
+        llvm::FunctionType::get(OpaquePtrTy, {OpaquePtrTy, OpaquePtrTy}, false);
+  llvm::Function *function =
+      llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, name,
+                             builder.GetInsertBlock()->getModule());
+  function->setDoesNotRecurse();
+  llvm::BasicBlock *entry =
+      llvm::BasicBlock::Create(Context, "entry", function);
+  llvm::IRBuilder<> bbBuilder(entry);
+
+  llvm::Value *arg0 = function->getArg(0);
+  llvm::Value *arg1 = function->getArg(1);
+
+  if (name == "red_init") {
+    function->addParamAttr(0, llvm::Attribute::NoAlias);
+    function->addParamAttr(1, llvm::Attribute::NoAlias);
+    if (isByRef[Cnt]) {
+      // TODO: Handle case where the initializer uses initialization from
+      // declare reduction construct using `arg1Alloca`.
+      llvm::AllocaInst *arg0Alloca =
+          bbBuilder.CreateAlloca(bbBuilder.getPtrTy());
+      llvm::AllocaInst *arg1Alloca =
+          bbBuilder.CreateAlloca(bbBuilder.getPtrTy());
+      bbBuilder.CreateStore(arg0, arg0Alloca);
+      bbBuilder.CreateStore(arg1, arg1Alloca);
+      llvm::Value *LoadVal =
+          bbBuilder.CreateLoad(bbBuilder.getPtrTy(), arg0Alloca);
+      moduleTranslation.mapValue(reductionDecls[Cnt].getInitializerAllocArg(),
+                                 LoadVal);
+    } else {
+      mapInitializationArgs(op, moduleTranslation, reductionDecls,
+                            reductionVariableMap, Cnt);
+    }
+  } else if (name == "red_comb") {
+    if (isByRef[Cnt]) {
+      llvm::AllocaInst *arg0Alloca =
+          bbBuilder.CreateAlloca(bbBuilder.getPtrTy());
+      llvm::AllocaInst *arg1Alloca =
+          bbBuilder.CreateAlloca(bbBuilder.getPtrTy());
+      bbBuilder.CreateStore(arg0, arg0Alloca);
+      bbBuilder.CreateStore(arg1, arg1Alloca);
+      llvm::Value *arg0L =
+          bbBuilder.CreateLoad(bbBuilder.getPtrTy(), arg0Alloca);
+      llvm::Value *arg1L =
+          bbBuilder.CreateLoad(bbBuilder.getPtrTy(), arg1Alloca);
+      moduleTranslation.mapValue(region.front().getArgument(0), arg0L);
+      moduleTranslation.mapValue(region.front().getArgument(1), arg1L);
+    } else {
+      llvm::Value *arg0L = bbBuilder.CreateLoad(redTy, arg0);
+      llvm::Value *arg1L = bbBuilder.CreateLoad(redTy, arg1);
+      moduleTranslation.mapValue(region.front().getArgument(0), arg0L);
+      moduleTranslation.mapValue(region.front().getArgument(1), arg1L);
+    }
+  }
+
+  SmallVector<llvm::Value *, 1> phis;
+  if (failed(inlineConvertOmpRegions(region, "", bbBuilder, moduleTranslation,
+                                     &phis)))
+    return nullptr;
+  assert(
+      phis.size() == 1 &&
+      "expected one value to be yielded from the reduction declaration region");
+  if (!isByRef[Cnt]) {
+    bbBuilder.CreateStore(phis[0], arg0);
+    bbBuilder.CreateRet(arg0); // Return from the function
+  } else {
+    bbBuilder.CreateRet(nullptr);
+  }
+  return function;
+}
+
+void emitTaskRedInitCall(
+    llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
+    const llvm::OpenMPIRBuilder::LocationDescription &ompLoc, int arraySize,
+    llvm::Value *ArrayAlloca) {
+  llvm::LLVMContext &Context = builder.getContext();
+  uint32_t SrcLocStrSize;
+  llvm::Constant *SrcLocStr =
+      moduleTranslation.getOpenMPBuilder()->getOrCreateSrcLocStr(ompLoc,
+                                                                 SrcLocStrSize);
+  llvm::Value *Ident = moduleTranslation.getOpenMPBuilder()->getOrCreateIdent(
+      SrcLocStr, SrcLocStrSize);
+  llvm::Value *ThreadID =
+      moduleTranslation.getOpenMPBuilder()->getOrCreateThreadID(Ident);
+  llvm::Constant *ConstInt =
+      llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), arraySize);
+  llvm::Function *TaskRedInitFn =
+      moduleTranslation.getOpenMPBuilder()->getOrCreateRuntimeFunctionPtr(
+          llvm::omp::OMPRTL___kmpc_taskred_init);
+  builder.CreateCall(TaskRedInitFn, {ThreadID, ConstInt, ArrayAlloca});
+}
+
+template <typename OP>
+static LogicalResult allocAndInitializeTaskReductionVars(
+    OP op, ArrayRef<BlockArgument> reductionArgs, llvm::IRBuilderBase &builder,
+    LLVM::ModuleTranslation &moduleTranslation,
+    llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
+    SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
+    SmallVectorImpl<llvm::Value *> &privateReductionVariables,
+    DenseMap<Value, llvm::Value *> &reductionVariableMap,
+    llvm::ArrayRef<bool> isByRef) {
+
+  if (op.getNumReductionVars() == 0)
+    return success();
+
+  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+  llvm::LLVMContext &Context = builder.getContext();
+  SmallVector<DeferredStore> deferredStores;
+
+  // Save the current insertion point
+  auto oldIP = builder.saveIP();
+
+  // Set insertion point after the allocations
+  builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
+
+  // Define the kmp_taskred_input_t structure
+  llvm::StructType *kmp_taskred_input_t =
+      llvm::StructType::create(Context, "kmp_taskred_input_t");
+  llvm::Type *OpaquePtrTy = llvm::PointerType::get(Context, 0); // void*
+  llvm::Type *SizeTy = builder.getInt64Ty(); // size_t (assumed to be i64)
+  llvm::Type *FlagsTy = llvm::Type::getInt32Ty(Context); // flags (i32)
+
+  // Structure members
+  std::vector<llvm::Type *> structMembers = {
+      OpaquePtrTy, // reduce_shar (void*)
+      OpaquePtrTy, // reduce_orig (void*)
+      SizeTy,      // reduce_size (size_t)
+      OpaquePtrTy, // reduce_init (void*)
+      OpaquePtrTy, // reduce_fini (void*)
+      OpaquePtrTy, // reduce_comb (void*)
+      FlagsTy      // flags (i32)
+  };
+
+  kmp_taskred_input_t->setBody(structMembers);
+  int arraySize = op.getNumReductionVars();
+  llvm::ArrayType *ArrayTy =
+      llvm::ArrayType::get(kmp_taskred_input_t, arraySize);
+
+  // Allocate the array for kmp_taskred_input_t
+  llvm::AllocaInst *ArrayAlloca =
+      builder.CreateAlloca(ArrayTy, nullptr, "kmp_taskred_array");
+
+  // Restore the insertion point
+  builder.restoreIP(oldIP);
+  llvm::DataLayout DL = builder.GetInsertBlock()->getModule()->getDataLayout();
+
+  for (int Cnt = 0; Cnt < arraySize; ++Cnt) {
+    llvm::Value *shared =
+        moduleTranslation.lookupValue(op.getReductionVars()[Cnt]);
+    // Create a GEP to access the reduction element
+    llvm::Value *StructPtr = builder.CreateGEP(
+        ArrayTy, ArrayAlloca, {builder.getInt32(0), builder.getInt32(Cnt)},
+        "red_element");
+    llvm::Value *FieldPtrReduceShar = builder.CreateStructGEP(
+        kmp_taskred_input_t, StructPtr, 0, "reduce_shar");
+    builder.CreateStore(shared, FieldPtrReduceShar);
+
+    llvm::Value *FieldPtrReduceOrig = builder.CreateStructGEP(
+        kmp_taskred_input_t, StructPtr, 1, "reduce_orig");
+    builder.CreateStore(shared, FieldPtrReduceOrig);
+
+    // Store size of the reduction variable
+    llvm::Value *FieldPtrReduceSize = builder.CreateStructGEP(
+        kmp_taskred_input_t, StructPtr, 2, "reduce_size");
+    llvm::Type *redTy;
+    if (auto *alloca = dyn_cast<llvm::AllocaInst>(shared)) {
+      redTy = alloca->getAllocatedType();
+      uint64_t sizeInBytes = DL.getTypeAllocSize(redTy);
+      llvm::ConstantInt *sizeConst =
+          llvm::ConstantInt::get(llvm::Type::getInt64Ty(Context), sizeInBytes);
+      builder.CreateStore(sizeConst, FieldPtrReduceSize);
+    } else {
+      llvm_unreachable("Non alloca instruction found.");
+    }
+
+    // Initialize reduction variable
+    llvm::Value *FieldPtrReduceInit = builder.CreateStructGEP(
+        kmp_taskred_input_t, StructPtr, 3, "reduce_init");
+    llvm::Value *initFunction = createTaskReductionFunction(
+        builder, "red_init", redTy, moduleTranslation, reductionDecls,
+        reductionDecls[Cnt].getInitializerRegion(), op, Cnt, isByRef,
+        privateReductionVariables, reductionVariableMap);
+    builder.CreateStore(initFunction, FieldPtrReduceInit);
+
+    // Create finish and combine functions
+    llvm::Value *FieldPtrReduceFini = builder.CreateStructGEP(
+        kmp_taskred_input_t, StructPtr, 4, "reduce_fini");
+    llvm::Value *finiFunction = createTaskReductionFunction(
+        builder, "red_fini", redTy, moduleTranslation, reductionDecls,
+        reductionDecls[Cnt].getCleanupRegion(), op, Cnt, isByRef,
+        privateReductionVariables, reductionVariableMap);
+    builder.CreateStore(finiFunction, FieldPtrReduceFini);
+
+    llvm::Value *FieldPtrReduceComb = builder.CreateStructGEP(
+        kmp_taskred_input_t, StructPtr, 5, "reduce_comb");
+    llvm::Value *combFunction = createTaskReductionFunction(
+        builder, "red_comb", redTy, moduleTranslation, reductionDecls,
+        reductionDecls[Cnt].getReductionRegion(), op, Cnt, isByRef,
+        privateReductionVariables, reductionVariableMap);
+    builder.CreateStore(combFunction, FieldPtrReduceComb);
+
+    llvm::Value *FieldPtrFlags =
+        builder.CreateStructGEP(kmp_taskred_input_t, StructPtr, 6, "flags");
+    llvm::ConstantInt *flagVal =
+        llvm::ConstantInt::get(llvm::Type::getInt64Ty(Context), 0);
+    builder.CreateStore(flagVal, FieldPtrFlags);
+  }
+
+  // Emit the runtime call
+  emitTaskRedInitCall(builder, moduleTranslation, ompLoc, arraySize,
+                      ArrayAlloca);
+  return success();
+}
+
 /// Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
 static LogicalResult
 convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
@@ -1794,9 +2018,25 @@ convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
   if (failed(checkImplementationStatus(*tgOp)))
     return failure();
+  LogicalResult bodyGenStatus = success();
+  // Setup for `task_reduction`
+  llvm::ArrayRef<bool> isByRef = getIsByRef(tgOp.getTaskReductionByref());
+  assert(isByRef.size() == tgOp.getNumReductionVars());
+  SmallVector<omp::DeclareReductionOp> reductionDecls;
+  collectReductionDecls(tgOp, reductionDecls);
+  SmallVector<llvm::Value *> privateReductionVariables(
+      tgOp.getNumReductionVars());
+  DenseMap<Value, llvm::Value *> reductionVariableMap;
+  MutableArrayRef<BlockArgument> reductionArgs =
+      tgOp.getRegion().getArguments();
 
   auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
     builder.restoreIP(codegenIP);
+    if (failed(allocAndInitializeTaskReductionVars(
+            tgOp, reductionArgs, builder, moduleTranslation, allocaIP,
+            reductionDecls, privateReductionVariables, reductionVariableMap,
+            isByRef)))
+      bodyGenStatus = failure();
     return convertOmpOpRegions(tgOp.getRegion(), "omp.taskgroup.region",
                                builder, moduleTranslation)
         .takeError();
@@ -1812,7 +2052,7 @@ convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
     return failure();
 
   builder.restoreIP(*afterIP);
-  return success();
+  return bodyGenStatus;
 }
 
 static LogicalResult
diff --git a/mlir/test/Target/LLVMIR/openmp-task-reduction.mlir b/mlir/test/Target/LLVMIR/openmp-task-reduction.mlir
new file mode 100644
index 00000000000000..1d4d22d5413c61
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-task-reduction.mlir
@@ -0,0 +1,79 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+omp.declare_reduction @add_reduction_i32 : i32 init {
+^bb0(%arg0: i32):
+  %0 = llvm.mlir.constant(0 : i32) : i32
+  omp.yield(%0 : i32)
+} combiner {
+^bb0(%arg0: i32, %arg1: i32):
+  %0 = llvm.add %arg0, %arg1 : i32
+  omp.yield(%0 : i32)
+}
+llvm.func @_QPtest_task_reduciton() {
+  %0 = llvm.mlir.constant(1 : i64) : i64
+  %1 = llvm.alloca %0 x i32 {bindc_name = "x"} : (i64) -> !llvm.ptr
+  omp.taskgroup task_reduction(@add_reduction_i32 %1 -> %arg0 : !llvm.ptr) {
+      %2 = llvm.load %1 : !llvm.ptr -> i32
+      %3 = llvm.mlir.constant(1 : i32) : i32
+      %4 = llvm.add %2, %3 : i32
+      llvm.store %4, %1 : i32, !llvm.ptr
+      omp.terminator
+  }
+  llvm.return
+}
+
+//CHECK-LABEL: define void @_QPtest_task_reduciton() {
+//CHECK:   %[[VAL1:.*]] = alloca i32, i64 1, align 4
+//CHECK:   %[[RED_ARRY:.*]] = alloca [1 x %kmp_taskred_input_t], align 8
+//CHECK:   br label %entry
+
+//CHECK: entry:
+//CHECK:   %[[TID:.*]] = call i32 @__kmpc_global_thread_num(ptr @{{.*}})
+//CHECK:   call void @__kmpc_taskgroup(ptr @1, i32 %[[TID]])
+//CHECK:   %[[RED_ELEMENT:.*]] = getelementptr [1 x %kmp_taskred_input_t], ptr %[[RED_ARRY]], i32 0, i32 0
+//CHECK:   %[[RED_SHARED:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 0
+//CHECK:   store ptr %[[VAL1]], ptr %[[RED_SHARED]], align 8
+//CHECK:   %[[RED_ORIG:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 1
+//CHECK:   store ptr %[[VAL1]], ptr %[[RED_ORIG]], align 8
+//CHECK:   %[[RED_SIZE:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 2
+//CHECK:   store i64 4, ptr %[[RED_SIZE]], align 4
+//CHECK:   %[[RED_INIT:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 3
+//CHECK:   store ptr @red_init, ptr %[[RED_INIT]], align 8
+//CHECK:   %[[RED_FINI:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 4
+//CHECK:   store ptr null, ptr %[[RED_FINI]], align 8
+//CHECK:   %[[RED_COMB:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 5
+//CHECK:   store ptr @red_comb, ptr %[[RED_COMB]], align 8
+//CHECK:   %[[FLAGS:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 6
+//CHECK:   store i64 0, ptr %[[FLAGS]], align 4
+//CHECK:   %[[TID1:.*]] = call i32 @__kmpc_global_thread_num(ptr @{{.*}})
+//CHECK:   %2 = call ptr @__kmpc_taskred_init(i32 %[[TID1]], i32 1, ptr %[[RED_ARRY]])
+//CHECK:   br label %omp.taskgroup.region
+
+//CHECK: omp.taskgroup.region:
+//CHECK:   %[[VAL3:.*]] = load i32, ptr %[[VAL1]], align 4
+//CHECK:   %4 = add i32 %[[VAL3]], 1
+//CHECK:   store i32 %4, ptr %[[VAL1]], align 4
+//CHECK:   br label %omp.region.cont
+
+//CHECK: omp.region.cont:
+//CHECK:   br label %taskgroup.exit
+
+//CHECK: taskgroup.exit:
+//CHECK:   call void @__kmpc_end_taskgroup(ptr @{{.+}}, i32 %[[TID]])
+//CHECK:   ret void
+//CHECK: }
+
+//CHECK-LABEL: define ptr @red_init(ptr noalias %0, ptr noalias %1) #2 {
+//CHECK: entry:
+//CHECK:   store i32 0, ptr %0, align 4
+//CHECK:   ret ptr %0
+//CHECK: }
+
+//CHECK-LABEL: define ptr @red_comb(ptr %0, ptr %1) #2 {
+//CHECK: entry:
+//CHECK:   %[[LD0:.*]] = load i32, ptr %0, align 4
+//CHECK:   %[[LD1:.*]] = load i32, ptr %1, align 4
+//CHECK:   %[[RES:.*]] = add i32 %[[LD0]], %[[LD1]]
+//CHECK:   store i32 %[[RES]], ptr %0, align 4
+//CHECK:   ret ptr %0
+//CHECK: }
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 8ae795ec1ec6b0..a0774d859eecf6 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -451,34 +451,6 @@ llvm.func @taskgroup_allocate(%x : !llvm.ptr) {
 
 // -----
 
-omp.declare_reduction @add_f32 : f32
-init {
-^bb0(%arg: f32):
-  %0 = llvm.mlir.constant(0.0 : f32) : f32
-  omp.yield (%0 : f32)
-}
-combiner {
-^bb1(%arg0: f32, %arg1: f32):
-  %1 = llvm.fadd %arg0, %arg1 : f32
-  omp.yield (%1 : f32)
-}
-atomic {
-^bb2(%arg2: !llvm.ptr, %arg3: !llvm.ptr):
-  %2 = llvm.load %arg3 : !llvm.ptr -> f32
-  llvm.atomicrmw fadd %arg2, %2 monotonic : !llvm.ptr, f32
-  omp.yield
-}
-llvm.func @taskgroup_task_reduction(%x : !llvm.ptr) {
-  // expected-error@below {{not yet implemented: Unhandled clause task_reduction in omp.taskgroup operation}}
-  // expected-error@below {{LLVM Translation failed for operation: omp.taskgroup}}
-  omp.taskgroup task_reduction(@add_f32 %x -> %prv : !llvm.ptr) {
-    omp.terminator
-  }
-  llvm.return
-}
-
-// -----
-
 llvm.func @taskloop(%lb : i32, %ub : i32, %step : i32) {
   // expected-error@below {{not yet implemented: omp.taskloop}}
   // expected-error@below {{LLVM Translation failed for operation: omp.taskloop}}

@llvmbot
Member

llvmbot commented Dec 23, 2024

@llvm/pr-subscribers-mlir-llvm

+    builder.CreateStore(flagVal, FieldPtrFlags);
+  }
+
+  // Emit the runtime call
+  emitTaskRedInitCall(builder, moduleTranslation, ompLoc, arraySize,
+                      ArrayAlloca);
+  return success();
+}
+
 /// Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
 static LogicalResult
 convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
@@ -1794,9 +2018,25 @@ convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
   if (failed(checkImplementationStatus(*tgOp)))
     return failure();
+  LogicalResult bodyGenStatus = success();
+  // Setup for `task_reduction`
+  llvm::ArrayRef<bool> isByRef = getIsByRef(tgOp.getTaskReductionByref());
+  assert(isByRef.size() == tgOp.getNumReductionVars());
+  SmallVector<omp::DeclareReductionOp> reductionDecls;
+  collectReductionDecls(tgOp, reductionDecls);
+  SmallVector<llvm::Value *> privateReductionVariables(
+      tgOp.getNumReductionVars());
+  DenseMap<Value, llvm::Value *> reductionVariableMap;
+  MutableArrayRef<BlockArgument> reductionArgs =
+      tgOp.getRegion().getArguments();
 
   auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
     builder.restoreIP(codegenIP);
+    if (failed(allocAndInitializeTaskReductionVars(
+            tgOp, reductionArgs, builder, moduleTranslation, allocaIP,
+            reductionDecls, privateReductionVariables, reductionVariableMap,
+            isByRef)))
+      bodyGenStatus = failure();
     return convertOmpOpRegions(tgOp.getRegion(), "omp.taskgroup.region",
                                builder, moduleTranslation)
         .takeError();
@@ -1812,7 +2052,7 @@ convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
     return failure();
 
   builder.restoreIP(*afterIP);
-  return success();
+  return bodyGenStatus;
 }
 
 static LogicalResult
diff --git a/mlir/test/Target/LLVMIR/openmp-task-reduction.mlir b/mlir/test/Target/LLVMIR/openmp-task-reduction.mlir
new file mode 100644
index 00000000000000..1d4d22d5413c61
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-task-reduction.mlir
@@ -0,0 +1,79 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+omp.declare_reduction @add_reduction_i32 : i32 init {
+^bb0(%arg0: i32):
+  %0 = llvm.mlir.constant(0 : i32) : i32
+  omp.yield(%0 : i32)
+} combiner {
+^bb0(%arg0: i32, %arg1: i32):
+  %0 = llvm.add %arg0, %arg1 : i32
+  omp.yield(%0 : i32)
+}
+llvm.func @_QPtest_task_reduciton() {
+  %0 = llvm.mlir.constant(1 : i64) : i64
+  %1 = llvm.alloca %0 x i32 {bindc_name = "x"} : (i64) -> !llvm.ptr
+  omp.taskgroup task_reduction(@add_reduction_i32 %1 -> %arg0 : !llvm.ptr) {
+      %2 = llvm.load %1 : !llvm.ptr -> i32
+      %3 = llvm.mlir.constant(1 : i32) : i32
+      %4 = llvm.add %2, %3 : i32
+      llvm.store %4, %1 : i32, !llvm.ptr
+      omp.terminator
+  }
+  llvm.return
+}
+
+//CHECK-LABEL: define void @_QPtest_task_reduciton() {
+//CHECK:   %[[VAL1:.*]] = alloca i32, i64 1, align 4
+//CHECK:   %[[RED_ARRY:.*]] = alloca [1 x %kmp_taskred_input_t], align 8
+//CHECK:   br label %entry
+
+//CHECK: entry:
+//CHECK:   %[[TID:.*]] = call i32 @__kmpc_global_thread_num(ptr @{{.*}})
+//CHECK:   call void @__kmpc_taskgroup(ptr @1, i32 %[[TID]])
+//CHECK:   %[[RED_ELEMENT:.*]] = getelementptr [1 x %kmp_taskred_input_t], ptr %[[RED_ARRY]], i32 0, i32 0
+//CHECK:   %[[RED_SHARED:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 0
+//CHECK:   store ptr %[[VAL1]], ptr %[[RED_SHARED]], align 8
+//CHECK:   %[[RED_ORIG:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 1
+//CHECK:   store ptr %[[VAL1]], ptr %[[RED_ORIG]], align 8
+//CHECK:   %[[RED_SIZE:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 2
+//CHECK:   store i64 4, ptr %[[RED_SIZE]], align 4
+//CHECK:   %[[RED_INIT:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 3
+//CHECK:   store ptr @red_init, ptr %[[RED_INIT]], align 8
+//CHECK:   %[[RED_FINI:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 4
+//CHECK:   store ptr null, ptr %[[RED_FINI]], align 8
+//CHECK:   %[[RED_COMB:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 5
+//CHECK:   store ptr @red_comb, ptr %[[RED_COMB]], align 8
+//CHECK:   %[[FLAGS:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 6
+//CHECK:   store i64 0, ptr %[[FLAGS]], align 4
+//CHECK:   %[[TID1:.*]] = call i32 @__kmpc_global_thread_num(ptr @{{.*}})
+//CHECK:   %2 = call ptr @__kmpc_taskred_init(i32 %[[TID1]], i32 1, ptr %[[RED_ARRY]])
+//CHECK:   br label %omp.taskgroup.region
+
+//CHECK: omp.taskgroup.region:
+//CHECK:   %[[VAL3:.*]] = load i32, ptr %[[VAL1]], align 4
+//CHECK:   %4 = add i32 %[[VAL3]], 1
+//CHECK:   store i32 %4, ptr %[[VAL1]], align 4
+//CHECK:   br label %omp.region.cont
+
+//CHECK: omp.region.cont:
+//CHECK:   br label %taskgroup.exit
+
+//CHECK: taskgroup.exit:
+//CHECK:   call void @__kmpc_end_taskgroup(ptr @{{.+}}, i32 %[[TID]])
+//CHECK:   ret void
+//CHECK: }
+
+//CHECK-LABEL: define ptr @red_init(ptr noalias %0, ptr noalias %1) #2 {
+//CHECK: entry:
+//CHECK:   store i32 0, ptr %0, align 4
+//CHECK:   ret ptr %0
+//CHECK: }
+
+//CHECK-LABEL: define ptr @red_comb(ptr %0, ptr %1) #2 {
+//CHECK: entry:
+//CHECK:   %[[LD0:.*]] = load i32, ptr %0, align 4
+//CHECK:   %[[LD1:.*]] = load i32, ptr %1, align 4
+//CHECK:   %[[RES:.*]] = add i32 %[[LD0]], %[[LD1]]
+//CHECK:   store i32 %[[RES]], ptr %0, align 4
+//CHECK:   ret ptr %0
+//CHECK: }
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 8ae795ec1ec6b0..a0774d859eecf6 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -451,34 +451,6 @@ llvm.func @taskgroup_allocate(%x : !llvm.ptr) {
 
 // -----
 
-omp.declare_reduction @add_f32 : f32
-init {
-^bb0(%arg: f32):
-  %0 = llvm.mlir.constant(0.0 : f32) : f32
-  omp.yield (%0 : f32)
-}
-combiner {
-^bb1(%arg0: f32, %arg1: f32):
-  %1 = llvm.fadd %arg0, %arg1 : f32
-  omp.yield (%1 : f32)
-}
-atomic {
-^bb2(%arg2: !llvm.ptr, %arg3: !llvm.ptr):
-  %2 = llvm.load %arg3 : !llvm.ptr -> f32
-  llvm.atomicrmw fadd %arg2, %2 monotonic : !llvm.ptr, f32
-  omp.yield
-}
-llvm.func @taskgroup_task_reduction(%x : !llvm.ptr) {
-  // expected-error@below {{not yet implemented: Unhandled clause task_reduction in omp.taskgroup operation}}
-  // expected-error@below {{LLVM Translation failed for operation: omp.taskgroup}}
-  omp.taskgroup task_reduction(@add_f32 %x -> %prv : !llvm.ptr) {
-    omp.terminator
-  }
-  llvm.return
-}
-
-// -----
-
 llvm.func @taskloop(%lb : i32, %ub : i32, %step : i32) {
   // expected-error@below {{not yet implemented: omp.taskloop}}
   // expected-error@below {{LLVM Translation failed for operation: omp.taskloop}}

@llvmbot
Member

llvmbot commented Dec 23, 2024

@llvm/pr-subscribers-flang-openmp

Author: None (NimishMishra)


Contributor

@tblah left a comment

Thanks for making a start on this. I think the pointer lifetimes need to be properly understood before this can proceed. If a solution to that ends up depending on my work fixing task privatization, it would be acceptable to me for this to be merged with support only for the simple by-value cases that work correctly, with a "not yet implemented" error for the cases that are not yet safe.
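
The fallback described in this comment (accept only by-value reductions and report "not yet implemented" otherwise) could be gated roughly as follows. This is a minimal standalone sketch, not code from the PR: `checkTaskReductionSupported` and the message text are hypothetical, with the wording loosely modeled on the existing diagnostics in `openmp-todo.mlir`. A real check would presumably sit in `checkImplementationStatus` and read the op's by-ref flags.

```cpp
#include <optional>
#include <string>
#include <vector>

// Hypothetical helper: returns an error message when any reduction variable
// is passed by reference, and std::nullopt when every variable is by value.
std::optional<std::string>
checkTaskReductionSupported(const std::vector<bool> &isByRef) {
  for (bool byRef : isByRef) {
    if (byRef)
      return std::string("not yet implemented: by-ref task_reduction in "
                         "omp.taskgroup operation");
  }
  return std::nullopt;
}
```

A caller would turn a returned message into a translation failure before any reduction functions are emitted, so the unsupported cases never reach code generation.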

Comment on lines +1792 to +1794
if (region.empty()) {
return llvm::Constant::getNullValue(OpaquePtrTy);
}
Contributor

Comment on lines +1782 to +1783
template <typename OP>
llvm::Value *createTaskReductionFunction(
Contributor

nit: this can be static

Comment on lines +1782 to +1783
template <typename OP>
llvm::Value *createTaskReductionFunction(
Contributor

Could you please write a brief documentation comment?

llvm::IRBuilderBase &builder, const std::string &name, llvm::Type *redTy,
LLVM::ModuleTranslation &moduleTranslation,
SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls, Region &region,
OP &op, unsigned Cnt, llvm::ArrayRef<bool> &isByRef,
Contributor

nit

Suggested change
- OP &op, unsigned Cnt, llvm::ArrayRef<bool> &isByRef,
+ OP &op, unsigned cnt, llvm::ArrayRef<bool> &isByRef,

Comment on lines +1803 to +1804
llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, name,
builder.GetInsertBlock()->getModule());
Contributor

Did you consider using OpenMPIRBuilder::getOrCreateRuntimeFunctionPtr? What made you decide to do it manually?

Contributor Author

Frankly, I was not aware of it; I can give it a try.

builder.restoreIP(oldIP);
llvm::DataLayout DL = builder.GetInsertBlock()->getModule()->getDataLayout();

for (int Cnt = 0; Cnt < arraySize; ++Cnt) {
Contributor

nit

Suggested change
- for (int Cnt = 0; Cnt < arraySize; ++Cnt) {
+ for (int cnt = 0; cnt < arraySize; ++cnt) {

This file is under mlir/, so the MLIR coding standards apply: https://mlir.llvm.org/getting_started/DeveloperGuide/#style-guide

Comment on lines +1965 to +1973
if (auto *alloca = dyn_cast<llvm::AllocaInst>(shared)) {
redTy = alloca->getAllocatedType();
uint64_t sizeInBytes = DL.getTypeAllocSize(redTy);
llvm::ConstantInt *sizeConst =
llvm::ConstantInt::get(llvm::Type::getInt64Ty(Context), sizeInBytes);
builder.CreateStore(sizeConst, FieldPtrReduceSize);
} else {
llvm_unreachable("Non alloca instruction found.");
}
Contributor

This feels error-prone to me. There is currently nothing in the MLIR dialect requiring all SSA values passed into a reduction to be stack allocated.

The MLIR OpenMP dialect has uses outside of flang.

Aside from this, some other questions:

  • What if the allocation size is not constant?
  • What if the allocation is for an array?

//CHECK: call void @__kmpc_taskgroup(ptr @1, i32 %[[TID]])
//CHECK: %[[RED_ELEMENT:.*]] = getelementptr [1 x %kmp_taskred_input_t], ptr %[[RED_ARRY]], i32 0, i32 0
//CHECK: %[[RED_SHARED:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 0
//CHECK: store ptr %[[VAL1]], ptr %[[RED_SHARED]], align 8
Contributor

I am concerned this could run into the same issue as we have for privatization (https://discourse.llvm.org/t/rfc-openmp-fix-issue-in-mlir-to-llvmir-translation-for-delayed-privatisation/81225)

A pointer to a stack allocation is stored into the task context. How can you guarantee that the pointer remains live at the point it is used?

Contributor Author

I think you are right about pointer lifetimes, and we will have to deal with them (which this PR does not do at this point). Is the PR submitted as a result of this RFC a good place to begin looking into it?

I think we could do the following:

(1) go ahead with the by-value reduction for the time being
(2) I go over your solution and try to incorporate it for by-reference. That could be another PR.

Does that work?

Contributor

I think this can go ahead for by-value reduction so long as you produce a helpful error message when by-reference is used. I think it is better to refuse to build the code than to miscompile it.

Then yes I will be happy to help once my solution for privatization is done. I'll do my best not to take so long, but the preparatory step redefining omp.private is a moving target and so has taken some time.
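
As a rough illustration of the "refuse rather than miscompile" idea above, here is a minimal standalone sketch. It does not use the MLIR or LLVM APIs; the helper name `checkTaskReductionSupported` and the message wording are hypothetical, and in the real translation the diagnostic would presumably be emitted on the operation instead of written to a string.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical helper (not part of the PR): returns true when every
// reduction is by-value. On the first by-reference entry it fills `error`
// with a descriptive message and returns false, so the caller can stop
// translation instead of emitting code with a dangling pointer.
bool checkTaskReductionSupported(const std::vector<bool> &isByRef,
                                 std::string &error) {
  for (std::size_t i = 0; i < isByRef.size(); ++i) {
    if (isByRef[i]) {
      error = "not yet implemented: by-reference task_reduction "
              "(reduction variable #" +
              std::to_string(i) + ")";
      return false;
    }
  }
  error.clear();
  return true;
}
```

The point of the sketch is only the control flow: the check runs before any code generation, and the message names the offending reduction variable.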

llvm::Value *arg0 = function->getArg(0);
llvm::Value *arg1 = function->getArg(1);

if (name == "red_init") {
Member

I think it would be more consistent with the rest of the code here if you pass a bodyGenCB lambda. This will also enable you to clean up the interface for createTaskReductionFunction (you can get rid of a number of parameters and just capture what you need in the code-gen callback).
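
To make the callback suggestion concrete, here is a small standalone sketch of the pattern. It deliberately avoids the real LLVM classes: `ToyBuilder`, `BodyGenCallback`, `createToyReductionFunction`, and `exampleRedInit` are all hypothetical stand-ins, and the "IR" is just recorded text. The helper owns the function skeleton, while everything specific to one reduction is captured by the lambda rather than passed as extra parameters.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Stand-in for an IR builder: it only records emitted lines as text.
struct ToyBuilder {
  std::vector<std::string> lines;
  void emit(const std::string &line) { lines.push_back(line); }
};

// The body generator receives the builder; anything else it needs is
// captured by the lambda at the call site.
using BodyGenCallback = std::function<void(ToyBuilder &)>;

// Emits the function prologue and epilogue and delegates the body.
std::vector<std::string>
createToyReductionFunction(const std::string &name,
                           const BodyGenCallback &bodyGenCB) {
  ToyBuilder builder;
  builder.emit("define " + name + "(ptr %arg0, ptr %arg1) {");
  bodyGenCB(builder);
  builder.emit("}");
  return builder.lines;
}

// Example call site: the initial value is captured, not threaded through
// the helper's parameter list.
std::vector<std::string> exampleRedInit(int initValue) {
  return createToyReductionFunction("red_init", [&](ToyBuilder &b) {
    b.emit("store i32 " + std::to_string(initValue) + ", ptr %arg0");
    b.emit("ret void");
  });
}
```

With this shape, the initializer and combiner can share one helper and differ only in the lambda passed at each call site.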

Contributor Author

I see. Thanks, will do so.

funcType =
llvm::FunctionType::get(OpaquePtrTy, {OpaquePtrTy, OpaquePtrTy}, false);
llvm::Function *function =
llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, name,
Member

For the name of the generated function, I think it would be nice for debuggability to either prefix or suffix the symbol name of the DeclareReductionOp op.

funcType =
llvm::FunctionType::get(OpaquePtrTy, {OpaquePtrTy, OpaquePtrTy}, false);
llvm::Function *function =
llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, name,
Member

Why do we need external linkage for this?

bbBuilder.CreateAlloca(bbBuilder.getPtrTy());
bbBuilder.CreateStore(arg0, arg0Alloca);
bbBuilder.CreateStore(arg1, arg1Alloca);
llvm::Value *LoadVal =
Member

Isn't this the same value as arg0? Why the need for loads, stores, and allocas? I must be missing something, just want to understand more. Same applies to the red_comb region.

Contributor Author

This section of the code is for pass-by-reference. Honestly, I tried to keep the logic here as close to Clang as possible:

define void @red_init(ptr noalias %arg0, ptr noalias %arg1) #2 {
  entry:
   %0 = alloca ptr, align 8
   %1 = alloca ptr, align 8
   store ptr %arg0, ptr %0, align 8
   store ptr %arg1, ptr %1, align 8
   %2 = load ptr, ptr %0, align 8
   store i32 0, ptr %2, align 4
   ret void
}

Contributor Author

For pass-by-value, a simple mapInitializationArgs suffices

llvm::LLVMContext &Context = builder.getContext();
llvm::Type *OpaquePtrTy = llvm::PointerType::get(Context, 0);
if (region.empty()) {
return llvm::Constant::getNullValue(OpaquePtrTy);
Member

Not very familiar with reductions, so to make sure, is it fine to set any of these function pointers to a null value in the final struct?

Contributor Author

I do think so. Essentially, in case the region is empty, we do not want to generate the functions.

@tblah Can you give an opinion on this? In case the region is empty, is it better to let an empty function body be present, or should we have an assertion in this place?

Contributor

Nice spot Kareem!

It looks like these pointers should not be null because init and combiner are called unconditionally by the runtime. E.g.

template <typename T> void __kmp_call_init(kmp_taskred_data_t &item, size_t j);

For normal reductions there will always be combiner and initializer regions. For declare reduction the initializer expression is optional so I think one could write legal code without an init region. The combiner expression is required.

The finalization does seem to be optional: https://github.com/llvm/llvm-project/blob/7e01a322f850e86be9eefde8ae5a30e532d22cfa/openmp/runtime/src/kmp_tasking.cpp#L2753C9-L2753C19

SmallVector<llvm::Value *, 1> phis;
if (failed(inlineConvertOmpRegions(region, "", bbBuilder, moduleTranslation,
&phis)))
return nullptr;
Member

I think it would be better to return a LogicalResult and pass an output parameter for the created function value. This is less error prone compared to directly using the return value of this function to store it somewhere for example like we do below.
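
A minimal standalone sketch of the suggested signature follows. It uses toy stand-ins rather than the MLIR types: `ToyResult` and `failed` imitate `LogicalResult`, and `ToyFunction` and `createToyFunction` are hypothetical names. The intent is only to show that success or failure travels in the return value while the created function travels in an output parameter, so a failure cannot be silently stored as if it were a valid function pointer.

```cpp
#include <cassert>
#include <string>

// Toy stand-in for mlir::LogicalResult.
enum class ToyResult { Success, Failure };
inline bool failed(ToyResult result) { return result == ToyResult::Failure; }

// Toy stand-in for the generated function.
struct ToyFunction {
  std::string name;
};

// The status is the return value; the created function is written to `out`
// only on success. On failure `out` is left untouched.
ToyResult createToyFunction(const std::string &name, bool bodyConverts,
                            ToyFunction &out) {
  if (!bodyConverts)
    return ToyResult::Failure;
  out.name = name;
  return ToyResult::Success;
}
```

A caller would then write `if (failed(createToyFunction(...))) return ...;` before using the output, instead of storing a possibly null return value directly.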

llvm::ConstantInt::get(llvm::Type::getInt64Ty(Context), sizeInBytes);
builder.CreateStore(sizeConst, FieldPtrReduceSize);
} else {
llvm_unreachable("Non alloca instruction found.");
Member

Is this temporary (a todo) or permanent? Can we simply use moduleTranslation.convertType(op.getReductionVars()[Cnt].getType()) (or something to the same effect)?

LogicalResult bodyGenStatus = success();
// Setup for `task_reduction`
llvm::ArrayRef<bool> isByRef = getIsByRef(tgOp.getTaskReductionByref());
assert(isByRef.size() == tgOp.getNumReductionVars());
Member

In the code above, we take different branches based on isByRef for each individual reduction; for example:

  if (isByRef[Cnt])
    funcType = llvm::FunctionType::get(builder.getVoidTy(),
                                       {OpaquePtrTy, OpaquePtrTy}, false);
  else
    funcType =
        llvm::FunctionType::get(OpaquePtrTy, {OpaquePtrTy, OpaquePtrTy}, false);

Does this assert mean that the PR does not contain any examples where isByRef == false for at least some of the reductions?

If I understood that correctly (i.e. that we only test the isByRef[i] == true case in this PR), I think it would be better to move the isByRef[i] == false code paths and add additional testing for by-value in a follow-up PR.

Contributor Author

Thanks, that is indeed the case.

I had a conversation with Tom. Given the concerns about pointer lifetimes, it is best to go ahead with the by-value implementation for now, and to wait for delayed privatization for tasks to be merged before merging a PR for by-ref.

Comment on lines +2035 to +2039
if (failed(allocAndInitializeTaskReductionVars(
tgOp, reductionArgs, builder, moduleTranslation, allocaIP,
reductionDecls, privateReductionVariables, reductionVariableMap,
isByRef)))
bodyGenStatus = failure();
Member

Looking at allocAndInitializeTaskReductionVars's implementation, it looks like we can use llvm::cantFail(...) to wrap this call.
