Skip to content

[NVPTX] Improve lowering of v4i8 #67866

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Oct 9, 2023
Merged
20 changes: 13 additions & 7 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,7 @@ pickOpcodeForVT(MVT::SimpleValueType VT, unsigned Opcode_i8,
case MVT::v2f16:
case MVT::v2bf16:
case MVT::v2i16:
case MVT::v4i8:
return Opcode_i32;
case MVT::f32:
return Opcode_f32;
Expand Down Expand Up @@ -910,7 +911,8 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
// Vector Setting
unsigned vecType = NVPTX::PTXLdStInstCode::Scalar;
if (SimpleVT.isVector()) {
assert(Isv2x16VT(LoadedVT) && "Unexpected vector type");
assert((Isv2x16VT(LoadedVT) || LoadedVT == MVT::v4i8) &&
"Unexpected vector type");
// v2f16/v2bf16/v2i16 is loaded using ld.b32
fromTypeWidth = 32;
}
Expand Down Expand Up @@ -1254,19 +1256,23 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
SDLoc DL(N);
SDNode *LD;
SDValue Base, Offset, Addr;
EVT OrigType = N->getValueType(0);

EVT EltVT = Mem->getMemoryVT();
unsigned NumElts = 1;
if (EltVT.isVector()) {
NumElts = EltVT.getVectorNumElements();
EltVT = EltVT.getVectorElementType();
// vectors of 16bits type are loaded/stored as multiples of v2x16 elements.
if ((EltVT == MVT::f16 && N->getValueType(0) == MVT::v2f16) ||
(EltVT == MVT::bf16 && N->getValueType(0) == MVT::v2bf16) ||
(EltVT == MVT::i16 && N->getValueType(0) == MVT::v2i16)) {
if ((EltVT == MVT::f16 && OrigType == MVT::v2f16) ||
(EltVT == MVT::bf16 && OrigType == MVT::v2bf16) ||
(EltVT == MVT::i16 && OrigType == MVT::v2i16)) {
assert(NumElts % 2 == 0 && "Vector must have even number of elements");
EltVT = N->getValueType(0);
EltVT = OrigType;
NumElts /= 2;
} else if (OrigType == MVT::v4i8) {
EltVT = OrigType;
NumElts = 1;
}
}

Expand Down Expand Up @@ -1601,7 +1607,6 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
// concept of sign-/zero-extension, so emulate it here by adding an explicit
// CVT instruction. Ptxas should clean up any redundancies here.

EVT OrigType = N->getValueType(0);
LoadSDNode *LdNode = dyn_cast<LoadSDNode>(N);

if (OrigType != EltVT &&
Expand Down Expand Up @@ -1679,7 +1684,8 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
MVT ScalarVT = SimpleVT.getScalarType();
unsigned toTypeWidth = ScalarVT.getSizeInBits();
if (SimpleVT.isVector()) {
assert(Isv2x16VT(StoreVT) && "Unexpected vector type");
assert((Isv2x16VT(StoreVT) || StoreVT == MVT::v4i8) &&
"Unexpected vector type");
// v2x16 is stored using st.b32
toTypeWidth = 32;
}
Expand Down
20 changes: 15 additions & 5 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,11 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
llvm_unreachable("Unexpected type");
}
NumElts /= 2;
} else if (EltVT.getSimpleVT() == MVT::i8 &&
(NumElts % 4 == 0 || NumElts == 3)) {
// v*i8 are formally lowered as v4i8
EltVT = MVT::v4i8;
NumElts = (NumElts + 3) / 4;
}
for (unsigned j = 0; j != NumElts; ++j) {
ValueVTs.push_back(EltVT);
Expand Down Expand Up @@ -458,6 +463,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass);
addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass);
addRegisterClass(MVT::v2i16, &NVPTX::Int32RegsRegClass);
addRegisterClass(MVT::v4i8, &NVPTX::Int32RegsRegClass);
addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass);
addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass);
Expand Down Expand Up @@ -2631,7 +2637,7 @@ SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
return expandUnalignedStore(Store, DAG);

