Skip to content

Commit 1a4acc7

Browse files
committed
Address review comments
1 parent 9520bf8 commit 1a4acc7

File tree

6 files changed

+319
-318
lines changed

6 files changed

+319
-318
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1681,10 +1681,8 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
16811681
MVT ScalarVT = SimpleVT.getScalarType();
16821682
unsigned toTypeWidth = ScalarVT.getSizeInBits();
16831683
if (SimpleVT.isVector()) {
1684-
assert((StoreVT == MVT::v2f16 || StoreVT == MVT::v2bf16 ||
1685-
StoreVT == MVT::v2i16) &&
1686-
"Unexpected vector type");
1687-
// v2f16 is stored using st.b32
1684+
assert(Isv2x16VT(StoreVT) && "Unexpected vector type");
1685+
// v2x16 is stored using st.b32
16881686
toTypeWidth = 32;
16891687
}
16901688

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -150,10 +150,6 @@ static bool IsPTXVectorType(MVT VT) {
150150
}
151151
}
152152

153-
static bool Isv2f16Orv2bf16Orv2i16Type(EVT VT) {
154-
return (VT == MVT::v2f16 || VT == MVT::v2bf16 || VT == MVT::v2i16);
155-
}
156-
157153
static bool Is16bitsType(MVT VT) {
158154
return (VT.SimpleTy == MVT::f16 || VT.SimpleTy == MVT::bf16 ||
159155
VT.SimpleTy == MVT::i16);
@@ -1372,7 +1368,7 @@ NVPTXTargetLowering::getPreferredVectorAction(MVT VT) const {
13721368
if (!VT.isScalableVector() && VT.getVectorNumElements() != 1 &&
13731369
VT.getScalarType() == MVT::i1)
13741370
return TypeSplitVector;
1375-
if (Isv2f16Orv2bf16Orv2i16Type(VT))
1371+
if (Isv2x16VT(VT))
13761372
return TypeLegal;
13771373
return TargetLoweringBase::getPreferredVectorAction(VT);
13781374
}
@@ -2153,7 +2149,7 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
21532149
SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
21542150
SelectionDAG &DAG) const {
21552151
EVT VT = Op->getValueType(0);
2156-
if (!(Isv2f16Orv2bf16Orv2i16Type(VT)))
2152+
if (!(Isv2x16VT(VT)))
21572153
return Op;
21582154
APInt E0;
21592155
APInt E1;
@@ -2192,8 +2188,7 @@ SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
21922188
// Extract individual elements and select one of them.
21932189
SDValue Vector = Op->getOperand(0);
21942190
EVT VectorVT = Vector.getValueType();
2195-
assert((VectorVT == MVT::v2f16 || VectorVT == MVT::v2i16) &&
2196-
"Unexpected vector type.");
2191+
assert(Isv2x16VT(VectorVT) && "Unexpected vector type.");
21972192
EVT EltVT = VectorVT.getVectorElementType();
21982193

21992194
SDLoc dl(Op.getNode());
@@ -2571,9 +2566,9 @@ SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
25712566
if (Op.getValueType() == MVT::i1)
25722567
return LowerLOADi1(Op, DAG);
25732568

2574-
// v2f16 is legal, so we can't rely on legalizer to handle unaligned
2575-
// loads and have to handle it here.
2576-
if (Isv2f16Orv2bf16Orv2i16Type(Op.getValueType())) {
2569+
// v2f16/v2bf16/v2i16 are legal, so we can't rely on legalizer to handle
2570+
// unaligned loads and have to handle it here.
2571+
if (Isv2x16VT(Op.getValueType())) {
25772572
LoadSDNode *Load = cast<LoadSDNode>(Op);
25782573
EVT MemVT = Load->getMemoryVT();
25792574
if (!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
@@ -2618,13 +2613,13 @@ SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
26182613

26192614
// v2f16 is legal, so we can't rely on legalizer to handle unaligned
26202615
// stores and have to handle it here.
2621-
if (Isv2f16Orv2bf16Orv2i16Type(VT) &&
2616+
if (Isv2x16VT(VT) &&
26222617
!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
26232618
VT, *Store->getMemOperand()))
26242619
return expandUnalignedStore(Store, DAG);
26252620

26262621
// v2f16, v2bf16 and v2i16 don't need special handling.
2627-
if (VT == MVT::v2f16 || VT == MVT::v2bf16 || VT == MVT::v2i16)
2622+
if (Isv2x16VT(VT))
26282623
return SDValue();
26292624

26302625
if (VT.isVector())
@@ -2896,7 +2891,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
28962891
EVT LoadVT = EltVT;
28972892
if (EltVT == MVT::i1)
28982893
LoadVT = MVT::i8;
2899-
else if (Isv2f16Orv2bf16Orv2i16Type(EltVT))
2894+
else if (Isv2x16VT(EltVT))
29002895
// getLoad needs a vector type, but it can't handle
29012896
// vectors which contain v2f16 or v2bf16 elements. So we must load
29022897
// using i32 here and then bitcast back.
@@ -2922,7 +2917,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
29222917
if (EltVT == MVT::i1)
29232918
Elt = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Elt);
29242919
// v2f16 was loaded as an i32. Now we must bitcast it back.
2925-
else if (Isv2f16Orv2bf16Orv2i16Type(EltVT))
2920+
else if (Isv2x16VT(EltVT))
29262921
Elt = DAG.getNode(ISD::BITCAST, dl, EltVT, Elt);
29272922

29282923
// If a promoted integer type is used, truncate down to the original
@@ -5335,7 +5330,7 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
53355330

53365331
unsigned Opcode = 0;
53375332
SDVTList LdResVTs;
5338-
bool LoadF16x2 = false;
5333+
bool Load16x2 = false;
53395334

53405335
switch (NumElts) {
53415336
default:
@@ -5355,7 +5350,7 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
53555350
// instruction. Instead, we split the vector into v2f16 chunks and
53565351
// load them with ld.v4.b32.
53575352
assert(Is16bitsType(EltVT.getSimpleVT()) && "Unsupported v8 vector type.");
5358-
LoadF16x2 = true;
5353+
Load16x2 = true;
53595354
Opcode = NVPTXISD::LoadV4;
53605355
EVT VVT;
53615356
if (EltVT == MVT::f16)
@@ -5382,7 +5377,7 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
53825377
LD->getMemOperand());
53835378

53845379
SmallVector<SDValue, 8> ScalarRes;
5385-
if (LoadF16x2) {
5380+
if (Load16x2) {
53865381
// Split v2f16 subvectors back into individual elements.
53875382
NumElts /= 2;
53885383
for (unsigned i = 0; i < NumElts; ++i) {

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,7 @@ class NVPTXTargetLowering : public TargetLowering {
617617
Align getArgumentAlignment(SDValue Callee, const CallBase *CB, Type *Ty,
618618
unsigned Idx, const DataLayout &DL) const;
619619
};
620+
620621
} // namespace llvm
621622

622623
#endif

llvm/lib/Target/NVPTX/NVPTXUtilities.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,4 +348,8 @@ bool shouldEmitPTXNoReturn(const Value *V, const TargetMachine &TM) {
348348
!isKernelFunction(*F);
349349
}
350350

351+
bool Isv2x16VT(EVT VT) {
352+
return (VT == MVT::v2f16 || VT == MVT::v2bf16 || VT == MVT::v2i16);
353+
}
354+
351355
} // namespace llvm

llvm/lib/Target/NVPTX/NVPTXUtilities.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#ifndef LLVM_LIB_TARGET_NVPTX_NVPTXUTILITIES_H
1414
#define LLVM_LIB_TARGET_NVPTX_NVPTXUTILITIES_H
1515

16+
#include "llvm/CodeGen/ValueTypes.h"
1617
#include "llvm/IR/Function.h"
1718
#include "llvm/IR/GlobalVariable.h"
1819
#include "llvm/IR/IntrinsicInst.h"
@@ -74,6 +75,8 @@ inline unsigned promoteScalarArgumentSize(unsigned size) {
7475
}
7576

7677
bool shouldEmitPTXNoReturn(const Value *V, const TargetMachine &TM);
78+
79+
bool Isv2x16VT(EVT VT);
7780
}
7881

7982
#endif

0 commit comments

Comments
 (0)