Skip to content

[MLIR][OpenMP] Support target SPMD #127821

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 25, 2025
Merged

Conversation

skatrak
Copy link
Member

@skatrak skatrak commented Feb 19, 2025

This patch implements MLIR to LLVM IR translation of host-evaluated loop bounds, completing initial support for target teams distribute parallel do [simd] and target teams distribute [simd].

@llvmbot
Copy link
Member

llvmbot commented Feb 19, 2025

@llvm/pr-subscribers-flang-openmp
@llvm/pr-subscribers-mlir-openmp

@llvm/pr-subscribers-mlir

Author: Sergio Afonso (skatrak)

Changes

This patch implements MLIR to LLVM IR translation of host-evaluated loop bounds, completing initial support for target teams distribute parallel do [simd] and target teams distribute [simd].


Full diff: https://github.com/llvm/llvm-project/pull/127821.diff

3 Files Affected:

  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+63-20)
  • (added) mlir/test/Target/LLVMIR/openmp-target-spmd.mlir (+96)
  • (modified) mlir/test/Target/LLVMIR/openmp-todo.mlir (-24)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 7e8a9bdb5b133..93a88c89162d6 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -176,15 +176,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
     if (op.getHint())
       op.emitWarning("hint clause discarded");
   };
-  auto checkHostEval = [](auto op, LogicalResult &result) {
-    // Host evaluated clauses are supported, except for loop bounds.
-    for (BlockArgument arg :
-         cast<omp::BlockArgOpenMPOpInterface>(*op).getHostEvalBlockArgs())
-      for (Operation *user : arg.getUsers())
-        if (isa<omp::LoopNestOp>(user))
-          result = op.emitError("not yet implemented: host evaluation of loop "
-                                "bounds in omp.target operation");
-  };
   auto checkInReduction = [&todo](auto op, LogicalResult &result) {
     if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
         op.getInReductionSyms())
@@ -321,7 +312,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
         checkBare(op, result);
         checkDevice(op, result);
         checkHasDeviceAddr(op, result);
-        checkHostEval(op, result);
         checkInReduction(op, result);
         checkIsDevicePtr(op, result);
         checkPrivate(op, result);
@@ -4054,9 +4044,13 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
 ///
 /// Loop bounds and steps are only optionally populated, if output vectors are
 /// provided.
-static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
-                                   Value &numTeamsLower, Value &numTeamsUpper,
-                                   Value &threadLimit) {
+static void
+extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
+                       Value &numTeamsLower, Value &numTeamsUpper,
+                       Value &threadLimit,
+                       llvm::SmallVectorImpl<Value> *lowerBounds = nullptr,
+                       llvm::SmallVectorImpl<Value> *upperBounds = nullptr,
+                       llvm::SmallVectorImpl<Value> *steps = nullptr) {
   auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
   for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
                                    blockArgIface.getHostEvalBlockArgs())) {
@@ -4081,11 +4075,26 @@ static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
               llvm_unreachable("unsupported host_eval use");
           })
           .Case([&](omp::LoopNestOp loopOp) {
-            // TODO: Extract bounds and step values. Currently, this cannot be
-            // reached because translation would have been stopped earlier as a
-            // result of `checkImplementationStatus` detecting and reporting
-            // this situation.
-            llvm_unreachable("unsupported host_eval use");
+            auto processBounds =
+                [&](OperandRange opBounds,
+                    llvm::SmallVectorImpl<Value> *outBounds) -> bool {
+              bool found = false;
+              for (auto [i, lb] : llvm::enumerate(opBounds)) {
+                if (lb == blockArg) {
+                  found = true;
+                  if (outBounds)
+                    (*outBounds)[i] = hostEvalVar;
+                }
+              }
+              return found;
+            };
+            bool found =
+                processBounds(loopOp.getLoopLowerBounds(), lowerBounds);
+            found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
+                    found;
+            found = processBounds(loopOp.getLoopSteps(), steps) || found;
+            if (!found)
+              llvm_unreachable("unsupported host_eval use");
           })
           .Default([](Operation *) {
             llvm_unreachable("unsupported host_eval use");
@@ -4222,6 +4231,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp,
     combinedMaxThreadsVal = maxThreadsVal;
 
   // Update kernel bounds structure for the `OpenMPIRBuilder` to use.
+  attrs.ExecFlags = targetOp.getKernelExecFlags();
   attrs.MinTeams = minTeamsVal;
   attrs.MaxTeams.front() = maxTeamsVal;
   attrs.MinThreads = 1;
@@ -4239,9 +4249,15 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
                        LLVM::ModuleTranslation &moduleTranslation,
                        omp::TargetOp targetOp,
                        llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
+  omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(
+      targetOp.getInnermostCapturedOmpOp());
+  unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
+
   Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
+  llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
+      steps(numLoops);
   extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
-                         teamsThreadLimit);
+                         teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
 
   // TODO: Handle constant 'if' clauses.
   if (Value targetThreadLimit = targetOp.getThreadLimit())
@@ -4261,7 +4277,34 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
   if (numThreads)
     attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);
 
