Skip to content

Commit 1ac489c

Browse files
authored
[RISCV] Initial codegen support for zvqdotq extension (#137039)
This patch adds pattern matching for the basic usages of the dot product instructions introduced by the experimental zvqdotq extension. It specifically only handles the case where the pattern is feeding an i32 sum reduction, as we need to reassociate the reduction tree to use these instructions. The vecreduce_add (sext) and vecreduce_add (zext) cases are included mostly to exercise the VX matchers. For the generic matching, we currently fail to match due to a combine-ordering issue which results in the bitcast being separated from the splat. I chose to do this lowering as an early combine so as to avoid having to integrate the entire logic into the reduction lowering flow. In particular, that would get a lot more complicated as we extend this to handle add-trees feeding the reductions.
1 parent b9d6cbd commit 1ac489c

File tree

4 files changed

+310
-68
lines changed

4 files changed

+310
-68
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 121 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6971,7 +6971,7 @@ static bool hasPassthruOp(unsigned Opcode) {
69716971
Opcode <= RISCVISD::LAST_STRICTFP_OPCODE &&
69726972
"not a RISC-V target specific op");
69736973
static_assert(
6974-
RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 134 &&
6974+
RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 139 &&
69756975
RISCVISD::LAST_STRICTFP_OPCODE - RISCVISD::FIRST_STRICTFP_OPCODE == 21 &&
69766976
"adding target specific op should update this function");
69776977
if (Opcode >= RISCVISD::ADD_VL && Opcode <= RISCVISD::VFMAX_VL)
@@ -6995,7 +6995,7 @@ static bool hasMaskOp(unsigned Opcode) {
69956995
Opcode <= RISCVISD::LAST_STRICTFP_OPCODE &&
69966996
"not a RISC-V target specific op");
69976997
static_assert(
6998-
RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 134 &&
6998+
RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 139 &&
69996999
RISCVISD::LAST_STRICTFP_OPCODE - RISCVISD::FIRST_STRICTFP_OPCODE == 21 &&
70007000
"adding target specific op should update this function");
70017001
if (Opcode >= RISCVISD::TRUNCATE_VECTOR_VL && Opcode <= RISCVISD::SETCC_VL)
@@ -18101,6 +18101,118 @@ static SDValue performBUILD_VECTORCombine(SDNode *N, SelectionDAG &DAG,
1810118101
DAG.getBuildVector(VT, DL, RHSOps));
1810218102
}
1810318103

