Skip to content

Commit d153ef6

Browse files
Add support for SPIR-V extension: SPV_INTEL_function_pointers (#80759)
This PR adds initial support for "SPV_INTEL_function_pointers" SPIR-V extension: https://github.com/intel/llvm/blob/sycl/sycl/doc/design/spirv-extensions/SPV_INTEL_function_pointers.asciidoc The goal of the extension is to support indirect function calls and translation of function pointers into SPIR-V.
1 parent 9d8a236 commit d153ef6

13 files changed

+307
-24
lines changed

llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp

Lines changed: 99 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@ bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
3434
const Value *Val, ArrayRef<Register> VRegs,
3535
FunctionLoweringInfo &FLI,
3636
Register SwiftErrorVReg) const {
37+
// Maybe run postponed production of types for function pointers
38+
if (IndirectCalls.size() > 0) {
39+
produceIndirectPtrTypes(MIRBuilder);
40+
IndirectCalls.clear();
41+
}
42+
3743
// Currently all return types should use a single register.
3844
// TODO: handle the case of multiple registers.
3945
if (VRegs.size() > 1)
@@ -292,7 +298,6 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
292298
}
293299
}
294300

295-
// Generate a SPIR-V type for the function.
296301
auto MRI = MIRBuilder.getMRI();
297302
Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
298303
MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass);
@@ -301,17 +306,17 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
301306
SPIRVType *RetTy = GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder);
302307
SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
303308
FTy, RetTy, ArgTypeVRegs, MIRBuilder);
304-
305-
// Build the OpTypeFunction declaring it.
306309
uint32_t FuncControl = getFunctionControl(F);
307310

308-
MIRBuilder.buildInstr(SPIRV::OpFunction)
309-
.addDef(FuncVReg)
310-
.addUse(GR->getSPIRVTypeID(RetTy))
311-
.addImm(FuncControl)
312-
.addUse(GR->getSPIRVTypeID(FuncTy));
311+
// Add OpFunction instruction
312+
MachineInstrBuilder MB = MIRBuilder.buildInstr(SPIRV::OpFunction)
313+
.addDef(FuncVReg)
314+
.addUse(GR->getSPIRVTypeID(RetTy))
315+
.addImm(FuncControl)
316+
.addUse(GR->getSPIRVTypeID(FuncTy));
317+
GR->recordFunctionDefinition(&F, &MB.getInstr()->getOperand(0));
313318

314-
// Add OpFunctionParameters.
319+
// Add OpFunctionParameter instructions
315320
int i = 0;
316321
for (const auto &Arg : F.args()) {
317322
assert(VRegs[i].size() == 1 && "Formal arg has multiple vregs");
@@ -343,9 +348,56 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
343348
{static_cast<uint32_t>(LnkTy)}, F.getGlobalIdentifier());
344349
}
345350

351+
// Handle function pointers decoration
352+
const auto *ST =
353+
static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
354+
bool hasFunctionPointers =
355+
ST->canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
356+
if (hasFunctionPointers) {
357+
if (F.hasFnAttribute("referenced-indirectly")) {
358+
assert((F.getCallingConv() != CallingConv::SPIR_KERNEL) &&
359+
"Unexpected 'referenced-indirectly' attribute of the kernel "
360+
"function");
361+
buildOpDecorate(FuncVReg, MIRBuilder,
362+
SPIRV::Decoration::ReferencedIndirectlyINTEL, {});
363+
}
364+
}
365+
346366
return true;
347367
}
348368

