Skip to content

Commit 94f6b6d

Browse files
authored
[SelectionDAG][RISCV] Promote VECREDUCE_{FMAX,FMIN,FMAXIMUM,FMINIMUM} (#128800)
This patch also adds the tests for VP_REDUCE_{FMAX,FMIN,FMAXIMUM,FMINIMUM}, which have been supported for a while.
1 parent 24abf2c commit 94f6b6d

File tree

7 files changed

+838
-15
lines changed

7 files changed

+838
-15
lines changed

llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2913,31 +2913,34 @@ SDValue SelectionDAGLegalize::ExpandPARITY(SDValue Op, const SDLoc &dl) {
29132913
}
29142914

29152915
SDValue SelectionDAGLegalize::PromoteReduction(SDNode *Node) {
2916-
MVT VecVT = Node->getOperand(1).getSimpleValueType();
2916+
bool IsVPOpcode = ISD::isVPOpcode(Node->getOpcode());
2917+
MVT VecVT = IsVPOpcode ? Node->getOperand(1).getSimpleValueType()
2918+
: Node->getOperand(0).getSimpleValueType();
29172919
MVT NewVecVT = TLI.getTypeToPromoteTo(Node->getOpcode(), VecVT);
29182920
MVT ScalarVT = Node->getSimpleValueType(0);
29192921
MVT NewScalarVT = NewVecVT.getVectorElementType();
29202922

29212923
SDLoc DL(Node);
29222924
SmallVector<SDValue, 4> Operands(Node->getNumOperands());
29232925

2924-
// promote the initial value.
29252926
// FIXME: Support integer.
29262927
assert(Node->getOperand(0).getValueType().isFloatingPoint() &&
29272928
"Only FP promotion is supported");
2928-
Operands[0] =
2929-
DAG.getNode(ISD::FP_EXTEND, DL, NewScalarVT, Node->getOperand(0));
29302929

2931-
for (unsigned j = 1; j != Node->getNumOperands(); ++j)
2930+
for (unsigned j = 0; j != Node->getNumOperands(); ++j)
29322931
if (Node->getOperand(j).getValueType().isVector() &&
2933-
!(ISD::isVPOpcode(Node->getOpcode()) &&
2932+
!(IsVPOpcode &&
29342933
ISD::getVPMaskIdx(Node->getOpcode()) == j)) { // Skip mask operand.
29352934
// promote the vector operand.
29362935
// FIXME: Support integer.
29372936
assert(Node->getOperand(j).getValueType().isFloatingPoint() &&
29382937
"Only FP promotion is supported");
29392938
Operands[j] =
29402939
DAG.getNode(ISD::FP_EXTEND, DL, NewVecVT, Node->getOperand(j));
2940+
} else if (Node->getOperand(j).getValueType().isFloatingPoint()) {
2941+
// promote the initial value.
2942+
Operands[j] =
2943+
DAG.getNode(ISD::FP_EXTEND, DL, NewScalarVT, Node->getOperand(j));
29412944
} else {
29422945
Operands[j] = Node->getOperand(j); // Skip VL operand.
29432946
}
@@ -5049,7 +5052,11 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
50495052
Node->getOpcode() == ISD::SINT_TO_FP ||
50505053
Node->getOpcode() == ISD::SETCC ||
50515054
Node->getOpcode() == ISD::EXTRACT_VECTOR_ELT ||
5052-
Node->getOpcode() == ISD::INSERT_VECTOR_ELT) {
5055+
Node->getOpcode() == ISD::INSERT_VECTOR_ELT ||
5056+
Node->getOpcode() == ISD::VECREDUCE_FMAX ||
5057+
Node->getOpcode() == ISD::VECREDUCE_FMIN ||
5058+
Node->getOpcode() == ISD::VECREDUCE_FMAXIMUM ||
5059+
Node->getOpcode() == ISD::VECREDUCE_FMINIMUM) {
50535060
OVT = Node->getOperand(0).getSimpleValueType();
50545061
}
50555062
if (Node->getOpcode() == ISD::ATOMIC_STORE ||
@@ -5796,6 +5803,10 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
57965803
DAG.getIntPtrConstant(0, dl, /*isTarget=*/true)));
57975804
break;
57985805
}
5806+
case ISD::VECREDUCE_FMAX:
5807+
case ISD::VECREDUCE_FMIN:
5808+
case ISD::VECREDUCE_FMAXIMUM:
5809+
case ISD::VECREDUCE_FMINIMUM:
57995810
case ISD::VP_REDUCE_FMAX:
58005811
case ISD::VP_REDUCE_FMIN:
58015812
case ISD::VP_REDUCE_FMAXIMUM:

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -503,13 +503,19 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
503503
case ISD::VECREDUCE_UMIN:
504504
case ISD::VECREDUCE_FADD:
505505
case ISD::VECREDUCE_FMUL:
506+
case ISD::VECTOR_FIND_LAST_ACTIVE:
507+
Action = TLI.getOperationAction(Node->getOpcode(),
508+
Node->getOperand(0).getValueType());
509+
break;
506510
case ISD::VECREDUCE_FMAX:
507511
case ISD::VECREDUCE_FMIN:
508512
case ISD::VECREDUCE_FMAXIMUM:
509513
case ISD::VECREDUCE_FMINIMUM:
510-
case ISD::VECTOR_FIND_LAST_ACTIVE:
511514
Action = TLI.getOperationAction(Node->getOpcode(),
512515
Node->getOperand(0).getValueType());
516+
// Defer non-vector results to LegalizeDAG.
517+
if (Action == TargetLowering::Promote)
518+
Action = TargetLowering::Legal;
513519
break;
514520
case ISD::VECREDUCE_SEQ_FADD:
515521
case ISD::VECREDUCE_SEQ_FMUL:

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -959,13 +959,35 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
959959

