Skip to content

Commit b277bf5

Browse files
[LLVM][CodeGen][SVE] Clean up lowering of VECTOR_SPLICE operations. (llvm#91330)
Remove the DAG combine that was performing type legalisation and instead add isel patterns for all legal types.
1 parent 1aca8ed commit b277bf5

File tree

5 files changed

+136
-49
lines changed

5 files changed

+136
-49
lines changed

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12240,9 +12240,8 @@ void SelectionDAGBuilder::visitVectorSplice(const CallInst &I) {
1224012240

1224112241
// VECTOR_SHUFFLE doesn't support a scalable mask so use a dedicated node.
1224212242
if (VT.isScalableVector()) {
12243-
MVT IdxVT = TLI.getVectorIdxTy(DAG.getDataLayout());
1224412243
setValue(&I, DAG.getNode(ISD::VECTOR_SPLICE, DL, VT, V1, V2,
12245-
DAG.getConstant(Imm, DL, IdxVT)));
12244+
DAG.getVectorIdxConstant(Imm, DL)));
1224612245
return;
1224712246
}
1224812247

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 7 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1048,9 +1048,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
10481048
setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN);
10491049

10501050
setTargetDAGCombine({ISD::ANY_EXTEND, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND,
1051-
ISD::VECTOR_SPLICE, ISD::SIGN_EXTEND_INREG,
1052-
ISD::CONCAT_VECTORS, ISD::EXTRACT_SUBVECTOR,
1053-
ISD::INSERT_SUBVECTOR, ISD::STORE, ISD::BUILD_VECTOR});
1051+
ISD::SIGN_EXTEND_INREG, ISD::CONCAT_VECTORS,
1052+
ISD::EXTRACT_SUBVECTOR, ISD::INSERT_SUBVECTOR,
1053+
ISD::STORE, ISD::BUILD_VECTOR});
10541054
setTargetDAGCombine(ISD::TRUNCATE);
10551055
setTargetDAGCombine(ISD::LOAD);
10561056

@@ -1580,6 +1580,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
15801580
setOperationAction(ISD::MLOAD, VT, Custom);
15811581
setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
15821582
setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
1583+
setOperationAction(ISD::VECTOR_SPLICE, VT, Custom);
15831584

15841585
if (!Subtarget->isLittleEndian())
15851586
setOperationAction(ISD::BITCAST, VT, Expand);
@@ -10102,10 +10103,9 @@ SDValue AArch64TargetLowering::LowerVECTOR_SPLICE(SDValue Op,
1010210103
Op.getOperand(1));
1010310104
}
1010410105

10105-
// This will select to an EXT instruction, which has a maximum immediate
10106-
// value of 255, hence 2048-bits is the maximum value we can lower.
10107-
if (IdxVal >= 0 &&
10108-
IdxVal < int64_t(2048 / Ty.getVectorElementType().getSizeInBits()))
10106+
// We can select to an EXT instruction when indexing the first 256 bytes.
10107+
unsigned BlockSize = AArch64::SVEBitsPerBlock / Ty.getVectorMinNumElements();
10108+
if (IdxVal >= 0 && (IdxVal * BlockSize / 8) < 256)
1010910109
return Op;
1011010110

1011110111
return SDValue();
@@ -24255,28 +24255,6 @@ performInsertVectorEltCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
2425524255
return performPostLD1Combine(N, DCI, true);
2425624256
}
2425724257

