Skip to content

Commit df08350

Browse files
authored
[RISCV] Implement foward inserting save/restore FRM instructions. (#77744)
Previously, RISCVInsertReadWriteCSR inserted an FRM swap for any value other than 7 and restored the original value right after the vector instruction. This is inefficient if multiple vector instructions use the same rounding mode if the next vector instruction uses a different explicit rounding mode. This patch implements a local optimization to solve the above problem. We assume the starting rounding mode of the basic block is "dynamic." When iterating through a basic block and encountering an instruction whose rounding mode is not the same as the current rounding mode, we change the current rounding mode and save the current rounding mode if needed. And we may need to restore FRM when encountering function call, inline asm and some uses of FRM. The advanced version of this is to perform cross basic block analysis for the starting rounding mode of each basic block.
1 parent a13b7df commit df08350

File tree

2 files changed

+729
-2
lines changed

2 files changed

+729
-2
lines changed

llvm/lib/Target/RISCV/RISCVInsertReadWriteCSR.cpp

Lines changed: 88 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@ using namespace llvm;
2323
#define DEBUG_TYPE "riscv-insert-read-write-csr"
2424
#define RISCV_INSERT_READ_WRITE_CSR_NAME "RISC-V Insert Read/Write CSR Pass"
2525

26+
static cl::opt<bool>
27+
DisableFRMInsertOpt("riscv-disable-frm-insert-opt", cl::init(false),
28+
cl::Hidden,
29+
cl::desc("Disable optimized frm insertion."));
30+
2631
namespace {
2732

2833
class RISCVInsertReadWriteCSR : public MachineFunctionPass {
@@ -46,6 +51,7 @@ class RISCVInsertReadWriteCSR : public MachineFunctionPass {
4651

4752
private:
4853
bool emitWriteRoundingMode(MachineBasicBlock &MBB);
54+
bool emitWriteRoundingModeOpt(MachineBasicBlock &MBB);
4955
};
5056

5157
} // end anonymous namespace
@@ -55,6 +61,82 @@ char RISCVInsertReadWriteCSR::ID = 0;
5561
INITIALIZE_PASS(RISCVInsertReadWriteCSR, DEBUG_TYPE,
5662
RISCV_INSERT_READ_WRITE_CSR_NAME, false, false)
5763

64+
// TODO: Use more accurate rounding mode at the start of MBB.
65+
bool RISCVInsertReadWriteCSR::emitWriteRoundingModeOpt(MachineBasicBlock &MBB) {
66+
bool Changed = false;
67+
MachineInstr *LastFRMChanger = nullptr;
68+
unsigned CurrentRM = RISCVFPRndMode::DYN;
69+
Register SavedFRM;
70+
71+
for (MachineInstr &MI : MBB) {
72+
if (MI.getOpcode() == RISCV::SwapFRMImm ||
73+
MI.getOpcode() == RISCV::WriteFRMImm) {
74+
CurrentRM = MI.getOperand(0).getImm();
75+
SavedFRM = Register();
76+
continue;
77+
}
78+
79+
if (MI.getOpcode() == RISCV::WriteFRM) {
80+
CurrentRM = RISCVFPRndMode::DYN;
81+
SavedFRM = Register();
82+
continue;
83+
}
84+
85+
if (MI.isCall() || MI.isInlineAsm() || MI.readsRegister(RISCV::FRM)) {
86+
// Restore FRM before unknown operations.
87+
if (SavedFRM.isValid())
88+
BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::WriteFRM))
89+
.addReg(SavedFRM);
90+
CurrentRM = RISCVFPRndMode::DYN;
91+
SavedFRM = Register();
92+
continue;
93+
}
94+
95+
assert(!MI.modifiesRegister(RISCV::FRM) &&
96+
"Expected that MI could not modify FRM.");
97+
98+
int FRMIdx = RISCVII::getFRMOpNum(MI.getDesc());
99+
if (FRMIdx < 0)
100+
continue;
101+
unsigned InstrRM = MI.getOperand(FRMIdx).getImm();
102+
103+
LastFRMChanger = &MI;
104+
105+
// Make MI implicit use FRM.
106+
MI.addOperand(MachineOperand::CreateReg(RISCV::FRM, /*IsDef*/ false,
107+
/*IsImp*/ true));
108+
Changed = true;
109+
110+
// Skip if MI uses same rounding mode as FRM.
111+
if (InstrRM == CurrentRM)
112+
continue;
113+
114+
if (!SavedFRM.isValid()) {
115+
// Save current FRM value to SavedFRM.
116+
MachineRegisterInfo *MRI = &MBB.getParent()->getRegInfo();
117+
SavedFRM = MRI->createVirtualRegister(&RISCV::GPRRegClass);
118+
BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::SwapFRMImm), SavedFRM)
119+
.addImm(InstrRM);
120+
} else {
121+
// Don't need to save current FRM when SavedFRM having value.
122+
BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::WriteFRMImm))
123+
.addImm(InstrRM);
124+
}
125+
CurrentRM = InstrRM;
126+
}
127+
128+
// Restore FRM if needed.
129+
if (SavedFRM.isValid()) {
130+
assert(LastFRMChanger && "Expected valid pointer.");
131+
MachineInstrBuilder MIB =
132+
BuildMI(*MBB.getParent(), {}, TII->get(RISCV::WriteFRM))
133+
.addReg(SavedFRM);
134+
MBB.insertAfter(LastFRMChanger, MIB);
135+
}
136+
137+
return Changed;
138+
}
139+
58140
// This function also swaps frm and restores it when encountering an RVV
59141
// floating point instruction with a static rounding mode.
60142
bool RISCVInsertReadWriteCSR::emitWriteRoundingMode(MachineBasicBlock &MBB) {
@@ -99,8 +181,12 @@ bool RISCVInsertReadWriteCSR::runOnMachineFunction(MachineFunction &MF) {
99181

100182
bool Changed = false;
101183

102-
for (MachineBasicBlock &MBB : MF)
103-
Changed |= emitWriteRoundingMode(MBB);
184+
for (MachineBasicBlock &MBB : MF) {
185+
if (DisableFRMInsertOpt)
186+
Changed |= emitWriteRoundingMode(MBB);
187+
else
188+
Changed |= emitWriteRoundingModeOpt(MBB);
189+
}
104190

105191
return Changed;
106192
}

0 commit comments

Comments
 (0)