// v2f16, v2bf16 and v2i16 don't need special handling.
if (Isv2x16VT(VT))
if (Isv2x16VT(VT) || VT == MVT::v4i8)
return SDValue();

if (VT.isVector())
Expand Down Expand Up @@ -2903,7 +2909,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
EVT LoadVT = EltVT;
if (EltVT == MVT::i1)
LoadVT = MVT::i8;
else if (Isv2x16VT(EltVT))
else if (Isv2x16VT(EltVT) || EltVT == MVT::v4i8)
// getLoad needs a vector type, but it can't handle
// vectors which contain v2f16 or v2bf16 elements. So we must load
// using i32 here and then bitcast back.
Expand All @@ -2929,7 +2935,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
if (EltVT == MVT::i1)
Elt = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Elt);
// v2f16 was loaded as an i32. Now we must bitcast it back.
else if (Isv2x16VT(EltVT))
else if (EltVT != LoadVT)
Elt = DAG.getNode(ISD::BITCAST, dl, EltVT, Elt);

// If a promoted integer type is used, truncate down to the original
Expand Down Expand Up @@ -5256,9 +5262,9 @@ static SDValue PerformEXTRACTCombine(SDNode *N,
SDValue Vector = N->getOperand(0);
EVT VectorVT = Vector.getValueType();
if (Vector->getOpcode() == ISD::LOAD && VectorVT.isSimple() &&
IsPTXVectorType(VectorVT.getSimpleVT()))
IsPTXVectorType(VectorVT.getSimpleVT()) && VectorVT != MVT::v4i8)
return SDValue(); // Native vector loads already combine nicely w/
// extract_vector_elt.
// extract_vector_elt, except for v4i8.
// Don't mess with singletons or v2*16 types, we already handle them OK.
if (VectorVT.getVectorNumElements() == 1 || Isv2x16VT(VectorVT))
return SDValue();
Expand Down Expand Up @@ -5289,6 +5295,10 @@ static SDValue PerformEXTRACTCombine(SDNode *N,
// If element has non-integer type, bitcast it back to the expected type.
if (EltVT != EltIVT)
Result = DCI.DAG.getNode(ISD::BITCAST, DL, EltVT, Result);
// Past legalizer, we may need to extent i8 -> i16 to match the register type.
if (EltVT != N->getValueType(0))
Result = DCI.DAG.getNode(ISD::ANY_EXTEND, DL, N->getValueType(0), Result);

return Result;
}

Expand Down
52 changes: 32 additions & 20 deletions llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -1486,23 +1486,24 @@ defm OR : BITWISE<"or", or>;
defm AND : BITWISE<"and", and>;
defm XOR : BITWISE<"xor", xor>;

// Lower logical v2i16 ops as bitwise ops on b32.
def: Pat<(or (v2i16 Int32Regs:$a), (v2i16 Int32Regs:$b)),
(ORb32rr Int32Regs:$a, Int32Regs:$b)>;
def: Pat<(xor (v2i16 Int32Regs:$a), (v2i16 Int32Regs:$b)),
(XORb32rr Int32Regs:$a, Int32Regs:$b)>;
def: Pat<(and (v2i16 Int32Regs:$a), (v2i16 Int32Regs:$b)),
(ANDb32rr Int32Regs:$a, Int32Regs:$b)>;

// The constants get legalized into a bitcast from i32, so that's what we need
// to match here.
def: Pat<(or Int32Regs:$a, (v2i16 (bitconvert (i32 imm:$b)))),
(ORb32ri Int32Regs:$a, imm:$b)>;
def: Pat<(xor Int32Regs:$a, (v2i16 (bitconvert (i32 imm:$b)))),
(XORb32ri Int32Regs:$a, imm:$b)>;
def: Pat<(and Int32Regs:$a, (v2i16 (bitconvert (i32 imm:$b)))),
(ANDb32ri Int32Regs:$a, imm:$b)>;

