Skip to content

Commit b22fa90

Browse files
authored
[ValueTracking][X86] Compute KnownBits for phadd/phsub (#92429)
Add KnownBits computations for the x86 phadd/phsub intrinsics to ValueTracking and for the corresponding HADD/HSUB nodes to X86 DAG lowering. These instructions add/subtract adjacent vector elements in their operands. Example: phadd [X1, X2] [Y1, Y2] = [X1 + X2, Y1 + Y2]. This means that, in this example, we can compute the KnownBits of the operation by computing the KnownBits of X1 + X2 (both elements taken from the first operand) and of Y1 + Y2 (both elements taken from the second operand) and intersecting the two results. This approach also generalizes to all x86 vector types. There are also the operations phadd.sw and phsub.sw, which perform saturating addition/subtraction. Use sadd_sat and ssub_sat to compute the KnownBits of these operations. Also adjust the existing test case pr53247.ll because it can be transformed to a constant using the new KnownBits computation. Fixes #82516.
1 parent 78dea4c commit b22fa90

File tree

8 files changed

+630
-31
lines changed

8 files changed

+630
-31
lines changed

llvm/include/llvm/Analysis/VectorUtils.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,23 @@ void processShuffleMasks(
255255
function_ref<void(ArrayRef<int>, unsigned, unsigned)> SingleInputAction,
256256
function_ref<void(ArrayRef<int>, unsigned, unsigned)> ManyInputsAction);
257257

258+
/// Compute the demanded elements mask of horizontal binary operations. A
259+
/// horizontal operation combines two adjacent elements in a vector operand.
260+
/// This function returns a mask for the elements that correspond to the first
261+
/// operand of this horizontal combination. For example, for two vectors
262+
/// [X1, X2, X3, X4] and [Y1, Y2, Y3, Y4], the resulting mask can include the
263+
/// elements X1, X3, Y1, and Y3. To get the other operands, simply shift the
264+
/// result of this function to the left by 1.
265+
///
266+
/// \param VectorBitWidth the total bit width of the vector
267+
/// \param DemandedElts the demanded elements mask for the operation
268+
/// \param DemandedLHS the demanded elements mask for the left operand
269+
/// \param DemandedRHS the demanded elements mask for the right operand
270+
void getHorizDemandedEltsForFirstOperand(unsigned VectorBitWidth,
271+
const APInt &DemandedElts,
272+
APInt &DemandedLHS,
273+
APInt &DemandedRHS);
274+
258275
/// Compute a map of integer instructions to their minimum legal type
259276
/// size.
260277
///

llvm/lib/Analysis/ValueTracking.cpp

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -959,6 +959,32 @@ getKnownBitsFromAndXorOr(const Operator *I, const APInt &DemandedElts,
959959
return KnownOut;
960960
}
961961

962+
/// Compute the KnownBits of a horizontal binary operation \p I, i.e. an
/// operation that combines adjacent elements of each of its two vector
/// operands (used for the x86 phadd/phsub intrinsics).
///
/// For every operand that contributes to a demanded result element, the known
/// bits of the first and of the second element of the demanded pairs are
/// queried separately and merged with \p KnownBitsFunc. If both operands
/// contribute, the two per-operand results are intersected.
///
/// \param I the horizontal operation; operands 0 and 1 are its vector inputs
/// \param DemandedElts the demanded elements mask of the result
/// \param Depth the current recursion depth; operands are queried at Depth + 1
/// \param Q the query context (also supplies the DataLayout for the type size)
/// \param KnownBitsFunc combines the known bits of (first, second) pair element
static KnownBits computeKnownBitsForHorizontalOperation(
    const Operator *I, const APInt &DemandedElts, unsigned Depth,
    const SimplifyQuery &Q,
    const function_ref<KnownBits(const KnownBits &, const KnownBits &)>
        KnownBitsFunc) {
  // Per-operand masks selecting the first element of every demanded pair.
  APInt FirstEltsLHS, FirstEltsRHS;
  getHorizDemandedEltsForFirstOperand(Q.DL.getTypeSizeInBits(I->getType()),
                                      DemandedElts, FirstEltsLHS, FirstEltsRHS);

  const auto CombinePairs = [&Q, Depth, KnownBitsFunc](const Value *Src,
                                                       const APInt &FirstElts) {
    // Shifting the mask left by one selects the second element of each pair.
    const APInt SecondElts = FirstElts << 1;
    KnownBits KnownFirst = computeKnownBits(Src, FirstElts, Depth + 1, Q);
    KnownBits KnownSecond = computeKnownBits(Src, SecondElts, Depth + 1, Q);
    return KnownBitsFunc(KnownFirst, KnownSecond);
  };

  const Value *LHS = I->getOperand(0);
  const Value *RHS = I->getOperand(1);

  // Only query an operand if it feeds at least one demanded result element;
  // when nothing is demanded from the right operand, the left one is used.
  if (FirstEltsRHS.isZero())
    return CombinePairs(LHS, FirstEltsLHS);
  if (FirstEltsLHS.isZero())
    return CombinePairs(RHS, FirstEltsRHS);

  // Both operands contribute: keep only the bits known in both results.
  KnownBits KnownFromLHS = CombinePairs(LHS, FirstEltsLHS);
  return KnownFromLHS.intersectWith(CombinePairs(RHS, FirstEltsRHS));
}
987+
962988
// Public so this can be used in `SimplifyDemandedUseBits`.
963989
KnownBits llvm::analyzeKnownBitsFromAndXorOr(const Operator *I,
964990
const KnownBits &KnownLHS,
@@ -1756,6 +1782,44 @@ static void computeKnownBitsFromOperator(const Operator *I,
17561782
case Intrinsic::x86_sse42_crc32_64_64:
17571783
Known.Zero.setBitsFrom(32);
17581784
break;
1785+
case Intrinsic::x86_ssse3_phadd_d_128:
1786+
case Intrinsic::x86_ssse3_phadd_w_128:
1787+
case Intrinsic::x86_avx2_phadd_d:
1788+
case Intrinsic::x86_avx2_phadd_w: {
1789+
Known = computeKnownBitsForHorizontalOperation(
1790+
I, DemandedElts, Depth, Q,
1791+
[](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
1792+
return KnownBits::computeForAddSub(/*Add=*/true, /*NSW=*/false,
1793+
/*NUW=*/false, KnownLHS,
1794+
KnownRHS);
1795+
});
1796+
break;
1797+
}
1798+
case Intrinsic::x86_ssse3_phadd_sw_128:
1799+
case Intrinsic::x86_avx2_phadd_sw: {
1800+
Known = computeKnownBitsForHorizontalOperation(I, DemandedElts, Depth,
1801+
Q, KnownBits::sadd_sat);
1802+
break;
1803+
}
1804+
case Intrinsic::x86_ssse3_phsub_d_128:
1805+
case Intrinsic::x86_ssse3_phsub_w_128:
1806+
case Intrinsic::x86_avx2_phsub_d:
1807+
case Intrinsic::x86_avx2_phsub_w: {
1808+
Known = computeKnownBitsForHorizontalOperation(
1809+
I, DemandedElts, Depth, Q,
1810+
[](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
1811+
return KnownBits::computeForAddSub(/*Add=*/false, /*NSW=*/false,
1812+
/*NUW=*/false, KnownLHS,
1813+
KnownRHS);
1814+
});
1815+
break;
1816+
}
1817+
case Intrinsic::x86_ssse3_phsub_sw_128:
1818+
case Intrinsic::x86_avx2_phsub_sw: {
1819+
Known = computeKnownBitsForHorizontalOperation(I, DemandedElts, Depth,
1820+
Q, KnownBits::ssub_sat);
1821+
break;
1822+
}
17591823
case Intrinsic::riscv_vsetvli:
17601824
case Intrinsic::riscv_vsetvlimax: {
17611825
bool HasAVL = II->getIntrinsicID() == Intrinsic::riscv_vsetvli;

llvm/lib/Analysis/VectorUtils.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,34 @@ void llvm::processShuffleMasks(
567567
}
568568
}
569569

570+
void llvm::getHorizDemandedEltsForFirstOperand(unsigned VectorBitWidth,
571+
const APInt &DemandedElts,
572+
APInt &DemandedLHS,
573+
APInt &DemandedRHS) {
574+
assert(VectorBitWidth >= 128 && "Vectors smaller than 128 bit not supported");
575+
int NumLanes = VectorBitWidth / 128;
576+
int NumElts = DemandedElts.getBitWidth();
577+
int NumEltsPerLane = NumElts / NumLanes;
578+
int HalfEltsPerLane = NumEltsPerLane / 2;
579+
580+
DemandedLHS = APInt::getZero(NumElts);
581+
DemandedRHS = APInt::getZero(NumElts);
582+
583+
// Map DemandedElts to the horizontal operands.
584+
for (int Idx = 0; Idx != NumElts; ++Idx) {
585+
if (!DemandedElts[Idx])
586+
continue;
587+
int LaneIdx = (Idx / NumEltsPerLane) * NumEltsPerLane;
588+
int LocalIdx = Idx % NumEltsPerLane;
589+
if (LocalIdx < HalfEltsPerLane) {
590+
DemandedLHS.setBit(LaneIdx + 2 * LocalIdx);
591+
} else {
592+
LocalIdx -= HalfEltsPerLane;
593+
DemandedRHS.setBit(LaneIdx + 2 * LocalIdx);
594+
}
595+
}
596+
}
597+
570598
MapVector<Instruction *, uint64_t>
571599
llvm::computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB,
572600
const TargetTransformInfo *TTI) {

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 41 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5204,29 +5204,10 @@ static void getPackDemandedElts(EVT VT, const APInt &DemandedElts,
52045204
// Split the demanded elts of a HADD/HSUB node between its operands.
52055205
static void getHorizDemandedElts(EVT VT, const APInt &DemandedElts,
52065206
APInt &DemandedLHS, APInt &DemandedRHS) {
5207-
int NumLanes = VT.getSizeInBits() / 128;
5208-
int NumElts = DemandedElts.getBitWidth();
5209-
int NumEltsPerLane = NumElts / NumLanes;
5210-
int HalfEltsPerLane = NumEltsPerLane / 2;
5211-
5212-
DemandedLHS = APInt::getZero(NumElts);
5213-
DemandedRHS = APInt::getZero(NumElts);
5214-
5215-
// Map DemandedElts to the horizontal operands.
5216-
for (int Idx = 0; Idx != NumElts; ++Idx) {
5217-
if (!DemandedElts[Idx])
5218-
continue;
5219-
int LaneIdx = (Idx / NumEltsPerLane) * NumEltsPerLane;
5220-
int LocalIdx = Idx % NumEltsPerLane;
5221-
if (LocalIdx < HalfEltsPerLane) {
5222-
DemandedLHS.setBit(LaneIdx + 2 * LocalIdx + 0);
5223-
DemandedLHS.setBit(LaneIdx + 2 * LocalIdx + 1);
5224-
} else {
5225-
LocalIdx -= HalfEltsPerLane;
5226-
DemandedRHS.setBit(LaneIdx + 2 * LocalIdx + 0);
5227-
DemandedRHS.setBit(LaneIdx + 2 * LocalIdx + 1);
5228-
}
5229-
}
5207+
getHorizDemandedEltsForFirstOperand(VT.getSizeInBits(), DemandedElts,
5208+
DemandedLHS, DemandedRHS);
5209+
DemandedLHS |= DemandedLHS << 1;
5210+
DemandedRHS |= DemandedRHS << 1;
52305211
}
52315212

52325213
/// Calculates the shuffle mask corresponding to the target-specific opcode.
@@ -37174,6 +37155,32 @@ static void computeKnownBitsForPMADDUBSW(SDValue LHS, SDValue RHS,
3717437155
Known = KnownBits::sadd_sat(Lo, Hi);
3717537156
}
3717637157

37158+
// Compute the KnownBits of a horizontal binary node (used for X86ISD::HADD and
// X86ISD::HSUB). For every operand that feeds a demanded result element, the
// known bits of the first and of the second element of the demanded pairs are
// queried separately and merged with KnownBitsFunc; if both operands
// contribute, the two per-operand results are intersected.
static KnownBits computeKnownBitsForHorizontalOperation(
    const SDValue Op, const APInt &DemandedElts, unsigned Depth,
    const SelectionDAG &DAG,
    const function_ref<KnownBits(const KnownBits &, const KnownBits &)>
        KnownBitsFunc) {
  // Per-operand masks selecting the first element of every demanded pair.
  APInt FirstEltsLHS, FirstEltsRHS;
  getHorizDemandedEltsForFirstOperand(Op.getValueType().getSizeInBits(),
                                      DemandedElts, FirstEltsLHS, FirstEltsRHS);

  const auto CombinePairs = [&DAG, Depth, KnownBitsFunc](
                                SDValue Src, const APInt &FirstElts) {
    // Shifting the mask left by one selects the second element of each pair.
    const APInt SecondElts = FirstElts << 1;
    KnownBits KnownFirst = DAG.computeKnownBits(Src, FirstElts, Depth + 1);
    KnownBits KnownSecond = DAG.computeKnownBits(Src, SecondElts, Depth + 1);
    return KnownBitsFunc(KnownFirst, KnownSecond);
  };

  const bool NeedsLHS = !FirstEltsLHS.isZero();
  const bool NeedsRHS = !FirstEltsRHS.isZero();

  // Only query an operand if it feeds at least one demanded result element;
  // when nothing is demanded from the right operand, the left one is used.
  if (!NeedsRHS)
    return CombinePairs(Op.getOperand(0), FirstEltsLHS);
  if (!NeedsLHS)
    return CombinePairs(Op.getOperand(1), FirstEltsRHS);

  // Both operands contribute: keep only the bits known in both results.
  KnownBits KnownFromLHS = CombinePairs(Op.getOperand(0), FirstEltsLHS);
  return KnownFromLHS.intersectWith(
      CombinePairs(Op.getOperand(1), FirstEltsRHS));
}
37183+
3717737184
void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
3717837185
KnownBits &Known,
3717937186
const APInt &DemandedElts,
@@ -37503,6 +37510,17 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
3750337510
}
3750437511
break;
3750537512
}
37513+
case X86ISD::HADD:
37514+
case X86ISD::HSUB: {
37515+
Known = computeKnownBitsForHorizontalOperation(
37516+
Op, DemandedElts, Depth, DAG,
37517+
[Opc](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
37518+
return KnownBits::computeForAddSub(
37519+
/*Add=*/Opc == X86ISD::HADD, /*NSW=*/false, /*NUW=*/false,
37520+
KnownLHS, KnownRHS);
37521+
});
37522+
break;
37523+
}
3750637524
case ISD::INTRINSIC_WO_CHAIN: {
3750737525
switch (Op->getConstantOperandVal(0)) {
3750837526
case Intrinsic::x86_sse2_pmadd_wd:

0 commit comments

Comments
 (0)