Skip to content

Commit 11c1827

Browse files
authored
[RISCV] Use masked pseudo peephole for reduction pseudos (#71508)
After #71483 we now have a way of marking masked pseudos as having an unmasked equivalent, but their mask shouldn't be folded unless it's all ones since it would affect the result. This patch uses it to mark the pseudos for vredsum and friends, which in turn allows us to remove the unmasked patterns, and catch some other forms of vmerge.
1 parent c79b544 commit 11c1827

File tree

3 files changed

+8
-75
lines changed

3 files changed

+8
-75
lines changed

llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3213,7 +3213,8 @@ multiclass VPseudoTernaryWithTailPolicy<VReg RetClass,
32133213
defvar mx = MInfo.MX;
32143214
let isCommutable = Commutable in
32153215
def "_" # mx # "_E" # sew : VPseudoTernaryNoMaskWithPolicy<RetClass, Op1Class, Op2Class, Constraint>;
3216-
def "_" # mx # "_E" # sew # "_MASK" : VPseudoTernaryMaskPolicy<RetClass, Op1Class, Op2Class, Constraint>;
3216+
def "_" # mx # "_E" # sew # "_MASK" : VPseudoTernaryMaskPolicy<RetClass, Op1Class, Op2Class, Constraint>,
3217+
RISCVMaskedPseudo<MaskIdx=3, MaskAffectsRes=true>;
32173218
}
32183219
}
32193220

@@ -3232,7 +3233,8 @@ multiclass VPseudoTernaryWithTailPolicyRoundingMode<VReg RetClass,
32323233
Op2Class, Constraint>;
32333234
def "_" # mx # "_E" # sew # "_MASK"
32343235
: VPseudoTernaryMaskPolicyRoundingMode<RetClass, Op1Class,
3235-
Op2Class, Constraint>;
3236+
Op2Class, Constraint>,
3237+
RISCVMaskedPseudo<MaskIdx=3, MaskAffectsRes=true>;
32363238
}
32373239
}
32383240

llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td