960960
// TODO: support more ops.
961961
static const unsigned ZvfhminZvfbfminPromoteOps[] = {
962-
ISD::FMINNUM, ISD::FMAXNUM, ISD::FADD, ISD::FSUB,
963-
ISD::FMUL, ISD::FMA, ISD::FDIV, ISD::FSQRT,
964-
ISD::FCEIL, ISD::FTRUNC, ISD::FFLOOR, ISD::FROUND,
965-
ISD::FROUNDEVEN, ISD::FRINT, ISD::FNEARBYINT, ISD::IS_FPCLASS,
966-
ISD::SETCC, ISD::FMAXIMUM, ISD::FMINIMUM, ISD::STRICT_FADD,
967-
ISD::STRICT_FSUB, ISD::STRICT_FMUL, ISD::STRICT_FDIV, ISD::STRICT_FSQRT,
968-
ISD::STRICT_FMA};
962+
ISD::FMINNUM,
963+
ISD::FMAXNUM,
964+
ISD::FADD,
965+
ISD::FSUB,
966+
ISD::FMUL,
967+
ISD::FMA,
968+
ISD::FDIV,
969+
ISD::FSQRT,
970+
ISD::FCEIL,
971+
ISD::FTRUNC,
972+
ISD::FFLOOR,
973+
ISD::FROUND,
974+
ISD::FROUNDEVEN,
975+
ISD::FRINT,
976+
ISD::FNEARBYINT,
977+
ISD::IS_FPCLASS,
978+
ISD::SETCC,
979+
ISD::FMAXIMUM,
980+
ISD::FMINIMUM,
981+
ISD::STRICT_FADD,
982+
ISD::STRICT_FSUB,
983+
ISD::STRICT_FMUL,
984+
ISD::STRICT_FDIV,
985+
ISD::STRICT_FSQRT,
986+
ISD::STRICT_FMA,
987+
ISD::VECREDUCE_FMIN,
988+
ISD::VECREDUCE_FMAX,
989+
ISD::VECREDUCE_FMINIMUM,
990+
ISD::VECREDUCE_FMAXIMUM};
969991

