Skip to content

[SPIR-V] Update type inference and instruction selection #88254

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 15, 2024
Merged
2 changes: 1 addition & 1 deletion llvm/lib/Target/SPIRV/SPIRV.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ FunctionPass *createSPIRVStripConvergenceIntrinsicsPass();
FunctionPass *createSPIRVRegularizerPass();
FunctionPass *createSPIRVPreLegalizerPass();
FunctionPass *createSPIRVPostLegalizerPass();
FunctionPass *createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM);
ModulePass *createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM);
InstructionSelector *
createSPIRVInstructionSelector(const SPIRVTargetMachine &TM,
const SPIRVSubtarget &Subtarget,
Expand Down
18 changes: 16 additions & 2 deletions llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,16 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
if (F.isDeclaration())
GR->add(&F, &MIRBuilder.getMF(), FuncVReg);
FunctionType *FTy = getOriginalFunctionType(F);
SPIRVType *RetTy = GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder);
Type *FRetTy = FTy->getReturnType();
if (isUntypedPointerTy(FRetTy)) {
if (Type *FRetElemTy = GR->findDeducedElementType(&F)) {
TypedPointerType *DerivedTy =
TypedPointerType::get(FRetElemTy, getPointerAddressSpace(FRetTy));
GR->addReturnType(&F, DerivedTy);
FRetTy = DerivedTy;
}
}
SPIRVType *RetTy = GR->getOrCreateSPIRVType(FRetTy, MIRBuilder);
FTy = fixFunctionTypeIfPtrArgs(GR, F, FTy, RetTy, ArgTypeVRegs);
SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
FTy, RetTy, ArgTypeVRegs, MIRBuilder);
Expand Down Expand Up @@ -505,8 +514,13 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
// TODO: support constexpr casts and indirect calls.
if (CF == nullptr)
return false;
if (FunctionType *FTy = getOriginalFunctionType(*CF))
if (FunctionType *FTy = getOriginalFunctionType(*CF)) {
OrigRetTy = FTy->getReturnType();
if (isUntypedPointerTy(OrigRetTy)) {
if (auto *DerivedRetTy = GR->findReturnType(CF))
OrigRetTy = DerivedRetTy;
}
}
}

MachineRegisterInfo *MRI = MIRBuilder.getMRI();
Expand Down
167 changes: 158 additions & 9 deletions llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ void initializeSPIRVEmitIntrinsicsPass(PassRegistry &);

namespace {
class SPIRVEmitIntrinsics
: public FunctionPass,
: public ModulePass,
public InstVisitor<SPIRVEmitIntrinsics, Instruction *> {
SPIRVTargetMachine *TM = nullptr;
SPIRVGlobalRegistry *GR = nullptr;
Expand All @@ -61,6 +61,9 @@ class SPIRVEmitIntrinsics
DenseMap<Instruction *, Type *> AggrConstTypes;
DenseSet<Instruction *> AggrStores;

// a registry of created Intrinsic::spv_assign_ptr_type instructions
DenseMap<Value *, CallInst *> AssignPtrTypeInstr;

// deduce element type of untyped pointers
Type *deduceElementType(Value *I);
Type *deduceElementTypeHelper(Value *I);
Expand All @@ -75,6 +78,9 @@ class SPIRVEmitIntrinsics
Type *deduceNestedTypeHelper(User *U, Type *Ty,
std::unordered_set<Value *> &Visited);

// deduce Types of operands of the Instruction if possible
void deduceOperandElementType(Instruction *I);

void preprocessCompositeConstants(IRBuilder<> &B);
void preprocessUndefs(IRBuilder<> &B);

Expand Down Expand Up @@ -111,10 +117,10 @@ class SPIRVEmitIntrinsics

public:
static char ID;
SPIRVEmitIntrinsics() : FunctionPass(ID) {
SPIRVEmitIntrinsics() : ModulePass(ID) {
initializeSPIRVEmitIntrinsicsPass(*PassRegistry::getPassRegistry());
}
SPIRVEmitIntrinsics(SPIRVTargetMachine *_TM) : FunctionPass(ID), TM(_TM) {
SPIRVEmitIntrinsics(SPIRVTargetMachine *_TM) : ModulePass(ID), TM(_TM) {
initializeSPIRVEmitIntrinsicsPass(*PassRegistry::getPassRegistry());
}
Instruction *visitInstruction(Instruction &I) { return &I; }
Expand All @@ -130,7 +136,15 @@ class SPIRVEmitIntrinsics
Instruction *visitAllocaInst(AllocaInst &I);
Instruction *visitAtomicCmpXchgInst(AtomicCmpXchgInst &I);
Instruction *visitUnreachableInst(UnreachableInst &I);
bool runOnFunction(Function &F) override;

StringRef getPassName() const override { return "SPIRV emit intrinsics"; }

bool runOnModule(Module &M) override;
bool runOnFunction(Function &F);

void getAnalysisUsage(AnalysisUsage &AU) const override {
ModulePass::getAnalysisUsage(AU);
}
};
} // namespace

Expand Down Expand Up @@ -269,6 +283,12 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
if (Ty)
break;
}
} else if (auto *Ref = dyn_cast<SelectInst>(I)) {
for (Value *Op : {Ref->getTrueValue(), Ref->getFalseValue()}) {
Ty = deduceElementTypeByUsersDeep(Op, Visited);
if (Ty)
break;
}
}

