Skip to content

Commit 7ec3209

Browse files
authored
[MLIR][OpenMP] Named recipe op's block args accessors (NFC) (#112192)
This patch adds extra class declarations to the `omp.declare_reduction` and `omp.private` operations to access the entry block arguments defined by their regions. Some existing accesses to these arguments are updated to use the new named methods to improve code readability.
1 parent 790d986 commit 7ec3209

File tree

2 files changed

+63
-28
lines changed

2 files changed

+63
-28
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 49 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,24 @@ def PrivateClauseOp : OpenMP_Op<"private", [IsolatedFromAbove, RecipeInterface]>
119119
CArg<"TypeAttr">:$type)>
120120
];
121121

122+
let extraClassDeclaration = [{
123+
BlockArgument getAllocMoldArg() {
124+
return getAllocRegion().getArgument(0);
125+
}
126+
BlockArgument getCopyMoldArg() {
127+
auto &region = getCopyRegion();
128+
return region.empty() ? nullptr : region.getArgument(0);
129+
}
130+
BlockArgument getCopyPrivateArg() {
131+
auto &region = getCopyRegion();
132+
return region.empty() ? nullptr : region.getArgument(1);
133+
}
134+
BlockArgument getDeallocMoldArg() {
135+
auto &region = getDeallocRegion();
136+
return region.empty() ? nullptr : region.getArgument(0);
137+
}
138+
}];
139+
122140
let hasVerifier = 1;
123141
}
124142

