[LLVM][CodeGen][SVE] Clean up lowering of VECTOR_SPLICE operations. #91330

Merged (2 commits) on May 10, 2024
llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (3 changes: 1 addition & 2 deletions)
@@ -12249,9 +12249,8 @@ void SelectionDAGBuilder::visitVectorSplice(const CallInst &I) {

  // VECTOR_SHUFFLE doesn't support a scalable mask so use a dedicated node.
  if (VT.isScalableVector()) {
-    MVT IdxVT = TLI.getVectorIdxTy(DAG.getDataLayout());
    setValue(&I, DAG.getNode(ISD::VECTOR_SPLICE, DL, VT, V1, V2,
-                             DAG.getConstant(Imm, DL, IdxVT)));
+                             DAG.getVectorIdxConstant(Imm, DL)));
    return;
  }
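The removed pair of calls is exactly what the new helper wraps. A minimal sketch of `SelectionDAG::getVectorIdxConstant`, reconstructed from the removed lines rather than quoted from the LLVM sources:

```cpp
// Sketch (assumed shape, not a verbatim copy of SelectionDAG.cpp): the helper
// folds the getVectorIdxTy() lookup and the getConstant() call into one step,
// so callers such as visitVectorSplice no longer spell out the index type.
SDValue SelectionDAG::getVectorIdxConstant(uint64_t Val, const SDLoc &DL,
                                           bool IsTarget) {
  return getConstant(Val, DL, TLI->getVectorIdxTy(getDataLayout()), IsTarget);
}
```

The `IsTarget` parameter is an assumption carried over from `getConstant`; the point is only that the index-type plumbing moves out of every call site.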
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (38 changes: 7 additions & 31 deletions)
@@ -1048,9 +1048,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
  setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN);

  setTargetDAGCombine({ISD::ANY_EXTEND, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND,
-                       ISD::VECTOR_SPLICE, ISD::SIGN_EXTEND_INREG,
-                       ISD::CONCAT_VECTORS, ISD::EXTRACT_SUBVECTOR,
-                       ISD::INSERT_SUBVECTOR, ISD::STORE, ISD::BUILD_VECTOR});
+                       ISD::SIGN_EXTEND_INREG, ISD::CONCAT_VECTORS,
+                       ISD::EXTRACT_SUBVECTOR, ISD::INSERT_SUBVECTOR,
+                       ISD::STORE, ISD::BUILD_VECTOR});
  setTargetDAGCombine(ISD::TRUNCATE);
  setTargetDAGCombine(ISD::LOAD);
@@ -1580,6 +1580,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
      setOperationAction(ISD::MLOAD, VT, Custom);
      setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
      setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
+      setOperationAction(ISD::VECTOR_SPLICE, VT, Custom);

      if (!Subtarget->isLittleEndian())
        setOperationAction(ISD::BITCAST, VT, Expand);
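Marking the operation `Custom` is what routes these nodes into the target-specific hook during legalization. A hedged sketch of the dispatch (a fragment of the usual `LowerOperation` switch; the surrounding cases are assumed, not shown in this diff):

```cpp
// Fragment sketch: with the action set to Custom, DAG legalization calls back
// into the target, and the opcode switch forwards VECTOR_SPLICE nodes to the
// dedicated lowering routine shown in the next hunk.
SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
                                              SelectionDAG &DAG) const {
  switch (Op.getOpcode()) {
  // ... many other opcodes elided ...
  case ISD::VECTOR_SPLICE:
    return LowerVECTOR_SPLICE(Op, DAG);
  // ...
  }
}
```

With this line in place, bf16 splices reach `LowerVECTOR_SPLICE` directly instead of relying on the DAG combine that the patch deletes below.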
@@ -10102,10 +10103,9 @@ SDValue AArch64TargetLowering::LowerVECTOR_SPLICE(SDValue Op,
                       Op.getOperand(1));
  }

-  // This will select to an EXT instruction, which has a maximum immediate
-  // value of 255, hence 2048-bits is the maximum value we can lower.
-  if (IdxVal >= 0 &&
-      IdxVal < int64_t(2048 / Ty.getVectorElementType().getSizeInBits()))
+  // We can select to an EXT instruction when indexing the first 256 bytes.
+  unsigned BlockSize = AArch64::SVEBitsPerBlock / Ty.getVectorMinNumElements();
+  if (IdxVal >= 0 && (IdxVal * BlockSize / 8) < 256)
    return Op;

  return SDValue();
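The rewritten bound is easy to sanity-check numerically. A standalone sketch (assuming `AArch64::SVEBitsPerBlock` is 128, the minimum SVE register width) that derives the largest in-range splice index per element count:

```cpp
#include <cstdio>

// Mirrors the new check outside LLVM: EXT encodes a byte offset in an 8-bit
// immediate, so an index is lowerable while IdxVal * BlockSize / 8 < 256,
// where BlockSize is the per-element stride within a 128-bit block. Unpacked
// types such as nxv2bf16 have fewer elements per block, hence a wider stride.
int main() {
  const unsigned SVEBitsPerBlock = 128; // assumption: AArch64 SVE block size
  for (unsigned MinNumElts : {16u, 8u, 4u, 2u}) {
    unsigned BlockSize = SVEBitsPerBlock / MinNumElts; // bits per element slot
    unsigned MaxIdx = (256 * 8) / BlockSize - 1; // last index under 256 bytes
    printf("MinNumElts=%2u  stride=%3u bits  max splice index=%3u\n",
           MinNumElts, BlockSize, MaxIdx);
  }
  return 0;
}
```

The four maxima it prints (255, 127, 63, 31) line up with the `sve_ext_imm_0_255` through `sve_ext_imm_0_31` predicates used by the AArch64SVEInstrInfo.td patterns later in this diff.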
@@ -24237,28 +24237,6 @@ performInsertVectorEltCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
  return performPostLD1Combine(N, DCI, true);
}

-static SDValue performSVESpliceCombine(SDNode *N, SelectionDAG &DAG) {
-  EVT Ty = N->getValueType(0);
-  if (Ty.isInteger())
-    return SDValue();
-
-  EVT IntTy = Ty.changeVectorElementTypeToInteger();
-  EVT ExtIntTy = getPackedSVEVectorVT(IntTy.getVectorElementCount());
-  if (ExtIntTy.getVectorElementType().getScalarSizeInBits() <
-      IntTy.getVectorElementType().getScalarSizeInBits())
-    return SDValue();
-
-  SDLoc DL(N);
-  SDValue LHS = DAG.getAnyExtOrTrunc(DAG.getBitcast(IntTy, N->getOperand(0)),
-                                     DL, ExtIntTy);
-  SDValue RHS = DAG.getAnyExtOrTrunc(DAG.getBitcast(IntTy, N->getOperand(1)),
-                                     DL, ExtIntTy);
-  SDValue Idx = N->getOperand(2);
-  SDValue Splice = DAG.getNode(ISD::VECTOR_SPLICE, DL, ExtIntTy, LHS, RHS, Idx);
-  SDValue Trunc = DAG.getAnyExtOrTrunc(Splice, DL, IntTy);
-  return DAG.getBitcast(Ty, Trunc);
-}
-
 static SDValue performFPExtendCombine(SDNode *N, SelectionDAG &DAG,
                                       TargetLowering::DAGCombinerInfo &DCI,
                                       const AArch64Subtarget *Subtarget) {
@@ -24643,8 +24621,6 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
  case ISD::MGATHER:
  case ISD::MSCATTER:
    return performMaskedGatherScatterCombine(N, DCI, DAG);
-  case ISD::VECTOR_SPLICE:
-    return performSVESpliceCombine(N, DAG);
  case ISD::FP_EXTEND:
    return performFPExtendCombine(N, DAG, DCI, Subtarget);
  case AArch64ISD::BRCOND:
llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td (23 changes: 15 additions & 8 deletions)
@@ -1994,14 +1994,21 @@ let Predicates = [HasSVEorSME] in {
                  (LASTB_VPZ_D (PTRUE_D 31), ZPR:$Z1), dsub))>;

  // Splice with lane bigger or equal to 0
-  def : Pat<(nxv16i8 (vector_splice (nxv16i8 ZPR:$Z1), (nxv16i8 ZPR:$Z2), (i64 (sve_ext_imm_0_255 i32:$index)))),
-            (EXT_ZZI ZPR:$Z1, ZPR:$Z2, imm0_255:$index)>;
-  def : Pat<(nxv8i16 (vector_splice (nxv8i16 ZPR:$Z1), (nxv8i16 ZPR:$Z2), (i64 (sve_ext_imm_0_127 i32:$index)))),
-            (EXT_ZZI ZPR:$Z1, ZPR:$Z2, imm0_255:$index)>;
-  def : Pat<(nxv4i32 (vector_splice (nxv4i32 ZPR:$Z1), (nxv4i32 ZPR:$Z2), (i64 (sve_ext_imm_0_63 i32:$index)))),
-            (EXT_ZZI ZPR:$Z1, ZPR:$Z2, imm0_255:$index)>;
-  def : Pat<(nxv2i64 (vector_splice (nxv2i64 ZPR:$Z1), (nxv2i64 ZPR:$Z2), (i64 (sve_ext_imm_0_31 i32:$index)))),
-            (EXT_ZZI ZPR:$Z1, ZPR:$Z2, imm0_255:$index)>;
+  foreach VT = [nxv16i8] in
+    def : Pat<(VT (vector_splice (VT ZPR:$Z1), (VT ZPR:$Z2), (i64 (sve_ext_imm_0_255 i32:$index)))),
+              (EXT_ZZI ZPR:$Z1, ZPR:$Z2, imm0_255:$index)>;

[Author's review comment on the single-entry foreach above] I suppose this is overly verbose, but I do prefer the symmetry with the other patterns, plus I don't think 8-bit FP types are that far away. Either way, I'm happy to remove it if reviewers disagree.

+  foreach VT = [nxv8i16, nxv8f16, nxv8bf16] in
+    def : Pat<(VT (vector_splice (VT ZPR:$Z1), (VT ZPR:$Z2), (i64 (sve_ext_imm_0_127 i32:$index)))),
+              (EXT_ZZI ZPR:$Z1, ZPR:$Z2, imm0_255:$index)>;

+  foreach VT = [nxv4i32, nxv4f16, nxv4f32, nxv4bf16] in
+    def : Pat<(VT (vector_splice (VT ZPR:$Z1), (VT ZPR:$Z2), (i64 (sve_ext_imm_0_63 i32:$index)))),
+              (EXT_ZZI ZPR:$Z1, ZPR:$Z2, imm0_255:$index)>;

+  foreach VT = [nxv2i64, nxv2f16, nxv2f32, nxv2f64, nxv2bf16] in
+    def : Pat<(VT (vector_splice (VT ZPR:$Z1), (VT ZPR:$Z2), (i64 (sve_ext_imm_0_31 i32:$index)))),
+              (EXT_ZZI ZPR:$Z1, ZPR:$Z2, imm0_255:$index)>;

defm CMPHS_PPzZZ : sve_int_cmp_0<0b000, "cmphs", SETUGE, SETULE>;
defm CMPHI_PPzZZ : sve_int_cmp_0<0b001, "cmphi", SETUGT, SETULT>;
llvm/lib/Target/AArch64/SVEInstrFormats.td (17 changes: 9 additions & 8 deletions)
@@ -7060,16 +7060,17 @@ multiclass sve_int_perm_splice<string asm, SDPatternOperator op> {
  def _S : sve_int_perm_splice<0b10, asm, ZPR32>;
  def _D : sve_int_perm_splice<0b11, asm, ZPR64>;

-  def : SVE_3_Op_Pat<nxv16i8, op, nxv16i1, nxv16i8, nxv16i8, !cast<Instruction>(NAME # _B)>;
-  def : SVE_3_Op_Pat<nxv8i16, op, nxv8i1, nxv8i16, nxv8i16, !cast<Instruction>(NAME # _H)>;
-  def : SVE_3_Op_Pat<nxv4i32, op, nxv4i1, nxv4i32, nxv4i32, !cast<Instruction>(NAME # _S)>;
-  def : SVE_3_Op_Pat<nxv2i64, op, nxv2i1, nxv2i64, nxv2i64, !cast<Instruction>(NAME # _D)>;
+  foreach VT = [nxv16i8] in
+    def : SVE_3_Op_Pat<VT, op, nxv16i1, VT, VT, !cast<Instruction>(NAME # _B)>;

-  def : SVE_3_Op_Pat<nxv8f16, op, nxv8i1, nxv8f16, nxv8f16, !cast<Instruction>(NAME # _H)>;
-  def : SVE_3_Op_Pat<nxv4f32, op, nxv4i1, nxv4f32, nxv4f32, !cast<Instruction>(NAME # _S)>;
-  def : SVE_3_Op_Pat<nxv2f64, op, nxv2i1, nxv2f64, nxv2f64, !cast<Instruction>(NAME # _D)>;
+  foreach VT = [nxv8i16, nxv8f16, nxv8bf16] in
+    def : SVE_3_Op_Pat<VT, op, nxv8i1, VT, VT, !cast<Instruction>(NAME # _H)>;

-  def : SVE_3_Op_Pat<nxv8bf16, op, nxv8i1, nxv8bf16, nxv8bf16, !cast<Instruction>(NAME # _H)>;
+  foreach VT = [nxv4i32, nxv4f16, nxv4f32, nxv4bf16] in
+    def : SVE_3_Op_Pat<VT, op, nxv4i1, VT, VT, !cast<Instruction>(NAME # _S)>;
+
+  foreach VT = [nxv2i64, nxv2f16, nxv2f32, nxv2f64, nxv2bf16] in
+    def : SVE_3_Op_Pat<VT, op, nxv2i1, VT, VT, !cast<Instruction>(NAME # _D)>;
}

class sve2_int_perm_splice_cons<bits<2> sz8_64, string asm,
llvm/test/CodeGen/AArch64/named-vector-shuffles-sve.ll (104 changes: 104 additions & 0 deletions)
@@ -692,6 +692,104 @@ define <vscale x 2 x double> @splice_nxv2f64_neg3(<vscale x 2 x double> %a, <vscale x 2 x double> %b) #0 {
ret <vscale x 2 x double> %res
}

define <vscale x 2 x bfloat> @splice_nxv2bf16_neg_idx(<vscale x 2 x bfloat> %a, <vscale x 2 x bfloat> %b) #0 {
; CHECK-LABEL: splice_nxv2bf16_neg_idx:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p0.d, vl1
; CHECK-NEXT: rev p0.d, p0.d
; CHECK-NEXT: splice z0.d, p0, z0.d, z1.d
; CHECK-NEXT: ret
%res = call <vscale x 2 x bfloat> @llvm.vector.splice.nxv2bf16(<vscale x 2 x bfloat> %a, <vscale x 2 x bfloat> %b, i32 -1)
ret <vscale x 2 x bfloat> %res
}

define <vscale x 2 x bfloat> @splice_nxv2bf16_neg2_idx(<vscale x 2 x bfloat> %a, <vscale x 2 x bfloat> %b) #0 {
; CHECK-LABEL: splice_nxv2bf16_neg2_idx:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p0.d, vl2
; CHECK-NEXT: rev p0.d, p0.d
; CHECK-NEXT: splice z0.d, p0, z0.d, z1.d
; CHECK-NEXT: ret
%res = call <vscale x 2 x bfloat> @llvm.vector.splice.nxv2bf16(<vscale x 2 x bfloat> %a, <vscale x 2 x bfloat> %b, i32 -2)
ret <vscale x 2 x bfloat> %res
}

define <vscale x 2 x bfloat> @splice_nxv2bf16_first_idx(<vscale x 2 x bfloat> %a, <vscale x 2 x bfloat> %b) #0 {
; CHECK-LABEL: splice_nxv2bf16_first_idx:
; CHECK: // %bb.0:
; CHECK-NEXT: ext z0.b, z0.b, z1.b, #8
; CHECK-NEXT: ret
%res = call <vscale x 2 x bfloat> @llvm.vector.splice.nxv2bf16(<vscale x 2 x bfloat> %a, <vscale x 2 x bfloat> %b, i32 1)
ret <vscale x 2 x bfloat> %res
}

define <vscale x 2 x bfloat> @splice_nxv2bf16_last_idx(<vscale x 2 x bfloat> %a, <vscale x 2 x bfloat> %b) vscale_range(16,16) #0 {
; CHECK-LABEL: splice_nxv2bf16_last_idx:
; CHECK: // %bb.0:
; CHECK-NEXT: ext z0.b, z0.b, z1.b, #248
; CHECK-NEXT: ret
%res = call <vscale x 2 x bfloat> @llvm.vector.splice.nxv2bf16(<vscale x 2 x bfloat> %a, <vscale x 2 x bfloat> %b, i32 31)
ret <vscale x 2 x bfloat> %res
}

define <vscale x 4 x bfloat> @splice_nxv4bf16_neg_idx(<vscale x 4 x bfloat> %a, <vscale x 4 x bfloat> %b) #0 {
; CHECK-LABEL: splice_nxv4bf16_neg_idx:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p0.s, vl1
; CHECK-NEXT: rev p0.s, p0.s
; CHECK-NEXT: splice z0.s, p0, z0.s, z1.s
; CHECK-NEXT: ret
%res = call <vscale x 4 x bfloat> @llvm.vector.splice.nxv4bf16(<vscale x 4 x bfloat> %a, <vscale x 4 x bfloat> %b, i32 -1)
ret <vscale x 4 x bfloat> %res
}

define <vscale x 4 x bfloat> @splice_nxv4bf16_neg3_idx(<vscale x 4 x bfloat> %a, <vscale x 4 x bfloat> %b) #0 {
; CHECK-LABEL: splice_nxv4bf16_neg3_idx:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p0.s, vl3
; CHECK-NEXT: rev p0.s, p0.s
; CHECK-NEXT: splice z0.s, p0, z0.s, z1.s
; CHECK-NEXT: ret
%res = call <vscale x 4 x bfloat> @llvm.vector.splice.nxv4bf16(<vscale x 4 x bfloat> %a, <vscale x 4 x bfloat> %b, i32 -3)
ret <vscale x 4 x bfloat> %res
}

define <vscale x 4 x bfloat> @splice_nxv4bf16_first_idx(<vscale x 4 x bfloat> %a, <vscale x 4 x bfloat> %b) #0 {
; CHECK-LABEL: splice_nxv4bf16_first_idx:
; CHECK: // %bb.0:
; CHECK-NEXT: ext z0.b, z0.b, z1.b, #4
; CHECK-NEXT: ret
%res = call <vscale x 4 x bfloat> @llvm.vector.splice.nxv4bf16(<vscale x 4 x bfloat> %a, <vscale x 4 x bfloat> %b, i32 1)
ret <vscale x 4 x bfloat> %res
}

define <vscale x 4 x bfloat> @splice_nxv4bf16_last_idx(<vscale x 4 x bfloat> %a, <vscale x 4 x bfloat> %b) vscale_range(16,16) #0 {
; CHECK-LABEL: splice_nxv4bf16_last_idx:
; CHECK: // %bb.0:
; CHECK-NEXT: ext z0.b, z0.b, z1.b, #252
; CHECK-NEXT: ret
%res = call <vscale x 4 x bfloat> @llvm.vector.splice.nxv4bf16(<vscale x 4 x bfloat> %a, <vscale x 4 x bfloat> %b, i32 63)
ret <vscale x 4 x bfloat> %res
}

define <vscale x 8 x bfloat> @splice_nxv8bf16_first_idx(<vscale x 8 x bfloat> %a, <vscale x 8 x bfloat> %b) #0 {
; CHECK-LABEL: splice_nxv8bf16_first_idx:
; CHECK: // %bb.0:
; CHECK-NEXT: ext z0.b, z0.b, z1.b, #2
; CHECK-NEXT: ret
%res = call <vscale x 8 x bfloat> @llvm.vector.splice.nxv8bf16(<vscale x 8 x bfloat> %a, <vscale x 8 x bfloat> %b, i32 1)
ret <vscale x 8 x bfloat> %res
}

define <vscale x 8 x bfloat> @splice_nxv8bf16_last_idx(<vscale x 8 x bfloat> %a, <vscale x 8 x bfloat> %b) vscale_range(16,16) #0 {
; CHECK-LABEL: splice_nxv8bf16_last_idx:
; CHECK: // %bb.0:
; CHECK-NEXT: ext z0.b, z0.b, z1.b, #254
; CHECK-NEXT: ret
%res = call <vscale x 8 x bfloat> @llvm.vector.splice.nxv8bf16(<vscale x 8 x bfloat> %a, <vscale x 8 x bfloat> %b, i32 127)
ret <vscale x 8 x bfloat> %res
}

; Ensure predicate based splice is promoted to use ZPRs.
define <vscale x 2 x i1> @splice_nxv2i1(<vscale x 2 x i1> %a, <vscale x 2 x i1> %b) #0 {
; CHECK-LABEL: splice_nxv2i1:
@@ -834,12 +932,14 @@ declare <vscale x 2 x i1> @llvm.vector.splice.nxv2i1(<vscale x 2 x i1>, <vscale x 2 x i1>, i32)
declare <vscale x 4 x i1> @llvm.vector.splice.nxv4i1(<vscale x 4 x i1>, <vscale x 4 x i1>, i32)
declare <vscale x 8 x i1> @llvm.vector.splice.nxv8i1(<vscale x 8 x i1>, <vscale x 8 x i1>, i32)
declare <vscale x 16 x i1> @llvm.vector.splice.nxv16i1(<vscale x 16 x i1>, <vscale x 16 x i1>, i32)

declare <vscale x 2 x i8> @llvm.vector.splice.nxv2i8(<vscale x 2 x i8>, <vscale x 2 x i8>, i32)
declare <vscale x 16 x i8> @llvm.vector.splice.nxv16i8(<vscale x 16 x i8>, <vscale x 16 x i8>, i32)
declare <vscale x 8 x i16> @llvm.vector.splice.nxv8i16(<vscale x 8 x i16>, <vscale x 8 x i16>, i32)
declare <vscale x 4 x i32> @llvm.vector.splice.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, i32)
declare <vscale x 8 x i32> @llvm.vector.splice.nxv8i32(<vscale x 8 x i32>, <vscale x 8 x i32>, i32)
declare <vscale x 2 x i64> @llvm.vector.splice.nxv2i64(<vscale x 2 x i64>, <vscale x 2 x i64>, i32)

declare <vscale x 2 x half> @llvm.vector.splice.nxv2f16(<vscale x 2 x half>, <vscale x 2 x half>, i32)
declare <vscale x 4 x half> @llvm.vector.splice.nxv4f16(<vscale x 4 x half>, <vscale x 4 x half>, i32)
declare <vscale x 8 x half> @llvm.vector.splice.nxv8f16(<vscale x 8 x half>, <vscale x 8 x half>, i32)
@@ -848,4 +948,8 @@ declare <vscale x 4 x float> @llvm.vector.splice.nxv4f32(<vscale x 4 x float>, <vscale x 4 x float>, i32)
declare <vscale x 16 x float> @llvm.vector.splice.nxv16f32(<vscale x 16 x float>, <vscale x 16 x float>, i32)
declare <vscale x 2 x double> @llvm.vector.splice.nxv2f64(<vscale x 2 x double>, <vscale x 2 x double>, i32)

declare <vscale x 2 x bfloat> @llvm.vector.splice.nxv2bf16(<vscale x 2 x bfloat>, <vscale x 2 x bfloat>, i32)
declare <vscale x 4 x bfloat> @llvm.vector.splice.nxv4bf16(<vscale x 4 x bfloat>, <vscale x 4 x bfloat>, i32)
declare <vscale x 8 x bfloat> @llvm.vector.splice.nxv8bf16(<vscale x 8 x bfloat>, <vscale x 8 x bfloat>, i32)

attributes #0 = { nounwind "target-features"="+sve" }
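The EXT immediates in the bf16 tests above follow mechanically from the byte-offset rule in `LowerVECTOR_SPLICE`. A small sketch re-deriving them (same 128-bit block assumption as before; the names refer to the test functions above):

```cpp
#include <cassert>

// Recomputes the '#imm' operands of the 'ext' CHECK lines:
// byte offset = index * (128 / MinNumElts) / 8.
int main() {
  auto ExtImm = [](unsigned MinNumElts, unsigned Idx) {
    return Idx * (128 / MinNumElts) / 8;
  };
  assert(ExtImm(2, 1) == 8);     // splice_nxv2bf16_first_idx: ext ..., #8
  assert(ExtImm(2, 31) == 248);  // splice_nxv2bf16_last_idx:  ext ..., #248
  assert(ExtImm(4, 1) == 4);     // splice_nxv4bf16_first_idx: ext ..., #4
  assert(ExtImm(4, 63) == 252);  // splice_nxv4bf16_last_idx:  ext ..., #252
  assert(ExtImm(8, 1) == 2);     // splice_nxv8bf16_first_idx: ext ..., #2
  assert(ExtImm(8, 127) == 254); // splice_nxv8bf16_last_idx:  ext ..., #254
  return 0;
}
```

The negative-index tests take the other path in `LowerVECTOR_SPLICE` (a predicated `splice` built from a reversed `ptrue`), so no EXT immediate applies there.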