Skip to content

Commit 23bb343

Browse files
AMDGPU/GlobalISel: RegBankLegalize rules for load
Add IDs for bit widths that cover multiple LLTs: B32, B64, etc. Add a "Predicate" wrapper class for bool predicate functions, used to write pretty rules; predicates can be combined using &&, || and !. Add lowering for splitting and widening loads. Write the rules for loads so that existing MIR tests from the old regbankselect do not change.
1 parent 5aed391 commit 23bb343

File tree

6 files changed

+900
-65
lines changed

6 files changed

+900
-65
lines changed

llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp

Lines changed: 284 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,83 @@ void RegBankLegalizeHelper::findRuleAndApplyMapping(MachineInstr &MI) {
3939
lower(MI, Mapping, WaterfallSgprs);
4040
}
4141

42+
/// Replace the load \p MI with a sequence of narrower loads, one per entry of
/// \p LLTBreakdown, and recombine their results into the original destination.
///
/// Each partial load reads from the base pointer plus the accumulated byte
/// size of the preceding parts, and gets a memory operand derived from the
/// original one at that offset. Partial results are created on the
/// destination's register bank; offset constants and pointer adds are created
/// on the pointer's register bank.
///
/// \param MI           Load to split; must have exactly one memory operand.
///                     It is erased once the replacement has been built.
/// \param LLTBreakdown Types of the partial loads, in increasing address order.
/// \param MergeTy      When valid, every partial result whose type differs
///                     from MergeTy is first unmerged into MergeTy pieces
///                     before the final merge. When invalid, the partial
///                     results are merged directly.
void RegBankLegalizeHelper::splitLoad(MachineInstr &MI,
                                      ArrayRef<LLT> LLTBreakdown, LLT MergeTy) {
  assert(MI.getNumMemOperands() == 1);
  MachineFunction &MF = B.getMF();
  MachineMemOperand &BaseMMO = **MI.memoperands_begin();

  Register Dst = MI.getOperand(0).getReg();
  Register Base = MI.getOperand(1).getReg();
  const RegisterBank *DstRB = MRI.getRegBankOrNull(Dst);
  const RegisterBank *PtrRB = MRI.getRegBankOrNull(Base);
  LLT PtrTy = MRI.getType(Base);
  LLT OffsetTy = LLT::scalar(PtrTy.getSizeInBits());

  // Emit one load per requested part, advancing the byte offset as we go.
  SmallVector<Register, 4> Pieces;
  unsigned BytesSoFar = 0;
  for (LLT PieceTy : LLTBreakdown) {
    // The first piece reads straight from the base pointer; later pieces need
    // an explicit pointer add with a constant offset.
    Register Addr = Base;
    if (BytesSoFar != 0) {
      auto OffsetCst = B.buildConstant({PtrRB, OffsetTy}, BytesSoFar);
      Addr = B.buildPtrAdd({PtrRB, PtrTy}, Base, OffsetCst).getReg(0);
    }
    auto *PieceMMO = MF.getMachineMemOperand(&BaseMMO, BytesSoFar, PieceTy);
    auto PieceLoad = B.buildLoad({DstRB, PieceTy}, Addr, *PieceMMO);
    Pieces.push_back(PieceLoad.getReg(0));
    BytesSoFar += PieceTy.getSizeInBytes();
  }

  if (MergeTy.isValid()) {
    // Pieces are not all of the same size: break the larger ones down into
    // MergeTy-sized registers so that everything can be merged into Dst.
    SmallVector<Register, 4> UniformPieces;
    for (Register Piece : Pieces) {
      if (MRI.getType(Piece) == MergeTy) {
        UniformPieces.push_back(Piece);
        continue;
      }
      auto Unmerge = B.buildUnmerge({DstRB, MergeTy}, Piece);
      // The last operand of the unmerge is its source; the rest are results.
      for (unsigned Idx = 0; Idx < Unmerge->getNumOperands() - 1; ++Idx)
        UniformPieces.push_back(Unmerge.getReg(Idx));
    }
    B.buildMergeLikeInstr(Dst, UniformPieces);
  } else {
    // All pieces have the same size: concat or merge them directly.
    B.buildMergeLikeInstr(Dst, Pieces);
  }

  MI.eraseFromParent();
}
90+
91+
/// Replace the load \p MI with a single load of the wider type \p WideTy from
/// the same base pointer, then narrow the result back to the original
/// destination.
///
/// The wide load gets a memory operand derived from the original one at
/// offset 0 with type WideTy, and its result is created on the destination's
/// register bank.
///
/// \param MI      Load to widen; must have exactly one memory operand. It is
///                erased once the replacement has been built.
/// \param WideTy  Type of the replacement load. If it is a scalar, the result
///                is truncated into the destination.
/// \param MergeTy Only used when WideTy is not a scalar: the wide result is
///                unmerged into MergeTy pieces and the leading pieces that
///                fit the destination's size are merged back into it.
void RegBankLegalizeHelper::widenLoad(MachineInstr &MI, LLT WideTy,
                                      LLT MergeTy) {
  assert(MI.getNumMemOperands() == 1);
  MachineFunction &MF = B.getMF();
  MachineMemOperand &BaseMMO = **MI.memoperands_begin();

  Register Dst = MI.getOperand(0).getReg();
  Register Base = MI.getOperand(1).getReg();
  const RegisterBank *DstRB = MRI.getRegBankOrNull(Dst);

  MachineMemOperand *WideMMO = MF.getMachineMemOperand(&BaseMMO, 0, WideTy);
  auto WideLoad = B.buildLoad({DstRB, WideTy}, Base, *WideMMO);

  // Scalar case: dropping the high bits is all that is needed.
  if (WideTy.isScalar()) {
    B.buildTrunc(Dst, WideLoad);
    MI.eraseFromParent();
    return;
  }

  // Otherwise split the wide value into MergeTy pieces and keep only as many
  // leading pieces as the destination can hold.
  auto Unmerge = B.buildUnmerge({DstRB, MergeTy}, WideLoad);
  LLT DstTy = MRI.getType(Dst);
  unsigned NumKept = DstTy.getSizeInBits() / MergeTy.getSizeInBits();

  SmallVector<Register, 4> KeptPieces;
  for (unsigned Idx = 0; Idx < NumKept; ++Idx)
    KeptPieces.push_back(Unmerge.getReg(Idx));
  B.buildMergeLikeInstr(Dst, KeptPieces);

  MI.eraseFromParent();
}
118+
42119
void RegBankLegalizeHelper::lower(MachineInstr &MI,
43120
const RegBankLLTMapping &Mapping,
44121
SmallSet<Register, 4> &WaterfallSgprs) {
@@ -117,6 +194,54 @@ void RegBankLegalizeHelper::lower(MachineInstr &MI,
117194
MI.eraseFromParent();
118195
break;
119196
}
197+
case SplitLoad: {
198+
LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
199+
unsigned Size = DstTy.getSizeInBits();
200+
// Even split to 128-bit loads
201+
if (Size > 128) {
202+
LLT B128;
203+
if (DstTy.isVector()) {
204+
LLT EltTy = DstTy.getElementType();
205+
B128 = LLT::fixed_vector(128 / EltTy.getSizeInBits(), EltTy);
206+
} else {
207+
B128 = LLT::scalar(128);
208+
}
209+
if (Size / 128 == 2)
210+
splitLoad(MI, {B128, B128});
211+
else if (Size / 128 == 4)
212+
splitLoad(MI, {B128, B128, B128, B128});
213+
else {
214+
LLVM_DEBUG(dbgs() << "MI: "; MI.dump(););
215+
llvm_unreachable("SplitLoad type not supported for MI");
216+
}
217+
}
218+
// 64 and 32 bit load
219+
else if (DstTy == S96)
220+
splitLoad(MI, {S64, S32}, S32);
221+
else if (DstTy == V3S32)
222+
splitLoad(MI, {V2S32, S32}, S32);
223+
else if (DstTy == V6S16)
224+
splitLoad(MI, {V4S16, V2S16}, V2S16);
225+
else {
226+
LLVM_DEBUG(dbgs() << "MI: "; MI.dump(););
227+
llvm_unreachable("SplitLoad type not supported for MI");
228+
}
229+
break;
230+
}
231+
case WidenLoad: {
232+
LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
233+
if (DstTy == S96)
234+
widenLoad(MI, S128);
235+
else if (DstTy == V3S32)
236+
widenLoad(MI, V4S32, S32);
237+
else if (DstTy == V6S16)
238+
widenLoad(MI, V8S16, V2S16);
239+
else {
240+
LLVM_DEBUG(dbgs() << "MI: "; MI.dump(););
241+
llvm_unreachable("WidenLoad type not supported for MI");
242+
}
243+
break;
244+
}
120245
}
121246

