Commit c66dee4

[AMDGPU] Refactor several functions for merging with downstream work. (#110562)
For setScore, the root function is setScoreByInterval with RegInterval input. For determineWait, the root function is determineWait with RegInterval input.
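Both refactors follow the same delegation pattern: the single-register entry point wraps its register number in a one-element half-open interval and forwards to the interval-based root function. Below is a minimal, self-contained sketch of that pattern; it uses a simplified RegInterval alias and a hypothetical ScoreTracker stand-in, not the actual WaitcntBrackets class.

#include <cassert>
#include <utility>

// Simplified stand-in: the real pass defines RegInterval as a half-open
// {first, second} range of register indices.
using RegInterval = std::pair<int, int>;

struct ScoreTracker {
  static constexpr int NumRegs = 16;
  unsigned Scores[NumRegs] = {};

  // Root function: applies the score to every register in the interval.
  void setScoreByInterval(RegInterval Interval, unsigned Score) {
    for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
      assert(RegNo < NumRegs && "register index out of range");
      Scores[RegNo] = Score;
    }
  }

  // Single-register wrapper delegates by forming a one-element interval,
  // mirroring setRegScore -> setScoreByInterval in this commit.
  void setRegScore(int RegNo, unsigned Score) {
    setScoreByInterval({RegNo, RegNo + 1}, Score);
  }
};

The determineWait pair in the diff has the same shape: the int RegNo overload simply calls the RegInterval overload with {RegNo, RegNo + 1}.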
1 parent 9f6f6af commit c66dee4

llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp

Lines changed: 133 additions & 111 deletions
@@ -310,7 +310,14 @@ class WaitcntBrackets {
   bool counterOutOfOrder(InstCounterType T) const;
   void simplifyWaitcnt(AMDGPU::Waitcnt &Wait) const;
   void simplifyWaitcnt(InstCounterType T, unsigned &Count) const;
-  void determineWait(InstCounterType T, int RegNo, AMDGPU::Waitcnt &Wait) const;
+
+  void determineWait(InstCounterType T, RegInterval Interval,
+                     AMDGPU::Waitcnt &Wait) const;
+  void determineWait(InstCounterType T, int RegNo,
+                     AMDGPU::Waitcnt &Wait) const {
+    determineWait(T, {RegNo, RegNo + 1}, Wait);
+  }
+
   void applyWaitcnt(const AMDGPU::Waitcnt &Wait);
   void applyWaitcnt(InstCounterType T, unsigned Count);
   void updateByEvent(const SIInstrInfo *TII, const SIRegisterInfo *TRI,
@@ -345,16 +352,22 @@ class WaitcntBrackets {
     LastFlat[DS_CNT] = ScoreUBs[DS_CNT];
   }
 
-  // Return true if there might be pending writes to the specified vgpr by VMEM
+  // Return true if there might be pending writes to the vgpr-interval by VMEM
   // instructions with types different from V.
-  bool hasOtherPendingVmemTypes(int GprNo, VmemType V) const {
-    assert(GprNo < NUM_ALL_VGPRS);
-    return VgprVmemTypes[GprNo] & ~(1 << V);
+  bool hasOtherPendingVmemTypes(RegInterval Interval, VmemType V) const {
+    for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
+      assert(RegNo < NUM_ALL_VGPRS);
+      if (VgprVmemTypes[RegNo] & ~(1 << V))
+        return true;
+    }
+    return false;
   }
 
-  void clearVgprVmemTypes(int GprNo) {
-    assert(GprNo < NUM_ALL_VGPRS);
-    VgprVmemTypes[GprNo] = 0;
+  void clearVgprVmemTypes(RegInterval Interval) {
+    for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
+      assert(RegNo < NUM_ALL_VGPRS);
+      VgprVmemTypes[RegNo] = 0;
+    }
   }
 
   void setStateOnFunctionEntryOrReturn() {
@@ -396,19 +409,16 @@ class WaitcntBrackets {
   }
 
   void setRegScore(int GprNo, InstCounterType T, unsigned Val) {
-    if (GprNo < NUM_ALL_VGPRS) {
-      VgprUB = std::max(VgprUB, GprNo);
-      VgprScores[T][GprNo] = Val;
-    } else {
-      assert(T == SmemAccessCounter);
-      SgprUB = std::max(SgprUB, GprNo - NUM_ALL_VGPRS);
-      SgprScores[GprNo - NUM_ALL_VGPRS] = Val;
-    }
+    setScoreByInterval({GprNo, GprNo + 1}, T, Val);
   }
 
-  void setExpScore(const MachineInstr *MI, const SIRegisterInfo *TRI,
-                   const MachineRegisterInfo *MRI, const MachineOperand &Op,
-                   unsigned Val);
+  void setScoreByInterval(RegInterval Interval, InstCounterType CntTy,
+                          unsigned Score);
+
+  void setScoreByOperand(const MachineInstr *MI, const SIRegisterInfo *TRI,
+                         const MachineRegisterInfo *MRI,
+                         const MachineOperand &Op, InstCounterType CntTy,
+                         unsigned Val);
 
   const GCNSubtarget *ST = nullptr;
   InstCounterType MaxCounter = NUM_EXTENDED_INST_CNTS;
@@ -772,17 +782,30 @@ RegInterval WaitcntBrackets::getRegInterval(const MachineInstr *MI,
   return Result;
 }
 
-void WaitcntBrackets::setExpScore(const MachineInstr *MI,
-                                  const SIRegisterInfo *TRI,
-                                  const MachineRegisterInfo *MRI,
-                                  const MachineOperand &Op, unsigned Val) {
-  RegInterval Interval = getRegInterval(MI, MRI, TRI, Op);
-  assert(TRI->isVectorRegister(*MRI, Op.getReg()));
+void WaitcntBrackets::setScoreByInterval(RegInterval Interval,
+                                         InstCounterType CntTy,
+                                         unsigned Score) {
   for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
-    setRegScore(RegNo, EXP_CNT, Val);
+    if (RegNo < NUM_ALL_VGPRS) {
+      VgprUB = std::max(VgprUB, RegNo);
+      VgprScores[CntTy][RegNo] = Score;
+    } else {
+      assert(CntTy == SmemAccessCounter);
+      SgprUB = std::max(SgprUB, RegNo - NUM_ALL_VGPRS);
+      SgprScores[RegNo - NUM_ALL_VGPRS] = Score;
+    }
   }
 }
 
+void WaitcntBrackets::setScoreByOperand(const MachineInstr *MI,
+                                        const SIRegisterInfo *TRI,
+                                        const MachineRegisterInfo *MRI,
+                                        const MachineOperand &Op,
+                                        InstCounterType CntTy, unsigned Score) {
+  RegInterval Interval = getRegInterval(MI, MRI, TRI, Op);
+  setScoreByInterval(Interval, CntTy, Score);
+}
+
 void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII,
                                     const SIRegisterInfo *TRI,
                                     const MachineRegisterInfo *MRI,
@@ -806,57 +829,61 @@ void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII,
       // All GDS operations must protect their address register (same as
       // export.)
       if (const auto *AddrOp = TII->getNamedOperand(Inst, AMDGPU::OpName::addr))
-        setExpScore(&Inst, TRI, MRI, *AddrOp, CurrScore);
+        setScoreByOperand(&Inst, TRI, MRI, *AddrOp, EXP_CNT, CurrScore);
 
       if (Inst.mayStore()) {
         if (const auto *Data0 =
                 TII->getNamedOperand(Inst, AMDGPU::OpName::data0))
-          setExpScore(&Inst, TRI, MRI, *Data0, CurrScore);
+          setScoreByOperand(&Inst, TRI, MRI, *Data0, EXP_CNT, CurrScore);
         if (const auto *Data1 =
                 TII->getNamedOperand(Inst, AMDGPU::OpName::data1))
-          setExpScore(&Inst, TRI, MRI, *Data1, CurrScore);
+          setScoreByOperand(&Inst, TRI, MRI, *Data1, EXP_CNT, CurrScore);
       } else if (SIInstrInfo::isAtomicRet(Inst) && !SIInstrInfo::isGWS(Inst) &&
                  Inst.getOpcode() != AMDGPU::DS_APPEND &&
                  Inst.getOpcode() != AMDGPU::DS_CONSUME &&
                  Inst.getOpcode() != AMDGPU::DS_ORDERED_COUNT) {
         for (const MachineOperand &Op : Inst.all_uses()) {
           if (TRI->isVectorRegister(*MRI, Op.getReg()))
-            setExpScore(&Inst, TRI, MRI, Op, CurrScore);
+            setScoreByOperand(&Inst, TRI, MRI, Op, EXP_CNT, CurrScore);
         }
       }
     } else if (TII->isFLAT(Inst)) {
       if (Inst.mayStore()) {
-        setExpScore(&Inst, TRI, MRI,
-                    *TII->getNamedOperand(Inst, AMDGPU::OpName::data),
-                    CurrScore);
+        setScoreByOperand(&Inst, TRI, MRI,
+                          *TII->getNamedOperand(Inst, AMDGPU::OpName::data),
+                          EXP_CNT, CurrScore);
       } else if (SIInstrInfo::isAtomicRet(Inst)) {
-        setExpScore(&Inst, TRI, MRI,
-                    *TII->getNamedOperand(Inst, AMDGPU::OpName::data),
-                    CurrScore);
+        setScoreByOperand(&Inst, TRI, MRI,
+                          *TII->getNamedOperand(Inst, AMDGPU::OpName::data),
+                          EXP_CNT, CurrScore);
       }
     } else if (TII->isMIMG(Inst)) {
       if (Inst.mayStore()) {
-        setExpScore(&Inst, TRI, MRI, Inst.getOperand(0), CurrScore);
+        setScoreByOperand(&Inst, TRI, MRI, Inst.getOperand(0), EXP_CNT,
+                          CurrScore);
       } else if (SIInstrInfo::isAtomicRet(Inst)) {
-        setExpScore(&Inst, TRI, MRI,
-                    *TII->getNamedOperand(Inst, AMDGPU::OpName::data),
-                    CurrScore);
+        setScoreByOperand(&Inst, TRI, MRI,
+                          *TII->getNamedOperand(Inst, AMDGPU::OpName::data),
+                          EXP_CNT, CurrScore);
       }
     } else if (TII->isMTBUF(Inst)) {
       if (Inst.mayStore())
-        setExpScore(&Inst, TRI, MRI, Inst.getOperand(0), CurrScore);
+        setScoreByOperand(&Inst, TRI, MRI, Inst.getOperand(0), EXP_CNT,
+                          CurrScore);
     } else if (TII->isMUBUF(Inst)) {
       if (Inst.mayStore()) {
-        setExpScore(&Inst, TRI, MRI, Inst.getOperand(0), CurrScore);
+        setScoreByOperand(&Inst, TRI, MRI, Inst.getOperand(0), EXP_CNT,
+                          CurrScore);
       } else if (SIInstrInfo::isAtomicRet(Inst)) {
-        setExpScore(&Inst, TRI, MRI,
-                    *TII->getNamedOperand(Inst, AMDGPU::OpName::data),
-                    CurrScore);
+        setScoreByOperand(&Inst, TRI, MRI,
+                          *TII->getNamedOperand(Inst, AMDGPU::OpName::data),
+                          EXP_CNT, CurrScore);
       }
     } else if (TII->isLDSDIR(Inst)) {
       // LDSDIR instructions attach the score to the destination.
-      setExpScore(&Inst, TRI, MRI,
-                  *TII->getNamedOperand(Inst, AMDGPU::OpName::vdst), CurrScore);
+      setScoreByOperand(&Inst, TRI, MRI,
+                        *TII->getNamedOperand(Inst, AMDGPU::OpName::vdst),
+                        EXP_CNT, CurrScore);
     } else {
       if (TII->isEXP(Inst)) {
         // For export the destination registers are really temps that
@@ -865,15 +892,13 @@ void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII,
         // score.
         for (MachineOperand &DefMO : Inst.all_defs()) {
           if (TRI->isVGPR(*MRI, DefMO.getReg())) {
-            setRegScore(
-                TRI->getEncodingValue(AMDGPU::getMCReg(DefMO.getReg(), *ST)),
-                EXP_CNT, CurrScore);
+            setScoreByOperand(&Inst, TRI, MRI, DefMO, EXP_CNT, CurrScore);
           }
         }
       }
       for (const MachineOperand &Op : Inst.all_uses()) {
         if (TRI->isVectorRegister(*MRI, Op.getReg()))
-          setExpScore(&Inst, TRI, MRI, Op, CurrScore);
+          setScoreByOperand(&Inst, TRI, MRI, Op, EXP_CNT, CurrScore);
       }
     }
   } else /* LGKM_CNT || EXP_CNT || VS_CNT || NUM_INST_CNTS */ {
@@ -901,9 +926,7 @@ void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII,
             VgprVmemTypes[RegNo] |= 1 << V;
         }
       }
-      for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
-        setRegScore(RegNo, T, CurrScore);
-      }
+      setScoreByInterval(Interval, T, CurrScore);
     }
     if (Inst.mayStore() &&
         (TII->isDS(Inst) || TII->mayWriteLDSThroughDMA(Inst)))
@@ -1034,31 +1057,34 @@ void WaitcntBrackets::simplifyWaitcnt(InstCounterType T,
   Count = ~0u;
 }
 
-void WaitcntBrackets::determineWait(InstCounterType T, int RegNo,
+void WaitcntBrackets::determineWait(InstCounterType T, RegInterval Interval,
                                     AMDGPU::Waitcnt &Wait) const {
-  unsigned ScoreToWait = getRegScore(RegNo, T);
-
-  // If the score of src_operand falls within the bracket, we need an
-  // s_waitcnt instruction.
   const unsigned LB = getScoreLB(T);
   const unsigned UB = getScoreUB(T);
-  if ((UB >= ScoreToWait) && (ScoreToWait > LB)) {
-    if ((T == LOAD_CNT || T == DS_CNT) && hasPendingFlat() &&
-        !ST->hasFlatLgkmVMemCountInOrder()) {
-      // If there is a pending FLAT operation, and this is a VMem or LGKM
-      // waitcnt and the target can report early completion, then we need
-      // to force a waitcnt 0.
-      addWait(Wait, T, 0);
-    } else if (counterOutOfOrder(T)) {
-      // Counter can get decremented out-of-order when there
-      // are multiple types event in the bracket. Also emit an s_wait counter
-      // with a conservative value of 0 for the counter.
-      addWait(Wait, T, 0);
-    } else {
-      // If a counter has been maxed out avoid overflow by waiting for
-      // MAX(CounterType) - 1 instead.
-      unsigned NeededWait = std::min(UB - ScoreToWait, getWaitCountMax(T) - 1);
-      addWait(Wait, T, NeededWait);
+  for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
+    unsigned ScoreToWait = getRegScore(RegNo, T);
+
+    // If the score of src_operand falls within the bracket, we need an
+    // s_waitcnt instruction.
+    if ((UB >= ScoreToWait) && (ScoreToWait > LB)) {
+      if ((T == LOAD_CNT || T == DS_CNT) && hasPendingFlat() &&
+          !ST->hasFlatLgkmVMemCountInOrder()) {
+        // If there is a pending FLAT operation, and this is a VMem or LGKM
+        // waitcnt and the target can report early completion, then we need
+        // to force a waitcnt 0.
+        addWait(Wait, T, 0);
+      } else if (counterOutOfOrder(T)) {
+        // Counter can get decremented out-of-order when there
+        // are multiple types event in the bracket. Also emit an s_wait counter
+        // with a conservative value of 0 for the counter.
+        addWait(Wait, T, 0);
+      } else {
+        // If a counter has been maxed out avoid overflow by waiting for
+        // MAX(CounterType) - 1 instead.
+        unsigned NeededWait =
+            std::min(UB - ScoreToWait, getWaitCountMax(T) - 1);
+        addWait(Wait, T, NeededWait);
+      }
     }
   }
 }
@@ -1670,18 +1696,16 @@ bool SIInsertWaitcnts::generateWaitcntInstBefore(MachineInstr &MI,
         RegInterval CallAddrOpInterval =
             ScoreBrackets.getRegInterval(&MI, MRI, TRI, CallAddrOp);
 
-        for (int RegNo = CallAddrOpInterval.first;
-             RegNo < CallAddrOpInterval.second; ++RegNo)
-          ScoreBrackets.determineWait(SmemAccessCounter, RegNo, Wait);
+        ScoreBrackets.determineWait(SmemAccessCounter, CallAddrOpInterval,
+                                    Wait);
 
         if (const auto *RtnAddrOp =
                 TII->getNamedOperand(MI, AMDGPU::OpName::dst)) {
           RegInterval RtnAddrOpInterval =
               ScoreBrackets.getRegInterval(&MI, MRI, TRI, *RtnAddrOp);
 
-          for (int RegNo = RtnAddrOpInterval.first;
-               RegNo < RtnAddrOpInterval.second; ++RegNo)
-            ScoreBrackets.determineWait(SmemAccessCounter, RegNo, Wait);
+          ScoreBrackets.determineWait(SmemAccessCounter, RtnAddrOpInterval,
+                                      Wait);
         }
       }
     } else {
@@ -1750,36 +1774,34 @@ bool SIInsertWaitcnts::generateWaitcntInstBefore(MachineInstr &MI,
       RegInterval Interval = ScoreBrackets.getRegInterval(&MI, MRI, TRI, Op);
 
       const bool IsVGPR = TRI->isVectorRegister(*MRI, Op.getReg());
-      for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
-        if (IsVGPR) {
-          // Implicit VGPR defs and uses are never a part of the memory
-          // instructions description and usually present to account for
-          // super-register liveness.
-          // TODO: Most of the other instructions also have implicit uses
-          // for the liveness accounting only.
-          if (Op.isImplicit() && MI.mayLoadOrStore())
-            continue;
-
-          // RAW always needs an s_waitcnt. WAW needs an s_waitcnt unless the
-          // previous write and this write are the same type of VMEM
-          // instruction, in which case they are (in some architectures)
-          // guaranteed to write their results in order anyway.
-          if (Op.isUse() || !updateVMCntOnly(MI) ||
-              ScoreBrackets.hasOtherPendingVmemTypes(RegNo,
-                                                     getVmemType(MI)) ||
-              !ST->hasVmemWriteVgprInOrder()) {
-            ScoreBrackets.determineWait(LOAD_CNT, RegNo, Wait);
-            ScoreBrackets.determineWait(SAMPLE_CNT, RegNo, Wait);
-            ScoreBrackets.determineWait(BVH_CNT, RegNo, Wait);
-            ScoreBrackets.clearVgprVmemTypes(RegNo);
-          }
-          if (Op.isDef() || ScoreBrackets.hasPendingEvent(EXP_LDS_ACCESS)) {
-            ScoreBrackets.determineWait(EXP_CNT, RegNo, Wait);
-          }
-          ScoreBrackets.determineWait(DS_CNT, RegNo, Wait);
-        } else {
-          ScoreBrackets.determineWait(SmemAccessCounter, RegNo, Wait);
+      if (IsVGPR) {
+        // Implicit VGPR defs and uses are never a part of the memory
+        // instructions description and usually present to account for
+        // super-register liveness.
+        // TODO: Most of the other instructions also have implicit uses
+        // for the liveness accounting only.
+        if (Op.isImplicit() && MI.mayLoadOrStore())
+          continue;
+
+        // RAW always needs an s_waitcnt. WAW needs an s_waitcnt unless the
+        // previous write and this write are the same type of VMEM
+        // instruction, in which case they are (in some architectures)
+        // guaranteed to write their results in order anyway.
+        if (Op.isUse() || !updateVMCntOnly(MI) ||
+            ScoreBrackets.hasOtherPendingVmemTypes(Interval,
+                                                   getVmemType(MI)) ||
+            !ST->hasVmemWriteVgprInOrder()) {
+          ScoreBrackets.determineWait(LOAD_CNT, Interval, Wait);
+          ScoreBrackets.determineWait(SAMPLE_CNT, Interval, Wait);
+          ScoreBrackets.determineWait(BVH_CNT, Interval, Wait);
+          ScoreBrackets.clearVgprVmemTypes(Interval);
+        }
+        if (Op.isDef() || ScoreBrackets.hasPendingEvent(EXP_LDS_ACCESS)) {
+          ScoreBrackets.determineWait(EXP_CNT, Interval, Wait);
         }
+        ScoreBrackets.determineWait(DS_CNT, Interval, Wait);
+      } else {
+        ScoreBrackets.determineWait(SmemAccessCounter, Interval, Wait);
       }
     }
