@@ -133,6 +133,7 @@ static bool IsPTXVectorType(MVT VT) {
133
133
case MVT::v4i8:
134
134
case MVT::v2i16:
135
135
case MVT::v4i16:
136
+ case MVT::v8i16: // <4 x i16x2>
136
137
case MVT::v2i32:
137
138
case MVT::v4i32:
138
139
case MVT::v2i64:
@@ -149,12 +150,13 @@ static bool IsPTXVectorType(MVT VT) {
149
150
}
150
151
}
151
152
152
- static bool Isv2f16Orv2bf16Type (EVT VT) {
153
- return (VT == MVT::v2f16 || VT == MVT::v2bf16);
153
+ static bool Isv2f16Orv2bf16Orv2i16Type (EVT VT) {
154
+ return (VT == MVT::v2f16 || VT == MVT::v2bf16 || VT == MVT::v2i16 );
154
155
}
155
156
156
- static bool Isf16Orbf16Type (MVT VT) {
157
- return (VT.SimpleTy == MVT::f16 || VT.SimpleTy == MVT::bf16);
157
+ static bool Is16bitsType (MVT VT) {
158
+ return (VT.SimpleTy == MVT::f16 || VT.SimpleTy == MVT::bf16 ||
159
+ VT.SimpleTy == MVT::i16);
158
160
}
159
161
160
162
// / ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
@@ -207,8 +209,13 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
207
209
// Vectors with an even number of f16 elements will be passed to
208
210
// us as an array of v2f16/v2bf16 elements. We must match this so we
209
211
// stay in sync with Ins/Outs.
210
- if ((Isf16Orbf16Type (EltVT.getSimpleVT ())) && NumElts % 2 == 0 ) {
211
- EltVT = EltVT == MVT::f16 ? MVT::v2f16 : MVT::v2bf16;
212
+ if ((Is16bitsType (EltVT.getSimpleVT ())) && NumElts % 2 == 0 ) {
213
+ if (EltVT == MVT::f16)
214
+ EltVT = MVT::v2f16;
215
+ else if (EltVT == MVT::bf16)
216
+ EltVT = MVT::v2bf16;
217
+ else if (EltVT == MVT::i16)
218
+ EltVT = MVT::v2i16;
212
219
NumElts /= 2 ;
213
220
}
214
221
for (unsigned j = 0 ; j != NumElts; ++j) {
@@ -427,8 +434,26 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
427
434
Op, VT, IsOpSupported ? Action : NoBF16Action);
428
435
};
429
436
437
+ auto setI16x2OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
438
+ LegalizeAction NoI16x2Action) {
439
+ bool IsOpSupported = false ;
440
+ // instructions are available on sm_90 only
441
+ switch (Op) {
442
+ case ISD::ADD:
443
+ case ISD::SMAX:
444
+ case ISD::SMIN:
445
+ case ISD::UMIN:
446
+ case ISD::UMAX:
447
+ case ISD::SUB:
448
+ IsOpSupported = STI.getSmVersion () >= 90 && STI.getPTXVersion () >= 80 ;
449
+ break ;
450
+ }
451
+ setOperationAction (Op, VT, IsOpSupported ? Action : NoI16x2Action);
452
+ };
453
+
430
454
addRegisterClass (MVT::i1, &NVPTX::Int1RegsRegClass);
431
455
addRegisterClass (MVT::i16, &NVPTX::Int16RegsRegClass);
456
+ addRegisterClass (MVT::v2i16, &NVPTX::Int32RegsRegClass);
432
457
addRegisterClass (MVT::i32, &NVPTX::Int32RegsRegClass);
433
458
addRegisterClass (MVT::i64, &NVPTX::Int64RegsRegClass);
434
459
addRegisterClass (MVT::f32, &NVPTX::Float32RegsRegClass);
@@ -459,9 +484,17 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
459
484
460
485
setBF16OperationAction (ISD::SETCC, MVT::bf16, Legal, Promote);
461
486
setBF16OperationAction (ISD::SETCC, MVT::v2bf16, Legal, Expand);
487
+
488
+ // Conversion to/from i16/i16x2 is always legal.
489
+ setOperationAction (ISD::BUILD_VECTOR, MVT::v2i16, Custom);
490
+ setOperationAction (ISD::EXTRACT_VECTOR_ELT, MVT::v2i16, Custom);
491
+ setOperationAction (ISD::INSERT_VECTOR_ELT, MVT::v2i16, Expand);
492
+ setOperationAction (ISD::VECTOR_SHUFFLE, MVT::v2i16, Expand);
493
+
462
494
// Operations not directly supported by NVPTX.
463
- for (MVT VT : {MVT::bf16, MVT::f16, MVT::v2bf16, MVT::v2f16, MVT::f32,
464
- MVT::f64, MVT::i1, MVT::i8, MVT::i16, MVT::i32, MVT::i64}) {
495
+ for (MVT VT :
496
+ {MVT::bf16, MVT::f16, MVT::v2bf16, MVT::v2f16, MVT::f32, MVT::f64,
497
+ MVT::i1, MVT::i8, MVT::i16, MVT::v2i16, MVT::i32, MVT::i64}) {
465
498
setOperationAction (ISD::SELECT_CC, VT, Expand);
466
499
setOperationAction (ISD::BR_CC, VT, Expand);
467
500
}
@@ -473,6 +506,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
473
506
setOperationAction (ISD::SIGN_EXTEND_INREG, MVT::i16, Legal);
474
507
setOperationAction (ISD::SIGN_EXTEND_INREG, MVT::i8 , Legal);
475
508
setOperationAction (ISD::SIGN_EXTEND_INREG, MVT::i1, Expand);
509
+ setOperationAction (ISD::SIGN_EXTEND_INREG, MVT::v2i16, Expand);
476
510
477
511
setOperationAction (ISD::SHL_PARTS, MVT::i32 , Custom);
478
512
setOperationAction (ISD::SRA_PARTS, MVT::i32 , Custom);
@@ -493,10 +527,13 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
493
527
setOperationAction (ISD::ROTR, MVT::i32, Legal);
494
528
495
529
setOperationAction (ISD::ROTL, MVT::i16, Expand);
530
+ setOperationAction (ISD::ROTL, MVT::v2i16, Expand);
496
531
setOperationAction (ISD::ROTR, MVT::i16, Expand);
532
+ setOperationAction (ISD::ROTR, MVT::v2i16, Expand);
497
533
setOperationAction (ISD::ROTL, MVT::i8, Expand);
498
534
setOperationAction (ISD::ROTR, MVT::i8, Expand);
499
535
setOperationAction (ISD::BSWAP, MVT::i16, Expand);
536
+ setOperationAction (ISD::BSWAP, MVT::v2i16, Expand);
500
537
setOperationAction (ISD::BSWAP, MVT::i32, Expand);
501
538
setOperationAction (ISD::BSWAP, MVT::i64, Expand);
502
539
@@ -584,6 +621,22 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
584
621
setOperationAction (ISD::CTLZ, Ty, Legal);
585
622
}
586
623
624
+ setI16x2OperationAction (ISD::ABS, MVT::v2i16, Legal, Expand);
625
+ setI16x2OperationAction (ISD::SMIN, MVT::v2i16, Legal, Expand);
626
+ setI16x2OperationAction (ISD::SMAX, MVT::v2i16, Legal, Expand);
627
+ setI16x2OperationAction (ISD::UMIN, MVT::v2i16, Legal, Expand);
628
+ setI16x2OperationAction (ISD::UMAX, MVT::v2i16, Legal, Expand);
629
+ setI16x2OperationAction (ISD::CTPOP, MVT::v2i16, Legal, Expand);
630
+ setI16x2OperationAction (ISD::CTLZ, MVT::v2i16, Legal, Expand);
631
+
632
+ setI16x2OperationAction (ISD::ADD, MVT::v2i16, Legal, Expand);
633
+ setI16x2OperationAction (ISD::SUB, MVT::v2i16, Legal, Expand);
634
+ setI16x2OperationAction (ISD::AND, MVT::v2i16, Legal, Expand);
635
+ setI16x2OperationAction (ISD::MUL, MVT::v2i16, Legal, Expand);
636
+ setI16x2OperationAction (ISD::SHL, MVT::v2i16, Legal, Expand);
637
+ setI16x2OperationAction (ISD::SREM, MVT::v2i16, Legal, Expand);
638
+ setI16x2OperationAction (ISD::UREM, MVT::v2i16, Legal, Expand);
639
+
587
640
setOperationAction (ISD::ADDC, MVT::i32, Legal);
588
641
setOperationAction (ISD::ADDE, MVT::i32, Legal);
589
642
setOperationAction (ISD::SUBC, MVT::i32, Legal);
@@ -596,6 +649,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
596
649
}
597
650
598
651
setOperationAction (ISD::CTTZ, MVT::i16, Expand);
652
+ setOperationAction (ISD::CTTZ, MVT::v2i16, Expand);
599
653
setOperationAction (ISD::CTTZ, MVT::i32, Expand);
600
654
setOperationAction (ISD::CTTZ, MVT::i64, Expand);
601
655
@@ -1318,7 +1372,7 @@ NVPTXTargetLowering::getPreferredVectorAction(MVT VT) const {
1318
1372
if (!VT.isScalableVector () && VT.getVectorNumElements () != 1 &&
1319
1373
VT.getScalarType () == MVT::i1)
1320
1374
return TypeSplitVector;
1321
- if (Isv2f16Orv2bf16Type (VT))
1375
+ if (Isv2f16Orv2bf16Orv2i16Type (VT))
1322
1376
return TypeLegal;
1323
1377
return TargetLoweringBase::getPreferredVectorAction (VT);
1324
1378
}
@@ -2098,15 +2152,31 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
2098
2152
// generates good SASS in both cases.
2099
2153
SDValue NVPTXTargetLowering::LowerBUILD_VECTOR (SDValue Op,
2100
2154
SelectionDAG &DAG) const {
2101
- if (!(Isv2f16Orv2bf16Type (Op->getValueType (0 )) &&
2102
- isa<ConstantFPSDNode>(Op->getOperand (0 )) &&
2103
- isa<ConstantFPSDNode>(Op->getOperand (1 ))))
2155
+ EVT VT = Op->getValueType (0 );
2156
+ if (!(Isv2f16Orv2bf16Orv2i16Type (VT)))
2104
2157
return Op;
2158
+ APInt E0 ;
2159
+ APInt E1 ;
2160
+ if (VT == MVT::v2f16 || VT == MVT::v2bf16) {
2161
+ if (!(isa<ConstantFPSDNode>(Op->getOperand (0 )) &&
2162
+ isa<ConstantFPSDNode>(Op->getOperand (1 ))))
2163
+ return Op;
2164
+
2165
+ E0 = cast<ConstantFPSDNode>(Op->getOperand (0 ))
2166
+ ->getValueAPF ()
2167
+ .bitcastToAPInt ();
2168
+ E1 = cast<ConstantFPSDNode>(Op->getOperand (1 ))
2169
+ ->getValueAPF ()
2170
+ .bitcastToAPInt ();
2171
+ } else {
2172
+ assert (VT == MVT::v2i16);
2173
+ if (!(isa<ConstantSDNode>(Op->getOperand (0 )) &&
2174
+ isa<ConstantSDNode>(Op->getOperand (1 ))))
2175
+ return Op;
2105
2176
2106
- APInt E0 =
2107
- cast<ConstantFPSDNode>(Op->getOperand (0 ))->getValueAPF ().bitcastToAPInt ();
2108
- APInt E1 =
2109
- cast<ConstantFPSDNode>(Op->getOperand (1 ))->getValueAPF ().bitcastToAPInt ();
2177
+ E0 = cast<ConstantSDNode>(Op->getOperand (0 ))->getAPIntValue ();
2178
+ E1 = cast<ConstantSDNode>(Op->getOperand (1 ))->getAPIntValue ();
2179
+ }
2110
2180
SDValue Const =
2111
2181
DAG.getConstant (E1 .zext (32 ).shl (16 ) | E0 .zext (32 ), SDLoc (Op), MVT::i32);
2112
2182
return DAG.getNode (ISD::BITCAST, SDLoc (Op), Op->getValueType (0 ), Const);
@@ -2122,7 +2192,8 @@ SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
2122
2192
// Extract individual elements and select one of them.
2123
2193
SDValue Vector = Op->getOperand (0 );
2124
2194
EVT VectorVT = Vector.getValueType ();
2125
- assert (VectorVT == MVT::v2f16 && " Unexpected vector type." );
2195
+ assert ((VectorVT == MVT::v2f16 || VectorVT == MVT::v2i16) &&
2196
+ " Unexpected vector type." );
2126
2197
EVT EltVT = VectorVT.getVectorElementType ();
2127
2198
2128
2199
SDLoc dl (Op.getNode ());
@@ -2470,7 +2541,7 @@ SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
2470
2541
2471
2542
// v2f16 is legal, so we can't rely on legalizer to handle unaligned
2472
2543
// loads and have to handle it here.
2473
- if (Isv2f16Orv2bf16Type (Op.getValueType ())) {
2544
+ if (Isv2f16Orv2bf16Orv2i16Type (Op.getValueType ())) {
2474
2545
LoadSDNode *Load = cast<LoadSDNode>(Op);
2475
2546
EVT MemVT = Load->getMemoryVT ();
2476
2547
if (!allowsMemoryAccessForAlignment (*DAG.getContext (), DAG.getDataLayout (),
@@ -2515,13 +2586,13 @@ SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
2515
2586
2516
2587
// v2f16 is legal, so we can't rely on legalizer to handle unaligned
2517
2588
// stores and have to handle it here.
2518
- if (Isv2f16Orv2bf16Type (VT) &&
2589
+ if (Isv2f16Orv2bf16Orv2i16Type (VT) &&
2519
2590
!allowsMemoryAccessForAlignment (*DAG.getContext (), DAG.getDataLayout (),
2520
2591
VT, *Store->getMemOperand ()))
2521
2592
return expandUnalignedStore (Store, DAG);
2522
2593
2523
- // v2f16 and v2bf16 don't need special handling.
2524
- if (VT == MVT::v2f16 || VT == MVT::v2bf16)
2594
+ // v2f16, v2bf16 and v2i16 don't need special handling.
2595
+ if (VT == MVT::v2f16 || VT == MVT::v2bf16 || VT == MVT::v2i16 )
2525
2596
return SDValue ();
2526
2597
2527
2598
if (VT.isVector ())
@@ -2562,6 +2633,7 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
2562
2633
case MVT::v4f32:
2563
2634
case MVT::v8f16: // <4 x f16x2>
2564
2635
case MVT::v8bf16: // <4 x bf16x2>
2636
+ case MVT::v8i16: // <4 x i16x2>
2565
2637
// This is a "native" vector type
2566
2638
break ;
2567
2639
}
@@ -2606,8 +2678,7 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
2606
2678
// v8f16 is a special case. PTX doesn't have st.v8.f16
2607
2679
// instruction. Instead, we split the vector into v2f16 chunks and
2608
2680
// store them with st.v4.b32.
2609
- assert (Isf16Orbf16Type (EltVT.getSimpleVT ()) &&
2610
- " Wrong type for the vector." );
2681
+ assert (Is16bitsType (EltVT.getSimpleVT ()) && " Wrong type for the vector." );
2611
2682
Opcode = NVPTXISD::StoreV4;
2612
2683
StoreF16x2 = true ;
2613
2684
break ;
@@ -2793,7 +2864,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
2793
2864
EVT LoadVT = EltVT;
2794
2865
if (EltVT == MVT::i1)
2795
2866
LoadVT = MVT::i8;
2796
- else if (Isv2f16Orv2bf16Type (EltVT))
2867
+ else if (Isv2f16Orv2bf16Orv2i16Type (EltVT))
2797
2868
// getLoad needs a vector type, but it can't handle
2798
2869
// vectors which contain v2f16 or v2bf16 elements. So we must load
2799
2870
// using i32 here and then bitcast back.
@@ -2819,7 +2890,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
2819
2890
if (EltVT == MVT::i1)
2820
2891
Elt = DAG.getNode (ISD::TRUNCATE, dl, MVT::i1, Elt);
2821
2892
// v2f16 was loaded as an i32. Now we must bitcast it back.
2822
- else if (Isv2f16Orv2bf16Type (EltVT))
2893
+ else if (Isv2f16Orv2bf16Orv2i16Type (EltVT))
2823
2894
Elt = DAG.getNode (ISD::BITCAST, dl, EltVT, Elt);
2824
2895
2825
2896
// If a promoted integer type is used, truncate down to the original
@@ -5198,6 +5269,7 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
5198
5269
case MVT::v4f16:
5199
5270
case MVT::v4f32:
5200
5271
case MVT::v8f16: // <4 x f16x2>
5272
+ case MVT::v8i16: // <4 x i16x2>
5201
5273
// This is a "native" vector type
5202
5274
break ;
5203
5275
}
@@ -5250,11 +5322,16 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
5250
5322
// v8f16 is a special case. PTX doesn't have ld.v8.f16
5251
5323
// instruction. Instead, we split the vector into v2f16 chunks and
5252
5324
// load them with ld.v4.b32.
5253
- assert (Isf16Orbf16Type (EltVT.getSimpleVT ()) &&
5254
- " Unsupported v8 vector type." );
5325
+ assert (Is16bitsType (EltVT.getSimpleVT ()) && " Unsupported v8 vector type." );
5255
5326
LoadF16x2 = true ;
5256
5327
Opcode = NVPTXISD::LoadV4;
5257
- EVT VVT = (EltVT == MVT::f16) ? MVT::v2f16 : MVT::v2bf16;
5328
+ EVT VVT;
5329
+ if (EltVT == MVT::f16)
5330
+ VVT = MVT::v2f16;
5331
+ else if (EltVT == MVT::bf16)
5332
+ VVT = MVT::v2bf16;
5333
+ else if (EltVT == MVT::i16)
5334
+ VVT = MVT::v2i16;
5258
5335
EVT ListVTs[] = {VVT, VVT, VVT, VVT, MVT::Other};
5259
5336
LdResVTs = DAG.getVTList (ListVTs);
5260
5337
break ;
0 commit comments