Skip to content

Commit a5b22b7

Browse files
author
Cameron McInally
committed
[AArch64][SVE] Add support for DestructiveBinary and DestructiveBinaryComm DestructiveInstTypes
Add support for the DestructiveBinary and DestructiveBinaryComm DestructiveInstTypes, as well as the lowering code to expand the new Pseudos into the final movprfx+instruction pairs. Differential Revision: https://reviews.llvm.org/D73711
1 parent b72f144 commit a5b22b7

11 files changed

+638
-26
lines changed

llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ class AArch64ExpandPseudo : public MachineFunctionPass {
6868
bool expandMOVImm(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
6969
unsigned BitSize);
7070

71+
bool expand_DestructiveOp(MachineInstr &MI, MachineBasicBlock &MBB,
72+
MachineBasicBlock::iterator MBBI);
7173
bool expandCMP_SWAP(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
7274
unsigned LdarOp, unsigned StlrOp, unsigned CmpOp,
7375
unsigned ExtendImm, unsigned ZeroReg,
@@ -344,6 +346,176 @@ bool AArch64ExpandPseudo::expandCMP_SWAP_128(
344346
return true;
345347
}
346348

349+
/// \brief Expand Pseudos to Instructions with destructive operands.
350+
///
351+
/// This mechanism uses MOVPRFX instructions for zeroing the false lanes
352+
/// or for fixing relaxed register allocation conditions to comply with
353+
/// the instructions register constraints. The latter case may be cheaper
354+
/// than setting the register constraints in the register allocator,
355+
/// since that will insert regular MOV instructions rather than MOVPRFX.
356+
///
357+
/// Example (after register allocation):
358+
///
359+
/// FSUB_ZPZZ_ZERO_B Z0, Pg, Z1, Z0
360+
///
361+
/// * The Pseudo FSUB_ZPZZ_ZERO_B maps to FSUB_ZPmZ_B.
362+
/// * We cannot map directly to FSUB_ZPmZ_B because the register
363+
/// constraints of the instruction are not met.
364+
/// * Also the _ZERO specifies the false lanes need to be zeroed.
365+
///
366+
/// We first try to see if the destructive operand == result operand,
367+
/// if not, we try to swap the operands, e.g.
368+
///
369+
/// FSUB_ZPmZ_B Z0, Pg/m, Z0, Z1
370+
///
371+
/// But because FSUB_ZPmZ is not commutative, this is semantically
372+
/// different, so we need a reverse instruction:
373+
///
374+
/// FSUBR_ZPmZ_B Z0, Pg/m, Z0, Z1
375+
///
376+
/// Then we implement the zeroing of the false lanes of Z0 by adding
377+
/// a zeroing MOVPRFX instruction:
378+
///
379+
/// MOVPRFX_ZPzZ_B Z0, Pg/z, Z0
380+
/// FSUBR_ZPmZ_B Z0, Pg/m, Z0, Z1
381+
///
382+
/// Note that this can only be done for _ZERO or _UNDEF variants where
383+
/// we can guarantee the false lanes to be zeroed (by implementing this)
384+
/// or that they are undef (don't care / not used), otherwise the
385+
/// swapping of operands is illegal because the operation is not
386+
/// (or cannot be emulated to be) fully commutative.
387+
bool AArch64ExpandPseudo::expand_DestructiveOp(
388+
MachineInstr &MI,
389+
MachineBasicBlock &MBB,
390+
MachineBasicBlock::iterator MBBI) {
391+
unsigned Opcode = AArch64::getSVEPseudoMap(MI.getOpcode());
392+
uint64_t DType = TII->get(Opcode).TSFlags & AArch64::DestructiveInstTypeMask;
393+
uint64_t FalseLanes = MI.getDesc().TSFlags & AArch64::FalseLanesMask;
394+
bool FalseZero = FalseLanes == AArch64::FalseLanesZero;
395+
396+
unsigned DstReg = MI.getOperand(0).getReg();
397+
bool DstIsDead = MI.getOperand(0).isDead();
398+
399+
if (DType == AArch64::DestructiveBinary)
400+
assert(DstReg != MI.getOperand(3).getReg());
401+
402+
bool UseRev = false;
403+
unsigned PredIdx, DOPIdx, SrcIdx;
404+
switch (DType) {
405+
case AArch64::DestructiveBinaryComm:
406+
case AArch64::DestructiveBinaryCommWithRev:
407+
if (DstReg == MI.getOperand(3).getReg()) {
408+
// FSUB Zd, Pg, Zs1, Zd ==> FSUBR Zd, Pg/m, Zd, Zs1
409+
std::tie(PredIdx, DOPIdx, SrcIdx) = std::make_tuple(1, 3, 2);
410+
UseRev = true;
411+
break;
412+
}
413+
LLVM_FALLTHROUGH;
414+
case AArch64::DestructiveBinary:
415+
std::tie(PredIdx, DOPIdx, SrcIdx) = std::make_tuple(1, 2, 3);
416+
break;
417+
default:
418+
llvm_unreachable("Unsupported Destructive Operand type");
419+
}
420+
421+
#ifndef NDEBUG
422+
// MOVPRFX can only be used if the destination operand
423+
// is the destructive operand, not as any other operand,
424+
// so the Destructive Operand must be unique.
425+
bool DOPRegIsUnique = false;
426+
switch (DType) {
427+
case AArch64::DestructiveBinaryComm:
428+
case AArch64::DestructiveBinaryCommWithRev:
429+
DOPRegIsUnique =
430+
DstReg != MI.getOperand(DOPIdx).getReg() ||
431+
MI.getOperand(DOPIdx).getReg() != MI.getOperand(SrcIdx).getReg();
432+
break;
433+
}
434+
435+
assert (DOPRegIsUnique && "The destructive operand should be unique");
436+
#endif
437+
438+
// Resolve the reverse opcode
439+
if (UseRev) {
440+
if (AArch64::getSVERevInstr(Opcode) != -1)
441+
Opcode = AArch64::getSVERevInstr(Opcode);
442+
else if (AArch64::getSVEOrigInstr(Opcode) != -1)
443+
Opcode = AArch64::getSVEOrigInstr(Opcode);
444+
}
445+
446+
// Get the right MOVPRFX
447+
uint64_t ElementSize = TII->getElementSizeForOpcode(Opcode);
448+
unsigned MovPrfx, MovPrfxZero;
449+
switch (ElementSize) {
450+
case AArch64::ElementSizeNone:
451+
case AArch64::ElementSizeB:
452+
MovPrfx = AArch64::MOVPRFX_ZZ;
453+
MovPrfxZero = AArch64::MOVPRFX_ZPzZ_B;
454+
break;
455+
case AArch64::ElementSizeH:
456+
MovPrfx = AArch64::MOVPRFX_ZZ;
457+
MovPrfxZero = AArch64::MOVPRFX_ZPzZ_H;
458+
break;
459+
case AArch64::ElementSizeS:
460+
MovPrfx = AArch64::MOVPRFX_ZZ;
461+
MovPrfxZero = AArch64::MOVPRFX_ZPzZ_S;
462+
break;
463+
case AArch64::ElementSizeD:
464+
MovPrfx = AArch64::MOVPRFX_ZZ;
465+
MovPrfxZero = AArch64::MOVPRFX_ZPzZ_D;
466+
break;
467+
default:
468+
llvm_unreachable("Unsupported ElementSize");
469+
}
470+
471+
//
472+
// Create the destructive operation (if required)
473+
//
474+
MachineInstrBuilder PRFX, DOP;
475+
if (FalseZero) {
476+
assert(ElementSize != AArch64::ElementSizeNone &&
477+
"This instruction is unpredicated");
478+
479+
// Merge source operand into destination register
480+
PRFX = BuildMI(MBB, MBBI, MI.getDebugLoc(), TII->get(MovPrfxZero))
481+
.addReg(DstReg, RegState::Define)
482+
.addReg(MI.getOperand(PredIdx).getReg())
483+
.addReg(MI.getOperand(DOPIdx).getReg());
484+
485+
// After the movprfx, the destructive operand is same as Dst
486+
DOPIdx = 0;
487+
} else if (DstReg != MI.getOperand(DOPIdx).getReg()) {
488+
PRFX = BuildMI(MBB, MBBI, MI.getDebugLoc(), TII->get(MovPrfx))
489+
.addReg(DstReg, RegState::Define)
490+
.addReg(MI.getOperand(DOPIdx).getReg());
491+
DOPIdx = 0;
492+
}
493+
494+
//
495+
// Create the destructive operation
496+
//
497+
DOP = BuildMI(MBB, MBBI, MI.getDebugLoc(), TII->get(Opcode))
498+
.addReg(DstReg, RegState::Define | getDeadRegState(DstIsDead));
499+
500+
switch (DType) {
501+
case AArch64::DestructiveBinaryComm:
502+
case AArch64::DestructiveBinaryCommWithRev:
503+
DOP.add(MI.getOperand(PredIdx))
504+
.addReg(MI.getOperand(DOPIdx).getReg(), RegState::Kill)
505+
.add(MI.getOperand(SrcIdx));
506+
break;
507+
}
508+
509+
if (PRFX) {
510+
finalizeBundle(MBB, PRFX->getIterator(), MBBI->getIterator());
511+
transferImpOps(MI, PRFX, DOP);
512+
} else
513+
transferImpOps(MI, DOP, DOP);
514+
515+
MI.eraseFromParent();
516+
return true;
517+
}
518+
347519
bool AArch64ExpandPseudo::expandSetTagLoop(
348520
MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
349521
MachineBasicBlock::iterator &NextMBBI) {
@@ -425,6 +597,17 @@ bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB,
425597
MachineBasicBlock::iterator &NextMBBI) {
426598
MachineInstr &MI = *MBBI;
427599
unsigned Opcode = MI.getOpcode();
600+
601+
// Check if we can expand the destructive op
602+
int OrigInstr = AArch64::getSVEPseudoMap(MI.getOpcode());
603+
if (OrigInstr != -1) {
604+
auto &Orig = TII->get(OrigInstr);
605+
if ((Orig.TSFlags & AArch64::DestructiveInstTypeMask)
606+
!= AArch64::NotDestructive) {
607+
return expand_DestructiveOp(MI, MBB, MBBI);
608+
}
609+
}
610+
428611
switch (Opcode) {
429612
default:
430613
break;

llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,25 @@ class AArch64DAGToDAGISel : public SelectionDAGISel {
162162
return false;
163163
}
164164

165+
// Returns true if N splats a constant zero (integer or floating point)
// across all vector lanes; used to recognize _ZERO pseudo patterns.
bool SelectDupZero(SDValue N) {
  unsigned Opc = N->getOpcode();
  if (Opc != AArch64ISD::DUP && Opc != ISD::SPLAT_VECTOR)
    return false;

  SDValue Splat = N->getOperand(0);
  if (auto *CInt = dyn_cast<ConstantSDNode>(Splat))
    return CInt->isNullValue();
  if (auto *CFP = dyn_cast<ConstantFPSDNode>(Splat))
    return CFP->isZero();
  return false;
}
183+
165184
template<MVT::SimpleValueType VT>
166185
bool SelectSVEAddSubImm(SDValue N, SDValue &Imm, SDValue &Shift) {
167186
return SelectSVEAddSubImm(N, VT, Imm, Shift);

llvm/lib/Target/AArch64/AArch64InstrFormats.td

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,27 @@ def NormalFrm : Format<1>; // Do we need any others?
2222

2323
// Enum describing whether an instruction is
// destructive in its first source operand.
class DestructiveInstTypeEnum<bits<4> val> {
  bits<4> Value = val;
}
def NotDestructive                : DestructiveInstTypeEnum<0>;
// Destructive in its first operand and can be MOVPRFX'd, but has no other
// special properties.
def DestructiveOther              : DestructiveInstTypeEnum<1>;
def DestructiveUnary              : DestructiveInstTypeEnum<2>;
def DestructiveBinaryImm          : DestructiveInstTypeEnum<3>;
def DestructiveBinaryShImmUnpred  : DestructiveInstTypeEnum<4>;
def DestructiveBinary             : DestructiveInstTypeEnum<5>;
def DestructiveBinaryComm         : DestructiveInstTypeEnum<6>;
def DestructiveBinaryCommWithRev  : DestructiveInstTypeEnum<7>;
def DestructiveTernaryCommWithRev : DestructiveInstTypeEnum<8>;

// Enum describing what the false (inactive) lanes of a predicated
// SVE pseudo's result contain.
class FalseLanesEnum<bits<2> val> {
  bits<2> Value = val;
}
def FalseLanesNone  : FalseLanesEnum<0>;
def FalseLanesZero  : FalseLanesEnum<1>;
def FalseLanesUndef : FalseLanesEnum<2>;
3246

3347
// AArch64 Instruction Format
3448
class AArch64Inst<Format f, string cstr> : Instruction {
@@ -46,10 +60,12 @@ class AArch64Inst<Format f, string cstr> : Instruction {
4660
bits<2> Form = F.Value;
4761

4862
// Defaults
63+
FalseLanesEnum FalseLanes = FalseLanesNone;
4964
DestructiveInstTypeEnum DestructiveInstType = NotDestructive;
5065
ElementSizeEnum ElementSize = ElementSizeNone;
5166

52-
let TSFlags{3} = DestructiveInstType.Value;
67+
let TSFlags{8-7} = FalseLanes.Value;
68+
let TSFlags{6-3} = DestructiveInstType.Value;
5369
let TSFlags{2-0} = ElementSize.Value;
5470

5571
let Pattern = [];

llvm/lib/Target/AArch64/AArch64InstrInfo.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,11 +119,25 @@ unsigned AArch64InstrInfo::getInstSizeInBytes(const MachineInstr &MI) const {
119119
case AArch64::SPACE:
120120
NumBytes = MI.getOperand(1).getImm();
121121
break;
122+
case TargetOpcode::BUNDLE:
123+
NumBytes = getInstBundleLength(MI);
124+
break;
122125
}
123126

124127
return NumBytes;
125128
}
126129

130+
unsigned AArch64InstrInfo::getInstBundleLength(const MachineInstr &MI) const {
131+
unsigned Size = 0;
132+
MachineBasicBlock::const_instr_iterator I = MI.getIterator();
133+
MachineBasicBlock::const_instr_iterator E = MI.getParent()->instr_end();
134+
while (++I != E && I->isInsideBundle()) {
135+
assert(!I->isBundle() && "No nested bundle!");
136+
Size += getInstSizeInBytes(*I);
137+
}
138+
return Size;
139+
}
140+
127141
static void parseCondBranch(MachineInstr *LastInst, MachineBasicBlock *&Target,
128142
SmallVectorImpl<MachineOperand> &Cond) {
129143
// Block ends with fall-through condbranch.
@@ -6680,5 +6694,10 @@ AArch64InstrInfo::describeLoadedValue(const MachineInstr &MI,
66806694
return TargetInstrInfo::describeLoadedValue(MI, Reg);
66816695
}
66826696

6697+
uint64_t AArch64InstrInfo::getElementSizeForOpcode(unsigned Opc) const {
6698+
return get(Opc).TSFlags & AArch64::ElementSizeMask;
6699+
}
6700+
66836701
#define GET_INSTRINFO_HELPERS
6702+
#define GET_INSTRMAP_INFO
66846703
#include "AArch64GenInstrInfo.inc"

llvm/lib/Target/AArch64/AArch64InstrInfo.h

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,8 @@ class AArch64InstrInfo final : public AArch64GenInstrInfo {
271271
MachineBasicBlock::iterator &It, MachineFunction &MF,
272272
const outliner::Candidate &C) const override;
273273
bool shouldOutlineFromFunctionByDefault(MachineFunction &MF) const override;
274+
/// Returns the vector element size (B, H, S or D) of an SVE opcode.
275+
uint64_t getElementSizeForOpcode(unsigned Opc) const;
274276
/// Returns true if the instruction has a shift by immediate that can be
275277
/// executed in one cycle less.
276278
static bool isFalkorShiftExtFast(const MachineInstr &MI);
@@ -295,6 +297,8 @@ class AArch64InstrInfo final : public AArch64GenInstrInfo {
295297
isCopyInstrImpl(const MachineInstr &MI) const override;
296298

297299
private:
300+
unsigned getInstBundleLength(const MachineInstr &MI) const;
301+
298302
/// Sets the offsets on outlined instructions in \p MBB which use SP
299303
/// so that they will be valid post-outlining.
300304
///
@@ -381,7 +385,8 @@ static inline bool isIndirectBranchOpcode(int Opc) {
381385

382386
// Helper macros laying out the AArch64-specific TSFlags bitfields:
// struct TSFlags {
#define TSFLAG_ELEMENT_SIZE_TYPE(X)      (X)        // 3-bits
#define TSFLAG_DESTRUCTIVE_INST_TYPE(X) ((X) << 3)  // 4-bits
#define TSFLAG_FALSE_LANE_TYPE(X)       ((X) << 7)  // 2-bits
// }
386391

387392
namespace AArch64 {
@@ -396,13 +401,31 @@ enum ElementSizeType {
396401
};
397402

398403
enum DestructiveInstType {
399-
DestructiveInstTypeMask = TSFLAG_DESTRUCTIVE_INST_TYPE(0x1),
400-
NotDestructive = TSFLAG_DESTRUCTIVE_INST_TYPE(0x0),
401-
DestructiveOther = TSFLAG_DESTRUCTIVE_INST_TYPE(0x1),
404+
DestructiveInstTypeMask = TSFLAG_DESTRUCTIVE_INST_TYPE(0xf),
405+
NotDestructive = TSFLAG_DESTRUCTIVE_INST_TYPE(0x0),
406+
DestructiveOther = TSFLAG_DESTRUCTIVE_INST_TYPE(0x1),
407+
DestructiveUnary = TSFLAG_DESTRUCTIVE_INST_TYPE(0x2),
408+
DestructiveBinaryImm = TSFLAG_DESTRUCTIVE_INST_TYPE(0x3),
409+
DestructiveBinaryShImmUnpred = TSFLAG_DESTRUCTIVE_INST_TYPE(0x4),
410+
DestructiveBinary = TSFLAG_DESTRUCTIVE_INST_TYPE(0x5),
411+
DestructiveBinaryComm = TSFLAG_DESTRUCTIVE_INST_TYPE(0x6),
412+
DestructiveBinaryCommWithRev = TSFLAG_DESTRUCTIVE_INST_TYPE(0x7),
413+
DestructiveTernaryCommWithRev = TSFLAG_DESTRUCTIVE_INST_TYPE(0x8),
414+
};
415+
416+
enum FalseLaneType {
417+
FalseLanesMask = TSFLAG_FALSE_LANE_TYPE(0x3),
418+
FalseLanesZero = TSFLAG_FALSE_LANE_TYPE(0x1),
419+
FalseLanesUndef = TSFLAG_FALSE_LANE_TYPE(0x2),
402420
};
403421

404422
#undef TSFLAG_ELEMENT_SIZE_TYPE
405423
#undef TSFLAG_DESTRUCTIVE_INST_TYPE
424+
#undef TSFLAG_FALSE_LANE_TYPE
425+
426+
int getSVEPseudoMap(uint16_t Opcode);
427+
int getSVERevInstr(uint16_t Opcode);
428+
int getSVEOrigInstr(uint16_t Opcode);
406429
}
407430

408431
} // end namespace llvm

0 commit comments

Comments
 (0)