@@ -150,10 +150,6 @@ static bool IsPTXVectorType(MVT VT) {
150
150
}
151
151
}
152
152
153
- static bool Isv2f16Orv2bf16Orv2i16Type (EVT VT) {
154
- return (VT == MVT::v2f16 || VT == MVT::v2bf16 || VT == MVT::v2i16);
155
- }
156
-
157
153
static bool Is16bitsType (MVT VT) {
158
154
return (VT.SimpleTy == MVT::f16 || VT.SimpleTy == MVT::bf16 ||
159
155
VT.SimpleTy == MVT::i16);
@@ -1372,7 +1368,7 @@ NVPTXTargetLowering::getPreferredVectorAction(MVT VT) const {
1372
1368
if (!VT.isScalableVector () && VT.getVectorNumElements () != 1 &&
1373
1369
VT.getScalarType () == MVT::i1)
1374
1370
return TypeSplitVector;
1375
- if (Isv2f16Orv2bf16Orv2i16Type (VT))
1371
+ if (Isv2x16VT (VT))
1376
1372
return TypeLegal;
1377
1373
return TargetLoweringBase::getPreferredVectorAction (VT);
1378
1374
}
@@ -2153,7 +2149,7 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
2153
2149
SDValue NVPTXTargetLowering::LowerBUILD_VECTOR (SDValue Op,
2154
2150
SelectionDAG &DAG) const {
2155
2151
EVT VT = Op->getValueType (0 );
2156
- if (!(Isv2f16Orv2bf16Orv2i16Type (VT)))
2152
+ if (!(Isv2x16VT (VT)))
2157
2153
return Op;
2158
2154
APInt E0 ;
2159
2155
APInt E1 ;
@@ -2192,8 +2188,7 @@ SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
2192
2188
// Extract individual elements and select one of them.
2193
2189
SDValue Vector = Op->getOperand (0 );
2194
2190
EVT VectorVT = Vector.getValueType ();
2195
- assert ((VectorVT == MVT::v2f16 || VectorVT == MVT::v2i16) &&
2196
- " Unexpected vector type." );
2191
+ assert (Isv2x16VT (VectorVT) && " Unexpected vector type." );
2197
2192
EVT EltVT = VectorVT.getVectorElementType ();
2198
2193
2199
2194
SDLoc dl (Op.getNode ());
@@ -2571,9 +2566,9 @@ SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
2571
2566
if (Op.getValueType () == MVT::i1)
2572
2567
return LowerLOADi1 (Op, DAG);
2573
2568
2574
- // v2f16 is legal, so we can't rely on legalizer to handle unaligned
2575
- // loads and have to handle it here.
2576
- if (Isv2f16Orv2bf16Orv2i16Type (Op.getValueType ())) {
2569
+ // v2f16/v2bf16/v2i16 are legal, so we can't rely on legalizer to handle
2570
+ // unaligned loads and have to handle it here.
2571
+ if (Isv2x16VT (Op.getValueType ())) {
2577
2572
LoadSDNode *Load = cast<LoadSDNode>(Op);
2578
2573
EVT MemVT = Load->getMemoryVT ();
2579
2574
if (!allowsMemoryAccessForAlignment (*DAG.getContext (), DAG.getDataLayout (),
@@ -2618,13 +2613,13 @@ SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
2618
2613
2619
2614
// v2f16 is legal, so we can't rely on legalizer to handle unaligned
2620
2615
// stores and have to handle it here.
2621
- if (Isv2f16Orv2bf16Orv2i16Type (VT) &&
2616
+ if (Isv2x16VT (VT) &&
2622
2617
!allowsMemoryAccessForAlignment (*DAG.getContext (), DAG.getDataLayout (),
2623
2618
VT, *Store->getMemOperand ()))
2624
2619
return expandUnalignedStore (Store, DAG);
2625
2620
2626
2621
// v2f16, v2bf16 and v2i16 don't need special handling.
2627
- if (VT == MVT::v2f16 || VT == MVT::v2bf16 || VT == MVT::v2i16 )
2622
+ if (Isv2x16VT (VT) )
2628
2623
return SDValue ();
2629
2624
2630
2625
if (VT.isVector ())
@@ -2896,7 +2891,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
2896
2891
EVT LoadVT = EltVT;
2897
2892
if (EltVT == MVT::i1)
2898
2893
LoadVT = MVT::i8;
2899
- else if (Isv2f16Orv2bf16Orv2i16Type (EltVT))
2894
+ else if (Isv2x16VT (EltVT))
2900
2895
// getLoad needs a vector type, but it can't handle
2901
2896
// vectors which contain v2f16 or v2bf16 elements. So we must load
2902
2897
// using i32 here and then bitcast back.
@@ -2922,7 +2917,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
2922
2917
if (EltVT == MVT::i1)
2923
2918
Elt = DAG.getNode (ISD::TRUNCATE, dl, MVT::i1, Elt);
2924
2919
// v2f16 was loaded as an i32. Now we must bitcast it back.
2925
- else if (Isv2f16Orv2bf16Orv2i16Type (EltVT))
2920
+ else if (Isv2x16VT (EltVT))
2926
2921
Elt = DAG.getNode (ISD::BITCAST, dl, EltVT, Elt);
2927
2922
2928
2923
// If a promoted integer type is used, truncate down to the original
@@ -5335,7 +5330,7 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
5335
5330
5336
5331
unsigned Opcode = 0 ;
5337
5332
SDVTList LdResVTs;
5338
- bool LoadF16x2 = false ;
5333
+ bool Load16x2 = false ;
5339
5334
5340
5335
switch (NumElts) {
5341
5336
default :
@@ -5355,7 +5350,7 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
5355
5350
// instruction. Instead, we split the vector into v2f16 chunks and
5356
5351
// load them with ld.v4.b32.
5357
5352
assert (Is16bitsType (EltVT.getSimpleVT ()) && " Unsupported v8 vector type." );
5358
- LoadF16x2 = true ;
5353
+ Load16x2 = true ;
5359
5354
Opcode = NVPTXISD::LoadV4;
5360
5355
EVT VVT;
5361
5356
if (EltVT == MVT::f16)
@@ -5382,7 +5377,7 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
5382
5377
LD->getMemOperand ());
5383
5378
5384
5379
SmallVector<SDValue, 8 > ScalarRes;
5385
- if (LoadF16x2 ) {
5380
+ if (Load16x2 ) {
5386
5381
// Split v2f16 subvectors back into individual elements.
5387
5382
NumElts /= 2 ;
5388
5383
for (unsigned i = 0 ; i < NumElts; ++i) {
0 commit comments