Skip to content

Commit 9b52d9e

Browse files
authored
[MLIR][OpenMP] Prevent loop wrapper translation crashes (#115475)
This patch updates the `convertOmpOpRegions` translation function to prevent calling it for a loop wrapper region from causing a compiler crash due to a lack of terminator operations. This problem is currently not triggered because there are no cases for which the region of a loop wrapper is passed to that function. This will have to change in order to support composite construct translation to LLVM IR.
1 parent db40592 commit 9b52d9e

File tree

1 file changed

+33
-20
lines changed

1 file changed

+33
-20
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,8 @@ static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(
381381
Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
382382
LLVM::ModuleTranslation &moduleTranslation,
383383
SmallVectorImpl<llvm::PHINode *> *continuationBlockPHIs = nullptr) {
384+
bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.getParentOp());
385+
384386
llvm::BasicBlock *continuationBlock =
385387
splitBB(builder, true, "omp.region.cont");
386388
llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
@@ -397,30 +399,34 @@ static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(
397399

398400
// Terminators (namely YieldOp) may be forwarding values to the region that
399401
// need to be available in the continuation block. Collect the types of these
400-
// operands in preparation of creating PHI nodes.
402+
// operands in preparation of creating PHI nodes. This is skipped for loop
403+
// wrapper operations, for which we know in advance they have no terminators.
401404
SmallVector<llvm::Type *> continuationBlockPHITypes;
402-
bool operandsProcessed = false;
403405
unsigned numYields = 0;
404-
for (Block &bb : region.getBlocks()) {
405-
if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
406-
if (!operandsProcessed) {
407-
for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
408-
continuationBlockPHITypes.push_back(
409-
moduleTranslation.convertType(yield->getOperand(i).getType()));
410-
}
411-
operandsProcessed = true;
412-
} else {
413-
assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
414-
"mismatching number of values yielded from the region");
415-
for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
416-
llvm::Type *operandType =
417-
moduleTranslation.convertType(yield->getOperand(i).getType());
418-
(void)operandType;
419-
assert(continuationBlockPHITypes[i] == operandType &&
420-
"values of mismatching types yielded from the region");
406+
407+
if (!isLoopWrapper) {
408+
bool operandsProcessed = false;
409+
for (Block &bb : region.getBlocks()) {
410+
if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
411+
if (!operandsProcessed) {
412+
for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
413+
continuationBlockPHITypes.push_back(
414+
moduleTranslation.convertType(yield->getOperand(i).getType()));
415+
}
416+
operandsProcessed = true;
417+
} else {
418+
assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
419+
"mismatching number of values yielded from the region");
420+
for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
421+
llvm::Type *operandType =
422+
moduleTranslation.convertType(yield->getOperand(i).getType());
423+
(void)operandType;
424+
assert(continuationBlockPHITypes[i] == operandType &&
425+
"values of mismatching types yielded from the region");
426+
}
421427
}
428+
numYields++;
422429
}
423-
numYields++;
424430
}
425431
}
426432

@@ -458,6 +464,13 @@ static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(
458464
moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder)))
459465
return llvm::make_error<PreviouslyReportedError>();
460466

467+
// Create a direct branch here for loop wrappers to prevent their lack of a
468+
// terminator from causing a crash below.
469+
if (isLoopWrapper) {
470+
builder.CreateBr(continuationBlock);
471+
continue;
472+
}
473+
461474
// Special handling for `omp.yield` and `omp.terminator` (we may have more
462475
// than one): they return the control to the parent OpenMP dialect operation
463476
// so replace them with the branch to the continuation block. We handle this

0 commit comments

Comments
 (0)