
Commit 1a8d72d

lower vector shuffle to shift
1 parent 41d718b commit 1a8d72d

6 files changed (+263, -333 lines)

llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp

Lines changed: 129 additions & 6 deletions
@@ -525,6 +525,121 @@ SDValue LoongArchTargetLowering::lowerBITREVERSE(SDValue Op,
   }
 }
 
+/// Attempts to match a shuffle mask against the VBSLL, VBSRL, VSLLI and VSRLI
+/// instructions.
+// The function matches elements from one of the input vectors shuffled to the
+// left or right with zeroable elements 'shifted in'. It handles both the
+// strictly bit-wise element shifts and the byte shift across an entire 128-bit
+// lane.
+// This is mainly copied from X86.
+static int matchShuffleAsShift(MVT &ShiftVT, unsigned &Opcode,
+                               unsigned ScalarSizeInBits, ArrayRef<int> Mask,
+                               int MaskOffset, const APInt &Zeroable) {
+  int Size = Mask.size();
+  unsigned SizeInBits = Size * ScalarSizeInBits;
+
+  auto CheckZeros = [&](int Shift, int Scale, bool Left) {
+    for (int i = 0; i < Size; i += Scale)
+      for (int j = 0; j < Shift; ++j)
+        if (!Zeroable[i + j + (Left ? 0 : (Scale - Shift))])
+          return false;
+
+    return true;
+  };
+
+  auto isSequentialOrUndefInRange = [&](unsigned Pos, unsigned Size, int Low,
+                                        int Step = 1) {
+    for (unsigned i = Pos, e = Pos + Size; i != e; ++i, Low += Step)
+      if (!(Mask[i] == -1 || Mask[i] == Low))
+        return false;
+    return true;
+  };
+
+  auto MatchShift = [&](int Shift, int Scale, bool Left) {
+    for (int i = 0; i != Size; i += Scale) {
+      unsigned Pos = Left ? i + Shift : i;
+      unsigned Low = Left ? i : i + Shift;
+      unsigned Len = Scale - Shift;
+      if (!isSequentialOrUndefInRange(Pos, Len, Low + MaskOffset))
+        return -1;
+    }
+
+    int ShiftEltBits = ScalarSizeInBits * Scale;
+    bool ByteShift = ShiftEltBits > 64;
+    Opcode = Left ? (ByteShift ? LoongArchISD::VBSLL : LoongArchISD::VSLLI)
+                  : (ByteShift ? LoongArchISD::VBSRL : LoongArchISD::VSRLI);
+    int ShiftAmt = Shift * ScalarSizeInBits / (ByteShift ? 8 : 1);
+
+    // Normalize the scale for byte shifts to still produce an i64 element
+    // type.
+    Scale = ByteShift ? Scale / 2 : Scale;
+
+    // We need to round trip through the appropriate type for the shift.
+    MVT ShiftSVT = MVT::getIntegerVT(ScalarSizeInBits * Scale);
+    ShiftVT = ByteShift ? MVT::getVectorVT(MVT::i8, SizeInBits / 8)
+                        : MVT::getVectorVT(ShiftSVT, Size / Scale);
+    return (int)ShiftAmt;
+  };
+
+  unsigned MaxWidth = 128;
+  for (int Scale = 2; Scale * ScalarSizeInBits <= MaxWidth; Scale *= 2)
+    for (int Shift = 1; Shift != Scale; ++Shift)
+      for (bool Left : {true, false})
+        if (CheckZeros(Shift, Scale, Left)) {
+          int ShiftAmt = MatchShift(Shift, Scale, Left);
+          if (0 < ShiftAmt)
+            return ShiftAmt;
+        }
+
+  // no match
+  return -1;
+}
+
+/// Lower VECTOR_SHUFFLE as shift (if possible).
+///
+/// For example:
+///   %2 = shufflevector <4 x i32> %0, <4 x i32> zeroinitializer,
+///                      <4 x i32> <i32 4, i32 0, i32 1, i32 2>
+/// is lowered to:
+///   (VBSLL_V $v0, $v0, 4)
+///
+///   %2 = shufflevector <4 x i32> %0, <4 x i32> zeroinitializer,
+///                      <4 x i32> <i32 4, i32 0, i32 4, i32 2>
+/// is lowered to:
+///   (VSLLI_D $v0, $v0, 32)
+static SDValue lowerVECTOR_SHUFFLEAsShift(const SDLoc &DL, ArrayRef<int> Mask,
+                                          MVT VT, SDValue V1, SDValue V2,
+                                          SelectionDAG &DAG,
+                                          const APInt &Zeroable) {
+  int Size = Mask.size();
+  assert(Size == (int)VT.getVectorNumElements() && "Unexpected mask size");
+
+  MVT ShiftVT;
+  SDValue V = V1;
+  unsigned Opcode;
+
+  // Try to match shuffle against V1 shift.
+  int ShiftAmt = matchShuffleAsShift(ShiftVT, Opcode, VT.getScalarSizeInBits(),
+                                     Mask, 0, Zeroable);
+
+  // If V1 failed, try to match shuffle against V2 shift.
+  if (ShiftAmt < 0) {
+    ShiftAmt = matchShuffleAsShift(ShiftVT, Opcode, VT.getScalarSizeInBits(),
+                                   Mask, Size, Zeroable);
+    V = V2;
+  }
+
+  if (ShiftAmt < 0)
+    return SDValue();
+
+  assert(DAG.getTargetLoweringInfo().isTypeLegal(ShiftVT) &&
+         "Illegal integer vector type");
+  V = DAG.getBitcast(ShiftVT, V);
+  V = DAG.getNode(Opcode, DL, ShiftVT, V,
+                  DAG.getConstant(ShiftAmt, DL, MVT::i64));
+  return DAG.getBitcast(VT, V);
+}
+
 /// Determine whether a range fits a regular pattern of values.
 /// This function accounts for the possibility of jumping over the End iterator.
 template <typename ValType>
@@ -593,14 +708,12 @@ static void computeZeroableShuffleElements(ArrayRef<int> Mask, SDValue V1,
 static SDValue lowerVECTOR_SHUFFLEAsZeroOrAnyExtend(const SDLoc &DL,
                                                     ArrayRef<int> Mask, MVT VT,
                                                     SDValue V1, SDValue V2,
-                                                    SelectionDAG &DAG) {
+                                                    SelectionDAG &DAG,
+                                                    const APInt &Zeroable) {
   int Bits = VT.getSizeInBits();
   int EltBits = VT.getScalarSizeInBits();
   int NumElements = VT.getVectorNumElements();
 
-  APInt KnownUndef, KnownZero;
-  computeZeroableShuffleElements(Mask, V1, V2, KnownUndef, KnownZero);
-  APInt Zeroable = KnownUndef | KnownZero;
   if (Zeroable.isAllOnes())
     return DAG.getConstant(0, DL, VT);
@@ -1062,6 +1175,10 @@ static SDValue lower128BitShuffle(const SDLoc &DL, ArrayRef<int> Mask, MVT VT,
          "Unexpected mask size for shuffle!");
   assert(Mask.size() % 2 == 0 && "Expected even mask size.");
 
+  APInt KnownUndef, KnownZero;
+  computeZeroableShuffleElements(Mask, V1, V2, KnownUndef, KnownZero);
+  APInt Zeroable = KnownUndef | KnownZero;
+
   SDValue Result;
   // TODO: Add more comparison patterns.
   if (V2.isUndef()) {
@@ -1089,12 +1206,14 @@ static SDValue lower128BitShuffle(const SDLoc &DL, ArrayRef<int> Mask, MVT VT,
     return Result;
   if ((Result = lowerVECTOR_SHUFFLE_VPICKOD(DL, Mask, VT, V1, V2, DAG)))
     return Result;
+  if ((Result = lowerVECTOR_SHUFFLEAsZeroOrAnyExtend(DL, Mask, VT, V1, V2, DAG,
+                                                     Zeroable)))
+    return Result;
   if ((Result =
-           lowerVECTOR_SHUFFLEAsZeroOrAnyExtend(DL, Mask, VT, V1, V2, DAG)))
+           lowerVECTOR_SHUFFLEAsShift(DL, Mask, VT, V1, V2, DAG, Zeroable)))
     return Result;
   if ((Result = lowerVECTOR_SHUFFLE_VSHUF(DL, Mask, VT, V1, V2, DAG)))
     return Result;
-
   return SDValue();
 }
@@ -5041,6 +5160,10 @@ const char *LoongArchTargetLowering::getTargetNodeName(unsigned Opcode) const {
     NODE_NAME_CASE(VANY_NONZERO)
     NODE_NAME_CASE(FRECIPE)
     NODE_NAME_CASE(FRSQRTE)
+    NODE_NAME_CASE(VSLLI)
+    NODE_NAME_CASE(VSRLI)
+    NODE_NAME_CASE(VBSLL)
+    NODE_NAME_CASE(VBSRL)
   }
 #undef NODE_NAME_CASE
   return nullptr;
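To see the new path end to end, here is a minimal IR sketch of my own (it mirrors the first example in the lowerVECTOR_SHUFFLEAsShift doc comment and is not one of the commit's tests); run through something like llc --mtriple=loongarch64 --mattr=+lsx, the shuffle should now select a single byte shift instead of a constant-pool vshuf:

; Hypothetical example: the mask pulls one zero element into lane 0 and moves
; the elements of %v up by one 32-bit lane, i.e. by 4 bytes across the whole
; 128-bit register, so matchShuffleAsShift only succeeds at Scale = 4 and
; reports the byte-shift node VBSLL with amount 4.
define <4 x i32> @shuffle_as_vbsll(<4 x i32> %v) {
  %r = shufflevector <4 x i32> %v, <4 x i32> zeroinitializer,
                     <4 x i32> <i32 4, i32 0, i32 1, i32 2>
  ; expected (per the doc comment above): vbsll.v $vr0, $vr0, 4
  ret <4 x i32> %r
}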

llvm/lib/Target/LoongArch/LoongArchISelLowering.h

Lines changed: 9 additions & 1 deletion
@@ -147,7 +147,15 @@ enum NodeType : unsigned {
 
   // Floating point approximate reciprocal operation
   FRECIPE,
-  FRSQRTE
+  FRSQRTE,
+
+  // Vector logical left / right shift by immediate
+  VSLLI,
+  VSRLI,
+
+  // Vector byte logical left / right shift
+  VBSLL,
+  VBSRL
 
   // Intrinsic operations end =============================================
 };
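The new enumerators keep the per-element immediate shifts (VSLLI, VSRLI) apart from the whole-register byte shifts (VBSLL, VBSRL) because the matcher can prove a shift at either granularity. As an illustration of my own, reusing the second example from the doc comment in the .cpp change: when only every other result lane comes from the source, the match already succeeds at Scale = 2, i.e. inside each 64-bit lane, and a per-element shift is enough:

; Hypothetical illustration: lanes 0 and 2 of the result are zero while lanes
; 1 and 3 take elements 0 and 2 of %v, so each i64 lane equals the
; corresponding source lane shifted left by 32 bits; the matcher returns
; VSLLI with ShiftVT = v2i64.
define <4 x i32> @shuffle_as_vslli(<4 x i32> %v) {
  %r = shufflevector <4 x i32> %v, <4 x i32> zeroinitializer,
                     <4 x i32> <i32 4, i32 0, i32 4, i32 2>
  ; expected (per the doc comment above): vslli.d $vr0, $vr0, 32
  ret <4 x i32> %r
}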

llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td

Lines changed: 50 additions & 0 deletions
@@ -58,6 +58,12 @@ def loongarch_vreplgr2vr: SDNode<"LoongArchISD::VREPLGR2VR", SDT_LoongArchVreplg
 def loongarch_vfrecipe: SDNode<"LoongArchISD::FRECIPE", SDT_LoongArchVFRECIPE>;
 def loongarch_vfrsqrte: SDNode<"LoongArchISD::FRSQRTE", SDT_LoongArchVFRSQRTE>;
 
+def loongarch_vslli : SDNode<"LoongArchISD::VSLLI", SDT_LoongArchV1RUimm>;
+def loongarch_vsrli : SDNode<"LoongArchISD::VSRLI", SDT_LoongArchV1RUimm>;
+
+def loongarch_vbsll : SDNode<"LoongArchISD::VBSLL", SDT_LoongArchV1RUimm>;
+def loongarch_vbsrl : SDNode<"LoongArchISD::VBSRL", SDT_LoongArchV1RUimm>;
+
 def immZExt1 : ImmLeaf<i64, [{return isUInt<1>(Imm);}]>;
 def immZExt2 : ImmLeaf<i64, [{return isUInt<2>(Imm);}]>;
 def immZExt3 : ImmLeaf<i64, [{return isUInt<3>(Imm);}]>;
@@ -1494,15 +1500,59 @@ def : Pat<(or (v16i8 LSX128:$vj), (v16i8 (SplatPat_uimm8 uimm8:$imm))),
 def : Pat<(xor (v16i8 LSX128:$vj), (v16i8 (SplatPat_uimm8 uimm8:$imm))),
           (VXORI_B LSX128:$vj, uimm8:$imm)>;
 
+// VBSLL_V
+def : Pat<(loongarch_vbsll v16i8:$vj, uimm5:$imm),
+          (VBSLL_V v16i8:$vj, uimm5:$imm)>;
+def : Pat<(loongarch_vbsll v8i16:$vj, uimm5:$imm),
+          (VBSLL_V v8i16:$vj, uimm5:$imm)>;
+def : Pat<(loongarch_vbsll v4i32:$vj, uimm5:$imm),
+          (VBSLL_V v4i32:$vj, uimm5:$imm)>;
+def : Pat<(loongarch_vbsll v2i64:$vj, uimm5:$imm),
+          (VBSLL_V v2i64:$vj, uimm5:$imm)>;
+def : Pat<(loongarch_vbsll v4f32:$vj, uimm5:$imm),
+          (VBSLL_V v4f32:$vj, uimm5:$imm)>;
+def : Pat<(loongarch_vbsll v2f64:$vj, uimm5:$imm),
+          (VBSLL_V v2f64:$vj, uimm5:$imm)>;
+
+// VBSRL_V
+def : Pat<(loongarch_vbsrl v16i8:$vj, uimm5:$imm),
+          (VBSRL_V v16i8:$vj, uimm5:$imm)>;
+def : Pat<(loongarch_vbsrl v8i16:$vj, uimm5:$imm),
+          (VBSRL_V v8i16:$vj, uimm5:$imm)>;
+def : Pat<(loongarch_vbsrl v4i32:$vj, uimm5:$imm),
+          (VBSRL_V v4i32:$vj, uimm5:$imm)>;
+def : Pat<(loongarch_vbsrl v2i64:$vj, uimm5:$imm),
+          (VBSRL_V v2i64:$vj, uimm5:$imm)>;
+def : Pat<(loongarch_vbsrl v4f32:$vj, uimm5:$imm),
+          (VBSRL_V v4f32:$vj, uimm5:$imm)>;
+def : Pat<(loongarch_vbsrl v2f64:$vj, uimm5:$imm),
+          (VBSRL_V v2f64:$vj, uimm5:$imm)>;
+
 // VSLL[I]_{B/H/W/D}
 defm : PatVrVr<shl, "VSLL">;
 defm : PatShiftVrVr<shl, "VSLL">;
 defm : PatShiftVrUimm<shl, "VSLLI">;
+def : Pat<(loongarch_vslli v16i8:$vj, uimm3:$imm),
+          (VSLLI_B v16i8:$vj, uimm3:$imm)>;
+def : Pat<(loongarch_vslli v8i16:$vj, uimm4:$imm),
+          (VSLLI_H v8i16:$vj, uimm4:$imm)>;
+def : Pat<(loongarch_vslli v4i32:$vj, uimm5:$imm),
+          (VSLLI_W v4i32:$vj, uimm5:$imm)>;
+def : Pat<(loongarch_vslli v2i64:$vj, uimm6:$imm),
+          (VSLLI_D v2i64:$vj, uimm6:$imm)>;
 
 // VSRL[I]_{B/H/W/D}
 defm : PatVrVr<srl, "VSRL">;
 defm : PatShiftVrVr<srl, "VSRL">;
 defm : PatShiftVrUimm<srl, "VSRLI">;
+def : Pat<(loongarch_vsrli v16i8:$vj, uimm3:$imm),
+          (VSRLI_B v16i8:$vj, uimm3:$imm)>;
+def : Pat<(loongarch_vsrli v8i16:$vj, uimm4:$imm),
+          (VSRLI_H v8i16:$vj, uimm4:$imm)>;
+def : Pat<(loongarch_vsrli v4i32:$vj, uimm5:$imm),
+          (VSRLI_W v4i32:$vj, uimm5:$imm)>;
+def : Pat<(loongarch_vsrli v2i64:$vj, uimm6:$imm),
+          (VSRLI_D v2i64:$vj, uimm6:$imm)>;
 
 // VSRA[I]_{B/H/W/D}
 defm : PatVrVr<sra, "VSRA">;
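The patterns also cover the right-shift direction. A sketch of my own, worked out from the matcher logic above rather than taken from the commit's tests: shifting the whole register down by one 32-bit element, with a zero entering at the top, should go through loongarch_vbsrl and select VBSRL_V:

; Hypothetical example: elements 1..3 of %v move down to lanes 0..2 and a
; zero enters lane 3, a 4-byte right shift of the 128-bit register. Scale = 2
; fails CheckZeros here (lane 1 is not zeroable), so the match succeeds only
; at Scale = 4 and the byte-shift node VBSRL is chosen.
define <4 x i32> @shuffle_as_vbsrl(<4 x i32> %v) {
  %r = shufflevector <4 x i32> %v, <4 x i32> zeroinitializer,
                     <4 x i32> <i32 1, i32 2, i32 3, i32 4>
  ; expected: vbsrl.v $vr0, $vr0, 4
  ret <4 x i32> %r
}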

llvm/test/CodeGen/LoongArch/lsx/build-vector.ll

Lines changed: 2 additions & 5 deletions
@@ -374,11 +374,8 @@ define void @extract1_i32_zext_insert0_i64_undef(ptr %src, ptr %dst) nounwind {
 ; CHECK-LABEL: extract1_i32_zext_insert0_i64_undef:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vld $vr0, $a0, 0
-; CHECK-NEXT:    pcalau12i $a0, %pc_hi20(.LCPI24_0)
-; CHECK-NEXT:    vld $vr1, $a0, %pc_lo12(.LCPI24_0)
-; CHECK-NEXT:    vrepli.b $vr2, 0
-; CHECK-NEXT:    vshuf.w $vr1, $vr2, $vr0
-; CHECK-NEXT:    vst $vr1, $a1, 0
+; CHECK-NEXT:    vsrli.d $vr0, $vr0, 32
+; CHECK-NEXT:    vst $vr0, $a1, 0
 ; CHECK-NEXT:    ret
   %v = load volatile <4 x i32>, ptr %src
   %e = extractelement <4 x i32> %v, i32 1
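The saving in this test falls straight out of the new matcher: extracting element 1 of a <4 x i32> and zero-extending it into the low i64 lane leaves that element in bits 0..31 with zeros above, which is exactly a 32-bit logical right shift of the lane. Here is one plausible shuffle form of that computation (my reconstruction, not necessarily the exact DAG the backend builds):

; Hypothetical reconstruction: viewed as <4 x i32>, lane 0 takes element 1 and
; lane 1 is zero; the upper lanes are undef. Positions 1..3 are all zeroable,
; so the matcher fires at Scale = 2 and emits VSRLI with ShiftVT = v2i64,
; i.e. the vsrli.d 32 seen in the updated CHECK lines.
define <2 x i64> @zext_extract_as_vsrli(<4 x i32> %v) {
  %s = shufflevector <4 x i32> %v, <4 x i32> zeroinitializer,
                     <4 x i32> <i32 1, i32 4, i32 undef, i32 undef>
  %r = bitcast <4 x i32> %s to <2 x i64>
  ret <2 x i64> %r
}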
