Skip to content

[OpenMP][MLIR] Lowering task_reduction clause to LLVMIR #111788

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
Original file line number Diff line number Diff line change
Expand Up @@ -1128,6 +1128,13 @@ class OpenMP_TaskReductionClauseSkip<
unsigned numTaskReductionBlockArgs() {
return getTaskReductionVars().size();
}

/// Returns the number of reduction variables, i.e. the size of the range
/// returned by `getReductionVars()`.
// NOTE(review): the neighbouring `numTaskReductionBlockArgs()` counts
// `getTaskReductionVars()` instead; confirm that `getReductionVars()` is
// available on every op that includes this clause.
unsigned getNumReductionVars() { return getReductionVars().size(); }

/// Returns whatever `getTaskReductionSyms()` returns; a thin forwarder that
/// exposes the task_reduction symbols under the name `getReductionSyms`.
auto getReductionSyms() {
return getTaskReductionSyms();
}
}];

let description = [{
Expand Down
228 changes: 223 additions & 5 deletions mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1112,9 +1112,8 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
LogicalResult bodyGenStatus = success();
if (taskOp.getUntiedAttr() || taskOp.getMergeableAttr() ||
taskOp.getInReductionSyms() || taskOp.getPriority() ||
!taskOp.getAllocateVars().empty() || !taskOp.getPrivateVars().empty() ||
taskOp.getPrivateSyms()) {
taskOp.getPriority() || !taskOp.getAllocateVars().empty() ||
!taskOp.getPrivateVars().empty() || taskOp.getPrivateSyms()) {
return taskOp.emitError("unhandled clauses for translation to LLVM IR");
}
auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
Expand Down Expand Up @@ -1142,19 +1141,238 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
return bodyGenStatus;
}

template <typename OP>
llvm::Value *createTaskReductionFunction(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: This should be static and requires a comment.

llvm::IRBuilderBase &builder, const std::string &name, llvm::Type *redTy,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
llvm::IRBuilderBase &builder, const std::string &name, llvm::Type *redTy,
llvm::IRBuilderBase &builder, StringRef name, llvm::Type *redTy,

LLVM::ModuleTranslation &moduleTranslation,
SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls, Region &region,
OP &op, unsigned Cnt, llvm::ArrayRef<bool> &isByRef,
SmallVectorImpl<llvm::Value *> &privateReductionVariables,
DenseMap<Value, llvm::Value *> &reductionVariableMap) {
llvm::LLVMContext &Context = builder.getContext();
llvm::Type *OpaquePtrTy = llvm::PointerType::get(Context, 0);
// TODO: by-ref reduction variables are yet to be handled.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't it better to crash/assert than silently swallowing cases like this?

if (region.empty() || isByRef[Cnt]) {
return llvm::Constant::getNullValue(OpaquePtrTy);
}
Comment on lines +1155 to +1157
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (region.empty() || isByRef[Cnt]) {
return llvm::Constant::getNullValue(OpaquePtrTy);
}
if (region.empty() || isByRef[Cnt])
return llvm::Constant::getNullValue(OpaquePtrTy);

llvm::FunctionType *funcType =
llvm::FunctionType::get(OpaquePtrTy, {OpaquePtrTy, OpaquePtrTy}, false);
Comment on lines +1158 to +1159
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
llvm::FunctionType *funcType =
llvm::FunctionType::get(OpaquePtrTy, {OpaquePtrTy, OpaquePtrTy}, false);
auto *funcType =
llvm::FunctionType::get(OpaquePtrTy, {OpaquePtrTy, OpaquePtrTy}, false);

Nit: We use auto when the type is explicitly given on the RHS already.

llvm::Function *function =
llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, name,
builder.GetInsertBlock()->getModule());
function->setDoesNotRecurse();
llvm::BasicBlock *entry =
llvm::BasicBlock::Create(Context, "entry", function);
llvm::IRBuilder<> bbBuilder(entry);

llvm::Value *arg0 = function->getArg(0);
llvm::Value *arg1 = function->getArg(1);

if (name == "red_init") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Please avoid magic strings. Introduce global constexpr StringLiterals for such things.

function->addParamAttr(0, llvm::Attribute::NoAlias);
function->addParamAttr(1, llvm::Attribute::NoAlias);
mapInitializationArgs(op, moduleTranslation, reductionDecls,
reductionVariableMap, Cnt);
} else if (name == "red_comb") {
llvm::Value *arg0L = bbBuilder.CreateLoad(redTy, arg0);
llvm::Value *arg1L = bbBuilder.CreateLoad(redTy, arg1);
moduleTranslation.mapValue(region.front().getArgument(0), arg0L);
moduleTranslation.mapValue(region.front().getArgument(1), arg1L);
}

