-
Notifications
You must be signed in to change notification settings - Fork 13.6k
AMDGPU: Propagate amdgpu-max-num-workgroups attribute #113018
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
fd3ae0b
41bb72b
dfe5c82
7994c83
eaf4a26
2ad2f35
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -179,6 +179,11 @@ class AMDGPUInformationCache : public InformationCache { | |
return {ST.getMinFlatWorkGroupSize(), ST.getMaxFlatWorkGroupSize()}; | ||
} | ||
|
||
/// Return what the GCN subtarget selected for \p F reports from its own
/// getMaxNumWorkGroups query on \p F.
SmallVector<unsigned> getMaxNumWorkGroups(const Function &F) {
  return TM.getSubtarget<GCNSubtarget>(F).getMaxNumWorkGroups(F);
}
|
||
/// Return the CodeObjectVersion value held by this object.
unsigned getCodeObjectVersion() const {
  return CodeObjectVersion;
}
|
||
|
@@ -821,6 +826,145 @@ AAAMDFlatWorkGroupSize::createForPosition(const IRPosition &IRP, | |
"AAAMDFlatWorkGroupSize is only valid for function position"); | ||
} | ||
|
||
struct TupleDecIntegerRangeState : public AbstractState { | ||
DecIntegerState<uint32_t> X, Y, Z; | ||
|
||
bool isValidState() const override { | ||
return X.isValidState() && Y.isValidState() && Z.isValidState(); | ||
} | ||
|
||
bool isAtFixpoint() const override { | ||
return X.isAtFixpoint() && Y.isAtFixpoint() && Z.isAtFixpoint(); | ||
} | ||
|
||
ChangeStatus indicateOptimisticFixpoint() override { | ||
return X.indicateOptimisticFixpoint() | Y.indicateOptimisticFixpoint() | | ||
Z.indicateOptimisticFixpoint(); | ||
} | ||
|
||
ChangeStatus indicatePessimisticFixpoint() override { | ||
return X.indicatePessimisticFixpoint() | Y.indicatePessimisticFixpoint() | | ||
Z.indicatePessimisticFixpoint(); | ||
} | ||
|
||
TupleDecIntegerRangeState operator^=(const TupleDecIntegerRangeState &Other) { | ||
X ^= Other.X; | ||
Y ^= Other.Y; | ||
Z ^= Other.Z; | ||
return *this; | ||
} | ||
|
||
bool operator==(const TupleDecIntegerRangeState &Other) const { | ||
return X == Other.X && Y == Other.Y && Z == Other.Z; | ||
} | ||
|
||
TupleDecIntegerRangeState &getAssumed() { return *this; } | ||
const TupleDecIntegerRangeState &getAssumed() const { return *this; } | ||
}; | ||
|
||
using AAAMDMaxNumWorkgroupsState = | ||
StateWrapper<TupleDecIntegerRangeState, AbstractAttribute, uint32_t>; | ||
|
||
/// Propagate amdgpu-max-num-workgroups attribute. | ||
struct AAAMDMaxNumWorkgroups | ||
: public StateWrapper<TupleDecIntegerRangeState, AbstractAttribute> { | ||
using Base = StateWrapper<TupleDecIntegerRangeState, AbstractAttribute>; | ||
|
||
/// Build the attribute for position \p IRP by forwarding it to Base; \p A is
/// not used by the constructor.
AAAMDMaxNumWorkgroups(const IRPosition &IRP, Attributor &A)
    : Base(IRP) {}
|
||
/// Seed the X/Y/Z component states for the associated function.
///
/// The per-function values come from
/// AMDGPUInformationCache::getMaxNumWorkGroups; element 0 is handed to
/// X.takeKnownMinimum, element 1 to Y and element 2 to Z. When the function's
/// calling convention is an entry-function calling convention, the state is
/// moved to a pessimistic fixpoint immediately afterwards.
void initialize(Attributor &A) override { | ||
Function *F = getAssociatedFunction(); | ||
// The Attributor's information cache is downcast to the AMDGPU-specific one.
auto &InfoCache = static_cast<AMDGPUInformationCache &>(A.getInfoCache()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's all this parse is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. but here it doesn't stop if an attribute is not a default one There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This case isn't actually like the others. It's only useful for known-bits style optimizations, and isn't a first class restriction like amdgpu-waves-per-eu There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure how this attribute is gonna be used in the future, but it can get convoluted and go south if the moving direction of this AA is different from its depending AA (if there is any in the future). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. #113019 is about the extent of what it can do |
||
|
||
SmallVector<unsigned> MaxNumWorkgroups = InfoCache.getMaxNumWorkGroups(*F); | ||
|
||
X.takeKnownMinimum(MaxNumWorkgroups[0]); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. update assumed here instead of known, only if they are not default. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's effectively what happens? The default is the maximum There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, takeKnownMinimum will only update known, not assumed. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I still think this should be takeKnownMinimum. This has more in common with AAAlign than amdgpu-waves-per-eu. It is an informative, not a prescriptive, attribute. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the key part here missing (if you update assumed value) is, Apparently There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. getMaxNumWorkGroups does check the attribute, otherwise nothing would work. I see no change by special case not adding the worst case. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For some reason, I made a typo in the last paragraph, which changed its meaning entirely. Lol. What I meant to say is that |
||
Y.takeKnownMinimum(MaxNumWorkgroups[1]); | ||
Z.takeKnownMinimum(MaxNumWorkgroups[2]); | ||
|
||
if (AMDGPU::isEntryFunctionCC(F->getCallingConv())) | ||
indicatePessimisticFixpoint(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should be an optimistic fix point We want the pessimistic state to be also the invalid state, such that its manifest will not be called There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Optimistic fix point is 0,0,0 which isn't even valid. It doesn't make sense to ever set that? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's why we want to update assumed instead of known, as I mentioned in another comment. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can't touch the entry point, how could it ever be optimistic There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. entry point is the starting point of the propagation, so whatever value it has, no matter user specified value or default value, it is the best it has. |
||
} | ||
|
||
/// Fold the state of every caller into this function's state.
///
/// For each call site, the caller's AAAMDMaxNumWorkgroups is requested with a
/// REQUIRED dependence. A missing or invalid caller attribute makes the
/// call-site callback return false; otherwise the caller's state is combined
/// into this state with clampStateAndIndicateChange and the resulting
/// ChangeStatus is accumulated. If checkForAllCallSites (invoked with
/// RequireAllCallSites = true) returns false, the state is moved to a
/// pessimistic fixpoint; otherwise the accumulated ChangeStatus is returned.
ChangeStatus updateImpl(Attributor &A) override { | ||
ChangeStatus Change = ChangeStatus::UNCHANGED; | ||
|
||
// Per-call-site callback handed to checkForAllCallSites below.
auto CheckCallSite = [&](AbstractCallSite CS) { | ||
Function *Caller = CS.getInstruction()->getFunction(); | ||
LLVM_DEBUG(dbgs() << "[AAAMDMaxNumWorkgroups] Call " << Caller->getName() | ||
<< "->" << getAssociatedFunction()->getName() << '\n'); | ||
|
||
const auto *CallerInfo = A.getAAFor<AAAMDMaxNumWorkgroups>( | ||
*this, IRPosition::function(*Caller), DepClassTy::REQUIRED); | ||
if (!CallerInfo || !CallerInfo->isValidState()) | ||
return false; | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. skip the update if it at initial state, like what we did in #114726 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've tried this and it seems to never fire |
||
Change |= | ||
clampStateAndIndicateChange(this->getState(), CallerInfo->getState()); | ||
return true; | ||
}; | ||
|
||
bool AllCallSitesKnown = true; | ||
if (!A.checkForAllCallSites(CheckCallSite, *this, | ||
/*RequireAllCallSites=*/true, | ||
AllCallSitesKnown)) | ||
return indicatePessimisticFixpoint(); | ||
|
||
return Change; | ||
} | ||
|
||
/// Create an abstract attribute view for the position \p IRP. | ||
static AAAMDMaxNumWorkgroups &createForPosition(const IRPosition &IRP, | ||
Attributor &A); | ||
|
||
/// Write the assumed state into the IR as an "amdgpu-max-num-workgroups"
/// string attribute at this attribute's IR position.
///
/// The attribute value is the assumed X, Y and Z values joined by commas
/// ("X,Y,Z"). The attribute is handed to Attributor::manifestAttrs with
/// ForceReplace set to true, and that call's ChangeStatus is returned.
ChangeStatus manifest(Attributor &A) override { | ||
Function *F = getAssociatedFunction(); | ||
// The context of the associated function is needed to create the attribute.
LLVMContext &Ctx = F->getContext(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. skil manifest if it is still in initial state There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would expect manifest to not get called in the first place if the worst state is considered invalid |
||
SmallString<32> Buffer; | ||
raw_svector_ostream OS(Buffer); | ||
OS << X.getAssumed() << ',' << Y.getAssumed() << ',' << Z.getAssumed(); | ||
|
||
// TODO: Should annotate loads of the group size for this to do anything | ||
// useful. | ||
return A.manifestAttrs( | ||
getIRPosition(), | ||
{Attribute::get(Ctx, "amdgpu-max-num-workgroups", OS.str())}, | ||
/* ForceReplace= */ true); | ||
} | ||
|
||
/// Name of this abstract attribute.
const std::string getName() const override {
  return "AAAMDMaxNumWorkgroups";
}
|
||
/// Render the assumed X, Y and Z values as
/// "AAAMDMaxNumWorkgroupsState[X,Y,Z]". The Attributor argument is unused.
const std::string getAsStr(Attributor *) const override {
  std::string Text = "AAAMDMaxNumWorkgroupsState[";
  raw_string_ostream Stream(Text);
  Stream << X.getAssumed() << ',';
  Stream << Y.getAssumed() << ',';
  Stream << Z.getAssumed() << ']';
  return Stream.str();
}
|
||
/// Return the address of this class's static ID member.
const char *getIdAddr() const override {
  return &ID;
}
|
||
/// This function should return true if the type of the \p AA is | ||
/// AAAMDMaxNumWorkgroups | ||
static bool classof(const AbstractAttribute *AA) { | ||
return (AA->getIdAddr() == &ID); | ||
} | ||
|
||
/// Intentionally empty: this attribute records no statistics.
void trackStatistics() const override {
}
|
||
/// Unique ID (due to the unique address) | ||
static const char ID; | ||
}; | ||
|
||
const char AAAMDMaxNumWorkgroups::ID = 0; | ||
|
||
/// Allocate an AAAMDMaxNumWorkgroups for \p IRP using the Attributor's
/// allocator. Only function positions are supported; any other position kind
/// reaches llvm_unreachable.
AAAMDMaxNumWorkgroups &
AAAMDMaxNumWorkgroups::createForPosition(const IRPosition &IRP, Attributor &A) {
  const bool IsFunctionPosition =
      IRP.getPositionKind() == IRPosition::IRP_FUNCTION;
  if (!IsFunctionPosition)
    llvm_unreachable(
        "AAAMDMaxNumWorkgroups is only valid for function position");
  return *new (A.Allocator) AAAMDMaxNumWorkgroups(IRP, A);
}
|
||
/// Propagate amdgpu-waves-per-eu attribute. | ||
struct AAAMDWavesPerEU : public AAAMDSizeRangeAttribute { | ||
AAAMDWavesPerEU(const IRPosition &IRP, Attributor &A) | ||
|
@@ -1046,8 +1190,8 @@ static bool runImpl(Module &M, AnalysisGetter &AG, TargetMachine &TM, | |
DenseSet<const char *> Allowed( | ||
{&AAAMDAttributes::ID, &AAUniformWorkGroupSize::ID, | ||
&AAPotentialValues::ID, &AAAMDFlatWorkGroupSize::ID, | ||
&AAAMDWavesPerEU::ID, &AAAMDGPUNoAGPR::ID, &AACallEdges::ID, | ||
&AAPointerInfo::ID, &AAPotentialConstantValues::ID, | ||
&AAAMDMaxNumWorkgroups::ID, &AAAMDWavesPerEU::ID, &AAAMDGPUNoAGPR::ID, | ||
&AACallEdges::ID, &AAPointerInfo::ID, &AAPotentialConstantValues::ID, | ||
&AAUnderlyingObjects::ID, &AAAddressSpace::ID, &AAIndirectCallInfo::ID, | ||
&AAInstanceInfo::ID}); | ||
|
||
|
@@ -1071,6 +1215,7 @@ static bool runImpl(Module &M, AnalysisGetter &AG, TargetMachine &TM, | |
for (auto *F : Functions) { | ||
A.getOrCreateAAFor<AAAMDAttributes>(IRPosition::function(*F)); | ||
A.getOrCreateAAFor<AAUniformWorkGroupSize>(IRPosition::function(*F)); | ||
A.getOrCreateAAFor<AAAMDMaxNumWorkgroups>(IRPosition::function(*F)); | ||
A.getOrCreateAAFor<AAAMDGPUNoAGPR>(IRPosition::function(*F)); | ||
CallingConv::ID CC = F->getCallingConv(); | ||
if (!AMDGPU::isEntryFunctionCC(CC)) { | ||
|
Uh oh!
There was an error while loading. Please reload this page.