Skip to content

Commit 1a08aa2

Browse files
authored
[AArch64] Split AArch64ISD::COND_SMSTART/STOP off AArch64::SMSTART/STOP (NFC) (#140711)
The conditional variants of SMSTART/STOP currently take the current PStateSM as a variadic value. This is not supported by the verification added in #140472 (which requires variadic values to be of type Register or RegisterMask), so this patch splits the the conditional variants into new `COND_` nodes, where these extra parameters are fixed arguments. Suggested in #140472 (comment) Part of #140472.
1 parent 22a4930 commit 1a08aa2

File tree

4 files changed

+43
-29
lines changed

4 files changed

+43
-29
lines changed

llvm/docs/AArch64SME.rst

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -213,12 +213,14 @@ Instruction Selection Nodes
213213

214214
.. code-block:: none
215215
216-
AArch64ISD::SMSTART Chain, [SM|ZA|Both], CurrentState, ExpectedState[, RegMask]
217-
AArch64ISD::SMSTOP Chain, [SM|ZA|Both], CurrentState, ExpectedState[, RegMask]
218-
219-
The ``SMSTART/SMSTOP`` nodes take ``CurrentState`` and ``ExpectedState`` operand for
220-
the case of a conditional SMSTART/SMSTOP. The instruction will only be executed
221-
if CurrentState != ExpectedState.
216+
AArch64ISD::SMSTART Chain, [SM|ZA|Both][, RegMask]
217+
AArch64ISD::SMSTOP Chain, [SM|ZA|Both][, RegMask]
218+
AArch64ISD::COND_SMSTART Chain, [SM|ZA|Both], CurrentState, ExpectedState[, RegMask]
219+
AArch64ISD::COND_SMSTOP Chain, [SM|ZA|Both], CurrentState, ExpectedState[, RegMask]
220+
221+
The ``COND_SMSTART/COND_SMSTOP`` nodes additionally take ``CurrentState`` and
222+
``ExpectedState``, in this case the instruction will only be executed if
223+
``CurrentState != ExpectedState``.
222224

223225
When ``CurrentState`` and ``ExpectedState`` can be evaluated at compile-time
224226
(i.e. they are both constants) then an unconditional ``smstart/smstop``

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2726,6 +2726,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
27262726
MAKE_CASE(AArch64ISD::VG_RESTORE)
27272727
MAKE_CASE(AArch64ISD::SMSTART)
27282728
MAKE_CASE(AArch64ISD::SMSTOP)
2729+
MAKE_CASE(AArch64ISD::COND_SMSTART)
2730+
MAKE_CASE(AArch64ISD::COND_SMSTOP)
27292731
MAKE_CASE(AArch64ISD::RESTORE_ZA)
27302732
MAKE_CASE(AArch64ISD::RESTORE_ZT)
27312733
MAKE_CASE(AArch64ISD::SAVE_ZT)
@@ -6033,14 +6035,12 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_VOID(SDValue Op,
60336035
return DAG.getNode(
60346036
AArch64ISD::SMSTART, DL, MVT::Other,
60356037
Op->getOperand(0), // Chain
6036-
DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
6037-
DAG.getConstant(AArch64SME::Always, DL, MVT::i64));
6038+
DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32));
60386039
case Intrinsic::aarch64_sme_za_disable:
60396040
return DAG.getNode(
60406041
AArch64ISD::SMSTOP, DL, MVT::Other,
60416042
Op->getOperand(0), // Chain
6042-
DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
6043-
DAG.getConstant(AArch64SME::Always, DL, MVT::i64));
6043+
DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32));
60446044
}
60456045
}
60466046

@@ -8927,18 +8927,22 @@ SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
89278927
SDValue RegMask = DAG.getRegisterMask(TRI->getSMStartStopCallPreservedMask());
89288928
SDValue MSROp =
89298929
DAG.getTargetConstant((int32_t)AArch64SVCR::SVCRSM, DL, MVT::i32);
8930-
SDValue ConditionOp = DAG.getTargetConstant(Condition, DL, MVT::i64);
8931-
SmallVector<SDValue> Ops = {Chain, MSROp, ConditionOp};
8930+
SmallVector<SDValue> Ops = {Chain, MSROp};
8931+
unsigned Opcode;
89328932
if (Condition != AArch64SME::Always) {
8933+
SDValue ConditionOp = DAG.getTargetConstant(Condition, DL, MVT::i64);
8934+
Opcode = Enable ? AArch64ISD::COND_SMSTART : AArch64ISD::COND_SMSTOP;
89338935
assert(PStateSM && "PStateSM should be defined");
8936+
Ops.push_back(ConditionOp);
89348937
Ops.push_back(PStateSM);
8938+
} else {
8939+
Opcode = Enable ? AArch64ISD::SMSTART : AArch64ISD::SMSTOP;
89358940
}
89368941
Ops.push_back(RegMask);
89378942

89388943
if (InGlue)
89398944
Ops.push_back(InGlue);
89408945

8941-
unsigned Opcode = Enable ? AArch64ISD::SMSTART : AArch64ISD::SMSTOP;
89428946
return DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops);
89438947
}
89448948

@@ -9203,9 +9207,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
92039207