122247
// TODO: executeInWaterfallLoop(... WaterfallSgprs)
@@ -140,12 +265,73 @@ LLT RegBankLegalizeHelper::getTyFromID(RegBankLLTMappingApplyID ID) {
140265
case Sgpr64:
141266
case Vgpr64:
142267
return LLT::scalar(64);
268+
case SgprP1:
269+
case VgprP1:
270+
return LLT::pointer(1, 64);
271+
case SgprP3:
272+
case VgprP3:
273+
return LLT::pointer(3, 32);
274+
case SgprP4:
275+
case VgprP4:
276+
return LLT::pointer(4, 64);
277+
case SgprP5:
278+
case VgprP5:
279+
return LLT::pointer(5, 32);
143280
case SgprV4S32:
144281
case VgprV4S32:
145282
case UniInVgprV4S32:
146283
return LLT::fixed_vector(4, 32);
147-
case VgprP1:
148-
return LLT::pointer(1, 64);
284+
default:
285+
return LLT();
286+
}
287+
}
288+
289+
LLT RegBankLegalizeHelper::getBTyFromID(RegBankLLTMappingApplyID ID, LLT Ty) {
290+
switch (ID) {
291+
case SgprB32:
292+
case VgprB32:
293+
case UniInVgprB32:
294+
if (Ty == LLT::scalar(32) || Ty == LLT::fixed_vector(2, 16) ||
295+
Ty == LLT::pointer(3, 32) || Ty == LLT::pointer(5, 32) ||
296+
Ty == LLT::pointer(6, 32))
297+
return Ty;
298+
return LLT();
299+
case SgprB64:
300+
case VgprB64:
301+
case UniInVgprB64:
302+
if (Ty == LLT::scalar(64) || Ty == LLT::fixed_vector(2, 32) ||
303+
Ty == LLT::fixed_vector(4, 16) || Ty == LLT::pointer(0, 64) ||
304+
Ty == LLT::pointer(1, 64) || Ty == LLT::pointer(4, 64))
305+
return Ty;
306+
return LLT();
307+
case SgprB96:
308+
case VgprB96:
309+
case UniInVgprB96:
310+
if (Ty == LLT::scalar(96) || Ty == LLT::fixed_vector(3, 32) ||
311+
Ty == LLT::fixed_vector(6, 16))
312+
return Ty;
313+
return LLT();
314+
case SgprB128:
315+
case VgprB128:
316+
case UniInVgprB128:
317+
if (Ty == LLT::scalar(128) || Ty == LLT::fixed_vector(4, 32) ||
318+
Ty == LLT::fixed_vector(2, 64))
319+
return Ty;
320+
return LLT();
321+
case SgprB256:
322+
case VgprB256:
323+
case UniInVgprB256:
324+
if (Ty == LLT::scalar(256) || Ty == LLT::fixed_vector(8, 32) ||
325+
Ty == LLT::fixed_vector(4, 64) || Ty == LLT::fixed_vector(16, 16))
326+
return Ty;
327+
return LLT();
328+
case SgprB512:
329+
case VgprB512:
330+
case UniInVgprB512:
331+
if (Ty == LLT::scalar(512) || Ty == LLT::fixed_vector(16, 32) ||
332+
Ty == LLT::fixed_vector(8, 64))
333+
return Ty;
334+
return LLT();
149335
default:
150336
return LLT();
151337
}
@@ -159,10 +345,26 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
159345
case Sgpr16:
160346
case Sgpr32:
161347
case Sgpr64:
348+
case SgprP1:
349+
case SgprP3:
350+
case SgprP4:
351+
case SgprP5:
162352
case SgprV4S32:
353+
case SgprB32:
354+
case SgprB64:
355+
case SgprB96:
356+
case SgprB128:
357+
case SgprB256:
358+
case SgprB512:
163359
case UniInVcc:
164360
case UniInVgprS32:
165361
case UniInVgprV4S32:
362+
case UniInVgprB32:
363+
case UniInVgprB64:
364+
case UniInVgprB96:
365+
case UniInVgprB128:
366+
case UniInVgprB256:
367+
case UniInVgprB512:
166368
case Sgpr32Trunc:
167369
case Sgpr32AExt:
168370
case Sgpr32AExtBoolInReg:
@@ -171,7 +373,16 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
171373
case Vgpr32:
172374
case Vgpr64:
173375
case VgprP1:
376+
case VgprP3:
377+
case VgprP4:
378+
case VgprP5:
174379
case VgprV4S32:
380+
case VgprB32:
381+
case VgprB64:
382+
case VgprB96:
383+
case VgprB128:
384+
case VgprB256:
385+
case VgprB512:
175386
return VgprRB;
176387
default:
177388
return nullptr;
@@ -196,16 +407,40 @@ void RegBankLegalizeHelper::applyMappingDst(
196407
case Sgpr16:
197408
case Sgpr32:
198409
case Sgpr64:
410+
case SgprP1:
411+
case SgprP3:
412+
case SgprP4:
413+
case SgprP5:
199414
case SgprV4S32:
200415
case Vgpr32:
201416
case Vgpr64:
202417
case VgprP1:
418+
case VgprP3:
419+
case VgprP4:
420+
case VgprP5:
203421
case VgprV4S32: {
204422
assert(Ty == getTyFromID(MethodIDs[OpIdx]));
205423
assert(RB == getRegBankFromID(MethodIDs[OpIdx]));
206424
break;
207425
}
208-
// uniform in vcc/vgpr: scalars and vectors
426+
// sgpr and vgpr B-types
427+
case SgprB32:
428+
case SgprB64:
429+
case SgprB96:
430+
case SgprB128:
431+
case SgprB256:
432+
case SgprB512:
433+
case VgprB32:
434+
case VgprB64:
435+
case VgprB96:
436+
case VgprB128:
437+
case VgprB256:
438+
case VgprB512: {
439+
assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty));
440+
assert(RB == getRegBankFromID(MethodIDs[OpIdx]));
441+
break;
442+
}
443+
// uniform in vcc/vgpr: scalars, vectors and B-types
209444
case UniInVcc: {
210445
assert(Ty == S1);
211446
assert(RB == SgprRB);
@@ -225,6 +460,19 @@ void RegBankLegalizeHelper::applyMappingDst(
225460
buildReadAnyLane(B, Reg, NewVgprDst, RBI);
226461
break;
227462
}
463+
case UniInVgprB32:
464+
case UniInVgprB64:
465+
case UniInVgprB96:
466+
case UniInVgprB128:
467+
case UniInVgprB256:
468+
case UniInVgprB512: {
469+
assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty));
470+
assert(RB == SgprRB);
471+
Register NewVgprDst = MRI.createVirtualRegister({VgprRB, Ty});
472+
Op.setReg(NewVgprDst);
473+
AMDGPU::buildReadAnyLane(B, Reg, NewVgprDst, RBI);
474+
break;
475+
}
228476
// sgpr trunc
229477
case Sgpr32Trunc: {
230478
assert(Ty.getSizeInBits() < 32);
@@ -273,15 +521,33 @@ void RegBankLegalizeHelper::applyMappingSrc(
273521
case Sgpr16:
274522
case Sgpr32:
275523
case Sgpr64:
524+
case SgprP1:
525+
case SgprP3:
526+
case SgprP4:
527+
case SgprP5:
276528
case SgprV4S32: {
277529
assert(Ty == getTyFromID(MethodIDs[i]));
278530
assert(RB == getRegBankFromID(MethodIDs[i]));
279531
break;
280532
}
533+
// sgpr B-types
534+
case SgprB32:
535+
case SgprB64:
536+
case SgprB96:
537+
case SgprB128:
538+
case SgprB256:
539+
case SgprB512: {
540+
assert(Ty == getBTyFromID(MethodIDs[i], Ty));
541+
assert(RB == getRegBankFromID(MethodIDs[i]));
542+
break;
543+
}
281544
// vgpr scalars, pointers and vectors
282545
case Vgpr32:
283546
case Vgpr64:
284547
case VgprP1:
548+
case VgprP3:
549+
case VgprP4:
550+
case VgprP5:
285551
case VgprV4S32: {
286552
assert(Ty == getTyFromID(MethodIDs[i]));
287553
if (RB != VgprRB) {
@@ -290,6 +556,20 @@ void RegBankLegalizeHelper::applyMappingSrc(
290556
}
291557
break;
292558
}
559+
// vgpr B-types
560+
case VgprB32:
561+
case VgprB64:
562+
case VgprB96:
563+
case VgprB128:
564+
case VgprB256:
565+
case VgprB512: {
566+
assert(Ty == getBTyFromID(MethodIDs[i], Ty));
567+
if (RB != VgprRB) {
568+
auto CopyToVgpr = B.buildCopy({VgprRB, Ty}, Reg);
569+
Op.setReg(CopyToVgpr.getReg(0));
570+
}
571+
break;
572+
}
293573
// sgpr and vgpr scalars with extend
294574
case Sgpr32AExt: {
295575
// Note: this ext allows S1, and it is meant to be combined away.
@@ -362,7 +642,7 @@ void RegBankLegalizeHelper::applyMappingPHI(MachineInstr &MI) {
362642
// We accept all types that can fit in some register class.
363643
// Uniform G_PHIs have all sgpr registers.
364644
// Divergent G_PHIs have vgpr dst but inputs can be sgpr or vgpr.
365-
if (Ty == LLT::scalar(32)) {
645+
if (Ty == LLT::scalar(32) || Ty == LLT::pointer(4, 64)) {
366646
return;
367647
}
368648

0 commit comments

Comments
 (0)