Skip to content

Commit 5a023f5

Browse files
author
Dinar Temirbulatov
authored
[AArch64][SVE2] Enable dynamic shuffle for fixed length types. (llvm#72490)
When the SVE register size is unknown, or the minimum size is not equal to the maximum size, we can determine the actual SVE register size at runtime and adjust the shuffle mask at runtime accordingly.
1 parent 4d4af15 commit 5a023f5

File tree

2 files changed

+432
-35
lines changed

2 files changed

+432
-35
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26798,26 +26798,47 @@ static SDValue GenerateFixedLengthSVETBL(SDValue Op, SDValue Op1, SDValue Op2,
2679826798

2679926799
// Ignore two operands if no SVE2 or all index numbers couldn't
2680026800
// be represented.
26801-
if (!IsSingleOp && (!Subtarget.hasSVE2() || MinSVESize != MaxSVESize))
26801+
if (!IsSingleOp && !Subtarget.hasSVE2())
2680226802
return SDValue();
2680326803

2680426804
EVT VTOp1 = Op.getOperand(0).getValueType();
2680526805
unsigned BitsPerElt = VTOp1.getVectorElementType().getSizeInBits();
2680626806
unsigned IndexLen = MinSVESize / BitsPerElt;
2680726807
unsigned ElementsPerVectorReg = VTOp1.getVectorNumElements();
2680826808
uint64_t MaxOffset = APInt(BitsPerElt, -1, false).getZExtValue();
26809+
EVT MaskEltType = VTOp1.getVectorElementType().changeTypeToInteger();
26810+
EVT MaskType = EVT::getVectorVT(*DAG.getContext(), MaskEltType, IndexLen);
26811+
bool MinMaxEqual = (MinSVESize == MaxSVESize);
2680926812
assert(ElementsPerVectorReg <= IndexLen && ShuffleMask.size() <= IndexLen &&
2681026813
"Incorrectly legalised shuffle operation");
2681126814

2681226815
SmallVector<SDValue, 8> TBLMask;
26816+
// If MinSVESize is not equal to MaxSVESize then we need to know which
26817+
// TBL mask element needs adjustment.
26818+
SmallVector<SDValue, 8> AddRuntimeVLMask;
26819+
26820+
// Bail out for 8-bits element types, because with 2048-bit SVE register
26821+
// size 8 bits is only sufficient to index into the first source vector.
26822+
if (!IsSingleOp && !MinMaxEqual && BitsPerElt == 8)
26823+
return SDValue();
26824+
2681326825
for (int Index : ShuffleMask) {
2681426826
// Handling poison index value.
2681526827
if (Index < 0)
2681626828
Index = 0;
26817-
// If we refer to the second operand then we have to add elements
26818-
// number in hardware register minus number of elements in a type.
26819-
if ((unsigned)Index >= ElementsPerVectorReg)
26820-
Index += IndexLen - ElementsPerVectorReg;
26829+
// If the mask refers to elements in the second operand, then we have to
26830+
// offset the index by the number of elements in a vector. If this number
26831+
// is not known at compile-time, we need to maintain a mask with 'VL' values
26832+
// to add at runtime.
26833+
if ((unsigned)Index >= ElementsPerVectorReg) {
26834+
if (MinMaxEqual) {
26835+
Index += IndexLen - ElementsPerVectorReg;
26836+
} else {
26837+
Index = Index - ElementsPerVectorReg;
26838+
AddRuntimeVLMask.push_back(DAG.getConstant(1, DL, MVT::i64));
26839+
}
26840+
} else if (!MinMaxEqual)
26841+
AddRuntimeVLMask.push_back(DAG.getConstant(0, DL, MVT::i64));
2682126842
// For 8-bit elements and 1024-bit SVE registers and MaxOffset equals
2682226843
// to 255, this might point to the last element in the second operand
2682326844
// of the shufflevector, thus we are rejecting this transform.
@@ -26830,11 +26851,12 @@ static SDValue GenerateFixedLengthSVETBL(SDValue Op, SDValue Op1, SDValue Op2,
2683026851
// value where it would perform first lane duplication for out of
2683126852
// index elements. For i8 elements an out-of-range index could be valid
2683226853
// for 2048-bit vector register size.
26833-
for (unsigned i = 0; i < IndexLen - ElementsPerVectorReg; ++i)
26854+
for (unsigned i = 0; i < IndexLen - ElementsPerVectorReg; ++i) {
2683426855
TBLMask.push_back(DAG.getConstant((int)MaxOffset, DL, MVT::i64));
26856+
if (!MinMaxEqual)
26857+
AddRuntimeVLMask.push_back(DAG.getConstant(0, DL, MVT::i64));
26858+
}
2683526859

26836-
EVT MaskEltType = EVT::getIntegerVT(*DAG.getContext(), BitsPerElt);
26837-
EVT MaskType = EVT::getVectorVT(*DAG.getContext(), MaskEltType, IndexLen);
2683826860
EVT MaskContainerVT = getContainerForFixedLengthVector(DAG, MaskType);
2683926861
SDValue VecMask =
2684026862
DAG.getBuildVector(MaskType, DL, ArrayRef(TBLMask.data(), IndexLen));
@@ -26846,13 +26868,29 @@ static SDValue GenerateFixedLengthSVETBL(SDValue Op, SDValue Op1, SDValue Op2,
2684626868
DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, ContainerVT,
2684726869
DAG.getConstant(Intrinsic::aarch64_sve_tbl, DL, MVT::i32),
2684826870
Op1, SVEMask);
26849-
else if (Subtarget.hasSVE2())
26871+
else if (Subtarget.hasSVE2()) {
26872+
if (!MinMaxEqual) {
26873+
unsigned MinNumElts = AArch64::SVEBitsPerBlock / BitsPerElt;
26874+
SDValue VScale = (BitsPerElt == 64)
26875+
? DAG.getVScale(DL, MVT::i64, APInt(64, MinNumElts))
26876+
: DAG.getVScale(DL, MVT::i32, APInt(32, MinNumElts));
26877+
SDValue VecMask =
26878+
DAG.getBuildVector(MaskType, DL, ArrayRef(TBLMask.data(), IndexLen));
26879+
SDValue MulByMask = DAG.getNode(
26880+
ISD::MUL, DL, MaskType,
26881+
DAG.getNode(ISD::SPLAT_VECTOR, DL, MaskType, VScale),
26882+
DAG.getBuildVector(MaskType, DL,
26883+
ArrayRef(AddRuntimeVLMask.data(), IndexLen)));
26884+
SDValue UpdatedVecMask =
26885+
DAG.getNode(ISD::ADD, DL, MaskType, VecMask, MulByMask);
26886+
SVEMask = convertToScalableVector(
26887+
DAG, getContainerForFixedLengthVector(DAG, MaskType), UpdatedVecMask);
26888+
}
2685026889
Shuffle =
2685126890
DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, ContainerVT,
2685226891
DAG.getConstant(Intrinsic::aarch64_sve_tbl2, DL, MVT::i32),
2685326892
Op1, Op2, SVEMask);
26854-
else
26855-
llvm_unreachable("Cannot lower shuffle without SVE2 TBL");
26893+
}
2685626894
Shuffle = convertFromScalableVector(DAG, VT, Shuffle);
2685726895
return DAG.getNode(ISD::BITCAST, DL, Op.getValueType(), Shuffle);
2685826896
}

0 commit comments

Comments
 (0)