Skip to content

Commit 891aaff

Browse files
committed
[AArch64][SVE2] Add the SVE2.1 pext and ptrue predicate-as-counter instructions
This patch adds the assembly/disassembly for the following instructions: pext (predicate): set predicate from predicate-as-counter; ptrue (predicate-as-counter): initialise predicate-as-counter to all active. This patch also introduces the predicate-as-counter registers pn8, etc. The reference can be found here: https://developer.arm.com/documentation/ddi0602/2022-09. Differential Revision: https://reviews.llvm.org/D136678
1 parent 80b08b6 commit 891aaff

15 files changed

+619
-123
lines changed

llvm/lib/Target/AArch64/AArch64RegisterInfo.td

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -871,15 +871,16 @@ class ZPRRegOp <string Suffix, AsmOperandClass C, ElementSizeEnum Size,
871871
//******************************************************************************
872872

873873
// SVE predicate register classes.
874-
class PPRClass<int lastreg> : RegisterClass<
874+
class PPRClass<int firstreg, int lastreg> : RegisterClass<
875875
"AArch64",
876876
[ nxv16i1, nxv8i1, nxv4i1, nxv2i1, nxv1i1 ], 16,
877-
(sequence "P%u", 0, lastreg)> {
877+
(sequence "P%u", firstreg, lastreg)> {
878878
let Size = 16;
879879
}
880880

881-
def PPR : PPRClass<15>;
882-
def PPR_3b : PPRClass<7>; // Restricted 3 bit SVE predicate register class.
881+
def PPR : PPRClass<0, 15>;
882+
def PPR_3b : PPRClass<0, 7>; // Restricted 3 bit SVE predicate register class.
883+
def PPR_p8to15 : PPRClass<8, 15>;
883884

884885
class PPRAsmOperand <string name, string RegClass, int Width>: AsmOperandClass {
885886
let Name = "SVE" # name # "Reg";
@@ -906,6 +907,38 @@ def PPRAsmOp3bAny : PPRAsmOperand<"Predicate3bAny", "PPR_3b", 0>;
906907

907908
def PPR3bAny : PPRRegOp<"", PPRAsmOp3bAny, ElementSizeNone, PPR_3b>;
908909

910+
911+
// SVE predicate-as-counter operand
912+
// Assembler operand class for SVE predicate-as-counter register operands.
// Derives from PPRAsmOperand and sets the matcher predicate, the diagnostic
// kind and the custom parser used for these operands.
//   name     - operand name fragment, also used to form the diagnostic type.
//   RegClass - name of the register class the operand must belong to.
//   Width    - element width passed to the matcher predicate.
class PNRAsmOperand<string name, string RegClass, int Width>
913+
: PPRAsmOperand<name, RegClass, Width> {
914+
// Expands to
// "isSVEPredicateAsCounterRegOfWidth<Width, AArch64::<RegClass>RegClassID>".
let PredicateMethod = "isSVEPredicateAsCounterRegOfWidth<"
915+
# Width # ", " # "AArch64::"
916+
# RegClass # "RegClassID>";
917+
// Diagnostic kind "InvalidSVE<name>Reg"; the asm parser handles the
// corresponding Match_InvalidSVE<name>Reg codes.
let DiagnosticType = "InvalidSVE" # name # "Reg";
918+
// Operands of this class are parsed by the dedicated "pn" register parser.
let ParserMethod = "tryParseSVEPredicateAsCounter";
919+
}
920+
921+
// Register operand for predicate-as-counter registers, built on PPRRegOp.
// The element size is not forwarded to PPRRegOp (ElementSizeNone is passed);
// EltSize is only used to select the printer instantiation.
class PNRP8to15RegOp<string Suffix, AsmOperandClass C, int EltSize, RegisterClass RC>
922+
: PPRRegOp<Suffix, C, ElementSizeNone, RC> {
923+
// Expands to "printPredicateAsCounter<EltSize>".
let PrintMethod = "printPredicateAsCounter<" # EltSize # ">";
924+
// Custom encoder/decoder hooks named after the PPR_p8to15 register class.
let EncoderMethod = "EncodePPR_p8to15";
925+
let DecoderMethod = "DecodePPR_p8to15RegisterClass";
926+
}
927+
928+
// Asm operand classes over the PPR_p8to15 register class: one with Width 0
// (presumably "no element-size suffix" - confirm against parseVectorKind) and
// one for each element width 8/16/32/64.
def PNRAsmAny_p8to15 : PNRAsmOperand<"PNPredicateAny_p8to15", "PPR_p8to15", 0>;
929+
def PNRAsmOp8_p8to15 : PNRAsmOperand<"PNPredicateB_p8to15", "PPR_p8to15", 8>;
930+
def PNRAsmOp16_p8to15 : PNRAsmOperand<"PNPredicateH_p8to15", "PPR_p8to15", 16>;
931+
def PNRAsmOp32_p8to15 : PNRAsmOperand<"PNPredicateS_p8to15", "PPR_p8to15", 32>;
932+
def PNRAsmOp64_p8to15 : PNRAsmOperand<"PNPredicateD_p8to15", "PPR_p8to15", 64>;
933+
934+
// Register operands pairing each asm operand class above with the matching
// suffix string and printer element size.
def PNRAny_p8to15 : PNRP8to15RegOp<"", PNRAsmAny_p8to15, 0, PPR_p8to15>;
935+
def PNR8_p8to15 : PNRP8to15RegOp<"b", PNRAsmOp8_p8to15, 8, PPR_p8to15>;
936+
def PNR16_p8to15 : PNRP8to15RegOp<"h", PNRAsmOp16_p8to15, 16, PPR_p8to15>;
937+
def PNR32_p8to15 : PNRP8to15RegOp<"s", PNRAsmOp32_p8to15, 32, PPR_p8to15>;
938+
def PNR64_p8to15 : PNRP8to15RegOp<"d", PNRAsmOp64_p8to15, 64, PPR_p8to15>;
939+
940+
941+
909942
//******************************************************************************
910943

911944
// SVE vector register classes

llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3586,4 +3586,7 @@ def SDOT_ZZZ_HtoS : sve2p1_two_way_dot_vv<"sdot", 0b0>;
35863586
def UDOT_ZZZ_HtoS : sve2p1_two_way_dot_vv<"udot", 0b1>;
35873587
def SDOT_ZZZI_HtoS : sve2p1_two_way_dot_vvi<"sdot", 0b0>;
35883588
def UDOT_ZZZI_HtoS : sve2p1_two_way_dot_vvi<"udot", 0b1>;
3589+
3590+
defm PEXT_PCI : sve2p1_pred_as_ctr_to_mask<"pext">;
3591+
defm PTRUE_C : sve2p1_ptrue_pn<"ptrue">;
35893592
} // End HasSVE2p1_or_HasSME2

llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
12
//==- AArch64AsmParser.cpp - Parse AArch64 assembly to MCInst instructions -==//
23
//
34
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
@@ -68,6 +69,7 @@ enum class RegKind {
6869
Scalar,
6970
NeonVector,
7071
SVEDataVector,
72+
SVEPredicateAsCounter,
7173
SVEPredicateVector,
7274
Matrix
7375
};
@@ -266,6 +268,7 @@ class AArch64AsmParser : public MCTargetAsmParser {
266268
template <bool ParseShiftExtend, bool ParseSuffix>
267269
OperandMatchResultTy tryParseSVEDataVector(OperandVector &Operands);
268270
OperandMatchResultTy tryParseSVEPredicateVector(OperandVector &Operands);
271+
OperandMatchResultTy tryParseSVEPredicateAsCounter(OperandVector &Operands);
269272
template <RegKind VectorKind>
270273
OperandMatchResultTy tryParseVectorList(OperandVector &Operands,
271274
bool ExpectMatch = false);
@@ -1198,6 +1201,22 @@ class AArch64Operand : public MCParsedAsmOperand {
11981201
bool isMatrix() const { return Kind == k_MatrixRegister; }
11991202
bool isMatrixTileList() const { return Kind == k_MatrixTileList; }
12001203

1204+
template <unsigned Class> bool isSVEPredicateAsCounterReg() const {
1205+
RegKind RK;
1206+
switch (Class) {
1207+
case AArch64::PPRRegClassID:
1208+
case AArch64::PPR_3bRegClassID:
1209+
case AArch64::PPR_p8to15RegClassID:
1210+
RK = RegKind::SVEPredicateAsCounter;
1211+
break;
1212+
default:
1213+
llvm_unreachable("Unsupport register class");
1214+
}
1215+
1216+
return (Kind == k_Register && Reg.Kind == RK) &&
1217+
AArch64MCRegisterClasses[Class].contains(getReg());
1218+
}
1219+
12011220
template <unsigned Class> bool isSVEVectorReg() const {
12021221
RegKind RK;
12031222
switch (Class) {
@@ -1234,6 +1253,17 @@ class AArch64Operand : public MCParsedAsmOperand {
12341253
return DiagnosticPredicateTy::NearMatch;
12351254
}
12361255

1256+
template <int ElementWidth, unsigned Class>
1257+
DiagnosticPredicate isSVEPredicateAsCounterRegOfWidth() const {
1258+
if (Kind != k_Register || Reg.Kind != RegKind::SVEPredicateAsCounter)
1259+
return DiagnosticPredicateTy::NoMatch;
1260+
1261+
if (isSVEPredicateAsCounterReg<Class>() && (Reg.ElementWidth == ElementWidth))
1262+
return DiagnosticPredicateTy::Match;
1263+
1264+
return DiagnosticPredicateTy::NearMatch;
1265+
}
1266+
12371267
template <int ElementWidth, unsigned Class>
12381268
DiagnosticPredicate isSVEDataVectorRegOfWidth() const {
12391269
if (Kind != k_Register || Reg.Kind != RegKind::SVEDataVector)
@@ -2059,7 +2089,8 @@ class AArch64Operand : public MCParsedAsmOperand {
20592089
unsigned ShiftAmount = 0,
20602090
unsigned HasExplicitAmount = false) {
20612091
assert((Kind == RegKind::NeonVector || Kind == RegKind::SVEDataVector ||
2062-
Kind == RegKind::SVEPredicateVector) &&
2092+
Kind == RegKind::SVEPredicateVector ||
2093+
Kind == RegKind::SVEPredicateAsCounter) &&
20632094
"Invalid vector kind");
20642095
auto Op = CreateReg(RegNum, Kind, S, E, Ctx, EqualsReg, ExtTy, ShiftAmount,
20652096
HasExplicitAmount);
@@ -2478,6 +2509,7 @@ static Optional<std::pair<int, int>> parseVectorKind(StringRef Suffix,
24782509
.Case(".d", {0, 64})
24792510
.Default({-1, -1});
24802511
break;
2512+
case RegKind::SVEPredicateAsCounter:
24812513
case RegKind::SVEPredicateVector:
24822514
case RegKind::SVEDataVector:
24832515
case RegKind::Matrix:
@@ -2562,6 +2594,27 @@ static unsigned matchSVEPredicateVectorRegName(StringRef Name) {
25622594
.Default(0);
25632595
}
25642596

2597+
/// Maps a predicate-as-counter register name ("pn0" .. "pn15", compared
/// case-insensitively) to the corresponding P register. Returns 0 when
/// \p Name is not one of those names.
static unsigned matchSVEPredicateAsCounterRegName(StringRef Name) {
  struct NameToReg {
    const char *Spelling;
    unsigned Reg;
  };
  static const NameToReg Table[] = {
      {"pn0", AArch64::P0},   {"pn1", AArch64::P1},   {"pn2", AArch64::P2},
      {"pn3", AArch64::P3},   {"pn4", AArch64::P4},   {"pn5", AArch64::P5},
      {"pn6", AArch64::P6},   {"pn7", AArch64::P7},   {"pn8", AArch64::P8},
      {"pn9", AArch64::P9},   {"pn10", AArch64::P10}, {"pn11", AArch64::P11},
      {"pn12", AArch64::P12}, {"pn13", AArch64::P13}, {"pn14", AArch64::P14},
      {"pn15", AArch64::P15},
  };

  // Names are matched on their lower-cased spelling, exactly and in full.
  const std::string Lowered = Name.lower();
  for (const NameToReg &Entry : Table)
    if (Lowered == Entry.Spelling)
      return Entry.Reg;

  return 0;
}
2617+
25652618
static unsigned matchMatrixTileListRegName(StringRef Name) {
25662619
return StringSwitch<unsigned>(Name.lower())
25672620
.Case("za0.d", AArch64::ZAD0)
@@ -2705,6 +2758,9 @@ unsigned AArch64AsmParser::matchRegisterNameAlias(StringRef Name,
27052758
if ((RegNum = matchSVEPredicateVectorRegName(Name)))
27062759
return Kind == RegKind::SVEPredicateVector ? RegNum : 0;
27072760

2761+
if ((RegNum = matchSVEPredicateAsCounterRegName(Name)))
2762+
return Kind == RegKind::SVEPredicateAsCounter ? RegNum : 0;
2763+
27082764
if ((RegNum = MatchNeonVectorRegName(Name)))
27092765
return Kind == RegKind::NeonVector ? RegNum : 0;
27102766

@@ -3803,6 +3859,32 @@ AArch64AsmParser::tryParseVectorRegister(unsigned &Reg, StringRef &Kind,
38033859
return MatchOperand_NoMatch;
38043860
}
38053861

3862+
/// tryParseSVEPredicateAsCounter - Parse a SVE predicate-as-counter register
/// operand, optionally followed by a vector index.
OperandMatchResultTy
AArch64AsmParser::tryParseSVEPredicateAsCounter(OperandVector &Operands) {
  const SMLoc StartLoc = getLoc();

  // Parse the register itself; Suffix receives whatever element suffix
  // followed the register name.
  StringRef Suffix;
  unsigned Reg;
  const OperandMatchResultTy RegRes =
      tryParseVectorRegister(Reg, Suffix, RegKind::SVEPredicateAsCounter);
  if (RegRes != MatchOperand_Success)
    return RegRes;

  const auto SuffixInfo =
      parseVectorKind(Suffix, RegKind::SVEPredicateAsCounter);
  if (!SuffixInfo)
    return MatchOperand_NoMatch;

  // The second member of the pair is the element width.
  const unsigned ElementWidth = SuffixInfo->second;
  Operands.push_back(AArch64Operand::CreateVectorReg(
      Reg, RegKind::SVEPredicateAsCounter, ElementWidth, StartLoc, getLoc(),
      getContext()));

  // An index after the register is optional: only a hard parse failure is
  // propagated, any other outcome still counts as success.
  if (tryParseVectorIndex(Operands) == MatchOperand_ParseFail)
    return MatchOperand_ParseFail;

  return MatchOperand_Success;
}
38063888
/// tryParseSVEPredicateVector - Parse a SVE predicate register operand.
38073889
OperandMatchResultTy
38083890
AArch64AsmParser::tryParseSVEPredicateVector(OperandVector &Operands) {
@@ -5573,6 +5655,15 @@ bool AArch64AsmParser::showMatchError(SMLoc Loc, unsigned ErrCode,
55735655
return Error(Loc, "invalid predicate register.");
55745656
case Match_InvalidSVEPredicate3bAnyReg:
55755657
return Error(Loc, "invalid restricted predicate register, expected p0..p7 (without element suffix)");
5658+
case Match_InvalidSVEPNPredicateB_p8to15Reg:
5659+
case Match_InvalidSVEPNPredicateH_p8to15Reg:
5660+
case Match_InvalidSVEPNPredicateS_p8to15Reg:
5661+
case Match_InvalidSVEPNPredicateD_p8to15Reg:
5662+
return Error(Loc, "Invalid predicate register, expected PN in range "
5663+
"pn8..pn15 with element suffix.");
5664+
case Match_InvalidSVEPNPredicateAny_p8to15Reg:
5665+
return Error(Loc, "invalid restricted predicate-as-counter register "
5666+
"expected pn8..pn15");
55765667
case Match_InvalidSVEExactFPImmOperandHalfOne:
55775668
return Error(Loc, "Invalid floating point constant, expected 0.5 or 1.0.");
55785669
case Match_InvalidSVEExactFPImmOperandHalfTwo:
@@ -6145,6 +6236,11 @@ bool AArch64AsmParser::MatchAndEmitInstruction(SMLoc IDLoc, unsigned &Opcode,
61456236
case Match_InvalidSVEPredicateSReg:
61466237
case Match_InvalidSVEPredicateDReg:
61476238
case Match_InvalidSVEPredicate3bAnyReg:
6239+
case Match_InvalidSVEPNPredicateB_p8to15Reg:
6240+
case Match_InvalidSVEPNPredicateH_p8to15Reg:
6241+
case Match_InvalidSVEPNPredicateS_p8to15Reg:
6242+
case Match_InvalidSVEPNPredicateD_p8to15Reg:
6243+
case Match_InvalidSVEPNPredicateAny_p8to15Reg:
61486244
case Match_InvalidSVEExactFPImmOperandHalfOne:
61496245
case Match_InvalidSVEExactFPImmOperandHalfTwo:
61506246
case Match_InvalidSVEExactFPImmOperandZeroOne:

llvm/lib/Target/AArch64/Disassembler/AArch64Disassembler.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,9 @@ static DecodeStatus DecodePPRRegisterClass(MCInst &Inst, unsigned RegNo,
137137
static DecodeStatus DecodePPR_3bRegisterClass(MCInst &Inst, unsigned RegNo,
138138
uint64_t Address,
139139
const MCDisassembler *Decoder);
140+
// Decoder hook for operands of the PPR_p8to15 register class (registers
// P8..P15).
static DecodeStatus
141+
DecodePPR_p8to15RegisterClass(MCInst &Inst, unsigned RegNo, uint64_t Address,
142+
const MCDisassembler *Decoder);
140143

141144
static DecodeStatus DecodeFixedPointScaleImm32(MCInst &Inst, unsigned Imm,
142145
uint64_t Address,
@@ -709,6 +712,16 @@ static DecodeStatus DecodePPR_3bRegisterClass(MCInst &Inst, unsigned RegNo,
709712
return DecodePPRRegisterClass(Inst, RegNo, Addr, Decoder);
710713
}
711714

715+
/// Decodes a register of the PPR_p8to15 class. \p RegNo is the encoded field
/// value: 0..7 select P8..P15, anything larger is rejected.
static DecodeStatus
DecodePPR_p8to15RegisterClass(MCInst &Inst, unsigned RegNo, uint64_t Addr,
                              const MCDisassembler *Decoder) {
  const bool FitsEncoding = RegNo <= 7;
  if (!FitsEncoding)
    return Fail;

  // The class starts at P8, so bias the field value by 8 and let the generic
  // PPR decoder do the table lookup.
  const unsigned PPRIndex = RegNo + 8;
  return DecodePPRRegisterClass(Inst, PPRIndex, Addr, Decoder);
}
724+
712725
static DecodeStatus DecodeQQRegisterClass(MCInst &Inst, unsigned RegNo,
713726
uint64_t Addr,
714727
const MCDisassembler *Decoder) {

llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1193,6 +1193,35 @@ void AArch64InstPrinter::printRegWithShiftExtend(const MCInst *MI,
11931193
}
11941194
}
11951195

1196+
template <int EltSize>
1197+
void AArch64InstPrinter::printPredicateAsCounter(const MCInst *MI,
1198+
unsigned OpNum,
1199+
const MCSubtargetInfo &STI,
1200+
raw_ostream &O) {
1201+
unsigned Reg = MI->getOperand(OpNum).getReg();
1202+
1203+
assert(Reg <= AArch64::P15 && "Unsupported predicate register");
1204+
O << "pn" << (Reg - AArch64::P0);
1205+
switch (EltSize) {
1206+
case 0:
1207+
break;
1208+
case 8:
1209+
O << ".b";
1210+
break;
1211+
case 16:
1212+
O << ".h";
1213+
break;
1214+
case 32:
1215+
O << ".s";
1216+
break;
1217+
case 64:
1218+
O << ".d";
1219+
break;
1220+
default:
1221+
llvm_unreachable("Unsupported element size");
1222+
}
1223+
}
1224+
11961225
void AArch64InstPrinter::printCondCode(const MCInst *MI, unsigned OpNum,
11971226
const MCSubtargetInfo &STI,
11981227
raw_ostream &O) {

llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,10 @@ class AArch64InstPrinter : public MCInstPrinter {
182182
const MCSubtargetInfo &STI, raw_ostream &O);
183183
void printSIMDType10Operand(const MCInst *MI, unsigned OpNum,
184184
const MCSubtargetInfo &STI, raw_ostream &O);
185+
/// Prints a predicate-as-counter register operand as "pn<N>" followed by the
/// element suffix selected by EltSize (0 prints no suffix).
template <int EltSize>
186+
void printPredicateAsCounter(const MCInst *MI, unsigned OpNum,
187+
const MCSubtargetInfo &STI, raw_ostream &O);
188+
185189
template<int64_t Angle, int64_t Remainder>
186190
void printComplexRotationOp(const MCInst *MI, unsigned OpNo,
187191
const MCSubtargetInfo &STI, raw_ostream &O);

llvm/lib/Target/AArch64/MCTargetDesc/AArch64MCCodeEmitter.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,9 @@ class AArch64MCCodeEmitter : public MCCodeEmitter {
189189
uint32_t EncodeRegAsMultipleOf(const MCInst &MI, unsigned OpIdx,
190190
SmallVectorImpl<MCFixup> &Fixups,
191191
const MCSubtargetInfo &STI) const;
192+
/// Encodes the register operand at OpIdx as its offset from P8.
uint32_t EncodePPR_p8to15(const MCInst &MI, unsigned OpIdx,
193+
SmallVectorImpl<MCFixup> &Fixups,
194+
const MCSubtargetInfo &STI) const;
192195

193196
uint32_t EncodeMatrixTileListRegisterClass(const MCInst &MI, unsigned OpIdx,
194197
SmallVectorImpl<MCFixup> &Fixups,
@@ -533,6 +536,14 @@ AArch64MCCodeEmitter::EncodeRegAsMultipleOf(const MCInst &MI, unsigned OpIdx,
533536
return RegVal / Multiple;
534537
}
535538

539+
uint32_t
540+
AArch64MCCodeEmitter::EncodePPR_p8to15(const MCInst &MI, unsigned OpIdx,
541+
SmallVectorImpl<MCFixup> &Fixups,
542+
const MCSubtargetInfo &STI) const {
543+
auto RegOpnd = MI.getOperand(OpIdx).getReg();
544+
return RegOpnd - AArch64::P8;
545+
}
546+
536547
uint32_t AArch64MCCodeEmitter::EncodeMatrixTileListRegisterClass(
537548
const MCInst &MI, unsigned OpIdx, SmallVectorImpl<MCFixup> &Fixups,
538549
const MCSubtargetInfo &STI) const {

llvm/lib/Target/AArch64/SVEInstrFormats.td

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8695,3 +8695,47 @@ class sve2p1_two_way_dot_vvi<string mnemonic, bit u>
86958695
let Constraints = "$Zda = $_Zda";
86968696
let DestructiveInstType = DestructiveOther;
86978697
}
8698+
8699+
8700+
// Instruction format for the predicate-as-counter form of ptrue: a single
// output register $PNd and no inputs.
//   sz    - 2-bit element-size field, placed in Inst{23-22}.
//   pnrty - register operand type used for $PNd.
class sve2p1_ptrue_pn<string mnemonic, bits<2> sz, PNRP8to15RegOp pnrty>
8701+
: I<(outs pnrty:$PNd), (ins ), mnemonic, "\t$PNd",
8702+
"", []>, Sched<[]> {
8703+
// 3-bit register field: the operand's encoder emits the offset from P8.
bits<3> PNd;
8704+
let Inst{31-24} = 0b00100101;
8705+
let Inst{23-22} = sz;
8706+
// Fixed opcode bits.
let Inst{21-3} = 0b1000000111100000010;
8707+
let Inst{2-0} = PNd;
8708+
}
8709+
8710+
8711+
// Instantiates the ptrue (predicate-as-counter) format once per element size:
// _B/_H/_S/_D use sz = 0b00/0b01/0b10/0b11 and the 8/16/32/64-bit register
// operand respectively.
multiclass sve2p1_ptrue_pn<string mnemonic> {
8712+
def _B : sve2p1_ptrue_pn<mnemonic, 0b00, PNR8_p8to15>;
8713+
def _H : sve2p1_ptrue_pn<mnemonic, 0b01, PNR16_p8to15>;
8714+
def _S : sve2p1_ptrue_pn<mnemonic, 0b10, PNR32_p8to15>;
8715+
def _D : sve2p1_ptrue_pn<mnemonic, 0b11, PNR64_p8to15>;
8716+
}
8717+
8718+
8719+
// Instruction format for pext: produces predicate $Pd from the
// predicate-as-counter register $PNn and the index $imm2.
//   sz    - 2-bit element-size field, placed in Inst{23-22}.
//   pprty - register operand type used for the destination $Pd.
class sve2p1_pred_as_ctr_to_mask<string mnemonic, bits<2> sz, PPRRegOp pprty>
8720+
: I<(outs pprty:$Pd), (ins PNRAny_p8to15:$PNn, VectorIndexS:$imm2),
8721+
// The index operand is printed directly after the source register, with no
// separator between them.
mnemonic, "\t$Pd, $PNn$imm2",
8722+
"", []>, Sched<[]> {
8723+
bits<4> Pd;
8724+
// 3-bit source register field (PNRAny_p8to15 operand).
bits<3> PNn;
8725+
bits<2> imm2;
8726+
let Inst{31-24} = 0b00100101;
8727+
let Inst{23-22} = sz;
8728+
// Fixed opcode bits.
let Inst{21-10} = 0b100000011100;
8729+
let Inst{9-8} = imm2;
8730+
let Inst{7-5} = PNn;
8731+
let Inst{4} = 0b1;
8732+
let Inst{3-0} = Pd;
8733+
}
8734+
8735+
8736+
// Instantiates the pext format once per element size: _B/_H/_S/_D use
// sz = 0b00/0b01/0b10/0b11 with destination operand PPR8/PPR16/PPR32/PPR64.
multiclass sve2p1_pred_as_ctr_to_mask<string mnemonic> {
8737+
def _B : sve2p1_pred_as_ctr_to_mask<mnemonic, 0b00, PPR8>;
8738+
def _H : sve2p1_pred_as_ctr_to_mask<mnemonic, 0b01, PPR16>;
8739+
def _S : sve2p1_pred_as_ctr_to_mask<mnemonic, 0b10, PPR32>;
8740+
def _D : sve2p1_pred_as_ctr_to_mask<mnemonic, 0b11, PPR64>;
8741+
}

0 commit comments

Comments
 (0)