
Commit d9ca701
[MLIR][OpenMP] Improve Generic-SPMD kernel detection
The previous implementation assumed that, for a target region to be tagged as Generic-SPMD, it had to contain a single `teams distribute` loop with a single `parallel wsloop` nested in it. These conditions were overly restrictive and caused a range of kernels to behave incorrectly. This patch updates the kernel execution flags identification logic to accept any number of `parallel` regions inside a single `teams distribute` loop (possibly as part of conditional or loop control flow) as Generic-SPMD.
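For illustration, this is the kernel shape the new logic accepts as Generic-SPMD, condensed from the device test added below: two `omp.parallel` regions under a single `teams distribute` loop, one of them reachable only through branching control flow. Module boilerplate is omitted; everything else mirrors the new test.

omp.target map_entries(%0 -> %ptr : !llvm.ptr) {
  %x = llvm.load %ptr : !llvm.ptr -> i32
  omp.teams {
    omp.distribute {
      omp.loop_nest (%iv) : i32 = (%x) to (%x) step (%x) {
        // First parallel region: always executed.
        omp.parallel {
          omp.terminator
        }
        llvm.br ^bb2
      ^bb1:
        // Second parallel region: reached only via the branch below.
        omp.parallel {
          omp.terminator
        }
        omp.yield
      ^bb2:
        llvm.br ^bb1
      }
    }
    omp.terminator
  }
  omp.terminator
}

Under the old single-nested-parallel rule this kernel would have stayed Generic; with this patch it is tagged Generic-SPMD, and the device lowering emits one `__kmpc_parallel_51` call per parallel region, as the updated test checks.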
1 parent 26da887

2 files changed: +134 additions, −97 deletions

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 37 additions & 51 deletions
@@ -1954,7 +1954,7 @@ LogicalResult TargetOp::verifyRegions() {
 }
 
 static Operation *
-findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec,
+findCapturedOmpOp(Operation *rootOp,
                   llvm::function_ref<bool(Operation *)> siblingAllowedFn) {
   assert(rootOp && "expected valid operation");
 
@@ -1982,19 +1982,17 @@ findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec,
     // (i.e. its block's successors can reach it) or if it's not guaranteed to
     // be executed before all exits of the region (i.e. it doesn't dominate all
     // blocks with no successors reachable from the entry block).
-    if (checkSingleMandatoryExec) {
-      Region *parentRegion = op->getParentRegion();
-      Block *parentBlock = op->getBlock();
-
-      for (Block *successor : parentBlock->getSuccessors())
-        if (successor->isReachable(parentBlock))
-          return WalkResult::interrupt();
-
-      for (Block &block : *parentRegion)
-        if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
-            !domInfo.dominates(parentBlock, &block))
-          return WalkResult::interrupt();
-    }
+    Region *parentRegion = op->getParentRegion();
+    Block *parentBlock = op->getBlock();
+
+    for (Block *successor : parentBlock->getSuccessors())
+      if (successor->isReachable(parentBlock))
+        return WalkResult::interrupt();
+
+    for (Block &block : *parentRegion)
+      if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
+          !domInfo.dominates(parentBlock, &block))
+        return WalkResult::interrupt();
 
     // Don't capture this op if it has a not-allowed sibling, and stop recursing
     // into nested operations.
@@ -2017,27 +2015,25 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
 
   // Only allow OpenMP terminators and non-OpenMP ops that have known memory
   // effects, but don't include a memory write effect.
-  return findCapturedOmpOp(
-      *this, /*checkSingleMandatoryExec=*/true, [&](Operation *sibling) {
-        if (!sibling)
-          return false;
-
-        if (ompDialect == sibling->getDialect())
-          return sibling->hasTrait<OpTrait::IsTerminator>();
-
-        if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
-          SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4>
-              effects;
-          memOp.getEffects(effects);
-          return !llvm::any_of(
-              effects, [&](MemoryEffects::EffectInstance &effect) {
-                return isa<MemoryEffects::Write>(effect.getEffect()) &&
-                       isa<SideEffects::AutomaticAllocationScopeResource>(
-                           effect.getResource());
-              });
-        }
-        return true;
+  return findCapturedOmpOp(*this, [&](Operation *sibling) {
+    if (!sibling)
+      return false;
+
+    if (ompDialect == sibling->getDialect())
+      return sibling->hasTrait<OpTrait::IsTerminator>();
+
+    if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
+      SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4>
+          effects;
+      memOp.getEffects(effects);
+      return !llvm::any_of(effects, [&](MemoryEffects::EffectInstance &effect) {
+        return isa<MemoryEffects::Write>(effect.getEffect()) &&
+               isa<SideEffects::AutomaticAllocationScopeResource>(
+                   effect.getResource());
       });
+    }
+    return true;
+  });
 }
 
 TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
@@ -2098,33 +2094,23 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
     if (isa<LoopOp>(innermostWrapper))
       return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
 
-    // Find single immediately nested captured omp.parallel and add spmd flag
-    // (generic-spmd case).
+    // Add spmd flag if there's a nested omp.parallel (generic-spmd case).
     //
     // TODO: This shouldn't have to be done here, as it is too easy to break.
    // The openmp-opt pass should be updated to be able to promote kernels like
    // this from "Generic" to "Generic-SPMD". However, the use of the
    // `kmpc_distribute_static_loop` family of functions produced by the
    // OMPIRBuilder for these kernels prevents that from working.
-    Dialect *ompDialect = targetOp->getDialect();
-    Operation *nestedCapture = findCapturedOmpOp(
-        capturedOp, /*checkSingleMandatoryExec=*/false,
-        [&](Operation *sibling) {
-          return sibling && (ompDialect != sibling->getDialect() ||
-                             sibling->hasTrait<OpTrait::IsTerminator>());
-        });
+    bool hasParallel = capturedOp
+                           ->walk<WalkOrder::PreOrder>([](ParallelOp) {
+                             return WalkResult::interrupt();
+                           })
+                           .wasInterrupted();
 
     TargetRegionFlags result =
         TargetRegionFlags::generic | TargetRegionFlags::trip_count;
 
-    if (!nestedCapture)
-      return result;
-
-    while (nestedCapture->getParentOp() != capturedOp)
-      nestedCapture = nestedCapture->getParentOp();
-
-    return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
-                                          : result;
+    return hasParallel ? result | TargetRegionFlags::spmd : result;
   }
   // Detect target-parallel-wsloop[-simd].
   else if (isa<WsloopOp>(innermostWrapper)) {
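Note that the mandatory-execution check, previously gated behind `checkSingleMandatoryExec`, now runs unconditionally while walking for the captured construct. As the retained comment states, an op is not captured if its block can be reached from its own successors or if it does not dominate all exits of its region. A hypothetical sketch (structure illustrative, not part of this commit) of a shape that is therefore still not captured, leaving the kernel Generic, because the `omp.teams` op sits behind a conditional branch:

omp.target {
  // %cond stands in for an i1 value computed earlier in the region.
  llvm.cond_br %cond, ^bb1, ^bb2
^bb1:
  omp.teams {
    // Loop constructs would go here.
    omp.terminator
  }
  llvm.br ^bb3
^bb2:
  llvm.br ^bb3
^bb3:
  omp.terminator
}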
Lines changed: 97 additions & 46 deletions
@@ -1,11 +1,7 @@
-// RUN: split-file %s %t
-// RUN: mlir-translate -mlir-to-llvmir %t/host.mlir | FileCheck %s --check-prefix=HOST
-// RUN: mlir-translate -mlir-to-llvmir %t/device.mlir | FileCheck %s --check-prefix=DEVICE
-
-//--- host.mlir
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
 
 module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} {
-  llvm.func @main(%arg0 : !llvm.ptr) {
+  llvm.func @host(%arg0 : !llvm.ptr) {
     %x = llvm.load %arg0 : !llvm.ptr -> i32
     %0 = omp.map.info var_ptr(%arg0 : !llvm.ptr, i32) map_clauses(to) capture(ByCopy) -> !llvm.ptr
     omp.target host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) map_entries(%0 -> %ptr : !llvm.ptr) {
@@ -32,36 +28,36 @@ module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-a
   }
 }
 
-// HOST-LABEL: define void @main
-// HOST: %omp_loop.tripcount = {{.*}}
-// HOST-NEXT: br label %[[ENTRY:.*]]
-// HOST: [[ENTRY]]:
-// HOST: %[[TRIPCOUNT:.*]] = zext i32 %omp_loop.tripcount to i64
-// HOST: %[[TRIPCOUNT_KARG:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KARGS:.*]], i32 0, i32 8
-// HOST-NEXT: store i64 %[[TRIPCOUNT]], ptr %[[TRIPCOUNT_KARG]]
-// HOST: %[[RESULT:.*]] = call i32 @__tgt_target_kernel({{.*}}, ptr %[[KARGS]])
-// HOST-NEXT: %[[CMP:.*]] = icmp ne i32 %[[RESULT]], 0
-// HOST-NEXT: br i1 %[[CMP]], label %[[OFFLOAD_FAILED:.*]], label %{{.*}}
-// HOST: [[OFFLOAD_FAILED]]:
-// HOST: call void @[[TARGET_OUTLINE:.*]]({{.*}})
+// CHECK-LABEL: define void @host
+// CHECK: %omp_loop.tripcount = {{.*}}
+// CHECK-NEXT: br label %[[ENTRY:.*]]
+// CHECK: [[ENTRY]]:
+// CHECK: %[[TRIPCOUNT:.*]] = zext i32 %omp_loop.tripcount to i64
+// CHECK: %[[TRIPCOUNT_KARG:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KARGS:.*]], i32 0, i32 8
+// CHECK-NEXT: store i64 %[[TRIPCOUNT]], ptr %[[TRIPCOUNT_KARG]]
+// CHECK: %[[RESULT:.*]] = call i32 @__tgt_target_kernel({{.*}}, ptr %[[KARGS]])
+// CHECK-NEXT: %[[CMP:.*]] = icmp ne i32 %[[RESULT]], 0
+// CHECK-NEXT: br i1 %[[CMP]], label %[[OFFLOAD_FAILED:.*]], label %{{.*}}
+// CHECK: [[OFFLOAD_FAILED]]:
+// CHECK: call void @[[TARGET_OUTLINE:.*]]({{.*}})
 
-// HOST: define internal void @[[TARGET_OUTLINE]]
-// HOST: call void{{.*}}@__kmpc_fork_teams({{.*}}, ptr @[[TEAMS_OUTLINE:.*]], {{.*}})
+// CHECK: define internal void @[[TARGET_OUTLINE]]
+// CHECK: call void{{.*}}@__kmpc_fork_teams({{.*}}, ptr @[[TEAMS_OUTLINE:.*]], {{.*}})
 
-// HOST: define internal void @[[TEAMS_OUTLINE]]
-// HOST: call void @[[DISTRIBUTE_OUTLINE:.*]]({{.*}})
+// CHECK: define internal void @[[TEAMS_OUTLINE]]
+// CHECK: call void @[[DISTRIBUTE_OUTLINE:.*]]({{.*}})
 
-// HOST: define internal void @[[DISTRIBUTE_OUTLINE]]
-// HOST: call void @__kmpc_for_static_init{{.*}}(ptr {{.*}}, i32 {{.*}}, i32 92, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, i32 {{.*}}, i32 {{.*}})
-// HOST: call void (ptr, i32, ptr, ...) @__kmpc_fork_call({{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], {{.*}})
+// CHECK: define internal void @[[DISTRIBUTE_OUTLINE]]
+// CHECK: call void @__kmpc_for_static_init{{.*}}(ptr {{.*}}, i32 {{.*}}, i32 92, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, i32 {{.*}}, i32 {{.*}})
+// CHECK: call void (ptr, i32, ptr, ...) @__kmpc_fork_call({{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], {{.*}})
 
-// HOST: define internal void @[[PARALLEL_OUTLINE]]
-// HOST: call void @__kmpc_for_static_init{{.*}}(ptr {{.*}}, i32 {{.*}}, i32 34, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, i32 {{.*}}, i32 {{.*}})
+// CHECK: define internal void @[[PARALLEL_OUTLINE]]
+// CHECK: call void @__kmpc_for_static_init{{.*}}(ptr {{.*}}, i32 {{.*}}, i32 34, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, i32 {{.*}}, i32 {{.*}})
 
-//--- device.mlir
+// -----
 
 module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memory_space", 5 : ui32>>, llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_target_device = true, omp.is_gpu = true} {
-  llvm.func @main(%arg0 : !llvm.ptr) {
+  llvm.func @device(%arg0 : !llvm.ptr) {
     %0 = omp.map.info var_ptr(%arg0 : !llvm.ptr, i32) map_clauses(to) capture(ByCopy) -> !llvm.ptr
     omp.target map_entries(%0 -> %ptr : !llvm.ptr) {
       %x = llvm.load %ptr : !llvm.ptr -> i32
@@ -87,25 +83,80 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
   }
 }
 
-// DEVICE: @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 [[EXEC_MODE:3]]
-// DEVICE: @llvm.compiler.used = appending global [1 x ptr] [ptr @[[KERNEL_NAME]]_exec_mode], section "llvm.metadata"
-// DEVICE: @[[KERNEL_NAME]]_kernel_environment = weak_odr protected constant %struct.KernelEnvironmentTy {
-// DEVICE-SAME: %struct.ConfigurationEnvironmentTy { i8 1, i8 1, i8 [[EXEC_MODE]], {{.*}}},
-// DEVICE-SAME: ptr @{{.*}}, ptr @{{.*}} }
+// CHECK: @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 [[EXEC_MODE:3]]
+// CHECK: @llvm.compiler.used = appending global [1 x ptr] [ptr @[[KERNEL_NAME]]_exec_mode], section "llvm.metadata"
+// CHECK: @[[KERNEL_NAME]]_kernel_environment = weak_odr protected constant %struct.KernelEnvironmentTy {
+// CHECK-SAME: %struct.ConfigurationEnvironmentTy { i8 1, i8 1, i8 [[EXEC_MODE]], {{.*}}},
+// CHECK-SAME: ptr @{{.*}}, ptr @{{.*}} }
+
+// CHECK: define weak_odr protected amdgpu_kernel void @[[KERNEL_NAME]]({{.*}})
+// CHECK: %{{.*}} = call i32 @__kmpc_target_init(ptr @[[KERNEL_NAME]]_kernel_environment, {{.*}})
+// CHECK: call void @[[TARGET_OUTLINE:.*]]({{.*}})
+// CHECK: call void @__kmpc_target_deinit()
+
+// CHECK: define internal void @[[TARGET_OUTLINE]]({{.*}})
+// CHECK: call void @[[TEAMS_OUTLINE:.*]]({{.*}})
+
+// CHECK: define internal void @[[TEAMS_OUTLINE]]({{.*}})
+// CHECK: call void @__kmpc_distribute_static_loop{{.*}}({{.*}}, ptr @[[DISTRIBUTE_OUTLINE:[^,]*]], {{.*}})
+
+// CHECK: define internal void @[[DISTRIBUTE_OUTLINE]]({{.*}})
+// CHECK: call void @__kmpc_parallel_51(ptr {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], ptr {{.*}}, ptr {{.*}}, i64 {{.*}})
+
+// CHECK: define internal void @[[PARALLEL_OUTLINE]]({{.*}})
+// CHECK: call void @__kmpc_for_static_loop{{.*}}({{.*}})
+
+// -----
+
+module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_target_device = true, omp.is_gpu = true} {
+  llvm.func @device2(%arg0 : !llvm.ptr) {
+    %0 = omp.map.info var_ptr(%arg0 : !llvm.ptr, i32) map_clauses(to) capture(ByCopy) -> !llvm.ptr
+    omp.target map_entries(%0 -> %ptr : !llvm.ptr) {
+      %x = llvm.load %ptr : !llvm.ptr -> i32
+      omp.teams {
+        omp.distribute {
+          omp.loop_nest (%iv1) : i32 = (%x) to (%x) step (%x) {
+            omp.parallel {
+              omp.terminator
+            }
+            llvm.br ^bb2
+          ^bb1:
+            omp.parallel {
+              omp.terminator
+            }
+            omp.yield
+          ^bb2:
+            llvm.br ^bb1
+          }
+        }
+        omp.terminator
+      }
+      omp.terminator
+    }
+    llvm.return
+  }
+}
+
+// CHECK: @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 [[EXEC_MODE:3]]
+// CHECK: @llvm.compiler.used = appending global [1 x ptr] [ptr @[[KERNEL_NAME]]_exec_mode], section "llvm.metadata"
+// CHECK: @[[KERNEL_NAME]]_kernel_environment = weak_odr protected constant %struct.KernelEnvironmentTy {
+// CHECK-SAME: %struct.ConfigurationEnvironmentTy { i8 1, i8 1, i8 [[EXEC_MODE]], {{.*}}},
+// CHECK-SAME: ptr @{{.*}}, ptr @{{.*}} }
 
-// DEVICE: define weak_odr protected amdgpu_kernel void @[[KERNEL_NAME]]({{.*}})
-// DEVICE: %{{.*}} = call i32 @__kmpc_target_init(ptr @[[KERNEL_NAME]]_kernel_environment, {{.*}})
-// DEVICE: call void @[[TARGET_OUTLINE:.*]]({{.*}})
-// DEVICE: call void @__kmpc_target_deinit()
+// CHECK: define weak_odr protected amdgpu_kernel void @[[KERNEL_NAME]]({{.*}})
+// CHECK: %{{.*}} = call i32 @__kmpc_target_init(ptr @[[KERNEL_NAME]]_kernel_environment, {{.*}})
+// CHECK: call void @[[TARGET_OUTLINE:.*]]({{.*}})
+// CHECK: call void @__kmpc_target_deinit()
 
-// DEVICE: define internal void @[[TARGET_OUTLINE]]({{.*}})
-// DEVICE: call void @[[TEAMS_OUTLINE:.*]]({{.*}})
+// CHECK: define internal void @[[TARGET_OUTLINE]]({{.*}})
+// CHECK: call void @[[TEAMS_OUTLINE:.*]]({{.*}})
 
-// DEVICE: define internal void @[[TEAMS_OUTLINE]]({{.*}})
-// DEVICE: call void @__kmpc_distribute_static_loop{{.*}}({{.*}}, ptr @[[DISTRIBUTE_OUTLINE:[^,]*]], {{.*}})
+// CHECK: define internal void @[[TEAMS_OUTLINE]]({{.*}})
+// CHECK: call void @__kmpc_distribute_static_loop{{.*}}({{.*}}, ptr @[[DISTRIBUTE_OUTLINE:[^,]*]], {{.*}})
 
-// DEVICE: define internal void @[[DISTRIBUTE_OUTLINE]]({{.*}})
-// DEVICE: call void @__kmpc_parallel_51(ptr {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], ptr {{.*}}, ptr {{.*}}, i64 {{.*}})
+// CHECK: define internal void @[[DISTRIBUTE_OUTLINE]]({{.*}})
+// CHECK: call void @__kmpc_parallel_51(ptr {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, ptr @[[PARALLEL_OUTLINE0:.*]], ptr {{.*}}, ptr {{.*}}, i64 {{.*}})
+// CHECK: call void @__kmpc_parallel_51(ptr {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, ptr @[[PARALLEL_OUTLINE1:.*]], ptr {{.*}}, ptr {{.*}}, i64 {{.*}})
 
-// DEVICE: define internal void @[[PARALLEL_OUTLINE]]({{.*}})
-// DEVICE: call void @__kmpc_for_static_loop{{.*}}({{.*}})
+// CHECK: define internal void @[[PARALLEL_OUTLINE1]]({{.*}})
+// CHECK: define internal void @[[PARALLEL_OUTLINE0]]({{.*}})
