Skip to content

[SPIRV][NFC] Refactor pointer creation in GlobalRegistery #134429

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 3 additions & 11 deletions llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,11 +215,8 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
Argument *Arg = F.getArg(ArgIdx);
Type *ArgType = Arg->getType();
if (isTypedPointerTy(ArgType)) {
SPIRVType *ElementType = GR->getOrCreateSPIRVType(
cast<TypedPointerType>(ArgType)->getElementType(), MIRBuilder,
SPIRV::AccessQualifier::ReadWrite, true);
return GR->getOrCreateSPIRVPointerType(
ElementType, MIRBuilder,
cast<TypedPointerType>(ArgType)->getElementType(), MIRBuilder,
addressSpaceToStorageClass(getPointerAddressSpace(ArgType), ST));
}

Expand All @@ -232,11 +229,8 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
// spv_assign_ptr_type intrinsic or otherwise use default pointer element
// type.
if (hasPointeeTypeAttr(Arg)) {
SPIRVType *ElementType =
GR->getOrCreateSPIRVType(getPointeeTypeByAttr(Arg), MIRBuilder,
SPIRV::AccessQualifier::ReadWrite, true);
return GR->getOrCreateSPIRVPointerType(
ElementType, MIRBuilder,
getPointeeTypeByAttr(Arg), MIRBuilder,
addressSpaceToStorageClass(getPointerAddressSpace(ArgType), ST));
}

Expand All @@ -259,10 +253,8 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
MetadataAsValue *VMD = cast<MetadataAsValue>(II->getOperand(1));
Type *ElementTy =
toTypedPointer(cast<ConstantAsMetadata>(VMD->getMetadata())->getType());
SPIRVType *ElementType = GR->getOrCreateSPIRVType(
ElementTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
return GR->getOrCreateSPIRVPointerType(
ElementType, MIRBuilder,
ElementTy, MIRBuilder,
addressSpaceToStorageClass(
cast<ConstantInt>(II->getOperand(2))->getZExtValue(), ST));
}
Expand Down
84 changes: 74 additions & 10 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,40 @@ static unsigned typeToAddressSpace(const Type *Ty) {
report_fatal_error("Unable to convert LLVM type to SPIRVType", true);
}

static bool
storageClassRequiresExplictLayout(SPIRV::StorageClass::StorageClass SC) {
switch (SC) {
case SPIRV::StorageClass::Uniform:
case SPIRV::StorageClass::PushConstant:
case SPIRV::StorageClass::StorageBuffer:
case SPIRV::StorageClass::PhysicalStorageBufferEXT:
return true;
case SPIRV::StorageClass::UniformConstant:
case SPIRV::StorageClass::Input:
case SPIRV::StorageClass::Output:
case SPIRV::StorageClass::Workgroup:
case SPIRV::StorageClass::CrossWorkgroup:
case SPIRV::StorageClass::Private:
case SPIRV::StorageClass::Function:
case SPIRV::StorageClass::Generic:
case SPIRV::StorageClass::AtomicCounter:
case SPIRV::StorageClass::Image:
case SPIRV::StorageClass::CallableDataNV:
case SPIRV::StorageClass::IncomingCallableDataNV:
case SPIRV::StorageClass::RayPayloadNV:
case SPIRV::StorageClass::HitAttributeNV:
case SPIRV::StorageClass::IncomingRayPayloadNV:
case SPIRV::StorageClass::ShaderRecordBufferNV:
case SPIRV::StorageClass::CodeSectionINTEL:
case SPIRV::StorageClass::DeviceOnlyINTEL:
case SPIRV::StorageClass::HostOnlyINTEL:
return false;
default:
llvm_unreachable("Unknown storage class");
return false;
}
}

SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize)
: PointerSize(PointerSize), Bound(0) {}

Expand Down Expand Up @@ -1342,7 +1376,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateVulkanBufferType(
SPIRV::Decoration::NonWritable, 0, {});
}

