Skip to content

Commit f768083

Browse files
[SPIR-V] Update type inference and instruction selection (#88254)
This PR contains a series of fixes which are to improve type inference and instruction selection. Namely, it includes: * fix OpSelect to support operands of a pointer type, according to the SPIR-V specification (previously only integer/float/vectors of integer or float were supported) -- a new test case is added and existing test case is updated; * fix TableGen typo's in definition of register classes and introduce a new reg class that is a vector of pointers; * fix usage of a machine function context when there is a need to switch between different machine functions to infer/validate correct types; * add usage of TypedPointerType instead of PointerType so that later stages of type inference are able to distinguish pointer types by their element types, effectively supporting hierarchy of pointer/pointee types and avoiding more complicated recursive type matching on level of machine instructions in favor of direct pointer comparison using LLVM's `Type *` values; * extracting detailed information about operand types using known type rules for some llvm instructions (for instance, by deducing PHI's operand pointee types if PHI's results type was deducted on previous stages of type inference), and adding correspondent `Intrinsic::spv_assign_ptr_type` to keep type info along consequent passes, * ensure that OpConstantComposite reuses a constant when it's already created and available in the same machine function -- otherwise there is a crash while building a dependency graph, the corresponding test case is attached, * implement deduction of function's return type for opaque pointers, a new test case is attached, * make 'emit intrinsics' a module pass to resolve function return types over the module -- first types for all functions of the module must be calculated, and only after that it's feasible to deduct function return types on this earlier stage of translation.
1 parent b79b6f9 commit f768083

19 files changed

+448
-38
lines changed

llvm/lib/Target/SPIRV/SPIRV.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ FunctionPass *createSPIRVStripConvergenceIntrinsicsPass();
2424
FunctionPass *createSPIRVRegularizerPass();
2525
FunctionPass *createSPIRVPreLegalizerPass();
2626
FunctionPass *createSPIRVPostLegalizerPass();
27-
FunctionPass *createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM);
27+
ModulePass *createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM);
2828
InstructionSelector *
2929
createSPIRVInstructionSelector(const SPIRVTargetMachine &TM,
3030
const SPIRVSubtarget &Subtarget,

llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,16 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
383383
if (F.isDeclaration())
384384
GR->add(&F, &MIRBuilder.getMF(), FuncVReg);
385385
FunctionType *FTy = getOriginalFunctionType(F);
386-
SPIRVType *RetTy = GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder);
386+
Type *FRetTy = FTy->getReturnType();
387+
if (isUntypedPointerTy(FRetTy)) {
388+
if (Type *FRetElemTy = GR->findDeducedElementType(&F)) {
389+
TypedPointerType *DerivedTy =
390+
TypedPointerType::get(FRetElemTy, getPointerAddressSpace(FRetTy));
391+
GR->addReturnType(&F, DerivedTy);
392+
FRetTy = DerivedTy;
393+
}
394+
}
395+
SPIRVType *RetTy = GR->getOrCreateSPIRVType(FRetTy, MIRBuilder);
387396
FTy = fixFunctionTypeIfPtrArgs(GR, F, FTy, RetTy, ArgTypeVRegs);
388397
SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
389398
FTy, RetTy, ArgTypeVRegs, MIRBuilder);
@@ -505,8 +514,13 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
505514
// TODO: support constexpr casts and indirect calls.
506515
if (CF == nullptr)
507516
return false;
508-
if (FunctionType *FTy = getOriginalFunctionType(*CF))
517+
if (FunctionType *FTy = getOriginalFunctionType(*CF)) {
509518
OrigRetTy = FTy->getReturnType();
519+
if (isUntypedPointerTy(OrigRetTy)) {
520+
if (auto *DerivedRetTy = GR->findReturnType(CF))
521+
OrigRetTy = DerivedRetTy;
522+
}
523+
}
510524
}
511525

512526
MachineRegisterInfo *MRI = MIRBuilder.getMRI();

llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp

Lines changed: 158 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ void initializeSPIRVEmitIntrinsicsPass(PassRegistry &);
5151

