Skip to content

Commit 494f672

Browse files
[SPIR-V] Prevent type change of GEP results in type inference (#129250)
The following reproducer demonstrates the issue with invalid definition of GEP results during type inference ``` define spir_kernel void @foo(i1 %fl, i64 %idx, ptr addrspace(1) %dest, ptr addrspace(3) %src) { %p1 = getelementptr inbounds i8, ptr addrspace(1) %dest, i64 %idx %res = tail call spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS1iPU3AS3Kimm9ocl_event(i32 2, ptr addrspace(1) %p1, ptr addrspace(3) %src, i64 128, i64 1, target("spirv.Event") zeroinitializer) ret void } declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS1iPU3AS3Kimm9ocl_event(i32, ptr addrspace(1), ptr addrspace(3), i64, i64, target("spirv.Event")) ``` Here `OpGroupAsyncCopy` expects i32* arguments and type inference fails to set a correct type of the GEP result `%p1`, because it is an argument of `OpGroupAsyncCopy`. This PR fixes the issue by preventing type change of GEP results in type inference.
1 parent e1e20c0 commit 494f672

File tree

2 files changed

+52
-9
lines changed

2 files changed

+52
-9
lines changed

llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,20 @@ void SPIRVEmitIntrinsics::maybeAssignPtrType(Type *&Ty, Value *Op, Type *RefTy,
646646
Ty = RefTy;
647647
}
648648

649+
Type *getGEPType(GetElementPtrInst *Ref) {
650+
Type *Ty = nullptr;
651+
// TODO: not sure if GetElementPtrInst::getTypeAtIndex() does anything
652+
// useful here
653+
if (isNestedPointer(Ref->getSourceElementType())) {
654+
Ty = Ref->getSourceElementType();
655+
for (Use &U : drop_begin(Ref->indices()))
656+
Ty = GetElementPtrInst::getTypeAtIndex(Ty, U.get());
657+
} else {
658+
Ty = Ref->getResultElementType();
659+
}
660+
return Ty;
661+
}
662+
649663
Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
650664
Value *I, std::unordered_set<Value *> &Visited, bool UnknownElemTypeI8,
651665
bool IgnoreKnownType) {
@@ -668,15 +682,7 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
668682
if (auto *Ref = dyn_cast<AllocaInst>(I)) {
669683
maybeAssignPtrType(Ty, I, Ref->getAllocatedType(), UnknownElemTypeI8);
670684
} else if (auto *Ref = dyn_cast<GetElementPtrInst>(I)) {
671-
// TODO: not sure if GetElementPtrInst::getTypeAtIndex() does anything
672-
// useful here
673-
if (isNestedPointer(Ref->getSourceElementType())) {
674-
Ty = Ref->getSourceElementType();
675-
for (Use &U : drop_begin(Ref->indices()))
676-
Ty = GetElementPtrInst::getTypeAtIndex(Ty, U.get());
677-
} else {
678-
Ty = Ref->getResultElementType();
679-
}
685+
Ty = getGEPType(Ref);
680686
} else if (auto *Ref = dyn_cast<LoadInst>(I)) {
681687
Value *Op = Ref->getPointerOperand();
682688
Type *KnownTy = GR->findDeducedElementType(Op);
@@ -2307,6 +2313,7 @@ bool SPIRVEmitIntrinsics::processFunctionPointers(Module &M) {
23072313

23082314
// Apply types parsed from demangled function declarations.
23092315
void SPIRVEmitIntrinsics::applyDemangledPtrArgTypes(IRBuilder<> &B) {
2316+
DenseMap<Function *, CallInst *> Ptrcasts;
23102317
for (auto It : FDeclPtrTys) {
23112318
Function *F = It.first;
23122319
for (auto *U : F->users()) {
@@ -2326,6 +2333,9 @@ void SPIRVEmitIntrinsics::applyDemangledPtrArgTypes(IRBuilder<> &B) {
23262333
B.SetCurrentDebugLocation(DebugLoc());
23272334
buildAssignPtr(B, ElemTy, Arg);
23282335
}
2336+
} else if (isa<GetElementPtrInst>(Param)) {
2337+
replaceUsesOfWithSpvPtrcast(Param, normalizeType(ElemTy), CI,
2338+
Ptrcasts);
23292339
} else if (isa<Instruction>(Param)) {
23302340
GR->addDeducedElementType(Param, normalizeType(ElemTy));
23312341
// insertAssignTypeIntrs() will complete buildAssignPtr()
@@ -2370,6 +2380,15 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
23702380
AggrConstTypes.clear();
23712381
AggrStores.clear();
23722382

2383+
// fix GEP result types ahead of inference
2384+
for (auto &I : instructions(Func)) {
2385+
auto *Ref = dyn_cast<GetElementPtrInst>(&I);
2386+
if (!Ref || GR->findDeducedElementType(Ref))
2387+
continue;
2388+
if (Type *GepTy = getGEPType(Ref))
2389+
GR->addDeducedElementType(Ref, normalizeType(GepTy));
2390+
}
2391+
23732392
processParamTypesByFunHeader(CurrF, B);
23742393

23752394
// StoreInst's operand type can be changed during the next transformations,
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
3+
4+
; CHECK-DAG: %[[#Char:]] = OpTypeInt 8 0
5+
; CHECK-DAG: %[[#Long:]] = OpTypeInt 32 0
6+
; CHECK-DAG: %[[#CharPtr:]] = OpTypePointer CrossWorkgroup %[[#Char]]
7+
; CHECK-DAG: %[[#LongPtr:]] = OpTypePointer CrossWorkgroup %[[#Long]]
8+
; CHECK-DAG: %[[#LongPtrWG:]] = OpTypePointer Workgroup %[[#Long]]
9+
; CHECK: OpFunction
10+
; CHECK: OpFunctionParameter
11+
; CHECK: %[[#Dest:]] = OpFunctionParameter %[[#CharPtr]]
12+
; CHECK: %[[#Src:]] = OpFunctionParameter %[[#LongPtrWG]]
13+
; CHECK: %[[#InDest:]] = OpInBoundsPtrAccessChain %[[#CharPtr]] %[[#Dest]] %[[#]]
14+
; CHECK: %[[#InDestCasted:]] = OpBitcast %[[#LongPtr]] %[[#InDest]]
15+
; CHECK: OpGroupAsyncCopy %[[#]] %[[#]] %[[#InDestCasted]] %[[#Src]] %[[#]] %[[#]] %[[#]]
16+
17+
define spir_kernel void @foo(i64 %idx, ptr addrspace(1) %dest, ptr addrspace(3) %src) {
18+
%p1 = getelementptr inbounds i8, ptr addrspace(1) %dest, i64 %idx
19+
%res = tail call spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS1iPU3AS3Kimm9ocl_event(i32 2, ptr addrspace(1) %p1, ptr addrspace(3) %src, i64 128, i64 1, target("spirv.Event") zeroinitializer)
20+
ret void
21+
}
22+
23+
; For this test case the mangling is important.
24+
declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS1iPU3AS3Kimm9ocl_event(i32, ptr addrspace(1), ptr addrspace(3), i64, i64, target("spirv.Event"))

0 commit comments

Comments
 (0)