@@ -39,6 +39,83 @@ void RegBankLegalizeHelper::findRuleAndApplyMapping(MachineInstr &MI) {
39
39
lower (MI, Mapping, WaterfallSgprs);
40
40
}
41
41
42
+ void RegBankLegalizeHelper::splitLoad (MachineInstr &MI,
43
+ ArrayRef<LLT> LLTBreakdown, LLT MergeTy) {
44
+ MachineFunction &MF = B.getMF ();
45
+ assert (MI.getNumMemOperands () == 1 );
46
+ MachineMemOperand &BaseMMO = **MI.memoperands_begin ();
47
+ Register Dst = MI.getOperand (0 ).getReg ();
48
+ const RegisterBank *DstRB = MRI.getRegBankOrNull (Dst);
49
+ Register Base = MI.getOperand (1 ).getReg ();
50
+ LLT PtrTy = MRI.getType (Base);
51
+ const RegisterBank *PtrRB = MRI.getRegBankOrNull (Base);
52
+ LLT OffsetTy = LLT::scalar (PtrTy.getSizeInBits ());
53
+ SmallVector<Register, 4 > LoadPartRegs;
54
+
55
+ unsigned ByteOffset = 0 ;
56
+ for (LLT PartTy : LLTBreakdown) {
57
+ Register BasePlusOffset;
58
+ if (ByteOffset == 0 ) {
59
+ BasePlusOffset = Base;
60
+ } else {
61
+ auto Offset = B.buildConstant ({PtrRB, OffsetTy}, ByteOffset);
62
+ BasePlusOffset = B.buildPtrAdd ({PtrRB, PtrTy}, Base, Offset).getReg (0 );
63
+ }
64
+ auto *OffsetMMO = MF.getMachineMemOperand (&BaseMMO, ByteOffset, PartTy);
65
+ auto LoadPart = B.buildLoad ({DstRB, PartTy}, BasePlusOffset, *OffsetMMO);
66
+ LoadPartRegs.push_back (LoadPart.getReg (0 ));
67
+ ByteOffset += PartTy.getSizeInBytes ();
68
+ }
69
+
70
+ if (!MergeTy.isValid ()) {
71
+ // Loads are of same size, concat or merge them together.
72
+ B.buildMergeLikeInstr (Dst, LoadPartRegs);
73
+ } else {
74
+ // Loads are not all of same size, need to unmerge them to smaller pieces
75
+ // of MergeTy type, then merge pieces to Dst.
76
+ SmallVector<Register, 4 > MergeTyParts;
77
+ for (Register Reg : LoadPartRegs) {
78
+ if (MRI.getType (Reg) == MergeTy) {
79
+ MergeTyParts.push_back (Reg);
80
+ } else {
81
+ auto Unmerge = B.buildUnmerge ({DstRB, MergeTy}, Reg);
82
+ for (unsigned i = 0 ; i < Unmerge->getNumOperands () - 1 ; ++i)
83
+ MergeTyParts.push_back (Unmerge.getReg (i));
84
+ }
85
+ }
86
+ B.buildMergeLikeInstr (Dst, MergeTyParts);
87
+ }
88
+ MI.eraseFromParent ();
89
+ }
90
+
91
+ void RegBankLegalizeHelper::widenLoad (MachineInstr &MI, LLT WideTy,
92
+ LLT MergeTy) {
93
+ MachineFunction &MF = B.getMF ();
94
+ assert (MI.getNumMemOperands () == 1 );
95
+ MachineMemOperand &BaseMMO = **MI.memoperands_begin ();
96
+ Register Dst = MI.getOperand (0 ).getReg ();
97
+ const RegisterBank *DstRB = MRI.getRegBankOrNull (Dst);
98
+ Register Base = MI.getOperand (1 ).getReg ();
99
+
100
+ MachineMemOperand *WideMMO = MF.getMachineMemOperand (&BaseMMO, 0 , WideTy);
101
+ auto WideLoad = B.buildLoad ({DstRB, WideTy}, Base, *WideMMO);
102
+
103
+ if (WideTy.isScalar ()) {
104
+ B.buildTrunc (Dst, WideLoad);
105
+ } else {
106
+ SmallVector<Register, 4 > MergeTyParts;
107
+ auto Unmerge = B.buildUnmerge ({DstRB, MergeTy}, WideLoad);
108
+
109
+ LLT DstTy = MRI.getType (Dst);
110
+ unsigned NumElts = DstTy.getSizeInBits () / MergeTy.getSizeInBits ();
111
+ for (unsigned i = 0 ; i < NumElts; ++i) {
112
+ MergeTyParts.push_back (Unmerge.getReg (i));
113
+ }
114
+ B.buildMergeLikeInstr (Dst, MergeTyParts);
115
+ }
116
+ MI.eraseFromParent ();
117
+ }
118
+
42
119
void RegBankLegalizeHelper::lower (MachineInstr &MI,
43
120
const RegBankLLTMapping &Mapping,
44
121
SmallSet<Register, 4 > &WaterfallSgprs) {
@@ -117,6 +194,54 @@ void RegBankLegalizeHelper::lower(MachineInstr &MI,
117
194
MI.eraseFromParent ();
118
195
break ;
119
196
}
197
+ case SplitLoad: {
198
+ LLT DstTy = MRI.getType (MI.getOperand (0 ).getReg ());
199
+ unsigned Size = DstTy.getSizeInBits ();
200
+ // Even split to 128-bit loads
201
+ if (Size > 128 ) {
202
+ LLT B128;
203
+ if (DstTy.isVector ()) {
204
+ LLT EltTy = DstTy.getElementType ();
205
+ B128 = LLT::fixed_vector (128 / EltTy.getSizeInBits (), EltTy);
206
+ } else {
207
+ B128 = LLT::scalar (128 );
208
+ }
209
+ if (Size / 128 == 2 )
210
+ splitLoad (MI, {B128, B128});
211
+ else if (Size / 128 == 4 )
212
+ splitLoad (MI, {B128, B128, B128, B128});
213
+ else {
214
+ LLVM_DEBUG (dbgs () << " MI: " ; MI.dump (););
215
+ llvm_unreachable (" SplitLoad type not supported for MI" );
216
+ }
217
+ }
218
+ // 64 and 32 bit load
219
+ else if (DstTy == S96)
220
+ splitLoad (MI, {S64, S32}, S32);
221
+ else if (DstTy == V3S32)
222
+ splitLoad (MI, {V2S32, S32}, S32);
223
+ else if (DstTy == V6S16)
224
+ splitLoad (MI, {V4S16, V2S16}, V2S16);
225
+ else {
226
+ LLVM_DEBUG (dbgs () << " MI: " ; MI.dump (););
227
+ llvm_unreachable (" SplitLoad type not supported for MI" );
228
+ }
229
+ break ;
230
+ }
231
+ case WidenLoad: {
232
+ LLT DstTy = MRI.getType (MI.getOperand (0 ).getReg ());
233
+ if (DstTy == S96)
234
+ widenLoad (MI, S128);
235
+ else if (DstTy == V3S32)
236
+ widenLoad (MI, V4S32, S32);
237
+ else if (DstTy == V6S16)
238
+ widenLoad (MI, V8S16, V2S16);
239
+ else {
240
+ LLVM_DEBUG (dbgs () << " MI: " ; MI.dump (););
241
+ llvm_unreachable (" WidenLoad type not supported for MI" );
242
+ }
243
+ break ;
244
+ }
120
245
}
121
246
122
247
// TODO: executeInWaterfallLoop(... WaterfallSgprs)
@@ -140,12 +265,73 @@ LLT RegBankLegalizeHelper::getTyFromID(RegBankLLTMappingApplyID ID) {
140
265
case Sgpr64:
141
266
case Vgpr64:
142
267
return LLT::scalar (64 );
268
+ case SgprP1:
269
+ case VgprP1:
270
+ return LLT::pointer (1 , 64 );
271
+ case SgprP3:
272
+ case VgprP3:
273
+ return LLT::pointer (3 , 32 );
274
+ case SgprP4:
275
+ case VgprP4:
276
+ return LLT::pointer (4 , 64 );
277
+ case SgprP5:
278
+ case VgprP5:
279
+ return LLT::pointer (5 , 32 );
143
280
case SgprV4S32:
144
281
case VgprV4S32:
145
282
case UniInVgprV4S32:
146
283
return LLT::fixed_vector (4 , 32 );
147
- case VgprP1:
148
- return LLT::pointer (1 , 64 );
284
+ default :
285
+ return LLT ();
286
+ }
287
+ }
288
+
289
+ LLT RegBankLegalizeHelper::getBTyFromID (RegBankLLTMappingApplyID ID, LLT Ty) {
290
+ switch (ID) {
291
+ case SgprB32:
292
+ case VgprB32:
293
+ case UniInVgprB32:
294
+ if (Ty == LLT::scalar (32 ) || Ty == LLT::fixed_vector (2 , 16 ) ||
295
+ Ty == LLT::pointer (3 , 32 ) || Ty == LLT::pointer (5 , 32 ) ||
296
+ Ty == LLT::pointer (6 , 32 ))
297
+ return Ty;
298
+ return LLT ();
299
+ case SgprB64:
300
+ case VgprB64:
301
+ case UniInVgprB64:
302
+ if (Ty == LLT::scalar (64 ) || Ty == LLT::fixed_vector (2 , 32 ) ||
303
+ Ty == LLT::fixed_vector (4 , 16 ) || Ty == LLT::pointer (0 , 64 ) ||
304
+ Ty == LLT::pointer (1 , 64 ) || Ty == LLT::pointer (4 , 64 ))
305
+ return Ty;
306
+ return LLT ();
307
+ case SgprB96:
308
+ case VgprB96:
309
+ case UniInVgprB96:
310
+ if (Ty == LLT::scalar (96 ) || Ty == LLT::fixed_vector (3 , 32 ) ||
311
+ Ty == LLT::fixed_vector (6 , 16 ))
312
+ return Ty;
313
+ return LLT ();
314
+ case SgprB128:
315
+ case VgprB128:
316
+ case UniInVgprB128:
317
+ if (Ty == LLT::scalar (128 ) || Ty == LLT::fixed_vector (4 , 32 ) ||
318
+ Ty == LLT::fixed_vector (2 , 64 ))
319
+ return Ty;
320
+ return LLT ();
321
+ case SgprB256:
322
+ case VgprB256:
323
+ case UniInVgprB256:
324
+ if (Ty == LLT::scalar (256 ) || Ty == LLT::fixed_vector (8 , 32 ) ||
325
+ Ty == LLT::fixed_vector (4 , 64 ) || Ty == LLT::fixed_vector (16 , 16 ))
326
+ return Ty;
327
+ return LLT ();
328
+ case SgprB512:
329
+ case VgprB512:
330
+ case UniInVgprB512:
331
+ if (Ty == LLT::scalar (512 ) || Ty == LLT::fixed_vector (16 , 32 ) ||
332
+ Ty == LLT::fixed_vector (8 , 64 ))
333
+ return Ty;
334
+ return LLT ();
149
335
default :
150
336
return LLT ();
151
337
}
@@ -159,10 +345,26 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
159
345
case Sgpr16:
160
346
case Sgpr32:
161
347
case Sgpr64:
348
+ case SgprP1:
349
+ case SgprP3:
350
+ case SgprP4:
351
+ case SgprP5:
162
352
case SgprV4S32:
353
+ case SgprB32:
354
+ case SgprB64:
355
+ case SgprB96:
356
+ case SgprB128:
357
+ case SgprB256:
358
+ case SgprB512:
163
359
case UniInVcc:
164
360
case UniInVgprS32:
165
361
case UniInVgprV4S32:
362
+ case UniInVgprB32:
363
+ case UniInVgprB64:
364
+ case UniInVgprB96:
365
+ case UniInVgprB128:
366
+ case UniInVgprB256:
367
+ case UniInVgprB512:
166
368
case Sgpr32Trunc:
167
369
case Sgpr32AExt:
168
370
case Sgpr32AExtBoolInReg:
@@ -171,7 +373,16 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
171
373
case Vgpr32:
172
374
case Vgpr64:
173
375
case VgprP1:
376
+ case VgprP3:
377
+ case VgprP4:
378
+ case VgprP5:
174
379
case VgprV4S32:
380
+ case VgprB32:
381
+ case VgprB64:
382
+ case VgprB96:
383
+ case VgprB128:
384
+ case VgprB256:
385
+ case VgprB512:
175
386
return VgprRB;
176
387
default :
177
388
return nullptr ;
@@ -196,16 +407,40 @@ void RegBankLegalizeHelper::applyMappingDst(
196
407
case Sgpr16:
197
408
case Sgpr32:
198
409
case Sgpr64:
410
+ case SgprP1:
411
+ case SgprP3:
412
+ case SgprP4:
413
+ case SgprP5:
199
414
case SgprV4S32:
200
415
case Vgpr32:
201
416
case Vgpr64:
202
417
case VgprP1:
418
+ case VgprP3:
419
+ case VgprP4:
420
+ case VgprP5:
203
421
case VgprV4S32: {
204
422
assert (Ty == getTyFromID (MethodIDs[OpIdx]));
205
423
assert (RB == getRegBankFromID (MethodIDs[OpIdx]));
206
424
break ;
207
425
}
208
- // uniform in vcc/vgpr: scalars and vectors
426
+ // sgpr and vgpr B-types
427
+ case SgprB32:
428
+ case SgprB64:
429
+ case SgprB96:
430
+ case SgprB128:
431
+ case SgprB256:
432
+ case SgprB512:
433
+ case VgprB32:
434
+ case VgprB64:
435
+ case VgprB96:
436
+ case VgprB128:
437
+ case VgprB256:
438
+ case VgprB512: {
439
+ assert (Ty == getBTyFromID (MethodIDs[OpIdx], Ty));
440
+ assert (RB == getRegBankFromID (MethodIDs[OpIdx]));
441
+ break ;
442
+ }
443
+ // uniform in vcc/vgpr: scalars, vectors and B-types
209
444
case UniInVcc: {
210
445
assert (Ty == S1);
211
446
assert (RB == SgprRB);
@@ -225,6 +460,19 @@ void RegBankLegalizeHelper::applyMappingDst(
225
460
buildReadAnyLane (B, Reg, NewVgprDst, RBI);
226
461
break ;
227
462
}
463
+ case UniInVgprB32:
464
+ case UniInVgprB64:
465
+ case UniInVgprB96:
466
+ case UniInVgprB128:
467
+ case UniInVgprB256:
468
+ case UniInVgprB512: {
469
+ assert (Ty == getBTyFromID (MethodIDs[OpIdx], Ty));
470
+ assert (RB == SgprRB);
471
+ Register NewVgprDst = MRI.createVirtualRegister ({VgprRB, Ty});
472
+ Op.setReg (NewVgprDst);
473
+ AMDGPU::buildReadAnyLane (B, Reg, NewVgprDst, RBI);
474
+ break ;
475
+ }
228
476
// sgpr trunc
229
477
case Sgpr32Trunc: {
230
478
assert (Ty.getSizeInBits () < 32 );
@@ -273,15 +521,33 @@ void RegBankLegalizeHelper::applyMappingSrc(
273
521
case Sgpr16:
274
522
case Sgpr32:
275
523
case Sgpr64:
524
+ case SgprP1:
525
+ case SgprP3:
526
+ case SgprP4:
527
+ case SgprP5:
276
528
case SgprV4S32: {
277
529
assert (Ty == getTyFromID (MethodIDs[i]));
278
530
assert (RB == getRegBankFromID (MethodIDs[i]));
279
531
break ;
280
532
}
533
+ // sgpr B-types
534
+ case SgprB32:
535
+ case SgprB64:
536
+ case SgprB96:
537
+ case SgprB128:
538
+ case SgprB256:
539
+ case SgprB512: {
540
+ assert (Ty == getBTyFromID (MethodIDs[i], Ty));
541
+ assert (RB == getRegBankFromID (MethodIDs[i]));
542
+ break ;
543
+ }
281
544
// vgpr scalars, pointers and vectors
282
545
case Vgpr32:
283
546
case Vgpr64:
284
547
case VgprP1:
548
+ case VgprP3:
549
+ case VgprP4:
550
+ case VgprP5:
285
551
case VgprV4S32: {
286
552
assert (Ty == getTyFromID (MethodIDs[i]));
287
553
if (RB != VgprRB) {
@@ -290,6 +556,20 @@ void RegBankLegalizeHelper::applyMappingSrc(
290
556
}
291
557
break ;
292
558
}
559
+ // vgpr B-types
560
+ case VgprB32:
561
+ case VgprB64:
562
+ case VgprB96:
563
+ case VgprB128:
564
+ case VgprB256:
565
+ case VgprB512: {
566
+ assert (Ty == getBTyFromID (MethodIDs[i], Ty));
567
+ if (RB != VgprRB) {
568
+ auto CopyToVgpr = B.buildCopy ({VgprRB, Ty}, Reg);
569
+ Op.setReg (CopyToVgpr.getReg (0 ));
570
+ }
571
+ break ;
572
+ }
293
573
// sgpr and vgpr scalars with extend
294
574
case Sgpr32AExt: {
295
575
// Note: this ext allows S1, and it is meant to be combined away.
@@ -362,7 +642,7 @@ void RegBankLegalizeHelper::applyMappingPHI(MachineInstr &MI) {
362
642
// We accept all types that can fit in some register class.
363
643
// Uniform G_PHIs have all sgpr registers.
364
644
// Divergent G_PHIs have vgpr dst but inputs can be sgpr or vgpr.
365
- if (Ty == LLT::scalar (32 )) {
645
+ if (Ty == LLT::scalar (32 ) || Ty == LLT::pointer ( 4 , 64 ) ) {
366
646
return ;
367
647
}
368
648
0 commit comments