Skip to content

Commit 7a0b9da

Browse files
authored
[RISCV] Generalize reduction tree matching to all integer reductions (#68014)
This builds on the transform introduced in #67821, and generalizes it for all integer reduction types. A couple of notes: * This will only form smax/smin/umax/umin reductions when zbb is enabled. Otherwise, we lower the min/max expressions early. I don't care about this case, and don't plan to address this further. * This excludes floating point. Floating point introduces concerns about associativity. I may or may not do a follow-up patch for that case. * The explodevector test change is mildly undesirable from a clarity perspective. If anyone sees a good way to rewrite that to stabilize the test, please suggest.
1 parent 8092933 commit 7a0b9da

File tree

3 files changed

+867
-646
lines changed

3 files changed

+867
-646
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11106,6 +11106,31 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
1110611106
}
1110711107
}
1110811108

11109+
/// Given an integer binary operator, return the generic ISD::VECREDUCE_OP
11110+
/// which corresponds to it.
11111+
static unsigned getVecReduceOpcode(unsigned Opc) {
11112+
switch (Opc) {
11113+
default:
11114+
llvm_unreachable("Unhandled binary to transfrom reduction");
11115+
case ISD::ADD:
11116+
return ISD::VECREDUCE_ADD;
11117+
case ISD::UMAX:
11118+
return ISD::VECREDUCE_UMAX;
11119+
case ISD::SMAX:
11120+
return ISD::VECREDUCE_SMAX;
11121+
case ISD::UMIN:
11122+
return ISD::VECREDUCE_UMIN;
11123+
case ISD::SMIN:
11124+
return ISD::VECREDUCE_SMIN;
11125+
case ISD::AND:
11126+
return ISD::VECREDUCE_AND;
11127+
case ISD::OR:
11128+
return ISD::VECREDUCE_OR;
11129+
case ISD::XOR:
11130+
return ISD::VECREDUCE_XOR;
11131+
}
11132+
};
11133+
1110911134
/// Perform two related transforms whose purpose is to incrementally recognize
1111011135
/// an explode_vector followed by scalar reduction as a vector reduction node.
1111111136
/// This exists to recover from a deficiency in SLP which can't handle
@@ -11124,8 +11149,15 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
1112411149

1112511150
const SDLoc DL(N);
1112611151
const EVT VT = N->getValueType(0);
11127-
[[maybe_unused]] const unsigned Opc = N->getOpcode();
11128-
assert(Opc == ISD::ADD && "extend this to other reduction types");
11152+
11153+
// TODO: Handle floating point here.
11154+
if (!VT.isInteger())
11155+
return SDValue();
11156+
11157+
const unsigned Opc = N->getOpcode();
11158+
const unsigned ReduceOpc = getVecReduceOpcode(Opc);
11159+
assert(Opc == ISD::getVecReduceBaseOpcode(ReduceOpc) &&
11160+
"Inconsistent mappings");
1112911161
const SDValue LHS = N->getOperand(0);
1113011162
const SDValue RHS = N->getOperand(1);
1113111163

@@ -11155,13 +11187,13 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
1115511187
EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, 2);
1115611188
SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
1115711189
DAG.getVectorIdxConstant(0, DL));
11158-
return DAG.getNode(ISD::VECREDUCE_ADD, DL, VT, Vec);
11190+
return DAG.getNode(ReduceOpc, DL, VT, Vec);
1115911191
}
1116011192

1116111193
// Match (binop (reduce (extract_subvector V, 0),
1116211194
// (extract_vector_elt V, sizeof(SubVec))))
1116311195
// into a reduction of one more element from the original vector V.
11164-
if (LHS.getOpcode() != ISD::VECREDUCE_ADD)
11196+
if (LHS.getOpcode() != ReduceOpc)
1116511197
return SDValue();
1116611198

1116711199
SDValue ReduceVec = LHS.getOperand(0);
@@ -11177,7 +11209,7 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
1117711209
EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, Idx + 1);
1117811210
SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
1117911211
DAG.getVectorIdxConstant(0, DL));
11180-
return DAG.getNode(ISD::VECREDUCE_ADD, DL, VT, Vec);
11212+
return DAG.getNode(ReduceOpc, DL, VT, Vec);
1118111213
}
1118211214
}
1118311215

@@ -11685,6 +11717,8 @@ static SDValue performANDCombine(SDNode *N,
1168511717

1168611718
if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget))
1168711719
return V;
11720+
if (SDValue V = combineBinOpOfExtractToReduceTree(N, DAG, Subtarget))
11721+
return V;
1168811722

1168911723
if (DCI.isAfterLegalizeDAG())
1169011724
if (SDValue V = combineDeMorganOfBoolean(N, DAG))
@@ -11737,6 +11771,8 @@ static SDValue performORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
1173711771

1173811772
if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget))
1173911773
return V;
11774+
if (SDValue V = combineBinOpOfExtractToReduceTree(N, DAG, Subtarget))
11775+
return V;
1174011776

1174111777
if (DCI.isAfterLegalizeDAG())
1174211778
if (SDValue V = combineDeMorganOfBoolean(N, DAG))
@@ -11788,6 +11824,9 @@ static SDValue performXORCombine(SDNode *N, SelectionDAG &DAG,
1178811824

1178911825
if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget))
1179011826
return V;
11827+
if (SDValue V = combineBinOpOfExtractToReduceTree(N, DAG, Subtarget))
11828+
return V;
11829+
1179111830
// fold (xor (select cond, 0, y), x) ->
1179211831
// (select cond, x, (xor x, y))
1179311832
return combineSelectAndUseCommutative(N, DAG, /*AllOnes*/ false, Subtarget);
@@ -13993,8 +14032,13 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
1399314032
case ISD::SMAX:
1399414033
case ISD::SMIN:
1399514034
case ISD::FMAXNUM:
13996-
case ISD::FMINNUM:
13997-
return combineBinOpToReduce(N, DAG, Subtarget);
14035+
case ISD::FMINNUM: {
14036+
if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget))
14037+
return V;
14038+
if (SDValue V = combineBinOpOfExtractToReduceTree(N, DAG, Subtarget))
14039+
return V;
14040+
return SDValue();
14041+
}
1399814042
case ISD::SETCC:
1399914043
return performSETCCCombine(N, DAG, Subtarget);
1400014044
case ISD::SIGN_EXTEND_INREG:

0 commit comments

Comments
 (0)