Skip to content

Commit 1386a93

Browse files
committed
[RISCV][Isel] Remove redundant vmerge for the vwadd.
1 parent faef68b commit 1386a93

File tree

3 files changed

+168
-1
lines changed

3 files changed

+168
-1
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13709,6 +13709,57 @@ static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
1370913709
return InputRootReplacement;
1371013710
}
1371113711

13712+
// Fold (vwadd.wv y, (vmerge cond, x, 0)) -> vwadd.wv y, x, y, cond
13713+
// y will be the Passthru and cond will be the Mask.
//
// Intent, per the fold above: in lanes where cond is false the merge yields
// 0, so the add would produce y unchanged; using y as the passthru of a
// masked add is meant to give the same result without the vmerge.
//
// Operand positions read from N: 0 = wide operand y, 1 = narrow operand
// (expected to be the vmerge), 2 = passthru, 3 = mask, 4 = forwarded
// unchanged to the new node (presumably the VL operand - confirm against
// the node definition).
//
// Returns the replacement node, or an empty SDValue when the pattern does
// not match.
13714+
static SDValue combineVWADDWSelect(SDNode *N, SelectionDAG &DAG) {
13715+
unsigned Opc = N->getOpcode();
13716+
assert(Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWADDU_W_VL);
13717+
13718+
SDValue Y = N->getOperand(0);
13719+
SDValue MergeOp = N->getOperand(1);
13720+
// The narrow operand must be produced by a VMERGE_VL node.
if (MergeOp.getOpcode() != RISCVISD::VMERGE_VL)
13721+
return SDValue();
13722+
// Operand 1 of the merge is the value forwarded as the narrow operand of
// the new node (the "x" in the fold above).
SDValue X = MergeOp->getOperand(1);
13723+
13724+
// Bail out if the merge result has any other user.
if (!MergeOp.hasOneUse())
13725+
return SDValue();
13726+
13727+
// Passthru should be undef
// (the new node installs y as its passthru, so an existing one is rejected).
13728+
SDValue Passthru = N->getOperand(2);
13729+
if (!Passthru.isUndef())
13730+
return SDValue();
13731+
13732+
// Mask should be all ones
// (recognised here only when it is directly a VMSET_VL node).
13733+
SDValue Mask = N->getOperand(3);
13734+
if (Mask.getOpcode() != RISCVISD::VMSET_VL)
13735+
return SDValue();
13736+
13737+
// False value of MergeOp should be all zeros
// Only one shape is accepted: an INSERT_SUBVECTOR whose inserted value
// (operand 1) is an all-zeros build vector and whose base vector
// (operand 0) is zero / a zero splat or undef.
13738+
SDValue Z = MergeOp->getOperand(2);
13739+
if (Z.getOpcode() != ISD::INSERT_SUBVECTOR)
13740+
return SDValue();
13741+
if (!ISD::isBuildVectorAllZeros(Z.getOperand(1).getNode()))
13742+
return SDValue();
13743+
if (!isNullOrNullSplat(Z.getOperand(0)) && !Z.getOperand(0).isUndef())
13744+
return SDValue();
13745+
13746+
// Rebuild the same opcode with operands {y, x, passthru = y,
// mask = operand 0 of the merge, original operand 4}, keeping N's flags.
return DAG.getNode(Opc, SDLoc(N), N->getValueType(0),
13747+
{Y, X, Y, MergeOp->getOperand(0), N->getOperand(4)},
13748+
N->getFlags());
13749+
}
13750+
13751+
// DAG-combine entry point for VWADD_W_VL / VWADDU_W_VL nodes.
// Tries combineBinOp_VLToVWBinOp_VL first and returns its result if it
// produced one; otherwise falls back to combineVWADDWSelect, whose result
// (possibly an empty SDValue) is returned as-is.
static SDValue performVWADDW_VLCombine(SDNode *N,
13752+
TargetLowering::DAGCombinerInfo &DCI,
13753+
const RISCVSubtarget &Subtarget) {
13754+
unsigned Opc = N->getOpcode();
13755+
assert(Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWADDU_W_VL);
13756+
13757+
// Generic widening-binop combine takes priority.
if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
13758+
return V;
13759+
13760+
// Otherwise try to absorb a select-with-zero feeding the narrow operand.
return combineVWADDWSelect(N, DCI.DAG);
13761+
}
13762+
1371213763
// Helper function for performMemPairCombine.
1371313764
// Try to combine the memory loads/stores LSNode1 and LSNode2
1371413765
// into a single memory pair operation.
@@ -15777,9 +15828,10 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
1577715828
if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
1577815829
return V;
1577915830
return combineToVWMACC(N, DAG, Subtarget);
15780-
case RISCVISD::SUB_VL:
1578115831
case RISCVISD::VWADD_W_VL:
1578215832
case RISCVISD::VWADDU_W_VL:
15833+
return performVWADDW_VLCombine(N, DCI, Subtarget);
15834+
case RISCVISD::SUB_VL:
1578315835
case RISCVISD::VWSUB_W_VL:
1578415836
case RISCVISD::VWSUBU_W_VL:
1578515837
case RISCVISD::MUL_VL:

llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -691,6 +691,27 @@ multiclass VPatTiedBinaryNoMaskVL_V<SDNode vop,
691691
GPR:$vl, sew, TU_MU)>;
692692
}
693693

694+
// Selection pattern for a masked binary VL node in "tied" form.
// Matches `vop` when its first operand and its third operand are the same
// register ($rs1, of result_type) and the mask operand is V0, and selects
// the pseudo named <instruction_name>_<suffix>_<vlmul.MX>_MASK_TIED with
// policy TU_MU. The third operand is presumably the node's passthru
// position - confirm against the SDNode's operand list.
class VPatTiedBinaryMaskVL_V<SDNode vop,
695+
string instruction_name,
696+
string suffix,
697+
ValueType result_type,
698+
ValueType op2_type,
699+
ValueType mask_type,
700+
int sew,
701+
LMULInfo vlmul,
702+
VReg result_reg_class,
703+
VReg op2_reg_class> :
704+
Pat<(result_type (vop
705+
(result_type result_reg_class:$rs1),
706+
(op2_type op2_reg_class:$rs2),
707+
// Same $rs1 again: the pattern only fires when this operand is identical
// to the first one.
(result_type result_reg_class:$rs1),
708+
(mask_type V0),
709+
VLOpFrag)),
710+
(!cast<Instruction>(instruction_name#"_"#suffix#"_"# vlmul.MX#"_MASK_TIED")
711+
result_reg_class:$rs1,
712+
op2_reg_class:$rs2,
713+
(mask_type V0), GPR:$vl, sew, TU_MU)>;
714+
694715
multiclass VPatTiedBinaryNoMaskVL_V_RM<SDNode vop,
695716
string instruction_name,
696717
string suffix,
@@ -819,6 +840,10 @@ multiclass VPatBinaryWVL_VV_VX_WV_WX<SDPatternOperator vop, SDNode vop_w,
819840
defm : VPatTiedBinaryNoMaskVL_V<vop_w, instruction_name, "WV",
820841
wti.Vector, vti.Vector, vti.Log2SEW,
821842
vti.LMul, wti.RegClass, vti.RegClass>;
843+
def : VPatTiedBinaryMaskVL_V<vop_w, instruction_name, "WV",
844+
wti.Vector, vti.Vector, wti.Mask,
845+
vti.Log2SEW, vti.LMul, wti.RegClass,
846+
vti.RegClass>;
822847
def : VPatBinaryVL_V<vop_w, instruction_name, "WV",
823848
wti.Vector, wti.Vector, vti.Vector, vti.Mask,
824849
vti.Log2SEW, vti.LMul, wti.RegClass, wti.RegClass,
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
2+
; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK
3+
; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK
4+
5+
; Positive test (signed): select(%mask, %x, zeroinitializer) is sign-extended
; and added to the already-wide %y. The expected output is a single masked
; vwadd.wv (v0.t) with no vmerge.
define <8 x i64> @vwadd_wv_mask_v8i32(<8 x i32> %x, <8 x i64> %y) {
6+
; CHECK-LABEL: vwadd_wv_mask_v8i32:
7+
; CHECK: # %bb.0:
8+
; CHECK-NEXT: li a0, 42
9+
; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
10+
; CHECK-NEXT: vmslt.vx v0, v8, a0
11+
; CHECK-NEXT: vsetvli zero, zero, e32, m2, tu, mu
12+
; CHECK-NEXT: vwadd.wv v12, v12, v8, v0.t
13+
; CHECK-NEXT: vmv4r.v v8, v12
14+
; CHECK-NEXT: ret
15+
%mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
16+
%a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> zeroinitializer
17+
%sa = sext <8 x i32> %a to <8 x i64>
18+
%ret = add <8 x i64> %sa, %y
19+
ret <8 x i64> %ret
20+
}
21+
22+
; Positive test (unsigned): same shape as the signed case but with zext.
; The expected output is a single masked vwaddu.wv (v0.t) with no vmerge.
define <8 x i64> @vwaddu_wv_mask_v8i32(<8 x i32> %x, <8 x i64> %y) {
23+
; CHECK-LABEL: vwaddu_wv_mask_v8i32:
24+
; CHECK: # %bb.0:
25+
; CHECK-NEXT: li a0, 42
26+
; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
27+
; CHECK-NEXT: vmslt.vx v0, v8, a0
28+
; CHECK-NEXT: vsetvli zero, zero, e32, m2, tu, mu
29+
; CHECK-NEXT: vwaddu.wv v12, v12, v8, v0.t
30+
; CHECK-NEXT: vmv4r.v v8, v12
31+
; CHECK-NEXT: ret
32+
%mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
33+
%a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> zeroinitializer
34+
%sa = zext <8 x i32> %a to <8 x i64>
35+
%ret = add <8 x i64> %sa, %y
36+
ret <8 x i64> %ret
37+
}
38+
39+
; Negative test: both addends are zero-extended narrow vectors, so the add
; is emitted as vwaddu.vv rather than the .wv form. The expected output
; still contains the vmerge.vvm, followed by an unmasked vwaddu.vv.
define <8 x i64> @vwaddu_vv_mask_v8i32(<8 x i32> %x, <8 x i32> %y) {
40+
; CHECK-LABEL: vwaddu_vv_mask_v8i32:
41+
; CHECK: # %bb.0:
42+
; CHECK-NEXT: li a0, 42
43+
; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
44+
; CHECK-NEXT: vmslt.vx v0, v8, a0
45+
; CHECK-NEXT: vmv.v.i v12, 0
46+
; CHECK-NEXT: vmerge.vvm v8, v12, v8, v0
47+
; CHECK-NEXT: vwaddu.vv v12, v8, v10
48+
; CHECK-NEXT: vmv4r.v v8, v12
49+
; CHECK-NEXT: ret
50+
%mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
51+
%a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> zeroinitializer
52+
%sa = zext <8 x i32> %a to <8 x i64>
53+
%sy = zext <8 x i32> %y to <8 x i64>
54+
%ret = add <8 x i64> %sa, %sy
55+
ret <8 x i64> %ret
56+
}
57+
58+
; Commuted operands: the add is written as (%y + sext(select)). The expected
; output is the same masked vwadd.wv (v0.t) as in the non-commuted case.
define <8 x i64> @vwadd_wv_mask_v8i32_commutative(<8 x i32> %x, <8 x i64> %y) {
59+
; CHECK-LABEL: vwadd_wv_mask_v8i32_commutative:
60+
; CHECK: # %bb.0:
61+
; CHECK-NEXT: li a0, 42
62+
; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
63+
; CHECK-NEXT: vmslt.vx v0, v8, a0
64+
; CHECK-NEXT: vsetvli zero, zero, e32, m2, tu, mu
65+
; CHECK-NEXT: vwadd.wv v12, v12, v8, v0.t
66+
; CHECK-NEXT: vmv4r.v v8, v12
67+
; CHECK-NEXT: ret
68+
%mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
69+
%a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> zeroinitializer
70+
%sa = sext <8 x i32> %a to <8 x i64>
71+
%ret = add <8 x i64> %y, %sa
72+
ret <8 x i64> %ret
73+
}
74+
75+
; Negative test: the false value of the select is a splat of 1, not zero.
; The expected output keeps the vmerge.vvm and uses an unmasked vwadd.wv.
define <8 x i64> @vwadd_wv_mask_v8i32_nonzero(<8 x i32> %x, <8 x i64> %y) {
76+
; CHECK-LABEL: vwadd_wv_mask_v8i32_nonzero:
77+
; CHECK: # %bb.0:
78+
; CHECK-NEXT: li a0, 42
79+
; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
80+
; CHECK-NEXT: vmslt.vx v0, v8, a0
81+
; CHECK-NEXT: vmv.v.i v10, 1
82+
; CHECK-NEXT: vmerge.vvm v16, v10, v8, v0
83+
; CHECK-NEXT: vwadd.wv v8, v12, v16
84+
; CHECK-NEXT: ret
85+
%mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
86+
%a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> <i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1>
87+
%sa = sext <8 x i32> %a to <8 x i64>
88+
%ret = add <8 x i64> %y, %sa
89+
ret <8 x i64> %ret
90+
}

0 commit comments

Comments
 (0)