Skip to content

Commit 0bc8320

Browse files
AMDGPU/GlobalISel: add RegBankLegalize rules for bit shifts and sext-inreg
Uniform S16 shifts have to be extended to S32 using appropriate Extend before lowering to S32 instruction. Uniform packed V2S16 are lowered to SGPR S32 instructions, other option is to use VALU packed V2S16 and ReadAnyLane. For uniform S32 and S64 and divergent S16, S32, S64 and V2S16 there are instructions available.
1 parent a2bb5ea commit 0bc8320

13 files changed

+375
-165
lines changed

llvm/lib/Target/AMDGPU/AMDGPURegBankLegalize.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ bool AMDGPURegBankLegalize::runOnMachineFunction(MachineFunction &MF) {
310310
// Opcodes that support pretty much all combinations of reg banks and LLTs
311311
// (except S1). There is no point in writing rules for them.
312312
if (Opc == AMDGPU::G_BUILD_VECTOR || Opc == AMDGPU::G_UNMERGE_VALUES ||
313-
Opc == AMDGPU::G_MERGE_VALUES) {
313+
Opc == AMDGPU::G_MERGE_VALUES || Opc == AMDGPU::G_BITCAST) {
314314
RBLHelper.applyMappingTrivial(*MI);
315315
continue;
316316
}

llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,62 @@ void RegBankLegalizeHelper::lowerVccExtToSel(MachineInstr &MI) {
171171
MI.eraseFromParent();
172172
}
173173

174+
const std::pair<Register, Register>
175+
RegBankLegalizeHelper::unpackZExt(Register Reg) {
176+
auto PackedS32 = B.buildBitcast(SgprRB_S32, Reg);
177+
auto Mask = B.buildConstant(SgprRB_S32, 0x0000ffff);
178+
auto Lo = B.buildAnd(SgprRB_S32, PackedS32, Mask);
179+
auto Hi = B.buildLShr(SgprRB_S32, PackedS32, B.buildConstant(SgprRB_S32, 16));
180+
return {Lo.getReg(0), Hi.getReg(0)};
181+
}
182+
183+
const std::pair<Register, Register>
184+
RegBankLegalizeHelper::unpackSExt(Register Reg) {
185+
auto PackedS32 = B.buildBitcast(SgprRB_S32, Reg);
186+
auto Lo = B.buildSExtInReg(SgprRB_S32, PackedS32, 16);
187+
auto Hi = B.buildAShr(SgprRB_S32, PackedS32, B.buildConstant(SgprRB_S32, 16));
188+
return {Lo.getReg(0), Hi.getReg(0)};
189+
}
190+
191+
const std::pair<Register, Register>
192+
RegBankLegalizeHelper::unpackAExt(Register Reg) {
193+
auto PackedS32 = B.buildBitcast(SgprRB_S32, Reg);
194+
auto Lo = PackedS32;
195+
auto Hi = B.buildLShr(SgprRB_S32, PackedS32, B.buildConstant(SgprRB_S32, 16));
196+
return {Lo.getReg(0), Hi.getReg(0)};
197+
}
198+
199+
void RegBankLegalizeHelper::lowerUnpackBitShift(MachineInstr &MI) {
200+
Register Lo, Hi;
201+
switch (MI.getOpcode()) {
202+
case AMDGPU::G_SHL: {
203+
auto [Val0, Val1] = unpackAExt(MI.getOperand(1).getReg());
204+
auto [Amt0, Amt1] = unpackAExt(MI.getOperand(2).getReg());
205+
Lo = B.buildInstr(MI.getOpcode(), {SgprRB_S32}, {Val0, Amt0}).getReg(0);
206+
Hi = B.buildInstr(MI.getOpcode(), {SgprRB_S32}, {Val1, Amt1}).getReg(0);
207+
break;
208+
}
209+
case AMDGPU::G_LSHR: {
210+
auto [Val0, Val1] = unpackZExt(MI.getOperand(1).getReg());
211+
auto [Amt0, Amt1] = unpackZExt(MI.getOperand(2).getReg());
212+
Lo = B.buildInstr(MI.getOpcode(), {SgprRB_S32}, {Val0, Amt0}).getReg(0);
213+
Hi = B.buildInstr(MI.getOpcode(), {SgprRB_S32}, {Val1, Amt1}).getReg(0);
214+
break;
215+
}
216+
case AMDGPU::G_ASHR: {
217+
auto [Val0, Val1] = unpackSExt(MI.getOperand(1).getReg());
218+
auto [Amt0, Amt1] = unpackSExt(MI.getOperand(2).getReg());
219+
Lo = B.buildAShr(SgprRB_S32, Val0, Amt0).getReg(0);
220+
Hi = B.buildAShr(SgprRB_S32, Val1, Amt1).getReg(0);
221+
break;
222+
}
223+
default:
224+
llvm_unreachable("Unpack lowering not implemented");
225+
}
226+
B.buildBuildVectorTrunc(MI.getOperand(0).getReg(), {Lo, Hi});
227+
MI.eraseFromParent();
228+
}
229+
174230
static bool isSignedBFE(MachineInstr &MI) {
175231
if (GIntrinsic *GI = dyn_cast<GIntrinsic>(&MI))
176232
return (GI->is(Intrinsic::amdgcn_sbfe));
@@ -306,6 +362,33 @@ void RegBankLegalizeHelper::lowerSplitTo32Select(MachineInstr &MI) {
306362
MI.eraseFromParent();
307363
}
308364

365+
void RegBankLegalizeHelper::lowerSplitTo32SExtInReg(MachineInstr &MI) {
366+
auto Op1 = B.buildUnmerge(VgprRB_S32, MI.getOperand(1).getReg());
367+
int Amt = MI.getOperand(2).getImm();
368+
Register Lo, Hi;
369+
// Hi|Lo: s sign bit, ?/x bits changed/not changed by sign-extend
370+
if (Amt <= 32) {
371+
auto Freeze = B.buildFreeze(VgprRB_S32, Op1.getReg(0));
372+
if (Amt == 32) {
373+
// Hi|Lo: ????????|sxxxxxxx -> ssssssss|sxxxxxxx
374+
Lo = Freeze.getReg(0);
375+
} else {
376+
// Hi|Lo: ????????|???sxxxx -> ssssssss|ssssxxxx
377+
Lo = B.buildSExtInReg(VgprRB_S32, Freeze, Amt).getReg(0);
378+
}
379+
380+
auto SignExtCst = B.buildConstant(SgprRB_S32, 31);
381+
Hi = B.buildAShr(VgprRB_S32, Lo, SignExtCst).getReg(0);
382+
} else {
383+
// Hi|Lo: ?????sxx|xxxxxxxx -> ssssssxx|xxxxxxxx
384+
Lo = Op1.getReg(0);
385+
Hi = B.buildSExtInReg(VgprRB_S32, Op1.getReg(1), Amt - 32).getReg(0);
386+
}
387+
388+
B.buildMergeLikeInstr(MI.getOperand(0).getReg(), {Lo, Hi});
389+
MI.eraseFromParent();
390+
}
391+
309392
void RegBankLegalizeHelper::lower(MachineInstr &MI,
310393
const RegBankLLTMapping &Mapping,
311394
SmallSet<Register, 4> &WaterfallSgprs) {
@@ -328,6 +411,8 @@ void RegBankLegalizeHelper::lower(MachineInstr &MI,
328411
MI.eraseFromParent();
329412
return;
330413
}
414+
case UnpackBitShift:
415+
return lowerUnpackBitShift(MI);
331416
case Ext32To64: {
332417
const RegisterBank *RB = MRI.getRegBank(MI.getOperand(0).getReg());
333418
MachineInstrBuilder Hi;
@@ -394,6 +479,8 @@ void RegBankLegalizeHelper::lower(MachineInstr &MI,
394479
return lowerSplitTo32(MI);
395480
case SplitTo32Select:
396481
return lowerSplitTo32Select(MI);
482+
case SplitTo32SExtInReg:
483+
return lowerSplitTo32SExtInReg(MI);
397484
case SplitLoad: {
398485
LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
399486
unsigned Size = DstTy.getSizeInBits();
@@ -483,6 +570,13 @@ LLT RegBankLegalizeHelper::getTyFromID(RegBankLLTMappingApplyID ID) {
483570
case SgprP5:
484571
case VgprP5:
485572
return LLT::pointer(5, 32);
573+
case SgprV2S16:
574+
case VgprV2S16:
575+
case UniInVgprV2S16:
576+
return LLT::fixed_vector(2, 16);
577+
case SgprV2S32:
578+
case VgprV2S32:
579+
return LLT::fixed_vector(2, 32);
486580
case SgprV4S32:
487581
case VgprV4S32:
488582
case UniInVgprV4S32:
@@ -556,6 +650,8 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
556650
case SgprP3:
557651
case SgprP4:
558652
case SgprP5:
653+
case SgprV2S16:
654+
case SgprV2S32:
559655
case SgprV4S32:
560656
case SgprB32:
561657
case SgprB64:
@@ -565,6 +661,7 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
565661
case SgprB512:
566662
case UniInVcc:
567663
case UniInVgprS32:
664+
case UniInVgprV2S16:
568665
case UniInVgprV4S32:
569666
case UniInVgprB32:
570667
case UniInVgprB64:
@@ -586,6 +683,8 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
586683
case VgprP3:
587684
case VgprP4:
588685
case VgprP5:
686+
case VgprV2S16:
687+
case VgprV2S32:
589688
case VgprV4S32:
590689
case VgprB32:
591690
case VgprB64:
@@ -623,6 +722,8 @@ void RegBankLegalizeHelper::applyMappingDst(
623722
case SgprP3:
624723
case SgprP4:
625724
case SgprP5:
725+
case SgprV2S16:
726+
case SgprV2S32:
626727
case SgprV4S32:
627728
case Vgpr16:
628729
case Vgpr32:
@@ -632,6 +733,8 @@ void RegBankLegalizeHelper::applyMappingDst(
632733
case VgprP3:
633734
case VgprP4:
634735
case VgprP5:
736+
case VgprV2S16:
737+
case VgprV2S32:
635738
case VgprV4S32: {
636739
assert(Ty == getTyFromID(MethodIDs[OpIdx]));
637740
assert(RB == getRegBankFromID(MethodIDs[OpIdx]));
@@ -666,6 +769,7 @@ void RegBankLegalizeHelper::applyMappingDst(
666769
break;
667770
}
668771
case UniInVgprS32:
772+
case UniInVgprV2S16:
669773
case UniInVgprV4S32: {
670774
assert(Ty == getTyFromID(MethodIDs[OpIdx]));
671775
assert(RB == SgprRB);
@@ -739,6 +843,8 @@ void RegBankLegalizeHelper::applyMappingSrc(
739843
case SgprP3:
740844
case SgprP4:
741845
case SgprP5:
846+
case SgprV2S16:
847+
case SgprV2S32:
742848
case SgprV4S32: {
743849
assert(Ty == getTyFromID(MethodIDs[i]));
744850
assert(RB == getRegBankFromID(MethodIDs[i]));
@@ -764,6 +870,8 @@ void RegBankLegalizeHelper::applyMappingSrc(
764870
case VgprP3:
765871
case VgprP4:
766872
case VgprP5:
873+
case VgprV2S16:
874+
case VgprV2S32:
767875
case VgprV4S32: {
768876
assert(Ty == getTyFromID(MethodIDs[i]));
769877
if (RB != VgprRB) {

llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,15 @@ class RegBankLegalizeHelper {
111111
SmallSet<Register, 4> &SgprWaterfallOperandRegs);
112112

113113
void lowerVccExtToSel(MachineInstr &MI);
114+
const std::pair<Register, Register> unpackZExt(Register Reg);
115+
const std::pair<Register, Register> unpackSExt(Register Reg);
116+
const std::pair<Register, Register> unpackAExt(Register Reg);
117+
void lowerUnpackBitShift(MachineInstr &MI);
114118
void lowerV_BFE(MachineInstr &MI);
115119
void lowerS_BFE(MachineInstr &MI);
116120
void lowerSplitTo32(MachineInstr &MI);
117121
void lowerSplitTo32Select(MachineInstr &MI);
122+
void lowerSplitTo32SExtInReg(MachineInstr &MI);
118123
};
119124

120125
} // end namespace AMDGPU

llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
6060
return MRI.getType(Reg) == LLT::pointer(4, 64);
6161
case P5:
6262
return MRI.getType(Reg) == LLT::pointer(5, 32);
63+
case V2S32:
64+
return MRI.getType(Reg) == LLT::fixed_vector(2, 32);
6365
case V4S32:
6466
return MRI.getType(Reg) == LLT::fixed_vector(4, 32);
6567
case B32:
@@ -92,6 +94,8 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
9294
return MRI.getType(Reg) == LLT::pointer(4, 64) && MUI.isUniform(Reg);
9395
case UniP5:
9496
return MRI.getType(Reg) == LLT::pointer(5, 32) && MUI.isUniform(Reg);
97+
case UniV2S16:
98+
return MRI.getType(Reg) == LLT::fixed_vector(2, 16) && MUI.isUniform(Reg);
9599
case UniB32:
96100
return MRI.getType(Reg).getSizeInBits() == 32 && MUI.isUniform(Reg);
97101
case UniB64:
@@ -122,6 +126,8 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
122126
return MRI.getType(Reg) == LLT::pointer(4, 64) && MUI.isDivergent(Reg);
123127
case DivP5:
124128
return MRI.getType(Reg) == LLT::pointer(5, 32) && MUI.isDivergent(Reg);
129+
case DivV2S16:
130+
return MRI.getType(Reg) == LLT::fixed_vector(2, 16) && MUI.isDivergent(Reg);
125131
case DivB32:
126132
return MRI.getType(Reg).getSizeInBits() == 32 && MUI.isDivergent(Reg);
127133
case DivB64:
@@ -435,7 +441,7 @@ RegBankLegalizeRules::RegBankLegalizeRules(const GCNSubtarget &_ST,
435441
MachineRegisterInfo &_MRI)
436442
: ST(&_ST), MRI(&_MRI) {
437443

438-
addRulesForGOpcs({G_ADD}, Standard)
444+
addRulesForGOpcs({G_ADD, G_SUB}, Standard)
439445
.Uni(S32, {{Sgpr32}, {Sgpr32, Sgpr32}})
440446
.Div(S32, {{Vgpr32}, {Vgpr32, Vgpr32}});
441447

@@ -452,11 +458,36 @@ RegBankLegalizeRules::RegBankLegalizeRules(const GCNSubtarget &_ST,
452458
.Div(B64, {{VgprB64}, {VgprB64, VgprB64}, SplitTo32});
453459

454460
addRulesForGOpcs({G_SHL}, Standard)
461+
.Uni(S16, {{Sgpr32Trunc}, {Sgpr32AExt, Sgpr32ZExt}})
462+
.Div(S16, {{Vgpr16}, {Vgpr16, Vgpr16}})
463+
.Uni(V2S16, {{SgprV2S16}, {SgprV2S16, SgprV2S16}, UnpackBitShift})
464+
.Div(V2S16, {{VgprV2S16}, {VgprV2S16, VgprV2S16}})
465+
.Uni(S32, {{Sgpr32}, {Sgpr32, Sgpr32}})
466+
.Uni(S64, {{Sgpr64}, {Sgpr64, Sgpr32}})
455467
.Div(S32, {{Vgpr32}, {Vgpr32, Vgpr32}})
468+
.Div(S64, {{Vgpr64}, {Vgpr64, Vgpr32}});
469+
470+
addRulesForGOpcs({G_LSHR}, Standard)
471+
.Uni(S16, {{Sgpr32Trunc}, {Sgpr32ZExt, Sgpr32ZExt}})
472+
.Div(S16, {{Vgpr16}, {Vgpr16, Vgpr16}})
473+
.Uni(V2S16, {{SgprV2S16}, {SgprV2S16, SgprV2S16}, UnpackBitShift})
474+
.Div(V2S16, {{VgprV2S16}, {VgprV2S16, VgprV2S16}})
475+
.Uni(S32, {{Sgpr32}, {Sgpr32, Sgpr32}})
456476
.Uni(S64, {{Sgpr64}, {Sgpr64, Sgpr32}})
477+
.Div(S32, {{Vgpr32}, {Vgpr32, Vgpr32}})
457478
.Div(S64, {{Vgpr64}, {Vgpr64, Vgpr32}});
458479

459-
addRulesForGOpcs({G_LSHR}, Standard).Uni(S32, {{Sgpr32}, {Sgpr32, Sgpr32}});
480+
addRulesForGOpcs({G_ASHR}, Standard)
481+
.Uni(S16, {{Sgpr32Trunc}, {Sgpr32SExt, Sgpr32ZExt}})
482+
.Div(S16, {{Vgpr16}, {Vgpr16, Vgpr16}})
483+
.Uni(V2S16, {{SgprV2S16}, {SgprV2S16, SgprV2S16}, UnpackBitShift})
484+
.Div(V2S16, {{VgprV2S16}, {VgprV2S16, VgprV2S16}})
485+
.Uni(S32, {{Sgpr32}, {Sgpr32, Sgpr32}})
486+
.Uni(S64, {{Sgpr64}, {Sgpr64, Sgpr32}})
487+
.Div(S32, {{Vgpr32}, {Vgpr32, Vgpr32}})
488+
.Div(S64, {{Vgpr64}, {Vgpr64, Vgpr32}});
489+
490+
addRulesForGOpcs({G_FRAME_INDEX}).Any({{UniP5, _}, {{SgprP5}, {None}}});
460491

461492
addRulesForGOpcs({G_UBFX, G_SBFX}, Standard)
462493
.Uni(S32, {{Sgpr32}, {Sgpr32, Sgpr32, Sgpr32}, S_BFE})
@@ -515,6 +546,8 @@ RegBankLegalizeRules::RegBankLegalizeRules(const GCNSubtarget &_ST,
515546
.Any({{DivS16, S32}, {{Vgpr16}, {Vgpr32}}})
516547
.Any({{UniS32, S64}, {{Sgpr32}, {Sgpr64}}})
517548
.Any({{DivS32, S64}, {{Vgpr32}, {Vgpr64}}})
549+
.Any({{UniV2S16, V2S32}, {{SgprV2S16}, {SgprV2S32}}})
550+
.Any({{DivV2S16, V2S32}, {{VgprV2S16}, {VgprV2S32}}})
518551
// This is non-trivial. VgprToVccCopy is done using compare instruction.
519552
.Any({{DivS1, DivS16}, {{Vcc}, {Vgpr16}, VgprToVccCopy}})
520553
.Any({{DivS1, DivS32}, {{Vcc}, {Vgpr32}, VgprToVccCopy}})
@@ -550,6 +583,12 @@ RegBankLegalizeRules::RegBankLegalizeRules(const GCNSubtarget &_ST,
550583
.Any({{UniS32, S16}, {{Sgpr32}, {Sgpr16}}})
551584
.Any({{DivS32, S16}, {{Vgpr32}, {Vgpr16}}});
552585

586+
addRulesForGOpcs({G_SEXT_INREG})
587+
.Any({{UniS32, S32}, {{Sgpr32}, {Sgpr32}}})
588+
.Any({{DivS32, S32}, {{Vgpr32}, {Vgpr32}}})
589+
.Any({{UniS64, S64}, {{Sgpr64}, {Sgpr64}}})
590+
.Any({{DivS64, S64}, {{Vgpr64}, {Vgpr64}, SplitTo32SExtInReg}});
591+
553592
bool hasUnalignedLoads = ST->getGeneration() >= AMDGPUSubtarget::GFX12;
554593
bool hasSMRDSmall = ST->hasScalarSubwordLoads();
555594

llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,10 @@ enum UniformityLLTOpPredicateID {
7575
V3S32,
7676
V4S32,
7777

78+
UniV2S16,
79+
80+
DivV2S16,
81+
7882
// B types
7983
B32,
8084
B64,
@@ -117,7 +121,9 @@ enum RegBankLLTMappingApplyID {
117121
SgprP3,
118122
SgprP4,
119123
SgprP5,
124+
SgprV2S16,
120125
SgprV4S32,
126+
SgprV2S32,
121127
SgprB32,
122128
SgprB64,
123129
SgprB96,
@@ -134,6 +140,8 @@ enum RegBankLLTMappingApplyID {
134140
VgprP3,
135141
VgprP4,
136142
VgprP5,
143+
VgprV2S16,
144+
VgprV2S32,
137145
VgprB32,
138146
VgprB64,
139147
VgprB96,
@@ -145,6 +153,7 @@ enum RegBankLLTMappingApplyID {
145153
// Dst only modifiers: read-any-lane and truncs
146154
UniInVcc,
147155
UniInVgprS32,
156+
UniInVgprV2S16,
148157
UniInVgprV4S32,
149158
UniInVgprB32,
150159
UniInVgprB64,
@@ -173,11 +182,13 @@ enum LoweringMethodID {
173182
DoNotLower,
174183
VccExtToSel,
175184
UniExtToSel,
185+
UnpackBitShift,
176186
S_BFE,
177187
V_BFE,
178188
VgprToVccCopy,
179189
SplitTo32,
180190
SplitTo32Select,
191+
SplitTo32SExtInReg,
181192
Ext32To64,
182193
UniCstExt,
183194
SplitLoad,

0 commit comments

Comments
 (0)