Skip to content

Commit ed1fb2d

Browse files
committed
[MLIR][OpenMP] Introduce overlapped record type map support
This PR introduces an additional type of map lowering for record types that Clang currently supports, in which a user can map a top-level record type and then individual members with different mappings, effectively creating a sort of "overlapping" mapping that we attempt to cut around. This is currently most predominantly used in Fortran when mapping descriptors and their data: we map the descriptor and its data with separate map modifiers and "cut around" the pointer data, so that we do not overwrite it unless the runtime deems it a necessary action based on its reference counting mechanism. However, it is a mechanism that will come in handy/trigger when a user explicitly maps a record type (derived type or structure) and then explicitly maps a member with a different map type. These additions were predominantly in the OpenMPToLLVMIRTranslation.cpp file and phase; however, one Flang test that checks end-to-end IR compilation (as far as we care for now at least) was altered. This is 2/3 of the required PRs to enable declare target to mapping; look at PR 3/3 to check for full green passes (this one will fail a number of tests due to some dependencies). Co-authored-by: Raghu Maddhipatla [email protected]
1 parent d18c299 commit ed1fb2d

File tree

6 files changed

+463
-154
lines changed

6 files changed

+463
-154
lines changed

flang/test/Integration/OpenMP/map-types-and-sizes.f90

Lines changed: 77 additions & 53 deletions
Large diffs are not rendered by default.

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 203 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -2979,39 +2979,61 @@ static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp) {
29792979
return std::distance(mapData.MapClause.begin(), res);
29802980
}
29812981

2982-
static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo,
2983-
bool first) {
2984-
ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
2985-
// Only 1 member has been mapped, we can return it.
2986-
if (indexAttr.size() == 1)
2987-
return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());
2982+
static void sortMapIndices(llvm::SmallVector<size_t> &indices,
2983+
mlir::omp::MapInfoOp mapInfo,
2984+
bool ascending = true) {
2985+
mlir::ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
2986+
if (indexAttr.empty() || indexAttr.size() == 1 || indices.empty() ||
2987+
indices.size() == 1)
2988+
return;
29882989

2989-
llvm::SmallVector<size_t> indices(indexAttr.size());
2990-
std::iota(indices.begin(), indices.end(), 0);
2990+
llvm::sort(
2991+
indices.begin(), indices.end(), [&](const size_t a, const size_t b) {
2992+
auto memberIndicesA = mlir::cast<mlir::ArrayAttr>(indexAttr[a]);
2993+
auto memberIndicesB = mlir::cast<mlir::ArrayAttr>(indexAttr[b]);
2994+
2995+
size_t smallestMember = memberIndicesA.size() < memberIndicesB.size()
2996+
? memberIndicesA.size()
2997+
: memberIndicesB.size();
29912998

2992-
llvm::sort(indices.begin(), indices.end(),
2993-
[&](const size_t a, const size_t b) {
2994-
auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
2995-
auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
2996-
for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
2997-
int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt();
2998-
int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt();
2999+
for (size_t i = 0; i < smallestMember; ++i) {
3000+
int64_t aIndex =
3001+
mlir::cast<mlir::IntegerAttr>(memberIndicesA.getValue()[i])
3002+
.getInt();
3003+
int64_t bIndex =
3004+
mlir::cast<mlir::IntegerAttr>(memberIndicesB.getValue()[i])
3005+
.getInt();
29993006

3000-
if (aIndex == bIndex)
3001-
continue;
3007+
if (aIndex == bIndex)
3008+
continue;
30023009

3003-
if (aIndex < bIndex)
3004-
return first;
3010+
if (aIndex < bIndex)
3011+
return ascending;
30053012

3006-
if (aIndex > bIndex)
3007-
return !first;
3008-
}
3013+
if (aIndex > bIndex)
3014+
return !ascending;
3015+
}
30093016

3010-
// Iterated the up until the end of the smallest member and
3011-
// they were found to be equal up to that point, so select
3012-
// the member with the lowest index count, so the "parent"
3013-
return memberIndicesA.size() < memberIndicesB.size();
3014-
});
3017+
// Iterated up until the end of the smallest member and
3018+
// they were found to be equal up to that point, so select
3019+
// the member with the lowest index count, so the "parent"
3020+
return memberIndicesA.size() < memberIndicesB.size();
3021+
});
3022+
}
3023+
3024+
static mlir::omp::MapInfoOp
3025+
getFirstOrLastMappedMemberPtr(mlir::omp::MapInfoOp mapInfo, bool first) {
3026+
mlir::ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
3027+
// Only 1 member has been mapped, we can return it.
3028+
if (indexAttr.size() == 1)
3029+
if (auto mapOp =
3030+
dyn_cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp()))
3031+
return mapOp;
3032+
3033+
llvm::SmallVector<size_t> indices;
3034+
indices.resize(indexAttr.size());
3035+
std::iota(indices.begin(), indices.end(), 0);
3036+
sortMapIndices(indices, mapInfo, first);
30153037

30163038
return llvm::cast<omp::MapInfoOp>(
30173039
mapInfo.getMembers()[indices.front()].getDefiningOp());
@@ -3110,6 +3132,91 @@ calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation,
31103132
return idx;
31113133
}
31123134

3135+
// Gathers members that are overlapping in the parent, excluding members that
3136+
// themselves overlap, keeping the top-most (closest to parents level) map.
3137+
static void getOverlappedMembers(llvm::SmallVector<size_t> &overlapMapDataIdxs,
3138+
MapInfoData &mapData,
3139+
omp::MapInfoOp parentOp) {
3140+
// No members mapped, no overlaps.
3141+
if (parentOp.getMembers().empty())
3142+
return;
3143+
3144+
// Single member, we can insert and return early.
3145+
if (parentOp.getMembers().size() == 1) {
3146+
overlapMapDataIdxs.push_back(0);
3147+
return;
3148+
}
3149+
3150+
// 1) collect list of top-level overlapping members from MemberOp
3151+
llvm::SmallVector<std::pair<int, mlir::ArrayAttr>> memberByIndex;
3152+
mlir::ArrayAttr indexAttr = parentOp.getMembersIndexAttr();
3153+
for (auto [memIndex, indicesAttr] : llvm::enumerate(indexAttr))
3154+
memberByIndex.push_back(
3155+
std::make_pair(memIndex, mlir::cast<mlir::ArrayAttr>(indicesAttr)));
3156+
3157+
// Sort the smallest first (higher up the parent -> member chain), so that
3158+
// when we remove members, we remove as much as we can in the initial
3159+
// iterations, shortening the number of passes required.
3160+
llvm::sort(memberByIndex.begin(), memberByIndex.end(),
3161+
[&](auto a, auto b) { return a.second.size() < b.second.size(); });
3162+
3163+
auto getAsIntegers = [](mlir::ArrayAttr values) {
3164+
llvm::SmallVector<int64_t> ints;
3165+
ints.reserve(values.size());
3166+
llvm::transform(values, std::back_inserter(ints),
3167+
[](mlir::Attribute value) {
3168+
return mlir::cast<mlir::IntegerAttr>(value).getInt();
3169+
});
3170+
return ints;
3171+
};
3172+
3173+
// Remove elements from the vector if there is a parent element that
3174+
// supersedes it. i.e. if member [0] is mapped, we can remove members [0,1],
3175+
// [0,2].. etc.
3176+
for (auto v : make_early_inc_range(memberByIndex)) {
3177+
auto vArr = getAsIntegers(v.second);
3178+
memberByIndex.erase(
3179+
std::remove_if(memberByIndex.begin(), memberByIndex.end(),
3180+
[&](auto x) {
3181+
if (v == x)
3182+
return false;
3183+
3184+
auto xArr = getAsIntegers(x.second);
3185+
return std::equal(vArr.begin(), vArr.end(),
3186+
xArr.begin()) &&
3187+
xArr.size() >= vArr.size();
3188+
}),
3189+
memberByIndex.end());
3190+
}
3191+
3192+
// Collect the indices from mapData that we need, as we technically need the
3193+
// base pointer etc. info, which is stored in there and primarily accessible
3194+
// via index at the moment.
3195+
for (auto v : memberByIndex)
3196+
overlapMapDataIdxs.push_back(v.first);
3197+
}
3198+
3199+
// The intent is to verify if the mapped data being passed is a
3200+
// pointer -> pointee that requires special handling in certain cases,
3201+
// e.g. applying the OMP_MAP_PTR_AND_OBJ map type.
3202+
//
3203+
// There may be a better way to verify this, but unfortunately with
3204+
// opaque pointers we lose the ability to easily check if something is
3205+
// a pointer whilst maintaining access to the underlying type.
3206+
static bool checkIfPointerMap(omp::MapInfoOp mapOp) {
3207+
// If we have a varPtrPtr field assigned then the underlying type is a pointer
3208+
if (mapOp.getVarPtrPtr())
3209+
return true;
3210+
3211+
// If the map data is declare target with a link clause, then it's represented
3212+
// as a pointer when we lower it to LLVM-IR even if at the MLIR level it has
3213+
// no relation to pointers.
3214+
if (isDeclareTargetLink(mapOp.getVarPtr()))
3215+
return true;
3216+
3217+
return false;
3218+
}
3219+
31133220
// This creates two insertions into the MapInfosTy data structure for the
31143221
// "parent" of a set of members, (usually a container e.g.
31153222
// class/structure/derived type) when subsequent members have also been
@@ -3150,7 +3257,6 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
31503257
// runtime information on the dynamically allocated data).
31513258
auto parentClause =
31523259
llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
3153-
31543260
llvm::Value *lowAddr, *highAddr;
31553261
if (!parentClause.getPartialMap()) {
31563262
lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
@@ -3197,37 +3303,77 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
31973303
// what we support as expected.
31983304
llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex];
31993305
ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
3200-
combinedInfo.Types.emplace_back(mapFlag);
3201-
combinedInfo.DevicePointers.emplace_back(
3202-
llvm::OpenMPIRBuilder::DeviceInfoTy::None);
3203-
combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
3204-
mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
3205-
combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
3206-
combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
3207-
combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
3208-
}
3209-
return memberOfFlag;
3210-
}
3211-
3212-
// The intent is to verify if the mapped data being passed is a
3213-
// pointer -> pointee that requires special handling in certain cases,
3214-
// e.g. applying the OMP_MAP_PTR_AND_OBJ map type.
3215-
//
3216-
// There may be a better way to verify this, but unfortunately with
3217-
// opaque pointers we lose the ability to easily check if something is
3218-
// a pointer whilst maintaining access to the underlying type.
3219-
static bool checkIfPointerMap(omp::MapInfoOp mapOp) {
3220-
// If we have a varPtrPtr field assigned then the underlying type is a pointer
3221-
if (mapOp.getVarPtrPtr())
3222-
return true;
32233306

3224-
// If the map data is declare target with a link clause, then it's represented
3225-
// as a pointer when we lower it to LLVM-IR even if at the MLIR level it has
3226-
// no relation to pointers.
3227-
if (isDeclareTargetLink(mapOp.getVarPtr()))
3228-
return true;
3307+
if (targetDirective == TargetDirective::TargetUpdate) {
3308+
combinedInfo.Types.emplace_back(mapFlag);
3309+
combinedInfo.DevicePointers.emplace_back(
3310+
mapData.DevicePointers[mapDataIndex]);
3311+
combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
3312+
mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
3313+
combinedInfo.BasePointers.emplace_back(
3314+
mapData.BasePointers[mapDataIndex]);
3315+
combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
3316+
combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
3317+
} else {
3318+
llvm::SmallVector<size_t> overlapIdxs;
3319+
// Find all of the members that "overlap", i.e. occlude other members that
3320+
// were mapped alongside the parent, e.g. member [0], occludes [0,1] and [0,2]
3321+
getOverlappedMembers(overlapIdxs, mapData, parentClause);
3322+
// We need to make sure the overlapped members are sorted in order of
3323+
// lowest address to highest address
3324+
sortMapIndices(overlapIdxs, parentClause);
3325+
3326+
lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
3327+
builder.getPtrTy());
3328+
highAddr = builder.CreatePointerCast(
3329+
builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
3330+
mapData.Pointers[mapDataIndex], 1),
3331+
builder.getPtrTy());
3332+
3333+
// TODO: We may want to skip arrays/array sections in this as Clang does
3334+
// so it appears to be an optimisation rather than a necessity though,
3335+
// but this requires further investigation. However, we would have to make
3336+
// sure to not exclude maps with bounds that ARE pointers, as these are
3337+
// processed as separate components, i.e. pointer + data.
3338+
for (auto v : overlapIdxs) {
3339+
auto mapDataOverlapIdx = getMapDataMemberIdx(
3340+
mapData,
3341+
cast<omp::MapInfoOp>(parentClause.getMembers()[v].getDefiningOp()));
3342+
combinedInfo.Types.emplace_back(mapFlag);
3343+
combinedInfo.DevicePointers.emplace_back(
3344+
mapData.DevicePointers[mapDataOverlapIdx]);
3345+
combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
3346+
mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
3347+
combinedInfo.BasePointers.emplace_back(
3348+
mapData.BasePointers[mapDataIndex]);
3349+
combinedInfo.Pointers.emplace_back(lowAddr);
3350+
combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
3351+
builder.CreatePtrDiff(builder.getInt8Ty(),
3352+
mapData.OriginalValue[mapDataOverlapIdx],
3353+
lowAddr),
3354+
builder.getInt64Ty(), /*isSigned=*/true));
3355+
lowAddr = builder.CreateConstGEP1_32(
3356+
checkIfPointerMap(llvm::cast<omp::MapInfoOp>(
3357+
mapData.MapClause[mapDataOverlapIdx]))
3358+
? builder.getPtrTy()
3359+
: mapData.BaseType[mapDataOverlapIdx],
3360+
mapData.BasePointers[mapDataOverlapIdx], 1);
3361+
}
32293362

3230-
return false;
3363+
combinedInfo.Types.emplace_back(mapFlag);
3364+
combinedInfo.DevicePointers.emplace_back(
3365+
mapData.DevicePointers[mapDataIndex]);
3366+
combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
3367+
mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
3368+
combinedInfo.BasePointers.emplace_back(
3369+
mapData.BasePointers[mapDataIndex]);
3370+
combinedInfo.Pointers.emplace_back(lowAddr);
3371+
combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
3372+
builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
3373+
builder.getInt64Ty(), true));
3374+
}
3375+
}
3376+
return memberOfFlag;
32313377
}
32323378

32333379
// This function is intended to add explicit mappings of members

mlir/test/Target/LLVMIR/omptarget-data-use-dev-ordering.mlir

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -67,18 +67,18 @@ module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-a
6767

6868
// CHECK: define void @mix_use_device_ptr_and_addr_and_map_(ptr %[[ARG_0:.*]], ptr %[[ARG_1:.*]], ptr %[[ARG_2:.*]], ptr %[[ARG_3:.*]], ptr %[[ARG_4:.*]], ptr %[[ARG_5:.*]], ptr %[[ARG_6:.*]], ptr %[[ARG_7:.*]]) {
6969
// CHECK: %[[ALLOCA:.*]] = alloca ptr, align 8
70-
// CHECK: %[[BASEPTR_0_GEP:.*]] = getelementptr inbounds [10 x ptr], ptr %.offload_baseptrs, i32 0, i32 0
70+
// CHECK: %[[BASEPTR_0_GEP:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 0
7171
// CHECK: store ptr %[[ARG_0]], ptr %[[BASEPTR_0_GEP]], align 8
72-
// CHECK: %[[BASEPTR_2_GEP:.*]] = getelementptr inbounds [10 x ptr], ptr %.offload_baseptrs, i32 0, i32 2
72+
// CHECK: %[[BASEPTR_2_GEP:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 4
7373
// CHECK: store ptr %[[ARG_2]], ptr %[[BASEPTR_2_GEP]], align 8
74-
// CHECK: %[[BASEPTR_6_GEP:.*]] = getelementptr inbounds [10 x ptr], ptr %.offload_baseptrs, i32 0, i32 6
75-
// CHECK: store ptr %[[ARG_4]], ptr %[[BASEPTR_6_GEP]], align 8
74+
// CHECK: %[[BASEPTR_3_GEP:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 9
75+
// CHECK: store ptr %[[ARG_4]], ptr %[[BASEPTR_3_GEP]], align 8
7676

7777
// CHECK: call void @__tgt_target_data_begin_mapper({{.*}})
7878
// CHECK: %[[LOAD_BASEPTR_0:.*]] = load ptr, ptr %[[BASEPTR_0_GEP]], align 8
7979
// store ptr %[[LOAD_BASEPTR_0]], ptr %[[ALLOCA]], align 8
8080
// CHECK: %[[LOAD_BASEPTR_2:.*]] = load ptr, ptr %[[BASEPTR_2_GEP]], align 8
81-
// CHECK: %[[LOAD_BASEPTR_6:.*]] = load ptr, ptr %[[BASEPTR_6_GEP]], align 8
81+
// CHECK: %[[LOAD_BASEPTR_3:.*]] = load ptr, ptr %[[BASEPTR_3_GEP]], align 8
8282
// CHECK: %[[GEP_A4:.*]] = getelementptr { i64 }, ptr %[[ARG_4]], i32 0, i32 0
8383
// CHECK: %[[GEP_A7:.*]] = getelementptr { i64 }, ptr %[[ARG_7]], i32 0, i32 0
8484
// CHECK: %[[LOAD_A4:.*]] = load i64, ptr %[[GEP_A4]], align 4
@@ -93,17 +93,17 @@ module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-a
9393

9494
// CHECK: define void @mix_use_device_ptr_and_addr_and_map_2(ptr %[[ARG_0:.*]], ptr %[[ARG_1:.*]], ptr %[[ARG_2:.*]], ptr %[[ARG_3:.*]], ptr %[[ARG_4:.*]], ptr %[[ARG_5:.*]], ptr %[[ARG_6:.*]], ptr %[[ARG_7:.*]]) {
9595
// CHECK: %[[ALLOCA:.*]] = alloca ptr, align 8
96-
// CHECK: %[[BASEPTR_1_GEP:.*]] = getelementptr inbounds [10 x ptr], ptr %.offload_baseptrs, i32 0, i32 1
96+
// CHECK: %[[BASEPTR_1_GEP:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 1
9797
// CHECK: store ptr %[[ARG_0]], ptr %[[BASEPTR_1_GEP]], align 8
98-
// CHECK: %[[BASEPTR_2_GEP:.*]] = getelementptr inbounds [10 x ptr], ptr %.offload_baseptrs, i32 0, i32 2
98+
// CHECK: %[[BASEPTR_2_GEP:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 4
9999
// CHECK: store ptr %[[ARG_2]], ptr %[[BASEPTR_2_GEP]], align 8
100-
// CHECK: %[[BASEPTR_6_GEP:.*]] = getelementptr inbounds [10 x ptr], ptr %.offload_baseptrs, i32 0, i32 6
101-
// CHECK: store ptr %[[ARG_4]], ptr %[[BASEPTR_6_GEP]], align 8
100+
// CHECK: %[[BASEPTR_3_GEP:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 9
101+
// CHECK: store ptr %[[ARG_4]], ptr %[[BASEPTR_3_GEP]], align 8
102102
// CHECK: call void @__tgt_target_data_begin_mapper({{.*}})
103103
// CHECK: %[[LOAD_BASEPTR_1:.*]] = load ptr, ptr %[[BASEPTR_1_GEP]], align 8
104104
// store ptr %[[LOAD_BASEPTR_1]], ptr %[[ALLOCA]], align 8
105105
// CHECK: %[[LOAD_BASEPTR_2:.*]] = load ptr, ptr %[[BASEPTR_2_GEP]], align 8
106-
// CHECK: %[[LOAD_BASEPTR_6:.*]] = load ptr, ptr %[[BASEPTR_6_GEP]], align 8
106+
// CHECK: %[[LOAD_BASEPTR_3:.*]] = load ptr, ptr %[[BASEPTR_3_GEP]], align 8
107107
// CHECK: %[[GEP_A4:.*]] = getelementptr { i64 }, ptr %[[ARG_4]], i32 0, i32 0
108108
// CHECK: %[[GEP_A7:.*]] = getelementptr { i64 }, ptr %[[ARG_7]], i32 0, i32 0
109109
// CHECK: %[[LOAD_A4:.*]] = load i64, ptr %[[GEP_A4]], align 4

0 commit comments

Comments
 (0)