Lines changed: 0 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1381,16 +1381,6 @@ multiclass VPatReductionVL<SDNode vop, string instruction_name, bit is_float> {
13811381
foreach vti = !if(is_float, AllFloatVectors, AllIntegerVectors) in {
13821382
defvar vti_m1 = !cast<VTypeInfo>(!if(is_float, "VF", "VI") # vti.SEW # "M1");
13831383
let Predicates = GetVTypePredicates<vti>.Predicates in {
1384-
def: Pat<(vti_m1.Vector (vop (vti_m1.Vector VR:$merge),
1385-
(vti.Vector vti.RegClass:$rs1), VR:$rs2,
1386-
(vti.Mask true_mask), VLOpFrag,
1387-
(XLenVT timm:$policy))),
1388-
(!cast<Instruction>(instruction_name#"_VS_"#vti.LMul.MX#"_E"#vti.SEW)
1389-
(vti_m1.Vector VR:$merge),
1390-
(vti.Vector vti.RegClass:$rs1),
1391-
(vti_m1.Vector VR:$rs2),
1392-
GPR:$vl, vti.Log2SEW, (XLenVT timm:$policy))>;
1393-
13941384
def: Pat<(vti_m1.Vector (vop (vti_m1.Vector VR:$merge),
13951385
(vti.Vector vti.RegClass:$rs1), VR:$rs2,
13961386
(vti.Mask V0), VLOpFrag,
@@ -1408,19 +1398,6 @@ multiclass VPatReductionVL_RM<SDNode vop, string instruction_name, bit is_float>
14081398
foreach vti = !if(is_float, AllFloatVectors, AllIntegerVectors) in {
14091399
defvar vti_m1 = !cast<VTypeInfo>(!if(is_float, "VF", "VI") # vti.SEW # "M1");
14101400
let Predicates = GetVTypePredicates<vti>.Predicates in {
1411-
def: Pat<(vti_m1.Vector (vop (vti_m1.Vector VR:$merge),
1412-
(vti.Vector vti.RegClass:$rs1), VR:$rs2,
1413-
(vti.Mask true_mask), VLOpFrag,
1414-
(XLenVT timm:$policy))),
1415-
(!cast<Instruction>(instruction_name#"_VS_"#vti.LMul.MX#"_E"#vti.SEW)
1416-
(vti_m1.Vector VR:$merge),
1417-
(vti.Vector vti.RegClass:$rs1),
1418-
(vti_m1.Vector VR:$rs2),
1419-
// Value to indicate no rounding mode change in
1420-
// RISCVInsertReadWriteCSR
1421-
FRM_DYN,
1422-
GPR:$vl, vti.Log2SEW, (XLenVT timm:$policy))>;
1423-
14241401
def: Pat<(vti_m1.Vector (vop (vti_m1.Vector VR:$merge),
14251402
(vti.Vector vti.RegClass:$rs1), VR:$rs2,
14261403
(vti.Mask V0), VLOpFrag,
@@ -1486,14 +1463,6 @@ multiclass VPatWidenReductionVL<SDNode vop, PatFrags extop, string instruction_n
14861463
defvar wti_m1 = !cast<VTypeInfo>(!if(is_float, "VF", "VI") # wti.SEW # "M1");
14871464
let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
14881465
GetVTypePredicates<wti>.Predicates) in {
1489-
def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
1490-
(wti.Vector (extop (vti.Vector vti.RegClass:$rs1))),
1491-
VR:$rs2, (vti.Mask true_mask), VLOpFrag,
1492-
(XLenVT timm:$policy))),
1493-
(!cast<Instruction>(instruction_name#"_VS_"#vti.LMul.MX#"_E"#vti.SEW)
1494-
(wti_m1.Vector VR:$merge), (vti.Vector vti.RegClass:$rs1),
1495-
(wti_m1.Vector VR:$rs2), GPR:$vl, vti.Log2SEW,
1496-
(XLenVT timm:$policy))>;
14971466
def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
14981467
(wti.Vector (extop (vti.Vector vti.RegClass:$rs1))),
14991468
VR:$rs2, (vti.Mask V0), VLOpFrag,
@@ -1513,18 +1482,6 @@ multiclass VPatWidenReductionVL_RM<SDNode vop, PatFrags extop, string instructio
15131482
defvar wti_m1 = !cast<VTypeInfo>(!if(is_float, "VF", "VI") # wti.SEW # "M1");
15141483
let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
15151484
GetVTypePredicates<wti>.Predicates) in {
1516-
def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
1517-
(wti.Vector (extop (vti.Vector vti.RegClass:$rs1))),
1518-
VR:$rs2, (vti.Mask true_mask), VLOpFrag,
1519-
(XLenVT timm:$policy))),
1520-
(!cast<Instruction>(instruction_name#"_VS_"#vti.LMul.MX#"_E"#vti.SEW)
1521-
(wti_m1.Vector VR:$merge), (vti.Vector vti.RegClass:$rs1),
1522-
(wti_m1.Vector VR:$rs2),
1523-
// Value to indicate no rounding mode change in
1524-
// RISCVInsertReadWriteCSR
1525-
FRM_DYN,
1526-
GPR:$vl, vti.Log2SEW,
1527-
(XLenVT timm:$policy))>;
15281485
def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
15291486
(wti.Vector (extop (vti.Vector vti.RegClass:$rs1))),
15301487
VR:$rs2, (vti.Mask V0), VLOpFrag,
@@ -1548,14 +1505,6 @@ multiclass VPatWidenReductionVL_Ext_VL<SDNode vop, PatFrags extop, string instru
15481505
defvar wti_m1 = !cast<VTypeInfo>(!if(is_float, "VF", "VI") # wti.SEW # "M1");
15491506
let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
15501507
GetVTypePredicates<wti>.Predicates) in {
1551-
def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
1552-
(wti.Vector (extop (vti.Vector vti.RegClass:$rs1), (vti.Mask true_mask), VLOpFrag)),
1553-
VR:$rs2, (vti.Mask true_mask), VLOpFrag,
1554-
(XLenVT timm:$policy))),
1555-
(!cast<Instruction>(instruction_name#"_VS_"#vti.LMul.MX#"_E"#vti.SEW)
1556-
(wti_m1.Vector VR:$merge), (vti.Vector vti.RegClass:$rs1),
1557-
(wti_m1.Vector VR:$rs2), GPR:$vl, vti.Log2SEW,
1558-
(XLenVT timm:$policy))>;
15591508
def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
15601509
(wti.Vector (extop (vti.Vector vti.RegClass:$rs1), (vti.Mask true_mask), VLOpFrag)),
15611510
VR:$rs2, (vti.Mask V0), VLOpFrag,
@@ -1575,18 +1524,6 @@ multiclass VPatWidenReductionVL_Ext_VL_RM<SDNode vop, PatFrags extop, string ins
15751524
defvar wti_m1 = !cast<VTypeInfo>(!if(is_float, "VF", "VI") # wti.SEW # "M1");
15761525
let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
15771526
GetVTypePredicates<wti>.Predicates) in {
1578-
def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
1579-
(wti.Vector (extop (vti.Vector vti.RegClass:$rs1), (vti.Mask true_mask), VLOpFrag)),
1580-
VR:$rs2, (vti.Mask true_mask), VLOpFrag,
1581-
(XLenVT timm:$policy))),
1582-
(!cast<Instruction>(instruction_name#"_VS_"#vti.LMul.MX#"_E"#vti.SEW)
1583-
(wti_m1.Vector VR:$merge), (vti.Vector vti.RegClass:$rs1),
1584-
(wti_m1.Vector VR:$rs2),
1585-
// Value to indicate no rounding mode change in
1586-
// RISCVInsertReadWriteCSR
1587-
FRM_DYN,
1588-
GPR:$vl, vti.Log2SEW,
1589-
(XLenVT timm:$policy))>;
15901527
def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
15911528
(wti.Vector (extop (vti.Vector vti.RegClass:$rs1), (vti.Mask true_mask), VLOpFrag)),
15921529
VR:$rs2, (vti.Mask V0), VLOpFrag,

llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-vops.ll

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1049,11 +1049,8 @@ define <vscale x 2 x float> @vfredusum(<vscale x 2 x float> %passthru, <vscale x
10491049
define <vscale x 2 x i32> @vredsum_allones_mask(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %x, <vscale x 2 x i32> %y, i64 %vl) {
10501050
; CHECK-LABEL: vredsum_allones_mask:
10511051
; CHECK: # %bb.0:
1052-
; CHECK-NEXT: vsetvli zero, a0, e32, m1, ta, ma
1053-
; CHECK-NEXT: vmv1r.v v11, v8
1054-
; CHECK-NEXT: vredsum.vs v11, v9, v10
1055-
; CHECK-NEXT: vsetvli zero, zero, e32, m1, tu, ma
1056-
; CHECK-NEXT: vmv.v.v v8, v11
1052+
; CHECK-NEXT: vsetvli zero, a0, e32, m1, tu, ma
1053+
; CHECK-NEXT: vredsum.vs v8, v9, v10
10571054
; CHECK-NEXT: ret
10581055
%splat = insertelement <vscale x 2 x i1> poison, i1 -1, i32 0
10591056
%mask = shufflevector <vscale x 2 x i1> %splat, <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer
@@ -1070,12 +1067,9 @@ define <vscale x 2 x i32> @vredsum_allones_mask(<vscale x 2 x i32> %passthru, <v
10701067
define <vscale x 2 x float> @vfredusum_allones_mask(<vscale x 2 x float> %passthru, <vscale x 2 x float> %x, <vscale x 2 x float> %y, i64 %vl) {
10711068
; CHECK-LABEL: vfredusum_allones_mask:
10721069
; CHECK: # %bb.0:
1073-
; CHECK-NEXT: vsetvli zero, a0, e32, m1, ta, ma
1070+
; CHECK-NEXT: vsetvli zero, a0, e32, m1, tu, ma
10741071
; CHECK-NEXT: fsrmi a0, 0
1075-
; CHECK-NEXT: vmv1r.v v11, v8
1076-
; CHECK-NEXT: vfredusum.vs v11, v9, v10
1077-
; CHECK-NEXT: vsetvli zero, zero, e32, m1, tu, ma
1078-
; CHECK-NEXT: vmv.v.v v8, v11
1072+
; CHECK-NEXT: vfredusum.vs v8, v9, v10
10791073
; CHECK-NEXT: fsrm a0
10801074
; CHECK-NEXT: ret
10811075
%splat = insertelement <vscale x 2 x i1> poison, i1 -1, i32 0

0 commit comments

Comments
 (0)