Skip to content

Commit c4b7bc7

Browse files
committed
Added customMapper error propagation. Updated test.
1 parent 826bc9d commit c4b7bc7

File tree

5 files changed

+92
-67
lines changed

5 files changed

+92
-67
lines changed

clang/lib/CodeGen/CGOpenMPRuntime.cpp

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8880,17 +8880,17 @@ static void emitOffloadingArraysAndArgs(
88808880
};
88818881

88828882
auto CustomMapperCB = [&](unsigned int I) {
8883-
llvm::Value *MFunc = nullptr;
8883+
llvm::Function *MFunc = nullptr;
88848884
if (CombinedInfo.Mappers[I]) {
88858885
Info.HasMapper = true;
88868886
MFunc = CGM.getOpenMPRuntime().getOrCreateUserDefinedMapperFunc(
88878887
cast<OMPDeclareMapperDecl>(CombinedInfo.Mappers[I]));
88888888
}
88898889
return MFunc;
88908890
};
8891-
OMPBuilder.emitOffloadingArraysAndArgs(
8891+
cantFail(OMPBuilder.emitOffloadingArraysAndArgs(
88928892
AllocaIP, CodeGenIP, Info, Info.RTArgs, CombinedInfo, CustomMapperCB,
8893-
IsNonContiguous, ForEndCall, DeviceAddrCB);
8893+
IsNonContiguous, ForEndCall, DeviceAddrCB));
88948894
}
88958895

88968896
/// Check for inner distribute directive.
@@ -9083,26 +9083,25 @@ void CGOpenMPRuntime::emitUserDefinedMapper(const OMPDeclareMapperDecl *D,
90839083
return CombinedInfo;
90849084
};
90859085

