Skip to content

Commit 8e77390

Browse files
authored
[X86][CodeGen] Support folding memory broadcast in X86InstrInfo::foldMemoryOperandImpl (#79761)
1 parent c12f30c commit 8e77390

8 files changed

+1719
-1616
lines changed

llvm/lib/Target/X86/X86InstrAVX512.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1067,7 +1067,7 @@ multiclass avx512_broadcast_rm_split<bits<8> opc, string OpcodeStr,
10671067
MaskInfo.RC:$src0))],
10681068
DestInfo.ExeDomain>, T8, PD, EVEX, EVEX_K, Sched<[SchedRR]>;
10691069

1070-
let hasSideEffects = 0, mayLoad = 1 in
1070+
let hasSideEffects = 0, mayLoad = 1, isReMaterializable = 1, canFoldAsLoad = 1 in
10711071
def rm : AVX512PI<opc, MRMSrcMem, (outs MaskInfo.RC:$dst),
10721072
(ins SrcInfo.ScalarMemOp:$src),
10731073
!strconcat(OpcodeStr, "\t{$src, $dst|$dst, $src}"),

llvm/lib/Target/X86/X86InstrFoldTables.cpp

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,23 @@ const X86FoldTableEntry *llvm::lookupFoldTable(unsigned RegOp, unsigned OpNum) {
143143
return lookupFoldTableImpl(FoldTable, RegOp);
144144
}
145145

146+
// Look up the broadcast fold-table entry for opcode RegOp when the operand
// being folded is OpNum.
//
// OpNum selects which table is searched: 1, 2, 3 and 4 map to
// BroadcastTable1, BroadcastTable2, BroadcastTable3 and BroadcastTable4
// respectively. Any other OpNum has no table and yields nullptr. The search
// within the selected table is delegated to lookupFoldTableImpl.
const X86FoldTableEntry *llvm::lookupBroadcastFoldTable(unsigned RegOp,
147+
unsigned OpNum) {
148+
ArrayRef<X86FoldTableEntry> FoldTable;
149+
if (OpNum == 1)
150+
FoldTable = ArrayRef(BroadcastTable1);
151+
else if (OpNum == 2)
152+
FoldTable = ArrayRef(BroadcastTable2);
153+
else if (OpNum == 3)
154+
FoldTable = ArrayRef(BroadcastTable3);
155+
else if (OpNum == 4)
156+
FoldTable = ArrayRef(BroadcastTable4);
157+
else
158+
// No broadcast table is consulted for this operand index.
return nullptr;
159+
160+
return lookupFoldTableImpl(FoldTable, RegOp);
161+
}
162+
146163
namespace {
147164

148165
// This class stores the memory unfolding tables. It is instantiated as a
@@ -288,8 +305,8 @@ struct X86BroadcastFoldTable {
288305
};
289306
} // namespace
290307