24258-
static SDValue performSVESpliceCombine(SDNode *N, SelectionDAG &DAG) {
24259-
EVT Ty = N->getValueType(0);
24260-
if (Ty.isInteger())
24261-
return SDValue();
24262-
24263-
EVT IntTy = Ty.changeVectorElementTypeToInteger();
24264-
EVT ExtIntTy = getPackedSVEVectorVT(IntTy.getVectorElementCount());
24265-
if (ExtIntTy.getVectorElementType().getScalarSizeInBits() <
24266-
IntTy.getVectorElementType().getScalarSizeInBits())
24267-
return SDValue();
24268-
24269-
SDLoc DL(N);
24270-
SDValue LHS = DAG.getAnyExtOrTrunc(DAG.getBitcast(IntTy, N->getOperand(0)),
24271-
DL, ExtIntTy);
24272-
SDValue RHS = DAG.getAnyExtOrTrunc(DAG.getBitcast(IntTy, N->getOperand(1)),
24273-
DL, ExtIntTy);
24274-
SDValue Idx = N->getOperand(2);
24275-
SDValue Splice = DAG.getNode(ISD::VECTOR_SPLICE, DL, ExtIntTy, LHS, RHS, Idx);
24276-
SDValue Trunc = DAG.getAnyExtOrTrunc(Splice, DL, IntTy);
24277-
return DAG.getBitcast(Ty, Trunc);
24278-
}
24279-
2428024258
static SDValue performFPExtendCombine(SDNode *N, SelectionDAG &DAG,
2428124259
TargetLowering::DAGCombinerInfo &DCI,
2428224260
const AArch64Subtarget *Subtarget) {
@@ -24661,8 +24639,6 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
2466124639
case ISD::MGATHER:
2466224640
case ISD::MSCATTER:
2466324641
return performMaskedGatherScatterCombine(N, DCI, DAG);
24664-
case ISD::VECTOR_SPLICE:
24665-
return performSVESpliceCombine(N, DAG);
2466624642
case ISD::FP_EXTEND:
2466724643
return performFPExtendCombine(N, DAG, DCI, Subtarget);
2466824644
case AArch64ISD::BRCOND:

llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1994,14 +1994,21 @@ let Predicates = [HasSVEorSME] in {
19941994
(LASTB_VPZ_D (PTRUE_D 31), ZPR:$Z1), dsub))>;
19951995

19961996
// Splice with lane bigger or equal to 0
1997-
def : Pat<(nxv16i8 (vector_splice (nxv16i8 ZPR:$Z1), (nxv16i8 ZPR:$Z2), (i64 (sve_ext_imm_0_255 i32:$index)))),
1998-
(EXT_ZZI ZPR:$Z1, ZPR:$Z2, imm0_255:$index)>;
1999-
def : Pat<(nxv8i16 (vector_splice (nxv8i16 ZPR:$Z1), (nxv8i16 ZPR:$Z2), (i64 (sve_ext_imm_0_127 i32:$index)))),
2000-
(EXT_ZZI ZPR:$Z1, ZPR:$Z2, imm0_255:$index)>;
2001-
def : Pat<(nxv4i32 (vector_splice (nxv4i32 ZPR:$Z1), (nxv4i32 ZPR:$Z2), (i64 (sve_ext_imm_0_63 i32:$index)))),
2002-
(EXT_ZZI ZPR:$Z1, ZPR:$Z2, imm0_255:$index)>;
2003-
def : Pat<(nxv2i64 (vector_splice (nxv2i64 ZPR:$Z1), (nxv2i64 ZPR:$Z2), (i64 (sve_ext_imm_0_31 i32:$index)))),
2004-
(EXT_ZZI ZPR:$Z1, ZPR:$Z2, imm0_255:$index)>;
1997+
foreach VT = [nxv16i8] in
1998+
def : Pat<(VT (vector_splice (VT ZPR:$Z1), (VT ZPR:$Z2), (i64 (sve_ext_imm_0_255 i32:$index)))),
1999+
(EXT_ZZI ZPR:$Z1, ZPR:$Z2, imm0_255:$index)>;
2000+
2001+
foreach VT = [nxv8i16, nxv8f16, nxv8bf16] in
2002+
def : Pat<(VT (vector_splice (VT ZPR:$Z1), (VT ZPR:$Z2), (i64 (sve_ext_imm_0_127 i32:$index)))),
2003+
(EXT_ZZI ZPR:$Z1, ZPR:$Z2, imm0_255:$index)>;
2004+
2005+
foreach VT = [nxv4i32, nxv4f16, nxv4f32, nxv4bf16] in
2006+
def : Pat<(VT (vector_splice (VT ZPR:$Z1), (VT ZPR:$Z2), (i64 (sve_ext_imm_0_63 i32:$index)))),
2007+
(EXT_ZZI ZPR:$Z1, ZPR:$Z2, imm0_255:$index)>;
2008+
2009+
foreach VT = [nxv2i64, nxv2f16, nxv2f32, nxv2f64, nxv2bf16] in
2010+
def : Pat<(VT (vector_splice (VT ZPR:$Z1), (VT ZPR:$Z2), (i64 (sve_ext_imm_0_31 i32:$index)))),
2011+
(EXT_ZZI ZPR:$Z1, ZPR:$Z2, imm0_255:$index)>;
20052012