92049208
if (DisableZA)
92059209
Chain = DAG.getNode(
9206-
AArch64ISD::SMSTOP, DL, MVT::Other, Chain,
9207-
DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
9208-
DAG.getConstant(AArch64SME::Always, DL, MVT::i64));
9210+
AArch64ISD::SMSTOP, DL, DAG.getVTList(MVT::Other, MVT::Glue), Chain,
9211+
DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32));
92099212

92109213
// Adjust the stack pointer for the new arguments...
92119214
// These operations are automatically eliminated by the prolog/epilog pass
@@ -9682,9 +9685,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
96829685
if (CallAttrs.requiresEnablingZAAfterCall())
96839686
// Unconditionally resume ZA.
96849687
Result = DAG.getNode(
9685-
AArch64ISD::SMSTART, DL, MVT::Other, Result,
9686-
DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
9687-
DAG.getConstant(AArch64SME::Always, DL, MVT::i64));
9688+
AArch64ISD::SMSTART, DL, DAG.getVTList(MVT::Other, MVT::Glue), Result,
9689+
DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32));
96889690

96899691
if (ShouldPreserveZT0)
96909692
Result =

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ enum NodeType : unsigned {
7373

7474
SMSTART,
7575
SMSTOP,
76+
COND_SMSTART,
77+
COND_SMSTOP,
7678
RESTORE_ZA,
7779
RESTORE_ZT,
7880
SAVE_ZT,

llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,20 @@
1010
//
1111
//===----------------------------------------------------------------------===//
1212

13-
def AArch64_smstart : SDNode<"AArch64ISD::SMSTART", SDTypeProfile<0, 2,
14-
[SDTCisInt<0>, SDTCisInt<0>]>,
13+
def AArch64_smstart : SDNode<"AArch64ISD::SMSTART", SDTypeProfile<0, 1,
14+
[SDTCisInt<0>]>,
1515
[SDNPHasChain, SDNPSideEffect, SDNPVariadic,
1616
SDNPOptInGlue, SDNPOutGlue]>;
17-
def AArch64_smstop : SDNode<"AArch64ISD::SMSTOP", SDTypeProfile<0, 2,
18-
[SDTCisInt<0>, SDTCisInt<0>]>,
17+
def AArch64_smstop : SDNode<"AArch64ISD::SMSTOP", SDTypeProfile<0, 1,
18+
[SDTCisInt<0>]>,
19+
[SDNPHasChain, SDNPSideEffect, SDNPVariadic,
20+
SDNPOptInGlue, SDNPOutGlue]>;
21+
def AArch64_cond_smstart : SDNode<"AArch64ISD::COND_SMSTART", SDTypeProfile<0, 3,
22+
[SDTCisInt<0>, SDTCisInt<1>, SDTCisInt<2>]>,
23+
[SDNPHasChain, SDNPSideEffect, SDNPVariadic,
24+
SDNPOptInGlue, SDNPOutGlue]>;
25+
def AArch64_cond_smstop : SDNode<"AArch64ISD::COND_SMSTOP", SDTypeProfile<0, 3,
26+
[SDTCisInt<0>, SDTCisInt<1>, SDTCisInt<2>]>,
1927
[SDNPHasChain, SDNPSideEffect, SDNPVariadic,
2028
SDNPOptInGlue, SDNPOutGlue]>;
2129
def AArch64_restore_za : SDNode<"AArch64ISD::RESTORE_ZA", SDTypeProfile<0, 3,
@@ -305,15 +313,15 @@ def MSRpstatePseudo :
305313
let Defs = [VG];
306314
}
307315

308-
def : Pat<(AArch64_smstart (i32 svcr_op:$pstate), (i64 timm0_31:$condition)),
309-
(MSRpstatePseudo svcr_op:$pstate, 0b1, timm0_31:$condition)>;
310-
def : Pat<(AArch64_smstop (i32 svcr_op:$pstate), (i64 timm0_31:$condition)),
311-
(MSRpstatePseudo svcr_op:$pstate, 0b0, timm0_31:$condition)>;
316+
def : Pat<(AArch64_cond_smstart (i32 svcr_op:$pstate), (i64 timm0_31:$condition), (i64 GPR64:$pstatesm)),
317+
(MSRpstatePseudo svcr_op:$pstate, 0b1, timm0_31:$condition, GPR64:$pstatesm)>;
318+
def : Pat<(AArch64_cond_smstop (i32 svcr_op:$pstate), (i64 timm0_31:$condition), (i64 GPR64:$pstatesm)),
319+
(MSRpstatePseudo svcr_op:$pstate, 0b0, timm0_31:$condition, GPR64:$pstatesm)>;
312320

313321
// Unconditional start/stop
314-
def : Pat<(AArch64_smstart (i32 svcr_op:$pstate), (i64 /*AArch64SME::Always*/0)),
322+
def : Pat<(AArch64_smstart (i32 svcr_op:$pstate)),
315323
(MSRpstatesvcrImm1 svcr_op:$pstate, 0b1)>;
316-
def : Pat<(AArch64_smstop (i32 svcr_op:$pstate), (i64 /*AArch64SME::Always*/0)),
324+
def : Pat<(AArch64_smstop (i32 svcr_op:$pstate)),
317325
(MSRpstatesvcrImm1 svcr_op:$pstate, 0b0)>;
318326

319327

0 commit comments

Comments
 (0)