Skip to content

Commit ac430b4

Browse files
committed
[NVPTX] Make i16x2 a native type and add support for instructions supporting it
On sm_90, some instructions now support i16x2, which allows the hardware to execute add, min, and max operations more efficiently. To support that, we need to make i16x2 a native type in the backend. This commit makes the changes necessary for i16x2 to be a native type and adds support for the instructions that natively support i16x2. As a side effect, a negative test in the NVPTX SLP tests started passing; that test has been changed to a positive one, as the IR is now correctly vectorized.
1 parent 24c5f18 commit ac430b4

13 files changed

+751
-179
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,7 @@ bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {
615615
// We only care about f16x2 as it's the only real vector type we
616616
// need to deal with.
617617
MVT VT = Vector.getSimpleValueType();
618-
if (!(VT == MVT::v2f16 || VT == MVT::v2bf16))
618+
if (!(VT == MVT::v2f16 || VT == MVT::v2bf16 || VT == MVT::v2i16))
619619
return false;
620620
// Find and record all uses of this vector that extract element 0 or 1.
621621
SmallVector<SDNode *, 4> E0, E1;
@@ -828,6 +828,7 @@ pickOpcodeForVT(MVT::SimpleValueType VT, unsigned Opcode_i8,
828828
return Opcode_i16;
829829
case MVT::v2f16:
830830
case MVT::v2bf16:
831+
case MVT::v2i16:
831832
return Opcode_i32;
832833
case MVT::f32:
833834
return Opcode_f32;
@@ -909,9 +910,10 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
909910
// Vector Setting
910911
unsigned vecType = NVPTX::PTXLdStInstCode::Scalar;
911912
if (SimpleVT.isVector()) {
912-
assert((LoadedVT == MVT::v2f16 || LoadedVT == MVT::v2bf16) &&
913+
assert((LoadedVT == MVT::v2f16 || LoadedVT == MVT::v2bf16 ||
914+
LoadedVT == MVT::v2i16) &&
913915
"Unexpected vector type");
914-
// v2f16/v2bf16 is loaded using ld.b32
916+
// v2f16/v2bf16/v2i16 is loaded using ld.b32
915917
fromTypeWidth = 32;
916918
}
917919

@@ -1064,7 +1066,7 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
10641066
// v8f16 is a special case. PTX doesn't have ld.v8.f16
10651067
// instruction. Instead, we split the vector into v2f16 chunks and
10661068
// load them with ld.v4.b32.
1067-
if (EltVT == MVT::v2f16 || EltVT == MVT::v2bf16) {
1069+
if (EltVT == MVT::v2f16 || EltVT == MVT::v2bf16 || EltVT == MVT::v2i16) {
10681070
assert(N->getOpcode() == NVPTXISD::LoadV4 && "Unexpected load opcode.");
10691071
EltVT = MVT::i32;
10701072
FromType = NVPTX::PTXLdStInstCode::Untyped;
@@ -1262,10 +1264,11 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
12621264
EltVT = EltVT.getVectorElementType();
12631265
// vectors of f16 are loaded/stored as multiples of v2f16 elements.
12641266
if ((EltVT == MVT::f16 && N->getValueType(0) == MVT::v2f16) ||
1265-
(EltVT == MVT::bf16 && N->getValueType(0) == MVT::v2bf16)) {
1266-
assert(NumElts % 2 == 0 && "Vector must have even number of elements");
1267-
EltVT = N->getValueType(0);
1268-
NumElts /= 2;
1267+
(EltVT == MVT::bf16 && N->getValueType(0) == MVT::v2bf16) ||
1268+
(EltVT == MVT::i16 && N->getValueType(0) == MVT::v2i16)) {
1269+
assert(NumElts % 2 == 0 && "Vector must have even number of elements");
1270+
EltVT = N->getValueType(0);
1271+
NumElts /= 2;
12691272
}
12701273
}
12711274

@@ -1678,7 +1681,8 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
16781681
MVT ScalarVT = SimpleVT.getScalarType();
16791682
unsigned toTypeWidth = ScalarVT.getSizeInBits();
16801683
if (SimpleVT.isVector()) {
1681-
assert((StoreVT == MVT::v2f16 || StoreVT == MVT::v2bf16) &&
1684+
assert((StoreVT == MVT::v2f16 || StoreVT == MVT::v2bf16 ||
1685+
StoreVT == MVT::v2i16) &&
16821686
"Unexpected vector type");
16831687
// v2f16 is stored using st.b32
16841688
toTypeWidth = 32;
@@ -1847,7 +1851,7 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
18471851
// v8f16 is a special case. PTX doesn't have st.v8.f16
18481852
// instruction. Instead, we split the vector into v2f16 chunks and
18491853
// store them with st.v4.b32.
1850-
if (EltVT == MVT::v2f16 || EltVT == MVT::v2bf16) {
1854+
if (EltVT == MVT::v2f16 || EltVT == MVT::v2bf16 || EltVT == MVT::v2i16) {
18511855
assert(N->getOpcode() == NVPTXISD::StoreV4 && "Unexpected load opcode.");
18521856
EltVT = MVT::i32;
18531857
ToType = NVPTX::PTXLdStInstCode::Untyped;

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 105 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ static bool IsPTXVectorType(MVT VT) {
133133
case MVT::v4i8:
134134
case MVT::v2i16:
135135
case MVT::v4i16:
136+
case MVT::v8i16: // <4 x i16x2>
136137
case MVT::v2i32:
137138
case MVT::v4i32:
138139
case MVT::v2i64:
@@ -149,12 +150,13 @@ static bool IsPTXVectorType(MVT VT) {
149150
}
150151
}
151152

152-
static bool Isv2f16Orv2bf16Type(EVT VT) {
153-
return (VT == MVT::v2f16 || VT == MVT::v2bf16);
153+
static bool Isv2f16Orv2bf16Orv2i16Type(EVT VT) {
154+
return (VT == MVT::v2f16 || VT == MVT::v2bf16 || VT == MVT::v2i16);
154155
}
155156

156-
static bool Isf16Orbf16Type(MVT VT) {
157-
return (VT.SimpleTy == MVT::f16 || VT.SimpleTy == MVT::bf16);
157+
static bool Is16bitsType(MVT VT) {
158+
return (VT.SimpleTy == MVT::f16 || VT.SimpleTy == MVT::bf16 ||
159+
VT.SimpleTy == MVT::i16);
158160
}
159161

160162
/// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
@@ -207,8 +209,13 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
207209
// Vectors with an even number of f16 elements will be passed to
208210
// us as an array of v2f16/v2bf16 elements. We must match this so we
209211
// stay in sync with Ins/Outs.
210-
if ((Isf16Orbf16Type(EltVT.getSimpleVT())) && NumElts % 2 == 0) {
211-
EltVT = EltVT == MVT::f16 ? MVT::v2f16 : MVT::v2bf16;
212+
if ((Is16bitsType(EltVT.getSimpleVT())) && NumElts % 2 == 0) {
213+
if (EltVT == MVT::f16)
214+
EltVT = MVT::v2f16;
215+
else if (EltVT == MVT::bf16)
216+
EltVT = MVT::v2bf16;
217+
else if (EltVT == MVT::i16)
218+
EltVT = MVT::v2i16;
212219
NumElts /= 2;
213220
}
214221
for (unsigned j = 0; j != NumElts; ++j) {
@@ -427,8 +434,26 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
427434
Op, VT, IsOpSupported ? Action : NoBF16Action);
428435
};
429436

437+
auto setI16x2OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
438+
LegalizeAction NoI16x2Action) {
439+
bool IsOpSupported = false;
440+
// instructions are available on sm_90 only
441+
switch (Op) {
442+
case ISD::ADD:
443+
case ISD::SMAX:
444+
case ISD::SMIN:
445+
case ISD::UMIN:
446+
case ISD::UMAX:
447+
case ISD::SUB:
448+
IsOpSupported = STI.getSmVersion() >= 90 && STI.getPTXVersion() >= 80;
449+
break;
450+
}
451+
setOperationAction(Op, VT, IsOpSupported ? Action : NoI16x2Action);
452+
};
453+
430454
addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass);
431455
addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass);
456+
addRegisterClass(MVT::v2i16, &NVPTX::Int32RegsRegClass);
432457
addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
433458
addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass);
434459
addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass);
@@ -459,9 +484,17 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
459484

460485
setBF16OperationAction(ISD::SETCC, MVT::bf16, Legal, Promote);
461486
setBF16OperationAction(ISD::SETCC, MVT::v2bf16, Legal, Expand);
487+
488+
// Conversion to/from i16/i16x2 is always legal.
489+
setOperationAction(ISD::BUILD_VECTOR, MVT::v2i16, Custom);
490+
setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2i16, Custom);
491+
setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v2i16, Expand);
492+
setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v2i16, Expand);
493+
462494
// Operations not directly supported by NVPTX.
463-
for (MVT VT : {MVT::bf16, MVT::f16, MVT::v2bf16, MVT::v2f16, MVT::f32,
464-
MVT::f64, MVT::i1, MVT::i8, MVT::i16, MVT::i32, MVT::i64}) {
495+
for (MVT VT :
496+
{MVT::bf16, MVT::f16, MVT::v2bf16, MVT::v2f16, MVT::f32, MVT::f64,
497+
MVT::i1, MVT::i8, MVT::i16, MVT::v2i16, MVT::i32, MVT::i64}) {
465498
setOperationAction(ISD::SELECT_CC, VT, Expand);
466499
setOperationAction(ISD::BR_CC, VT, Expand);
467500
}
@@ -473,6 +506,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
473506
setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i16, Legal);
474507
setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i8 , Legal);
475508
setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i1, Expand);
509+
setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::v2i16, Expand);
476510