20062013
defm CMPHS_PPzZZ : sve_int_cmp_0<0b000, "cmphs", SETUGE, SETULE>;
20072014
defm CMPHI_PPzZZ : sve_int_cmp_0<0b001, "cmphi", SETUGT, SETULT>;

llvm/lib/Target/AArch64/SVEInstrFormats.td

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7060,16 +7060,17 @@ multiclass sve_int_perm_splice<string asm, SDPatternOperator op> {
70607060
def _S : sve_int_perm_splice<0b10, asm, ZPR32>;
70617061
def _D : sve_int_perm_splice<0b11, asm, ZPR64>;
70627062

7063-
def : SVE_3_Op_Pat<nxv16i8, op, nxv16i1, nxv16i8, nxv16i8, !cast<Instruction>(NAME # _B)>;
7064-
def : SVE_3_Op_Pat<nxv8i16, op, nxv8i1, nxv8i16, nxv8i16, !cast<Instruction>(NAME # _H)>;
7065-
def : SVE_3_Op_Pat<nxv4i32, op, nxv4i1, nxv4i32, nxv4i32, !cast<Instruction>(NAME # _S)>;
7066-
def : SVE_3_Op_Pat<nxv2i64, op, nxv2i1, nxv2i64, nxv2i64, !cast<Instruction>(NAME # _D)>;
7063+
foreach VT = [nxv16i8] in
7064+
def : SVE_3_Op_Pat<VT, op, nxv16i1, VT, VT, !cast<Instruction>(NAME # _B)>;
70677065

7068-
def : SVE_3_Op_Pat<nxv8f16, op, nxv8i1, nxv8f16, nxv8f16, !cast<Instruction>(NAME # _H)>;
7069-
def : SVE_3_Op_Pat<nxv4f32, op, nxv4i1, nxv4f32, nxv4f32, !cast<Instruction>(NAME # _S)>;
7070-
def : SVE_3_Op_Pat<nxv2f64, op, nxv2i1, nxv2f64, nxv2f64, !cast<Instruction>(NAME # _D)>;
7066+
foreach VT = [nxv8i16, nxv8f16, nxv8bf16] in
7067+
def : SVE_3_Op_Pat<VT, op, nxv8i1, VT, VT, !cast<Instruction>(NAME # _H)>;
70717068

7072-
def : SVE_3_Op_Pat<nxv8bf16, op, nxv8i1, nxv8bf16, nxv8bf16, !cast<Instruction>(NAME # _H)>;
7069+
foreach VT = [nxv4i32, nxv4f16, nxv4f32, nxv4bf16] in
7070+
def : SVE_3_Op_Pat<VT, op, nxv4i1, VT, VT, !cast<Instruction>(NAME # _S)>;
7071+
7072+
foreach VT = [nxv2i64, nxv2f16, nxv2f32, nxv2f64, nxv2bf16] in
7073+
def : SVE_3_Op_Pat<VT, op, nxv2i1, VT, VT, !cast<Instruction>(NAME # _D)>;
70737074
}
70747075

70757076
class sve2_int_perm_splice_cons<bits<2> sz8_64, string asm,

llvm/test/CodeGen/AArch64/named-vector-shuffles-sve.ll

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -692,6 +692,104 @@ define <vscale x 2 x double> @splice_nxv2f64_neg3(<vscale x 2 x double> %a, <vsc
692692
ret <vscale x 2 x double> %res
693693
}
694694

695+
define <vscale x 2 x bfloat> @splice_nxv2bf16_neg_idx(<vscale x 2 x bfloat> %a, <vscale x 2 x bfloat> %b) #0 {
696+
; CHECK-LABEL: splice_nxv2bf16_neg_idx:
697+
; CHECK: // %bb.0:
698+
; CHECK-NEXT: ptrue p0.d, vl1
699+
; CHECK-NEXT: rev p0.d, p0.d
700+
; CHECK-NEXT: splice z0.d, p0, z0.d, z1.d
701+
; CHECK-NEXT: ret
702+
%res = call <vscale x 2 x bfloat> @llvm.vector.splice.nxv2bf16(<vscale x 2 x bfloat> %a, <vscale x 2 x bfloat> %b, i32 -1)
703+
ret <vscale x 2 x bfloat> %res
704+
}
705+
706+
define <vscale x 2 x bfloat> @splice_nxv2bf16_neg2_idx(<vscale x 2 x bfloat> %a, <vscale x 2 x bfloat> %b) #0 {
707+
; CHECK-LABEL: splice_nxv2bf16_neg2_idx:
708+
; CHECK: // %bb.0:
709+
; CHECK-NEXT: ptrue p0.d, vl2
710+
; CHECK-NEXT: rev p0.d, p0.d
711+
; CHECK-NEXT: splice z0.d, p0, z0.d, z1.d
712+
; CHECK-NEXT: ret
713+
%res = call <vscale x 2 x bfloat> @llvm.vector.splice.nxv2bf16(<vscale x 2 x bfloat> %a, <vscale x 2 x bfloat> %b, i32 -2)
714+
ret <vscale x 2 x bfloat> %res
715+
}
716+
717+
define <vscale x 2 x bfloat> @splice_nxv2bf16_first_idx(<vscale x 2 x bfloat> %a, <vscale x 2 x bfloat> %b) #0 {
718+
; CHECK-LABEL: splice_nxv2bf16_first_idx:
719+
; CHECK: // %bb.0:
720+
; CHECK-NEXT: ext z0.b, z0.b, z1.b, #8
721+
; CHECK-NEXT: ret
722+
%res = call <vscale x 2 x bfloat> @llvm.vector.splice.nxv2bf16(<vscale x 2 x bfloat> %a, <vscale x 2 x bfloat> %b, i32 1)
723+
ret <vscale x 2 x bfloat> %res
724+
}
725+
726+
define <vscale x 2 x bfloat> @splice_nxv2bf16_last_idx(<vscale x 2 x bfloat> %a, <vscale x 2 x bfloat> %b) vscale_range(16,16) #0 {
727+
; CHECK-LABEL: splice_nxv2bf16_last_idx:
728+
; CHECK: // %bb.0:
729+
; CHECK-NEXT: ext z0.b, z0.b, z1.b, #248
730+
; CHECK-NEXT: ret
731+
%res = call <vscale x 2 x bfloat> @llvm.vector.splice.nxv2bf16(<vscale x 2 x bfloat> %a, <vscale x 2 x bfloat> %b, i32 31)
732+
ret <vscale x 2 x bfloat> %res
733+
}
734+
735+
define <vscale x 4 x bfloat> @splice_nxv4bf16_neg_idx(<vscale x 4 x bfloat> %a, <vscale x 4 x bfloat> %b) #0 {
736+
; CHECK-LABEL: splice_nxv4bf16_neg_idx:
737+
; CHECK: // %bb.0:
738+
; CHECK-NEXT: ptrue p0.s, vl1
739+
; CHECK-NEXT: rev p0.s, p0.s
740+
; CHECK-NEXT: splice z0.s, p0, z0.s, z1.s
741+
; CHECK-NEXT: ret
742+
%res = call <vscale x 4 x bfloat> @llvm.vector.splice.nxv4bf16(<vscale x 4 x bfloat> %a, <vscale x 4 x bfloat> %b, i32 -1)
743+
ret <vscale x 4 x bfloat> %res
744+
}
745+
746+
define <vscale x 4 x bfloat> @splice_nxv4bf16_neg3_idx(<vscale x 4 x bfloat> %a, <vscale x 4 x bfloat> %b) #0 {
747+
; CHECK-LABEL: splice_nxv4bf16_neg3_idx:
748+
; CHECK: // %bb.0:
749+
; CHECK-NEXT: ptrue p0.s, vl3
750+
; CHECK-NEXT: rev p0.s, p0.s
751+
; CHECK-NEXT: splice z0.s, p0, z0.s, z1.s
752+
; CHECK-NEXT: ret
753+
%res = call <vscale x 4 x bfloat> @llvm.vector.splice.nxv4bf16(<vscale x 4 x bfloat> %a, <vscale x 4 x bfloat> %b, i32 -3)
754+
ret <vscale x 4 x bfloat> %res
755+
}
756+
757+
define <vscale x 4 x bfloat> @splice_nxv4bf16_first_idx(<vscale x 4 x bfloat> %a, <vscale x 4 x bfloat> %b) #0 {
758+
; CHECK-LABEL: splice_nxv4bf16_first_idx:
759+
; CHECK: // %bb.0:
760+
; CHECK-NEXT: ext z0.b, z0.b, z1.b, #4
761+
; CHECK-NEXT: ret
762+
%res = call <vscale x 4 x bfloat> @llvm.vector.splice.nxv4bf16(<vscale x 4 x bfloat> %a, <vscale x 4 x bfloat> %b, i32 1)
763+
ret <vscale x 4 x bfloat> %res
764+
}
765+
766+
define <vscale x 4 x bfloat> @splice_nxv4bf16_last_idx(<vscale x 4 x bfloat> %a, <vscale x 4 x bfloat> %b) vscale_range(16,16) #0 {
767+
; CHECK-LABEL: splice_nxv4bf16_last_idx:
768+
; CHECK: // %bb.0:
769+
; CHECK-NEXT: ext z0.b, z0.b, z1.b, #252
770+
; CHECK-NEXT: ret
771+
%res = call <vscale x 4 x bfloat> @llvm.vector.splice.nxv4bf16(<vscale x 4 x bfloat> %a, <vscale x 4 x bfloat> %b, i32 63)
772+
ret <vscale x 4 x bfloat> %res
773+
}
774+
775+
define <vscale x 8 x bfloat> @splice_nxv8bf16_first_idx(<vscale x 8 x bfloat> %a, <vscale x 8 x bfloat> %b) #0 {
776+
; CHECK-LABEL: splice_nxv8bf16_first_idx:
777+
; CHECK: // %bb.0:
778+
; CHECK-NEXT: ext z0.b, z0.b, z1.b, #2
779+
; CHECK-NEXT: ret
780+
%res = call <vscale x 8 x bfloat> @llvm.vector.splice.nxv8bf16(<vscale x 8 x bfloat> %a, <vscale x 8 x bfloat> %b, i32 1)
781+
ret <vscale x 8 x bfloat> %res
782+
}
783+
784+
define <vscale x 8 x bfloat> @splice_nxv8bf16_last_idx(<vscale x 8 x bfloat> %a, <vscale x 8 x bfloat> %b) vscale_range(16,16) #0 {
785+
; CHECK-LABEL: splice_nxv8bf16_last_idx:
786+
; CHECK: // %bb.0:
787+
; CHECK-NEXT: ext z0.b, z0.b, z1.b, #254
788+
; CHECK-NEXT: ret
789+
%res = call <vscale x 8 x bfloat> @llvm.vector.splice.nxv8bf16(<vscale x 8 x bfloat> %a, <vscale x 8 x bfloat> %b, i32 127)
790+
ret <vscale x 8 x bfloat> %res
791+
}
792+
695793
; Ensure predicate based splice is promoted to use ZPRs.
696794
define <vscale x 2 x i1> @splice_nxv2i1(<vscale x 2 x i1> %a, <vscale x 2 x i1> %b) #0 {
697795
; CHECK-LABEL: splice_nxv2i1:
@@ -834,12 +932,14 @@ declare <vscale x 2 x i1> @llvm.vector.splice.nxv2i1(<vscale x 2 x i1>, <vscale
834932
declare <vscale x 4 x i1> @llvm.vector.splice.nxv4i1(<vscale x 4 x i1>, <vscale x 4 x i1>, i32)
835933
declare <vscale x 8 x i1> @llvm.vector.splice.nxv8i1(<vscale x 8 x i1>, <vscale x 8 x i1>, i32)
836934
declare <vscale x 16 x i1> @llvm.vector.splice.nxv16i1(<vscale x 16 x i1>, <vscale x 16 x i1>, i32)
935+
837936
declare <vscale x 2 x i8> @llvm.vector.splice.nxv2i8(<vscale x 2 x i8>, <vscale x 2 x i8>, i32)
838937
declare <vscale x 16 x i8> @llvm.vector.splice.nxv16i8(<vscale x 16 x i8>, <vscale x 16 x i8>, i32)
839938
declare <vscale x 8 x i16> @llvm.vector.splice.nxv8i16(<vscale x 8 x i16>, <vscale x 8 x i16>, i32)
840939
declare <vscale x 4 x i32> @llvm.vector.splice.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, i32)
841940
declare <vscale x 8 x i32> @llvm.vector.splice.nxv8i32(<vscale x 8 x i32>, <vscale x 8 x i32>, i32)
842941
declare <vscale x 2 x i64> @llvm.vector.splice.nxv2i64(<vscale x 2 x i64>, <vscale x 2 x i64>, i32)
942+
843943
declare <vscale x 2 x half> @llvm.vector.splice.nxv2f16(<vscale x 2 x half>, <vscale x 2 x half>, i32)
844944
declare <vscale x 4 x half> @llvm.vector.splice.nxv4f16(<vscale x 4 x half>, <vscale x 4 x half>, i32)
845945
declare <vscale x 8 x half> @llvm.vector.splice.nxv8f16(<vscale x 8 x half>, <vscale x 8 x half>, i32)
@@ -848,4 +948,8 @@ declare <vscale x 4 x float> @llvm.vector.splice.nxv4f32(<vscale x 4 x float>, <
848948
declare <vscale x 16 x float> @llvm.vector.splice.nxv16f32(<vscale x 16 x float>, <vscale x 16 x float>, i32)
849949
declare <vscale x 2 x double> @llvm.vector.splice.nxv2f64(<vscale x 2 x double>, <vscale x 2 x double>, i32)
850950

951+
declare <vscale x 2 x bfloat> @llvm.vector.splice.nxv2bf16(<vscale x 2 x bfloat>, <vscale x 2 x bfloat>, i32)
952+
declare <vscale x 4 x bfloat> @llvm.vector.splice.nxv4bf16(<vscale x 4 x bfloat>, <vscale x 4 x bfloat>, i32)
953+
declare <vscale x 8 x bfloat> @llvm.vector.splice.nxv8bf16(<vscale x 8 x bfloat>, <vscale x 8 x bfloat>, i32)
954+
851955
attributes #0 = { nounwind "target-features"="+sve" }

0 commit comments

Comments (0)