Skip to content

Commit 02a334d

Browse files
authored
[SPIR-V] Fix bad insertion for type/id MIR (#109686)
Those instructions were inserted either after the instruction using it, or in the middle of the module. The first directly causes an issue. The second causes a more subtle issue: the first type the type is inserted, the emission is fine, but the second times, the first instruction is reused, without checking its position in the function. This can lead to the second usage dominating the definition. In SPIR-V, types are usually in the header, above all code definition, but at this stage I don't think we can, so what I do instead is to emit it in the first basic block. This commit reduces the failed tests with expensive checks from 107 to 71. Signed-off-by: Nathan Gauër <[email protected]>
1 parent 57bee1e commit 02a334d

File tree

3 files changed

+98
-45
lines changed

3 files changed

+98
-45
lines changed

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp

Lines changed: 86 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "llvm/IR/Type.h"
2626
#include "llvm/Support/Casting.h"
2727
#include <cassert>
28+
#include <functional>
2829

2930
using namespace llvm;
3031
SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize)
@@ -83,8 +84,11 @@ inline Register createTypeVReg(MachineIRBuilder &MIRBuilder) {
8384
}
8485

8586
SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
86-
return MIRBuilder.buildInstr(SPIRV::OpTypeBool)
87-
.addDef(createTypeVReg(MIRBuilder));
87+
88+
return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
89+
return MIRBuilder.buildInstr(SPIRV::OpTypeBool)
90+
.addDef(createTypeVReg(MIRBuilder));
91+
});
8892
}
8993

9094
unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {
@@ -118,24 +122,53 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
118122
MIRBuilder.buildInstr(SPIRV::OpCapability)
119123
.addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL);
120124
}
121-
auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeInt)
122-
.addDef(createTypeVReg(MIRBuilder))
123-
.addImm(Width)
124-
.addImm(IsSigned ? 1 : 0);
125-
return MIB;
125+
return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
126+
return MIRBuilder.buildInstr(SPIRV::OpTypeInt)
127+
.addDef(createTypeVReg(MIRBuilder))
128+
.addImm(Width)
129+
.addImm(IsSigned ? 1 : 0);
130+
});
126131
}
127132

128133
SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
129134
MachineIRBuilder &MIRBuilder) {
130-
auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
131-
.addDef(createTypeVReg(MIRBuilder))
132-
.addImm(Width);
133-
return MIB;
135+
return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
136+
return MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
137+
.addDef(createTypeVReg(MIRBuilder))
138+
.addImm(Width);
139+
});
134140
}
135141

136142
SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) {
137-
return MIRBuilder.buildInstr(SPIRV::OpTypeVoid)
138-
.addDef(createTypeVReg(MIRBuilder));
143+
return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
144+
return MIRBuilder.buildInstr(SPIRV::OpTypeVoid)
145+
.addDef(createTypeVReg(MIRBuilder));
146+
});
147+
}
148+
149+
SPIRVType *SPIRVGlobalRegistry::createOpType(
150+
MachineIRBuilder &MIRBuilder,
151+
std::function<MachineInstr *(MachineIRBuilder &)> Op) {
152+
auto oldInsertPoint = MIRBuilder.getInsertPt();
153+
MachineBasicBlock *OldMBB = &MIRBuilder.getMBB();
154+
155+
auto LastInsertedType = LastInsertedTypeMap.find(CurMF);
156+
if (LastInsertedType != LastInsertedTypeMap.end()) {
157+
MIRBuilder.setInsertPt(*MIRBuilder.getMF().begin(),
158+
LastInsertedType->second->getIterator());
159+
} else {
160+
MIRBuilder.setInsertPt(*MIRBuilder.getMF().begin(),
161+
MIRBuilder.getMF().begin()->begin());
162+
auto Result = LastInsertedTypeMap.try_emplace(CurMF, nullptr);
163+
assert(Result.second);
164+
LastInsertedType = Result.first;
165+
}
166+
167+
MachineInstr *Type = Op(MIRBuilder);
168+
LastInsertedType->second = Type;
169+
170+
MIRBuilder.setInsertPt(*OldMBB, oldInsertPoint);
171+
return Type;
139172
}
140173

141174
SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems,
@@ -147,11 +180,12 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems,
147180
EleOpc == SPIRV::OpTypeBool) &&
148181
"Invalid vector element type");
149182

150-
auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeVector)
151-
.addDef(createTypeVReg(MIRBuilder))
152-
.addUse(getSPIRVTypeID(ElemType))
153-
.addImm(NumElems);
154-
return MIB;
183+
return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
184+
return MIRBuilder.buildInstr(SPIRV::OpTypeVector)
185+
.addDef(createTypeVReg(MIRBuilder))
186+
.addUse(getSPIRVTypeID(ElemType))
187+
.addImm(NumElems);
188+
});
155189
}
156190

157191
std::tuple<Register, ConstantInt *, bool, unsigned>
@@ -688,22 +722,25 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems,
688722
SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType(32, MIRBuilder);
689723
Register NumElementsVReg =
690724
buildConstantInt(NumElems, MIRBuilder, SpvTypeInt32, EmitIR);
691-
auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeArray)
692-
.addDef(createTypeVReg(MIRBuilder))
693-
.addUse(getSPIRVTypeID(ElemType))
694-
.addUse(NumElementsVReg);
695-
return MIB;
725+
return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
726+
return MIRBuilder.buildInstr(SPIRV::OpTypeArray)
727+
.addDef(createTypeVReg(MIRBuilder))
728+
.addUse(getSPIRVTypeID(ElemType))
729+
.addUse(NumElementsVReg);
730+
});
696731
}
697732

