@@ -50,6 +50,83 @@ void RegBankLegalizeHelper::findRuleAndApplyMapping(MachineInstr &MI) {
50
50
lower (MI, Mapping, WaterfallSgprs);
51
51
}
52
52
53
+ void RegBankLegalizeHelper::splitLoad (MachineInstr &MI,
54
+ ArrayRef<LLT> LLTBreakdown, LLT MergeTy) {
55
+ MachineFunction &MF = B.getMF ();
56
+ assert (MI.getNumMemOperands () == 1 );
57
+ MachineMemOperand &BaseMMO = **MI.memoperands_begin ();
58
+ Register Dst = MI.getOperand (0 ).getReg ();
59
+ const RegisterBank *DstRB = MRI.getRegBankOrNull (Dst);
60
+ Register Base = MI.getOperand (1 ).getReg ();
61
+ LLT PtrTy = MRI.getType (Base);
62
+ const RegisterBank *PtrRB = MRI.getRegBankOrNull (Base);
63
+ LLT OffsetTy = LLT::scalar (PtrTy.getSizeInBits ());
64
+ SmallVector<Register, 4 > LoadPartRegs;
65
+
66
+ unsigned ByteOffset = 0 ;
67
+ for (LLT PartTy : LLTBreakdown) {
68
+ Register BasePlusOffset;
69
+ if (ByteOffset == 0 ) {
70
+ BasePlusOffset = Base;
71
+ } else {
72
+ auto Offset = B.buildConstant ({PtrRB, OffsetTy}, ByteOffset );
73
+ BasePlusOffset = B.buildPtrAdd ({PtrRB, PtrTy}, Base, Offset).getReg (0 );
74
+ }
75
+ auto *OffsetMMO = MF.getMachineMemOperand (&BaseMMO, ByteOffset , PartTy);
76
+ auto LoadPart = B.buildLoad ({DstRB, PartTy}, BasePlusOffset, *OffsetMMO);
77
+ LoadPartRegs.push_back (LoadPart.getReg (0 ));
78
+ ByteOffset += PartTy.getSizeInBytes ();
79
+ }
80
+
81
+ if (!MergeTy.isValid ()) {
82
+ // Loads are of same size, concat or merge them together.
83
+ B.buildMergeLikeInstr (Dst, LoadPartRegs);
84
+ } else {
85
+ // Loads are not all of same size, need to unmerge them to smaller pieces
86
+ // of MergeTy type, then merge pieces to Dst.
87
+ SmallVector<Register, 4 > MergeTyParts;
88
+ for (Register Reg : LoadPartRegs) {
89
+ if (MRI.getType (Reg) == MergeTy) {
90
+ MergeTyParts.push_back (Reg);
91
+ } else {
92
+ auto Unmerge = B.buildUnmerge ({DstRB, MergeTy}, Reg);
93
+ for (unsigned i = 0 ; i < Unmerge->getNumOperands () - 1 ; ++i)
94
+ MergeTyParts.push_back (Unmerge.getReg (i));
95
+ }
96
+ }
97
+ B.buildMergeLikeInstr (Dst, MergeTyParts);
98
+ }
99
+ MI.eraseFromParent ();
100
+ }
101
+
102
+ void RegBankLegalizeHelper::widenLoad (MachineInstr &MI, LLT WideTy,
103
+ LLT MergeTy) {
104
+ MachineFunction &MF = B.getMF ();
105
+ assert (MI.getNumMemOperands () == 1 );
106
+ MachineMemOperand &BaseMMO = **MI.memoperands_begin ();
107
+ Register Dst = MI.getOperand (0 ).getReg ();
108
+ const RegisterBank *DstRB = MRI.getRegBankOrNull (Dst);
109
+ Register Base = MI.getOperand (1 ).getReg ();
110
+
111
+ MachineMemOperand *WideMMO = MF.getMachineMemOperand (&BaseMMO, 0 , WideTy);
112
+ auto WideLoad = B.buildLoad ({DstRB, WideTy}, Base, *WideMMO);
113
+
114
+ if (WideTy.isScalar ()) {
115
+ B.buildTrunc (Dst, WideLoad);
116
+ } else {
117
+ SmallVector<Register, 4 > MergeTyParts;
118
+ auto Unmerge = B.buildUnmerge ({DstRB, MergeTy}, WideLoad);
119
+
120
+ LLT DstTy = MRI.getType (Dst);
121
+ unsigned NumElts = DstTy.getSizeInBits () / MergeTy.getSizeInBits ();
122
+ for (unsigned i = 0 ; i < NumElts; ++i) {
123
+ MergeTyParts.push_back (Unmerge.getReg (i));
124
+ }
125
+ B.buildMergeLikeInstr (Dst, MergeTyParts);
126
+ }
127
+ MI.eraseFromParent ();
128
+ }
129
+
53
130
void RegBankLegalizeHelper::lower (MachineInstr &MI,
54
131
const RegBankLLTMapping &Mapping,
55
132
SmallSet<Register, 4 > &WaterfallSgprs) {
@@ -128,6 +205,54 @@ void RegBankLegalizeHelper::lower(MachineInstr &MI,
128
205
MI.eraseFromParent ();
129
206
break ;
130
207
}
208
+ case SplitLoad: {
209
+ LLT DstTy = MRI.getType (MI.getOperand (0 ).getReg ());
210
+ unsigned Size = DstTy.getSizeInBits ();
211
+ // Even split to 128-bit loads
212
+ if (Size > 128 ) {
213
+ LLT B128;
214
+ if (DstTy.isVector ()) {
215
+ LLT EltTy = DstTy.getElementType ();
216
+ B128 = LLT::fixed_vector (128 / EltTy.getSizeInBits (), EltTy);
217
+ } else {
218
+ B128 = LLT::scalar (128 );
219
+ }
220
+ if (Size / 128 == 2 )
221
+ splitLoad (MI, {B128, B128});
222
+ else if (Size / 128 == 4 )
223
+ splitLoad (MI, {B128, B128, B128, B128});
224
+ else {
225
+ LLVM_DEBUG (dbgs () << " MI: " ; MI.dump (););
226
+ llvm_unreachable (" SplitLoad type not supported for MI" );
227
+ }
228
+ }
229
+ // 64 and 32 bit load
230
+ else if (DstTy == S96)
231
+ splitLoad (MI, {S64, S32}, S32);
232
+ else if (DstTy == V3S32)
233
+ splitLoad (MI, {V2S32, S32}, S32);
234
+ else if (DstTy == V6S16)
235
+ splitLoad (MI, {V4S16, V2S16}, V2S16);
236
+ else {
237
+ LLVM_DEBUG (dbgs () << " MI: " ; MI.dump (););
238
+ llvm_unreachable (" SplitLoad type not supported for MI" );
239
+ }
240
+ break ;
241
+ }
242
+ case WidenLoad: {
243
+ LLT DstTy = MRI.getType (MI.getOperand (0 ).getReg ());
244
+ if (DstTy == S96)
245
+ widenLoad (MI, S128);
246
+ else if (DstTy == V3S32)
247
+ widenLoad (MI, V4S32, S32);
248
+ else if (DstTy == V6S16)
249
+ widenLoad (MI, V8S16, V2S16);
250
+ else {
251
+ LLVM_DEBUG (dbgs () << " MI: " ; MI.dump (););
252
+ llvm_unreachable (" WidenLoad type not supported for MI" );
253
+ }
254
+ break ;
255
+ }
131
256
}
132
257
133
258
// TODO: executeInWaterfallLoop(... WaterfallSgprs)
@@ -151,12 +276,73 @@ LLT RegBankLegalizeHelper::getTyFromID(RegBankLLTMappingApplyID ID) {
151
276
case Sgpr64:
152
277
case Vgpr64:
153
278
return LLT::scalar (64 );
279
+ case SgprP1:
280
+ case VgprP1:
281
+ return LLT::pointer (1 , 64 );
282
+ case SgprP3:
283
+ case VgprP3:
284
+ return LLT::pointer (3 , 32 );
285
+ case SgprP4:
286
+ case VgprP4:
287
+ return LLT::pointer (4 , 64 );
288
+ case SgprP5:
289
+ case VgprP5:
290
+ return LLT::pointer (5 , 32 );
154
291
case SgprV4S32:
155
292
case VgprV4S32:
156
293
case UniInVgprV4S32:
157
294
return LLT::fixed_vector (4 , 32 );
158
- case VgprP1:
159
- return LLT::pointer (1 , 64 );
295
+ default :
296
+ return LLT ();
297
+ }
298
+ }
299
+
300
+ LLT RegBankLegalizeHelper::getBTyFromID (RegBankLLTMappingApplyID ID, LLT Ty) {
301
+ switch (ID) {
302
+ case SgprB32:
303
+ case VgprB32:
304
+ case UniInVgprB32:
305
+ if (Ty == LLT::scalar (32 ) || Ty == LLT::fixed_vector (2 , 16 ) ||
306
+ Ty == LLT::pointer (3 , 32 ) || Ty == LLT::pointer (5 , 32 ) ||
307
+ Ty == LLT::pointer (6 , 32 ))
308
+ return Ty;
309
+ return LLT ();
310
+ case SgprB64:
311
+ case VgprB64:
312
+ case UniInVgprB64:
313
+ if (Ty == LLT::scalar (64 ) || Ty == LLT::fixed_vector (2 , 32 ) ||
314
+ Ty == LLT::fixed_vector (4 , 16 ) || Ty == LLT::pointer (0 , 64 ) ||
315
+ Ty == LLT::pointer (1 , 64 ) || Ty == LLT::pointer (4 , 64 ))
316
+ return Ty;
317
+ return LLT ();
318
+ case SgprB96:
319
+ case VgprB96:
320
+ case UniInVgprB96:
321
+ if (Ty == LLT::scalar (96 ) || Ty == LLT::fixed_vector (3 , 32 ) ||
322
+ Ty == LLT::fixed_vector (6 , 16 ))
323
+ return Ty;
324
+ return LLT ();
325
+ case SgprB128:
326
+ case VgprB128:
327
+ case UniInVgprB128:
328
+ if (Ty == LLT::scalar (128 ) || Ty == LLT::fixed_vector (4 , 32 ) ||
329
+ Ty == LLT::fixed_vector (2 , 64 ))
330
+ return Ty;
331
+ return LLT ();
332
+ case SgprB256:
333
+ case VgprB256:
334
+ case UniInVgprB256:
335
+ if (Ty == LLT::scalar (256 ) || Ty == LLT::fixed_vector (8 , 32 ) ||
336
+ Ty == LLT::fixed_vector (4 , 64 ) || Ty == LLT::fixed_vector (16 , 16 ))
337
+ return Ty;
338
+ return LLT ();
339
+ case SgprB512:
340
+ case VgprB512:
341
+ case UniInVgprB512:
342
+ if (Ty == LLT::scalar (512 ) || Ty == LLT::fixed_vector (16 , 32 ) ||
343
+ Ty == LLT::fixed_vector (8 , 64 ))
344
+ return Ty;
345
+ return LLT ();
160
346
default :
161
347
return LLT ();
162
348
}
@@ -170,10 +356,26 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
170
356
case Sgpr16:
171
357
case Sgpr32:
172
358
case Sgpr64:
359
+ case SgprP1:
360
+ case SgprP3:
361
+ case SgprP4:
362
+ case SgprP5:
173
363
case SgprV4S32:
364
+ case SgprB32:
365
+ case SgprB64:
366
+ case SgprB96:
367
+ case SgprB128:
368
+ case SgprB256:
369
+ case SgprB512:
174
370
case UniInVcc:
175
371
case UniInVgprS32:
176
372
case UniInVgprV4S32:
373
+ case UniInVgprB32:
374
+ case UniInVgprB64:
375
+ case UniInVgprB96:
376
+ case UniInVgprB128:
377
+ case UniInVgprB256:
378
+ case UniInVgprB512:
177
379
case Sgpr32Trunc:
178
380
case Sgpr32AExt:
179
381
case Sgpr32AExtBoolInReg:
@@ -182,7 +384,16 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
182
384
case Vgpr32:
183
385
case Vgpr64:
184
386
case VgprP1:
387
+ case VgprP3:
388
+ case VgprP4:
389
+ case VgprP5:
185
390
case VgprV4S32:
391
+ case VgprB32:
392
+ case VgprB64:
393
+ case VgprB96:
394
+ case VgprB128:
395
+ case VgprB256:
396
+ case VgprB512:
186
397
return VgprRB;
187
398
default :
188
399
return nullptr ;
@@ -207,16 +418,40 @@ void RegBankLegalizeHelper::applyMappingDst(
207
418
case Sgpr16:
208
419
case Sgpr32:
209
420
case Sgpr64:
421
+ case SgprP1:
422
+ case SgprP3:
423
+ case SgprP4:
424
+ case SgprP5:
210
425
case SgprV4S32:
211
426
case Vgpr32:
212
427
case Vgpr64:
213
428
case VgprP1:
429
+ case VgprP3:
430
+ case VgprP4:
431
+ case VgprP5:
214
432
case VgprV4S32: {
215
433
assert (Ty == getTyFromID (MethodIDs[OpIdx]));
216
434
assert (RB == getRegBankFromID (MethodIDs[OpIdx]));
217
435
break ;
218
436
}
219
- // uniform in vcc/vgpr: scalars and vectors
437
+ // sgpr and vgpr B-types
438
+ case SgprB32:
439
+ case SgprB64:
440
+ case SgprB96:
441
+ case SgprB128:
442
+ case SgprB256:
443
+ case SgprB512:
444
+ case VgprB32:
445
+ case VgprB64:
446
+ case VgprB96:
447
+ case VgprB128:
448
+ case VgprB256:
449
+ case VgprB512: {
450
+ assert (Ty == getBTyFromID (MethodIDs[OpIdx], Ty));
451
+ assert (RB == getRegBankFromID (MethodIDs[OpIdx]));
452
+ break ;
453
+ }
454
+ // uniform in vcc/vgpr: scalars, vectors and B-types
220
455
case UniInVcc: {
221
456
assert (Ty == S1);
222
457
assert (RB == SgprRB);
@@ -236,6 +471,19 @@ void RegBankLegalizeHelper::applyMappingDst(
236
471
buildReadAnyLane (B, Reg, NewVgprDst, RBI);
237
472
break ;
238
473
}
474
+ case UniInVgprB32:
475
+ case UniInVgprB64:
476
+ case UniInVgprB96:
477
+ case UniInVgprB128:
478
+ case UniInVgprB256:
479
+ case UniInVgprB512: {
480
+ assert (Ty == getBTyFromID (MethodIDs[OpIdx], Ty));
481
+ assert (RB == SgprRB);
482
+ Register NewVgprDst = MRI.createVirtualRegister ({VgprRB, Ty});
483
+ Op.setReg (NewVgprDst);
484
+ AMDGPU::buildReadAnyLane (B, Reg, NewVgprDst, RBI);
485
+ break ;
486
+ }
239
487
// sgpr trunc
240
488
case Sgpr32Trunc: {
241
489
assert (Ty.getSizeInBits () < 32 );
@@ -284,15 +532,33 @@ void RegBankLegalizeHelper::applyMappingSrc(
284
532
case Sgpr16:
285
533
case Sgpr32:
286
534
case Sgpr64:
535
+ case SgprP1:
536
+ case SgprP3:
537
+ case SgprP4:
538
+ case SgprP5:
287
539
case SgprV4S32: {
288
540
assert (Ty == getTyFromID (MethodIDs[i]));
289
541
assert (RB == getRegBankFromID (MethodIDs[i]));
290
542
break ;
291
543
}
544
+ // sgpr B-types
545
+ case SgprB32:
546
+ case SgprB64:
547
+ case SgprB96:
548
+ case SgprB128:
549
+ case SgprB256:
550
+ case SgprB512: {
551
+ assert (Ty == getBTyFromID (MethodIDs[i], Ty));
552
+ assert (RB == getRegBankFromID (MethodIDs[i]));
553
+ break ;
554
+ }
292
555
// vgpr scalars, pointers and vectors
293
556
case Vgpr32:
294
557
case Vgpr64:
295
558
case VgprP1:
559
+ case VgprP3:
560
+ case VgprP4:
561
+ case VgprP5:
296
562
case VgprV4S32: {
297
563
assert (Ty == getTyFromID (MethodIDs[i]));
298
564
if (RB != VgprRB) {
@@ -301,6 +567,20 @@ void RegBankLegalizeHelper::applyMappingSrc(
301
567
}
302
568
break ;
303
569
}
570
+ // vgpr B-types
571
+ case VgprB32:
572
+ case VgprB64:
573
+ case VgprB96:
574
+ case VgprB128:
575
+ case VgprB256:
576
+ case VgprB512: {
577
+ assert (Ty == getBTyFromID (MethodIDs[i], Ty));
578
+ if (RB != VgprRB) {
579
+ auto CopyToVgpr = B.buildCopy ({VgprRB, Ty}, Reg);
580
+ Op.setReg (CopyToVgpr.getReg (0 ));
581
+ }
582
+ break ;
583
+ }
304
584
// sgpr and vgpr scalars with extend
305
585
case Sgpr32AExt: {
306
586
// Note: this ext allows S1, and it is meant to be combined away.
@@ -373,7 +653,7 @@ void RegBankLegalizeHelper::applyMappingPHI(MachineInstr &MI) {
373
653
// We accept all types that can fit in some register class.
374
654
// Uniform G_PHIs have all sgpr registers.
375
655
// Divergent G_PHIs have vgpr dst but inputs can be sgpr or vgpr.
376
- if (Ty == LLT::scalar (32 )) {
656
+ if (Ty == LLT::scalar (32 ) || Ty == LLT::pointer ( 4 , 64 ) ) {
377
657
return ;
378
658
}
379
659
0 commit comments