// Lower logical v2i16/v4i8 ops as bitwise ops on b32.
foreach vt = [v2i16, v4i8] in {
def: Pat<(or (vt Int32Regs:$a), (vt Int32Regs:$b)),
(ORb32rr Int32Regs:$a, Int32Regs:$b)>;
def: Pat<(xor (vt Int32Regs:$a), (vt Int32Regs:$b)),
(XORb32rr Int32Regs:$a, Int32Regs:$b)>;
def: Pat<(and (vt Int32Regs:$a), (vt Int32Regs:$b)),
(ANDb32rr Int32Regs:$a, Int32Regs:$b)>;

// The constants get legalized into a bitcast from i32, so that's what we need
// to match here.
def: Pat<(or Int32Regs:$a, (vt (bitconvert (i32 imm:$b)))),
(ORb32ri Int32Regs:$a, imm:$b)>;
def: Pat<(xor Int32Regs:$a, (vt (bitconvert (i32 imm:$b)))),
(XORb32ri Int32Regs:$a, imm:$b)>;
def: Pat<(and Int32Regs:$a, (vt (bitconvert (i32 imm:$b)))),
(ANDb32ri Int32Regs:$a, imm:$b)>;
}

def NOT1 : NVPTXInst<(outs Int1Regs:$dst), (ins Int1Regs:$src),
"not.pred \t$dst, $src;",
Expand Down Expand Up @@ -2682,7 +2683,7 @@ foreach vt = [f16, bf16] in {
def: Pat<(vt (ProxyReg vt:$src)), (ProxyRegI16 Int16Regs:$src)>;
}

foreach vt = [v2f16, v2bf16, v2i16] in {
foreach vt = [v2f16, v2bf16, v2i16, v4i8] in {
def: Pat<(vt (ProxyReg vt:$src)), (ProxyRegI32 Int32Regs:$src)>;
}

Expand Down Expand Up @@ -2995,8 +2996,8 @@ def: Pat<(i16 (bitconvert (vt Int16Regs:$a))),
(ProxyRegI16 Int16Regs:$a)>;
}

