Skip to content

Commit 07cddbc

Browse files
committed
[RISCV] Generalize reduction tree matching to all integer reductions
This builds on the transform introduced in llvm#67821, and generalizes it for all integer reduction types. A couple of notes: * This will only form smax/smin/umax/umin reductions when zbb is enabled. Otherwise, we lower the min/max expressions early. I don't care about this case, and don't plan to address this further. * This excludes floating point. Floating point introduces concerns about associativity. I may or may not do a follow-up patch for that case. * The explodevector test change is mildly undesirable from a clarity perspective. If anyone sees a good way to rewrite that to stabilize the test, please suggest.
1 parent 0e8f924 commit 07cddbc

File tree

3 files changed

+867
-646
lines changed

3 files changed

+867
-646
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11108,6 +11108,31 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
1110811108
}
1110911109
}
1111011110

11111+
/// Given an integer binary operator, return the generic ISD::VECREDUCE_OP
11112+
/// which corresponds to it.
11113+
static unsigned getVecReduceOpcode(unsigned Opc) {
11114+
switch (Opc) {
11115+
default:
11116+
llvm_unreachable("Unhandled binary to transfrom reduction");
11117+
case ISD::ADD:
11118+
return ISD::VECREDUCE_ADD;
11119+
case ISD::UMAX:
11120+
return ISD::VECREDUCE_UMAX;
11121+
case ISD::SMAX:
11122+
return ISD::VECREDUCE_SMAX;
11123+
case ISD::UMIN:
11124+
return ISD::VECREDUCE_UMIN;
11125+
case ISD::SMIN:
11126+
return ISD::VECREDUCE_SMIN;
11127+
case ISD::AND:
11128+
return ISD::VECREDUCE_AND;
11129+
case ISD::OR:
11130+
return ISD::VECREDUCE_OR;
11131+
case ISD::XOR:
11132+
return ISD::VECREDUCE_XOR;
11133+
}
11134+
};
11135+
1111111136
/// Perform two related transforms whose purpose is to incrementally recognize
1111211137
/// an explode_vector followed by scalar reduction as a vector reduction node.
1111311138
/// This exists to recover from a deficiency in SLP which can't handle
@@ -11126,8 +11151,15 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
1112611151

1112711152
const SDLoc DL(N);
1112811153
const EVT VT = N->getValueType(0);
11129-
[[maybe_unused]] const unsigned Opc = N->getOpcode();
11130-
assert(Opc == ISD::ADD && "extend this to other reduction types");
11154+
11155+
// TODO: Handle floating point here.
11156+
if (!VT.isInteger())
11157+
return SDValue();
11158+
11159+
const unsigned Opc = N->getOpcode();
11160+
const unsigned ReduceOpc = getVecReduceOpcode(Opc);
11161+
assert(Opc == ISD::getVecReduceBaseOpcode(ReduceOpc) &&
11162+
"Inconsistent mappings");
1113111163
const SDValue LHS = N->getOperand(0);
1113211164
const SDValue RHS = N->getOperand(1);
1113311165

@@ -11157,13 +11189,13 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
1115711189
EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, 2);
1115811190
SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
1115911191
DAG.getVectorIdxConstant(0, DL));
11160-
return DAG.getNode(ISD::VECREDUCE_ADD, DL, VT, Vec);
11192+
return DAG.getNode(ReduceOpc, DL, VT, Vec);
1116111193
}
1116211194

1116311195
// Match (binop (reduce (extract_subvector V, 0),
1116411196
// (extract_vector_elt V, sizeof(SubVec))))
1116511197
// into a reduction of one more element from the original vector V.
11166-
if (LHS.getOpcode() != ISD::VECREDUCE_ADD)
11198+
if (LHS.getOpcode() != ReduceOpc)
1116711199
return SDValue();
1116811200

1116911201
SDValue ReduceVec = LHS.getOperand(0);
@@ -11179,7 +11211,7 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
1117911211
EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, Idx + 1);
1118011212
SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
1118111213
DAG.getVectorIdxConstant(0, DL));
11182-
return DAG.getNode(ISD::VECREDUCE_ADD, DL, VT, Vec);
11214+
return DAG.getNode(ReduceOpc, DL, VT, Vec);
1118311215
}
1118411216
}
1118511217

@@ -11687,6 +11719,8 @@ static SDValue performANDCombine(SDNode *N,
1168711719

1168811720
if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget))
1168911721
return V;
11722+
if (SDValue V = combineBinOpOfExtractToReduceTree(N, DAG, Subtarget))
11723+
return V;
1169011724

1169111725
if (DCI.isAfterLegalizeDAG())
1169211726
if (SDValue V = combineDeMorganOfBoolean(N, DAG))
@@ -11739,6 +11773,8 @@ static SDValue performORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
1173911773

1174011774
if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget))
1174111775
return V;
11776+
if (SDValue V = combineBinOpOfExtractToReduceTree(N, DAG, Subtarget))
11777+
return V;
1174211778

1174311779
if (DCI.isAfterLegalizeDAG())
1174411780
if (SDValue V = combineDeMorganOfBoolean(N, DAG))
@@ -11790,6 +11826,9 @@ static SDValue performXORCombine(SDNode *N, SelectionDAG &DAG,
1179011826

1179111827
if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget))
1179211828
return V;
11829+
if (SDValue V = combineBinOpOfExtractToReduceTree(N, DAG, Subtarget))
11830+
return V;
11831+
1179311832
// fold (xor (select cond, 0, y), x) ->
1179411833
// (select cond, x, (xor x, y))
1179511834
return combineSelectAndUseCommutative(N, DAG, /*AllOnes*/ false, Subtarget);
@@ -13995,8 +14034,13 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
1399514034
case ISD::SMAX:
1399614035
case ISD::SMIN:
1399714036
case ISD::FMAXNUM:
13998-
case ISD::FMINNUM:
13999-
return combineBinOpToReduce(N, DAG, Subtarget);
14037+
case ISD::FMINNUM: {
14038+
if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget))
14039+
return V;
14040+
if (SDValue V = combineBinOpOfExtractToReduceTree(N, DAG, Subtarget))
14041+
return V;
14042+
return SDValue();
14043+
}
1400014044
case ISD::SETCC:
1400114045
return performSETCCCombine(N, DAG, Subtarget);
1400214046
case ISD::SIGN_EXTEND_INREG:

0 commit comments

Comments
 (0)