Skip to content

Commit 109e4a1

Browse files
authored
[RISCV] Handle zeroinitializer of vector tuple Type (#113995)
It doesn't make sense to add a new generic ISD opcode just to handle the RISC-V vector tuple type. Instead, we use `SPLAT_VECTOR` as the ISD node and lower it further to `VMV_V_X`. Note: if there is a `visitSPLAT_VECTOR` in the generic DAG combiner, it needs to skip the RISC-V vector tuple type. Stacked on #114329.
1 parent 4f41862 commit 109e4a1

File tree

4 files changed

+79
-1
lines changed

4 files changed

+79
-1
lines changed

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1896,6 +1896,18 @@ SDValue SelectionDAGBuilder::getValueImpl(const Value *V) {
18961896
DAG.getConstant(0, getCurSDLoc(), MVT::nxv16i1));
18971897
}
18981898

1899+
if (VT.isRISCVVectorTuple()) {
1900+
assert(C->isNullValue() && "Can only zero this target type!");
1901+
return NodeMap[V] = DAG.getNode(
1902+
ISD::BITCAST, getCurSDLoc(), VT,
1903+
DAG.getNode(
1904+
ISD::SPLAT_VECTOR, getCurSDLoc(),
1905+
EVT::getVectorVT(*DAG.getContext(), MVT::i8,
1906+
VT.getSizeInBits().getKnownMinValue() / 8,
1907+
true),
1908+
DAG.getConstant(0, getCurSDLoc(), MVT::getIntegerVT(8))));
1909+
}
1910+
18991911
VectorType *VecTy = cast<VectorType>(V->getType());
19001912

19011913
// Now that we know the number and type of the elements, get that number of

llvm/lib/IR/Type.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -990,7 +990,7 @@ static TargetTypeInfo getTargetTypeInfo(const TargetExtType *Ty) {
990990
Ty->getIntParameter(0);
991991
return TargetTypeInfo(
992992
ScalableVectorType::get(Type::getInt8Ty(C), TotalNumElts),
993-
TargetExtType::CanBeLocal);
993+
TargetExtType::CanBeLocal, TargetExtType::HasZeroInit);
994994
}
995995

996996
// DirectX resources

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18060,6 +18060,20 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
1806018060
SDValue N0 = N->getOperand(0);
1806118061
EVT VT = N->getValueType(0);
1806218062
EVT SrcVT = N0.getValueType();
18063+
if (VT.isRISCVVectorTuple() && N0->getOpcode() == ISD::SPLAT_VECTOR) {
18064+
unsigned NF = VT.getRISCVVectorTupleNumFields();
18065+
unsigned NumScalElts = VT.getSizeInBits().getKnownMinValue() / (NF * 8);
18066+
SDValue EltVal = DAG.getConstant(0, DL, Subtarget.getXLenVT());
18067+
MVT ScalTy = MVT::getScalableVectorVT(MVT::getIntegerVT(8), NumScalElts);
18068+
18069+
SDValue Splat = DAG.getNode(ISD::SPLAT_VECTOR, DL, ScalTy, EltVal);
18070+
18071+
SDValue Result = DAG.getUNDEF(VT);
18072+
for (unsigned i = 0; i < NF; ++i)
18073+
Result = DAG.getNode(RISCVISD::TUPLE_INSERT, DL, VT, Result, Splat,
18074+
DAG.getVectorIdxConstant(i, DL));
18075+
return Result;
18076+
}
1806318077
// If this is a bitcast between a MVT::v4i1/v2i1/v1i1 and an illegal integer
1806418078
// type, widen both sides to avoid a trip through memory.
1806518079
if ((SrcVT == MVT::v1i1 || SrcVT == MVT::v2i1 || SrcVT == MVT::v4i1) &&
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc -mtriple=riscv32 -mattr=+v \
3+
; RUN: -verify-machineinstrs %s -o - | FileCheck %s --check-prefixes=CHECK
4+
; RUN: llc -mtriple=riscv64 -mattr=+v \
5+
; RUN: -verify-machineinstrs %s -o - | FileCheck %s --check-prefixes=CHECK
6+
7+
define target("riscv.vector.tuple", <vscale x 16 x i8>, 2) @test_tuple_zero_power_of_2() {
8+
; CHECK-LABEL: test_tuple_zero_power_of_2:
9+
; CHECK: # %bb.0: # %entry
10+
; CHECK-NEXT: vsetvli a0, zero, e8, m2, ta, ma
11+
; CHECK-NEXT: vmv.v.i v8, 0
12+
; CHECK-NEXT: vmv.v.i v10, 0
13+
; CHECK-NEXT: ret
14+
entry:
15+
ret target("riscv.vector.tuple", <vscale x 16 x i8>, 2) zeroinitializer
16+
}
17+
18+
define target("riscv.vector.tuple", <vscale x 16 x i8>, 3) @test_tuple_zero_non_power_of_2() {
19+
; CHECK-LABEL: test_tuple_zero_non_power_of_2:
20+
; CHECK: # %bb.0: # %entry
21+
; CHECK-NEXT: vsetvli a0, zero, e8, m2, ta, ma
22+
; CHECK-NEXT: vmv.v.i v8, 0
23+
; CHECK-NEXT: vmv.v.i v10, 0
24+
; CHECK-NEXT: vmv.v.i v12, 0
25+
; CHECK-NEXT: ret
26+
entry:
27+
ret target("riscv.vector.tuple", <vscale x 16 x i8>, 3) zeroinitializer
28+
}
29+
30+
define target("riscv.vector.tuple", <vscale x 16 x i8>, 2) @test_tuple_zero_insert1(<vscale x 4 x i32> %a) {
31+
; CHECK-LABEL: test_tuple_zero_insert1:
32+
; CHECK: # %bb.0: # %entry
33+
; CHECK-NEXT: vsetvli a0, zero, e8, m2, ta, ma
34+
; CHECK-NEXT: vmv.v.i v10, 0
35+
; CHECK-NEXT: ret
36+
entry:
37+
%1 = call target("riscv.vector.tuple", <vscale x 16 x i8>, 2) @llvm.riscv.tuple.insert.triscv.vector.tuple_nxv16i8_2t.nxv4i32(target("riscv.vector.tuple", <vscale x 16 x i8>, 2) zeroinitializer, <vscale x 4 x i32> %a, i32 0)
38+
ret target("riscv.vector.tuple", <vscale x 16 x i8>, 2) %1
39+
}
40+
41+
define target("riscv.vector.tuple", <vscale x 16 x i8>, 2) @test_tuple_zero_insert2(<vscale x 4 x i32> %a) {
42+
; CHECK-LABEL: test_tuple_zero_insert2:
43+
; CHECK: # %bb.0: # %entry
44+
; CHECK-NEXT: vsetvli a0, zero, e8, m2, ta, ma
45+
; CHECK-NEXT: vmv.v.i v6, 0
46+
; CHECK-NEXT: vmv2r.v v10, v8
47+
; CHECK-NEXT: vmv2r.v v8, v6
48+
; CHECK-NEXT: ret
49+
entry:
50+
%1 = call target("riscv.vector.tuple", <vscale x 16 x i8>, 2) @llvm.riscv.tuple.insert.triscv.vector.tuple_nxv16i8_2t.nxv4i32(target("riscv.vector.tuple", <vscale x 16 x i8>, 2) zeroinitializer, <vscale x 4 x i32> %a, i32 1)
51+
ret target("riscv.vector.tuple", <vscale x 16 x i8>, 2) %1
52+
}

0 commit comments

Comments
 (0)