291-
static bool matchBroadcastSize(const X86FoldTableEntry &Entry,
292-
unsigned BroadcastBits) {
308+
bool llvm::matchBroadcastSize(const X86FoldTableEntry &Entry,
309+
unsigned BroadcastBits) {
293310
switch (Entry.Flags & TB_BCAST_MASK) {
294311
case TB_BCAST_W:
295312
case TB_BCAST_SH:

llvm/lib/Target/X86/X86InstrFoldTables.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@ const X86FoldTableEntry *lookupTwoAddrFoldTable(unsigned RegOp);
4444
// operand OpNum.
4545
const X86FoldTableEntry *lookupFoldTable(unsigned RegOp, unsigned OpNum);
4646

47+
// Look up the broadcast folding table entry for folding a broadcast into
48+
// operand OpNum of the instruction whose opcode is RegOp.
// Returns nullptr when OpNum is not in the range 1..4.
// NOTE(review): presumably also returns nullptr when RegOp has no entry in
// the selected table (the caller null-checks the result) - confirm against
// lookupFoldTableImpl.
49+
const X86FoldTableEntry *lookupBroadcastFoldTable(unsigned RegOp,
50+
unsigned OpNum);
51+
4752
// Look up the memory unfolding table entry for this instruction.
4853
const X86FoldTableEntry *lookupUnfoldTable(unsigned MemOp);
4954

@@ -52,6 +57,7 @@ const X86FoldTableEntry *lookupUnfoldTable(unsigned MemOp);
5257
const X86FoldTableEntry *lookupBroadcastFoldTableBySize(unsigned MemOp,
5358
unsigned BroadcastBits);
5459

60+
// Checks the broadcast kind stored in Entry.Flags (selected with
// TB_BCAST_MASK) against BroadcastBits, a broadcast size in bits.
// X86InstrInfo::foldMemoryBroadcast performs the fold only when this
// returns true.
// NOTE(review): presumably true means the entry's broadcast element width
// equals BroadcastBits - confirm against the definition in
// X86InstrFoldTables.cpp.
bool matchBroadcastSize(const X86FoldTableEntry &Entry, unsigned BroadcastBits);
5561
} // namespace llvm
5662

5763
#endif

llvm/lib/Target/X86/X86InstrInfo.cpp

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -862,6 +862,28 @@ bool X86InstrInfo::isReallyTriviallyReMaterializable(
862862
case X86::MMX_MOVD64rm:
863863
case X86::MMX_MOVQ64rm:
864864
// AVX-512
865+
case X86::VPBROADCASTBZ128rm:
866+
case X86::VPBROADCASTBZ256rm:
867+
case X86::VPBROADCASTBZrm:
868+
case X86::VBROADCASTF32X2Z256rm:
869+
case X86::VBROADCASTF32X2Zrm:
870+
case X86::VBROADCASTI32X2Z128rm:
871+
case X86::VBROADCASTI32X2Z256rm:
872+
case X86::VBROADCASTI32X2Zrm:
873+
case X86::VPBROADCASTWZ128rm:
874+
case X86::VPBROADCASTWZ256rm:
875+
case X86::VPBROADCASTWZrm:
876+
case X86::VPBROADCASTDZ128rm:
877+
case X86::VPBROADCASTDZ256rm:
878+
case X86::VPBROADCASTDZrm:
879+
case X86::VBROADCASTSSZ128rm:
880+
case X86::VBROADCASTSSZ256rm:
881+
case X86::VBROADCASTSSZrm:
882+
case X86::VPBROADCASTQZ128rm:
883+
case X86::VPBROADCASTQZ256rm:
884+
case X86::VPBROADCASTQZrm:
885+
case X86::VBROADCASTSDZ256rm:
886+
case X86::VBROADCASTSDZrm:
865887
case X86::VMOVSSZrm:
866888
case X86::VMOVSSZrm_alt:
867889
case X86::VMOVSDZrm:
@@ -8067,6 +8089,39 @@ MachineInstr *X86InstrInfo::foldMemoryOperandImpl(
80678089
MOs.push_back(MachineOperand::CreateReg(0, false));
80688090
break;
80698091
}
8092+
case X86::VPBROADCASTBZ128rm:
8093+
case X86::VPBROADCASTBZ256rm:
8094+
case X86::VPBROADCASTBZrm:
8095+
case X86::VBROADCASTF32X2Z256rm:
8096+
case X86::VBROADCASTF32X2Zrm:
8097+
case X86::VBROADCASTI32X2Z128rm:
8098+
case X86::VBROADCASTI32X2Z256rm:
8099+
case X86::VBROADCASTI32X2Zrm:
8100+
// No instructions currently fuse with 8-bit or 32-bit x 2 broadcasts.
8101+
return nullptr;
8102+
8103+
#define FOLD_BROADCAST(SIZE) \
8104+
MOs.append(LoadMI.operands_begin() + NumOps - X86::AddrNumOperands, \
8105+
LoadMI.operands_begin() + NumOps); \
8106+
return foldMemoryBroadcast(MF, MI, Ops[0], MOs, InsertPt, /*Size=*/SIZE, \
8107+
/*AllowCommute=*/true);
8108+
case X86::VPBROADCASTWZ128rm:
8109+
case X86::VPBROADCASTWZ256rm:
8110+
case X86::VPBROADCASTWZrm:
8111+
FOLD_BROADCAST(16);
8112+
case X86::VPBROADCASTDZ128rm:
8113+
case X86::VPBROADCASTDZ256rm:
8114+
case X86::VPBROADCASTDZrm:
8115+
case X86::VBROADCASTSSZ128rm:
8116+
case X86::VBROADCASTSSZ256rm:
8117+
case X86::VBROADCASTSSZrm:
8118+
FOLD_BROADCAST(32);
8119+
case X86::VPBROADCASTQZ128rm:
8120+
case X86::VPBROADCASTQZ256rm:
8121+
case X86::VPBROADCASTQZrm:
8122+
case X86::VBROADCASTSDZ256rm:
8123+
case X86::VBROADCASTSDZrm:
8124+
FOLD_BROADCAST(64);
80708125
default: {
80718126
if (isNonFoldablePartialRegisterLoad(LoadMI, MI, MF))
80728127
return nullptr;
@@ -8081,6 +8136,37 @@ MachineInstr *X86InstrInfo::foldMemoryOperandImpl(
80818136
/*Size=*/0, Alignment, /*AllowCommute=*/true);
80828137
}
80838138

8139+
// Try to fold a memory broadcast into operand OpNum of MI.
//
// MOs holds the memory operands handed to FuseInst; the FOLD_BROADCAST call
// sites fill it with the trailing X86::AddrNumOperands operands of the
// broadcast load. BitsSize is the broadcast size in bits that the fold-table
// entry must match (those call sites pass 16, 32 or 64). When AllowCommute is
// true and no table entry exists for (MI opcode, OpNum), MI is commuted and
// the fold is retried once on the commuted operand index.
//
// Returns the result of FuseInst on success, or nullptr when no fold is
// performed.
MachineInstr *
8140+
X86InstrInfo::foldMemoryBroadcast(MachineFunction &MF, MachineInstr &MI,
8141+
unsigned OpNum, ArrayRef<MachineOperand> MOs,
8142+
MachineBasicBlock::iterator InsertPt,
8143+
unsigned BitsSize, bool AllowCommute) const {
8144+
8145+
// A table entry exists for this opcode/operand pair: fold only if its
// broadcast size matches BitsSize. On a mismatch this returns nullptr
// directly, without attempting a commute.
if (auto *I = lookupBroadcastFoldTable(MI.getOpcode(), OpNum))
8146+
return matchBroadcastSize(*I, BitsSize)
8147+
? FuseInst(MF, I->DstOp, OpNum, MOs, InsertPt, MI, *this)
8148+
: nullptr;
8149+
8150+
if (AllowCommute) {
8151+
// If the instruction and target operand are commutable, commute the
8152+
// instruction and try again.
8153+
unsigned CommuteOpIdx2 = commuteOperandsForFold(MI, OpNum);
8154+
// Getting OpNum back is treated here as "no commute happened": report the
// failure and give up.
if (CommuteOpIdx2 == OpNum) {
8155+
printFailMsgforFold(MI, OpNum);
8156+
return nullptr;
8157+
}
8158+
// Retry on the commuted operand index; AllowCommute=false keeps the
// recursion to a single level.
MachineInstr *NewMI =
8159+
foldMemoryBroadcast(MF, MI, CommuteOpIdx2, MOs, InsertPt, BitsSize,
8160+
/*AllowCommute=*/false);
8161+
if (NewMI)
8162+
return NewMI;
8163+
// The retry failed, so undo the commute performed above before reporting
// the failure.
UndoCommuteForFold(MI, OpNum, CommuteOpIdx2);
8164+
}
8165+
8166+
printFailMsgforFold(MI, OpNum);
8167+
return nullptr;
8168+
}
8169+
80848170
static SmallVector<MachineMemOperand *, 2>
80858171
extractLoadMMOs(ArrayRef<MachineMemOperand *> MMOs, MachineFunction &MF) {
80868172
SmallVector<MachineMemOperand *, 2> LoadMMOs;

llvm/lib/Target/X86/X86InstrInfo.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,12 @@ class X86InstrInfo final : public X86GenInstrInfo {
643643
MachineBasicBlock::iterator InsertPt,
644644
unsigned Size, Align Alignment) const;
645645

646+
/// foldMemoryBroadcast - Try to fold a memory broadcast, described by the
/// memory operands \p MOs, into operand \p OpNum of \p MI. \p BitsSize is
/// the broadcast size in bits that the broadcast fold-table entry must
/// match. If \p AllowCommute is true and no table entry exists for
/// \p OpNum, the instruction is commuted and the fold is retried once.
/// Returns the folded instruction, or nullptr if no fold was performed.
MachineInstr *foldMemoryBroadcast(MachineFunction &MF, MachineInstr &MI,
647+
unsigned OpNum,
648+
ArrayRef<MachineOperand> MOs,
649+
MachineBasicBlock::iterator InsertPt,
650+
unsigned BitsSize, bool AllowCommute) const;
651+
646652
/// isFrameOperand - Return true and the FrameIndex if the specified
647653
/// operand and follow operands form a reference to the stack frame.
648654
bool isFrameOperand(const MachineInstr &MI, unsigned int Op,

0 commit comments

Comments
 (0)