Skip to content

[SPIR-V] Fix generation of gMIR vs. SPIR-V code from utility methods #128159

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 37 additions & 30 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -445,12 +445,12 @@ static std::tuple<Register, SPIRVType *>
buildBoolRegister(MachineIRBuilder &MIRBuilder, const SPIRVType *ResultType,
SPIRVGlobalRegistry *GR) {
LLT Type;
SPIRVType *BoolType = GR->getOrCreateSPIRVBoolType(MIRBuilder);
SPIRVType *BoolType = GR->getOrCreateSPIRVBoolType(MIRBuilder, true);

if (ResultType->getOpcode() == SPIRV::OpTypeVector) {
unsigned VectorElements = ResultType->getOperand(2).getImm();
BoolType =
GR->getOrCreateSPIRVVectorType(BoolType, VectorElements, MIRBuilder);
BoolType = GR->getOrCreateSPIRVVectorType(BoolType, VectorElements,
MIRBuilder, true);
const FixedVectorType *LLVMVectorType =
cast<FixedVectorType>(GR->getTypeForSPIRVType(BoolType));
Type = LLT::vector(LLVMVectorType->getElementCount(), 1);
Expand All @@ -476,11 +476,12 @@ static bool buildSelectInst(MachineIRBuilder &MIRBuilder,
if (ReturnType->getOpcode() == SPIRV::OpTypeVector) {
unsigned Bits = GR->getScalarOrVectorBitWidth(ReturnType);
uint64_t AllOnes = APInt::getAllOnes(Bits).getZExtValue();
TrueConst = GR->getOrCreateConsIntVector(AllOnes, MIRBuilder, ReturnType);
FalseConst = GR->getOrCreateConsIntVector(0, MIRBuilder, ReturnType);
TrueConst =
GR->getOrCreateConsIntVector(AllOnes, MIRBuilder, ReturnType, true);
FalseConst = GR->getOrCreateConsIntVector(0, MIRBuilder, ReturnType, true);
} else {
TrueConst = GR->buildConstantInt(1, MIRBuilder, ReturnType);
FalseConst = GR->buildConstantInt(0, MIRBuilder, ReturnType);
TrueConst = GR->buildConstantInt(1, MIRBuilder, ReturnType, true);
FalseConst = GR->buildConstantInt(0, MIRBuilder, ReturnType, true);
}

return MIRBuilder.buildSelect(ReturnRegister, SourceRegister, TrueConst,
Expand Down Expand Up @@ -580,8 +581,8 @@ static SPIRV::Scope::Scope getSPIRVScope(SPIRV::CLMemoryScope ClScope) {
static Register buildConstantIntReg32(uint64_t Val,
MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry *GR) {
return GR->buildConstantInt(Val, MIRBuilder,
GR->getOrCreateSPIRVIntegerType(32, MIRBuilder));
return GR->buildConstantInt(
Val, MIRBuilder, GR->getOrCreateSPIRVIntegerType(32, MIRBuilder), true);
}

static Register buildScopeReg(Register CLScopeRegister,
Expand Down Expand Up @@ -1152,7 +1153,7 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,

Register Arg0;
if (GroupBuiltin->HasBoolArg) {
SPIRVType *BoolType = GR->getOrCreateSPIRVBoolType(MIRBuilder);
SPIRVType *BoolType = GR->getOrCreateSPIRVBoolType(MIRBuilder, true);
Register BoolReg = Call->Arguments[0];
SPIRVType *BoolRegType = GR->getSPIRVTypeForVReg(BoolReg);
if (!BoolRegType)
Expand All @@ -1161,14 +1162,15 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,
if (ArgInstruction->getOpcode() == TargetOpcode::G_CONSTANT) {
if (BoolRegType->getOpcode() != SPIRV::OpTypeBool)
Arg0 = GR->buildConstantInt(getIConstVal(BoolReg, MRI), MIRBuilder,
BoolType);
BoolType, true);
} else {
if (BoolRegType->getOpcode() == SPIRV::OpTypeInt) {
Arg0 = MRI->createGenericVirtualRegister(LLT::scalar(1));
MRI->setRegClass(Arg0, &SPIRV::iIDRegClass);
GR->assignSPIRVTypeToVReg(BoolType, Arg0, MIRBuilder.getMF());
MIRBuilder.buildICmp(CmpInst::ICMP_NE, Arg0, BoolReg,
GR->buildConstantInt(0, MIRBuilder, BoolRegType));
MIRBuilder.buildICmp(
CmpInst::ICMP_NE, Arg0, BoolReg,
GR->buildConstantInt(0, MIRBuilder, BoolRegType, true));
insertAssignInstr(Arg0, nullptr, BoolType, GR, MIRBuilder,
MIRBuilder.getMF().getRegInfo());
} else if (BoolRegType->getOpcode() != SPIRV::OpTypeBool) {
Expand Down Expand Up @@ -1213,7 +1215,7 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,
LLT::fixed_vector(VecLen, MRI->getType(ElemReg)));
MRI->setRegClass(VecReg, &SPIRV::vIDRegClass);
SPIRVType *VecType =
GR->getOrCreateSPIRVVectorType(ElemType, VecLen, MIRBuilder);
GR->getOrCreateSPIRVVectorType(ElemType, VecLen, MIRBuilder, true);
GR->assignSPIRVTypeToVReg(VecType, VecReg, MIRBuilder.getMF());
auto MIB =
MIRBuilder.buildInstr(TargetOpcode::G_BUILD_VECTOR).addDef(VecReg);
Expand Down Expand Up @@ -1457,11 +1459,11 @@ static bool genWorkgroupQuery(const SPIRV::IncomingCall *Call,
ToTruncate = DefaultReg;
}
auto NewRegister =
GR->buildConstantInt(DefaultValue, MIRBuilder, PointerSizeType);
GR->buildConstantInt(DefaultValue, MIRBuilder, PointerSizeType, true);
MIRBuilder.buildCopy(DefaultReg, NewRegister);
} else { // If it could be in range, we need to load from the given builtin.
auto Vec3Ty =
GR->getOrCreateSPIRVVectorType(PointerSizeType, 3, MIRBuilder);
GR->getOrCreateSPIRVVectorType(PointerSizeType, 3, MIRBuilder, true);
Register LoadedVector =
buildBuiltinVariableLoad(MIRBuilder, Vec3Ty, GR, BuiltinValue,
LLT::fixed_vector(3, PointerSize));
Expand All @@ -1484,21 +1486,22 @@ static bool genWorkgroupQuery(const SPIRV::IncomingCall *Call,
*MRI);

auto IndexType = GR->getSPIRVTypeForVReg(IndexRegister);
auto BoolType = GR->getOrCreateSPIRVBoolType(MIRBuilder);
auto BoolType = GR->getOrCreateSPIRVBoolType(MIRBuilder, true);

Register CompareRegister =
MRI->createGenericVirtualRegister(LLT::scalar(1));
MRI->setRegClass(CompareRegister, &SPIRV::iIDRegClass);
GR->assignSPIRVTypeToVReg(BoolType, CompareRegister, MIRBuilder.getMF());

// Use G_ICMP to check if idxVReg < 3.
MIRBuilder.buildICmp(CmpInst::ICMP_ULT, CompareRegister, IndexRegister,
GR->buildConstantInt(3, MIRBuilder, IndexType));
MIRBuilder.buildICmp(
CmpInst::ICMP_ULT, CompareRegister, IndexRegister,
GR->buildConstantInt(3, MIRBuilder, IndexType, true));

// Get constant for the default value (0 or 1 depending on which
// function).
Register DefaultRegister =
GR->buildConstantInt(DefaultValue, MIRBuilder, PointerSizeType);
GR->buildConstantInt(DefaultValue, MIRBuilder, PointerSizeType, true);

// Get a register for the selection result (possibly a new temporary one).
Register SelectionResult = Call->ReturnRegister;
Expand Down Expand Up @@ -1812,7 +1815,7 @@ static bool generateImageSizeQueryInst(const SPIRV::IncomingCall *Call,
MIRBuilder.getMRI()->setRegClass(QueryResult, &SPIRV::vIDRegClass);
SPIRVType *IntTy = GR->getOrCreateSPIRVIntegerType(32, MIRBuilder);
QueryResultType = GR->getOrCreateSPIRVVectorType(
IntTy, NumActualRetComponents, MIRBuilder);
IntTy, NumActualRetComponents, MIRBuilder, true);
GR->assignSPIRVTypeToVReg(QueryResultType, QueryResult, MIRBuilder.getMF());
}
bool IsDimBuf = ImgType->getOperand(2).getImm() == SPIRV::Dim::DIM_Buffer;
Expand Down Expand Up @@ -1969,7 +1972,7 @@ static bool generateReadImageInst(const StringRef DemangledCall,

if (Call->ReturnType->getOpcode() != SPIRV::OpTypeVector) {
SPIRVType *TempType =
GR->getOrCreateSPIRVVectorType(Call->ReturnType, 4, MIRBuilder);
GR->getOrCreateSPIRVVectorType(Call->ReturnType, 4, MIRBuilder, true);
Register TempRegister =
MRI->createGenericVirtualRegister(GR->getRegType(TempType));
MRI->setRegClass(TempRegister, GR->getRegClass(TempType));
Expand Down Expand Up @@ -2067,7 +2070,7 @@ static bool generateSampleImageInst(const StringRef DemangledCall,
SPIRVType *Type =
Call->ReturnType
? Call->ReturnType
: GR->getOrCreateSPIRVTypeByName(ReturnType, MIRBuilder);
: GR->getOrCreateSPIRVTypeByName(ReturnType, MIRBuilder, true);
if (!Type) {
std::string DiagMsg =
"Unable to recognize SPIRV type name: " + ReturnType;
Expand Down Expand Up @@ -2265,7 +2268,8 @@ static bool buildNDRange(const SPIRV::IncomingCall *Call,
unsigned BitWidth = GR->getPointerSize() == 64 ? 64 : 32;
Type *BaseTy = IntegerType::get(MF.getFunction().getContext(), BitWidth);
Type *FieldTy = ArrayType::get(BaseTy, Size);
SPIRVType *SpvFieldTy = GR->getOrCreateSPIRVType(FieldTy, MIRBuilder);
SPIRVType *SpvFieldTy = GR->getOrCreateSPIRVType(
FieldTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
GlobalWorkSize = MRI->createVirtualRegister(&SPIRV::iIDRegClass);
GR->assignSPIRVTypeToVReg(SpvFieldTy, GlobalWorkSize, MF);
MIRBuilder.buildInstr(SPIRV::OpLoad)
Expand All @@ -2277,7 +2281,7 @@ static bool buildNDRange(const SPIRV::IncomingCall *Call,
Const = GR->getOrCreateConstIntArray(0, Size, *MIRBuilder.getInsertPt(),
SpvFieldTy, *ST.getInstrInfo());
} else {
Const = GR->buildConstantInt(0, MIRBuilder, SpvTy);
Const = GR->buildConstantInt(0, MIRBuilder, SpvTy, true);
}
if (!LocalWorkSize.isValid())
LocalWorkSize = Const;
Expand All @@ -2303,7 +2307,8 @@ getOrCreateSPIRVDeviceEventPointer(MachineIRBuilder &MIRBuilder,
LLVMContext &Context = MIRBuilder.getMF().getFunction().getContext();
unsigned SC1 = storageClassToAddressSpace(SPIRV::StorageClass::Generic);
Type *PtrType = PointerType::get(Context, SC1);
return GR->getOrCreateSPIRVType(PtrType, MIRBuilder);
return GR->getOrCreateSPIRVType(PtrType, MIRBuilder,
SPIRV::AccessQualifier::ReadWrite, true);
}

static bool buildEnqueueKernel(const SPIRV::IncomingCall *Call,
Expand Down Expand Up @@ -2452,7 +2457,7 @@ static bool generateAsyncCopy(const SPIRV::IncomingCall *Call,
SPIRVType *NewType =
Call->ReturnType->getOpcode() == SPIRV::OpTypeEvent
? nullptr
: GR->getOrCreateSPIRVTypeByName("spirv.Event", MIRBuilder);
: GR->getOrCreateSPIRVTypeByName("spirv.Event", MIRBuilder, true);
Register TypeReg = GR->getSPIRVTypeID(NewType ? NewType : Call->ReturnType);
unsigned NumArgs = Call->Arguments.size();
Register EventReg = Call->Arguments[NumArgs - 1];
Expand Down Expand Up @@ -2953,12 +2958,13 @@ static SPIRVType *getCoopMatrType(const TargetExtType *ExtensionType,
assert(ExtensionType->getNumTypeParameters() == 1 &&
"SPIR-V coop matrices builtin type must have a type parameter!");
const SPIRVType *ElemType =
GR->getOrCreateSPIRVType(ExtensionType->getTypeParameter(0), MIRBuilder);
GR->getOrCreateSPIRVType(ExtensionType->getTypeParameter(0), MIRBuilder,
SPIRV::AccessQualifier::ReadWrite, true);
// Create or get an existing type from GlobalRegistry.
return GR->getOrCreateOpTypeCoopMatr(
MIRBuilder, ExtensionType, ElemType, ExtensionType->getIntParameter(0),
ExtensionType->getIntParameter(1), ExtensionType->getIntParameter(2),
ExtensionType->getIntParameter(3));
ExtensionType->getIntParameter(3), true);
}

static SPIRVType *
Expand All @@ -2968,7 +2974,8 @@ getImageType(const TargetExtType *ExtensionType,
assert(ExtensionType->getNumTypeParameters() == 1 &&
"SPIR-V image builtin type must have sampled type parameter!");
const SPIRVType *SampledType =
GR->getOrCreateSPIRVType(ExtensionType->getTypeParameter(0), MIRBuilder);
GR->getOrCreateSPIRVType(ExtensionType->getTypeParameter(0), MIRBuilder,
SPIRV::AccessQualifier::ReadWrite, true);
assert((ExtensionType->getNumIntParameters() == 7 ||
ExtensionType->getNumIntParameters() == 6) &&
"Invalid number of parameters for SPIR-V image builtin!");
Expand Down
38 changes: 25 additions & 13 deletions llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,13 +209,15 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
// If OriginalArgType is non-pointer, use the OriginalArgType (the type cannot
// be legally reassigned later).
if (!isPointerTy(OriginalArgType))
return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual,
true);

Argument *Arg = F.getArg(ArgIdx);
Type *ArgType = Arg->getType();
if (isTypedPointerTy(ArgType)) {
SPIRVType *ElementType = GR->getOrCreateSPIRVType(
cast<TypedPointerType>(ArgType)->getElementType(), MIRBuilder);
cast<TypedPointerType>(ArgType)->getElementType(), MIRBuilder,
SPIRV::AccessQualifier::ReadWrite, true);
return GR->getOrCreateSPIRVPointerType(
ElementType, MIRBuilder,
addressSpaceToStorageClass(getPointerAddressSpace(ArgType), ST));
Expand All @@ -231,7 +233,8 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
// type.
if (hasPointeeTypeAttr(Arg)) {
SPIRVType *ElementType =
GR->getOrCreateSPIRVType(getPointeeTypeByAttr(Arg), MIRBuilder);
GR->getOrCreateSPIRVType(getPointeeTypeByAttr(Arg), MIRBuilder,
SPIRV::AccessQualifier::ReadWrite, true);
return GR->getOrCreateSPIRVPointerType(
ElementType, MIRBuilder,
addressSpaceToStorageClass(getPointerAddressSpace(ArgType), ST));
Expand All @@ -245,7 +248,8 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
Type *BuiltinType =
cast<ConstantAsMetadata>(VMD->getMetadata())->getType();
assert(BuiltinType->isTargetExtTy() && "Expected TargetExtType");
return GR->getOrCreateSPIRVType(BuiltinType, MIRBuilder, ArgAccessQual);
return GR->getOrCreateSPIRVType(BuiltinType, MIRBuilder, ArgAccessQual,
true);
}

// Check if this is spv_assign_ptr_type assigning pointer element type.
Expand All @@ -255,7 +259,8 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
MetadataAsValue *VMD = cast<MetadataAsValue>(II->getOperand(1));
Type *ElementTy =
toTypedPointer(cast<ConstantAsMetadata>(VMD->getMetadata())->getType());
SPIRVType *ElementType = GR->getOrCreateSPIRVType(ElementTy, MIRBuilder);
SPIRVType *ElementType = GR->getOrCreateSPIRVType(
ElementTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
return GR->getOrCreateSPIRVPointerType(
ElementType, MIRBuilder,
addressSpaceToStorageClass(
Expand All @@ -265,7 +270,7 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
// Replace PointerType with TypedPointerType to be able to map SPIR-V types to
// LLVM types in a consistent manner
return GR->getOrCreateSPIRVType(toTypedPointer(OriginalArgType), MIRBuilder,
ArgAccessQual);
ArgAccessQual, true);
}

static SPIRV::ExecutionModel::ExecutionModel
Expand Down Expand Up @@ -405,7 +410,8 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
FRetTy = DerivedTy;
}
}
SPIRVType *RetTy = GR->getOrCreateSPIRVType(FRetTy, MIRBuilder);
SPIRVType *RetTy = GR->getOrCreateSPIRVType(
FRetTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
FTy = fixFunctionTypeIfPtrArgs(GR, F, FTy, RetTy, ArgTypeVRegs);
SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
FTy, RetTy, ArgTypeVRegs, MIRBuilder);
Expand Down Expand Up @@ -486,10 +492,12 @@ void SPIRVCallLowering::produceIndirectPtrTypes(
// Create indirect call data types if any
MachineFunction &MF = MIRBuilder.getMF();
for (auto const &IC : IndirectCalls) {
SPIRVType *SpirvRetTy = GR->getOrCreateSPIRVType(IC.RetTy, MIRBuilder);
SPIRVType *SpirvRetTy = GR->getOrCreateSPIRVType(
IC.RetTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
SmallVector<SPIRVType *, 4> SpirvArgTypes;
for (size_t i = 0; i < IC.ArgTys.size(); ++i) {
SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(IC.ArgTys[i], MIRBuilder);
SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(
IC.ArgTys[i], MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
SpirvArgTypes.push_back(SPIRVTy);
if (!GR->getSPIRVTypeForVReg(IC.ArgRegs[i]))
GR->assignSPIRVTypeToVReg(SPIRVTy, IC.ArgRegs[i], MF);
Expand Down Expand Up @@ -557,10 +565,12 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
RetTy =
TypedPointerType::get(ElemTy, PtrRetTy->getAddressSpace());
}
setRegClassType(ResVReg, RetTy, GR, MIRBuilder);
setRegClassType(ResVReg, RetTy, GR, MIRBuilder,
SPIRV::AccessQualifier::ReadWrite, true);
}
} else {
ResVReg = createVirtualRegister(OrigRetTy, GR, MIRBuilder);
ResVReg = createVirtualRegister(OrigRetTy, GR, MIRBuilder,
SPIRV::AccessQualifier::ReadWrite, true);
}
SmallVector<Register, 8> ArgVRegs;
for (auto Arg : Info.OrigArgs) {
Expand All @@ -584,7 +594,8 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
ArgTy = Arg.Ty;
}
if (ArgTy) {
SpvType = GR->getOrCreateSPIRVType(ArgTy, MIRBuilder);
SpvType = GR->getOrCreateSPIRVType(
ArgTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
GR->assignSPIRVTypeToVReg(SpvType, ArgReg, MF);
}
}
Expand Down Expand Up @@ -669,7 +680,8 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
// Make sure there's a valid return reg, even for functions returning void.
if (!ResVReg.isValid())
ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::iIDRegClass);
SPIRVType *RetType = GR->assignTypeToVReg(OrigRetTy, ResVReg, MIRBuilder);
SPIRVType *RetType = GR->assignTypeToVReg(
OrigRetTy, ResVReg, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);

// Emit the call instruction and its args.
auto MIB = MIRBuilder.buildInstr(CallOp)
Expand Down
6 changes: 4 additions & 2 deletions llvm/lib/Target/SPIRV/SPIRVEmitNonSemanticDI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,8 @@ bool SPIRVEmitNonSemanticDI::emitGlobalDI(MachineFunction &MF) {
};

const SPIRVType *VoidTy =
GR->getOrCreateSPIRVType(Type::getVoidTy(*Context), MIRBuilder);
GR->getOrCreateSPIRVType(Type::getVoidTy(*Context), MIRBuilder,
SPIRV::AccessQualifier::ReadWrite, false);

const auto EmitDIInstruction =
[&](SPIRV::NonSemanticExtInst::NonSemanticExtInst Inst,
Expand All @@ -217,7 +218,8 @@ bool SPIRVEmitNonSemanticDI::emitGlobalDI(MachineFunction &MF) {
};

const SPIRVType *I32Ty =
GR->getOrCreateSPIRVType(Type::getInt32Ty(*Context), MIRBuilder);
GR->getOrCreateSPIRVType(Type::getInt32Ty(*Context), MIRBuilder,
SPIRV::AccessQualifier::ReadWrite, false);

const Register DwarfVersionReg =
GR->buildConstantInt(DwarfVersion, MIRBuilder, I32Ty, false);
Expand Down
Loading
Loading