SPIRVType *R = getOrCreateSPIRVPointerType(BlockType, MIRBuilder, SC);
SPIRVType *R = getOrCreateSPIRVPointerTypeInternal(BlockType, MIRBuilder, SC);
add(Key, R);
return R;
}
Expand Down Expand Up @@ -1524,7 +1558,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(

// Handle "type*" or "type* vector[N]".
if (TypeStr.starts_with("*")) {
SpirvTy = getOrCreateSPIRVPointerType(SpirvTy, MIRBuilder, SC);
SpirvTy = getOrCreateSPIRVPointerType(Ty, MIRBuilder, SC);
TypeStr = TypeStr.substr(strlen("*"));
}

Expand Down Expand Up @@ -1693,6 +1727,44 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVArrayType(
}

SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
const Type *BaseType, MachineInstr &I,
SPIRV::StorageClass::StorageClass SC) {
MachineIRBuilder MIRBuilder(I);
return getOrCreateSPIRVPointerType(BaseType, MIRBuilder, SC);
}

SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
const Type *BaseType, MachineIRBuilder &MIRBuilder,
SPIRV::StorageClass::StorageClass SC) {
SPIRVType *SpirvBaseType = getOrCreateSPIRVType(
BaseType, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
return getOrCreateSPIRVPointerTypeInternal(SpirvBaseType, MIRBuilder, SC);
}

SPIRVType *SPIRVGlobalRegistry::changePointerStorageClass(
SPIRVType *PtrType, SPIRV::StorageClass::StorageClass SC, MachineInstr &I) {
SPIRV::StorageClass::StorageClass OldSC = getPointerStorageClass(PtrType);
assert(storageClassRequiresExplictLayout(OldSC) ==
storageClassRequiresExplictLayout(SC));

SPIRVType *PointeeType = getPointeeType(PtrType);
MachineIRBuilder MIRBuilder(I);
return getOrCreateSPIRVPointerTypeInternal(PointeeType, MIRBuilder, SC);
}

SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
SPIRVType *BaseType, MachineIRBuilder &MIRBuilder,
SPIRV::StorageClass::StorageClass SC) {
const Type *LLVMType = getTypeForSPIRVType(BaseType);
assert(!storageClassRequiresExplictLayout(SC));
SPIRVType *R = getOrCreateSPIRVPointerType(LLVMType, MIRBuilder, SC);
assert(
getPointeeType(R) == BaseType &&
"The base type was not correctly laid out for the given storage class.");
return R;
}

SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerTypeInternal(
SPIRVType *BaseType, MachineIRBuilder &MIRBuilder,
SPIRV::StorageClass::StorageClass SC) {
const Type *PointerElementType = getTypeForSPIRVType(BaseType);
Expand All @@ -1714,14 +1786,6 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
return finishCreatingSPIRVType(Ty, NewMI);
}

SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &,
SPIRV::StorageClass::StorageClass SC) {
MachineInstr *DepMI = const_cast<MachineInstr *>(BaseType);
MachineIRBuilder MIRBuilder(*DepMI->getParent(), DepMI->getIterator());
return getOrCreateSPIRVPointerType(BaseType, MIRBuilder, SC);
}

Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I,
SPIRVType *SpvType,
const SPIRVInstrInfo &TII) {
Expand Down
39 changes: 33 additions & 6 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,15 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
Constant *CA, unsigned BitWidth,
unsigned ElemCnt);

// Returns a pointer to a SPIR-V pointer type with the given base type and
// storage class. It is the responsibility of the caller to make sure the
// decorations on the base type are valid for the given storage class. For
// example, it has the correct offset and stride decorations.
SPIRVType *
getOrCreateSPIRVPointerTypeInternal(SPIRVType *BaseType,
MachineIRBuilder &MIRBuilder,
SPIRV::StorageClass::StorageClass SC);

public:
Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder,
SPIRVType *SpvType, bool EmitIR,
Expand Down Expand Up @@ -540,12 +549,30 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
unsigned NumElements, MachineInstr &I,
const SPIRVInstrInfo &TII);

SPIRVType *getOrCreateSPIRVPointerType(
SPIRVType *BaseType, MachineIRBuilder &MIRBuilder,
SPIRV::StorageClass::StorageClass SClass = SPIRV::StorageClass::Function);
SPIRVType *getOrCreateSPIRVPointerType(
SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &TII,
SPIRV::StorageClass::StorageClass SClass = SPIRV::StorageClass::Function);
// Returns a pointer to a SPIR-V pointer type with the given base type and
// storage class. The base type will be translated to a SPIR-V type, and the
// appropriate layout decorations will be added to the base type.
SPIRVType *getOrCreateSPIRVPointerType(const Type *BaseType,
MachineIRBuilder &MIRBuilder,
SPIRV::StorageClass::StorageClass SC);
SPIRVType *getOrCreateSPIRVPointerType(const Type *BaseType, MachineInstr &I,
SPIRV::StorageClass::StorageClass SC);