-  // TODO: Populate attrs.LoopTripCount if it is target SPMD.
+  if (targetOp.getKernelExecFlags() != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC) {
+    llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+    attrs.LoopTripCount = nullptr;
+
+    // To calculate the trip count, we multiply together the trip counts of
+    // every collapsed canonical loop. We don't need to create the loop nests
+    // here, since we're only interested in the trip count.
+    for (auto [loopLower, loopUpper, loopStep] :
+         llvm::zip_equal(lowerBounds, upperBounds, steps)) {
+      llvm::Value *lowerBound = moduleTranslation.lookupValue(loopLower);
+      llvm::Value *upperBound = moduleTranslation.lookupValue(loopUpper);
+      llvm::Value *step = moduleTranslation.lookupValue(loopStep);
+
+      llvm::OpenMPIRBuilder::LocationDescription loc(builder);
+      llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
+          loc, lowerBound, upperBound, step, /*IsSigned=*/true,
+          loopOp.getLoopInclusive());
+
+      if (!attrs.LoopTripCount) {
+        attrs.LoopTripCount = tripCount;
+        continue;
+      }
+
+      // TODO: Enable UndefinedSanitizer to diagnose an overflow here.
+      attrs.LoopTripCount = builder.CreateMul(attrs.LoopTripCount, tripCount,
+                                              {}, /*HasNUW=*/true);
+    }
+  }
 }
 
 static LogicalResult
