Skip to content

[AMDGPU] Refactor several functions for merging with downstream work. #110562

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 1, 2024

Conversation

cmc-rep
Copy link
Contributor

@cmc-rep cmc-rep commented Sep 30, 2024

For setScore, the root function is setScoreByInterval with RegInterval input
For determineWait, the root function is determineWait with RegInterval input

@llvmbot
Copy link
Member

llvmbot commented Sep 30, 2024

@llvm/pr-subscribers-backend-amdgpu

Author: Gang Chen (cmc-rep)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/110562.diff

1 Files Affected:

  • (modified) llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp (+130-111)
diff --git a/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp b/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
index 80a7529002ac90..ba567f863a7e75 100644
--- a/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
+++ b/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
@@ -310,7 +310,14 @@ class WaitcntBrackets {
   bool counterOutOfOrder(InstCounterType T) const;
   void simplifyWaitcnt(AMDGPU::Waitcnt &Wait) const;
   void simplifyWaitcnt(InstCounterType T, unsigned &Count) const;
-  void determineWait(InstCounterType T, int RegNo, AMDGPU::Waitcnt &Wait) const;
+
+  void determineWait(InstCounterType T, RegInterval Interval,
+                     AMDGPU::Waitcnt &Wait) const;
+  void determineWait(InstCounterType T, int RegNo,
+                     AMDGPU::Waitcnt &Wait) const {
+    determineWait(T, {RegNo, RegNo + 1}, Wait);
+  }
+
   void applyWaitcnt(const AMDGPU::Waitcnt &Wait);
   void applyWaitcnt(InstCounterType T, unsigned Count);
   void updateByEvent(const SIInstrInfo *TII, const SIRegisterInfo *TRI,
@@ -345,16 +352,22 @@ class WaitcntBrackets {
     LastFlat[DS_CNT] = ScoreUBs[DS_CNT];
   }
 
-  // Return true if there might be pending writes to the specified vgpr by VMEM
+  // Return true if there might be pending writes to the vgpr-interval by VMEM
   // instructions with types different from V.
-  bool hasOtherPendingVmemTypes(int GprNo, VmemType V) const {
-    assert(GprNo < NUM_ALL_VGPRS);
-    return VgprVmemTypes[GprNo] & ~(1 << V);
+  bool hasOtherPendingVmemTypes(RegInterval Interval, VmemType V) const {
+    for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
+      assert(RegNo < NUM_ALL_VGPRS);
+      if (VgprVmemTypes[RegNo] & ~(1 << V))
+        return true;
+    }
+    return false;
   }
 
-  void clearVgprVmemTypes(int GprNo) {
-    assert(GprNo < NUM_ALL_VGPRS);
-    VgprVmemTypes[GprNo] = 0;
+  void clearVgprVmemTypes(RegInterval Interval) {
+    for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
+      assert(RegNo < NUM_ALL_VGPRS);
+      VgprVmemTypes[RegNo] = 0;
+    }
   }
 
   void setStateOnFunctionEntryOrReturn() {
@@ -396,19 +409,16 @@ class WaitcntBrackets {
   }
 
   void setRegScore(int GprNo, InstCounterType T, unsigned Val) {
-    if (GprNo < NUM_ALL_VGPRS) {
-      VgprUB = std::max(VgprUB, GprNo);
-      VgprScores[T][GprNo] = Val;
-    } else {
-      assert(T == SmemAccessCounter);
-      SgprUB = std::max(SgprUB, GprNo - NUM_ALL_VGPRS);
-      SgprScores[GprNo - NUM_ALL_VGPRS] = Val;
-    }
+    setScoreByInterval({GprNo, GprNo + 1}, T, Val);
   }
 
-  void setExpScore(const MachineInstr *MI, const SIRegisterInfo *TRI,
-                   const MachineRegisterInfo *MRI, const MachineOperand &Op,
-                   unsigned Val);
+  void setScoreByInterval(RegInterval Interval, InstCounterType CntTy,
+                          unsigned Score);
+
+  void setScoreByOperand(const MachineInstr *MI, const SIRegisterInfo *TRI,
+                         const MachineRegisterInfo *MRI,
+                         const MachineOperand &Op, InstCounterType CntTy,
+                         unsigned Val);
 
   const GCNSubtarget *ST = nullptr;
   InstCounterType MaxCounter = NUM_EXTENDED_INST_CNTS;
@@ -772,17 +782,30 @@ RegInterval WaitcntBrackets::getRegInterval(const MachineInstr *MI,
   return Result;
 }
 
-void WaitcntBrackets::setExpScore(const MachineInstr *MI,
-                                  const SIRegisterInfo *TRI,
-                                  const MachineRegisterInfo *MRI,
-                                  const MachineOperand &Op, unsigned Val) {
-  RegInterval Interval = getRegInterval(MI, MRI, TRI, Op);
-  assert(TRI->isVectorRegister(*MRI, Op.getReg()));
+void WaitcntBrackets::setScoreByInterval(RegInterval Interval,
+                                         InstCounterType CntTy,
+                                         unsigned Score) {
   for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
-    setRegScore(RegNo, EXP_CNT, Val);
+    if (RegNo < NUM_ALL_VGPRS) {
+      VgprUB = std::max(VgprUB, RegNo);
+      VgprScores[CntTy][RegNo] = Score;
+    } else {
+      assert(CntTy == SmemAccessCounter);
+      SgprUB = std::max(SgprUB, RegNo - NUM_ALL_VGPRS);
+      SgprScores[RegNo - NUM_ALL_VGPRS] = Score;
+    }
   }
 }
 
+void WaitcntBrackets::setScoreByOperand(const MachineInstr *MI,
+                                        const SIRegisterInfo *TRI,
+                                        const MachineRegisterInfo *MRI,
+                                        const MachineOperand &Op,
+                                        InstCounterType CntTy, unsigned Score) {
+  RegInterval Interval = getRegInterval(MI, MRI, TRI, Op);
+  setScoreByInterval(Interval, CntTy, Score);
+}
+
 void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII,
                                     const SIRegisterInfo *TRI,
                                     const MachineRegisterInfo *MRI,
@@ -806,57 +829,58 @@ void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII,
       // All GDS operations must protect their address register (same as
       // export.)
       if (const auto *AddrOp = TII->getNamedOperand(Inst, AMDGPU::OpName::addr))
-        setExpScore(&Inst, TRI, MRI, *AddrOp, CurrScore);
+        setScoreByOperand(&Inst, TRI, MRI, *AddrOp, EXP_CNT, CurrScore);
 
       if (Inst.mayStore()) {
         if (const auto *Data0 =
                 TII->getNamedOperand(Inst, AMDGPU::OpName::data0))
-          setExpScore(&Inst, TRI, MRI, *Data0, CurrScore);
+          setScoreByOperand(&Inst, TRI, MRI, *Data0, EXP_CNT, CurrScore);
         if (const auto *Data1 =
                 TII->getNamedOperand(Inst, AMDGPU::OpName::data1))
-          setExpScore(&Inst, TRI, MRI, *Data1, CurrScore);
+          setScoreByOperand(&Inst, TRI, MRI, *Data1, EXP_CNT, CurrScore);
       } else if (SIInstrInfo::isAtomicRet(Inst) && !SIInstrInfo::isGWS(Inst) &&
                  Inst.getOpcode() != AMDGPU::DS_APPEND &&
                  Inst.getOpcode() != AMDGPU::DS_CONSUME &&
                  Inst.getOpcode() != AMDGPU::DS_ORDERED_COUNT) {
         for (const MachineOperand &Op : Inst.all_uses()) {
           if (TRI->isVectorRegister(*MRI, Op.getReg()))
-            setExpScore(&Inst, TRI, MRI, Op, CurrScore);
+            setScoreByOperand(&Inst, TRI, MRI, Op, EXP_CNT, CurrScore);
         }
       }
     } else if (TII->isFLAT(Inst)) {
       if (Inst.mayStore()) {
-        setExpScore(&Inst, TRI, MRI,
-                    *TII->getNamedOperand(Inst, AMDGPU::OpName::data),
-                    CurrScore);
+        setScoreByOperand(&Inst, TRI, MRI,
+                          *TII->getNamedOperand(Inst, AMDGPU::OpName::data),
+                          EXP_CNT, CurrScore);
       } else if (SIInstrInfo::isAtomicRet(Inst)) {
-        setExpScore(&Inst, TRI, MRI,
-                    *TII->getNamedOperand(Inst, AMDGPU::OpName::data),
-                    CurrScore);
+        setScoreByOperand(&Inst, TRI, MRI,
+                          *TII->getNamedOperand(Inst, AMDGPU::OpName::data),
+                          EXP_CNT, CurrScore);
       }
     } else if (TII->isMIMG(Inst)) {
       if (Inst.mayStore()) {
-        setExpScore(&Inst, TRI, MRI, Inst.getOperand(0), CurrScore);
+        setScoreByOperand(&Inst, TRI, MRI, Inst.getOperand(0), EXP_CNT, CurrScore);
       } else if (SIInstrInfo::isAtomicRet(Inst)) {
-        setExpScore(&Inst, TRI, MRI,
-                    *TII->getNamedOperand(Inst, AMDGPU::OpName::data),
-                    CurrScore);
+        setScoreByOperand(&Inst, TRI, MRI,
+                          *TII->getNamedOperand(Inst, AMDGPU::OpName::data),
+                          EXP_CNT, CurrScore);
       }
     } else if (TII->isMTBUF(Inst)) {
       if (Inst.mayStore())
-        setExpScore(&Inst, TRI, MRI, Inst.getOperand(0), CurrScore);
+        setScoreByOperand(&Inst, TRI, MRI, Inst.getOperand(0), EXP_CNT, CurrScore);
     } else if (TII->isMUBUF(Inst)) {
       if (Inst.mayStore()) {
-        setExpScore(&Inst, TRI, MRI, Inst.getOperand(0), CurrScore);
+        setScoreByOperand(&Inst, TRI, MRI, Inst.getOperand(0), EXP_CNT, CurrScore);
       } else if (SIInstrInfo::isAtomicRet(Inst)) {
-        setExpScore(&Inst, TRI, MRI,
-                    *TII->getNamedOperand(Inst, AMDGPU::OpName::data),
-                    CurrScore);
+        setScoreByOperand(&Inst, TRI, MRI,
+                          *TII->getNamedOperand(Inst, AMDGPU::OpName::data),
+                          EXP_CNT, CurrScore);
       }
     } else if (TII->isLDSDIR(Inst)) {
       // LDSDIR instructions attach the score to the destination.
-      setExpScore(&Inst, TRI, MRI,
-                  *TII->getNamedOperand(Inst, AMDGPU::OpName::vdst), CurrScore);
+      setScoreByOperand(&Inst, TRI, MRI,
+                        *TII->getNamedOperand(Inst, AMDGPU::OpName::vdst),
+                        EXP_CNT, CurrScore);
     } else {
       if (TII->isEXP(Inst)) {
         // For export the destination registers are really temps that
@@ -865,15 +889,13 @@ void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII,
         // score.
         for (MachineOperand &DefMO : Inst.all_defs()) {
           if (TRI->isVGPR(*MRI, DefMO.getReg())) {
-            setRegScore(
-                TRI->getEncodingValue(AMDGPU::getMCReg(DefMO.getReg(), *ST)),
-                EXP_CNT, CurrScore);
+            setScoreByOperand(&Inst, TRI, MRI, DefMO, EXP_CNT, CurrScore);
           }
         }
       }
       for (const MachineOperand &Op : Inst.all_uses()) {
         if (TRI->isVectorRegister(*MRI, Op.getReg()))
-          setExpScore(&Inst, TRI, MRI, Op, CurrScore);
+          setScoreByOperand(&Inst, TRI, MRI, Op, EXP_CNT, CurrScore);
       }
     }
   } else /* LGKM_CNT || EXP_CNT || VS_CNT || NUM_INST_CNTS */ {
@@ -901,9 +923,7 @@ void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII,
             VgprVmemTypes[RegNo] |= 1 << V;
         }
       }
