Skip to content

Commit 86779da

Browse files
authored
[VectorCombine] Fold "(or (zext (bitcast X)), (shl (zext (bitcast Y)), C))" -> "(bitcast (concat X, Y))" MOVMSK bool mask style patterns (#119695)
Mask/Bool vectors are often bitcast to/from scalar integers, in particular when concatenating mask results, often this is due to the difficulties of working with vector of bools on C/C++. On x86 this typically involves the MOVMSK/KMOV instructions. To concatenate bool masks, these are typically cast to scalars, which are then zero-extended, shifted and OR'd together. This patch attempts to match these scalar concatenation patterns and convert them to vector shuffles instead. This in turn often assists with further vector combines, depending on the cost model. Reapplied patch from #119559 - fixed use after free issue. Fixes #111431
1 parent f9734b9 commit 86779da

File tree

2 files changed

+266
-114
lines changed

2 files changed

+266
-114
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ class VectorCombine {
115115
bool foldExtractedCmps(Instruction &I);
116116
bool foldSingleElementStore(Instruction &I);
117117
bool scalarizeLoadExtract(Instruction &I);
118+
bool foldConcatOfBoolMasks(Instruction &I);
118119
bool foldPermuteOfBinops(Instruction &I);
119120
bool foldShuffleOfBinops(Instruction &I);
120121
bool foldShuffleOfCastops(Instruction &I);
@@ -1423,6 +1424,113 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
14231424
return true;
14241425
}
14251426

1427+
/// Try to fold "(or (zext (bitcast X)), (shl (zext (bitcast Y)), C))"
1428+
/// to "(bitcast (concat X, Y))"
1429+
/// where X/Y are bitcasted from i1 mask vectors.
1430+
bool VectorCombine::foldConcatOfBoolMasks(Instruction &I) {
1431+
Type *Ty = I.getType();
1432+
if (!Ty->isIntegerTy())
1433+
return false;
1434+
1435+
// TODO: Add big endian test coverage
1436+
if (DL->isBigEndian())
1437+
return false;
1438+
1439+
// Restrict to disjoint cases so the mask vectors aren't overlapping.
1440+
Instruction *X, *Y;
1441+
if (!match(&I, m_DisjointOr(m_Instruction(X), m_Instruction(Y))))
1442+
return false;
1443+
1444+
// Allow both sources to contain shl, to handle more generic pattern:
1445+
// "(or (shl (zext (bitcast X)), C1), (shl (zext (bitcast Y)), C2))"
1446+
Value *SrcX;
1447+
uint64_t ShAmtX = 0;
1448+
if (!match(X, m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcX)))))) &&
1449+
!match(X, m_OneUse(
1450+
m_Shl(m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcX))))),
1451+
m_ConstantInt(ShAmtX)))))
1452+
return false;
1453+
1454+
Value *SrcY;
1455+
uint64_t ShAmtY = 0;
1456+
if (!match(Y, m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcY)))))) &&
1457+
!match(Y, m_OneUse(
1458+
m_Shl(m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcY))))),
1459+
m_ConstantInt(ShAmtY)))))
1460+
return false;
1461+
1462+
// Canonicalize larger shift to the RHS.
1463+
if (ShAmtX > ShAmtY) {
1464+
std::swap(X, Y);
1465+
std::swap(SrcX, SrcY);
1466+
std::swap(ShAmtX, ShAmtY);
1467+
}
1468+
1469+
// Ensure both sources are matching vXi1 bool mask types, and that the shift
1470+
// difference is the mask width so they can be easily concatenated together.
1471+
uint64_t ShAmtDiff = ShAmtY - ShAmtX;
1472+
unsigned NumSHL = (ShAmtX > 0) + (ShAmtY > 0);
1473+
unsigned BitWidth = Ty->getPrimitiveSizeInBits();
1474+
auto *MaskTy = dyn_cast<FixedVectorType>(SrcX->getType());
1475+
if (!MaskTy || SrcX->getType() != SrcY->getType() ||
1476+
!MaskTy->getElementType()->isIntegerTy(1) ||
1477+
MaskTy->getNumElements() != ShAmtDiff ||
1478+
MaskTy->getNumElements() > (BitWidth / 2))
1479+
return false;
1480+
1481+
auto *ConcatTy = FixedVectorType::getDoubleElementsVectorType(MaskTy);
1482+
auto *ConcatIntTy =
1483+
Type::getIntNTy(Ty->getContext(), ConcatTy->getNumElements());
1484+
auto *MaskIntTy = Type::getIntNTy(Ty->getContext(), ShAmtDiff);
1485+
1486+
SmallVector<int, 32> ConcatMask(ConcatTy->getNumElements());
1487+
std::iota(ConcatMask.begin(), ConcatMask.end(), 0);
1488+
1489+
// TODO: Is it worth supporting multi use cases?
1490+
InstructionCost OldCost = 0;
1491+
OldCost += TTI.getArithmeticInstrCost(Instruction::Or, Ty, CostKind);
1492+
OldCost +=
1493+
NumSHL * TTI.getArithmeticInstrCost(Instruction::Shl, Ty, CostKind);
1494+
OldCost += 2 * TTI.getCastInstrCost(Instruction::ZExt, Ty, MaskIntTy,
1495+
TTI::CastContextHint::None, CostKind);
1496+
OldCost += 2 * TTI.getCastInstrCost(Instruction::BitCast, MaskIntTy, MaskTy,
1497+
TTI::CastContextHint::None, CostKind);
1498+
1499+
InstructionCost NewCost = 0;
1500+
NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, MaskTy,
1501+
ConcatMask, CostKind);
1502+
NewCost += TTI.getCastInstrCost(Instruction::BitCast, ConcatIntTy, ConcatTy,
1503+
TTI::CastContextHint::None, CostKind);
1504+
if (Ty != ConcatIntTy)
1505+
NewCost += TTI.getCastInstrCost(Instruction::ZExt, Ty, ConcatIntTy,
1506+
TTI::CastContextHint::None, CostKind);
1507+
if (ShAmtX > 0)
1508+
NewCost += TTI.getArithmeticInstrCost(Instruction::Shl, Ty, CostKind);
1509+
1510+
if (NewCost > OldCost)
1511+
return false;
1512+
1513+
// Build bool mask concatenation, bitcast back to scalar integer, and perform
1514+
// any residual zero-extension or shifting.
1515+
Value *Concat = Builder.CreateShuffleVector(SrcX, SrcY, ConcatMask);
1516+
Worklist.pushValue(Concat);
1517+
1518+
Value *Result = Builder.CreateBitCast(Concat, ConcatIntTy);
1519+
1520+
if (Ty != ConcatIntTy) {
1521+
Worklist.pushValue(Result);
1522+
Result = Builder.CreateZExt(Result, Ty);
1523+
}
1524+
1525+
if (ShAmtX > 0) {
1526+
Worklist.pushValue(Result);
1527+
Result = Builder.CreateShl(Result, ShAmtX);
1528+
}
1529+
1530+
replaceValue(I, *Result);
1531+
return true;
1532+
}
1533+
14261534
/// Try to convert "shuffle (binop (shuffle, shuffle)), undef"
14271535
/// --> "binop (shuffle), (shuffle)".
14281536
bool VectorCombine::foldPermuteOfBinops(Instruction &I) {
@@ -2945,6 +3053,9 @@ bool VectorCombine::run() {
29453053
case Instruction::FCmp:
29463054
MadeChange |= foldExtractExtract(I);
29473055
break;
3056+
case Instruction::Or:
3057+
MadeChange |= foldConcatOfBoolMasks(I);
3058+
[[fallthrough]];
29483059
default:
29493060
if (Instruction::isBinaryOp(Opcode)) {
29503061
MadeChange |= foldExtractExtract(I);

0 commit comments

Comments
 (0)