diff --git a/mlir/test/Target/LLVMIR/openmp-target-spmd.mlir b/mlir/test/Target/LLVMIR/openmp-target-spmd.mlir
new file mode 100644
index 0000000000000..7930554cbe11a
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-target-spmd.mlir
@@ -0,0 +1,96 @@
+// RUN: split-file %s %t
+// RUN: mlir-translate -mlir-to-llvmir %t/host.mlir | FileCheck %s --check-prefix=HOST
+// RUN: mlir-translate -mlir-to-llvmir %t/device.mlir | FileCheck %s --check-prefix=DEVICE
+
+//--- host.mlir
+
+module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} {
+  llvm.func @main(%x : i32) {
+    omp.target host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+      omp.teams {
+        omp.parallel {
+          omp.distribute {
+            omp.wsloop {
+              omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+                omp.yield
+              }
+            } {omp.composite}
+          } {omp.composite}
+          omp.terminator
+        } {omp.composite}
+        omp.terminator
+      }
+      omp.terminator
+    }
+    llvm.return
+  }
+}
+
+// HOST-LABEL: define void @main
+// HOST:         %omp_loop.tripcount = {{.*}}
+// HOST-NEXT:    br label %[[ENTRY:.*]]
+// HOST:       [[ENTRY]]:
+// HOST-NEXT:    %[[TRIPCOUNT:.*]] = zext i32 %omp_loop.tripcount to i64
+// HOST:         %[[TRIPCOUNT_KARG:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KARGS:.*]], i32 0, i32 8
+// HOST-NEXT:    store i64 %[[TRIPCOUNT]], ptr %[[TRIPCOUNT_KARG]]
+// HOST:         %[[RESULT:.*]] = call i32 @__tgt_target_kernel({{.*}}, ptr %[[KARGS]])
+// HOST-NEXT:    %[[CMP:.*]] = icmp ne i32 %[[RESULT]], 0
+// HOST-NEXT:    br i1 %[[CMP]], label %[[OFFLOAD_FAILED:.*]], label %{{.*}}
+// HOST:       [[OFFLOAD_FAILED]]:
+// HOST:         call void @[[TARGET_OUTLINE:.*]]({{.*}})
+
+// HOST:       define internal void @[[TARGET_OUTLINE]]
+// HOST:         call void{{.*}}@__kmpc_fork_teams({{.*}}, ptr @[[TEAMS_OUTLINE:.*]], {{.*}})
+
+// HOST:       define internal void @[[TEAMS_OUTLINE]]
+// HOST:         call void{{.*}}@__kmpc_fork_call({{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], {{.*}})
+
+// HOST:       define internal void @[[PARALLEL_OUTLINE]]
+// HOST:         call void @[[DISTRIBUTE_OUTLINE:.*]]({{.*}})
+
+// HOST:       define internal void @[[DISTRIBUTE_OUTLINE]]
+// HOST:         call void @__kmpc_dist_for_static_init{{.*}}(ptr {{.*}}, i32 {{.*}}, i32 34, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, i32 {{.*}}, i32 {{.*}})
+
+//--- device.mlir
+
+module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_target_device = true, omp.is_gpu = true} {
+  llvm.func @main(%x : i32) {
+    omp.target host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+      omp.teams {
+        omp.parallel {
+          omp.distribute {
+            omp.wsloop {
+              omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+                omp.yield
+              }
+            } {omp.composite}
+          } {omp.composite}
+          omp.terminator
+        } {omp.composite}
+        omp.terminator
+      }
+      omp.terminator
+    }
+    llvm.return
+  }
+}
+
+// DEVICE:      @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 2
+// DEVICE:      @llvm.compiler.used = appending global [1 x ptr] [ptr @[[KERNEL_NAME]]_exec_mode], section "llvm.metadata"
+// DEVICE:      @[[KERNEL_NAME]]_kernel_environment = weak_odr protected constant %struct.KernelEnvironmentTy {
+// DEVICE-SAME: %struct.ConfigurationEnvironmentTy { i8 0, i8 1, i8 [[EXEC_MODE:2]], {{.*}}},
+// DEVICE-SAME: ptr @{{.*}}, ptr @{{.*}} }
+
+// DEVICE:      define weak_odr protected amdgpu_kernel void @[[KERNEL_NAME]]({{.*}})
+// DEVICE:        %{{.*}} = call i32 @__kmpc_target_init(ptr @[[KERNEL_NAME]]_kernel_environment, {{.*}})
+// DEVICE:        call void @[[TARGET_OUTLINE:.*]]({{.*}})
+// DEVICE:        call void @__kmpc_target_deinit()
+
+// DEVICE:      define internal void @[[TARGET_OUTLINE]]({{.*}})
+// DEVICE:        call void @__kmpc_parallel_51(ptr {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], ptr {{.*}}, ptr {{.*}}, i64 {{.*}})
+
+// DEVICE:      define internal void @[[PARALLEL_OUTLINE]]({{.*}})
+// DEVICE:        call void @[[DISTRIBUTE_OUTLINE:.*]]({{.*}})
+
+// DEVICE:      define internal void @[[DISTRIBUTE_OUTLINE]]({{.*}})
+// DEVICE:        call void @__kmpc_distribute_for_static_loop{{.*}}({{.*}})
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index d1c745af9bff5..f907bb3f94a2a 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -319,30 +319,6 @@ llvm.func @target_has_device_addr(%x : !llvm.ptr) {
 
 // -----
 
-llvm.func @target_host_eval(%x : i32) {
-  // expected-error@below {{not yet implemented: host evaluation of loop bounds in omp.target operation}}
-  // expected-error@below {{LLVM Translation failed for operation: omp.target}}
-  omp.target host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
-    omp.teams {
-      omp.parallel {
-        omp.distribute {
-          omp.wsloop {
-            omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
-              omp.yield
-            }
-          } {omp.composite}
-        } {omp.composite}
-        omp.terminator
-      } {omp.composite}
-      omp.terminator
-    }
-    omp.terminator
-  }
-  llvm.return
-}
-
-// -----
-
 omp.declare_reduction @add_f32 : f32
 init {
 ^bb0(%arg: f32):

@llvmbot
Copy link
Member

llvmbot commented Feb 19, 2025

@llvm/pr-subscribers-mlir-llvm

Author: Sergio Afonso (skatrak)

Changes

This patch implements MLIR to LLVM IR translation of host-evaluated loop bounds, completing initial support for target teams distribute parallel do [simd] and target teams distribute [simd].


Full diff: https://github.com/llvm/llvm-project/pull/127821.diff

3 Files Affected:

  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+63-20)
  • (added) mlir/test/Target/LLVMIR/openmp-target-spmd.mlir (+96)
  • (modified) mlir/test/Target/LLVMIR/openmp-todo.mlir (-24)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 7e8a9bdb5b133..93a88c89162d6 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -176,15 +176,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
     if (op.getHint())
       op.emitWarning("hint clause discarded");
   };
