@@ -4446,34 +4446,9 @@ static SDValue lowerScalarInsert(SDValue Scalar, SDValue VL, MVT VT,
4446
4446
VL);
4447
4447
}
4448
4448
4449
- // Is this a shuffle extracts either the even or odd elements of a vector?
4450
- // That is, specifically, either (a) or (b) in the options below.
4451
- // Single operand shuffle is easy:
4452
- // a) t35: v8i8 = vector_shuffle<0,2,4,6,u,u,u,u> t34, undef
4453
- // b) t35: v8i8 = vector_shuffle<1,3,5,7,u,u,u,u> t34, undef
4454
- // Double operand shuffle:
4455
- // t34: v8i8 = extract_subvector t11, Constant:i64<0>
4456
- // t33: v8i8 = extract_subvector t11, Constant:i64<8>
4457
- // a) t35: v8i8 = vector_shuffle<0,2,4,6,8,10,12,14> t34, t33
4458
- // b) t35: v8i8 = vector_shuffle<1,3,5,7,9,11,13,15> t34, t33
4459
- static SDValue isDeinterleaveShuffle(MVT VT, MVT ContainerVT, SDValue V1,
4460
- SDValue V2, ArrayRef<int> Mask,
4461
- const RISCVSubtarget &Subtarget) {
4462
- // Need to be able to widen the vector.
4463
- if (VT.getScalarSizeInBits() >= Subtarget.getELen())
4464
- return SDValue();
4465
-
4466
- // First index must be the first even or odd element from V1.
4467
- if (Mask[0] != 0 && Mask[0] != 1)
4468
- return SDValue();
4469
-
4470
- // The others must increase by 2 each time.
4471
- for (unsigned i = 1; i != Mask.size(); ++i)
4472
- if (Mask[i] != -1 && Mask[i] != Mask[0] + (int)i * 2)
4473
- return SDValue();
4474
-
4475
- if (1 == count_if(Mask, [](int Idx) { return Idx != -1; }))
4476
- return SDValue();
4449
+ // Can this shuffle be performed on exactly one (possibly larger) input?
4450
+ static SDValue getSingleShuffleSrc(MVT VT, MVT ContainerVT, SDValue V1,
4451
+ SDValue V2) {
4477
4452
4478
4453
if (V2.isUndef() &&
4479
4454
RISCVTargetLowering::getLMUL(ContainerVT) != RISCVII::VLMUL::LMUL_8)
@@ -4490,12 +4465,13 @@ static SDValue isDeinterleaveShuffle(MVT VT, MVT ContainerVT, SDValue V1,
4490
4465
return SDValue();
4491
4466
4492
4467
// Src needs to have twice the number of elements.
4493
- if (Src.getValueType().getVectorNumElements() != (Mask.size() * 2))
4468
+ unsigned NumElts = VT.getVectorNumElements();
4469
+ if (Src.getValueType().getVectorNumElements() != (NumElts * 2))
4494
4470
return SDValue();
4495
4471
4496
4472
// The extracts must extract the two halves of the source.
4497
4473
if (V1.getConstantOperandVal(1) != 0 ||
4498
- V2.getConstantOperandVal(1) != Mask.size() )
4474
+ V2.getConstantOperandVal(1) != NumElts )
4499
4475
return SDValue();
4500
4476
4501
4477
return Src;
@@ -4612,36 +4588,29 @@ static int isElementRotate(int &LoSrc, int &HiSrc, ArrayRef<int> Mask) {
4612
4588
return Rotation;
4613
4589
}
4614
4590
4615
- // Lower a deinterleave shuffle to vnsrl.
4616
- // [a, p, b, q, c, r, d, s] -> [a, b, c, d] (EvenElts == true)
4617
- // -> [p, q, r, s] (EvenElts == false)
4618
- // VT is the type of the vector to return, <[vscale x ]n x ty>
4619
- // Src is the vector to deinterleave of type <[vscale x ]n*2 x ty>
4620
- static SDValue getDeinterleaveViaVNSRL(const SDLoc &DL, MVT VT, SDValue Src,
4621
- bool EvenElts, SelectionDAG &DAG) {
4622
- // The result is a vector of type <m x n x ty>. The source is a vector of
4623
- // type <m x n*2 x ty> (For the single source case, the high half is undef)
4624
- if (Src.getValueType() == VT) {
4625
- EVT WideVT = VT.getDoubleNumVectorElementsVT();
4626
- Src = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, WideVT, DAG.getUNDEF(WideVT),
4627
- Src, DAG.getVectorIdxConstant(0, DL));
4628
- }
4629
-
4630
- // Bitcast the source vector from <m x n*2 x ty> -> <m x n x ty*2>
4631
- // This also converts FP to int.
4591
+ // Lower a deinterleave shuffle to SRL and TRUNC. Factor must be
4592
+ // 2, 4, 8 and the integer type Factor-times larger than VT's
4593
+ // element type must be a legal element type.
4594
+ // [a, p, b, q, c, r, d, s] -> [a, b, c, d] (Factor=2, Index=0)
4595
+ // -> [p, q, r, s] (Factor=2, Index=1)
4596
+ static SDValue getDeinterleaveShiftAndTrunc(const SDLoc &DL, MVT VT,
4597
+ SDValue Src, unsigned Factor,
4598
+ unsigned Index, SelectionDAG &DAG) {
4632
4599
unsigned EltBits = VT.getScalarSizeInBits();
4633
- MVT WideSrcVT = MVT::getVectorVT(MVT::getIntegerVT(EltBits * 2),
4634
- VT.getVectorElementCount());
4600
+ ElementCount SrcEC = Src.getValueType().getVectorElementCount();
4601
+ MVT WideSrcVT = MVT::getVectorVT(MVT::getIntegerVT(EltBits * Factor),
4602
+ SrcEC.divideCoefficientBy(Factor));
4603
+ MVT ResVT = MVT::getVectorVT(MVT::getIntegerVT(EltBits),
4604
+ SrcEC.divideCoefficientBy(Factor));
4635
4605
Src = DAG.getBitcast(WideSrcVT, Src);
4636
4606
4637
- MVT IntVT = VT.changeVectorElementTypeToInteger();
4638
-
4639
- // If we want even elements, then the shift amount is 0. Otherwise, shift by
4640
- // the original element size.
4641
- unsigned Shift = EvenElts ? 0 : EltBits;
4607
+ unsigned Shift = Index * EltBits;
4642
4608
SDValue Res = DAG.getNode(ISD::SRL, DL, WideSrcVT, Src,
4643
4609
DAG.getConstant(Shift, DL, WideSrcVT));
4644
- Res = DAG.getNode(ISD::TRUNCATE, DL, IntVT, Res);
4610
+ Res = DAG.getNode(ISD::TRUNCATE, DL, ResVT, Res);
4611
+ MVT IntVT = VT.changeVectorElementTypeToInteger();
4612
+ Res = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, IntVT, DAG.getUNDEF(IntVT), Res,
4613
+ DAG.getVectorIdxConstant(0, DL));
4645
4614
return DAG.getBitcast(VT, Res);
4646
4615
}
4647
4616
@@ -5332,11 +5301,24 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG,
5332
5301
if (ShuffleVectorInst::isReverseMask(Mask, NumElts) && V2.isUndef())
5333
5302
return DAG.getNode(ISD::VECTOR_REVERSE, DL, VT, V1);
5334
5303
5335
- // If this is a deinterleave and we can widen the vector, then we can use
5336
- // vnsrl to deinterleave.
5337
- if (SDValue Src =
5338
- isDeinterleaveShuffle(VT, ContainerVT, V1, V2, Mask, Subtarget))
5339
- return getDeinterleaveViaVNSRL(DL, VT, Src, Mask[0] == 0, DAG);
5304
+ // If this is a deinterleave(2,4,8) and we can widen the vector, then we can
5305
+ // use shift and truncate to perform the shuffle.
5306
+ // TODO: For Factor=6, we can perform the first step of the deinterleave via
5307
+ // shift-and-trunc reducing total cost for everything except an mf8 result.
5308
+ // TODO: For Factor=4,8, we can do the same when the ratio isn't high enough
5309
+ // to do the entire operation.
5310
+ if (VT.getScalarSizeInBits() < Subtarget.getELen()) {
5311
+ const unsigned MaxFactor = Subtarget.getELen() / VT.getScalarSizeInBits();
5312
+ assert(MaxFactor == 2 || MaxFactor == 4 || MaxFactor == 8);
5313
+ for (unsigned Factor = 2; Factor <= MaxFactor; Factor <<= 1) {
5314
+ unsigned Index = 0;
5315
+ if (ShuffleVectorInst::isDeInterleaveMaskOfFactor(Mask, Factor, Index) &&
5316
+ 1 < count_if(Mask, [](int Idx) { return Idx != -1; })) {
5317
+ if (SDValue Src = getSingleShuffleSrc(VT, ContainerVT, V1, V2))
5318
+ return getDeinterleaveShiftAndTrunc(DL, VT, Src, Factor, Index, DAG);
5319
+ }
5320
+ }
5321
+ }
5340
5322
5341
5323
if (SDValue V =
5342
5324
lowerVECTOR_SHUFFLEAsVSlideup(DL, VT, V1, V2, Mask, Subtarget, DAG))
@@ -10739,8 +10721,8 @@ SDValue RISCVTargetLowering::lowerVECTOR_DEINTERLEAVE(SDValue Op,
10739
10721
// We can deinterleave through vnsrl.wi if the element type is smaller than
10740
10722
// ELEN
10741
10723
if (VecVT.getScalarSizeInBits() < Subtarget.getELen()) {
10742
- SDValue Even = getDeinterleaveViaVNSRL (DL, VecVT, Concat, true , DAG);
10743
- SDValue Odd = getDeinterleaveViaVNSRL (DL, VecVT, Concat, false , DAG);
10724
+ SDValue Even = getDeinterleaveShiftAndTrunc (DL, VecVT, Concat, 2, 0 , DAG);
10725
+ SDValue Odd = getDeinterleaveShiftAndTrunc (DL, VecVT, Concat, 2, 1 , DAG);
10744
10726
return DAG.getMergeValues({Even, Odd}, DL);
10745
10727
}
10746
10728
0 commit comments