
Commit 1588aab

[AArch64] Generalize integer FPR lane stores for all types (#134117)
This rewrites the fold from #129756 to apply to all types, including stores of i8. This required adding a new `aarch64mfp8` MVT to represent FPR8 types on AArch64, which can be used to extract and store 8-bit values using b sub-registers.

Follow-on from: #129756
Closes: #131793
1 parent bf6986f commit 1588aab
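
As an illustration of the motivating case (a sketch, not taken from the commit's test files): storing an i8 lane extracted from a vector no longer has to round-trip through a general-purpose register, since the lane can stay in a SIMD/FP register and be stored from its b sub-register.

define void @store_lane0_i8(<16 x i8> %v, ptr %p) {
  ; Extract lane 0 and store it as an i8.
  %e = extractelement <16 x i8> %v, i64 0
  store i8 %e, ptr %p
  ; With this change the expected lowering is along the lines of
  ;   str b0, [x0]
  ; instead of umov w8, v0.b[0] + strb w8, [x0] (exact output may differ).
  ret void
}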


45 files changed: 1,035 additions and 795 deletions (only a subset of the changed files is shown below).

llvm/include/llvm/CodeGen/ValueTypes.td

Lines changed: 2 additions & 0 deletions
@@ -338,6 +338,8 @@ def amdgpuBufferFatPointer : ValueType<160, 234>;
 // FIXME: Remove this and the getPointerType() override if MVT::i82 is added.
 def amdgpuBufferStridedPointer : ValueType<192, 235>;
 
+def aarch64mfp8 : ValueType<8, 236>;  // 8-bit value in FPR (AArch64)
+
 let isNormalValueType = false in {
 def token : ValueType<0, 504>;  // TokenTy
 def MetadataVT : ValueType<0, 505> { // Metadata

llvm/lib/CodeGen/ValueTypes.cpp

Lines changed: 4 additions & 0 deletions
@@ -198,6 +198,8 @@ std::string EVT::getEVTString() const {
     return "amdgpuBufferFatPointer";
   case MVT::amdgpuBufferStridedPointer:
     return "amdgpuBufferStridedPointer";
+  case MVT::aarch64mfp8:
+    return "aarch64mfp8";
   }
 }
 
@@ -221,6 +223,8 @@ Type *EVT::getTypeForEVT(LLVMContext &Context) const {
   case MVT::x86mmx: return llvm::FixedVectorType::get(llvm::IntegerType::get(Context, 64), 1);
   case MVT::aarch64svcount:
     return TargetExtType::get(Context, "aarch64.svcount");
+  case MVT::aarch64mfp8:
+    return FixedVectorType::get(IntegerType::get(Context, 8), 1);
   case MVT::x86amx: return Type::getX86_AMXTy(Context);
   case MVT::i64x8: return IntegerType::get(Context, 512);
   case MVT::amdgpuBufferFatPointer: return IntegerType::get(Context, 160);

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 54 additions & 25 deletions
@@ -400,6 +400,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   }
 
   if (Subtarget->hasFPARMv8()) {
+    addRegisterClass(MVT::aarch64mfp8, &AArch64::FPR8RegClass);
     addRegisterClass(MVT::f16, &AArch64::FPR16RegClass);
     addRegisterClass(MVT::bf16, &AArch64::FPR16RegClass);
     addRegisterClass(MVT::f32, &AArch64::FPR32RegClass);
@@ -23930,6 +23931,8 @@ static SDValue combineI8TruncStore(StoreSDNode *ST, SelectionDAG &DAG,
 static unsigned getFPSubregForVT(EVT VT) {
   assert(VT.isSimple() && "Expected simple VT");
   switch (VT.getSimpleVT().SimpleTy) {
+  case MVT::aarch64mfp8:
+    return AArch64::bsub;
   case MVT::f16:
     return AArch64::hsub;
   case MVT::f32:
@@ -24019,39 +24022,65 @@ static SDValue performSTORECombine(SDNode *N,
     SDValue ExtIdx = Value.getOperand(1);
     EVT VectorVT = Vector.getValueType();
     EVT ElemVT = VectorVT.getVectorElementType();
-    if (!ValueVT.isInteger() || ElemVT == MVT::i8 || MemVT == MVT::i8)
+
+    if (!ValueVT.isInteger())
+      return SDValue();
+
+    // Propagate zero constants (applying this fold may miss optimizations).
+    if (ISD::isConstantSplatVectorAllZeros(Vector.getNode())) {
+      SDValue ZeroElt = DAG.getConstant(0, DL, ValueVT);
+      DAG.ReplaceAllUsesWith(Value, ZeroElt);
       return SDValue();
+    }
 
     if (ValueVT != MemVT && !ST->isTruncatingStore())
       return SDValue();
 
-    // Heuristic: If there are other users of integer scalars extracted from
-    // this vector that won't fold into the store -- abandon folding. Applying
-    // this fold may extend the vector lifetime and disrupt paired stores.
-    for (const auto &Use : Vector->uses()) {
-      if (Use.getResNo() != Vector.getResNo())
-        continue;
-      const SDNode *User = Use.getUser();
-      if (User->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
-          (!User->hasOneUse() ||
-           (*User->user_begin())->getOpcode() != ISD::STORE))
-        return SDValue();
-    }
+    // This could generate an additional extract if the index is non-zero and
+    // the extracted value has multiple uses.
+    auto *ExtCst = dyn_cast<ConstantSDNode>(ExtIdx);
+    if ((!ExtCst || !ExtCst->isZero()) && !Value.hasOneUse())
+      return SDValue();
 
-    EVT FPElemVT = EVT::getFloatingPointVT(ElemVT.getSizeInBits());
-    EVT FPVectorVT = VectorVT.changeVectorElementType(FPElemVT);
-    SDValue Cast = DAG.getNode(ISD::BITCAST, DL, FPVectorVT, Vector);
-    SDValue Ext =
-        DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, FPElemVT, Cast, ExtIdx);
+    // These can lower to st1, which is preferable if we're unlikely to fold the
+    // addressing into the store.
+    if (Subtarget->isNeonAvailable() && ElemVT == MemVT &&
+        (VectorVT.is64BitVector() || VectorVT.is128BitVector()) && ExtCst &&
+        !ExtCst->isZero() && ST->getBasePtr().getOpcode() != ISD::ADD)
+      return SDValue();
 
-    EVT FPMemVT = EVT::getFloatingPointVT(MemVT.getSizeInBits());
-    if (ST->isTruncatingStore() && FPMemVT != FPElemVT) {
-      SDValue Trunc = DAG.getTargetExtractSubreg(getFPSubregForVT(FPMemVT), DL,
-                                                 FPMemVT, Ext);
-      return DAG.getStore(ST->getChain(), DL, Trunc, ST->getBasePtr(),
-                          ST->getMemOperand());
+    if (MemVT == MVT::i64 || MemVT == MVT::i32) {
+      // Heuristic: If there are other users of w/x integer scalars extracted
+      // from this vector that won't fold into the store -- abandon folding.
+      // Applying this fold may disrupt paired stores.
+      for (const auto &Use : Vector->uses()) {
+        if (Use.getResNo() != Vector.getResNo())
+          continue;
+        const SDNode *User = Use.getUser();
+        if (User->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
+            (!User->hasOneUse() ||
+             (*User->user_begin())->getOpcode() != ISD::STORE))
+          return SDValue();
+      }
     }
 
-    return DAG.getStore(ST->getChain(), DL, Ext, ST->getBasePtr(),
+    SDValue ExtVector = Vector;
+    if (!ExtCst || !ExtCst->isZero()) {
+      // Handle extracting from lanes != 0.
+      SDValue Ext = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL,
+                                Value.getValueType(), Vector, ExtIdx);
+      SDValue Zero = DAG.getVectorIdxConstant(0, DL);
+      ExtVector = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VectorVT,
+                              DAG.getUNDEF(VectorVT), Ext, Zero);
+    }
+
+    EVT FPMemVT = MemVT == MVT::i8
+                      ? MVT::aarch64mfp8
+                      : EVT::getFloatingPointVT(MemVT.getSizeInBits());
+    SDValue FPSubreg = DAG.getTargetExtractSubreg(getFPSubregForVT(FPMemVT), DL,
+                                                  FPMemVT, ExtVector);
+
+    return DAG.getStore(ST->getChain(), DL, FPSubreg, ST->getBasePtr(),
                         ST->getMemOperand());
   }
 
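To make the new control flow above concrete, here is a hedged IR sketch of the non-zero-lane case (illustrative function name and lane index; whether this exact DAG shape survives to the combine depends on earlier folds and the heuristics above). When the constant extract index is not 0, the combine re-extracts the lane, inserts it into lane 0 of an undef vector, and then stores the FPR sub-register for the memory type -- bsub via the new aarch64mfp8 type when the memory type is i8.

define void @trunc_store_lane2(<8 x i16> %v, ptr %p) {
  ; Truncating store of lane 2: conceptually the lane is moved to position 0
  ; of an undef vector and then stored from the b sub-register.
  %e = extractelement <8 x i16> %v, i64 2
  %t = trunc i16 %e to i8
  store i8 %t, ptr %p
  ret void
}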
llvm/lib/Target/AArch64/AArch64InstrInfo.td

Lines changed: 31 additions & 13 deletions
@@ -3590,7 +3590,7 @@ defm LDRW : LoadUI<0b10, 0, 0b01, GPR32z, uimm12s4, "ldr",
                    (load (am_indexed32 GPR64sp:$Rn, uimm12s4:$offset)))]>;
 let Predicates = [HasFPARMv8] in {
 defm LDRB : LoadUI<0b00, 1, 0b01, FPR8Op, uimm12s1, "ldr",
-                   [(set FPR8Op:$Rt,
+                   [(set (i8 FPR8Op:$Rt),
                          (load (am_indexed8 GPR64sp:$Rn, uimm12s1:$offset)))]>;
 defm LDRH : LoadUI<0b01, 1, 0b01, FPR16Op, uimm12s2, "ldr",
                    [(set (f16 FPR16Op:$Rt),
@@ -3778,7 +3778,7 @@ defm LDURW : LoadUnscaled<0b10, 0, 0b01, GPR32z, "ldur",
                     (load (am_unscaled32 GPR64sp:$Rn, simm9:$offset)))]>;
 let Predicates = [HasFPARMv8] in {
 defm LDURB : LoadUnscaled<0b00, 1, 0b01, FPR8Op, "ldur",
-                    [(set FPR8Op:$Rt,
+                    [(set (i8 FPR8Op:$Rt),
                           (load (am_unscaled8 GPR64sp:$Rn, simm9:$offset)))]>;
 defm LDURH : LoadUnscaled<0b01, 1, 0b01, FPR16Op, "ldur",
                     [(set (f16 FPR16Op:$Rt),
@@ -4348,7 +4348,7 @@ defm STRW : StoreUIz<0b10, 0, 0b00, GPR32z, uimm12s4, "str",
                     (am_indexed32 GPR64sp:$Rn, uimm12s4:$offset))]>;
 let Predicates = [HasFPARMv8] in {
 defm STRB : StoreUI<0b00, 1, 0b00, FPR8Op, uimm12s1, "str",
-                    [(store FPR8Op:$Rt,
+                    [(store (i8 FPR8Op:$Rt),
                             (am_indexed8 GPR64sp:$Rn, uimm12s1:$offset))]>;
 defm STRH : StoreUI<0b01, 1, 0b00, FPR16Op, uimm12s2, "str",
                     [(store (f16 FPR16Op:$Rt),
@@ -4484,7 +4484,7 @@ defm STURW : StoreUnscaled<0b10, 0, 0b00, GPR32z, "stur",
                      (am_unscaled32 GPR64sp:$Rn, simm9:$offset))]>;
 let Predicates = [HasFPARMv8] in {
 defm STURB : StoreUnscaled<0b00, 1, 0b00, FPR8Op, "stur",
-                     [(store FPR8Op:$Rt,
+                     [(store (i8 FPR8Op:$Rt),
                               (am_unscaled8 GPR64sp:$Rn, simm9:$offset))]>;
 defm STURH : StoreUnscaled<0b01, 1, 0b00, FPR16Op, "stur",
                      [(store (f16 FPR16Op:$Rt),
@@ -4604,6 +4604,12 @@ def : Pat<(truncstorei16 GPR64:$Rt, (am_unscaled16 GPR64sp:$Rn, simm9:$offset)),
 def : Pat<(truncstorei8 GPR64:$Rt, (am_unscaled8 GPR64sp:$Rn, simm9:$offset)),
           (STURBBi (EXTRACT_SUBREG GPR64:$Rt, sub_32), GPR64sp:$Rn, simm9:$offset)>;
 
+// aarch64mfp8 (bsub) stores
+def : Pat<(store aarch64mfp8:$Rt, (am_unscaled8 GPR64sp:$Rn, simm9:$offset)),
+          (STURBi FPR8:$Rt, GPR64sp:$Rn, simm9:$offset)>;
+def : Pat<(store aarch64mfp8:$Rt, (am_indexed8 GPR64sp:$Rn, uimm12s1:$offset)),
+          (STRBui FPR8:$Rt, GPR64sp:$Rn, uimm12s1:$offset)>;
+
 // Match stores from lane 0 to the appropriate subreg's store.
 multiclass VecStoreULane0Pat<SDPatternOperator StoreOp,
                              ValueType VTy, ValueType STy,
@@ -7245,8 +7251,15 @@ def : Pat<(v2i64 (int_aarch64_neon_vcopy_lane
 
 // Move elements between vectors
 multiclass Neon_INS_elt_pattern<ValueType VT128, ValueType VT64, ValueType VTSVE,
-                                ValueType VTScal, Operand SVEIdxTy, Instruction INS> {
+                                ValueType VTScal, Operand SVEIdxTy, Instruction INS, Instruction DUP, SubRegIndex DUPSub> {
   // Extracting from the lowest 128-bits of an SVE vector
+  def : Pat<(VT128 (vector_insert undef,
+                     (VTScal (vector_extract VTSVE:$Rm, (i64 SVEIdxTy:$Immn))),
+                     (i64 0))),
+            (INSERT_SUBREG (VT128 (IMPLICIT_DEF)),
+                           (DUP (VT128 (EXTRACT_SUBREG VTSVE:$Rm, zsub)), SVEIdxTy:$Immn),
+                           DUPSub)>;
+
   def : Pat<(VT128 (vector_insert VT128:$Rn,
                      (VTScal (vector_extract VTSVE:$Rm, (i64 SVEIdxTy:$Immn))),
                      (i64 imm:$Immd))),
@@ -7265,6 +7278,11 @@ multiclass Neon_INS_elt_pattern<ValueType VT128, ValueType VT64, ValueType VTSVE
                      (i64 imm:$Immd))),
             (INS V128:$src, imm:$Immd, V128:$Rn, imm:$Immn)>;
 
+  def : Pat<(VT128 (vector_insert undef,
+                     (VTScal (vector_extract (VT128 V128:$Rn), (i64 imm:$Immn))),
+                     (i64 0))),
+            (INSERT_SUBREG (VT128 (IMPLICIT_DEF)), (DUP V128:$Rn, imm:$Immn), DUPSub)>;
+
   def : Pat<(VT128 (vector_insert V128:$src,
                      (VTScal (vector_extract (VT64 V64:$Rn), (i64 imm:$Immn))),
                      (i64 imm:$Immd))),
@@ -7287,15 +7305,15 @@ multiclass Neon_INS_elt_pattern<ValueType VT128, ValueType VT64, ValueType VTSVE
                            dsub)>;
 }
 
-defm : Neon_INS_elt_pattern<v8f16, v4f16, nxv8f16, f16, VectorIndexH, INSvi16lane>;
-defm : Neon_INS_elt_pattern<v8bf16, v4bf16, nxv8bf16, bf16, VectorIndexH, INSvi16lane>;
-defm : Neon_INS_elt_pattern<v4f32, v2f32, nxv4f32, f32, VectorIndexS, INSvi32lane>;
-defm : Neon_INS_elt_pattern<v2f64, v1f64, nxv2f64, f64, VectorIndexD, INSvi64lane>;
+defm : Neon_INS_elt_pattern<v8f16, v4f16, nxv8f16, f16, VectorIndexH, INSvi16lane, DUPi16, hsub>;
+defm : Neon_INS_elt_pattern<v8bf16, v4bf16, nxv8bf16, bf16, VectorIndexH, INSvi16lane, DUPi16, hsub>;
+defm : Neon_INS_elt_pattern<v4f32, v2f32, nxv4f32, f32, VectorIndexS, INSvi32lane, DUPi32, ssub>;
+defm : Neon_INS_elt_pattern<v2f64, v1f64, nxv2f64, f64, VectorIndexD, INSvi64lane, DUPi64, dsub>;
 
-defm : Neon_INS_elt_pattern<v16i8, v8i8, nxv16i8, i32, VectorIndexB, INSvi8lane>;
-defm : Neon_INS_elt_pattern<v8i16, v4i16, nxv8i16, i32, VectorIndexH, INSvi16lane>;
-defm : Neon_INS_elt_pattern<v4i32, v2i32, nxv4i32, i32, VectorIndexS, INSvi32lane>;
-defm : Neon_INS_elt_pattern<v2i64, v1i64, nxv2i64, i64, VectorIndexD, INSvi64lane>;
+defm : Neon_INS_elt_pattern<v16i8, v8i8, nxv16i8, i32, VectorIndexB, INSvi8lane, DUPi8, bsub>;
+defm : Neon_INS_elt_pattern<v8i16, v4i16, nxv8i16, i32, VectorIndexH, INSvi16lane, DUPi16, hsub>;
+defm : Neon_INS_elt_pattern<v4i32, v2i32, nxv4i32, i32, VectorIndexS, INSvi32lane, DUPi32, ssub>;
+defm : Neon_INS_elt_pattern<v2i64, v1i64, nxv2i64, i64, VectorIndexD, INSvi64lane, DUPi64, dsub>;
 
 // Insert from bitcast
 // vector_insert(bitcast(f32 src), n, lane) -> INSvi32lane(src, lane, INSERT_SUBREG(-, n), 0)

llvm/lib/Target/AArch64/AArch64RegisterInfo.td

Lines changed: 1 addition & 1 deletion
@@ -497,7 +497,7 @@ def Q30 : AArch64Reg<30, "q30", [D30, D30_HI], ["v30", ""]>, DwarfRegAlias<B30
 def Q31 : AArch64Reg<31, "q31", [D31, D31_HI], ["v31", ""]>, DwarfRegAlias<B31>;
 }
 
-def FPR8 : RegisterClass<"AArch64", [i8], 8, (sequence "B%u", 0, 31)> {
+def FPR8 : RegisterClass<"AArch64", [i8, aarch64mfp8], 8, (sequence "B%u", 0, 31)> {
   let Size = 8;
   let DecoderMethod = "DecodeSimpleRegisterClass<AArch64::FPR8RegClassID, 0, 32>";
 }

llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Lines changed: 16 additions & 0 deletions
@@ -3498,6 +3498,22 @@ let Predicates = [HasSVE_or_SME] in {
             (EXTRACT_SUBREG ZPR:$Zs, dsub)>;
 }
 
+multiclass sve_insert_extract_elt<ValueType VT, ValueType VTScalar, Instruction DUP, Operand IdxTy> {
+  // NOP pattern (needed to avoid pointless DUPs being added by the second pattern).
+  def : Pat<(VT (vector_insert undef,
+                  (VTScalar (vector_extract VT:$vec, (i64 0))), (i64 0))),
+            (VT $vec)>;
+
+  def : Pat<(VT (vector_insert undef,
+                  (VTScalar (vector_extract VT:$vec, (i64 IdxTy:$Idx))), (i64 0))),
+            (DUP ZPR:$vec, IdxTy:$Idx)>;
+}
+
+defm : sve_insert_extract_elt<nxv16i8, i32, DUP_ZZI_B, sve_elm_idx_extdup_b>;
+defm : sve_insert_extract_elt<nxv8i16, i32, DUP_ZZI_H, sve_elm_idx_extdup_h>;
+defm : sve_insert_extract_elt<nxv4i32, i32, DUP_ZZI_S, sve_elm_idx_extdup_s>;
+defm : sve_insert_extract_elt<nxv2i64, i64, DUP_ZZI_D, sve_elm_idx_extdup_d>;
+
 multiclass sve_predicated_add<SDNode extend, int value> {
   def : Pat<(nxv16i8 (add ZPR:$op, (extend nxv16i1:$pred))),
             (ADD_ZPmZ_B PPR:$pred, ZPR:$op, (DUP_ZI_B value, 0))>;
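
These SVE patterns back the same store combine when the source vector is scalable: insert(undef, extract(Zn, idx), 0) selects to a DUP that broadcasts the requested element (so it also lands in lane 0), after which the lane-0 FPR store patterns can apply. A hedged end-to-end sketch, with an illustrative name and lane index rather than anything from the commit's tests:

define void @store_sve_lane5_i8(<vscale x 16 x i8> %z, ptr %p) {
  ; Extract a non-zero lane of an SVE vector and store it as an i8. Subject to
  ; the heuristics in performSTORECombine, this can select to a DUP followed by
  ; a b-register store rather than a move through a general-purpose register.
  %e = extractelement <vscale x 16 x i8> %z, i64 5
  store i8 %e, ptr %p
  ret void
}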