-  auto checkHostEval = [](auto op, LogicalResult &result) {
-    // Host evaluated clauses are supported, except for loop bounds.
-    for (BlockArgument arg :
-         cast<omp::BlockArgOpenMPOpInterface>(*op).getHostEvalBlockArgs())
-      for (Operation *user : arg.getUsers())
-        if (isa<omp::LoopNestOp>(user))
-          result = op.emitError("not yet implemented: host evaluation of loop "
-                                "bounds in omp.target operation");
-  };
   auto checkInReduction = [&todo](auto op, LogicalResult &result) {
     if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
         op.getInReductionSyms())
@@ -321,7 +312,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
         checkBare(op, result);
         checkDevice(op, result);
         checkHasDeviceAddr(op, result);
-        checkHostEval(op, result);
         checkInReduction(op, result);
         checkIsDevicePtr(op, result);
         checkPrivate(op, result);
@@ -4054,9 +4044,13 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
 ///
 /// Loop bounds and steps are only optionally populated, if output vectors are
 /// provided.
-static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
-                                   Value &numTeamsLower, Value &numTeamsUpper,
-                                   Value &threadLimit) {
+static void
+extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
+                       Value &numTeamsLower, Value &numTeamsUpper,
+                       Value &threadLimit,
+                       llvm::SmallVectorImpl<Value> *lowerBounds = nullptr,
+                       llvm::SmallVectorImpl<Value> *upperBounds = nullptr,
+                       llvm::SmallVectorImpl<Value> *steps = nullptr) {
   auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
   for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
                                    blockArgIface.getHostEvalBlockArgs())) {
@@ -4081,11 +4075,26 @@ static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
               llvm_unreachable("unsupported host_eval use");
           })
           .Case([&](omp::LoopNestOp loopOp) {
-            // TODO: Extract bounds and step values. Currently, this cannot be
-            // reached because translation would have been stopped earlier as a
-            // result of `checkImplementationStatus` detecting and reporting
-            // this situation.
-            llvm_unreachable("unsupported host_eval use");
+            auto processBounds =
+                [&](OperandRange opBounds,
+                    llvm::SmallVectorImpl<Value> *outBounds) -> bool {
+              bool found = false;
+              for (auto [i, lb] : llvm::enumerate(opBounds)) {
+                if (lb == blockArg) {
+                  found = true;
+                  if (outBounds)
+                    (*outBounds)[i] = hostEvalVar;
+                }
+              }
+              return found;
+            };
+            bool found =
+                processBounds(loopOp.getLoopLowerBounds(), lowerBounds);
+            found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
+                    found;
+            found = processBounds(loopOp.getLoopSteps(), steps) || found;
+            if (!found)
+              llvm_unreachable("unsupported host_eval use");
           })
           .Default([](Operation *) {
             llvm_unreachable("unsupported host_eval use");
@@ -4222,6 +4231,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp,
     combinedMaxThreadsVal = maxThreadsVal;
 
   // Update kernel bounds structure for the `OpenMPIRBuilder` to use.
+  attrs.ExecFlags = targetOp.getKernelExecFlags();
   attrs.MinTeams = minTeamsVal;
   attrs.MaxTeams.front() = maxTeamsVal;
   attrs.MinThreads = 1;
@@ -4239,9 +4249,15 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
                        LLVM::ModuleTranslation &moduleTranslation,
                        omp::TargetOp targetOp,
                        llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
+  omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(
+      targetOp.getInnermostCapturedOmpOp());
+  unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
+
   Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
+  llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
+      steps(numLoops);
   extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
-                         teamsThreadLimit);
+                         teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
 
   // TODO: Handle constant 'if' clauses.
   if (Value targetThreadLimit = targetOp.getThreadLimit())
@@ -4261,7 +4277,34 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
   if (numThreads)
     attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);
 
