[AMDGPU] Refactor several functions for merging with downstream work. #110562

Merged: 2 commits, Oct 1, 2024

Changes from all commits
244 changes: 133 additions & 111 deletions llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
@@ -310,7 +310,14 @@ class WaitcntBrackets {
bool counterOutOfOrder(InstCounterType T) const;
void simplifyWaitcnt(AMDGPU::Waitcnt &Wait) const;
void simplifyWaitcnt(InstCounterType T, unsigned &Count) const;
void determineWait(InstCounterType T, int RegNo, AMDGPU::Waitcnt &Wait) const;

void determineWait(InstCounterType T, RegInterval Interval,
AMDGPU::Waitcnt &Wait) const;
void determineWait(InstCounterType T, int RegNo,
AMDGPU::Waitcnt &Wait) const {
determineWait(T, {RegNo, RegNo + 1}, Wait);
}
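For context: RegInterval in this pass is a half-open pair, so {RegNo, RegNo + 1} above denotes the single register RegNo, and the old single-register determineWait survives as a thin wrapper over the interval form. A minimal standalone sketch of that delegation pattern; the Tracker type and maxScore are invented for illustration and are not part of the pass:

#include <algorithm>
#include <cstdio>
#include <utility>

// Half-open interval of register slots: [first, second).
using RegInterval = std::pair<int, int>;

struct Tracker {
  unsigned Scores[16] = {};

  // The interval form does the real work.
  unsigned maxScore(RegInterval Interval) const {
    unsigned M = 0;
    for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo)
      M = std::max(M, Scores[RegNo]);
    return M;
  }

  // The single-register form delegates, like the determineWait wrapper above.
  unsigned maxScore(int RegNo) const { return maxScore({RegNo, RegNo + 1}); }
};

int main() {
  Tracker T;
  T.Scores[3] = 7;
  std::printf("%u %u\n", T.maxScore(3), T.maxScore({2, 5})); // prints: 7 7
  return 0;
}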

void applyWaitcnt(const AMDGPU::Waitcnt &Wait);
void applyWaitcnt(InstCounterType T, unsigned Count);
void updateByEvent(const SIInstrInfo *TII, const SIRegisterInfo *TRI,
@@ -345,16 +352,22 @@ class WaitcntBrackets {
LastFlat[DS_CNT] = ScoreUBs[DS_CNT];
}

// Return true if there might be pending writes to the specified vgpr by VMEM
// Return true if there might be pending writes to the given VGPR interval by VMEM
// instructions with types different from V.
bool hasOtherPendingVmemTypes(int GprNo, VmemType V) const {
assert(GprNo < NUM_ALL_VGPRS);
return VgprVmemTypes[GprNo] & ~(1 << V);
bool hasOtherPendingVmemTypes(RegInterval Interval, VmemType V) const {
for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
assert(RegNo < NUM_ALL_VGPRS);
if (VgprVmemTypes[RegNo] & ~(1 << V))
return true;
}
return false;
}

void clearVgprVmemTypes(int GprNo) {
assert(GprNo < NUM_ALL_VGPRS);
VgprVmemTypes[GprNo] = 0;
void clearVgprVmemTypes(RegInterval Interval) {
for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
assert(RegNo < NUM_ALL_VGPRS);
VgprVmemTypes[RegNo] = 0;
}
}
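For context on the state these two helpers walk: each VGPR slot keeps one byte in VgprVmemTypes, with bit V set while a VMEM access of type V may still be pending, so "any other type pending" reduces to mask & ~(1 << V). A hedged standalone sketch of the interval query; the enum values and array size are illustrative, not the pass's real constants:

#include <cstdint>
#include <cstdio>

enum VmemType { VMEM_NOSAMPLER, VMEM_SAMPLER, VMEM_BVH }; // illustrative values

constexpr int NUM_SLOTS = 8; // stand-in for NUM_ALL_VGPRS

uint8_t VgprVmemTypes[NUM_SLOTS] = {};

// True if any slot in [First, Second) has a pending access of a type other
// than V, i.e. any bit besides (1 << V) is set in its mask.
bool hasOtherPendingVmemTypes(int First, int Second, VmemType V) {
  for (int RegNo = First; RegNo < Second; ++RegNo)
    if (VgprVmemTypes[RegNo] & ~(1u << V))
      return true;
  return false;
}

int main() {
  VgprVmemTypes[2] = 1u << VMEM_SAMPLER; // a sampler access pending on slot 2
  std::printf("%d\n", hasOtherPendingVmemTypes(0, 4, VMEM_NOSAMPLER)); // 1
  std::printf("%d\n", hasOtherPendingVmemTypes(0, 4, VMEM_SAMPLER));   // 0
  return 0;
}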

void setStateOnFunctionEntryOrReturn() {
@@ -396,19 +409,16 @@
}

void setRegScore(int GprNo, InstCounterType T, unsigned Val) {
if (GprNo < NUM_ALL_VGPRS) {
VgprUB = std::max(VgprUB, GprNo);
VgprScores[T][GprNo] = Val;
} else {
assert(T == SmemAccessCounter);
SgprUB = std::max(SgprUB, GprNo - NUM_ALL_VGPRS);
SgprScores[GprNo - NUM_ALL_VGPRS] = Val;
}
setScoreByInterval({GprNo, GprNo + 1}, T, Val);
}

void setExpScore(const MachineInstr *MI, const SIRegisterInfo *TRI,
const MachineRegisterInfo *MRI, const MachineOperand &Op,
unsigned Val);
void setScoreByInterval(RegInterval Interval, InstCounterType CntTy,
unsigned Score);

void setScoreByOperand(const MachineInstr *MI, const SIRegisterInfo *TRI,
const MachineRegisterInfo *MRI,
const MachineOperand &Op, InstCounterType CntTy,
unsigned Val);

const GCNSubtarget *ST = nullptr;
InstCounterType MaxCounter = NUM_EXTENDED_INST_CNTS;
@@ -772,17 +782,30 @@ RegInterval WaitcntBrackets::getRegInterval(const MachineInstr *MI,
return Result;
}

void WaitcntBrackets::setExpScore(const MachineInstr *MI,
const SIRegisterInfo *TRI,
const MachineRegisterInfo *MRI,
const MachineOperand &Op, unsigned Val) {
RegInterval Interval = getRegInterval(MI, MRI, TRI, Op);
assert(TRI->isVectorRegister(*MRI, Op.getReg()));
void WaitcntBrackets::setScoreByInterval(RegInterval Interval,
InstCounterType CntTy,
unsigned Score) {
for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
setRegScore(RegNo, EXP_CNT, Val);
if (RegNo < NUM_ALL_VGPRS) {
VgprUB = std::max(VgprUB, RegNo);
VgprScores[CntTy][RegNo] = Score;
} else {
assert(CntTy == SmemAccessCounter);
SgprUB = std::max(SgprUB, RegNo - NUM_ALL_VGPRS);
SgprScores[RegNo - NUM_ALL_VGPRS] = Score;
}
}
}

void WaitcntBrackets::setScoreByOperand(const MachineInstr *MI,
const SIRegisterInfo *TRI,
const MachineRegisterInfo *MRI,
const MachineOperand &Op,
InstCounterType CntTy, unsigned Score) {
RegInterval Interval = getRegInterval(MI, MRI, TRI, Op);
setScoreByInterval(Interval, CntTy, Score);
}
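These two helpers centralize the addressing scheme the old setRegScore carried inline: one flat index space in which slots below NUM_ALL_VGPRS are VGPRs and the rest are SGPRs stored at RegNo - NUM_ALL_VGPRS. A minimal sketch of that dispatch under toy sizes; the real constants differ and the SMEM-only assert is simplified to a comment:

#include <algorithm>
#include <cstdio>

constexpr int NUM_ALL_VGPRS = 4; // toy size; the real pass uses a larger constant

unsigned VgprScores[NUM_ALL_VGPRS] = {};
unsigned SgprScores[4] = {};
int VgprUB = -1, SgprUB = -1; // highest slot written, kept for faster merges

void setScoreByInterval(int First, int Second, unsigned Score) {
  for (int RegNo = First; RegNo < Second; ++RegNo) {
    if (RegNo < NUM_ALL_VGPRS) { // VGPR portion of the flat index space
      VgprUB = std::max(VgprUB, RegNo);
      VgprScores[RegNo] = Score;
    } else { // SGPR portion; the real pass also asserts an SMEM counter here
      SgprUB = std::max(SgprUB, RegNo - NUM_ALL_VGPRS);
      SgprScores[RegNo - NUM_ALL_VGPRS] = Score;
    }
  }
}

int main() {
  setScoreByInterval(3, 6, 42); // interval spanning the VGPR/SGPR boundary
  std::printf("%u %u %u\n", VgprScores[3], SgprScores[0],
              SgprScores[1]); // 42 42 42
  return 0;
}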

void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII,
const SIRegisterInfo *TRI,
const MachineRegisterInfo *MRI,
@@ -806,57 +829,61 @@ void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII,
// All GDS operations must protect their address register (same as
// export.)
if (const auto *AddrOp = TII->getNamedOperand(Inst, AMDGPU::OpName::addr))
setExpScore(&Inst, TRI, MRI, *AddrOp, CurrScore);
setScoreByOperand(&Inst, TRI, MRI, *AddrOp, EXP_CNT, CurrScore);

if (Inst.mayStore()) {
if (const auto *Data0 =
TII->getNamedOperand(Inst, AMDGPU::OpName::data0))
setExpScore(&Inst, TRI, MRI, *Data0, CurrScore);
setScoreByOperand(&Inst, TRI, MRI, *Data0, EXP_CNT, CurrScore);
if (const auto *Data1 =
TII->getNamedOperand(Inst, AMDGPU::OpName::data1))
setExpScore(&Inst, TRI, MRI, *Data1, CurrScore);
setScoreByOperand(&Inst, TRI, MRI, *Data1, EXP_CNT, CurrScore);
} else if (SIInstrInfo::isAtomicRet(Inst) && !SIInstrInfo::isGWS(Inst) &&
Inst.getOpcode() != AMDGPU::DS_APPEND &&
Inst.getOpcode() != AMDGPU::DS_CONSUME &&
Inst.getOpcode() != AMDGPU::DS_ORDERED_COUNT) {
for (const MachineOperand &Op : Inst.all_uses()) {
if (TRI->isVectorRegister(*MRI, Op.getReg()))
setExpScore(&Inst, TRI, MRI, Op, CurrScore);
setScoreByOperand(&Inst, TRI, MRI, Op, EXP_CNT, CurrScore);
}
}
} else if (TII->isFLAT(Inst)) {
if (Inst.mayStore()) {
setExpScore(&Inst, TRI, MRI,
*TII->getNamedOperand(Inst, AMDGPU::OpName::data),
CurrScore);
setScoreByOperand(&Inst, TRI, MRI,
*TII->getNamedOperand(Inst, AMDGPU::OpName::data),
EXP_CNT, CurrScore);
} else if (SIInstrInfo::isAtomicRet(Inst)) {
setExpScore(&Inst, TRI, MRI,
*TII->getNamedOperand(Inst, AMDGPU::OpName::data),
CurrScore);
setScoreByOperand(&Inst, TRI, MRI,
*TII->getNamedOperand(Inst, AMDGPU::OpName::data),
EXP_CNT, CurrScore);
}
} else if (TII->isMIMG(Inst)) {
if (Inst.mayStore()) {
setExpScore(&Inst, TRI, MRI, Inst.getOperand(0), CurrScore);
setScoreByOperand(&Inst, TRI, MRI, Inst.getOperand(0), EXP_CNT,
CurrScore);
} else if (SIInstrInfo::isAtomicRet(Inst)) {
setExpScore(&Inst, TRI, MRI,
*TII->getNamedOperand(Inst, AMDGPU::OpName::data),
CurrScore);
setScoreByOperand(&Inst, TRI, MRI,
*TII->getNamedOperand(Inst, AMDGPU::OpName::data),
EXP_CNT, CurrScore);
}
} else if (TII->isMTBUF(Inst)) {
if (Inst.mayStore())
setExpScore(&Inst, TRI, MRI, Inst.getOperand(0), CurrScore);
setScoreByOperand(&Inst, TRI, MRI, Inst.getOperand(0), EXP_CNT,
CurrScore);
} else if (TII->isMUBUF(Inst)) {
if (Inst.mayStore()) {
setExpScore(&Inst, TRI, MRI, Inst.getOperand(0), CurrScore);
setScoreByOperand(&Inst, TRI, MRI, Inst.getOperand(0), EXP_CNT,
CurrScore);
} else if (SIInstrInfo::isAtomicRet(Inst)) {
setExpScore(&Inst, TRI, MRI,
*TII->getNamedOperand(Inst, AMDGPU::OpName::data),
CurrScore);
setScoreByOperand(&Inst, TRI, MRI,
*TII->getNamedOperand(Inst, AMDGPU::OpName::data),
EXP_CNT, CurrScore);
}
} else if (TII->isLDSDIR(Inst)) {
// LDSDIR instructions attach the score to the destination.
setExpScore(&Inst, TRI, MRI,
*TII->getNamedOperand(Inst, AMDGPU::OpName::vdst), CurrScore);
setScoreByOperand(&Inst, TRI, MRI,
*TII->getNamedOperand(Inst, AMDGPU::OpName::vdst),
EXP_CNT, CurrScore);
} else {
if (TII->isEXP(Inst)) {
// For export the destination registers are really temps that
@@ -865,15 +892,13 @@
// score.
for (MachineOperand &DefMO : Inst.all_defs()) {
if (TRI->isVGPR(*MRI, DefMO.getReg())) {
setRegScore(
TRI->getEncodingValue(AMDGPU::getMCReg(DefMO.getReg(), *ST)),
EXP_CNT, CurrScore);
setScoreByOperand(&Inst, TRI, MRI, DefMO, EXP_CNT, CurrScore);
}
}
}
for (const MachineOperand &Op : Inst.all_uses()) {
if (TRI->isVectorRegister(*MRI, Op.getReg()))
setExpScore(&Inst, TRI, MRI, Op, CurrScore);
setScoreByOperand(&Inst, TRI, MRI, Op, EXP_CNT, CurrScore);
}
}
} else /* LGKM_CNT || EXP_CNT || VS_CNT || NUM_INST_CNTS */ {
@@ -901,9 +926,7 @@ void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII,
VgprVmemTypes[RegNo] |= 1 << V;
}
}
for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
setRegScore(RegNo, T, CurrScore);
}
setScoreByInterval(Interval, T, CurrScore);
}
if (Inst.mayStore() &&
(TII->isDS(Inst) || TII->mayWriteLDSThroughDMA(Inst))) {
@@ -1034,31 +1057,34 @@ void WaitcntBrackets::simplifyWaitcnt(InstCounterType T,
Count = ~0u;
}

void WaitcntBrackets::determineWait(InstCounterType T, int RegNo,
void WaitcntBrackets::determineWait(InstCounterType T, RegInterval Interval,
AMDGPU::Waitcnt &Wait) const {
unsigned ScoreToWait = getRegScore(RegNo, T);

// If the score of src_operand falls within the bracket, we need an
// s_waitcnt instruction.
const unsigned LB = getScoreLB(T);
const unsigned UB = getScoreUB(T);
if ((UB >= ScoreToWait) && (ScoreToWait > LB)) {
if ((T == LOAD_CNT || T == DS_CNT) && hasPendingFlat() &&
!ST->hasFlatLgkmVMemCountInOrder()) {
// If there is a pending FLAT operation, and this is a VMem or LGKM
// waitcnt and the target can report early completion, then we need
// to force a waitcnt 0.
addWait(Wait, T, 0);
} else if (counterOutOfOrder(T)) {
// Counter can get decremented out-of-order when there
// are multiple event types in the bracket. Also emit an s_wait counter
// with a conservative value of 0 for the counter.
addWait(Wait, T, 0);
} else {
// If a counter has been maxed out avoid overflow by waiting for
// MAX(CounterType) - 1 instead.
unsigned NeededWait = std::min(UB - ScoreToWait, getWaitCountMax(T) - 1);
addWait(Wait, T, NeededWait);
for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
unsigned ScoreToWait = getRegScore(RegNo, T);

// If the score of src_operand falls within the bracket, we need an
// s_waitcnt instruction.
if ((UB >= ScoreToWait) && (ScoreToWait > LB)) {
if ((T == LOAD_CNT || T == DS_CNT) && hasPendingFlat() &&
!ST->hasFlatLgkmVMemCountInOrder()) {
// If there is a pending FLAT operation, and this is a VMem or LGKM
// waitcnt and the target can report early completion, then we need
// to force a waitcnt 0.
addWait(Wait, T, 0);
} else if (counterOutOfOrder(T)) {
// Counter can get decremented out-of-order when there
// are multiple event types in the bracket. Also emit an s_wait counter
// with a conservative value of 0 for the counter.
addWait(Wait, T, 0);
} else {
// If a counter has been maxed out avoid overflow by waiting for
// MAX(CounterType) - 1 instead.
unsigned NeededWait =
std::min(UB - ScoreToWait, getWaitCountMax(T) - 1);
addWait(Wait, T, NeededWait);
}
}
}
}
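For reference while reading the new loop: a wait is needed only when LB < ScoreToWait <= UB, and UB - ScoreToWait is how many events were issued after the one that produced the score, so waiting until at most that many remain guarantees completion; the clamp keeps the immediate encodable. A self-contained numeric sketch of just this arithmetic, with invented constants and the FLAT and out-of-order special cases omitted:

#include <algorithm>
#include <cstdio>

// Toy reproduction of the bracket arithmetic in determineWait. Scores are
// monotonically increasing timestamps handed out as events are issued.
unsigned neededWait(unsigned LB, unsigned UB, unsigned ScoreToWait,
                    unsigned WaitCountMax) {
  if (UB >= ScoreToWait && ScoreToWait > LB)
    // Wait until the producing event has retired; clamp so the encoded
    // immediate cannot overflow the counter field.
    return std::min(UB - ScoreToWait, WaitCountMax - 1);
  return ~0u; // score outside the bracket: no wait required
}

int main() {
  // Bracket [LB=2, UB=9], register written at score 7: at most 9 - 7 = 2
  // later events may still be outstanding ahead of it.
  std::printf("%u\n", neededWait(2, 9, 7, 16)); // 2
  // A score at LB is already known to have completed.
  std::printf("%u\n", neededWait(2, 9, 2, 16)); // 4294967295 (no wait)
  return 0;
}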
@@ -1670,18 +1696,16 @@ bool SIInsertWaitcnts::generateWaitcntInstBefore(MachineInstr &MI,
RegInterval CallAddrOpInterval =
ScoreBrackets.getRegInterval(&MI, MRI, TRI, CallAddrOp);

for (int RegNo = CallAddrOpInterval.first;
RegNo < CallAddrOpInterval.second; ++RegNo)
ScoreBrackets.determineWait(SmemAccessCounter, RegNo, Wait);
ScoreBrackets.determineWait(SmemAccessCounter, CallAddrOpInterval,
Wait);

if (const auto *RtnAddrOp =
TII->getNamedOperand(MI, AMDGPU::OpName::dst)) {
RegInterval RtnAddrOpInterval =
ScoreBrackets.getRegInterval(&MI, MRI, TRI, *RtnAddrOp);

for (int RegNo = RtnAddrOpInterval.first;
RegNo < RtnAddrOpInterval.second; ++RegNo)
ScoreBrackets.determineWait(SmemAccessCounter, RegNo, Wait);
ScoreBrackets.determineWait(SmemAccessCounter, RtnAddrOpInterval,
Wait);
}
}
} else {
@@ -1750,36 +1774,34 @@ bool SIInsertWaitcnts::generateWaitcntInstBefore(MachineInstr &MI,
RegInterval Interval = ScoreBrackets.getRegInterval(&MI, MRI, TRI, Op);

const bool IsVGPR = TRI->isVectorRegister(*MRI, Op.getReg());
for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
if (IsVGPR) {
// Implicit VGPR defs and uses are never a part of the memory
// instruction's description and are usually present to account for
// super-register liveness.
// TODO: Most of the other instructions also have implicit uses
// for the liveness accounting only.
if (Op.isImplicit() && MI.mayLoadOrStore())
continue;

// RAW always needs an s_waitcnt. WAW needs an s_waitcnt unless the
// previous write and this write are the same type of VMEM
// instruction, in which case they are (in some architectures)
// guaranteed to write their results in order anyway.
if (Op.isUse() || !updateVMCntOnly(MI) ||
ScoreBrackets.hasOtherPendingVmemTypes(RegNo,
getVmemType(MI)) ||
!ST->hasVmemWriteVgprInOrder()) {
ScoreBrackets.determineWait(LOAD_CNT, RegNo, Wait);
ScoreBrackets.determineWait(SAMPLE_CNT, RegNo, Wait);
ScoreBrackets.determineWait(BVH_CNT, RegNo, Wait);
ScoreBrackets.clearVgprVmemTypes(RegNo);
}
if (Op.isDef() || ScoreBrackets.hasPendingEvent(EXP_LDS_ACCESS)) {
ScoreBrackets.determineWait(EXP_CNT, RegNo, Wait);
}
ScoreBrackets.determineWait(DS_CNT, RegNo, Wait);
} else {
ScoreBrackets.determineWait(SmemAccessCounter, RegNo, Wait);
if (IsVGPR) {
// Implicit VGPR defs and uses are never a part of the memory
// instruction's description and are usually present to account for
// super-register liveness.
// TODO: Most of the other instructions also have implicit uses
// for the liveness accounting only.
if (Op.isImplicit() && MI.mayLoadOrStore())
continue;

// RAW always needs an s_waitcnt. WAW needs an s_waitcnt unless the
// previous write and this write are the same type of VMEM
// instruction, in which case they are (in some architectures)
// guaranteed to write their results in order anyway.
if (Op.isUse() || !updateVMCntOnly(MI) ||
ScoreBrackets.hasOtherPendingVmemTypes(Interval,
getVmemType(MI)) ||
!ST->hasVmemWriteVgprInOrder()) {
ScoreBrackets.determineWait(LOAD_CNT, Interval, Wait);
ScoreBrackets.determineWait(SAMPLE_CNT, Interval, Wait);
ScoreBrackets.determineWait(BVH_CNT, Interval, Wait);
ScoreBrackets.clearVgprVmemTypes(Interval);
}
if (Op.isDef() || ScoreBrackets.hasPendingEvent(EXP_LDS_ACCESS)) {
ScoreBrackets.determineWait(EXP_CNT, Interval, Wait);
}
ScoreBrackets.determineWait(DS_CNT, Interval, Wait);
} else {
ScoreBrackets.determineWait(SmemAccessCounter, Interval, Wait);
}
}
}
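The rewritten guard keeps the policy stated in the RAW/WAW comment: reads always wait, and a write may skip the wait only when every pending write on the interval is the same VMEM type and the target guarantees in-order VGPR writes. A condensed sketch of that predicate; the parameter names are stand-ins for the real checks (Op.isUse(), updateVMCntOnly, hasOtherPendingVmemTypes, hasVmemWriteVgprInOrder):

#include <cstdio>

// Condensed form of the guard on the LOAD_CNT/SAMPLE_CNT/BVH_CNT waits.
bool needsVmemWait(bool IsUse, bool UpdatesVmCntOnly, bool OtherTypesPending,
                   bool HwWritesVgprInOrder) {
  // RAW always waits; WAW may be skipped only when all conditions line up.
  return IsUse || !UpdatesVmCntOnly || OtherTypesPending ||
         !HwWritesVgprInOrder;
}

int main() {
  std::printf("%d\n", needsVmemWait(false, true, false, true)); // 0: WAW skip
  std::printf("%d\n", needsVmemWait(true, true, false, true));  // 1: RAW waits
  return 0;
}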