@@ -2979,39 +2979,61 @@ static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp) {
2979
2979
return std::distance (mapData.MapClause .begin (), res);
2980
2980
}
2981
2981
2982
- static omp::MapInfoOp getFirstOrLastMappedMemberPtr (omp::MapInfoOp mapInfo,
2983
- bool first) {
2984
- ArrayAttr indexAttr = mapInfo.getMembersIndexAttr ();
2985
- // Only 1 member has been mapped, we can return it.
2986
- if (indexAttr.size () == 1 )
2987
- return cast<omp::MapInfoOp>(mapInfo.getMembers ()[0 ].getDefiningOp ());
2982
+ static void sortMapIndices (llvm::SmallVector<size_t > &indices,
2983
+ mlir::omp::MapInfoOp mapInfo,
2984
+ bool ascending = true ) {
2985
+ mlir::ArrayAttr indexAttr = mapInfo.getMembersIndexAttr ();
2986
+ if (indexAttr.empty () || indexAttr.size () == 1 || indices.empty () ||
2987
+ indices.size () == 1 )
2988
+ return ;
2988
2989
2989
- llvm::SmallVector<size_t > indices (indexAttr.size ());
2990
- std::iota (indices.begin (), indices.end (), 0 );
2990
+ llvm::sort (
2991
+ indices.begin (), indices.end (), [&](const size_t a, const size_t b) {
2992
+ auto memberIndicesA = mlir::cast<mlir::ArrayAttr>(indexAttr[a]);
2993
+ auto memberIndicesB = mlir::cast<mlir::ArrayAttr>(indexAttr[b]);
2994
+
2995
+ size_t smallestMember = memberIndicesA.size () < memberIndicesB.size ()
2996
+ ? memberIndicesA.size ()
2997
+ : memberIndicesB.size ();
2991
2998
2992
- llvm::sort (indices. begin (), indices. end (),
2993
- [&]( const size_t a, const size_t b) {
2994
- auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
2995
- auto memberIndicesB = cast<ArrayAttr>(indexAttr[b] );
2996
- for ( const auto it : llvm::zip (memberIndicesA, memberIndicesB)) {
2997
- int64_t aIndex = cast<IntegerAttr>(std::get< 0 >(it)). getInt ();
2998
- int64_t bIndex = cast<IntegerAttr>(std::get< 1 >(it)) .getInt ();
2999
+ for ( size_t i = 0 ; i < smallestMember; ++i) {
3000
+ int64_t aIndex =
3001
+ mlir:: cast<mlir::IntegerAttr>(memberIndicesA. getValue ()[i])
3002
+ . getInt ( );
3003
+ int64_t bIndex =
3004
+ mlir:: cast<mlir::IntegerAttr>(memberIndicesB. getValue ()[i])
3005
+ .getInt ();
2999
3006
3000
- if (aIndex == bIndex)
3001
- continue ;
3007
+ if (aIndex == bIndex)
3008
+ continue ;
3002
3009
3003
- if (aIndex < bIndex)
3004
- return first ;
3010
+ if (aIndex < bIndex)
3011
+ return ascending ;
3005
3012
3006
- if (aIndex > bIndex)
3007
- return !first ;
3008
- }
3013
+ if (aIndex > bIndex)
3014
+ return !ascending ;
3015
+ }
3009
3016
3010
- // Iterated the up until the end of the smallest member and
3011
- // they were found to be equal up to that point, so select
3012
- // the member with the lowest index count, so the "parent"
3013
- return memberIndicesA.size () < memberIndicesB.size ();
3014
- });
3017
+ // Iterated up until the end of the smallest member and
3018
+ // they were found to be equal up to that point, so select
3019
+ // the member with the lowest index count, so the "parent"
3020
+ return memberIndicesA.size () < memberIndicesB.size ();
3021
+ });
3022
+ }
3023
+
3024
+ static mlir::omp::MapInfoOp
3025
+ getFirstOrLastMappedMemberPtr (mlir::omp::MapInfoOp mapInfo, bool first) {
3026
+ mlir::ArrayAttr indexAttr = mapInfo.getMembersIndexAttr ();
3027
+ // Only 1 member has been mapped, we can return it.
3028
+ if (indexAttr.size () == 1 )
3029
+ if (auto mapOp =
3030
+ dyn_cast<omp::MapInfoOp>(mapInfo.getMembers ()[0 ].getDefiningOp ()))
3031
+ return mapOp;
3032
+
3033
+ llvm::SmallVector<size_t > indices;
3034
+ indices.resize (indexAttr.size ());
3035
+ std::iota (indices.begin (), indices.end (), 0 );
3036
+ sortMapIndices (indices, mapInfo, first);
3015
3037
3016
3038
return llvm::cast<omp::MapInfoOp>(
3017
3039
mapInfo.getMembers ()[indices.front ()].getDefiningOp ());
@@ -3110,6 +3132,91 @@ calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation,
3110
3132
return idx;
3111
3133
}
3112
3134
3135
+ // Gathers members that are overlapping in the parent, excluding members that
3136
+ // themselves overlap, keeping the top-most (closest to parents level) map.
3137
+ static void getOverlappedMembers (llvm::SmallVector<size_t > &overlapMapDataIdxs,
3138
+ MapInfoData &mapData,
3139
+ omp::MapInfoOp parentOp) {
3140
+ // No members mapped, no overlaps.
3141
+ if (parentOp.getMembers ().empty ())
3142
+ return ;
3143
+
3144
+ // Single member, we can insert and return early.
3145
+ if (parentOp.getMembers ().size () == 1 ) {
3146
+ overlapMapDataIdxs.push_back (0 );
3147
+ return ;
3148
+ }
3149
+
3150
+ // 1) collect list of top-level overlapping members from MemberOp
3151
+ llvm::SmallVector<std::pair<int , mlir::ArrayAttr>> memberByIndex;
3152
+ mlir::ArrayAttr indexAttr = parentOp.getMembersIndexAttr ();
3153
+ for (auto [memIndex, indicesAttr] : llvm::enumerate (indexAttr))
3154
+ memberByIndex.push_back (
3155
+ std::make_pair (memIndex, mlir::cast<mlir::ArrayAttr>(indicesAttr)));
3156
+
3157
+ // Sort the smallest first (higher up the parent -> member chain), so that
3158
+ // when we remove members, we remove as much as we can in the initial
3159
+ // iterations, shortening the number of passes required.
3160
+ llvm::sort (memberByIndex.begin (), memberByIndex.end (),
3161
+ [&](auto a, auto b) { return a.second .size () < b.second .size (); });
3162
+
3163
+ auto getAsIntegers = [](mlir::ArrayAttr values) {
3164
+ llvm::SmallVector<int64_t > ints;
3165
+ ints.reserve (values.size ());
3166
+ llvm::transform (values, std::back_inserter (ints),
3167
+ [](mlir::Attribute value) {
3168
+ return mlir::cast<mlir::IntegerAttr>(value).getInt ();
3169
+ });
3170
+ return ints;
3171
+ };
3172
+
3173
+ // Remove elements from the vector if there is a parent element that
3174
+ // supersedes it. i.e. if member [0] is mapped, we can remove members [0,1],
3175
+ // [0,2].. etc.
3176
+ for (auto v : make_early_inc_range (memberByIndex)) {
3177
+ auto vArr = getAsIntegers (v.second );
3178
+ memberByIndex.erase (
3179
+ std::remove_if (memberByIndex.begin (), memberByIndex.end (),
3180
+ [&](auto x) {
3181
+ if (v == x)
3182
+ return false ;
3183
+
3184
+ auto xArr = getAsIntegers (x.second );
3185
+ return std::equal (vArr.begin (), vArr.end (),
3186
+ xArr.begin ()) &&
3187
+ xArr.size () >= vArr.size ();
3188
+ }),
3189
+ memberByIndex.end ());
3190
+ }
3191
+
3192
+ // Collect the indices from mapData that we need, as we technically need the
3193
+ // base pointer etc. info, which is stored in there and primarily accessible
3194
+ // via index at the moment.
3195
+ for (auto v : memberByIndex)
3196
+ overlapMapDataIdxs.push_back (v.first );
3197
+ }
3198
+
3199
+ // The intent is to verify if the mapped data being passed is a
3200
+ // pointer -> pointee that requires special handling in certain cases,
3201
+ // e.g. applying the OMP_MAP_PTR_AND_OBJ map type.
3202
+ //
3203
+ // There may be a better way to verify this, but unfortunately with
3204
+ // opaque pointers we lose the ability to easily check if something is
3205
+ // a pointer whilst maintaining access to the underlying type.
3206
+ static bool checkIfPointerMap (omp::MapInfoOp mapOp) {
3207
+ // If we have a varPtrPtr field assigned then the underlying type is a pointer
3208
+ if (mapOp.getVarPtrPtr ())
3209
+ return true ;
3210
+
3211
+ // If the map data is declare target with a link clause, then it's represented
3212
+ // as a pointer when we lower it to LLVM-IR even if at the MLIR level it has
3213
+ // no relation to pointers.
3214
+ if (isDeclareTargetLink (mapOp.getVarPtr ()))
3215
+ return true ;
3216
+
3217
+ return false ;
3218
+ }
3219
+
3113
3220
// This creates two insertions into the MapInfosTy data structure for the
3114
3221
// "parent" of a set of members, (usually a container e.g.
3115
3222
// class/structure/derived type) when subsequent members have also been
@@ -3150,7 +3257,6 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
3150
3257
// runtime information on the dynamically allocated data).
3151
3258
auto parentClause =
3152
3259
llvm::cast<omp::MapInfoOp>(mapData.MapClause [mapDataIndex]);
3153
-
3154
3260
llvm::Value *lowAddr, *highAddr;
3155
3261
if (!parentClause.getPartialMap ()) {
3156
3262
lowAddr = builder.CreatePointerCast (mapData.Pointers [mapDataIndex],
@@ -3197,37 +3303,77 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
3197
3303
// what we support as expected.
3198
3304
llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types [mapDataIndex];
3199
3305
ompBuilder.setCorrectMemberOfFlag (mapFlag, memberOfFlag);
3200
- combinedInfo.Types .emplace_back (mapFlag);
3201
- combinedInfo.DevicePointers .emplace_back (
3202
- llvm::OpenMPIRBuilder::DeviceInfoTy::None);
3203
- combinedInfo.Names .emplace_back (LLVM::createMappingInformation (
3204
- mapData.MapClause [mapDataIndex]->getLoc (), ompBuilder));
3205
- combinedInfo.BasePointers .emplace_back (mapData.BasePointers [mapDataIndex]);
3206
- combinedInfo.Pointers .emplace_back (mapData.Pointers [mapDataIndex]);
3207
- combinedInfo.Sizes .emplace_back (mapData.Sizes [mapDataIndex]);
3208
- }
3209
- return memberOfFlag;
3210
- }
3211
-
3212
- // The intent is to verify if the mapped data being passed is a
3213
- // pointer -> pointee that requires special handling in certain cases,
3214
- // e.g. applying the OMP_MAP_PTR_AND_OBJ map type.
3215
- //
3216
- // There may be a better way to verify this, but unfortunately with
3217
- // opaque pointers we lose the ability to easily check if something is
3218
- // a pointer whilst maintaining access to the underlying type.
3219
- static bool checkIfPointerMap (omp::MapInfoOp mapOp) {
3220
- // If we have a varPtrPtr field assigned then the underlying type is a pointer
3221
- if (mapOp.getVarPtrPtr ())
3222
- return true ;
3223
3306
3224
- // If the map data is declare target with a link clause, then it's represented
3225
- // as a pointer when we lower it to LLVM-IR even if at the MLIR level it has
3226
- // no relation to pointers.
3227
- if (isDeclareTargetLink (mapOp.getVarPtr ()))
3228
- return true ;
3307
+ if (targetDirective == TargetDirective::TargetUpdate) {
3308
+ combinedInfo.Types .emplace_back (mapFlag);
3309
+ combinedInfo.DevicePointers .emplace_back (
3310
+ mapData.DevicePointers [mapDataIndex]);
3311
+ combinedInfo.Names .emplace_back (LLVM::createMappingInformation (
3312
+ mapData.MapClause [mapDataIndex]->getLoc (), ompBuilder));
3313
+ combinedInfo.BasePointers .emplace_back (
3314
+ mapData.BasePointers [mapDataIndex]);
3315
+ combinedInfo.Pointers .emplace_back (mapData.Pointers [mapDataIndex]);
3316
+ combinedInfo.Sizes .emplace_back (mapData.Sizes [mapDataIndex]);
3317
+ } else {
3318
+ llvm::SmallVector<size_t > overlapIdxs;
3319
+ // Find all of the members that "overlap", i.e. occlude other members that
3320
+ // were mapped alongside the parent, e.g. member [0], occludes
3321
+ getOverlappedMembers (overlapIdxs, mapData, parentClause);
3322
+ // We need to make sure the overlapped members are sorted in order of
3323
+ // lowest address to highest address
3324
+ sortMapIndices (overlapIdxs, parentClause);
3325
+
3326
+ lowAddr = builder.CreatePointerCast (mapData.Pointers [mapDataIndex],
3327
+ builder.getPtrTy ());
3328
+ highAddr = builder.CreatePointerCast (
3329
+ builder.CreateConstGEP1_32 (mapData.BaseType [mapDataIndex],
3330
+ mapData.Pointers [mapDataIndex], 1 ),
3331
+ builder.getPtrTy ());
3332
+
3333
+ // TODO: We may want to skip arrays/array sections in this as Clang does
3334
+ // so it appears to be an optimisation rather than a neccessity though,
3335
+ // but this requires further investigation. However, we would have to make
3336
+ // sure to not exclude maps with bounds that ARE pointers, as these are
3337
+ // processed as seperate components, i.e. pointer + data.
3338
+ for (auto v : overlapIdxs) {
3339
+ auto mapDataOverlapIdx = getMapDataMemberIdx (
3340
+ mapData,
3341
+ cast<omp::MapInfoOp>(parentClause.getMembers ()[v].getDefiningOp ()));
3342
+ combinedInfo.Types .emplace_back (mapFlag);
3343
+ combinedInfo.DevicePointers .emplace_back (
3344
+ mapData.DevicePointers [mapDataOverlapIdx]);
3345
+ combinedInfo.Names .emplace_back (LLVM::createMappingInformation (
3346
+ mapData.MapClause [mapDataIndex]->getLoc (), ompBuilder));
3347
+ combinedInfo.BasePointers .emplace_back (
3348
+ mapData.BasePointers [mapDataIndex]);
3349
+ combinedInfo.Pointers .emplace_back (lowAddr);
3350
+ combinedInfo.Sizes .emplace_back (builder.CreateIntCast (
3351
+ builder.CreatePtrDiff (builder.getInt8Ty (),
3352
+ mapData.OriginalValue [mapDataOverlapIdx],
3353
+ lowAddr),
3354
+ builder.getInt64Ty (), /* isSigned=*/ true ));
3355
+ lowAddr = builder.CreateConstGEP1_32 (
3356
+ checkIfPointerMap (llvm::cast<omp::MapInfoOp>(
3357
+ mapData.MapClause [mapDataOverlapIdx]))
3358
+ ? builder.getPtrTy ()
3359
+ : mapData.BaseType [mapDataOverlapIdx],
3360
+ mapData.BasePointers [mapDataOverlapIdx], 1 );
3361
+ }
3229
3362
3230
- return false ;
3363
+ combinedInfo.Types .emplace_back (mapFlag);
3364
+ combinedInfo.DevicePointers .emplace_back (
3365
+ mapData.DevicePointers [mapDataIndex]);
3366
+ combinedInfo.Names .emplace_back (LLVM::createMappingInformation (
3367
+ mapData.MapClause [mapDataIndex]->getLoc (), ompBuilder));
3368
+ combinedInfo.BasePointers .emplace_back (
3369
+ mapData.BasePointers [mapDataIndex]);
3370
+ combinedInfo.Pointers .emplace_back (lowAddr);
3371
+ combinedInfo.Sizes .emplace_back (builder.CreateIntCast (
3372
+ builder.CreatePtrDiff (builder.getInt8Ty (), highAddr, lowAddr),
3373
+ builder.getInt64Ty (), true ));
3374
+ }
3375
+ }
3376
+ return memberOfFlag;
3231
3377
}
3232
3378
3233
3379
// This function is intended to add explicit mappings of members
0 commit comments