Skip to content

[LLVM][AArch64] Enable verifyTargetSDNode for scalable vectors and fix the fallout. #104820

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 64 additions & 23 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14897,10 +14897,11 @@ SDValue AArch64TargetLowering::LowerINSERT_SUBVECTOR(SDValue Op,
// NOP cast operands to the largest legal vector of the same element count.
if (VT.isFloatingPoint()) {
Vec0 = getSVESafeBitCast(NarrowVT, Vec0, DAG);
Vec1 = getSVESafeBitCast(WideVT, Vec1, DAG);
Vec1 = getSVESafeBitCast(NarrowVT, Vec1, DAG);
} else {
// Legal integer vectors are already their largest so Vec0 is fine as is.
Vec1 = DAG.getNode(ISD::ANY_EXTEND, DL, WideVT, Vec1);
Vec1 = DAG.getNode(AArch64ISD::NVCAST, DL, NarrowVT, Vec1);
}

// To replace the top/bottom half of vector V with vector SubV we widen the
Expand All @@ -14909,11 +14910,13 @@ SDValue AArch64TargetLowering::LowerINSERT_SUBVECTOR(SDValue Op,
SDValue Narrow;
if (Idx == 0) {
SDValue HiVec0 = DAG.getNode(AArch64ISD::UUNPKHI, DL, WideVT, Vec0);
HiVec0 = DAG.getNode(AArch64ISD::NVCAST, DL, NarrowVT, HiVec0);
Narrow = DAG.getNode(AArch64ISD::UZP1, DL, NarrowVT, Vec1, HiVec0);
} else {
assert(Idx == InVT.getVectorMinNumElements() &&
"Invalid subvector index!");
SDValue LoVec0 = DAG.getNode(AArch64ISD::UUNPKLO, DL, WideVT, Vec0);
LoVec0 = DAG.getNode(AArch64ISD::NVCAST, DL, NarrowVT, LoVec0);
Narrow = DAG.getNode(AArch64ISD::UZP1, DL, NarrowVT, LoVec0, Vec1);
}

Expand Down Expand Up @@ -15013,7 +15016,9 @@ SDValue AArch64TargetLowering::LowerDIV(SDValue Op, SelectionDAG &DAG) const {
SDValue Op1Hi = DAG.getNode(UnpkHi, dl, WidenedVT, Op.getOperand(1));
SDValue ResultLo = DAG.getNode(Op.getOpcode(), dl, WidenedVT, Op0Lo, Op1Lo);
SDValue ResultHi = DAG.getNode(Op.getOpcode(), dl, WidenedVT, Op0Hi, Op1Hi);
return DAG.getNode(AArch64ISD::UZP1, dl, VT, ResultLo, ResultHi);
SDValue ResultLoCast = DAG.getNode(AArch64ISD::NVCAST, dl, VT, ResultLo);
SDValue ResultHiCast = DAG.getNode(AArch64ISD::NVCAST, dl, VT, ResultHi);
return DAG.getNode(AArch64ISD::UZP1, dl, VT, ResultLoCast, ResultHiCast);
}

bool AArch64TargetLowering::shouldExpandBuildVectorWithShuffles(
Expand Down Expand Up @@ -22667,7 +22672,19 @@ static SDValue trySimplifySrlAddToRshrnb(SDValue Srl, SelectionDAG &DAG,
SDValue Rshrnb = DAG.getNode(
AArch64ISD::RSHRNB_I, DL, ResVT,
{RShOperand, DAG.getTargetConstant(ShiftValue, DL, MVT::i32)});
return DAG.getNode(ISD::BITCAST, DL, VT, Rshrnb);
return DAG.getNode(AArch64ISD::NVCAST, DL, VT, Rshrnb);
}

static SDValue isNVCastToHalfWidthElements(SDValue V) {
if (V.getOpcode() != AArch64ISD::NVCAST)
return SDValue();

SDValue Op = V.getOperand(0);
if (V.getValueType().getVectorElementCount() !=
Op.getValueType().getVectorElementCount() * 2)
return SDValue();
Comment on lines +22682 to +22685
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pedantic question: If possible here, what happens if one type is a double tuple and one is a quad tuple here? Does getVectorElementCount() return the element count for each of the scalable vectors in the quad or just one scalable type within the tuple?

Copy link
Collaborator Author

@paulwalker-arm paulwalker-arm Aug 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getValueType() returns an EVT, which has no concept of tuple types. The implementation of getVectorElementCount will itself assert that it is a plain vector type.


return Op;
}

static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG,
Expand Down Expand Up @@ -22730,25 +22747,37 @@ static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG,
if (SDValue Urshr = tryCombineExtendRShTrunc(N, DAG))
return Urshr;

if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(Op0, DAG, Subtarget))
return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Rshrnb, Op1);
if (SDValue PreCast = isNVCastToHalfWidthElements(Op0)) {
if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(PreCast, DAG, Subtarget)) {
Rshrnb = DAG.getNode(AArch64ISD::NVCAST, DL, ResVT, Rshrnb);
return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Rshrnb, Op1);
}
}