-  // TODO: Populate attrs.LoopTripCount if it is target SPMD.
+  if (targetOp.getKernelExecFlags() != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC) {
+    llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+    attrs.LoopTripCount = nullptr;
+
+    // To calculate the trip count, we multiply together the trip counts of
+    // every collapsed canonical loop. We don't need to create the loop nests
+    // here, since we're only interested in the trip count.
+    for (auto [loopLower, loopUpper, loopStep] :
+         llvm::zip_equal(lowerBounds, upperBounds, steps)) {
+      llvm::Value *lowerBound = moduleTranslation.lookupValue(loopLower);
+      llvm::Value *upperBound = moduleTranslation.lookupValue(loopUpper);
+      llvm::Value *step = moduleTranslation.lookupValue(loopStep);
+
+      llvm::OpenMPIRBuilder::LocationDescription loc(builder);
+      llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
+          loc, lowerBound, upperBound, step, /*IsSigned=*/true,
+          loopOp.getLoopInclusive());
+
+      if (!attrs.LoopTripCount) {
+        attrs.LoopTripCount = tripCount;
+        continue;
+      }
+
+      // TODO: Enable UndefinedSanitizer to diagnose an overflow here.
+      attrs.LoopTripCount = builder.CreateMul(attrs.LoopTripCount, tripCount,
+                                              {}, /*HasNUW=*/true);
+    }
+  }
 }
 
 static LogicalResult