5252
namespace {
5353
class SPIRVEmitIntrinsics
54-
: public FunctionPass,
54+
: public ModulePass,
5555
public InstVisitor<SPIRVEmitIntrinsics, Instruction *> {
5656
SPIRVTargetMachine *TM = nullptr;
5757
SPIRVGlobalRegistry *GR = nullptr;
@@ -61,6 +61,9 @@ class SPIRVEmitIntrinsics
6161
DenseMap<Instruction *, Type *> AggrConstTypes;
6262
DenseSet<Instruction *> AggrStores;
6363

64+
// a registry of created Intrinsic::spv_assign_ptr_type instructions
65+
DenseMap<Value *, CallInst *> AssignPtrTypeInstr;
66+
6467
// deduce element type of untyped pointers
6568
Type *deduceElementType(Value *I);
6669
Type *deduceElementTypeHelper(Value *I);
@@ -75,6 +78,9 @@ class SPIRVEmitIntrinsics
7578
Type *deduceNestedTypeHelper(User *U, Type *Ty,
7679
std::unordered_set<Value *> &Visited);
7780

81+
// deduce Types of operands of the Instruction if possible
82+
void deduceOperandElementType(Instruction *I);
83+
7884
void preprocessCompositeConstants(IRBuilder<> &B);
7985
void preprocessUndefs(IRBuilder<> &B);
8086

@@ -111,10 +117,10 @@ class SPIRVEmitIntrinsics
111117

112118
public:
113119
static char ID;
114-
SPIRVEmitIntrinsics() : FunctionPass(ID) {
120+
SPIRVEmitIntrinsics() : ModulePass(ID) {
115121
initializeSPIRVEmitIntrinsicsPass(*PassRegistry::getPassRegistry());
116122
}
117-
SPIRVEmitIntrinsics(SPIRVTargetMachine *_TM) : FunctionPass(ID), TM(_TM) {
123+
SPIRVEmitIntrinsics(SPIRVTargetMachine *_TM) : ModulePass(ID), TM(_TM) {
118124
initializeSPIRVEmitIntrinsicsPass(*PassRegistry::getPassRegistry());
119125
}
120126
Instruction *visitInstruction(Instruction &I) { return &I; }
@@ -130,7 +136,15 @@ class SPIRVEmitIntrinsics
130136
Instruction *visitAllocaInst(AllocaInst &I);
131137
Instruction *visitAtomicCmpXchgInst(AtomicCmpXchgInst &I);
132138
Instruction *visitUnreachableInst(UnreachableInst &I);
133-
bool runOnFunction(Function &F) override;
139+
140+
StringRef getPassName() const override { return "SPIRV emit intrinsics"; }
141+
142+
bool runOnModule(Module &M) override;
143+
bool runOnFunction(Function &F);
144+
145+
void getAnalysisUsage(AnalysisUsage &AU) const override {
146+
ModulePass::getAnalysisUsage(AU);
147+
}
134148
};
135149
} // namespace
136150

@@ -269,6 +283,12 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
269283
if (Ty)
270284
break;
271285
}
286+
} else if (auto *Ref = dyn_cast<SelectInst>(I)) {
287+
for (Value *Op : {Ref->getTrueValue(), Ref->getFalseValue()}) {
288+
Ty = deduceElementTypeByUsersDeep(Op, Visited);
289+
if (Ty)
290+
break;
291+
}
272292
}
273293

274294
// remember the found relationship
@@ -368,6 +388,112 @@ Type *SPIRVEmitIntrinsics::deduceElementType(Value *I) {
368388
return IntegerType::getInt8Ty(I->getContext());
369389
}
370390

