@@ -2874,39 +2874,61 @@ static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp) {
2874
2874
return std::distance (mapData.MapClause .begin (), res);
2875
2875
}
2876
2876
2877
- static omp::MapInfoOp getFirstOrLastMappedMemberPtr (omp::MapInfoOp mapInfo,
2878
- bool first) {
2879
- ArrayAttr indexAttr = mapInfo.getMembersIndexAttr ();
2880
- // Only 1 member has been mapped, we can return it.
2881
- if (indexAttr.size () == 1 )
2882
- return cast<omp::MapInfoOp>(mapInfo.getMembers ()[0 ].getDefiningOp ());
2877
+ static void sortMapIndices (llvm::SmallVector<size_t > &indices,
2878
+ mlir::omp::MapInfoOp mapInfo,
2879
+ bool ascending = true ) {
2880
+ mlir::ArrayAttr indexAttr = mapInfo.getMembersIndexAttr ();
2881
+ if (indexAttr.empty () || indexAttr.size () == 1 || indices.empty () ||
2882
+ indices.size () == 1 )
2883
+ return ;
2883
2884
2884
- llvm::SmallVector<size_t > indices (indexAttr.size ());
2885
- std::iota (indices.begin (), indices.end (), 0 );
2885
+ llvm::sort (
2886
+ indices.begin (), indices.end (), [&](const size_t a, const size_t b) {
2887
+ auto memberIndicesA = mlir::cast<mlir::ArrayAttr>(indexAttr[a]);
2888
+ auto memberIndicesB = mlir::cast<mlir::ArrayAttr>(indexAttr[b]);
2889
+
2890
+ size_t smallestMember = memberIndicesA.size () < memberIndicesB.size ()
2891
+ ? memberIndicesA.size ()
2892
+ : memberIndicesB.size ();
2886
2893
2887
- llvm::sort (indices. begin (), indices. end (),
2888
- [&]( const size_t a, const size_t b) {
2889
- auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
2890
- auto memberIndicesB = cast<ArrayAttr>(indexAttr[b] );
2891
- for ( const auto it : llvm::zip (memberIndicesA, memberIndicesB)) {
2892
- int64_t aIndex = cast<IntegerAttr>(std::get< 0 >(it)). getInt ();
2893
- int64_t bIndex = cast<IntegerAttr>(std::get< 1 >(it)) .getInt ();
2894
+ for ( size_t i = 0 ; i < smallestMember; ++i) {
2895
+ int64_t aIndex =
2896
+ mlir:: cast<mlir::IntegerAttr>(memberIndicesA. getValue ()[i])
2897
+ . getInt ( );
2898
+ int64_t bIndex =
2899
+ mlir:: cast<mlir::IntegerAttr>(memberIndicesB. getValue ()[i])
2900
+ .getInt ();
2894
2901
2895
- if (aIndex == bIndex)
2896
- continue ;
2902
+ if (aIndex == bIndex)
2903
+ continue ;
2897
2904
2898
- if (aIndex < bIndex)
2899
- return first ;
2905
+ if (aIndex < bIndex)
2906
+ return ascending ;
2900
2907
2901
- if (aIndex > bIndex)
2902
- return !first ;
2903
- }
2908
+ if (aIndex > bIndex)
2909
+ return !ascending ;
2910
+ }
2904
2911
2905
- // Iterated the up until the end of the smallest member and
2906
- // they were found to be equal up to that point, so select
2907
- // the member with the lowest index count, so the "parent"
2908
- return memberIndicesA.size () < memberIndicesB.size ();
2909
- });
2912
+ // Iterated up until the end of the smallest member and
2913
+ // they were found to be equal up to that point, so select
2914
+ // the member with the lowest index count, so the "parent"
2915
+ return memberIndicesA.size () < memberIndicesB.size ();
2916
+ });
2917
+ }
2918
+
2919
+ static mlir::omp::MapInfoOp
2920
+ getFirstOrLastMappedMemberPtr (mlir::omp::MapInfoOp mapInfo, bool first) {
2921
+ mlir::ArrayAttr indexAttr = mapInfo.getMembersIndexAttr ();
2922
+ // Only 1 member has been mapped, we can return it.
2923
+ if (indexAttr.size () == 1 )
2924
+ if (auto mapOp =
2925
+ dyn_cast<omp::MapInfoOp>(mapInfo.getMembers ()[0 ].getDefiningOp ()))
2926
+ return mapOp;
2927
+
2928
+ llvm::SmallVector<size_t > indices;
2929
+ indices.resize (indexAttr.size ());
2930
+ std::iota (indices.begin (), indices.end (), 0 );
2931
+ sortMapIndices (indices, mapInfo, first);
2910
2932
2911
2933
return llvm::cast<omp::MapInfoOp>(
2912
2934
mapInfo.getMembers ()[indices.front ()].getDefiningOp ());
@@ -3005,6 +3027,91 @@ calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation,
3005
3027
return idx;
3006
3028
}
3007
3029
3030
+ // Gathers members that are overlapping in the parent, excluding members that
3031
+ // themselves overlap, keeping the top-most (closest to parents level) map.
3032
+ static void getOverlappedMembers (llvm::SmallVector<size_t > &overlapMapDataIdxs,
3033
+ MapInfoData &mapData,
3034
+ omp::MapInfoOp parentOp) {
3035
+ // No members mapped, no overlaps.
3036
+ if (parentOp.getMembers ().empty ())
3037
+ return ;
3038
+
3039
+ // Single member, we can insert and return early.
3040
+ if (parentOp.getMembers ().size () == 1 ) {
3041
+ overlapMapDataIdxs.push_back (0 );
3042
+ return ;
3043
+ }
3044
+
3045
+ // 1) collect list of top-level overlapping members from MemberOp
3046
+ llvm::SmallVector<std::pair<int , mlir::ArrayAttr>> memberByIndex;
3047
+ mlir::ArrayAttr indexAttr = parentOp.getMembersIndexAttr ();
3048
+ for (auto [memIndex, indicesAttr] : llvm::enumerate (indexAttr))
3049
+ memberByIndex.push_back (
3050
+ std::make_pair (memIndex, mlir::cast<mlir::ArrayAttr>(indicesAttr)));
3051
+
3052
+ // Sort the smallest first (higher up the parent -> member chain), so that
3053
+ // when we remove members, we remove as much as we can in the initial
3054
+ // iterations, shortening the number of passes required.
3055
+ llvm::sort (memberByIndex.begin (), memberByIndex.end (),
3056
+ [&](auto a, auto b) { return a.second .size () < b.second .size (); });
3057
+
3058
+ auto getAsIntegers = [](mlir::ArrayAttr values) {
3059
+ llvm::SmallVector<int64_t > ints;
3060
+ ints.reserve (values.size ());
3061
+ llvm::transform (values, std::back_inserter (ints),
3062
+ [](mlir::Attribute value) {
3063
+ return mlir::cast<mlir::IntegerAttr>(value).getInt ();
3064
+ });
3065
+ return ints;
3066
+ };
3067
+
3068
+ // Remove elements from the vector if there is a parent element that
3069
+ // supersedes it. i.e. if member [0] is mapped, we can remove members [0,1],
3070
+ // [0,2].. etc.
3071
+ for (auto v : make_early_inc_range (memberByIndex)) {
3072
+ auto vArr = getAsIntegers (v.second );
3073
+ memberByIndex.erase (
3074
+ std::remove_if (memberByIndex.begin (), memberByIndex.end (),
3075
+ [&](auto x) {
3076
+ if (v == x)
3077
+ return false ;
3078
+
3079
+ auto xArr = getAsIntegers (x.second );
3080
+ return std::equal (vArr.begin (), vArr.end (),
3081
+ xArr.begin ()) &&
3082
+ xArr.size () >= vArr.size ();
3083
+ }),
3084
+ memberByIndex.end ());
3085
+ }
3086
+
3087
+ // Collect the indices from mapData that we need, as we technically need the
3088
+ // base pointer etc. info, which is stored in there and primarily accessible
3089
+ // via index at the moment.
3090
+ for (auto v : memberByIndex)
3091
+ overlapMapDataIdxs.push_back (v.first );
3092
+ }
3093
+
3094
+ // The intent is to verify if the mapped data being passed is a
3095
+ // pointer -> pointee that requires special handling in certain cases,
3096
+ // e.g. applying the OMP_MAP_PTR_AND_OBJ map type.
3097
+ //
3098
+ // There may be a better way to verify this, but unfortunately with
3099
+ // opaque pointers we lose the ability to easily check if something is
3100
+ // a pointer whilst maintaining access to the underlying type.
3101
+ static bool checkIfPointerMap (omp::MapInfoOp mapOp) {
3102
+ // If we have a varPtrPtr field assigned then the underlying type is a pointer
3103
+ if (mapOp.getVarPtrPtr ())
3104
+ return true ;
3105
+
3106
+ // If the map data is declare target with a link clause, then it's represented
3107
+ // as a pointer when we lower it to LLVM-IR even if at the MLIR level it has
3108
+ // no relation to pointers.
3109
+ if (isDeclareTargetLink (mapOp.getVarPtr ()))
3110
+ return true ;
3111
+
3112
+ return false ;
3113
+ }
3114
+
3008
3115
// This creates two insertions into the MapInfosTy data structure for the
3009
3116
// "parent" of a set of members, (usually a container e.g.
3010
3117
// class/structure/derived type) when subsequent members have also been
@@ -3045,7 +3152,6 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
3045
3152
// runtime information on the dynamically allocated data).
3046
3153
auto parentClause =
3047
3154
llvm::cast<omp::MapInfoOp>(mapData.MapClause [mapDataIndex]);
3048
-
3049
3155
llvm::Value *lowAddr, *highAddr;
3050
3156
if (!parentClause.getPartialMap ()) {
3051
3157
lowAddr = builder.CreatePointerCast (mapData.Pointers [mapDataIndex],
@@ -3092,37 +3198,77 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
3092
3198
// what we support as expected.
3093
3199
llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types [mapDataIndex];
3094
3200
ompBuilder.setCorrectMemberOfFlag (mapFlag, memberOfFlag);
3095
- combinedInfo.Types .emplace_back (mapFlag);
3096
- combinedInfo.DevicePointers .emplace_back (
3097
- llvm::OpenMPIRBuilder::DeviceInfoTy::None);
3098
- combinedInfo.Names .emplace_back (LLVM::createMappingInformation (
3099
- mapData.MapClause [mapDataIndex]->getLoc (), ompBuilder));
3100
- combinedInfo.BasePointers .emplace_back (mapData.BasePointers [mapDataIndex]);
3101
- combinedInfo.Pointers .emplace_back (mapData.Pointers [mapDataIndex]);
3102
- combinedInfo.Sizes .emplace_back (mapData.Sizes [mapDataIndex]);
3103
- }
3104
- return memberOfFlag;
3105
- }
3106
-
3107
- // The intent is to verify if the mapped data being passed is a
3108
- // pointer -> pointee that requires special handling in certain cases,
3109
- // e.g. applying the OMP_MAP_PTR_AND_OBJ map type.
3110
- //
3111
- // There may be a better way to verify this, but unfortunately with
3112
- // opaque pointers we lose the ability to easily check if something is
3113
- // a pointer whilst maintaining access to the underlying type.
3114
- static bool checkIfPointerMap (omp::MapInfoOp mapOp) {
3115
- // If we have a varPtrPtr field assigned then the underlying type is a pointer
3116
- if (mapOp.getVarPtrPtr ())
3117
- return true ;
3118
3201
3119
- // If the map data is declare target with a link clause, then it's represented
3120
- // as a pointer when we lower it to LLVM-IR even if at the MLIR level it has
3121
- // no relation to pointers.
3122
- if (isDeclareTargetLink (mapOp.getVarPtr ()))
3123
- return true ;
3202
+ if (targetDirective == TargetDirective::TargetUpdate) {
3203
+ combinedInfo.Types .emplace_back (mapFlag);
3204
+ combinedInfo.DevicePointers .emplace_back (
3205
+ mapData.DevicePointers [mapDataIndex]);
3206
+ combinedInfo.Names .emplace_back (LLVM::createMappingInformation (
3207
+ mapData.MapClause [mapDataIndex]->getLoc (), ompBuilder));
3208
+ combinedInfo.BasePointers .emplace_back (
3209
+ mapData.BasePointers [mapDataIndex]);
3210
+ combinedInfo.Pointers .emplace_back (mapData.Pointers [mapDataIndex]);
3211
+ combinedInfo.Sizes .emplace_back (mapData.Sizes [mapDataIndex]);
3212
+ } else {
3213
+ llvm::SmallVector<size_t > overlapIdxs;
3214
+ // Find all of the members that "overlap", i.e. occlude other members that
3215
+ // were mapped alongside the parent, e.g. member [0], occludes
3216
+ getOverlappedMembers (overlapIdxs, mapData, parentClause);
3217
+ // We need to make sure the overlapped members are sorted in order of
3218
+ // lowest address to highest address
3219
+ sortMapIndices (overlapIdxs, parentClause);
3220
+
3221
+ lowAddr = builder.CreatePointerCast (mapData.Pointers [mapDataIndex],
3222
+ builder.getPtrTy ());
3223
+ highAddr = builder.CreatePointerCast (
3224
+ builder.CreateConstGEP1_32 (mapData.BaseType [mapDataIndex],
3225
+ mapData.Pointers [mapDataIndex], 1 ),
3226
+ builder.getPtrTy ());
3227
+
3228
+ // TODO: We may want to skip arrays/array sections in this as Clang does
3229
+ // so it appears to be an optimisation rather than a neccessity though,
3230
+ // but this requires further investigation. However, we would have to make
3231
+ // sure to not exclude maps with bounds that ARE pointers, as these are
3232
+ // processed as seperate components, i.e. pointer + data.
3233
+ for (auto v : overlapIdxs) {
3234
+ auto mapDataOverlapIdx = getMapDataMemberIdx (
3235
+ mapData,
3236
+ cast<omp::MapInfoOp>(parentClause.getMembers ()[v].getDefiningOp ()));
3237
+ combinedInfo.Types .emplace_back (mapFlag);
3238
+ combinedInfo.DevicePointers .emplace_back (
3239
+ mapData.DevicePointers [mapDataOverlapIdx]);
3240
+ combinedInfo.Names .emplace_back (LLVM::createMappingInformation (
3241
+ mapData.MapClause [mapDataIndex]->getLoc (), ompBuilder));
3242
+ combinedInfo.BasePointers .emplace_back (
3243
+ mapData.BasePointers [mapDataIndex]);
3244
+ combinedInfo.Pointers .emplace_back (lowAddr);
3245
+ combinedInfo.Sizes .emplace_back (builder.CreateIntCast (
3246
+ builder.CreatePtrDiff (builder.getInt8Ty (),
3247
+ mapData.OriginalValue [mapDataOverlapIdx],
3248
+ lowAddr),
3249
+ builder.getInt64Ty (), /* isSigned=*/ true ));
3250
+ lowAddr = builder.CreateConstGEP1_32 (
3251
+ checkIfPointerMap (llvm::cast<omp::MapInfoOp>(
3252
+ mapData.MapClause [mapDataOverlapIdx]))
3253
+ ? builder.getPtrTy ()
3254
+ : mapData.BaseType [mapDataOverlapIdx],
3255
+ mapData.BasePointers [mapDataOverlapIdx], 1 );
3256
+ }
3124
3257
3125
- return false ;
3258
+ combinedInfo.Types .emplace_back (mapFlag);
3259
+ combinedInfo.DevicePointers .emplace_back (
3260
+ mapData.DevicePointers [mapDataIndex]);
3261
+ combinedInfo.Names .emplace_back (LLVM::createMappingInformation (
3262
+ mapData.MapClause [mapDataIndex]->getLoc (), ompBuilder));
3263
+ combinedInfo.BasePointers .emplace_back (
3264
+ mapData.BasePointers [mapDataIndex]);
3265
+ combinedInfo.Pointers .emplace_back (lowAddr);
3266
+ combinedInfo.Sizes .emplace_back (builder.CreateIntCast (
3267
+ builder.CreatePtrDiff (builder.getInt8Ty (), highAddr, lowAddr),
3268
+ builder.getInt64Ty (), true ));
3269
+ }
3270
+ }
3271
+ return memberOfFlag;
3126
3272
}
3127
3273
3128
3274
// This function is intended to add explicit mappings of members
0 commit comments