SmallVector<llvm::Value *, 1> phis;
if (failed(inlineConvertOmpRegions(region, "", bbBuilder, moduleTranslation,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Adding a block name makes this a lot easier to debug if the region is translated into multiple basic blocks.

&phis)))
return nullptr;
assert(
phis.size() == 1 &&
"expected one value to be yielded from the reduction declaration region");

bbBuilder.CreateStore(phis[0], arg0);
bbBuilder.CreateRet(arg0); // Return from the function
return function;
}

void emitTaskRedInitCall(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: As above, this should be static and requires a comment.

llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
const llvm::OpenMPIRBuilder::LocationDescription &ompLoc, int arraySize,
llvm::Value *ArrayAlloca) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
llvm::Value *ArrayAlloca) {
llvm::Value *arrayAlloca) {

Ultra nit: Variable names that begin with an uppercase letter violate the MLIR style guide.


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change

llvm::LLVMContext &Context = builder.getContext();
uint32_t SrcLocStrSize;
llvm::Constant *SrcLocStr =
moduleTranslation.getOpenMPBuilder()->getOrCreateSrcLocStr(ompLoc,
SrcLocStrSize);
llvm::Value *Ident = moduleTranslation.getOpenMPBuilder()->getOrCreateIdent(
SrcLocStr, SrcLocStrSize);
llvm::Value *ThreadID =
moduleTranslation.getOpenMPBuilder()->getOrCreateThreadID(Ident);
llvm::Constant *ConstInt =
llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), arraySize);

llvm::Function *TaskRedInitFn =
moduleTranslation.getOpenMPBuilder()->getOrCreateRuntimeFunctionPtr(
llvm::omp::OMPRTL___kmpc_taskred_init);
builder.CreateCall(TaskRedInitFn, {ThreadID, ConstInt, ArrayAlloca});
Comment on lines +1213 to +1216
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For non-task reductions these sorts of function calls are generated by OpenMPIRBuilder so that we can share code with clang.

Are the clang people happy with us having diverging implementations here? If so I don't mind.

}

