Skip to content

Improve how lowering of formal arguments in SPIR-V Backend interprets a value of 'kernel_arg_type' #78730

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1386,6 +1386,11 @@ static bool generateSampleImageInst(const StringRef DemangledCall,
ReturnType = ReturnType.substr(0, ReturnType.find('('));
}
SPIRVType *Type = GR->getOrCreateSPIRVTypeByName(ReturnType, MIRBuilder);
if (!Type) {
std::string DiagMsg =
"Unable to recognize SPIRV type name: " + ReturnType;
report_fatal_error(DiagMsg.c_str());
}
MRI->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
MRI->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
MRI->setRegClass(Call->Arguments[3], &SPIRV::IDRegClass);
Expand Down
32 changes: 16 additions & 16 deletions llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,22 +157,22 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
isSpecialOpaqueType(OriginalArgType))
return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);

MDString *MDKernelArgType = getOCLKernelArgType(F, ArgIdx);
if (!MDKernelArgType || (!MDKernelArgType->getString().ends_with("*") &&
!MDKernelArgType->getString().ends_with("_t")))
return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);

if (MDKernelArgType->getString().ends_with("*"))
return GR->getOrCreateSPIRVTypeByName(
MDKernelArgType->getString(), MIRBuilder,
addressSpaceToStorageClass(OriginalArgType->getPointerAddressSpace()));

if (MDKernelArgType->getString().ends_with("_t"))
return GR->getOrCreateSPIRVTypeByName(
"opencl." + MDKernelArgType->getString().str(), MIRBuilder,
SPIRV::StorageClass::Function, ArgAccessQual);

llvm_unreachable("Unable to recognize argument type name.");
SPIRVType *ResArgType = nullptr;
if (MDString *MDKernelArgType = getOCLKernelArgType(F, ArgIdx)) {
StringRef MDTypeStr = MDKernelArgType->getString();
if (MDTypeStr.ends_with("*"))
ResArgType = GR->getOrCreateSPIRVTypeByName(
MDTypeStr, MIRBuilder,
addressSpaceToStorageClass(
OriginalArgType->getPointerAddressSpace()));
else if (MDTypeStr.ends_with("_t"))
ResArgType = GR->getOrCreateSPIRVTypeByName(
"opencl." + MDTypeStr.str(), MIRBuilder,
SPIRV::StorageClass::Function, ArgAccessQual);
}
return ResArgType ? ResArgType
: GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder,
ArgAccessQual);
}

static bool isEntryPoint(const Function &F) {
Expand Down
12 changes: 8 additions & 4 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -443,8 +443,9 @@ Register SPIRVGlobalRegistry::buildConstantSampler(
SPIRVType *SampTy;
if (SpvType)
SampTy = getOrCreateSPIRVType(getTypeForSPIRVType(SpvType), MIRBuilder);
else
SampTy = getOrCreateSPIRVTypeByName("opencl.sampler_t", MIRBuilder);
else if ((SampTy = getOrCreateSPIRVTypeByName("opencl.sampler_t",
MIRBuilder)) == nullptr)
report_fatal_error("Unable to recognize SPIRV type name: opencl.sampler_t");

auto Sampler =
ResReg.isValid()
Expand Down Expand Up @@ -941,6 +942,7 @@ SPIRVGlobalRegistry::checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD,
return nullptr;
}

// Returns nullptr if unable to recognize SPIRV type name
// TODO: maybe use tablegen to implement this.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(
StringRef TypeStr, MachineIRBuilder &MIRBuilder,
Expand Down Expand Up @@ -992,8 +994,10 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(
} else if (TypeStr.starts_with("double")) {
Ty = Type::getDoubleTy(Ctx);
TypeStr = TypeStr.substr(strlen("double"));
} else
llvm_unreachable("Unable to recognize SPIRV type name.");
} else {
// Unable to recognize SPIRV type name
return nullptr;
}

auto SpirvTy = getOrCreateSPIRVType(Ty, MIRBuilder, AQ);

Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ class SPIRVGlobalRegistry {

// Either generate a new OpTypeXXX instruction or return an existing one
// corresponding to the given string containing the name of the builtin type.
// Return nullptr if unable to recognize SPIRV type name from `TypeStr`.
SPIRVType *getOrCreateSPIRVTypeByName(
StringRef TypeStr, MachineIRBuilder &MIRBuilder,
SPIRV::StorageClass::StorageClass SC = SPIRV::StorageClass::Function,
Expand Down
34 changes: 34 additions & 0 deletions llvm/test/CodeGen/SPIRV/pointers/custom-kernel-arg-type.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}

; CHECK: %[[TyInt:.*]] = OpTypeInt 8 0
; CHECK: %[[TyPtr:.*]] = OpTypePointer {{[a-zA-Z]+}} %[[TyInt]]
; CHECK: OpFunctionParameter %[[TyPtr]]
; CHECK: OpFunctionParameter %[[TyPtr]]

%struct.my_kernel_data = type { i32, i32, i32, i32, i32 }
%struct.my_struct = type { i32, i32 }

define spir_kernel void @test(ptr addrspace(1) %in, ptr addrspace(1) %outData) !kernel_arg_type !5 {
entry:
ret void
}

!llvm.module.flags = !{!0}
!opencl.enable.FP_CONTRACT = !{}
!opencl.ocl.version = !{!1}
!opencl.spir.version = !{!2}
!opencl.used.extensions = !{!3}
!opencl.used.optional.core.features = !{!3}
!opencl.compiler.options = !{!3}
!llvm.ident = !{!4}
!opencl.kernels = !{!6}

!0 = !{i32 1, !"wchar_size", i32 4}
!1 = !{i32 1, i32 0}
!2 = !{i32 1, i32 2}
!3 = !{}
!4 = !{!"clang version 6.0.0"}
!5 = !{!"my_kernel_data*", !"struct my_struct*"}
!6 = !{ptr @test}