// Returns a pointer to a SPIR-V pointer type with the given base type and
// storage class. It is the responsibility of the caller to make sure the
// decorations on the base type are valid for the given storage class. For
// example, it has the correct offset and stride decorations.
SPIRVType *getOrCreateSPIRVPointerType(SPIRVType *BaseType,
MachineIRBuilder &MIRBuilder,
SPIRV::StorageClass::StorageClass SC);

// Returns a pointer to a SPIR-V pointer type that is the same as `PtrType`
// except the stroage class has been changed to `SC`. It is the responsibility
// of the caller to be sure that the original and new storage class have the
// same layout requirements.
SPIRVType *changePointerStorageClass(SPIRVType *PtrType,
SPIRV::StorageClass::StorageClass SC,
MachineInstr &I);

SPIRVType *getOrCreateVulkanBufferType(MachineIRBuilder &MIRBuilder,
Type *ElemType,
Expand Down
6 changes: 2 additions & 4 deletions llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,10 +214,8 @@ static void validateLifetimeStart(const SPIRVSubtarget &STI,
PtrType->getOperand(1).getImm());
MachineIRBuilder MIB(I);
LLVMContext &Context = MF->getFunction().getContext();
SPIRVType *ElemType =
GR.getOrCreateSPIRVType(IntegerType::getInt8Ty(Context), MIB,
SPIRV::AccessQualifier::ReadWrite, false);
SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(ElemType, MIB, SC);
SPIRVType *NewPtrType =
GR.getOrCreateSPIRVPointerType(IntegerType::getInt8Ty(Context), MIB, SC);
doInsertBitcast(STI, MRI, GR, I, PtrReg, 0, NewPtrType);
}

Expand Down
48 changes: 21 additions & 27 deletions llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1259,14 +1259,18 @@ bool SPIRVInstructionSelector::selectMemOperation(Register ResVReg,
Register SrcReg = I.getOperand(1).getReg();
bool Result = true;
if (I.getOpcode() == TargetOpcode::G_MEMSET) {
MachineIRBuilder MIRBuilder(I);
assert(I.getOperand(1).isReg() && I.getOperand(2).isReg());
unsigned Val = getIConstVal(I.getOperand(1).getReg(), MRI);
unsigned Num = getIConstVal(I.getOperand(2).getReg(), MRI);
SPIRVType *ValTy = GR.getOrCreateSPIRVIntegerType(8, I, TII);
SPIRVType *ArrTy = GR.getOrCreateSPIRVArrayType(ValTy, Num, I, TII);
Register Const = GR.getOrCreateConstIntArray(Val, Num, I, ArrTy, TII);
Type *ValTy = Type::getInt8Ty(I.getMF()->getFunction().getContext());
Type *ArrTy = ArrayType::get(ValTy, Num);
SPIRVType *VarTy = GR.getOrCreateSPIRVPointerType(
ArrTy, I, TII, SPIRV::StorageClass::UniformConstant);
ArrTy, MIRBuilder, SPIRV::StorageClass::UniformConstant);

SPIRVType *SpvArrTy = GR.getOrCreateSPIRVType(
ArrTy, MIRBuilder, SPIRV::AccessQualifier::None, false);
Register Const = GR.getOrCreateConstIntArray(Val, Num, I, SpvArrTy, TII);
// TODO: check if we have such GV, add init, use buildGlobalVariable.
Function &CurFunction = GR.CurMF->getFunction();
Type *LLVMArrTy =
Expand All @@ -1289,7 +1293,7 @@ bool SPIRVInstructionSelector::selectMemOperation(Register ResVReg,

buildOpDecorate(VarReg, I, TII, SPIRV::Decoration::Constant, {});
SPIRVType *SourceTy = GR.getOrCreateSPIRVPointerType(
ValTy, I, TII, SPIRV::StorageClass::UniformConstant);
ValTy, I, SPIRV::StorageClass::UniformConstant);
SrcReg = MRI->createGenericVirtualRegister(LLT::scalar(64));
selectOpWithSrcs(SrcReg, SourceTy, I, {VarReg}, SPIRV::OpBitcast);
}
Expand Down Expand Up @@ -1590,7 +1594,7 @@ static bool isASCastInGVar(MachineRegisterInfo *MRI, Register ResVReg) {
Register SPIRVInstructionSelector::getUcharPtrTypeReg(
MachineInstr &I, SPIRV::StorageClass::StorageClass SC) const {
return GR.getSPIRVTypeID(GR.getOrCreateSPIRVPointerType(
GR.getOrCreateSPIRVIntegerType(8, I, TII), I, TII, SC));
Type::getInt8Ty(I.getMF()->getFunction().getContext()), I, SC));
}

