Skip to content

Commit 25e4333

Browse files
authored
[RISCV] Lower shuffle which splats a single span (without exact VLEN) (#127108)
If we have a shuffle which repeats the same pattern of elements in every span, and all of those elements come from the first register in the source register group, we can lower it to a single vrgather at m1 that performs the element rearrangement, and then reuse that result for each register in the result vector register group.
1 parent 625cb5a commit 25e4333

File tree

2 files changed

+70
-16
lines changed

2 files changed

+70
-16
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5352,8 +5352,24 @@ static bool isLocalRepeatingShuffle(ArrayRef<int> Mask, int Span) {
53525352

53535353
/// Is this mask only using elements from the first span of the input?
53545354
static bool isLowSourceShuffle(ArrayRef<int> Mask, int Span) {
5355-
return all_of(Mask,
5356-
[&](const auto &Idx) { return Idx == -1 || Idx < Span; });
5355+
return all_of(Mask, [&](const auto &Idx) { return Idx == -1 || Idx < Span; });
5356+
}
5357+
5358+
/// Return true for a mask which performs an arbitrary shuffle within the first
5359+
/// span, and then repeats that same result across all remaining spans. Note
5360+
/// that this doesn't check if all the inputs come from a single span!
5361+
static bool isSpanSplatShuffle(ArrayRef<int> Mask, int Span) {
5362+
SmallVector<int> LowSpan(Span, -1);
5363+
for (auto [I, M] : enumerate(Mask)) {
5364+
if (M == -1)
5365+
continue;
5366+
int SpanIdx = I % Span;
5367+
if (LowSpan[SpanIdx] == -1)
5368+
LowSpan[SpanIdx] = M;
5369+
if (LowSpan[SpanIdx] != M)
5370+
return false;
5371+
}
5372+
return true;
53575373
}
53585374

53595375
/// Try to widen element type to get a new mask value for a better permutation
@@ -5771,6 +5787,35 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG,
57715787
Gather = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ContainerVT, Gather,
57725788
SubVec, SubIdx);
57735789
}
5790+
} else if (NumElts > MinVLMAX && isLowSourceShuffle(Mask, MinVLMAX) &&
5791+
isSpanSplatShuffle(Mask, MinVLMAX)) {
5792+
// If we have a shuffle which only uses the first register in our source
5793+
// register group, and repeats the same index across all spans, we can
5794+
// use a single vrgather (and possibly some register moves).
5795+
// TODO: This can be generalized for m2 or m4, or for any shuffle for
5796+
// which we can do a linear number of shuffles to form an m1 which
5797+
// contains all the output elements.
5798+
const MVT M1VT = getLMUL1VT(ContainerVT);
5799+
EVT SubIndexVT = M1VT.changeVectorElementType(IndexVT.getScalarType());
5800+
auto [InnerTrueMask, InnerVL] =
5801+
getDefaultScalableVLOps(M1VT, DL, DAG, Subtarget);
5802+
int N = ContainerVT.getVectorMinNumElements() /
5803+
M1VT.getVectorMinNumElements();
5804+
assert(isPowerOf2_32(N) && N <= 8);
5805+
SDValue SubV1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, M1VT, V1,
5806+
DAG.getVectorIdxConstant(0, DL));
5807+
SDValue SubIndex =
5808+
DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubIndexVT, LHSIndices,
5809+
DAG.getVectorIdxConstant(0, DL));
5810+
SDValue SubVec = DAG.getNode(GatherVVOpc, DL, M1VT, SubV1, SubIndex,
5811+
DAG.getUNDEF(M1VT), InnerTrueMask, InnerVL);
5812+
Gather = DAG.getUNDEF(ContainerVT);
5813+
for (int i = 0; i < N; i++) {
5814+
SDValue SubIdx =
5815+
DAG.getVectorIdxConstant(M1VT.getVectorMinNumElements() * i, DL);
5816+
Gather = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ContainerVT, Gather,
5817+
SubVec, SubIdx);
5818+
}
57745819
} else if (NumElts > MinVLMAX && isLowSourceShuffle(Mask, MinVLMAX)) {
57755820
// If we have a shuffle which only uses the first register in our
57765821
// source register group, we can do a linear number of m1 vrgathers

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-shuffles.ll

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1311,22 +1311,14 @@ define void @shuffle_i128_splat(ptr %p) nounwind {
13111311
; CHECK: # %bb.0:
13121312
; CHECK-NEXT: vsetivli zero, 8, e64, m4, ta, ma
13131313
; CHECK-NEXT: vle64.v v8, (a0)
1314-
; CHECK-NEXT: csrr a1, vlenb
1315-
; CHECK-NEXT: lui a2, 16
1316-
; CHECK-NEXT: srli a1, a1, 3
1314+
; CHECK-NEXT: lui a1, 16
13171315
; CHECK-NEXT: vsetivli zero, 4, e32, m1, ta, ma
1318-
; CHECK-NEXT: vmv.v.x v9, a2
1319-
; CHECK-NEXT: vsetivli zero, 8, e16, m1, ta, ma
1320-
; CHECK-NEXT: vslidedown.vx v10, v9, a1
1321-
; CHECK-NEXT: vslidedown.vx v11, v10, a1
1322-
; CHECK-NEXT: vsetvli a2, zero, e64, m1, ta, ma
1323-
; CHECK-NEXT: vrgatherei16.vv v13, v8, v10
1324-
; CHECK-NEXT: vrgatherei16.vv v12, v8, v9
1325-
; CHECK-NEXT: vrgatherei16.vv v14, v8, v11
1326-
; CHECK-NEXT: vsetivli zero, 8, e16, m1, ta, ma
1327-
; CHECK-NEXT: vslidedown.vx v9, v11, a1
1316+
; CHECK-NEXT: vmv.v.x v9, a1
13281317
; CHECK-NEXT: vsetvli a1, zero, e64, m1, ta, ma
1329-
; CHECK-NEXT: vrgatherei16.vv v15, v8, v9
1318+
; CHECK-NEXT: vrgatherei16.vv v12, v8, v9
1319+
; CHECK-NEXT: vmv.v.v v13, v12
1320+
; CHECK-NEXT: vmv.v.v v14, v12
1321+
; CHECK-NEXT: vmv.v.v v15, v12
13301322
; CHECK-NEXT: vsetivli zero, 8, e64, m4, ta, ma
13311323
; CHECK-NEXT: vse64.v v12, (a0)
13321324
; CHECK-NEXT: ret
@@ -1435,3 +1427,20 @@ define <4 x i16> @vmerge_3(<4 x i16> %x) {
14351427
%s = shufflevector <4 x i16> %x, <4 x i16> <i16 poison, i16 5, i16 poison, i16 poison>, <4 x i32> <i32 0, i32 5, i32 5, i32 3>
14361428
ret <4 x i16> %s
14371429
}
1430+
1431+
1432+
define <8 x i64> @shuffle_v8i164_span_splat(<8 x i64> %a) nounwind {
1433+
; CHECK-LABEL: shuffle_v8i164_span_splat:
1434+
; CHECK: # %bb.0:
1435+
; CHECK-NEXT: vsetivli zero, 4, e32, m1, ta, ma
1436+
; CHECK-NEXT: vmv.v.i v9, 1
1437+
; CHECK-NEXT: vsetvli a0, zero, e64, m1, ta, ma
1438+
; CHECK-NEXT: vrgatherei16.vv v12, v8, v9
1439+
; CHECK-NEXT: vmv.v.v v13, v12
1440+
; CHECK-NEXT: vmv.v.v v14, v12
1441+
; CHECK-NEXT: vmv.v.v v15, v12
1442+
; CHECK-NEXT: vmv4r.v v8, v12
1443+
; CHECK-NEXT: ret
1444+
%res = shufflevector <8 x i64> %a, <8 x i64> poison, <8 x i32> <i32 1, i32 0, i32 1, i32 0, i32 1, i32 0, i32 1, i32 0>
1445+
ret <8 x i64> %res
1446+
}

0 commit comments

Comments
 (0)