Skip to content

Commit 6f8d278

Browse files
[SandboxIR] Add missing VectorType functions (#107650)
Fills in many missing functions from VectorType
1 parent 53a81d4 commit 6f8d278

File tree

3 files changed

+108
-8
lines changed

3 files changed

+108
-8
lines changed

llvm/include/llvm/SandboxIR/Type.h

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ class Type {
5050
friend class ConstantArray; // For LLVMTy.
5151
friend class ConstantStruct; // For LLVMTy.
5252
friend class ConstantVector; // For LLVMTy.
53-
friend class CmpInst; // For LLVMTy. TODO: Cleanup after sandboxir::VectorType
54-
// is more complete.
53+
friend class CmpInst; // For LLVMTy. TODO: Cleanup after
54+
// sandboxir::VectorType is more complete.
5555

5656
// Friend all instruction classes because `create()` functions use LLVMTy.
5757
#define DEF_INSTR(ID, OPCODE, CLASS) friend class CLASS;
@@ -317,7 +317,28 @@ class StructType : public Type {
317317
class VectorType : public Type {
318318
public:
319319
static VectorType *get(Type *ElementType, ElementCount EC);
320-
// TODO: add missing functions
320+
static VectorType *get(Type *ElementType, unsigned NumElements,
321+
bool Scalable) {
322+
return VectorType::get(ElementType,
323+
ElementCount::get(NumElements, Scalable));
324+
}
325+
Type *getElementType() const;
326+
327+
static VectorType *get(Type *ElementType, const VectorType *Other) {
328+
return VectorType::get(ElementType, Other->getElementCount());
329+
}
330+
331+
inline ElementCount getElementCount() const {
332+
return cast<llvm::VectorType>(LLVMTy)->getElementCount();
333+
}
334+
static VectorType *getInteger(VectorType *VTy);
335+
static VectorType *getExtendedElementVectorType(VectorType *VTy);
336+
static VectorType *getTruncatedElementVectorType(VectorType *VTy);
337+
static VectorType *getSubdividedVectorType(VectorType *VTy, int NumSubdivs);
338+
static VectorType *getHalfElementsVectorType(VectorType *VTy);
339+
static VectorType *getDoubleElementsVectorType(VectorType *VTy);
340+
static bool isValidElementType(Type *ElemTy);
341+
321342
static bool classof(const Type *From) {
322343
return isa<llvm::VectorType>(From->LLVMTy);
323344
}

llvm/lib/SandboxIR/Type.cpp

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ Type *Type::getDoubleTy(Context &Ctx) {
3636
Type *Type::getFloatTy(Context &Ctx) {
3737
return Ctx.getType(llvm::Type::getFloatTy(Ctx.LLVMCtx));
3838
}
39-
4039
PointerType *PointerType::get(Type *ElementType, unsigned AddressSpace) {
4140
return cast<PointerType>(ElementType->getContext().getType(
4241
llvm::PointerType::get(ElementType->LLVMTy, AddressSpace)));
@@ -67,6 +66,43 @@ VectorType *VectorType::get(Type *ElementType, ElementCount EC) {
6766
llvm::VectorType::get(ElementType->LLVMTy, EC)));
6867
}
6968

69+
Type *VectorType::getElementType() const {
70+
return Ctx.getType(cast<llvm::VectorType>(LLVMTy)->getElementType());
71+
}
72+
VectorType *VectorType::getInteger(VectorType *VTy) {
73+
return cast<VectorType>(VTy->getContext().getType(
74+
llvm::VectorType::getInteger(cast<llvm::VectorType>(VTy->LLVMTy))));
75+
}
76+
VectorType *VectorType::getExtendedElementVectorType(VectorType *VTy) {
77+
return cast<VectorType>(
78+
VTy->getContext().getType(llvm::VectorType::getExtendedElementVectorType(
79+
cast<llvm::VectorType>(VTy->LLVMTy))));
80+
}
81+
VectorType *VectorType::getTruncatedElementVectorType(VectorType *VTy) {
82+
return cast<VectorType>(
83+
VTy->getContext().getType(llvm::VectorType::getTruncatedElementVectorType(
84+
cast<llvm::VectorType>(VTy->LLVMTy))));
85+
}
86+
VectorType *VectorType::getSubdividedVectorType(VectorType *VTy,
87+
int NumSubdivs) {
88+
return cast<VectorType>(
89+
VTy->getContext().getType(llvm::VectorType::getSubdividedVectorType(
90+
cast<llvm::VectorType>(VTy->LLVMTy), NumSubdivs)));
91+
}
92+
VectorType *VectorType::getHalfElementsVectorType(VectorType *VTy) {
93+
return cast<VectorType>(
94+
VTy->getContext().getType(llvm::VectorType::getHalfElementsVectorType(
95+
cast<llvm::VectorType>(VTy->LLVMTy))));
96+
}
97+
VectorType *VectorType::getDoubleElementsVectorType(VectorType *VTy) {
98+
return cast<VectorType>(
99+
VTy->getContext().getType(llvm::VectorType::getDoubleElementsVectorType(
100+
cast<llvm::VectorType>(VTy->LLVMTy))));
101+
}
102+
bool VectorType::isValidElementType(Type *ElemTy) {
103+
return llvm::VectorType::isValidElementType(ElemTy->LLVMTy);
104+
}
105+
70106
IntegerType *IntegerType::get(Context &Ctx, unsigned NumBits) {
71107
return cast<IntegerType>(
72108
Ctx.getType(llvm::IntegerType::get(Ctx.LLVMCtx, NumBits)));

llvm/unittests/SandboxIR/TypesTest.cpp

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -268,16 +268,59 @@ define void @foo({i32, i8} %v0) {
268268

269269
TEST_F(SandboxTypeTest, VectorType) {
270270
parseIR(C, R"IR(
271-
define void @foo(<2 x i8> %v0) {
271+
define void @foo(<4 x i16> %vi0, <4 x float> %vf1, i8 %i0) {
272272
ret void
273273
}
274274
)IR");
275275
llvm::Function *LLVMF = &*M->getFunction("foo");
276276
sandboxir::Context Ctx(C);
277277
auto *F = Ctx.createFunction(LLVMF);
278-
// Check classof(), creation.
279-
[[maybe_unused]] auto *VecTy =
280-
cast<sandboxir::VectorType>(F->getArg(0)->getType());
278+
// Check classof(), creation, accessors
279+
auto *VecTy = cast<sandboxir::VectorType>(F->getArg(0)->getType());
280+
EXPECT_TRUE(VecTy->getElementType()->isIntegerTy(16));
281+
EXPECT_EQ(VecTy->getElementCount(), ElementCount::getFixed(4));
282+
283+
// get(ElementType, NumElements, Scalable)
284+
EXPECT_EQ(sandboxir::VectorType::get(sandboxir::Type::getInt16Ty(Ctx), 4,
285+
/*Scalable=*/false),
286+
F->getArg(0)->getType());
287+
// get(ElementType, Other)
288+
EXPECT_EQ(sandboxir::VectorType::get(
289+
sandboxir::Type::getInt16Ty(Ctx),
290+
cast<sandboxir::VectorType>(F->getArg(0)->getType())),
291+
F->getArg(0)->getType());
292+
auto *FVecTy = cast<sandboxir::VectorType>(F->getArg(1)->getType());
293+
EXPECT_TRUE(FVecTy->getElementType()->isFloatTy());
294+
// getInteger
295+
auto *IVecTy = sandboxir::VectorType::getInteger(FVecTy);
296+
EXPECT_TRUE(IVecTy->getElementType()->isIntegerTy(32));
297+
EXPECT_EQ(IVecTy->getElementCount(), FVecTy->getElementCount());
298+
// getExtendedElementCountVectorType
299+
auto *ExtVecTy = sandboxir::VectorType::getExtendedElementVectorType(IVecTy);
300+
EXPECT_TRUE(ExtVecTy->getElementType()->isIntegerTy(64));
301+
EXPECT_EQ(ExtVecTy->getElementCount(), VecTy->getElementCount());
302+
// getTruncatedElementVectorType
303+
auto *TruncVecTy =
304+
sandboxir::VectorType::getTruncatedElementVectorType(IVecTy);
305+
EXPECT_TRUE(TruncVecTy->getElementType()->isIntegerTy(16));
306+
EXPECT_EQ(TruncVecTy->getElementCount(), VecTy->getElementCount());
307+
// getSubdividedVectorType
308+
auto *SubVecTy = sandboxir::VectorType::getSubdividedVectorType(VecTy, 1);
309+
EXPECT_TRUE(SubVecTy->getElementType()->isIntegerTy(8));
310+
EXPECT_EQ(SubVecTy->getElementCount(), ElementCount::getFixed(8));
311+
// getHalfElementsVectorType
312+
auto *HalfVecTy = sandboxir::VectorType::getHalfElementsVectorType(VecTy);
313+
EXPECT_TRUE(HalfVecTy->getElementType()->isIntegerTy(16));
314+
EXPECT_EQ(HalfVecTy->getElementCount(), ElementCount::getFixed(2));
315+
// getDoubleElementsVectorType
316+
auto *DoubleVecTy = sandboxir::VectorType::getDoubleElementsVectorType(VecTy);
317+
EXPECT_TRUE(DoubleVecTy->getElementType()->isIntegerTy(16));
318+
EXPECT_EQ(DoubleVecTy->getElementCount(), ElementCount::getFixed(8));
319+
// isValidElementType
320+
auto *I8Type = F->getArg(2)->getType();
321+
EXPECT_TRUE(I8Type->isIntegerTy());
322+
EXPECT_TRUE(sandboxir::VectorType::isValidElementType(I8Type));
323+
EXPECT_FALSE(sandboxir::VectorType::isValidElementType(FVecTy));
281324
}
282325

283326
TEST_F(SandboxTypeTest, FunctionType) {

0 commit comments

Comments
 (0)