// remember the found relationship
Expand Down Expand Up @@ -368,6 +388,112 @@ Type *SPIRVEmitIntrinsics::deduceElementType(Value *I) {
return IntegerType::getInt8Ty(I->getContext());
}

// If the Instruction has Pointer operands with unresolved types, this function
// tries to deduce them. If the Instruction has Pointer operands with known
// types which differ from expected, this function tries to insert a bitcast to
// resolve the issue.
void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I) {
SmallVector<std::pair<Value *, unsigned>> Ops;
Type *KnownElemTy = nullptr;
// look for known basic patterns of type inference
if (auto *Ref = dyn_cast<PHINode>(I)) {
if (!isPointerTy(I->getType()) ||
!(KnownElemTy = GR->findDeducedElementType(I)))
return;
for (unsigned i = 0; i < Ref->getNumIncomingValues(); i++) {
Value *Op = Ref->getIncomingValue(i);
if (isPointerTy(Op->getType()))
Ops.push_back(std::make_pair(Op, i));
}
} else if (auto *Ref = dyn_cast<SelectInst>(I)) {
if (!isPointerTy(I->getType()) ||
!(KnownElemTy = GR->findDeducedElementType(I)))
return;
for (unsigned i = 0; i < Ref->getNumOperands(); i++) {
Value *Op = Ref->getOperand(i);
if (isPointerTy(Op->getType()))
Ops.push_back(std::make_pair(Op, i));
}
} else if (auto *Ref = dyn_cast<ReturnInst>(I)) {
Type *RetTy = F->getReturnType();
if (!isPointerTy(RetTy))
return;
Value *Op = Ref->getReturnValue();
if (!Op)
return;
if (!(KnownElemTy = GR->findDeducedElementType(F))) {
if (Type *OpElemTy = GR->findDeducedElementType(Op)) {
GR->addDeducedElementType(F, OpElemTy);
TypedPointerType *DerivedTy =
TypedPointerType::get(OpElemTy, getPointerAddressSpace(RetTy));
GR->addReturnType(F, DerivedTy);
}
return;
}
Ops.push_back(std::make_pair(Op, 0));
} else if (auto *Ref = dyn_cast<ICmpInst>(I)) {
if (!isPointerTy(Ref->getOperand(0)->getType()))
return;
Value *Op0 = Ref->getOperand(0);
Value *Op1 = Ref->getOperand(1);
Type *ElemTy0 = GR->findDeducedElementType(Op0);
Type *ElemTy1 = GR->findDeducedElementType(Op1);
if (ElemTy0) {
KnownElemTy = ElemTy0;
Ops.push_back(std::make_pair(Op1, 1));
} else if (ElemTy1) {
KnownElemTy = ElemTy1;
Ops.push_back(std::make_pair(Op0, 0));
}
}

// There is no enough info to deduce types or all is valid.
if (!KnownElemTy || Ops.size() == 0)
return;

LLVMContext &Ctx = F->getContext();
IRBuilder<> B(Ctx);
for (auto &OpIt : Ops) {
Value *Op = OpIt.first;
if (Op->use_empty())
continue;
Type *Ty = GR->findDeducedElementType(Op);
if (Ty == KnownElemTy)
continue;
if (Instruction *User = dyn_cast<Instruction>(Op->use_begin()->get()))
setInsertPointSkippingPhis(B, User->getNextNode());
else
B.SetInsertPoint(I);
Value *OpTyVal = Constant::getNullValue(KnownElemTy);
Type *OpTy = Op->getType();
if (!Ty) {
GR->addDeducedElementType(Op, KnownElemTy);
// check if there is existing Intrinsic::spv_assign_ptr_type instruction
auto It = AssignPtrTypeInstr.find(Op);
if (It == AssignPtrTypeInstr.end()) {
CallInst *CI =
buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {OpTy}, OpTyVal, Op,
{B.getInt32(getPointerAddressSpace(OpTy))}, B);
AssignPtrTypeInstr[Op] = CI;
} else {
It->second->setArgOperand(
1,
MetadataAsValue::get(
Ctx, MDNode::get(Ctx, ValueAsMetadata::getConstant(OpTyVal))));
}
} else {
SmallVector<Type *, 2> Types = {OpTy, OpTy};
MetadataAsValue *VMD = MetadataAsValue::get(
Ctx, MDNode::get(Ctx, ValueAsMetadata::getConstant(OpTyVal)));
SmallVector<Value *, 2> Args = {Op, VMD,
B.getInt32(getPointerAddressSpace(OpTy))};
CallInst *PtrCastI =
B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
I->setOperand(OpIt.second, PtrCastI);
}
}
}

