@@ -3554,27 +3554,27 @@ emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder,
3554
3554
LLVM::ModuleTranslation &moduleTranslation);
3555
3555
3556
3556
static llvm::Expected<llvm::Function *>
3557
- getOrCreateUserDefinedMapperFunc(Operation *declMapperOp,
3558
- llvm::IRBuilderBase &builder,
3557
+ getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder,
3559
3558
LLVM::ModuleTranslation &moduleTranslation) {
3560
- static llvm::DenseMap<const Operation *, llvm::Function *> userDefMapperMap;
3561
- auto iter = userDefMapperMap.find(declMapperOp);
3562
- if (iter != userDefMapperMap.end())
3563
- return iter->second;
3559
+ auto declMapperOp = cast<omp::DeclareMapperOp>(op);
3560
+ std::string mapperFuncName =
3561
+ moduleTranslation.getOpenMPBuilder()->createPlatformSpecificName(
3562
+ {"omp_mapper", declMapperOp.getSymName()});
3563
+ if (auto *lookupFunc = moduleTranslation.lookupFunction(mapperFuncName))
3564
+ return lookupFunc;
3565
+
3564
3566
llvm::Expected<llvm::Function *> mapperFunc =
3565
3567
emitUserDefinedMapper(declMapperOp, builder, moduleTranslation);
3566
3568
if (!mapperFunc)
3567
3569
return mapperFunc.takeError();
3568
- userDefMapperMap.try_emplace(declMapperOp, *mapperFunc);
3569
3570
return mapperFunc;
3570
3571
}
3571
3572
3572
3573
static llvm::Expected<llvm::Function *>
3573
3574
emitUserDefinedMapper(Operation *op, llvm::IRBuilderBase &builder,
3574
3575
LLVM::ModuleTranslation &moduleTranslation) {
3575
3576
auto declMapperOp = cast<omp::DeclareMapperOp>(op);
3576
- auto declMapperInfoOp =
3577
- *declMapperOp.getOps<omp::DeclareMapperInfoOp>().begin();
3577
+ auto declMapperInfoOp = declMapperOp.getDeclareMapperInfo();
3578
3578
DataLayout dl = DataLayout(declMapperOp->getParentOfType<ModuleOp>());
3579
3579
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3580
3580
llvm::Type *varType = moduleTranslation.convertType(declMapperOp.getType());
@@ -3590,7 +3590,7 @@ emitUserDefinedMapper(Operation *op, llvm::IRBuilderBase &builder,
3590
3590
[&](InsertPointTy codeGenIP, llvm::Value *ptrPHI,
3591
3591
llvm::Value *unused2) -> llvm::OpenMPIRBuilder::MapInfosOrErrorTy {
3592
3592
builder.restoreIP(codeGenIP);
3593
- moduleTranslation.mapValue(declMapperOp.getRegion().getArgument(0 ), ptrPHI);
3593
+ moduleTranslation.mapValue(declMapperOp.getSymVal( ), ptrPHI);
3594
3594
moduleTranslation.mapBlock(&declMapperOp.getRegion().front(),
3595
3595
builder.GetInsertBlock());
3596
3596
if (failed(moduleTranslation.convertBlock(declMapperOp.getRegion().front(),
@@ -3857,10 +3857,11 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
3857
3857
findAllocaInsertPoint(builder, moduleTranslation);
3858
3858
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() {
3859
3859
if (isa<omp::TargetDataOp>(op))
3860
- return ompBuilder->createTargetData(
3861
- ompLoc, allocaIP, builder.saveIP(), builder.getInt64(deviceID),
3862
- ifCond, info, genMapInfoCB, customMapperCB, nullptr, bodyGenCB,
3863
- /*DeviceAddrCB=*/nullptr);
3860
+ return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
3861
+ builder.getInt64(deviceID), ifCond,
3862
+ info, genMapInfoCB, customMapperCB,
3863
+ /*MapperFunc=*/nullptr, bodyGenCB,
3864
+ /*DeviceAddrCB=*/nullptr);
3864
3865
return ompBuilder->createTargetData(
3865
3866
ompLoc, allocaIP, builder.saveIP(), builder.getInt64(deviceID), ifCond,
3866
3867
info, genMapInfoCB, customMapperCB, &RTLFn);
@@ -4546,25 +4547,6 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
4546
4547
findAllocaInsertPoint(builder, moduleTranslation);
4547
4548
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4548
4549
4549
- llvm::OpenMPIRBuilder::TargetDataInfo info(
4550
- /*RequiresDevicePointerInfo=*/false,
4551
- /*SeparateBeginEndCalls=*/true);
4552
- llvm::Value *ifCond = nullptr;
4553
- if (Value targetIfCond = targetOp.getIfExpr())
4554
- ifCond = moduleTranslation.lookupValue(targetIfCond);
4555
-
4556
- auto customMapperCB = [&](unsigned int i) {
4557
- llvm::Value *mapperFunc = nullptr;
4558
- if (combinedInfos.Mappers[i]) {
4559
- info.HasMapper = true;
4560
- llvm::Expected<llvm::Function *> newFn = getOrCreateUserDefinedMapperFunc(
4561
- combinedInfos.Mappers[i], builder, moduleTranslation);
4562
- assert(newFn && "Expect a valid mapper function is available");
4563
- mapperFunc = *newFn;
4564
- }
4565
- return mapperFunc;
4566
- };
4567
-
4568
4550
llvm::OpenMPIRBuilder::TargetDataInfo info(
4569
4551
/*RequiresDevicePointerInfo=*/false,
4570
4552
/*SeparateBeginEndCalls=*/true);
@@ -4588,8 +4570,8 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
4588
4570
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4589
4571
moduleTranslation.getOpenMPBuilder()->createTarget(
4590
4572
ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info, entryInfo,
4591
- defaultValTeams, defaultValThreads , kernelInput, genMapInfoCB, bodyCB,
4592
- argAccessorCB, dds, targetOp.getNowait());
4573
+ defaultAttrs, runtimeAttrs, ifCond , kernelInput, genMapInfoCB, bodyCB,
4574
+ argAccessorCB, customMapperCB, dds, targetOp.getNowait());
4593
4575
4594
4576
if (failed(handleError(afterIP, opInst)))
4595
4577
return failure();
0 commit comments