Commit 78ffc92

AMDGPU/GlobalISel: RegBankLegalize rules for load
Add IDs for bit widths that cover multiple LLTs: B32, B64, etc. Add a "Predicate" wrapper class for bool predicate functions, used to write readable rules; predicates can be combined using &&, || and !. Add lowering for splitting and widening loads. Write the rules for loads so that existing MIR tests from the old regbankselect are not changed.
1 parent d7060d4 commit 78ffc92
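
The "Predicate" wrapper mentioned in the message lives in the rules file, which is not part of the excerpt below. As a rough illustration of the idea — a bool-returning check that can be composed with &&, || and ! — a minimal standalone C++ sketch might look like this (the Query struct, all names, and main() are invented for this example and are not the committed code):

// Minimal sketch of a composable predicate wrapper; illustrative only.
#include <functional>
#include <iostream>
#include <utility>

struct Query {
  unsigned SizeInBits; // size of the loaded value (hypothetical field)
  bool IsUniform;      // uniform vs. divergent access (hypothetical field)
};

class Predicate {
  std::function<bool(const Query &)> Pred;

public:
  Predicate(std::function<bool(const Query &)> P) : Pred(std::move(P)) {}

  bool operator()(const Query &Q) const { return Pred(Q); }

  // Combining two predicates produces a new predicate, so nested
  // &&/||/! expressions stay readable at the rule-definition site.
  Predicate operator&&(const Predicate &RHS) const {
    return Predicate(
        [LHS = *this, RHS](const Query &Q) { return LHS(Q) && RHS(Q); });
  }
  Predicate operator||(const Predicate &RHS) const {
    return Predicate(
        [LHS = *this, RHS](const Query &Q) { return LHS(Q) || RHS(Q); });
  }
  Predicate operator!() const {
    return Predicate([P = *this](const Query &Q) { return !P(Q); });
  }
};

int main() {
  Predicate isUniform([](const Query &Q) { return Q.IsUniform; });
  Predicate fitsIn128([](const Query &Q) { return Q.SizeInBits <= 128; });

  Predicate needsSplit = !fitsIn128;               // e.g. a 256-bit load
  Predicate uniformSmall = isUniform && fitsIn128; // single uniform load

  std::cout << needsSplit({256, true}) << '\n';  // prints 1
  std::cout << uniformSmall({64, true}) << '\n'; // prints 1
}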

File tree

6 files changed: +900 -65 lines changed

llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp

Lines changed: 284 additions & 4 deletions
@@ -38,6 +38,83 @@ void RegBankLegalizeHelper::findRuleAndApplyMapping(MachineInstr &MI) {
   lower(MI, Mapping, WaterfallSgprs);
 }
 
+void RegBankLegalizeHelper::splitLoad(MachineInstr &MI,
+                                      ArrayRef<LLT> LLTBreakdown, LLT MergeTy) {
+  MachineFunction &MF = B.getMF();
+  assert(MI.getNumMemOperands() == 1);
+  MachineMemOperand &BaseMMO = **MI.memoperands_begin();
+  Register Dst = MI.getOperand(0).getReg();
+  const RegisterBank *DstRB = MRI.getRegBankOrNull(Dst);
+  Register Base = MI.getOperand(1).getReg();
+  LLT PtrTy = MRI.getType(Base);
+  const RegisterBank *PtrRB = MRI.getRegBankOrNull(Base);
+  LLT OffsetTy = LLT::scalar(PtrTy.getSizeInBits());
+  SmallVector<Register, 4> LoadPartRegs;
+
+  unsigned ByteOffset = 0;
+  for (LLT PartTy : LLTBreakdown) {
+    Register BasePlusOffset;
+    if (ByteOffset == 0) {
+      BasePlusOffset = Base;
+    } else {
+      auto Offset = B.buildConstant({PtrRB, OffsetTy}, ByteOffset);
+      BasePlusOffset = B.buildPtrAdd({PtrRB, PtrTy}, Base, Offset).getReg(0);
+    }
+    auto *OffsetMMO = MF.getMachineMemOperand(&BaseMMO, ByteOffset, PartTy);
+    auto LoadPart = B.buildLoad({DstRB, PartTy}, BasePlusOffset, *OffsetMMO);
+    LoadPartRegs.push_back(LoadPart.getReg(0));
+    ByteOffset += PartTy.getSizeInBytes();
+  }
+
+  if (!MergeTy.isValid()) {
+    // Loads are of same size, concat or merge them together.
+    B.buildMergeLikeInstr(Dst, LoadPartRegs);
+  } else {
+    // Loads are not all of same size, need to unmerge them to smaller pieces
+    // of MergeTy type, then merge pieces to Dst.
+    SmallVector<Register, 4> MergeTyParts;
+    for (Register Reg : LoadPartRegs) {
+      if (MRI.getType(Reg) == MergeTy) {
+        MergeTyParts.push_back(Reg);
+      } else {
+        auto Unmerge = B.buildUnmerge({DstRB, MergeTy}, Reg);
+        for (unsigned i = 0; i < Unmerge->getNumOperands() - 1; ++i)
+          MergeTyParts.push_back(Unmerge.getReg(i));
+      }
+    }
+    B.buildMergeLikeInstr(Dst, MergeTyParts);
+  }
+  MI.eraseFromParent();
+}
+
+void RegBankLegalizeHelper::widenLoad(MachineInstr &MI, LLT WideTy,
+                                      LLT MergeTy) {
+  MachineFunction &MF = B.getMF();
+  assert(MI.getNumMemOperands() == 1);
+  MachineMemOperand &BaseMMO = **MI.memoperands_begin();
+  Register Dst = MI.getOperand(0).getReg();
+  const RegisterBank *DstRB = MRI.getRegBankOrNull(Dst);
+  Register Base = MI.getOperand(1).getReg();
+
+  MachineMemOperand *WideMMO = MF.getMachineMemOperand(&BaseMMO, 0, WideTy);
+  auto WideLoad = B.buildLoad({DstRB, WideTy}, Base, *WideMMO);
+
+  if (WideTy.isScalar()) {
+    B.buildTrunc(Dst, WideLoad);
+  } else {
+    SmallVector<Register, 4> MergeTyParts;
+    auto Unmerge = B.buildUnmerge({DstRB, MergeTy}, WideLoad);
+
+    LLT DstTy = MRI.getType(Dst);
+    unsigned NumElts = DstTy.getSizeInBits() / MergeTy.getSizeInBits();
+    for (unsigned i = 0; i < NumElts; ++i) {
+      MergeTyParts.push_back(Unmerge.getReg(i));
+    }
+    B.buildMergeLikeInstr(Dst, MergeTyParts);
+  }
+  MI.eraseFromParent();
+}
+
 void RegBankLegalizeHelper::lower(MachineInstr &MI,
                                   const RegBankLLTMapping &Mapping,
                                   SmallSet<Register, 4> &WaterfallSgprs) {
@@ -116,6 +193,54 @@ void RegBankLegalizeHelper::lower(MachineInstr &MI,
     MI.eraseFromParent();
     break;
   }
+  case SplitLoad: {
+    LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
+    unsigned Size = DstTy.getSizeInBits();
+    // Even split to 128-bit loads
+    if (Size > 128) {
+      LLT B128;
+      if (DstTy.isVector()) {
+        LLT EltTy = DstTy.getElementType();
+        B128 = LLT::fixed_vector(128 / EltTy.getSizeInBits(), EltTy);
+      } else {
+        B128 = LLT::scalar(128);
+      }
+      if (Size / 128 == 2)
+        splitLoad(MI, {B128, B128});
+      else if (Size / 128 == 4)
+        splitLoad(MI, {B128, B128, B128, B128});
+      else {
+        LLVM_DEBUG(dbgs() << "MI: "; MI.dump(););
+        llvm_unreachable("SplitLoad type not supported for MI");
+      }
+    }
+    // 64 and 32 bit load
+    else if (DstTy == S96)
+      splitLoad(MI, {S64, S32}, S32);
+    else if (DstTy == V3S32)
+      splitLoad(MI, {V2S32, S32}, S32);
+    else if (DstTy == V6S16)
+      splitLoad(MI, {V4S16, V2S16}, V2S16);
+    else {
+      LLVM_DEBUG(dbgs() << "MI: "; MI.dump(););
+      llvm_unreachable("SplitLoad type not supported for MI");
+    }
+    break;
+  }
+  case WidenLoad: {
+    LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
+    if (DstTy == S96)
+      widenLoad(MI, S128);
+    else if (DstTy == V3S32)
+      widenLoad(MI, V4S32, S32);
+    else if (DstTy == V6S16)
+      widenLoad(MI, V8S16, V2S16);
+    else {
+      LLVM_DEBUG(dbgs() << "MI: "; MI.dump(););
+      llvm_unreachable("WidenLoad type not supported for MI");
+    }
+    break;
+  }
   }
 
   // TODO: executeInWaterfallLoop(... WaterfallSgprs)
@@ -139,12 +264,73 @@ LLT RegBankLegalizeHelper::getTyFromID(RegBankLLTMappingApplyID ID) {
   case Sgpr64:
   case Vgpr64:
     return LLT::scalar(64);
+  case SgprP1:
+  case VgprP1:
+    return LLT::pointer(1, 64);
+  case SgprP3:
+  case VgprP3:
+    return LLT::pointer(3, 32);
+  case SgprP4:
+  case VgprP4:
+    return LLT::pointer(4, 64);
+  case SgprP5:
+  case VgprP5:
+    return LLT::pointer(5, 32);
   case SgprV4S32:
   case VgprV4S32:
   case UniInVgprV4S32:
     return LLT::fixed_vector(4, 32);
-  case VgprP1:
-    return LLT::pointer(1, 64);
+  default:
+    return LLT();
+  }
+}
+
+LLT RegBankLegalizeHelper::getBTyFromID(RegBankLLTMappingApplyID ID, LLT Ty) {
+  switch (ID) {
+  case SgprB32:
+  case VgprB32:
+  case UniInVgprB32:
+    if (Ty == LLT::scalar(32) || Ty == LLT::fixed_vector(2, 16) ||
+        Ty == LLT::pointer(3, 32) || Ty == LLT::pointer(5, 32) ||
+        Ty == LLT::pointer(6, 32))
+      return Ty;
+    return LLT();
+  case SgprB64:
+  case VgprB64:
+  case UniInVgprB64:
+    if (Ty == LLT::scalar(64) || Ty == LLT::fixed_vector(2, 32) ||
+        Ty == LLT::fixed_vector(4, 16) || Ty == LLT::pointer(0, 64) ||
+        Ty == LLT::pointer(1, 64) || Ty == LLT::pointer(4, 64))
+      return Ty;
+    return LLT();
+  case SgprB96:
+  case VgprB96:
+  case UniInVgprB96:
+    if (Ty == LLT::scalar(96) || Ty == LLT::fixed_vector(3, 32) ||
+        Ty == LLT::fixed_vector(6, 16))
+      return Ty;
+    return LLT();
+  case SgprB128:
+  case VgprB128:
+  case UniInVgprB128:
+    if (Ty == LLT::scalar(128) || Ty == LLT::fixed_vector(4, 32) ||
+        Ty == LLT::fixed_vector(2, 64))
+      return Ty;
+    return LLT();
+  case SgprB256:
+  case VgprB256:
+  case UniInVgprB256:
+    if (Ty == LLT::scalar(256) || Ty == LLT::fixed_vector(8, 32) ||
+        Ty == LLT::fixed_vector(4, 64) || Ty == LLT::fixed_vector(16, 16))
+      return Ty;
+    return LLT();
+  case SgprB512:
+  case VgprB512:
+  case UniInVgprB512:
+    if (Ty == LLT::scalar(512) || Ty == LLT::fixed_vector(16, 32) ||
+        Ty == LLT::fixed_vector(8, 64))
+      return Ty;
+    return LLT();
   default:
     return LLT();
   }
@@ -158,10 +344,26 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
   case Sgpr16:
   case Sgpr32:
   case Sgpr64:
+  case SgprP1:
+  case SgprP3:
+  case SgprP4:
+  case SgprP5:
   case SgprV4S32:
+  case SgprB32:
+  case SgprB64:
+  case SgprB96:
+  case SgprB128:
+  case SgprB256:
+  case SgprB512:
   case UniInVcc:
   case UniInVgprS32:
   case UniInVgprV4S32:
+  case UniInVgprB32:
+  case UniInVgprB64:
+  case UniInVgprB96:
+  case UniInVgprB128:
+  case UniInVgprB256:
+  case UniInVgprB512:
   case Sgpr32Trunc:
   case Sgpr32AExt:
   case Sgpr32AExtBoolInReg:
@@ -170,7 +372,16 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
   case Vgpr32:
   case Vgpr64:
   case VgprP1:
+  case VgprP3:
+  case VgprP4:
+  case VgprP5:
   case VgprV4S32:
+  case VgprB32:
+  case VgprB64:
+  case VgprB96:
+  case VgprB128:
+  case VgprB256:
+  case VgprB512:
     return VgprRB;
   default:
     return nullptr;
@@ -195,16 +406,40 @@ void RegBankLegalizeHelper::applyMappingDst(
     case Sgpr16:
     case Sgpr32:
     case Sgpr64:
+    case SgprP1:
+    case SgprP3:
+    case SgprP4:
+    case SgprP5:
     case SgprV4S32:
     case Vgpr32:
     case Vgpr64:
     case VgprP1:
+    case VgprP3:
+    case VgprP4:
+    case VgprP5:
     case VgprV4S32: {
       assert(Ty == getTyFromID(MethodIDs[OpIdx]));
       assert(RB == getRegBankFromID(MethodIDs[OpIdx]));
       break;
     }
-    // uniform in vcc/vgpr: scalars and vectors
+    // sgpr and vgpr B-types
+    case SgprB32:
+    case SgprB64:
+    case SgprB96:
+    case SgprB128:
+    case SgprB256:
+    case SgprB512:
+    case VgprB32:
+    case VgprB64:
+    case VgprB96:
+    case VgprB128:
+    case VgprB256:
+    case VgprB512: {
+      assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty));
+      assert(RB == getRegBankFromID(MethodIDs[OpIdx]));
+      break;
+    }
+    // uniform in vcc/vgpr: scalars, vectors and B-types
     case UniInVcc: {
       assert(Ty == S1);
       assert(RB == SgprRB);
@@ -224,6 +459,19 @@ void RegBankLegalizeHelper::applyMappingDst(
       buildReadAnyLane(B, Reg, NewVgprDst, RBI);
       break;
     }
+    case UniInVgprB32:
+    case UniInVgprB64:
+    case UniInVgprB96:
+    case UniInVgprB128:
+    case UniInVgprB256:
+    case UniInVgprB512: {
+      assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty));
+      assert(RB == SgprRB);
+      Register NewVgprDst = MRI.createVirtualRegister({VgprRB, Ty});
+      Op.setReg(NewVgprDst);
+      AMDGPU::buildReadAnyLane(B, Reg, NewVgprDst, RBI);
+      break;
+    }
     // sgpr trunc
     case Sgpr32Trunc: {
       assert(Ty.getSizeInBits() < 32);
@@ -272,15 +520,33 @@ void RegBankLegalizeHelper::applyMappingSrc(
     case Sgpr16:
     case Sgpr32:
     case Sgpr64:
+    case SgprP1:
+    case SgprP3:
+    case SgprP4:
+    case SgprP5:
     case SgprV4S32: {
       assert(Ty == getTyFromID(MethodIDs[i]));
       assert(RB == getRegBankFromID(MethodIDs[i]));
       break;
     }
+    // sgpr B-types
+    case SgprB32:
+    case SgprB64:
+    case SgprB96:
+    case SgprB128:
+    case SgprB256:
+    case SgprB512: {
+      assert(Ty == getBTyFromID(MethodIDs[i], Ty));
+      assert(RB == getRegBankFromID(MethodIDs[i]));
+      break;
+    }
     // vgpr scalars, pointers and vectors
     case Vgpr32:
     case Vgpr64:
     case VgprP1:
+    case VgprP3:
+    case VgprP4:
+    case VgprP5:
     case VgprV4S32: {
       assert(Ty == getTyFromID(MethodIDs[i]));
       if (RB != VgprRB) {
@@ -289,6 +555,20 @@ void RegBankLegalizeHelper::applyMappingSrc(
       }
       break;
     }
+    // vgpr B-types
+    case VgprB32:
+    case VgprB64:
+    case VgprB96:
+    case VgprB128:
+    case VgprB256:
+    case VgprB512: {
+      assert(Ty == getBTyFromID(MethodIDs[i], Ty));
+      if (RB != VgprRB) {
+        auto CopyToVgpr = B.buildCopy({VgprRB, Ty}, Reg);
+        Op.setReg(CopyToVgpr.getReg(0));
+      }
+      break;
+    }
     // sgpr and vgpr scalars with extend
     case Sgpr32AExt: {
       // Note: this ext allows S1, and it is meant to be combined away.
@@ -361,7 +641,7 @@ void RegBankLegalizeHelper::applyMappingPHI(MachineInstr &MI) {
   // We accept all types that can fit in some register class.
   // Uniform G_PHIs have all sgpr registers.
   // Divergent G_PHIs have vgpr dst but inputs can be sgpr or vgpr.
-  if (Ty == LLT::scalar(32)) {
+  if (Ty == LLT::scalar(32) || Ty == LLT::pointer(4, 64)) {
     return;
   }
 
0 commit comments