@@ -4823,17 +4823,6 @@ SDValue AArch64TargetLowering::getPStateSM(SelectionDAG &DAG, SDValue Chain,
4823
4823
Mask);
4824
4824
}
4825
4825
4826
- static std::optional<SMEAttrs> getCalleeAttrsFromExternalFunction(SDValue V) {
4827
- if (auto *ES = dyn_cast<ExternalSymbolSDNode>(V)) {
4828
- StringRef S(ES->getSymbol());
4829
- if (S == "__arm_sme_state" || S == "__arm_tpidr2_save")
4830
- return SMEAttrs(SMEAttrs::SM_Compatible | SMEAttrs::ZA_Preserved);
4831
- if (S == "__arm_tpidr2_restore")
4832
- return SMEAttrs(SMEAttrs::SM_Compatible | SMEAttrs::ZA_Shared);
4833
- }
4834
- return std::nullopt;
4835
- }
4836
-
4837
4826
SDValue AArch64TargetLowering::LowerINTRINSIC_VOID(SDValue Op,
4838
4827
SelectionDAG &DAG) const {
4839
4828
unsigned IntNo = Op.getConstantOperandVal(1);
@@ -7375,28 +7364,31 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
7375
7364
SMEAttrs CalleeAttrs, CallerAttrs(MF.getFunction());
7376
7365
if (CLI.CB)
7377
7366
CalleeAttrs = SMEAttrs(*CLI.CB);
7378
- else if (std::optional<SMEAttrs> Attrs =
7379
- getCalleeAttrsFromExternalFunction(CLI.Callee))
7380
- CalleeAttrs = *Attrs;
7367
+ else if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
7368
+ CalleeAttrs = SMEAttrs(ES->getSymbol());
7381
7369
7382
7370
bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs);
7383
-
7384
- MachineFrameInfo &MFI = MF.getFrameInfo();
7385
7371
if (RequiresLazySave) {
7386
- // Set up a lazy save mechanism by storing the runtime live slices
7387
- // (worst-case N*N) to the TPIDR2 stack object.
7388
- SDValue N = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
7389
- DAG.getConstant(1, DL, MVT::i32));
7390
- SDValue NN = DAG.getNode(ISD::MUL, DL, MVT::i64, N, N);
7391
- unsigned TPIDR2Obj = FuncInfo->getLazySaveTPIDR2Obj();
7372
+ SDValue NumZaSaveSlices;
7373
+ if (!CalleeAttrs.preservesZA()) {
7374
+ // Set up a lazy save mechanism by storing the runtime live slices
7375
+ // (worst-case SVL*SVL) to the TPIDR2 stack object.
7376
+ SDValue SVL = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
7377
+ DAG.getConstant(1, DL, MVT::i32));
7378
+ NumZaSaveSlices = DAG.getNode(ISD::MUL, DL, MVT::i64, SVL, SVL);
7379
+ } else { // Callee preserves ZA: record a save-slice count of zero.
7380
+ NumZaSaveSlices = DAG.getConstant(0, DL, MVT::i64);
7381
+ }
7392
7382
7383
+ unsigned TPIDR2Obj = FuncInfo->getLazySaveTPIDR2Obj();
7393
7384
MachinePointerInfo MPI = MachinePointerInfo::getStack(MF, TPIDR2Obj);
7394
7385
SDValue TPIDR2ObjAddr = DAG.getFrameIndex(TPIDR2Obj,
7395
7386
DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
7396
- SDValue BufferPtrAddr =
7387
+ SDValue NumZaSaveSlicesAddr =
7397
7388
DAG.getNode(ISD::ADD, DL, TPIDR2ObjAddr.getValueType(), TPIDR2ObjAddr,
7398
7389
DAG.getConstant(8, DL, TPIDR2ObjAddr.getValueType()));
7399
- Chain = DAG.getTruncStore(Chain, DL, NN, BufferPtrAddr, MPI, MVT::i16);
7390
+ Chain = DAG.getTruncStore(Chain, DL, NumZaSaveSlices, NumZaSaveSlicesAddr,
7391
+ MPI, MVT::i16);
7400
7392
Chain = DAG.getNode(
7401
7393
ISD::INTRINSIC_VOID, DL, MVT::Other, Chain,
7402
7394
DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
@@ -7503,6 +7495,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
7503
7495
7504
7496
Type *Ty = EVT(VA.getValVT()).getTypeForEVT(*DAG.getContext());
7505
7497
Align Alignment = DAG.getDataLayout().getPrefTypeAlign(Ty);
7498
+ MachineFrameInfo &MFI = MF.getFrameInfo();
7506
7499
int FI = MFI.CreateStackObject(StoreSize, Alignment, false);
7507
7500
if (isScalable)
7508
7501
MFI.setStackID(FI, TargetStackID::ScalableVector);
@@ -7819,35 +7812,34 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
7819
7812
}
7820
7813
7821
7814
if (RequiresLazySave) {
7822
- // Unconditionally resume ZA.
7823
- Result = DAG.getNode(
7824
- AArch64ISD::SMSTART, DL, MVT::Other, Result,
7825
- DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
7826
- DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));
7827
-
7828
- // Conditionally restore the lazy save using a pseudo node.
7829
- unsigned FI = FuncInfo->getLazySaveTPIDR2Obj();
7830
- SDValue RegMask = DAG.getRegisterMask(
7831
- TRI->SMEABISupportRoutinesCallPreservedMaskFromX0());
7832
- SDValue RestoreRoutine = DAG.getTargetExternalSymbol(
7833
- "__arm_tpidr2_restore", getPointerTy(DAG.getDataLayout()));
7834
- SDValue TPIDR2_EL0 = DAG.getNode(
7835
- ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Result,
7836
- DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32));
7837
-
7838
- // Copy the address of the TPIDR2 block into X0 before 'calling' the
7839
- // RESTORE_ZA pseudo.
7840
- SDValue Glue;
7841
- SDValue TPIDR2Block = DAG.getFrameIndex(
7842
- FI, DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
7843
- Result = DAG.getCopyToReg(Result, DL, AArch64::X0, TPIDR2Block, Glue);
7844
- Result = DAG.getNode(AArch64ISD::RESTORE_ZA, DL, MVT::Other,
7845
- {Result, TPIDR2_EL0,
7846
- DAG.getRegister(AArch64::X0, MVT::i64),
7847
- RestoreRoutine,
7848
- RegMask,
7849
- Result.getValue(1)});
7850
-
7815
+ if (!CalleeAttrs.preservesZA()) {
7816
+ // Unconditionally resume ZA.
7817
+ Result = DAG.getNode(
7818
+ AArch64ISD::SMSTART, DL, MVT::Other, Result,
7819
+ DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
7820
+ DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));
7821
+
7822
+ // Conditionally restore the lazy save using a pseudo node.
7823
+ unsigned FI = FuncInfo->getLazySaveTPIDR2Obj();
7824
+ SDValue RegMask = DAG.getRegisterMask(
7825
+ TRI->SMEABISupportRoutinesCallPreservedMaskFromX0());
7826
+ SDValue RestoreRoutine = DAG.getTargetExternalSymbol(
7827
+ "__arm_tpidr2_restore", getPointerTy(DAG.getDataLayout()));
7828
+ SDValue TPIDR2_EL0 = DAG.getNode(
7829
+ ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Result,
7830
+ DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32));
7831
+
7832
+ // Copy the address of the TPIDR2 block into X0 before 'calling' the
7833
+ // RESTORE_ZA pseudo.
7834
+ SDValue Glue;
7835
+ SDValue TPIDR2Block = DAG.getFrameIndex(
7836
+ FI, DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
7837
+ Result = DAG.getCopyToReg(Result, DL, AArch64::X0, TPIDR2Block, Glue);
7838
+ Result = DAG.getNode(AArch64ISD::RESTORE_ZA, DL, MVT::Other,
7839
+ {Result, TPIDR2_EL0,
7840
+ DAG.getRegister(AArch64::X0, MVT::i64),
7841
+ RestoreRoutine, RegMask, Result.getValue(1)});
7842
+ }
7851
7843
// Finally reset the TPIDR2_EL0 register to 0.
7852
7844
Result = DAG.getNode(
7853
7845
ISD::INTRINSIC_VOID, DL, MVT::Other, Result,
0 commit comments