Skip to content

Commit 4831fa8

Browse files
AMDGPU/GlobalISel: RegBankLegalize rules for load (#112882)
Add IDs for bit width that cover multiple LLTs: B32 B64 etc. "Predicate" wrapper class for bool predicate functions used to write pretty rules. Predicates can be combined using &&, || and !. Lowering for splitting and widening loads. Write rules for loads to not change existing mir tests from old regbankselect.
1 parent 8e6d6a5 commit 4831fa8

6 files changed

+900
-65
lines changed

llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp

Lines changed: 284 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,83 @@ void RegBankLegalizeHelper::findRuleAndApplyMapping(MachineInstr &MI) {
5050
lower(MI, Mapping, WaterfallSgprs);
5151
}
5252

53+
void RegBankLegalizeHelper::splitLoad(MachineInstr &MI,
54+
ArrayRef<LLT> LLTBreakdown, LLT MergeTy) {
55+
MachineFunction &MF = B.getMF();
56+
assert(MI.getNumMemOperands() == 1);
57+
MachineMemOperand &BaseMMO = **MI.memoperands_begin();
58+
Register Dst = MI.getOperand(0).getReg();
59+
const RegisterBank *DstRB = MRI.getRegBankOrNull(Dst);
60+
Register Base = MI.getOperand(1).getReg();
61+
LLT PtrTy = MRI.getType(Base);
62+
const RegisterBank *PtrRB = MRI.getRegBankOrNull(Base);
63+
LLT OffsetTy = LLT::scalar(PtrTy.getSizeInBits());
64+
SmallVector<Register, 4> LoadPartRegs;
65+
66+
unsigned ByteOffset = 0;
67+
for (LLT PartTy : LLTBreakdown) {
68+
Register BasePlusOffset;
69+
if (ByteOffset == 0) {
70+
BasePlusOffset = Base;
71+
} else {
72+
auto Offset = B.buildConstant({PtrRB, OffsetTy}, ByteOffset);
73+
BasePlusOffset = B.buildPtrAdd({PtrRB, PtrTy}, Base, Offset).getReg(0);
74+
}
75+
auto *OffsetMMO = MF.getMachineMemOperand(&BaseMMO, ByteOffset, PartTy);
76+
auto LoadPart = B.buildLoad({DstRB, PartTy}, BasePlusOffset, *OffsetMMO);
77+
LoadPartRegs.push_back(LoadPart.getReg(0));
78+
ByteOffset += PartTy.getSizeInBytes();
79+
}
80+
81+
if (!MergeTy.isValid()) {
82+
// Loads are of same size, concat or merge them together.
83+
B.buildMergeLikeInstr(Dst, LoadPartRegs);
84+
} else {
85+
// Loads are not all of same size, need to unmerge them to smaller pieces
86+
// of MergeTy type, then merge pieces to Dst.
87+
SmallVector<Register, 4> MergeTyParts;
88+
for (Register Reg : LoadPartRegs) {
89+
if (MRI.getType(Reg) == MergeTy) {
90+
MergeTyParts.push_back(Reg);
91+
} else {
92+
auto Unmerge = B.buildUnmerge({DstRB, MergeTy}, Reg);
93+
for (unsigned i = 0; i < Unmerge->getNumOperands() - 1; ++i)
94+
MergeTyParts.push_back(Unmerge.getReg(i));
95+
}
96+
}
97+
B.buildMergeLikeInstr(Dst, MergeTyParts);
98+
}
99+
MI.eraseFromParent();
100+
}
101+
102+
void RegBankLegalizeHelper::widenLoad(MachineInstr &MI, LLT WideTy,
103+
LLT MergeTy) {
104+
MachineFunction &MF = B.getMF();
105+
assert(MI.getNumMemOperands() == 1);
106+
MachineMemOperand &BaseMMO = **MI.memoperands_begin();
107+
Register Dst = MI.getOperand(0).getReg();
108+
const RegisterBank *DstRB = MRI.getRegBankOrNull(Dst);
109+
Register Base = MI.getOperand(1).getReg();
110+
111+
MachineMemOperand *WideMMO = MF.getMachineMemOperand(&BaseMMO, 0, WideTy);
112+
auto WideLoad = B.buildLoad({DstRB, WideTy}, Base, *WideMMO);
113+
114+
if (WideTy.isScalar()) {
115+
B.buildTrunc(Dst, WideLoad);
116+
} else {
117+
SmallVector<Register, 4> MergeTyParts;
118+
auto Unmerge = B.buildUnmerge({DstRB, MergeTy}, WideLoad);
119+
120+
LLT DstTy = MRI.getType(Dst);
121+
unsigned NumElts = DstTy.getSizeInBits() / MergeTy.getSizeInBits();
122+
for (unsigned i = 0; i < NumElts; ++i) {
123+
MergeTyParts.push_back(Unmerge.getReg(i));
124+
}
125+
B.buildMergeLikeInstr(Dst, MergeTyParts);
126+
}
127+
MI.eraseFromParent();
128+
}
129+
53130
void RegBankLegalizeHelper::lower(MachineInstr &MI,
54131
const RegBankLLTMapping &Mapping,
55132
SmallSet<Register, 4> &WaterfallSgprs) {
@@ -128,6 +205,54 @@ void RegBankLegalizeHelper::lower(MachineInstr &MI,
128205
MI.eraseFromParent();
129206
break;
130207
}
208+
case SplitLoad: {
209+
LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
210+
unsigned Size = DstTy.getSizeInBits();
211+
// Even split to 128-bit loads
212+
if (Size > 128) {
213+
LLT B128;
214+
if (DstTy.isVector()) {
215+
LLT EltTy = DstTy.getElementType();
216+
B128 = LLT::fixed_vector(128 / EltTy.getSizeInBits(), EltTy);
217+
} else {
218+
B128 = LLT::scalar(128);
219+
}
220+
if (Size / 128 == 2)
221+
splitLoad(MI, {B128, B128});
222+
else if (Size / 128 == 4)
223+
splitLoad(MI, {B128, B128, B128, B128});
224+
else {
225+
LLVM_DEBUG(dbgs() << "MI: "; MI.dump(););
226+
llvm_unreachable("SplitLoad type not supported for MI");
227+
}
228+
}
229+
// 64 and 32 bit load
230+
else if (DstTy == S96)
231+
splitLoad(MI, {S64, S32}, S32);
232+
else if (DstTy == V3S32)
233+
splitLoad(MI, {V2S32, S32}, S32);
234+
else if (DstTy == V6S16)
235+
splitLoad(MI, {V4S16, V2S16}, V2S16);
236+
else {
237+
LLVM_DEBUG(dbgs() << "MI: "; MI.dump(););
238+
llvm_unreachable("SplitLoad type not supported for MI");
239+
}
240+
break;
241+
}
242+
case WidenLoad: {
243+
LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
244+
if (DstTy == S96)
245+
widenLoad(MI, S128);
246+
else if (DstTy == V3S32)
247+
widenLoad(MI, V4S32, S32);
248+
else if (DstTy == V6S16)
249+
widenLoad(MI, V8S16, V2S16);
250+
else {
251+
LLVM_DEBUG(dbgs() << "MI: "; MI.dump(););
252+
llvm_unreachable("WidenLoad type not supported for MI");
253+
}
254+
break;
255+
}
131256
}
132257

133258
// TODO: executeInWaterfallLoop(... WaterfallSgprs)
@@ -151,12 +276,73 @@ LLT RegBankLegalizeHelper::getTyFromID(RegBankLLTMappingApplyID ID) {
151276
case Sgpr64:
152277
case Vgpr64:
153278
return LLT::scalar(64);
279+
case SgprP1:
280+
case VgprP1:
281+
return LLT::pointer(1, 64);
282+
case SgprP3:
283+
case VgprP3:
284+
return LLT::pointer(3, 32);
285+
case SgprP4:
286+
case VgprP4:
287+
return LLT::pointer(4, 64);
288+
case SgprP5:
289+
case VgprP5:
290+
return LLT::pointer(5, 32);
154291
case SgprV4S32:
155292
case VgprV4S32:
156293
case UniInVgprV4S32:
157294
return LLT::fixed_vector(4, 32);
158-
case VgprP1:
159-
return LLT::pointer(1, 64);
295+
default:
296+
return LLT();
297+
}
298+
}
299+
300+
LLT RegBankLegalizeHelper::getBTyFromID(RegBankLLTMappingApplyID ID, LLT Ty) {
301+
switch (ID) {
302+
case SgprB32:
303+
case VgprB32:
304+
case UniInVgprB32:
305+
if (Ty == LLT::scalar(32) || Ty == LLT::fixed_vector(2, 16) ||
306+
Ty == LLT::pointer(3, 32) || Ty == LLT::pointer(5, 32) ||
307+
Ty == LLT::pointer(6, 32))
308+
return Ty;
309+
return LLT();
310+
case SgprB64:
311+
case VgprB64:
312+
case UniInVgprB64:
313+
if (Ty == LLT::scalar(64) || Ty == LLT::fixed_vector(2, 32) ||
314+
Ty == LLT::fixed_vector(4, 16) || Ty == LLT::pointer(0, 64) ||
315+
Ty == LLT::pointer(1, 64) || Ty == LLT::pointer(4, 64))
316+
return Ty;
317+
return LLT();
318+
case SgprB96:
319+
case VgprB96:
320+
case UniInVgprB96:
321+
if (Ty == LLT::scalar(96) || Ty == LLT::fixed_vector(3, 32) ||
322+
Ty == LLT::fixed_vector(6, 16))
323+
return Ty;
324+
return LLT();
325+
case SgprB128:
326+
case VgprB128:
327+
case UniInVgprB128:
328+
if (Ty == LLT::scalar(128) || Ty == LLT::fixed_vector(4, 32) ||
329+
Ty == LLT::fixed_vector(2, 64))
330+
return Ty;
331+
return LLT();
332+
case SgprB256:
333+
case VgprB256:
334+
case UniInVgprB256:
335+
if (Ty == LLT::scalar(256) || Ty == LLT::fixed_vector(8, 32) ||
336+
Ty == LLT::fixed_vector(4, 64) || Ty == LLT::fixed_vector(16, 16))
337+
return Ty;
338+
return LLT();
339+
case SgprB512:
340+
case VgprB512:
341+
case UniInVgprB512:
342+
if (Ty == LLT::scalar(512) || Ty == LLT::fixed_vector(16, 32) ||
343+
Ty == LLT::fixed_vector(8, 64))
344+
return Ty;
345+
return LLT();
160346
default:
161347
return LLT();
162348
}
@@ -170,10 +356,26 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
170356
case Sgpr16:
171357
case Sgpr32:
172358
case Sgpr64:
359+
case SgprP1:
360+
case SgprP3:
361+
case SgprP4:
362+
case SgprP5:
173363
case SgprV4S32:
364+
case SgprB32:
365+
case SgprB64:
366+
case SgprB96:
367+
case SgprB128:
368+
case SgprB256:
369+
case SgprB512:
174370
case UniInVcc:
175371
case UniInVgprS32:
176372
case UniInVgprV4S32:
373+
case UniInVgprB32:
374+
case UniInVgprB64:
375+
case UniInVgprB96:
376+
case UniInVgprB128:
377+
case UniInVgprB256:
378+
case UniInVgprB512:
177379
case Sgpr32Trunc:
178380
case Sgpr32AExt:
179381
case Sgpr32AExtBoolInReg:
@@ -182,7 +384,16 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
182384
case Vgpr32:
183385
case Vgpr64:
184386
case VgprP1:
387+
case VgprP3:
388+
case VgprP4:
389+
case VgprP5:
185390
case VgprV4S32:
391+
case VgprB32:
392+
case VgprB64:
393+
case VgprB96:
394+
case VgprB128:
395+
case VgprB256:
396+
case VgprB512:
186397
return VgprRB;
187398
default:
188399
return nullptr;
@@ -207,16 +418,40 @@ void RegBankLegalizeHelper::applyMappingDst(
207418
case Sgpr16:
208419
case Sgpr32:
209420
case Sgpr64:
421+
case SgprP1:
422+
case SgprP3:
423+
case SgprP4:
424+
case SgprP5:
210425
case SgprV4S32:
211426
case Vgpr32:
212427
case Vgpr64:
213428
case VgprP1:
429+
case VgprP3:
430+
case VgprP4:
431+
case VgprP5:
214432
case VgprV4S32: {
215433
assert(Ty == getTyFromID(MethodIDs[OpIdx]));
216434
assert(RB == getRegBankFromID(MethodIDs[OpIdx]));
217435
break;
218436
}
219-
// uniform in vcc/vgpr: scalars and vectors
437+
// sgpr and vgpr B-types
438+
case SgprB32:
439+
case SgprB64:
440+
case SgprB96:
441+
case SgprB128:
442+
case SgprB256:
443+
case SgprB512:
444+
case VgprB32:
445+
case VgprB64:
446+
case VgprB96:
447+
case VgprB128:
448+
case VgprB256:
449+
case VgprB512: {
450+
assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty));
451+
assert(RB == getRegBankFromID(MethodIDs[OpIdx]));
452+
break;
453+
}
454+
// uniform in vcc/vgpr: scalars, vectors and B-types
220455
case UniInVcc: {
221456
assert(Ty == S1);
222457
assert(RB == SgprRB);
@@ -236,6 +471,19 @@ void RegBankLegalizeHelper::applyMappingDst(
236471
buildReadAnyLane(B, Reg, NewVgprDst, RBI);
237472
break;
238473
}
474+
case UniInVgprB32:
475+
case UniInVgprB64:
476+
case UniInVgprB96:
477+
case UniInVgprB128:
478+
case UniInVgprB256:
479+
case UniInVgprB512: {
480+
assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty));
481+
assert(RB == SgprRB);
482+
Register NewVgprDst = MRI.createVirtualRegister({VgprRB, Ty});
483+
Op.setReg(NewVgprDst);
484+
AMDGPU::buildReadAnyLane(B, Reg, NewVgprDst, RBI);
485+
break;
486+
}
239487
// sgpr trunc
240488
case Sgpr32Trunc: {
241489
assert(Ty.getSizeInBits() < 32);
@@ -284,15 +532,33 @@ void RegBankLegalizeHelper::applyMappingSrc(
284532
case Sgpr16:
285533
case Sgpr32:
286534
case Sgpr64:
535+
case SgprP1:
536+
case SgprP3:
537+
case SgprP4:
538+
case SgprP5:
287539
case SgprV4S32: {
288540
assert(Ty == getTyFromID(MethodIDs[i]));
289541
assert(RB == getRegBankFromID(MethodIDs[i]));
290542
break;
291543
}
544+
// sgpr B-types
545+
case SgprB32:
546+
case SgprB64:
547+
case SgprB96:
548+
case SgprB128:
549+
case SgprB256:
550+
case SgprB512: {
551+
assert(Ty == getBTyFromID(MethodIDs[i], Ty));
552+
assert(RB == getRegBankFromID(MethodIDs[i]));
553+
break;
554+
}
292555
// vgpr scalars, pointers and vectors
293556
case Vgpr32:
294557
case Vgpr64:
295558
case VgprP1:
559+
case VgprP3:
560+
case VgprP4:
561+
case VgprP5:
296562
case VgprV4S32: {
297563
assert(Ty == getTyFromID(MethodIDs[i]));
298564
if (RB != VgprRB) {
@@ -301,6 +567,20 @@ void RegBankLegalizeHelper::applyMappingSrc(
301567
}
302568
break;
303569
}
570+
// vgpr B-types
571+
case VgprB32:
572+
case VgprB64:
573+
case VgprB96:
574+
case VgprB128:
575+
case VgprB256:
576+
case VgprB512: {
577+
assert(Ty == getBTyFromID(MethodIDs[i], Ty));
578+
if (RB != VgprRB) {
579+
auto CopyToVgpr = B.buildCopy({VgprRB, Ty}, Reg);
580+
Op.setReg(CopyToVgpr.getReg(0));
581+
}
582+
break;
583+
}
304584
// sgpr and vgpr scalars with extend
305585
case Sgpr32AExt: {
306586
// Note: this ext allows S1, and it is meant to be combined away.
@@ -373,7 +653,7 @@ void RegBankLegalizeHelper::applyMappingPHI(MachineInstr &MI) {
373653
// We accept all types that can fit in some register class.
374654
// Uniform G_PHIs have all sgpr registers.
375655
// Divergent G_PHIs have vgpr dst but inputs can be sgpr or vgpr.
376-
if (Ty == LLT::scalar(32)) {
656+
if (Ty == LLT::scalar(32) || Ty == LLT::pointer(4, 64)) {
377657
return;
378658
}
379659

0 commit comments

Comments
 (0)