Skip to content

Commit b30e8f6

Browse files
committed
[RISCV] Use masked pseudo peephole for reduction pseudos
After llvm#71483 we have a way of marking a masked pseudo as having an unmasked equivalent while also recording that its mask must not be folded away unless it is all ones, because the mask affects the result. This patch applies that marking to the pseudos for vredsum and friends, which in turn allows us to remove the unmasked patterns and to remove vmerges entirely when the mask is known to be all ones.
1 parent fd48044 commit b30e8f6

File tree

3 files changed

+8
-75
lines changed

3 files changed

+8
-75
lines changed

llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3213,7 +3213,8 @@ multiclass VPseudoTernaryWithTailPolicy<VReg RetClass,
32133213
defvar mx = MInfo.MX;
32143214
let isCommutable = Commutable in
32153215
def "_" # mx # "_E" # sew : VPseudoTernaryNoMaskWithPolicy<RetClass, Op1Class, Op2Class, Constraint>;
3216-
def "_" # mx # "_E" # sew # "_MASK" : VPseudoTernaryMaskPolicy<RetClass, Op1Class, Op2Class, Constraint>;
3216+
def "_" # mx # "_E" # sew # "_MASK" : VPseudoTernaryMaskPolicy<RetClass, Op1Class, Op2Class, Constraint>,
3217+
RISCVMaskedPseudo<MaskIdx=3, MaskAffectsRes=true>;
32173218
}
32183219
}
32193220

@@ -3232,7 +3233,8 @@ multiclass VPseudoTernaryWithTailPolicyRoundingMode<VReg RetClass,
32323233
Op2Class, Constraint>;
32333234
def "_" # mx # "_E" # sew # "_MASK"
32343235
: VPseudoTernaryMaskPolicyRoundingMode<RetClass, Op1Class,
3235-
Op2Class, Constraint>;
3236+
Op2Class, Constraint>,
3237+
RISCVMaskedPseudo<MaskIdx=3, MaskAffectsRes=true>;
32363238
}
32373239
}
32383240

llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td

