Skip to content

Commit 2fef449

Browse files
[LLVM][AArch64] Enable verifyTargetSDNode for scalable vectors and fix the fallout. (#104820)
Fix incorrect use of AArch64ISD::UZP1/UUNPK{HI,LO} in AArch64TargetLowering::LowerDIV and AArch64TargetLowering::LowerINSERT_SUBVECTOR. The latter highlighted DAG combines that relied on broken behaviour, which this patch also fixes.
1 parent 126d6f2 commit 2fef449

File tree

1 file changed

+64
-23
lines changed

1 file changed

+64
-23
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 64 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -14908,10 +14908,11 @@ SDValue AArch64TargetLowering::LowerINSERT_SUBVECTOR(SDValue Op,
1490814908
// NOP cast operands to the largest legal vector of the same element count.
1490914909
if (VT.isFloatingPoint()) {
1491014910
Vec0 = getSVESafeBitCast(NarrowVT, Vec0, DAG);
14911-
Vec1 = getSVESafeBitCast(WideVT, Vec1, DAG);
14911+
Vec1 = getSVESafeBitCast(NarrowVT, Vec1, DAG);
1491214912
} else {
1491314913
// Legal integer vectors are already their largest so Vec0 is fine as is.
1491414914
Vec1 = DAG.getNode(ISD::ANY_EXTEND, DL, WideVT, Vec1);
14915+
Vec1 = DAG.getNode(AArch64ISD::NVCAST, DL, NarrowVT, Vec1);
1491514916
}
1491614917

1491714918
// To replace the top/bottom half of vector V with vector SubV we widen the
@@ -14920,11 +14921,13 @@ SDValue AArch64TargetLowering::LowerINSERT_SUBVECTOR(SDValue Op,
1492014921
SDValue Narrow;
1492114922
if (Idx == 0) {
1492214923
SDValue HiVec0 = DAG.getNode(AArch64ISD::UUNPKHI, DL, WideVT, Vec0);
14924+
HiVec0 = DAG.getNode(AArch64ISD::NVCAST, DL, NarrowVT, HiVec0);
1492314925
Narrow = DAG.getNode(AArch64ISD::UZP1, DL, NarrowVT, Vec1, HiVec0);
1492414926
} else {
1492514927
assert(Idx == InVT.getVectorMinNumElements() &&
1492614928
"Invalid subvector index!");
1492714929
SDValue LoVec0 = DAG.getNode(AArch64ISD::UUNPKLO, DL, WideVT, Vec0);
14930+
LoVec0 = DAG.getNode(AArch64ISD::NVCAST, DL, NarrowVT, LoVec0);
1492814931
Narrow = DAG.getNode(AArch64ISD::UZP1, DL, NarrowVT, LoVec0, Vec1);
1492914932
}
1493014933

@@ -15024,7 +15027,9 @@ SDValue AArch64TargetLowering::LowerDIV(SDValue Op, SelectionDAG &DAG) const {
1502415027
SDValue Op1Hi = DAG.getNode(UnpkHi, dl, WidenedVT, Op.getOperand(1));
1502515028
SDValue ResultLo = DAG.getNode(Op.getOpcode(), dl, WidenedVT, Op0Lo, Op1Lo);
1502615029
SDValue ResultHi = DAG.getNode(Op.getOpcode(), dl, WidenedVT, Op0Hi, Op1Hi);
15027-
return DAG.getNode(AArch64ISD::UZP1, dl, VT, ResultLo, ResultHi);
15030+
SDValue ResultLoCast = DAG.getNode(AArch64ISD::NVCAST, dl, VT, ResultLo);
15031+
SDValue ResultHiCast = DAG.getNode(AArch64ISD::NVCAST, dl, VT, ResultHi);
15032+
return DAG.getNode(AArch64ISD::UZP1, dl, VT, ResultLoCast, ResultHiCast);
1502815033
}
1502915034

1503015035
bool AArch64TargetLowering::shouldExpandBuildVectorWithShuffles(
@@ -22739,7 +22744,19 @@ static SDValue trySimplifySrlAddToRshrnb(SDValue Srl, SelectionDAG &DAG,
2273922744
SDValue Rshrnb = DAG.getNode(
2274022745
AArch64ISD::RSHRNB_I, DL, ResVT,
2274122746
{RShOperand, DAG.getTargetConstant(ShiftValue, DL, MVT::i32)});
22742-
return DAG.getNode(ISD::BITCAST, DL, VT, Rshrnb);
22747+
return DAG.getNode(AArch64ISD::NVCAST, DL, VT, Rshrnb);
22748+
}
22749+
22750+
// Looks through an AArch64ISD::NVCAST whose result type has twice as many
// vector elements as its operand type. Returns the cast's operand on a match
// and an empty SDValue otherwise, so the result can be tested directly in an
// `if` condition.
// NOTE(review): the name implies the result elements are half the width of the
// operand's, which presumably follows from NVCAST keeping the total vector
// size unchanged; the size itself is not checked here, so confirm.
static SDValue isNVCastToHalfWidthElements(SDValue V) {
22751+
// Anything other than an NVCAST is not a candidate.
if (V.getOpcode() != AArch64ISD::NVCAST)
22752+
return SDValue();
22753+
22754+
SDValue Op = V.getOperand(0);
22755+
// Require the cast result to have exactly double the operand's element count.
if (V.getValueType().getVectorElementCount() !=
22756+
Op.getValueType().getVectorElementCount() * 2)
22757+
return SDValue();
22758+
22759+
// Hand back the pre-cast value so callers can inspect the node underneath.
return Op;
2274322760
}
2274422761

2274522762
static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG,
@@ -22802,25 +22819,37 @@ static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG,
2280222819
if (SDValue Urshr = tryCombineExtendRShTrunc(N, DAG))
2280322820
return Urshr;
2280422821

22805-
if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(Op0, DAG, Subtarget))
22806-
return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Rshrnb, Op1);
22822+
if (SDValue PreCast = isNVCastToHalfWidthElements(Op0)) {
22823+
if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(PreCast, DAG, Subtarget)) {
22824+
Rshrnb = DAG.getNode(AArch64ISD::NVCAST, DL, ResVT, Rshrnb);
22825+
return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Rshrnb, Op1);
22826+
}
22827+
}
2280722828

