
Commit 08b20d8

[RISCV] Generalize reduction tree matching to fp sum reductions (#68599)
This builds on the transform introduced in f0505c3 and generalized to all integer operations in 45a334d. This change adds support for floating-point summation. A few notes:

* I chose to leave fmaxnum and fminnum unhandled for the moment; they have a slightly different set of legality rules.
* We could form strictly sequenced FADD reductions for FADDs without fast-math flags. As the ordered reductions are more expensive, I left thinking about this as a future exercise.
* This can't yet match the full vector reduce + start value idiom. That will be an upcoming set of changes.
1 parent 6668d14 commit 08b20d8
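
In effect, the combine now also recognizes scalarized float sum trees like the one below (an illustrative LLVM IR sketch, not taken from the patch; the function name is hypothetical), rewriting the chain of extracts and reassociable fadds into a single unordered vector reduction (vfredusum.vs on RISC-V) over the used prefix of the vector:

define float @sum_prefix3(ptr %p) {
  %v  = load <8 x float>, ptr %p
  %e0 = extractelement <8 x float> %v, i32 0
  %e1 = extractelement <8 x float> %v, i32 1
  %e2 = extractelement <8 x float> %v, i32 2
  ; 'fast' implies reassoc, so the backend may regroup this left-leaning
  ; add tree into one vector reduce over lanes 0..2.
  %a0 = fadd fast float %e0, %e1
  %a1 = fadd fast float %a0, %e2
  ret float %a1
}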

2 files changed (+174, -6)

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 15 additions & 6 deletions
@@ -11299,7 +11299,7 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
   }
 }
 
-/// Given an integer binary operator, return the generic ISD::VECREDUCE_OP
+/// Given a binary operator, return the *associative* generic ISD::VECREDUCE_OP
 /// which corresponds to it.
 static unsigned getVecReduceOpcode(unsigned Opc) {
   switch (Opc) {
@@ -11321,6 +11321,9 @@ static unsigned getVecReduceOpcode(unsigned Opc) {
     return ISD::VECREDUCE_OR;
   case ISD::XOR:
     return ISD::VECREDUCE_XOR;
+  case ISD::FADD:
+    // Note: This is the associative form of the generic reduction opcode.
+    return ISD::VECREDUCE_FADD;
   }
 }
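
An aside on the note above: ISD::VECREDUCE_FADD is the relaxed form that may reassociate lanes freely, while ISD::VECREDUCE_SEQ_FADD is the strictly ordered variant that takes an explicit start value. The same split is visible at the IR level, where the reassoc flag on llvm.vector.reduce.fadd permits the unordered lowering; a minimal sketch (operand values are hypothetical):

declare float @llvm.vector.reduce.fadd.v4f32(float, <4 x float>)

; With reassoc, lanes may be summed in any order; on RISC-V this can
; lower to the unordered vfredusum.vs.
%relaxed = call reassoc float @llvm.vector.reduce.fadd.v4f32(float -0.0, <4 x float> %v)

; Without reassoc, the sequential order start + v[0] + v[1] + ... must
; be preserved, matching the ordered vfredosum.vs form.
%ordered = call float @llvm.vector.reduce.fadd.v4f32(float %start, <4 x float> %v)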

@@ -11347,12 +11350,16 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
 
   const SDLoc DL(N);
   const EVT VT = N->getValueType(0);
+  const unsigned Opc = N->getOpcode();
 
-  // TODO: Handle floating point here.
-  if (!VT.isInteger())
+  // For FADD, we only handle the case with reassociation allowed.  We
+  // could handle strict reduction order, but at the moment, there's no
+  // known reason to, and the complexity isn't worth it.
+  // TODO: Handle fminnum and fmaxnum here
+  if (!VT.isInteger() &&
+      (Opc != ISD::FADD || !N->getFlags().hasAllowReassociation()))
     return SDValue();
 
-  const unsigned Opc = N->getOpcode();
   const unsigned ReduceOpc = getVecReduceOpcode(Opc);
   assert(Opc == ISD::getVecReduceBaseOpcode(ReduceOpc) &&
          "Inconsistent mappings");
@@ -11385,7 +11392,7 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
     EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, 2);
     SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
                               DAG.getVectorIdxConstant(0, DL));
-    return DAG.getNode(ReduceOpc, DL, VT, Vec);
+    return DAG.getNode(ReduceOpc, DL, VT, Vec, N->getFlags());
   }
 
   // Match (binop (reduce (extract_subvector V, 0),
@@ -11407,7 +11414,9 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
     EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, Idx + 1);
     SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
                               DAG.getVectorIdxConstant(0, DL));
-    return DAG.getNode(ReduceOpc, DL, VT, Vec);
+    auto Flags = ReduceVec->getFlags();
+    Flags.intersectWith(N->getFlags());
+    return DAG.getNode(ReduceOpc, DL, VT, Vec, Flags);
   }
 }
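
The intersectWith here is deliberately conservative: when an existing partial reduction is folded together with one more fadd, the merged node may only carry the fast-math flags common to both inputs. An IR-level analogy of the idea (the combine itself operates on DAG node flags; the values shown are hypothetical):

declare float @llvm.vector.reduce.fadd.v2f32(float, <2 x float>)

; A partial reduction formed from 'fast' adds, folded with a
; 'reassoc'-only fadd: the wider reduction may keep only 'reassoc',
; the intersection of the two flag sets.
%partial = call fast float @llvm.vector.reduce.fadd.v2f32(float -0.0, <2 x float> %lo)
%sum     = fadd reassoc float %partial, %e2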

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll

Lines changed: 159 additions & 0 deletions
@@ -764,6 +764,165 @@ define i32 @reduce_umin_16xi32_prefix5(ptr %p) {
   %umin3 = call i32 @llvm.umin.i32(i32 %umin2, i32 %e4)
   ret i32 %umin3
 }
+
+define float @reduce_fadd_16xf32_prefix2(ptr %p) {
+; CHECK-LABEL: reduce_fadd_16xf32_prefix2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vmv.s.x v9, zero
+; CHECK-NEXT:    vfredusum.vs v8, v8, v9
+; CHECK-NEXT:    vfmv.f.s fa0, v8
+; CHECK-NEXT:    ret
+  %v = load <16 x float>, ptr %p, align 256
+  %e0 = extractelement <16 x float> %v, i32 0
+  %e1 = extractelement <16 x float> %v, i32 1
+  %fadd0 = fadd fast float %e0, %e1
+  ret float %fadd0
+}
+
+define float @reduce_fadd_16xi32_prefix5(ptr %p) {
+; CHECK-LABEL: reduce_fadd_16xi32_prefix5:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    lui a1, 524288
+; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vmv.s.x v10, a1
+; CHECK-NEXT:    vsetivli zero, 6, e32, m2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 5
+; CHECK-NEXT:    vsetivli zero, 7, e32, m2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 6
+; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 7
+; CHECK-NEXT:    vfredusum.vs v8, v8, v10
+; CHECK-NEXT:    vfmv.f.s fa0, v8
+; CHECK-NEXT:    ret
+  %v = load <16 x float>, ptr %p, align 256
+  %e0 = extractelement <16 x float> %v, i32 0
+  %e1 = extractelement <16 x float> %v, i32 1
+  %e2 = extractelement <16 x float> %v, i32 2
+  %e3 = extractelement <16 x float> %v, i32 3
+  %e4 = extractelement <16 x float> %v, i32 4
+  %fadd0 = fadd fast float %e0, %e1
+  %fadd1 = fadd fast float %fadd0, %e2
+  %fadd2 = fadd fast float %fadd1, %e3
+  %fadd3 = fadd fast float %fadd2, %e4
+  ret float %fadd3
+}
+
+;; Corner case tests for fadd associativity
+
+; Negative test, not associative. Would need strict opcode.
+define float @reduce_fadd_2xf32_non_associative(ptr %p) {
+; CHECK-LABEL: reduce_fadd_2xf32_non_associative:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vfmv.f.s fa5, v8
+; CHECK-NEXT:    vslidedown.vi v8, v8, 1
+; CHECK-NEXT:    vfmv.f.s fa4, v8
+; CHECK-NEXT:    fadd.s fa0, fa5, fa4
+; CHECK-NEXT:    ret
+  %v = load <2 x float>, ptr %p, align 256
+  %e0 = extractelement <2 x float> %v, i32 0
+  %e1 = extractelement <2 x float> %v, i32 1
+  %fadd0 = fadd float %e0, %e1
+  ret float %fadd0
+}
+
+; Positive test - minimal set of fast math flags
+define float @reduce_fadd_2xf32_reassoc_only(ptr %p) {
+; CHECK-LABEL: reduce_fadd_2xf32_reassoc_only:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    lui a0, 524288
+; CHECK-NEXT:    vmv.s.x v9, a0
+; CHECK-NEXT:    vfredusum.vs v8, v8, v9
+; CHECK-NEXT:    vfmv.f.s fa0, v8
+; CHECK-NEXT:    ret
+  %v = load <2 x float>, ptr %p, align 256
+  %e0 = extractelement <2 x float> %v, i32 0
+  %e1 = extractelement <2 x float> %v, i32 1
+  %fadd0 = fadd reassoc float %e0, %e1
+  ret float %fadd0
+}
+
+; Negative test - wrong fast math flag.
+define float @reduce_fadd_2xf32_ninf_only(ptr %p) {
+; CHECK-LABEL: reduce_fadd_2xf32_ninf_only:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vfmv.f.s fa5, v8
+; CHECK-NEXT:    vslidedown.vi v8, v8, 1
+; CHECK-NEXT:    vfmv.f.s fa4, v8
+; CHECK-NEXT:    fadd.s fa0, fa5, fa4
+; CHECK-NEXT:    ret
+  %v = load <2 x float>, ptr %p, align 256
+  %e0 = extractelement <2 x float> %v, i32 0
+  %e1 = extractelement <2 x float> %v, i32 1
+  %fadd0 = fadd ninf float %e0, %e1
+  ret float %fadd0
+}
+
+
+; Negative test - last fadd is not associative
+define float @reduce_fadd_4xi32_non_associative(ptr %p) {
+; CHECK-LABEL: reduce_fadd_4xi32_non_associative:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vslidedown.vi v9, v8, 3
+; CHECK-NEXT:    vfmv.f.s fa5, v9
+; CHECK-NEXT:    lui a0, 524288
+; CHECK-NEXT:    vmv.s.x v9, a0
+; CHECK-NEXT:    vslideup.vi v8, v9, 3
+; CHECK-NEXT:    vfredusum.vs v8, v8, v9
+; CHECK-NEXT:    vfmv.f.s fa4, v8
+; CHECK-NEXT:    fadd.s fa0, fa4, fa5
+; CHECK-NEXT:    ret
+  %v = load <4 x float>, ptr %p, align 256
+  %e0 = extractelement <4 x float> %v, i32 0
+  %e1 = extractelement <4 x float> %v, i32 1
+  %e2 = extractelement <4 x float> %v, i32 2
+  %e3 = extractelement <4 x float> %v, i32 3
+  %fadd0 = fadd fast float %e0, %e1
+  %fadd1 = fadd fast float %fadd0, %e2
+  %fadd2 = fadd float %fadd1, %e3
+  ret float %fadd2
+}
+
+; Negative test - first fadd is not associative
+; We could form a reduce for elements 2 and 3.
+define float @reduce_fadd_4xi32_non_associative2(ptr %p) {
+; CHECK-LABEL: reduce_fadd_4xi32_non_associative2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vfmv.f.s fa5, v8
+; CHECK-NEXT:    vslidedown.vi v9, v8, 1
+; CHECK-NEXT:    vfmv.f.s fa4, v9
+; CHECK-NEXT:    vslidedown.vi v9, v8, 2
+; CHECK-NEXT:    vfmv.f.s fa3, v9
+; CHECK-NEXT:    vslidedown.vi v8, v8, 3
+; CHECK-NEXT:    vfmv.f.s fa2, v8
+; CHECK-NEXT:    fadd.s fa5, fa5, fa4
+; CHECK-NEXT:    fadd.s fa4, fa3, fa2
+; CHECK-NEXT:    fadd.s fa0, fa5, fa4
+; CHECK-NEXT:    ret
+  %v = load <4 x float>, ptr %p, align 256
+  %e0 = extractelement <4 x float> %v, i32 0
+  %e1 = extractelement <4 x float> %v, i32 1
+  %e2 = extractelement <4 x float> %v, i32 2
+  %e3 = extractelement <4 x float> %v, i32 3
+  %fadd0 = fadd float %e0, %e1
+  %fadd1 = fadd fast float %fadd0, %e2
+  %fadd2 = fadd fast float %fadd1, %e3
+  ret float %fadd2
+}
+
+
 ;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
 ; RV32: {{.*}}
 ; RV64: {{.*}}
