@@ -115,6 +115,7 @@ class VectorCombine {
   bool foldExtractedCmps(Instruction &I);
   bool foldSingleElementStore(Instruction &I);
   bool scalarizeLoadExtract(Instruction &I);
+  bool foldConcatOfBoolMasks(Instruction &I);
   bool foldPermuteOfBinops(Instruction &I);
   bool foldShuffleOfBinops(Instruction &I);
   bool foldShuffleOfCastops(Instruction &I);
@@ -1423,6 +1424,113 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
   return true;
 }
 
+/// Try to fold "(or (zext (bitcast X)), (shl (zext (bitcast Y)), C))"
+/// to "(bitcast (concat X, Y))"
+/// where X/Y are bitcasted from i1 mask vectors.
+bool VectorCombine::foldConcatOfBoolMasks(Instruction &I) {
+  Type *Ty = I.getType();
+  if (!Ty->isIntegerTy())
+    return false;
+
+  // TODO: Add big endian test coverage
+  if (DL->isBigEndian())
+    return false;
+
+  // Restrict to disjoint cases so the mask vectors aren't overlapping.
+  Instruction *X, *Y;
+  if (!match(&I, m_DisjointOr(m_Instruction(X), m_Instruction(Y))))
+    return false;
+
+  // Allow both sources to contain shl, to handle more generic pattern:
+  // "(or (shl (zext (bitcast X)), C1), (shl (zext (bitcast Y)), C2))"
+  Value *SrcX;
+  uint64_t ShAmtX = 0;
+  if (!match(X, m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcX)))))) &&
+      !match(X, m_OneUse(
+                    m_Shl(m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcX))))),
+                          m_ConstantInt(ShAmtX)))))
+    return false;
+
+  Value *SrcY;
+  uint64_t ShAmtY = 0;
+  if (!match(Y, m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcY)))))) &&
+      !match(Y, m_OneUse(
+                    m_Shl(m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcY))))),
+                          m_ConstantInt(ShAmtY)))))
+    return false;
+
+  // Canonicalize larger shift to the RHS.
+  if (ShAmtX > ShAmtY) {
+    std::swap(X, Y);
+    std::swap(SrcX, SrcY);
+    std::swap(ShAmtX, ShAmtY);
+  }
+
+  // Ensure both sources are matching vXi1 bool mask types, and that the shift
+  // difference is the mask width so they can be easily concatenated together.
+  uint64_t ShAmtDiff = ShAmtY - ShAmtX;
+  unsigned NumSHL = (ShAmtX > 0) + (ShAmtY > 0);
+  unsigned BitWidth = Ty->getPrimitiveSizeInBits();
+  auto *MaskTy = dyn_cast<FixedVectorType>(SrcX->getType());
+  if (!MaskTy || SrcX->getType() != SrcY->getType() ||
+      !MaskTy->getElementType()->isIntegerTy(1) ||
+      MaskTy->getNumElements() != ShAmtDiff ||
+      MaskTy->getNumElements() > (BitWidth / 2))
+    return false;
+
+  auto *ConcatTy = FixedVectorType::getDoubleElementsVectorType(MaskTy);
+  auto *ConcatIntTy =
+      Type::getIntNTy(Ty->getContext(), ConcatTy->getNumElements());
+  auto *MaskIntTy = Type::getIntNTy(Ty->getContext(), ShAmtDiff);
+
+  SmallVector<int, 32> ConcatMask(ConcatTy->getNumElements());
+  std::iota(ConcatMask.begin(), ConcatMask.end(), 0);
+
+  // TODO: Is it worth supporting multi use cases?
+  InstructionCost OldCost = 0;
+  OldCost += TTI.getArithmeticInstrCost(Instruction::Or, Ty, CostKind);
+  OldCost +=
+      NumSHL * TTI.getArithmeticInstrCost(Instruction::Shl, Ty, CostKind);
+  OldCost += 2 * TTI.getCastInstrCost(Instruction::ZExt, Ty, MaskIntTy,
+                                      TTI::CastContextHint::None, CostKind);
+  OldCost += 2 * TTI.getCastInstrCost(Instruction::BitCast, MaskIntTy, MaskTy,
+                                      TTI::CastContextHint::None, CostKind);
+
+  InstructionCost NewCost = 0;
+  NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, MaskTy,
+                                ConcatMask, CostKind);
+  NewCost += TTI.getCastInstrCost(Instruction::BitCast, ConcatIntTy, ConcatTy,
+                                  TTI::CastContextHint::None, CostKind);
+  if (Ty != ConcatIntTy)
+    NewCost += TTI.getCastInstrCost(Instruction::ZExt, Ty, ConcatIntTy,
+                                    TTI::CastContextHint::None, CostKind);
+  if (ShAmtX > 0)
+    NewCost += TTI.getArithmeticInstrCost(Instruction::Shl, Ty, CostKind);
+
+  if (NewCost > OldCost)
+    return false;
+
+  // Build bool mask concatenation, bitcast back to scalar integer, and perform
+  // any residual zero-extension or shifting.
+  Value *Concat = Builder.CreateShuffleVector(SrcX, SrcY, ConcatMask);
+  Worklist.pushValue(Concat);
+
+  Value *Result = Builder.CreateBitCast(Concat, ConcatIntTy);
+
+  if (Ty != ConcatIntTy) {
+    Worklist.pushValue(Result);
+    Result = Builder.CreateZExt(Result, Ty);
+  }
+
+  if (ShAmtX > 0) {
+    Worklist.pushValue(Result);
+    Result = Builder.CreateShl(Result, ShAmtX);
+  }
+
+  replaceValue(I, *Result);
+  return true;
+}
+
 /// Try to convert "shuffle (binop (shuffle, shuffle)), undef"
 ///            --> "binop (shuffle), (shuffle)".
 bool VectorCombine::foldPermuteOfBinops(Instruction &I) {
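The little-endian restriction in the new fold is what makes it legal: bitcasting a `<N x i1>` mask to an `iN` integer places element `i` in bit `i`, so OR-ing the zero-extended halves (with the high half shifted by the mask width) yields the same scalar as bitcasting the concatenated `<2N x i1>` vector. A minimal standalone sketch of that identity, not part of the patch; `packMask` is an illustrative helper:

```cpp
// Illustrative sketch only: models the little-endian bit layout the fold
// relies on. Bitcasting <N x i1> to iN packs mask element i into bit i, so
// OR-ing the zext'd halves with the high half shifted by N equals packing
// the concatenated 2N-element mask directly.
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

template <size_t N>
uint64_t packMask(const std::array<bool, N> &M) {
  uint64_t Bits = 0;
  for (size_t I = 0; I != N; ++I)
    Bits |= uint64_t(M[I]) << I; // element i -> bit i
  return Bits;
}

int main() {
  std::array<bool, 4> X = {true, false, true, true};
  std::array<bool, 4> Y = {false, true, true, false};

  // "(or (zext (bitcast X)), (shl (zext (bitcast Y)), 4))"
  uint64_t Scalar = packMask(X) | (packMask(Y) << 4);

  // "(bitcast (concat X, Y))"
  std::array<bool, 8> Concat = {true, false, true, true,
                                false, true, true, false};
  assert(Scalar == packMask(Concat));
  (void)Scalar;
  return 0;
}
```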
@@ -2908,6 +3016,9 @@ bool VectorCombine::run() {
     if (TryEarlyFoldsOnly)
       return;
 
+    if (I.getType()->isIntegerTy())
+      MadeChange |= foldConcatOfBoolMasks(I);
+
     // Otherwise, try folds that improve codegen but may interfere with
     // early IR canonicalizations.
     // The type checking is for run-time efficiency. We can avoid wasting time
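One more note on the concatenation step in the new fold: the `ConcatMask` built with `std::iota` is the identity mask `0..2N-1`, which for a two-source shuffle simply takes all `N` lanes of `SrcX` followed by all `N` lanes of `SrcY`. A small standalone sketch of that lane selection, illustrative only and not part of the patch:

```cpp
// Illustrative sketch only: emulates shufflevector(X, Y, iota-mask), i.e. the
// concatenation performed by Builder.CreateShuffleVector in the fold.
#include <cassert>
#include <numeric>
#include <vector>

int main() {
  std::vector<int> X = {10, 11, 12, 13}, Y = {20, 21, 22, 23};

  std::vector<int> ConcatMask(X.size() + Y.size());
  std::iota(ConcatMask.begin(), ConcatMask.end(), 0); // 0, 1, ..., 7

  // Two-source shuffle semantics: indices < N read from X, indices >= N from Y.
  std::vector<int> Concat;
  for (int Idx : ConcatMask)
    Concat.push_back(Idx < (int)X.size() ? X[Idx] : Y[Idx - X.size()]);

  assert(Concat == (std::vector<int>{10, 11, 12, 13, 20, 21, 22, 23}));
  return 0;
}
```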