if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(Op1, DAG, Subtarget))
return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0, Rshrnb);
if (SDValue PreCast = isNVCastToHalfWidthElements(Op1)) {
if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(PreCast, DAG, Subtarget)) {
Rshrnb = DAG.getNode(AArch64ISD::NVCAST, DL, ResVT, Rshrnb);
return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0, Rshrnb);
}
}

// uzp1(unpklo(uzp1(x, y)), z) => uzp1(x, z)
if (Op0.getOpcode() == AArch64ISD::UUNPKLO) {
if (Op0.getOperand(0).getOpcode() == AArch64ISD::UZP1) {
SDValue X = Op0.getOperand(0).getOperand(0);
return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, X, Op1);
// uzp1<ty>(nvcast(unpklo(uzp1<ty>(x, y))), z) => uzp1<ty>(x, z)
if (SDValue PreCast = isNVCastToHalfWidthElements(Op0)) {
if (PreCast.getOpcode() == AArch64ISD::UUNPKLO) {
if (PreCast.getOperand(0).getOpcode() == AArch64ISD::UZP1) {
SDValue X = PreCast.getOperand(0).getOperand(0);
return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, X, Op1);
}
}
}

// uzp1(x, unpkhi(uzp1(y, z))) => uzp1(x, z)
if (Op1.getOpcode() == AArch64ISD::UUNPKHI) {
if (Op1.getOperand(0).getOpcode() == AArch64ISD::UZP1) {
SDValue Z = Op1.getOperand(0).getOperand(1);
return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0, Z);
// uzp1<ty>(x, nvcast(unpkhi(uzp1<ty>(y, z)))) => uzp1<ty>(x, z)
if (SDValue PreCast = isNVCastToHalfWidthElements(Op1)) {
if (PreCast.getOpcode() == AArch64ISD::UUNPKHI) {
if (PreCast.getOperand(0).getOpcode() == AArch64ISD::UZP1) {
SDValue Z = PreCast.getOperand(0).getOperand(1);
return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0, Z);
}
}
}

Expand Down Expand Up @@ -29343,9 +29372,6 @@ void AArch64TargetLowering::verifyTargetSDNode(const SDNode *N) const {
VT.isInteger() && "Expected integer vectors!");
assert(OpVT.getSizeInBits() == VT.getSizeInBits() &&
"Expected vectors of equal size!");
// TODO: Enable assert once bogus creations have been fixed.
if (VT.isScalableVector())
break;
assert(OpVT.getVectorElementCount() == VT.getVectorElementCount() * 2 &&
"Expected result vector with half the lanes of its input!");
break;
Expand All @@ -29363,12 +29389,27 @@ void AArch64TargetLowering::verifyTargetSDNode(const SDNode *N) const {
EVT Op1VT = N->getOperand(1).getValueType();
assert(VT.isVector() && Op0VT.isVector() && Op1VT.isVector() &&
"Expected vectors!");
// TODO: Enable assert once bogus creations have been fixed.
if (VT.isScalableVector())
break;
assert(VT == Op0VT && VT == Op1VT && "Expected matching vectors!");
break;
}
case AArch64ISD::RSHRNB_I: {
assert(N->getNumValues() == 1 && "Expected one result!");
assert(N->getNumOperands() == 2 && "Expected two operands!");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
assert(N->getNumOperands() == 2 && "Expected two operand!");
assert(N->getNumOperands() == 2 && "Expected two operands!");

EVT VT = N->getValueType(0);
EVT Op0VT = N->getOperand(0).getValueType();
EVT Op1VT = N->getOperand(1).getValueType();
assert(VT.isVector() && VT.isInteger() &&
"Expected integer vector result type!");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the message for the second operand is very specific, this message could be more specific about which operand numbers this refers to

assert(Op0VT.isVector() && Op0VT.isInteger() &&
"Expected first operand to be an integer vector!");
assert(VT.getSizeInBits() == Op0VT.getSizeInBits() &&
"Expected vectors of equal size!");
assert(VT.getVectorElementCount() == Op0VT.getVectorElementCount() * 2 &&
"Expected input vector with half the lanes of its result!");
assert(Op1VT == MVT::i32 && isa<ConstantSDNode>(N->getOperand(1)) &&
"Expected second operand to be a constant i32!");
break;
}
}
}
#endif
Loading