-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[OpenMPIRBuilder] Remove wrapper function in createTask, createTeams
#67723
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
6aabc3c
ac20181
a1a9438
eb506a7
4b71558
7c95d29
81bcd15
1bef65f
1ef1690
e550cd4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -340,6 +340,44 @@ BasicBlock *llvm::splitBBWithSuffix(IRBuilderBase &Builder, bool CreateBranch, | |
return splitBB(Builder, CreateBranch, Old->getName() + Suffix); | ||
} | ||
|
||
// Creates a fake 32-bit integer value together with a fake use of it, and | ||
// returns the fake value. This is useful in modeling the extra arguments to | ||
// the outlined functions. | ||
// | ||
// Parameters: | ||
//   Builder       - IR builder used to emit the instructions; its insertion | ||
//                   point is left at InnerAllocaIP, after the fake use. | ||
//   OuterAllocaIP - insertion point where the i32 alloca (Name + ".addr") | ||
//                   and, when AsPtr is false, the load (Name + ".val") are | ||
//                   created. | ||
//   ToBeDeleted   - stack that receives every instruction created here. Each | ||
//                   value is pushed before its user, so popping the stack | ||
//                   visits users before the values they use. | ||
//   InnerAllocaIP - insertion point where the fake use is created. | ||
//   Name          - prefix for the names of the created instructions. | ||
//   AsPtr         - if true, the returned fake value is the alloca itself and | ||
//                   the fake use is a load from it (Name + ".use"); if false, | ||
//                   the returned fake value is an i32 loaded from the alloca | ||
//                   and the fake use is an add of that value and 10. | ||
// | ||
// Returns the fake value: the alloca when AsPtr is true, the load otherwise. | ||
Value *createFakeIntVal(IRBuilder<> &Builder, | ||
OpenMPIRBuilder::InsertPointTy OuterAllocaIP, | ||
std::stack<Instruction *> &ToBeDeleted, | ||
OpenMPIRBuilder::InsertPointTy InnerAllocaIP, | ||
const Twine &Name = "", bool AsPtr = true) { | ||
Builder.restoreIP(OuterAllocaIP); | ||
Instruction *FakeVal; | ||
AllocaInst *FakeValAddr = | ||
Builder.CreateAlloca(Builder.getInt32Ty(), nullptr, Name + ".addr"); | ||
ToBeDeleted.push(FakeValAddr); | ||
|
||
if (AsPtr) | ||
FakeVal = FakeValAddr; | ||
else { | ||
FakeVal = | ||
Builder.CreateLoad(Builder.getInt32Ty(), FakeValAddr, Name + ".val"); | ||
ToBeDeleted.push(FakeVal); | ||
shraiysh marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
// The fake value and its fake use are needed only for modeling purposes, to | ||
// get the associated argument in the outlined function. Every instruction | ||
// created here is pushed onto ToBeDeleted so that it can be erased later. | ||
|
||
// Fake use of the fake value, emitted at InnerAllocaIP. | ||
Builder.restoreIP(InnerAllocaIP); | ||
Instruction *UseFakeVal; | ||
if (AsPtr) | ||
UseFakeVal = | ||
Builder.CreateLoad(Builder.getInt32Ty(), FakeVal, Name + ".use"); | ||
else | ||
// NOTE(review): cast<> requires CreateAdd to return a BinaryOperator; | ||
// presumably this holds because FakeVal is a load, not a constant - confirm. | ||
UseFakeVal = | ||
cast<BinaryOperator>(Builder.CreateAdd(FakeVal, Builder.getInt32(10))); | ||
ToBeDeleted.push(UseFakeVal); | ||
return FakeVal; | ||
} | ||
|
||
//===----------------------------------------------------------------------===// | ||
// OpenMPIRBuilderConfig | ||
//===----------------------------------------------------------------------===// | ||
|
@@ -1496,6 +1534,7 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc, | |
InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB, | ||
bool Tied, Value *Final, Value *IfCondition, | ||
SmallVector<DependData> Dependencies) { | ||
|
||
if (!updateToLocation(Loc)) | ||
return InsertPointTy(); | ||
|
||
|
@@ -1523,41 +1562,31 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc, | |
BasicBlock *TaskAllocaBB = | ||
splitBB(Builder, /*CreateBranch=*/true, "task.alloca"); | ||
|
||
InsertPointTy TaskAllocaIP = | ||
InsertPointTy(TaskAllocaBB, TaskAllocaBB->begin()); | ||
InsertPointTy TaskBodyIP = InsertPointTy(TaskBodyBB, TaskBodyBB->begin()); | ||
BodyGenCB(TaskAllocaIP, TaskBodyIP); | ||
|
||
OutlineInfo OI; | ||
OI.EntryBB = TaskAllocaBB; | ||
OI.OuterAllocaBB = AllocaIP.getBlock(); | ||
OI.ExitBB = TaskExitBB; | ||
OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition, | ||
Dependencies](Function &OutlinedFn) { | ||
// The input IR here looks like the following- | ||
// ``` | ||
// func @current_fn() { | ||
// outlined_fn(%args) | ||
// } | ||
// func @outlined_fn(%args) { ... } | ||
// ``` | ||
// | ||
// This is changed to the following- | ||
// | ||
// ``` | ||
// func @current_fn() { | ||
// runtime_call(..., wrapper_fn, ...) | ||
// } | ||
// func @wrapper_fn(..., %args) { | ||
// outlined_fn(%args) | ||
// } | ||
// func @outlined_fn(%args) { ... } | ||
// ``` | ||
|
||
// The stale call instruction will be replaced with a new call instruction | ||
// for runtime call with a wrapper function. | ||
// Add the thread ID argument. | ||
std::stack<Instruction *> ToBeDeleted; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why not a SmallVector, like we use everywhere else? We can reasonably guess the size to avoid dynamic allocations. |
||
OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal( | ||
Builder, AllocaIP, ToBeDeleted, TaskAllocaIP, "global.tid", false)); | ||
|
||
OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition, Dependencies, | ||
TaskAllocaBB, ToBeDeleted](Function &OutlinedFn) mutable { | ||
// Replace the Stale CI by appropriate RTL function call. | ||
assert(OutlinedFn.getNumUses() == 1 && | ||
"there must be a single user for the outlined function"); | ||
CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back()); | ||
|
||
// HasShareds is true if any variables are captured in the outlined region, | ||
// false otherwise. | ||
bool HasShareds = StaleCI->arg_size() > 0; | ||
bool HasShareds = StaleCI->arg_size() > 1; | ||
Builder.SetInsertPoint(StaleCI); | ||
|
||
// Gather the arguments for emitting the runtime call for | ||
|
@@ -1595,7 +1624,7 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc, | |
Value *SharedsSize = Builder.getInt64(0); | ||
if (HasShareds) { | ||
AllocaInst *ArgStructAlloca = | ||
dyn_cast<AllocaInst>(StaleCI->getArgOperand(0)); | ||
dyn_cast<AllocaInst>(StaleCI->getArgOperand(1)); | ||
assert(ArgStructAlloca && | ||
"Unable to find the alloca instruction corresponding to arguments " | ||
"for extracted function"); | ||
|
@@ -1606,31 +1635,17 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc, | |
SharedsSize = | ||
Builder.getInt64(M.getDataLayout().getTypeStoreSize(ArgStructType)); | ||
} | ||
|
||
// Argument - task_entry (the wrapper function) | ||
// If the outlined function has some captured variables (i.e. HasShareds is | ||
// true), then the wrapper function will have an additional argument (the | ||
// struct containing captured variables). Otherwise, no such argument will | ||
// be present. | ||
SmallVector<Type *> WrapperArgTys{Builder.getInt32Ty()}; | ||
if (HasShareds) | ||
WrapperArgTys.push_back(OutlinedFn.getArg(0)->getType()); | ||
FunctionCallee WrapperFuncVal = M.getOrInsertFunction( | ||
(Twine(OutlinedFn.getName()) + ".wrapper").str(), | ||
FunctionType::get(Builder.getInt32Ty(), WrapperArgTys, false)); | ||
Function *WrapperFunc = dyn_cast<Function>(WrapperFuncVal.getCallee()); | ||
|
||
// Emit the @__kmpc_omp_task_alloc runtime call | ||
// The runtime call returns a pointer to an area where the task captured | ||
// variables must be copied before the task is run (TaskData) | ||
CallInst *TaskData = Builder.CreateCall( | ||
TaskAllocFn, {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags, | ||
/*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize, | ||
/*task_func=*/WrapperFunc}); | ||
/*task_func=*/&OutlinedFn}); | ||
|
||
// Copy the arguments for outlined function | ||
if (HasShareds) { | ||
Value *Shareds = StaleCI->getArgOperand(0); | ||
Value *Shareds = StaleCI->getArgOperand(1); | ||
Align Alignment = TaskData->getPointerAlignment(M.getDataLayout()); | ||
Value *TaskShareds = Builder.CreateLoad(VoidPtr, TaskData); | ||
Builder.CreateMemCpy(TaskShareds, Alignment, Shareds, Alignment, | ||
|
@@ -1689,18 +1704,17 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc, | |
// br label %exit | ||
// else: | ||
// call @__kmpc_omp_task_begin_if0(...) | ||
// call @wrapper_fn(...) | ||
// call @outlined_fn(...) | ||
// call @__kmpc_omp_task_complete_if0(...) | ||
// br label %exit | ||
// exit: | ||
// ... | ||
if (IfCondition) { | ||
// `SplitBlockAndInsertIfThenElse` requires the block to have a | ||
// terminator. | ||
BasicBlock *NewBasicBlock = | ||
splitBB(Builder, /*CreateBranch=*/true, "if.end"); | ||
splitBB(Builder, /*CreateBranch=*/true, "if.end"); | ||
Instruction *IfTerminator = | ||
NewBasicBlock->getSinglePredecessor()->getTerminator(); | ||
Builder.GetInsertPoint()->getParent()->getTerminator(); | ||
Instruction *ThenTI = IfTerminator, *ElseTI = nullptr; | ||
Builder.SetInsertPoint(IfTerminator); | ||
SplitBlockAndInsertIfThenElse(IfCondition, IfTerminator, &ThenTI, | ||
|
@@ -1711,10 +1725,12 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc, | |
Function *TaskCompleteFn = | ||
getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_complete_if0); | ||
Builder.CreateCall(TaskBeginFn, {Ident, ThreadID, TaskData}); | ||
CallInst *CI = nullptr; | ||
if (HasShareds) | ||
Builder.CreateCall(WrapperFunc, {ThreadID, TaskData}); | ||
CI = Builder.CreateCall(&OutlinedFn, {ThreadID, TaskData}); | ||
else | ||
Builder.CreateCall(WrapperFunc, {ThreadID}); | ||
CI = Builder.CreateCall(&OutlinedFn, {ThreadID}); | ||
CI->setDebugLoc(StaleCI->getDebugLoc()); | ||
Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, TaskData}); | ||
Builder.SetInsertPoint(ThenTI); | ||
} | ||
|
@@ -1736,26 +1752,20 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc, | |
|
||
StaleCI->eraseFromParent(); | ||
|
||
// Emit the body for wrapper function | ||
BasicBlock *WrapperEntryBB = | ||
BasicBlock::Create(M.getContext(), "", WrapperFunc); | ||
Builder.SetInsertPoint(WrapperEntryBB); | ||
Builder.SetInsertPoint(TaskAllocaBB, TaskAllocaBB->begin()); | ||
if (HasShareds) { | ||
llvm::Value *Shareds = | ||
Builder.CreateLoad(VoidPtr, WrapperFunc->getArg(1)); | ||
Builder.CreateCall(&OutlinedFn, {Shareds}); | ||
} else { | ||
Builder.CreateCall(&OutlinedFn); | ||
LoadInst *Shareds = Builder.CreateLoad(VoidPtr, OutlinedFn.getArg(1)); | ||
OutlinedFn.getArg(1)->replaceUsesWithIf( | ||
Shareds, [Shareds](Use &U) { return U.getUser() != Shareds; }); | ||
} | ||
|
||
while (!ToBeDeleted.empty()) { | ||
ToBeDeleted.top()->eraseFromParent(); | ||
ToBeDeleted.pop(); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There is no need to pop anything, we have this code in other places already, why do we need to come up with new and exciting ways to do the same thing?
|
||
} | ||
Builder.CreateRet(Builder.getInt32(0)); | ||
}; | ||
|
||
addOutlineInfo(std::move(OI)); | ||
|
||
InsertPointTy TaskAllocaIP = | ||
InsertPointTy(TaskAllocaBB, TaskAllocaBB->begin()); | ||
InsertPointTy TaskBodyIP = InsertPointTy(TaskBodyBB, TaskBodyBB->begin()); | ||
BodyGenCB(TaskAllocaIP, TaskBodyIP); | ||
Builder.SetInsertPoint(TaskExitBB, TaskExitBB->begin()); | ||
|
||
return Builder.saveIP(); | ||
|
@@ -5748,6 +5758,7 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc, | |
BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "teams.entry"); | ||
Builder.SetInsertPoint(BodyBB, BodyBB->begin()); | ||
} | ||
InsertPointTy OuterAllocaIP(&OuterAllocaBB, OuterAllocaBB.begin()); | ||
shraiysh marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
// The current basic block is split into four basic blocks. After outlining, | ||
// they will be mapped as follows: | ||
|
@@ -5771,84 +5782,62 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc, | |
BasicBlock *AllocaBB = | ||
splitBB(Builder, /*CreateBranch=*/true, "teams.alloca"); | ||
|
||
// Generate the body of teams. | ||
InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin()); | ||
InsertPointTy CodeGenIP(BodyBB, BodyBB->begin()); | ||
BodyGenCB(AllocaIP, CodeGenIP); | ||
|
||
OutlineInfo OI; | ||
OI.EntryBB = AllocaBB; | ||
OI.ExitBB = ExitBB; | ||
OI.OuterAllocaBB = &OuterAllocaBB; | ||
OI.PostOutlineCB = [this, Ident](Function &OutlinedFn) { | ||
// The input IR here looks like the following- | ||
// ``` | ||
// func @current_fn() { | ||
// outlined_fn(%args) | ||
// } | ||
// func @outlined_fn(%args) { ... } | ||
// ``` | ||
// | ||
// This is changed to the following- | ||
// | ||
// ``` | ||
// func @current_fn() { | ||
// runtime_call(..., wrapper_fn, ...) | ||
// } | ||
// func @wrapper_fn(..., %args) { | ||
// outlined_fn(%args) | ||
// } | ||
// func @outlined_fn(%args) { ... } | ||
// ``` | ||
|
||
// Insert fake values for global tid and bound tid. | ||
std::stack<Instruction *> ToBeDeleted; | ||
OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal( | ||
Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "gid", true)); | ||
OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal( | ||
Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "tid", true)); | ||
|
||
OI.PostOutlineCB = [this, Ident, ToBeDeleted](Function &OutlinedFn) mutable { | ||
// The stale call instruction will be replaced with a new call instruction | ||
// for runtime call with a wrapper function. | ||
// for runtime call with the outlined function. | ||
|
||
assert(OutlinedFn.getNumUses() == 1 && | ||
"there must be a single user for the outlined function"); | ||
CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back()); | ||
ToBeDeleted.push(StaleCI); | ||
|
||
assert((OutlinedFn.arg_size() == 2 || OutlinedFn.arg_size() == 3) && | ||
"Outlined function must have two or three arguments only"); | ||
|
||
bool HasShared = OutlinedFn.arg_size() == 3; | ||
|
||
// Create the wrapper function. | ||
SmallVector<Type *> WrapperArgTys{Builder.getPtrTy(), Builder.getPtrTy()}; | ||
for (auto &Arg : OutlinedFn.args()) | ||
WrapperArgTys.push_back(Arg.getType()); | ||
FunctionCallee WrapperFuncVal = M.getOrInsertFunction( | ||
(Twine(OutlinedFn.getName()) + ".teams").str(), | ||
FunctionType::get(Builder.getVoidTy(), WrapperArgTys, false)); | ||
Function *WrapperFunc = dyn_cast<Function>(WrapperFuncVal.getCallee()); | ||
WrapperFunc->getArg(0)->setName("global_tid"); | ||
WrapperFunc->getArg(1)->setName("bound_tid"); | ||
if (WrapperFunc->arg_size() > 2) | ||
WrapperFunc->getArg(2)->setName("data"); | ||
|
||
// Emit the body of the wrapper function - just a call to outlined function | ||
// and return statement. | ||
BasicBlock *WrapperEntryBB = | ||
BasicBlock::Create(M.getContext(), "entrybb", WrapperFunc); | ||
Builder.SetInsertPoint(WrapperEntryBB); | ||
SmallVector<Value *> Args; | ||
for (size_t ArgIndex = 2; ArgIndex < WrapperFunc->arg_size(); ArgIndex++) | ||
Args.push_back(WrapperFunc->getArg(ArgIndex)); | ||
Builder.CreateCall(&OutlinedFn, Args); | ||
Builder.CreateRetVoid(); | ||
|
||
OutlinedFn.addFnAttr(Attribute::AttrKind::AlwaysInline); | ||
OutlinedFn.getArg(0)->setName("global.tid.ptr"); | ||
OutlinedFn.getArg(1)->setName("bound.tid.ptr"); | ||
if (HasShared) | ||
OutlinedFn.getArg(2)->setName("data"); | ||
|
||
// Call to the runtime function for teams in the current function. | ||
assert(StaleCI && "Error while outlining - no CallInst user found for the " | ||
"outlined function."); | ||
Builder.SetInsertPoint(StaleCI); | ||
Args = {Ident, Builder.getInt32(StaleCI->arg_size()), WrapperFunc}; | ||
for (Use &Arg : StaleCI->args()) | ||
Args.push_back(Arg); | ||
SmallVector<Value *> Args = { | ||
Ident, Builder.getInt32(StaleCI->arg_size() - 2), &OutlinedFn}; | ||
if (HasShared) | ||
Args.push_back(StaleCI->getArgOperand(2)); | ||
Builder.CreateCall(getOrCreateRuntimeFunctionPtr( | ||
omp::RuntimeFunction::OMPRTL___kmpc_fork_teams), | ||
Args); | ||
StaleCI->eraseFromParent(); | ||
|
||
while (!ToBeDeleted.empty()) { | ||
ToBeDeleted.top()->eraseFromParent(); | ||
ToBeDeleted.pop(); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. same |
||
} | ||
}; | ||
|
||
addOutlineInfo(std::move(OI)); | ||
|
||
// Generate the body of teams. | ||
InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin()); | ||
InsertPointTy CodeGenIP(BodyBB, BodyBB->begin()); | ||
BodyGenCB(AllocaIP, CodeGenIP); | ||
|
||
Builder.SetInsertPoint(ExitBB, ExitBB->begin()); | ||
|
||
return Builder.saveIP(); | ||
|
Uh oh!
There was an error while loading. Please reload this page.