Skip to content

Commit 5fe93b0

Browse files
authored
[CodeGen][TII] Allow reassociation on custom operand indices (#88306)
This opens up a door for reusing reassociation optimizations on target-specific binary operations with non-standard operand list. This is effectively a NFC.
1 parent 033453a commit 5fe93b0

File tree

3 files changed

+114
-47
lines changed

3 files changed

+114
-47
lines changed

llvm/include/llvm/CodeGen/TargetInstrInfo.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include "llvm/MC/MCInstrInfo.h"
3232
#include "llvm/Support/BranchProbability.h"
3333
#include "llvm/Support/ErrorHandling.h"
34+
#include <array>
3435
#include <cassert>
3536
#include <cstddef>
3637
#include <cstdint>
@@ -1271,11 +1272,20 @@ class TargetInstrInfo : public MCInstrInfo {
12711272
return true;
12721273
}
12731274

1275+
/// The returned array encodes the operand index for each parameter because
1276+
/// the operands may be commuted; the operand indices for associative
1277+
/// operations might also be target-specific. Each element specifies the index
1278+
/// of {Prev, A, B, X, Y}.
1279+
virtual void
1280+
getReassociateOperandIndices(const MachineInstr &Root, unsigned Pattern,
1281+
std::array<unsigned, 5> &OperandIndices) const;
1282+
12741283
/// Attempt to reassociate \P Root and \P Prev according to \P Pattern to
12751284
/// reduce critical path length.
12761285
void reassociateOps(MachineInstr &Root, MachineInstr &Prev, unsigned Pattern,
12771286
SmallVectorImpl<MachineInstr *> &InsInstrs,
12781287
SmallVectorImpl<MachineInstr *> &DelInstrs,
1288+
ArrayRef<unsigned> OperandIndices,
12791289
DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) const;
12801290

12811291
/// Reassociation of some instructions requires inverse operations (e.g.

llvm/lib/CodeGen/TargetInstrInfo.cpp

Lines changed: 100 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,42 +1055,45 @@ static std::pair<bool, bool> mustSwapOperands(unsigned Pattern) {
10551055
}
10561056
}
10571057

