[X86] X86LowerTileCopy: Find dead register to use to prevent save-reload of tile register #83628

Merged · 1 commit · Apr 21, 2024
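
Before this change, the pass always used RAX as the stride register for a lowered tile copy and spilled/reloaded RAX around every copy. The patch instead walks each block backward with LiveRegUnits and picks a GR64 register that is already dead at the copy, falling back to the RAX save/reload only when nothing is free. A minimal standalone sketch of that scan, assuming LLVM's LiveRegUnits API exactly as the diff below uses it (the helper name findDeadGR64 is illustrative, not part of the patch):

  #include "llvm/ADT/BitVector.h"
  #include "llvm/ADT/STLExtras.h"
  #include "llvm/CodeGen/LiveRegUnits.h"
  #include "llvm/CodeGen/MachineBasicBlock.h"
  #include "llvm/CodeGen/MachineInstr.h"
  #include "llvm/CodeGen/TargetRegisterInfo.h"

  using namespace llvm;

  // Return an allocatable GR64 register with no register unit live
  // immediately before instruction At in MBB, or Register() if every
  // candidate is in use. GR64Regs is the allocatable-GR64 BitVector the
  // pass builds once per function.
  static Register findDeadGR64(const MachineBasicBlock &MBB,
                               const MachineInstr &At,
                               const TargetRegisterInfo &TRI,
                               const BitVector &GR64Regs) {
    LiveRegUnits UsedRegs(TRI);
    UsedRegs.addLiveOuts(MBB); // seed with the block's live-out registers
    for (const MachineInstr &MI : reverse(MBB)) {
      UsedRegs.stepBackward(MI); // clear defs, then mark uses as live
      if (&MI == &At)
        break;
    }
    for (unsigned Reg : GR64Regs.set_bits())
      if (UsedRegs.available(Reg)) // no aliasing register unit is live here
        return Reg;
    return Register(); // nothing free: caller saves/reloads RAX instead
  }

The patch itself fuses this scan into a single reverse walk per block rather than rescanning for every copy, so each block is still processed once.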
61 changes: 42 additions & 19 deletions llvm/lib/Target/X86/X86LowerTileCopy.cpp
@@ -20,6 +20,7 @@
 #include "X86InstrBuilder.h"
 #include "X86InstrInfo.h"
 #include "X86Subtarget.h"
+#include "llvm/CodeGen/LiveRegUnits.h"
 #include "llvm/CodeGen/MachineBasicBlock.h"
 #include "llvm/CodeGen/MachineFrameInfo.h"
 #include "llvm/CodeGen/MachineFunction.h"
@@ -72,10 +73,16 @@ FunctionPass *llvm::createX86LowerTileCopyPass() {
 bool X86LowerTileCopy::runOnMachineFunction(MachineFunction &MF) {
   const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>();
   const X86InstrInfo *TII = ST.getInstrInfo();
+  const TargetRegisterInfo *TRI = ST.getRegisterInfo();
+  BitVector GR64Regs =
+      TRI->getAllocatableSet(MF, TRI->getRegClass(X86::GR64RegClassID));
   bool Changed = false;

   for (MachineBasicBlock &MBB : MF) {
-    for (MachineInstr &MI : llvm::make_early_inc_range(MBB)) {
+    LiveRegUnits UsedRegs(*TRI);
+    UsedRegs.addLiveOuts(MBB);
@phoebewang (Contributor), Apr 21, 2024:
I think we can early out the loop by checking tile registers, e.g.,

  BitVector TILERegs =
      TRI->getAllocatableSet(MF, TRI->getRegClass(X86::TILERegClassID));
  bool Changed = false;

  for (MachineBasicBlock &MBB : MF) {
    LiveRegUnits UsedRegs(*TRI);
    UsedRegs.addLiveOuts(MBB);
    // There won't be a tile copy if no tile register is live out.
    if (!TILERegs.anyCommon(UsedRegs.getBitVector()))
      continue;

Contributor:
Should we add if (!ST.hasAMXTILE()) return false at line 75?

Contributor:
Yes, that's a good idea.
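
A minimal sketch of the combined suggestions, assuming the standard X86Subtarget::hasAMXTILE() feature predicate; the elided body stands in for the lowering loop in the diff:

  bool X86LowerTileCopy::runOnMachineFunction(MachineFunction &MF) {
    const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>();
    // Functions compiled for targets without AMX-TILE cannot contain tile
    // registers, so bail out before doing any per-block liveness work.
    if (!ST.hasAMXTILE())
      return false;

    bool Changed = false;
    // ... per-block LiveRegUnits scan and tile-copy lowering as in the diff ...
    return Changed;
  }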

+    for (MachineInstr &MI : llvm::make_early_inc_range(reverse(MBB))) {
+      UsedRegs.stepBackward(MI);
       if (!MI.isCopy())
         continue;
       MachineOperand &DstMO = MI.getOperand(0);
@@ -85,27 +92,41 @@ bool X86LowerTileCopy::runOnMachineFunction(MachineFunction &MF) {
       if (!X86::TILERegClass.contains(DstReg, SrcReg))
         continue;

-      const TargetRegisterInfo *TRI = ST.getRegisterInfo();
       // Allocate stack slot for tile register
       unsigned Size = TRI->getSpillSize(X86::TILERegClass);
       Align Alignment = TRI->getSpillAlign(X86::TILERegClass);
       int TileSS = MF.getFrameInfo().CreateSpillStackObject(Size, Alignment);
-      // Allocate stack slot for stride register
-      Size = TRI->getSpillSize(X86::GR64RegClass);
-      Alignment = TRI->getSpillAlign(X86::GR64RegClass);
-      int StrideSS = MF.getFrameInfo().CreateSpillStackObject(Size, Alignment);

-      // TODO: Pick a killed regiter to avoid save/reload. There is problem
-      // to get live interval in this stage.
-      Register GR64Cand = X86::RAX;
+      int StrideSS = 0;
+
+      // Pick a killed register to avoid a save/reload.
+      Register GR64Cand = X86::NoRegister;
+      for (auto RegT : GR64Regs.set_bits()) {
+        if (UsedRegs.available(RegT)) {
+          GR64Cand = RegT;
+          break;
+        }
+      }

       const DebugLoc &DL = MI.getDebugLoc();
-      // mov %rax (%sp)
-      BuildMI(MBB, MI, DL, TII->get(X86::IMPLICIT_DEF), GR64Cand);
-      addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOV64mr)), StrideSS)
-          .addReg(GR64Cand);
-      // mov 64 %rax
-      BuildMI(MBB, MI, DL, TII->get(X86::MOV64ri), GR64Cand).addImm(64);
+      if (GR64Cand) {
+        // mov 64 %reg
+        BuildMI(MBB, MI, DL, TII->get(X86::MOV64ri), GR64Cand).addImm(64);
+      } else {
+        // No available register? Save RAX and reload it after use.
+
+        // Allocate stack slot for stride register
+        Size = TRI->getSpillSize(X86::GR64RegClass);
+        Alignment = TRI->getSpillAlign(X86::GR64RegClass);
+        StrideSS = MF.getFrameInfo().CreateSpillStackObject(Size, Alignment);
+
+        // mov %reg (%sp)
+        addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOV64mr)),
+                          StrideSS)
+            .addReg(X86::RAX);
+        // mov 64 %reg
+        BuildMI(MBB, MI, DL, TII->get(X86::MOV64ri), X86::RAX).addImm(64);
+      }
       // tilestored %tmm, (%sp, %idx)
 #define GET_EGPR_IF_ENABLED(OPC) (ST.hasEGPR() ? OPC##_EVEX : OPC)
       unsigned Opc = GET_EGPR_IF_ENABLED(X86::TILESTORED);
@@ -120,10 +141,12 @@ bool X86LowerTileCopy::runOnMachineFunction(MachineFunction &MF) {
 #undef GET_EGPR_IF_ENABLED
       NewMI = addFrameReference(BuildMI(MBB, MI, DL, TII->get(Opc), DstReg),
                                 TileSS);
-      // restore %rax
-      // mov (%sp) %rax
-      addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOV64rm), GR64Cand),
-                        StrideSS);
+      if (!GR64Cand) {
+        // restore %rax
+        // mov (%sp) %rax
+        addFrameReference(
+            BuildMI(MBB, MI, DL, TII->get(X86::MOV64rm), GR64Cand), StrideSS);
+      }
       MI.eraseFromParent();
       Changed = true;
     }
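The collapsed context between the last two hunks is where the chosen register is actually consumed: the pass patches it into the index operand of the tilestored/tileloadd frame references it just built. A hedged sketch of that step, assuming the X86 memory-operand layout (X86::AddrIndexReg from MCTargetDesc/X86BaseInfo.h) and the GET_EGPR_IF_ENABLED macro defined in the diff; variable names follow the diff, but the statements are paraphrased, not quoted from the patch:

  // Spill the source tile: tilestored %tmmS, TileSS(%rsp, %stride).
  // The stride register holds 64, the row pitch of the spill slot.
  Register Stride = GR64Cand ? GR64Cand : Register(X86::RAX);
  unsigned Opc = GET_EGPR_IF_ENABLED(X86::TILESTORED);
  MachineInstr *Store =
      addFrameReference(BuildMI(MBB, MI, DL, TII->get(Opc)), TileSS)
          .addReg(SrcReg, getKillRegState(SrcMO.isKill()));
  // addFrameReference() filled in (base, scale, index, disp, segment);
  // point the index slot at the stride register.
  Store->getOperand(X86::AddrIndexReg).setReg(Stride);

  // Reload into the destination: tileloadd TileSS(%rsp, %stride), %tmmD.
  Opc = GET_EGPR_IF_ENABLED(X86::TILELOADD);
  MachineInstr *Load =
      addFrameReference(BuildMI(MBB, MI, DL, TII->get(Opc), DstReg), TileSS);
  // Operand 0 is the tile def, so the memory operands start at index 1.
  Load->getOperand(1 + X86::AddrIndexReg).setReg(Stride);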
10 changes: 0 additions & 10 deletions llvm/test/CodeGen/X86/AMX/amx-lower-tile-copy.ll
@@ -44,12 +44,8 @@ define dso_local void @test1(ptr%buf) nounwind {
 ; CHECK-NEXT: tileloadd 3024(%rsp,%rax), %tmm3 # 1024-byte Folded Reload
 ; CHECK-NEXT: tileloadd (%rbx,%r15), %tmm0
 ; CHECK-NEXT: tileloadd (%rbx,%r15), %tmm1
-; CHECK-NEXT: # implicit-def: $rax
-; CHECK-NEXT: movq %rax, {{[-0-9]+}}(%r{{[sb]}}p) # 8-byte Spill
-; CHECK-NEXT: movabsq $64, %rax
 ; CHECK-NEXT: tilestored %tmm3, 1024(%rsp,%rax) # 1024-byte Folded Spill
 ; CHECK-NEXT: tileloadd {{[-0-9]+}}(%r{{[sb]}}p), %tmm2 # 1024-byte Folded Reload
-; CHECK-NEXT: movq {{[-0-9]+}}(%r{{[sb]}}p), %rax # 8-byte Reload
 ; CHECK-NEXT: tdpbssd %tmm1, %tmm0, %tmm2
 ; CHECK-NEXT: tilestored %tmm2, (%rbx,%r15)
 ; CHECK-NEXT: incl %r14d
@@ -111,16 +107,10 @@ define dso_local void @test1(ptr%buf) nounwind {
 ; EGPR-NEXT: # EVEX TO VEX Compression encoding: [0xc4,0xe2,0x7b,0x4b,0x9c,0x04,0xd0,0x0b,0x00,0x00]
 ; EGPR-NEXT: tileloadd (%rbx,%r15), %tmm0 # EVEX TO VEX Compression encoding: [0xc4,0xa2,0x7b,0x4b,0x04,0x3b]
 ; EGPR-NEXT: tileloadd (%rbx,%r15), %tmm1 # EVEX TO VEX Compression encoding: [0xc4,0xa2,0x7b,0x4b,0x0c,0x3b]
-; EGPR-NEXT: # implicit-def: $rax
-; EGPR-NEXT: movq %rax, {{[-0-9]+}}(%r{{[sb]}}p) # 8-byte Spill
-; EGPR-NEXT: # encoding: [0x48,0x89,0x84,0x24,0xb8,0x03,0x00,0x00]
-; EGPR-NEXT: movabsq $64, %rax # encoding: [0x48,0xb8,0x40,0x00,0x00,0x00,0x00,0x00,0x00,0x00]
 ; EGPR-NEXT: tilestored %tmm3, 1024(%rsp,%rax) # 1024-byte Folded Spill
 ; EGPR-NEXT: # EVEX TO VEX Compression encoding: [0xc4,0xe2,0x7a,0x4b,0x9c,0x04,0x00,0x04,0x00,0x00]
 ; EGPR-NEXT: tileloadd {{[-0-9]+}}(%r{{[sb]}}p), %tmm2 # 1024-byte Folded Reload
 ; EGPR-NEXT: # EVEX TO VEX Compression encoding: [0xc4,0xe2,0x7b,0x4b,0x94,0x24,0x00,0x04,0x00,0x00]
-; EGPR-NEXT: movq {{[-0-9]+}}(%r{{[sb]}}p), %rax # 8-byte Reload
-; EGPR-NEXT: # encoding: [0x48,0x8b,0x84,0x24,0xb8,0x03,0x00,0x00]
 ; EGPR-NEXT: tdpbssd %tmm1, %tmm0, %tmm2 # encoding: [0xc4,0xe2,0x73,0x5e,0xd0]
 ; EGPR-NEXT: tilestored %tmm2, (%rbx,%r15) # EVEX TO VEX Compression encoding: [0xc4,0xa2,0x7a,0x4b,0x14,0x3b]
 ; EGPR-NEXT: incl %r14d # encoding: [0x41,0xff,0xc6]