391+
// If the Instruction has Pointer operands with unresolved types, this function
392+
// tries to deduce them. If the Instruction has Pointer operands with known
393+
// types which differ from expected, this function tries to insert a bitcast to
394+
// resolve the issue.
395+
void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I) {
396+
SmallVector<std::pair<Value *, unsigned>> Ops;
397+
Type *KnownElemTy = nullptr;
398+
// look for known basic patterns of type inference
399+
if (auto *Ref = dyn_cast<PHINode>(I)) {
400+
if (!isPointerTy(I->getType()) ||
401+
!(KnownElemTy = GR->findDeducedElementType(I)))
402+
return;
403+
for (unsigned i = 0; i < Ref->getNumIncomingValues(); i++) {
404+
Value *Op = Ref->getIncomingValue(i);
405+
if (isPointerTy(Op->getType()))
406+
Ops.push_back(std::make_pair(Op, i));
407+
}
408+
} else if (auto *Ref = dyn_cast<SelectInst>(I)) {
409+
if (!isPointerTy(I->getType()) ||
410+
!(KnownElemTy = GR->findDeducedElementType(I)))
411+
return;
412+
for (unsigned i = 0; i < Ref->getNumOperands(); i++) {
413+
Value *Op = Ref->getOperand(i);
414+
if (isPointerTy(Op->getType()))
415+
Ops.push_back(std::make_pair(Op, i));
416+
}
417+
} else if (auto *Ref = dyn_cast<ReturnInst>(I)) {
418+
Type *RetTy = F->getReturnType();
419+
if (!isPointerTy(RetTy))
420+
return;
421+
Value *Op = Ref->getReturnValue();
422+
if (!Op)
423+
return;
424+
if (!(KnownElemTy = GR->findDeducedElementType(F))) {
425+
if (Type *OpElemTy = GR->findDeducedElementType(Op)) {
426+
GR->addDeducedElementType(F, OpElemTy);
427+
TypedPointerType *DerivedTy =
428+
TypedPointerType::get(OpElemTy, getPointerAddressSpace(RetTy));
429+
GR->addReturnType(F, DerivedTy);
430+
}
431+
return;
432+
}
433+
Ops.push_back(std::make_pair(Op, 0));
434+
} else if (auto *Ref = dyn_cast<ICmpInst>(I)) {
435+
if (!isPointerTy(Ref->getOperand(0)->getType()))
436+
return;
437+
Value *Op0 = Ref->getOperand(0);
438+
Value *Op1 = Ref->getOperand(1);
439+
Type *ElemTy0 = GR->findDeducedElementType(Op0);
440+
Type *ElemTy1 = GR->findDeducedElementType(Op1);
441+
if (ElemTy0) {
442+
KnownElemTy = ElemTy0;
443+
Ops.push_back(std::make_pair(Op1, 1));
444+
} else if (ElemTy1) {
445+
KnownElemTy = ElemTy1;
446+
Ops.push_back(std::make_pair(Op0, 0));
447+
}
448+
}
449+
450+
// There is no enough info to deduce types or all is valid.
451+
if (!KnownElemTy || Ops.size() == 0)
452+
return;
453+
454+
LLVMContext &Ctx = F->getContext();
455+
IRBuilder<> B(Ctx);
456+
for (auto &OpIt : Ops) {
457+
Value *Op = OpIt.first;
458+
if (Op->use_empty())
459+
continue;
460+
Type *Ty = GR->findDeducedElementType(Op);
461+
if (Ty == KnownElemTy)
462+
continue;
463+
if (Instruction *User = dyn_cast<Instruction>(Op->use_begin()->get()))
464+
setInsertPointSkippingPhis(B, User->getNextNode());
465+
else
466+
B.SetInsertPoint(I);
467+
Value *OpTyVal = Constant::getNullValue(KnownElemTy);
468+
Type *OpTy = Op->getType();
469+
if (!Ty) {
470+
GR->addDeducedElementType(Op, KnownElemTy);
471+
// check if there is existing Intrinsic::spv_assign_ptr_type instruction
472+
auto It = AssignPtrTypeInstr.find(Op);
473+
if (It == AssignPtrTypeInstr.end()) {
474+
CallInst *CI =
475+
buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {OpTy}, OpTyVal, Op,
476+
{B.getInt32(getPointerAddressSpace(OpTy))}, B);
477+
AssignPtrTypeInstr[Op] = CI;
478+
} else {
479+
It->second->setArgOperand(
480+
1,
481+
MetadataAsValue::get(
482+
Ctx, MDNode::get(Ctx, ValueAsMetadata::getConstant(OpTyVal))));
483+
}
484+
} else {
485+
SmallVector<Type *, 2> Types = {OpTy, OpTy};
486+
MetadataAsValue *VMD = MetadataAsValue::get(
487+
Ctx, MDNode::get(Ctx, ValueAsMetadata::getConstant(OpTyVal)));
488+
SmallVector<Value *, 2> Args = {Op, VMD,
489+
B.getInt32(getPointerAddressSpace(OpTy))};
490+
CallInst *PtrCastI =
491+
B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
492+
I->setOperand(OpIt.second, PtrCastI);
493+
}
494+
}
495+
}
496+
371497
void SPIRVEmitIntrinsics::replaceMemInstrUses(Instruction *Old,
372498
Instruction *New,
373499
IRBuilder<> &B) {
@@ -630,6 +756,7 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
630756
ExpectedElementTypeConst, Pointer, {B.getInt32(AddressSpace)}, B);
631757
GR->addDeducedElementType(CI, ExpectedElementType);
632758
GR->addDeducedElementType(Pointer, ExpectedElementType);
759+
AssignPtrTypeInstr[Pointer] = CI;
633760
return;
634761
}
635762

