[AArch64][SME] Spill p-regs as z-regs when streaming hazards are possible #123752

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged · 8 commits · Feb 3, 2025
313 changes: 308 additions & 5 deletions llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
@@ -1634,6 +1634,9 @@ static bool IsSVECalleeSave(MachineBasicBlock::iterator I) {
case AArch64::STR_PXI:
case AArch64::LDR_ZXI:
case AArch64::LDR_PXI:
case AArch64::PTRUE_B:
case AArch64::CPY_ZPzI_B:
case AArch64::CMPNE_PPzZI_B:
return I->getFlag(MachineInstr::FrameSetup) ||
I->getFlag(MachineInstr::FrameDestroy);
}
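Note: these three opcodes are what the new predicate spill/fill pseudos expand to, so when that expansion happens in the prologue/epilogue, the resulting instructions must still be recognized as part of the SVE callee-save sequence. A rough sketch of the frame-setup-flagged MIR this now matches (registers and offsets illustrative, not from the patch):
```
$z0 = frame-setup CPY_ZPzI_B $p4, 1, 0
frame-setup STR_ZXI $z0, $sp, 7
```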
@@ -3265,7 +3268,8 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
StrOpc = RPI.isPaired() ? AArch64::ST1B_2Z_IMM : AArch64::STR_ZXI;
break;
case RegPairInfo::PPR:
-      StrOpc = AArch64::STR_PXI;
+      StrOpc =
+          Size == 16 ? AArch64::SPILL_PPR_TO_ZPR_SLOT_PSEUDO : AArch64::STR_PXI;
break;
case RegPairInfo::VG:
StrOpc = AArch64::STRXui;
@@ -3494,7 +3498,8 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
LdrOpc = RPI.isPaired() ? AArch64::LD1B_2Z_IMM : AArch64::LDR_ZXI;
break;
case RegPairInfo::PPR:
-      LdrOpc = AArch64::LDR_PXI;
+      LdrOpc = Size == 16 ? AArch64::FILL_PPR_FROM_ZPR_SLOT_PSEUDO
+                          : AArch64::LDR_PXI;
break;
case RegPairInfo::VG:
continue;
@@ -3720,6 +3725,14 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF,
continue;
}

// Always save P4 when PPR spills are ZPR-sized and a predicate in p8-p15 is
// spilled. If all of p0-p3 are used as return values, p4 must be free to
// reload p8-p15.
if (RegInfo->getSpillSize(AArch64::PPRRegClass) == 16 &&
AArch64::PPR_p8to15RegClass.contains(Reg)) {
SavedRegs.set(AArch64::P4);
}
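Why p4 specifically: restoring a predicate from a ZPR-sized slot goes through CMPNE, whose governing predicate must be one of p0-p7 (see the fill expansion further down). If p0-p3 all hold return values, p4 is the first predicate that can legally be freed. A sketch of that worst case (registers and stack slot hypothetical):
```
$z0 = LDR_ZXI %stack.0, 0
$p4 = PTRUE_B 31
$p8 = CMPNE_PPzZI_B $p4, $z0, 0, implicit-def $nzcv
```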

// MachO's compact unwind format relies on all registers being stored in
// pairs.
// FIXME: the usual format is actually better if unwinding isn't needed.
@@ -4159,8 +4172,295 @@ int64_t AArch64FrameLowering::assignSVEStackObjectOffsets(
true);
}

/// Attempts to scavenge a register from \p ScavengeableRegs given the used
/// registers in \p UsedRegs.
static Register tryScavengeRegister(LiveRegUnits const &UsedRegs,
BitVector const &ScavengeableRegs) {
for (auto Reg : ScavengeableRegs.set_bits()) {
if (UsedRegs.available(Reg))
return Reg;
}
return AArch64::NoRegister;
}

/// Propagates frame-setup/destroy flags from \p SourceMI to all instructions in
/// \p MachineInstrs.
static void propagateFrameFlags(MachineInstr &SourceMI,
ArrayRef<MachineInstr *> MachineInstrs) {
for (MachineInstr *MI : MachineInstrs) {
if (SourceMI.getFlag(MachineInstr::FrameSetup))
MI->setFlag(MachineInstr::FrameSetup);
if (SourceMI.getFlag(MachineInstr::FrameDestroy))
MI->setFlag(MachineInstr::FrameDestroy);
}
}

/// RAII helper class for scavenging or spilling a register. On construction,
/// attempts to find a free register of class \p RC (given \p UsedRegs and \p
/// AllocatableRegs); if no register can be found, spills \p SpillCandidate to
/// \p MaybeSpillFI to free a register. The freed register is returned via the
/// \p FreeReg output parameter. On destruction, if there was a spill, its
/// previous value is reloaded. The spilling and scavenging is only valid at
/// the insertion point \p MBBI; this class should _not_ be used in places that
/// create or manipulate basic blocks, moving the expected insertion point.
struct ScopedScavengeOrSpill {
ScopedScavengeOrSpill(const ScopedScavengeOrSpill &) = delete;
ScopedScavengeOrSpill(ScopedScavengeOrSpill &&) = delete;

ScopedScavengeOrSpill(MachineFunction &MF, MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI,
Register SpillCandidate, const TargetRegisterClass &RC,
LiveRegUnits const &UsedRegs,
BitVector const &AllocatableRegs,
std::optional<int> *MaybeSpillFI)
: MBB(MBB), MBBI(MBBI), RC(RC), TII(static_cast<const AArch64InstrInfo &>(
*MF.getSubtarget().getInstrInfo())),
TRI(*MF.getSubtarget().getRegisterInfo()) {
FreeReg = tryScavengeRegister(UsedRegs, AllocatableRegs);
if (FreeReg != AArch64::NoRegister)
return;
assert(MaybeSpillFI && "Expected emergency spill slot FI information "
"(attempted to spill in prologue/epilogue?)");
if (!MaybeSpillFI->has_value()) {
MachineFrameInfo &MFI = MF.getFrameInfo();
*MaybeSpillFI = MFI.CreateSpillStackObject(TRI.getSpillSize(RC),
TRI.getSpillAlign(RC));
}
FreeReg = SpillCandidate;
SpillFI = MaybeSpillFI->value();
TII.storeRegToStackSlot(MBB, MBBI, FreeReg, false, *SpillFI, &RC, &TRI,
Register());
}

bool hasSpilled() const { return SpillFI.has_value(); }

/// Returns the free register (found from scavenging or spilling a register).
Register freeRegister() const { return FreeReg; }

Register operator*() const { return freeRegister(); }

~ScopedScavengeOrSpill() {
if (hasSpilled())
TII.loadRegFromStackSlot(MBB, MBBI, FreeReg, *SpillFI, &RC, &TRI,
Register());
}

private:
MachineBasicBlock &MBB;
MachineBasicBlock::iterator MBBI;
const TargetRegisterClass &RC;
const AArch64InstrInfo &TII;
const TargetRegisterInfo &TRI;
Register FreeReg = AArch64::NoRegister;
std::optional<int> SpillFI;
};

/// Emergency stack slots for expanding SPILL_PPR_TO_ZPR_SLOT_PSEUDO and
/// FILL_PPR_FROM_ZPR_SLOT_PSEUDO.
struct EmergencyStackSlots {
std::optional<int> ZPRSpillFI;
std::optional<int> PPRSpillFI;
std::optional<int> GPRSpillFI;
};

/// Registers available for scavenging (ZPR, PPR3b, GPR).
struct ScavengeableRegs {
BitVector ZPRRegs;
BitVector PPR3bRegs;
BitVector GPRRegs;
};

static bool isInPrologueOrEpilogue(const MachineInstr &MI) {
return MI.getFlag(MachineInstr::FrameSetup) ||
MI.getFlag(MachineInstr::FrameDestroy);
}

/// Expands:
/// ```
/// SPILL_PPR_TO_ZPR_SLOT_PSEUDO $p0, %stack.0, 0
/// ```
/// To:
/// ```
/// $z0 = CPY_ZPzI_B $p0, 1, 0
/// STR_ZXI $z0, %stack.0, 0
/// ```
/// While ensuring a ZPR ($z0 in this example) is free for the predicate
/// (spilling if necessary).
static void expandSpillPPRToZPRSlotPseudo(MachineBasicBlock &MBB,
MachineInstr &MI,
const TargetRegisterInfo &TRI,
LiveRegUnits const &UsedRegs,
ScavengeableRegs const &SR,
EmergencyStackSlots &SpillSlots) {
MachineFunction &MF = *MBB.getParent();
auto *TII =
static_cast<const AArch64InstrInfo *>(MF.getSubtarget().getInstrInfo());

ScopedScavengeOrSpill ZPredReg(
MF, MBB, MI, AArch64::Z0, AArch64::ZPRRegClass, UsedRegs, SR.ZPRRegs,
isInPrologueOrEpilogue(MI) ? nullptr : &SpillSlots.ZPRSpillFI);

SmallVector<MachineInstr *, 2> MachineInstrs;
const DebugLoc &DL = MI.getDebugLoc();
MachineInstrs.push_back(BuildMI(MBB, MI, DL, TII->get(AArch64::CPY_ZPzI_B))
.addReg(*ZPredReg, RegState::Define)
.add(MI.getOperand(0))
.addImm(1)
.addImm(0)
.getInstr());
MachineInstrs.push_back(BuildMI(MBB, MI, DL, TII->get(AArch64::STR_ZXI))
.addReg(*ZPredReg)
.add(MI.getOperand(1))
.addImm(MI.getOperand(2).getImm())
.setMemRefs(MI.memoperands())
.getInstr());
propagateFrameFlags(MI, MachineInstrs);
}
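If no ZPR is free at the expansion point, ScopedScavengeOrSpill brackets the expansion with an emergency spill and reload of the candidate register. A sketch of the resulting MIR, assuming $z0 had to be saved to an emergency slot %stack.1 (register and slot hypothetical):
```
STR_ZXI $z0, %stack.1, 0      ; emergency spill of $z0
$z0 = CPY_ZPzI_B $p0, 1, 0    ; expanded predicate-to-vector copy
STR_ZXI $z0, %stack.0, 0      ; store to the ZPR-sized predicate slot
$z0 = LDR_ZXI %stack.1, 0     ; emergency reload of $z0
```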

/// Expands:
/// ```
/// $p0 = FILL_PPR_FROM_ZPR_SLOT_PSEUDO %stack.0, 0
/// ```
/// To:
/// ```
/// $z0 = LDR_ZXI %stack.0, 0
/// $p0 = PTRUE_B 31, implicit $vg
/// $p0 = CMPNE_PPzZI_B $p0, $z0, 0, implicit-def $nzcv
/// ```
/// While ensuring a ZPR ($z0 in this example) is free for the predicate
/// (spilling if necessary). If the status flags are in use at the point of
/// expansion, they are preserved (by moving them to/from a GPR). This may
/// cause an additional spill if no GPR is free at the expansion point.
static bool expandFillPPRFromZPRSlotPseudo(MachineBasicBlock &MBB,
MachineInstr &MI,
const TargetRegisterInfo &TRI,
LiveRegUnits const &UsedRegs,
ScavengeableRegs const &SR,
EmergencyStackSlots &SpillSlots) {
MachineFunction &MF = *MBB.getParent();
auto *TII =
static_cast<const AArch64InstrInfo *>(MF.getSubtarget().getInstrInfo());

ScopedScavengeOrSpill ZPredReg(
MF, MBB, MI, AArch64::Z0, AArch64::ZPRRegClass, UsedRegs, SR.ZPRRegs,
isInPrologueOrEpilogue(MI) ? nullptr : &SpillSlots.ZPRSpillFI);

ScopedScavengeOrSpill PredReg(
MF, MBB, MI, AArch64::P0, AArch64::PPR_3bRegClass, UsedRegs, SR.PPR3bRegs,
isInPrologueOrEpilogue(MI) ? nullptr : &SpillSlots.PPRSpillFI);

// Elide NZCV spills if we know it is not used.
bool IsNZCVUsed = !UsedRegs.available(AArch64::NZCV);
std::optional<ScopedScavengeOrSpill> NZCVSaveReg;
if (IsNZCVUsed)
NZCVSaveReg.emplace(
MF, MBB, MI, AArch64::X0, AArch64::GPR64RegClass, UsedRegs, SR.GPRRegs,
isInPrologueOrEpilogue(MI) ? nullptr : &SpillSlots.GPRSpillFI);
SmallVector<MachineInstr *, 4> MachineInstrs;
const DebugLoc &DL = MI.getDebugLoc();
MachineInstrs.push_back(BuildMI(MBB, MI, DL, TII->get(AArch64::LDR_ZXI))
.addReg(*ZPredReg, RegState::Define)
.add(MI.getOperand(1))
.addImm(MI.getOperand(2).getImm())
.setMemRefs(MI.memoperands())
.getInstr());
if (IsNZCVUsed)
MachineInstrs.push_back(
BuildMI(MBB, MI, DL, TII->get(AArch64::MRS))
.addReg(NZCVSaveReg->freeRegister(), RegState::Define)
.addImm(AArch64SysReg::NZCV)
.addReg(AArch64::NZCV, RegState::Implicit)
.getInstr());
MachineInstrs.push_back(BuildMI(MBB, MI, DL, TII->get(AArch64::PTRUE_B))
.addReg(*PredReg, RegState::Define)
.addImm(31));
MachineInstrs.push_back(
BuildMI(MBB, MI, DL, TII->get(AArch64::CMPNE_PPzZI_B))
.addReg(MI.getOperand(0).getReg(), RegState::Define)
.addReg(*PredReg)
.addReg(*ZPredReg)
.addImm(0)
.addReg(AArch64::NZCV, RegState::ImplicitDefine)
.getInstr());
if (IsNZCVUsed)
MachineInstrs.push_back(BuildMI(MBB, MI, DL, TII->get(AArch64::MSR))
.addImm(AArch64SysReg::NZCV)
.addReg(NZCVSaveReg->freeRegister())
.addReg(AArch64::NZCV, RegState::ImplicitDefine)
.getInstr());

propagateFrameFlags(MI, MachineInstrs);
return PredReg.hasSpilled();
}
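When NZCV is live across the fill, the CMPNE (which clobbers the flags) is wrapped in an MRS/MSR pair through a scavenged GPR. A sketch of the full expansion, assuming $x8 was the free GPR (register choice hypothetical):
```
$z0 = LDR_ZXI %stack.0, 0
$x8 = MRS NZCV, implicit $nzcv                       ; save flags
$p0 = PTRUE_B 31
$p0 = CMPNE_PPzZI_B $p0, $z0, 0, implicit-def $nzcv
MSR NZCV, $x8, implicit-def $nzcv                    ; restore flags
```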

/// Expands all FILL_PPR_FROM_ZPR_SLOT_PSEUDO and SPILL_PPR_TO_ZPR_SLOT_PSEUDO
/// operations within the MachineBasicBlock \p MBB.
static bool expandSMEPPRToZPRSpillPseudos(MachineBasicBlock &MBB,
const TargetRegisterInfo &TRI,
ScavengeableRegs const &SR,
EmergencyStackSlots &SpillSlots) {
LiveRegUnits UsedRegs(TRI);
UsedRegs.addLiveOuts(MBB);
bool HasPPRSpills = false;
for (MachineInstr &MI : make_early_inc_range(reverse(MBB))) {
UsedRegs.stepBackward(MI);
switch (MI.getOpcode()) {
case AArch64::FILL_PPR_FROM_ZPR_SLOT_PSEUDO:
HasPPRSpills |= expandFillPPRFromZPRSlotPseudo(MBB, MI, TRI, UsedRegs, SR,
SpillSlots);
MI.eraseFromParent();
break;
case AArch64::SPILL_PPR_TO_ZPR_SLOT_PSEUDO:
expandSpillPPRToZPRSlotPseudo(MBB, MI, TRI, UsedRegs, SR, SpillSlots);
MI.eraseFromParent();
break;
default:
break;
}
}

return HasPPRSpills;
}

void AArch64FrameLowering::processFunctionBeforeFrameFinalized(
MachineFunction &MF, RegScavenger *RS) const {

AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
const TargetSubtargetInfo &TSI = MF.getSubtarget();
const TargetRegisterInfo &TRI = *TSI.getRegisterInfo();

// If predicate spills are 16 bytes, we may need to expand
// SPILL_PPR_TO_ZPR_SLOT_PSEUDO/FILL_PPR_FROM_ZPR_SLOT_PSEUDO.
if (AFI->hasStackFrame() && TRI.getSpillSize(AArch64::PPRRegClass) == 16) {
auto ComputeScavengeableRegisters = [&](unsigned RegClassID) {
BitVector Regs = TRI.getAllocatableSet(MF, TRI.getRegClass(RegClassID));
assert(Regs.count() > 0 && "Expected scavengeable registers");
return Regs;
};

ScavengeableRegs SR{};
SR.ZPRRegs = ComputeScavengeableRegisters(AArch64::ZPRRegClassID);
// Only p0-p7 are possible as the second operand of cmpne (needed for fills).
SR.PPR3bRegs = ComputeScavengeableRegisters(AArch64::PPR_3bRegClassID);
SR.GPRRegs = ComputeScavengeableRegisters(AArch64::GPR64RegClassID);

EmergencyStackSlots SpillSlots;
for (MachineBasicBlock &MBB : MF) {
// In the case we had to spill a predicate (in the range p0-p7) to reload
// a predicate (>= p8), additional spill/fill pseudos will be created.
// These need an additional expansion pass. Note: There will only be at
// most two expansion passes, as spilling/filling a predicate in the range
// p0-p7 never requires spilling another predicate.
for (int Pass = 0; Pass < 2; Pass++) {
bool HasPPRSpills =
expandSMEPPRToZPRSpillPseudos(MBB, TRI, SR, SpillSlots);
assert((Pass == 0 || !HasPPRSpills) && "Did not expect PPR spills");
if (!HasPPRSpills)
break;
}
}
}
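The two-pass bound follows from the scavenging order: pass one can create new pseudos only when expanding a fill of p8-p15 forces a spill of a predicate in p0-p7 (such as the p4 reserved in determineCalleeSaves), and expanding a p0-p7 spill/fill never needs to spill another predicate, so pass two always reaches a fixed point. A sketch of what pass one may leave behind (registers and slots hypothetical):
```
SPILL_PPR_TO_ZPR_SLOT_PSEUDO $p4, %stack.2, 0   ; freed $p4 for the fill below
$z0 = LDR_ZXI %stack.0, 0
$p4 = PTRUE_B 31
$p8 = CMPNE_PPzZI_B $p4, $z0, 0, implicit-def $nzcv
$p4 = FILL_PPR_FROM_ZPR_SLOT_PSEUDO %stack.2, 0 ; reload $p4 afterwards
```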

MachineFrameInfo &MFI = MF.getFrameInfo();

assert(getStackGrowthDirection() == TargetFrameLowering::StackGrowsDown &&
@@ -4170,7 +4470,6 @@
int64_t SVEStackSize =
assignSVEStackObjectOffsets(MFI, MinCSFrameIndex, MaxCSFrameIndex);

-  AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
AFI->setStackSizeSVE(alignTo(SVEStackSize, 16U));
AFI->setMinMaxSVECSFrameIndex(MinCSFrameIndex, MaxCSFrameIndex);

@@ -5204,9 +5503,13 @@ void AArch64FrameLowering::emitRemarks(

unsigned RegTy = StackAccess::AccessType::GPR;
if (MFI.getStackID(FrameIdx) == TargetStackID::ScalableVector) {
-      if (AArch64::PPRRegClass.contains(MI.getOperand(0).getReg()))
+      // SPILL_PPR_TO_ZPR_SLOT_PSEUDO and FILL_PPR_FROM_ZPR_SLOT_PSEUDO
+      // spill/fill the predicate as a data vector (so are an FPR access).
+      if (MI.getOpcode() != AArch64::SPILL_PPR_TO_ZPR_SLOT_PSEUDO &&
+          MI.getOpcode() != AArch64::FILL_PPR_FROM_ZPR_SLOT_PSEUDO &&
+          AArch64::PPRRegClass.contains(MI.getOperand(0).getReg())) {
       RegTy = StackAccess::PPR;
-      else
+      } else
RegTy = StackAccess::FPR;
} else if (AArch64InstrInfo::isFpOrNEON(MI)) {
RegTy = StackAccess::FPR;