Skip to content

[MLIR][OpenMP] Support target SPMD #127821

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 63 additions & 20 deletions mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,15 +173,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
if (op.getHint())
op.emitWarning("hint clause discarded");
};
auto checkHostEval = [](auto op, LogicalResult &result) {
// Host evaluated clauses are supported, except for loop bounds.
for (BlockArgument arg :
cast<omp::BlockArgOpenMPOpInterface>(*op).getHostEvalBlockArgs())
for (Operation *user : arg.getUsers())
if (isa<omp::LoopNestOp>(user))
result = op.emitError("not yet implemented: host evaluation of loop "
"bounds in omp.target operation");
};
auto checkInReduction = [&todo](auto op, LogicalResult &result) {
if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
op.getInReductionSyms())
Expand Down Expand Up @@ -318,7 +309,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
checkBare(op, result);
checkDevice(op, result);
checkHasDeviceAddr(op, result);
checkHostEval(op, result);
checkInReduction(op, result);
checkIsDevicePtr(op, result);
checkPrivate(op, result);
Expand Down Expand Up @@ -4158,9 +4148,13 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
///
/// Loop bounds and steps are only populated if the corresponding output
/// vectors are provided.
static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
Value &numTeamsLower, Value &numTeamsUpper,
Value &threadLimit) {
static void
extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
Value &numTeamsLower, Value &numTeamsUpper,
Value &threadLimit,
llvm::SmallVectorImpl<Value> *lowerBounds = nullptr,
llvm::SmallVectorImpl<Value> *upperBounds = nullptr,
llvm::SmallVectorImpl<Value> *steps = nullptr) {
auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
blockArgIface.getHostEvalBlockArgs())) {
Expand All @@ -4185,11 +4179,26 @@ static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
llvm_unreachable("unsupported host_eval use");
})
.Case([&](omp::LoopNestOp loopOp) {
// TODO: Extract bounds and step values. Currently, this cannot be
// reached because translation would have been stopped earlier as a
// result of `checkImplementationStatus` detecting and reporting
// this situation.
llvm_unreachable("unsupported host_eval use");
// Scans `opBounds` for operands equal to the current host_eval block
// argument. For each match, the host-evaluated value is recorded at the same
// index in `outBounds` (when an output vector was supplied). Returns whether
// at least one operand matched.
auto processBounds =
    [&](OperandRange opBounds,
        llvm::SmallVectorImpl<Value> *outBounds) -> bool {
  bool matched = false;
  for (auto [idx, bound] : llvm::enumerate(opBounds)) {
    if (bound != blockArg)
      continue;
    matched = true;
    if (outBounds)
      (*outBounds)[idx] = hostEvalVar;
  }
  return matched;
};
bool found =
processBounds(loopOp.getLoopLowerBounds(), lowerBounds);
found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
found;
found = processBounds(loopOp.getLoopSteps(), steps) || found;
(void)found;
assert(found && "unsupported host_eval use");
})
.Default([](Operation *) {
llvm_unreachable("unsupported host_eval use");
Expand Down Expand Up @@ -4326,6 +4335,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp,
combinedMaxThreadsVal = maxThreadsVal;

// Update kernel bounds structure for the `OpenMPIRBuilder` to use.
attrs.ExecFlags = targetOp.getKernelExecFlags();
attrs.MinTeams = minTeamsVal;
attrs.MaxTeams.front() = maxTeamsVal;
attrs.MinThreads = 1;
Expand All @@ -4343,9 +4353,15 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation,
omp::TargetOp targetOp,
llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(
targetOp.getInnermostCapturedOmpOp());
unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;

Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
steps(numLoops);
extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
teamsThreadLimit);
teamsThreadLimit, &lowerBounds, &upperBounds, &steps);

// TODO: Handle constant 'if' clauses.
if (Value targetThreadLimit = targetOp.getThreadLimit())
Expand All @@ -4365,7 +4381,34 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
if (numThreads)
attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);

// TODO: Populate attrs.LoopTripCount if it is target SPMD.
if (targetOp.getKernelExecFlags() != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC) {
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
attrs.LoopTripCount = nullptr;

// To calculate the trip count, we multiply together the trip counts of
// every collapsed canonical loop. We don't need to create the loop nests
// here, since we're only interested in the trip count.
for (auto [loopLower, loopUpper, loopStep] :
llvm::zip_equal(lowerBounds, upperBounds, steps)) {
llvm::Value *lowerBound = moduleTranslation.lookupValue(loopLower);
llvm::Value *upperBound = moduleTranslation.lookupValue(loopUpper);
llvm::Value *step = moduleTranslation.lookupValue(loopStep);

llvm::OpenMPIRBuilder::LocationDescription loc(builder);
llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
loc, lowerBound, upperBound, step, /*IsSigned=*/true,
loopOp.getLoopInclusive());

if (!attrs.LoopTripCount) {
attrs.LoopTripCount = tripCount;
continue;
}

// TODO: Enable UndefinedBehaviorSanitizer (UBSan) to diagnose an overflow here.
attrs.LoopTripCount = builder.CreateMul(attrs.LoopTripCount, tripCount,
{}, /*HasNUW=*/true);
}
}
}

static LogicalResult
Expand Down
96 changes: 96 additions & 0 deletions mlir/test/Target/LLVMIR/openmp-target-spmd.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
// RUN: split-file %s %t
// RUN: mlir-translate -mlir-to-llvmir %t/host.mlir | FileCheck %s --check-prefix=HOST
// RUN: mlir-translate -mlir-to-llvmir %t/device.mlir | FileCheck %s --check-prefix=DEVICE

//--- host.mlir

