Skip to content

[TRI][RISCV] Add methods to get common register class of two registers #118435

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions llvm/include/llvm/CodeGen/TargetRegisterInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -347,13 +347,28 @@ class TargetRegisterInfo : public MCRegisterInfo {
const TargetRegisterClass *getMinimalPhysRegClass(MCRegister Reg,
MVT VT = MVT::Other) const;

/// Returns the common Register Class of two physical registers of the given
/// type, picking the most sub register class of the right type that contains
/// these two physregs.
const TargetRegisterClass *
getCommonMinimalPhysRegClass(MCRegister Reg1, MCRegister Reg2,
MVT VT = MVT::Other) const;

/// Returns the Register Class of a physical register of the given type,
/// picking the most sub register class of the right type that contains this
/// physreg. If there is no register class compatible with the given type,
/// returns nullptr.
const TargetRegisterClass *getMinimalPhysRegClassLLT(MCRegister Reg,
LLT Ty = LLT()) const;

/// Returns the common Register Class of two physical registers of the given
/// type, picking the most sub register class of the right type that contains
/// these two physregs. If there is no register class compatible with the
/// given type, returns nullptr.
const TargetRegisterClass *
getCommonMinimalPhysRegClassLLT(MCRegister Reg1, MCRegister Reg2,
LLT Ty = LLT()) const;

/// Return the maximal subclass of the given register class that is
/// allocatable or NULL.
const TargetRegisterClass *
Expand Down
77 changes: 59 additions & 18 deletions llvm/lib/CodeGen/TargetRegisterInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,44 +201,85 @@ TargetRegisterInfo::getAllocatableClass(const TargetRegisterClass *RC) const {
return nullptr;
}

/// getMinimalPhysRegClass - Returns the Register Class of a physical
/// register of the given type, picking the most sub register class of
/// the right type that contains this physreg.
const TargetRegisterClass *
TargetRegisterInfo::getMinimalPhysRegClass(MCRegister reg, MVT VT) const {
assert(Register::isPhysicalRegister(reg) &&
template <typename TypeT>
static const TargetRegisterClass *
getMinimalPhysRegClass(const TargetRegisterInfo *TRI, MCRegister Reg,
TypeT Ty) {
static_assert(std::is_same_v<TypeT, MVT> || std::is_same_v<TypeT, LLT>);
assert(Register::isPhysicalRegister(Reg) &&
"reg must be a physical register");

bool IsDefault = [&]() {
if constexpr (std::is_same_v<TypeT, MVT>)
return Ty == MVT::Other;
else
return !Ty.isValid();
}();

// Pick the most sub register class of the right type that contains
// this physreg.
const TargetRegisterClass* BestRC = nullptr;
for (const TargetRegisterClass* RC : regclasses()) {
if ((VT == MVT::Other || isTypeLegalForClass(*RC, VT)) &&
RC->contains(reg) && (!BestRC || BestRC->hasSubClass(RC)))
const TargetRegisterClass *BestRC = nullptr;
for (const TargetRegisterClass *RC : TRI->regclasses()) {
if ((IsDefault || TRI->isTypeLegalForClass(*RC, Ty)) && RC->contains(Reg) &&
(!BestRC || BestRC->hasSubClass(RC)))
BestRC = RC;
}

assert(BestRC && "Couldn't find the register class");
if constexpr (std::is_same_v<TypeT, MVT>)
assert(BestRC && "Couldn't find the register class");
return BestRC;
}

const TargetRegisterClass *
TargetRegisterInfo::getMinimalPhysRegClassLLT(MCRegister reg, LLT Ty) const {
assert(Register::isPhysicalRegister(reg) &&
"reg must be a physical register");
template <typename TypeT>
static const TargetRegisterClass *
getCommonMinimalPhysRegClass(const TargetRegisterInfo *TRI, MCRegister Reg1,
MCRegister Reg2, TypeT Ty) {
static_assert(std::is_same_v<TypeT, MVT> || std::is_same_v<TypeT, LLT>);
assert(Register::isPhysicalRegister(Reg1) &&
Register::isPhysicalRegister(Reg2) &&
"Reg1/Reg2 must be a physical register");

bool IsDefault = [&]() {
if constexpr (std::is_same_v<TypeT, MVT>)
return Ty == MVT::Other;
else
return !Ty.isValid();
}();

// Pick the most sub register class of the right type that contains
// this physreg.
const TargetRegisterClass *BestRC = nullptr;
for (const TargetRegisterClass *RC : regclasses()) {
if ((!Ty.isValid() || isTypeLegalForClass(*RC, Ty)) && RC->contains(reg) &&
(!BestRC || BestRC->hasSubClass(RC)))
for (const TargetRegisterClass *RC : TRI->regclasses()) {
if ((IsDefault || TRI->isTypeLegalForClass(*RC, Ty)) &&
RC->contains(Reg1, Reg2) && (!BestRC || BestRC->hasSubClass(RC)))
BestRC = RC;
}

if constexpr (std::is_same_v<TypeT, MVT>)
assert(BestRC && "Couldn't find the register class");
return BestRC;
}

const TargetRegisterClass *
TargetRegisterInfo::getMinimalPhysRegClass(MCRegister Reg, MVT VT) const {
return ::getMinimalPhysRegClass(this, Reg, VT);
}

const TargetRegisterClass *TargetRegisterInfo::getCommonMinimalPhysRegClass(
MCRegister Reg1, MCRegister Reg2, MVT VT) const {
return ::getCommonMinimalPhysRegClass(this, Reg1, Reg2, VT);
}

const TargetRegisterClass *
TargetRegisterInfo::getMinimalPhysRegClassLLT(MCRegister Reg, LLT Ty) const {
return ::getMinimalPhysRegClass(this, Reg, Ty);
}

const TargetRegisterClass *TargetRegisterInfo::getCommonMinimalPhysRegClassLLT(
MCRegister Reg1, MCRegister Reg2, LLT Ty) const {
return ::getCommonMinimalPhysRegClass(this, Reg1, Reg2, Ty);
}

/// getAllocatableSetForRC - Toggle the bits that represent allocatable
/// registers for the specific register class.
static void getAllocatableSetForRC(const MachineFunction &MF,
Expand Down
18 changes: 6 additions & 12 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ void RISCVInstrInfo::copyPhysRegVector(
auto FindRegWithEncoding = [TRI](const TargetRegisterClass &RegClass,
uint16_t Encoding) {
MCRegister Reg = RISCV::V0 + Encoding;
if (&RegClass == &RISCV::VRRegClass)
if (RISCVRI::getLMul(RegClass.TSFlags) == RISCVII::LMUL_1)
return Reg;
return TRI->getMatchingSuperReg(Reg, RISCV::sub_vrm1_0, &RegClass);
};
Expand Down Expand Up @@ -564,17 +564,11 @@ void RISCVInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
}

// VR->VR copies.
static const TargetRegisterClass *RVVRegClasses[] = {
&RISCV::VRRegClass, &RISCV::VRM2RegClass, &RISCV::VRM4RegClass,
&RISCV::VRM8RegClass, &RISCV::VRN2M1RegClass, &RISCV::VRN2M2RegClass,
&RISCV::VRN2M4RegClass, &RISCV::VRN3M1RegClass, &RISCV::VRN3M2RegClass,
&RISCV::VRN4M1RegClass, &RISCV::VRN4M2RegClass, &RISCV::VRN5M1RegClass,
&RISCV::VRN6M1RegClass, &RISCV::VRN7M1RegClass, &RISCV::VRN8M1RegClass};
for (const auto &RegClass : RVVRegClasses) {
if (RegClass->contains(DstReg, SrcReg)) {
copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RegClass);
return;
}
const TargetRegisterClass *RegClass =
TRI->getCommonMinimalPhysRegClass(SrcReg, DstReg);
if (RISCVRegisterInfo::isRVVRegClass(RegClass)) {
copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RegClass);
return;
}

llvm_unreachable("Impossible reg-to-reg copy");
Expand Down
Loading