Skip to content

Commit 0a443f1

Browse files
[SPIR-V] Add implementation of G_SPLAT_VECTOR opcode and fix invalid types processing (#84766)
This PR: * adds support for G_SPLAT_VECTOR generic opcode that may be legally generated instead of G_BUILD_VECTOR by previous passes of the translator (see #80378 for the source of breaking changes); * improves deduction of types for opaque pointers. This PR also fixes the following issues: * if a function has ptr argument(s), two functions that have different SPIR-V type definitions may get identical LLVM function types and break agreements of global register and duplicate checker; * checks for pointer types do not account for TypedPointerType. Update of tests: * A test case is added to cover the issue with function ptr parameters. * The first case, that is support for G_SPLAT_VECTOR generic opcode, is covered by existing test cases. * Multiple additional checks by `spirv-val` is added to cover more possibilities of generation of invalid code.
1 parent cd2f616 commit 0a443f1

File tree

71 files changed

+382
-56
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

71 files changed

+382
-56
lines changed

llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,42 @@ static ConstantInt *getConstInt(MDNode *MD, unsigned NumOp) {
8585
return nullptr;
8686
}
8787

88+
// If the function has pointer arguments, we are forced to re-create this
89+
// function type from the very beginning, changing PointerType by
90+
// TypedPointerType for each pointer argument. Otherwise, the same `Type*`
91+
// potentially corresponds to different SPIR-V function type, effectively
92+
// invalidating logic behind global registry and duplicates tracker.
93+
static FunctionType *
94+
fixFunctionTypeIfPtrArgs(SPIRVGlobalRegistry *GR, const Function &F,
95+
FunctionType *FTy, const SPIRVType *SRetTy,
96+
const SmallVector<SPIRVType *, 4> &SArgTys) {
97+
if (F.getParent()->getNamedMetadata("spv.cloned_funcs"))
98+
return FTy;
99+
100+
bool hasArgPtrs = false;
101+
for (auto &Arg : F.args()) {
102+
// check if it's an instance of a non-typed PointerType
103+
if (Arg.getType()->isPointerTy()) {
104+
hasArgPtrs = true;
105+
break;
106+
}
107+
}
108+
if (!hasArgPtrs) {
109+
Type *RetTy = FTy->getReturnType();
110+
// check if it's an instance of a non-typed PointerType
111+
if (!RetTy->isPointerTy())
112+
return FTy;
113+
}
114+
115+
// re-create function type, using TypedPointerType instead of PointerType to
116+
// properly trace argument types
117+
const Type *RetTy = GR->getTypeForSPIRVType(SRetTy);
118+
SmallVector<Type *, 4> ArgTys;
119+
for (auto SArgTy : SArgTys)
120+
ArgTys.push_back(const_cast<Type *>(GR->getTypeForSPIRVType(SArgTy)));
121+
return FunctionType::get(const_cast<Type *>(RetTy), ArgTys, false);
122+
}
123+
88124
// This code restores function args/retvalue types for composite cases
89125
// because the final types should still be aggregate whereas they're i32
90126
// during the translation to cope with aggregate flattening etc.
@@ -162,7 +198,7 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
162198

163199
// If OriginalArgType is non-pointer, use the OriginalArgType (the type cannot
164200
// be legally reassigned later).
165-
if (!OriginalArgType->isPointerTy())
201+
if (!isPointerTy(OriginalArgType))
166202
return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
167203

168204
// In case OriginalArgType is of pointer type, there are three possibilities:
@@ -179,8 +215,7 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
179215
SPIRVType *ElementType = GR->getOrCreateSPIRVType(ByValRefType, MIRBuilder);
180216
return GR->getOrCreateSPIRVPointerType(
181217
ElementType, MIRBuilder,
182-
addressSpaceToStorageClass(Arg->getType()->getPointerAddressSpace(),
183-
ST));
218+
addressSpaceToStorageClass(getPointerAddressSpace(Arg->getType()), ST));
184219
}
185220

