@@ -198,6 +198,17 @@ class AMDGPUInformationCache : public InformationCache {
198
198
return ST.getWavesPerEU (F, FlatWorkGroupSize);
199
199
}
200
200
201
+ std::optional<std::pair<unsigned , unsigned >>
202
+ getWavesPerEUAttr (const Function &F) {
203
+ auto Val = AMDGPU::getIntegerPairAttribute (F, " amdgpu-waves-per-eu" ,
204
+ /* OnlyFirstRequired=*/ true );
205
+ if (Val && Val->second == 0 ) {
206
+ const GCNSubtarget &ST = TM.getSubtarget <GCNSubtarget>(F);
207
+ Val->second = ST.getMaxWavesPerEU ();
208
+ }
209
+ return Val;
210
+ }
211
+
201
212
std::pair<unsigned , unsigned >
202
213
getEffectiveWavesPerEU (const Function &F,
203
214
std::pair<unsigned , unsigned > WavesPerEU,
@@ -768,22 +779,6 @@ struct AAAMDSizeRangeAttribute
768
779
/* ForceReplace=*/ true );
769
780
}
770
781
771
- ChangeStatus emitAttributeIfNotDefault (Attributor &A, unsigned Min,
772
- unsigned Max) {
773
- // Don't add the attribute if it's the implied default.
774
- if (getAssumed ().getLower () == Min && getAssumed ().getUpper () - 1 == Max)
775
- return ChangeStatus::UNCHANGED;
776
-
777
- Function *F = getAssociatedFunction ();
778
- LLVMContext &Ctx = F->getContext ();
779
- SmallString<10 > Buffer;
780
- raw_svector_ostream OS (Buffer);
781
- OS << getAssumed ().getLower () << ' ,' << getAssumed ().getUpper () - 1 ;
782
- return A.manifestAttrs (getIRPosition (),
783
- {Attribute::get (Ctx, AttrName, OS.str ())},
784
- /* ForceReplace=*/ true );
785
- }
786
-
787
782
const std::string getAsStr (Attributor *) const override {
788
783
std::string Str;
789
784
raw_string_ostream OS (Str);
@@ -873,29 +868,47 @@ struct AAAMDWavesPerEU : public AAAMDSizeRangeAttribute {
873
868
AAAMDWavesPerEU (const IRPosition &IRP, Attributor &A)
874
869
: AAAMDSizeRangeAttribute(IRP, A, " amdgpu-waves-per-eu" ) {}
875
870
876
- bool isValidState () const override {
877
- return !Assumed.isEmptySet () && IntegerRangeState::isValidState ();
878
- }
879
-
880
871
void initialize (Attributor &A) override {
881
872
Function *F = getAssociatedFunction ();
882
873
auto &InfoCache = static_cast <AMDGPUInformationCache &>(A.getInfoCache ());
883
874
884
- if (const auto *AssumedGroupSize = A.getAAFor <AAAMDFlatWorkGroupSize>(
885
- *this , IRPosition::function (*F), DepClassTy::REQUIRED);
886
- AssumedGroupSize->isValidState ()) {
875
+ auto TakeRange = [&](std::pair<unsigned , unsigned > R) {
876
+ auto [Min, Max] = R;
877
+ ConstantRange Range (APInt (32 , Min), APInt (32 , Max + 1 ));
878
+ IntegerRangeState RangeState (Range);
879
+ clampStateAndIndicateChange (this ->getState (), RangeState);
880
+ indicateOptimisticFixpoint ();
881
+ };
887
882
888
- unsigned Min, Max;
889
- std::tie (Min, Max) = InfoCache.getWavesPerEU (
890
- *F, {AssumedGroupSize->getAssumed ().getLower ().getZExtValue (),
891
- AssumedGroupSize->getAssumed ().getUpper ().getZExtValue () - 1 });
883
+ std::pair<unsigned , unsigned > MaxWavesPerEURange{
884
+ 1U , InfoCache.getMaxWavesPerEU (*F)};
892
885
893
- ConstantRange Range (APInt (32 , Min), APInt (32 , Max + 1 ));
894
- intersectKnown (Range);
886
+ // If the attribute exists, we will honor it if it is not the default.
887
+ if (auto Attr = InfoCache.getWavesPerEUAttr (*F)) {
888
+ if (*Attr != MaxWavesPerEURange) {
889
+ TakeRange (*Attr);
890
+ return ;
891
+ }
895
892
}
896
893
897
- if (AMDGPU::isEntryFunctionCC (F->getCallingConv ()))
898
- indicatePessimisticFixpoint ();
894
+ // Unlike AAAMDFlatWorkGroupSize, it's getting trickier here. Since the
895
+ // calculation of waves per EU involves flat work group size, we can't
896
+ // simply use an assumed flat work group size as a start point, because the
897
+ // update of flat work group size is in an inverse direction of waves per
898
+ // EU. However, we can still do something if it is an entry function. Since
899
+ // an entry function is a terminal node, and flat work group size either
900
+ // from attribute or default will be used anyway, we can take that value and
901
+ // calculate the waves per EU based on it. This result can't be updated by
902
+ // any means, but that could still allow us to propagate it.
903
+ if (AMDGPU::isEntryFunctionCC (F->getCallingConv ())) {
904
+ std::pair<unsigned , unsigned > FlatWorkGroupSize;
905
+ if (auto Attr = InfoCache.getFlatWorkGroupSizeAttr (*F))
906
+ FlatWorkGroupSize = *Attr;
907
+ else
908
+ FlatWorkGroupSize = InfoCache.getDefaultFlatWorkGroupSize (*F);
909
+ TakeRange (InfoCache.getEffectiveWavesPerEU (*F, MaxWavesPerEURange,
910
+ FlatWorkGroupSize));
911
+ }
899
912
}
900
913
901
914
ChangeStatus updateImpl (Attributor &A) override {
@@ -944,8 +957,8 @@ struct AAAMDWavesPerEU : public AAAMDSizeRangeAttribute {
944
957
ChangeStatus manifest (Attributor &A) override {
945
958
Function *F = getAssociatedFunction ();
946
959
auto &InfoCache = static_cast <AMDGPUInformationCache &>(A.getInfoCache ());
947
- unsigned Max = InfoCache. getMaxWavesPerEU (*F);
948
- return emitAttributeIfNotDefault ( A, 1 , Max );
960
+ return emitAttributeIfNotDefaultAfterClamp (
961
+ A, { 1U , InfoCache. getMaxWavesPerEU (*F)} );
949
962
}
950
963
951
964
// / See AbstractAttribute::getName()
0 commit comments