Commit ac0a888

preames authored and Ankur-0429 committed
[RISCV] Use early return to simplify VLA shuffle lowering [nfc]
1 parent 7bca358 · commit ac0a888

File tree: 1 file changed, +50 -49 lines

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 50 additions & 49 deletions
@@ -6053,23 +6053,30 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG,
     SDValue LHSIndices = DAG.getBuildVector(IndexVT, DL, GatherIndicesLHS);
     LHSIndices =
         convertToScalableVector(IndexContainerVT, LHSIndices, DAG, Subtarget);
+    // At m1 and less, there's no point trying any of the high LMUL splitting
+    // techniques.  TODO: Should we reconsider this for DLEN < VLEN?
+    if (NumElts <= MinVLMAX) {
+      SDValue Gather = DAG.getNode(GatherVVOpc, DL, ContainerVT, V1, LHSIndices,
+                                   DAG.getUNDEF(ContainerVT), TrueMask, VL);
+      return convertFromScalableVector(VT, Gather, DAG, Subtarget);
+    }

-    SDValue Gather;
-    if (NumElts > MinVLMAX && isLocalRepeatingShuffle(Mask, MinVLMAX)) {
-      // If we have a locally repeating mask, then we can reuse the first
-      // register in the index register group for all registers within the
-      // source register group.  TODO: This generalizes to m2, and m4.
-      const MVT M1VT = getLMUL1VT(ContainerVT);
-      EVT SubIndexVT = M1VT.changeVectorElementType(IndexVT.getScalarType());
+    const MVT M1VT = getLMUL1VT(ContainerVT);
+    EVT SubIndexVT = M1VT.changeVectorElementType(IndexVT.getScalarType());
+    auto [InnerTrueMask, InnerVL] =
+        getDefaultScalableVLOps(M1VT, DL, DAG, Subtarget);
+    int N =
+        ContainerVT.getVectorMinNumElements() / M1VT.getVectorMinNumElements();
+    assert(isPowerOf2_32(N) && N <= 8);
+
+    // If we have a locally repeating mask, then we can reuse the first
+    // register in the index register group for all registers within the
+    // source register group.  TODO: This generalizes to m2, and m4.
+    if (isLocalRepeatingShuffle(Mask, MinVLMAX)) {
       SDValue SubIndex =
           DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubIndexVT, LHSIndices,
                       DAG.getVectorIdxConstant(0, DL));
-      auto [InnerTrueMask, InnerVL] =
-          getDefaultScalableVLOps(M1VT, DL, DAG, Subtarget);
-      int N = ContainerVT.getVectorMinNumElements() /
-              M1VT.getVectorMinNumElements();
-      assert(isPowerOf2_32(N) && N <= 8);
-      Gather = DAG.getUNDEF(ContainerVT);
+      SDValue Gather = DAG.getUNDEF(ContainerVT);
       for (int i = 0; i < N; i++) {
         SDValue SubIdx =
             DAG.getVectorIdxConstant(M1VT.getVectorMinNumElements() * i, DL);
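
Note for readers: isLocalRepeatingShuffle is a pre-existing helper elsewhere in RISCVISelLowering.cpp, not part of this diff. As a rough sketch of the property this first case relies on (illustrative only; the Sketch name and exact undef handling are ours, not LLVM's): each MinVLMAX-sized span of the mask draws only from its own span and applies the same in-span permutation as every other span, e.g. Mask=[1,0,3,2, 5,4,7,6] with MinVLMAX=4.

    #include "llvm/ADT/ArrayRef.h"

    // Illustrative sketch, not the in-tree implementation: a mask is
    // "locally repeating" for a given Span if every element stays inside
    // its own span and every span applies the same local permutation.
    // -1 means undef and matches anything.
    static bool isLocalRepeatingShuffleSketch(llvm::ArrayRef<int> Mask,
                                              int Span) {
      for (int I = 0, E = (int)Mask.size(); I != E; ++I) {
        if (Mask[I] == -1)
          continue;
        // Local: the source element lies in the same span as the destination.
        if (Mask[I] / Span != I / Span)
          return false;
        // Repeating: same in-span offset as the corresponding slot in span 0.
        int First = Mask[I % Span];
        if (First != -1 && Mask[I] % Span != First % Span)
          return false;
      }
      return true;
    }

When that holds, the loop above can gather each register of the group with the one SubIndex extracted from span 0, which is exactly the reuse the comment describes.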
@@ -6081,54 +6088,45 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG,
         Gather = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ContainerVT, Gather,
                              SubVec, SubIdx);
       }
-    } else if (NumElts > MinVLMAX && isLowSourceShuffle(Mask, MinVLMAX) &&
-               isSpanSplatShuffle(Mask, MinVLMAX)) {
-      // If we have a shuffle which only uses the first register in our source
-      // register group, and repeats the same index across all spans, we can
-      // use a single vrgather (and possibly some register moves).
-      // TODO: This can be generalized for m2 or m4, or for any shuffle for
-      // which we can do a linear number of shuffles to form an m1 which
-      // contains all the output elements.
-      const MVT M1VT = getLMUL1VT(ContainerVT);
-      EVT SubIndexVT = M1VT.changeVectorElementType(IndexVT.getScalarType());
-      auto [InnerTrueMask, InnerVL] =
-          getDefaultScalableVLOps(M1VT, DL, DAG, Subtarget);
-      int N = ContainerVT.getVectorMinNumElements() /
-              M1VT.getVectorMinNumElements();
-      assert(isPowerOf2_32(N) && N <= 8);
+      return convertFromScalableVector(VT, Gather, DAG, Subtarget);
+    }
+
+    // If we have a shuffle which only uses the first register in our source
+    // register group, and repeats the same index across all spans, we can
+    // use a single vrgather (and possibly some register moves).
+    // TODO: This can be generalized for m2 or m4, or for any shuffle for
+    // which we can do a linear number of shuffles to form an m1 which
+    // contains all the output elements.
+    if (isLowSourceShuffle(Mask, MinVLMAX) &&
+        isSpanSplatShuffle(Mask, MinVLMAX)) {
       SDValue SubV1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, M1VT, V1,
                                   DAG.getVectorIdxConstant(0, DL));
       SDValue SubIndex =
           DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubIndexVT, LHSIndices,
                       DAG.getVectorIdxConstant(0, DL));
       SDValue SubVec = DAG.getNode(GatherVVOpc, DL, M1VT, SubV1, SubIndex,
                                    DAG.getUNDEF(M1VT), InnerTrueMask, InnerVL);
-      Gather = DAG.getUNDEF(ContainerVT);
+      SDValue Gather = DAG.getUNDEF(ContainerVT);
       for (int i = 0; i < N; i++) {
         SDValue SubIdx =
             DAG.getVectorIdxConstant(M1VT.getVectorMinNumElements() * i, DL);
         Gather = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ContainerVT, Gather,
                              SubVec, SubIdx);
       }
-    } else if (NumElts > MinVLMAX && isLowSourceShuffle(Mask, MinVLMAX)) {
-      // If we have a shuffle which only uses the first register in our
-      // source register group, we can do a linear number of m1 vrgathers
-      // reusing the same source register (but with different indices)
-      // TODO: This can be generalized for m2 or m4, or for any shuffle
-      // for which we can do a vslidedown followed by this expansion.
-      const MVT M1VT = getLMUL1VT(ContainerVT);
-      EVT SubIndexVT = M1VT.changeVectorElementType(IndexVT.getScalarType());
-      auto [InnerTrueMask, InnerVL] =
-          getDefaultScalableVLOps(M1VT, DL, DAG, Subtarget);
-      int N = ContainerVT.getVectorMinNumElements() /
-              M1VT.getVectorMinNumElements();
-      assert(isPowerOf2_32(N) && N <= 8);
-      Gather = DAG.getUNDEF(ContainerVT);
+      return convertFromScalableVector(VT, Gather, DAG, Subtarget);
+    }
+
+    // If we have a shuffle which only uses the first register in our
+    // source register group, we can do a linear number of m1 vrgathers
+    // reusing the same source register (but with different indices)
+    // TODO: This can be generalized for m2 or m4, or for any shuffle
+    // for which we can do a vslidedown followed by this expansion.
+    if (isLowSourceShuffle(Mask, MinVLMAX)) {
       SDValue SlideAmt =
           DAG.getElementCount(DL, XLenVT, M1VT.getVectorElementCount());
-      SDValue SubV1 =
-          DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, M1VT, V1,
-                      DAG.getVectorIdxConstant(0, DL));
+      SDValue SubV1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, M1VT, V1,
+                                  DAG.getVectorIdxConstant(0, DL));
+      SDValue Gather = DAG.getUNDEF(ContainerVT);
       for (int i = 0; i < N; i++) {
         if (i != 0)
           LHSIndices = getVSlidedown(DAG, Subtarget, DL, IndexContainerVT,
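
Similarly, isLowSourceShuffle and isSpanSplatShuffle are pre-existing helpers, not part of this change. Reading them off the comments in the diff: low-source means every defined mask element selects from the first MinVLMAX-wide register of the source group, and span-splat (as used here, on top of low-source) means every span of the result wants the same index pattern as span 0, so one gathered m1 register can be inserted N times. A hedged sketch under that reading (names and undef handling are ours, not LLVM's):

    #include "llvm/ADT/ArrayRef.h"
    #include "llvm/ADT/STLExtras.h"

    // Illustrative sketches, not the in-tree implementations.
    // Low-source: all defined elements come from the first Span-wide
    // register, e.g. Span=4, Mask=[2,0,1,3, 3,1,0,2].
    static bool isLowSourceShuffleSketch(llvm::ArrayRef<int> Mask, int Span) {
      return llvm::all_of(Mask, [=](int M) { return M == -1 || M < Span; });
    }

    // Span-splat (as relied on above): every span repeats span 0's index
    // pattern, e.g. Span=4, Mask=[2,0,1,3, 2,0,1,3], so a single m1
    // vrgather result is valid for every span of the destination.
    static bool isSpanSplatShuffleSketch(llvm::ArrayRef<int> Mask, int Span) {
      for (int I = 0, E = (int)Mask.size(); I != E; ++I)
        if (Mask[I] != -1 && Mask[I % Span] != -1 && Mask[I] != Mask[I % Span])
          return false;
      return true;
    }

The third case drops the span-splat requirement: each destination register gets its own m1 vrgather against a freshly slid-down span of indices, which is why the loop above needs the vslidedown when i != 0.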
@@ -6145,10 +6143,13 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG,
         Gather = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ContainerVT, Gather,
                              SubVec, SubIdx);
       }
-    } else {
-      Gather = DAG.getNode(GatherVVOpc, DL, ContainerVT, V1, LHSIndices,
-                           DAG.getUNDEF(ContainerVT), TrueMask, VL);
+      return convertFromScalableVector(VT, Gather, DAG, Subtarget);
     }
+
+    // Fallback to generic vrgather if we can't find anything better.
+    // On many machines, this will be O(LMUL^2)
+    SDValue Gather = DAG.getNode(GatherVVOpc, DL, ContainerVT, V1, LHSIndices,
+                                 DAG.getUNDEF(ContainerVT), TrueMask, VL);
     return convertFromScalableVector(VT, Gather, DAG, Subtarget);
   }

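
The shape of the whole change, reduced to a compilable toy (function names and strategies are hypothetical, not from the patch): the shared mutable Gather threaded through an if/else-if chain becomes a series of guarded early returns, the setup common to all cases (M1VT, SubIndexVT, InnerTrueMask/InnerVL, N in the real code) is hoisted above the guards, and the generic O(LMUL^2) vrgather becomes straight-line fallback code at the bottom.

    #include <cstdio>

    // Before: one mutable result threaded through an if/else-if chain.
    static int lowerBefore(int X) {
      int Result;
      if (X % 2 == 0)
        Result = X / 2;        // cheap special case A
      else if (X % 3 == 0)
        Result = X / 3;        // cheap special case B
      else
        Result = X * X;        // expensive generic fallback
      return Result + 1;       // shared epilogue
    }

    // After: each case owns its locals and returns early; the fallback is
    // simply the code that falls out of the bottom of the function.
    static int lowerAfter(int X) {
      if (X % 2 == 0)
        return X / 2 + 1;      // cheap special case A
      if (X % 3 == 0)
        return X / 3 + 1;      // cheap special case B
      return X * X + 1;        // expensive generic fallback
    }

    int main() {
      for (int X : {4, 9, 7})
        std::printf("%d -> before=%d, after=%d\n", X, lowerBefore(X),
                    lowerAfter(X));
      return 0;
    }

Behavior is identical by construction, which is what the [nfc] tag in the commit title asserts for the real patch.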