Skip to content

Commit 75fde1a

Browse files
committed
[llvm][mlir][OMPIRBuilder] Translate omp.single's copyprivate
Use the new copyprivate list from omp.single to emit calls to __kmpc_copyprivate, during the creation of the single operation in OMPIRBuilder. This is patch 4 of 4, to add support for COPYPRIVATE in Flang. Original PR: llvm#73128
1 parent c1e9883 commit 75fde1a

File tree

5 files changed

+187
-5
lines changed

5 files changed

+187
-5
lines changed

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1830,12 +1830,16 @@ class OpenMPIRBuilder {
18301830
/// \param FiniCB Callback to finalize variable copies.
18311831
/// \param IsNowait If false, a barrier is emitted.
18321832
/// \param DidIt Local variable used as a flag to indicate 'single' thread
1833+
/// \param CPVars copyprivate variables.
1834+
/// \param CPFuncs copy functions to use for each copyprivate variable.
18331835
///
18341836
/// \returns The insertion position *after* the single call.
18351837
InsertPointTy createSingle(const LocationDescription &Loc,
18361838
BodyGenCallbackTy BodyGenCB,
18371839
FinalizeCallbackTy FiniCB, bool IsNowait,
1838-
llvm::Value *DidIt);
1840+
llvm::Value *DidIt,
1841+
ArrayRef<llvm::Value *> CPVars = {},
1842+
ArrayRef<llvm::Function *> CPFuncs = {});
18391843

18401844
/// Generator for '#omp master'
18411845
///

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4002,7 +4002,8 @@ OpenMPIRBuilder::createCopyPrivate(const LocationDescription &Loc,
40024002

40034003
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createSingle(
40044004
const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
4005-
FinalizeCallbackTy FiniCB, bool IsNowait, llvm::Value *DidIt) {
4005+
FinalizeCallbackTy FiniCB, bool IsNowait, llvm::Value *DidIt,
4006+
ArrayRef<llvm::Value *> CPVars, ArrayRef<llvm::Function *> CPFuncs) {
40064007

40074008
if (!updateToLocation(Loc))
40084009
return Loc.IP;
@@ -4025,17 +4026,33 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createSingle(
40254026
Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_single);
40264027
Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, Args);
40274028

4029+
auto FiniCBWrapper = [&](InsertPointTy IP) {
4030+
FiniCB(IP);
4031+
4032+
if (DidIt)
4033+
Builder.CreateStore(Builder.getInt32(1), DidIt);
4034+
};
4035+
40284036
// generates the following:
40294037
// if (__kmpc_single()) {
40304038
// .... single region ...
40314039
// __kmpc_end_single
40324040
// }
4041+
// __kmpc_copyprivate
40334042
// __kmpc_barrier
40344043

4035-
EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
4044+
EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCBWrapper,
40364045
/*Conditional*/ true,
40374046
/*hasFinalize*/ true);
4038-
if (!IsNowait)
4047+
4048+
if (DidIt) {
4049+
for (size_t I = 0, E = CPVars.size(); I < E; ++I)
4050+
// NOTE BufSize is currently unused, so just pass 0.
4051+
createCopyPrivate(LocationDescription(Builder.saveIP(), Loc.DL),
4052+
/*BufSize=*/ConstantInt::get(Int64, 0), CPVars[I],
4053+
CPFuncs[I], DidIt);
4054+
// NOTE __kmpc_copyprivate already inserts a barrier
4055+
} else if (!IsNowait)
40394056
createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
40404057
omp::Directive::OMPD_unknown, /* ForceSimpleCall */ false,
40414058
/* CheckCancelFlag */ false);

llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3464,6 +3464,117 @@ TEST_F(OpenMPIRBuilderTest, SingleDirectiveNowait) {
34643464
EXPECT_EQ(ExitBarrier, nullptr);
34653465
}
34663466