@@ -1601,22 +1619,41 @@ def DeclareReductionOp : OpenMP_Op<"declare_reduction", [IsolatedFromAbove,
16011619
"( `cleanup` $cleanupRegion^ )? ";
16021620

16031621
let extraClassDeclaration = [{
1622+
BlockArgument getAllocMoldArg() {
1623+
auto &region = getAllocRegion();
1624+
return region.empty() ? nullptr : region.getArgument(0);
1625+
}
1626+
BlockArgument getInitializerMoldArg() {
1627+
return getInitializerRegion().getArgument(0);
1628+
}
1629+
BlockArgument getInitializerAllocArg() {
1630+
return getAllocRegion().empty() ?
1631+
nullptr : getInitializerRegion().getArgument(1);
1632+
}
1633+
BlockArgument getReductionLhsArg() {
1634+
return getReductionRegion().getArgument(0);
1635+
}
1636+
BlockArgument getReductionRhsArg() {
1637+
return getReductionRegion().getArgument(1);
1638+
}
1639+
BlockArgument getAtomicReductionLhsArg() {
1640+
auto &region = getAtomicReductionRegion();
1641+
return region.empty() ? nullptr : region.getArgument(0);
1642+
}
1643+
BlockArgument getAtomicReductionRhsArg() {
1644+
auto &region = getAtomicReductionRegion();
1645+
return region.empty() ? nullptr : region.getArgument(1);
1646+
}
1647+
BlockArgument getCleanupAllocArg() {
1648+
auto &region = getCleanupRegion();
1649+
return region.empty() ? nullptr : region.getArgument(0);
1650+
}
1651+
16041652
PointerLikeType getAccumulatorType() {
16051653
if (getAtomicReductionRegion().empty())
16061654
return {};
16071655

1608-
return cast<PointerLikeType>(getAtomicReductionRegion().front().getArgument(0).getType());
1609-
}
1610-
1611-
Value getInitializerMoldArg() {
1612-
return getInitializerRegion().front().getArgument(0);
1613-
}
1614-
1615-
Value getInitializerAllocArg() {
1616-
if (getAllocRegion().empty() ||
1617-
getInitializerRegion().front().getNumArguments() != 2)
1618-
return {nullptr};
1619-
return getInitializerRegion().front().getArgument(1);
1656+
return cast<PointerLikeType>(getAtomicReductionLhsArg().getType());
16201657
}
16211658
}];
16221659
let hasRegionVerifier = 1;

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -480,12 +480,11 @@ makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder,
480480
[&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
481481
llvm::Value *lhs, llvm::Value *rhs,
482482
llvm::Value *&result) mutable {
483-
Region &reductionRegion = decl.getReductionRegion();
484-
moduleTranslation.mapValue(reductionRegion.front().getArgument(0), lhs);
485-
moduleTranslation.mapValue(reductionRegion.front().getArgument(1), rhs);
483+
moduleTranslation.mapValue(decl.getReductionLhsArg(), lhs);
484+
moduleTranslation.mapValue(decl.getReductionRhsArg(), rhs);
486485
builder.restoreIP(insertPoint);
487486
SmallVector<llvm::Value *> phis;
488-
if (failed(inlineConvertOmpRegions(reductionRegion,
487+
if (failed(inlineConvertOmpRegions(decl.getReductionRegion(),
489488
"omp.reduction.nonatomic.body",
490489
builder, moduleTranslation, &phis)))
491490
return llvm::OpenMPIRBuilder::InsertPointTy();
@@ -513,12 +512,11 @@ makeAtomicReductionGen(omp::DeclareReductionOp decl,
513512
OwningAtomicReductionGen atomicGen =
514513
[&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
515514
llvm::Value *lhs, llvm::Value *rhs) mutable {
516-
Region &atomicRegion = decl.getAtomicReductionRegion();
517-
moduleTranslation.mapValue(atomicRegion.front().getArgument(0), lhs);
518-
moduleTranslation.mapValue(atomicRegion.front().getArgument(1), rhs);
515+
moduleTranslation.mapValue(decl.getAtomicReductionLhsArg(), lhs);
516+
moduleTranslation.mapValue(decl.getAtomicReductionRhsArg(), rhs);
519517
builder.restoreIP(insertPoint);
520518
SmallVector<llvm::Value *> phis;
521-
if (failed(inlineConvertOmpRegions(atomicRegion,
519+
if (failed(inlineConvertOmpRegions(decl.getAtomicReductionRegion(),
522520
"omp.reduction.atomic.body", builder,
523521
moduleTranslation, &phis)))
524522
return llvm::OpenMPIRBuilder::InsertPointTy();
@@ -1674,9 +1672,10 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
16741672
// argument of the `alloc` region and the second argument of the `copy`
16751673
// region to be the yielded value of the `alloc` region (this is the
16761674
// private clone of the privatized value).
1677-
copyCloneBuilder.mergeBlocks(
1678-
&*newCopyRegionFrontBlock, &*oldAllocBackBlock,
1679-
{allocRegion.getArgument(0), oldAllocYieldOp.getOperand(0)});
1675+
copyCloneBuilder.mergeBlocks(&*newCopyRegionFrontBlock,
1676+
&*oldAllocBackBlock,
1677+
{mlirPrivatizerClone.getAllocMoldArg(),
1678+
oldAllocYieldOp.getOperand(0)});
16801679

16811680
// 4. The old terminator of the `alloc` region is not needed anymore, so
16821681
// delete it.
@@ -1686,8 +1685,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
16861685
// Replace the privatizer block argument with mlir value being privatized.
16871686
// This way, the body of the privatizer will be changed from using the
16881687
// region/block argument to the value being privatized.
1689-
auto allocRegionArg = allocRegion.getArgument(0);
1690-
replaceAllUsesInRegionWith(allocRegionArg, mlirPrivVar, allocRegion);
1688+
replaceAllUsesInRegionWith(mlirPrivatizerClone.getAllocMoldArg(),
1689+
mlirPrivVar, allocRegion);
16911690

16921691
auto oldIP = builder.saveIP();
16931692
builder.restoreIP(allocaIP);
@@ -3480,10 +3479,9 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
34803479
" private allocatables is not supported yet");
34813480
bodyGenStatus = failure();
34823481
} else {
3483-
Region &allocRegion = privatizer.getAllocRegion();
3484-
BlockArgument allocRegionArg = allocRegion.getArgument(0);
3485-
moduleTranslation.mapValue(allocRegionArg,
3482+
moduleTranslation.mapValue(privatizer.getAllocMoldArg(),
34863483
moduleTranslation.lookupValue(privVar));
3484+
Region &allocRegion = privatizer.getAllocRegion();
34873485
SmallVector<llvm::Value *, 1> yieldedValues;
34883486
if (failed(inlineConvertOmpRegions(
34893487
allocRegion, "omp.targetop.privatizer", builder,

0 commit comments

Comments
 (0)