Skip to content

Commit 7818e5a

Browse files
authored
[LoongArch] lower vector shuffle to shift if possible (#132866)
1 parent 807cc37 commit 7818e5a

9 files changed

+354
-3039
lines changed

llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp

Lines changed: 136 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,121 @@ SDValue LoongArchTargetLowering::lowerBITREVERSE(SDValue Op,
525525
}
526526
}
527527

528+
/// Attempts to match a shuffle mask against the VBSLL, VBSRL, VSLLI and VSRLI
529+
/// instruction.
530+
// The funciton matches elements from one of the input vector shuffled to the
531+
// left or right with zeroable elements 'shifted in'. It handles both the
532+
// strictly bit-wise element shifts and the byte shfit across an entire 128-bit
533+
// lane.
534+
// Mostly copied from X86.
535+
static int matchShuffleAsShift(MVT &ShiftVT, unsigned &Opcode,
536+
unsigned ScalarSizeInBits, ArrayRef<int> Mask,
537+
int MaskOffset, const APInt &Zeroable) {
538+
int Size = Mask.size();
539+
unsigned SizeInBits = Size * ScalarSizeInBits;
540+
541+
auto CheckZeros = [&](int Shift, int Scale, bool Left) {
542+
for (int i = 0; i < Size; i += Scale)
543+
for (int j = 0; j < Shift; ++j)
544+
if (!Zeroable[i + j + (Left ? 0 : (Scale - Shift))])
545+
return false;
546+
547+
return true;
548+
};
549+
550+
auto isSequentialOrUndefInRange = [&](unsigned Pos, unsigned Size, int Low,
551+
int Step = 1) {
552+
for (unsigned i = Pos, e = Pos + Size; i != e; ++i, Low += Step)
553+
if (!(Mask[i] == -1 || Mask[i] == Low))
554+
return false;
555+
return true;
556+
};
557+
558+
auto MatchShift = [&](int Shift, int Scale, bool Left) {
559+
for (int i = 0; i != Size; i += Scale) {
560+
unsigned Pos = Left ? i + Shift : i;
561+
unsigned Low = Left ? i : i + Shift;
562+
unsigned Len = Scale - Shift;
563+
if (!isSequentialOrUndefInRange(Pos, Len, Low + MaskOffset))
564+
return -1;
565+
}
566+
567+
int ShiftEltBits = ScalarSizeInBits * Scale;
568+
bool ByteShift = ShiftEltBits > 64;
569+
Opcode = Left ? (ByteShift ? LoongArchISD::VBSLL : LoongArchISD::VSLLI)
570+
: (ByteShift ? LoongArchISD::VBSRL : LoongArchISD::VSRLI);
571+
int ShiftAmt = Shift * ScalarSizeInBits / (ByteShift ? 8 : 1);
572+
573+
// Normalize the scale for byte shifts to still produce an i64 element
574+
// type.
575+
Scale = ByteShift ? Scale / 2 : Scale;
576+
577+
// We need to round trip through the appropriate type for the shift.
578+
MVT ShiftSVT = MVT::getIntegerVT(ScalarSizeInBits * Scale);
579+
ShiftVT = ByteShift ? MVT::getVectorVT(MVT::i8, SizeInBits / 8)
580+
: MVT::getVectorVT(ShiftSVT, Size / Scale);
581+
return (int)ShiftAmt;
582+
};
583+
584+
unsigned MaxWidth = 128;
585+
for (int Scale = 2; Scale * ScalarSizeInBits <= MaxWidth; Scale *= 2)
586+
for (int Shift = 1; Shift != Scale; ++Shift)
587+
for (bool Left : {true, false})
588+
if (CheckZeros(Shift, Scale, Left)) {
589+
int ShiftAmt = MatchShift(Shift, Scale, Left);
590+
if (0 < ShiftAmt)
591+
return ShiftAmt;
592+
}
593+
594+
// no match
595+
return -1;
596+
}
597+
598+
/// Lower VECTOR_SHUFFLE as shift (if possible).
599+
///
600+
/// For example:
601+
/// %2 = shufflevector <4 x i32> %0, <4 x i32> zeroinitializer,
602+
/// <4 x i32> <i32 4, i32 0, i32 1, i32 2>
603+
/// is lowered to:
604+
/// (VBSLL_V $v0, $v0, 4)
605+
///
606+
/// %2 = shufflevector <4 x i32> %0, <4 x i32> zeroinitializer,
607+
/// <4 x i32> <i32 4, i32 0, i32 4, i32 2>
608+
/// is lowered to:
609+
/// (VSLLI_D $v0, $v0, 32)
610+
static SDValue lowerVECTOR_SHUFFLEAsShift(const SDLoc &DL, ArrayRef<int> Mask,
611+
MVT VT, SDValue V1, SDValue V2,
612+
SelectionDAG &DAG,
613+
const APInt &Zeroable) {
614+
int Size = Mask.size();
615+
assert(Size == (int)VT.getVectorNumElements() && "Unexpected mask size");
616+
617+
MVT ShiftVT;
618+
SDValue V = V1;
619+
unsigned Opcode;
620+
621+
// Try to match shuffle against V1 shift.
622+
int ShiftAmt = matchShuffleAsShift(ShiftVT, Opcode, VT.getScalarSizeInBits(),
623+
Mask, 0, Zeroable);
624+
625+
// If V1 failed, try to match shuffle against V2 shift.
626+
if (ShiftAmt < 0) {
627+
ShiftAmt = matchShuffleAsShift(ShiftVT, Opcode, VT.getScalarSizeInBits(),
628+
Mask, Size, Zeroable);
629+
V = V2;
630+
}
631+
632+
if (ShiftAmt < 0)
633+
return SDValue();
634+
635+
assert(DAG.getTargetLoweringInfo().isTypeLegal(ShiftVT) &&
636+
"Illegal integer vector type");
637+
V = DAG.getBitcast(ShiftVT, V);
638+
V = DAG.getNode(Opcode, DL, ShiftVT, V,
639+
DAG.getConstant(ShiftAmt, DL, MVT::i64));
640+
return DAG.getBitcast(VT, V);
641+
}
642+
528643
/// Determine whether a range fits a regular pattern of values.
529644
/// This function accounts for the possibility of jumping over the End iterator.
530645
template <typename ValType>
@@ -593,14 +708,12 @@ static void computeZeroableShuffleElements(ArrayRef<int> Mask, SDValue V1,
593708
static SDValue lowerVECTOR_SHUFFLEAsZeroOrAnyExtend(const SDLoc &DL,
594709
ArrayRef<int> Mask, MVT VT,
595710
SDValue V1, SDValue V2,
596-
SelectionDAG &DAG) {
711+
SelectionDAG &DAG,
712+
const APInt &Zeroable) {
597713
int Bits = VT.getSizeInBits();
598714
int EltBits = VT.getScalarSizeInBits();
599715
int NumElements = VT.getVectorNumElements();
600716

601-
APInt KnownUndef, KnownZero;
602-
computeZeroableShuffleElements(Mask, V1, V2, KnownUndef, KnownZero);
603-
APInt Zeroable = KnownUndef | KnownZero;
604717
if (Zeroable.isAllOnes())
605718
return DAG.getConstant(0, DL, VT);
606719

@@ -1062,6 +1175,10 @@ static SDValue lower128BitShuffle(const SDLoc &DL, ArrayRef<int> Mask, MVT VT,
10621175
"Unexpected mask size for shuffle!");
10631176
assert(Mask.size() % 2 == 0 && "Expected even mask size.");
10641177

1178+
APInt KnownUndef, KnownZero;
1179+
computeZeroableShuffleElements(Mask, V1, V2, KnownUndef, KnownZero);
1180+
APInt Zeroable = KnownUndef | KnownZero;
1181+
10651182
SDValue Result;
10661183
// TODO: Add more comparison patterns.
10671184
if (V2.isUndef()) {
@@ -1089,12 +1206,14 @@ static SDValue lower128BitShuffle(const SDLoc &DL, ArrayRef<int> Mask, MVT VT,
10891206
return Result;
10901207
if ((Result = lowerVECTOR_SHUFFLE_VPICKOD(DL, Mask, VT, V1, V2, DAG)))
10911208
return Result;
1209+
if ((Result = lowerVECTOR_SHUFFLEAsZeroOrAnyExtend(DL, Mask, VT, V1, V2, DAG,
1210+
Zeroable)))
1211+
return Result;
10921212
if ((Result =
1093-
lowerVECTOR_SHUFFLEAsZeroOrAnyExtend(DL, Mask, VT, V1, V2, DAG)))
1213+
lowerVECTOR_SHUFFLEAsShift(DL, Mask, VT, V1, V2, DAG, Zeroable)))
10941214
return Result;
10951215
if ((Result = lowerVECTOR_SHUFFLE_VSHUF(DL, Mask, VT, V1, V2, DAG)))
10961216
return Result;
1097-
10981217
return SDValue();
10991218
}
11001219

@@ -1495,6 +1614,10 @@ static SDValue lower256BitShuffle(const SDLoc &DL, ArrayRef<int> Mask, MVT VT,
14951614
SmallVector<int> NewMask(Mask);
14961615
canonicalizeShuffleVectorByLane(DL, NewMask, VT, V1, V2, DAG);
14971616

1617+
APInt KnownUndef, KnownZero;
1618+
computeZeroableShuffleElements(NewMask, V1, V2, KnownUndef, KnownZero);
1619+
APInt Zeroable = KnownUndef | KnownZero;
1620+
14981621
SDValue Result;
14991622
// TODO: Add more comparison patterns.
15001623
if (V2.isUndef()) {
@@ -1522,6 +1645,9 @@ static SDValue lower256BitShuffle(const SDLoc &DL, ArrayRef<int> Mask, MVT VT,
15221645
return Result;
15231646
if ((Result = lowerVECTOR_SHUFFLE_XVPICKOD(DL, NewMask, VT, V1, V2, DAG)))
15241647
return Result;
1648+
if ((Result =
1649+
lowerVECTOR_SHUFFLEAsShift(DL, NewMask, VT, V1, V2, DAG, Zeroable)))
1650+
return Result;
15251651
if ((Result = lowerVECTOR_SHUFFLE_XVSHUF(DL, NewMask, VT, V1, V2, DAG)))
15261652
return Result;
15271653

@@ -5041,6 +5167,10 @@ const char *LoongArchTargetLowering::getTargetNodeName(unsigned Opcode) const {
50415167
NODE_NAME_CASE(VANY_NONZERO)
50425168
NODE_NAME_CASE(FRECIPE)
50435169
NODE_NAME_CASE(FRSQRTE)
5170+
NODE_NAME_CASE(VSLLI)
5171+
NODE_NAME_CASE(VSRLI)
5172+
NODE_NAME_CASE(VBSLL)
5173+
NODE_NAME_CASE(VBSRL)
50445174
}
50455175
#undef NODE_NAME_CASE
50465176
return nullptr;

llvm/lib/Target/LoongArch/LoongArchISelLowering.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,15 @@ enum NodeType : unsigned {
147147

148148
// Floating point approximate reciprocal operation
149149
FRECIPE,
150-
FRSQRTE
150+
FRSQRTE,
151+
152+
// Vector logical left / right shift by immediate
153+
VSLLI,
154+
VSRLI,
155+
156+
// Vector byte logical left / right shift
157+
VBSLL,
158+
VBSRL
151159

152160
// Intrinsic operations end =============================================
153161
};

llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1187,7 +1187,7 @@ multiclass PatShiftXrXr<SDPatternOperator OpNode, string Inst> {
11871187
(!cast<LAInst>(Inst#"_D") LASX256:$xj, LASX256:$xk)>;
11881188
}
11891189

1190-
multiclass PatShiftXrUimm<SDPatternOperator OpNode, string Inst> {
1190+
multiclass PatShiftXrSplatUimm<SDPatternOperator OpNode, string Inst> {
11911191
def : Pat<(OpNode (v32i8 LASX256:$xj), (v32i8 (SplatPat_uimm3 uimm3:$imm))),
11921192
(!cast<LAInst>(Inst#"_B") LASX256:$xj, uimm3:$imm)>;
11931193
def : Pat<(OpNode (v16i16 LASX256:$xj), (v16i16 (SplatPat_uimm4 uimm4:$imm))),
@@ -1198,6 +1198,17 @@ multiclass PatShiftXrUimm<SDPatternOperator OpNode, string Inst> {
11981198
(!cast<LAInst>(Inst#"_D") LASX256:$xj, uimm6:$imm)>;
11991199
}
12001200

1201+
// Match a LoongArch immediate-shift SD node whose shift amount is a plain
// target-constant operand (unlike PatShiftXrSplatUimm, which matches a splat
// vector amount) against the 256-bit LASX instruction `Inst`, one pattern per
// element width (B/H/W/D).
multiclass PatShiftXrUimm<SDPatternOperator OpNode, string Inst> {
  def : Pat<(OpNode(v32i8 LASX256:$vj), uimm3:$imm),
            (!cast<LAInst>(Inst#"_B") LASX256:$vj, uimm3:$imm)>;
  def : Pat<(OpNode(v16i16 LASX256:$vj), uimm4:$imm),
            (!cast<LAInst>(Inst#"_H") LASX256:$vj, uimm4:$imm)>;
  def : Pat<(OpNode(v8i32 LASX256:$vj), uimm5:$imm),
            (!cast<LAInst>(Inst#"_W") LASX256:$vj, uimm5:$imm)>;
  def : Pat<(OpNode(v4i64 LASX256:$vj), uimm6:$imm),
            (!cast<LAInst>(Inst#"_D") LASX256:$vj, uimm6:$imm)>;
}
1211+
12011212
multiclass PatCCXrSimm5<CondCode CC, string Inst> {
12021213
def : Pat<(v32i8 (setcc (v32i8 LASX256:$xj),
12031214
(v32i8 (SplatPat_simm5 simm5:$imm)), CC)),
@@ -1335,20 +1346,32 @@ def : Pat<(or (v32i8 LASX256:$xj), (v32i8 (SplatPat_uimm8 uimm8:$imm))),
13351346
def : Pat<(xor (v32i8 LASX256:$xj), (v32i8 (SplatPat_uimm8 uimm8:$imm))),
13361347
(XVXORI_B LASX256:$xj, uimm8:$imm)>;
13371348

1349+
// XVBSLL_V: select the byte-shift-left node for every 256-bit vector element
// type (integer and FP alike), so the node matches regardless of how the
// operand was bitcast.
foreach vt = [v32i8, v16i16, v8i32, v4i64, v8f32, v4f64] in
  def : Pat<(loongarch_vbsll(vt LASX256:$xj), uimm5:$imm),
            (XVBSLL_V LASX256:$xj, uimm5:$imm)>;

// XVBSRL_V: likewise for the byte-shift-right node.
foreach vt = [v32i8, v16i16, v8i32, v4i64, v8f32, v4f64] in
  def : Pat<(loongarch_vbsrl(vt LASX256:$xj), uimm5:$imm),
            (XVBSRL_V LASX256:$xj, uimm5:$imm)>;
1358+
13381359
// XVSLL[I]_{B/H/W/D}
13391360
defm : PatXrXr<shl, "XVSLL">;
13401361
defm : PatShiftXrXr<shl, "XVSLL">;
1341-
defm : PatShiftXrUimm<shl, "XVSLLI">;
1362+
defm : PatShiftXrSplatUimm<shl, "XVSLLI">;
1363+
defm : PatShiftXrUimm<loongarch_vslli, "XVSLLI">;
13421364

13431365
// XVSRL[I]_{B/H/W/D}
13441366
defm : PatXrXr<srl, "XVSRL">;
13451367
defm : PatShiftXrXr<srl, "XVSRL">;
1346-
defm : PatShiftXrUimm<srl, "XVSRLI">;
1368+
defm : PatShiftXrSplatUimm<srl, "XVSRLI">;
1369+
defm : PatShiftXrUimm<loongarch_vsrli, "XVSRLI">;
13471370

13481371
// XVSRA[I]_{B/H/W/D}
13491372
defm : PatXrXr<sra, "XVSRA">;
13501373
defm : PatShiftXrXr<sra, "XVSRA">;
1351-
defm : PatShiftXrUimm<sra, "XVSRAI">;
1374+
defm : PatShiftXrSplatUimm<sra, "XVSRAI">;
13521375

13531376
// XVCLZ_{B/H/W/D}
13541377
defm : PatXr<ctlz, "XVCLZ">;

llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,12 @@ def loongarch_vreplgr2vr: SDNode<"LoongArchISD::VREPLGR2VR", SDT_LoongArchVreplg
5858
def loongarch_vfrecipe: SDNode<"LoongArchISD::FRECIPE", SDT_LoongArchVFRECIPE>;
5959
def loongarch_vfrsqrte: SDNode<"LoongArchISD::FRSQRTE", SDT_LoongArchVFRSQRTE>;
6060

61+
// SelectionDAG nodes for vector logical shifts by an immediate (element-wise),
// emitted by matchShuffleAsShift in LoongArchISelLowering.cpp.
def loongarch_vslli : SDNode<"LoongArchISD::VSLLI", SDT_LoongArchV1RUimm>;
def loongarch_vsrli : SDNode<"LoongArchISD::VSRLI", SDT_LoongArchV1RUimm>;

// SelectionDAG nodes for vector byte shifts (VBSLL_V/VBSRL_V-style).
def loongarch_vbsll : SDNode<"LoongArchISD::VBSLL", SDT_LoongArchV1RUimm>;
def loongarch_vbsrl : SDNode<"LoongArchISD::VBSRL", SDT_LoongArchV1RUimm>;
66+
6167
def immZExt1 : ImmLeaf<i64, [{return isUInt<1>(Imm);}]>;
6268
def immZExt2 : ImmLeaf<i64, [{return isUInt<2>(Imm);}]>;
6369
def immZExt3 : ImmLeaf<i64, [{return isUInt<3>(Imm);}]>;
@@ -1346,7 +1352,7 @@ multiclass PatShiftVrVr<SDPatternOperator OpNode, string Inst> {
13461352
(!cast<LAInst>(Inst#"_D") LSX128:$vj, LSX128:$vk)>;
13471353
}
13481354

1349-
multiclass PatShiftVrUimm<SDPatternOperator OpNode, string Inst> {
1355+
multiclass PatShiftVrSplatUimm<SDPatternOperator OpNode, string Inst> {
13501356
def : Pat<(OpNode (v16i8 LSX128:$vj), (v16i8 (SplatPat_uimm3 uimm3:$imm))),
13511357
(!cast<LAInst>(Inst#"_B") LSX128:$vj, uimm3:$imm)>;
13521358
def : Pat<(OpNode (v8i16 LSX128:$vj), (v8i16 (SplatPat_uimm4 uimm4:$imm))),
@@ -1357,6 +1363,17 @@ multiclass PatShiftVrUimm<SDPatternOperator OpNode, string Inst> {
13571363
(!cast<LAInst>(Inst#"_D") LSX128:$vj, uimm6:$imm)>;
13581364
}
13591365

1366+
// Match a LoongArch immediate-shift SD node whose shift amount is a plain
// target-constant operand (unlike PatShiftVrSplatUimm, which matches a splat
// vector amount) against the 128-bit LSX instruction `Inst`, one pattern per
// element width (B/H/W/D).
multiclass PatShiftVrUimm<SDPatternOperator OpNode, string Inst> {
  def : Pat<(OpNode(v16i8 LSX128:$vj), uimm3:$imm),
            (!cast<LAInst>(Inst#"_B") LSX128:$vj, uimm3:$imm)>;
  def : Pat<(OpNode(v8i16 LSX128:$vj), uimm4:$imm),
            (!cast<LAInst>(Inst#"_H") LSX128:$vj, uimm4:$imm)>;
  def : Pat<(OpNode(v4i32 LSX128:$vj), uimm5:$imm),
            (!cast<LAInst>(Inst#"_W") LSX128:$vj, uimm5:$imm)>;
  def : Pat<(OpNode(v2i64 LSX128:$vj), uimm6:$imm),
            (!cast<LAInst>(Inst#"_D") LSX128:$vj, uimm6:$imm)>;
}
1376+
13601377
multiclass PatCCVrSimm5<CondCode CC, string Inst> {
13611378
def : Pat<(v16i8 (setcc (v16i8 LSX128:$vj),
13621379
(v16i8 (SplatPat_simm5 simm5:$imm)), CC)),
@@ -1494,20 +1511,32 @@ def : Pat<(or (v16i8 LSX128:$vj), (v16i8 (SplatPat_uimm8 uimm8:$imm))),
14941511
def : Pat<(xor (v16i8 LSX128:$vj), (v16i8 (SplatPat_uimm8 uimm8:$imm))),
14951512
(VXORI_B LSX128:$vj, uimm8:$imm)>;
14961513

1514+
// VBSLL_V: select the byte-shift-left node for every 128-bit vector element
// type (integer and FP alike), so the node matches regardless of how the
// operand was bitcast.
foreach vt = [v16i8, v8i16, v4i32, v2i64, v4f32, v2f64] in
  def : Pat<(loongarch_vbsll(vt LSX128:$vj), uimm5:$imm),
            (VBSLL_V LSX128:$vj, uimm5:$imm)>;

// VBSRL_V: likewise for the byte-shift-right node.
foreach vt = [v16i8, v8i16, v4i32, v2i64, v4f32, v2f64] in
  def : Pat<(loongarch_vbsrl(vt LSX128:$vj), uimm5:$imm),
            (VBSRL_V LSX128:$vj, uimm5:$imm)>;
1523+
14971524
// VSLL[I]_{B/H/W/D}
14981525
defm : PatVrVr<shl, "VSLL">;
14991526
defm : PatShiftVrVr<shl, "VSLL">;
1500-
defm : PatShiftVrUimm<shl, "VSLLI">;
1527+
defm : PatShiftVrSplatUimm<shl, "VSLLI">;
1528+
defm : PatShiftVrUimm<loongarch_vslli, "VSLLI">;
15011529

15021530
// VSRL[I]_{B/H/W/D}
15031531
defm : PatVrVr<srl, "VSRL">;
15041532
defm : PatShiftVrVr<srl, "VSRL">;
1505-
defm : PatShiftVrUimm<srl, "VSRLI">;
1533+
defm : PatShiftVrSplatUimm<srl, "VSRLI">;
1534+
defm : PatShiftVrUimm<loongarch_vsrli, "VSRLI">;
15061535

15071536
// VSRA[I]_{B/H/W/D}
15081537
defm : PatVrVr<sra, "VSRA">;
15091538
defm : PatShiftVrVr<sra, "VSRA">;
1510-
defm : PatShiftVrUimm<sra, "VSRAI">;
1539+
defm : PatShiftVrSplatUimm<sra, "VSRAI">;
15111540

15121541
// VCLZ_{B/H/W/D}
15131542
defm : PatVr<ctlz, "VCLZ">;

0 commit comments

Comments
 (0)