Skip to content

Commit 08f9040

Browse files
authored
[VectorCombine] Fold "(or (zext (bitcast X)), (shl (zext (bitcast Y)), C))" -> "(bitcast (concat X, Y))" MOVMSK bool mask style patterns (#119559)
Mask/Bool vectors are often bitcast to/from scalar integers, in particular when concatenating mask results, often this is due to the difficulties of working with vector of bools on C/C++. On x86 this typically involves the MOVMSK/KMOV instructions. To concatenate bool masks, these are typically cast to scalars, which are then zero-extended, shifted and OR'd together. This patch attempts to match these scalar concatenation patterns and convert them to vector shuffles instead. This in turn often assists with further vector combines, depending on the cost model. Fixes #111431
1 parent 8b63bfb commit 08f9040

File tree

2 files changed

+266
-114
lines changed

2 files changed

+266
-114
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ class VectorCombine {
115115
bool foldExtractedCmps(Instruction &I);
116116
bool foldSingleElementStore(Instruction &I);
117117
bool scalarizeLoadExtract(Instruction &I);
118+
bool foldConcatOfBoolMasks(Instruction &I);
118119
bool foldPermuteOfBinops(Instruction &I);
119120
bool foldShuffleOfBinops(Instruction &I);
120121
bool foldShuffleOfCastops(Instruction &I);
@@ -1423,6 +1424,113 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
14231424
return true;
14241425
}
14251426

1427+
/// Try to fold "(or (zext (bitcast X)), (shl (zext (bitcast Y)), C))"
1428+
/// to "(bitcast (concat X, Y))"
1429+
/// where X/Y are bitcasted from i1 mask vectors.
1430+
bool VectorCombine::foldConcatOfBoolMasks(Instruction &I) {
1431+
Type *Ty = I.getType();
1432+
if (!Ty->isIntegerTy())
1433+
return false;
1434+
1435+
// TODO: Add big endian test coverage
1436+
if (DL->isBigEndian())
1437+
return false;
1438+
1439+
// Restrict to disjoint cases so the mask vectors aren't overlapping.
1440+
Instruction *X, *Y;
1441+
if (!match(&I, m_DisjointOr(m_Instruction(X), m_Instruction(Y))))
1442+
return false;
1443+
1444+
// Allow both sources to contain shl, to handle more generic pattern:
1445+
// "(or (shl (zext (bitcast X)), C1), (shl (zext (bitcast Y)), C2))"
1446+
Value *SrcX;
1447+
uint64_t ShAmtX = 0;
1448+
if (!match(X, m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcX)))))) &&
1449+
!match(X, m_OneUse(
1450+
m_Shl(m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcX))))),
1451+
m_ConstantInt(ShAmtX)))))
1452+
return false;
1453+
1454+
Value *SrcY;
1455+
uint64_t ShAmtY = 0;
1456+
if (!match(Y, m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcY)))))) &&
1457+
!match(Y, m_OneUse(
1458+
m_Shl(m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcY))))),
1459+
m_ConstantInt(ShAmtY)))))
1460+
return false;
1461+
1462+
// Canonicalize larger shift to the RHS.
1463+
if (ShAmtX > ShAmtY) {
1464+
std::swap(X, Y);
1465+
std::swap(SrcX, SrcY);
1466+
std::swap(ShAmtX, ShAmtY);
1467+
}
1468+
1469+
// Ensure both sources are matching vXi1 bool mask types, and that the shift
1470+
// difference is the mask width so they can be easily concatenated together.
1471+
uint64_t ShAmtDiff = ShAmtY - ShAmtX;
1472+
unsigned NumSHL = (ShAmtX > 0) + (ShAmtY > 0);
1473+
unsigned BitWidth = Ty->getPrimitiveSizeInBits();
1474+
auto *MaskTy = dyn_cast<FixedVectorType>(SrcX->getType());
1475+
if (!MaskTy || SrcX->getType() != SrcY->getType() ||
1476+
!MaskTy->getElementType()->isIntegerTy(1) ||
1477+
MaskTy->getNumElements() != ShAmtDiff ||
1478+
MaskTy->getNumElements() > (BitWidth / 2))
1479+
return false;
1480+
1481+
auto *ConcatTy = FixedVectorType::getDoubleElementsVectorType(MaskTy);
1482+
auto *ConcatIntTy =
1483+
Type::getIntNTy(Ty->getContext(), ConcatTy->getNumElements());
1484+
auto *MaskIntTy = Type::getIntNTy(Ty->getContext(), ShAmtDiff);
1485+
1486+
SmallVector<int, 32> ConcatMask(ConcatTy->getNumElements());
1487+
std::iota(ConcatMask.begin(), ConcatMask.end(), 0);
1488+
1489+
// TODO: Is it worth supporting multi use cases?
1490+
InstructionCost OldCost = 0;
1491+
OldCost += TTI.getArithmeticInstrCost(Instruction::Or, Ty, CostKind);
1492+
OldCost +=
1493+
NumSHL * TTI.getArithmeticInstrCost(Instruction::Shl, Ty, CostKind);
1494+
OldCost += 2 * TTI.getCastInstrCost(Instruction::ZExt, Ty, MaskIntTy,
1495+
TTI::CastContextHint::None, CostKind);
1496+
OldCost += 2 * TTI.getCastInstrCost(Instruction::BitCast, MaskIntTy, MaskTy,
1497+
TTI::CastContextHint::None, CostKind);
1498+
1499+
InstructionCost NewCost = 0;
1500+
NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, MaskTy,
1501+
ConcatMask, CostKind);
1502+
NewCost += TTI.getCastInstrCost(Instruction::BitCast, ConcatIntTy, ConcatTy,
1503+
TTI::CastContextHint::None, CostKind);
1504+
if (Ty != ConcatIntTy)
1505+
NewCost += TTI.getCastInstrCost(Instruction::ZExt, Ty, ConcatIntTy,
1506+
TTI::CastContextHint::None, CostKind);
1507+
if (ShAmtX > 0)
1508+
NewCost += TTI.getArithmeticInstrCost(Instruction::Shl, Ty, CostKind);
1509+
1510+
if (NewCost > OldCost)
1511+
return false;
1512+
1513+
// Build bool mask concatenation, bitcast back to scalar integer, and perform
1514+
// any residual zero-extension or shifting.
1515+
Value *Concat = Builder.CreateShuffleVector(SrcX, SrcY, ConcatMask);
1516+
Worklist.pushValue(Concat);
1517+
1518+
Value *Result = Builder.CreateBitCast(Concat, ConcatIntTy);
1519+
1520+
if (Ty != ConcatIntTy) {
1521+
Worklist.pushValue(Result);
1522+
Result = Builder.CreateZExt(Result, Ty);
1523+
}
1524+
1525+
if (ShAmtX > 0) {
1526+
Worklist.pushValue(Result);
1527+
Result = Builder.CreateShl(Result, ShAmtX);
1528+
}
1529+
1530+
replaceValue(I, *Result);
1531+
return true;
1532+
}
1533+
14261534
/// Try to convert "shuffle (binop (shuffle, shuffle)), undef"
14271535
/// --> "binop (shuffle), (shuffle)".
14281536
bool VectorCombine::foldPermuteOfBinops(Instruction &I) {
@@ -2908,6 +3016,9 @@ bool VectorCombine::run() {
29083016
if (TryEarlyFoldsOnly)
29093017
return;
29103018

3019+
if (I.getType()->isIntegerTy())
3020+
MadeChange |= foldConcatOfBoolMasks(I);
3021+
29113022
// Otherwise, try folds that improve codegen but may interfere with
29123023
// early IR canonicalizations.
29133024
// The type checking is for run-time efficiency. We can avoid wasting time

0 commit comments

Comments
 (0)