@@ -850,7 +850,8 @@ class IGLPStrategy {
850
850
// Add SchedGroups to \p Pipeline to implement this Strategy.
851
851
virtual void applyIGLPStrategy (
852
852
DenseMap<int , SUnitsToCandidateSGsMap> &SyncedInstrs,
853
- DenseMap<int , SmallVector<SchedGroup, 4 >> &SyncedSchedGroups) = 0;
853
+ DenseMap<int , SmallVector<SchedGroup, 4 >> &SyncedSchedGroups,
854
+ bool IsPostRA) = 0;
854
855
855
856
// Returns true if this strategy should be applied to a ScheduleDAG.
856
857
virtual bool shouldApplyStrategy (ScheduleDAGInstrs *DAG) = 0;
@@ -868,7 +869,8 @@ class MFMASmallGemmOpt final : public IGLPStrategy {
868
869
public:
869
870
void applyIGLPStrategy (
870
871
DenseMap<int , SUnitsToCandidateSGsMap> &SyncedInstrs,
871
- DenseMap<int , SmallVector<SchedGroup, 4 >> &SyncedSchedGroups) override ;
872
+ DenseMap<int , SmallVector<SchedGroup, 4 >> &SyncedSchedGroups,
873
+ bool IsPostRA) override ;
872
874
873
875
bool shouldApplyStrategy (ScheduleDAGInstrs *DAG) override { return true ; }
874
876
@@ -880,7 +882,8 @@ class MFMASmallGemmOpt final : public IGLPStrategy {
880
882
881
883
void MFMASmallGemmOpt::applyIGLPStrategy (
882
884
DenseMap<int , SUnitsToCandidateSGsMap> &SyncedInstrs,
883
- DenseMap<int , SmallVector<SchedGroup, 4 >> &SyncedSchedGroups) {
885
+ DenseMap<int , SmallVector<SchedGroup, 4 >> &SyncedSchedGroups,
886
+ bool IsPostRA) {
884
887
// Count the number of MFMA instructions.
885
888
unsigned MFMACount = 0 ;
886
889
for (const MachineInstr &I : *DAG)
@@ -1076,9 +1079,12 @@ class MFMASmallGemmSingleWaveOpt final : public IGLPStrategy {
1076
1079
Cache->push_back (Pred.getSUnit ());
1077
1080
}
1078
1081
}
1082
+
1083
+ // If the other group has no V_PERM preds, then this group cannot share any with it.
1084
+ if (!Cache->size ())
1085
+ return false ;
1079
1086
}
1080
1087
1081
- assert (Cache->size ());
1082
1088
auto DAG = SyncPipe[0 ].DAG ;
1083
1089
// Does the previous DS_WRITE share a V_PERM predecessor with this
1084
1090
// VMEM_READ
@@ -1095,7 +1101,8 @@ class MFMASmallGemmSingleWaveOpt final : public IGLPStrategy {
1095
1101
public:
1096
1102
void applyIGLPStrategy (
1097
1103
DenseMap<int , SUnitsToCandidateSGsMap> &SyncedInstrs,
1098
- DenseMap<int , SmallVector<SchedGroup, 4 >> &SyncedSchedGroups) override ;
1104
+ DenseMap<int , SmallVector<SchedGroup, 4 >> &SyncedSchedGroups,
1105
+ bool IsPostRA) override ;
1099
1106
1100
1107
bool shouldApplyStrategy (ScheduleDAGInstrs *DAG) override { return true ; }
1101
1108
@@ -1105,14 +1112,20 @@ class MFMASmallGemmSingleWaveOpt final : public IGLPStrategy {
1105
1112
}
1106
1113
};
1107
1114
1115
// DS_WRITE statistics used by MFMASmallGemmSingleWaveOpt::applyIGLPStrategy.
// They are file-scope statics rather than locals, so their values persist
// across calls; applyIGLPStrategy only updates them when !IsPostRA.
// NOTE(review): nothing visible here resets these counters between
// invocations, yet the pre-RA path asserts they are zero and
// DSWWithSharedVMEMCount is accumulated with += -- confirm they are cleared
// elsewhere (and that mutable file-level state is acceptable here).
+ static unsigned DSWCount = 0 ;
1116
+ static unsigned DSWWithPermCount = 0 ;
1117
+ static unsigned DSWWithSharedVMEMCount = 0 ;
1118
+
1108
1119
void MFMASmallGemmSingleWaveOpt::applyIGLPStrategy (
1109
1120
DenseMap<int , SUnitsToCandidateSGsMap> &SyncedInstrs,
1110
- DenseMap<int , SmallVector<SchedGroup, 4 >> &SyncedSchedGroups) {
1121
+ DenseMap<int , SmallVector<SchedGroup, 4 >> &SyncedSchedGroups,
1122
+ bool IsPostRA) {
1111
1123
unsigned MFMACount = 0 ;
1112
- unsigned DSWCount = 0 ;
1113
- unsigned DSWWithPermCount = 0 ;
1114
- unsigned DSWWithSharedVMEMCount = 0 ;
1115
1124
unsigned DSRCount = 0 ;
1125
+
1126
+ // The DSW counters are file-level statics that are only recomputed in the
+ // pre-RA (!IsPostRA) pass, so that pass must start with all three at zero.
+ // Each counter is compared against zero explicitly: C++ does not chain `==`
+ // (`a == b == c == 0` parses as `((a == b) == c) == 0`).
+ assert ((IsPostRA ||
1127
+ (DSWCount == 0 && DSWWithPermCount == 0 && DSWWithSharedVMEMCount == 0 )) &&
1128
+ " DSWCounters should be zero in pre-RA scheduling!" );
1116
1129
SmallVector<SUnit *, 6 > DSWithPerms;
1117
1130
for (auto &SU : DAG->SUnits ) {
1118
1131
auto I = SU.getInstr ();
@@ -1121,7 +1134,7 @@ void MFMASmallGemmSingleWaveOpt::applyIGLPStrategy(
1121
1134
else if (TII->isDS (*I)) {
1122
1135
if (I->mayLoad ())
1123
1136
++DSRCount;
1124
- else if (I->mayStore ()) {
1137
+ else if (I->mayStore () && !IsPostRA ) {
1125
1138
++DSWCount;
1126
1139
for (auto Pred : SU.Preds ) {
1127
1140
if (Pred.getSUnit ()->getInstr ()->getOpcode () ==
@@ -1133,56 +1146,59 @@ void MFMASmallGemmSingleWaveOpt::applyIGLPStrategy(
1133
1146
}
1134
1147
}
1135
1148
}
1136
- DSWWithPermCount = DSWithPerms.size ();
1137
- auto I = DSWithPerms.begin ();
1138
- auto E = DSWithPerms.end ();
1139
-
1140
- // Get the count of DS_WRITES with V_PERM predecessors which
1141
- // have loop carried dependencies (WAR) on the same VMEM_READs.
1142
- // We consider partial overlap as a miss -- in other words,
1143
- // for a given DS_W, we only consider another DS_W as matching
1144
- // if there is a corresponding (in terms of the VMEM_R it uses) V_PERM pred
1145
- // for every V_PERM pred of this DS_W.
1146
- DenseMap<MachineInstr *, SUnit *> VMEMLookup;
1147
- SmallVector<SUnit *, 6 > Counted;
1148
- for (; I != E; I++) {
1149
- SUnit *Cand = nullptr ;
1150
- bool MissedAny = false ;
1151
- for (auto &Pred : (*I)->Preds ) {
1152
- if (Pred.getSUnit ()->getInstr ()->getOpcode () != AMDGPU::V_PERM_B32_e64)
1153
- continue ;
1154
1149
1155
- if (Cand && llvm::is_contained (Counted, Cand))
1156
- break ;
1157
-
1158
- for (auto &Succ : Pred.getSUnit ()->Succs ) {
1159
- auto MI = Succ.getSUnit ()->getInstr ();
1160
- if (!TII->isVMEM (*MI) || !MI->mayLoad ())
1150
+ if (!IsPostRA) {
1151
+ DSWWithPermCount = DSWithPerms.size ();
1152
+ auto I = DSWithPerms.begin ();
1153
+ auto E = DSWithPerms.end ();
1154
+
1155
+ // Get the count of DS_WRITES with V_PERM predecessors which
1156
+ // have loop carried dependencies (WAR) on the same VMEM_READs.
1157
+ // We consider partial overlap as a miss -- in other words,
1158
+ // for a given DS_W, we only consider another DS_W as matching
1159
+ // if there is a corresponding (in terms of the VMEM_R it uses) V_PERM pred
1160
+ // for every V_PERM pred of this DS_W.
1161
+ DenseMap<MachineInstr *, SUnit *> VMEMLookup;
1162
+ SmallVector<SUnit *, 6 > Counted;
1163
+ for (; I != E; I++) {
1164
+ SUnit *Cand = nullptr ;
1165
+ bool MissedAny = false ;
1166
+ for (auto &Pred : (*I)->Preds ) {
1167
+ if (Pred.getSUnit ()->getInstr ()->getOpcode () != AMDGPU::V_PERM_B32_e64)
1161
1168
continue ;
1162
1169
1163
- if (MissedAny || !VMEMLookup.size ()) {
1164
- MissedAny = true ;
1165
- VMEMLookup[MI] = *I;
1166
- continue ;
1167
- }
1170
+ if (Cand && llvm::is_contained (Counted, Cand))
1171
+ break ;
1168
1172
1169
- if (!VMEMLookup.contains (MI)) {
1170
- MissedAny = true ;
1171
- VMEMLookup[MI] = *I;
1172
- continue ;
1173
- }
1173
+ for (auto &Succ : Pred.getSUnit ()->Succs ) {
1174
+ auto MI = Succ.getSUnit ()->getInstr ();
1175
+ if (!TII->isVMEM (*MI) || !MI->mayLoad ())
1176
+ continue ;
1174
1177
1175
- Cand = VMEMLookup[MI];
1176
- if (llvm::is_contained (Counted, Cand)) {
1177
- MissedAny = true ;
1178
- break ;
1178
+ if (MissedAny || !VMEMLookup.size ()) {
1179
+ MissedAny = true ;
1180
+ VMEMLookup[MI] = *I;
1181
+ continue ;
1182
+ }
1183
+
1184
+ if (!VMEMLookup.contains (MI)) {
1185
+ MissedAny = true ;
1186
+ VMEMLookup[MI] = *I;
1187
+ continue ;
1188
+ }
1189
+
1190
+ Cand = VMEMLookup[MI];
1191
+ if (llvm::is_contained (Counted, Cand)) {
1192
+ MissedAny = true ;
1193
+ break ;
1194
+ }
1179
1195
}
1180
1196
}
1181
- }
1182
- if (!MissedAny && Cand) {
1183
- DSWWithSharedVMEMCount += 2 ;
1184
- Counted.push_back (Cand );
1185
- Counted. push_back (*I);
1197
+ if (!MissedAny && Cand) {
1198
+ DSWWithSharedVMEMCount += 2 ;
1199
+ Counted. push_back (Cand) ;
1200
+ Counted.push_back (*I );
1201
+ }
1186
1202
}
1187
1203
}
1188
1204
@@ -1398,7 +1414,11 @@ class IGroupLPDAGMutation : public ScheduleDAGMutation {
1398
1414
// first created SchedGroup first.
1399
1415
bool IsBottomUp = 1 ;
1400
1416
1417
+ // Whether the mutation is being applied during post-RA scheduling.
1418
+ bool IsPostRA = false ;
1419
+
1401
1420
IGroupLPDAGMutation () = default ;
1421
+ IGroupLPDAGMutation (bool IsPostRA) : IsPostRA(IsPostRA) {}
1402
1422
};
1403
1423
1404
1424
unsigned SchedGroup::NumSchedGroups = 0 ;
@@ -1686,16 +1706,16 @@ void IGroupLPDAGMutation::initIGLPOpt(SUnit &SU) {
1686
1706
auto S = createIGLPStrategy (StrategyID, DAG, TII);
1687
1707
if (S->shouldApplyStrategy (DAG)) {
1688
1708
IsBottomUp = S->IsBottomUp ;
1689
- S->applyIGLPStrategy (SyncedInstrs, SyncedSchedGroups);
1709
+ S->applyIGLPStrategy (SyncedInstrs, SyncedSchedGroups, IsPostRA );
1690
1710
}
1691
1711
}
1692
1712
1693
1713
} // namespace
1694
1714
1695
1715
namespace llvm {
1696
1716
1697
- std::unique_ptr<ScheduleDAGMutation> createIGroupLPDAGMutation () {
1698
- return std::make_unique<IGroupLPDAGMutation>();
1717
// Factory for the IGroupLP DAG mutation: constructs an IGroupLPDAGMutation,
// forwarding \p IsPostRA to its constructor, and returns it through the
// ScheduleDAGMutation base-class pointer.
+ std::unique_ptr<ScheduleDAGMutation> createIGroupLPDAGMutation (bool IsPostRA ) {
1718
+ return std::make_unique<IGroupLPDAGMutation>(IsPostRA );
1699
1719
}
1700
1720
1701
1721
} // end namespace llvm
0 commit comments