18104+
/// Build a VQDOT_VL / VQDOTU_VL / VQDOTSU_VL node for the fixed-length i32
/// vectors \p Op0 and \p Op1 and return the result as a fixed-length vector of
/// the same type.
///
/// The operands are moved into the scalable container type for the fixed
/// type, the node is created with an all-zero passthru, the default mask and
/// VL for the fixed type, and a tail-agnostic/mask-agnostic policy operand,
/// and the result is converted back to the fixed-length type.
static SDValue lowerVQDOT(unsigned Opc, SDValue Op0, SDValue Op1,
                          const SDLoc &DL, SelectionDAG &DAG,
                          const RISCVSubtarget &Subtarget) {
  assert(Opc == RISCVISD::VQDOT_VL || Opc == RISCVISD::VQDOTU_VL ||
         Opc == RISCVISD::VQDOTSU_VL);

  const MVT FixedVT = Op0.getSimpleValueType();
  assert(FixedVT == Op1.getSimpleValueType() &&
         FixedVT.getVectorElementType() == MVT::i32);
  assert(FixedVT.isFixedLengthVector());

  const MVT ScalableVT =
      getContainerForFixedLengthVector(DAG, FixedVT, Subtarget);

  // Zero passthru first, then the two sources, all in the container type.
  SDValue ZeroVec = DAG.getConstant(0, DL, FixedVT);
  SDValue Passthru =
      convertToScalableVector(ScalableVT, ZeroVec, DAG, Subtarget);
  SDValue LHS = convertToScalableVector(ScalableVT, Op0, DAG, Subtarget);
  SDValue RHS = convertToScalableVector(ScalableVT, Op1, DAG, Subtarget);

  auto [Mask, VL] = getDefaultVLOps(FixedVT, ScalableVT, DL, DAG, Subtarget);

  const unsigned PolicyBits =
      RISCVVType::TAIL_AGNOSTIC | RISCVVType::MASK_AGNOSTIC;
  SDValue PolicyOp =
      DAG.getTargetConstant(PolicyBits, DL, Subtarget.getXLenVT());

  SDValue Dot = DAG.getNode(Opc, DL, ScalableVT,
                            {LHS, RHS, Passthru, Mask, VL, PolicyOp});
  return convertFromScalableVector(FixedVT, Dot, DAG, Subtarget);
}
18127+
18128+
/// Map an i8 vector type to the i32 vector type with a quarter as many
/// elements, i.e. the type obtained by viewing each group of four i8 elements
/// as one i32 element. \p OpVT must be an i8 vector whose element count is a
/// known multiple of four.
static MVT getQDOTXResultType(MVT OpVT) {
  const ElementCount ByteCount = OpVT.getVectorElementCount();
  assert(ByteCount.isKnownMultipleOf(4) &&
         OpVT.getVectorElementType() == MVT::i8);

  const ElementCount WordCount = ByteCount.divideCoefficientBy(4);
  return MVT::getVectorVT(MVT::i32, WordCount);
}
18133+
18134+
/// Try to rewrite \p InVec, the operand of an i32 add reduction, as a
/// VQDOT*_VL node computed directly from the un-extended i8 sources.
///
/// Recognized forms of \p InVec (i32 elements, fixed-length, element count a
/// multiple of four, extended from legal i8 vectors):
///   sext a / zext a          -> vqdot / vqdotu of a with a splat of i8 ones
///   mul (sext a), (sext b)   -> vqdot
///   mul (zext a), (zext b)   -> vqdotu
///   mul (sext a), (zext b)   -> vqdotsu (operands swapped if necessary)
///
/// \returns a fixed-length i32 vector with a quarter of the elements of
/// \p InVec that the caller reduces in place of \p InVec, or an empty SDValue
/// if nothing matched.
static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
                                         SelectionDAG &DAG,
                                         const RISCVSubtarget &Subtarget,
                                         const RISCVTargetLowering &TLI) {
  // Note: We intentionally do not check the legality of the reduction type.
  // We want to handle the m4/m8 *src* types, and thus need to let illegal
  // intermediate types flow through here.
  EVT InVT = InVec.getValueType();
  if (InVT.getVectorElementType() != MVT::i32 ||
      !InVT.getVectorElementCount().isKnownMultipleOf(4))
    return SDValue();

  // lowerVQDOT only handles fixed-length vectors (it asserts on anything
  // else), so scalable reductions must not be matched here.
  if (!InVT.isFixedLengthVector())
    return SDValue();

  // reduce (zext a) <--> reduce (mul zext a, zext 1)
  // reduce (sext a) <--> reduce (mul sext a, sext 1)
  if (InVec.getOpcode() == ISD::ZERO_EXTEND ||
      InVec.getOpcode() == ISD::SIGN_EXTEND) {
    SDValue A = InVec.getOperand(0);
    if (A.getValueType().getVectorElementType() != MVT::i8 ||
        !TLI.isTypeLegal(A.getValueType()))
      return SDValue();

    MVT ResVT = getQDOTXResultType(A.getSimpleValueType());
    A = DAG.getBitcast(ResVT, A);
    // 0x01010101 is four i8 ones packed into each i32 element.
    SDValue B = DAG.getConstant(0x01010101, DL, ResVT);

    bool IsSigned = InVec.getOpcode() == ISD::SIGN_EXTEND;
    unsigned Opc = IsSigned ? RISCVISD::VQDOT_VL : RISCVISD::VQDOTU_VL;
    return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
  }

  // mul (sext, sext) -> vqdot
  // mul (zext, zext) -> vqdotu
  // mul (sext, zext) -> vqdotsu
  // mul (zext, sext) -> vqdotsu (swapped)
  // TODO: Improve .vx handling - we end up with a sub-vector insert
  // which confuses the splat pattern matching.  Also, match vqdotus.vx
  if (InVec.getOpcode() != ISD::MUL)
    return SDValue();

  SDValue A = InVec.getOperand(0);
  SDValue B = InVec.getOperand(1);
  unsigned Opc = 0;
  if (A.getOpcode() == B.getOpcode()) {
    if (A.getOpcode() == ISD::SIGN_EXTEND)
      Opc = RISCVISD::VQDOT_VL;
    else if (A.getOpcode() == ISD::ZERO_EXTEND)
      Opc = RISCVISD::VQDOTU_VL;
    else
      return SDValue();
  } else {
    // Canonicalize the mixed case to (sext, zext); anything else is rejected.
    if (B.getOpcode() != ISD::ZERO_EXTEND)
      std::swap(A, B);
    if (A.getOpcode() != ISD::SIGN_EXTEND || B.getOpcode() != ISD::ZERO_EXTEND)
      return SDValue();
    Opc = RISCVISD::VQDOTSU_VL;
  }
  assert(Opc);

  // Both A and B are extends at this point. As in the plain extend case above,
  // it is the i8 *source* type that has to be legal, not the (possibly
  // illegal) extended type; this also guarantees the source type is simple
  // before getSimpleValueType() is called below.
  EVT SrcVT = A.getOperand(0).getValueType();
  if (SrcVT.getVectorElementType() != MVT::i8 ||
      SrcVT != B.getOperand(0).getValueType() || !TLI.isTypeLegal(SrcVT))
    return SDValue();

  MVT ResVT = getQDOTXResultType(SrcVT.getSimpleVT());
  A = DAG.getBitcast(ResVT, A.getOperand(0));
  B = DAG.getBitcast(ResVT, B.getOperand(0));
  return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
}
18201+
18202+
/// Combine for a vector add reduction \p N when Zvqdotq is available: if the
/// reduction operand can be rewritten by foldReduceOperandViaVQDOT, return a
/// new VECREDUCE_ADD of the rewritten operand with the same result type.
/// Returns an empty SDValue otherwise.
static SDValue performVECREDUCECombine(SDNode *N, SelectionDAG &DAG,
                                       const RISCVSubtarget &Subtarget,
                                       const RISCVTargetLowering &TLI) {
  if (Subtarget.hasStdExtZvqdotq()) {
    SDLoc DL(N);
    EVT ResultVT = N->getValueType(0);
    SDValue Folded =
        foldReduceOperandViaVQDOT(N->getOperand(0), DL, DAG, Subtarget, TLI);
    if (Folded)
      return DAG.getNode(ISD::VECREDUCE_ADD, DL, ResultVT, Folded);
  }
  return SDValue();
}
18215+
1810418216
static SDValue performINSERT_VECTOR_ELTCombine(SDNode *N, SelectionDAG &DAG,
1810518217
const RISCVSubtarget &Subtarget,
1810618218
const RISCVTargetLowering &TLI) {
@@ -19878,8 +19990,11 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
1987819990

1987919991
return SDValue();
1988019992
}
19881-
case ISD::CTPOP:
1988219993
case ISD::VECREDUCE_ADD:
19994+
if (SDValue V = performVECREDUCECombine(N, DAG, Subtarget, *this))
19995+
return V;
19996+
[[fallthrough]];
19997+
case ISD::CTPOP:
1988319998
if (SDValue V = combineToVCPOP(N, DAG, Subtarget))
1988419999
return V;
1988520000
break;
@@ -22401,6 +22516,9 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
2240122516
NODE_NAME_CASE(RI_VUNZIP2A_VL)
2240222517
NODE_NAME_CASE(RI_VUNZIP2B_VL)
2240322518
NODE_NAME_CASE(RI_VEXTRACT)
22519+
NODE_NAME_CASE(VQDOT_VL)
22520+
NODE_NAME_CASE(VQDOTU_VL)
22521+
NODE_NAME_CASE(VQDOTSU_VL)
2240422522
NODE_NAME_CASE(READ_CSR)
2240522523
NODE_NAME_CASE(WRITE_CSR)
2240622524
NODE_NAME_CASE(SWAP_CSR)

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,12 @@ enum NodeType : unsigned {
416416
RI_VUNZIP2A_VL,
417417
RI_VUNZIP2B_VL,
418418

419-
LAST_VL_VECTOR_OP = RI_VUNZIP2B_VL,
419+
// Zvqdotq dot product instructions with additional passthru, mask and VL operands
420+
VQDOT_VL,
421+
VQDOTU_VL,
422+
VQDOTSU_VL,
423+
424+
LAST_VL_VECTOR_OP = VQDOTSU_VL,
420425

421426
// XRivosVisni
422427
// VEXTRACT matches the semantics of ri.vextract.x.v. The result is always

llvm/lib/Target/RISCV/RISCVInstrInfoZvqdotq.td

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,34 @@ let Predicates = [HasStdExtZvqdotq] in {
2626
def VQDOTSU_VX : VALUVX<0b101010, OPMVX, "vqdotsu.vx">;
2727
def VQDOTUS_VX : VALUVX<0b101110, OPMVX, "vqdotus.vx">;
2828
} // Predicates = [HasStdExtZvqdotq]
29+
30+
31+
// SelectionDAG node definitions for the RISCVISD::VQDOT_VL, VQDOTU_VL and
// VQDOTSU_VL opcodes. All three share the SDT_RISCVIntBinOp_VL type profile.
def riscv_vqdot_vl : SDNode<"RISCVISD::VQDOT_VL", SDT_RISCVIntBinOp_VL>;
32+
def riscv_vqdotu_vl : SDNode<"RISCVISD::VQDOTU_VL", SDT_RISCVIntBinOp_VL>;
33+
def riscv_vqdotsu_vl : SDNode<"RISCVISD::VQDOTSU_VL", SDT_RISCVIntBinOp_VL>;
34+
35+
// Pseudo-instruction multiclass shared by the vqdot family. For every entry m
// of MxSet<32>.m (presumably the LMULs valid at SEW=32 - confirm against the
// MxSet definition) it instantiates a VPseudoBinaryV_VV and a
// VPseudoBinaryV_VX, scheduled with the WriteVIALU*/ReadVIALU* classes and
// with forcePassthruRead set.
multiclass VPseudoVQDOT_VV_VX {
36+
foreach m = MxSet<32>.m in {
37+
defm "" : VPseudoBinaryV_VV<m>,
38+
SchedBinary<"WriteVIALUV", "ReadVIALUV", "ReadVIALUV", m.MX,
39+
forcePassthruRead=true>;
40+
defm "" : VPseudoBinaryV_VX<m>,
41+
SchedBinary<"WriteVIALUX", "ReadVIALUV", "ReadVIALUX", m.MX,
42+
forcePassthruRead=true>;
43+
}
44+
}
45+
46+
// TODO: Add pseudo and patterns for vqdotus.vx
47+
// TODO: Add isCommutable for VQDOT and VQDOTU
48+
// Instantiate the pseudos for vqdot, vqdotu and vqdotsu. They are only
// available with HasStdExtZvqdotq and are marked as neither loading, storing
// nor having side effects.
let Predicates = [HasStdExtZvqdotq], mayLoad = 0, mayStore = 0,
49+
hasSideEffects = 0 in {
50+
defm PseudoVQDOT : VPseudoVQDOT_VV_VX;
51+
defm PseudoVQDOTU : VPseudoVQDOT_VV_VX;
52+
defm PseudoVQDOTSU : VPseudoVQDOT_VV_VX;
53+
}
54+
55+
// Selection patterns mapping each riscv_vqdot*_vl node onto the matching
// pseudo, instantiated for the i32 vector types listed in AllE32Vectors
// (VI32MF2 through VI32M8).
defvar AllE32Vectors = [VI32MF2, VI32M1, VI32M2, VI32M4, VI32M8];
56+
defm : VPatBinaryVL_VV_VX<riscv_vqdot_vl, "PseudoVQDOT", AllE32Vectors>;
57+
defm : VPatBinaryVL_VV_VX<riscv_vqdotu_vl, "PseudoVQDOTU", AllE32Vectors>;
58+
defm : VPatBinaryVL_VV_VX<riscv_vqdotsu_vl, "PseudoVQDOTSU", AllE32Vectors>;
59+

0 commit comments

Comments
 (0)