477511
setOperationAction(ISD::SHL_PARTS, MVT::i32 , Custom);
478512
setOperationAction(ISD::SRA_PARTS, MVT::i32 , Custom);
@@ -493,10 +527,13 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
493527
setOperationAction(ISD::ROTR, MVT::i32, Legal);
494528

495529
setOperationAction(ISD::ROTL, MVT::i16, Expand);
530+
setOperationAction(ISD::ROTL, MVT::v2i16, Expand);
496531
setOperationAction(ISD::ROTR, MVT::i16, Expand);
532+
setOperationAction(ISD::ROTR, MVT::v2i16, Expand);
497533
setOperationAction(ISD::ROTL, MVT::i8, Expand);
498534
setOperationAction(ISD::ROTR, MVT::i8, Expand);
499535
setOperationAction(ISD::BSWAP, MVT::i16, Expand);
536+
setOperationAction(ISD::BSWAP, MVT::v2i16, Expand);
500537
setOperationAction(ISD::BSWAP, MVT::i32, Expand);
501538
setOperationAction(ISD::BSWAP, MVT::i64, Expand);
502539

@@ -584,6 +621,22 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
584621
setOperationAction(ISD::CTLZ, Ty, Legal);
585622
}
586623

624+
setI16x2OperationAction(ISD::ABS, MVT::v2i16, Legal, Expand);
625+
setI16x2OperationAction(ISD::SMIN, MVT::v2i16, Legal, Expand);
626+
setI16x2OperationAction(ISD::SMAX, MVT::v2i16, Legal, Expand);
627+
setI16x2OperationAction(ISD::UMIN, MVT::v2i16, Legal, Expand);
628+
setI16x2OperationAction(ISD::UMAX, MVT::v2i16, Legal, Expand);
629+
setI16x2OperationAction(ISD::CTPOP, MVT::v2i16, Legal, Expand);
630+
setI16x2OperationAction(ISD::CTLZ, MVT::v2i16, Legal, Expand);
631+
632+
setI16x2OperationAction(ISD::ADD, MVT::v2i16, Legal, Expand);
633+
setI16x2OperationAction(ISD::SUB, MVT::v2i16, Legal, Expand);
634+
setI16x2OperationAction(ISD::AND, MVT::v2i16, Legal, Expand);
635+
setI16x2OperationAction(ISD::MUL, MVT::v2i16, Legal, Expand);
636+
setI16x2OperationAction(ISD::SHL, MVT::v2i16, Legal, Expand);
637+
setI16x2OperationAction(ISD::SREM, MVT::v2i16, Legal, Expand);
638+
setI16x2OperationAction(ISD::UREM, MVT::v2i16, Legal, Expand);
639+
587640
setOperationAction(ISD::ADDC, MVT::i32, Legal);
588641
setOperationAction(ISD::ADDE, MVT::i32, Legal);
589642
setOperationAction(ISD::SUBC, MVT::i32, Legal);
@@ -596,6 +649,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
596649
}
597650

