[MLIR][OpenMP] Improve Generic-SPMD kernel detection #137307

Open · skatrak wants to merge 1 commit into main from generic-spmd-fix

Conversation

skatrak (Member) commented Apr 25, 2025

The previous implementation assumed that, for a target region to be tagged as Generic-SPMD, it would need to contain a single teams distribute loop with a single parallel wsloop nested in it. However, this was an overly restrictive set of conditions which resulted in a range of kernels behaving incorrectly.

This patch updates the kernel execution flags identification logic to accept any number of parallel regions inside of a single teams distribute loop (possibly as part of conditional or loop control flow) as Generic-SPMD.
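
For illustration, the following is the shape of kernel that the updated logic now tags as Generic-SPMD: a minimal Fortran sketch (subroutine and variable names are illustrative) with two consecutive parallel loops nested inside a single teams distribute loop, along the lines of the reproducers discussed later in this thread.

subroutine two_parallel_regions(v, m, n)
  implicit none
  integer, intent(in) :: m, n
  integer, intent(inout) :: v(m, n)
  integer :: i, j

  ! Previously this kernel was tagged Generic because of the second
  ! parallel region; with this patch it is tagged Generic-SPMD.
  !$omp target teams distribute
  do i = 1, m
    !$omp parallel do
    do j = 1, n
      v(i, j) = v(i, j) + 1
    end do

    !$omp parallel do
    do j = 1, n
      v(i, j) = v(i, j) + 2
    end do
  end do
end subroutine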

llvmbot (Member) commented Apr 25, 2025

@llvm/pr-subscribers-flang-openmp
@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir-openmp

Author: Sergio Afonso (skatrak)

Changes

The previous implementation assumed that, for a target region to be tagged as Generic-SPMD, it would need to contain a single teams distribute loop with a single parallel wsloop nested in it. However, this was an overly restrictive set of conditions which resulted in a range of kernels behaving incorrectly.

This patch updates the kernel execution flags identification logic to accept any number of parallel regions inside of a single teams distribute loop (possibly as part of conditional or loop control flow) as Generic-SPMD.


Full diff: https://github.com/llvm/llvm-project/pull/137307.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+37-51)
  • (modified) mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir (+97-46)
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index dd701da507fc6..3afb374381bdf 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1954,7 +1954,7 @@ LogicalResult TargetOp::verifyRegions() {
 }
 
 static Operation *
-findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec,
+findCapturedOmpOp(Operation *rootOp,
                   llvm::function_ref<bool(Operation *)> siblingAllowedFn) {
   assert(rootOp && "expected valid operation");
 
@@ -1982,19 +1982,17 @@ findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec,
     // (i.e. its block's successors can reach it) or if it's not guaranteed to
     // be executed before all exits of the region (i.e. it doesn't dominate all
     // blocks with no successors reachable from the entry block).
-    if (checkSingleMandatoryExec) {
-      Region *parentRegion = op->getParentRegion();
-      Block *parentBlock = op->getBlock();
-
-      for (Block *successor : parentBlock->getSuccessors())
-        if (successor->isReachable(parentBlock))
-          return WalkResult::interrupt();
-
-      for (Block &block : *parentRegion)
-        if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
-            !domInfo.dominates(parentBlock, &block))
-          return WalkResult::interrupt();
-    }
+    Region *parentRegion = op->getParentRegion();
+    Block *parentBlock = op->getBlock();
+
+    for (Block *successor : parentBlock->getSuccessors())
+      if (successor->isReachable(parentBlock))
+        return WalkResult::interrupt();
+
+    for (Block &block : *parentRegion)
+      if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
+          !domInfo.dominates(parentBlock, &block))
+        return WalkResult::interrupt();
 
     // Don't capture this op if it has a not-allowed sibling, and stop recursing
     // into nested operations.
@@ -2017,27 +2015,25 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
 
   // Only allow OpenMP terminators and non-OpenMP ops that have known memory
   // effects, but don't include a memory write effect.
-  return findCapturedOmpOp(
-      *this, /*checkSingleMandatoryExec=*/true, [&](Operation *sibling) {
-        if (!sibling)
-          return false;
-
-        if (ompDialect == sibling->getDialect())
-          return sibling->hasTrait<OpTrait::IsTerminator>();
-
-        if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
-          SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4>
-              effects;
-          memOp.getEffects(effects);
-          return !llvm::any_of(
-              effects, [&](MemoryEffects::EffectInstance &effect) {
-                return isa<MemoryEffects::Write>(effect.getEffect()) &&
-                       isa<SideEffects::AutomaticAllocationScopeResource>(
-                           effect.getResource());
-              });
-        }
-        return true;
+  return findCapturedOmpOp(*this, [&](Operation *sibling) {
+    if (!sibling)
+      return false;
+
+    if (ompDialect == sibling->getDialect())
+      return sibling->hasTrait<OpTrait::IsTerminator>();
+
+    if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
+      SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4>
+          effects;
+      memOp.getEffects(effects);
+      return !llvm::any_of(effects, [&](MemoryEffects::EffectInstance &effect) {
+        return isa<MemoryEffects::Write>(effect.getEffect()) &&
+               isa<SideEffects::AutomaticAllocationScopeResource>(
+                   effect.getResource());
       });
+    }
+    return true;
+  });
 }
 
 TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
@@ -2098,33 +2094,23 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
     if (isa<LoopOp>(innermostWrapper))
       return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
 
-    // Find single immediately nested captured omp.parallel and add spmd flag
-    // (generic-spmd case).
+    // Add spmd flag if there's a nested omp.parallel (generic-spmd case).
     //
     // TODO: This shouldn't have to be done here, as it is too easy to break.
     // The openmp-opt pass should be updated to be able to promote kernels like
     // this from "Generic" to "Generic-SPMD". However, the use of the
     // `kmpc_distribute_static_loop` family of functions produced by the
     // OMPIRBuilder for these kernels prevents that from working.
-    Dialect *ompDialect = targetOp->getDialect();
-    Operation *nestedCapture = findCapturedOmpOp(
-        capturedOp, /*checkSingleMandatoryExec=*/false,
-        [&](Operation *sibling) {
-          return sibling && (ompDialect != sibling->getDialect() ||
-                             sibling->hasTrait<OpTrait::IsTerminator>());
-        });
+    bool hasParallel = capturedOp
+                           ->walk<WalkOrder::PreOrder>([](ParallelOp) {
+                             return WalkResult::interrupt();
+                           })
+                           .wasInterrupted();
 
     TargetRegionFlags result =
         TargetRegionFlags::generic | TargetRegionFlags::trip_count;
 
-    if (!nestedCapture)
-      return result;
-
-    while (nestedCapture->getParentOp() != capturedOp)
-      nestedCapture = nestedCapture->getParentOp();
-
-    return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
-                                          : result;
+    return hasParallel ? result | TargetRegionFlags::spmd : result;
   }
   // Detect target-parallel-wsloop[-simd].
   else if (isa<WsloopOp>(innermostWrapper)) {
diff --git a/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir b/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir
index 8101660e571e4..3273de0c26d27 100644
--- a/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir
@@ -1,11 +1,7 @@
-// RUN: split-file %s %t
-// RUN: mlir-translate -mlir-to-llvmir %t/host.mlir | FileCheck %s --check-prefix=HOST
-// RUN: mlir-translate -mlir-to-llvmir %t/device.mlir | FileCheck %s --check-prefix=DEVICE
-
-//--- host.mlir
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
 
 module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} {
-  llvm.func @main(%arg0 : !llvm.ptr) {
+  llvm.func @host(%arg0 : !llvm.ptr) {
     %x = llvm.load %arg0 : !llvm.ptr -> i32
     %0 = omp.map.info var_ptr(%arg0 : !llvm.ptr, i32) map_clauses(to) capture(ByCopy) -> !llvm.ptr
     omp.target host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) map_entries(%0 -> %ptr : !llvm.ptr) {
@@ -32,36 +28,36 @@ module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-a
   }
 }
 
-// HOST-LABEL: define void @main
-// HOST:         %omp_loop.tripcount = {{.*}}
-// HOST-NEXT:    br label %[[ENTRY:.*]]
-// HOST:       [[ENTRY]]:
-// HOST:         %[[TRIPCOUNT:.*]] = zext i32 %omp_loop.tripcount to i64
-// HOST:         %[[TRIPCOUNT_KARG:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KARGS:.*]], i32 0, i32 8
-// HOST-NEXT:    store i64 %[[TRIPCOUNT]], ptr %[[TRIPCOUNT_KARG]]
-// HOST:         %[[RESULT:.*]] = call i32 @__tgt_target_kernel({{.*}}, ptr %[[KARGS]])
-// HOST-NEXT:    %[[CMP:.*]] = icmp ne i32 %[[RESULT]], 0
-// HOST-NEXT:    br i1 %[[CMP]], label %[[OFFLOAD_FAILED:.*]], label %{{.*}}
-// HOST:       [[OFFLOAD_FAILED]]:
-// HOST:         call void @[[TARGET_OUTLINE:.*]]({{.*}})
+// CHECK-LABEL: define void @host
+// CHECK:         %omp_loop.tripcount = {{.*}}
+// CHECK-NEXT:    br label %[[ENTRY:.*]]
+// CHECK:       [[ENTRY]]:
+// CHECK:         %[[TRIPCOUNT:.*]] = zext i32 %omp_loop.tripcount to i64
+// CHECK:         %[[TRIPCOUNT_KARG:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KARGS:.*]], i32 0, i32 8
+// CHECK-NEXT:    store i64 %[[TRIPCOUNT]], ptr %[[TRIPCOUNT_KARG]]
+// CHECK:         %[[RESULT:.*]] = call i32 @__tgt_target_kernel({{.*}}, ptr %[[KARGS]])
+// CHECK-NEXT:    %[[CMP:.*]] = icmp ne i32 %[[RESULT]], 0
+// CHECK-NEXT:    br i1 %[[CMP]], label %[[OFFLOAD_FAILED:.*]], label %{{.*}}
+// CHECK:       [[OFFLOAD_FAILED]]:
+// CHECK:         call void @[[TARGET_OUTLINE:.*]]({{.*}})
 
-// HOST:       define internal void @[[TARGET_OUTLINE]]
-// HOST:         call void{{.*}}@__kmpc_fork_teams({{.*}}, ptr @[[TEAMS_OUTLINE:.*]], {{.*}})
+// CHECK:       define internal void @[[TARGET_OUTLINE]]
+// CHECK:         call void{{.*}}@__kmpc_fork_teams({{.*}}, ptr @[[TEAMS_OUTLINE:.*]], {{.*}})
 
-// HOST:       define internal void @[[TEAMS_OUTLINE]]
-// HOST:         call void @[[DISTRIBUTE_OUTLINE:.*]]({{.*}})
+// CHECK:       define internal void @[[TEAMS_OUTLINE]]
+// CHECK:         call void @[[DISTRIBUTE_OUTLINE:.*]]({{.*}})
 
-// HOST:       define internal void @[[DISTRIBUTE_OUTLINE]]
-// HOST:         call void @__kmpc_for_static_init{{.*}}(ptr {{.*}}, i32 {{.*}}, i32 92, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, i32 {{.*}}, i32 {{.*}})
-// HOST:         call void (ptr, i32, ptr, ...) @__kmpc_fork_call({{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], {{.*}})
+// CHECK:       define internal void @[[DISTRIBUTE_OUTLINE]]
+// CHECK:         call void @__kmpc_for_static_init{{.*}}(ptr {{.*}}, i32 {{.*}}, i32 92, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, i32 {{.*}}, i32 {{.*}})
+// CHECK:         call void (ptr, i32, ptr, ...) @__kmpc_fork_call({{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], {{.*}})
 
-// HOST:       define internal void @[[PARALLEL_OUTLINE]]
-// HOST:         call void @__kmpc_for_static_init{{.*}}(ptr {{.*}}, i32 {{.*}}, i32 34, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, i32 {{.*}}, i32 {{.*}})
+// CHECK:       define internal void @[[PARALLEL_OUTLINE]]
+// CHECK:         call void @__kmpc_for_static_init{{.*}}(ptr {{.*}}, i32 {{.*}}, i32 34, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, i32 {{.*}}, i32 {{.*}})
 
-//--- device.mlir
+// -----
 
 module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_target_device = true, omp.is_gpu = true} {
-  llvm.func @main(%arg0 : !llvm.ptr) {
+  llvm.func @device(%arg0 : !llvm.ptr) {
     %0 = omp.map.info var_ptr(%arg0 : !llvm.ptr, i32) map_clauses(to) capture(ByCopy) -> !llvm.ptr
     omp.target map_entries(%0 -> %ptr : !llvm.ptr) {
       %x = llvm.load %ptr : !llvm.ptr -> i32
@@ -87,25 +83,80 @@ module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_target_devic
   }
 }
 
-// DEVICE:      @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 [[EXEC_MODE:3]]
-// DEVICE:      @llvm.compiler.used = appending global [1 x ptr] [ptr @[[KERNEL_NAME]]_exec_mode], section "llvm.metadata"
-// DEVICE:      @[[KERNEL_NAME]]_kernel_environment = weak_odr protected constant %struct.KernelEnvironmentTy {
-// DEVICE-SAME: %struct.ConfigurationEnvironmentTy { i8 1, i8 1, i8 [[EXEC_MODE]], {{.*}}},
-// DEVICE-SAME: ptr @{{.*}}, ptr @{{.*}} }
+// CHECK:      @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 [[EXEC_MODE:3]]
+// CHECK:      @llvm.compiler.used = appending global [1 x ptr] [ptr @[[KERNEL_NAME]]_exec_mode], section "llvm.metadata"
+// CHECK:      @[[KERNEL_NAME]]_kernel_environment = weak_odr protected constant %struct.KernelEnvironmentTy {
+// CHECK-SAME: %struct.ConfigurationEnvironmentTy { i8 1, i8 1, i8 [[EXEC_MODE]], {{.*}}},
+// CHECK-SAME: ptr @{{.*}}, ptr @{{.*}} }
+
+// CHECK:      define weak_odr protected amdgpu_kernel void @[[KERNEL_NAME]]({{.*}})
+// CHECK:        %{{.*}} = call i32 @__kmpc_target_init(ptr @[[KERNEL_NAME]]_kernel_environment, {{.*}})
+// CHECK:        call void @[[TARGET_OUTLINE:.*]]({{.*}})
+// CHECK:        call void @__kmpc_target_deinit()
+
+// CHECK:      define internal void @[[TARGET_OUTLINE]]({{.*}})
+// CHECK:        call void @[[TEAMS_OUTLINE:.*]]({{.*}})
+
+// CHECK:      define internal void @[[TEAMS_OUTLINE]]({{.*}})
+// CHECK:        call void @__kmpc_distribute_static_loop{{.*}}({{.*}}, ptr @[[DISTRIBUTE_OUTLINE:[^,]*]], {{.*}})
+
+// CHECK:      define internal void @[[DISTRIBUTE_OUTLINE]]({{.*}})
+// CHECK:        call void @__kmpc_parallel_51(ptr {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], ptr {{.*}}, ptr {{.*}}, i64 {{.*}})
+
+// CHECK:      define internal void @[[PARALLEL_OUTLINE]]({{.*}})
+// CHECK:        call void @__kmpc_for_static_loop{{.*}}({{.*}})
+
+// -----
+
+module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_target_device = true, omp.is_gpu = true} {
+  llvm.func @device2(%arg0 : !llvm.ptr) {
+    %0 = omp.map.info var_ptr(%arg0 : !llvm.ptr, i32) map_clauses(to) capture(ByCopy) -> !llvm.ptr
+    omp.target map_entries(%0 -> %ptr : !llvm.ptr) {
+      %x = llvm.load %ptr : !llvm.ptr -> i32
+      omp.teams {
+        omp.distribute {
+          omp.loop_nest (%iv1) : i32 = (%x) to (%x) step (%x) {
+            omp.parallel {
+              omp.terminator
+            }
+            llvm.br ^bb2
+          ^bb1:
+            omp.parallel {
+              omp.terminator
+            }
+            omp.yield
+          ^bb2:
+            llvm.br ^bb1
+          }
+        }
+        omp.terminator
+      }
+      omp.terminator
+    }
+    llvm.return
+  }
+}
+
+// CHECK:      @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 [[EXEC_MODE:3]]
+// CHECK:      @llvm.compiler.used = appending global [1 x ptr] [ptr @[[KERNEL_NAME]]_exec_mode], section "llvm.metadata"
+// CHECK:      @[[KERNEL_NAME]]_kernel_environment = weak_odr protected constant %struct.KernelEnvironmentTy {
+// CHECK-SAME: %struct.ConfigurationEnvironmentTy { i8 1, i8 1, i8 [[EXEC_MODE]], {{.*}}},
+// CHECK-SAME: ptr @{{.*}}, ptr @{{.*}} }
 
-// DEVICE:      define weak_odr protected amdgpu_kernel void @[[KERNEL_NAME]]({{.*}})
-// DEVICE:        %{{.*}} = call i32 @__kmpc_target_init(ptr @[[KERNEL_NAME]]_kernel_environment, {{.*}})
-// DEVICE:        call void @[[TARGET_OUTLINE:.*]]({{.*}})
-// DEVICE:        call void @__kmpc_target_deinit()
+// CHECK:      define weak_odr protected amdgpu_kernel void @[[KERNEL_NAME]]({{.*}})
+// CHECK:        %{{.*}} = call i32 @__kmpc_target_init(ptr @[[KERNEL_NAME]]_kernel_environment, {{.*}})
+// CHECK:        call void @[[TARGET_OUTLINE:.*]]({{.*}})
+// CHECK:        call void @__kmpc_target_deinit()
 
-// DEVICE:      define internal void @[[TARGET_OUTLINE]]({{.*}})
-// DEVICE:        call void @[[TEAMS_OUTLINE:.*]]({{.*}})
+// CHECK:      define internal void @[[TARGET_OUTLINE]]({{.*}})
+// CHECK:        call void @[[TEAMS_OUTLINE:.*]]({{.*}})
 
-// DEVICE:      define internal void @[[TEAMS_OUTLINE]]({{.*}})
-// DEVICE:        call void @__kmpc_distribute_static_loop{{.*}}({{.*}}, ptr @[[DISTRIBUTE_OUTLINE:[^,]*]], {{.*}})
+// CHECK:      define internal void @[[TEAMS_OUTLINE]]({{.*}})
+// CHECK:        call void @__kmpc_distribute_static_loop{{.*}}({{.*}}, ptr @[[DISTRIBUTE_OUTLINE:[^,]*]], {{.*}})
 
-// DEVICE:      define internal void @[[DISTRIBUTE_OUTLINE]]({{.*}})
-// DEVICE:        call void @__kmpc_parallel_51(ptr {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], ptr {{.*}}, ptr {{.*}}, i64 {{.*}})
+// CHECK:      define internal void @[[DISTRIBUTE_OUTLINE]]({{.*}})
+// CHECK:        call void @__kmpc_parallel_51(ptr {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, ptr @[[PARALLEL_OUTLINE0:.*]], ptr {{.*}}, ptr {{.*}}, i64 {{.*}})
+// CHECK:        call void @__kmpc_parallel_51(ptr {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, ptr @[[PARALLEL_OUTLINE1:.*]], ptr {{.*}}, ptr {{.*}}, i64 {{.*}})
 
-// DEVICE:      define internal void @[[PARALLEL_OUTLINE]]({{.*}})
-// DEVICE:        call void @__kmpc_for_static_loop{{.*}}({{.*}})
+// CHECK:      define internal void @[[PARALLEL_OUTLINE1]]({{.*}})
+// CHECK:      define internal void @[[PARALLEL_OUTLINE0]]({{.*}})

llvmbot (Member) commented Apr 25, 2025

@llvm/pr-subscribers-mlir

Meinersbur (Member) left a comment

What is "a range of kernels behaving incorrectly"?

Why do you think this is legal? If you have control flow in the teams distribute construct (including a sequence of omp parallel constructs), the teams must synchronize with the initial thread. The absence of such synchronization is what allows SPMD.

skatrak (Member, Author) commented Apr 30, 2025

I think I probably didn't explain very well which cases this patch covers: having a parallel construct inside a loop or conditional block is already accepted as part of the Generic-SPMD pattern (checkSingleMandatoryExec=false on that call to findCapturedOmpOp). What this patch additionally allows is multiple consecutive parallel constructs or, more generally, multiple OpenMP constructs within the region.

I absolutely agree with you that this seems counterintuitive. But I think the main point here is that we're not tagging these kernels as "SPMD", but rather "Generic-SPMD". This is in contrast to just "Generic", which is what we're currently doing. The reason is that, in practice, if a parallel region appears inside of a Generic kernel, it doesn't seem to run properly. The following tests show some cases that don't work without this patch:

! This works (it's already tagged Generic-SPMD):
! condition=true: 1 1 1 1 1 1
! condition=false: 2 2 2 2 2 2
subroutine if_cond_single(condition)
  implicit none
  logical, intent(in) :: condition
  integer, parameter :: M = 2, N = 3
  integer :: i, j
  integer :: v(M,N)

  v(:,:) = 0

  !$omp target teams distribute
  do i=1, M
    if (condition) then
      !$omp parallel do
      do j=1, N
        v(i, j) = v(i, j) + 1
      end do
    else
      do j=1, N
        v(i, j) = v(i, j) + 2
      end do
    end if
  end do

  print *, v(:,:)
end subroutine

! This doesn't work without this patch:
! condition=true: 0 0 0 0 0 0
! condition=false: 0 0 0 0 0 0
subroutine if_cond_multiple(condition)
  implicit none
  logical, intent(in) :: condition
  integer, parameter :: M = 2, N = 3
  integer :: i, j
  integer :: v(M,N)

  v(:,:) = 0

  !$omp target teams distribute
  do i=1, M
    if (condition) then
      !$omp parallel do
      do j=1, N
        v(i, j) = v(i, j) + 1
      end do
    else
      !$omp parallel do
      do j=1, N
        v(i, j) = v(i, j) + 2
      end do
    end if
  end do

  print *, v(:,:)
end subroutine

! This works (it's already tagged Generic-SPMD):
! 3 3 2 2 2 2
subroutine single_parallel()
  implicit none
  integer, parameter :: M = 2, N = 3
  integer :: i, j
  integer :: v(M,N)

  v(:,:) = 0

  !$omp target teams distribute
  do i=1, M
    !$omp parallel do
    do j=1, N
      v(i, j) = v(i, j) + 1
    end do

    v(i, 1) = v(i, 1) + 1

    do j=1, N
      v(i, j) = v(i, j) + 1
    end do
  end do

  print *, v(:,:)
end subroutine

! This doesn't work without this patch:
! 1 1 0 0 0 0
subroutine multi_parallel()
  implicit none
  integer, parameter :: M = 2, N = 3
  integer :: i, j
  integer :: v(M,N)

  v(:,:) = 0

  !$omp target teams distribute
  do i=1, M
    !$omp parallel do
    do j=1, N
      v(i, j) = v(i, j) + 1
    end do

    v(i, 1) = v(i, 1) + 1

    !$omp parallel do
    do j=1, N
      v(i, j) = v(i, j) + 1
    end do
  end do

  print *, v(:,:)
end subroutine

I'm no expert on the exact uses of Generic-SPMD, but making it mean roughly "a target teams distribute construct with at least one parallel region inside" is so far what makes most applications and tests we're looking at work. SPMD is only used for target teams distribute parallel do composite constructs and Generic is everything else. I'm sure we'll have to tune this detection further, but I believe this change doesn't break anything we knew to be already working and it does make other cases work.

Meinersbur (Member) commented Apr 30, 2025

IIUC, Generic-SPMD is when the openmp-opt pass converts a Generic kernel to an SPMD kernel. Because openmp-opt only sees the GPU kernel, it cannot modify the host-side kernel invocation, so you end up with a mix of both. One consequence is that at kernel invocation, it does not pass the number of iterations because it is not known.

With this background, I don't see why a frontend would ever use Generic-SPMD mode, since it has control over kernel code and host-side invocation. Clang does not know about it.

Independent of that, Generic -- as the name implies -- is supposed to always work. If it does not, it is a bug.

skatrak (Member, Author) commented Apr 30, 2025

IIUC, Generic-SPMD is when the openmp-opt pass converts a Generic kernel to an SPMD kernel. Because openmp-opt only sees the GPU kernel, it cannot modify the host-side kernel invocation, so you end up with a mix of both. One consequence is that at kernel invocation, it does not pass the number of threads because it is not known.

Yes, that's how it works for clang. This does not work for flang because we use different DeviceRTL functions for distribute than clang does, and that optimization looks at certain specific function calls. I did try at one point adding support for this, but I wasn't able to (seemingly related to the fact that the DeviceRTL functions we use in flang take a function pointer to the distribute body instead of updating the loop bounds passed and having the distribute body inline).

With this background, I don't see why a frontend would ever use Generic-SPMD mode, since it has control over kernel code and host-side invocation.

Independent of that, Generic -- as the name implies -- is supposed to always work. If it does not, it is a bug.

That's the thing I also struggle to understand. There must be a bug in Generic mode if it doesn't always produce correct results, performance considerations aside. But it appears that these tests only work if tagged as Generic-SPMD, not Generic or SPMD. Since the OpenMPOpt pass can't currently make the promotion from Generic on its own, we are temporarily handling it in codegen. There's a TODO comment documenting this.

skatrak force-pushed the generic-spmd-fix branch from d9ca701 to d1a45ad on May 5, 2025 15:06

The previous implementation assumed that, for a target region to be tagged as
Generic-SPMD, it would need to contain a single `teams distribute` loop with a
single `parallel wsloop` nested in it. However, this was an overly restrictive
set of conditions which resulted in a range of kernels behaving incorrectly.

This patch updates the kernel execution flags identification logic to accept
any number of `parallel` regions inside of a single `teams distribute` loop
(possibly as part of conditional or loop control flow) as Generic-SPMD.

skatrak force-pushed the generic-spmd-fix branch from d1a45ad to c7813e6 on May 15, 2025 15:19
jdoerfert (Member) commented May 16, 2025

I tried to replicate the issue in C, but that doesn't seem to work. Maybe we should compare the IR. What I did is:

#include <stdio.h>

void multi_parallel() {
  int M = 2, N = 3;
  int v[M][N];
  for (int i = 0; i < M; ++i)
    for (int j = 0; j < N; ++j)
      v[i][j] = 0;

  #pragma omp target teams distribute
  for (int i = 0; i < M; ++i) {
    #pragma omp parallel for
    for (int j = 0; j < N; ++j) {
      v[i][j] = v[i][j] + 1;
    }
    v[i][0] = v[i][0] + 1;
    #pragma omp parallel for
    for (int j = 0; j < N; ++j) {
      v[i][j] = v[i][j] + 1;
    }
  }

  for (int i = 0; i < M; ++i)
    for (int j = 0; j < N; ++j)
      printf("v[%i][%i] = %i\n", i, j, v[i][j]);
}

int main() {
  multi_parallel();
}
clang -fopenmp --offload-arch=native -L /p/vast1/doerfert/build/llvm/lib generic.c -o generic.O0 -O0
clang -fopenmp --offload-arch=native -L /p/vast1/doerfert/build/llvm/lib generic.c -o generic.O1 -O1
clang -fopenmp --offload-arch=native -L /p/vast1/doerfert/build/llvm/lib generic.c -o generic.O2 -O2
clang -fopenmp --offload-arch=native -L /p/vast1/doerfert/build/llvm/lib generic.c -o generic.O3 -O3

All resulting in

v[0][0] = 3
v[0][1] = 2
v[0][2] = 2
v[1][0] = 3
v[1][1] = 2
v[1][2] = 2

Note that O0 is executed in Generic mode:

"PluginInterface" device 0 info: Launching kernel __omp_offloading_86fafab6_8700211b_multi_parallel_l10 with [2,1,1] blocks and [256,1,1] threads in Generic mode
AMDGPU device 0 info: #Args: 6 Teams x Thrds:    2x 256 (MaxFlatWorkGroupSize: 256) LDS Usage: 2104B #SGPRs/VGPRs: 106/47 #SGPR/VGPR Spills: 11/0 Tripcount: 2

while the rest is optimized to Generic-SPMD.

skatrak (Member, Author) commented May 20, 2025

I tried to replicate the issue in C, but that doesn't seem to work. Maybe we should compare the IR.

Thank you for checking this. It looks like that test has been fixed sometime after I created this PR, since I was able to reproduce failures with clang until I updated to the latest main branch. This other test that @Meinersbur made, however, does show another case where running in Generic-SPMD mode is currently required in order to get the expected results:

#include <stdio.h>
#include <omp.h>

int main() {
  int i, j, a = 0, b = 0, c = 0, g = 21;

  #pragma omp target teams distribute thread_limit(10) private(i,j) reduction(+:a,b,c,g)
  for (i = 1; i <= 10; ++i) {
    j = i;
    if (j == 5) {
      g += 10 * omp_get_team_num() + omp_get_thread_num();
      ++c;
      j = 11;
    }
    if (j == 11) {
      #pragma omp parallel num_threads(10) reduction(+:a)
      {
        ++a;
      }
    } else {
      #pragma omp parallel num_threads(10) reduction(+:b)
      {
        ++b;
      }
    }
  }

  printf("a: %d\nb: %d\nc: %d\ng: %d", a, b, c, g);
  return 0;
}

On this, we get the following (same output for -O1, -O2 and -O3, since they all use Generic-SPMD):

clang -fopenmp --offload-arch=native test.c -O0 -o generic.O0 && ./generic.O0
a: 1
b: 9
c: 1
g: 61

"PluginInterface" device 0 info: Launching kernel __omp_offloading_10307_d124924_main_l29 with [10,1,1] blocks and [10,1,1] threads in Generic mode
AMDGPU device 0 info: #Args: 5 Teams x Thrds:   10x  10 (MaxFlatWorkGroupSize: 10) LDS Usage: 2280B #SGPRs/VGPRs: 76/59 #SGPR/VGPR Spills: 20/13 Tripcount: 10

clang -fopenmp --offload-arch=native test.c -O1 -o generic.O1 && ./generic.O1
a: 10
b: 90
c: 1
g: 61

"PluginInterface" device 0 info: Launching kernel __omp_offloading_10307_d124924_main_l29 with [10,1,1] blocks and [10,1,1] threads in Generic-SPMD mode
AMDGPU device 0 info: #Args: 5 Teams x Thrds:   10x  10 (MaxFlatWorkGroupSize: 10) LDS Usage: 1768B #SGPRs/VGPRs: 37/32 #SGPR/VGPR Spills: 0/0 Tripcount: 10

jdoerfert (Member) commented

This other test that @Meinersbur made, however, does show another case where running in Generic-SPMD mode is currently required in order to get the expected results:

Long story short, the result of O0 is correct and expected.
The result you see with O1-O3 is also correct and expected.

What's going on:

First, note that #pragma omp parallel can choose fewer threads than the user requested, so 1 is always a valid option (modulo the strict modifier, which we don't implement yet).
Next, a few design choices:

  • When we run in Generic-mode we need an extra warp for the main thread. We count that warp against the thread limit, which one could argue we shouldn't. As we do, a thread limit of 10 doesn't allow for 2 warps, hence we end up with the single "main thread warp" and no workers at all. We can lift the thread limit to 2x WarpSize but that alone won't make a difference.
  • We choose not to run generic-mode parallel regions with partial warps. Honestly, I don't remember why, nor if I introduced this originally. My best guess is that it makes the Generic-mode barriers simpler/doable. One could look into this choice.

So, for those two reasons, we see single-thread parallel regions in the example. Again, that's perfectly valid OpenMP. If you add omp for statements you can see we properly "workshare" among the available threads (1 or more):

#include <omp.h>
#include <stdio.h>

int main() {
  int i, j, a = 0, b = 0, c = 0, g = 21;

#pragma omp target teams distribute thread_limit(128) private(i, j)            \
    reduction(+ : a, b, c, g)
  for (i = 1; i <= 10; ++i) {
    j = i;
    if (j == 5) {
      g += 10 * omp_get_team_num() + omp_get_thread_num();
      ++c;
      j = 11;
    }
    if (j == 11) {
#pragma omp parallel num_threads(64) reduction(+ : a)
#ifdef WS1
#pragma omp for
      for (int k = 0; k < 10; ++k)
#endif
        ++a;
    } else {
#pragma omp parallel num_threads(10) reduction(+ : b)
#ifdef WS2
#pragma omp for
      for (int k = 0; k < 10; ++k)
#endif
        ++b;
    }
  }

  printf("a: %d\nb: %d\nc: %d\ng: %d", a, b, c, g);
  return 0;
}

The fun part is, I run into a hang if you enable WS2; see #140786 for more information.
