[LoongArch] lower vector shuffle to shift if possible #132866

Merged · 4 commits · Apr 10, 2025

142 changes: 136 additions & 6 deletions llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
@@ -525,6 +525,121 @@ SDValue LoongArchTargetLowering::lowerBITREVERSE(SDValue Op,
}
}

/// Attempts to match a shuffle mask against the VBSLL, VBSRL, VSLLI and VSRLI
/// instructions.
/// The function matches elements from one of the input vectors shuffled to the
/// left or right with zeroable elements 'shifted in'. It handles both the
/// strictly bit-wise element shifts and the byte shift across an entire
/// 128-bit lane.
/// Mostly copied from X86.
static int matchShuffleAsShift(MVT &ShiftVT, unsigned &Opcode,
unsigned ScalarSizeInBits, ArrayRef<int> Mask,
int MaskOffset, const APInt &Zeroable) {
int Size = Mask.size();
unsigned SizeInBits = Size * ScalarSizeInBits;

auto CheckZeros = [&](int Shift, int Scale, bool Left) {
for (int i = 0; i < Size; i += Scale)
for (int j = 0; j < Shift; ++j)
if (!Zeroable[i + j + (Left ? 0 : (Scale - Shift))])
return false;

return true;
};

auto isSequentialOrUndefInRange = [&](unsigned Pos, unsigned Size, int Low,
int Step = 1) {
for (unsigned i = Pos, e = Pos + Size; i != e; ++i, Low += Step)
if (!(Mask[i] == -1 || Mask[i] == Low))
return false;
return true;
};

auto MatchShift = [&](int Shift, int Scale, bool Left) {
for (int i = 0; i != Size; i += Scale) {
unsigned Pos = Left ? i + Shift : i;
unsigned Low = Left ? i : i + Shift;
unsigned Len = Scale - Shift;
if (!isSequentialOrUndefInRange(Pos, Len, Low + MaskOffset))
return -1;
}

int ShiftEltBits = ScalarSizeInBits * Scale;
bool ByteShift = ShiftEltBits > 64;
Opcode = Left ? (ByteShift ? LoongArchISD::VBSLL : LoongArchISD::VSLLI)
: (ByteShift ? LoongArchISD::VBSRL : LoongArchISD::VSRLI);
int ShiftAmt = Shift * ScalarSizeInBits / (ByteShift ? 8 : 1);

// Normalize the scale for byte shifts to still produce an i64 element
// type.
Scale = ByteShift ? Scale / 2 : Scale;

// We need to round trip through the appropriate type for the shift.
MVT ShiftSVT = MVT::getIntegerVT(ScalarSizeInBits * Scale);
ShiftVT = ByteShift ? MVT::getVectorVT(MVT::i8, SizeInBits / 8)
: MVT::getVectorVT(ShiftSVT, Size / Scale);
return (int)ShiftAmt;
};

unsigned MaxWidth = 128;
for (int Scale = 2; Scale * ScalarSizeInBits <= MaxWidth; Scale *= 2)
for (int Shift = 1; Shift != Scale; ++Shift)
for (bool Left : {true, false})
if (CheckZeros(Shift, Scale, Left)) {
int ShiftAmt = MatchShift(Shift, Scale, Left);
if (0 < ShiftAmt)
return ShiftAmt;
}

// No match.
return -1;
}
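
To make the search concrete, the sketch below re-implements the same three nested loops and runs them on the two masks documented for lowerVECTOR_SHUFFLEAsShift just after this. It is illustrative only: the toy types (plain std::vector in place of ArrayRef/APInt) and the matchShift helper are assumptions, not the LLVM API.

#include <cstdio>
#include <vector>

// Toy version of matchShuffleAsShift: mask entries < Size pick lanes from V1,
// entries >= Size pick lanes from V2, -1 is undef; Zeroable marks result
// lanes already known to be zero (e.g. taken from a zeroinitializer operand).
static int matchShift(int EltBits, const std::vector<int> &Mask,
                      const std::vector<bool> &Zeroable, bool &ByteShift,
                      bool &Left) {
  const int Size = (int)Mask.size();
  for (int Scale = 2; Scale * EltBits <= 128; Scale *= 2)
    for (int Shift = 1; Shift != Scale; ++Shift)
      for (bool L : {true, false}) {
        bool Ok = true;
        // The lanes 'shifted in' must be zeroable...
        for (int i = 0; Ok && i < Size; i += Scale)
          for (int j = 0; Ok && j < Shift; ++j)
            Ok = Zeroable[i + j + (L ? 0 : Scale - Shift)];
        // ...and the surviving lanes must be a sequential run from V1.
        for (int i = 0; Ok && i != Size; i += Scale) {
          int Pos = L ? i + Shift : i, Low = L ? i : i + Shift;
          for (int k = 0; Ok && k != Scale - Shift; ++k)
            Ok = Mask[Pos + k] == -1 || Mask[Pos + k] == Low + k;
        }
        if (Ok) {
          ByteShift = Scale * EltBits > 64; // matched a whole-lane byte shift
          Left = L;
          return Shift * EltBits / (ByteShift ? 8 : 1);
        }
      }
  return -1; // no match
}

int main() {
  bool Byte, Left;
  // <4,0,1,2> on v4i32 with V2 == zeroinitializer: only lane 0 is zeroable.
  int A = matchShift(32, {4, 0, 1, 2}, {true, false, false, false}, Byte, Left);
  std::printf("amt=%d byte=%d left=%d\n", A, Byte, Left); // amt=4 byte=1 left=1
  // <4,0,4,2>: lanes 0 and 2 are zeroable.
  int B = matchShift(32, {4, 0, 4, 2}, {true, false, true, false}, Byte, Left);
  std::printf("amt=%d byte=%d left=%d\n", B, Byte, Left); // amt=32 byte=0 left=1
}

The first mask only matches at Scale == 4 (a full 128 bits), so it is lowered as a byte shift; the second already matches at Scale == 2 (64 bits) and stays a bit-wise element shift.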

/// Lower VECTOR_SHUFFLE as shift (if possible).
///
/// For example:
/// %2 = shufflevector <4 x i32> %0, <4 x i32> zeroinitializer,
/// <4 x i32> <i32 4, i32 0, i32 1, i32 2>
/// is lowered to:
/// (VBSLL_V $v0, $v0, 4)
///
/// %2 = shufflevector <4 x i32> %0, <4 x i32> zeroinitializer,
/// <4 x i32> <i32 4, i32 0, i32 4, i32 2>
/// is lowered to:
/// (VSLLI_D $v0, $v0, 32)
static SDValue lowerVECTOR_SHUFFLEAsShift(const SDLoc &DL, ArrayRef<int> Mask,
MVT VT, SDValue V1, SDValue V2,
SelectionDAG &DAG,
const APInt &Zeroable) {
int Size = Mask.size();
assert(Size == (int)VT.getVectorNumElements() && "Unexpected mask size");

MVT ShiftVT;
SDValue V = V1;
unsigned Opcode;

// Try to match shuffle against V1 shift.
int ShiftAmt = matchShuffleAsShift(ShiftVT, Opcode, VT.getScalarSizeInBits(),
Mask, 0, Zeroable);

// If V1 failed, try to match shuffle against V2 shift.
if (ShiftAmt < 0) {
ShiftAmt = matchShuffleAsShift(ShiftVT, Opcode, VT.getScalarSizeInBits(),
Mask, Size, Zeroable);
V = V2;
}

if (ShiftAmt < 0)
return SDValue();

assert(DAG.getTargetLoweringInfo().isTypeLegal(ShiftVT) &&
"Illegal integer vector type");
V = DAG.getBitcast(ShiftVT, V);
V = DAG.getNode(Opcode, DL, ShiftVT, V,
DAG.getConstant(ShiftAmt, DL, MVT::i64));
return DAG.getBitcast(VT, V);
}
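
As a sanity check on the semantics, this little-endian simulation of one 128-bit lane reproduces the two lowerings documented above (the vbsll/vslli_d helpers and sample values are illustrative assumptions, not LoongArch intrinsics or LLVM code):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Byte shift of a 128-bit lane toward higher addresses (the VBSLL_V effect).
static void vbsll(uint8_t v[16], int Bytes) {
  for (int i = 15; i >= 0; --i)
    v[i] = i >= Bytes ? v[i - Bytes] : 0;
}

// Independent left shift of each 64-bit element (the VSLLI_D effect).
static void vslli_d(uint8_t v[16], int Bits) {
  uint64_t d[2];
  std::memcpy(d, v, 16);
  d[0] <<= Bits;
  d[1] <<= Bits;
  std::memcpy(v, d, 16);
}

int main() {
  const uint32_t e[4] = {0x11111111, 0x22222222, 0x33333333, 0x44444444};
  uint8_t lane[16];
  uint32_t r[4];

  std::memcpy(lane, e, 16);
  vbsll(lane, 4); // shuffle <4,0,1,2> == <0, e0, e1, e2>
  std::memcpy(r, lane, 16);
  std::printf("%08x %08x %08x %08x\n", r[0], r[1], r[2], r[3]);
  // prints: 00000000 11111111 22222222 33333333

  std::memcpy(lane, e, 16);
  vslli_d(lane, 32); // shuffle <4,0,4,2> == <0, e0, 0, e2>
  std::memcpy(r, lane, 16);
  std::printf("%08x %08x %08x %08x\n", r[0], r[1], r[2], r[3]);
  // prints: 00000000 11111111 00000000 33333333
}

Both outputs line up with the shuffle masks in the comment above, which is exactly the equivalence matchShuffleAsShift relies on.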

/// Determine whether a range fits a regular pattern of values.
/// This function accounts for the possibility of jumping over the End iterator.
template <typename ValType>
@@ -593,14 +708,12 @@ static void computeZeroableShuffleElements(ArrayRef<int> Mask, SDValue V1,
static SDValue lowerVECTOR_SHUFFLEAsZeroOrAnyExtend(const SDLoc &DL,
ArrayRef<int> Mask, MVT VT,
SDValue V1, SDValue V2,
SelectionDAG &DAG) {
SelectionDAG &DAG,
const APInt &Zeroable) {
int Bits = VT.getSizeInBits();
int EltBits = VT.getScalarSizeInBits();
int NumElements = VT.getVectorNumElements();

APInt KnownUndef, KnownZero;
computeZeroableShuffleElements(Mask, V1, V2, KnownUndef, KnownZero);
APInt Zeroable = KnownUndef | KnownZero;
if (Zeroable.isAllOnes())
return DAG.getConstant(0, DL, VT);

@@ -1062,6 +1175,10 @@ static SDValue lower128BitShuffle(const SDLoc &DL, ArrayRef<int> Mask, MVT VT,
"Unexpected mask size for shuffle!");
assert(Mask.size() % 2 == 0 && "Expected even mask size.");

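// Precompute the result lanes that are known to be zero (undef, or sourced
// from a zero operand); the lowerings below 'shift in' exactly these lanes.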
APInt KnownUndef, KnownZero;
computeZeroableShuffleElements(Mask, V1, V2, KnownUndef, KnownZero);
APInt Zeroable = KnownUndef | KnownZero;

SDValue Result;
// TODO: Add more comparison patterns.
if (V2.isUndef()) {
@@ -1089,12 +1206,14 @@ static SDValue lower128BitShuffle(const SDLoc &DL, ArrayRef<int> Mask, MVT VT,
return Result;
if ((Result = lowerVECTOR_SHUFFLE_VPICKOD(DL, Mask, VT, V1, V2, DAG)))
return Result;
if ((Result = lowerVECTOR_SHUFFLEAsZeroOrAnyExtend(DL, Mask, VT, V1, V2, DAG,
Zeroable)))
return Result;
if ((Result =
lowerVECTOR_SHUFFLEAsZeroOrAnyExtend(DL, Mask, VT, V1, V2, DAG)))
lowerVECTOR_SHUFFLEAsShift(DL, Mask, VT, V1, V2, DAG, Zeroable)))
return Result;
if ((Result = lowerVECTOR_SHUFFLE_VSHUF(DL, Mask, VT, V1, V2, DAG)))
return Result;

return SDValue();
}

@@ -1495,6 +1614,10 @@ static SDValue lower256BitShuffle(const SDLoc &DL, ArrayRef<int> Mask, MVT VT,
SmallVector<int> NewMask(Mask);
canonicalizeShuffleVectorByLane(DL, NewMask, VT, V1, V2, DAG);

APInt KnownUndef, KnownZero;
computeZeroableShuffleElements(NewMask, V1, V2, KnownUndef, KnownZero);
APInt Zeroable = KnownUndef | KnownZero;

SDValue Result;
// TODO: Add more comparison patterns.
if (V2.isUndef()) {
@@ -1522,6 +1645,9 @@ static SDValue lower256BitShuffle(const SDLoc &DL, ArrayRef<int> Mask, MVT VT,
return Result;
if ((Result = lowerVECTOR_SHUFFLE_XVPICKOD(DL, NewMask, VT, V1, V2, DAG)))
return Result;
if ((Result =
lowerVECTOR_SHUFFLEAsShift(DL, NewMask, VT, V1, V2, DAG, Zeroable)))
return Result;
if ((Result = lowerVECTOR_SHUFFLE_XVSHUF(DL, NewMask, VT, V1, V2, DAG)))
return Result;

@@ -5041,6 +5167,10 @@ const char *LoongArchTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(VANY_NONZERO)
NODE_NAME_CASE(FRECIPE)
NODE_NAME_CASE(FRSQRTE)
NODE_NAME_CASE(VSLLI)
NODE_NAME_CASE(VSRLI)
NODE_NAME_CASE(VBSLL)
NODE_NAME_CASE(VBSRL)
}
#undef NODE_NAME_CASE
return nullptr;
10 changes: 9 additions & 1 deletion llvm/lib/Target/LoongArch/LoongArchISelLowering.h
@@ -147,7 +147,15 @@ enum NodeType : unsigned {

// Floating point approximate reciprocal operation
FRECIPE,
FRSQRTE
FRSQRTE,

// Vector logical left / right shift by immediate
VSLLI,
VSRLI,

// Vector byte logical left / right shift
VBSLL,
VBSRL

// Intrinsic operations end =============================================
};
31 changes: 27 additions & 4 deletions llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
@@ -1187,7 +1187,7 @@ multiclass PatShiftXrXr<SDPatternOperator OpNode, string Inst> {
(!cast<LAInst>(Inst#"_D") LASX256:$xj, LASX256:$xk)>;
}

multiclass PatShiftXrUimm<SDPatternOperator OpNode, string Inst> {
multiclass PatShiftXrSplatUimm<SDPatternOperator OpNode, string Inst> {
def : Pat<(OpNode (v32i8 LASX256:$xj), (v32i8 (SplatPat_uimm3 uimm3:$imm))),
(!cast<LAInst>(Inst#"_B") LASX256:$xj, uimm3:$imm)>;
def : Pat<(OpNode (v16i16 LASX256:$xj), (v16i16 (SplatPat_uimm4 uimm4:$imm))),
@@ -1198,6 +1198,17 @@ multiclass PatShiftXrUimm<SDPatternOperator OpNode, string Inst> {
(!cast<LAInst>(Inst#"_D") LASX256:$xj, uimm6:$imm)>;
}

multiclass PatShiftXrUimm<SDPatternOperator OpNode, string Inst> {
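// Unlike PatShiftXrSplatUimm above, these patterns match the LoongArchISD
// shift nodes, whose shift amount is a plain scalar immediate rather than a
// splat vector.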
def : Pat<(OpNode (v32i8 LASX256:$xj), uimm3:$imm),
(!cast<LAInst>(Inst#"_B") LASX256:$xj, uimm3:$imm)>;
def : Pat<(OpNode (v16i16 LASX256:$xj), uimm4:$imm),
(!cast<LAInst>(Inst#"_H") LASX256:$xj, uimm4:$imm)>;
def : Pat<(OpNode (v8i32 LASX256:$xj), uimm5:$imm),
(!cast<LAInst>(Inst#"_W") LASX256:$xj, uimm5:$imm)>;
def : Pat<(OpNode (v4i64 LASX256:$xj), uimm6:$imm),
(!cast<LAInst>(Inst#"_D") LASX256:$xj, uimm6:$imm)>;
}

multiclass PatCCXrSimm5<CondCode CC, string Inst> {
def : Pat<(v32i8 (setcc (v32i8 LASX256:$xj),
(v32i8 (SplatPat_simm5 simm5:$imm)), CC)),
@@ -1335,20 +1346,32 @@ def : Pat<(or (v32i8 LASX256:$xj), (v32i8 (SplatPat_uimm8 uimm8:$imm))),
def : Pat<(xor (v32i8 LASX256:$xj), (v32i8 (SplatPat_uimm8 uimm8:$imm))),
(XVXORI_B LASX256:$xj, uimm8:$imm)>;

// XVBSLL_V
foreach vt = [v32i8, v16i16, v8i32, v4i64, v8f32, v4f64] in
def : Pat<(loongarch_vbsll (vt LASX256:$xj), uimm5:$imm),
(XVBSLL_V LASX256:$xj, uimm5:$imm)>;

// XVBSRL_V
foreach vt = [v32i8, v16i16, v8i32, v4i64, v8f32, v4f64] in
def : Pat<(loongarch_vbsrl (vt LASX256:$xj), uimm5:$imm),
(XVBSRL_V LASX256:$xj, uimm5:$imm)>;

// XVSLL[I]_{B/H/W/D}
defm : PatXrXr<shl, "XVSLL">;
defm : PatShiftXrXr<shl, "XVSLL">;
defm : PatShiftXrUimm<shl, "XVSLLI">;
defm : PatShiftXrSplatUimm<shl, "XVSLLI">;
defm : PatShiftXrUimm<loongarch_vslli, "XVSLLI">;

// XVSRL[I]_{B/H/W/D}
defm : PatXrXr<srl, "XVSRL">;
defm : PatShiftXrXr<srl, "XVSRL">;
defm : PatShiftXrUimm<srl, "XVSRLI">;
defm : PatShiftXrSplatUimm<srl, "XVSRLI">;
defm : PatShiftXrUimm<loongarch_vsrli, "XVSRLI">;

// XVSRA[I]_{B/H/W/D}
defm : PatXrXr<sra, "XVSRA">;
defm : PatShiftXrXr<sra, "XVSRA">;
defm : PatShiftXrUimm<sra, "XVSRAI">;
defm : PatShiftXrSplatUimm<sra, "XVSRAI">;

// XVCLZ_{B/H/W/D}
defm : PatXr<ctlz, "XVCLZ">;
37 changes: 33 additions & 4 deletions llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
@@ -58,6 +58,12 @@ def loongarch_vreplgr2vr: SDNode<"LoongArchISD::VREPLGR2VR", SDT_LoongArchVreplg
def loongarch_vfrecipe: SDNode<"LoongArchISD::FRECIPE", SDT_LoongArchVFRECIPE>;
def loongarch_vfrsqrte: SDNode<"LoongArchISD::FRSQRTE", SDT_LoongArchVFRSQRTE>;

def loongarch_vslli : SDNode<"LoongArchISD::VSLLI", SDT_LoongArchV1RUimm>;
def loongarch_vsrli : SDNode<"LoongArchISD::VSRLI", SDT_LoongArchV1RUimm>;

def loongarch_vbsll : SDNode<"LoongArchISD::VBSLL", SDT_LoongArchV1RUimm>;
def loongarch_vbsrl : SDNode<"LoongArchISD::VBSRL", SDT_LoongArchV1RUimm>;

def immZExt1 : ImmLeaf<i64, [{return isUInt<1>(Imm);}]>;
def immZExt2 : ImmLeaf<i64, [{return isUInt<2>(Imm);}]>;
def immZExt3 : ImmLeaf<i64, [{return isUInt<3>(Imm);}]>;
@@ -1346,7 +1352,7 @@ multiclass PatShiftVrVr<SDPatternOperator OpNode, string Inst> {
(!cast<LAInst>(Inst#"_D") LSX128:$vj, LSX128:$vk)>;
}

multiclass PatShiftVrUimm<SDPatternOperator OpNode, string Inst> {
multiclass PatShiftVrSplatUimm<SDPatternOperator OpNode, string Inst> {
def : Pat<(OpNode (v16i8 LSX128:$vj), (v16i8 (SplatPat_uimm3 uimm3:$imm))),
(!cast<LAInst>(Inst#"_B") LSX128:$vj, uimm3:$imm)>;
def : Pat<(OpNode (v8i16 LSX128:$vj), (v8i16 (SplatPat_uimm4 uimm4:$imm))),
@@ -1357,6 +1363,17 @@ multiclass PatShiftVrUimm<SDPatternOperator OpNode, string Inst> {
(!cast<LAInst>(Inst#"_D") LSX128:$vj, uimm6:$imm)>;
}

multiclass PatShiftVrUimm<SDPatternOperator OpNode, string Inst> {
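// Unlike PatShiftVrSplatUimm above, these patterns match the LoongArchISD
// shift nodes, whose shift amount is a plain scalar immediate rather than a
// splat vector.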
def : Pat<(OpNode (v16i8 LSX128:$vj), uimm3:$imm),
(!cast<LAInst>(Inst#"_B") LSX128:$vj, uimm3:$imm)>;
def : Pat<(OpNode (v8i16 LSX128:$vj), uimm4:$imm),
(!cast<LAInst>(Inst#"_H") LSX128:$vj, uimm4:$imm)>;
def : Pat<(OpNode (v4i32 LSX128:$vj), uimm5:$imm),
(!cast<LAInst>(Inst#"_W") LSX128:$vj, uimm5:$imm)>;
def : Pat<(OpNode (v2i64 LSX128:$vj), uimm6:$imm),
(!cast<LAInst>(Inst#"_D") LSX128:$vj, uimm6:$imm)>;
}

multiclass PatCCVrSimm5<CondCode CC, string Inst> {
def : Pat<(v16i8 (setcc (v16i8 LSX128:$vj),
(v16i8 (SplatPat_simm5 simm5:$imm)), CC)),
@@ -1494,20 +1511,32 @@ def : Pat<(or (v16i8 LSX128:$vj), (v16i8 (SplatPat_uimm8 uimm8:$imm))),
def : Pat<(xor (v16i8 LSX128:$vj), (v16i8 (SplatPat_uimm8 uimm8:$imm))),
(VXORI_B LSX128:$vj, uimm8:$imm)>;

// VBSLL_V
foreach vt = [v16i8, v8i16, v4i32, v2i64, v4f32, v2f64] in
def : Pat<(loongarch_vbsll (vt LSX128:$vj), uimm5:$imm),
(VBSLL_V LSX128:$vj, uimm5:$imm)>;

// VBSRL_V
foreach vt = [v16i8, v8i16, v4i32, v2i64, v4f32, v2f64] in
def : Pat<(loongarch_vbsrl (vt LSX128:$vj), uimm5:$imm),
(VBSRL_V LSX128:$vj, uimm5:$imm)>;

// VSLL[I]_{B/H/W/D}
defm : PatVrVr<shl, "VSLL">;
defm : PatShiftVrVr<shl, "VSLL">;
defm : PatShiftVrUimm<shl, "VSLLI">;
defm : PatShiftVrSplatUimm<shl, "VSLLI">;
defm : PatShiftVrUimm<loongarch_vslli, "VSLLI">;

// VSRL[I]_{B/H/W/D}
defm : PatVrVr<srl, "VSRL">;
defm : PatShiftVrVr<srl, "VSRL">;
defm : PatShiftVrUimm<srl, "VSRLI">;
defm : PatShiftVrSplatUimm<srl, "VSRLI">;
defm : PatShiftVrUimm<loongarch_vsrli, "VSRLI">;

// VSRA[I]_{B/H/W/D}
defm : PatVrVr<sra, "VSRA">;
defm : PatShiftVrVr<sra, "VSRA">;
defm : PatShiftVrUimm<sra, "VSRAI">;
defm : PatShiftVrSplatUimm<sra, "VSRAI">;

// VCLZ_{B/H/W/D}
defm : PatVr<ctlz, "VCLZ">;