@@ -95,6 +95,11 @@ static cl::opt<int> BrMergingCcmpBias(
"supports conditional compare instructions."),
cl::Hidden);
+ static cl::opt<bool>
+ WidenShift("x86-widen-shift", cl::init(true),
+ cl::desc("Replacte narrow shifts with wider shifts."),
+ cl::Hidden);
+
static cl::opt<int> BrMergingLikelyBias(
"x86-br-merging-likely-bias", cl::init(0),
cl::desc("Increases 'x86-br-merging-base-cost' in cases that it is likely "
@@ -29851,119 +29856,143 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
}
}
- // Constant ISD::SRA/SRL/SHL can be performed efficiently on vXi8 vectors by
- // using vXi16 vector operations.
+ // Constant ISD::SRA/SRL/SHL can be performed efficiently on vXiN vectors by
+ // using vYiM vector operations where X*N == Y*M and M > N.
if (ConstantAmt &&
29857
- (VT == MVT::v16i8 || ( VT == MVT::v32i8 && Subtarget.hasInt256()) ||
29858
- ( VT == MVT::v64i8 && Subtarget.hasBWI()) ) &&
29862
+ (VT == MVT::v16i8 || VT == MVT::v32i8 || VT == MVT::v64i8 ||
29863
+ VT == MVT::v8i16 || VT == MVT::v16i16 || VT == MVT::v32i16 ) &&
29859
29864
!Subtarget.hasXOP()) {
+ MVT NarrowScalarVT = VT.getScalarType();
int NumElts = VT.getVectorNumElements();
- MVT VT16 = MVT::getVectorVT(MVT::i16, NumElts / 2);
- // We can do this extra fast if each pair of i8 elements is shifted by the
- // same amount by doing this SWAR style: use a shift to move the valid bits
- // to the right position, mask out any bits which crossed from one element
- // to the other.
- APInt UndefElts;
- SmallVector<APInt, 64> AmtBits;
+ // We can do this extra fast if each pair of narrow elements is shifted by
+ // the same amount by doing this SWAR style: use a shift to move the valid
+ // bits to the right position, mask out any bits which crossed from one
+ // element to the other.
// This optimized lowering is only valid if the elements in a pair can
// be treated identically.
- bool SameShifts = true;
- SmallVector<APInt, 32> AmtBits16(NumElts / 2);
- APInt UndefElts16 = APInt::getZero(AmtBits16.size());
- if (getTargetConstantBitsFromNode(Amt, /*EltSizeInBits=*/8, UndefElts,
- AmtBits, /*AllowWholeUndefs=*/true,
- /*AllowPartialUndefs=*/false)) {
- // Collect information to construct the BUILD_VECTOR for the i16 version
- // of the shift. Conceptually, this is equivalent to:
- // 1. Making sure the shift amounts are the same for both the low i8 and
- // high i8 corresponding to the i16 lane.
29880
- // 2. Extending that shift amount to i16 for a build vector operation.
29881
- //
29882
- // We want to handle undef shift amounts which requires a little more
29883
- // logic (e.g. if one is undef and the other is not, grab the other shift
29884
- // amount).
29885
- for (unsigned SrcI = 0, E = AmtBits .size(); SrcI != E; SrcI += 2) {
29873
+ SmallVector<SDValue, 32> AmtWideElts;
+ AmtWideElts.reserve(NumElts);
+ for (int I = 0; I != NumElts; ++I) {
+ AmtWideElts.push_back(Amt.getOperand(I));
+ }
+ SmallVector<SDValue, 32> TmpAmtWideElts;
+ int WideEltSizeInBits = EltSizeInBits;
+ while (WideEltSizeInBits < 32) {
+ // AVX1 does not have psrlvd, etc. which makes interesting 32-bit shifts
+ // unprofitable.
+ if (WideEltSizeInBits >= 16 && !Subtarget.hasAVX2()) {
+ break;
+ }
+ TmpAmtWideElts.resize(AmtWideElts.size() / 2);
+ bool SameShifts = true;
+ for (unsigned SrcI = 0, E = AmtWideElts.size(); SrcI != E; SrcI += 2) {
unsigned DstI = SrcI / 2;
// Both elements are undef? Make a note and keep going.
- if (UndefElts[SrcI] && UndefElts[SrcI + 1]) {
- AmtBits16[DstI] = APInt::getZero(16);
- UndefElts16.setBit(DstI);
+ if (AmtWideElts[SrcI].isUndef() && AmtWideElts[SrcI + 1].isUndef()) {
+ TmpAmtWideElts[DstI] = AmtWideElts[SrcI];
continue;
}
// Even element is undef? We will shift it by the same shift amount as
// the odd element.
- if (UndefElts[SrcI]) {
- AmtBits16[DstI] = AmtBits[SrcI + 1].zext(16);
+ if (AmtWideElts[SrcI].isUndef()) {
+ TmpAmtWideElts[DstI] = AmtWideElts[SrcI + 1];
continue;
}
// Odd element is undef? We will shift it by the same shift amount as
// the even element.
- if (UndefElts[SrcI + 1]) {
- AmtBits16[DstI] = AmtBits[SrcI].zext(16);
+ if (AmtWideElts[SrcI + 1].isUndef()) {
+ TmpAmtWideElts[DstI] = AmtWideElts[SrcI];
continue;
}
// Both elements are equal.
- if (AmtBits[SrcI] == AmtBits[SrcI + 1]) {
- AmtBits16[DstI] = AmtBits[SrcI].zext(16);
+ if (AmtWideElts[SrcI].getNode()->getAsAPIntVal() ==
+ AmtWideElts[SrcI + 1].getNode()->getAsAPIntVal()) {
+ TmpAmtWideElts[DstI] = AmtWideElts[SrcI];
continue;
}
- // One of the provisional i16 elements will not have the same shift
+ // One of the provisional wide elements will not have the same shift
// amount. Let's bail.
SameShifts = false;
break;
}
+ if (!SameShifts) {
+ break;
+ }
+ WideEltSizeInBits *= 2;
+ std::swap(TmpAmtWideElts, AmtWideElts);
}
+ APInt APIntShiftAmt;
+ bool IsConstantSplat = X86::isConstantSplat(Amt, APIntShiftAmt);
+ bool Profitable = WidenShift;
+ // AVX512BW brings support for vpsllvw.
+ if (WideEltSizeInBits * AmtWideElts.size() >= 512 &&
+ WideEltSizeInBits < 32 && !Subtarget.hasBWI()) {
+ Profitable = false;
+ }
+ // Leave AVX512 uniform arithmetic shifts alone, they can be implemented
+ // fairly cheaply in other ways.
+ if (WideEltSizeInBits * AmtWideElts.size() >= 512 && IsConstantSplat) {
+ Profitable = false;
+ }
+ // Leave it up to GFNI if we have it around.
+ // TODO: gf2p8affine is usually higher latency and more port restricted. It
+ // is probably a win to use other strategies in some cases.
+ if (EltSizeInBits == 8 && Subtarget.hasGFNI()) {
+ Profitable = false;
+ }
+
+ // AVX1 does not have vpand which makes our masking impractical. It does
+ // have vandps but that is an FP instruction and crossing FP<->int typically
+ // has some cost.
+ if (WideEltSizeInBits * AmtWideElts.size() >= 256 &&
+ (WideEltSizeInBits < 32 || IsConstantSplat) && !Subtarget.hasAVX2()) {
+ Profitable = false;
+ }
+ int WideNumElts = AmtWideElts.size();
// We are only dealing with identical pairs.
- if (SameShifts) {
- // Cast the operand to vXi16.
- SDValue R16 = DAG.getBitcast(VT16, R);
+ if (Profitable && WideNumElts != NumElts) {
+ MVT WideScalarVT = MVT::getIntegerVT(WideEltSizeInBits);
+ MVT WideVT = MVT::getVectorVT(WideScalarVT, WideNumElts);
+ // Cast the operand to vXiM.
+ SDValue RWide = DAG.getBitcast(WideVT, R);
// Create our new vector of shift amounts.
- SDValue Amt16 = getConstVector(AmtBits16, UndefElts16, VT16, DAG, dl);
+ SDValue AmtWide = DAG.getBuildVector(
+ MVT::getVectorVT(NarrowScalarVT, WideNumElts), dl, AmtWideElts);
+ AmtWide = DAG.getZExtOrTrunc(AmtWide, dl, WideVT);
// Perform the actual shift.
unsigned LogicalOpc = Opc == ISD::SRA ? ISD::SRL : Opc;
- SDValue ShiftedR = DAG.getNode(LogicalOpc, dl, VT16, R16, Amt16);
+ SDValue ShiftedR = DAG.getNode(LogicalOpc, dl, WideVT, RWide, AmtWide);
// Now we need to construct a mask which will "drop" bits that get
// shifted past the LSB/MSB. For a logical shift left, it will look
// like:
- // MaskLowBits = (0xff << Amt16) & 0xff;
- // MaskHighBits = MaskLowBits << 8;
- // Mask = MaskLowBits | MaskHighBits;
+ // FullMask = (1 << EltSizeInBits) - 1
+ // Mask = FullMask << Amt
//
- // This masking ensures that bits cannot migrate from one i8 to
+ // This masking ensures that bits cannot migrate from one narrow lane to
// another. The construction of this mask will be constant folded.
// The mask for a logical right shift is nearly identical, the only
- // difference is that 0xff is shifted right instead of left.
- SDValue Cst255 = DAG.getConstant(0xff, dl, MVT::i16);
- SDValue Splat255 = DAG.getSplat(VT16, dl, Cst255);
- // The mask for the low bits is most simply expressed as an 8-bit
- // field of all ones which is shifted in the exact same way the data
- // is shifted but masked with 0xff.
- SDValue MaskLowBits = DAG.getNode(LogicalOpc, dl, VT16, Splat255, Amt16);
- MaskLowBits = DAG.getNode(ISD::AND, dl, VT16, MaskLowBits, Splat255);
- SDValue Cst8 = DAG.getConstant(8, dl, MVT::i16);
- SDValue Splat8 = DAG.getSplat(VT16, dl, Cst8);
- // The mask for the high bits is the same as the mask for the low bits but
- // shifted up by 8.
- SDValue MaskHighBits =
- DAG.getNode(ISD::SHL, dl, VT16, MaskLowBits, Splat8);
- SDValue Mask = DAG.getNode(ISD::OR, dl, VT16, MaskLowBits, MaskHighBits);
+ // difference is that the all ones mask is shifted right instead of left.
+ SDValue CstFullMask = DAG.getAllOnesConstant(dl, NarrowScalarVT);
+ SDValue SplatFullMask = DAG.getSplat(VT, dl, CstFullMask);
+ SDValue Mask = DAG.getNode(LogicalOpc, dl, VT, SplatFullMask, Amt);
+ Mask = DAG.getBitcast(WideVT, Mask);
// Finally, we mask the shifted vector with the SWAR mask.
- SDValue Masked = DAG.getNode(ISD::AND, dl, VT16, ShiftedR, Mask);
+ SDValue Masked = DAG.getNode(ISD::AND, dl, WideVT, ShiftedR, Mask);
Masked = DAG.getBitcast(VT, Masked);
29953
29982
if (Opc != ISD::SRA) {
29954
29983
// Logical shifts are complete at this point.
29955
29984
return Masked;
29956
29985
}
29957
29986
// At this point, we have done a *logical* shift right. We now need to
29958
29987
// sign extend the result so that we get behavior equivalent to an
29959
- // arithmetic shift right. Post-shifting by Amt16 , our i8 elements are
29960
- // `8-Amt16 ` bits wide.
29988
+ // arithmetic shift right. Post-shifting by AmtWide , our narrow elements
29989
+ // are `EltSizeInBits-AmtWide ` bits wide.
29961
29990
//
29962
- // To convert our `8-Amt16 ` bit unsigned numbers to 8-bit signed numbers,
29963
- // we need to replicate the bit at position `7-Amt16` into the MSBs of
29964
- // each i8.
29965
- // We can use the following trick to accomplish this:
29966
- // SignBitMask = 1 << (7-Amt16 )
29991
+ // To convert our `EltSizeInBits-AmtWide ` bit unsigned numbers to signed
29992
+ // numbers as wide as `EltSizeInBits`, we need to replicate the bit at
29993
+ // position `EltSizeInBits-AmtWide` into the MSBs of each narrow lane. We
29994
+ // can use the following trick to accomplish this:
29995
+ // SignBitMask = 1 << (EltSizeInBits-AmtWide-1 )
// (Masked ^ SignBitMask) - SignBitMask
//
// When the sign bit is already clear, this will compute:
@@ -29977,7 +30006,8 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
//
// This is equal to Masked - 2*SignBitMask which will correctly sign
// extend our result.
- SDValue CstHighBit = DAG.getConstant(0x80, dl, MVT::i8);
+ SDValue CstHighBit =
+ DAG.getConstant(1 << (EltSizeInBits - 1), dl, NarrowScalarVT);
SDValue SplatHighBit = DAG.getSplat(VT, dl, CstHighBit);
// This does not induce recursion, all operands are constants.
SDValue SignBitMask = DAG.getNode(LogicalOpc, dl, VT, SplatHighBit, Amt);
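The pair-merging loop in the patch keeps doubling the lane width while every adjacent pair of lanes uses the same constant shift amount, with undef amounts adopting their neighbour's value, and stops at the 32-bit cap. A rough standalone model of just that pairing decision, assuming `-1` stands in for an undef amount; `widenAmounts` is a hypothetical helper written for illustration, not code from this patch.

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

// Sketch of the pair-merging idea: keep doubling the element width as long as
// every adjacent pair of lanes uses the same shift amount. -1 stands in for an
// undef amount, which can adopt its neighbour's value. MaxBits models the
// 32-bit cap in the patch's widening loop.
static unsigned widenAmounts(std::vector<int> &Amts, unsigned EltBits,
                             unsigned MaxBits) {
  while (EltBits < MaxBits && Amts.size() > 1) {
    std::vector<int> Merged(Amts.size() / 2);
    bool SameShifts = true;
    for (std::size_t I = 0; I != Amts.size(); I += 2) {
      int A = Amts[I], B = Amts[I + 1];
      if (A == -1 || B == -1 || A == B) {
        Merged[I / 2] = (A == -1) ? B : A; // undef adopts the other amount
        continue;
      }
      SameShifts = false; // the pair disagrees: stop widening
      break;
    }
    if (!SameShifts)
      break;
    Amts = Merged;
    EltBits *= 2;
  }
  return EltBits;
}

int main() {
  // All pairs agree once, so this widens 8 -> 16 bits; the second pass sees
  // (3, 5) disagree and stops there.
  std::vector<int> Amts = {3, 3, 5, 5, 2, 2, -1, 2};
  unsigned Width = widenAmounts(Amts, 8, 32);
  std::printf("widened to i%u, %zu lanes\n", Width, Amts.size()); // i16, 4 lanes
  return 0;
}
```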
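To make the SWAR comments above concrete: one wide shift moves every narrow lane at once, and a mask built as `FullMask << Amt` (splatted per lane) drops the bits that crossed a lane boundary. Below is a minimal scalar sketch of that idea for two i8 lanes packed into a `uint16_t` sharing one shift amount; `swarShl8x2` is a hypothetical helper, not code from this patch.

```cpp
#include <cstdint>
#include <cstdio>

// Scalar sketch (hypothetical, not LLVM code): shift two i8 lanes packed in a
// uint16_t left by a shared amount using one wide shift plus a lane mask.
static uint16_t swarShl8x2(uint16_t Packed, unsigned Amt) {
  // One wide 16-bit shift moves both lanes at once...
  uint16_t Shifted = static_cast<uint16_t>(Packed << Amt);
  // ...then mask off bits that crossed from the low i8 lane into the high
  // one. This is the "FullMask << Amt" mask, splatted to every lane.
  uint8_t LaneMask = static_cast<uint8_t>(0xFFu << Amt);
  uint16_t Mask = static_cast<uint16_t>(LaneMask * 0x0101u);
  return Shifted & Mask;
}

int main() {
  // Lanes 0xAB (high) and 0xCD (low), each shifted left by 3.
  uint16_t Out = swarShl8x2(0xABCD, 3);
  // Per-lane reference: (0xAB << 3) & 0xFF == 0x58, (0xCD << 3) & 0xFF == 0x68.
  std::printf("%04X\n", static_cast<unsigned>(Out)); // prints 5868
  return 0;
}
```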
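For arithmetic shifts, the lowering does a wide logical right shift, applies the SWAR mask, and then sign-extends each narrow lane with `(Masked ^ SignBitMask) - SignBitMask`; the XOR and SUB are issued on the narrow element type, so borrows cannot cross lanes. A scalar sketch of the same sequence, again assuming two i8 lanes in a `uint16_t` and a hypothetical `swarSra8x2` helper.

```cpp
#include <cstdint>
#include <cstdio>

// Scalar sketch (hypothetical, not LLVM code): arithmetic shift right of two
// i8 lanes packed in a uint16_t, built from one wide logical shift, the SWAR
// mask, and the (x ^ SignBit) - SignBit sign-extension trick.
static uint16_t swarSra8x2(uint16_t Packed, unsigned Amt) {
  // One wide logical shift right moves both lanes at once.
  uint16_t Shifted = static_cast<uint16_t>(Packed >> Amt);
  // Drop bits that crossed from the high i8 lane into the low one; the
  // all-ones lane mask is shifted right, matching the data shift.
  uint8_t LaneMask = static_cast<uint8_t>(0xFFu >> Amt);
  uint16_t Masked = Shifted & static_cast<uint16_t>(LaneMask * 0x0101u);
  // Each lane now holds an (8 - Amt)-bit unsigned value whose sign bit sits
  // at position 7 - Amt. The backend applies the XOR/SUB per i8 lane, so the
  // sign extension is modeled per lane here as well.
  uint8_t SignBit = static_cast<uint8_t>(0x80u >> Amt);
  uint8_t Lo = static_cast<uint8_t>(((Masked & 0xFFu) ^ SignBit) - SignBit);
  uint8_t Hi = static_cast<uint8_t>(((Masked >> 8) ^ SignBit) - SignBit);
  return static_cast<uint16_t>((static_cast<unsigned>(Hi) << 8) | Lo);
}

int main() {
  // High lane 0x90 (negative as i8), low lane 0x70 (positive), both >> 2.
  uint16_t Out = swarSra8x2(0x9070, 2);
  // Per-lane reference: (int8_t)0x90 >> 2 yields 0xE4, (int8_t)0x70 >> 2 yields 0x1C.
  std::printf("%04X\n", static_cast<unsigned>(Out)); // prints E41C
  return 0;
}
```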