3467+
TEST_F(OpenMPIRBuilderTest, SingleDirectiveCopyPrivate) {
3468+
using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
3469+
OpenMPIRBuilder OMPBuilder(*M);
3470+
OMPBuilder.initialize();
3471+
F->setName("func");
3472+
IRBuilder<> Builder(BB);
3473+
3474+
OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
3475+
3476+
AllocaInst *PrivAI = nullptr;
3477+
3478+
BasicBlock *EntryBB = nullptr;
3479+
BasicBlock *ThenBB = nullptr;
3480+
3481+
Value *CPVar = Builder.CreateAlloca(F->arg_begin()->getType());
3482+
Builder.CreateStore(F->arg_begin(), CPVar);
3483+
3484+
FunctionType *CopyFuncTy = FunctionType::get(
3485+
Builder.getVoidTy(), {Builder.getPtrTy(), Builder.getPtrTy()}, false);
3486+
Function *CopyFunc =
3487+
Function::Create(CopyFuncTy, Function::PrivateLinkage, "copy_var", *M);
3488+
3489+
Value *DidIt = Builder.CreateAlloca(Type::getInt32Ty(Builder.getContext()));
3490+
3491+
auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
3492+
if (AllocaIP.isSet())
3493+
Builder.restoreIP(AllocaIP);
3494+
else
3495+
Builder.SetInsertPoint(&*(F->getEntryBlock().getFirstInsertionPt()));
3496+
PrivAI = Builder.CreateAlloca(F->arg_begin()->getType());
3497+
Builder.CreateStore(F->arg_begin(), PrivAI);
3498+
3499+
llvm::BasicBlock *CodeGenIPBB = CodeGenIP.getBlock();
3500+
llvm::Instruction *CodeGenIPInst = &*CodeGenIP.getPoint();
3501+
EXPECT_EQ(CodeGenIPBB->getTerminator(), CodeGenIPInst);
3502+
3503+
Builder.restoreIP(CodeGenIP);
3504+
3505+
// collect some info for checks later
3506+
ThenBB = Builder.GetInsertBlock();
3507+
EntryBB = ThenBB->getUniquePredecessor();
3508+
3509+
// simple instructions for body
3510+
Value *PrivLoad =
3511+
Builder.CreateLoad(PrivAI->getAllocatedType(), PrivAI, "local.use");
3512+
Builder.CreateICmpNE(F->arg_begin(), PrivLoad);
3513+
};
3514+
3515+
auto FiniCB = [&](InsertPointTy IP) {
3516+
BasicBlock *IPBB = IP.getBlock();
3517+
EXPECT_NE(IPBB->end(), IP.getPoint());
3518+
};
3519+
3520+
Builder.restoreIP(OMPBuilder.createSingle(Builder, BodyGenCB, FiniCB,
3521+
/*IsNowait*/ false, DidIt, {CPVar},
3522+
{CopyFunc}));
3523+
Value *EntryBBTI = EntryBB->getTerminator();
3524+
EXPECT_NE(EntryBBTI, nullptr);
3525+
EXPECT_TRUE(isa<BranchInst>(EntryBBTI));
3526+
BranchInst *EntryBr = cast<BranchInst>(EntryBB->getTerminator());
3527+
EXPECT_TRUE(EntryBr->isConditional());
3528+
EXPECT_EQ(EntryBr->getSuccessor(0), ThenBB);
3529+
BasicBlock *ExitBB = ThenBB->getUniqueSuccessor();
3530+
EXPECT_EQ(EntryBr->getSuccessor(1), ExitBB);
3531+
3532+
CmpInst *CondInst = cast<CmpInst>(EntryBr->getCondition());
3533+
EXPECT_TRUE(isa<CallInst>(CondInst->getOperand(0)));
3534+
3535+
CallInst *SingleEntryCI = cast<CallInst>(CondInst->getOperand(0));
3536+
EXPECT_EQ(SingleEntryCI->arg_size(), 2U);
3537+
EXPECT_EQ(SingleEntryCI->getCalledFunction()->getName(), "__kmpc_single");
3538+
EXPECT_TRUE(isa<GlobalVariable>(SingleEntryCI->getArgOperand(0)));
3539+
3540+
CallInst *SingleEndCI = nullptr;
3541+
for (auto &FI : *ThenBB) {
3542+
Instruction *Cur = &FI;
3543+
if (isa<CallInst>(Cur)) {
3544+
SingleEndCI = cast<CallInst>(Cur);
3545+
if (SingleEndCI->getCalledFunction()->getName() == "__kmpc_end_single")
3546+
break;
3547+
SingleEndCI = nullptr;
3548+
}
3549+
}
3550+
EXPECT_NE(SingleEndCI, nullptr);
3551+
EXPECT_EQ(SingleEndCI->arg_size(), 2U);
3552+
EXPECT_TRUE(isa<GlobalVariable>(SingleEndCI->getArgOperand(0)));
3553+
EXPECT_EQ(SingleEndCI->getArgOperand(1), SingleEntryCI->getArgOperand(1));
3554+
3555+
CallInst *CopyPrivateCI = nullptr;
3556+
bool FoundBarrier = false;
3557+
for (auto &FI : *ExitBB) {
3558+
Instruction *Cur = &FI;
3559+
if (auto *CI = dyn_cast<CallInst>(Cur)) {
3560+
if (CI->getCalledFunction()->getName() == "__kmpc_barrier")
3561+
FoundBarrier = true;
3562+
else if (CI->getCalledFunction()->getName() == "__kmpc_copyprivate")
3563+
CopyPrivateCI = CI;
3564+
}
3565+
}
3566+
EXPECT_FALSE(FoundBarrier);
3567+
EXPECT_NE(CopyPrivateCI, nullptr);
3568+
EXPECT_EQ(CopyPrivateCI->arg_size(), 6U);
3569+
EXPECT_TRUE(isa<AllocaInst>(CopyPrivateCI->getArgOperand(3)));
3570+
EXPECT_EQ(CopyPrivateCI->getArgOperand(3), CPVar);
3571+
EXPECT_TRUE(isa<Function>(CopyPrivateCI->getArgOperand(4)));
3572+
EXPECT_EQ(CopyPrivateCI->getArgOperand(4), CopyFunc);
3573+
EXPECT_TRUE(isa<LoadInst>(CopyPrivateCI->getArgOperand(5)));
3574+
LoadInst *DidItLI = cast<LoadInst>(CopyPrivateCI->getArgOperand(5));
3575+
EXPECT_EQ(DidItLI->getOperand(0), DidIt);
3576+
}
3577+
34673578
TEST_F(OpenMPIRBuilderTest, OMPAtomicReadFlt) {
34683579
OpenMPIRBuilder OMPBuilder(*M);
34693580
OMPBuilder.initialize();

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -656,8 +656,26 @@ convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder,
656656
moduleTranslation, bodyGenStatus);
657657
};
658658
auto finiCB = [&](InsertPointTy codeGenIP) {};
659+
660+
// Handle copyprivate
661+
Operation::operand_range cpVars = singleOp.getCopyprivateVars();
662+
std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateFuncs();
663+
llvm::SmallVector<llvm::Value *> llvmCPVars;
664+
llvm::SmallVector<llvm::Function *> llvmCPFuncs;
665+
for (size_t i = 0, e = cpVars.size(); i < e; ++i) {
666+
llvmCPVars.push_back(moduleTranslation.lookupValue(cpVars[i]));
667+
auto llvmFuncOp = SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(
668+
singleOp, cast<SymbolRefAttr>((*cpFuncs)[i]));
669+
llvmCPFuncs.push_back(
670+
moduleTranslation.lookupFunction(llvmFuncOp.getName()));
671+
}
672+
llvm::Value *didIt = nullptr;
673+
if (!llvmCPVars.empty())
674+
didIt = builder.CreateAlloca(llvm::Type::getInt32Ty(builder.getContext()));
675+
659676
builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createSingle(
660-
ompLoc, bodyCB, finiCB, singleOp.getNowait(), /*DidIt=*/nullptr));
677+
ompLoc, bodyCB, finiCB, singleOp.getNowait(), didIt, llvmCPVars,
678+
llvmCPFuncs));
661679
return bodyGenStatus;
662680
}
663681