diff --git a/mlir/test/Target/LLVMIR/openmp-target-spmd.mlir b/mlir/test/Target/LLVMIR/openmp-target-spmd.mlir
new file mode 100644
index 0000000000000..7930554cbe11a
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-target-spmd.mlir
@@ -0,0 +1,96 @@
+// RUN: split-file %s %t
+// RUN: mlir-translate -mlir-to-llvmir %t/host.mlir | FileCheck %s --check-prefix=HOST
+// RUN: mlir-translate -mlir-to-llvmir %t/device.mlir | FileCheck %s --check-prefix=DEVICE
+
+//--- host.mlir
+
+module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} {
+  llvm.func @main(%x : i32) {
+    omp.target host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+      omp.teams {
+        omp.parallel {
+          omp.distribute {
+            omp.wsloop {
+              omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+                omp.yield
+              }
+            } {omp.composite}
+          } {omp.composite}
+          omp.terminator
+        } {omp.composite}
+        omp.terminator
+      }
+      omp.terminator
+    }
+    llvm.return
+  }
+}
+
+// HOST-LABEL: define void @main
+// HOST:         %omp_loop.tripcount = {{.*}}
+// HOST-NEXT:    br label %[[ENTRY:.*]]
+// HOST:       [[ENTRY]]:
+// HOST-NEXT:    %[[TRIPCOUNT:.*]] = zext i32 %omp_loop.tripcount to i64
+// HOST:         %[[TRIPCOUNT_KARG:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KARGS:.*]], i32 0, i32 8
+// HOST-NEXT:    store i64 %[[TRIPCOUNT]], ptr %[[TRIPCOUNT_KARG]]
+// HOST:         %[[RESULT:.*]] = call i32 @__tgt_target_kernel({{.*}}, ptr %[[KARGS]])
+// HOST-NEXT:    %[[CMP:.*]] = icmp ne i32 %[[RESULT]], 0
+// HOST-NEXT:    br i1 %[[CMP]], label %[[OFFLOAD_FAILED:.*]], label %{{.*}}
+// HOST:       [[OFFLOAD_FAILED]]:
+// HOST:         call void @[[TARGET_OUTLINE:.*]]({{.*}})
+
+// HOST:       define internal void @[[TARGET_OUTLINE]]
+// HOST:         call void{{.*}}@__kmpc_fork_teams({{.*}}, ptr @[[TEAMS_OUTLINE:.*]], {{.*}})
+
+// HOST:       define internal void @[[TEAMS_OUTLINE]]
+// HOST:         call void{{.*}}@__kmpc_fork_call({{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], {{.*}})
+
+// HOST:       define internal void @[[PARALLEL_OUTLINE]]
+// HOST:         call void @[[DISTRIBUTE_OUTLINE:.*]]({{.*}})
+
+// HOST:       define internal void @[[DISTRIBUTE_OUTLINE]]
+// HOST:         call void @__kmpc_dist_for_static_init{{.*}}(ptr {{.*}}, i32 {{.*}}, i32 34, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, i32 {{.*}}, i32 {{.*}})
+
+//--- device.mlir
+
+module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_target_device = true, omp.is_gpu = true} {
+  llvm.func @main(%x : i32) {
+    omp.target host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+      omp.teams {
+        omp.parallel {
+          omp.distribute {
+            omp.wsloop {
+              omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+                omp.yield
+              }
+            } {omp.composite}
+          } {omp.composite}
+          omp.terminator
+        } {omp.composite}
+        omp.terminator
+      }
+      omp.terminator
+    }
+    llvm.return
+  }
+}
+
+// DEVICE:      @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 2
+// DEVICE:      @llvm.compiler.used = appending global [1 x ptr] [ptr @[[KERNEL_NAME]]_exec_mode], section "llvm.metadata"
+// DEVICE:      @[[KERNEL_NAME]]_kernel_environment = weak_odr protected constant %struct.KernelEnvironmentTy {
+// DEVICE-SAME: %struct.ConfigurationEnvironmentTy { i8 0, i8 1, i8 [[EXEC_MODE:2]], {{.*}}},
+// DEVICE-SAME: ptr @{{.*}}, ptr @{{.*}} }
+
+// DEVICE:      define weak_odr protected amdgpu_kernel void @[[KERNEL_NAME]]({{.*}})
+// DEVICE:        %{{.*}} = call i32 @__kmpc_target_init(ptr @[[KERNEL_NAME]]_kernel_environment, {{.*}})
+// DEVICE:        call void @[[TARGET_OUTLINE:.*]]({{.*}})
+// DEVICE:        call void @__kmpc_target_deinit()
+
+// DEVICE:      define internal void @[[TARGET_OUTLINE]]({{.*}})
+// DEVICE:        call void @__kmpc_parallel_51(ptr {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], ptr {{.*}}, ptr {{.*}}, i64 {{.*}})
+
+// DEVICE:      define internal void @[[PARALLEL_OUTLINE]]({{.*}})
+// DEVICE:        call void @[[DISTRIBUTE_OUTLINE:.*]]({{.*}})
+
+// DEVICE:      define internal void @[[DISTRIBUTE_OUTLINE]]({{.*}})
+// DEVICE:        call void @__kmpc_distribute_for_static_loop{{.*}}({{.*}})
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index d1c745af9bff5..f907bb3f94a2a 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -319,30 +319,6 @@ llvm.func @target_has_device_addr(%x : !llvm.ptr) {
 
 // -----
 
-llvm.func @target_host_eval(%x : i32) {
-  // expected-error@below {{not yet implemented: host evaluation of loop bounds in omp.target operation}}
-  // expected-error@below {{LLVM Translation failed for operation: omp.target}}
-  omp.target host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
-    omp.teams {
-      omp.parallel {
-        omp.distribute {
-          omp.wsloop {
-            omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
-              omp.yield
-            }
-          } {omp.composite}
-        } {omp.composite}
-        omp.terminator
-      } {omp.composite}
-      omp.terminator
-    }
-    omp.terminator
-  }
-  llvm.return
-}
-
-// -----
-
 omp.declare_reduction @add_f32 : f32
 init {
 ^bb0(%arg: f32):

Copy link
Member

@Meinersbur Meinersbur left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@skatrak skatrak force-pushed the users/skatrak/target-spmd-06-canonical-refactor branch from 5153e0d to 082d8e1 Compare February 21, 2025 10:08
@skatrak skatrak force-pushed the users/skatrak/target-spmd-07-mlir-spmd branch from 33409d2 to d5cd2ca Compare February 21, 2025 10:10
@skatrak skatrak force-pushed the users/skatrak/target-spmd-06-canonical-refactor branch from 082d8e1 to 033091e Compare February 21, 2025 10:21
@skatrak skatrak force-pushed the users/skatrak/target-spmd-07-mlir-spmd branch from d5cd2ca to b074390 Compare February 21, 2025 10:21
@skatrak skatrak force-pushed the users/skatrak/target-spmd-06-canonical-refactor branch from 033091e to f14d964 Compare February 21, 2025 10:51
@skatrak skatrak force-pushed the users/skatrak/target-spmd-07-mlir-spmd branch from b074390 to 36e1b5f Compare February 21, 2025 10:52
@skatrak skatrak force-pushed the users/skatrak/target-spmd-06-canonical-refactor branch from f14d964 to aa6182c Compare February 24, 2025 12:55
@skatrak skatrak force-pushed the users/skatrak/target-spmd-07-mlir-spmd branch from 36e1b5f to 0365152 Compare February 24, 2025 12:56
@skatrak skatrak force-pushed the users/skatrak/target-spmd-06-canonical-refactor branch from aa6182c to 35078fd Compare February 25, 2025 10:32
Base automatically changed from users/skatrak/target-spmd-06-canonical-refactor to main February 25, 2025 10:32
This patch implements MLIR to LLVM IR translation of host-evaluated loop
bounds, completing initial support for `target teams distribute parallel do
[simd]` and `target teams distribute [simd]`.
@skatrak skatrak force-pushed the users/skatrak/target-spmd-07-mlir-spmd branch from 0365152 to 146d4ba Compare February 25, 2025 10:33
@skatrak skatrak merged commit 29e1495 into main Feb 25, 2025
6 of 10 checks passed
@skatrak skatrak deleted the users/skatrak/target-spmd-07-mlir-spmd branch February 25, 2025 10:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants