Skip to content

Commit 173a950

Browse files
committed
[MLIR][OpenMP] Introduce overlapped record type map support
This PR introduces a new additional type of map lowering for record types that Clang currently supports, in which a user can map a top-level record type and then individual members with different mapping, effectively creating a sort of "overlapping" mapping that we attempt to cut around. This is currently most predominantly used in Fortran, when mapping descriptors and there data, we map the descriptor and its data with separate map modifiers and "cut around" the pointer data, so that wedo not overwrite it unless the runtime deems it a neccesary action based on its reference counting mechanism. However, it is a mechanism that will come in handy/trigger when a user explitily maps a record type (derived type or structure) and then explicitly maps a member with a different map type. These additions were predominantly in the OpenMPToLLVMIRTranslation.cpp file and phase, however, one Flang test that checks end-to-end IR compilation (as far as we care for now at least) was altered. 2/3 required PRs to enable declare target to mapping, should look at PR 3/3 to check for full green passes (this one will fail a number due to some dependencies). Co-authored-by: Raghu Maddhipatla [email protected]
1 parent 2757285 commit 173a950

File tree

6 files changed

+463
-154
lines changed

6 files changed

+463
-154
lines changed

flang/test/Integration/OpenMP/map-types-and-sizes.f90

Lines changed: 77 additions & 53 deletions
Large diffs are not rendered by default.

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 203 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -2874,39 +2874,61 @@ static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp) {
28742874
return std::distance(mapData.MapClause.begin(), res);
28752875
}
28762876

2877-
static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo,
2878-
bool first) {
2879-
ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
2880-
// Only 1 member has been mapped, we can return it.
2881-
if (indexAttr.size() == 1)
2882-
return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());
2877+
static void sortMapIndices(llvm::SmallVector<size_t> &indices,
2878+
mlir::omp::MapInfoOp mapInfo,
2879+
bool ascending = true) {
2880+
mlir::ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
2881+
if (indexAttr.empty() || indexAttr.size() == 1 || indices.empty() ||
2882+
indices.size() == 1)
2883+
return;
28832884

2884-
llvm::SmallVector<size_t> indices(indexAttr.size());
2885-
std::iota(indices.begin(), indices.end(), 0);
2885+
llvm::sort(
2886+
indices.begin(), indices.end(), [&](const size_t a, const size_t b) {
2887+
auto memberIndicesA = mlir::cast<mlir::ArrayAttr>(indexAttr[a]);
2888+
auto memberIndicesB = mlir::cast<mlir::ArrayAttr>(indexAttr[b]);
2889+
2890+
size_t smallestMember = memberIndicesA.size() < memberIndicesB.size()
2891+
? memberIndicesA.size()
2892+
: memberIndicesB.size();
28862893

2887-
llvm::sort(indices.begin(), indices.end(),
2888-
[&](const size_t a, const size_t b) {
2889-
auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
2890-
auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
2891-
for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
2892-
int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt();
2893-
int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt();
2894+
for (size_t i = 0; i < smallestMember; ++i) {
2895+
int64_t aIndex =
2896+
mlir::cast<mlir::IntegerAttr>(memberIndicesA.getValue()[i])
2897+
.getInt();
2898+
int64_t bIndex =
2899+
mlir::cast<mlir::IntegerAttr>(memberIndicesB.getValue()[i])
2900+
.getInt();
28942901

2895-
if (aIndex == bIndex)
2896-
continue;
2902+
if (aIndex == bIndex)
2903+
continue;
28972904

2898-
if (aIndex < bIndex)
2899-
return first;
2905+
if (aIndex < bIndex)
2906+
return ascending;
29002907

2901-
if (aIndex > bIndex)
2902-
return !first;
2903-
}
2908+
if (aIndex > bIndex)
2909+
return !ascending;
2910+
}
29042911

2905-
// Iterated the up until the end of the smallest member and
2906-
// they were found to be equal up to that point, so select
2907-
// the member with the lowest index count, so the "parent"
2908-
return memberIndicesA.size() < memberIndicesB.size();
2909-
});
2912+
// Iterated up until the end of the smallest member and
2913+
// they were found to be equal up to that point, so select
2914+
// the member with the lowest index count, so the "parent"
2915+
return memberIndicesA.size() < memberIndicesB.size();
2916+
});
2917+
}
2918+
2919+
static mlir::omp::MapInfoOp
2920+
getFirstOrLastMappedMemberPtr(mlir::omp::MapInfoOp mapInfo, bool first) {
2921+
mlir::ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
2922+
// Only 1 member has been mapped, we can return it.
2923+
if (indexAttr.size() == 1)
2924+
if (auto mapOp =
2925+
dyn_cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp()))
2926+
return mapOp;
2927+
2928+
llvm::SmallVector<size_t> indices;
2929+
indices.resize(indexAttr.size());
2930+
std::iota(indices.begin(), indices.end(), 0);
2931+
sortMapIndices(indices, mapInfo, first);
29102932

29112933
return llvm::cast<omp::MapInfoOp>(
29122934
mapInfo.getMembers()[indices.front()].getDefiningOp());
@@ -3005,6 +3027,91 @@ calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation,
30053027
return idx;
30063028
}
30073029

3030+
// Gathers members that are overlapping in the parent, excluding members that
3031+
// themselves overlap, keeping the top-most (closest to parents level) map.
3032+
static void getOverlappedMembers(llvm::SmallVector<size_t> &overlapMapDataIdxs,
3033+
MapInfoData &mapData,
3034+
omp::MapInfoOp parentOp) {
3035+
// No members mapped, no overlaps.
3036+
if (parentOp.getMembers().empty())
3037+
return;
3038+
3039+
// Single member, we can insert and return early.
3040+
if (parentOp.getMembers().size() == 1) {
3041+
overlapMapDataIdxs.push_back(0);
3042+
return;
3043+
}
3044+
3045+
// 1) collect list of top-level overlapping members from MemberOp
3046+
llvm::SmallVector<std::pair<int, mlir::ArrayAttr>> memberByIndex;
3047+
mlir::ArrayAttr indexAttr = parentOp.getMembersIndexAttr();
3048+
for (auto [memIndex, indicesAttr] : llvm::enumerate(indexAttr))
3049+
memberByIndex.push_back(
3050+
std::make_pair(memIndex, mlir::cast<mlir::ArrayAttr>(indicesAttr)));
3051+
3052+
// Sort the smallest first (higher up the parent -> member chain), so that
3053+
// when we remove members, we remove as much as we can in the initial
3054+
// iterations, shortening the number of passes required.
3055+
llvm::sort(memberByIndex.begin(), memberByIndex.end(),
3056+
[&](auto a, auto b) { return a.second.size() < b.second.size(); });
3057+
3058+
auto getAsIntegers = [](mlir::ArrayAttr values) {
3059+
llvm::SmallVector<int64_t> ints;
3060+
ints.reserve(values.size());
3061+
llvm::transform(values, std::back_inserter(ints),
3062+
[](mlir::Attribute value) {
3063+
return mlir::cast<mlir::IntegerAttr>(value).getInt();
3064+
});
3065+
return ints;
3066+
};
3067+
3068+
// Remove elements from the vector if there is a parent element that
3069+
// supersedes it. i.e. if member [0] is mapped, we can remove members [0,1],
3070+
// [0,2].. etc.
3071+
for (auto v : make_early_inc_range(memberByIndex)) {
3072+
auto vArr = getAsIntegers(v.second);
3073+
memberByIndex.erase(
3074+
std::remove_if(memberByIndex.begin(), memberByIndex.end(),
3075+
[&](auto x) {
3076+
if (v == x)
3077+
return false;
3078+
3079+
auto xArr = getAsIntegers(x.second);
3080+
return std::equal(vArr.begin(), vArr.end(),
3081+
xArr.begin()) &&
3082+
xArr.size() >= vArr.size();
3083+
}),
3084+
memberByIndex.end());
3085+
}
3086+
3087+
// Collect the indices from mapData that we need, as we technically need the
3088+
// base pointer etc. info, which is stored in there and primarily accessible
3089+
// via index at the moment.
3090+
for (auto v : memberByIndex)
3091+
overlapMapDataIdxs.push_back(v.first);
3092+
}
3093+
3094+
// The intent is to verify if the mapped data being passed is a
3095+
// pointer -> pointee that requires special handling in certain cases,
3096+
// e.g. applying the OMP_MAP_PTR_AND_OBJ map type.
3097+
//
3098+
// There may be a better way to verify this, but unfortunately with
3099+
// opaque pointers we lose the ability to easily check if something is
3100+
// a pointer whilst maintaining access to the underlying type.
3101+
static bool checkIfPointerMap(omp::MapInfoOp mapOp) {
3102+
// If we have a varPtrPtr field assigned then the underlying type is a pointer
3103+
if (mapOp.getVarPtrPtr())
3104+
return true;
3105+
3106+
// If the map data is declare target with a link clause, then it's represented
3107+
// as a pointer when we lower it to LLVM-IR even if at the MLIR level it has
3108+
// no relation to pointers.
3109+
if (isDeclareTargetLink(mapOp.getVarPtr()))
3110+
return true;
3111+
3112+
return false;
3113+
}
3114+
30083115
// This creates two insertions into the MapInfosTy data structure for the
30093116
// "parent" of a set of members, (usually a container e.g.
30103117
// class/structure/derived type) when subsequent members have also been
@@ -3045,7 +3152,6 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
30453152
// runtime information on the dynamically allocated data).
30463153
auto parentClause =
30473154
llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
3048-
30493155
llvm::Value *lowAddr, *highAddr;
30503156
if (!parentClause.getPartialMap()) {
30513157
lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
@@ -3092,37 +3198,77 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
30923198
// what we support as expected.
30933199
llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex];
30943200
ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
3095-
combinedInfo.Types.emplace_back(mapFlag);
3096-
combinedInfo.DevicePointers.emplace_back(
3097-
llvm::OpenMPIRBuilder::DeviceInfoTy::None);
3098-
combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
3099-
mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
3100-
combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
3101-
combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
3102-
combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
3103-
}
3104-
return memberOfFlag;
3105-
}
3106-
3107-
// The intent is to verify if the mapped data being passed is a
3108-
// pointer -> pointee that requires special handling in certain cases,
3109-
// e.g. applying the OMP_MAP_PTR_AND_OBJ map type.
3110-
//
3111-
// There may be a better way to verify this, but unfortunately with
3112-
// opaque pointers we lose the ability to easily check if something is
3113-
// a pointer whilst maintaining access to the underlying type.
3114-
static bool checkIfPointerMap(omp::MapInfoOp mapOp) {
3115-
// If we have a varPtrPtr field assigned then the underlying type is a pointer
3116-
if (mapOp.getVarPtrPtr())
3117-
return true;
31183201

3119-
// If the map data is declare target with a link clause, then it's represented
3120-
// as a pointer when we lower it to LLVM-IR even if at the MLIR level it has
3121-
// no relation to pointers.
3122-
if (isDeclareTargetLink(mapOp.getVarPtr()))
3123-
return true;
3202+
if (targetDirective == TargetDirective::TargetUpdate) {
3203+
combinedInfo.Types.emplace_back(mapFlag);
3204+
combinedInfo.DevicePointers.emplace_back(
3205+
mapData.DevicePointers[mapDataIndex]);
3206+
combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
3207+
mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
3208+
combinedInfo.BasePointers.emplace_back(
3209+
mapData.BasePointers[mapDataIndex]);
3210+
combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
3211+
combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
3212+
} else {
3213+
llvm::SmallVector<size_t> overlapIdxs;
3214+
// Find all of the members that "overlap", i.e. occlude other members that
3215+
// were mapped alongside the parent, e.g. member [0], occludes
3216+
getOverlappedMembers(overlapIdxs, mapData, parentClause);
3217+
// We need to make sure the overlapped members are sorted in order of
3218+
// lowest address to highest address
3219+
sortMapIndices(overlapIdxs, parentClause);
3220+
3221+
lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
3222+
builder.getPtrTy());
3223+
highAddr = builder.CreatePointerCast(
3224+
builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
3225+
mapData.Pointers[mapDataIndex], 1),
3226+
builder.getPtrTy());
3227+
3228+
// TODO: We may want to skip arrays/array sections in this as Clang does
3229+
// so it appears to be an optimisation rather than a neccessity though,
3230+
// but this requires further investigation. However, we would have to make
3231+
// sure to not exclude maps with bounds that ARE pointers, as these are
3232+
// processed as seperate components, i.e. pointer + data.
3233+
for (auto v : overlapIdxs) {
3234+
auto mapDataOverlapIdx = getMapDataMemberIdx(
3235+
mapData,
3236+
cast<omp::MapInfoOp>(parentClause.getMembers()[v].getDefiningOp()));
3237+
combinedInfo.Types.emplace_back(mapFlag);
3238+
combinedInfo.DevicePointers.emplace_back(
3239+
mapData.DevicePointers[mapDataOverlapIdx]);
3240+
combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
3241+
mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
3242+
combinedInfo.BasePointers.emplace_back(
3243+
mapData.BasePointers[mapDataIndex]);
3244+
combinedInfo.Pointers.emplace_back(lowAddr);
3245+
combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
3246+
builder.CreatePtrDiff(builder.getInt8Ty(),
3247+
mapData.OriginalValue[mapDataOverlapIdx],
3248+
lowAddr),
3249+
builder.getInt64Ty(), /*isSigned=*/true));
3250+
lowAddr = builder.CreateConstGEP1_32(
3251+
checkIfPointerMap(llvm::cast<omp::MapInfoOp>(
3252+
mapData.MapClause[mapDataOverlapIdx]))
3253+
? builder.getPtrTy()
3254+
: mapData.BaseType[mapDataOverlapIdx],
3255+
mapData.BasePointers[mapDataOverlapIdx], 1);
3256+
}
31243257

3125-
return false;
3258+
combinedInfo.Types.emplace_back(mapFlag);
3259+
combinedInfo.DevicePointers.emplace_back(
3260+
mapData.DevicePointers[mapDataIndex]);
3261+
combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
3262+
mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
3263+
combinedInfo.BasePointers.emplace_back(
3264+
mapData.BasePointers[mapDataIndex]);
3265+
combinedInfo.Pointers.emplace_back(lowAddr);
3266+
combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
3267+
builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
3268+
builder.getInt64Ty(), true));
3269+
}
3270+
}
3271+
return memberOfFlag;
31263272
}
31273273

31283274
// This function is intended to add explicit mappings of members

mlir/test/Target/LLVMIR/omptarget-data-use-dev-ordering.mlir

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -67,18 +67,18 @@ module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-a
6767

6868
// CHECK: define void @mix_use_device_ptr_and_addr_and_map_(ptr %[[ARG_0:.*]], ptr %[[ARG_1:.*]], ptr %[[ARG_2:.*]], ptr %[[ARG_3:.*]], ptr %[[ARG_4:.*]], ptr %[[ARG_5:.*]], ptr %[[ARG_6:.*]], ptr %[[ARG_7:.*]]) {
6969
// CHECK: %[[ALLOCA:.*]] = alloca ptr, align 8
70-
// CHECK: %[[BASEPTR_0_GEP:.*]] = getelementptr inbounds [10 x ptr], ptr %.offload_baseptrs, i32 0, i32 0
70+
// CHECK: %[[BASEPTR_0_GEP:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 0
7171
// CHECK: store ptr %[[ARG_0]], ptr %[[BASEPTR_0_GEP]], align 8
72-
// CHECK: %[[BASEPTR_2_GEP:.*]] = getelementptr inbounds [10 x ptr], ptr %.offload_baseptrs, i32 0, i32 2
72+
// CHECK: %[[BASEPTR_2_GEP:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 4
7373
// CHECK: store ptr %[[ARG_2]], ptr %[[BASEPTR_2_GEP]], align 8
74-
// CHECK: %[[BASEPTR_6_GEP:.*]] = getelementptr inbounds [10 x ptr], ptr %.offload_baseptrs, i32 0, i32 6
75-
// CHECK: store ptr %[[ARG_4]], ptr %[[BASEPTR_6_GEP]], align 8
74+
// CHECK: %[[BASEPTR_3_GEP:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 9
75+
// CHECK: store ptr %[[ARG_4]], ptr %[[BASEPTR_3_GEP]], align 8
7676

7777
// CHECK: call void @__tgt_target_data_begin_mapper({{.*}})
7878
// CHECK: %[[LOAD_BASEPTR_0:.*]] = load ptr, ptr %[[BASEPTR_0_GEP]], align 8
7979
// store ptr %[[LOAD_BASEPTR_0]], ptr %[[ALLOCA]], align 8
8080
// CHECK: %[[LOAD_BASEPTR_2:.*]] = load ptr, ptr %[[BASEPTR_2_GEP]], align 8
81-
// CHECK: %[[LOAD_BASEPTR_6:.*]] = load ptr, ptr %[[BASEPTR_6_GEP]], align 8
81+
// CHECK: %[[LOAD_BASEPTR_3:.*]] = load ptr, ptr %[[BASEPTR_3_GEP]], align 8
8282
// CHECK: %[[GEP_A4:.*]] = getelementptr { i64 }, ptr %[[ARG_4]], i32 0, i32 0
8383
// CHECK: %[[GEP_A7:.*]] = getelementptr { i64 }, ptr %[[ARG_7]], i32 0, i32 0
8484
// CHECK: %[[LOAD_A4:.*]] = load i64, ptr %[[GEP_A4]], align 4
@@ -93,17 +93,17 @@ module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-a
9393

9494
// CHECK: define void @mix_use_device_ptr_and_addr_and_map_2(ptr %[[ARG_0:.*]], ptr %[[ARG_1:.*]], ptr %[[ARG_2:.*]], ptr %[[ARG_3:.*]], ptr %[[ARG_4:.*]], ptr %[[ARG_5:.*]], ptr %[[ARG_6:.*]], ptr %[[ARG_7:.*]]) {
9595
// CHECK: %[[ALLOCA:.*]] = alloca ptr, align 8
96-
// CHECK: %[[BASEPTR_1_GEP:.*]] = getelementptr inbounds [10 x ptr], ptr %.offload_baseptrs, i32 0, i32 1
96+
// CHECK: %[[BASEPTR_1_GEP:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 1
9797
// CHECK: store ptr %[[ARG_0]], ptr %[[BASEPTR_1_GEP]], align 8
98-
// CHECK: %[[BASEPTR_2_GEP:.*]] = getelementptr inbounds [10 x ptr], ptr %.offload_baseptrs, i32 0, i32 2
98+
// CHECK: %[[BASEPTR_2_GEP:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 4
9999
// CHECK: store ptr %[[ARG_2]], ptr %[[BASEPTR_2_GEP]], align 8
100-
// CHECK: %[[BASEPTR_6_GEP:.*]] = getelementptr inbounds [10 x ptr], ptr %.offload_baseptrs, i32 0, i32 6
101-
// CHECK: store ptr %[[ARG_4]], ptr %[[BASEPTR_6_GEP]], align 8
100+
// CHECK: %[[BASEPTR_3_GEP:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 9
101+
// CHECK: store ptr %[[ARG_4]], ptr %[[BASEPTR_3_GEP]], align 8
102102
// CHECK: call void @__tgt_target_data_begin_mapper({{.*}})
103103
// CHECK: %[[LOAD_BASEPTR_1:.*]] = load ptr, ptr %[[BASEPTR_1_GEP]], align 8
104104
// store ptr %[[LOAD_BASEPTR_1]], ptr %[[ALLOCA]], align 8
105105
// CHECK: %[[LOAD_BASEPTR_2:.*]] = load ptr, ptr %[[BASEPTR_2_GEP]], align 8
106-
// CHECK: %[[LOAD_BASEPTR_6:.*]] = load ptr, ptr %[[BASEPTR_6_GEP]], align 8
106+
// CHECK: %[[LOAD_BASEPTR_3:.*]] = load ptr, ptr %[[BASEPTR_3_GEP]], align 8
107107
// CHECK: %[[GEP_A4:.*]] = getelementptr { i64 }, ptr %[[ARG_4]], i32 0, i32 0
108108
// CHECK: %[[GEP_A7:.*]] = getelementptr { i64 }, ptr %[[ARG_7]], i32 0, i32 0
109109
// CHECK: %[[LOAD_A4:.*]] = load i64, ptr %[[GEP_A4]], align 4

0 commit comments

Comments
 (0)