-      for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
-        setRegScore(RegNo, T, CurrScore);
-      }
+      setScoreByInterval(Interval, T, CurrScore);
     }
     if (Inst.mayStore() &&
         (TII->isDS(Inst) || TII->mayWriteLDSThroughDMA(Inst))) {
@@ -1034,31 +1054,34 @@ void WaitcntBrackets::simplifyWaitcnt(InstCounterType T,
     Count = ~0u;
 }
 
-void WaitcntBrackets::determineWait(InstCounterType T, int RegNo,
+void WaitcntBrackets::determineWait(InstCounterType T, RegInterval Interval,
                                     AMDGPU::Waitcnt &Wait) const {
-  unsigned ScoreToWait = getRegScore(RegNo, T);
-
-  // If the score of src_operand falls within the bracket, we need an
-  // s_waitcnt instruction.
   const unsigned LB = getScoreLB(T);
   const unsigned UB = getScoreUB(T);
-  if ((UB >= ScoreToWait) && (ScoreToWait > LB)) {
-    if ((T == LOAD_CNT || T == DS_CNT) && hasPendingFlat() &&
-        !ST->hasFlatLgkmVMemCountInOrder()) {
-      // If there is a pending FLAT operation, and this is a VMem or LGKM
-      // waitcnt and the target can report early completion, then we need
-      // to force a waitcnt 0.
-      addWait(Wait, T, 0);
-    } else if (counterOutOfOrder(T)) {
-      // Counter can get decremented out-of-order when there
-      // are multiple types event in the bracket. Also emit an s_wait counter
-      // with a conservative value of 0 for the counter.
-      addWait(Wait, T, 0);
-    } else {
-      // If a counter has been maxed out avoid overflow by waiting for
-      // MAX(CounterType) - 1 instead.
-      unsigned NeededWait = std::min(UB - ScoreToWait, getWaitCountMax(T) - 1);
-      addWait(Wait, T, NeededWait);
+  for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
+    unsigned ScoreToWait = getRegScore(RegNo, T);
+
+    // If the score of src_operand falls within the bracket, we need an
+    // s_waitcnt instruction.
+    if ((UB >= ScoreToWait) && (ScoreToWait > LB)) {
+      if ((T == LOAD_CNT || T == DS_CNT) && hasPendingFlat() &&
+          !ST->hasFlatLgkmVMemCountInOrder()) {
+        // If there is a pending FLAT operation, and this is a VMem or LGKM
+        // waitcnt and the target can report early completion, then we need
+        // to force a waitcnt 0.
+        addWait(Wait, T, 0);
+      } else if (counterOutOfOrder(T)) {
+        // Counter can get decremented out-of-order when there
+        // are multiple types event in the bracket. Also emit an s_wait counter
+        // with a conservative value of 0 for the counter.
+        addWait(Wait, T, 0);
+      } else {
+        // If a counter has been maxed out avoid overflow by waiting for
+        // MAX(CounterType) - 1 instead.
+        unsigned NeededWait =
+            std::min(UB - ScoreToWait, getWaitCountMax(T) - 1);
+        addWait(Wait, T, NeededWait);
+      }
     }
   }
 }
@@ -1670,18 +1693,16 @@ bool SIInsertWaitcnts::generateWaitcntInstBefore(MachineInstr &MI,
         RegInterval CallAddrOpInterval =
             ScoreBrackets.getRegInterval(&MI, MRI, TRI, CallAddrOp);
 
-        for (int RegNo = CallAddrOpInterval.first;
-             RegNo < CallAddrOpInterval.second; ++RegNo)
-          ScoreBrackets.determineWait(SmemAccessCounter, RegNo, Wait);
+        ScoreBrackets.determineWait(SmemAccessCounter, CallAddrOpInterval,
+                                    Wait);
 
         if (const auto *RtnAddrOp =
                 TII->getNamedOperand(MI, AMDGPU::OpName::dst)) {
           RegInterval RtnAddrOpInterval =
               ScoreBrackets.getRegInterval(&MI, MRI, TRI, *RtnAddrOp);
 
-          for (int RegNo = RtnAddrOpInterval.first;
-               RegNo < RtnAddrOpInterval.second; ++RegNo)
-            ScoreBrackets.determineWait(SmemAccessCounter, RegNo, Wait);
+          ScoreBrackets.determineWait(SmemAccessCounter, RtnAddrOpInterval,
+                                      Wait);
         }
       }
     } else {
@@ -1750,36 +1771,34 @@ bool SIInsertWaitcnts::generateWaitcntInstBefore(MachineInstr &MI,
         RegInterval Interval = ScoreBrackets.getRegInterval(&MI, MRI, TRI, Op);
 
         const bool IsVGPR = TRI->isVectorRegister(*MRI, Op.getReg());
-        for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
-          if (IsVGPR) {
-            // Implicit VGPR defs and uses are never a part of the memory
-            // instructions description and usually present to account for
-            // super-register liveness.
-            // TODO: Most of the other instructions also have implicit uses
-            // for the liveness accounting only.
-            if (Op.isImplicit() && MI.mayLoadOrStore())
-              continue;
-
-            // RAW always needs an s_waitcnt. WAW needs an s_waitcnt unless the
-            // previous write and this write are the same type of VMEM
-            // instruction, in which case they are (in some architectures)
-            // guaranteed to write their results in order anyway.
-            if (Op.isUse() || !updateVMCntOnly(MI) ||
-                ScoreBrackets.hasOtherPendingVmemTypes(RegNo,
-                                                       getVmemType(MI)) ||
-                !ST->hasVmemWriteVgprInOrder()) {
-              ScoreBrackets.determineWait(LOAD_CNT, RegNo, Wait);
-              ScoreBrackets.determineWait(SAMPLE_CNT, RegNo, Wait);
-              ScoreBrackets.determineWait(BVH_CNT, RegNo, Wait);
-              ScoreBrackets.clearVgprVmemTypes(RegNo);
-            }
-            if (Op.isDef() || ScoreBrackets.hasPendingEvent(EXP_LDS_ACCESS)) {
-              ScoreBrackets.determineWait(EXP_CNT, RegNo, Wait);
-            }
-            ScoreBrackets.determineWait(DS_CNT, RegNo, Wait);
-          } else {
-            ScoreBrackets.determineWait(SmemAccessCounter, RegNo, Wait);
+        if (IsVGPR) {
+          // Implicit VGPR defs and uses are never a part of the memory
+          // instructions description and usually present to account for
+          // super-register liveness.
+          // TODO: Most of the other instructions also have implicit uses
+          // for the liveness accounting only.
+          if (Op.isImplicit() && MI.mayLoadOrStore())
+            continue;
+
+          // RAW always needs an s_waitcnt. WAW needs an s_waitcnt unless the
+          // previous write and this write are the same type of VMEM
+          // instruction, in which case they are (in some architectures)
+          // guaranteed to write their results in order anyway.
+          if (Op.isUse() || !updateVMCntOnly(MI) ||
+              ScoreBrackets.hasOtherPendingVmemTypes(Interval,
+                                                     getVmemType(MI)) ||
+              !ST->hasVmemWriteVgprInOrder()) {
+            ScoreBrackets.determineWait(LOAD_CNT, Interval, Wait);
+            ScoreBrackets.determineWait(SAMPLE_CNT, Interval, Wait);
+            ScoreBrackets.determineWait(BVH_CNT, Interval, Wait);
+            ScoreBrackets.clearVgprVmemTypes(Interval);
+          }
+          if (Op.isDef() || ScoreBrackets.hasPendingEvent(EXP_LDS_ACCESS)) {
+            ScoreBrackets.determineWait(EXP_CNT, Interval, Wait);
           }
+          ScoreBrackets.determineWait(DS_CNT, Interval, Wait);
+        } else {
+          ScoreBrackets.determineWait(SmemAccessCounter, Interval, Wait);
         }
       }
     }

Copy link

github-actions bot commented Sep 30, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

For setScore, the root function is setScoreByInterval with RegInterval
input.
For determineWait, the root function is determineWait with RegInterval
inputs.
@cmc-rep cmc-rep requested review from jayfoad and arsenm September 30, 2024 21:15
@cmc-rep cmc-rep merged commit c66dee4 into llvm:main Oct 1, 2024
8 checks passed
Sterling-Augustine pushed a commit to Sterling-Augustine/llvm-project that referenced this pull request Oct 3, 2024
…llvm#110562)

For setScore, the root function is setScoreByInterval with RegInterval
input
For determineWait, the root function is determineWait with RegInterval
input
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants