@@ -1113,47 +1113,25 @@ struct AAAMDWavesPerEU : public AAAMDSizeRangeAttribute {
1113
1113
Function *F = getAssociatedFunction ();
1114
1114
auto &InfoCache = static_cast <AMDGPUInformationCache &>(A.getInfoCache ());
1115
1115
1116
- auto TakeRange = [&](std::pair<unsigned , unsigned > R) {
1117
- auto [Min, Max] = R;
1118
- ConstantRange Range (APInt (32 , Min), APInt (32 , Max + 1 ));
1119
- IntegerRangeState RangeState (Range);
1120
- clampStateAndIndicateChange (this ->getState (), RangeState);
1121
- indicateOptimisticFixpoint ();
1122
- };
1123
-
1124
- std::pair<unsigned , unsigned > MaxWavesPerEURange{
1125
- 1U , InfoCache.getMaxWavesPerEU (*F)};
1126
-
1127
1116
// If the attribute exists, we will honor it if it is not the default.
1128
1117
if (auto Attr = InfoCache.getWavesPerEUAttr (*F)) {
1118
+ std::pair<unsigned , unsigned > MaxWavesPerEURange{
1119
+ 1U , InfoCache.getMaxWavesPerEU (*F)};
1129
1120
if (*Attr != MaxWavesPerEURange) {
1130
- TakeRange (*Attr);
1121
+ auto [Min, Max] = *Attr;
1122
+ ConstantRange Range (APInt (32 , Min), APInt (32 , Max + 1 ));
1123
+ IntegerRangeState RangeState (Range);
1124
+ this ->getState () = RangeState;
1125
+ indicateOptimisticFixpoint ();
1131
1126
return ;
1132
1127
}
1133
1128
}
1134
1129
1135
- // Unlike AAAMDFlatWorkGroupSize, it's getting trickier here. Since the
1136
- // calculation of waves per EU involves flat work group size, we can't
1137
- // simply use an assumed flat work group size as a start point, because the
1138
- // update of flat work group size is in an inverse direction of waves per
1139
- // EU. However, we can still do something if it is an entry function. Since
1140
- // an entry function is a terminal node, and flat work group size either
1141
- // from attribute or default will be used anyway, we can take that value and
1142
- // calculate the waves per EU based on it. This result can't be updated by
1143
- // no means, but that could still allow us to propagate it.
1144
- if (AMDGPU::isEntryFunctionCC (F->getCallingConv ())) {
1145
- std::pair<unsigned , unsigned > FlatWorkGroupSize;
1146
- if (auto Attr = InfoCache.getFlatWorkGroupSizeAttr (*F))
1147
- FlatWorkGroupSize = *Attr;
1148
- else
1149
- FlatWorkGroupSize = InfoCache.getDefaultFlatWorkGroupSize (*F);
1150
- TakeRange (InfoCache.getEffectiveWavesPerEU (*F, MaxWavesPerEURange,
1151
- FlatWorkGroupSize));
1152
- }
1130
+ if (AMDGPU::isEntryFunctionCC (F->getCallingConv ()))
1131
+ indicatePessimisticFixpoint ();
1153
1132
}
1154
1133
1155
1134
ChangeStatus updateImpl (Attributor &A) override {
1156
- auto &InfoCache = static_cast <AMDGPUInformationCache &>(A.getInfoCache ());
1157
1135
ChangeStatus Change = ChangeStatus::UNCHANGED;
1158
1136
1159
1137
auto CheckCallSite = [&](AbstractCallSite CS) {
@@ -1162,24 +1140,21 @@ struct AAAMDWavesPerEU : public AAAMDSizeRangeAttribute {
1162
1140
LLVM_DEBUG (dbgs () << ' [' << getName () << " ] Call " << Caller->getName ()
1163
1141
<< " ->" << Func->getName () << ' \n ' );
1164
1142
1165
- const auto *CallerInfo = A.getAAFor <AAAMDWavesPerEU>(
1143
+ const auto *CallerAA = A.getAAFor <AAAMDWavesPerEU>(
1166
1144
*this , IRPosition::function (*Caller), DepClassTy::REQUIRED);
1167
- const auto *AssumedGroupSize = A.getAAFor <AAAMDFlatWorkGroupSize>(
1168
- *this , IRPosition::function (*Func), DepClassTy::REQUIRED);
1169
- if (!CallerInfo || !AssumedGroupSize || !CallerInfo->isValidState () ||
1170
- !AssumedGroupSize->isValidState ())
1145
+ if (!CallerAA || !CallerAA->isValidState ())
1171
1146
return false ;
1172
1147
1173
- unsigned Min, Max ;
1174
- std::tie (Min, Max) = InfoCache. getEffectiveWavesPerEU (
1175
- *Caller,
1176
- {CallerInfo-> getAssumed (). getLower ().getZExtValue (),
1177
- CallerInfo ->getAssumed ().getUpper ().getZExtValue () - 1 },
1178
- {AssumedGroupSize-> getAssumed (). getLower (). getZExtValue (),
1179
- AssumedGroupSize-> getAssumed (). getUpper (). getZExtValue () - 1 } );
1180
- ConstantRange CallerRange ( APInt ( 32 , Min), APInt ( 32 , Max + 1 )) ;
1181
- IntegerRangeState CallerRangeState (CallerRange);
1182
- Change |= clampStateAndIndicateChange ( this -> getState (), CallerRangeState) ;
1148
+ ConstantRange Assumed = getAssumed () ;
1149
+ unsigned Min = std::max (Assumed. getLower (). getZExtValue (),
1150
+ CallerAA-> getAssumed (). getLower (). getZExtValue ());
1151
+ unsigned Max = std::max (Assumed. getUpper ().getZExtValue (),
1152
+ CallerAA ->getAssumed ().getUpper ().getZExtValue ());
1153
+ ConstantRange Range ( APInt ( 32 , Min), APInt ( 32 , Max));
1154
+ IntegerRangeState RangeState (Range );
1155
+ getState () = RangeState ;
1156
+ Change |= getState () == Assumed ? ChangeStatus::UNCHANGED
1157
+ : ChangeStatus::CHANGED ;
1183
1158
1184
1159
return true ;
1185
1160
};
@@ -1323,6 +1298,74 @@ struct AAAMDGPUNoAGPR
1323
1298
1324
1299
const char AAAMDGPUNoAGPR::ID = 0 ;
1325
1300
1301
+ // / Performs the final check and updates the 'amdgpu-waves-per-eu' attribute
1302
+ // / based on the finalized 'amdgpu-flat-work-group-size' attribute.
1303
+ // / Both attributes start with narrow ranges that expand during iteration.
1304
+ // / However, a narrower flat-workgroup-size leads to a broader waves-per-eu
1305
+ // / range, preventing optimal updates later. Therefore, waves-per-eu can't be
1306
+ // / updated with intermediate values during the attributor run. We defer the
1307
+ // / calculation of waves-per-eu until after the flat-workgroup-size is
1308
+ // / finalized.
1309
+ // / TODO: Remove this and move similar logic back into the attributor run once
1310
+ // / we have a better representation for waves-per-eu.
1311
+ static bool updateWavesPerEU (Module &M, TargetMachine &TM) {
1312
+ bool Changed = false ;
1313
+
1314
+ LLVMContext &Ctx = M.getContext ();
1315
+
1316
+ for (Function &F : M) {
1317
+ if (F.isDeclaration ())
1318
+ continue ;
1319
+
1320
+ const GCNSubtarget &ST = TM.getSubtarget <GCNSubtarget>(F);
1321
+
1322
+ std::optional<std::pair<unsigned , std::optional<unsigned >>>
1323
+ FlatWgrpSizeAttr =
1324
+ AMDGPU::getIntegerPairAttribute (F, " amdgpu-flat-work-group-size" );
1325
+
1326
+ unsigned MinWavesPerEU = ST.getMinWavesPerEU ();
1327
+ unsigned MaxWavesPerEU = ST.getMaxWavesPerEU ();
1328
+
1329
+ unsigned MinFlatWgrpSize = ST.getMinFlatWorkGroupSize ();
1330
+ unsigned MaxFlatWgrpSize = ST.getMaxFlatWorkGroupSize ();
1331
+ if (FlatWgrpSizeAttr.has_value ()) {
1332
+ MinFlatWgrpSize = FlatWgrpSizeAttr->first ;
1333
+ MaxFlatWgrpSize = *(FlatWgrpSizeAttr->second );
1334
+ }
1335
+
1336
+ // Start with the "best" range.
1337
+ unsigned Min = MinWavesPerEU;
1338
+ unsigned Max = MinWavesPerEU;
1339
+
1340
+ // Compute the range from flat workgroup size. `getWavesPerEU` will also
1341
+ // account for the 'amdgpu-waves-er-eu' attribute.
1342
+ auto [MinFromFlatWgrpSize, MaxFromFlatWgrpSize] =
1343
+ ST.getWavesPerEU (F, {MinFlatWgrpSize, MaxFlatWgrpSize});
1344
+
1345
+ // For the lower bound, we have to "tighten" it.
1346
+ Min = std::max (Min, MinFromFlatWgrpSize);
1347
+ // For the upper bound, we have to "extend" it.
1348
+ Max = std::max (Max, MaxFromFlatWgrpSize);
1349
+
1350
+ // Clamp the range to the max range.
1351
+ Min = std::max (Min, MinWavesPerEU);
1352
+ Max = std::min (Max, MaxWavesPerEU);
1353
+
1354
+ // Update the attribute if it is not the max.
1355
+ if (Min != MinWavesPerEU || Max != MaxWavesPerEU) {
1356
+ SmallString<10 > Buffer;
1357
+ raw_svector_ostream OS (Buffer);
1358
+ OS << Min << ' ,' << Max;
1359
+ Attribute OldAttr = F.getFnAttribute (" amdgpu-waves-per-eu" );
1360
+ Attribute NewAttr = Attribute::get (Ctx, " amdgpu-waves-per-eu" , OS.str ());
1361
+ F.addFnAttr (NewAttr);
1362
+ Changed |= OldAttr == NewAttr;
1363
+ }
1364
+ }
1365
+
1366
+ return Changed;
1367
+ }
1368
+
1326
1369
static bool runImpl (Module &M, AnalysisGetter &AG, TargetMachine &TM,
1327
1370
AMDGPUAttributorOptions Options,
1328
1371
ThinOrFullLTOPhase LTOPhase) {
@@ -1396,8 +1439,11 @@ static bool runImpl(Module &M, AnalysisGetter &AG, TargetMachine &TM,
1396
1439
}
1397
1440
}
1398
1441
1399
- ChangeStatus Change = A.run ();
1400
- return Change == ChangeStatus::CHANGED;
1442
+ bool Changed = A.run () == ChangeStatus::CHANGED;
1443
+
1444
+ Changed |= updateWavesPerEU (M, TM);
1445
+
1446
+ return Changed;
1401
1447
}
1402
1448
1403
1449
class AMDGPUAttributorLegacy : public ModulePass {
0 commit comments