186221
for (auto User : Arg->users()) {
@@ -240,7 +275,6 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
240275
static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
241276

242277
// Assign types and names to all args, and store their types for later.
243-
FunctionType *FTy = getOriginalFunctionType(F);
244278
SmallVector<SPIRVType *, 4> ArgTypeVRegs;
245279
if (VRegs.size() > 0) {
246280
unsigned i = 0;
@@ -255,7 +289,7 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
255289

256290
if (Arg.hasName())
257291
buildOpName(VRegs[i][0], Arg.getName(), MIRBuilder);
258-
if (Arg.getType()->isPointerTy()) {
292+
if (isPointerTy(Arg.getType())) {
259293
auto DerefBytes = static_cast<unsigned>(Arg.getDereferenceableBytes());
260294
if (DerefBytes != 0)
261295
buildOpDecorate(VRegs[i][0], MIRBuilder,
@@ -322,7 +356,9 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
322356
MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass);
323357
if (F.isDeclaration())
324358
GR->add(&F, &MIRBuilder.getMF(), FuncVReg);
359+
FunctionType *FTy = getOriginalFunctionType(F);
325360
SPIRVType *RetTy = GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder);
361+
FTy = fixFunctionTypeIfPtrArgs(GR, F, FTy, RetTy, ArgTypeVRegs);
326362
SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
327363
FTy, RetTy, ArgTypeVRegs, MIRBuilder);
328364
uint32_t FuncControl = getFunctionControl(F);
@@ -429,7 +465,6 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
429465
return false;
430466
MachineFunction &MF = MIRBuilder.getMF();
431467
GR->setCurrentFunc(MF);
432-
FunctionType *FTy = nullptr;
433468
const Function *CF = nullptr;
434469
std::string DemangledName;
435470
const Type *OrigRetTy = Info.OrigRet.Ty;
@@ -444,7 +479,7 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
444479
// TODO: support constexpr casts and indirect calls.
445480
if (CF == nullptr)
446481
return false;
447-
if ((FTy = getOriginalFunctionType(*CF)) != nullptr)
482+
if (FunctionType *FTy = getOriginalFunctionType(*CF))
448483
OrigRetTy = FTy->getReturnType();
449484
}
450485

llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp

Lines changed: 100 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,14 @@ class SPIRVEmitIntrinsics
5757
bool TrackConstants = true;
5858
DenseMap<Instruction *, Constant *> AggrConsts;
5959
DenseSet<Instruction *> AggrStores;
60+
61+
// deduce values type
62+
DenseMap<Value *, Type *> DeducedElTys;
63+
Type *deduceElementType(Value *I);
64+
6065
void preprocessCompositeConstants(IRBuilder<> &B);
6166
void preprocessUndefs(IRBuilder<> &B);
67+
6268
CallInst *buildIntrWithMD(Intrinsic::ID IntrID, ArrayRef<Type *> Types,
6369
Value *Arg, Value *Arg2, ArrayRef<Constant *> Imms,
6470
IRBuilder<> &B) {
@@ -72,6 +78,7 @@ class SPIRVEmitIntrinsics
7278
Args.push_back(Imm);
7379
return B.CreateIntrinsic(IntrID, {Types}, Args);
7480
}
81+
7582
void replaceMemInstrUses(Instruction *Old, Instruction *New, IRBuilder<> &B);
7683
void processInstrAfterVisit(Instruction *I, IRBuilder<> &B);
7784
void insertAssignPtrTypeIntrs(Instruction *I, IRBuilder<> &B);
@@ -156,6 +163,48 @@ static inline void reportFatalOnTokenType(const Instruction *I) {
156163
false);
157164
}
158165

