Skip to content

Commit cb28f82

Browse files
AMDGPU/GlobalISel: add RegBankLegalize rules for bit shifts and sext-inreg
Uniform S16 shifts have to be extended to S32 using appropriate Extend before lowering to S32 instruction. Uniform packed V2S16 are lowered to SGPR S32 instructions, other option is to use VALU packed V2S16 and ReadAnyLane. For uniform S32 and S64 and divergent S16, S32, S64 and V2S16 there are instructions available.
1 parent 2c7c01d commit cb28f82

13 files changed

+372
-165
lines changed

llvm/lib/Target/AMDGPU/AMDGPURegBankLegalize.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ bool AMDGPURegBankLegalize::runOnMachineFunction(MachineFunction &MF) {
310310
// Opcodes that support pretty much all combinations of reg banks and LLTs
311311
// (except S1). There is no point in writing rules for them.
312312
if (Opc == AMDGPU::G_BUILD_VECTOR || Opc == AMDGPU::G_UNMERGE_VALUES ||
313-
Opc == AMDGPU::G_MERGE_VALUES) {
313+
Opc == AMDGPU::G_MERGE_VALUES || Opc == G_BITCAST) {
314314
RBLHelper.applyMappingTrivial(*MI);
315315
continue;
316316
}

llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,59 @@ void RegBankLegalizeHelper::lowerVccExtToSel(MachineInstr &MI) {
171171
MI.eraseFromParent();
172172
}
173173

174+
std::pair<Register, Register> RegBankLegalizeHelper::unpackZExt(Register Reg) {
175+
auto PackedS32 = B.buildBitcast(SgprRB_S32, Reg);
176+
auto Mask = B.buildConstant(SgprRB_S32, 0x0000ffff);
177+
auto Lo = B.buildAnd(SgprRB_S32, PackedS32, Mask);
178+
auto Hi = B.buildLShr(SgprRB_S32, PackedS32, B.buildConstant(SgprRB_S32, 16));
179+
return {Lo.getReg(0), Hi.getReg(0)};
180+
}
181+
182+
std::pair<Register, Register> RegBankLegalizeHelper::unpackSExt(Register Reg) {
183+
auto PackedS32 = B.buildBitcast(SgprRB_S32, Reg);
184+
auto Lo = B.buildSExtInReg(SgprRB_S32, PackedS32, 16);
185+
auto Hi = B.buildAShr(SgprRB_S32, PackedS32, B.buildConstant(SgprRB_S32, 16));
186+
return {Lo.getReg(0), Hi.getReg(0)};
187+
}
188+
189+
std::pair<Register, Register> RegBankLegalizeHelper::unpackAExt(Register Reg) {
190+
auto PackedS32 = B.buildBitcast(SgprRB_S32, Reg);
191+
auto Lo = PackedS32;
192+
auto Hi = B.buildLShr(SgprRB_S32, PackedS32, B.buildConstant(SgprRB_S32, 16));
193+
return {Lo.getReg(0), Hi.getReg(0)};
194+
}
195+
196+
void RegBankLegalizeHelper::lowerUnpack(MachineInstr &MI) {
197+
Register Lo, Hi;
198+
switch (MI.getOpcode()) {
199+
case AMDGPU::G_SHL: {
200+
auto [Val0, Val1] = unpackAExt(MI.getOperand(1).getReg());
201+
auto [Amt0, Amt1] = unpackAExt(MI.getOperand(2).getReg());
202+
Lo = B.buildInstr(MI.getOpcode(), {SgprRB_S32}, {Val0, Amt0}).getReg(0);
203+
Hi = B.buildInstr(MI.getOpcode(), {SgprRB_S32}, {Val1, Amt1}).getReg(0);
204+
break;
205+
}
206+
case AMDGPU::G_LSHR: {
207+
auto [Val0, Val1] = unpackZExt(MI.getOperand(1).getReg());
208+
auto [Amt0, Amt1] = unpackZExt(MI.getOperand(2).getReg());
209+
Lo = B.buildInstr(MI.getOpcode(), {SgprRB_S32}, {Val0, Amt0}).getReg(0);
210+
Hi = B.buildInstr(MI.getOpcode(), {SgprRB_S32}, {Val1, Amt1}).getReg(0);
211+
break;
212+
}
213+
case AMDGPU::G_ASHR: {
214+
auto [Val0, Val1] = unpackSExt(MI.getOperand(1).getReg());
215+
auto [Amt0, Amt1] = unpackSExt(MI.getOperand(2).getReg());
216+
Lo = B.buildAShr(SgprRB_S32, Val0, Amt0).getReg(0);
217+
Hi = B.buildAShr(SgprRB_S32, Val1, Amt1).getReg(0);
218+
break;
219+
}
220+
default:
221+
llvm_unreachable("Unpack lowering not implemented");
222+
}
223+
B.buildBuildVectorTrunc(MI.getOperand(0).getReg(), {Lo, Hi});
224+
MI.eraseFromParent();
225+
}
226+
174227
static bool isSignedBFE(MachineInstr &MI) {
175228
if (GIntrinsic *GI = dyn_cast<GIntrinsic>(&MI))
176229
return (GI->is(Intrinsic::amdgcn_sbfe));
@@ -306,6 +359,33 @@ void RegBankLegalizeHelper::lowerSplitTo32Sel(MachineInstr &MI) {
306359
MI.eraseFromParent();
307360
}
308361

362+
void RegBankLegalizeHelper::lowerSplitTo32SExtInReg(MachineInstr &MI) {
363+
auto Op1 = B.buildUnmerge(VgprRB_S32, MI.getOperand(1).getReg());
364+
int Amt = MI.getOperand(2).getImm();
365+
Register Lo, Hi;
366+
// Hi|Lo: s sign bit, ?/x bits changed/not changed by sign-extend
367+
if (Amt <= 32) {
368+
auto Freeze = B.buildFreeze(VgprRB_S32, Op1.getReg(0));
369+
if (Amt == 32) {
370+
// Hi|Lo: ????????|sxxxxxxx -> ssssssss|sxxxxxxx
371+
Lo = Freeze.getReg(0);
372+
} else {
373+
// Hi|Lo: ????????|???sxxxx -> ssssssss|ssssxxxx
374+
Lo = B.buildSExtInReg(VgprRB_S32, Freeze, Amt).getReg(0);
375+
}
376+
377+
auto SignExtCst = B.buildConstant(SgprRB_S32, 31);
378+
Hi = B.buildAShr(VgprRB_S32, Lo, SignExtCst).getReg(0);
379+
} else {
380+
// Hi|Lo: ?????sxx|xxxxxxxx -> ssssssxx|xxxxxxxx
381+
Lo = Op1.getReg(0);
382+
Hi = B.buildSExtInReg(VgprRB_S32, Op1.getReg(1), Amt - 32).getReg(0);
383+
}
384+
385+
B.buildMergeLikeInstr(MI.getOperand(0).getReg(), {Lo, Hi});
386+
MI.eraseFromParent();
387+
}
388+
309389
void RegBankLegalizeHelper::lower(MachineInstr &MI,
310390
const RegBankLLTMapping &Mapping,
311391
SmallSet<Register, 4> &WaterfallSgprs) {
@@ -328,6 +408,8 @@ void RegBankLegalizeHelper::lower(MachineInstr &MI,
328408
MI.eraseFromParent();
329409
return;
330410
}
411+
case Unpack:
412+
return lowerUnpack(MI);
331413
case Ext32To64: {
332414
const RegisterBank *RB = MRI.getRegBank(MI.getOperand(0).getReg());
333415
MachineInstrBuilder Hi;
@@ -394,6 +476,8 @@ void RegBankLegalizeHelper::lower(MachineInstr &MI,
394476
return lowerSplitTo32(MI);
395477
case SplitTo32Sel:
396478
return lowerSplitTo32Sel(MI);
479+
case SplitTo32SExtInReg:
480+
return lowerSplitTo32SExtInReg(MI);
397481
case SplitLoad: {
398482
LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
399483
unsigned Size = DstTy.getSizeInBits();
@@ -483,6 +567,13 @@ LLT RegBankLegalizeHelper::getTyFromID(RegBankLLTMappingApplyID ID) {
483567
case SgprP5:
484568
case VgprP5:
485569
return LLT::pointer(5, 32);
570+
case SgprV2S16:
571+
case VgprV2S16:
572+
case UniInVgprV2S16:
573+
return LLT::fixed_vector(2, 16);
574+
case SgprV2S32:
575+
case VgprV2S32:
576+
return LLT::fixed_vector(2, 32);
486577
case SgprV4S32:
487578
case VgprV4S32:
488579
case UniInVgprV4S32:
@@ -556,6 +647,8 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
556647
case SgprP3:
557648
case SgprP4:
558649
case SgprP5:
650+
case SgprV2S16:
651+
case SgprV2S32:
559652
case SgprV4S32:
560653
case SgprB32:
561654
case SgprB64:
@@ -565,6 +658,7 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
565658
case SgprB512:
566659
case UniInVcc:
567660
case UniInVgprS32:
661+
case UniInVgprV2S16:
568662
case UniInVgprV4S32:
569663
case UniInVgprB32:
570664
case UniInVgprB64:
@@ -586,6 +680,8 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
586680
case VgprP3:
587681
case VgprP4:
588682
case VgprP5:
683+
case VgprV2S16:
684+
case VgprV2S32:
589685
case VgprV4S32:
590686
case VgprB32:
591687
case VgprB64:
@@ -623,6 +719,8 @@ void RegBankLegalizeHelper::applyMappingDst(
623719
case SgprP3:
624720
case SgprP4:
625721
case SgprP5:
722+
case SgprV2S16:
723+
case SgprV2S32:
626724
case SgprV4S32:
627725
case Vgpr16:
628726
case Vgpr32:
@@ -632,6 +730,8 @@ void RegBankLegalizeHelper::applyMappingDst(
632730
case VgprP3:
633731
case VgprP4:
634732
case VgprP5:
733+
case VgprV2S16:
734+
case VgprV2S32:
635735
case VgprV4S32: {
636736
assert(Ty == getTyFromID(MethodIDs[OpIdx]));
637737
assert(RB == getRegBankFromID(MethodIDs[OpIdx]));
@@ -666,6 +766,7 @@ void RegBankLegalizeHelper::applyMappingDst(
666766
break;
667767
}
668768
case UniInVgprS32:
769+
case UniInVgprV2S16:
669770
case UniInVgprV4S32: {
670771
assert(Ty == getTyFromID(MethodIDs[OpIdx]));
671772
assert(RB == SgprRB);
@@ -739,6 +840,8 @@ void RegBankLegalizeHelper::applyMappingSrc(
739840
case SgprP3:
740841
case SgprP4:
741842
case SgprP5:
843+
case SgprV2S16:
844+
case SgprV2S32:
742845
case SgprV4S32: {
743846
assert(Ty == getTyFromID(MethodIDs[i]));
744847
assert(RB == getRegBankFromID(MethodIDs[i]));
@@ -764,6 +867,8 @@ void RegBankLegalizeHelper::applyMappingSrc(
764867
case VgprP3:
765868
case VgprP4:
766869
case VgprP5:
870+
case VgprV2S16:
871+
case VgprV2S32:
767872
case VgprV4S32: {
768873
assert(Ty == getTyFromID(MethodIDs[i]));
769874
if (RB != VgprRB) {

llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,15 @@ class RegBankLegalizeHelper {
111111
SmallSet<Register, 4> &SgprWaterfallOperandRegs);
112112

113113
void lowerVccExtToSel(MachineInstr &MI);
114+
std::pair<Register, Register> unpackZExt(Register Reg);
115+
std::pair<Register, Register> unpackSExt(Register Reg);
116+
std::pair<Register, Register> unpackAExt(Register Reg);
117+
void lowerUnpack(MachineInstr &MI);
114118
void lowerV_BFE(MachineInstr &MI);
115119
void lowerS_BFE(MachineInstr &MI);
116120
void lowerSplitTo32(MachineInstr &MI);
117121
void lowerSplitTo32Sel(MachineInstr &MI);
122+
void lowerSplitTo32SExtInReg(MachineInstr &MI);
118123
};
119124

120125
} // end namespace AMDGPU

llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
6060
return MRI.getType(Reg) == LLT::pointer(4, 64);
6161
case P5:
6262
return MRI.getType(Reg) == LLT::pointer(5, 32);
63+
case V2S32:
64+
return MRI.getType(Reg) == LLT::fixed_vector(2, 32);
6365
case V4S32:
6466
return MRI.getType(Reg) == LLT::fixed_vector(4, 32);
6567
case B32:
@@ -92,6 +94,8 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
9294
return MRI.getType(Reg) == LLT::pointer(4, 64) && MUI.isUniform(Reg);
9395
case UniP5:
9496
return MRI.getType(Reg) == LLT::pointer(5, 32) && MUI.isUniform(Reg);
97+
case UniV2S16:
98+
return MRI.getType(Reg) == LLT::fixed_vector(2, 16) && MUI.isUniform(Reg);
9599
case UniB32:
96100
return MRI.getType(Reg).getSizeInBits() == 32 && MUI.isUniform(Reg);
97101
case UniB64:
@@ -122,6 +126,8 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
122126
return MRI.getType(Reg) == LLT::pointer(4, 64) && MUI.isDivergent(Reg);
123127
case DivP5:
124128
return MRI.getType(Reg) == LLT::pointer(5, 32) && MUI.isDivergent(Reg);
129+
case DivV2S16:
130+
return MRI.getType(Reg) == LLT::fixed_vector(2, 16) && MUI.isDivergent(Reg);
125131
case DivB32:
126132
return MRI.getType(Reg).getSizeInBits() == 32 && MUI.isDivergent(Reg);
127133
case DivB64:
@@ -435,7 +441,7 @@ RegBankLegalizeRules::RegBankLegalizeRules(const GCNSubtarget &_ST,
435441
MachineRegisterInfo &_MRI)
436442
: ST(&_ST), MRI(&_MRI) {
437443

438-
addRulesForGOpcs({G_ADD}, Standard)
444+
addRulesForGOpcs({G_ADD, G_SUB}, Standard)
439445
.Uni(S32, {{Sgpr32}, {Sgpr32, Sgpr32}})
440446
.Div(S32, {{Vgpr32}, {Vgpr32, Vgpr32}});
441447

@@ -452,11 +458,36 @@ RegBankLegalizeRules::RegBankLegalizeRules(const GCNSubtarget &_ST,
452458
.Div(B64, {{VgprB64}, {VgprB64, VgprB64}, SplitTo32});
453459

454460
addRulesForGOpcs({G_SHL}, Standard)
461+
.Uni(S16, {{Sgpr32Trunc}, {Sgpr32AExt, Sgpr32ZExt}})
462+
.Div(S16, {{Vgpr16}, {Vgpr16, Vgpr16}})
463+
.Uni(V2S16, {{SgprV2S16}, {SgprV2S16, SgprV2S16}, Unpack})
464+
.Div(V2S16, {{VgprV2S16}, {VgprV2S16, VgprV2S16}})
465+
.Uni(S32, {{Sgpr32}, {Sgpr32, Sgpr32}})
466+
.Uni(S64, {{Sgpr64}, {Sgpr64, Sgpr32}})
455467
.Div(S32, {{Vgpr32}, {Vgpr32, Vgpr32}})
468+
.Div(S64, {{Vgpr64}, {Vgpr64, Vgpr32}});
469+
470+
addRulesForGOpcs({G_LSHR}, Standard)
471+
.Uni(S16, {{Sgpr32Trunc}, {Sgpr32ZExt, Sgpr32ZExt}})
472+
.Div(S16, {{Vgpr16}, {Vgpr16, Vgpr16}})
473+
.Uni(V2S16, {{SgprV2S16}, {SgprV2S16, SgprV2S16}, Unpack})
474+
.Div(V2S16, {{VgprV2S16}, {VgprV2S16, VgprV2S16}})
475+
.Uni(S32, {{Sgpr32}, {Sgpr32, Sgpr32}})
456476
.Uni(S64, {{Sgpr64}, {Sgpr64, Sgpr32}})
477+
.Div(S32, {{Vgpr32}, {Vgpr32, Vgpr32}})
457478
.Div(S64, {{Vgpr64}, {Vgpr64, Vgpr32}});
458479

459-
addRulesForGOpcs({G_LSHR}, Standard).Uni(S32, {{Sgpr32}, {Sgpr32, Sgpr32}});
480+
addRulesForGOpcs({G_ASHR}, Standard)
481+
.Uni(S16, {{Sgpr32Trunc}, {Sgpr32SExt, Sgpr32ZExt}})
482+
.Div(S16, {{Vgpr16}, {Vgpr16, Vgpr16}})
483+
.Uni(V2S16, {{SgprV2S16}, {SgprV2S16, SgprV2S16}, Unpack})
484+
.Div(V2S16, {{VgprV2S16}, {VgprV2S16, VgprV2S16}})
485+
.Uni(S32, {{Sgpr32}, {Sgpr32, Sgpr32}})
486+
.Uni(S64, {{Sgpr64}, {Sgpr64, Sgpr32}})
487+
.Div(S32, {{Vgpr32}, {Vgpr32, Vgpr32}})
488+
.Div(S64, {{Vgpr64}, {Vgpr64, Vgpr32}});
489+
490+
addRulesForGOpcs({G_FRAME_INDEX}).Any({{UniP5, _}, {{SgprP5}, {None}}});
460491

461492
addRulesForGOpcs({G_UBFX, G_SBFX}, Standard)
462493
.Uni(S32, {{Sgpr32}, {Sgpr32, Sgpr32, Sgpr32}, S_BFE})
@@ -515,6 +546,8 @@ RegBankLegalizeRules::RegBankLegalizeRules(const GCNSubtarget &_ST,
515546
.Any({{DivS16, S32}, {{Vgpr16}, {Vgpr32}}})
516547
.Any({{UniS32, S64}, {{Sgpr32}, {Sgpr64}}})
517548
.Any({{DivS32, S64}, {{Vgpr32}, {Vgpr64}}})
549+
.Any({{UniV2S16, V2S32}, {{SgprV2S16}, {SgprV2S32}}})
550+
.Any({{DivV2S16, V2S32}, {{VgprV2S16}, {VgprV2S32}}})
518551
// This is non-trivial. VgprToVccCopy is done using compare instruction.
519552
.Any({{DivS1, DivS16}, {{Vcc}, {Vgpr16}, VgprToVccCopy}})
520553
.Any({{DivS1, DivS32}, {{Vcc}, {Vgpr32}, VgprToVccCopy}})
@@ -550,6 +583,12 @@ RegBankLegalizeRules::RegBankLegalizeRules(const GCNSubtarget &_ST,
550583
.Any({{UniS32, S16}, {{Sgpr32}, {Sgpr16}}})
551584
.Any({{DivS32, S16}, {{Vgpr32}, {Vgpr16}}});
552585

586+
addRulesForGOpcs({G_SEXT_INREG})
587+
.Any({{UniS32, S32}, {{Sgpr32}, {Sgpr32}}})
588+
.Any({{DivS32, S32}, {{Vgpr32}, {Vgpr32}}})
589+
.Any({{UniS64, S64}, {{Sgpr64}, {Sgpr64}}})
590+
.Any({{DivS64, S64}, {{Vgpr64}, {Vgpr64}, SplitTo32SExtInReg}});
591+
553592
bool hasUnalignedLoads = ST->getGeneration() >= AMDGPUSubtarget::GFX12;
554593
bool hasSMRDSmall = ST->hasScalarSubwordLoads();
555594

llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,10 @@ enum UniformityLLTOpPredicateID {
7575
V3S32,
7676
V4S32,
7777

78+
UniV2S16,
79+
80+
DivV2S16,
81+
7882
// B types
7983
B32,
8084
B64,
@@ -117,7 +121,9 @@ enum RegBankLLTMappingApplyID {
117121
SgprP3,
118122
SgprP4,
119123
SgprP5,
124+
SgprV2S16,
120125
SgprV4S32,
126+
SgprV2S32,
121127
SgprB32,
122128
SgprB64,
123129
SgprB96,
@@ -134,6 +140,8 @@ enum RegBankLLTMappingApplyID {
134140
VgprP3,
135141
VgprP4,
136142
VgprP5,
143+
VgprV2S16,
144+
VgprV2S32,
137145
VgprB32,
138146
VgprB64,
139147
VgprB96,
@@ -145,6 +153,7 @@ enum RegBankLLTMappingApplyID {
145153
// Dst only modifiers: read-any-lane and truncs
146154
UniInVcc,
147155
UniInVgprS32,
156+
UniInVgprV2S16,
148157
UniInVgprV4S32,
149158
UniInVgprB32,
150159
UniInVgprB64,
@@ -173,11 +182,13 @@ enum LoweringMethodID {
173182
DoNotLower,
174183
VccExtToSel,
175184
UniExtToSel,
185+
Unpack,
176186
S_BFE,
177187
V_BFE,
178188
VgprToVccCopy,
179189
SplitTo32,
180190
SplitTo32Sel,
191+
SplitTo32SExtInReg,
181192
Ext32To64,
182193
UniCstExt,
183194
SplitLoad,

0 commit comments

Comments
 (0)