/// Builds the `kmp_taskred_input_t` descriptor array for the reduction
/// variables of `op` and emits the runtime call that registers it.
///
/// One array element is filled per reduction variable with: the LLVM value
/// mapped to the variable (stored as both `reduce_shar` and `reduce_orig`),
/// the allocation size of its type, the `red_init` / `red_fini` / `red_comb`
/// helpers produced by `createTaskReductionFunction` from the matching entry
/// of `reductionDecls`, and a zero `flags` word. The array is allocated in the
/// block of `allocaIP` (before its terminator); the stores and the final
/// `emitTaskRedInitCall` are emitted at the builder's current insertion point.
///
/// \param op operation providing `getNumReductionVars()` and
///        `getReductionVars()`.
/// \param reductionArgs currently unused by this function.
/// \param builder IR builder positioned where the initialization code goes.
/// \param moduleTranslation used to look up the LLVM values of the reduction
///        variables and forwarded to the helpers.
/// \param allocaIP insertion point whose block receives the array alloca.
/// \param reductionDecls reduction declarations, indexed like the reduction
///        variables.
/// \param privateReductionVariables forwarded to
///        `createTaskReductionFunction`.
/// \param reductionVariableMap forwarded to `createTaskReductionFunction`.
/// \param isByRef per-variable by-ref flags, forwarded to
///        `createTaskReductionFunction`.
/// \returns success when there is nothing to do or everything was emitted;
///          failure when a reduction variable is not an `alloca` or when a
///          helper function could not be generated.
template <typename OP>
static LogicalResult allocAndInitializeTaskReductionVars(
    OP op, ArrayRef<BlockArgument> reductionArgs, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation,
    llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
    SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
    SmallVectorImpl<llvm::Value *> &privateReductionVariables,
    DenseMap<Value, llvm::Value *> &reductionVariableMap,
    llvm::ArrayRef<bool> isByRef) {
  if (op.getNumReductionVars() == 0)
    return success();

  // Capture the location while the builder still points at the code-gen
  // position; the insertion point is moved temporarily below.
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::LLVMContext &context = builder.getContext();

  // Temporarily move to the end of the alloca block to allocate the array.
  auto savedIP = builder.saveIP();
  builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());

  // Describe kmp_taskred_input_t.
  llvm::StructType *taskRedInputTy =
      llvm::StructType::create(context, "kmp_taskred_input_t");
  llvm::Type *ptrTy = llvm::PointerType::get(context, 0); // void*
  llvm::Type *sizeTy = builder.getInt64Ty(); // size_t (assumed to be i64)
  llvm::Type *flagsTy = llvm::Type::getInt32Ty(context); // flags (i32)
  llvm::Type *members[] = {
      ptrTy,  // reduce_shar (void*)
      ptrTy,  // reduce_orig (void*)
      sizeTy, // reduce_size (size_t)
      ptrTy,  // reduce_init (void*)
      ptrTy,  // reduce_fini (void*)
      ptrTy,  // reduce_comb (void*)
      flagsTy // flags (i32)
  };
  taskRedInputTy->setBody(members);

  const int arraySize = op.getNumReductionVars();
  llvm::ArrayType *arrayTy = llvm::ArrayType::get(taskRedInputTy, arraySize);
  llvm::AllocaInst *arrayAlloca =
      builder.CreateAlloca(arrayTy, nullptr, "kmp_taskred_array");

  // Back to the code-gen position for the per-element initialization.
  builder.restoreIP(savedIP);
  const llvm::DataLayout &dataLayout =
      builder.GetInsertBlock()->getModule()->getDataLayout();

  for (int Cnt = 0; Cnt < arraySize; ++Cnt) {
    llvm::Value *shared =
        moduleTranslation.lookupValue(op.getReductionVars()[Cnt]);

    // The element size is derived from the allocated type, so only variables
    // backed by an alloca can be handled. Report that instead of aborting.
    auto *sharedAlloca = dyn_cast<llvm::AllocaInst>(shared);
    if (!sharedAlloca)
      return op.emitError("task reduction variable is not defined by an "
                          "alloca; unhandled for translation to LLVM IR");
    llvm::Type *redTy = sharedAlloca->getAllocatedType();

    // Address of this variable's descriptor inside the array.
    llvm::Value *elementPtr = builder.CreateGEP(
        arrayTy, arrayAlloca, {builder.getInt32(0), builder.getInt32(Cnt)},
        "red_element");

    // Address of the descriptor field `index`.
    auto fieldAddr = [&](unsigned index, const char *fieldName) {
      return builder.CreateStructGEP(taskRedInputTy, elementPtr, index,
                                     fieldName);
    };
    // Generates one of the runtime callbacks from a declaration region.
    auto makeCallback = [&](const char *fnName,
                            Region &region) -> llvm::Value * {
      return createTaskReductionFunction(
          builder, fnName, redTy, moduleTranslation, reductionDecls, region,
          op, Cnt, isByRef, privateReductionVariables, reductionVariableMap);
    };

    // The same pointer is recorded as both the shared and the original item.
    builder.CreateStore(shared, fieldAddr(0, "reduce_shar"));
    builder.CreateStore(shared, fieldAddr(1, "reduce_orig"));

    // Size of the reduction variable in bytes.
    llvm::Value *sizeAddr = fieldAddr(2, "reduce_size");
    const uint64_t sizeInBytes = dataLayout.getTypeAllocSize(redTy);
    builder.CreateStore(llvm::ConstantInt::get(sizeTy, sizeInBytes), sizeAddr);

    // Initializer callback.
    llvm::Value *initAddr = fieldAddr(3, "reduce_init");
    llvm::Value *initFunction =
        makeCallback("red_init", reductionDecls[Cnt].getInitializerRegion());
    if (!initFunction)
      return failure();
    builder.CreateStore(initFunction, initAddr);

    // Finalizer callback.
    llvm::Value *finiAddr = fieldAddr(4, "reduce_fini");
    llvm::Value *finiFunction =
        makeCallback("red_fini", reductionDecls[Cnt].getCleanupRegion());
    if (!finiFunction)
      return failure();
    builder.CreateStore(finiFunction, finiAddr);

    // Combiner callback.
    llvm::Value *combAddr = fieldAddr(5, "reduce_comb");
    llvm::Value *combFunction =
        makeCallback("red_comb", reductionDecls[Cnt].getReductionRegion());
    if (!combFunction)
      return failure();
    builder.CreateStore(combFunction, combAddr);

    // NOTE(review): the flags member is declared as i32 above, yet an i64
    // zero is stored here. Kept as-is to preserve the emitted IR; confirm
    // whether this should be an i32 store.
    llvm::Value *flagsAddr = fieldAddr(6, "flags");
    builder.CreateStore(
        llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), 0), flagsAddr);
  }

  // Hand the filled array to the runtime.
  emitTaskRedInitCall(builder, moduleTranslation, ompLoc, arraySize,
                      arrayAlloca);
  return success();
}

/// Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
LogicalResult bodyGenStatus = success();
if (!tgOp.getTaskReductionVars().empty() || !tgOp.getAllocateVars().empty()) {
if (!tgOp.getAllocateVars().empty()) {
return tgOp.emitError("unhandled clauses for translation to LLVM IR");
}

llvm::ArrayRef<bool> isByRef = getIsByRef(tgOp.getTaskReductionByref());
assert(isByRef.size() == tgOp.getNumReductionVars());

SmallVector<omp::DeclareReductionOp> reductionDecls;
collectReductionDecls(tgOp, reductionDecls);
SmallVector<llvm::Value *> privateReductionVariables(
tgOp.getNumReductionVars());
DenseMap<Value, llvm::Value *> reductionVariableMap;

MutableArrayRef<BlockArgument> reductionArgs =
tgOp.getRegion().getArguments();

auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
builder.restoreIP(codegenIP);

if (failed(allocAndInitializeTaskReductionVars(
tgOp, reductionArgs, builder, moduleTranslation, allocaIP,
reductionDecls, privateReductionVariables, reductionVariableMap,
isByRef)))
bodyGenStatus = failure();
SmallVector<llvm::PHINode *> phis;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The phis seem unused here. Can the task reduction region omp.yield any values or is omp.terminator the only terminator?

convertOmpOpRegions(tgOp.getRegion(), "omp.taskgroup.region", builder,
moduleTranslation, bodyGenStatus);
moduleTranslation, bodyGenStatus, &phis);
};
InsertPointTy allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
Expand Down
79 changes: 79 additions & 0 deletions mlir/test/Target/LLVMIR/openmp-task-reduction.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s

// Reduction declaration over i32: the init region yields the constant 0 and
// the combiner yields the `llvm.add` of its two block arguments.
omp.declare_reduction @add_reduction_i32 : i32 init {
^bb0(%arg0: i32):
%0 = llvm.mlir.constant(0 : i32) : i32
omp.yield(%0 : i32)
} combiner {
^bb0(%arg0: i32, %arg1: i32):
%0 = llvm.add %arg0, %arg1 : i32
omp.yield(%0 : i32)
}
llvm.func @_QPtest_task_reduciton() {
%0 = llvm.mlir.constant(1 : i64) : i64
%1 = llvm.alloca %0 x i32 {bindc_name = "x"} : (i64) -> !llvm.ptr
omp.taskgroup task_reduction(@add_reduction_i32 %1 -> %arg0 : !llvm.ptr) {
%2 = llvm.load %1 : !llvm.ptr -> i32
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know this is only test code but shouldn't this be using the block argument?

%3 = llvm.mlir.constant(1 : i32) : i32
%4 = llvm.add %2, %3 : i32
llvm.store %4, %1 : i32, !llvm.ptr
omp.terminator
}
llvm.return
}

//CHECK-LABEL: define void @_QPtest_task_reduciton() {
//CHECK: %[[VAL1:.*]] = alloca i32, i64 1, align 4
//CHECK: %[[RED_ARRY:.*]] = alloca [1 x %kmp_taskred_input_t], align 8
//CHECK: br label %entry

//CHECK: entry:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a bit odd. Normally, "entry" is really the first block of a function.

//CHECK: %[[TID:.*]] = call i32 @__kmpc_global_thread_num(ptr @{{.*}})
//CHECK: call void @__kmpc_taskgroup(ptr @1, i32 %[[TID]])
//CHECK: %[[RED_ELEMENT:.*]] = getelementptr [1 x %kmp_taskred_input_t], ptr %[[RED_ARRY]], i32 0, i32 0
//CHECK: %[[RED_SHARED:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 0
//CHECK: store ptr %[[VAL1]], ptr %[[RED_SHARED]], align 8
//CHECK: %[[RED_ORIG:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 1
//CHECK: store ptr %[[VAL1]], ptr %[[RED_ORIG]], align 8
//CHECK: %[[RED_SIZE:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 2
//CHECK: store i64 4, ptr %[[RED_SIZE]], align 4
//CHECK: %[[RED_INIT:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 3
//CHECK: store ptr @red_init, ptr %[[RED_INIT]], align 8
//CHECK: %[[RED_FINI:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 4
//CHECK: store ptr null, ptr %[[RED_FINI]], align 8
//CHECK: %[[RED_COMB:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 5
//CHECK: store ptr @red_comb, ptr %[[RED_COMB]], align 8
//CHECK: %[[FLAGS:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 6
//CHECK: store i64 0, ptr %[[FLAGS]], align 4
//CHECK: %[[TID1:.*]] = call i32 @__kmpc_global_thread_num(ptr @{{.*}})
//CHECK: %2 = call ptr @__kmpc_taskred_init(i32 %[[TID1]], i32 1, ptr %[[RED_ARRY]])
//CHECK: br label %omp.taskgroup.region

//CHECK: omp.taskgroup.region:
//CHECK: %[[VAL3:.*]] = load i32, ptr %[[VAL1]], align 4
//CHECK: %4 = add i32 %[[VAL3]], 1
//CHECK: store i32 %4, ptr %[[VAL1]], align 4
//CHECK: br label %omp.region.cont

//CHECK: omp.region.cont:
//CHECK: br label %taskgroup.exit

//CHECK: taskgroup.exit:
//CHECK: call void @__kmpc_end_taskgroup(ptr @{{.+}}, i32 %[[TID]])
//CHECK: ret void
//CHECK: }

//CHECK-LABEL: define ptr @red_init(ptr noalias %0, ptr noalias %1) #2 {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This name could collide with symbols in the program. clang uses .red_init. (and likewise .red_comb.).

//CHECK: entry:
//CHECK: store i32 0, ptr %0, align 4
//CHECK: ret ptr %0
//CHECK: }

//CHECK-LABEL: define ptr @red_comb(ptr %0, ptr %1) #2 {
//CHECK: entry:
//CHECK: %[[LD0:.*]] = load i32, ptr %0, align 4
//CHECK: %[[LD1:.*]] = load i32, ptr %1, align 4
//CHECK: %[[RES:.*]] = add i32 %[[LD0]], %[[LD1]]
//CHECK: store i32 %[[RES]], ptr %0, align 4
//CHECK: ret ptr %0
//CHECK: }
Loading