166+
// Deduce and return a successfully deduced Type of the Instruction,
167+
// or nullptr otherwise.
168+
static Type *deduceElementTypeHelper(Value *I,
169+
std::unordered_set<Value *> &Visited,
170+
DenseMap<Value *, Type *> &DeducedElTys) {
171+
// maybe already known
172+
auto It = DeducedElTys.find(I);
173+
if (It != DeducedElTys.end())
174+
return It->second;
175+
176+
// maybe a cycle
177+
if (Visited.find(I) != Visited.end())
178+
return nullptr;
179+
Visited.insert(I);
180+
181+
// fallback value in case when we fail to deduce a type
182+
Type *Ty = nullptr;
183+
// look for known basic patterns of type inference
184+
if (auto *Ref = dyn_cast<AllocaInst>(I))
185+
Ty = Ref->getAllocatedType();
186+
else if (auto *Ref = dyn_cast<GetElementPtrInst>(I))
187+
Ty = Ref->getResultElementType();
188+
else if (auto *Ref = dyn_cast<GlobalValue>(I))
189+
Ty = Ref->getValueType();
190+
else if (auto *Ref = dyn_cast<AddrSpaceCastInst>(I))
191+
Ty = deduceElementTypeHelper(Ref->getPointerOperand(), Visited,
192+
DeducedElTys);
193+
194+
// remember the found relationship
195+
if (Ty)
196+
DeducedElTys[I] = Ty;
197+
198+
return Ty;
199+
}
200+
201+
Type *SPIRVEmitIntrinsics::deduceElementType(Value *I) {
202+
std::unordered_set<Value *> Visited;
203+
if (Type *Ty = deduceElementTypeHelper(I, Visited, DeducedElTys))
204+
return Ty;
205+
return IntegerType::getInt8Ty(I->getContext());
206+
}
207+
159208
void SPIRVEmitIntrinsics::replaceMemInstrUses(Instruction *Old,
160209
Instruction *New,
161210
IRBuilder<> &B) {
@@ -280,7 +329,7 @@ Instruction *SPIRVEmitIntrinsics::visitBitCastInst(BitCastInst &I) {
280329
// varying element types. In case of IR coming from older versions of LLVM
281330
// such bitcasts do not provide sufficient information, should be just skipped
282331
// here, and handled in insertPtrCastOrAssignTypeInstr.
283-
if (I.getType()->isPointerTy()) {
332+
if (isPointerTy(I.getType())) {
284333
I.replaceAllUsesWith(Source);
285334
I.eraseFromParent();
286335
return nullptr;
@@ -333,20 +382,10 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
333382
while (BitCastInst *BC = dyn_cast<BitCastInst>(Pointer))
334383
Pointer = BC->getOperand(0);
335384

336-
// Do not emit spv_ptrcast if Pointer is a GlobalValue of expected type.
337-
GlobalValue *GV = dyn_cast<GlobalValue>(Pointer);
338-
if (GV && GV->getValueType() == ExpectedElementType)
339-
return;
340-
341-
// Do not emit spv_ptrcast if Pointer is a result of alloca with expected
342-
// type.
343-
AllocaInst *A = dyn_cast<AllocaInst>(Pointer);
344-
if (A && A->getAllocatedType() == ExpectedElementType)
345-
return;
346-
347-
// Do not emit spv_ptrcast if Pointer is a result of GEP of expected type.
348-
GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(Pointer);
349-
if (GEPI && GEPI->getResultElementType() == ExpectedElementType)
385+
// Do not emit spv_ptrcast if Pointer's element type is ExpectedElementType
386+
std::unordered_set<Value *> Visited;
387+
Type *PointerElemTy = deduceElementTypeHelper(Pointer, Visited, DeducedElTys);
388+
if (PointerElemTy == ExpectedElementType)
350389
return;
351390

352391
setInsertPointSkippingPhis(B, I);
@@ -356,7 +395,7 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
356395
ValueAsMetadata::getConstant(ExpectedElementTypeConst);
357396
MDTuple *TyMD = MDNode::get(F->getContext(), CM);
358397
MetadataAsValue *VMD = MetadataAsValue::get(F->getContext(), TyMD);
359-
unsigned AddressSpace = Pointer->getType()->getPointerAddressSpace();
398+
unsigned AddressSpace = getPointerAddressSpace(Pointer->getType());
360399
bool FirstPtrCastOrAssignPtrType = true;
361400

362401
// Do not emit new spv_ptrcast if equivalent one already exists or when
@@ -401,9 +440,11 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
401440
// spv_assign_ptr_type instead.
402441
if (FirstPtrCastOrAssignPtrType &&
403442
(isa<Instruction>(Pointer) || isa<Argument>(Pointer))) {
404-
buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {Pointer->getType()},
405-
ExpectedElementTypeConst, Pointer,
406-
{B.getInt32(AddressSpace)}, B);
443+
CallInst *CI = buildIntrWithMD(
444+
Intrinsic::spv_assign_ptr_type, {Pointer->getType()},
445+
ExpectedElementTypeConst, Pointer, {B.getInt32(AddressSpace)}, B);
446+
DeducedElTys[CI] = ExpectedElementType;
447+
DeducedElTys[Pointer] = ExpectedElementType;
407448
return;
408449
}
409450

@@ -419,7 +460,7 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
419460
// Handle basic instructions:
420461
StoreInst *SI = dyn_cast<StoreInst>(I);
421462
if (SI && F->getCallingConv() == CallingConv::SPIR_KERNEL &&
422-
SI->getValueOperand()->getType()->isPointerTy() &&
463+
isPointerTy(SI->getValueOperand()->getType()) &&
423464
isa<Argument>(SI->getValueOperand())) {
424465
return replacePointerOperandWithPtrCast(
425466
I, SI->getValueOperand(), IntegerType::getInt8Ty(F->getContext()), 0,
@@ -440,9 +481,34 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
440481
if (!CI || CI->isIndirectCall() || CI->getCalledFunction()->isIntrinsic())
441482
return;
442483

484+
// collect information about formal parameter types
485+
Function *CalledF = CI->getCalledFunction();
486+
SmallVector<Type *, 4> CalledArgTys;
487+
bool HaveTypes = false;
488+
for (auto &CalledArg : CalledF->args()) {
489+
if (!isPointerTy(CalledArg.getType())) {
490+
CalledArgTys.push_back(nullptr);
491+
continue;
492+
}
493+
auto It = DeducedElTys.find(&CalledArg);
494+
Type *ParamTy = It != DeducedElTys.end() ? It->second : nullptr;
495+
if (!ParamTy) {
496+
for (User *U : CalledArg.users()) {
497+
if (Instruction *Inst = dyn_cast<Instruction>(U)) {
498+
std::unordered_set<Value *> Visited;
499+
ParamTy = deduceElementTypeHelper(Inst, Visited, DeducedElTys);
500+
if (ParamTy)
501+
break;
502+
}
503+
}
504+
}
505+
HaveTypes |= ParamTy != nullptr;
506+
CalledArgTys.push_back(ParamTy);
507+
}
508+
443509
std::string DemangledName =
444510
getOclOrSpirvBuiltinDemangledName(CI->getCalledFunction()->getName());
445-
if (DemangledName.empty())
511+
if (DemangledName.empty() && !HaveTypes)
446512
return;
447513

448514
for (unsigned OpIdx = 0; OpIdx < CI->arg_size(); OpIdx++) {
@@ -455,8 +521,11 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
455521
if (!isa<Instruction>(ArgOperand) && !isa<Argument>(ArgOperand))
456522
continue;
457523

458-
Type *ExpectedType = SPIRV::parseBuiltinCallArgumentBaseType(
459-
DemangledName, OpIdx, I->getContext());
524+
Type *ExpectedType =
525+
OpIdx < CalledArgTys.size() ? CalledArgTys[OpIdx] : nullptr;
526+
if (!ExpectedType && !DemangledName.empty())
527+
ExpectedType = SPIRV::parseBuiltinCallArgumentBaseType(
528+
DemangledName, OpIdx, I->getContext());
460529
if (!ExpectedType)
461530
continue;
462531

@@ -639,30 +708,25 @@ void SPIRVEmitIntrinsics::processGlobalValue(GlobalVariable &GV,
639708
void SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I,
640709
IRBuilder<> &B) {
641710
reportFatalOnTokenType(I);
642-
if (!I->getType()->isPointerTy() || !requireAssignType(I) ||
711+
if (!isPointerTy(I->getType()) || !requireAssignType(I) ||
643712
isa<BitCastInst>(I))
644713
return;
645714

646715
setInsertPointSkippingPhis(B, I->getNextNode());
647716

648-
Constant *EltTyConst;
649-
unsigned AddressSpace = I->getType()->getPointerAddressSpace();
650-
if (auto *AI = dyn_cast<AllocaInst>(I))
651-
EltTyConst = UndefValue::get(AI->getAllocatedType());
652-
else if (auto *GEP = dyn_cast<GetElementPtrInst>(I))
653-
EltTyConst = UndefValue::get(GEP->getResultElementType());
654-
else
655-
EltTyConst = UndefValue::get(IntegerType::getInt8Ty(I->getContext()));
656-
657-
buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {I->getType()}, EltTyConst, I,
658-
{B.getInt32(AddressSpace)}, B);
717+
Type *ElemTy = deduceElementType(I);
718+
Constant *EltTyConst = UndefValue::get(ElemTy);
719+
unsigned AddressSpace = getPointerAddressSpace(I->getType());
720+
CallInst *CI = buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {I->getType()},
721+
EltTyConst, I, {B.getInt32(AddressSpace)}, B);
722+
DeducedElTys[CI] = ElemTy;
659723
}
660724

661725
void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
662726
IRBuilder<> &B) {
663727
reportFatalOnTokenType(I);
664728
Type *Ty = I->getType();
665-
if (!Ty->isVoidTy() && !Ty->isPointerTy() && requireAssignType(I)) {
729+
if (!Ty->isVoidTy() && !isPointerTy(Ty) && requireAssignType(I)) {
666730
setInsertPointSkippingPhis(B, I->getNextNode());
667731
Type *TypeToAssign = Ty;
668732
if (auto *II = dyn_cast<IntrinsicInst>(I)) {

0 commit comments

Comments
 (0)