@@ -914,6 +1041,7 @@ void SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I,
9141041
CallInst *CI = buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {I->getType()},
9151042
EltTyConst, I, {B.getInt32(AddressSpace)}, B);
9161043
GR->addDeducedElementType(CI, ElemTy);
1044+
AssignPtrTypeInstr[I] = CI;
9171045
}
9181046

9191047
void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
@@ -1070,6 +1198,7 @@ void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
10701198
{B.getInt32(getPointerAddressSpace(Arg->getType()))}, B);
10711199
GR->addDeducedElementType(AssignPtrTyCI, ElemTy);
10721200
GR->addDeducedElementType(Arg, ElemTy);
1201+
AssignPtrTypeInstr[Arg] = AssignPtrTyCI;
10731202
}
10741203
}
10751204
}
@@ -1114,6 +1243,10 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
11141243
insertAssignTypeIntrs(I, B);
11151244
insertPtrCastOrAssignTypeInstr(I, B);
11161245
}
1246+
1247+
for (auto &I : instructions(Func))
1248+
deduceOperandElementType(&I);
1249+
11171250
for (auto *I : Worklist) {
11181251
TrackConstants = true;
11191252
if (!I->getType()->isVoidTy() || isa<StoreInst>(I))
@@ -1126,13 +1259,29 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
11261259
processInstrAfterVisit(I, B);
11271260
}
11281261

1129-
// check if function parameter types are set
1130-
if (!F->isIntrinsic())
1131-
processParamTypes(F, B);
1132-
11331262
return true;
11341263
}
11351264

1136-
FunctionPass *llvm::createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM) {
1265+
bool SPIRVEmitIntrinsics::runOnModule(Module &M) {
1266+
bool Changed = false;
1267+
1268+
for (auto &F : M) {
1269+
Changed |= runOnFunction(F);
1270+
}
1271+
1272+
for (auto &F : M) {
1273+
// check if function parameter types are set
1274+
if (!F.isDeclaration() && !F.isIntrinsic()) {
1275+
const SPIRVSubtarget &ST = TM->getSubtarget<SPIRVSubtarget>(F);
1276+
GR = ST.getSPIRVGlobalRegistry();
1277+
IRBuilder<> B(F.getContext());
1278+
processParamTypes(&F, B);
1279+
}
1280+
}
1281+
1282+
return Changed;
1283+
}
1284+
1285+
ModulePass *llvm::createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM) {
11371286
return new SPIRVEmitIntrinsics(TM);
11381287
}

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
#include "llvm/ADT/APInt.h"
2424
#include "llvm/IR/Constants.h"
2525
#include "llvm/IR/Type.h"
26-
#include "llvm/IR/TypedPointerType.h"
2726
#include "llvm/Support/Casting.h"
2827
#include <cassert>
2928

