@@ -81,73 +81,6 @@ class AMDGPULateCodeGenPrepare
  bool visitLoadInst(LoadInst &LI);
};

-using ValueToValueMap = DenseMap<const Value *, Value *>;
-
-class LiveRegOptimizer {
-private:
-  Module *Mod = nullptr;
-  const DataLayout *DL = nullptr;
-  const GCNSubtarget *ST;
-  /// The scalar type to convert to.
-  Type *ConvertToScalar;
-  /// The set of visited Instructions.
-  SmallPtrSet<Instruction *, 4> Visited;
-  /// The set of Instructions to be deleted.
-  SmallPtrSet<Instruction *, 4> DeadInstrs;
-  /// Map of Value -> Converted Value.
-  ValueToValueMap ValMap;
-  /// Map containing conversions from Optimal Type -> Original Type, per BB.
-  DenseMap<BasicBlock *, ValueToValueMap> BBUseValMap;
-
-public:
-  /// Calculate and \p return the type to convert to, given a problematic \p
-  /// OriginalType. In some instances we may widen the type (e.g. v2i8 -> i32).
-  Type *calculateConvertType(Type *OriginalType);
-  /// Convert the virtual register defined by \p V to the compatible vector of
-  /// legal type.
-  Value *convertToOptType(Instruction *V, BasicBlock::iterator &InstPt);
-  /// Convert the virtual register defined by \p V back to the original type \p
-  /// ConvertType, stripping away the MSBs in cases where there was an
-  /// imperfect fit (e.g. v2i32 -> v7i8).
-  Value *convertFromOptType(Type *ConvertType, Instruction *V,
-                            BasicBlock::iterator &InstPt,
-                            BasicBlock *InsertBlock);
-  /// Check for problematic PHI nodes or cross-BB values based on the value
-  /// defined by \p I, and coerce to legal types if necessary. For a
-  /// problematic PHI node, we coerce all incoming values in a single
-  /// invocation.
-  bool optimizeLiveType(Instruction *I);
-
-  /// Remove all instructions that have become dead (i.e. all the re-typed
-  /// PHIs).
-  void removeDeadInstrs();
-
-  // Whether or not the type should be replaced to avoid inefficient
-  // legalization code.
-  bool shouldReplace(Type *ITy) {
-    FixedVectorType *VTy = dyn_cast<FixedVectorType>(ITy);
-    if (!VTy)
-      return false;
-
-    auto TLI = ST->getTargetLowering();
-
-    Type *EltTy = VTy->getElementType();
-    // If the element size is not less than the convert-to-scalar size, we
-    // can't do any bit packing.
-    if (!EltTy->isIntegerTy() ||
-        EltTy->getScalarSizeInBits() > ConvertToScalar->getScalarSizeInBits())
-      return false;
-
-    // Only coerce illegal types.
-    TargetLoweringBase::LegalizeKind LK =
-        TLI->getTypeConversion(EltTy->getContext(), EVT::getEVT(EltTy, false));
-    return LK.first != TargetLoweringBase::TypeLegal;
-  }
-
-  LiveRegOptimizer(Module *Mod, const GCNSubtarget *ST) : Mod(Mod), ST(ST) {
-    DL = &Mod->getDataLayout();
-    ConvertToScalar = Type::getInt32Ty(Mod->getContext());
-  }
-};
-
} // end anonymous namespace

bool AMDGPULateCodeGenPrepare::doInitialization(Module &M) {
@@ -169,238 +102,14 @@ bool AMDGPULateCodeGenPrepare::runOnFunction(Function &F) {
  AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
  UA = &getAnalysis<UniformityInfoWrapperPass>().getUniformityInfo();

-  // "Optimize" the virtual regs that cross basic block boundaries. When
-  // building the SelectionDAG, vectors of illegal types that cross basic
-  // blocks will be scalarized and widened, with each scalar living in its
-  // own register. To work around this, the optimization converts the
-  // vectors to equivalent vectors of legal type (which are converted back
-  // before uses in subsequent blocks), to pack the bits into fewer physical
-  // registers (used in CopyToReg/CopyFromReg pairs).
-  LiveRegOptimizer LRO(Mod, &ST);
-
  bool Changed = false;
-
  for (auto &BB : F)
-    for (Instruction &I : make_early_inc_range(BB)) {
+    for (Instruction &I : llvm::make_early_inc_range(BB))
      Changed |= visit(I);
-      Changed |= LRO.optimizeLiveType(&I);
-    }

-  LRO.removeDeadInstrs();
  return Changed;
}

-Type *LiveRegOptimizer::calculateConvertType(Type *OriginalType) {
-  assert(OriginalType->getScalarSizeInBits() <=
-         ConvertToScalar->getScalarSizeInBits());
-
-  FixedVectorType *VTy = cast<FixedVectorType>(OriginalType);
-
-  TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
-  TypeSize ConvertScalarSize = DL->getTypeSizeInBits(ConvertToScalar);
-  unsigned ConvertEltCount =
-      (OriginalSize + ConvertScalarSize - 1) / ConvertScalarSize;
-
-  if (OriginalSize <= ConvertScalarSize)
-    return IntegerType::get(Mod->getContext(), ConvertScalarSize);
-
-  return VectorType::get(Type::getIntNTy(Mod->getContext(), ConvertScalarSize),
-                         ConvertEltCount, false);
-}
-
-Value *LiveRegOptimizer::convertToOptType(Instruction *V,
-                                          BasicBlock::iterator &InsertPt) {
-  FixedVectorType *VTy = cast<FixedVectorType>(V->getType());
-  Type *NewTy = calculateConvertType(V->getType());
-
-  TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
-  TypeSize NewSize = DL->getTypeSizeInBits(NewTy);
-
-  IRBuilder<> Builder(V->getParent(), InsertPt);
-  // If there is a bitsize match, we can fit the old vector into a new vector
-  // of the desired type.
-  if (OriginalSize == NewSize)
-    return Builder.CreateBitCast(V, NewTy, V->getName() + ".bc");
-
-  // If there is a bitsize mismatch, we must use a wider vector.
-  assert(NewSize > OriginalSize);
-  uint64_t ExpandedVecElementCount = NewSize / VTy->getScalarSizeInBits();
-
-  SmallVector<int, 8> ShuffleMask;
-  uint64_t OriginalElementCount = VTy->getElementCount().getFixedValue();
-  for (unsigned I = 0; I < OriginalElementCount; I++)
-    ShuffleMask.push_back(I);
-
-  for (uint64_t I = OriginalElementCount; I < ExpandedVecElementCount; I++)
-    ShuffleMask.push_back(OriginalElementCount);
-
-  Value *ExpandedVec = Builder.CreateShuffleVector(V, ShuffleMask);
-  return Builder.CreateBitCast(ExpandedVec, NewTy, V->getName() + ".bc");
-}
-
-Value *LiveRegOptimizer::convertFromOptType(Type *ConvertType, Instruction *V,
-                                            BasicBlock::iterator &InsertPt,
-                                            BasicBlock *InsertBB) {
-  FixedVectorType *NewVTy = cast<FixedVectorType>(ConvertType);
-
-  TypeSize OriginalSize = DL->getTypeSizeInBits(V->getType());
-  TypeSize NewSize = DL->getTypeSizeInBits(NewVTy);
-
-  IRBuilder<> Builder(InsertBB, InsertPt);
-  // If there is a bitsize match, we simply convert back to the original type.
-  if (OriginalSize == NewSize)
-    return Builder.CreateBitCast(V, NewVTy, V->getName() + ".bc");
-
-  // If there is a bitsize mismatch, then we must have used a wider value to
-  // hold the bits.
-  assert(OriginalSize > NewSize);
-  // For wide scalars, we can just truncate the value.
-  if (!V->getType()->isVectorTy()) {
-    Instruction *Trunc = cast<Instruction>(
-        Builder.CreateTrunc(V, IntegerType::get(Mod->getContext(), NewSize)));
-    return cast<Instruction>(Builder.CreateBitCast(Trunc, NewVTy));
-  }
-
-  // For wider vectors, we must strip the MSBs to convert back to the original
-  // type.
-  VectorType *ExpandedVT = VectorType::get(
-      Type::getIntNTy(Mod->getContext(), NewVTy->getScalarSizeInBits()),
-      (OriginalSize / NewVTy->getScalarSizeInBits()), false);
-  Instruction *Converted =
-      cast<Instruction>(Builder.CreateBitCast(V, ExpandedVT));
-
-  unsigned NarrowElementCount = NewVTy->getElementCount().getFixedValue();
-  SmallVector<int, 8> ShuffleMask(NarrowElementCount);
-  std::iota(ShuffleMask.begin(), ShuffleMask.end(), 0);
-
-  return Builder.CreateShuffleVector(Converted, ShuffleMask);
-}
-
-bool LiveRegOptimizer::optimizeLiveType(Instruction *I) {
-  SmallVector<Instruction *, 4> Worklist;
-  SmallPtrSet<PHINode *, 4> PhiNodes;
-  SmallPtrSet<Instruction *, 4> Defs;
-  SmallPtrSet<Instruction *, 4> Uses;
-
-  Worklist.push_back(cast<Instruction>(I));
-  while (!Worklist.empty()) {
-    Instruction *II = Worklist.pop_back_val();
-
-    if (!Visited.insert(II).second)
-      continue;
-
-    if (!shouldReplace(II->getType()))
-      continue;
-
-    if (PHINode *Phi = dyn_cast<PHINode>(II)) {
-      PhiNodes.insert(Phi);
-      // Collect all the incoming values of problematic PHI nodes.
-      for (Value *V : Phi->incoming_values()) {
-        // Repeat the collection process for newly found PHI nodes.
-        if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
-          if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
-            Worklist.push_back(OpPhi);
-          continue;
-        }
-
-        Instruction *IncInst = dyn_cast<Instruction>(V);
-        // Other incoming value types (e.g. vector literals) are unhandled.
-        if (!IncInst && !isa<ConstantAggregateZero>(V))
-          return false;
-
-        // Collect all other incoming values for coercion.
-        if (IncInst)
-          Defs.insert(IncInst);
-      }
-    }
-
-    // Collect all relevant uses.
-    for (User *V : II->users()) {
-      // Repeat the collection process for problematic PHI nodes.
-      if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
-        if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
-          Worklist.push_back(OpPhi);
-        continue;
-      }
-
-      Instruction *UseInst = cast<Instruction>(V);
-      // Collect all uses of PHI nodes and any use that crosses BB boundaries.
-      if (UseInst->getParent() != II->getParent() || isa<PHINode>(II)) {
-        Uses.insert(UseInst);
-        if (!Defs.count(II) && !isa<PHINode>(II)) {
-          Defs.insert(II);
-        }
-      }
-    }
-  }
-
-  // Coerce and track the defs.
-  for (Instruction *D : Defs) {
-    if (!ValMap.contains(D)) {
-      BasicBlock::iterator InsertPt = std::next(D->getIterator());
-      Value *ConvertVal = convertToOptType(D, InsertPt);
-      assert(ConvertVal);
-      ValMap[D] = ConvertVal;
-    }
-  }
-
-  // Construct new-typed PHI nodes.
-  for (PHINode *Phi : PhiNodes) {
-    ValMap[Phi] = PHINode::Create(calculateConvertType(Phi->getType()),
-                                  Phi->getNumIncomingValues(),
-                                  Phi->getName() + ".tc", Phi->getIterator());
-  }
-
-  // Connect all the PHI nodes with their new incoming values.
-  for (PHINode *Phi : PhiNodes) {
-    PHINode *NewPhi = cast<PHINode>(ValMap[Phi]);
-    bool MissingIncVal = false;
-    for (int I = 0, E = Phi->getNumIncomingValues(); I < E; I++) {
-      Value *IncVal = Phi->getIncomingValue(I);
-      if (isa<ConstantAggregateZero>(IncVal)) {
-        Type *NewType = calculateConvertType(Phi->getType());
-        NewPhi->addIncoming(ConstantInt::get(NewType, 0, false),
-                            Phi->getIncomingBlock(I));
-      } else if (ValMap.contains(IncVal))
-        NewPhi->addIncoming(ValMap[IncVal], Phi->getIncomingBlock(I));
-      else
-        MissingIncVal = true;
-    }
-    DeadInstrs.insert(MissingIncVal ? cast<Instruction>(ValMap[Phi]) : Phi);
-  }
-  // Coerce back to the original type and replace the uses.
-  for (Instruction *U : Uses) {
-    // Replace all converted operands for a use.
-    for (auto [OpIdx, Op] : enumerate(U->operands())) {
-      if (ValMap.contains(Op)) {
-        Value *NewVal = nullptr;
-        if (BBUseValMap.contains(U->getParent()) &&
-            BBUseValMap[U->getParent()].contains(ValMap[Op]))
-          NewVal = BBUseValMap[U->getParent()][ValMap[Op]];
-        else {
-          BasicBlock::iterator InsertPt = U->getParent()->getFirstNonPHIIt();
-          NewVal =
-              convertFromOptType(Op->getType(), cast<Instruction>(ValMap[Op]),
-                                 InsertPt, U->getParent());
-          BBUseValMap[U->getParent()][ValMap[Op]] = NewVal;
-        }
-        assert(NewVal);
-        U->setOperand(OpIdx, NewVal);
-      }
-    }
-  }
-
-  return true;
-}
-
-void LiveRegOptimizer::removeDeadInstrs() {
-  // Remove instrs that have been marked dead after type coercion.
-  for (auto *I : DeadInstrs) {
-    I->replaceAllUsesWith(PoisonValue::get(I->getType()));
-    I->eraseFromParent();
-  }
-}
-
bool AMDGPULateCodeGenPrepare::canWidenScalarExtLoad(LoadInst &LI) const {
  unsigned AS = LI.getPointerAddressSpace();
  // Skip non-constant address space.
@@ -410,7 +119,7 @@ bool AMDGPULateCodeGenPrepare::canWidenScalarExtLoad(LoadInst &LI) const {
  // Skip non-simple loads.
  if (!LI.isSimple())
    return false;
-  Type *Ty = LI.getType();
+  auto *Ty = LI.getType();
  // Skip aggregate types.
  if (Ty->isAggregateType())
    return false;
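
For context, a minimal LLVM IR sketch of the round trip the removed LiveRegOptimizer performed, using the v2i8 -> i32 widening mentioned in calculateConvertType's comment. The block and value names (%p, %v, %use) are hypothetical, and the instruction sequence is inferred from convertToOptType and convertFromOptType above rather than taken from actual pass output:

; Before: a <2 x i8> def in bb0 is live into bb1. The i8 element type is not
; legal, so SelectionDAG would scalarize and widen the vector across the
; block boundary, one virtual register per element.
bb0:
  %v = load <2 x i8>, ptr addrspace(1) %p
  br label %bb1
bb1:
  %use = extractelement <2 x i8> %v, i64 0

; After: convertToOptType packs the 16 live bits into an i32 just after the
; def (widening shufflevector, then bitcast); convertFromOptType unpacks
; them at the first non-PHI point of the using block (trunc, then bitcast).
bb0:
  %v = load <2 x i8>, ptr addrspace(1) %p
  %v.ext = shufflevector <2 x i8> %v, <2 x i8> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 2>
  %v.bc = bitcast <4 x i8> %v.ext to i32
  br label %bb1
bb1:
  %v.tr = trunc i32 %v.bc to i16
  %v.orig = bitcast i16 %v.tr to <2 x i8>
  %use = extractelement <2 x i8> %v.orig, i64 0

Only the single i32 now crosses the block boundary, so it needs one CopyToReg/CopyFromReg pair instead of one per widened scalar element.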