mlir/test/Target/LLVMIR/openmp-llvm.mlir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2186,6 +2186,38 @@ llvm.func @single_nowait(%x: i32, %y: i32, %zaddr: !llvm.ptr) {
21862186

21872187
// -----
21882188

2189+
llvm.func @copy_i32(!llvm.ptr, !llvm.ptr)
2190+
llvm.func @copy_f32(!llvm.ptr, !llvm.ptr)
2191+
2192+
// CHECK-LABEL: @single_copyprivate
2193+
// CHECK-SAME: (ptr %[[ip:.*]], ptr %[[fp:.*]])
2194+
llvm.func @single_copyprivate(%ip: !llvm.ptr, %fp: !llvm.ptr) {
2195+
// CHECK: call i32 @__kmpc_single
2196+
omp.single copyprivate(%ip -> @copy_i32 : !llvm.ptr, %fp -> @copy_f32 : !llvm.ptr) {
2197+
// CHECK: %[[i:.*]] = load i32, ptr %[[ip]]
2198+
%i = llvm.load %ip : !llvm.ptr -> i32
2199+
// CHECK: %[[i2:.*]] = add i32 %[[i]], %[[i]]
2200+
%i2 = llvm.add %i, %i : i32
2201+
// CHECK: store i32 %[[i2]], ptr %[[ip]]
2202+
llvm.store %i2, %ip : i32, !llvm.ptr
2203+
// CHECK: %[[f:.*]] = load float, ptr %[[fp]]
2204+
%f = llvm.load %fp : !llvm.ptr -> f32
2205+
// CHECK: %[[f2:.*]] = fadd float %[[f]], %[[f]]
2206+
%f2 = llvm.fadd %f, %f : f32
2207+
// CHECK: store float %[[f2]], ptr %[[fp]]
2208+
llvm.store %f2, %fp : f32, !llvm.ptr
2209+
// CHECK: call void @__kmpc_end_single
2210+
// CHECK: call void @__kmpc_copyprivate({{.*}}, ptr %[[ip]], ptr @copy_i32, {{.*}})
2211+
// CHECK: call void @__kmpc_copyprivate({{.*}}, ptr %[[fp]], ptr @copy_f32, {{.*}})
2212+
// CHECK-NOT: call void @__kmpc_barrier
2213+
omp.terminator
2214+
}
2215+
// CHECK: ret void
2216+
llvm.return
2217+
}
2218+
2219+
// -----
2220+
21892221
// CHECK: @_QFsubEx = internal global i32 undef
21902222
// CHECK: @_QFsubEx.cache = common global ptr null
21912223

0 commit comments

Comments
 (0)