foreach ta = [v2f16, v2bf16, v2i16, i32] in {
foreach tb = [v2f16, v2bf16, v2i16, i32] in {
foreach ta = [v2f16, v2bf16, v2i16, v4i8, i32] in {
foreach tb = [v2f16, v2bf16, v2i16, v4i8, i32] in {
if !ne(ta, tb) then {
def: Pat<(ta (bitconvert (tb Int32Regs:$a))),
(ProxyRegI32 Int32Regs:$a)>;
Expand Down Expand Up @@ -3292,6 +3293,10 @@ let hasSideEffects = false in {
(ins Int16Regs:$s1, Int16Regs:$s2,
Int16Regs:$s3, Int16Regs:$s4),
"mov.b64 \t$d, {{$s1, $s2, $s3, $s4}};", []>;
def V4I8toI32 : NVPTXInst<(outs Int32Regs:$d),
(ins Int16Regs:$s1, Int16Regs:$s2,
Int16Regs:$s3, Int16Regs:$s4),
"mov.b32 \t$d, {{$s1, $s2, $s3, $s4}};", []>;
def V2I16toI32 : NVPTXInst<(outs Int32Regs:$d),
(ins Int16Regs:$s1, Int16Regs:$s2),
"mov.b32 \t$d, {{$s1, $s2}};", []>;
Expand All @@ -3307,6 +3312,10 @@ let hasSideEffects = false in {
Int16Regs:$d3, Int16Regs:$d4),
(ins Int64Regs:$s),
"mov.b64 \t{{$d1, $d2, $d3, $d4}}, $s;", []>;
def I32toV4I8 : NVPTXInst<(outs Int16Regs:$d1, Int16Regs:$d2,
Int16Regs:$d3, Int16Regs:$d4),
(ins Int32Regs:$s),
"mov.b32 \t{{$d1, $d2, $d3, $d4}}, $s;", []>;
def I32toV2I16 : NVPTXInst<(outs Int16Regs:$d1, Int16Regs:$d2),
(ins Int32Regs:$s),
"mov.b32 \t{{$d1, $d2}}, $s;", []>;
Expand Down Expand Up @@ -3354,6 +3363,9 @@ def : Pat<(v2bf16 (build_vector (bf16 Int16Regs:$a), (bf16 Int16Regs:$b))),
(V2I16toI32 Int16Regs:$a, Int16Regs:$b)>;
def : Pat<(v2i16 (build_vector (i16 Int16Regs:$a), (i16 Int16Regs:$b))),
(V2I16toI32 Int16Regs:$a, Int16Regs:$b)>;
def : Pat<(v4i8 (build_vector (i16 Int16Regs:$a), (i16 Int16Regs:$b),
(i16 Int16Regs:$c), (i16 Int16Regs:$d))),
(V4I8toI32 Int16Regs:$a, Int16Regs:$b, Int16Regs:$c, Int16Regs:$d)>;

// Count leading zeros
let hasSideEffects = false in {
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ foreach i = 0...31 in {
//===----------------------------------------------------------------------===//
def Int1Regs : NVPTXRegClass<[i1], 8, (add (sequence "P%u", 0, 4))>;
def Int16Regs : NVPTXRegClass<[i16, f16, bf16], 16, (add (sequence "RS%u", 0, 4))>;
def Int32Regs : NVPTXRegClass<[i32, v2f16, v2bf16, v2i16], 32,
def Int32Regs : NVPTXRegClass<[i32, v2f16, v2bf16, v2i16, v4i8], 32,
(add (sequence "R%u", 0, 4),
VRFrame32, VRFrameLocal32)>;
def Int64Regs : NVPTXRegClass<[i64], 64, (add (sequence "RL%u", 0, 4), VRFrame64, VRFrameLocal64)>;
Expand Down
4 changes: 2 additions & 2 deletions llvm/test/CodeGen/NVPTX/load-with-non-coherent-cache.ll
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,9 @@ define void @foo12(ptr noalias readonly %from, ptr %to) {
}

; SM20-LABEL: .visible .entry foo13(
; SM20: ld.global.v4.u8
; SM20: ld.global.u32
; SM35-LABEL: .visible .entry foo13(
; SM35: ld.global.nc.v4.u8
; SM35: ld.global.nc.u32
define void @foo13(ptr noalias readonly %from, ptr %to) {
%1 = load <4 x i8>, ptr %from
store <4 x i8> %1, ptr %to
Expand Down
26 changes: 12 additions & 14 deletions llvm/test/CodeGen/NVPTX/param-load-store.ll
Original file line number Diff line number Diff line change
Expand Up @@ -212,18 +212,16 @@ define signext i8 @test_i8s(i8 signext %a) {
; CHECK: .func (.param .align 4 .b8 func_retval0[4])
; CHECK-LABEL: test_v3i8(
; CHECK-NEXT: .param .align 4 .b8 test_v3i8_param_0[4]
; CHECK-DAG: ld.param.u8 [[E2:%rs[0-9]+]], [test_v3i8_param_0+2];
; CHECK-DAG: ld.param.v2.u8 {[[E0:%rs[0-9]+]], [[E1:%rs[0-9]+]]}, [test_v3i8_param_0];
; CHECK: ld.param.u32 [[R:%r[0-9]+]], [test_v3i8_param_0];
; CHECK: .param .align 4 .b8 param0[4];
; CHECK: st.param.v2.b8 [param0+0], {[[E0]], [[E1]]};
; CHECK: st.param.b8 [param0+2], [[E2]];
; CHECK: st.param.b32 [param0+0], [[R]]
; CHECK: .param .align 4 .b8 retval0[4];
; CHECK: call.uni (retval0),
; CHECK-NEXT: test_v3i8,
; CHECK-DAG: ld.param.v2.b8 {[[RE0:%rs[0-9]+]], [[RE1:%rs[0-9]+]]}, [retval0+0];
; CHECK-DAG: ld.param.b8 [[RE2:%rs[0-9]+]], [retval0+2];
; CHECK-DAG: st.param.v2.b8 [func_retval0+0], {[[RE0]], [[RE1]]};
; CHECK-DAG: st.param.b8 [func_retval0+2], [[RE2]];
; CHECK: ld.param.b32 [[RE:%r[0-9]+]], [retval0+0];
; v4i8/i32->{v3i8 elements}->v4i8/i32 conversion is messy and not very
; interesting here, so it's skipped.
; CHECK: st.param.b32 [func_retval0+0],
; CHECK-NEXT: ret;
define <3 x i8> @test_v3i8(<3 x i8> %a) {
%r = tail call <3 x i8> @test_v3i8(<3 x i8> %a);
Expand All @@ -233,14 +231,14 @@ define <3 x i8> @test_v3i8(<3 x i8> %a) {
; CHECK: .func (.param .align 4 .b8 func_retval0[4])
; CHECK-LABEL: test_v4i8(
; CHECK-NEXT: .param .align 4 .b8 test_v4i8_param_0[4]
; CHECK: ld.param.v4.u8 {[[E0:%rs[0-9]+]], [[E1:%rs[0-9]+]], [[E2:%rs[0-9]+]], [[E3:%rs[0-9]+]]}, [test_v4i8_param_0]
; CHECK: ld.param.u32 [[R:%r[0-9]+]], [test_v4i8_param_0]
; CHECK: .param .align 4 .b8 param0[4];
; CHECK: st.param.v4.b8 [param0+0], {[[E0]], [[E1]], [[E2]], [[E3]]};
; CHECK: st.param.b32 [param0+0], [[R]];
; CHECK: .param .align 4 .b8 retval0[4];
; CHECK: call.uni (retval0),
; CHECK-NEXT: test_v4i8,
; CHECK: ld.param.v4.b8 {[[RE0:%rs[0-9]+]], [[RE1:%rs[0-9]+]], [[RE2:%rs[0-9]+]], [[RE3:%rs[0-9]+]]}, [retval0+0];
; CHECK: st.param.v4.b8 [func_retval0+0], {[[RE0]], [[RE1]], [[RE2]], [[RE3]]}
; CHECK: ld.param.b32 [[RET:%r[0-9]+]], [retval0+0];
; CHECK: st.param.b32 [func_retval0+0], [[RET]];
; CHECK-NEXT: ret;
define <4 x i8> @test_v4i8(<4 x i8> %a) {
%r = tail call <4 x i8> @test_v4i8(<4 x i8> %a);
Expand All @@ -250,10 +248,10 @@ define <4 x i8> @test_v4i8(<4 x i8> %a) {
; CHECK: .func (.param .align 8 .b8 func_retval0[8])
; CHECK-LABEL: test_v5i8(
; CHECK-NEXT: .param .align 8 .b8 test_v5i8_param_0[8]
; CHECK-DAG: ld.param.u32 [[E0:%r[0-9]+]], [test_v5i8_param_0]
; CHECK-DAG: ld.param.u8 [[E4:%rs[0-9]+]], [test_v5i8_param_0+4];
; CHECK-DAG: ld.param.v4.u8 {[[E0:%rs[0-9]+]], [[E1:%rs[0-9]+]], [[E2:%rs[0-9]+]], [[E3:%rs[0-9]+]]}, [test_v5i8_param_0]
; CHECK: .param .align 8 .b8 param0[8];
; CHECK-DAG: st.param.v4.b8 [param0+0], {[[E0]], [[E1]], [[E2]], [[E3]]};
; CHECK-DAG: st.param.v4.b8 [param0+0],
; CHECK-DAG: st.param.b8 [param0+4], [[E4]];
; CHECK: .param .align 8 .b8 retval0[8];
; CHECK: call.uni (retval0),
Expand Down
Loading