Commit 0fa258c

[X86] Implement certain 16-bit vector shifts via 32-bit shifts
x86 vector ISAs are non-orthogonal in a number of ways. For example, AVX2 has vpsravd but not vpsravw; however, we can simulate the latter via vpsrlvd plus some SWAR-style masking. Another example is 8-bit shifts: we can use vpsllvd to simulate the missing "vpsllvb" as long as the shift amounts within each wider lane are the same. Existing code generation uses a variety of techniques, including vpmulhuw, which has higher latency and often more rigid port requirements than simple bitwise operations.
1 parent e584278 commit 0fa258c
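
The following is a minimal, standalone sketch (not taken from the patch) of the SWAR idea described above: emulating a per-16-bit-lane logical right shift on AVX2, which lacks vpsrlvw, by doing a 32-bit variable shift (vpsrlvd) and masking off the bits that cross the 16-bit boundary. It assumes both 16-bit halves of every 32-bit lane use the same shift amount; the helper name srlv_epi16_via_epi32 is made up for illustration, and the mask is built with scalar code here, whereas the committed lowering constant-folds it in the DAG.

#include <immintrin.h> // AVX2 intrinsics; compile with -mavx2
#include <cstdint>

// Logically shift each 16-bit lane of V right. Lanes 2k and 2k+1 must use the
// same amount Amt[k], with 0 <= Amt[k] <= 15.
static __m256i srlv_epi16_via_epi32(__m256i V, const uint32_t Amt[8]) {
  __m256i A = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(Amt));
  // One 32-bit variable shift moves both 16-bit halves at once, but bits from
  // the high half spill into the top of the low half.
  __m256i Shifted = _mm256_srlv_epi32(V, A);
  // Mask = (0xFFFF >> Amt) replicated into both halves of each 32-bit lane;
  // it clears exactly the bits that crossed the 16-bit boundary.
  alignas(32) uint32_t MaskWords[8];
  for (int I = 0; I != 8; ++I) {
    uint32_t M = 0xFFFFu >> Amt[I];
    MaskWords[I] = M | (M << 16);
  }
  __m256i Mask =
      _mm256_load_si256(reinterpret_cast<const __m256i *>(MaskWords));
  return _mm256_and_si256(Shifted, Mask);
}

The arithmetic (vpsravw-like) case uses the same shift-and-mask step followed by a sign-extension fixup; see the scalar sketch after the diff below.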

10 files changed: +1013 −74 lines

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 100 additions & 70 deletions
@@ -95,6 +95,11 @@ static cl::opt<int> BrMergingCcmpBias(
              "supports conditional compare instructions."),
     cl::Hidden);
 
+static cl::opt<bool>
+    WidenShift("x86-widen-shift", cl::init(true),
+               cl::desc("Replace narrow shifts with wider shifts."),
+               cl::Hidden);
+
 static cl::opt<int> BrMergingLikelyBias(
     "x86-br-merging-likely-bias", cl::init(0),
     cl::desc("Increases 'x86-br-merging-base-cost' in cases that it is likely "
@@ -29851,119 +29856,143 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
     }
   }
 
-  // Constant ISD::SRA/SRL/SHL can be performed efficiently on vXi8 vectors by
-  // using vXi16 vector operations.
+  // Constant ISD::SRA/SRL/SHL can be performed efficiently on vXiN vectors by
+  // using vYiM vector operations where X*N == Y*M and M > N.
   if (ConstantAmt &&
-      (VT == MVT::v16i8 || (VT == MVT::v32i8 && Subtarget.hasInt256()) ||
-       (VT == MVT::v64i8 && Subtarget.hasBWI())) &&
+      (VT == MVT::v16i8 || VT == MVT::v32i8 || VT == MVT::v64i8 ||
+       VT == MVT::v8i16 || VT == MVT::v16i16 || VT == MVT::v32i16) &&
       !Subtarget.hasXOP()) {
+    MVT NarrowScalarVT = VT.getScalarType();
     int NumElts = VT.getVectorNumElements();
-    MVT VT16 = MVT::getVectorVT(MVT::i16, NumElts / 2);
-    // We can do this extra fast if each pair of i8 elements is shifted by the
-    // same amount by doing this SWAR style: use a shift to move the valid bits
-    // to the right position, mask out any bits which crossed from one element
-    // to the other.
-    APInt UndefElts;
-    SmallVector<APInt, 64> AmtBits;
+    // We can do this extra fast if each pair of narrow elements is shifted by
+    // the same amount by doing this SWAR style: use a shift to move the valid
+    // bits to the right position, mask out any bits which crossed from one
+    // element to the other.
     // This optimized lowering is only valid if the elements in a pair can
     // be treated identically.
-    bool SameShifts = true;
-    SmallVector<APInt, 32> AmtBits16(NumElts / 2);
-    APInt UndefElts16 = APInt::getZero(AmtBits16.size());
-    if (getTargetConstantBitsFromNode(Amt, /*EltSizeInBits=*/8, UndefElts,
-                                      AmtBits, /*AllowWholeUndefs=*/true,
-                                      /*AllowPartialUndefs=*/false)) {
-      // Collect information to construct the BUILD_VECTOR for the i16 version
-      // of the shift. Conceptually, this is equivalent to:
-      // 1. Making sure the shift amounts are the same for both the low i8 and
-      //    high i8 corresponding to the i16 lane.
-      // 2. Extending that shift amount to i16 for a build vector operation.
-      //
-      // We want to handle undef shift amounts which requires a little more
-      // logic (e.g. if one is undef and the other is not, grab the other shift
-      // amount).
-      for (unsigned SrcI = 0, E = AmtBits.size(); SrcI != E; SrcI += 2) {
+    SmallVector<SDValue, 32> AmtWideElts;
+    AmtWideElts.reserve(NumElts);
+    for (int I = 0; I != NumElts; ++I) {
+      AmtWideElts.push_back(Amt.getOperand(I));
+    }
+    SmallVector<SDValue, 32> TmpAmtWideElts;
+    int WideEltSizeInBits = EltSizeInBits;
+    while (WideEltSizeInBits < 32) {
+      // AVX1 does not have psrlvd, etc. which makes interesting 32-bit shifts
+      // unprofitable.
+      if (WideEltSizeInBits >= 16 && !Subtarget.hasAVX2()) {
+        break;
+      }
+      TmpAmtWideElts.resize(AmtWideElts.size() / 2);
+      bool SameShifts = true;
+      for (unsigned SrcI = 0, E = AmtWideElts.size(); SrcI != E; SrcI += 2) {
         unsigned DstI = SrcI / 2;
         // Both elements are undef? Make a note and keep going.
-        if (UndefElts[SrcI] && UndefElts[SrcI + 1]) {
-          AmtBits16[DstI] = APInt::getZero(16);
-          UndefElts16.setBit(DstI);
+        if (AmtWideElts[SrcI].isUndef() && AmtWideElts[SrcI + 1].isUndef()) {
+          TmpAmtWideElts[DstI] = AmtWideElts[SrcI];
           continue;
         }
         // Even element is undef? We will shift it by the same shift amount as
         // the odd element.
-        if (UndefElts[SrcI]) {
-          AmtBits16[DstI] = AmtBits[SrcI + 1].zext(16);
+        if (AmtWideElts[SrcI].isUndef()) {
+          TmpAmtWideElts[DstI] = AmtWideElts[SrcI + 1];
           continue;
         }
         // Odd element is undef? We will shift it by the same shift amount as
         // the even element.
-        if (UndefElts[SrcI + 1]) {
-          AmtBits16[DstI] = AmtBits[SrcI].zext(16);
+        if (AmtWideElts[SrcI + 1].isUndef()) {
+          TmpAmtWideElts[DstI] = AmtWideElts[SrcI];
           continue;
         }
         // Both elements are equal.
-        if (AmtBits[SrcI] == AmtBits[SrcI + 1]) {
-          AmtBits16[DstI] = AmtBits[SrcI].zext(16);
+        if (AmtWideElts[SrcI].getNode()->getAsAPIntVal() ==
+            AmtWideElts[SrcI + 1].getNode()->getAsAPIntVal()) {
+          TmpAmtWideElts[DstI] = AmtWideElts[SrcI];
           continue;
         }
-        // One of the provisional i16 elements will not have the same shift
+        // One of the provisional wide elements will not have the same shift
         // amount. Let's bail.
         SameShifts = false;
         break;
       }
+      if (!SameShifts) {
+        break;
+      }
+      WideEltSizeInBits *= 2;
+      std::swap(TmpAmtWideElts, AmtWideElts);
     }
+    APInt APIntShiftAmt;
+    bool IsConstantSplat = X86::isConstantSplat(Amt, APIntShiftAmt);
+    bool Profitable = WidenShift;
+    // AVX512BW brings support for vpsllvw.
+    if (WideEltSizeInBits * AmtWideElts.size() >= 512 &&
+        WideEltSizeInBits < 32 && !Subtarget.hasBWI()) {
+      Profitable = false;
+    }
+    // Leave AVX512 uniform arithmetic shifts alone, they can be implemented
+    // fairly cheaply in other ways.
+    if (WideEltSizeInBits * AmtWideElts.size() >= 512 && IsConstantSplat) {
+      Profitable = false;
+    }
+    // Leave it up to GFNI if we have it around.
+    // TODO: gf2p8affine is usually higher latency and more port restricted. It
+    // is probably a win to use other strategies in some cases.
+    if (EltSizeInBits == 8 && Subtarget.hasGFNI()) {
+      Profitable = false;
+    }
+
+    // AVX1 does not have vpand which makes our masking impractical. It does
+    // have vandps but that is an FP instruction and crossing FP<->int typically
+    // has some cost.
+    if (WideEltSizeInBits * AmtWideElts.size() >= 256 &&
+        (WideEltSizeInBits < 32 || IsConstantSplat) && !Subtarget.hasAVX2()) {
+      Profitable = false;
+    }
+    int WideNumElts = AmtWideElts.size();
     // We are only dealing with identical pairs.
-    if (SameShifts) {
-      // Cast the operand to vXi16.
-      SDValue R16 = DAG.getBitcast(VT16, R);
+    if (Profitable && WideNumElts != NumElts) {
+      MVT WideScalarVT = MVT::getIntegerVT(WideEltSizeInBits);
+      MVT WideVT = MVT::getVectorVT(WideScalarVT, WideNumElts);
+      // Cast the operand to vXiM.
+      SDValue RWide = DAG.getBitcast(WideVT, R);
       // Create our new vector of shift amounts.
-      SDValue Amt16 = getConstVector(AmtBits16, UndefElts16, VT16, DAG, dl);
+      SDValue AmtWide = DAG.getBuildVector(
+          MVT::getVectorVT(NarrowScalarVT, WideNumElts), dl, AmtWideElts);
+      AmtWide = DAG.getZExtOrTrunc(AmtWide, dl, WideVT);
       // Perform the actual shift.
       unsigned LogicalOpc = Opc == ISD::SRA ? ISD::SRL : Opc;
-      SDValue ShiftedR = DAG.getNode(LogicalOpc, dl, VT16, R16, Amt16);
+      SDValue ShiftedR = DAG.getNode(LogicalOpc, dl, WideVT, RWide, AmtWide);
       // Now we need to construct a mask which will "drop" bits that get
       // shifted past the LSB/MSB. For a logical shift left, it will look
       // like:
-      // MaskLowBits = (0xff << Amt16) & 0xff;
-      // MaskHighBits = MaskLowBits << 8;
-      // Mask = MaskLowBits | MaskHighBits;
+      // FullMask = (1 << EltSizeInBits) - 1
+      // Mask = FullMask << Amt
       //
-      // This masking ensures that bits cannot migrate from one i8 to
+      // This masking ensures that bits cannot migrate from one narrow lane to
       // another. The construction of this mask will be constant folded.
       // The mask for a logical right shift is nearly identical, the only
-      // difference is that 0xff is shifted right instead of left.
-      SDValue Cst255 = DAG.getConstant(0xff, dl, MVT::i16);
-      SDValue Splat255 = DAG.getSplat(VT16, dl, Cst255);
-      // The mask for the low bits is most simply expressed as an 8-bit
-      // field of all ones which is shifted in the exact same way the data
-      // is shifted but masked with 0xff.
-      SDValue MaskLowBits = DAG.getNode(LogicalOpc, dl, VT16, Splat255, Amt16);
-      MaskLowBits = DAG.getNode(ISD::AND, dl, VT16, MaskLowBits, Splat255);
-      SDValue Cst8 = DAG.getConstant(8, dl, MVT::i16);
-      SDValue Splat8 = DAG.getSplat(VT16, dl, Cst8);
-      // The mask for the high bits is the same as the mask for the low bits but
-      // shifted up by 8.
-      SDValue MaskHighBits =
-          DAG.getNode(ISD::SHL, dl, VT16, MaskLowBits, Splat8);
-      SDValue Mask = DAG.getNode(ISD::OR, dl, VT16, MaskLowBits, MaskHighBits);
+      // difference is that the all ones mask is shifted right instead of left.
+      SDValue CstFullMask = DAG.getAllOnesConstant(dl, NarrowScalarVT);
+      SDValue SplatFullMask = DAG.getSplat(VT, dl, CstFullMask);
+      SDValue Mask = DAG.getNode(LogicalOpc, dl, VT, SplatFullMask, Amt);
+      Mask = DAG.getBitcast(WideVT, Mask);
       // Finally, we mask the shifted vector with the SWAR mask.
-      SDValue Masked = DAG.getNode(ISD::AND, dl, VT16, ShiftedR, Mask);
+      SDValue Masked = DAG.getNode(ISD::AND, dl, WideVT, ShiftedR, Mask);
       Masked = DAG.getBitcast(VT, Masked);
       if (Opc != ISD::SRA) {
         // Logical shifts are complete at this point.
         return Masked;
       }
       // At this point, we have done a *logical* shift right. We now need to
       // sign extend the result so that we get behavior equivalent to an
-      // arithmetic shift right. Post-shifting by Amt16, our i8 elements are
-      // `8-Amt16` bits wide.
+      // arithmetic shift right. Post-shifting by AmtWide, our narrow elements
+      // are `EltSizeInBits-AmtWide` bits wide.
       //
-      // To convert our `8-Amt16` bit unsigned numbers to 8-bit signed numbers,
-      // we need to replicate the bit at position `7-Amt16` into the MSBs of
-      // each i8.
-      // We can use the following trick to accomplish this:
-      // SignBitMask = 1 << (7-Amt16)
+      // To convert our `EltSizeInBits-AmtWide` bit unsigned numbers to signed
+      // numbers as wide as `EltSizeInBits`, we need to replicate the bit at
+      // position `EltSizeInBits-AmtWide` into the MSBs of each narrow lane. We
+      // can use the following trick to accomplish this:
+      // SignBitMask = 1 << (EltSizeInBits-AmtWide-1)
       // (Masked ^ SignBitMask) - SignBitMask
       //
       // When the sign bit is already clear, this will compute:
@@ -29977,7 +30006,8 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
       //
       // This is equal to Masked - 2*SignBitMask which will correctly sign
       // extend our result.
-      SDValue CstHighBit = DAG.getConstant(0x80, dl, MVT::i8);
+      SDValue CstHighBit =
+          DAG.getConstant(1 << (EltSizeInBits - 1), dl, NarrowScalarVT);
       SDValue SplatHighBit = DAG.getSplat(VT, dl, CstHighBit);
       // This does not induce recursion, all operands are constants.
       SDValue SignBitMask = DAG.getNode(LogicalOpc, dl, VT, SplatHighBit, Amt);
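
As a complement to the comments in the diff above, here is a scalar sketch (illustrative only; the helper name sra_i8_via_srl is made up) of the sign-extension fixup used for ISD::SRA: after the logical shift, each narrow lane holds an (EltSizeInBits - Amt)-bit unsigned value, and XOR-ing in the relocated sign bit then subtracting it back replicates that bit into the vacated high bits. It assumes Amt is a per-lane constant in [0, 7] for an 8-bit lane.

#include <cassert>
#include <cstdint>

// Arithmetic shift right of an 8-bit lane implemented with a logical shift
// plus the (x ^ SignBitMask) - SignBitMask fixup described above.
static uint8_t sra_i8_via_srl(uint8_t X, unsigned Amt) {
  uint8_t Masked = uint8_t(X >> Amt);              // logical shift right
  uint8_t SignBitMask = uint8_t(1u << (7 - Amt));  // sign bit's new position
  // If the sign bit is set, the XOR clears it and the subtraction borrows
  // through the high bits, filling them with ones; if it is clear, the two
  // operations cancel out.
  return uint8_t((Masked ^ SignBitMask) - SignBitMask);
}

int main() {
  // 0xF0 is -16 as an i8; an arithmetic shift right by 2 yields -4 (0xFC).
  assert(sra_i8_via_srl(0xF0, 2) == 0xFC);
  // A positive value is unaffected by the fixup: 0x70 >> 2 == 0x1C.
  assert(sra_i8_via_srl(0x70, 2) == 0x1C);
  return 0;
}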