369+
// Used to postpone producing of indirect function pointer types after all
370+
// indirect calls info is collected
371+
// TODO:
372+
// - add a topological sort of IndirectCalls to ensure the best types knowledge
373+
// - we may need to fix function formal parameter types if they are opaque
374+
// pointers used as function pointers in these indirect calls
375+
void SPIRVCallLowering::produceIndirectPtrTypes(
376+
MachineIRBuilder &MIRBuilder) const {
377+
// Create indirect call data types if any
378+
MachineFunction &MF = MIRBuilder.getMF();
379+
for (auto const &IC : IndirectCalls) {
380+
SPIRVType *SpirvRetTy = GR->getOrCreateSPIRVType(IC.RetTy, MIRBuilder);
381+
SmallVector<SPIRVType *, 4> SpirvArgTypes;
382+
for (size_t i = 0; i < IC.ArgTys.size(); ++i) {
383+
SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(IC.ArgTys[i], MIRBuilder);
384+
SpirvArgTypes.push_back(SPIRVTy);
385+
if (!GR->getSPIRVTypeForVReg(IC.ArgRegs[i]))
386+
GR->assignSPIRVTypeToVReg(SPIRVTy, IC.ArgRegs[i], MF);
387+
}
388+
// SPIR-V function type:
389+
FunctionType *FTy =
390+
FunctionType::get(const_cast<Type *>(IC.RetTy), IC.ArgTys, false);
391+
SPIRVType *SpirvFuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
392+
FTy, SpirvRetTy, SpirvArgTypes, MIRBuilder);
393+
// SPIR-V pointer to function type:
394+
SPIRVType *IndirectFuncPtrTy = GR->getOrCreateSPIRVPointerType(
395+
SpirvFuncTy, MIRBuilder, SPIRV::StorageClass::Function);
396+
// Correct the Calee type
397+
GR->assignSPIRVTypeToVReg(IndirectFuncPtrTy, IC.Callee, MF);
398+
}
399+
}
400+
349401
bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
350402
CallLoweringInfo &Info) const {
351403
// Currently call returns should have single vregs.
@@ -356,45 +408,44 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
356408
GR->setCurrentFunc(MF);
357409
FunctionType *FTy = nullptr;
358410
const Function *CF = nullptr;
411+
std::string DemangledName;
412+
const Type *OrigRetTy = Info.OrigRet.Ty;
359413

360414
// Emit a regular OpFunctionCall. If it's an externally declared function,
361415
// be sure to emit its type and function declaration here. It will be hoisted
362416
// globally later.
363417
if (Info.Callee.isGlobal()) {
418+
std::string FuncName = Info.Callee.getGlobal()->getName().str();
419+
DemangledName = getOclOrSpirvBuiltinDemangledName(FuncName);
364420
CF = dyn_cast_or_null<const Function>(Info.Callee.getGlobal());
365421
// TODO: support constexpr casts and indirect calls.
366422
if (CF == nullptr)
367423
return false;
368-
FTy = getOriginalFunctionType(*CF);
424+
if ((FTy = getOriginalFunctionType(*CF)) != nullptr)
425+
OrigRetTy = FTy->getReturnType();
369426
}
370427

371428
MachineRegisterInfo *MRI = MIRBuilder.getMRI();
372429
Register ResVReg =
373430
Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0];
374-
std::string FuncName = Info.Callee.getGlobal()->getName().str();
375-
std::string DemangledName = getOclOrSpirvBuiltinDemangledName(FuncName);
376431
const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget());
377432
// TODO: check that it's OCL builtin, then apply OpenCL_std.
378433
if (!DemangledName.empty() && CF && CF->isDeclaration() &&
379434
ST->canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
380-
const Type *OrigRetTy = Info.OrigRet.Ty;
381-
if (FTy)
382-
OrigRetTy = FTy->getReturnType();
383435
SmallVector<Register, 8> ArgVRegs;
384436
for (auto Arg : Info.OrigArgs) {
385437
assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
386438
ArgVRegs.push_back(Arg.Regs[0]);
387439
SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(Arg.Ty, MIRBuilder);
388440
if (!GR->getSPIRVTypeForVReg(Arg.Regs[0]))
389-
GR->assignSPIRVTypeToVReg(SPIRVTy, Arg.Regs[0], MIRBuilder.getMF());
441+
GR->assignSPIRVTypeToVReg(SPIRVTy, Arg.Regs[0], MF);
390442
}
391443
if (auto Res = SPIRV::lowerBuiltin(
392444
DemangledName, SPIRV::InstructionSet::OpenCL_std, MIRBuilder,
393445
ResVReg, OrigRetTy, ArgVRegs, GR))
394446
return *Res;
395447
}
396-
if (CF && CF->isDeclaration() &&
397-
!GR->find(CF, &MIRBuilder.getMF()).isValid()) {
448+
if (CF && CF->isDeclaration() && !GR->find(CF, &MF).isValid()) {
398449
// Emit the type info and forward function declaration to the first MBB
399450
// to ensure VReg definition dependencies are valid across all MBBs.
400451
MachineIRBuilder FirstBlockBuilder;
@@ -416,14 +467,40 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
416467
lowerFormalArguments(FirstBlockBuilder, *CF, VRegArgs, FuncInfo);
417468
}
418469

470+
unsigned CallOp;
471+
if (Info.CB->isIndirectCall()) {
472+
if (!ST->canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers))
473+
report_fatal_error("An indirect call is encountered but SPIR-V without "
474+
"extensions does not support it",
475+
false);
476+
// Set instruction operation according to SPV_INTEL_function_pointers
477+
CallOp = SPIRV::OpFunctionPointerCallINTEL;
478+
// Collect information about the indirect call to support possible
479+
// specification of opaque ptr types of parent function's parameters
480+
Register CalleeReg = Info.Callee.getReg();
481+
if (CalleeReg.isValid()) {
482+
SPIRVCallLowering::SPIRVIndirectCall IndirectCall;
483+
IndirectCall.Callee = CalleeReg;
484+
IndirectCall.RetTy = OrigRetTy;
485+
for (const auto &Arg : Info.OrigArgs) {
486+
assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
487+
IndirectCall.ArgTys.push_back(Arg.Ty);
488+
IndirectCall.ArgRegs.push_back(Arg.Regs[0]);
489+
}
490+
IndirectCalls.push_back(IndirectCall);
491+
}
492+
} else {
493+
// Emit a regular OpFunctionCall
494+
CallOp = SPIRV::OpFunctionCall;
495+
}
496+
419497
// Make sure there's a valid return reg, even for functions returning void.
420498
if (!ResVReg.isValid())
421499
ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
422-
SPIRVType *RetType =
423-
GR->assignTypeToVReg(FTy->getReturnType(), ResVReg, MIRBuilder);
500+
SPIRVType *RetType = GR->assignTypeToVReg(OrigRetTy, ResVReg, MIRBuilder);
424501