598651
setOperationAction(ISD::CTTZ, MVT::i16, Expand);
652+
setOperationAction(ISD::CTTZ, MVT::v2i16, Expand);
599653
setOperationAction(ISD::CTTZ, MVT::i32, Expand);
600654
setOperationAction(ISD::CTTZ, MVT::i64, Expand);
601655

@@ -1318,7 +1372,7 @@ NVPTXTargetLowering::getPreferredVectorAction(MVT VT) const {
13181372
if (!VT.isScalableVector() && VT.getVectorNumElements() != 1 &&
13191373
VT.getScalarType() == MVT::i1)
13201374
return TypeSplitVector;
1321-
if (Isv2f16Orv2bf16Type(VT))
1375+
if (Isv2f16Orv2bf16Orv2i16Type(VT))
13221376
return TypeLegal;
13231377
return TargetLoweringBase::getPreferredVectorAction(VT);
13241378
}
@@ -2098,15 +2152,31 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
20982152
// generates good SASS in both cases.
20992153
SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
21002154
SelectionDAG &DAG) const {
2101-
if (!(Isv2f16Orv2bf16Type(Op->getValueType(0)) &&
2102-
isa<ConstantFPSDNode>(Op->getOperand(0)) &&
2103-
isa<ConstantFPSDNode>(Op->getOperand(1))))
2155+
EVT VT = Op->getValueType(0);
2156+
if (!(Isv2f16Orv2bf16Orv2i16Type(VT)))
21042157
return Op;
2158+
APInt E0;
2159+
APInt E1;
2160+
if (VT == MVT::v2f16 || VT == MVT::v2bf16) {
2161+
if (!(isa<ConstantFPSDNode>(Op->getOperand(0)) &&
2162+
isa<ConstantFPSDNode>(Op->getOperand(1))))
2163+
return Op;
2164+
2165+
E0 = cast<ConstantFPSDNode>(Op->getOperand(0))
2166+
->getValueAPF()
2167+
.bitcastToAPInt();
2168+
E1 = cast<ConstantFPSDNode>(Op->getOperand(1))
2169+
->getValueAPF()
2170+
.bitcastToAPInt();
2171+
} else {
2172+
assert(VT == MVT::v2i16);
2173+
if (!(isa<ConstantSDNode>(Op->getOperand(0)) &&
2174+
isa<ConstantSDNode>(Op->getOperand(1))))
2175+
return Op;
21052176

2106-
APInt E0 =
2107-
cast<ConstantFPSDNode>(Op->getOperand(0))->getValueAPF().bitcastToAPInt();
2108-
APInt E1 =
2109-
cast<ConstantFPSDNode>(Op->getOperand(1))->getValueAPF().bitcastToAPInt();
2177+
E0 = cast<ConstantSDNode>(Op->getOperand(0))->getAPIntValue();
2178+
E1 = cast<ConstantSDNode>(Op->getOperand(1))->getAPIntValue();
2179+
}
21102180
SDValue Const =
21112181
DAG.getConstant(E1.zext(32).shl(16) | E0.zext(32), SDLoc(Op), MVT::i32);
21122182
return DAG.getNode(ISD::BITCAST, SDLoc(Op), Op->getValueType(0), Const);
@@ -2122,7 +2192,8 @@ SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
21222192
// Extract individual elements and select one of them.
21232193
SDValue Vector = Op->getOperand(0);
21242194
EVT VectorVT = Vector.getValueType();
2125-
assert(VectorVT == MVT::v2f16 && "Unexpected vector type.");
2195+
assert((VectorVT == MVT::v2f16 || VectorVT == MVT::v2i16) &&
2196+
"Unexpected vector type.");
21262197
EVT EltVT = VectorVT.getVectorElementType();
21272198

21282199
SDLoc dl(Op.getNode());
@@ -2470,7 +2541,7 @@ SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
24702541

24712542
// v2f16 is legal, so we can't rely on legalizer to handle unaligned
24722543
// loads and have to handle it here.
2473-
if (Isv2f16Orv2bf16Type(Op.getValueType())) {
2544+
if (Isv2f16Orv2bf16Orv2i16Type(Op.getValueType())) {
24742545
LoadSDNode *Load = cast<LoadSDNode>(Op);
24752546
EVT MemVT = Load->getMemoryVT();
24762547
if (!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
@@ -2515,13 +2586,13 @@ SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
25152586

25162587
// v2f16 is legal, so we can't rely on legalizer to handle unaligned
25172588
// stores and have to handle it here.
2518-
if (Isv2f16Orv2bf16Type(VT) &&
2589+
if (Isv2f16Orv2bf16Orv2i16Type(VT) &&
25192590
!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
25202591
VT, *Store->getMemOperand()))
25212592
return expandUnalignedStore(Store, DAG);
25222593

2523-
// v2f16 and v2bf16 don't need special handling.
2524-
if (VT == MVT::v2f16 || VT == MVT::v2bf16)
2594+
// v2f16, v2bf16 and v2i16 don't need special handling.
2595+
if (VT == MVT::v2f16 || VT == MVT::v2bf16 || VT == MVT::v2i16)
25252596
return SDValue();
25262597

25272598
if (VT.isVector())
@@ -2562,6 +2633,7 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
25622633
case MVT::v4f32:
25632634
case MVT::v8f16: // <4 x f16x2>
25642635
case MVT::v8bf16: // <4 x bf16x2>
2636+
case MVT::v8i16: // <4 x i16x2>
25652637
// This is a "native" vector type
25662638
break;
25672639
}
@@ -2606,8 +2678,7 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
26062678
// v8f16 is a special case. PTX doesn't have st.v8.f16
26072679
// instruction. Instead, we split the vector into v2f16 chunks and
26082680
// store them with st.v4.b32.
2609-
assert(Isf16Orbf16Type(EltVT.getSimpleVT()) &&
2610-
"Wrong type for the vector.");
2681+
assert(Is16bitsType(EltVT.getSimpleVT()) && "Wrong type for the vector.");
26112682
Opcode = NVPTXISD::StoreV4;
26122683
StoreF16x2 = true;
26132684
break;
@@ -2793,7 +2864,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
27932864
EVT LoadVT = EltVT;
27942865
if (EltVT == MVT::i1)
27952866
LoadVT = MVT::i8;
2796-
else if (Isv2f16Orv2bf16Type(EltVT))
2867+
else if (Isv2f16Orv2bf16Orv2i16Type(EltVT))
27972868
// getLoad needs a vector type, but it can't handle
27982869
// vectors which contain v2f16 or v2bf16 elements. So we must load
27992870
// using i32 here and then bitcast back.
@@ -2819,7 +2890,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
28192890
if (EltVT == MVT::i1)
28202891
Elt = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Elt);
28212892
// v2f16 was loaded as an i32. Now we must bitcast it back.
2822-
else if (Isv2f16Orv2bf16Type(EltVT))
2893+
else if (Isv2f16Orv2bf16Orv2i16Type(EltVT))
28232894
Elt = DAG.getNode(ISD::BITCAST, dl, EltVT, Elt);
28242895

28252896
// If a promoted integer type is used, truncate down to the original
@@ -5198,6 +5269,7 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
51985269
case MVT::v4f16:
51995270
case MVT::v4f32:
52005271
case MVT::v8f16: // <4 x f16x2>
5272+
case MVT::v8i16: // <4 x i16x2>
52015273
// This is a "native" vector type
52025274
break;
52035275
}
@@ -5250,11 +5322,16 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
52505322
// v8f16 is a special case. PTX doesn't have ld.v8.f16
52515323
// instruction. Instead, we split the vector into v2f16 chunks and
52525324
// load them with ld.v4.b32.
5253-
assert(Isf16Orbf16Type(EltVT.getSimpleVT()) &&
5254-
"Unsupported v8 vector type.");
5325+
assert(Is16bitsType(EltVT.getSimpleVT()) && "Unsupported v8 vector type.");
52555326
LoadF16x2 = true;
52565327
Opcode = NVPTXISD::LoadV4;
5257-
EVT VVT = (EltVT == MVT::f16) ? MVT::v2f16 : MVT::v2bf16;
5328+
EVT VVT;
5329+
if (EltVT == MVT::f16)
5330+
VVT = MVT::v2f16;
5331+
else if (EltVT == MVT::bf16)
5332+
VVT = MVT::v2bf16;
5333+
else if (EltVT == MVT::i16)
5334+
VVT = MVT::v2i16;
52585335
EVT ListVTs[] = {VVT, VVT, VVT, VVT, MVT::Other};
52595336
LdResVTs = DAG.getVTList(ListVTs);
52605337
break;

0 commit comments

Comments
 (0)