22808-
if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(Op1, DAG, Subtarget))
22809-
return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0, Rshrnb);
22829+
if (SDValue PreCast = isNVCastToHalfWidthElements(Op1)) {
22830+
if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(PreCast, DAG, Subtarget)) {
22831+
Rshrnb = DAG.getNode(AArch64ISD::NVCAST, DL, ResVT, Rshrnb);
22832+
return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0, Rshrnb);
22833+
}
22834+
}
2281022835

22811-
// uzp1(unpklo(uzp1(x, y)), z) => uzp1(x, z)
22812-
if (Op0.getOpcode() == AArch64ISD::UUNPKLO) {
22813-
if (Op0.getOperand(0).getOpcode() == AArch64ISD::UZP1) {
22814-
SDValue X = Op0.getOperand(0).getOperand(0);
22815-
return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, X, Op1);
22836+
// uzp1<ty>(nvcast(unpklo(uzp1<ty>(x, y))), z) => uzp1<ty>(x, z)
22837+
if (SDValue PreCast = isNVCastToHalfWidthElements(Op0)) {
22838+
if (PreCast.getOpcode() == AArch64ISD::UUNPKLO) {
22839+
if (PreCast.getOperand(0).getOpcode() == AArch64ISD::UZP1) {
22840+
SDValue X = PreCast.getOperand(0).getOperand(0);
22841+
return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, X, Op1);
22842+
}
2281622843
}
2281722844
}
2281822845

22819-
// uzp1(x, unpkhi(uzp1(y, z))) => uzp1(x, z)
22820-
if (Op1.getOpcode() == AArch64ISD::UUNPKHI) {
22821-
if (Op1.getOperand(0).getOpcode() == AArch64ISD::UZP1) {
22822-
SDValue Z = Op1.getOperand(0).getOperand(1);
22823-
return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0, Z);
22846+
// uzp1<ty>(x, nvcast(unpkhi(uzp1<ty>(y, z)))) => uzp1<ty>(x, z)
22847+
if (SDValue PreCast = isNVCastToHalfWidthElements(Op1)) {
22848+
if (PreCast.getOpcode() == AArch64ISD::UUNPKHI) {
22849+
if (PreCast.getOperand(0).getOpcode() == AArch64ISD::UZP1) {
22850+
SDValue Z = PreCast.getOperand(0).getOperand(1);
22851+
return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0, Z);
22852+
}
2282422853
}
2282522854
}
2282622855

@@ -29415,9 +29444,6 @@ void AArch64TargetLowering::verifyTargetSDNode(const SDNode *N) const {
2941529444
VT.isInteger() && "Expected integer vectors!");
2941629445
assert(OpVT.getSizeInBits() == VT.getSizeInBits() &&
2941729446
"Expected vectors of equal size!");
29418-
// TODO: Enable assert once bogus creations have been fixed.
29419-
if (VT.isScalableVector())
29420-
break;
2942129447
assert(OpVT.getVectorElementCount() == VT.getVectorElementCount() * 2 &&
2942229448
"Expected result vector with half the lanes of its input!");
2942329449
break;
@@ -29435,12 +29461,27 @@ void AArch64TargetLowering::verifyTargetSDNode(const SDNode *N) const {
2943529461
EVT Op1VT = N->getOperand(1).getValueType();
2943629462
assert(VT.isVector() && Op0VT.isVector() && Op1VT.isVector() &&
2943729463
"Expected vectors!");
29438-
// TODO: Enable assert once bogus creations have been fixed.
29439-
if (VT.isScalableVector())
29440-
break;
2944129464
assert(VT == Op0VT && VT == Op1VT && "Expected matching vectors!");
2944229465
break;
2944329466
}
29467+
case AArch64ISD::RSHRNB_I: {
29468+
assert(N->getNumValues() == 1 && "Expected one result!");
29469+
assert(N->getNumOperands() == 2 && "Expected two operands!");
29470+
EVT VT = N->getValueType(0);
29471+
EVT Op0VT = N->getOperand(0).getValueType();
29472+
EVT Op1VT = N->getOperand(1).getValueType();
29473+
assert(VT.isVector() && VT.isInteger() &&
29474+
"Expected integer vector result type!");
29475+
assert(Op0VT.isVector() && Op0VT.isInteger() &&
29476+
"Expected first operand to be an integer vector!");
29477+
assert(VT.getSizeInBits() == Op0VT.getSizeInBits() &&
29478+
"Expected vectors of equal size!");
29479+
assert(VT.getVectorElementCount() == Op0VT.getVectorElementCount() * 2 &&
29480+
"Expected input vector with half the lanes of its result!");
29481+
assert(Op1VT == MVT::i32 && isa<ConstantSDNode>(N->getOperand(1)) &&
29482+
"Expected second operand to be a constant i32!");
29483+
break;
29484+
}
2944429485
}
2944529486
}
2944629487
#endif

0 commit comments

Comments
 (0)