-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[flang][mlir] Add support for translating task_reduction to LLVMIR #120957
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -228,11 +228,6 @@ static LogicalResult checkImplementationStatus(Operation &op) { | |||||
if (op.getThreadLimit()) | ||||||
result = todo("thread_limit"); | ||||||
}; | ||||||
auto checkTaskReduction = [&todo](auto op, LogicalResult &result) { | ||||||
if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() || | ||||||
op.getTaskReductionSyms()) | ||||||
result = todo("task_reduction"); | ||||||
}; | ||||||
auto checkUntied = [&todo](auto op, LogicalResult &result) { | ||||||
if (op.getUntied()) | ||||||
result = todo("untied"); | ||||||
|
@@ -259,10 +254,7 @@ static LogicalResult checkImplementationStatus(Operation &op) { | |||||
checkInReduction(op, result); | ||||||
checkPriority(op, result); | ||||||
}) | ||||||
.Case([&](omp::TaskgroupOp op) { | ||||||
checkAllocate(op, result); | ||||||
checkTaskReduction(op, result); | ||||||
}) | ||||||
.Case([&](omp::TaskgroupOp op) { checkAllocate(op, result); }) | ||||||
.Case([&](omp::TaskwaitOp op) { | ||||||
checkDepend(op, result); | ||||||
checkNowait(op, result); | ||||||
|
@@ -1787,16 +1779,264 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder, | |||||
return success(); | ||||||
} | ||||||
|
||||||
template <typename OP> | ||||||
llvm::Value *createTaskReductionFunction( | ||||||
Comment on lines
+1782
to
+1783
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please could you write a brief documentation comment |
||||||
llvm::IRBuilderBase &builder, const std::string &name, llvm::Type *redTy, | ||||||
LLVM::ModuleTranslation &moduleTranslation, | ||||||
SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls, Region ®ion, | ||||||
OP &op, unsigned Cnt, llvm::ArrayRef<bool> &isByRef, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit
Suggested change
|
||||||
SmallVectorImpl<llvm::Value *> &privateReductionVariables, | ||||||
DenseMap<Value, llvm::Value *> &reductionVariableMap) { | ||||||
llvm::LLVMContext &Context = builder.getContext(); | ||||||
llvm::Type *OpaquePtrTy = llvm::PointerType::get(Context, 0); | ||||||
if (region.empty()) { | ||||||
return llvm::Constant::getNullValue(OpaquePtrTy); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not very familiar with reductions, so to make sure, is it fine to set any of these function pointers to a null value in the final struct? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I do think so. Essentially, in case the @tblah Can you give an opinion on this? In case the region is empty, is it better to let an empty function body be present, or should we have an assertion in this place? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice spot Kareem! It looks like these pointers should not be null because init and combiner are called unconditionally by the runtime. E.g. llvm-project/openmp/runtime/src/kmp_tasking.cpp Line 2511 in 7e01a32
For normal reductions there will always be combiner and initializer regions. For The finalization does seem to be optional: https://github.com/llvm/llvm-project/blob/7e01a322f850e86be9eefde8ae5a30e532d22cfa/openmp/runtime/src/kmp_tasking.cpp#L2753C9-L2753C19 |
||||||
} | ||||||
Comment on lines
+1792
to
+1794
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||
llvm::FunctionType *funcType = nullptr; | ||||||
if (isByRef[Cnt]) | ||||||
funcType = llvm::FunctionType::get(builder.getVoidTy(), | ||||||
{OpaquePtrTy, OpaquePtrTy}, false); | ||||||
else | ||||||
funcType = | ||||||
llvm::FunctionType::get(OpaquePtrTy, {OpaquePtrTy, OpaquePtrTy}, false); | ||||||
llvm::Function *function = | ||||||
llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, name, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For the name of the generated function, I think it would nice for debuggability to either prefix or suffix the symbol name of the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we need external linkage for this? |
||||||
builder.GetInsertBlock()->getModule()); | ||||||
Comment on lines
+1803
to
+1804
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Did you consider using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was not aware of it frankly; can give it a try. |
||||||
function->setDoesNotRecurse(); | ||||||
llvm::BasicBlock *entry = | ||||||
llvm::BasicBlock::Create(Context, "entry", function); | ||||||
llvm::IRBuilder<> bbBuilder(entry); | ||||||
|
||||||
llvm::Value *arg0 = function->getArg(0); | ||||||
llvm::Value *arg1 = function->getArg(1); | ||||||
|
||||||
if (name == "red_init") { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it would be more consistent with the rest of the code here if you pass a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see. Thanks, will do so. |
||||||
function->addParamAttr(0, llvm::Attribute::NoAlias); | ||||||
function->addParamAttr(1, llvm::Attribute::NoAlias); | ||||||
if (isByRef[Cnt]) { | ||||||
// TODO: Handle case where the initializer uses initialization from | ||||||
// declare reduction construct using `arg1Alloca`. | ||||||
llvm::AllocaInst *arg0Alloca = | ||||||
bbBuilder.CreateAlloca(bbBuilder.getPtrTy()); | ||||||
llvm::AllocaInst *arg1Alloca = | ||||||
bbBuilder.CreateAlloca(bbBuilder.getPtrTy()); | ||||||
bbBuilder.CreateStore(arg0, arg0Alloca); | ||||||
bbBuilder.CreateStore(arg1, arg1Alloca); | ||||||
llvm::Value *LoadVal = | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Isn't this the same value as There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This section of the code is for pass-by-references. Honestly, I tried to keep the logic here as close to Clang as possible:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For pass-by-value, a simple |
||||||
bbBuilder.CreateLoad(bbBuilder.getPtrTy(), arg0Alloca); | ||||||
moduleTranslation.mapValue(reductionDecls[Cnt].getInitializerAllocArg(), | ||||||
LoadVal); | ||||||
} else { | ||||||
mapInitializationArgs(op, moduleTranslation, reductionDecls, | ||||||
reductionVariableMap, Cnt); | ||||||
} | ||||||
} else if (name == "red_comb") { | ||||||
if (isByRef[Cnt]) { | ||||||
llvm::AllocaInst *arg0Alloca = | ||||||
bbBuilder.CreateAlloca(bbBuilder.getPtrTy()); | ||||||
llvm::AllocaInst *arg1Alloca = | ||||||
bbBuilder.CreateAlloca(bbBuilder.getPtrTy()); | ||||||
bbBuilder.CreateStore(arg0, arg0Alloca); | ||||||
bbBuilder.CreateStore(arg1, arg1Alloca); | ||||||
llvm::Value *arg0L = | ||||||
bbBuilder.CreateLoad(bbBuilder.getPtrTy(), arg0Alloca); | ||||||
llvm::Value *arg1L = | ||||||
bbBuilder.CreateLoad(bbBuilder.getPtrTy(), arg1Alloca); | ||||||
moduleTranslation.mapValue(region.front().getArgument(0), arg0L); | ||||||
moduleTranslation.mapValue(region.front().getArgument(1), arg1L); | ||||||
} else { | ||||||
llvm::Value *arg0L = bbBuilder.CreateLoad(redTy, arg0); | ||||||
llvm::Value *arg1L = bbBuilder.CreateLoad(redTy, arg1); | ||||||
moduleTranslation.mapValue(region.front().getArgument(0), arg0L); | ||||||
moduleTranslation.mapValue(region.front().getArgument(1), arg1L); | ||||||
} | ||||||
} | ||||||
|
||||||
SmallVector<llvm::Value *, 1> phis; | ||||||
if (failed(inlineConvertOmpRegions(region, "", bbBuilder, moduleTranslation, | ||||||
&phis))) | ||||||
return nullptr; | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it would be better to return a |
||||||
assert( | ||||||
phis.size() == 1 && | ||||||
"expected one value to be yielded from the reduction declaration region"); | ||||||
if (!isByRef[Cnt]) { | ||||||
bbBuilder.CreateStore(phis[0], arg0); | ||||||
bbBuilder.CreateRet(arg0); // Return from the function | ||||||
} else { | ||||||
bbBuilder.CreateRet(nullptr); | ||||||
} | ||||||
return function; | ||||||
} | ||||||
|
||||||
void emitTaskRedInitCall( | ||||||
llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, | ||||||
const llvm::OpenMPIRBuilder::LocationDescription &ompLoc, int arraySize, | ||||||
llvm::Value *ArrayAlloca) { | ||||||
llvm::LLVMContext &Context = builder.getContext(); | ||||||
uint32_t SrcLocStrSize; | ||||||
llvm::Constant *SrcLocStr = | ||||||
moduleTranslation.getOpenMPBuilder()->getOrCreateSrcLocStr(ompLoc, | ||||||
SrcLocStrSize); | ||||||
llvm::Value *Ident = moduleTranslation.getOpenMPBuilder()->getOrCreateIdent( | ||||||
SrcLocStr, SrcLocStrSize); | ||||||
llvm::Value *ThreadID = | ||||||
moduleTranslation.getOpenMPBuilder()->getOrCreateThreadID(Ident); | ||||||
llvm::Constant *ConstInt = | ||||||
llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), arraySize); | ||||||
llvm::Function *TaskRedInitFn = | ||||||
moduleTranslation.getOpenMPBuilder()->getOrCreateRuntimeFunctionPtr( | ||||||
llvm::omp::OMPRTL___kmpc_taskred_init); | ||||||
builder.CreateCall(TaskRedInitFn, {ThreadID, ConstInt, ArrayAlloca}); | ||||||
} | ||||||
|
||||||
template <typename OP> | ||||||
static LogicalResult allocAndInitializeTaskReductionVars( | ||||||
OP op, ArrayRef<BlockArgument> reductionArgs, llvm::IRBuilderBase &builder, | ||||||
LLVM::ModuleTranslation &moduleTranslation, | ||||||
llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, | ||||||
SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls, | ||||||
SmallVectorImpl<llvm::Value *> &privateReductionVariables, | ||||||
DenseMap<Value, llvm::Value *> &reductionVariableMap, | ||||||
llvm::ArrayRef<bool> isByRef) { | ||||||
|
||||||
if (op.getNumReductionVars() == 0) | ||||||
return success(); | ||||||
|
||||||
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); | ||||||
llvm::LLVMContext &Context = builder.getContext(); | ||||||
SmallVector<DeferredStore> deferredStores; | ||||||
|
||||||
// Save the current insertion point | ||||||
auto oldIP = builder.saveIP(); | ||||||
|
||||||
// Set insertion point after the allocations | ||||||
builder.SetInsertPoint(allocaIP.getBlock()->getTerminator()); | ||||||
|
||||||
// Define the kmp_taskred_input_t structure | ||||||
llvm::StructType *kmp_taskred_input_t = | ||||||
llvm::StructType::create(Context, "kmp_taskred_input_t"); | ||||||
llvm::Type *OpaquePtrTy = llvm::PointerType::get(Context, 0); // void* | ||||||
llvm::Type *SizeTy = builder.getInt64Ty(); // size_t (assumed to be i64) | ||||||
llvm::Type *FlagsTy = llvm::Type::getInt32Ty(Context); // flags (i32) | ||||||
|
||||||
// Structure members | ||||||
std::vector<llvm::Type *> structMembers = { | ||||||
OpaquePtrTy, // reduce_shar (void*) | ||||||
OpaquePtrTy, // reduce_orig (void*) | ||||||
SizeTy, // reduce_size (size_t) | ||||||
OpaquePtrTy, // reduce_init (void*) | ||||||
OpaquePtrTy, // reduce_fini (void*) | ||||||
OpaquePtrTy, // reduce_comb (void*) | ||||||
FlagsTy // flags (i32) | ||||||
}; | ||||||
|
||||||
kmp_taskred_input_t->setBody(structMembers); | ||||||
int arraySize = op.getNumReductionVars(); | ||||||
llvm::ArrayType *ArrayTy = | ||||||
llvm::ArrayType::get(kmp_taskred_input_t, arraySize); | ||||||
|
||||||
// Allocate the array for kmp_taskred_input_t | ||||||
llvm::AllocaInst *ArrayAlloca = | ||||||
builder.CreateAlloca(ArrayTy, nullptr, "kmp_taskred_array"); | ||||||
|
||||||
// Restore the insertion point | ||||||
builder.restoreIP(oldIP); | ||||||
llvm::DataLayout DL = builder.GetInsertBlock()->getModule()->getDataLayout(); | ||||||
|
||||||
for (int Cnt = 0; Cnt < arraySize; ++Cnt) { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit
Suggested change
This file is under mlir/ so mlir coding standards apply. https://mlir.llvm.org/getting_started/DeveloperGuide/#style-guide |
||||||
llvm::Value *shared = | ||||||
moduleTranslation.lookupValue(op.getReductionVars()[Cnt]); | ||||||
// Create a GEP to access the reduction element | ||||||
llvm::Value *StructPtr = builder.CreateGEP( | ||||||
ArrayTy, ArrayAlloca, {builder.getInt32(0), builder.getInt32(Cnt)}, | ||||||
"red_element"); | ||||||
llvm::Value *FieldPtrReduceShar = builder.CreateStructGEP( | ||||||
kmp_taskred_input_t, StructPtr, 0, "reduce_shar"); | ||||||
builder.CreateStore(shared, FieldPtrReduceShar); | ||||||
|
||||||
llvm::Value *FieldPtrReduceOrig = builder.CreateStructGEP( | ||||||
kmp_taskred_input_t, StructPtr, 1, "reduce_orig"); | ||||||
builder.CreateStore(shared, FieldPtrReduceOrig); | ||||||
|
||||||
// Store size of the reduction variable | ||||||
llvm::Value *FieldPtrReduceSize = builder.CreateStructGEP( | ||||||
kmp_taskred_input_t, StructPtr, 2, "reduce_size"); | ||||||
llvm::Type *redTy; | ||||||
if (auto *alloca = dyn_cast<llvm::AllocaInst>(shared)) { | ||||||
redTy = alloca->getAllocatedType(); | ||||||
uint64_t sizeInBytes = DL.getTypeAllocSize(redTy); | ||||||
llvm::ConstantInt *sizeConst = | ||||||
llvm::ConstantInt::get(llvm::Type::getInt64Ty(Context), sizeInBytes); | ||||||
builder.CreateStore(sizeConst, FieldPtrReduceSize); | ||||||
} else { | ||||||
llvm_unreachable("Non alloca instruction found."); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this temporary (a todo) or permenant? Can we simply use |
||||||
} | ||||||
Comment on lines
+1965
to
+1973
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This feels error-prone to me. There is currently nothing in the MLIR dialect requiring all SSA values passed into a reduction to be stack allocated. The MLIR OpenMP dialect has uses outside of flang. Aside from this, some other questions:
|
||||||
|
||||||
// Initialize reduction variable | ||||||
llvm::Value *FieldPtrReduceInit = builder.CreateStructGEP( | ||||||
kmp_taskred_input_t, StructPtr, 3, "reduce_init"); | ||||||
llvm::Value *initFunction = createTaskReductionFunction( | ||||||
builder, "red_init", redTy, moduleTranslation, reductionDecls, | ||||||
reductionDecls[Cnt].getInitializerRegion(), op, Cnt, isByRef, | ||||||
privateReductionVariables, reductionVariableMap); | ||||||
builder.CreateStore(initFunction, FieldPtrReduceInit); | ||||||
|
||||||
// Create finish and combine functions | ||||||
llvm::Value *FieldPtrReduceFini = builder.CreateStructGEP( | ||||||
kmp_taskred_input_t, StructPtr, 4, "reduce_fini"); | ||||||
llvm::Value *finiFunction = createTaskReductionFunction( | ||||||
builder, "red_fini", redTy, moduleTranslation, reductionDecls, | ||||||
reductionDecls[Cnt].getCleanupRegion(), op, Cnt, isByRef, | ||||||
privateReductionVariables, reductionVariableMap); | ||||||
builder.CreateStore(finiFunction, FieldPtrReduceFini); | ||||||
|
||||||
llvm::Value *FieldPtrReduceComb = builder.CreateStructGEP( | ||||||
kmp_taskred_input_t, StructPtr, 5, "reduce_comb"); | ||||||
llvm::Value *combFunction = createTaskReductionFunction( | ||||||
builder, "red_comb", redTy, moduleTranslation, reductionDecls, | ||||||
reductionDecls[Cnt].getReductionRegion(), op, Cnt, isByRef, | ||||||
privateReductionVariables, reductionVariableMap); | ||||||
builder.CreateStore(combFunction, FieldPtrReduceComb); | ||||||
|
||||||
llvm::Value *FieldPtrFlags = | ||||||
builder.CreateStructGEP(kmp_taskred_input_t, StructPtr, 6, "flags"); | ||||||
llvm::ConstantInt *flagVal = | ||||||
llvm::ConstantInt::get(llvm::Type::getInt64Ty(Context), 0); | ||||||
builder.CreateStore(flagVal, FieldPtrFlags); | ||||||
} | ||||||
|
||||||
// Emit the runtime call | ||||||
emitTaskRedInitCall(builder, moduleTranslation, ompLoc, arraySize, | ||||||
ArrayAlloca); | ||||||
return success(); | ||||||
} | ||||||
|
||||||
/// Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder. | ||||||
static LogicalResult | ||||||
convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder, | ||||||
LLVM::ModuleTranslation &moduleTranslation) { | ||||||
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy; | ||||||
if (failed(checkImplementationStatus(*tgOp))) | ||||||
return failure(); | ||||||
LogicalResult bodyGenStatus = success(); | ||||||
// Setup for `task_reduction` | ||||||
llvm::ArrayRef<bool> isByRef = getIsByRef(tgOp.getTaskReductionByref()); | ||||||
assert(isByRef.size() == tgOp.getNumReductionVars()); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In the code above, we take different branches based on
Does this If I understood that correctly (i.e. that we only test the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks, that is indeed the case. I had a conversation with Tom. With concerns about pointer lifetimes, it is best that we go ahead with by-val implementation for now, and wait for delayed privatization for task to get merged before merging a PR about by-ref. |
||||||
SmallVector<omp::DeclareReductionOp> reductionDecls; | ||||||
collectReductionDecls(tgOp, reductionDecls); | ||||||
SmallVector<llvm::Value *> privateReductionVariables( | ||||||
tgOp.getNumReductionVars()); | ||||||
DenseMap<Value, llvm::Value *> reductionVariableMap; | ||||||
MutableArrayRef<BlockArgument> reductionArgs = | ||||||
tgOp.getRegion().getArguments(); | ||||||
|
||||||
auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) { | ||||||
builder.restoreIP(codegenIP); | ||||||
if (failed(allocAndInitializeTaskReductionVars( | ||||||
tgOp, reductionArgs, builder, moduleTranslation, allocaIP, | ||||||
reductionDecls, privateReductionVariables, reductionVariableMap, | ||||||
isByRef))) | ||||||
bodyGenStatus = failure(); | ||||||
Comment on lines
+2035
to
+2039
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Looking at |
||||||
return convertOmpOpRegions(tgOp.getRegion(), "omp.taskgroup.region", | ||||||
builder, moduleTranslation) | ||||||
.takeError(); | ||||||
|
@@ -1812,7 +2052,7 @@ convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder, | |||||
return failure(); | ||||||
|
||||||
builder.restoreIP(*afterIP); | ||||||
return success(); | ||||||
return bodyGenStatus; | ||||||
} | ||||||
|
||||||
static LogicalResult | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: this can be
static