Skip to content

Commit efd1216

Browse files
committed
More refactoring
- Added task counter suffix to each data structure avoiding duplicates - Simplifiy code for the outline function terminator, since tasks are single entry single exit - Tweak CodeExtractor to avoid doing anything with lifetimes TODO Closes llvm#3
1 parent 78ed96d commit efd1216

File tree

4 files changed

+143
-64
lines changed

4 files changed

+143
-64
lines changed

llvm/include/llvm/Transforms/Utils/CodeExtractor.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,10 @@ class Value;
7676
BasicBlock *newHeader,
7777
Function *oldFunction,
7878
Module *M,
79-
const SetVector<BasicBlock *> &Blocks)> constructOmpSsFunctions;
79+
const SetVector<BasicBlock *> &Blocks)> rewriteOutToInTaskBrAndGetOmpSsUnpackFunc;
8080
std::function<CallInst*(Function *newFunction,
8181
BasicBlock *codeReplacer,
82-
const SetVector<BasicBlock *> &Blocks)> emitCaptureAndCall;
82+
const SetVector<BasicBlock *> &Blocks)> emitOmpSsCaptureAndSubmitTask;
8383

8484
public:
8585
/// Create a code extractor for a sequence of blocks.
@@ -113,10 +113,10 @@ class Value;
113113
BasicBlock *newHeader,
114114
Function *oldFunction,
115115
Module *M,
116-
const SetVector<BasicBlock *> &Blocks)> constructOmpSsFunctions,
116+
const SetVector<BasicBlock *> &Blocks)> rewriteOutToInTaskBrAndGetOmpSsUnpackFunc,
117117
std::function<CallInst*(Function *newFunction,
118118
BasicBlock *codeReplacer,
119-
const SetVector<BasicBlock *> &Blocks)> emitCaptureAndCall);
119+
const SetVector<BasicBlock *> &Blocks)> emitOmpSsCaptureAndSubmitTask);
120120

121121
/// Perform the extraction, returning the new function.
122122
///

llvm/lib/Transforms/OmpSs/OmpSsTransform.cpp

