Commit 29e1495

[MLIR][OpenMP] Support target SPMD (#127821)
This patch implements MLIR to LLVM IR translation of host-evaluated loop bounds, completing initial support for `target teams distribute parallel do [simd]` and `target teams distribute [simd]`.
1 parent 56975b4 commit 29e1495
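
For context, the combined construct this completes corresponds roughly to the following C++ OpenMP analogue of the Fortran `target teams distribute parallel do` (an illustrative sketch; the function and variable names are made up). Because the kernel runs in SPMD mode, the host must know the loop's trip count before launching it, which is why the loop bounds are evaluated on the host and passed through the `host_eval` arguments of `omp.target`:

    // Illustrative sketch only, not code from this commit.
    void saxpy(int n, float a, const float *x, float *y) {
      // A frontend lowers this to roughly the omp.target/teams/parallel/
      // distribute/wsloop/loop_nest structure exercised by the new test below,
      // with the bounds (0, n, 1) carried as host_eval values.
    #pragma omp target teams distribute parallel for \
        map(to: x[0:n]) map(tofrom: y[0:n])
      for (int i = 0; i < n; ++i)
        y[i] = a * x[i] + y[i];
    }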

3 files changed (+159, −44 lines)

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 63 additions & 20 deletions
@@ -173,15 +173,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
     if (op.getHint())
       op.emitWarning("hint clause discarded");
   };
-  auto checkHostEval = [](auto op, LogicalResult &result) {
-    // Host evaluated clauses are supported, except for loop bounds.
-    for (BlockArgument arg :
-         cast<omp::BlockArgOpenMPOpInterface>(*op).getHostEvalBlockArgs())
-      for (Operation *user : arg.getUsers())
-        if (isa<omp::LoopNestOp>(user))
-          result = op.emitError("not yet implemented: host evaluation of loop "
-                                "bounds in omp.target operation");
-  };
   auto checkInReduction = [&todo](auto op, LogicalResult &result) {
     if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
         op.getInReductionSyms())
@@ -318,7 +309,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
         checkBare(op, result);
         checkDevice(op, result);
         checkHasDeviceAddr(op, result);
-        checkHostEval(op, result);
         checkInReduction(op, result);
         checkIsDevicePtr(op, result);
         checkPrivate(op, result);
@@ -4158,9 +4148,13 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
 ///
 /// Loop bounds and steps are only optionally populated, if output vectors are
 /// provided.
-static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
-                                   Value &numTeamsLower, Value &numTeamsUpper,
-                                   Value &threadLimit) {
+static void
+extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
+                       Value &numTeamsLower, Value &numTeamsUpper,
+                       Value &threadLimit,
+                       llvm::SmallVectorImpl<Value> *lowerBounds = nullptr,
+                       llvm::SmallVectorImpl<Value> *upperBounds = nullptr,
+                       llvm::SmallVectorImpl<Value> *steps = nullptr) {
   auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
   for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
                                    blockArgIface.getHostEvalBlockArgs())) {
@@ -4185,11 +4179,26 @@ static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
           llvm_unreachable("unsupported host_eval use");
         })
         .Case([&](omp::LoopNestOp loopOp) {
-          // TODO: Extract bounds and step values. Currently, this cannot be
-          // reached because translation would have been stopped earlier as a
-          // result of `checkImplementationStatus` detecting and reporting
-          // this situation.
-          llvm_unreachable("unsupported host_eval use");
+          auto processBounds =
+              [&](OperandRange opBounds,
+                  llvm::SmallVectorImpl<Value> *outBounds) -> bool {
+            bool found = false;
+            for (auto [i, lb] : llvm::enumerate(opBounds)) {
+              if (lb == blockArg) {
+                found = true;
+                if (outBounds)
+                  (*outBounds)[i] = hostEvalVar;
+              }
+            }
+            return found;
+          };
+          bool found =
+              processBounds(loopOp.getLoopLowerBounds(), lowerBounds);
+          found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
+                  found;
+          found = processBounds(loopOp.getLoopSteps(), steps) || found;
+          (void)found;
+          assert(found && "unsupported host_eval use");
         })
         .Default([](Operation *) {
           llvm_unreachable("unsupported host_eval use");
@@ -4326,6 +4335,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp,
     combinedMaxThreadsVal = maxThreadsVal;

   // Update kernel bounds structure for the `OpenMPIRBuilder` to use.
+  attrs.ExecFlags = targetOp.getKernelExecFlags();
   attrs.MinTeams = minTeamsVal;
   attrs.MaxTeams.front() = maxTeamsVal;
   attrs.MinThreads = 1;
@@ -4343,9 +4353,15 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
                        LLVM::ModuleTranslation &moduleTranslation,
                        omp::TargetOp targetOp,
                        llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
+  omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(
+      targetOp.getInnermostCapturedOmpOp());
+  unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
+
   Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
+  llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
+      steps(numLoops);
   extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
-                         teamsThreadLimit);
+                         teamsThreadLimit, &lowerBounds, &upperBounds, &steps);

   // TODO: Handle constant 'if' clauses.
   if (Value targetThreadLimit = targetOp.getThreadLimit())
@@ -4365,7 +4381,34 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
   if (numThreads)
     attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);

-  // TODO: Populate attrs.LoopTripCount if it is target SPMD.
+  if (targetOp.getKernelExecFlags() != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC) {
+    llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+    attrs.LoopTripCount = nullptr;
+
+    // To calculate the trip count, we multiply together the trip counts of
+    // every collapsed canonical loop. We don't need to create the loop nests
+    // here, since we're only interested in the trip count.
+    for (auto [loopLower, loopUpper, loopStep] :
+         llvm::zip_equal(lowerBounds, upperBounds, steps)) {
+      llvm::Value *lowerBound = moduleTranslation.lookupValue(loopLower);
+      llvm::Value *upperBound = moduleTranslation.lookupValue(loopUpper);
+      llvm::Value *step = moduleTranslation.lookupValue(loopStep);
+
+      llvm::OpenMPIRBuilder::LocationDescription loc(builder);
+      llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
+          loc, lowerBound, upperBound, step, /*IsSigned=*/true,
+          loopOp.getLoopInclusive());
+
+      if (!attrs.LoopTripCount) {
+        attrs.LoopTripCount = tripCount;
+        continue;
+      }
+
+      // TODO: Enable UndefinedSanitizer to diagnose an overflow here.
+      attrs.LoopTripCount = builder.CreateMul(attrs.LoopTripCount, tripCount,
+                                              {}, /*HasNUW=*/true);
+    }
+  }
 }

 static LogicalResult
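
The trip-count logic added to `initTargetRuntimeAttrs` calls `OpenMPIRBuilder::calculateCanonicalLoopTripCount` once per collapsed loop and multiplies the results. A minimal standalone sketch of that arithmetic, assuming the usual canonical-loop trip-count rule (the real code emits LLVM IR values rather than computing constants, and the helper names here are hypothetical):

    #include <cassert>
    #include <cstdint>
    #include <initializer_list>

    // Assumed rule: span = ub - lb (direction-normalized, +1 if the upper
    // bound is inclusive); trip count = ceil(span / |step|), clamped at zero.
    static uint64_t canonicalTripCount(int64_t lb, int64_t ub, int64_t step,
                                       bool inclusive) {
      assert(step != 0 && "canonical loops require a non-zero step");
      int64_t span = step > 0 ? ub - lb : lb - ub;
      int64_t absStep = step > 0 ? step : -step;
      if (inclusive)
        span += 1;
      return span <= 0 ? 0
                       : static_cast<uint64_t>((span + absStep - 1) / absStep);
    }

    // A collapsed nest contributes the product of its per-loop counts,
    // mirroring the chain of CreateMul calls above.
    static uint64_t collapsedTripCount(std::initializer_list<uint64_t> counts) {
      uint64_t total = 1;
      for (uint64_t c : counts)
        total *= c;
      return total;
    }

For example, collapsing two loops of 10 and 20 iterations gives collapsedTripCount({10, 20}) == 200; on the host path that product is the value zero-extended and stored into the `__tgt_kernel_arguments` trip-count slot, as the HOST checks in the new test verify.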
Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
+// RUN: split-file %s %t
+// RUN: mlir-translate -mlir-to-llvmir %t/host.mlir | FileCheck %s --check-prefix=HOST
+// RUN: mlir-translate -mlir-to-llvmir %t/device.mlir | FileCheck %s --check-prefix=DEVICE
+
+//--- host.mlir
+
+module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} {
+  llvm.func @main(%x : i32) {
+    omp.target host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+      omp.teams {
+        omp.parallel {
+          omp.distribute {
+            omp.wsloop {
+              omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+                omp.yield
+              }
+            } {omp.composite}
+          } {omp.composite}
+          omp.terminator
+        } {omp.composite}
+        omp.terminator
+      }
+      omp.terminator
+    }
+    llvm.return
+  }
+}
+
+// HOST-LABEL: define void @main
+// HOST: %omp_loop.tripcount = {{.*}}
+// HOST-NEXT: br label %[[ENTRY:.*]]
+// HOST: [[ENTRY]]:
+// HOST-NEXT: %[[TRIPCOUNT:.*]] = zext i32 %omp_loop.tripcount to i64
+// HOST: %[[TRIPCOUNT_KARG:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KARGS:.*]], i32 0, i32 8
+// HOST-NEXT: store i64 %[[TRIPCOUNT]], ptr %[[TRIPCOUNT_KARG]]
+// HOST: %[[RESULT:.*]] = call i32 @__tgt_target_kernel({{.*}}, ptr %[[KARGS]])
+// HOST-NEXT: %[[CMP:.*]] = icmp ne i32 %[[RESULT]], 0
+// HOST-NEXT: br i1 %[[CMP]], label %[[OFFLOAD_FAILED:.*]], label %{{.*}}
+// HOST: [[OFFLOAD_FAILED]]:
+// HOST: call void @[[TARGET_OUTLINE:.*]]({{.*}})
+
+// HOST: define internal void @[[TARGET_OUTLINE]]
+// HOST: call void{{.*}}@__kmpc_fork_teams({{.*}}, ptr @[[TEAMS_OUTLINE:.*]], {{.*}})
+
+// HOST: define internal void @[[TEAMS_OUTLINE]]
+// HOST: call void{{.*}}@__kmpc_fork_call({{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], {{.*}})
+
+// HOST: define internal void @[[PARALLEL_OUTLINE]]
+// HOST: call void @[[DISTRIBUTE_OUTLINE:.*]]({{.*}})
+
+// HOST: define internal void @[[DISTRIBUTE_OUTLINE]]
+// HOST: call void @__kmpc_dist_for_static_init{{.*}}(ptr {{.*}}, i32 {{.*}}, i32 34, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, i32 {{.*}}, i32 {{.*}})
+
+//--- device.mlir
+
+module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_target_device = true, omp.is_gpu = true} {
+  llvm.func @main(%x : i32) {
+    omp.target host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+      omp.teams {
+        omp.parallel {
+          omp.distribute {
+            omp.wsloop {
+              omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+                omp.yield
+              }
+            } {omp.composite}
+          } {omp.composite}
+          omp.terminator
+        } {omp.composite}
+        omp.terminator
+      }
+      omp.terminator
+    }
+    llvm.return
+  }
+}
+
+// DEVICE: @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 2
+// DEVICE: @llvm.compiler.used = appending global [1 x ptr] [ptr @[[KERNEL_NAME]]_exec_mode], section "llvm.metadata"
+// DEVICE: @[[KERNEL_NAME]]_kernel_environment = weak_odr protected constant %struct.KernelEnvironmentTy {
+// DEVICE-SAME: %struct.ConfigurationEnvironmentTy { i8 0, i8 1, i8 [[EXEC_MODE:2]], {{.*}}},
+// DEVICE-SAME: ptr @{{.*}}, ptr @{{.*}} }
+
+// DEVICE: define weak_odr protected amdgpu_kernel void @[[KERNEL_NAME]]({{.*}})
+// DEVICE: %{{.*}} = call i32 @__kmpc_target_init(ptr @[[KERNEL_NAME]]_kernel_environment, {{.*}})
+// DEVICE: call void @[[TARGET_OUTLINE:.*]]({{.*}})
+// DEVICE: call void @__kmpc_target_deinit()
+
+// DEVICE: define internal void @[[TARGET_OUTLINE]]({{.*}})
+// DEVICE: call void @__kmpc_parallel_51(ptr {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], ptr {{.*}}, ptr {{.*}}, i64 {{.*}})
+
+// DEVICE: define internal void @[[PARALLEL_OUTLINE]]({{.*}})
+// DEVICE: call void @[[DISTRIBUTE_OUTLINE:.*]]({{.*}})
+
+// DEVICE: define internal void @[[DISTRIBUTE_OUTLINE]]({{.*}})
+// DEVICE: call void @__kmpc_distribute_for_static_loop{{.*}}({{.*}})
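
The `constant i8 2` that the DEVICE lines expect for `_exec_mode` and for the `ConfigurationEnvironmentTy` field is the SPMD execution-mode flag selected via `targetOp.getKernelExecFlags()` in the translation above; in SPMD mode every device thread executes the loop body directly, so the host-computed trip count can be handed straight to the runtime. A hedged mirror of the encoding, assuming the values implied by these checks (an illustrative enum, not the LLVM frontend header itself):

    // Hypothetical mirror of the target execution-mode flags; the real
    // enumerators (OMP_TGT_EXEC_MODE_*) live in LLVM's OpenMP frontend support.
    enum class TargetExecMode : signed char {
      Generic = 1,     // generic kernel with a worker state machine
      Spmd = 2,        // SPMD kernel: what "_exec_mode = ... constant i8 2" encodes
      GenericSpmd = 3, // Generic | Spmd
    };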

mlir/test/Target/LLVMIR/openmp-todo.mlir

Lines changed: 0 additions & 24 deletions
@@ -319,30 +319,6 @@ llvm.func @target_has_device_addr(%x : !llvm.ptr) {

 // -----

-llvm.func @target_host_eval(%x : i32) {
-  // expected-error@below {{not yet implemented: host evaluation of loop bounds in omp.target operation}}
-  // expected-error@below {{LLVM Translation failed for operation: omp.target}}
-  omp.target host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
-    omp.teams {
-      omp.parallel {
-        omp.distribute {
-          omp.wsloop {
-            omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
-              omp.yield
-            }
-          } {omp.composite}
-        } {omp.composite}
-        omp.terminator
-      } {omp.composite}
-      omp.terminator
-    }
-    omp.terminator
-  }
-  llvm.return
-}
-
-// -----
-
 omp.declare_reduction @add_f32 : f32
 init {
 ^bb0(%arg: f32):