void SPIRVEmitIntrinsics::replaceMemInstrUses(Instruction *Old,
Instruction *New,
IRBuilder<> &B) {
Expand Down Expand Up @@ -630,6 +756,7 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
ExpectedElementTypeConst, Pointer, {B.getInt32(AddressSpace)}, B);
GR->addDeducedElementType(CI, ExpectedElementType);
GR->addDeducedElementType(Pointer, ExpectedElementType);
AssignPtrTypeInstr[Pointer] = CI;
return;
}

Expand Down Expand Up @@ -914,6 +1041,7 @@ void SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I,
CallInst *CI = buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {I->getType()},
EltTyConst, I, {B.getInt32(AddressSpace)}, B);
GR->addDeducedElementType(CI, ElemTy);
AssignPtrTypeInstr[I] = CI;
}

void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
Expand Down Expand Up @@ -1070,6 +1198,7 @@ void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
{B.getInt32(getPointerAddressSpace(Arg->getType()))}, B);
GR->addDeducedElementType(AssignPtrTyCI, ElemTy);
GR->addDeducedElementType(Arg, ElemTy);
AssignPtrTypeInstr[Arg] = AssignPtrTyCI;
}
}
}
Expand Down Expand Up @@ -1114,6 +1243,10 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
insertAssignTypeIntrs(I, B);
insertPtrCastOrAssignTypeInstr(I, B);
}

for (auto &I : instructions(Func))
deduceOperandElementType(&I);

for (auto *I : Worklist) {
TrackConstants = true;
if (!I->getType()->isVoidTy() || isa<StoreInst>(I))
Expand All @@ -1126,13 +1259,29 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
processInstrAfterVisit(I, B);
}

// check if function parameter types are set
if (!F->isIntrinsic())
processParamTypes(F, B);

return true;
}

FunctionPass *llvm::createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM) {
bool SPIRVEmitIntrinsics::runOnModule(Module &M) {
bool Changed = false;

for (auto &F : M) {
Changed |= runOnFunction(F);
}

for (auto &F : M) {
// check if function parameter types are set
if (!F.isDeclaration() && !F.isIntrinsic()) {
const SPIRVSubtarget &ST = TM->getSubtarget<SPIRVSubtarget>(F);
GR = ST.getSPIRVGlobalRegistry();
IRBuilder<> B(F.getContext());
processParamTypes(&F, B);
}
}

return Changed;
}

ModulePass *llvm::createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM) {
return new SPIRVEmitIntrinsics(TM);
}
11 changes: 6 additions & 5 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
#include "llvm/ADT/APInt.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/TypedPointerType.h"
#include "llvm/Support/Casting.h"
#include <cassert>