Lines changed: 49 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ struct OmpSs : public ModulePass {
4545
StructType *Ty = nullptr;
4646
};
4747
TaskConstraintsTy TskConstraintsTy;
48+
4849
struct TaskInvInfoTy {
4950
struct Members {
5051
Type *InvSourceTy;
@@ -53,6 +54,7 @@ struct OmpSs : public ModulePass {
5354
Members Mmbers;
5455
};
5556
TaskInvInfoTy TskInvInfoTy;
57+
5658
struct TaskImplInfoTy {
5759
struct Members {
5860
Type *DeviceTypeIdTy;
@@ -66,6 +68,7 @@ struct OmpSs : public ModulePass {
6668
Members Mmbers;
6769
};
6870
TaskImplInfoTy TskImplInfoTy;
71+
6972
struct TaskInfoTy {
7073
struct Members {
7174
Type *NumSymbolsTy;
@@ -92,7 +95,7 @@ struct OmpSs : public ModulePass {
9295
"nanos6_taskwait", IRB.getVoidTy(), IRB.getInt8PtrTy()));
9396
// 2. Build String
9497
// TODO: add debug info (line:col)
95-
Constant *Nanos6TaskwaitStr = IRB.CreateGlobalStringPtr(M.getModuleIdentifier());
98+
Constant *Nanos6TaskwaitStr = IRB.CreateGlobalStringPtr(M.getSourceFileName());
9699

97100
// 3. Insert the call
98101
IRB.CreateCall(Func, {Nanos6TaskwaitStr});
@@ -102,6 +105,7 @@ struct OmpSs : public ModulePass {
102105

103106
void lowerTask(const TaskInfo &TI,
104107
Function &F,
108+
size_t taskNum,
105109
Module &M) {
106110
// 1. Split BB
107111
BasicBlock *EntryBB = TI.Entry->getParent();
@@ -135,7 +139,7 @@ struct OmpSs : public ModulePass {
135139
for (Value *V : TI.DSAInfo.Firstprivate) {
136140
TaskArgsMemberTy.push_back(V->getType()->getPointerElementType());
137141
}
138-
StructType *TaskArgsTy = StructType::create(M.getContext(), TaskArgsMemberTy, ("nanos6_task_args_" + F.getName()).str());
142+
StructType *TaskArgsTy = StructType::create(M.getContext(), TaskArgsMemberTy, ("nanos6_task_args_" + F.getName() + Twine(taskNum)).str());
139143
// Create nanos6_task_args_* END
140144

141145
// nanos6_unpacked_task_region_* START
@@ -155,7 +159,7 @@ struct OmpSs : public ModulePass {
155159

156160
Function *unpackFuncVar = Function::Create(
157161
unpackFuncType, GlobalValue::InternalLinkage, F.getAddressSpace(),
158-
"nanos6_unpacked_task_region_" + F.getName(), &M);
162+
"nanos6_unpacked_task_region_" + F.getName() + Twine(taskNum), &M);
159163

160164
// Create an iterator to name all of the arguments we inserted.
161165
Function::arg_iterator AI = unpackFuncVar->arg_begin();
@@ -187,7 +191,7 @@ struct OmpSs : public ModulePass {
187191

188192
Function *outlineFuncVar = Function::Create(
189193
outlineFuncType, GlobalValue::InternalLinkage, F.getAddressSpace(),
190-
"nanos6_ol_task_region_" + F.getName(), &M);
194+
"nanos6_ol_task_region_" + F.getName() + Twine(taskNum), &M);
191195

192196
BasicBlock *outlineEntryBB = BasicBlock::Create(M.getContext(), "entry", outlineFuncVar);
193197

@@ -225,27 +229,28 @@ struct OmpSs : public ModulePass {
225229
}
226230
TaskUnpackParams.push_back(&*AI++);
227231
TaskUnpackParams.push_back(&*AI++);
228-
Instruction *TaskUnpackCall =
229-
BBBuilder.CreateCall(unpackFuncVar, TaskUnpackParams);
230-
ReturnInst *TaskOlRet = BBBuilder.CreateRetVoid();
232+
// Build TaskUnpackCall
233+
BBBuilder.CreateCall(unpackFuncVar, TaskUnpackParams);
234+
// Make BB legal with a terminator to task outline function
235+
BBBuilder.CreateRetVoid();
231236

232237
// nanos6_ol_task_region_* END
233238

234-
// 0.1 Create Nanos6 task data structures info
235-
Constant *TaskInvInfoVar = M.getOrInsertGlobal(("task_invocation_info_" + F.getName()).str(),
239+
// 3. Create Nanos6 task data structures info
240+
Constant *TaskInvInfoVar = M.getOrInsertGlobal(("task_invocation_info_" + F.getName() + Twine(taskNum)).str(),
236241
TskInvInfoTy.Ty,
237242
[&] {
238243
GlobalVariable *GV = new GlobalVariable(M, TskInvInfoTy.Ty,
239244
false,
240245
GlobalVariable::InternalLinkage,
241246
ConstantStruct::get(TskInvInfoTy.Ty,
242247
ConstantPointerNull::get(Type::getInt8PtrTy(M.getContext()))),
243-
("task_invocation_info_" + F.getName()).str());
248+
("task_invocation_info_" + F.getName() + Twine(taskNum)).str());
244249
GV->setAlignment(64);
245250
return GV;
246251
});
247252

248-
Constant *TaskImplInfoVar = M.getOrInsertGlobal(("implementations_var_" + F.getName()).str(),
253+
Constant *TaskImplInfoVar = M.getOrInsertGlobal(("implementations_var_" + F.getName() + Twine(taskNum)).str(),
249254
ArrayType::get(TskImplInfoTy.Ty, 1),
250255
[&] {
251256
auto *outlineFuncCastTy = FunctionType::get(Type::getVoidTy(M.getContext()),
@@ -264,13 +269,13 @@ struct OmpSs : public ModulePass {
264269
ConstantPointerNull::get(TskImplInfoTy.Mmbers.TaskLabelTy->getPointerTo()),
265270
ConstantPointerNull::get(TskImplInfoTy.Mmbers.DeclSourceTy->getPointerTo()),
266271
ConstantPointerNull::get(TskImplInfoTy.Mmbers.RunWrapperFuncTy->getPointerTo()))),
267-
("implementations_var_" + F.getName()).str());
272+
("implementations_var_" + F.getName() + Twine(taskNum)).str());
268273

269274
GV->setAlignment(64);
270275
return GV;
271276
});
272277

273-
Constant *TaskInfoVar = M.getOrInsertGlobal(("task_info_var_" + F.getName()).str(),
278+
Constant *TaskInfoVar = M.getOrInsertGlobal(("task_info_var_" + F.getName() + Twine(taskNum)).str(),
274279
TskInfoTy.Ty,
275280
[&] {
276281
GlobalVariable *GV = new GlobalVariable(M, TskInfoTy.Ty,
@@ -287,12 +292,13 @@ struct OmpSs : public ModulePass {
287292
ConstantPointerNull::get(TskInfoTy.Mmbers.DuplicateArgsBlockFuncTy->getPointerTo()),
288293
ConstantPointerNull::get(TskInfoTy.Mmbers.ReductInitsFuncTy->getPointerTo()),
289294
ConstantPointerNull::get(TskInfoTy.Mmbers.ReductCombsFuncTy->getPointerTo())),
290-
("task_info_var_" + F.getName()).str());
295+
("task_info_var_" + F.getName() + Twine(taskNum)).str());
291296

292297
GV->setAlignment(64);
293298
return GV;
294299
});
295300

301+
// 4. Create nanos6_create_task nanos6_submit_task function types
296302
Function *CreateTaskFuncTy = cast<Function>(M.getOrInsertFunction("nanos6_create_task",
297303
Type::getVoidTy(M.getContext()),
298304
TskInfoTy.Ty->getPointerTo(),
@@ -307,7 +313,7 @@ struct OmpSs : public ModulePass {
307313
Type::getVoidTy(M.getContext()),
308314
Type::getInt8PtrTy(M.getContext())));
309315

310-
auto constructOmpSsFunctions = [&](BasicBlock *header,
316+
auto rewriteOutToInTaskBrAndGetOmpSsUnpackFunc = [&](BasicBlock *header,
311317
BasicBlock *newRootNode,
312318
BasicBlock *newHeader,
313319
Function *oldFunction,
@@ -316,11 +322,9 @@ struct OmpSs : public ModulePass {
316322

317323
unpackFuncVar->getBasicBlockList().push_back(newRootNode);
318324

319-
// Rewrite branches to basic blocks outside of the loop to new dummy blocks
320-
// within the new function. This must be done before we lose track of which
321-
// blocks were originally in the code region.
322-
// ?? FIXME: Parece que esto se usa para cambiar los branches al codigo que movemos
323-
// Por ej. br label %codeRepl
325+
// Rewrite branches from basic blocks outside of the task region to blocks
326+
// inside the region to use the new label (newHeader) since the task region
327+
// will be outlined
324328
std::vector<User *> Users(header->user_begin(), header->user_end());
325329
for (unsigned i = 0, e = Users.size(); i != e; ++i)
326330
// The BasicBlock which contains the branch is not in the region
@@ -330,19 +334,20 @@ struct OmpSs : public ModulePass {
330334
I->getParent()->getParent() == oldFunction)
331335
I->replaceUsesOfWith(header, newHeader);
332336

333-
334337
return unpackFuncVar;
335338
};
336-
auto emitCaptureAndCall = [&](Function *newFunction,
339+
auto emitOmpSsCaptureAndSubmitTask = [&](Function *newFunction,
337340
BasicBlock *codeReplacer,
338341
const SetVector<BasicBlock *> &Blocks) {
339342

340343
IRBuilder<> IRB(codeReplacer);
341344
Value *TaskArgsVar = IRB.CreateAlloca(TaskArgsTy->getPointerTo());
342345
Value *TaskArgsVarCast = IRB.CreateBitCast(TaskArgsVar, IRB.getInt8PtrTy()->getPointerTo());
346+
// TODO: For now TaskFlagsVar is hardcoded
343347
// Value *TaskFlagsVar = IRB.CreateAlloca(IRB.getInt64Ty());
344348
// IRB.CreateStore(ConstantInt::get(IRB.getInt64Ty(), 0), TaskFlagsVar);
345349
Value *TaskPtrVar = IRB.CreateAlloca(IRB.getInt8PtrTy());
350+
// TODO: For now TaskNumDepsVar is hardcoded
346351
// Value *TaskNumDepsVar = IRB.CreateAlloca(IRB.getInt64Ty());
347352
// IRB.CreateStore(ConstantInt::get(IRB.getInt64Ty(), 0), TaskNumDepsVar);
348353
uint64_t TaskArgsSizeOf = M.getDataLayout().getTypeAllocSize(TaskArgsTy);
@@ -363,7 +368,7 @@ struct OmpSs : public ModulePass {
363368
Idx[1] = ConstantInt::get(IRB.getInt32Ty(), TaskArgsIdx);
364369
Value *GEP = IRB.CreateGEP(
365370
TaskArgsVarL, Idx, "gep_" + TI.DSAInfo.Shared[i]->getName());
366-
Value *CaptureDSA = IRB.CreateStore(TI.DSAInfo.Shared[i], GEP);
371+
IRB.CreateStore(TI.DSAInfo.Shared[i], GEP);
367372
}
368373
TaskArgsIdx += TI.DSAInfo.Private.size();
369374
for (unsigned i = 0; i < TI.DSAInfo.Firstprivate.size(); ++i, ++TaskArgsIdx) {
@@ -373,49 +378,42 @@ struct OmpSs : public ModulePass {
373378
Value *GEP = IRB.CreateGEP(
374379
TaskArgsVarL, Idx, "gep_" + TI.DSAInfo.Firstprivate[i]->getName());
375380
Value *FPValue = IRB.CreateLoad(TI.DSAInfo.Firstprivate[i]);
376-
Value *CaptureDSA = IRB.CreateStore(FPValue, GEP);
381+
IRB.CreateStore(FPValue, GEP);
377382
}
378383

379384
Value *TaskPtrVarL = IRB.CreateLoad(TaskPtrVar);
380-
IRB.CreateCall(TaskSubmitFuncTy, TaskPtrVarL);
381-
382-
// Since there may be multiple exits from the original region, make the new
383-
// function return an unsigned, switch on that number. This loop iterates
384-
// over all of the blocks in the extracted region, updating any terminator
385-
// instructions in the to-be-extracted region that branch to blocks that are
386-
// not in the region to be extracted.
387-
std::map<BasicBlock *, BasicBlock *> ExitBlockMap;
385+
CallInst *TaskSubmitFuncCall = IRB.CreateCall(TaskSubmitFuncTy, TaskPtrVarL);
388386

389-
unsigned switchVal = 0;
387+
// Add a branch to the next basic block after the task region
388+
// and replace the terminator that exits the task region
389+
// Since this is a single entry single exit region this should
390+
// be done once.
391+
bool DoneOnce = false;
390392
for (BasicBlock *Block : Blocks) {
391393
Instruction *TI = Block->getTerminator();
392394
for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
393395
if (!Blocks.count(TI->getSuccessor(i))) {
396+
assert(!DoneOnce && "More than one exit in task code");
397+
DoneOnce = true;
398+
394399
BasicBlock *OldTarget = TI->getSuccessor(i);
395400

401+
// Create branch to next BB after the task region
396402
IRB.CreateBr(OldTarget);
397403

398-
// add a new basic block which returns the appropriate value
399-
BasicBlock *&NewTarget = ExitBlockMap[OldTarget];
400-
if (!NewTarget) {
401-
// If we don't already have an exit stub for this non-extracted
402-
// destination, create one now!
403-
NewTarget = BasicBlock::Create(M.getContext(),
404-
OldTarget->getName() + ".exitStub",
405-
newFunction);
406-
407-
ReturnInst::Create(M.getContext(), nullptr, NewTarget);
408-
}
409-
// rewrite the original branch instruction with this new target
410-
TI->setSuccessor(i, NewTarget);
404+
IRBuilder<> BNewTerminatorI(TI);
405+
BNewTerminatorI.CreateRetVoid();
411406
}
407+
if (DoneOnce)
408+
TI->eraseFromParent();
412409
}
413410

414-
return nullptr;
411+
return TaskSubmitFuncCall;
415412
};
416-
CodeExtractor CE(TaskBBs.getArrayRef(), constructOmpSsFunctions, emitCaptureAndCall);
417413

418-
Function *OutF = CE.extractCodeRegion();
414+
// 4. Extract region the way we want
415+
CodeExtractor CE(TaskBBs.getArrayRef(), rewriteOutToInTaskBrAndGetOmpSsUnpackFunc, emitOmpSsCaptureAndSubmitTask);
416+
CE.extractCodeRegion();
419417
}
420418

421419
bool runOnModule(Module &M) override {
@@ -521,8 +519,9 @@ struct OmpSs : public ModulePass {
521519
for (TaskwaitInfo& TwI : TwFI.PostOrder) {
522520
lowerTaskwait(TwI, M);
523521
}
522+
size_t taskNum = 0;
524523
for (TaskInfo TI : TFI.PostOrder) {
525-
lowerTask(TI, F, M);
524+
lowerTask(TI, F, taskNum++, M);
526525
}
527526

528527
}

llvm/lib/Transforms/Utils/CodeExtractor.cpp

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -249,15 +249,15 @@ CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs,
249249
BasicBlock *newHeader,
250250
Function *oldFunction,
251251
Module *M,
252-
const SetVector<BasicBlock *> &Blocks)> constructOmpSsFunctions,
252+
const SetVector<BasicBlock *> &Blocks)> rewriteOutToInTaskBrAndGetOmpSsUnpackFunc,
253253
std::function<CallInst*(Function *newFunction,
254254
BasicBlock *codeReplacer,
255-
const SetVector<BasicBlock *> &Blocks)> emitCaptureAndCall)
255+
const SetVector<BasicBlock *> &Blocks)> emitOmpSsCaptureAndSubmitTask)
256256
: DT(nullptr), AggregateArgs(false), BFI(nullptr),
257257
BPI(nullptr), AllowVarArgs(false),
258258
Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, /* AllowAlloca */ true)),
259-
constructOmpSsFunctions(constructOmpSsFunctions),
260-
emitCaptureAndCall(emitCaptureAndCall) {}
259+
rewriteOutToInTaskBrAndGetOmpSsUnpackFunc(rewriteOutToInTaskBrAndGetOmpSsUnpackFunc),
260+
emitOmpSsCaptureAndSubmitTask(emitOmpSsCaptureAndSubmitTask) {}
261261

262262
CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs,
263263
BlockFrequencyInfo *BFI,
@@ -1385,8 +1385,8 @@ Function *CodeExtractor::extractCodeRegion() {
13851385

13861386
// Construct new function based on inputs/outputs & add allocas for all defs.
13871387
Function *newFunction;
1388-
if (constructOmpSsFunctions) {
1389-
newFunction = constructOmpSsFunctions(header, newFuncRoot, codeReplacer,
1388+
if (rewriteOutToInTaskBrAndGetOmpSsUnpackFunc) {
1389+
newFunction = rewriteOutToInTaskBrAndGetOmpSsUnpackFunc(header, newFuncRoot, codeReplacer,
13901390
oldFunction, oldFunction->getParent(), Blocks);
13911391
} else {
13921392
newFunction = constructFunction(inputs, outputs, header, newFuncRoot, codeReplacer,
@@ -1403,8 +1403,8 @@ Function *CodeExtractor::extractCodeRegion() {
14031403
}
14041404

14051405
CallInst *TheCall;
1406-
if (emitCaptureAndCall) {
1407-
TheCall = emitCaptureAndCall(newFunction, codeReplacer, Blocks);
1406+
if (emitOmpSsCaptureAndSubmitTask) {
1407+
TheCall = emitOmpSsCaptureAndSubmitTask(newFunction, codeReplacer, Blocks);
14081408
} else {
14091409
TheCall = emitCallAndSwitchStatement(newFunction, codeReplacer, inputs, outputs);
14101410
}
@@ -1413,8 +1413,10 @@ Function *CodeExtractor::extractCodeRegion() {
14131413

14141414
// Replicate the effects of any lifetime start/end markers which referenced
14151415
// input objects in the extraction region by placing markers around the call.
1416-
insertLifetimeMarkersSurroundingCall(oldFunction->getParent(),
1417-
InputObjectsWithLifetime, TheCall);
1416+
// FIXME OmpSs: For now lets ignore lifetimes
1417+
if (!rewriteOutToInTaskBrAndGetOmpSsUnpackFunc && !emitOmpSsCaptureAndSubmitTask)
1418+
insertLifetimeMarkersSurroundingCall(oldFunction->getParent(),
1419+
InputObjectsWithLifetime, TheCall);
14181420

14191421
// Propagate personality info to the new function if there is one.
14201422
if (oldFunction->hasPersonalityFn())

0 commit comments

Comments
 (0)