9086-
auto CustomMapperCB = [&](unsigned I, llvm::Function **MapperFunc) {
9086+
auto CustomMapperCB = [&](unsigned I) {
9087+
llvm::Function *MapperFunc = nullptr;
90879088
if (CombinedInfo.Mappers[I]) {
90889089
// Call the corresponding mapper function.
9089-
*MapperFunc = getOrCreateUserDefinedMapperFunc(
9090+
MapperFunc = getOrCreateUserDefinedMapperFunc(
90909091
cast<OMPDeclareMapperDecl>(CombinedInfo.Mappers[I]));
9091-
assert(*MapperFunc && "Expect a valid mapper function is available.");
9092-
return true;
9092+
assert(MapperFunc && "Expect a valid mapper function is available.");
90939093
}
9094-
return false;
9094+
return MapperFunc;
90959095
};
90969096

90979097
SmallString<64> TyStr;
90989098
llvm::raw_svector_ostream Out(TyStr);
90999099
CGM.getCXXABI().getMangleContext().mangleCanonicalTypeName(Ty, Out);
91009100
std::string Name = getName({"omp_mapper", TyStr, D->getName()});
91019101

9102-
llvm::Expected<llvm::Function *> NewFn = OMPBuilder.emitUserDefinedMapper(
9103-
PrivatizeAndGenMapInfoCB, ElemTy, Name, CustomMapperCB);
9104-
assert(NewFn && "Unexpected error in emitUserDefinedMapper");
9105-
UDMMap.try_emplace(D, *NewFn);
9102+
llvm::Function *NewFn = cantFail(OMPBuilder.emitUserDefinedMapper(
9103+
PrivatizeAndGenMapInfoCB, ElemTy, Name, CustomMapperCB));
9104+
UDMMap.try_emplace(D, NewFn);
91069105
if (CGF)
91079106
FunctionUDMMap[CGF->CurFn].push_back(D);
91089107
}
@@ -10075,7 +10074,7 @@ void CGOpenMPRuntime::emitTargetDataCalls(
1007510074
};
1007610075

1007710076
auto CustomMapperCB = [&](unsigned int I) {
10078-
llvm::Value *MFunc = nullptr;
10077+
llvm::Function *MFunc = nullptr;
1007910078
if (CombinedInfo.Mappers[I]) {
1008010079
Info.HasMapper = true;
1008110080
MFunc = CGF.CGM.getOpenMPRuntime().getOrCreateUserDefinedMapperFunc(

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2408,6 +2408,11 @@ class OpenMPIRBuilder {
24082408
using EmitFallbackCallbackTy =
24092409
function_ref<InsertPointOrErrorTy(InsertPointTy)>;
24102410

2411+
// Callback function type for emitting and fetching user defined custom
2412+
// mappers.
2413+
using CustomMapperCallbackTy =
2414+
function_ref<Expected<Function *>(unsigned int)>;
2415+
24112416
/// Generate a target region entry call and host fallback call.
24122417
///
24132418
/// \param Loc The location at which the request originated and is fulfilled.
@@ -2474,9 +2479,9 @@ class OpenMPIRBuilder {
24742479
/// return nullptr by reference. Accepts a reference to a MapInfosTy object
24752480
/// that contains information generated for mappable clauses,
24762481
/// including base pointers, pointers, sizes, map types, user-defined mappers.
2477-
void emitOffloadingArrays(
2482+
Error emitOffloadingArrays(
24782483
InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo,
2479-
TargetDataInfo &Info, function_ref<Value *(unsigned int)> CustomMapperCB,
2484+
TargetDataInfo &Info, CustomMapperCallbackTy CustomMapperCB,
24802485
bool IsNonContiguous = false,
24812486
function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr);
24822487

@@ -2486,11 +2491,11 @@ class OpenMPIRBuilder {
24862491
/// library. In essence, this function is a combination of
24872492
/// emitOffloadingArrays and emitOffloadingArraysArgument and should arguably
24882493
/// be preferred by clients of OpenMPIRBuilder.
2489-
void emitOffloadingArraysAndArgs(
2494+
Error emitOffloadingArraysAndArgs(
24902495
InsertPointTy AllocaIP, InsertPointTy CodeGenIP, TargetDataInfo &Info,
24912496
TargetDataRTArgs &RTArgs, MapInfosTy &CombinedInfo,
2492-
function_ref<Value *(unsigned int)> CustomMapperCB,
2493-
bool IsNonContiguous = false, bool ForEndCall = false,
2497+
CustomMapperCallbackTy CustomMapperCB, bool IsNonContiguous = false,
2498+
bool ForEndCall = false,
24942499
function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr);
24952500

24962501
/// Creates offloading entry for the provided entry ID \a ID, address \a
@@ -2956,7 +2961,7 @@ class OpenMPIRBuilder {
29562961
InsertPointTy CodeGenIP, llvm::Value *PtrPHI, llvm::Value *BeginArg)>
29572962
PrivAndGenMapInfoCB,
29582963
llvm::Type *ElemTy, StringRef FuncName,
2959-
function_ref<bool(unsigned int, Function **)> CustomMapperCB);
2964+
CustomMapperCallbackTy CustomMapperCB);
29602965

29612966
/// Generator for '#omp target data'
29622967
///
@@ -2979,7 +2984,7 @@ class OpenMPIRBuilder {
29792984
const LocationDescription &Loc, InsertPointTy AllocaIP,
29802985
InsertPointTy CodeGenIP, Value *DeviceID, Value *IfCond,
29812986
TargetDataInfo &Info, GenMapInfoCallbackTy GenMapInfoCB,
2982-
function_ref<Value *(unsigned int)> CustomMapperCB,
2987+
CustomMapperCallbackTy CustomMapperCB,
29832988
omp::RuntimeFunction *MapperFunc = nullptr,
29842989
function_ref<InsertPointOrErrorTy(InsertPointTy CodeGenIP,
29852990
BodyGenTy BodyGenType)>
@@ -3028,7 +3033,7 @@ class OpenMPIRBuilder {
30283033
SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
30293034
TargetBodyGenCallbackTy BodyGenCB,
30303035
TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
3031-
function_ref<Value *(unsigned int)> CustomMapperCB,
3036+
CustomMapperCallbackTy CustomMapperCB,
30323037
SmallVector<DependData> Dependencies = {}, bool HasNowait = false);
30333038

30343039
/// Returns __kmpc_for_static_init_* runtime function for the specified

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 40 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6549,8 +6549,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTargetData(
65496549
const LocationDescription &Loc, InsertPointTy AllocaIP,
65506550
InsertPointTy CodeGenIP, Value *DeviceID, Value *IfCond,
65516551
TargetDataInfo &Info, GenMapInfoCallbackTy GenMapInfoCB,
6552-
function_ref<Value *(unsigned int)> CustomMapperCB,
6553-
omp::RuntimeFunction *MapperFunc,
6552+
CustomMapperCallbackTy CustomMapperCB, omp::RuntimeFunction *MapperFunc,
65546553
function_ref<InsertPointOrErrorTy(InsertPointTy CodeGenIP,
65556554
BodyGenTy BodyGenType)>
65566555
BodyGenCB,
@@ -6579,9 +6578,10 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTargetData(
65796578
auto BeginThenGen = [&](InsertPointTy AllocaIP,
65806579
InsertPointTy CodeGenIP) -> Error {
65816580
MapInfo = &GenMapInfoCB(Builder.saveIP());
6582-
emitOffloadingArrays(AllocaIP, Builder.saveIP(), *MapInfo, Info,
6583-
CustomMapperCB,
6584-
/*IsNonContiguous=*/true, DeviceAddrCB);
6581+
if (Error Err = emitOffloadingArrays(
6582+
AllocaIP, Builder.saveIP(), *MapInfo, Info, CustomMapperCB,
6583+
/*IsNonContiguous=*/true, DeviceAddrCB))
6584+
return Err;
65856585

65866586
TargetDataRTArgs RTArgs;
65876587
emitOffloadingArraysArgument(Builder, RTArgs, Info);
@@ -7392,14 +7392,17 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitTargetTask(
73927392
return Builder.saveIP();
73937393
}
73947394

7395-
void OpenMPIRBuilder::emitOffloadingArraysAndArgs(
7395+
Error OpenMPIRBuilder::emitOffloadingArraysAndArgs(
73967396
InsertPointTy AllocaIP, InsertPointTy CodeGenIP, TargetDataInfo &Info,
73977397
TargetDataRTArgs &RTArgs, MapInfosTy &CombinedInfo,
7398-
function_ref<Value *(unsigned int)> CustomMapperCB, bool IsNonContiguous,
7398+
CustomMapperCallbackTy CustomMapperCB, bool IsNonContiguous,
73997399
bool ForEndCall, function_ref<void(unsigned int, Value *)> DeviceAddrCB) {
7400-
emitOffloadingArrays(AllocaIP, CodeGenIP, CombinedInfo, Info, CustomMapperCB,
7401-
IsNonContiguous, DeviceAddrCB);
7400+
if (Error Err =
7401+
emitOffloadingArrays(AllocaIP, CodeGenIP, CombinedInfo, Info,
7402+
CustomMapperCB, IsNonContiguous, DeviceAddrCB))
7403+
return Err;
74027404
emitOffloadingArraysArgument(Builder, RTArgs, Info, ForEndCall);
7405+
return Error::success();
74037406
}
74047407

74057408
static void
@@ -7411,7 +7414,7 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
74117414
Value *IfCond, Function *OutlinedFn, Constant *OutlinedFnID,
74127415
SmallVectorImpl<Value *> &Args,
74137416
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
7414-
function_ref<Value *(unsigned int)> CustomMapperCB,
7417+
OpenMPIRBuilder::CustomMapperCallbackTy CustomMapperCB,
74157418
SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies,
74167419
bool HasNoWait) {
74177420
// Generate a function call to the host fallback implementation of the target
@@ -7486,10 +7489,11 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
74867489
OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
74877490
OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
74887491
OpenMPIRBuilder::TargetDataRTArgs RTArgs;
7489-
OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info,
7490-
RTArgs, MapInfo, CustomMapperCB,
7491-
/*IsNonContiguous=*/true,
7492-
/*ForEndCall=*/false);
7492+
if (Error Err = OMPBuilder.emitOffloadingArraysAndArgs(
7493+
AllocaIP, Builder.saveIP(), Info, RTArgs, MapInfo, CustomMapperCB,
7494+
/*IsNonContiguous=*/true,
7495+
/*ForEndCall=*/false))
7496+
return Err;
74937497

74947498
SmallVector<Value *, 3> NumTeamsC;
74957499
for (auto [DefaultVal, RuntimeVal] :
@@ -7598,8 +7602,8 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
75987602
SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
75997603
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
76007604
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
7601-
function_ref<Value *(unsigned int)> CustomMapperCB,
7602-
SmallVector<DependData> Dependencies, bool HasNowait) {
7605+
CustomMapperCallbackTy CustomMapperCB, SmallVector<DependData> Dependencies,
7606+
bool HasNowait) {
76037607

76047608
if (!updateToLocation(Loc))
76057609
return InsertPointTy();
@@ -7951,8 +7955,7 @@ Expected<Function *> OpenMPIRBuilder::emitUserDefinedMapper(
79517955
function_ref<MapInfosOrErrorTy(InsertPointTy CodeGenIP, llvm::Value *PtrPHI,
79527956
llvm::Value *BeginArg)>
79537957
GenMapInfoCB,
7954-
Type *ElemTy, StringRef FuncName,
7955-
function_ref<bool(unsigned int, Function **)> CustomMapperCB) {
7958+
Type *ElemTy, StringRef FuncName, CustomMapperCallbackTy CustomMapperCB) {
79567959
SmallVector<Type *> Params;
79577960
Params.emplace_back(Builder.getPtrTy());
79587961
Params.emplace_back(Builder.getPtrTy());
@@ -8132,17 +8135,19 @@ Expected<Function *> OpenMPIRBuilder::emitUserDefinedMapper(
81328135

81338136
Value *OffloadingArgs[] = {MapperHandle, CurBaseArg, CurBeginArg,
81348137
CurSizeArg, CurMapType, CurNameArg};
8135-
Function *ChildMapperFn = nullptr;
8136-
if (CustomMapperCB && CustomMapperCB(I, &ChildMapperFn)) {
8138+
8139+
auto ChildMapperFn = CustomMapperCB(I);
8140+
if (!ChildMapperFn)
8141+
return ChildMapperFn.takeError();
8142+
if (*ChildMapperFn)
81378143
// Call the corresponding mapper function.
8138-
Builder.CreateCall(ChildMapperFn, OffloadingArgs)->setDoesNotThrow();
8139-
} else {
8144+
Builder.CreateCall(*ChildMapperFn, OffloadingArgs)->setDoesNotThrow();
8145+
else
81408146
// Call the runtime API __tgt_push_mapper_component to fill up the runtime
81418147
// data structure.
81428148
Builder.CreateCall(
81438149
getOrCreateRuntimeFunction(M, OMPRTL___tgt_push_mapper_component),
81448150
OffloadingArgs);
8145-
}
81468151
}
81478152

81488153
// Update the pointer to point to the next element that needs to be mapped,
@@ -8169,9 +8174,9 @@ Expected<Function *> OpenMPIRBuilder::emitUserDefinedMapper(
81698174
return MapperFn;
81708175
}
81718176

8172-
void OpenMPIRBuilder::emitOffloadingArrays(
8177+
Error OpenMPIRBuilder::emitOffloadingArrays(
81738178
InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo,
8174-
TargetDataInfo &Info, function_ref<Value *(unsigned int)> CustomMapperCB,
8179+
TargetDataInfo &Info, CustomMapperCallbackTy CustomMapperCB,
81758180
bool IsNonContiguous,
81768181
function_ref<void(unsigned int, Value *)> DeviceAddrCB) {
81778182

@@ -8180,7 +8185,7 @@ void OpenMPIRBuilder::emitOffloadingArrays(
81808185
Info.NumberOfPtrs = CombinedInfo.BasePointers.size();
81818186

81828187
if (Info.NumberOfPtrs == 0)
8183-
return;
8188+
return Error::success();
81848189

81858190
Builder.restoreIP(AllocaIP);
81868191
// Detect if we have any capture size requiring runtime evaluation of the
@@ -8344,9 +8349,13 @@ void OpenMPIRBuilder::emitOffloadingArrays(
83448349
// Fill up the mapper array.
83458350
unsigned IndexSize = M.getDataLayout().getIndexSizeInBits(0);
83468351
Value *MFunc = ConstantPointerNull::get(PtrTy);
8347-
if (CustomMapperCB)
8348-
if (Value *CustomMFunc = CustomMapperCB(I))
8349-
MFunc = Builder.CreatePointerCast(CustomMFunc, PtrTy);
8352+
8353+
auto CustomMFunc = CustomMapperCB(I);
8354+
if (!CustomMFunc)
8355+
return CustomMFunc.takeError();
8356+
if (*CustomMFunc)
8357+
MFunc = Builder.CreatePointerCast(*CustomMFunc, PtrTy);
8358+
83508359
Value *MAddr = Builder.CreateInBoundsGEP(
83518360
MappersArray->getAllocatedType(), MappersArray,
83528361
{Builder.getIntN(IndexSize, 0), Builder.getIntN(IndexSize, I)});
@@ -8356,8 +8365,9 @@ void OpenMPIRBuilder::emitOffloadingArrays(
83568365

83578366
if (!IsNonContiguous || CombinedInfo.NonContigInfo.Offsets.empty() ||
83588367
Info.NumberOfPtrs == 0)
8359-
return;
8368+
return Error::success();
83608369
emitNonContiguousDescriptor(AllocaIP, CodeGenIP, CombinedInfo, Info);
8370+
return Error::success();
83618371
}
83628372

83638373
void OpenMPIRBuilder::emitBranch(BasicBlock *Target) {

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3608,16 +3608,17 @@ emitUserDefinedMapper(Operation *op, llvm::IRBuilderBase &builder,
36083608
return combinedInfo;
36093609
};
36103610

3611-
auto customMapperCB = [&](unsigned i, llvm::Function **mapperFunc) {
3611+
auto customMapperCB = [&](unsigned i) -> llvm::Expected<llvm::Function *> {
3612+
llvm::Function *mapperFunc = nullptr;
36123613
if (combinedInfo.Mappers[i]) {
36133614
// Call the corresponding mapper function.
36143615
llvm::Expected<llvm::Function *> newFn = getOrCreateUserDefinedMapperFunc(
36153616
combinedInfo.Mappers[i], builder, moduleTranslation);
3616-
assert(newFn && "Expect a valid mapper function is available");
3617-
*mapperFunc = *newFn;
3618-
return true;
3617+
if (!newFn)
3618+
return newFn.takeError();
3619+
mapperFunc = *newFn;
36193620
}
3620-
return false;
3621+
return mapperFunc;
36213622
};
36223623

36233624
llvm::Expected<llvm::Function *> newFn = ompBuilder->emitUserDefinedMapper(
@@ -3840,13 +3841,15 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
38403841
return builder.saveIP();
38413842
};
38423843

3843-
auto customMapperCB = [&](unsigned int i) {
3844+
auto customMapperCB =
3845+
[&](unsigned int i) -> llvm::Expected<llvm::Function *> {
38443846
llvm::Function *mapperFunc = nullptr;
38453847
if (combinedInfo.Mappers[i]) {
38463848
info.HasMapper = true;
38473849
llvm::Expected<llvm::Function *> newFn = getOrCreateUserDefinedMapperFunc(
38483850
combinedInfo.Mappers[i], builder, moduleTranslation);
3849-
assert(newFn && "Expect a valid mapper function is available");
3851+
if (!newFn)
3852+
return newFn.takeError();
38503853
mapperFunc = *newFn;
38513854
}
38523855
return mapperFunc;
@@ -4551,13 +4554,15 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
45514554
/*RequiresDevicePointerInfo=*/false,
45524555
/*SeparateBeginEndCalls=*/true);
45534556

4554-
auto customMapperCB = [&](unsigned int i) {
4555-
llvm::Value *mapperFunc = nullptr;
4557+
auto customMapperCB =
4558+
[&](unsigned int i) -> llvm::Expected<llvm::Function *> {
4559+
llvm::Function *mapperFunc = nullptr;
45564560
if (combinedInfos.Mappers[i]) {
45574561
info.HasMapper = true;
45584562
llvm::Expected<llvm::Function *> newFn = getOrCreateUserDefinedMapperFunc(
45594563
combinedInfos.Mappers[i], builder, moduleTranslation);
4560-
assert(newFn && "Expect a valid mapper function is available");
4564+
if (!newFn)
4565+
return newFn.takeError();
45614566
mapperFunc = *newFn;
45624567
}
45634568
return mapperFunc;

offload/test/offloading/fortran/target-custom-mapper.f90

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,26 +10,32 @@ program test_openmp_mapper
1010
integer :: data(n)
1111
end type mytype
1212

13+
type :: mytype2
14+
type(mytype) :: my_data
15+
end type mytype2
16+
1317
! Declare custom mappers for the derived type `mytype`
14-
!$omp declare mapper(my_mapper1 : mytype :: t) map(to: t%data)
15-
!$omp declare mapper(my_mapper2 : mytype :: t) map(mapper(my_mapper1): t%data)
18+
!$omp declare mapper(my_mapper1 : mytype :: t) map(to: t%data(1 : n))
19+
20+
! Declare custom mappers for the derived type `mytype2`
21+
!$omp declare mapper(my_mapper2 : mytype2 :: t) map(mapper(my_mapper1): t%my_data)
1622

17-
type(mytype) :: obj
23+
type(mytype2) :: obj
1824
integer :: i, sum_host, sum_device
1925

2026
! Initialize the host data
2127
do i = 1, n
22-
obj%data(i) = 1
28+
obj%my_data%data(i) = 1
2329
end do
2430

2531
! Compute the sum on the host for verification
26-
sum_host = sum(obj%data)
32+
sum_host = sum(obj%my_data%data)
2733

2834
! Offload computation to the device using the named mapper `my_mapper2`
2935
sum_device = 0
3036
!$omp target map(tofrom: sum_device) map(mapper(my_mapper2) : obj)
3137
do i = 1, n
32-
sum_device = sum_device + obj%data(i)
38+
sum_device = sum_device + obj%my_data%data(i)
3339
end do
3440
!$omp end target
3541

0 commit comments

Comments
 (0)