@@ -6551,17 +6551,17 @@ SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, const SDLoc &DL,
6551
6551
6552
6552
ElementCount NumElts = VT.getVectorElementCount();
6553
6553
6554
- // See if we can fold through bitcasted integer ops.
6554
+ // See if we can fold through any bitcasted integer ops.
6555
6555
if (NumOps == 2 && VT.isFixedLengthVector() && VT.isInteger() &&
6556
6556
Ops[0].getValueType() == VT && Ops[1].getValueType() == VT &&
6557
- Ops[0].getOpcode() == ISD::BITCAST &&
6558
- Ops[1].getOpcode() == ISD::BITCAST) {
6557
+ ( Ops[0].getOpcode() == ISD::BITCAST ||
6558
+ Ops[1].getOpcode() == ISD::BITCAST) ) {
6559
6559
SDValue N1 = peekThroughBitcasts(Ops[0]);
6560
6560
SDValue N2 = peekThroughBitcasts(Ops[1]);
6561
6561
auto *BV1 = dyn_cast<BuildVectorSDNode>(N1);
6562
6562
auto *BV2 = dyn_cast<BuildVectorSDNode>(N2);
6563
- EVT BVVT = N1.getValueType();
6564
- if (BV1 && BV2 && BVVT.isInteger() && BVVT == N2.getValueType ()) {
6563
+ if (BV1 && BV2 && N1.getValueType().isInteger() &&
6564
+ N2.getValueType().isInteger ()) {
6565
6565
bool IsLE = getDataLayout().isLittleEndian();
6566
6566
unsigned EltBits = VT.getScalarSizeInBits();
6567
6567
SmallVector<APInt> RawBits1, RawBits2;
@@ -6577,15 +6577,22 @@ SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, const SDLoc &DL,
6577
6577
RawBits.push_back(*Fold);
6578
6578
}
6579
6579
if (RawBits.size() == NumElts.getFixedValue()) {
6580
- // We have constant folded, but we need to cast this again back to
6581
- // the original (possibly legalized) type.
6580
+ // We have constant folded, but we might need to cast this again back
6581
+ // to the original (possibly legalized) type.
6582
+ EVT BVVT, BVEltVT;
6583
+ if (N1.getValueType() == VT) {
6584
+ BVVT = N1.getValueType();
6585
+ BVEltVT = BV1->getOperand(0).getValueType();
6586
+ } else {
6587
+ BVVT = N2.getValueType();
6588
+ BVEltVT = BV2->getOperand(0).getValueType();
6589
+ }
6590
+ unsigned BVEltBits = BVEltVT.getSizeInBits();
6582
6591
SmallVector<APInt> DstBits;
6583
6592
BitVector DstUndefs;
6584
6593
BuildVectorSDNode::recastRawBits(IsLE, BVVT.getScalarSizeInBits(),
6585
6594
DstBits, RawBits, DstUndefs,
6586
6595
BitVector(RawBits.size(), false));
6587
- EVT BVEltVT = BV1->getOperand(0).getValueType();
6588
- unsigned BVEltBits = BVEltVT.getSizeInBits();
6589
6596
SmallVector<SDValue> Ops(DstBits.size(), getUNDEF(BVEltVT));
6590
6597
for (unsigned I = 0, E = DstBits.size(); I != E; ++I) {
6591
6598
if (DstUndefs[I])
0 commit comments