@@ -173,15 +173,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
173
173
if (op.getHint())
174
174
op.emitWarning("hint clause discarded");
175
175
};
176
- auto checkHostEval = [](auto op, LogicalResult &result) {
177
- // Host evaluated clauses are supported, except for loop bounds.
178
- for (BlockArgument arg :
179
- cast<omp::BlockArgOpenMPOpInterface>(*op).getHostEvalBlockArgs())
180
- for (Operation *user : arg.getUsers())
181
- if (isa<omp::LoopNestOp>(user))
182
- result = op.emitError("not yet implemented: host evaluation of loop "
183
- "bounds in omp.target operation");
184
- };
185
176
auto checkInReduction = [&todo](auto op, LogicalResult &result) {
186
177
if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
187
178
op.getInReductionSyms())
@@ -318,7 +309,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
318
309
checkBare(op, result);
319
310
checkDevice(op, result);
320
311
checkHasDeviceAddr(op, result);
321
- checkHostEval(op, result);
322
312
checkInReduction(op, result);
323
313
checkIsDevicePtr(op, result);
324
314
checkPrivate(op, result);
@@ -4158,9 +4148,13 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
4158
4148
///
4159
4149
/// Loop bounds and steps are only optionally populated, if output vectors are
4160
4150
/// provided.
4161
- static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
4162
- Value &numTeamsLower, Value &numTeamsUpper,
4163
- Value &threadLimit) {
4151
+ static void
4152
+ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
4153
+ Value &numTeamsLower, Value &numTeamsUpper,
4154
+ Value &threadLimit,
4155
+ llvm::SmallVectorImpl<Value> *lowerBounds = nullptr,
4156
+ llvm::SmallVectorImpl<Value> *upperBounds = nullptr,
4157
+ llvm::SmallVectorImpl<Value> *steps = nullptr) {
4164
4158
auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
4165
4159
for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
4166
4160
blockArgIface.getHostEvalBlockArgs())) {
@@ -4185,11 +4179,26 @@ static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
4185
4179
llvm_unreachable("unsupported host_eval use");
4186
4180
})
4187
4181
.Case([&](omp::LoopNestOp loopOp) {
4188
- // TODO: Extract bounds and step values. Currently, this cannot be
4189
- // reached because translation would have been stopped earlier as a
4190
- // result of `checkImplementationStatus` detecting and reporting
4191
- // this situation.
4192
- llvm_unreachable("unsupported host_eval use");
4182
+ auto processBounds =
4183
+ [&](OperandRange opBounds,
4184
+ llvm::SmallVectorImpl<Value> *outBounds) -> bool {
4185
+ bool found = false;
4186
+ for (auto [i, lb] : llvm::enumerate(opBounds)) {
4187
+ if (lb == blockArg) {
4188
+ found = true;
4189
+ if (outBounds)
4190
+ (*outBounds)[i] = hostEvalVar;
4191
+ }
4192
+ }
4193
+ return found;
4194
+ };
4195
+ bool found =
4196
+ processBounds(loopOp.getLoopLowerBounds(), lowerBounds);
4197
+ found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
4198
+ found;
4199
+ found = processBounds(loopOp.getLoopSteps(), steps) || found;
4200
+ (void)found;
4201
+ assert(found && "unsupported host_eval use");
4193
4202
})
4194
4203
.Default([](Operation *) {
4195
4204
llvm_unreachable("unsupported host_eval use");
@@ -4326,6 +4335,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp,
4326
4335
combinedMaxThreadsVal = maxThreadsVal;
4327
4336
4328
4337
// Update kernel bounds structure for the `OpenMPIRBuilder` to use.
4338
+ attrs.ExecFlags = targetOp.getKernelExecFlags();
4329
4339
attrs.MinTeams = minTeamsVal;
4330
4340
attrs.MaxTeams.front() = maxTeamsVal;
4331
4341
attrs.MinThreads = 1;
@@ -4343,9 +4353,15 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
4343
4353
LLVM::ModuleTranslation &moduleTranslation,
4344
4354
omp::TargetOp targetOp,
4345
4355
llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
4356
+ omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(
4357
+ targetOp.getInnermostCapturedOmpOp());
4358
+ unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
4359
+
4346
4360
Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
4361
+ llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
4362
+ steps(numLoops);
4347
4363
extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
4348
- teamsThreadLimit);
4364
+ teamsThreadLimit, &lowerBounds, &upperBounds, &steps );
4349
4365
4350
4366
// TODO: Handle constant 'if' clauses.
4351
4367
if (Value targetThreadLimit = targetOp.getThreadLimit())
@@ -4365,7 +4381,34 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
4365
4381
if (numThreads)
4366
4382
attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);
4367
4383
4368
- // TODO: Populate attrs.LoopTripCount if it is target SPMD.
4384
+ if (targetOp.getKernelExecFlags() != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC) {
4385
+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4386
+ attrs.LoopTripCount = nullptr;
4387
+
4388
+ // To calculate the trip count, we multiply together the trip counts of
4389
+ // every collapsed canonical loop. We don't need to create the loop nests
4390
+ // here, since we're only interested in the trip count.
4391
+ for (auto [loopLower, loopUpper, loopStep] :
4392
+ llvm::zip_equal(lowerBounds, upperBounds, steps)) {
4393
+ llvm::Value *lowerBound = moduleTranslation.lookupValue(loopLower);
4394
+ llvm::Value *upperBound = moduleTranslation.lookupValue(loopUpper);
4395
+ llvm::Value *step = moduleTranslation.lookupValue(loopStep);
4396
+
4397
+ llvm::OpenMPIRBuilder::LocationDescription loc(builder);
4398
+ llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
4399
+ loc, lowerBound, upperBound, step, /*IsSigned=*/true,
4400
+ loopOp.getLoopInclusive());
4401
+
4402
+ if (!attrs.LoopTripCount) {
4403
+ attrs.LoopTripCount = tripCount;
4404
+ continue;
4405
+ }
4406
+
4407
+ // TODO: Enable UndefinedSanitizer to diagnose an overflow here.
4408
+ attrs.LoopTripCount = builder.CreateMul(attrs.LoopTripCount, tripCount,
4409
+ {}, /*HasNUW=*/true);
4410
+ }
4411
+ }
4369
4412
}
4370
4413
4371
4414
static LogicalResult
0 commit comments