Skip to content

Commit 8f320df

Browse files
committed
[RISCV] Handle .vx/.vi pseudos in hasAllNBitUsers
Vector pseudos with scalar operands only use the lower SEW bits (or less in the case of shifts and clips). This patch accounts for this in hasAllNBitUsers for both SDNodes in RISCVISelDAGToDAG. We also need to handle this in RISCVOptWInstrs otherwise we introduce slliw instructions that are less compressible than their original slli counterpart.
1 parent 487c784 commit 8f320df

11 files changed

+1088
-784
lines changed

llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2752,6 +2752,175 @@ bool RISCVDAGToDAGISel::selectSHXADD_UWOp(SDValue N, unsigned ShAmt,
27522752
return false;
27532753
}
27542754

2755+
static bool vectorPseudoHasAllNBitUsers(SDNode *User, unsigned UserOpNo,
2756+
unsigned Bits,
2757+
const TargetInstrInfo *TII) {
2758+
const RISCVVPseudosTable::PseudoInfo *PseudoInfo =
2759+
RISCVVPseudosTable::getPseudoInfo(User->getMachineOpcode());
2760+
2761+
if (!PseudoInfo)
2762+
return false;
2763+
2764+
const MCInstrDesc &MCID = TII->get(User->getMachineOpcode());
2765+
const uint64_t TSFlags = MCID.TSFlags;
2766+
if (!RISCVII::hasSEWOp(TSFlags))
2767+
return false;
2768+
assert(RISCVII::hasVLOp(TSFlags));
2769+
2770+
bool HasGlueOp = User->getGluedNode() != nullptr;
2771+
unsigned ChainOpIdx = User->getNumOperands() - HasGlueOp - 1;
2772+
bool HasChainOp = User->getOperand(ChainOpIdx).getValueType() == MVT::Other;
2773+
bool HasVecPolicyOp = RISCVII::hasVecPolicyOp(TSFlags);
2774+
unsigned VLIdx =
2775+
User->getNumOperands() - HasVecPolicyOp - HasChainOp - HasGlueOp - 2;
2776+
const unsigned Log2SEW = User->getConstantOperandVal(VLIdx + 1);
2777+
2778+
// TODO: The Largest VL 65,536 occurs for LMUL=8 and SEW=8 with
2779+
// VLEN=65,536. We could check if Bits < 16 here.
2780+
if (UserOpNo == VLIdx)
2781+
return false;
2782+
2783+
// TODO: Handle Zvbb instructions
2784+
switch (PseudoInfo->BaseInstr) {
2785+
default:
2786+
return false;
2787+
2788+
// 11.6. Vector Single-Width Shift Instructions
2789+
case RISCV::VSLL_VX:
2790+
case RISCV::VSLL_VI:
2791+
case RISCV::VSRL_VX:
2792+
case RISCV::VSRL_VI:
2793+
case RISCV::VSRA_VX:
2794+
case RISCV::VSRA_VI:
2795+
// 12.4. Vector Single-Width Scaling Shift Instructions
2796+
case RISCV::VSSRL_VX:
2797+
case RISCV::VSSRL_VI:
2798+
case RISCV::VSSRA_VX:
2799+
case RISCV::VSSRA_VI:
2800+
// Only the low lg2(SEW) bits of the shift-amount value are used.
2801+
if (Bits < Log2SEW)
2802+
return false;
2803+
break;
2804+
2805+
// 11.7 Vector Narrowing Integer Right Shift Instructions
2806+
case RISCV::VNSRL_WX:
2807+
case RISCV::VNSRL_WI:
2808+
case RISCV::VNSRA_WX:
2809+
case RISCV::VNSRA_WI:
2810+
// 12.5. Vector Narrowing Fixed-Point Clip Instructions
2811+
case RISCV::VNCLIPU_WX:
2812+
case RISCV::VNCLIPU_WI:
2813+
case RISCV::VNCLIP_WX:
2814+
case RISCV::VNCLIP_WI:
2815+
// Only the low lg2(2*SEW) bits of the shift-amount value are used.
2816+
if (Bits < Log2SEW + 1)
2817+
return false;
2818+
break;
2819+
2820+
// 11.1. Vector Single-Width Integer Add and Subtract
2821+
case RISCV::VADD_VX:
2822+
case RISCV::VADD_VI:
2823+
case RISCV::VSUB_VX:
2824+
case RISCV::VRSUB_VX:
2825+
case RISCV::VRSUB_VI:
2826+
// 11.2. Vector Widening Integer Add/Subtract
2827+
case RISCV::VWADDU_VX:
2828+
case RISCV::VWSUBU_VX:
2829+
case RISCV::VWADD_VX:
2830+
case RISCV::VWSUB_VX:
2831+
case RISCV::VWADDU_WX:
2832+
case RISCV::VWSUBU_WX:
2833+
case RISCV::VWADD_WX:
2834+
case RISCV::VWSUB_WX:
2835+
// 11.4. Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
2836+
case RISCV::VADC_VXM:
2837+
case RISCV::VADC_VIM:
2838+
case RISCV::VMADC_VXM:
2839+
case RISCV::VMADC_VIM:
2840+
case RISCV::VMADC_VX:
2841+
case RISCV::VMADC_VI:
2842+
case RISCV::VSBC_VXM:
2843+
case RISCV::VMSBC_VXM:
2844+
case RISCV::VMSBC_VX:
2845+
// 11.5 Vector Bitwise Logical Instructions
2846+
case RISCV::VAND_VX:
2847+
case RISCV::VAND_VI:
2848+
case RISCV::VOR_VX:
2849+
case RISCV::VOR_VI:
2850+
case RISCV::VXOR_VX:
2851+
case RISCV::VXOR_VI:
2852+
// 11.8. Vector Integer Compare Instructions
2853+
case RISCV::VMSEQ_VX:
2854+
case RISCV::VMSEQ_VI:
2855+
case RISCV::VMSNE_VX:
2856+
case RISCV::VMSNE_VI:
2857+
case RISCV::VMSLTU_VX:
2858+
case RISCV::VMSLT_VX:
2859+
case RISCV::VMSLEU_VX:
2860+
case RISCV::VMSLEU_VI:
2861+
case RISCV::VMSLE_VX:
2862+
case RISCV::VMSLE_VI:
2863+
case RISCV::VMSGTU_VX:
2864+
case RISCV::VMSGTU_VI:
2865+
case RISCV::VMSGT_VX:
2866+
case RISCV::VMSGT_VI:
2867+
// 11.9. Vector Integer Min/Max Instructions
2868+
case RISCV::VMINU_VX:
2869+
case RISCV::VMIN_VX:
2870+
case RISCV::VMAXU_VX:
2871+
case RISCV::VMAX_VX:
2872+
// 11.10. Vector Single-Width Integer Multiply Instructions
2873+
case RISCV::VMUL_VX:
2874+
case RISCV::VMULH_VX:
2875+
case RISCV::VMULHU_VX:
2876+
case RISCV::VMULHSU_VX:
2877+
// 11.11. Vector Integer Divide Instructions
2878+
case RISCV::VDIVU_VX:
2879+
case RISCV::VDIV_VX:
2880+
case RISCV::VREMU_VX:
2881+
case RISCV::VREM_VX:
2882+
// 11.12. Vector Widening Integer Multiply Instructions
2883+
case RISCV::VWMUL_VX:
2884+
case RISCV::VWMULU_VX:
2885+
case RISCV::VWMULSU_VX:
2886+
// 11.13. Vector Single-Width Integer Multiply-Add Instructions
2887+
case RISCV::VMACC_VX:
2888+
case RISCV::VNMSAC_VX:
2889+
case RISCV::VMADD_VX:
2890+
case RISCV::VNMSUB_VX:
2891+
// 11.14. Vector Widening Integer Multiply-Add Instructions
2892+
case RISCV::VWMACCU_VX:
2893+
case RISCV::VWMACC_VX:
2894+
case RISCV::VWMACCSU_VX:
2895+
case RISCV::VWMACCUS_VX:
2896+
// 11.15. Vector Integer Merge Instructions
2897+
case RISCV::VMERGE_VXM:
2898+
case RISCV::VMERGE_VIM:
2899+
// 11.16. Vector Integer Move Instructions
2900+
case RISCV::VMV_V_X:
2901+
case RISCV::VMV_V_I:
2902+
// 12.1. Vector Single-Width Saturating Add and Subtract
2903+
case RISCV::VSADDU_VX:
2904+
case RISCV::VSADDU_VI:
2905+
case RISCV::VSADD_VX:
2906+
case RISCV::VSADD_VI:
2907+
case RISCV::VSSUBU_VX:
2908+
case RISCV::VSSUB_VX:
2909+
// 12.2. Vector Single-Width Averaging Add and Subtract
2910+
case RISCV::VAADDU_VX:
2911+
case RISCV::VAADD_VX:
2912+
case RISCV::VASUBU_VX:
2913+
case RISCV::VASUB_VX:
2914+
// 12.3. Vector Single-Width Fractional Multiply with Rounding and Saturation
2915+
case RISCV::VSMUL_VX:
2916+
// 16.1. Integer Scalar Move Instructions
2917+
case RISCV::VMV_S_X:
2918+
if (Bits < (1 << Log2SEW))
2919+
return false;
2920+
}
2921+
return true;
2922+
}
2923+
27552924
// Return true if all users of this SDNode* only consume the lower \p Bits.
27562925
// This can be used to form W instructions for add/sub/mul/shl even when the
27572926
// root isn't a sext_inreg. This can allow the ADDW/SUBW/MULW/SLLIW to CSE if
@@ -2783,6 +2952,8 @@ bool RISCVDAGToDAGISel::hasAllNBitUsers(SDNode *Node, unsigned Bits,
27832952
// TODO: Add more opcodes?
27842953
switch (User->getMachineOpcode()) {
27852954
default:
2955+
if (vectorPseudoHasAllNBitUsers(User, UI.getOperandNo(), Bits, TII))
2956+
break;
27862957
return false;
27872958
case RISCV::ADDW:
27882959
case RISCV::ADDIW:

llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,168 @@ FunctionPass *llvm::createRISCVOptWInstrsPass() {
7777
return new RISCVOptWInstrs();
7878
}
7979

80+
static bool vectorPseudoHasAllNBitUsers(const MachineOperand &UserOp,
81+
unsigned Bits) {
82+
const MachineInstr &MI = *UserOp.getParent();
83+
const RISCVVPseudosTable::PseudoInfo *PseudoInfo =
84+
RISCVVPseudosTable::getPseudoInfo(MI.getOpcode());
85+
86+
if (!PseudoInfo)
87+
return false;
88+
89+
const MCInstrDesc &MCID = MI.getDesc();
90+
const uint64_t TSFlags = MI.getDesc().TSFlags;
91+
if (!RISCVII::hasSEWOp(TSFlags))
92+
return false;
93+
assert(RISCVII::hasVLOp(TSFlags));
94+
const unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(MCID)).getImm();
95+
96+
// TODO: The Largest VL 65,536 occurs for LMUL=8 and SEW=8 with
97+
// VLEN=65,536. We could check if Bits < 16 here.
98+
if (UserOp.getOperandNo() == RISCVII::getVLOpNum(MCID))
99+
return false;
100+
101+
// TODO: Handle Zvbb instructions
102+
switch (PseudoInfo->BaseInstr) {
103+
default:
104+
return false;
105+
106+
// 11.6. Vector Single-Width Shift Instructions
107+
case RISCV::VSLL_VX:
108+
case RISCV::VSLL_VI:
109+
case RISCV::VSRL_VX:
110+
case RISCV::VSRL_VI:
111+
case RISCV::VSRA_VX:
112+
case RISCV::VSRA_VI:
113+
// 12.4. Vector Single-Width Scaling Shift Instructions
114+
case RISCV::VSSRL_VX:
115+
case RISCV::VSSRL_VI:
116+
case RISCV::VSSRA_VX:
117+
case RISCV::VSSRA_VI:
118+
// Only the low lg2(SEW) bits of the shift-amount value are used.
119+
if (Bits < Log2SEW)
120+
return false;
121+
break;
122+
123+
// 11.7 Vector Narrowing Integer Right Shift Instructions
124+
case RISCV::VNSRL_WX:
125+
case RISCV::VNSRL_WI:
126+
case RISCV::VNSRA_WX:
127+
case RISCV::VNSRA_WI:
128+
// 12.5. Vector Narrowing Fixed-Point Clip Instructions
129+
case RISCV::VNCLIPU_WX:
130+
case RISCV::VNCLIPU_WI:
131+
case RISCV::VNCLIP_WX:
132+
case RISCV::VNCLIP_WI:
133+
// Only the low lg2(2*SEW) bits of the shift-amount value are used.
134+
if (Bits < Log2SEW + 1)
135+
return false;
136+
break;
137+
138+
// 11.1. Vector Single-Width Integer Add and Subtract
139+
case RISCV::VADD_VX:
140+
case RISCV::VADD_VI:
141+
case RISCV::VSUB_VX:
142+
case RISCV::VRSUB_VX:
143+
case RISCV::VRSUB_VI:
144+
// 11.2. Vector Widening Integer Add/Subtract
145+
case RISCV::VWADDU_VX:
146+
case RISCV::VWSUBU_VX:
147+
case RISCV::VWADD_VX:
148+
case RISCV::VWSUB_VX:
149+
case RISCV::VWADDU_WX:
150+
case RISCV::VWSUBU_WX:
151+
case RISCV::VWADD_WX:
152+
case RISCV::VWSUB_WX:
153+
// 11.4. Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
154+
case RISCV::VADC_VXM:
155+
case RISCV::VADC_VIM:
156+
case RISCV::VMADC_VXM:
157+
case RISCV::VMADC_VIM:
158+
case RISCV::VMADC_VX:
159+
case RISCV::VMADC_VI:
160+
case RISCV::VSBC_VXM:
161+
case RISCV::VMSBC_VXM:
162+
case RISCV::VMSBC_VX:
163+
// 11.5 Vector Bitwise Logical Instructions
164+
case RISCV::VAND_VX:
165+
case RISCV::VAND_VI:
166+
case RISCV::VOR_VX:
167+
case RISCV::VOR_VI:
168+
case RISCV::VXOR_VX:
169+
case RISCV::VXOR_VI:
170+
// 11.8. Vector Integer Compare Instructions
171+
case RISCV::VMSEQ_VX:
172+
case RISCV::VMSEQ_VI:
173+
case RISCV::VMSNE_VX:
174+
case RISCV::VMSNE_VI:
175+
case RISCV::VMSLTU_VX:
176+
case RISCV::VMSLT_VX:
177+
case RISCV::VMSLEU_VX:
178+
case RISCV::VMSLEU_VI:
179+
case RISCV::VMSLE_VX:
180+
case RISCV::VMSLE_VI:
181+
case RISCV::VMSGTU_VX:
182+
case RISCV::VMSGTU_VI:
183+
case RISCV::VMSGT_VX:
184+
case RISCV::VMSGT_VI:
185+
// 11.9. Vector Integer Min/Max Instructions
186+
case RISCV::VMINU_VX:
187+
case RISCV::VMIN_VX:
188+
case RISCV::VMAXU_VX:
189+
case RISCV::VMAX_VX:
190+
// 11.10. Vector Single-Width Integer Multiply Instructions
191+
case RISCV::VMUL_VX:
192+
case RISCV::VMULH_VX:
193+
case RISCV::VMULHU_VX:
194+
case RISCV::VMULHSU_VX:
195+
// 11.11. Vector Integer Divide Instructions
196+
case RISCV::VDIVU_VX:
197+
case RISCV::VDIV_VX:
198+
case RISCV::VREMU_VX:
199+
case RISCV::VREM_VX:
200+
// 11.12. Vector Widening Integer Multiply Instructions
201+
case RISCV::VWMUL_VX:
202+
case RISCV::VWMULU_VX:
203+
case RISCV::VWMULSU_VX:
204+
// 11.13. Vector Single-Width Integer Multiply-Add Instructions
205+
case RISCV::VMACC_VX:
206+
case RISCV::VNMSAC_VX:
207+
case RISCV::VMADD_VX:
208+
case RISCV::VNMSUB_VX:
209+
// 11.14. Vector Widening Integer Multiply-Add Instructions
210+
case RISCV::VWMACCU_VX:
211+
case RISCV::VWMACC_VX:
212+
case RISCV::VWMACCSU_VX:
213+
case RISCV::VWMACCUS_VX:
214+
// 11.15. Vector Integer Merge Instructions
215+
case RISCV::VMERGE_VXM:
216+
case RISCV::VMERGE_VIM:
217+
// 11.16. Vector Integer Move Instructions
218+
case RISCV::VMV_V_X:
219+
case RISCV::VMV_V_I:
220+
// 12.1. Vector Single-Width Saturating Add and Subtract
221+
case RISCV::VSADDU_VX:
222+
case RISCV::VSADDU_VI:
223+
case RISCV::VSADD_VX:
224+
case RISCV::VSADD_VI:
225+
case RISCV::VSSUBU_VX:
226+
case RISCV::VSSUB_VX:
227+
// 12.2. Vector Single-Width Averaging Add and Subtract
228+
case RISCV::VAADDU_VX:
229+
case RISCV::VAADD_VX:
230+
case RISCV::VASUBU_VX:
231+
case RISCV::VASUB_VX:
232+
// 12.3. Vector Single-Width Fractional Multiply with Rounding and Saturation
233+
case RISCV::VSMUL_VX:
234+
// 16.1. Integer Scalar Move Instructions
235+
case RISCV::VMV_S_X:
236+
if (Bits < (1 << Log2SEW))
237+
return false;
238+
}
239+
return true;
240+
}
241+
80242
// Checks if all users only demand the lower \p OrigBits of the original
81243
// instruction's result.
82244
// TODO: handle multiple interdependent transformations
@@ -107,6 +269,8 @@ static bool hasAllNBitUsers(const MachineInstr &OrigMI,
107269

108270
switch (UserMI->getOpcode()) {
109271
default:
272+
if (vectorPseudoHasAllNBitUsers(UserOp, Bits))
273+
break;
110274
return false;
111275

112276
case RISCV::ADDIW:

0 commit comments

Comments
 (0)