Skip to content

[RISCV] Add 16 bit GPR sub-register for Zhinx. #107446

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,13 @@ struct RISCVOperand final : public MCParsedAsmOperand {
RISCVMCRegisterClasses[RISCV::GPRRegClassID].contains(Reg.RegNum);
}

/// Returns true if this operand is a register that belongs to the GPRF16
/// register class.
bool isGPRF16() const {
  // Only register operands carry a valid Reg payload.
  if (Kind != KindTy::Register)
    return false;
  const auto &GPRF16Class = RISCVMCRegisterClasses[RISCV::GPRF16RegClassID];
  return GPRF16Class.contains(Reg.RegNum);
}

/// Returns true if this operand is a GPR whose IsGPRAsFPR flag is set.
bool isGPRAsFPR() const {
  // Check the register class first; Reg is only read once isGPR() holds.
  if (!isGPR())
    return false;
  return Reg.IsGPRAsFPR;
}
/// Returns true if this operand is a GPRF16 register whose IsGPRAsFPR flag is
/// set.
bool isGPRAsFPR16() const {
  // Check the register class first; Reg is only read once isGPRF16() holds.
  if (!isGPRF16())
    return false;
  return Reg.IsGPRAsFPR;
}

bool isGPRPair() const {
return Kind == KindTy::Register &&
Expand Down Expand Up @@ -1341,6 +1347,10 @@ unsigned RISCVAsmParser::validateTargetOperandClass(MCParsedAsmOperand &AsmOp,
Op.Reg.RegNum = convertFPR64ToFPR16(Reg);
return Match_Success;
}
if (Kind == MCK_GPRAsFPR16 && Op.isGPRAsFPR()) {
Op.Reg.RegNum = Reg - RISCV::X0 + RISCV::X0_H;
return Match_Success;
}
// As the parser couldn't differentiate a VRM2/VRM4/VRM8 from a VR, coerce
// the register from VR to VRM2/VRM4/VRM8 if necessary.
if (IsRegVR && (Kind == MCK_VRM2 || Kind == MCK_VRM4 || Kind == MCK_VRM8)) {
Expand Down
13 changes: 13 additions & 0 deletions llvm/lib/Target/RISCV/Disassembler/RISCVDisassembler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,19 @@ static DecodeStatus DecodeGPRRegisterClass(MCInst &Inst, uint32_t RegNo,
return MCDisassembler::Success;
}

/// Decodes register number \p RegNo as the matching 16-bit GPR sub-register
/// (X0_H + RegNo) and appends it to \p Inst as a register operand.
///
/// Fails if \p RegNo is outside the register file: 32 registers normally, 16
/// when the E extension is enabled.
static DecodeStatus DecodeGPRF16RegisterClass(MCInst &Inst, uint32_t RegNo,
                                              uint64_t Address,
                                              const MCDisassembler *Decoder) {
  const bool HasRVE =
      Decoder->getSubtargetInfo().hasFeature(RISCV::FeatureStdExtE);
  const uint32_t NumGPRs = HasRVE ? 16 : 32;

  if (RegNo >= NumGPRs)
    return MCDisassembler::Fail;

  MCRegister HalfReg = RISCV::X0_H + RegNo;
  Inst.addOperand(MCOperand::createReg(HalfReg));
  return MCDisassembler::Success;
}

static DecodeStatus DecodeGPRX1X5RegisterClass(MCInst &Inst, uint32_t RegNo,
uint64_t Address,
const MCDisassembler *Decoder) {
Expand Down
55 changes: 53 additions & 2 deletions llvm/lib/Target/RISCV/RISCVCallingConv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,23 @@ ArrayRef<MCPhysReg> RISCV::getArgGPRs(const RISCVABI::ABI ABI) {
return ArrayRef(ArgIGPRs);
}

static ArrayRef<MCPhysReg> getArgGPR16s(const RISCVABI::ABI ABI) {
// The GPRs used for passing arguments in the ILP32* and LP64* ABIs, except
// the ILP32E ABI.
static const MCPhysReg ArgIGPRs[] = {RISCV::X10_H, RISCV::X11_H, RISCV::X12_H,
RISCV::X13_H, RISCV::X14_H, RISCV::X15_H,
RISCV::X16_H, RISCV::X17_H};
// The GPRs used for passing arguments in the ILP32E/LP64E ABI.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo: ILP64E => LP64E (I've just committed a fix to other similar instances in this file).

static const MCPhysReg ArgEGPRs[] = {RISCV::X10_H, RISCV::X11_H,
RISCV::X12_H, RISCV::X13_H,
RISCV::X14_H, RISCV::X15_H};

if (ABI == RISCVABI::ABI_ILP32E || ABI == RISCVABI::ABI_LP64E)
return ArrayRef(ArgEGPRs);

return ArrayRef(ArgIGPRs);
}

static ArrayRef<MCPhysReg> getFastCCArgGPRs(const RISCVABI::ABI ABI) {
// The GPRs used for passing arguments in the FastCC, X5 and X6 might be used
// for save-restore libcall, so we don't use them.
Expand All @@ -157,6 +174,26 @@ static ArrayRef<MCPhysReg> getFastCCArgGPRs(const RISCVABI::ABI ABI) {
return ArrayRef(FastCCIGPRs);
}

/// Returns the 16-bit GPR sub-registers that FastCC may use to pass arguments
/// under \p ABI.
static ArrayRef<MCPhysReg> getFastCCArgGPRF16s(const RISCVABI::ABI ABI) {
  // FastCC argument registers for every ABI other than ILP32E/LP64E. X5 and
  // X6 are left out because they might be used for the save-restore libcall,
  // and X7 is left out because Zicfilp uses it as the label register.
  static const MCPhysReg FastCCIGPRs[] = {
      RISCV::X10_H, RISCV::X11_H, RISCV::X12_H, RISCV::X13_H,
      RISCV::X14_H, RISCV::X15_H, RISCV::X16_H, RISCV::X17_H,
      RISCV::X28_H, RISCV::X29_H, RISCV::X30_H, RISCV::X31_H};

  // FastCC argument registers when using the ILP32E/LP64E ABIs.
  static const MCPhysReg FastCCEGPRs[] = {RISCV::X10_H, RISCV::X11_H,
                                          RISCV::X12_H, RISCV::X13_H,
                                          RISCV::X14_H, RISCV::X15_H};

  switch (ABI) {
  case RISCVABI::ABI_ILP32E:
  case RISCVABI::ABI_LP64E:
    return ArrayRef(FastCCEGPRs);
  default:
    return ArrayRef(FastCCIGPRs);
  }
}

// Pass a 2*XLEN argument that has been split into two XLEN values through
// registers or the stack as necessary.
static bool CC_RISCVAssign2XLen(unsigned XLen, CCState &State, CCValAssign VA1,
Expand Down Expand Up @@ -309,6 +346,13 @@ bool llvm::CC_RISCV(unsigned ValNo, MVT ValVT, MVT LocVT,
// similar local variables rather than directly checking against the target
// ABI.

if ((ValVT == MVT::f16 && Subtarget.hasStdExtZhinxmin())) {
if (MCRegister Reg = State.AllocateReg(getArgGPR16s(ABI))) {
State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
return false;
}
}

ArrayRef<MCPhysReg> ArgGPRs = RISCV::getArgGPRs(ABI);

if ((ValVT == MVT::f32 && XLen == 32 && Subtarget.hasStdExtZfinx()) ||
Expand Down Expand Up @@ -566,8 +610,7 @@ bool llvm::CC_RISCV_FastCC(unsigned ValNo, MVT ValVT, MVT LocVT,
}

// Check if there is an available GPR before hitting the stack.
if ((LocVT == MVT::f16 && Subtarget.hasStdExtZhinxmin()) ||
(LocVT == MVT::f32 && Subtarget.hasStdExtZfinx()) ||
if ((LocVT == MVT::f32 && Subtarget.hasStdExtZfinx()) ||
(LocVT == MVT::f64 && Subtarget.is64Bit() &&
Subtarget.hasStdExtZdinx())) {
if (MCRegister Reg = State.AllocateReg(getFastCCArgGPRs(ABI))) {
Expand All @@ -582,6 +625,14 @@ bool llvm::CC_RISCV_FastCC(unsigned ValNo, MVT ValVT, MVT LocVT,
}
}

// Check if there is an available GPRF16 before hitting the stack.
if ((LocVT == MVT::f16 && Subtarget.hasStdExtZhinxmin())) {
if (MCRegister Reg = State.AllocateReg(getFastCCArgGPRF16s(ABI))) {
State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
return false;
}
}

if (LocVT == MVT::f16 || LocVT == MVT::bf16) {
int64_t Offset2 = State.AllocateStack(2, Align(2));
State.addLoc(CCValAssign::getMem(ValNo, ValVT, Offset2, LocVT, LocInfo));
Expand Down
9 changes: 7 additions & 2 deletions llvm/lib/Target/RISCV/RISCVDeadRegisterDefinitions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,19 @@ bool RISCVDeadRegisterDefinitions::runOnMachineFunction(MachineFunction &MF) {
continue;
LLVM_DEBUG(dbgs() << " Dead def operand #" << I << " in:\n ";
MI.print(dbgs()));
Register X0Reg;
const TargetRegisterClass *RC = TII->getRegClass(Desc, I, TRI, MF);
if (!(RC && RC->contains(RISCV::X0))) {
if (RC && RC->contains(RISCV::X0)) {
X0Reg = RISCV::X0;
} else if (RC && RC->contains(RISCV::X0_H)) {
X0Reg = RISCV::X0_H;
} else {
LLVM_DEBUG(dbgs() << " Ignoring, register is not a GPR.\n");
continue;
}
assert(LIS.hasInterval(Reg));
LIS.removeInterval(Reg);
MO.setReg(RISCV::X0);
MO.setReg(X0Reg);
LLVM_DEBUG(dbgs() << " Replacing with zero register. New:\n ";
MI.print(dbgs()));
++NumDeadDefsReplaced;
Expand Down
5 changes: 4 additions & 1 deletion llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -959,7 +959,10 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
}

SDNode *Res;
if (Opc == RISCV::FCVT_D_W_IN32X || Opc == RISCV::FCVT_D_W)
if (VT.SimpleTy == MVT::f16 && Opc == RISCV::COPY) {
Res =
CurDAG->getTargetExtractSubreg(RISCV::sub_16, DL, VT, Imm).getNode();
} else if (Opc == RISCV::FCVT_D_W_IN32X || Opc == RISCV::FCVT_D_W)
Res = CurDAG->getMachineNode(
Opc, DL, VT, Imm,
CurDAG->getTargetConstant(RISCVFPRndMode::RNE, DL, XLenVT));
Expand Down
23 changes: 23 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,23 @@ void RISCVInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
return;
}

if (RISCV::GPRF16RegClass.contains(DstReg, SrcReg)) {
if (STI.hasStdExtZhinx()) {
BuildMI(MBB, MBBI, DL, get(RISCV::FSGNJ_H_INX), DstReg)
.addReg(SrcReg, getKillRegState(KillSrc))
.addReg(SrcReg, getKillRegState(KillSrc));
return;
}
DstReg =
TRI->getMatchingSuperReg(DstReg, RISCV::sub_16, &RISCV::GPRRegClass);
SrcReg =
TRI->getMatchingSuperReg(SrcReg, RISCV::sub_16, &RISCV::GPRRegClass);
BuildMI(MBB, MBBI, DL, get(RISCV::ADDI), DstReg)
.addReg(SrcReg, getKillRegState(KillSrc))
.addImm(0);
return;
}

if (RISCV::GPRPairRegClass.contains(DstReg, SrcReg)) {
// Emit an ADDI for both parts of GPRPair.
BuildMI(MBB, MBBI, DL, get(RISCV::ADDI),
Expand Down Expand Up @@ -579,6 +596,9 @@ void RISCVInstrInfo::storeRegToStackSlot(MachineBasicBlock &MBB,
Opcode = TRI->getRegSizeInBits(RISCV::GPRRegClass) == 32 ?
RISCV::SW : RISCV::SD;
IsScalableVector = false;
} else if (RISCV::GPRF16RegClass.hasSubClassEq(RC)) {
Opcode = RISCV::SH_INX;
IsScalableVector = false;
} else if (RISCV::GPRPairRegClass.hasSubClassEq(RC)) {
Opcode = RISCV::PseudoRV32ZdinxSD;
IsScalableVector = false;
Expand Down Expand Up @@ -662,6 +682,9 @@ void RISCVInstrInfo::loadRegFromStackSlot(MachineBasicBlock &MBB,
Opcode = TRI->getRegSizeInBits(RISCV::GPRRegClass) == 32 ?
RISCV::LW : RISCV::LD;
IsScalableVector = false;
} else if (RISCV::GPRF16RegClass.hasSubClassEq(RC)) {
Opcode = RISCV::LH_INX;
IsScalableVector = false;
} else if (RISCV::GPRPairRegClass.hasSubClassEq(RC)) {
Opcode = RISCV::PseudoRV32ZdinxLD;
IsScalableVector = false;
Expand Down
8 changes: 4 additions & 4 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -514,8 +514,8 @@ class BranchCC_rri<bits<3> funct3, string opcodestr>
}

let hasSideEffects = 0, mayLoad = 1, mayStore = 0 in {
class Load_ri<bits<3> funct3, string opcodestr>
: RVInstI<funct3, OPC_LOAD, (outs GPR:$rd), (ins GPRMem:$rs1, simm12:$imm12),
class Load_ri<bits<3> funct3, string opcodestr, DAGOperand rty = GPR>
: RVInstI<funct3, OPC_LOAD, (outs rty:$rd), (ins GPRMem:$rs1, simm12:$imm12),
opcodestr, "$rd, ${imm12}(${rs1})">;

class HLoad_r<bits<7> funct7, bits<5> funct5, string opcodestr>
Expand All @@ -529,9 +529,9 @@ class HLoad_r<bits<7> funct7, bits<5> funct5, string opcodestr>
// reflecting the order these fields are specified in the instruction
// encoding.
let hasSideEffects = 0, mayLoad = 0, mayStore = 1 in {
class Store_rri<bits<3> funct3, string opcodestr>
class Store_rri<bits<3> funct3, string opcodestr, DAGOperand rty = GPR>
: RVInstS<funct3, OPC_STORE, (outs),
(ins GPR:$rs2, GPRMem:$rs1, simm12:$imm12),
(ins rty:$rs2, GPRMem:$rs1, simm12:$imm12),
opcodestr, "$rs2, ${imm12}(${rs1})">;

class HStore_rr<bits<7> funct7, string opcodestr>
Expand Down
26 changes: 17 additions & 9 deletions llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,14 @@ def riscv_fmv_x_signexth

// Zhinxmin and Zhinx

// Assembler operand class for the Zhinx/Zhinxmin 16-bit operands that live in
// GPRs. Operands of this class are parsed with parseGPRAsFPR and rendered
// with addRegOperands.
def GPRAsFPR16 : AsmOperandClass {
let Name = "GPRAsFPR16";
let ParserMethod = "parseGPRAsFPR";
let RenderMethod = "addRegOperands";
}

def FPR16INX : RegisterOperand<GPRF16> {
let ParserMatchClass = GPRAsFPR;
let DecoderMethod = "DecodeGPRRegisterClass";
let ParserMatchClass = GPRAsFPR16;
}

def ZfhExt : ExtInfo<"", "", [HasStdExtZfh],
Expand Down Expand Up @@ -84,6 +89,12 @@ def FLH : FPLoad_r<0b001, "flh", FPR16, WriteFLD16>;
def FSH : FPStore_r<0b001, "fsh", FPR16, WriteFST16>;
} // Predicates = [HasHalfFPLoadStoreMove]

// Codegen-only "lh"/"sh" instructions whose data register operand is a GPRF16
// instead of a full GPR. They are guarded by the Zhinxmin predicate.
let Predicates = [HasStdExtZhinxmin], isCodeGenOnly = 1 in {
def LH_INX : Load_ri<0b001, "lh", GPRF16>, Sched<[WriteLDH, ReadMemBase]>;
def SH_INX : Store_rri<0b001, "sh", GPRF16>,
Sched<[WriteSTH, ReadStoreData, ReadMemBase]>;
}

foreach Ext = ZfhExts in {
let SchedRW = [WriteFMA16, ReadFMA16, ReadFMA16, ReadFMA16Addend] in {
defm FMADD_H : FPFMA_rrr_frm_m<OPC_MADD, 0b10, "fmadd.h", Ext>;
Expand Down Expand Up @@ -426,13 +437,10 @@ let Predicates = [HasStdExtZhinxmin] in {
defm Select_FPR16INX : SelectCC_GPR_rrirr<FPR16INX, f16>;

/// Loads
def : Pat<(f16 (load (AddrRegImm (XLenVT GPR:$rs1), simm12:$imm12))),
(COPY_TO_REGCLASS (LH GPR:$rs1, simm12:$imm12), GPRF16)>;
def : LdPat<load, LH_INX, f16>;

/// Stores
def : Pat<(store (f16 FPR16INX:$rs2),
(AddrRegImm (XLenVT GPR:$rs1), simm12:$imm12)),
(SH (COPY_TO_REGCLASS FPR16INX:$rs2, GPR), GPR:$rs1, simm12:$imm12)>;
def : StPat<store, SH_INX, GPRF16, f16>;
} // Predicates = [HasStdExtZhinxmin]

let Predicates = [HasStdExtZfhmin] in {
Expand All @@ -458,8 +466,8 @@ def : Pat<(any_fpround FPR32INX:$rs1), (FCVT_H_S_INX FPR32INX:$rs1, FRM_DYN)>;
def : Pat<(any_fpextend FPR16INX:$rs1), (FCVT_S_H_INX FPR16INX:$rs1, FRM_RNE)>;

// Moves (no conversion)
def : Pat<(f16 (riscv_fmv_h_x GPR:$src)), (COPY_TO_REGCLASS GPR:$src, GPR)>;
def : Pat<(riscv_fmv_x_anyexth FPR16INX:$src), (COPY_TO_REGCLASS FPR16INX:$src, GPR)>;
def : Pat<(f16 (riscv_fmv_h_x GPR:$src)), (EXTRACT_SUBREG GPR:$src, sub_16)>;
def : Pat<(riscv_fmv_x_anyexth FPR16INX:$src), (INSERT_SUBREG (XLenVT (IMPLICIT_DEF)), FPR16INX:$src, sub_16)>;

def : Pat<(fcopysign FPR32INX:$rs1, FPR16INX:$rs2), (FSGNJ_S_INX $rs1, (FCVT_S_H_INX $rs2, FRM_RNE))>;
} // Predicates = [HasStdExtZhinxmin]
Expand Down
14 changes: 7 additions & 7 deletions llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,11 @@ BitVector RISCVRegisterInfo::getReservedRegs(const MachineFunction &MF) const {
}

// Use markSuperRegs to ensure any register aliases are also reserved
markSuperRegs(Reserved, RISCV::X2); // sp
markSuperRegs(Reserved, RISCV::X3); // gp
markSuperRegs(Reserved, RISCV::X4); // tp
markSuperRegs(Reserved, RISCV::X2_H); // sp
markSuperRegs(Reserved, RISCV::X3_H); // gp
markSuperRegs(Reserved, RISCV::X4_H); // tp
if (TFI->hasFP(MF))
markSuperRegs(Reserved, RISCV::X8); // fp
markSuperRegs(Reserved, RISCV::X8_H); // fp
// Reserve the base register if we need to realign the stack and allocate
// variable-sized objects at runtime.
if (TFI->hasBP(MF))
Expand All @@ -131,7 +131,7 @@ BitVector RISCVRegisterInfo::getReservedRegs(const MachineFunction &MF) const {

// There are only 16 GPRs for RVE.
if (Subtarget.hasStdExtE())
for (MCPhysReg Reg = RISCV::X16; Reg <= RISCV::X31; Reg++)
for (MCPhysReg Reg = RISCV::X16_H; Reg <= RISCV::X31_H; Reg++)
markSuperRegs(Reserved, Reg);

// V registers for code generation. We handle them manually.
Expand All @@ -150,8 +150,8 @@ BitVector RISCVRegisterInfo::getReservedRegs(const MachineFunction &MF) const {
if (MF.getFunction().getCallingConv() == CallingConv::GRAAL) {
if (Subtarget.hasStdExtE())
report_fatal_error("Graal reserved registers do not exist in RVE");
markSuperRegs(Reserved, RISCV::X23);
markSuperRegs(Reserved, RISCV::X27);
markSuperRegs(Reserved, RISCV::X23_H);
markSuperRegs(Reserved, RISCV::X27_H);
}

// Shadow stack pointer.
Expand Down
Loading
Loading