1058+
void TargetInstrInfo::getReassociateOperandIndices(
1059+
const MachineInstr &Root, unsigned Pattern,
1060+
std::array<unsigned, 5> &OperandIndices) const {
1061+
switch (Pattern) {
1062+
case MachineCombinerPattern::REASSOC_AX_BY:
1063+
OperandIndices = {1, 1, 1, 2, 2};
1064+
break;
1065+
case MachineCombinerPattern::REASSOC_AX_YB:
1066+
OperandIndices = {2, 1, 2, 2, 1};
1067+
break;
1068+
case MachineCombinerPattern::REASSOC_XA_BY:
1069+
OperandIndices = {1, 2, 1, 1, 2};
1070+
break;
1071+
case MachineCombinerPattern::REASSOC_XA_YB:
1072+
OperandIndices = {2, 2, 2, 1, 1};
1073+
break;
1074+
default:
1075+
llvm_unreachable("unexpected MachineCombinerPattern");
1076+
}
1077+
}
1078+
10581079
/// Attempt the reassociation transformation to reduce critical path length.
10591080
/// See the above comments before getMachineCombinerPatterns().
10601081
void TargetInstrInfo::reassociateOps(
10611082
MachineInstr &Root, MachineInstr &Prev, unsigned Pattern,
10621083
SmallVectorImpl<MachineInstr *> &InsInstrs,
10631084
SmallVectorImpl<MachineInstr *> &DelInstrs,
1085+
ArrayRef<unsigned> OperandIndices,
10641086
DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) const {
10651087
MachineFunction *MF = Root.getMF();
10661088
MachineRegisterInfo &MRI = MF->getRegInfo();
10671089
const TargetInstrInfo *TII = MF->getSubtarget().getInstrInfo();
10681090
const TargetRegisterInfo *TRI = MF->getSubtarget().getRegisterInfo();
10691091
const TargetRegisterClass *RC = Root.getRegClassConstraint(0, TII, TRI);
10701092

1071-
// This array encodes the operand index for each parameter because the
1072-
// operands may be commuted. Each row corresponds to a pattern value,
1073-
// and each column specifies the index of A, B, X, Y.
1074-
unsigned OpIdx[4][4] = {
1075-
{ 1, 1, 2, 2 },
1076-
{ 1, 2, 2, 1 },
1077-
{ 2, 1, 1, 2 },
1078-
{ 2, 2, 1, 1 }
1079-
};
1080-
1081-
int Row;
1082-
switch (Pattern) {
1083-
case MachineCombinerPattern::REASSOC_AX_BY: Row = 0; break;
1084-
case MachineCombinerPattern::REASSOC_AX_YB: Row = 1; break;
1085-
case MachineCombinerPattern::REASSOC_XA_BY: Row = 2; break;
1086-
case MachineCombinerPattern::REASSOC_XA_YB: Row = 3; break;
1087-
default: llvm_unreachable("unexpected MachineCombinerPattern");
1088-
}
1089-
1090-
MachineOperand &OpA = Prev.getOperand(OpIdx[Row][0]);
1091-
MachineOperand &OpB = Root.getOperand(OpIdx[Row][1]);
1092-
MachineOperand &OpX = Prev.getOperand(OpIdx[Row][2]);
1093-
MachineOperand &OpY = Root.getOperand(OpIdx[Row][3]);
1093+
MachineOperand &OpA = Prev.getOperand(OperandIndices[1]);
1094+
MachineOperand &OpB = Root.getOperand(OperandIndices[2]);
1095+
MachineOperand &OpX = Prev.getOperand(OperandIndices[3]);
1096+
MachineOperand &OpY = Root.getOperand(OperandIndices[4]);
10941097
MachineOperand &OpC = Root.getOperand(0);
10951098

10961099
Register RegA = OpA.getReg();
@@ -1129,21 +1132,83 @@ void TargetInstrInfo::reassociateOps(
11291132
std::swap(KillX, KillY);
11301133
}
11311134

1135+
unsigned PrevFirstOpIdx, PrevSecondOpIdx;
1136+
unsigned RootFirstOpIdx, RootSecondOpIdx;
1137+
switch (Pattern) {
1138+
case MachineCombinerPattern::REASSOC_AX_BY:
1139+
PrevFirstOpIdx = OperandIndices[1];
1140+
PrevSecondOpIdx = OperandIndices[3];
1141+
RootFirstOpIdx = OperandIndices[2];
1142+
RootSecondOpIdx = OperandIndices[4];
1143+
break;
1144+
case MachineCombinerPattern::REASSOC_AX_YB:
1145+
PrevFirstOpIdx = OperandIndices[1];
1146+
PrevSecondOpIdx = OperandIndices[3];
1147+
RootFirstOpIdx = OperandIndices[4];
1148+
RootSecondOpIdx = OperandIndices[2];
1149+
break;
1150+
case MachineCombinerPattern::REASSOC_XA_BY:
1151+
PrevFirstOpIdx = OperandIndices[3];
1152+
PrevSecondOpIdx = OperandIndices[1];
1153+
RootFirstOpIdx = OperandIndices[2];
1154+
RootSecondOpIdx = OperandIndices[4];
1155+
break;
1156+
case MachineCombinerPattern::REASSOC_XA_YB:
1157+
PrevFirstOpIdx = OperandIndices[3];
1158+
PrevSecondOpIdx = OperandIndices[1];
1159+
RootFirstOpIdx = OperandIndices[4];
1160+
RootSecondOpIdx = OperandIndices[2];
1161+
break;
1162+
default:
1163+
llvm_unreachable("unexpected MachineCombinerPattern");
1164+
}
1165+
1166+
// Basically BuildMI but doesn't add implicit operands by default.
1167+
auto buildMINoImplicit = [](MachineFunction &MF, const MIMetadata &MIMD,
1168+
const MCInstrDesc &MCID, Register DestReg) {
1169+
return MachineInstrBuilder(
1170+
MF, MF.CreateMachineInstr(MCID, MIMD.getDL(), /*NoImpl=*/true))
1171+
.setPCSections(MIMD.getPCSections())
1172+
.addReg(DestReg, RegState::Define);
1173+
};
1174+
11321175
// Create new instructions for insertion.
11331176
MachineInstrBuilder MIB1 =
1134-
BuildMI(*MF, MIMetadata(Prev), TII->get(NewPrevOpc), NewVR)
1135-
.addReg(RegX, getKillRegState(KillX))
1136-
.addReg(RegY, getKillRegState(KillY));
1177+
buildMINoImplicit(*MF, MIMetadata(Prev), TII->get(NewPrevOpc), NewVR);
1178+
for (const auto &MO : Prev.explicit_operands()) {
1179+
unsigned Idx = MO.getOperandNo();
1180+
// Skip the result operand we'd already added.
1181+
if (Idx == 0)
1182+
continue;
1183+
if (Idx == PrevFirstOpIdx)
1184+
MIB1.addReg(RegX, getKillRegState(KillX));
1185+
else if (Idx == PrevSecondOpIdx)
1186+
MIB1.addReg(RegY, getKillRegState(KillY));
1187+
else
1188+
MIB1.add(MO);
1189+
}
1190+
MIB1.copyImplicitOps(Prev);
11371191

11381192
if (SwapRootOperands) {
11391193
std::swap(RegA, NewVR);
11401194
std::swap(KillA, KillNewVR);
11411195
}
11421196

11431197
MachineInstrBuilder MIB2 =
1144-
BuildMI(*MF, MIMetadata(Root), TII->get(NewRootOpc), RegC)
1145-
.addReg(RegA, getKillRegState(KillA))
1146-
.addReg(NewVR, getKillRegState(KillNewVR));
1198+
buildMINoImplicit(*MF, MIMetadata(Root), TII->get(NewRootOpc), RegC);
1199+
for (const auto &MO : Root.explicit_operands()) {
1200+
unsigned Idx = MO.getOperandNo();
1201+
// Skip the result operand.
1202+
if (Idx == 0)
1203+
continue;
1204+
if (Idx == RootFirstOpIdx)
1205+
MIB2 = MIB2.addReg(RegA, getKillRegState(KillA));
1206+
else if (Idx == RootSecondOpIdx)
1207+
MIB2 = MIB2.addReg(NewVR, getKillRegState(KillNewVR));
1208+
else
1209+
MIB2 = MIB2.add(MO);
1210+
}
1211+
MIB2.copyImplicitOps(Root);
11471212

11481213
// Propagate FP flags from the original instructions.
11491214
// But clear poison-generating flags because those may not be valid now.
@@ -1187,25 +1252,17 @@ void TargetInstrInfo::genAlternativeCodeSequence(
11871252
MachineRegisterInfo &MRI = Root.getMF()->getRegInfo();
11881253

11891254
// Select the previous instruction in the sequence based on the input pattern.
1190-
MachineInstr *Prev = nullptr;
1191-
switch (Pattern) {
1192-
case MachineCombinerPattern::REASSOC_AX_BY:
1193-
case MachineCombinerPattern::REASSOC_XA_BY:
1194-
Prev = MRI.getUniqueVRegDef(Root.getOperand(1).getReg());
1195-
break;
1196-
case MachineCombinerPattern::REASSOC_AX_YB:
1197-
case MachineCombinerPattern::REASSOC_XA_YB:
1198-
Prev = MRI.getUniqueVRegDef(Root.getOperand(2).getReg());
1199-
break;
1200-
default:
1201-
llvm_unreachable("Unknown pattern for machine combiner");
1202-
}
1255+
std::array<unsigned, 5> OperandIndices;
1256+
getReassociateOperandIndices(Root, Pattern, OperandIndices);
1257+
MachineInstr *Prev =
1258+
MRI.getUniqueVRegDef(Root.getOperand(OperandIndices[0]).getReg());
12031259

12041260
// Don't reassociate if Prev and Root are in different blocks.
12051261
if (Prev->getParent() != Root.getParent())
12061262
return;
12071263

1208-
reassociateOps(Root, *Prev, Pattern, InsInstrs, DelInstrs, InstIdxForVirtReg);
1264+
reassociateOps(Root, *Prev, Pattern, InsInstrs, DelInstrs, OperandIndices,
1265+
InstIdxForVirtReg);
12091266
}
12101267

12111268
MachineTraceStrategy TargetInstrInfo::getMachineCombinerTraceStrategy() const {

llvm/lib/Target/RISCV/RISCVInstrInfo.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1582,10 +1582,10 @@ void RISCVInstrInfo::finalizeInsInstrs(
15821582
MachineFunction &MF = *Root.getMF();
15831583

15841584
for (auto *NewMI : InsInstrs) {
1585-
assert(static_cast<unsigned>(RISCV::getNamedOperandIdx(
1586-
NewMI->getOpcode(), RISCV::OpName::frm)) ==
1587-
NewMI->getNumOperands() &&
1588-
"Instruction has unexpected number of operands");
1585+
// We'd already added the FRM operand.
1586+
if (static_cast<unsigned>(RISCV::getNamedOperandIdx(
1587+
NewMI->getOpcode(), RISCV::OpName::frm)) != NewMI->getNumOperands())
1588+
continue;
15891589
MachineInstrBuilder MIB(MF, NewMI);
15901590
MIB.add(FRM);
15911591
if (FRM.getImm() == RISCVFPRndMode::DYN)

0 commit comments

Comments
 (0)