[flang][mlir] Add support for translating task_reduction to LLVMIR #120957

Open: wants to merge 2 commits into base: main
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1207,6 +1207,12 @@ class OpenMP_TaskReductionClauseSkip<
unsigned numTaskReductionBlockArgs() {
return getTaskReductionVars().size();
}

/// Returns the number of reduction variables.
unsigned getNumReductionVars() { return getReductionVars().size(); }


auto getReductionSyms() { return getTaskReductionSyms(); }
}];

let description = [{
260 changes: 250 additions & 10 deletions mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -228,11 +228,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
if (op.getThreadLimit())
result = todo("thread_limit");
};
auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
op.getTaskReductionSyms())
result = todo("task_reduction");
};
auto checkUntied = [&todo](auto op, LogicalResult &result) {
if (op.getUntied())
result = todo("untied");
@@ -259,10 +254,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
checkInReduction(op, result);
checkPriority(op, result);
})
.Case([&](omp::TaskgroupOp op) {
checkAllocate(op, result);
checkTaskReduction(op, result);
})
.Case([&](omp::TaskgroupOp op) { checkAllocate(op, result); })
.Case([&](omp::TaskwaitOp op) {
checkDepend(op, result);
checkNowait(op, result);
@@ -1787,16 +1779,264 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
return success();
}

template <typename OP>
llvm::Value *createTaskReductionFunction(
Comment on lines +1782 to +1783
Contributor

nit: this can be static

Comment on lines +1782 to +1783
Contributor

Please could you write a brief documentation comment

llvm::IRBuilderBase &builder, const std::string &name, llvm::Type *redTy,
LLVM::ModuleTranslation &moduleTranslation,
SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls, Region &region,
OP &op, unsigned Cnt, llvm::ArrayRef<bool> &isByRef,
Contributor

nit

Suggested change
OP &op, unsigned Cnt, llvm::ArrayRef<bool> &isByRef,
OP &op, unsigned cnt, llvm::ArrayRef<bool> &isByRef,

SmallVectorImpl<llvm::Value *> &privateReductionVariables,
DenseMap<Value, llvm::Value *> &reductionVariableMap) {
llvm::LLVMContext &Context = builder.getContext();
llvm::Type *OpaquePtrTy = llvm::PointerType::get(Context, 0);
if (region.empty()) {
return llvm::Constant::getNullValue(OpaquePtrTy);
Member

Not very familiar with reductions, so to make sure, is it fine to set any of these function pointers to a null value in the final struct?

Contributor Author

I do think so. Essentially, in case the region is empty, we do not want to generate the functions.

@tblah Can you give an opinion on this? In case the region is empty, is it better to let an empty function body be present, or should we have an assertion in this place?

Contributor

Nice spot Kareem!

It looks like these pointers should not be null because init and combiner are called unconditionally by the runtime. E.g.

template <typename T> void __kmp_call_init(kmp_taskred_data_t &item, size_t j);

For normal reductions there will always be combiner and initializer regions. For declare reduction the initializer expression is optional so I think one could write legal code without an init region. The combiner expression is required.

The finalization does seem to be optional: https://github.com/llvm/llvm-project/blob/7e01a322f850e86be9eefde8ae5a30e532d22cfa/openmp/runtime/src/kmp_tasking.cpp#L2753C9-L2753C19
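The point about null pointers can be illustrated with a toy model. The sketch below is not the libomp code: `RedItem`, `runReduction`, `initZero`, and `combAdd` are hypothetical names, and the model simply assumes the behaviour described in the comment above (initializer and combiner called without a null check, finalizer guarded).

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical, simplified stand-in for one task-reduction item.
// The real runtime record has more fields and a different layout.
struct RedItem {
  void *priv;                    // thread-private copy
  void *shared;                  // original (shared) variable
  void (*init)(void *, void *);  // assumed to be called unconditionally
  void (*fini)(void *);          // optional: may be null
  void (*comb)(void *, void *);  // assumed to be called unconditionally
};

// Initialise the private copy, fold it into the shared variable, then
// finalise. Only the finalizer is null-checked in this model.
inline void runReduction(RedItem &item) {
  item.init(item.priv, item.shared); // a null pointer here would crash
  item.comb(item.shared, item.priv);
  if (item.fini)
    item.fini(item.priv);
}

// Example callbacks for an `int` sum reduction.
inline void initZero(void *priv, void *) { *static_cast<int *>(priv) = 0; }
inline void combAdd(void *lhs, void *rhs) {
  *static_cast<int *>(lhs) += *static_cast<int *>(rhs);
}
```

Under that assumption, leaving `fini` null is harmless, while a null `init` or `comb` would be dereferenced as soon as the item is processed.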

}
Comment on lines +1792 to +1794
Contributor

llvm::FunctionType *funcType = nullptr;
if (isByRef[Cnt])
funcType = llvm::FunctionType::get(builder.getVoidTy(),
{OpaquePtrTy, OpaquePtrTy}, false);
else
funcType =
llvm::FunctionType::get(OpaquePtrTy, {OpaquePtrTy, OpaquePtrTy}, false);
llvm::Function *function =
llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, name,
Member

For the name of the generated function, I think it would be nice for debuggability to add the symbol name of the DeclareReductionOp op as either a prefix or a suffix.

Member

Why do we need external linkage for this?

builder.GetInsertBlock()->getModule());
Comment on lines +1803 to +1804
Contributor

Did you consider using OpenMPIRBuilder::getOrCreateRuntimeFunctionPtr? What made you decide to do it manually?

Contributor Author

I was not aware of it frankly; can give it a try.

function->setDoesNotRecurse();
llvm::BasicBlock *entry =
llvm::BasicBlock::Create(Context, "entry", function);
llvm::IRBuilder<> bbBuilder(entry);

llvm::Value *arg0 = function->getArg(0);
llvm::Value *arg1 = function->getArg(1);

if (name == "red_init") {
Member

I think it would be more consistent with the rest of the code here if you pass a bodyGenCB lambda. This will also enable you to clean up the interface for createTaskReductionFunction (you can get rid of a number of parameters and just capture what you need in the code-gen callback).

Contributor Author

I see. Thanks, will do so.
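A minimal sketch of the callback shape suggested above, written without LLVM so that it stands alone. `FakeBuilder`, `BodyGenCB`, and `createReductionFunction` are hypothetical stand-ins, not names from the patch; the real helper would take an `llvm::IRBuilderBase` and emit IR rather than strings.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Hypothetical stand-in for an IR builder: it only records emitted lines.
struct FakeBuilder {
  std::vector<std::string> lines;
  void emit(const std::string &line) { lines.push_back(line); }
};

// The body generator receives the builder and reports success or failure.
using BodyGenCB = std::function<bool(FakeBuilder &)>;

// Emits the prologue and epilogue shared by every reduction helper and
// delegates the region-specific part to the callback. Whatever the body
// needs (reduction decl, index, by-ref flag) is captured by the lambda
// instead of being passed as extra parameters.
inline bool createReductionFunction(FakeBuilder &builder,
                                    const std::string &name,
                                    const BodyGenCB &bodyGenCB) {
  builder.emit("define " + name);
  if (!bodyGenCB(builder))
    return false;
  builder.emit("ret");
  return true;
}
```

With this shape the `name == "red_init"` / `name == "red_comb"` string comparisons disappear, because each call site supplies its own body.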

function->addParamAttr(0, llvm::Attribute::NoAlias);
function->addParamAttr(1, llvm::Attribute::NoAlias);
if (isByRef[Cnt]) {
// TODO: Handle case where the initializer uses initialization from
// declare reduction construct using `arg1Alloca`.
llvm::AllocaInst *arg0Alloca =
bbBuilder.CreateAlloca(bbBuilder.getPtrTy());
llvm::AllocaInst *arg1Alloca =
bbBuilder.CreateAlloca(bbBuilder.getPtrTy());
bbBuilder.CreateStore(arg0, arg0Alloca);
bbBuilder.CreateStore(arg1, arg1Alloca);
llvm::Value *LoadVal =
Member

Isn't this the same value as arg0? Why the need for loads, stores, and allocas? I must be missing something, just want to understand more. Same applies to the red_comb region.

Contributor Author

This section of the code is for pass-by-references. Honestly, I tried to keep the logic here as close to Clang as possible:

define void @red_init(ptr noalias %arg0, ptr noalias %arg1) #2 {
  entry:
   %0 = alloca ptr, align 8
   %1 = alloca ptr, align 8
   store ptr %arg0, ptr %0, align 8
   store ptr %arg1, ptr %1, align 8
   %2 = load ptr, ptr %0, align 8
   store i32 0, ptr %2, align 4
   ret void
}
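For readers less used to LLVM IR, the function above corresponds roughly to the C++ below. This is a sketch that assumes a 32-bit integer reduction variable initialised to 0, as in the IR; `red_init_sketch` is a hypothetical name, and `__restrict` (a common compiler extension) stands in for `noalias`.

```cpp
#include <cassert>

// Rough C++ equivalent of the by-reference initializer shown in the IR.
// The first pointer is the private copy to initialise; the second is the
// original variable, which this initializer stores locally but never reads.
inline void red_init_sketch(void *__restrict priv, void *__restrict orig) {
  (void)orig;                     // mirrors %arg1: spilled, never loaded
  *static_cast<int *>(priv) = 0;  // mirrors `store i32 0, ptr %2`
}
```

The alloca/store/load round trip in the IR has no counterpart here: it is the kind of spill an unoptimised build produces for parameters, and the value loaded back is the same pointer as `%arg0`.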

Contributor Author

For pass-by-value, a simple mapInitializationArgs suffices

bbBuilder.CreateLoad(bbBuilder.getPtrTy(), arg0Alloca);
moduleTranslation.mapValue(reductionDecls[Cnt].getInitializerAllocArg(),
LoadVal);
} else {
mapInitializationArgs(op, moduleTranslation, reductionDecls,
reductionVariableMap, Cnt);
}
} else if (name == "red_comb") {
if (isByRef[Cnt]) {
llvm::AllocaInst *arg0Alloca =
bbBuilder.CreateAlloca(bbBuilder.getPtrTy());
llvm::AllocaInst *arg1Alloca =
bbBuilder.CreateAlloca(bbBuilder.getPtrTy());
bbBuilder.CreateStore(arg0, arg0Alloca);
bbBuilder.CreateStore(arg1, arg1Alloca);
llvm::Value *arg0L =
bbBuilder.CreateLoad(bbBuilder.getPtrTy(), arg0Alloca);
llvm::Value *arg1L =
bbBuilder.CreateLoad(bbBuilder.getPtrTy(), arg1Alloca);
moduleTranslation.mapValue(region.front().getArgument(0), arg0L);
moduleTranslation.mapValue(region.front().getArgument(1), arg1L);
} else {
llvm::Value *arg0L = bbBuilder.CreateLoad(redTy, arg0);
llvm::Value *arg1L = bbBuilder.CreateLoad(redTy, arg1);
moduleTranslation.mapValue(region.front().getArgument(0), arg0L);
moduleTranslation.mapValue(region.front().getArgument(1), arg1L);
}
}

SmallVector<llvm::Value *, 1> phis;
if (failed(inlineConvertOmpRegions(region, "", bbBuilder, moduleTranslation,
&phis)))
return nullptr;
Member

I think it would be better to return a LogicalResult and pass an output parameter for the created function value. This is less error prone compared to directly using the return value of this function to store it somewhere for example like we do below.
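The suggestion can be sketched as follows, again without LLVM or MLIR so the snippet stands alone. `Status` stands in for `LogicalResult`, `Function` for `llvm::Function`, and `createFunction` is a hypothetical helper, not the function in the patch.

```cpp
#include <cassert>
#include <memory>
#include <string>

struct Function { std::string name; };   // stand-in for llvm::Function
enum class Status { Success, Failure };  // stand-in for LogicalResult

// Reports failure through the return value and writes `out` only on
// success. [[nodiscard]] makes it harder to ignore the status, whereas a
// nullptr return value can be stored into a struct field unnoticed.
[[nodiscard]] inline Status createFunction(bool bodyOk,
                                           const std::string &name,
                                           std::unique_ptr<Function> &out) {
  if (!bodyOk)
    return Status::Failure;
  out = std::make_unique<Function>(Function{name});
  return Status::Success;
}
```

A caller then checks the status before using the output parameter, so a failed region conversion cannot silently become a stored null pointer.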

assert(
phis.size() == 1 &&
"expected one value to be yielded from the reduction declaration region");
if (!isByRef[Cnt]) {
bbBuilder.CreateStore(phis[0], arg0);
bbBuilder.CreateRet(arg0); // Return from the function
} else {
bbBuilder.CreateRet(nullptr);
}
return function;
}

void emitTaskRedInitCall(
llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
const llvm::OpenMPIRBuilder::LocationDescription &ompLoc, int arraySize,
llvm::Value *ArrayAlloca) {
llvm::LLVMContext &Context = builder.getContext();
uint32_t SrcLocStrSize;
llvm::Constant *SrcLocStr =
moduleTranslation.getOpenMPBuilder()->getOrCreateSrcLocStr(ompLoc,
SrcLocStrSize);
llvm::Value *Ident = moduleTranslation.getOpenMPBuilder()->getOrCreateIdent(
SrcLocStr, SrcLocStrSize);
llvm::Value *ThreadID =
moduleTranslation.getOpenMPBuilder()->getOrCreateThreadID(Ident);
llvm::Constant *ConstInt =
llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), arraySize);
llvm::Function *TaskRedInitFn =
moduleTranslation.getOpenMPBuilder()->getOrCreateRuntimeFunctionPtr(
llvm::omp::OMPRTL___kmpc_taskred_init);
builder.CreateCall(TaskRedInitFn, {ThreadID, ConstInt, ArrayAlloca});
}

template <typename OP>
static LogicalResult allocAndInitializeTaskReductionVars(
OP op, ArrayRef<BlockArgument> reductionArgs, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation,
llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
SmallVectorImpl<llvm::Value *> &privateReductionVariables,
DenseMap<Value, llvm::Value *> &reductionVariableMap,
llvm::ArrayRef<bool> isByRef) {

if (op.getNumReductionVars() == 0)
return success();

llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
llvm::LLVMContext &Context = builder.getContext();
SmallVector<DeferredStore> deferredStores;

// Save the current insertion point
auto oldIP = builder.saveIP();

// Set insertion point after the allocations
builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());

// Define the kmp_taskred_input_t structure
llvm::StructType *kmp_taskred_input_t =
llvm::StructType::create(Context, "kmp_taskred_input_t");
llvm::Type *OpaquePtrTy = llvm::PointerType::get(Context, 0); // void*
llvm::Type *SizeTy = builder.getInt64Ty(); // size_t (assumed to be i64)
llvm::Type *FlagsTy = llvm::Type::getInt32Ty(Context); // flags (i32)

// Structure members
std::vector<llvm::Type *> structMembers = {
OpaquePtrTy, // reduce_shar (void*)
OpaquePtrTy, // reduce_orig (void*)
SizeTy, // reduce_size (size_t)
OpaquePtrTy, // reduce_init (void*)
OpaquePtrTy, // reduce_fini (void*)
OpaquePtrTy, // reduce_comb (void*)
FlagsTy // flags (i32)
};

kmp_taskred_input_t->setBody(structMembers);
int arraySize = op.getNumReductionVars();
llvm::ArrayType *ArrayTy =
llvm::ArrayType::get(kmp_taskred_input_t, arraySize);

// Allocate the array for kmp_taskred_input_t
llvm::AllocaInst *ArrayAlloca =
builder.CreateAlloca(ArrayTy, nullptr, "kmp_taskred_array");

// Restore the insertion point
builder.restoreIP(oldIP);
llvm::DataLayout DL = builder.GetInsertBlock()->getModule()->getDataLayout();

for (int Cnt = 0; Cnt < arraySize; ++Cnt) {
Contributor

nit

Suggested change
for (int Cnt = 0; Cnt < arraySize; ++Cnt) {
for (int cnt = 0; cnt < arraySize; ++cnt) {

This file is under mlir/ so mlir coding standards apply. https://mlir.llvm.org/getting_started/DeveloperGuide/#style-guide

llvm::Value *shared =
moduleTranslation.lookupValue(op.getReductionVars()[Cnt]);
// Create a GEP to access the reduction element
llvm::Value *StructPtr = builder.CreateGEP(
ArrayTy, ArrayAlloca, {builder.getInt32(0), builder.getInt32(Cnt)},
"red_element");
llvm::Value *FieldPtrReduceShar = builder.CreateStructGEP(
kmp_taskred_input_t, StructPtr, 0, "reduce_shar");
builder.CreateStore(shared, FieldPtrReduceShar);

llvm::Value *FieldPtrReduceOrig = builder.CreateStructGEP(
kmp_taskred_input_t, StructPtr, 1, "reduce_orig");
builder.CreateStore(shared, FieldPtrReduceOrig);

// Store size of the reduction variable
llvm::Value *FieldPtrReduceSize = builder.CreateStructGEP(
kmp_taskred_input_t, StructPtr, 2, "reduce_size");
llvm::Type *redTy;
if (auto *alloca = dyn_cast<llvm::AllocaInst>(shared)) {
redTy = alloca->getAllocatedType();
uint64_t sizeInBytes = DL.getTypeAllocSize(redTy);
llvm::ConstantInt *sizeConst =
llvm::ConstantInt::get(llvm::Type::getInt64Ty(Context), sizeInBytes);
builder.CreateStore(sizeConst, FieldPtrReduceSize);
} else {
llvm_unreachable("Non alloca instruction found.");
Member

Is this temporary (a todo) or permanent? Can we simply use moduleTranslation.convertType(op.getReductionVars()[Cnt].getType()) (or something to the same effect)?

}
Comment on lines +1965 to +1973
Contributor

This feels error-prone to me. There is currently nothing in the MLIR dialect requiring all SSA values passed into a reduction to be stack allocated.

The MLIR OpenMP dialect has uses outside of flang.

Aside from this, some other questions:

  • What if the allocation size is not constant?
  • What if the allocation is for an array?


// Initialize reduction variable
llvm::Value *FieldPtrReduceInit = builder.CreateStructGEP(
kmp_taskred_input_t, StructPtr, 3, "reduce_init");
llvm::Value *initFunction = createTaskReductionFunction(
builder, "red_init", redTy, moduleTranslation, reductionDecls,
reductionDecls[Cnt].getInitializerRegion(), op, Cnt, isByRef,
privateReductionVariables, reductionVariableMap);
builder.CreateStore(initFunction, FieldPtrReduceInit);

// Create finish and combine functions
llvm::Value *FieldPtrReduceFini = builder.CreateStructGEP(
kmp_taskred_input_t, StructPtr, 4, "reduce_fini");
llvm::Value *finiFunction = createTaskReductionFunction(
builder, "red_fini", redTy, moduleTranslation, reductionDecls,
reductionDecls[Cnt].getCleanupRegion(), op, Cnt, isByRef,
privateReductionVariables, reductionVariableMap);
builder.CreateStore(finiFunction, FieldPtrReduceFini);

llvm::Value *FieldPtrReduceComb = builder.CreateStructGEP(
kmp_taskred_input_t, StructPtr, 5, "reduce_comb");
llvm::Value *combFunction = createTaskReductionFunction(
builder, "red_comb", redTy, moduleTranslation, reductionDecls,
reductionDecls[Cnt].getReductionRegion(), op, Cnt, isByRef,
privateReductionVariables, reductionVariableMap);
builder.CreateStore(combFunction, FieldPtrReduceComb);

llvm::Value *FieldPtrFlags =
builder.CreateStructGEP(kmp_taskred_input_t, StructPtr, 6, "flags");
llvm::ConstantInt *flagVal =
llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), 0);
builder.CreateStore(flagVal, FieldPtrFlags);
}

// Emit the runtime call
emitTaskRedInitCall(builder, moduleTranslation, ompLoc, arraySize,
ArrayAlloca);
return success();
}

/// Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
if (failed(checkImplementationStatus(*tgOp)))
return failure();
LogicalResult bodyGenStatus = success();
// Setup for `task_reduction`
llvm::ArrayRef<bool> isByRef = getIsByRef(tgOp.getTaskReductionByref());
assert(isByRef.size() == tgOp.getNumReductionVars());
Member

In the code above, we take different branches based on isByRef for each individual reduction; for example:

  if (isByRef[Cnt])
    funcType = llvm::FunctionType::get(builder.getVoidTy(),
                                       {OpaquePtrTy, OpaquePtrTy}, false);
  else
    funcType =
        llvm::FunctionType::get(OpaquePtrTy, {OpaquePtrTy, OpaquePtrTy}, false);

Does this assert mean that the PR does not contain any examples where isByRef == false for at least some of the reductions?

If I understood that correctly (i.e. that we only test isByRef[i] == true in this PR), I think it would be better to move the isByRef[i] == false code paths and add additional testing for by-value in a follow-up PR.

Contributor Author

Thanks, that is indeed the case.

I had a conversation with Tom. With concerns about pointer lifetimes, it is best that we go ahead with by-val implementation for now, and wait for delayed privatization for task to get merged before merging a PR about by-ref.
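The two code paths discussed in this thread differ in the shape of the generated helper. The sketch below mirrors the two function types from the diff in plain C++ for an `int` sum; the aliases and function names are hypothetical and the element type is an assumption.

```cpp
#include <cassert>

// Hypothetical aliases for the two function types built in the patch.
using ByRefFn = void (*)(void *, void *);   // returns void
using ByValFn = void *(*)(void *, void *);  // returns a pointer

// By-value combiner: load both operands, combine them, store the result
// through the first pointer, and return that pointer.
inline void *combAddByVal(void *lhs, void *rhs) {
  int a = *static_cast<int *>(lhs);
  int b = *static_cast<int *>(rhs);
  *static_cast<int *>(lhs) = a + b;
  return lhs;
}

// By-reference combiner: the body works through the pointers directly
// and nothing is returned.
inline void combAddByRef(void *lhs, void *rhs) {
  *static_cast<int *>(lhs) += *static_cast<int *>(rhs);
}
```

Both variants leave the combined value in the first argument; only the return convention differs, which is why each needs its own test coverage.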

SmallVector<omp::DeclareReductionOp> reductionDecls;
collectReductionDecls(tgOp, reductionDecls);
SmallVector<llvm::Value *> privateReductionVariables(
tgOp.getNumReductionVars());
DenseMap<Value, llvm::Value *> reductionVariableMap;
MutableArrayRef<BlockArgument> reductionArgs =
tgOp.getRegion().getArguments();

auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
builder.restoreIP(codegenIP);
if (failed(allocAndInitializeTaskReductionVars(
tgOp, reductionArgs, builder, moduleTranslation, allocaIP,
reductionDecls, privateReductionVariables, reductionVariableMap,
isByRef)))
bodyGenStatus = failure();
Comment on lines +2035 to +2039
Member

Looking at allocAndInitializeTaskReductionVars's implementation, looks like we can use llvm::cantFail(...) to wrap this call.

return convertOmpOpRegions(tgOp.getRegion(), "omp.taskgroup.region",
builder, moduleTranslation)
.takeError();
@@ -1812,7 +2052,7 @@ convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
return failure();

builder.restoreIP(*afterIP);
return success();
return bodyGenStatus;
}

static LogicalResult