Skip to content

Add support for SPIR-V extension: SPV_INTEL_function_pointers #80759

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 99 additions & 22 deletions llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
const Value *Val, ArrayRef<Register> VRegs,
FunctionLoweringInfo &FLI,
Register SwiftErrorVReg) const {
// Maybe run postponed production of types for function pointers
if (IndirectCalls.size() > 0) {
produceIndirectPtrTypes(MIRBuilder);
IndirectCalls.clear();
}

// Currently all return types should use a single register.
// TODO: handle the case of multiple registers.
if (VRegs.size() > 1)
Expand Down Expand Up @@ -292,7 +298,6 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
}
}

// Generate a SPIR-V type for the function.
auto MRI = MIRBuilder.getMRI();
Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass);
Expand All @@ -301,17 +306,17 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
SPIRVType *RetTy = GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder);
SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
FTy, RetTy, ArgTypeVRegs, MIRBuilder);

// Build the OpTypeFunction declaring it.
uint32_t FuncControl = getFunctionControl(F);

MIRBuilder.buildInstr(SPIRV::OpFunction)
.addDef(FuncVReg)
.addUse(GR->getSPIRVTypeID(RetTy))
.addImm(FuncControl)
.addUse(GR->getSPIRVTypeID(FuncTy));
// Add OpFunction instruction
MachineInstrBuilder MB = MIRBuilder.buildInstr(SPIRV::OpFunction)
.addDef(FuncVReg)
.addUse(GR->getSPIRVTypeID(RetTy))
.addImm(FuncControl)
.addUse(GR->getSPIRVTypeID(FuncTy));
GR->recordFunctionDefinition(&F, &MB.getInstr()->getOperand(0));

// Add OpFunctionParameters.
// Add OpFunctionParameter instructions
int i = 0;
for (const auto &Arg : F.args()) {
assert(VRegs[i].size() == 1 && "Formal arg has multiple vregs");
Expand Down Expand Up @@ -343,9 +348,56 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
{static_cast<uint32_t>(LnkTy)}, F.getGlobalIdentifier());
}

// Handle function pointers decoration
const auto *ST =
static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
bool hasFunctionPointers =
ST->canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
if (hasFunctionPointers) {
if (F.hasFnAttribute("referenced-indirectly")) {
assert((F.getCallingConv() != CallingConv::SPIR_KERNEL) &&
"Unexpected 'referenced-indirectly' attribute of the kernel "
"function");
buildOpDecorate(FuncVReg, MIRBuilder,
SPIRV::Decoration::ReferencedIndirectlyINTEL, {});
}
}

return true;
}

// Used to postpone producing of indirect function pointer types after all
// indirect calls info is collected
// TODO:
// - add a topological sort of IndirectCalls to ensure the best types knowledge
// - we may need to fix function formal parameter types if they are opaque
// pointers used as function pointers in these indirect calls
void SPIRVCallLowering::produceIndirectPtrTypes(
MachineIRBuilder &MIRBuilder) const {
// Create indirect call data types if any
MachineFunction &MF = MIRBuilder.getMF();
for (auto const &IC : IndirectCalls) {
SPIRVType *SpirvRetTy = GR->getOrCreateSPIRVType(IC.RetTy, MIRBuilder);
SmallVector<SPIRVType *, 4> SpirvArgTypes;
for (size_t i = 0; i < IC.ArgTys.size(); ++i) {
SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(IC.ArgTys[i], MIRBuilder);
SpirvArgTypes.push_back(SPIRVTy);
if (!GR->getSPIRVTypeForVReg(IC.ArgRegs[i]))
GR->assignSPIRVTypeToVReg(SPIRVTy, IC.ArgRegs[i], MF);
}
// SPIR-V function type:
FunctionType *FTy =
FunctionType::get(const_cast<Type *>(IC.RetTy), IC.ArgTys, false);
SPIRVType *SpirvFuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
FTy, SpirvRetTy, SpirvArgTypes, MIRBuilder);
// SPIR-V pointer to function type:
SPIRVType *IndirectFuncPtrTy = GR->getOrCreateSPIRVPointerType(
SpirvFuncTy, MIRBuilder, SPIRV::StorageClass::Function);
// Correct the Calee type
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo: Calee -> Callee

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you! I'll fix it in the next PR/commit

GR->assignSPIRVTypeToVReg(IndirectFuncPtrTy, IC.Callee, MF);
}
}

bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
CallLoweringInfo &Info) const {
// Currently call returns should have single vregs.
Expand All @@ -356,45 +408,44 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
GR->setCurrentFunc(MF);
FunctionType *FTy = nullptr;
const Function *CF = nullptr;
std::string DemangledName;
const Type *OrigRetTy = Info.OrigRet.Ty;

// Emit a regular OpFunctionCall. If it's an externally declared function,
// be sure to emit its type and function declaration here. It will be hoisted
// globally later.
if (Info.Callee.isGlobal()) {
std::string FuncName = Info.Callee.getGlobal()->getName().str();
DemangledName = getOclOrSpirvBuiltinDemangledName(FuncName);
CF = dyn_cast_or_null<const Function>(Info.Callee.getGlobal());
// TODO: support constexpr casts and indirect calls.
if (CF == nullptr)
return false;
FTy = getOriginalFunctionType(*CF);
if ((FTy = getOriginalFunctionType(*CF)) != nullptr)
OrigRetTy = FTy->getReturnType();
}

MachineRegisterInfo *MRI = MIRBuilder.getMRI();
Register ResVReg =
Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0];
std::string FuncName = Info.Callee.getGlobal()->getName().str();
std::string DemangledName = getOclOrSpirvBuiltinDemangledName(FuncName);
const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget());
// TODO: check that it's OCL builtin, then apply OpenCL_std.
if (!DemangledName.empty() && CF && CF->isDeclaration() &&
ST->canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
const Type *OrigRetTy = Info.OrigRet.Ty;
if (FTy)
OrigRetTy = FTy->getReturnType();
SmallVector<Register, 8> ArgVRegs;
for (auto Arg : Info.OrigArgs) {
assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
ArgVRegs.push_back(Arg.Regs[0]);
SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(Arg.Ty, MIRBuilder);
if (!GR->getSPIRVTypeForVReg(Arg.Regs[0]))
GR->assignSPIRVTypeToVReg(SPIRVTy, Arg.Regs[0], MIRBuilder.getMF());
GR->assignSPIRVTypeToVReg(SPIRVTy, Arg.Regs[0], MF);
}
if (auto Res = SPIRV::lowerBuiltin(
DemangledName, SPIRV::InstructionSet::OpenCL_std, MIRBuilder,
ResVReg, OrigRetTy, ArgVRegs, GR))
return *Res;
}
if (CF && CF->isDeclaration() &&
!GR->find(CF, &MIRBuilder.getMF()).isValid()) {
if (CF && CF->isDeclaration() && !GR->find(CF, &MF).isValid()) {
// Emit the type info and forward function declaration to the first MBB
// to ensure VReg definition dependencies are valid across all MBBs.
MachineIRBuilder FirstBlockBuilder;
Expand All @@ -416,14 +467,40 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
lowerFormalArguments(FirstBlockBuilder, *CF, VRegArgs, FuncInfo);
}

unsigned CallOp;
if (Info.CB->isIndirectCall()) {
if (!ST->canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers))
report_fatal_error("An indirect call is encountered but SPIR-V without "
"extensions does not support it",
false);
// Set instruction operation according to SPV_INTEL_function_pointers
CallOp = SPIRV::OpFunctionPointerCallINTEL;
// Collect information about the indirect call to support possible
// specification of opaque ptr types of parent function's parameters
Register CalleeReg = Info.Callee.getReg();
if (CalleeReg.isValid()) {
SPIRVCallLowering::SPIRVIndirectCall IndirectCall;
IndirectCall.Callee = CalleeReg;
IndirectCall.RetTy = OrigRetTy;
for (const auto &Arg : Info.OrigArgs) {
assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
IndirectCall.ArgTys.push_back(Arg.Ty);
IndirectCall.ArgRegs.push_back(Arg.Regs[0]);
}
IndirectCalls.push_back(IndirectCall);
}
} else {
// Emit a regular OpFunctionCall
CallOp = SPIRV::OpFunctionCall;
}

