Skip to content

[LLVM][CodeGen][SVE] Improve custom lowering for EXTRACT_SUBVECTOR. #90963

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 33 additions & 26 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13856,45 +13856,52 @@ AArch64TargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,

SDValue AArch64TargetLowering::LowerEXTRACT_SUBVECTOR(SDValue Op,
SelectionDAG &DAG) const {
assert(Op.getValueType().isFixedLengthVector() &&
EVT VT = Op.getValueType();
assert(VT.isFixedLengthVector() &&
"Only cases that extract a fixed length vector are supported!");

EVT InVT = Op.getOperand(0).getValueType();
unsigned Idx = Op.getConstantOperandVal(1);
unsigned Size = Op.getValueSizeInBits();

// If we don't have legal types yet, do nothing
if (!DAG.getTargetLoweringInfo().isTypeLegal(InVT))
if (!isTypeLegal(InVT))
return SDValue();

if (InVT.isScalableVector()) {
// This will be matched by custom code during ISelDAGToDAG.
if (Idx == 0 && isPackedVectorType(InVT, DAG))
if (InVT.is128BitVector()) {
assert(VT.is64BitVector() && "Extracting unexpected vector type!");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the rationale here that at this point all types should be legal and therefore the only possible result VTs are 64-bit and 128-bit? I guess we are assuming that for 128-bit result types we're relying on the EXTRACT_SUBVECTOR being folded away beforehand?

Copy link
Collaborator Author

@paulwalker-arm paulwalker-arm May 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. We can be sure that if the input is type legal then the result must also be type legal. That means for operation legalisation the only NEON sized combination that can happen is extracting a 64-bit vector from a 128-bit vector. The only exception is when the result type matches the input type and the index is zero, which is a NOP and optimised by SelectionDAG::getNode() and so is not worth considering at this point in the pipeline.

unsigned Idx = Op.getConstantOperandVal(1);

// This will get lowered to an appropriate EXTRACT_SUBREG in ISel.
if (Idx == 0)
return Op;

return SDValue();
// If this is extracting the upper 64-bits of a 128-bit vector, we match
// that directly.
if (Idx * InVT.getScalarSizeInBits() == 64 && Subtarget->isNeonAvailable())
return Op;
}

// This will get lowered to an appropriate EXTRACT_SUBREG in ISel.
if (Idx == 0 && InVT.getSizeInBits() <= 128)
return Op;

// If this is extracting the upper 64-bits of a 128-bit vector, we match
// that directly.
if (Size == 64 && Idx * InVT.getScalarSizeInBits() == 64 &&
InVT.getSizeInBits() == 128 && Subtarget->isNeonAvailable())
return Op;

if (useSVEForFixedLengthVectorVT(InVT, !Subtarget->isNeonAvailable())) {
if (InVT.isScalableVector() ||
useSVEForFixedLengthVectorVT(InVT, !Subtarget->isNeonAvailable())) {
SDLoc DL(Op);
SDValue Vec = Op.getOperand(0);
SDValue Idx = Op.getOperand(1);

EVT ContainerVT = getContainerForFixedLengthVector(DAG, InVT);
SDValue NewInVec =
convertToScalableVector(DAG, ContainerVT, Op.getOperand(0));
EVT PackedVT = getPackedSVEVectorVT(InVT.getVectorElementType());
if (PackedVT != InVT) {
// Pack input into the bottom part of an SVE register and try again.
SDValue Container = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, PackedVT,
DAG.getUNDEF(PackedVT), Vec,
DAG.getVectorIdxConstant(0, DL));
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Container, Idx);
}

// This will get matched by custom code during ISelDAGToDAG.
if (isNullConstant(Idx))
return Op;

SDValue Splice = DAG.getNode(ISD::VECTOR_SPLICE, DL, ContainerVT, NewInVec,
NewInVec, DAG.getConstant(Idx, DL, MVT::i64));
return convertFromScalableVector(DAG, Op.getValueType(), Splice);
assert(InVT.isScalableVector() && "Unexpected vector type!");
// Move requested subvector to the start of the vector and try again.
SDValue Splice = DAG.getNode(ISD::VECTOR_SPLICE, DL, InVT, Vec, Vec, Idx);
return convertFromScalableVector(DAG, VT, Splice);
}

return SDValue();
Expand Down
48 changes: 11 additions & 37 deletions llvm/test/CodeGen/AArch64/sve-extract-fixed-from-scalable-vector.ll
Original file line number Diff line number Diff line change
Expand Up @@ -143,15 +143,8 @@ define <4 x float> @extract_v4f32_nxv16f32_12(<vscale x 16 x float> %arg) {
define <2 x float> @extract_v2f32_nxv16f32_2(<vscale x 16 x float> %arg) {
; CHECK-LABEL: extract_v2f32_nxv16f32_2:
; CHECK: // %bb.0:
; CHECK-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
; CHECK-NEXT: addvl sp, sp, #-1
; CHECK-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
; CHECK-NEXT: .cfi_offset w29, -16
; CHECK-NEXT: ptrue p0.s
; CHECK-NEXT: st1w { z0.s }, p0, [sp]
; CHECK-NEXT: ldr d0, [sp, #8]
; CHECK-NEXT: addvl sp, sp, #1
; CHECK-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
; CHECK-NEXT: ext z0.b, z0.b, z0.b, #8
; CHECK-NEXT: // kill: def $d0 killed $d0 killed $z0
; CHECK-NEXT: ret
%ext = call <2 x float> @llvm.vector.extract.v2f32.nxv16f32(<vscale x 16 x float> %arg, i64 2)
ret <2 x float> %ext
Expand Down Expand Up @@ -274,15 +267,8 @@ define <4 x i3> @extract_v4i3_nxv32i3_16(<vscale x 32 x i3> %arg) {
define <2 x i32> @extract_v2i32_nxv16i32_2(<vscale x 16 x i32> %arg) {
; CHECK-LABEL: extract_v2i32_nxv16i32_2:
; CHECK: // %bb.0:
; CHECK-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
; CHECK-NEXT: addvl sp, sp, #-1
; CHECK-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
; CHECK-NEXT: .cfi_offset w29, -16
; CHECK-NEXT: ptrue p0.s
; CHECK-NEXT: st1w { z0.s }, p0, [sp]
; CHECK-NEXT: ldr d0, [sp, #8]
; CHECK-NEXT: addvl sp, sp, #1
; CHECK-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
; CHECK-NEXT: ext z0.b, z0.b, z0.b, #8
; CHECK-NEXT: // kill: def $d0 killed $d0 killed $z0
; CHECK-NEXT: ret
%ext = call <2 x i32> @llvm.vector.extract.v2i32.nxv16i32(<vscale x 16 x i32> %arg, i64 2)
ret <2 x i32> %ext
Expand Down Expand Up @@ -314,16 +300,9 @@ define <4 x half> @extract_v4f16_nxv2f16_0(<vscale x 2 x half> %arg) {
; CHECK-NEXT: addvl sp, sp, #-1
; CHECK-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
; CHECK-NEXT: .cfi_offset w29, -16
; CHECK-NEXT: cntd x8
; CHECK-NEXT: ptrue p0.d
; CHECK-NEXT: addpl x9, sp, #6
; CHECK-NEXT: subs x8, x8, #4
; CHECK-NEXT: csel x8, xzr, x8, lo
; CHECK-NEXT: st1h { z0.d }, p0, [sp, #3, mul vl]
; CHECK-NEXT: cmp x8, #0
; CHECK-NEXT: csel x8, x8, xzr, lo
; CHECK-NEXT: lsl x8, x8, #1
; CHECK-NEXT: ldr d0, [x9, x8]
; CHECK-NEXT: st1h { z0.d }, p0, [sp]
; CHECK-NEXT: ldr d0, [sp]
; CHECK-NEXT: addvl sp, sp, #1
; CHECK-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
; CHECK-NEXT: ret
Expand All @@ -338,17 +317,12 @@ define <4 x half> @extract_v4f16_nxv2f16_4(<vscale x 2 x half> %arg) {
; CHECK-NEXT: addvl sp, sp, #-1
; CHECK-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
; CHECK-NEXT: .cfi_offset w29, -16
; CHECK-NEXT: cntd x8
; CHECK-NEXT: mov w9, #4 // =0x4
; CHECK-NEXT: ptrue p0.d
; CHECK-NEXT: subs x8, x8, #4
; CHECK-NEXT: csel x8, xzr, x8, lo
; CHECK-NEXT: st1h { z0.d }, p0, [sp, #3, mul vl]
; CHECK-NEXT: cmp x8, #4
; CHECK-NEXT: csel x8, x8, x9, lo
; CHECK-NEXT: addpl x9, sp, #6
; CHECK-NEXT: lsl x8, x8, #1
; CHECK-NEXT: ldr d0, [x9, x8]
; CHECK-NEXT: ptrue p1.h
; CHECK-NEXT: st1h { z0.d }, p0, [sp]
; CHECK-NEXT: ld1h { z0.h }, p1/z, [sp]
; CHECK-NEXT: ext z0.b, z0.b, z0.b, #8
; CHECK-NEXT: // kill: def $d0 killed $d0 killed $z0
; CHECK-NEXT: addvl sp, sp, #1
; CHECK-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
; CHECK-NEXT: ret
Expand Down
Loading
Loading