// Host-side translation input: omp.is_target_device is false and the only
// offload target triple is amdgcn-amd-amdhsa.
module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} {
llvm.func @main(%x : i32) {
// The function argument %x is forwarded through host_eval three times, so the
// lower bound, upper bound and step of the loop nest are all values defined
// outside the target region.
omp.target host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
omp.teams {
// parallel, distribute and wsloop all carry the omp.composite attribute and
// wrap a single one-dimensional loop nest.
omp.parallel {
omp.distribute {
omp.wsloop {
omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
omp.yield
}
} {omp.composite}
} {omp.composite}
omp.terminator
} {omp.composite}
omp.terminator
}
omp.terminator
}
llvm.return
}
}

// HOST-LABEL: define void @main
// HOST: %omp_loop.tripcount = {{.*}}
// HOST-NEXT: br label %[[ENTRY:.*]]
// HOST: [[ENTRY]]:
// HOST-NEXT: %[[TRIPCOUNT:.*]] = zext i32 %omp_loop.tripcount to i64
// HOST: %[[TRIPCOUNT_KARG:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KARGS:.*]], i32 0, i32 8
// HOST-NEXT: store i64 %[[TRIPCOUNT]], ptr %[[TRIPCOUNT_KARG]]
// HOST: %[[RESULT:.*]] = call i32 @__tgt_target_kernel({{.*}}, ptr %[[KARGS]])
// HOST-NEXT: %[[CMP:.*]] = icmp ne i32 %[[RESULT]], 0
// HOST-NEXT: br i1 %[[CMP]], label %[[OFFLOAD_FAILED:.*]], label %{{.*}}
// HOST: [[OFFLOAD_FAILED]]:
// HOST: call void @[[TARGET_OUTLINE:.*]]({{.*}})

// HOST: define internal void @[[TARGET_OUTLINE]]
// HOST: call void{{.*}}@__kmpc_fork_teams({{.*}}, ptr @[[TEAMS_OUTLINE:.*]], {{.*}})

// HOST: define internal void @[[TEAMS_OUTLINE]]
// HOST: call void{{.*}}@__kmpc_fork_call({{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], {{.*}})

// HOST: define internal void @[[PARALLEL_OUTLINE]]
// HOST: call void @[[DISTRIBUTE_OUTLINE:.*]]({{.*}})

// HOST: define internal void @[[DISTRIBUTE_OUTLINE]]
// HOST: call void @__kmpc_dist_for_static_init{{.*}}(ptr {{.*}}, i32 {{.*}}, i32 34, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, i32 {{.*}}, i32 {{.*}})

//--- device.mlir

// Device-side translation input: omp.is_target_device and omp.is_gpu are both
// true and llvm.target_triple is amdgcn-amd-amdhsa. The body of @main is the
// same as in the host input of this test.
module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_target_device = true, omp.is_gpu = true} {
llvm.func @main(%x : i32) {
// Loop lower bound, upper bound and step all come from %x via host_eval.
omp.target host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
omp.teams {
omp.parallel {
omp.distribute {
omp.wsloop {
omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
omp.yield
}
} {omp.composite}
} {omp.composite}
omp.terminator
} {omp.composite}
omp.terminator
}
omp.terminator
}
llvm.return
}
}

// DEVICE: @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 2
// DEVICE: @llvm.compiler.used = appending global [1 x ptr] [ptr @[[KERNEL_NAME]]_exec_mode], section "llvm.metadata"
// DEVICE: @[[KERNEL_NAME]]_kernel_environment = weak_odr protected constant %struct.KernelEnvironmentTy {
// DEVICE-SAME: %struct.ConfigurationEnvironmentTy { i8 0, i8 1, i8 [[EXEC_MODE:2]], {{.*}}},
// DEVICE-SAME: ptr @{{.*}}, ptr @{{.*}} }

// DEVICE: define weak_odr protected amdgpu_kernel void @[[KERNEL_NAME]]({{.*}})
// DEVICE: %{{.*}} = call i32 @__kmpc_target_init(ptr @[[KERNEL_NAME]]_kernel_environment, {{.*}})
// DEVICE: call void @[[TARGET_OUTLINE:.*]]({{.*}})
// DEVICE: call void @__kmpc_target_deinit()

// DEVICE: define internal void @[[TARGET_OUTLINE]]({{.*}})
// DEVICE: call void @__kmpc_parallel_51(ptr {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], ptr {{.*}}, ptr {{.*}}, i64 {{.*}})

// DEVICE: define internal void @[[PARALLEL_OUTLINE]]({{.*}})
// DEVICE: call void @[[DISTRIBUTE_OUTLINE:.*]]({{.*}})

// DEVICE: define internal void @[[DISTRIBUTE_OUTLINE]]({{.*}})
// DEVICE: call void @__kmpc_distribute_for_static_loop{{.*}}({{.*}})
24 changes: 0 additions & 24 deletions mlir/test/Target/LLVMIR/openmp-todo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -319,30 +319,6 @@ llvm.func @target_has_device_addr(%x : !llvm.ptr) {

// -----

// Uses host_eval values as the bounds and step of an omp.loop_nest; the
// diagnostics annotated on omp.target below expect translation to be rejected.
llvm.func @target_host_eval(%x : i32) {
// expected-error@below {{not yet implemented: host evaluation of loop bounds in omp.target operation}}
// expected-error@below {{LLVM Translation failed for operation: omp.target}}
omp.target host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
omp.teams {
omp.parallel {
omp.distribute {
omp.wsloop {
omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
omp.yield
}
} {omp.composite}
} {omp.composite}
omp.terminator
} {omp.composite}
omp.terminator
}
omp.terminator
}
llvm.return
}

// -----

omp.declare_reduction @add_f32 : f32
init {
^bb0(%arg: f32):
Expand Down
Loading