MachineInstrBuilder
Expand All @@ -1608,8 +1612,8 @@ SPIRVInstructionSelector::buildSpecConstantOp(MachineInstr &I, Register Dest,
MachineInstrBuilder
SPIRVInstructionSelector::buildConstGenericPtr(MachineInstr &I, Register SrcPtr,
SPIRVType *SrcPtrTy) const {
SPIRVType *GenericPtrTy = GR.getOrCreateSPIRVPointerType(
GR.getPointeeType(SrcPtrTy), I, TII, SPIRV::StorageClass::Generic);
SPIRVType *GenericPtrTy =
GR.changePointerStorageClass(SrcPtrTy, SPIRV::StorageClass::Generic, I);
Register Tmp = MRI->createVirtualRegister(&SPIRV::pIDRegClass);
MRI->setType(Tmp, LLT::pointer(storageClassToAddressSpace(
SPIRV::StorageClass::Generic),
Expand Down Expand Up @@ -1694,8 +1698,8 @@ bool SPIRVInstructionSelector::selectAddrSpaceCast(Register ResVReg,
return selectUnOp(ResVReg, ResType, I, SPIRV::OpGenericCastToPtr);
// Casting between 2 eligible pointers using Generic as an intermediary.
if (isGenericCastablePtr(SrcSC) && isGenericCastablePtr(DstSC)) {
SPIRVType *GenericPtrTy = GR.getOrCreateSPIRVPointerType(
GR.getPointeeType(SrcPtrTy), I, TII, SPIRV::StorageClass::Generic);
SPIRVType *GenericPtrTy =
GR.changePointerStorageClass(SrcPtrTy, SPIRV::StorageClass::Generic, I);
Register Tmp = createVirtualRegister(GenericPtrTy, &GR, MRI, MRI->getMF());
bool Result = BuildMI(BB, I, DL, TII.get(SPIRV::OpPtrCastToGeneric))
.addDef(Tmp)
Expand Down Expand Up @@ -3366,18 +3370,20 @@ bool SPIRVInstructionSelector::selectImageWriteIntrinsic(
}

Register SPIRVInstructionSelector::buildPointerToResource(
const SPIRVType *ResType, SPIRV::StorageClass::StorageClass SC,
const SPIRVType *SpirvResType, SPIRV::StorageClass::StorageClass SC,
uint32_t Set, uint32_t Binding, uint32_t ArraySize, Register IndexReg,
bool IsNonUniform, MachineIRBuilder MIRBuilder) const {
const Type *ResType = GR.getTypeForSPIRVType(SpirvResType);
if (ArraySize == 1) {
SPIRVType *PtrType =
GR.getOrCreateSPIRVPointerType(ResType, MIRBuilder, SC);
assert(GR.getPointeeType(PtrType) == SpirvResType &&
"SpirvResType did not have an explicit layout.");
return GR.getOrCreateGlobalVariableWithBinding(PtrType, Set, Binding,
MIRBuilder);
}

const SPIRVType *VarType = GR.getOrCreateSPIRVArrayType(
ResType, ArraySize, *MIRBuilder.getInsertPt(), TII);
const Type *VarType = ArrayType::get(const_cast<Type *>(ResType), ArraySize);
SPIRVType *VarPointerType =
GR.getOrCreateSPIRVPointerType(VarType, MIRBuilder, SC);
Register VarReg = GR.getOrCreateGlobalVariableWithBinding(
Expand Down Expand Up @@ -3807,17 +3813,6 @@ bool SPIRVInstructionSelector::selectGlobalValue(
MachineIRBuilder MIRBuilder(I);
const GlobalValue *GV = I.getOperand(1).getGlobal();
Type *GVType = toTypedPointer(GR.getDeducedGlobalValueType(GV));
SPIRVType *PointerBaseType;
if (GVType->isArrayTy()) {
SPIRVType *ArrayElementType =
GR.getOrCreateSPIRVType(GVType->getArrayElementType(), MIRBuilder,
SPIRV::AccessQualifier::ReadWrite, false);
PointerBaseType = GR.getOrCreateSPIRVArrayType(
ArrayElementType, GVType->getArrayNumElements(), I, TII);
} else {
PointerBaseType = GR.getOrCreateSPIRVType(
GVType, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, false);
}

std::string GlobalIdent;
if (!GV->hasName()) {
Expand Down Expand Up @@ -3850,7 +3845,7 @@ bool SPIRVInstructionSelector::selectGlobalValue(
? dyn_cast<Function>(GV)
: nullptr;
SPIRVType *ResType = GR.getOrCreateSPIRVPointerType(
PointerBaseType, I, TII,
GVType, I,
GVFun ? SPIRV::StorageClass::CodeSectionINTEL
: addressSpaceToStorageClass(GV->getAddressSpace(), STI));
if (GVFun) {
Expand Down Expand Up @@ -3908,8 +3903,7 @@ bool SPIRVInstructionSelector::selectGlobalValue(
const unsigned AddrSpace = GV->getAddressSpace();
SPIRV::StorageClass::StorageClass StorageClass =
addressSpaceToStorageClass(AddrSpace, STI);
SPIRVType *ResType =
GR.getOrCreateSPIRVPointerType(PointerBaseType, I, TII, StorageClass);
SPIRVType *ResType = GR.getOrCreateSPIRVPointerType(GVType, I, StorageClass);
Register Reg = GR.buildGlobalVariable(
ResVReg, ResType, GlobalIdent, GV, StorageClass, Init,
GlobalVar->isConstant(), HasLnkTy, LnkType, MIRBuilder, true);
Expand Down
13 changes: 4 additions & 9 deletions llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,10 +251,8 @@ static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
Register Def = MI.getOperand(0).getReg();
Register Source = MI.getOperand(2).getReg();
Type *ElemTy = getMDOperandAsType(MI.getOperand(3).getMetadata(), 0);
SPIRVType *BaseTy = GR->getOrCreateSPIRVType(
ElemTy, MIB, SPIRV::AccessQualifier::ReadWrite, true);
SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
ElemTy, MI,
addressSpaceToStorageClass(MI.getOperand(4).getImm(), *ST));

// If the ptrcast would be redundant, replace all uses with the source
Expand Down Expand Up @@ -366,9 +364,8 @@ static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
RegType.getAddressSpace()) {
const SPIRVSubtarget &ST =
MI->getParent()->getParent()->getSubtarget<SPIRVSubtarget>();
SpvType = GR->getOrCreateSPIRVPointerType(
GR->getPointeeType(SpvType), *MI, *ST.getInstrInfo(),
addressSpaceToStorageClass(RegType.getAddressSpace(), ST));
auto TSC = addressSpaceToStorageClass(RegType.getAddressSpace(), ST);
SpvType = GR->changePointerStorageClass(SpvType, TSC, *MI);
}
GR->assignSPIRVTypeToVReg(SpvType, Reg, MIB.getMF());
}
Expand Down Expand Up @@ -518,10 +515,8 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
Register Reg = MI.getOperand(1).getReg();
MIB.setInsertPt(*MI.getParent(), MI.getIterator());
Type *ElementTy = getMDOperandAsType(MI.getOperand(2).getMetadata(), 0);
SPIRVType *BaseTy = GR->getOrCreateSPIRVType(
ElementTy, MIB, SPIRV::AccessQualifier::ReadWrite, true);
SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
ElementTy, MI,
addressSpaceToStorageClass(MI.getOperand(3).getImm(), *ST));
MachineInstr *Def = MRI.getVRegDef(Reg);
assert(Def && "Expecting an instruction that defines the register");
Expand Down
Loading