698733
SPIRVType *SPIRVGlobalRegistry::getOpTypeOpaque(const StructType *Ty,
699734
MachineIRBuilder &MIRBuilder) {
700735
assert(Ty->hasName());
701736
const StringRef Name = Ty->hasName() ? Ty->getName() : "";
702737
Register ResVReg = createTypeVReg(MIRBuilder);
703-
auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeOpaque).addDef(ResVReg);
704-
addStringImm(Name, MIB);
705-
buildOpName(ResVReg, Name, MIRBuilder);
706-
return MIB;
738+
return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
739+
auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeOpaque).addDef(ResVReg);
740+
addStringImm(Name, MIB);
741+
buildOpName(ResVReg, Name, MIRBuilder);
742+
return MIB;
743+
});
707744
}
708745

709746
SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty,
@@ -717,14 +754,16 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty,
717754
FieldTypes.push_back(getSPIRVTypeID(ElemTy));
718755
}
719756
Register ResVReg = createTypeVReg(MIRBuilder);
720-
auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeStruct).addDef(ResVReg);
721-
for (const auto &Ty : FieldTypes)
722-
MIB.addUse(Ty);
723-
if (Ty->hasName())
724-
buildOpName(ResVReg, Ty->getName(), MIRBuilder);
725-
if (Ty->isPacked())
726-
buildOpDecorate(ResVReg, MIRBuilder, SPIRV::Decoration::CPacked, {});
727-
return MIB;
757+
return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
758+
auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeStruct).addDef(ResVReg);
759+
for (const auto &Ty : FieldTypes)
760+
MIB.addUse(Ty);
761+
if (Ty->hasName())
762+
buildOpName(ResVReg, Ty->getName(), MIRBuilder);
763+
if (Ty->isPacked())
764+
buildOpDecorate(ResVReg, MIRBuilder, SPIRV::Decoration::CPacked, {});
765+
return MIB;
766+
});
728767
}
729768

730769
SPIRVType *SPIRVGlobalRegistry::getOrCreateSpecialType(
@@ -739,17 +778,22 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypePointer(
739778
MachineIRBuilder &MIRBuilder, Register Reg) {
740779
if (!Reg.isValid())
741780
Reg = createTypeVReg(MIRBuilder);
742-
return MIRBuilder.buildInstr(SPIRV::OpTypePointer)
743-
.addDef(Reg)
744-
.addImm(static_cast<uint32_t>(SC))
745-
.addUse(getSPIRVTypeID(ElemType));
781+
782+
return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
783+
return MIRBuilder.buildInstr(SPIRV::OpTypePointer)
784+
.addDef(Reg)
785+
.addImm(static_cast<uint32_t>(SC))
786+
.addUse(getSPIRVTypeID(ElemType));
787+
});
746788
}
747789

748790
SPIRVType *SPIRVGlobalRegistry::getOpTypeForwardPointer(
749791
SPIRV::StorageClass::StorageClass SC, MachineIRBuilder &MIRBuilder) {
750-
return MIRBuilder.buildInstr(SPIRV::OpTypeForwardPointer)
751-
.addUse(createTypeVReg(MIRBuilder))
752-
.addImm(static_cast<uint32_t>(SC));
792+
return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
793+
return MIRBuilder.buildInstr(SPIRV::OpTypeForwardPointer)
794+
.addUse(createTypeVReg(MIRBuilder))
795+
.addImm(static_cast<uint32_t>(SC));
796+
});
753797
}
754798

755799
SPIRVType *SPIRVGlobalRegistry::getOpTypeFunction(

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ class SPIRVGlobalRegistry {
6464
SmallPtrSet<const Type *, 4> TypesInProcessing;
6565
DenseMap<const Type *, SPIRVType *> ForwardPointerTypes;
6666

67+
// Stores for each function the last inserted SPIR-V Type.
68+
// See: SPIRVGlobalRegistry::createOpType.
69+
DenseMap<const MachineFunction *, MachineInstr *> LastInsertedTypeMap;
70+
6771
// if a function returns a pointer, this is to map it into TypedPointerType
6872
DenseMap<const Function *, TypedPointerType *> FunResPointerTypes;
6973

@@ -97,6 +101,13 @@ class SPIRVGlobalRegistry {
97101
SPIRV::AccessQualifier::AccessQualifier AccessQual,
98102
bool EmitIR);
99103

104+
// Internal function creating the an OpType at the correct position in the
105+
// function by tweaking the passed "MIRBuilder" insertion point and restoring
106+
// it to the correct position. "Op" should be the function creating the
107+
// specific OpType you need, and should return the newly created instruction.
108+
SPIRVType *createOpType(MachineIRBuilder &MIRBuilder,
109+
std::function<MachineInstr *(MachineIRBuilder &)> Op);
110+
100111
public:
101112
SPIRVGlobalRegistry(unsigned PointerSize);
102113

llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -389,9 +389,7 @@ void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
389389
createNewIdReg(nullptr, MI.getOperand(0).getReg(), MRI, *GR).first;
390390
AssignTypeInst.getOperand(1).setReg(NewReg);
391391
MI.getOperand(0).setReg(NewReg);
392-
MIB.setInsertPt(*MI.getParent(),
393-
(MI.getNextNode() ? MI.getNextNode()->getIterator()
394-
: MI.getParent()->end()));
392+
MIB.setInsertPt(*MI.getParent(), MI.getIterator());
395393
for (auto &Op : MI.operands()) {
396394
if (!Op.isReg() || Op.isDef())
397395
continue;

0 commit comments

Comments
 (0)