Skip to content

[llvm][RISCV] Add RISCV vector tuple type to value types(MVT) #97993

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions llvm/include/llvm/CodeGen/ValueTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,9 @@ namespace llvm {
return isSimple() ? V.isScalableVector() : isExtendedScalableVector();
}

/// Return true if this is a vector value type.
bool isRISCVVectorTuple() const { return V.isRISCVVectorTuple(); }

bool isFixedLengthVector() const {
return isSimple() ? V.isFixedLengthVector()
: isExtendedFixedLengthVector();
Expand Down Expand Up @@ -351,6 +354,11 @@ namespace llvm {
return getVectorElementCount().getKnownMinValue();
}

/// Given a RISCV vector tuple type, return the num_fields.
unsigned getRISCVVectorTupleNumFields() const {
return V.getRISCVVectorTupleNumFields();
}

/// Return the size of the specified value type in bits.
///
/// If the value type is a scalable vector type, the scalable property will
Expand Down
65 changes: 54 additions & 11 deletions llvm/include/llvm/CodeGen/ValueTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ class ValueType<int size, int value> {
bit isFP = false;
bit isVector = false;
bit isScalable = false;
int NF = 0;
bit isRISCVVecTuple = false;
// Indicates this VT should be included in the
// [FIRST_VALUETYPE,LAST_VALUETYPE] range.
bit isNormalValueType = true;
Expand Down Expand Up @@ -56,6 +58,13 @@ class VTScalableVec<int nelem, ValueType elt, int value>
let isScalable = true;
}

class VTVecTup<int size, int nf, ValueType dummy_elt, int value>
: ValueType<size, value> {
let NF = nf;
let ElementType = dummy_elt;
let isRISCVVecTuple = true;
}

defset list<ValueType> ValueTypes = {

def OtherVT : ValueType<0, 1> { // "Other" value
Expand Down Expand Up @@ -273,20 +282,54 @@ def nxv2f64 : VTScalableVec<2, f64, 187>; // n x 2 x f64 vector value
def nxv4f64 : VTScalableVec<4, f64, 188>; // n x 4 x f64 vector value
def nxv8f64 : VTScalableVec<8, f64, 189>; // n x 8 x f64 vector value

def x86mmx : ValueType<64, 190>; // X86 MMX value
def Glue : ValueType<0, 191>; // Pre-RA sched glue
def isVoid : ValueType<0, 192>; // Produces no value
def untyped : ValueType<8, 193> { // Produces an untyped value
// Sz = NF * MinNumElts * 8(bits)
def riscv_nxv1i8x2 : VTVecTup<16, 2, i8, 190>; // RISCV vector tuple(min_num_elts=1, nf=2)
def riscv_nxv1i8x3 : VTVecTup<24, 3, i8, 191>; // RISCV vector tuple(min_num_elts=1, nf=3)
def riscv_nxv1i8x4 : VTVecTup<32, 4, i8, 192>; // RISCV vector tuple(min_num_elts=1, nf=4)
def riscv_nxv1i8x5 : VTVecTup<40, 5, i8, 193>; // RISCV vector tuple(min_num_elts=1, nf=5)
def riscv_nxv1i8x6 : VTVecTup<48, 6, i8, 194>; // RISCV vector tuple(min_num_elts=1, nf=6)
def riscv_nxv1i8x7 : VTVecTup<56, 7, i8, 195>; // RISCV vector tuple(min_num_elts=1, nf=7)
def riscv_nxv1i8x8 : VTVecTup<64, 8, i8, 196>; // RISCV vector tuple(min_num_elts=1, nf=8)
def riscv_nxv2i8x2 : VTVecTup<32, 2, i8, 197>; // RISCV vector tuple(min_num_elts=2, nf=2)
def riscv_nxv2i8x3 : VTVecTup<48, 3, i8, 198>; // RISCV vector tuple(min_num_elts=2, nf=3)
def riscv_nxv2i8x4 : VTVecTup<64, 4, i8, 199>; // RISCV vector tuple(min_num_elts=2, nf=4)
def riscv_nxv2i8x5 : VTVecTup<80, 5, i8, 200>; // RISCV vector tuple(min_num_elts=2, nf=5)
def riscv_nxv2i8x6 : VTVecTup<96, 6, i8, 201>; // RISCV vector tuple(min_num_elts=2, nf=6)
def riscv_nxv2i8x7 : VTVecTup<112, 7, i8, 202>; // RISCV vector tuple(min_num_elts=2, nf=7)
def riscv_nxv2i8x8 : VTVecTup<128, 8, i8, 203>; // RISCV vector tuple(min_num_elts=2, nf=8)
def riscv_nxv4i8x2 : VTVecTup<64, 2, i8, 204>; // RISCV vector tuple(min_num_elts=4, nf=2)
def riscv_nxv4i8x3 : VTVecTup<96, 3, i8, 205>; // RISCV vector tuple(min_num_elts=4, nf=3)
def riscv_nxv4i8x4 : VTVecTup<128, 4, i8, 206>; // RISCV vector tuple(min_num_elts=4, nf=4)
def riscv_nxv4i8x5 : VTVecTup<160, 5, i8, 207>; // RISCV vector tuple(min_num_elts=4, nf=5)
def riscv_nxv4i8x6 : VTVecTup<192, 6, i8, 208>; // RISCV vector tuple(min_num_elts=4, nf=6)
def riscv_nxv4i8x7 : VTVecTup<224, 7, i8, 209>; // RISCV vector tuple(min_num_elts=4, nf=7)
def riscv_nxv4i8x8 : VTVecTup<256, 8, i8, 210>; // RISCV vector tuple(min_num_elts=4, nf=8)
def riscv_nxv8i8x2 : VTVecTup<128, 2, i8, 211>; // RISCV vector tuple(min_num_elts=8, nf=2)
def riscv_nxv8i8x3 : VTVecTup<192, 3, i8, 212>; // RISCV vector tuple(min_num_elts=8, nf=3)
def riscv_nxv8i8x4 : VTVecTup<256, 4, i8, 213>; // RISCV vector tuple(min_num_elts=8, nf=4)
def riscv_nxv8i8x5 : VTVecTup<320, 5, i8, 214>; // RISCV vector tuple(min_num_elts=8, nf=5)
def riscv_nxv8i8x6 : VTVecTup<384, 6, i8, 215>; // RISCV vector tuple(min_num_elts=8, nf=6)
def riscv_nxv8i8x7 : VTVecTup<448, 7, i8, 216>; // RISCV vector tuple(min_num_elts=8, nf=7)
def riscv_nxv8i8x8 : VTVecTup<512, 8, i8, 217>; // RISCV vector tuple(min_num_elts=8, nf=8)
def riscv_nxv16i8x2 : VTVecTup<256, 2, i8, 218>; // RISCV vector tuple(min_num_elts=16, nf=2)
def riscv_nxv16i8x3 : VTVecTup<384, 3, i8, 219>; // RISCV vector tuple(min_num_elts=16, nf=3)
def riscv_nxv16i8x4 : VTVecTup<512, 4, i8, 220>; // RISCV vector tuple(min_num_elts=16, nf=4)
def riscv_nxv32i8x2 : VTVecTup<512, 2, i8, 221>; // RISCV vector tuple(min_num_elts=32, nf=2)

def x86mmx : ValueType<64, 222>; // X86 MMX value
def Glue : ValueType<0, 223>; // Pre-RA sched glue
def isVoid : ValueType<0, 224>; // Produces no value
def untyped : ValueType<8, 225> { // Produces an untyped value
let LLVMName = "Untyped";
}
def funcref : ValueType<0, 194>; // WebAssembly's funcref type
def externref : ValueType<0, 195>; // WebAssembly's externref type
def exnref : ValueType<0, 196>; // WebAssembly's exnref type
def x86amx : ValueType<8192, 197>; // X86 AMX value
def i64x8 : ValueType<512, 198>; // 8 Consecutive GPRs (AArch64)
def funcref : ValueType<0, 226>; // WebAssembly's funcref type
def externref : ValueType<0, 227>; // WebAssembly's externref type
def exnref : ValueType<0, 228>; // WebAssembly's exnref type
def x86amx : ValueType<8192, 229>; // X86 AMX value
def i64x8 : ValueType<512, 230>; // 8 Consecutive GPRs (AArch64)
def aarch64svcount
: ValueType<16, 199>; // AArch64 predicate-as-counter
def spirvbuiltin : ValueType<0, 200>; // SPIR-V's builtin type
: ValueType<16, 231>; // AArch64 predicate-as-counter
def spirvbuiltin : ValueType<0, 232>; // SPIR-V's builtin type

let isNormalValueType = false in {
def token : ValueType<0, 504>; // TokenTy
Expand Down
53 changes: 43 additions & 10 deletions llvm/include/llvm/CodeGenTypes/MachineValueType.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ namespace llvm {
// are considered extended value types.
INVALID_SIMPLE_VALUE_TYPE = 0,

#define GET_VT_ATTR(Ty, n, sz, Any, Int, FP, Vec, Sc, NElem, EltTy) Ty = n,
#define GET_VT_ATTR(Ty, n, sz, Any, Int, FP, Vec, Sc, Tup, NF, NElem, EltTy) \
Ty = n,
#define GET_VT_RANGES
#include "llvm/CodeGen/GenVT.inc"
#undef GET_VT_ATTR
Expand Down Expand Up @@ -114,6 +115,13 @@ namespace llvm {
SimpleTy <= MVT::LAST_SCALABLE_VECTOR_VALUETYPE);
}

/// Return true if this is a RISCV vector tuple type where the
/// runtime length is machine dependent
bool isRISCVVectorTuple() const {
return (SimpleTy >= MVT::FIRST_RISCV_VECTOR_TUPLE_VALUETYPE &&
SimpleTy <= MVT::LAST_RISCV_VECTOR_TUPLE_VALUETYPE);
}

/// Return true if this is a custom target type that has a scalable size.
bool isScalableTargetExtVT() const {
return SimpleTy == MVT::aarch64svcount;
Expand Down Expand Up @@ -172,7 +180,7 @@ namespace llvm {
/// Return true if this is an overloaded type for TableGen.
bool isOverloaded() const {
switch (SimpleTy) {
#define GET_VT_ATTR(Ty, n, sz, Any, Int, FP, Vec, Sc, NElem, EltTy) \
#define GET_VT_ATTR(Ty, n, sz, Any, Int, FP, Vec, Sc, Tup, NF, NElem, EltTy) \
case Ty: \
return Any;
#include "llvm/CodeGen/GenVT.inc"
Expand Down Expand Up @@ -255,7 +263,8 @@ namespace llvm {
MVT getVectorElementType() const {
assert(SimpleTy >= FIRST_VALUETYPE && SimpleTy <= LAST_VALUETYPE);
static constexpr SimpleValueType EltTyTable[] = {
#define GET_VT_ATTR(Ty, N, Sz, Any, Int, FP, Vec, Sc, NElem, EltTy) EltTy,
#define GET_VT_ATTR(Ty, N, Sz, Any, Int, FP, Vec, Sc, Tup, NF, NElem, EltTy) \
EltTy,
#include "llvm/CodeGen/GenVT.inc"
#undef GET_VT_ATTR
};
Expand All @@ -268,7 +277,8 @@ namespace llvm {
unsigned getVectorMinNumElements() const {
assert(SimpleTy >= FIRST_VALUETYPE && SimpleTy <= LAST_VALUETYPE);
static constexpr uint16_t NElemTable[] = {
#define GET_VT_ATTR(Ty, N, Sz, Any, Int, FP, Vec, Sc, NElem, EltTy) NElem,
#define GET_VT_ATTR(Ty, N, Sz, Any, Int, FP, Vec, Sc, Tup, NF, NElem, EltTy) \
NElem,
#include "llvm/CodeGen/GenVT.inc"
#undef GET_VT_ATTR
};
Expand Down Expand Up @@ -297,7 +307,7 @@ namespace llvm {
/// base size.
TypeSize getSizeInBits() const {
static constexpr TypeSize SizeTable[] = {
#define GET_VT_ATTR(Ty, N, Sz, Any, Int, FP, Vec, Sc, NElem, EltTy) \
#define GET_VT_ATTR(Ty, N, Sz, Any, Int, FP, Vec, Sc, Tup, NF, NElem, EltTy) \
TypeSize(Sz, Sc || Ty == aarch64svcount /* FIXME: Not in the td. */),
#include "llvm/CodeGen/GenVT.inc"
#undef GET_VT_ATTR
Expand Down Expand Up @@ -419,7 +429,7 @@ namespace llvm {
}

static MVT getFloatingPointVT(unsigned BitWidth) {
#define GET_VT_ATTR(Ty, n, sz, Any, Int, FP, Vec, Sc, NElem, EltTy) \
#define GET_VT_ATTR(Ty, n, sz, Any, Int, FP, Vec, Sc, Tup, NF, NElem, EltTy) \
if (FP == 3 && sz == BitWidth) \
return Ty;
#include "llvm/CodeGen/GenVT.inc"
Expand All @@ -429,7 +439,7 @@ namespace llvm {
}

static MVT getIntegerVT(unsigned BitWidth) {
#define GET_VT_ATTR(Ty, n, sz, Any, Int, FP, Vec, Sc, NElem, EltTy) \
#define GET_VT_ATTR(Ty, n, sz, Any, Int, FP, Vec, Sc, Tup, NF, NElem, EltTy) \
if (Int == 3 && sz == BitWidth) \
return Ty;
#include "llvm/CodeGen/GenVT.inc"
Expand All @@ -439,8 +449,8 @@ namespace llvm {
}

static MVT getVectorVT(MVT VT, unsigned NumElements) {
#define GET_VT_VECATTR(Ty, Sc, nElem, ElTy) \
if (!Sc && VT.SimpleTy == ElTy && NumElements == nElem) \
#define GET_VT_VECATTR(Ty, Sc, Tup, nElem, ElTy) \
if (!Sc && !Tup && VT.SimpleTy == ElTy && NumElements == nElem) \
return Ty;
#include "llvm/CodeGen/GenVT.inc"
#undef GET_VT_VECATTR
Expand All @@ -449,7 +459,7 @@ namespace llvm {
}

static MVT getScalableVectorVT(MVT VT, unsigned NumElements) {
#define GET_VT_VECATTR(Ty, Sc, nElem, ElTy) \
#define GET_VT_VECATTR(Ty, Sc, Tup, nElem, ElTy) \
if (Sc && VT.SimpleTy == ElTy && NumElements == nElem) \
return Ty;
#include "llvm/CodeGen/GenVT.inc"
Expand All @@ -458,6 +468,29 @@ namespace llvm {
return (MVT::SimpleValueType)(MVT::INVALID_SIMPLE_VALUE_TYPE);
}

static MVT getRISCVVectorTupleVT(unsigned Sz, unsigned NFields) {
#define GET_VT_ATTR(Ty, n, sz, Any, Int, FP, Vec, Sc, Tup, NF, nElem, EltTy) \
if (Tup && sz == Sz && NF == NFields) \
return Ty;
#include "llvm/CodeGen/GenVT.inc"
#undef GET_VT_ATTR

llvm_unreachable("Invalid RISCV vector tuple type");
}

/// Given a RISC-V vector tuple type, return the num_fields.
unsigned getRISCVVectorTupleNumFields() const {
assert(isRISCVVectorTuple() && SimpleTy >= FIRST_VALUETYPE &&
SimpleTy <= LAST_VALUETYPE);
static constexpr uint8_t NFTable[] = {
#define GET_VT_ATTR(Ty, N, Sz, Any, Int, FP, Vec, Sc, Tup, NF, NElem, EltTy) \
NF,
#include "llvm/CodeGen/GenVT.inc"
#undef GET_VT_ATTR
};
return NFTable[SimpleTy - FIRST_VALUETYPE];
}

static MVT getVectorVT(MVT VT, unsigned NumElements, bool IsScalable) {
if (IsScalable)
return getScalableVectorVT(VT, NumElements);
Expand Down
14 changes: 14 additions & 0 deletions llvm/lib/CodeGen/ValueTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,12 @@ TypeSize EVT::getExtendedSizeInBits() const {
std::string EVT::getEVTString() const {
switch (V.SimpleTy) {
default:
if (isRISCVVectorTuple()) {
unsigned Sz = getSizeInBits();
unsigned NF = getRISCVVectorTupleNumFields();
unsigned MinNumElts = Sz / (NF * 8);
return "riscv_nxv" + utostr(MinNumElts) + "i8x" + utostr(NF);
}
if (isVector())
return (isScalableVector() ? "nxv" : "v") +
utostr(getVectorElementCount().getKnownMinValue()) +
Expand Down Expand Up @@ -250,6 +256,14 @@ MVT MVT::getVT(Type *Ty, bool HandleUnknown){
return MVT(MVT::aarch64svcount);
else if (TargetExtTy->getName().starts_with("spirv."))
return MVT(MVT::spirvbuiltin);
if (TargetExtTy->getName() == "riscv.vector.tuple") {
unsigned Sz = cast<ScalableVectorType>(TargetExtTy->getTypeParameter(0))
->getMinNumElements() *
8;
unsigned NF = TargetExtTy->getIntParameter(0);

return MVT::getRISCVVectorTupleVT(Sz * NF, NF);
}
if (HandleUnknown)
return MVT(MVT::Other);
llvm_unreachable("Unknown target ext type!");
Expand Down
2 changes: 1 addition & 1 deletion llvm/utils/TableGen/Common/CodeGenTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ StringRef llvm::getName(MVT::SimpleValueType T) {
StringRef llvm::getEnumName(MVT::SimpleValueType T) {
// clang-format off
switch (T) {
#define GET_VT_ATTR(Ty, N, Sz, Any, Int, FP, Vec, Sc, NElem, EltTy) \
#define GET_VT_ATTR(Ty, N, Sz, Any, Int, FP, Vec, Sc, Tup, NF, NElem, EltTy) \
case MVT::Ty: return "MVT::" # Ty;
#include "llvm/CodeGen/GenVT.inc"
default: llvm_unreachable("ILLEGAL VALUE TYPE!");
Expand Down
24 changes: 21 additions & 3 deletions llvm/utils/TableGen/VTEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,17 @@ class VTEmitter {

static void VTtoGetLLVMTyString(raw_ostream &OS, const Record *VT) {
bool IsVector = VT->getValueAsBit("isVector");
bool IsRISCVVecTuple = VT->getValueAsBit("isRISCVVecTuple");

if (IsRISCVVecTuple) {
unsigned NElem = VT->getValueAsInt("nElem");
unsigned Sz = VT->getValueAsInt("Size");
OS << "TargetExtType::get(Context, \"riscv.vector.tuple\", "
"ScalableVectorType::get(Type::getInt8Ty(Context), "
<< (Sz / (NElem * 8)) << "), " << NElem << ")";
return;
}

if (IsVector)
OS << (VT->getValueAsBit("isScalable") ? "Scalable" : "Fixed")
<< "VectorType::get(";
Expand Down Expand Up @@ -109,7 +120,7 @@ void VTEmitter::run(raw_ostream &OS) {
}
};

OS << "#ifdef GET_VT_ATTR // (Ty, n, sz, Any, Int, FP, Vec, Sc)\n";
OS << "#ifdef GET_VT_ATTR // (Ty, n, sz, Any, Int, FP, Vec, Sc, Tup, NF)\n";
for (const auto *VT : VTsByNumber) {
if (!VT)
continue;
Expand All @@ -119,6 +130,8 @@ void VTEmitter::run(raw_ostream &OS) {
bool IsFP = VT->getValueAsBit("isFP");
bool IsVector = VT->getValueAsBit("isVector");
bool IsScalable = VT->getValueAsBit("isScalable");
bool IsRISCVVecTuple = VT->getValueAsBit("isRISCVVecTuple");
int64_t NF = VT->getValueAsInt("NF");
bool IsNormalValueType = VT->getValueAsBit("isNormalValueType");
int64_t NElem = IsVector ? VT->getValueAsInt("nElem") : 0;
StringRef EltName = IsVector ? VT->getValueAsDef("ElementType")->getName()
Expand All @@ -133,6 +146,7 @@ void VTEmitter::run(raw_ostream &OS) {
UpdateVTRange("FP_SCALABLE_VECTOR_VALUETYPE", Name, IsFP && IsScalable);
UpdateVTRange("FIXEDLEN_VECTOR_VALUETYPE", Name, IsVector && !IsScalable);
UpdateVTRange("SCALABLE_VECTOR_VALUETYPE", Name, IsScalable);
UpdateVTRange("RISCV_VECTOR_TUPLE_VALUETYPE", Name, IsRISCVVecTuple);
UpdateVTRange("VECTOR_VALUETYPE", Name, IsVector);
UpdateVTRange("INTEGER_VALUETYPE", Name, IsInteger && !IsVector);
UpdateVTRange("FP_VALUETYPE", Name, IsFP && !IsVector);
Expand All @@ -148,6 +162,8 @@ void VTEmitter::run(raw_ostream &OS) {
<< (IsFP ? Name[0] == 'f' ? 3 : 1 : 0) << ", "
<< IsVector << ", "
<< IsScalable << ", "
<< IsRISCVVecTuple << ", "
<< NF << ", "
<< NElem << ", "
<< EltName << ")\n";
// clang-format on
Expand All @@ -162,7 +178,7 @@ void VTEmitter::run(raw_ostream &OS) {
}
OS << "#endif\n\n";

OS << "#ifdef GET_VT_VECATTR // (Ty, Sc, nElem, ElTy)\n";
OS << "#ifdef GET_VT_VECATTR // (Ty, Sc, Tup, nElem, ElTy)\n";
for (const auto *VT : VTsByNumber) {
if (!VT || !VT->getValueAsBit("isVector"))
continue;
Expand All @@ -172,6 +188,7 @@ void VTEmitter::run(raw_ostream &OS) {
OS << " GET_VT_VECATTR("
<< VT->getValueAsString("LLVMName") << ", "
<< VT->getValueAsBit("isScalable") << ", "
<< VT->getValueAsBit("isRISCVVecTuple") << ", "
<< VT->getValueAsInt("nElem") << ", "
<< ElTy->getName() << ")\n";
// clang-format on
Expand All @@ -185,8 +202,9 @@ void VTEmitter::run(raw_ostream &OS) {
bool IsInteger = VT->getValueAsBit("isInteger");
bool IsVector = VT->getValueAsBit("isVector");
bool IsFP = VT->getValueAsBit("isFP");
bool IsRISCVVecTuple = VT->getValueAsBit("isRISCVVecTuple");

if (!IsInteger && !IsVector && !IsFP)
if (!IsInteger && !IsVector && !IsFP && !IsRISCVVecTuple)
continue;

OS << " GET_VT_EVT(" << VT->getValueAsString("LLVMName") << ", ";
Expand Down
Loading