970992
// TODO: support more vp ops.
971993
static const unsigned ZvfhminZvfbfminPromoteVPOps[] = {
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
2+
; RUN: llc -mtriple=riscv32 -mattr=+d,+zfbfmin,+zvfbfmin,+v -target-abi=ilp32d \
3+
; RUN: -verify-machineinstrs < %s | FileCheck %s
4+
; RUN: llc -mtriple=riscv64 -mattr=+d,+zfbfmin,+zvfbfmin,+v -target-abi=lp64d \
5+
; RUN: -verify-machineinstrs < %s | FileCheck %s
6+
7+
define bfloat @vreduce_fmin_nxv4bf16(<vscale x 4 x bfloat> %val) {
8+
; CHECK-LABEL: vreduce_fmin_nxv4bf16:
9+
; CHECK: # %bb.0:
10+
; CHECK-NEXT: vsetvli a0, zero, e16, m1, ta, ma
11+
; CHECK-NEXT: vfwcvtbf16.f.f.v v10, v8
12+
; CHECK-NEXT: vsetvli zero, zero, e32, m2, ta, ma
13+
; CHECK-NEXT: vfredmin.vs v8, v10, v10
14+
; CHECK-NEXT: vfmv.f.s fa5, v8
15+
; CHECK-NEXT: fcvt.bf16.s fa0, fa5
16+
; CHECK-NEXT: ret
17+
%s = call bfloat @llvm.vector.reduce.fmin.nxv4bf16(<vscale x 4 x bfloat> %val)
18+
ret bfloat %s
19+
}
20+
21+
define bfloat @vreduce_fmax_nxv4bf16(<vscale x 4 x bfloat> %val) {
22+
; CHECK-LABEL: vreduce_fmax_nxv4bf16:
23+
; CHECK: # %bb.0:
24+
; CHECK-NEXT: vsetvli a0, zero, e16, m1, ta, ma
25+
; CHECK-NEXT: vfwcvtbf16.f.f.v v10, v8
26+
; CHECK-NEXT: vsetvli zero, zero, e32, m2, ta, ma
27+
; CHECK-NEXT: vfredmax.vs v8, v10, v10
28+
; CHECK-NEXT: vfmv.f.s fa5, v8
29+
; CHECK-NEXT: fcvt.bf16.s fa0, fa5
30+
; CHECK-NEXT: ret
31+
%s = call bfloat @llvm.vector.reduce.fmax.nxv4bf16(<vscale x 4 x bfloat> %val)
32+
ret bfloat %s
33+
}
34+
35+
define bfloat @vreduce_fmin_nnan_nxv4bf16(<vscale x 4 x bfloat> %val) {
36+
; CHECK-LABEL: vreduce_fmin_nnan_nxv4bf16:
37+
; CHECK: # %bb.0:
38+
; CHECK-NEXT: vsetvli a0, zero, e16, m1, ta, ma
39+
; CHECK-NEXT: vfwcvtbf16.f.f.v v10, v8
40+
; CHECK-NEXT: vsetvli zero, zero, e32, m2, ta, ma
41+
; CHECK-NEXT: vfredmin.vs v8, v10, v10
42+
; CHECK-NEXT: vfmv.f.s fa5, v8
43+
; CHECK-NEXT: fcvt.bf16.s fa0, fa5
44+
; CHECK-NEXT: ret
45+
%s = call nnan bfloat @llvm.vector.reduce.fmin.nxv4bf16(<vscale x 4 x bfloat> %val)
46+
ret bfloat %s
47+
}
48+
49+
define bfloat @vreduce_fmax_nnan_nxv4bf16(<vscale x 4 x bfloat> %val) {
50+
; CHECK-LABEL: vreduce_fmax_nnan_nxv4bf16:
51+
; CHECK: # %bb.0:
52+
; CHECK-NEXT: vsetvli a0, zero, e16, m1, ta, ma
53+
; CHECK-NEXT: vfwcvtbf16.f.f.v v10, v8
54+
; CHECK-NEXT: vsetvli zero, zero, e32, m2, ta, ma
55+
; CHECK-NEXT: vfredmax.vs v8, v10, v10
56+
; CHECK-NEXT: vfmv.f.s fa5, v8
57+
; CHECK-NEXT: fcvt.bf16.s fa0, fa5
58+
; CHECK-NEXT: ret
59+
%s = call nnan bfloat @llvm.vector.reduce.fmax.nxv4bf16(<vscale x 4 x bfloat> %val)
60+
ret bfloat %s
61+
}
62+
63+
define bfloat @vreduce_fminimum_nxv4bf16(<vscale x 4 x bfloat> %val) {
64+
; CHECK-LABEL: vreduce_fminimum_nxv4bf16:
65+
; CHECK: # %bb.0:
66+
; CHECK-NEXT: vsetvli a0, zero, e16, m1, ta, ma
67+
; CHECK-NEXT: vfwcvtbf16.f.f.v v10, v8
68+
; CHECK-NEXT: vsetvli zero, zero, e32, m2, ta, ma
69+
; CHECK-NEXT: vmfne.vv v8, v10, v10
70+
; CHECK-NEXT: vcpop.m a0, v8
71+
; CHECK-NEXT: beqz a0, .LBB4_2
72+
; CHECK-NEXT: # %bb.1:
73+
; CHECK-NEXT: lui a0, 523264
74+
; CHECK-NEXT: fmv.w.x fa5, a0
75+
; CHECK-NEXT: fcvt.bf16.s fa0, fa5
76+
; CHECK-NEXT: ret
77+
; CHECK-NEXT: .LBB4_2:
78+
; CHECK-NEXT: vfredmin.vs v8, v10, v10
79+
; CHECK-NEXT: vfmv.f.s fa5, v8
80+
; CHECK-NEXT: fcvt.bf16.s fa0, fa5
81+
; CHECK-NEXT: ret
82+
%s = call bfloat @llvm.vector.reduce.fminimum.nxv4bf16(<vscale x 4 x bfloat> %val)
83+
ret bfloat %s
84+
}
85+
86+
define bfloat @vreduce_fmaximum_nxv4bf16(<vscale x 4 x bfloat> %val) {
87+
; CHECK-LABEL: vreduce_fmaximum_nxv4bf16:
88+
; CHECK: # %bb.0:
89+
; CHECK-NEXT: vsetvli a0, zero, e16, m1, ta, ma
90+
; CHECK-NEXT: vfwcvtbf16.f.f.v v10, v8
91+
; CHECK-NEXT: vsetvli zero, zero, e32, m2, ta, ma
92+
; CHECK-NEXT: vmfne.vv v8, v10, v10
93+
; CHECK-NEXT: vcpop.m a0, v8
94+
; CHECK-NEXT: beqz a0, .LBB5_2
95+
; CHECK-NEXT: # %bb.1:
96+
; CHECK-NEXT: lui a0, 523264
97+
; CHECK-NEXT: fmv.w.x fa5, a0
98+
; CHECK-NEXT: fcvt.bf16.s fa0, fa5
99+
; CHECK-NEXT: ret
100+
; CHECK-NEXT: .LBB5_2:
101+
; CHECK-NEXT: vfredmax.vs v8, v10, v10
102+
; CHECK-NEXT: vfmv.f.s fa5, v8
103+
; CHECK-NEXT: fcvt.bf16.s fa0, fa5
104+
; CHECK-NEXT: ret
105+
%s = call bfloat @llvm.vector.reduce.fmaximum.nxv4bf16(<vscale x 4 x bfloat> %val)
106+
ret bfloat %s
107+
}
108+
109+
define bfloat @vreduce_fminimum_nnan_nxv4bf16(<vscale x 4 x bfloat> %val) {
110+
; CHECK-LABEL: vreduce_fminimum_nnan_nxv4bf16:
111+
; CHECK: # %bb.0:
112+
; CHECK-NEXT: vsetvli a0, zero, e16, m1, ta, ma
113+
; CHECK-NEXT: vfwcvtbf16.f.f.v v10, v8
114+
; CHECK-NEXT: vsetvli zero, zero, e32, m2, ta, ma
115+
; CHECK-NEXT: vfredmin.vs v8, v10, v10
116+
; CHECK-NEXT: vfmv.f.s fa5, v8
117+
; CHECK-NEXT: fcvt.bf16.s fa0, fa5
118+
; CHECK-NEXT: ret
119+
%s = call nnan bfloat @llvm.vector.reduce.fminimum.nxv4bf16(<vscale x 4 x bfloat> %val)
120+
ret bfloat %s
121+
}
122+
123+
define bfloat @vreduce_fmaximum_nnan_nxv4bf16(<vscale x 4 x bfloat> %val) {
124+
; CHECK-LABEL: vreduce_fmaximum_nnan_nxv4bf16:
125+
; CHECK: # %bb.0:
126+
; CHECK-NEXT: vsetvli a0, zero, e16, m1, ta, ma
127+
; CHECK-NEXT: vfwcvtbf16.f.f.v v10, v8
128+
; CHECK-NEXT: vsetvli zero, zero, e32, m2, ta, ma
129+
; CHECK-NEXT: vfredmax.vs v8, v10, v10
130+
; CHECK-NEXT: vfmv.f.s fa5, v8
131+
; CHECK-NEXT: fcvt.bf16.s fa0, fa5
132+
; CHECK-NEXT: ret
133+
%s = call nnan bfloat @llvm.vector.reduce.fmaximum.nxv4bf16(<vscale x 4 x bfloat> %val)
134+
ret bfloat %s
135+
}
136+

0 commit comments

Comments
 (0)