// Make sure there's a valid return reg, even for functions returning void.
if (!ResVReg.isValid())
ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
SPIRVType *RetType =
GR->assignTypeToVReg(FTy->getReturnType(), ResVReg, MIRBuilder);
SPIRVType *RetType = GR->assignTypeToVReg(OrigRetTy, ResVReg, MIRBuilder);

// Emit the OpFunctionCall and its args.
auto MIB = MIRBuilder.buildInstr(SPIRV::OpFunctionCall)
// Emit the call instruction and its args.
auto MIB = MIRBuilder.buildInstr(CallOp)
.addDef(ResVReg)
.addUse(GR->getSPIRVTypeID(RetType))
.add(Info.Callee);
Expand Down
11 changes: 11 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVCallLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,17 @@ class SPIRVCallLowering : public CallLowering {
// Used to create and assign function, argument, and return type information.
SPIRVGlobalRegistry *GR;

// Used to postpone producing of indirect function pointer types
// after all indirect calls info is collected
struct SPIRVIndirectCall {
const Type *RetTy = nullptr;
SmallVector<Type *> ArgTys;
SmallVector<Register> ArgRegs;
Register Callee;
};
void produceIndirectPtrTypes(MachineIRBuilder &MIRBuilder) const;
mutable SmallVector<SPIRVIndirectCall> IndirectCalls;

public:
SPIRVCallLowering(const SPIRVTargetLowering &TLI, SPIRVGlobalRegistry *GR);

Expand Down
9 changes: 8 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,14 @@ void SPIRVGeneralDuplicatesTracker::buildDepsGraph(
MachineOperand &Op = MI->getOperand(i);
if (!Op.isReg())
continue;
MachineOperand *RegOp = &MRI.getVRegDef(Op.getReg())->getOperand(0);
MachineInstr *VRegDef = MRI.getVRegDef(Op.getReg());
// References to a function via function pointers generate virtual
// registers without a definition. We are able to resolve this
// reference using Globar Register info into an OpFunction instruction
// but do not expect to find it in Reg2Entry.
if (MI->getOpcode() == SPIRV::OpConstantFunctionPointerINTEL && i == 2)
continue;
MachineOperand *RegOp = &VRegDef->getOperand(0);
assert((MI->getOpcode() == SPIRV::OpVariable && i == 3) ||
Reg2Entry.count(RegOp));
if (Reg2Entry.count(RegOp))
Expand Down
29 changes: 29 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ class SPIRVGlobalRegistry {

DenseMap<SPIRVType *, const Type *> SPIRVToLLVMType;

// map a Function to its definition (as a machine instruction operand)
DenseMap<const Function *, const MachineOperand *> FunctionToInstr;
// map function pointer (as a machine instruction operand) to the used
// Function
DenseMap<const MachineOperand *, const Function *> InstrToFunction;

// Look for an equivalent of the newType in the map. Return the equivalent
// if it's found, otherwise insert newType to the map and return the type.
const MachineInstr *checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD,
Expand Down Expand Up @@ -101,6 +107,29 @@ class SPIRVGlobalRegistry {
DT.buildDepsGraph(Graph, MMI);
}

// Map a machine operand that represents a use of a function via function
// pointer to a machine operand that represents the function definition.
// Return either the register or invalid value, because we have no context for
// a good diagnostic message in case of unexpectedly missing references.
const MachineOperand *getFunctionDefinitionByUse(const MachineOperand *Use) {
auto ResF = InstrToFunction.find(Use);
if (ResF == InstrToFunction.end())
return nullptr;
auto ResReg = FunctionToInstr.find(ResF->second);
return ResReg == FunctionToInstr.end() ? nullptr : ResReg->second;
}
// map function pointer (as a machine instruction operand) to the used
// Function
void recordFunctionPointer(const MachineOperand *MO, const Function *F) {
InstrToFunction[MO] = F;
}
// map a Function to its definition (as a machine instruction)
void recordFunctionDefinition(const Function *F, const MachineOperand *MO) {
FunctionToInstr[F] = MO;
}
// Return true if any OpConstantFunctionPointerINTEL were generated
bool hasConstFunPtr() { return !InstrToFunction.empty(); }

// Get or create a SPIR-V type corresponding the given LLVM IR type,
// and map it to the given VReg by creating an ASSIGN_TYPE instruction.
SPIRVType *assignTypeToVReg(const Type *Type, Register VReg,
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ bool SPIRVInstrInfo::isConstantInstr(const MachineInstr &MI) const {
case SPIRV::OpSpecConstantComposite:
case SPIRV::OpSpecConstantOp:
case SPIRV::OpUndef:
case SPIRV::OpConstantFunctionPointerINTEL:
return true;
default:
return false;
Expand Down
10 changes: 10 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -762,6 +762,16 @@ def OpGroupNonUniformLogicalAnd: OpGroupNUGroup<"LogicalAnd", 362>;
def OpGroupNonUniformLogicalOr: OpGroupNUGroup<"LogicalOr", 363>;
def OpGroupNonUniformLogicalXor: OpGroupNUGroup<"LogicalXor", 364>;

// 3.49.7, Constant-Creation Instructions

// - SPV_INTEL_function_pointers
def OpConstantFunctionPointerINTEL: Op<5600, (outs ID:$res), (ins TYPE:$ty, ID:$fun), "$res = OpConstantFunctionPointerINTEL $ty $fun">;

// 3.49.9. Function Instructions

// - SPV_INTEL_function_pointers
def OpFunctionPointerCallINTEL: Op<5601, (outs ID:$res), (ins TYPE:$ty, ID:$funPtr, variable_ops), "$res = OpFunctionPointerCallINTEL $ty $funPtr">;

// 3.49.21. Group and Subgroup Instructions
def OpSubgroupShuffleINTEL: Op<5571, (outs ID:$res), (ins TYPE:$type, ID:$data, ID:$invocationId),
"$res = OpSubgroupShuffleINTEL $type $data $invocationId">;
Expand Down
27 changes: 27 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1534,6 +1534,12 @@ bool SPIRVInstructionSelector::selectGlobalValue(
GlobalIdent = GV->getGlobalIdentifier();
}

// Behaviour of functions as operands depends on availability of the
// corresponding extension (SPV_INTEL_function_pointers):
// - If there is an extension to operate with functions as operands:
// We create a proper constant operand and evaluate a correct type for a
// function pointer.
// - Without the required extension:
// We have functions as operands in tests with blocks of instruction e.g. in
// transcoding/global_block.ll. These operands are not used and should be
// substituted by zero constants. Their type is expected to be always
Expand All @@ -1545,6 +1551,27 @@ bool SPIRVInstructionSelector::selectGlobalValue(
if (!NewReg.isValid()) {
Register NewReg = ResVReg;
GR.add(ConstVal, GR.CurMF, NewReg);
const Function *GVFun =
STI.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)
? dyn_cast<Function>(GV)
: nullptr;
if (GVFun) {
// References to a function via function pointers generate virtual
// registers without a definition. We will resolve it later, during
// module analysis stage.
MachineRegisterInfo *MRI = MIRBuilder.getMRI();
Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass);
MachineInstrBuilder MB =
BuildMI(BB, I, I.getDebugLoc(),
TII.get(SPIRV::OpConstantFunctionPointerINTEL))
.addDef(NewReg)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(FuncVReg);
// mapping the function pointer to the used Function
GR.recordFunctionPointer(&MB.getInstr()->getOperand(2), GVFun);
return MB.constrainAllUses(TII, TRI, RBI);
}
return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
.addDef(NewReg)
.addUse(GR.getSPIRVTypeID(ResType))
Expand Down
Loading