Expand Down Expand Up @@ -61,7 +60,6 @@ SPIRVType *SPIRVGlobalRegistry::assignVectTypeToVReg(
SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg(
const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {

SPIRVType *SpirvType =
getOrCreateSPIRVType(Type, MIRBuilder, AccessQual, EmitIR);
assignSPIRVTypeToVReg(SpirvType, VReg, MIRBuilder.getMF());
Expand Down Expand Up @@ -726,7 +724,8 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty,
bool EmitIR) {
SmallVector<Register, 4> FieldTypes;
for (const auto &Elem : Ty->elements()) {
SPIRVType *ElemTy = findSPIRVType(Elem, MIRBuilder);
SPIRVType *ElemTy =
findSPIRVType(toTypedPointer(Elem, Ty->getContext()), MIRBuilder);
assert(ElemTy && ElemTy->getOpcode() != SPIRV::OpTypeVoid &&
"Invalid struct element type");
FieldTypes.push_back(getSPIRVTypeID(ElemTy));
Expand Down Expand Up @@ -919,8 +918,10 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
return SpirvType;
}

SPIRVType *SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg) const {
auto t = VRegToTypeMap.find(CurMF);
SPIRVType *
SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg,
const MachineFunction *MF) const {
auto t = VRegToTypeMap.find(MF ? MF : CurMF);
if (t != VRegToTypeMap.end()) {
auto tt = t->second.find(VReg);
if (tt != t->second.end())
Expand Down
22 changes: 20 additions & 2 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "SPIRVInstrInfo.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/TypedPointerType.h"

namespace llvm {
using SPIRVType = const MachineInstr;
Expand Down Expand Up @@ -58,6 +59,9 @@ class SPIRVGlobalRegistry {
SmallPtrSet<const Type *, 4> TypesInProcessing;
DenseMap<const Type *, SPIRVType *> ForwardPointerTypes;

// if a function returns a pointer, this is to map it into TypedPointerType
DenseMap<const Function *, TypedPointerType *> FunResPointerTypes;

// Number of bits pointers and size_t integers require.
const unsigned PointerSize;

Expand Down Expand Up @@ -134,6 +138,16 @@ class SPIRVGlobalRegistry {
void setBound(unsigned V) { Bound = V; }
unsigned getBound() { return Bound; }

// Add a record to the map of function return pointer types.
void addReturnType(const Function *ArgF, TypedPointerType *DerivedTy) {
FunResPointerTypes[ArgF] = DerivedTy;
}
// Find a record in the map of function return pointer types.
const TypedPointerType *findReturnType(const Function *ArgF) {
auto It = FunResPointerTypes.find(ArgF);
return It == FunResPointerTypes.end() ? nullptr : It->second;
}

// Deduced element types of untyped pointers and composites:
// - Add a record to the map of deduced element types.
void addDeducedElementType(Value *Val, Type *Ty) { DeducedElTys[Val] = Ty; }
Expand Down Expand Up @@ -276,8 +290,12 @@ class SPIRVGlobalRegistry {
SPIRV::AccessQualifier::ReadWrite);

// Return the SPIR-V type instruction corresponding to the given VReg, or
// nullptr if no such type instruction exists.
SPIRVType *getSPIRVTypeForVReg(Register VReg) const;
// nullptr if no such type instruction exists. The second argument MF
// allows to search for the association in a context of the machine functions
// than the current one, without switching between different "current" machine
// functions.
SPIRVType *getSPIRVTypeForVReg(Register VReg,
const MachineFunction *MF = nullptr) const;

// Whether the given VReg has a SPIR-V type mapped to it yet.
bool hasSPIRVTypeForVReg(Register VReg) const {
Expand Down
Loading