425-
// Emit the OpFunctionCall and its args.
426-
auto MIB = MIRBuilder.buildInstr(SPIRV::OpFunctionCall)
502+
// Emit the call instruction and its args.
503+
auto MIB = MIRBuilder.buildInstr(CallOp)
427504
.addDef(ResVReg)
428505
.addUse(GR->getSPIRVTypeID(RetType))
429506
.add(Info.Callee);

llvm/lib/Target/SPIRV/SPIRVCallLowering.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,17 @@ class SPIRVCallLowering : public CallLowering {
2626
// Used to create and assign function, argument, and return type information.
2727
SPIRVGlobalRegistry *GR;
2828

29+
// Used to postpone producing of indirect function pointer types
30+
// after all indirect calls info is collected
31+
struct SPIRVIndirectCall {
32+
const Type *RetTy = nullptr;
33+
SmallVector<Type *> ArgTys;
34+
SmallVector<Register> ArgRegs;
35+
Register Callee;
36+
};
37+
void produceIndirectPtrTypes(MachineIRBuilder &MIRBuilder) const;
38+
mutable SmallVector<SPIRVIndirectCall> IndirectCalls;
39+
2940
public:
3041
SPIRVCallLowering(const SPIRVTargetLowering &TLI, SPIRVGlobalRegistry *GR);
3142

llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,14 @@ void SPIRVGeneralDuplicatesTracker::buildDepsGraph(
5454
MachineOperand &Op = MI->getOperand(i);
5555
if (!Op.isReg())
5656
continue;
57-
MachineOperand *RegOp = &MRI.getVRegDef(Op.getReg())->getOperand(0);
57+
MachineInstr *VRegDef = MRI.getVRegDef(Op.getReg());
58+
// References to a function via function pointers generate virtual
59+
// registers without a definition. We are able to resolve this
60+
// reference using Globar Register info into an OpFunction instruction
61+
// but do not expect to find it in Reg2Entry.
62+
if (MI->getOpcode() == SPIRV::OpConstantFunctionPointerINTEL && i == 2)
63+
continue;
64+
MachineOperand *RegOp = &VRegDef->getOperand(0);
5865
assert((MI->getOpcode() == SPIRV::OpVariable && i == 3) ||
5966
Reg2Entry.count(RegOp));
6067
if (Reg2Entry.count(RegOp))

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,12 @@ class SPIRVGlobalRegistry {
3838

3939
DenseMap<SPIRVType *, const Type *> SPIRVToLLVMType;
4040

41+
// map a Function to its definition (as a machine instruction operand)
42+
DenseMap<const Function *, const MachineOperand *> FunctionToInstr;
43+
// map function pointer (as a machine instruction operand) to the used
44+
// Function
45+
DenseMap<const MachineOperand *, const Function *> InstrToFunction;
46+
4147
// Look for an equivalent of the newType in the map. Return the equivalent
4248
// if it's found, otherwise insert newType to the map and return the type.
4349
const MachineInstr *checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD,
@@ -101,6 +107,29 @@ class SPIRVGlobalRegistry {
101107
DT.buildDepsGraph(Graph, MMI);
102108
}
103109

110+
// Map a machine operand that represents a use of a function via function
111+
// pointer to a machine operand that represents the function definition.
112+
// Return either the register or invalid value, because we have no context for
113+
// a good diagnostic message in case of unexpectedly missing references.
114+
const MachineOperand *getFunctionDefinitionByUse(const MachineOperand *Use) {
115+
auto ResF = InstrToFunction.find(Use);
116+
if (ResF == InstrToFunction.end())
117+
return nullptr;
118+
auto ResReg = FunctionToInstr.find(ResF->second);
119+
return ResReg == FunctionToInstr.end() ? nullptr : ResReg->second;
120+
}
121+
// map function pointer (as a machine instruction operand) to the used
122+
// Function
123+
void recordFunctionPointer(const MachineOperand *MO, const Function *F) {
124+
InstrToFunction[MO] = F;
125+
}
126+
// map a Function to its definition (as a machine instruction)
127+
void recordFunctionDefinition(const Function *F, const MachineOperand *MO) {
128+
FunctionToInstr[F] = MO;
129+
}
130+
// Return true if any OpConstantFunctionPointerINTEL were generated
131+
bool hasConstFunPtr() { return !InstrToFunction.empty(); }
132+
104133
// Get or create a SPIR-V type corresponding the given LLVM IR type,
105134
// and map it to the given VReg by creating an ASSIGN_TYPE instruction.
106135
SPIRVType *assignTypeToVReg(const Type *Type, Register VReg,

llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ bool SPIRVInstrInfo::isConstantInstr(const MachineInstr &MI) const {
4040
case SPIRV::OpSpecConstantComposite:
4141
case SPIRV::OpSpecConstantOp:
4242
case SPIRV::OpUndef:
43+
case SPIRV::OpConstantFunctionPointerINTEL:
4344
return true;
4445
default:
4546
return false;

llvm/lib/Target/SPIRV/SPIRVInstrInfo.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -762,6 +762,16 @@ def OpGroupNonUniformLogicalAnd: OpGroupNUGroup<"LogicalAnd", 362>;
762762
def OpGroupNonUniformLogicalOr: OpGroupNUGroup<"LogicalOr", 363>;
763763
def OpGroupNonUniformLogicalXor: OpGroupNUGroup<"LogicalXor", 364>;
764764

765+
// 3.49.7, Constant-Creation Instructions
766+
767+
// - SPV_INTEL_function_pointers
768+
def OpConstantFunctionPointerINTEL: Op<5600, (outs ID:$res), (ins TYPE:$ty, ID:$fun), "$res = OpConstantFunctionPointerINTEL $ty $fun">;
769+
770+
// 3.49.9. Function Instructions
771+
772+
// - SPV_INTEL_function_pointers
773+
def OpFunctionPointerCallINTEL: Op<5601, (outs ID:$res), (ins TYPE:$ty, ID:$funPtr, variable_ops), "$res = OpFunctionPointerCallINTEL $ty $funPtr">;
774+
765775
// 3.49.21. Group and Subgroup Instructions
766776
def OpSubgroupShuffleINTEL: Op<5571, (outs ID:$res), (ins TYPE:$type, ID:$data, ID:$invocationId),
767777
"$res = OpSubgroupShuffleINTEL $type $data $invocationId">;

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1534,6 +1534,12 @@ bool SPIRVInstructionSelector::selectGlobalValue(
15341534
GlobalIdent = GV->getGlobalIdentifier();
15351535
}
15361536

1537+
// Behaviour of functions as operands depends on availability of the
1538+
// corresponding extension (SPV_INTEL_function_pointers):
1539+
// - If there is an extension to operate with functions as operands:
1540+
// We create a proper constant operand and evaluate a correct type for a
1541+
// function pointer.
1542+
// - Without the required extension:
15371543
// We have functions as operands in tests with blocks of instruction e.g. in
15381544
// transcoding/global_block.ll. These operands are not used and should be
15391545
// substituted by zero constants. Their type is expected to be always
@@ -1545,6 +1551,27 @@ bool SPIRVInstructionSelector::selectGlobalValue(
15451551
if (!NewReg.isValid()) {
15461552
Register NewReg = ResVReg;
15471553
GR.add(ConstVal, GR.CurMF, NewReg);
1554+
const Function *GVFun =
1555+
STI.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)
1556+
? dyn_cast<Function>(GV)
1557+
: nullptr;
1558+
if (GVFun) {
1559+
// References to a function via function pointers generate virtual
1560+
// registers without a definition. We will resolve it later, during
1561+
// module analysis stage.
1562+
MachineRegisterInfo *MRI = MIRBuilder.getMRI();
1563+
Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
1564+
MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass);
1565+
MachineInstrBuilder MB =
1566+
BuildMI(BB, I, I.getDebugLoc(),
1567+
TII.get(SPIRV::OpConstantFunctionPointerINTEL))
1568+
.addDef(NewReg)
1569+
.addUse(GR.getSPIRVTypeID(ResType))
1570+
.addUse(FuncVReg);
1571+
// mapping the function pointer to the used Function
1572+
GR.recordFunctionPointer(&MB.getInstr()->getOperand(2), GVFun);
1573+
return MB.constrainAllUses(TII, TRI, RBI);
1574+
}
15481575
return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
15491576
.addDef(NewReg)
15501577
.addUse(GR.getSPIRVTypeID(ResType))

0 commit comments

Comments
 (0)