@@ -38,6 +38,83 @@ void RegBankLegalizeHelper::findRuleAndApplyMapping(MachineInstr &MI) {
38
38
lower (MI, Mapping, WaterfallSgprs);
39
39
}
40
40
41
+ void RegBankLegalizeHelper::splitLoad (MachineInstr &MI,
42
+ ArrayRef<LLT> LLTBreakdown, LLT MergeTy) {
43
+ MachineFunction &MF = B.getMF ();
44
+ assert (MI.getNumMemOperands () == 1 );
45
+ MachineMemOperand &BaseMMO = **MI.memoperands_begin ();
46
+ Register Dst = MI.getOperand (0 ).getReg ();
47
+ const RegisterBank *DstRB = MRI.getRegBankOrNull (Dst);
48
+ Register Base = MI.getOperand (1 ).getReg ();
49
+ LLT PtrTy = MRI.getType (Base);
50
+ const RegisterBank *PtrRB = MRI.getRegBankOrNull (Base);
51
+ LLT OffsetTy = LLT::scalar (PtrTy.getSizeInBits ());
52
+ SmallVector<Register, 4 > LoadPartRegs;
53
+
54
+ unsigned ByteOffset = 0 ;
55
+ for (LLT PartTy : LLTBreakdown) {
56
+ Register BasePlusOffset;
57
+ if (ByteOffset == 0 ) {
58
+ BasePlusOffset = Base;
59
+ } else {
60
+ auto Offset = B.buildConstant ({PtrRB, OffsetTy}, ByteOffset);
61
+ BasePlusOffset = B.buildPtrAdd ({PtrRB, PtrTy}, Base, Offset).getReg (0 );
62
+ }
63
+ auto *OffsetMMO = MF.getMachineMemOperand (&BaseMMO, ByteOffset, PartTy);
64
+ auto LoadPart = B.buildLoad ({DstRB, PartTy}, BasePlusOffset, *OffsetMMO);
65
+ LoadPartRegs.push_back (LoadPart.getReg (0 ));
66
+ ByteOffset += PartTy.getSizeInBytes ();
67
+ }
68
+
69
+ if (!MergeTy.isValid ()) {
70
+ // Loads are of same size, concat or merge them together.
71
+ B.buildMergeLikeInstr (Dst, LoadPartRegs);
72
+ } else {
73
+ // Loads are not all of same size, need to unmerge them to smaller pieces
74
+ // of MergeTy type, then merge pieces to Dst.
75
+ SmallVector<Register, 4 > MergeTyParts;
76
+ for (Register Reg : LoadPartRegs) {
77
+ if (MRI.getType (Reg) == MergeTy) {
78
+ MergeTyParts.push_back (Reg);
79
+ } else {
80
+ auto Unmerge = B.buildUnmerge ({DstRB, MergeTy}, Reg);
81
+ for (unsigned i = 0 ; i < Unmerge->getNumOperands () - 1 ; ++i)
82
+ MergeTyParts.push_back (Unmerge.getReg (i));
83
+ }
84
+ }
85
+ B.buildMergeLikeInstr (Dst, MergeTyParts);
86
+ }
87
+ MI.eraseFromParent ();
88
+ }
89
+
90
+ void RegBankLegalizeHelper::widenLoad (MachineInstr &MI, LLT WideTy,
91
+ LLT MergeTy) {
92
+ MachineFunction &MF = B.getMF ();
93
+ assert (MI.getNumMemOperands () == 1 );
94
+ MachineMemOperand &BaseMMO = **MI.memoperands_begin ();
95
+ Register Dst = MI.getOperand (0 ).getReg ();
96
+ const RegisterBank *DstRB = MRI.getRegBankOrNull (Dst);
97
+ Register Base = MI.getOperand (1 ).getReg ();
98
+
99
+ MachineMemOperand *WideMMO = MF.getMachineMemOperand (&BaseMMO, 0 , WideTy);
100
+ auto WideLoad = B.buildLoad ({DstRB, WideTy}, Base, *WideMMO);
101
+
102
+ if (WideTy.isScalar ()) {
103
+ B.buildTrunc (Dst, WideLoad);
104
+ } else {
105
+ SmallVector<Register, 4 > MergeTyParts;
106
+ auto Unmerge = B.buildUnmerge ({DstRB, MergeTy}, WideLoad);
107
+
108
+ LLT DstTy = MRI.getType (Dst);
109
+ unsigned NumElts = DstTy.getSizeInBits () / MergeTy.getSizeInBits ();
110
+ for (unsigned i = 0 ; i < NumElts; ++i) {
111
+ MergeTyParts.push_back (Unmerge.getReg (i));
112
+ }
113
+ B.buildMergeLikeInstr (Dst, MergeTyParts);
114
+ }
115
+ MI.eraseFromParent ();
116
+ }
117
+
41
118
void RegBankLegalizeHelper::lower (MachineInstr &MI,
42
119
const RegBankLLTMapping &Mapping,
43
120
SmallSet<Register, 4 > &WaterfallSgprs) {
@@ -116,6 +193,54 @@ void RegBankLegalizeHelper::lower(MachineInstr &MI,
116
193
MI.eraseFromParent ();
117
194
break ;
118
195
}
196
+ case SplitLoad: {
197
+ LLT DstTy = MRI.getType (MI.getOperand (0 ).getReg ());
198
+ unsigned Size = DstTy.getSizeInBits ();
199
+ // Even split to 128-bit loads
200
+ if (Size > 128 ) {
201
+ LLT B128;
202
+ if (DstTy.isVector ()) {
203
+ LLT EltTy = DstTy.getElementType ();
204
+ B128 = LLT::fixed_vector (128 / EltTy.getSizeInBits (), EltTy);
205
+ } else {
206
+ B128 = LLT::scalar (128 );
207
+ }
208
+ if (Size / 128 == 2 )
209
+ splitLoad (MI, {B128, B128});
210
+ else if (Size / 128 == 4 )
211
+ splitLoad (MI, {B128, B128, B128, B128});
212
+ else {
213
+ LLVM_DEBUG (dbgs () << " MI: " ; MI.dump (););
214
+ llvm_unreachable (" SplitLoad type not supported for MI" );
215
+ }
216
+ }
217
+ // 64 and 32 bit load
218
+ else if (DstTy == S96)
219
+ splitLoad (MI, {S64, S32}, S32);
220
+ else if (DstTy == V3S32)
221
+ splitLoad (MI, {V2S32, S32}, S32);
222
+ else if (DstTy == V6S16)
223
+ splitLoad (MI, {V4S16, V2S16}, V2S16);
224
+ else {
225
+ LLVM_DEBUG (dbgs () << " MI: " ; MI.dump (););
226
+ llvm_unreachable (" SplitLoad type not supported for MI" );
227
+ }
228
+ break ;
229
+ }
230
+ case WidenLoad: {
231
+ LLT DstTy = MRI.getType (MI.getOperand (0 ).getReg ());
232
+ if (DstTy == S96)
233
+ widenLoad (MI, S128);
234
+ else if (DstTy == V3S32)
235
+ widenLoad (MI, V4S32, S32);
236
+ else if (DstTy == V6S16)
237
+ widenLoad (MI, V8S16, V2S16);
238
+ else {
239
+ LLVM_DEBUG (dbgs () << " MI: " ; MI.dump (););
240
+ llvm_unreachable (" WidenLoad type not supported for MI" );
241
+ }
242
+ break ;
243
+ }
119
244
}
120
245
121
246
// TODO: executeInWaterfallLoop(... WaterfallSgprs)
@@ -139,12 +264,73 @@ LLT RegBankLegalizeHelper::getTyFromID(RegBankLLTMappingApplyID ID) {
139
264
case Sgpr64:
140
265
case Vgpr64:
141
266
return LLT::scalar (64 );
267
+ case SgprP1:
268
+ case VgprP1:
269
+ return LLT::pointer (1 , 64 );
270
+ case SgprP3:
271
+ case VgprP3:
272
+ return LLT::pointer (3 , 32 );
273
+ case SgprP4:
274
+ case VgprP4:
275
+ return LLT::pointer (4 , 64 );
276
+ case SgprP5:
277
+ case VgprP5:
278
+ return LLT::pointer (5 , 32 );
142
279
case SgprV4S32:
143
280
case VgprV4S32:
144
281
case UniInVgprV4S32:
145
282
return LLT::fixed_vector (4 , 32 );
146
- case VgprP1:
147
- return LLT::pointer (1 , 64 );
283
+ default :
284
+ return LLT ();
285
+ }
286
+ }
287
+
288
+ LLT RegBankLegalizeHelper::getBTyFromID (RegBankLLTMappingApplyID ID, LLT Ty) {
289
+ switch (ID) {
290
+ case SgprB32:
291
+ case VgprB32:
292
+ case UniInVgprB32:
293
+ if (Ty == LLT::scalar (32 ) || Ty == LLT::fixed_vector (2 , 16 ) ||
294
+ Ty == LLT::pointer (3 , 32 ) || Ty == LLT::pointer (5 , 32 ) ||
295
+ Ty == LLT::pointer (6 , 32 ))
296
+ return Ty;
297
+ return LLT ();
298
+ case SgprB64:
299
+ case VgprB64:
300
+ case UniInVgprB64:
301
+ if (Ty == LLT::scalar (64 ) || Ty == LLT::fixed_vector (2 , 32 ) ||
302
+ Ty == LLT::fixed_vector (4 , 16 ) || Ty == LLT::pointer (0 , 64 ) ||
303
+ Ty == LLT::pointer (1 , 64 ) || Ty == LLT::pointer (4 , 64 ))
304
+ return Ty;
305
+ return LLT ();
306
+ case SgprB96:
307
+ case VgprB96:
308
+ case UniInVgprB96:
309
+ if (Ty == LLT::scalar (96 ) || Ty == LLT::fixed_vector (3 , 32 ) ||
310
+ Ty == LLT::fixed_vector (6 , 16 ))
311
+ return Ty;
312
+ return LLT ();
313
+ case SgprB128:
314
+ case VgprB128:
315
+ case UniInVgprB128:
316
+ if (Ty == LLT::scalar (128 ) || Ty == LLT::fixed_vector (4 , 32 ) ||
317
+ Ty == LLT::fixed_vector (2 , 64 ))
318
+ return Ty;
319
+ return LLT ();
320
+ case SgprB256:
321
+ case VgprB256:
322
+ case UniInVgprB256:
323
+ if (Ty == LLT::scalar (256 ) || Ty == LLT::fixed_vector (8 , 32 ) ||
324
+ Ty == LLT::fixed_vector (4 , 64 ) || Ty == LLT::fixed_vector (16 , 16 ))
325
+ return Ty;
326
+ return LLT ();
327
+ case SgprB512:
328
+ case VgprB512:
329
+ case UniInVgprB512:
330
+ if (Ty == LLT::scalar (512 ) || Ty == LLT::fixed_vector (16 , 32 ) ||
331
+ Ty == LLT::fixed_vector (8 , 64 ))
332
+ return Ty;
333
+ return LLT ();
148
334
default :
149
335
return LLT ();
150
336
}
@@ -158,10 +344,26 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
158
344
case Sgpr16:
159
345
case Sgpr32:
160
346
case Sgpr64:
347
+ case SgprP1:
348
+ case SgprP3:
349
+ case SgprP4:
350
+ case SgprP5:
161
351
case SgprV4S32:
352
+ case SgprB32:
353
+ case SgprB64:
354
+ case SgprB96:
355
+ case SgprB128:
356
+ case SgprB256:
357
+ case SgprB512:
162
358
case UniInVcc:
163
359
case UniInVgprS32:
164
360
case UniInVgprV4S32:
361
+ case UniInVgprB32:
362
+ case UniInVgprB64:
363
+ case UniInVgprB96:
364
+ case UniInVgprB128:
365
+ case UniInVgprB256:
366
+ case UniInVgprB512:
165
367
case Sgpr32Trunc:
166
368
case Sgpr32AExt:
167
369
case Sgpr32AExtBoolInReg:
@@ -170,7 +372,16 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
170
372
case Vgpr32:
171
373
case Vgpr64:
172
374
case VgprP1:
375
+ case VgprP3:
376
+ case VgprP4:
377
+ case VgprP5:
173
378
case VgprV4S32:
379
+ case VgprB32:
380
+ case VgprB64:
381
+ case VgprB96:
382
+ case VgprB128:
383
+ case VgprB256:
384
+ case VgprB512:
174
385
return VgprRB;
175
386
default :
176
387
return nullptr ;
@@ -195,16 +406,40 @@ void RegBankLegalizeHelper::applyMappingDst(
195
406
case Sgpr16:
196
407
case Sgpr32:
197
408
case Sgpr64:
409
+ case SgprP1:
410
+ case SgprP3:
411
+ case SgprP4:
412
+ case SgprP5:
198
413
case SgprV4S32:
199
414
case Vgpr32:
200
415
case Vgpr64:
201
416
case VgprP1:
417
+ case VgprP3:
418
+ case VgprP4:
419
+ case VgprP5:
202
420
case VgprV4S32: {
203
421
assert (Ty == getTyFromID (MethodIDs[OpIdx]));
204
422
assert (RB == getRegBankFromID (MethodIDs[OpIdx]));
205
423
break ;
206
424
}
207
- // uniform in vcc/vgpr: scalars and vectors
425
+ // sgpr and vgpr B-types
426
+ case SgprB32:
427
+ case SgprB64:
428
+ case SgprB96:
429
+ case SgprB128:
430
+ case SgprB256:
431
+ case SgprB512:
432
+ case VgprB32:
433
+ case VgprB64:
434
+ case VgprB96:
435
+ case VgprB128:
436
+ case VgprB256:
437
+ case VgprB512: {
438
+ assert (Ty == getBTyFromID (MethodIDs[OpIdx], Ty));
439
+ assert (RB == getRegBankFromID (MethodIDs[OpIdx]));
440
+ break ;
441
+ }
442
+ // uniform in vcc/vgpr: scalars, vectors and B-types
208
443
case UniInVcc: {
209
444
assert (Ty == S1);
210
445
assert (RB == SgprRB);
@@ -224,6 +459,19 @@ void RegBankLegalizeHelper::applyMappingDst(
224
459
buildReadAnyLane (B, Reg, NewVgprDst, RBI);
225
460
break ;
226
461
}
462
+ case UniInVgprB32:
463
+ case UniInVgprB64:
464
+ case UniInVgprB96:
465
+ case UniInVgprB128:
466
+ case UniInVgprB256:
467
+ case UniInVgprB512: {
468
+ assert (Ty == getBTyFromID (MethodIDs[OpIdx], Ty));
469
+ assert (RB == SgprRB);
470
+ Register NewVgprDst = MRI.createVirtualRegister ({VgprRB, Ty});
471
+ Op.setReg (NewVgprDst);
472
+ AMDGPU::buildReadAnyLane (B, Reg, NewVgprDst, RBI);
473
+ break ;
474
+ }
227
475
// sgpr trunc
228
476
case Sgpr32Trunc: {
229
477
assert (Ty.getSizeInBits () < 32 );
@@ -272,15 +520,33 @@ void RegBankLegalizeHelper::applyMappingSrc(
272
520
case Sgpr16:
273
521
case Sgpr32:
274
522
case Sgpr64:
523
+ case SgprP1:
524
+ case SgprP3:
525
+ case SgprP4:
526
+ case SgprP5:
275
527
case SgprV4S32: {
276
528
assert (Ty == getTyFromID (MethodIDs[i]));
277
529
assert (RB == getRegBankFromID (MethodIDs[i]));
278
530
break ;
279
531
}
532
+ // sgpr B-types
533
+ case SgprB32:
534
+ case SgprB64:
535
+ case SgprB96:
536
+ case SgprB128:
537
+ case SgprB256:
538
+ case SgprB512: {
539
+ assert (Ty == getBTyFromID (MethodIDs[i], Ty));
540
+ assert (RB == getRegBankFromID (MethodIDs[i]));
541
+ break ;
542
+ }
280
543
// vgpr scalars, pointers and vectors
281
544
case Vgpr32:
282
545
case Vgpr64:
283
546
case VgprP1:
547
+ case VgprP3:
548
+ case VgprP4:
549
+ case VgprP5:
284
550
case VgprV4S32: {
285
551
assert (Ty == getTyFromID (MethodIDs[i]));
286
552
if (RB != VgprRB) {
@@ -289,6 +555,20 @@ void RegBankLegalizeHelper::applyMappingSrc(
289
555
}
290
556
break ;
291
557
}
558
+ // vgpr B-types
559
+ case VgprB32:
560
+ case VgprB64:
561
+ case VgprB96:
562
+ case VgprB128:
563
+ case VgprB256:
564
+ case VgprB512: {
565
+ assert (Ty == getBTyFromID (MethodIDs[i], Ty));
566
+ if (RB != VgprRB) {
567
+ auto CopyToVgpr = B.buildCopy ({VgprRB, Ty}, Reg);
568
+ Op.setReg (CopyToVgpr.getReg (0 ));
569
+ }
570
+ break ;
571
+ }
292
572
// sgpr and vgpr scalars with extend
293
573
case Sgpr32AExt: {
294
574
// Note: this ext allows S1, and it is meant to be combined away.
@@ -361,7 +641,7 @@ void RegBankLegalizeHelper::applyMappingPHI(MachineInstr &MI) {
361
641
// We accept all types that can fit in some register class.
362
642
// Uniform G_PHIs have all sgpr registers.
363
643
// Divergent G_PHIs have vgpr dst but inputs can be sgpr or vgpr.
364
- if (Ty == LLT::scalar (32 )) {
644
+ if (Ty == LLT::scalar (32 ) || Ty == LLT::pointer ( 4 , 64 ) ) {
365
645
return ;
366
646
}
367
647
0 commit comments