Lines changed: 0 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1381,16 +1381,6 @@ multiclass VPatReductionVL<SDNode vop, string instruction_name, bit is_float> {
13811381
foreach vti = !if(is_float, AllFloatVectors, AllIntegerVectors) in {
13821382
defvar vti_m1 = !cast<VTypeInfo>(!if(is_float, "VF", "VI") # vti.SEW # "M1");
13831383
let Predicates = GetVTypePredicates<vti>.Predicates in {
1384-
def: Pat<(vti_m1.Vector (vop (vti_m1.Vector VR:$merge),
1385-
(vti.Vector vti.RegClass:$rs1), VR:$rs2,
1386-
(vti.Mask true_mask), VLOpFrag,
1387-
(XLenVT timm:$policy))),
1388-
(!cast<Instruction>(instruction_name#"_VS_"#vti.LMul.MX#"_E"#vti.SEW)
1389-
(vti_m1.Vector VR:$merge),
1390-
(vti.Vector vti.RegClass:$rs1),
1391-
(vti_m1.Vector VR:$rs2),
1392-
GPR:$vl, vti.Log2SEW, (XLenVT timm:$policy))>;
1393-
13941384
def: Pat<(vti_m1.Vector (vop (vti_m1.Vector VR:$merge),
13951385
(vti.Vector vti.RegClass:$rs1), VR:$rs2,
13961386
(vti.Mask V0), VLOpFrag,
@@ -1408,19 +1398,6 @@ multiclass VPatReductionVL_RM<SDNode vop, string instruction_name, bit is_float>
14081398
foreach vti = !if(is_float, AllFloatVectors, AllIntegerVectors) in {
14091399
defvar vti_m1 = !cast<VTypeInfo>(!if(is_float, "VF", "VI") # vti.SEW # "M1");
14101400
let Predicates = GetVTypePredicates<vti>.Predicates in {
1411-
def: Pat<(vti_m1.Vector (vop (vti_m1.Vector VR:$merge),
1412-
(vti.Vector vti.RegClass:$rs1), VR:$rs2,
1413-
(vti.Mask true_mask), VLOpFrag,
1414-
(XLenVT timm:$policy))),
1415-
(!cast<Instruction>(instruction_name#"_VS_"#vti.LMul.MX#"_E"#vti.SEW)
1416-
(vti_m1.Vector VR:$merge),
1417-
(vti.Vector vti.RegClass:$rs1),
1418-
(vti_m1.Vector VR:$rs2),
1419-
// Value to indicate no rounding mode change in
1420-
// RISCVInsertReadWriteCSR
1421-
FRM_DYN,
1422-
GPR:$vl, vti.Log2SEW, (XLenVT timm:$policy))>;
1423-
14241401
def: Pat<(vti_m1.Vector (vop (vti_m1.Vector VR:$merge),
14251402
(vti.Vector vti.RegClass:$rs1), VR:$rs2,
14261403
(vti.Mask V0), VLOpFrag,
@@ -1486,14 +1463,6 @@ multiclass VPatWidenReductionVL<SDNode vop, PatFrags extop, string instruction_n
14861463
defvar wti_m1 = !cast<VTypeInfo>(!if(is_float, "VF", "VI") # wti.SEW # "M1");
14871464
let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
14881465
GetVTypePredicates<wti>.Predicates) in {
1489-
def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
1490-
(wti.Vector (extop (vti.Vector vti.RegClass:$rs1))),
1491-
VR:$rs2, (vti.Mask true_mask), VLOpFrag,
1492-
(XLenVT timm:$policy))),
1493-
(!cast<Instruction>(instruction_name#"_VS_"#vti.LMul.MX#"_E"#vti.SEW)
1494-
(wti_m1.Vector VR:$merge), (vti.Vector vti.RegClass:$rs1),
1495-
(wti_m1.Vector VR:$rs2), GPR:$vl, vti.Log2SEW,
1496-
(XLenVT timm:$policy))>;
14971466
def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
14981467
(wti.Vector (extop (vti.Vector vti.RegClass:$rs1))),
14991468
VR:$rs2, (vti.Mask V0), VLOpFrag,
@@ -1513,18 +1482,6 @@ multiclass VPatWidenReductionVL_RM<SDNode vop, PatFrags extop, string instructio
15131482
defvar wti_m1 = !cast<VTypeInfo>(!if(is_float, "VF", "VI") # wti.SEW # "M1");
15141483
let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
15151484
GetVTypePredicates<wti>.Predicates) in {
1516-
def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
1517-
(wti.Vector (extop (vti.Vector vti.RegClass:$rs1))),
1518-
VR:$rs2, (vti.Mask true_mask), VLOpFrag,
1519-
(XLenVT timm:$policy))),
1520-
(!cast<Instruction>(instruction_name#"_VS_"#vti.LMul.MX#"_E"#vti.SEW)
1521-
(wti_m1.Vector VR:$merge), (vti.Vector vti.RegClass:$rs1),
1522-
(wti_m1.Vector VR:$rs2),
1523-
// Value to indicate no rounding mode change in
1524-
// RISCVInsertReadWriteCSR
1525-
FRM_DYN,
1526-
GPR:$vl, vti.Log2SEW,
1527-
(XLenVT timm:$policy))>;
15281485
def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
15291486
(wti.Vector (extop (vti.Vector vti.RegClass:$rs1))),
15301487
VR:$rs2, (vti.Mask V0), VLOpFrag,
@@ -1548,14 +1505,6 @@ multiclass VPatWidenReductionVL_Ext_VL<SDNode vop, PatFrags extop, string instru
15481505
defvar wti_m1 = !cast<VTypeInfo>(!if(is_float, "VF", "VI") # wti.SEW # "M1");
15491506
let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
15501507
GetVTypePredicates<wti>.Predicates) in {
1551-
def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
1552-
(wti.Vector (extop (vti.Vector vti.RegClass:$rs1), (vti.Mask true_mask), VLOpFrag)),
1553-
VR:$rs2, (vti.Mask true_mask), VLOpFrag,
1554-
(XLenVT timm:$policy))),
1555-
(!cast<Instruction>(instruction_name#"_VS_"#vti.LMul.MX#"_E"#vti.SEW)
1556-
(wti_m1.Vector VR:$merge), (vti.Vector vti.RegClass:$rs1),
1557-
(wti_m1.Vector VR:$rs2), GPR:$vl, vti.Log2SEW,
1558-
(XLenVT timm:$policy))>;
15591508
def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
15601509
(wti.Vector (extop (vti.Vector vti.RegClass:$rs1), (vti.Mask true_mask), VLOpFrag)),
15611510
VR:$rs2, (vti.Mask V0), VLOpFrag,
@@ -1575,18 +1524,6 @@ multiclass VPatWidenReductionVL_Ext_VL_RM<SDNode vop, PatFrags extop, string ins
15751524
defvar wti_m1 = !cast<VTypeInfo>(!if(is_float, "VF", "VI") # wti.SEW # "M1");
15761525
let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
15771526
GetVTypePredicates<wti>.Predicates) in {
1578-
def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
1579-
(wti.Vector (extop (vti.Vector vti.RegClass:$rs1), (vti.Mask true_mask), VLOpFrag)),
1580-
VR:$rs2, (vti.Mask true_mask), VLOpFrag,
1581-
(XLenVT timm:$policy))),
1582-
(!cast<Instruction>(instruction_name#"_VS_"#vti.LMul.MX#"_E"#vti.SEW)
1583-
(wti_m1.Vector VR:$merge), (vti.Vector vti.RegClass:$rs1),
1584-
(wti_m1.Vector VR:$rs2),
1585-
// Value to indicate no rounding mode change in
1586-
// RISCVInsertReadWriteCSR
1587-
FRM_DYN,
1588-
GPR:$vl, vti.Log2SEW,
1589-
(XLenVT timm:$policy))>;
15901527
def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
15911528
(wti.Vector (extop (vti.Vector vti.RegClass:$rs1), (vti.Mask true_mask), VLOpFrag)),
15921529
VR:$rs2, (vti.Mask V0), VLOpFrag,

llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-vops.ll

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1049,11 +1049,8 @@ define <vscale x 2 x float> @vfredusum(<vscale x 2 x float> %passthru, <vscale x
10491049
define <vscale x 2 x i32> @vredsum_allones_mask(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %x, <vscale x 2 x i32> %y, i64 %vl) {
10501050
; CHECK-LABEL: vredsum_allones_mask:
10511051
; CHECK: # %bb.0:
1052-
; CHECK-NEXT: vsetvli zero, a0, e32, m1, ta, ma
1053-
; CHECK-NEXT: vmv1r.v v11, v8
1054-
; CHECK-NEXT: vredsum.vs v11, v9, v10
1055-
; CHECK-NEXT: vsetvli zero, zero, e32, m1, tu, ma
1056-
; CHECK-NEXT: vmv.v.v v8, v11
1052+
; CHECK-NEXT: vsetvli zero, a0, e32, m1, tu, ma
1053+
; CHECK-NEXT: vredsum.vs v8, v9, v10
10571054
; CHECK-NEXT: ret
10581055
%splat = insertelement <vscale x 2 x i1> poison, i1 -1, i32 0
10591056
%mask = shufflevector <vscale x 2 x i1> %splat, <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer
@@ -1070,12 +1067,9 @@ define <vscale x 2 x i32> @vredsum_allones_mask(<vscale x 2 x i32> %passthru, <v
10701067
define <vscale x 2 x float> @vfredusum_allones_mask(<vscale x 2 x float> %passthru, <vscale x 2 x float> %x, <vscale x 2 x float> %y, i64 %vl) {
10711068
; CHECK-LABEL: vfredusum_allones_mask:
10721069
; CHECK: # %bb.0:
1073-
; CHECK-NEXT: vsetvli zero, a0, e32, m1, ta, ma
1070+
; CHECK-NEXT: vsetvli zero, a0, e32, m1, tu, ma
10741071
; CHECK-NEXT: fsrmi a0, 0
1075-
; CHECK-NEXT: vmv1r.v v11, v8
1076-
; CHECK-NEXT: vfredusum.vs v11, v9, v10
1077-
; CHECK-NEXT: vsetvli zero, zero, e32, m1, tu, ma
1078-
; CHECK-NEXT: vmv.v.v v8, v11
1072+
; CHECK-NEXT: vfredusum.vs v8, v9, v10
10791073
; CHECK-NEXT: fsrm a0
10801074
; CHECK-NEXT: ret
10811075
%splat = insertelement <vscale x 2 x i1> poison, i1 -1, i32 0

0 commit comments

Comments
 (0)