@@ -6971,7 +6971,7 @@ static bool hasPassthruOp(unsigned Opcode) {
6971
6971
Opcode <= RISCVISD::LAST_STRICTFP_OPCODE &&
6972
6972
"not a RISC-V target specific op");
6973
6973
static_assert(
6974
- RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 134 &&
6974
+ RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 139 &&
6975
6975
RISCVISD::LAST_STRICTFP_OPCODE - RISCVISD::FIRST_STRICTFP_OPCODE == 21 &&
6976
6976
"adding target specific op should update this function");
6977
6977
if (Opcode >= RISCVISD::ADD_VL && Opcode <= RISCVISD::VFMAX_VL)
@@ -6995,7 +6995,7 @@ static bool hasMaskOp(unsigned Opcode) {
6995
6995
Opcode <= RISCVISD::LAST_STRICTFP_OPCODE &&
6996
6996
"not a RISC-V target specific op");
6997
6997
static_assert(
6998
- RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 134 &&
6998
+ RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 139 &&
6999
6999
RISCVISD::LAST_STRICTFP_OPCODE - RISCVISD::FIRST_STRICTFP_OPCODE == 21 &&
7000
7000
"adding target specific op should update this function");
7001
7001
if (Opcode >= RISCVISD::TRUNCATE_VECTOR_VL && Opcode <= RISCVISD::SETCC_VL)
@@ -18101,6 +18101,118 @@ static SDValue performBUILD_VECTORCombine(SDNode *N, SelectionDAG &DAG,
18101
18101
DAG.getBuildVector(VT, DL, RHSOps));
18102
18102
}
18103
18103
18104
+ static SDValue lowerVQDOT(unsigned Opc, SDValue Op0, SDValue Op1,
18105
+ const SDLoc &DL, SelectionDAG &DAG,
18106
+ const RISCVSubtarget &Subtarget) {
18107
+ assert(RISCVISD::VQDOT_VL == Opc || RISCVISD::VQDOTU_VL == Opc ||
18108
+ RISCVISD::VQDOTSU_VL == Opc);
18109
+ MVT VT = Op0.getSimpleValueType();
18110
+ assert(VT == Op1.getSimpleValueType() &&
18111
+ VT.getVectorElementType() == MVT::i32);
18112
+
18113
+ assert(VT.isFixedLengthVector());
18114
+ MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
18115
+ SDValue Passthru = convertToScalableVector(
18116
+ ContainerVT, DAG.getConstant(0, DL, VT), DAG, Subtarget);
18117
+ Op0 = convertToScalableVector(ContainerVT, Op0, DAG, Subtarget);
18118
+ Op1 = convertToScalableVector(ContainerVT, Op1, DAG, Subtarget);
18119
+
18120
+ auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
18121
+ const unsigned Policy = RISCVVType::TAIL_AGNOSTIC | RISCVVType::MASK_AGNOSTIC;
18122
+ SDValue PolicyOp = DAG.getTargetConstant(Policy, DL, Subtarget.getXLenVT());
18123
+ SDValue LocalAccum = DAG.getNode(Opc, DL, ContainerVT,
18124
+ {Op0, Op1, Passthru, Mask, VL, PolicyOp});
18125
+ return convertFromScalableVector(VT, LocalAccum, DAG, Subtarget);
18126
+ }
18127
+
18128
+ static MVT getQDOTXResultType(MVT OpVT) {
18129
+ ElementCount OpEC = OpVT.getVectorElementCount();
18130
+ assert(OpEC.isKnownMultipleOf(4) && OpVT.getVectorElementType() == MVT::i8);
18131
+ return MVT::getVectorVT(MVT::i32, OpEC.divideCoefficientBy(4));
18132
+ }
18133
+
18134
+ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
18135
+ SelectionDAG &DAG,
18136
+ const RISCVSubtarget &Subtarget,
18137
+ const RISCVTargetLowering &TLI) {
18138
+ // Note: We intentionally do not check the legality of the reduction type.
18139
+ // We want to handle the m4/m8 *src* types, and thus need to let illegal
18140
+ // intermediate types flow through here.
18141
+ if (InVec.getValueType().getVectorElementType() != MVT::i32 ||
18142
+ !InVec.getValueType().getVectorElementCount().isKnownMultipleOf(4))
18143
+ return SDValue();
18144
+
18145
+ // reduce (zext a) <--> reduce (mul zext a. zext 1)
18146
+ // reduce (sext a) <--> reduce (mul sext a. sext 1)
18147
+ if (InVec.getOpcode() == ISD::ZERO_EXTEND ||
18148
+ InVec.getOpcode() == ISD::SIGN_EXTEND) {
18149
+ SDValue A = InVec.getOperand(0);
18150
+ if (A.getValueType().getVectorElementType() != MVT::i8 ||
18151
+ !TLI.isTypeLegal(A.getValueType()))
18152
+ return SDValue();
18153
+
18154
+ MVT ResVT = getQDOTXResultType(A.getSimpleValueType());
18155
+ A = DAG.getBitcast(ResVT, A);
18156
+ SDValue B = DAG.getConstant(0x01010101, DL, ResVT);
18157
+
18158
+ bool IsSigned = InVec.getOpcode() == ISD::SIGN_EXTEND;
18159
+ unsigned Opc = IsSigned ? RISCVISD::VQDOT_VL : RISCVISD::VQDOTU_VL;
18160
+ return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
18161
+ }
18162
+
18163
+ // mul (sext, sext) -> vqdot
18164
+ // mul (zext, zext) -> vqdotu
18165
+ // mul (sext, zext) -> vqdotsu
18166
+ // mul (zext, sext) -> vqdotsu (swapped)
18167
+ // TODO: Improve .vx handling - we end up with a sub-vector insert
18168
+ // which confuses the splat pattern matching. Also, match vqdotus.vx
18169
+ if (InVec.getOpcode() != ISD::MUL)
18170
+ return SDValue();
18171
+
18172
+ SDValue A = InVec.getOperand(0);
18173
+ SDValue B = InVec.getOperand(1);
18174
+ unsigned Opc = 0;
18175
+ if (A.getOpcode() == B.getOpcode()) {
18176
+ if (A.getOpcode() == ISD::SIGN_EXTEND)
18177
+ Opc = RISCVISD::VQDOT_VL;
18178
+ else if (A.getOpcode() == ISD::ZERO_EXTEND)
18179
+ Opc = RISCVISD::VQDOTU_VL;
18180
+ else
18181
+ return SDValue();
18182
+ } else {
18183
+ if (B.getOpcode() != ISD::ZERO_EXTEND)
18184
+ std::swap(A, B);
18185
+ if (A.getOpcode() != ISD::SIGN_EXTEND || B.getOpcode() != ISD::ZERO_EXTEND)
18186
+ return SDValue();
18187
+ Opc = RISCVISD::VQDOTSU_VL;
18188
+ }
18189
+ assert(Opc);
18190
+
18191
+ if (A.getOperand(0).getValueType().getVectorElementType() != MVT::i8 ||
18192
+ A.getOperand(0).getValueType() != B.getOperand(0).getValueType() ||
18193
+ !TLI.isTypeLegal(A.getValueType()))
18194
+ return SDValue();
18195
+
18196
+ MVT ResVT = getQDOTXResultType(A.getOperand(0).getSimpleValueType());
18197
+ A = DAG.getBitcast(ResVT, A.getOperand(0));
18198
+ B = DAG.getBitcast(ResVT, B.getOperand(0));
18199
+ return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
18200
+ }
18201
+
18202
+ static SDValue performVECREDUCECombine(SDNode *N, SelectionDAG &DAG,
18203
+ const RISCVSubtarget &Subtarget,
18204
+ const RISCVTargetLowering &TLI) {
18205
+ if (!Subtarget.hasStdExtZvqdotq())
18206
+ return SDValue();
18207
+
18208
+ SDLoc DL(N);
18209
+ EVT VT = N->getValueType(0);
18210
+ SDValue InVec = N->getOperand(0);
18211
+ if (SDValue V = foldReduceOperandViaVQDOT(InVec, DL, DAG, Subtarget, TLI))
18212
+ return DAG.getNode(ISD::VECREDUCE_ADD, DL, VT, V);
18213
+ return SDValue();
18214
+ }
18215
+
18104
18216
static SDValue performINSERT_VECTOR_ELTCombine(SDNode *N, SelectionDAG &DAG,
18105
18217
const RISCVSubtarget &Subtarget,
18106
18218
const RISCVTargetLowering &TLI) {
@@ -19878,8 +19990,11 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
19878
19990
19879
19991
return SDValue();
19880
19992
}
19881
- case ISD::CTPOP:
19882
19993
case ISD::VECREDUCE_ADD:
19994
+ if (SDValue V = performVECREDUCECombine(N, DAG, Subtarget, *this))
19995
+ return V;
19996
+ [[fallthrough]];
19997
+ case ISD::CTPOP:
19883
19998
if (SDValue V = combineToVCPOP(N, DAG, Subtarget))
19884
19999
return V;
19885
20000
break;
@@ -22401,6 +22516,9 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
22401
22516
NODE_NAME_CASE(RI_VUNZIP2A_VL)
22402
22517
NODE_NAME_CASE(RI_VUNZIP2B_VL)
22403
22518
NODE_NAME_CASE(RI_VEXTRACT)
22519
+ NODE_NAME_CASE(VQDOT_VL)
22520
+ NODE_NAME_CASE(VQDOTU_VL)
22521
+ NODE_NAME_CASE(VQDOTSU_VL)
22404
22522
NODE_NAME_CASE(READ_CSR)
22405
22523
NODE_NAME_CASE(WRITE_CSR)
22406
22524
NODE_NAME_CASE(SWAP_CSR)
0 commit comments