Skip to content

Commit 6d26799

Browse files
committed
[AArch64] Don't rely on (zext (trunc x)) pattern to detect zext_inreg MULL patterns - use value tracking directly
As explained on D159533, I'm trying to generalize the "(zext (trunc x)) -> x iff the upper bits are known zero" fold in getNode() and I was seeing assertions in the aarch64 mull matching code as it was assuming these 'zero-extend-inreg' patterns will remain from earlier in LowerMUL. Instead I've updated selectUmullSmull/skipExtensionForVectorMULL to just use value tracking to detect when the upper bits are known zero, and to insert the truncation nodes later if necessary. Differential Revision: https://reviews.llvm.org/D159537
1 parent bba83e2 commit 6d26799

File tree

1 file changed

+15
-31
lines changed

1 file changed

+15
-31
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 15 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4426,18 +4426,25 @@ static bool isExtendedBUILD_VECTOR(SDValue N, SelectionDAG &DAG,
44264426
}
44274427

44284428
static SDValue skipExtensionForVectorMULL(SDValue N, SelectionDAG &DAG) {
4429+
EVT VT = N.getValueType();
4430+
assert(VT.is128BitVector() && "Unexpected vector MULL size");
4431+
4432+
unsigned NumElts = VT.getVectorNumElements();
4433+
unsigned OrigEltSize = VT.getScalarSizeInBits();
4434+
unsigned EltSize = OrigEltSize / 2;
4435+
MVT TruncVT = MVT::getVectorVT(MVT::getIntegerVT(EltSize), NumElts);
4436+
4437+
APInt HiBits = APInt::getHighBitsSet(OrigEltSize, EltSize);
4438+
if (DAG.MaskedValueIsZero(N, HiBits))
4439+
return DAG.getNode(ISD::TRUNCATE, SDLoc(N), TruncVT, N);
4440+
44294441
if (ISD::isExtOpcode(N.getOpcode()))
44304442
return addRequiredExtensionForVectorMULL(N.getOperand(0), DAG,
4431-
N.getOperand(0).getValueType(),
4432-
N.getValueType(),
4443+
N.getOperand(0).getValueType(), VT,
44334444
N.getOpcode());
44344445

44354446
assert(N.getOpcode() == ISD::BUILD_VECTOR && "expected BUILD_VECTOR");
4436-
EVT VT = N.getValueType();
44374447
SDLoc dl(N);
4438-
unsigned EltSize = VT.getScalarSizeInBits() / 2;
4439-
unsigned NumElts = VT.getVectorNumElements();
4440-
MVT TruncVT = MVT::getIntegerVT(EltSize);
44414448
SmallVector<SDValue, 8> Ops;
44424449
for (unsigned i = 0; i != NumElts; ++i) {
44434450
ConstantSDNode *C = cast<ConstantSDNode>(N.getOperand(i));
@@ -4446,7 +4453,7 @@ static SDValue skipExtensionForVectorMULL(SDValue N, SelectionDAG &DAG) {
44464453
// The values are implicitly truncated so sext vs. zext doesn't matter.
44474454
Ops.push_back(DAG.getConstant(CInt.zextOrTrunc(32), dl, MVT::i32));
44484455
}
4449-
return DAG.getBuildVector(MVT::getVectorVT(TruncVT, NumElts), dl, Ops);
4456+
return DAG.getBuildVector(TruncVT, dl, Ops);
44504457
}
44514458

44524459
static bool isSignExtended(SDValue N, SelectionDAG &DAG) {
@@ -4588,31 +4595,8 @@ static unsigned selectUmullSmull(SDValue &N0, SDValue &N1, SelectionDAG &DAG,
45884595
EVT VT = N0.getValueType();
45894596
APInt Mask = APInt::getHighBitsSet(VT.getScalarSizeInBits(),
45904597
VT.getScalarSizeInBits() / 2);
4591-
if (DAG.MaskedValueIsZero(IsN0ZExt ? N1 : N0, Mask)) {
4592-
EVT HalfVT;
4593-
switch (VT.getSimpleVT().SimpleTy) {
4594-
case MVT::v2i64:
4595-
HalfVT = MVT::v2i32;
4596-
break;
4597-
case MVT::v4i32:
4598-
HalfVT = MVT::v4i16;
4599-
break;
4600-
case MVT::v8i16:
4601-
HalfVT = MVT::v8i8;
4602-
break;
4603-
default:
4604-
return 0;
4605-
}
4606-
// Truncate and then extend the result.
4607-
SDValue NewExt =
4608-
DAG.getNode(ISD::TRUNCATE, DL, HalfVT, IsN0ZExt ? N1 : N0);
4609-
NewExt = DAG.getZExtOrTrunc(NewExt, DL, VT);
4610-
if (IsN0ZExt)
4611-
N1 = NewExt;
4612-
else
4613-
N0 = NewExt;
4598+
if (DAG.MaskedValueIsZero(IsN0ZExt ? N1 : N0, Mask))
46144599
return AArch64ISD::UMULL;
4615-
}
46164600
}
46174601

46184602
if (!IsN1SExt && !IsN1ZExt)

0 commit comments

Comments
 (0)