@@ -61,7 +60,6 @@ SPIRVType *SPIRVGlobalRegistry::assignVectTypeToVReg(
6160
SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg(
6261
const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder,
6362
SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
64-
6563
SPIRVType *SpirvType =
6664
getOrCreateSPIRVType(Type, MIRBuilder, AccessQual, EmitIR);
6765
assignSPIRVTypeToVReg(SpirvType, VReg, MIRBuilder.getMF());
@@ -726,7 +724,8 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty,
726724
bool EmitIR) {
727725
SmallVector<Register, 4> FieldTypes;
728726
for (const auto &Elem : Ty->elements()) {
729-
SPIRVType *ElemTy = findSPIRVType(Elem, MIRBuilder);
727+
SPIRVType *ElemTy =
728+
findSPIRVType(toTypedPointer(Elem, Ty->getContext()), MIRBuilder);
730729
assert(ElemTy && ElemTy->getOpcode() != SPIRV::OpTypeVoid &&
731730
"Invalid struct element type");
732731
FieldTypes.push_back(getSPIRVTypeID(ElemTy));
@@ -919,8 +918,10 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
919918
return SpirvType;
920919
}
921920

922-
SPIRVType *SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg) const {
923-
auto t = VRegToTypeMap.find(CurMF);
921+
SPIRVType *
922+
SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg,
923+
const MachineFunction *MF) const {
924+
auto t = VRegToTypeMap.find(MF ? MF : CurMF);
924925
if (t != VRegToTypeMap.end()) {
925926
auto tt = t->second.find(VReg);
926927
if (tt != t->second.end())

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "SPIRVInstrInfo.h"
2222
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
2323
#include "llvm/IR/Constant.h"
24+
#include "llvm/IR/TypedPointerType.h"
2425

2526
namespace llvm {
2627
using SPIRVType = const MachineInstr;
@@ -58,6 +59,9 @@ class SPIRVGlobalRegistry {
5859
SmallPtrSet<const Type *, 4> TypesInProcessing;
5960
DenseMap<const Type *, SPIRVType *> ForwardPointerTypes;
6061

62+
// if a function returns a pointer, this is to map it into TypedPointerType
63+
DenseMap<const Function *, TypedPointerType *> FunResPointerTypes;
64+
6165
// Number of bits pointers and size_t integers require.
6266
const unsigned PointerSize;
6367

@@ -134,6 +138,16 @@ class SPIRVGlobalRegistry {
134138
void setBound(unsigned V) { Bound = V; }
135139
unsigned getBound() { return Bound; }
136140

141+
// Add a record to the map of function return pointer types.
142+
void addReturnType(const Function *ArgF, TypedPointerType *DerivedTy) {
143+
FunResPointerTypes[ArgF] = DerivedTy;
144+
}
145+
// Find a record in the map of function return pointer types.
146+
const TypedPointerType *findReturnType(const Function *ArgF) {
147+
auto It = FunResPointerTypes.find(ArgF);
148+
return It == FunResPointerTypes.end() ? nullptr : It->second;
149+
}
150+
137151
// Deduced element types of untyped pointers and composites:
138152
// - Add a record to the map of deduced element types.
139153
void addDeducedElementType(Value *Val, Type *Ty) { DeducedElTys[Val] = Ty; }
@@ -276,8 +290,12 @@ class SPIRVGlobalRegistry {
276290
SPIRV::AccessQualifier::ReadWrite);
277291

278292
// Return the SPIR-V type instruction corresponding to the given VReg, or
279-
// nullptr if no such type instruction exists.
280-
SPIRVType *getSPIRVTypeForVReg(Register VReg) const;
293+
// nullptr if no such type instruction exists. The second argument MF
294+
// allows to search for the association in a context of the machine functions
295+
// than the current one, without switching between different "current" machine
296+
// functions.
297+
SPIRVType *getSPIRVTypeForVReg(Register VReg,
298+
const MachineFunction *MF = nullptr) const;
281299

282300
// Whether the given VReg has a SPIR-V type mapped to it yet.
283301
bool hasSPIRVTypeForVReg(Register VReg) const {

0 commit comments

Comments
 (0)