-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[NVPTX] Fold (add (select 0, (mul a, b)), c) -> (select c, (mad a, b, c)) #96352
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-backend-nvptx Author: Alex MacLean (AlexMaclean) ChangesAdd folding for Full diff: https://github.com/llvm/llvm-project/pull/96352.diff 2 Files Affected:
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index f4ef7c9914f13..0c609554370a3 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -5215,103 +5215,129 @@ bool NVPTXTargetLowering::allowUnsafeFPMath(MachineFunction &MF) const {
return F.getFnAttribute("unsafe-fp-math").getValueAsBool();
}
+static bool isConstZero(const SDValue &Operand) {
+ const auto *Const = dyn_cast<ConstantSDNode>(Operand);
+ return Const && Const->getZExtValue() == 0;
+}
+
/// PerformADDCombineWithOperands - Try DAG combinations for an ADD with
/// operands N0 and N1. This is a helper for PerformADDCombine that is
/// called with the default operands, and if that fails, with commuted
/// operands.
-static SDValue PerformADDCombineWithOperands(
- SDNode *N, SDValue N0, SDValue N1, TargetLowering::DAGCombinerInfo &DCI,
- const NVPTXSubtarget &Subtarget, CodeGenOptLevel OptLevel) {
- SelectionDAG &DAG = DCI.DAG;
- // Skip non-integer, non-scalar case
- EVT VT=N0.getValueType();
- if (VT.isVector())
+static SDValue
+PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
+ TargetLowering::DAGCombinerInfo &DCI) {
+ EVT VT = N0.getValueType();
+
+ // Since integer multiply-add costs the same as integer multiply
+ // but is more costly than integer add, do the fusion only when
+ // the mul is only used in the add.
+ if (!N0.getNode()->hasOneUse())
return SDValue();
// fold (add (mul a, b), c) -> (mad a, b, c)
//
- if (N0.getOpcode() == ISD::MUL) {
- assert (VT.isInteger());
- // For integer:
- // Since integer multiply-add costs the same as integer multiply
- // but is more costly than integer add, do the fusion only when
- // the mul is only used in the add.
- if (OptLevel == CodeGenOptLevel::None || VT != MVT::i32 ||
- !N0.getNode()->hasOneUse())
+ if (N0.getOpcode() == ISD::MUL)
+ return DCI.DAG.getNode(NVPTXISD::IMAD, SDLoc(N), VT, N0.getOperand(0),
+ N0.getOperand(1), N1);
+
+ // fold (add (select cond, 0, (mul a, b)), c)
+ // -> (select cond, (mad a, b, c), c)
+ //
+ if (N0.getOpcode() == ISD::SELECT) {
+ bool ZeroCond;
+ if (isConstZero(N0->getOperand(1)))
+ ZeroCond = true;
+ else if (isConstZero(N0->getOperand(2)))
+ ZeroCond = false;
+ else
+ return SDValue();
+
+ SDValue M = N0->getOperand(ZeroCond ? 2 : 1);
+ if (M->getOpcode() != ISD::MUL || !M.getNode()->hasOneUse())
return SDValue();
- // Do the folding
- return DAG.getNode(NVPTXISD::IMAD, SDLoc(N), VT,
- N0.getOperand(0), N0.getOperand(1), N1);
+ SDValue MAD = DCI.DAG.getNode(NVPTXISD::IMAD, SDLoc(N), VT,
+ M->getOperand(0), M->getOperand(1), N1);
+ return DCI.DAG.getSelect(SDLoc(N), VT, N0->getOperand(0),
+ (ZeroCond ? N1 : MAD), (ZeroCond ? MAD : N1));
}
- else if (N0.getOpcode() == ISD::FMUL) {
- if (VT == MVT::f32 || VT == MVT::f64) {
- const auto *TLI = static_cast<const NVPTXTargetLowering *>(
- &DAG.getTargetLoweringInfo());
- if (!TLI->allowFMA(DAG.getMachineFunction(), OptLevel))
- return SDValue();
- // For floating point:
- // Do the fusion only when the mul has less than 5 uses and all
- // are add.
- // The heuristic is that if a use is not an add, then that use
- // cannot be fused into fma, therefore mul is still needed anyway.
- // If there are more than 4 uses, even if they are all add, fusing
- // them will increase register pressue.
- //
- int numUses = 0;
- int nonAddCount = 0;
- for (const SDNode *User : N0.getNode()->uses()) {
- numUses++;
- if (User->getOpcode() != ISD::FADD)
- ++nonAddCount;
- }
+ return SDValue();
+}
+
+static SDValue
+PerformFADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
+ TargetLowering::DAGCombinerInfo &DCI,
+ CodeGenOptLevel OptLevel) {
+ EVT VT = N0.getValueType();
+ if (N0.getOpcode() == ISD::FMUL) {
+ const auto *TLI = static_cast<const NVPTXTargetLowering *>(
+ &DCI.DAG.getTargetLoweringInfo());
+ if (!TLI->allowFMA(DCI.DAG.getMachineFunction(), OptLevel))
+ return SDValue();
+
+ // For floating point:
+ // Do the fusion only when the mul has less than 5 uses and all
+ // are add.
+ // The heuristic is that if a use is not an add, then that use
+ // cannot be fused into fma, therefore mul is still needed anyway.
+ // If there are more than 4 uses, even if they are all add, fusing
+ // them will increase register pressue.
+ //
+ int numUses = 0;
+ int nonAddCount = 0;
+ for (const SDNode *User : N0.getNode()->uses()) {
+ numUses++;
+ if (User->getOpcode() != ISD::FADD)
+ ++nonAddCount;
if (numUses >= 5)
return SDValue();
- if (nonAddCount) {
- int orderNo = N->getIROrder();
- int orderNo2 = N0.getNode()->getIROrder();
- // simple heuristics here for considering potential register
- // pressure, the logics here is that the differnce are used
- // to measure the distance between def and use, the longer distance
- // more likely cause register pressure.
- if (orderNo - orderNo2 < 500)
- return SDValue();
-
- // Now, check if at least one of the FMUL's operands is live beyond the node N,
- // which guarantees that the FMA will not increase register pressure at node N.
- bool opIsLive = false;
- const SDNode *left = N0.getOperand(0).getNode();
- const SDNode *right = N0.getOperand(1).getNode();
-
- if (isa<ConstantSDNode>(left) || isa<ConstantSDNode>(right))
- opIsLive = true;
-
- if (!opIsLive)
- for (const SDNode *User : left->uses()) {
- int orderNo3 = User->getIROrder();
- if (orderNo3 > orderNo) {
- opIsLive = true;
- break;
- }
- }
+ }
+ if (nonAddCount) {
+ int orderNo = N->getIROrder();
+ int orderNo2 = N0.getNode()->getIROrder();
+ // simple heuristics here for considering potential register
+ // pressure, the logics here is that the differnce are used
+ // to measure the distance between def and use, the longer distance
+ // more likely cause register pressure.
+ if (orderNo - orderNo2 < 500)
+ return SDValue();
- if (!opIsLive)
- for (const SDNode *User : right->uses()) {
- int orderNo3 = User->getIROrder();
- if (orderNo3 > orderNo) {
- opIsLive = true;
- break;
- }
+ // Now, check if at least one of the FMUL's operands is live beyond the
+ // node N, which guarantees that the FMA will not increase register
+ // pressure at node N.
+ bool opIsLive = false;
+ const SDNode *left = N0.getOperand(0).getNode();
+ const SDNode *right = N0.getOperand(1).getNode();
+
+ if (isa<ConstantSDNode>(left) || isa<ConstantSDNode>(right))
+ opIsLive = true;
+
+ if (!opIsLive)
+ for (const SDNode *User : left->uses()) {
+ int orderNo3 = User->getIROrder();
+ if (orderNo3 > orderNo) {
+ opIsLive = true;
+ break;
}
+ }
- if (!opIsLive)
- return SDValue();
- }
+ if (!opIsLive)
+ for (const SDNode *User : right->uses()) {
+ int orderNo3 = User->getIROrder();
+ if (orderNo3 > orderNo) {
+ opIsLive = true;
+ break;
+ }
+ }
- return DAG.getNode(ISD::FMA, SDLoc(N), VT,
- N0.getOperand(0), N0.getOperand(1), N1);
+ if (!opIsLive)
+ return SDValue();
}
+
+ return DCI.DAG.getNode(ISD::FMA, SDLoc(N), VT, N0.getOperand(0),
+ N0.getOperand(1), N1);
}
return SDValue();
@@ -5332,18 +5358,44 @@ static SDValue PerformStoreRetvalCombine(SDNode *N) {
///
static SDValue PerformADDCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
- const NVPTXSubtarget &Subtarget,
+ CodeGenOptLevel OptLevel) {
+ if (OptLevel == CodeGenOptLevel::None)
+ return SDValue();
+
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+
+ // Skip non-integer, non-scalar case
+ EVT VT = N0.getValueType();
+ if (VT.isVector() || VT != MVT::i32)
+ return SDValue();
+
+ // First try with the default operand order.
+ if (SDValue Result = PerformADDCombineWithOperands(N, N0, N1, DCI))
+ return Result;
+
+ // If that didn't work, try again with the operands commuted.
+ return PerformADDCombineWithOperands(N, N1, N0, DCI);
+}
+
+/// PerformFADDCombine - Target-specific dag combine xforms for ISD::FADD.
+///
+static SDValue PerformFADDCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI,
CodeGenOptLevel OptLevel) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
+ EVT VT = N0.getValueType();
+ if (VT.isVector() || !(VT == MVT::f32 || VT == MVT::f64))
+ return SDValue();
+
// First try with the default operand order.
- if (SDValue Result =
- PerformADDCombineWithOperands(N, N0, N1, DCI, Subtarget, OptLevel))
+ if (SDValue Result = PerformFADDCombineWithOperands(N, N0, N1, DCI, OptLevel))
return Result;
// If that didn't work, try again with the operands commuted.
- return PerformADDCombineWithOperands(N, N1, N0, DCI, Subtarget, OptLevel);
+ return PerformFADDCombineWithOperands(N, N1, N0, DCI, OptLevel);
}
static SDValue PerformANDCombine(SDNode *N,
@@ -5876,8 +5928,9 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
switch (N->getOpcode()) {
default: break;
case ISD::ADD:
+ return PerformADDCombine(N, DCI, OptLevel);
case ISD::FADD:
- return PerformADDCombine(N, DCI, STI, OptLevel);
+ return PerformFADDCombine(N, DCI, OptLevel);
case ISD::MUL:
return PerformMULCombine(N, DCI, OptLevel);
case ISD::SHL:
diff --git a/llvm/test/CodeGen/NVPTX/combine-mad.ll b/llvm/test/CodeGen/NVPTX/combine-mad.ll
index 0637bc916ea49..56bfaa14c5877 100644
--- a/llvm/test/CodeGen/NVPTX/combine-mad.ll
+++ b/llvm/test/CodeGen/NVPTX/combine-mad.ll
@@ -134,3 +134,52 @@ define i32 @test3(i32 %n, i32 %m, i32 %s) {
%mul = mul i32 %sel, %m
ret i32 %mul
}
+
+;; (add (select 0, (mul a, b)), c) -> (select (mad a, b, c), c)
+define i32 @test4(i32 %a, i32 %b, i32 %c, i1 %p) {
+; CHECK-LABEL: test4(
+; CHECK: {
+; CHECK-NEXT: .reg .pred %p<2>;
+; CHECK-NEXT: .reg .b16 %rs<3>;
+; CHECK-NEXT: .reg .b32 %r<6>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u8 %rs1, [test4_param_3];
+; CHECK-NEXT: and.b16 %rs2, %rs1, 1;
+; CHECK-NEXT: setp.eq.b16 %p1, %rs2, 1;
+; CHECK-NEXT: ld.param.u32 %r1, [test4_param_0];
+; CHECK-NEXT: ld.param.u32 %r2, [test4_param_1];
+; CHECK-NEXT: ld.param.u32 %r3, [test4_param_2];
+; CHECK-NEXT: mad.lo.s32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT: selp.b32 %r5, %r4, %r3, %p1;
+; CHECK-NEXT: st.param.b32 [func_retval0+0], %r5;
+; CHECK-NEXT: ret;
+ %mul = mul i32 %a, %b
+ %sel = select i1 %p, i32 %mul, i32 0
+ %add = add i32 %c, %sel
+ ret i32 %add
+}
+
+define i32 @test4_rev(i32 %a, i32 %b, i32 %c, i1 %p) {
+; CHECK-LABEL: test4_rev(
+; CHECK: {
+; CHECK-NEXT: .reg .pred %p<2>;
+; CHECK-NEXT: .reg .b16 %rs<3>;
+; CHECK-NEXT: .reg .b32 %r<6>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u8 %rs1, [test4_rev_param_3];
+; CHECK-NEXT: and.b16 %rs2, %rs1, 1;
+; CHECK-NEXT: setp.eq.b16 %p1, %rs2, 1;
+; CHECK-NEXT: ld.param.u32 %r1, [test4_rev_param_0];
+; CHECK-NEXT: ld.param.u32 %r2, [test4_rev_param_1];
+; CHECK-NEXT: ld.param.u32 %r3, [test4_rev_param_2];
+; CHECK-NEXT: mad.lo.s32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT: selp.b32 %r5, %r3, %r4, %p1;
+; CHECK-NEXT: st.param.b32 [func_retval0+0], %r5;
+; CHECK-NEXT: ret;
+ %mul = mul i32 %a, %b
+ %sel = select i1 %p, i32 0, i32 %mul
+ %add = add i32 %c, %sel
+ ret i32 %add
+}
|
|
||
// Since integer multiply-add costs the same as integer multiply | ||
// but is more costly than integer add, do the fusion only when | ||
// the mul is only used in the add. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
-
Is this statement true? According to https://docs.nvidia.com/cuda/cuda-c-programming-guide/#arithmetic-instructions, on sm70+ integer multiply, integer add, and integer madd all have the same throughput.
-
If we really do want to add the hasOneUse() constraint for the simple madd transformation, we should definitely call it out in the commit message (or even do it as a separate commit?)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Both this statement and the hasOneUse()
constraint were here before this change. I've just refactored them around a little to avoid the need for code duplication. I agree the hasOneUse()
, may be a little too conservative, but relaxing this constrain poses some risks and I think it makes sense to maintain it as is for the purposes of this commit.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, I see it. Sorry, I looked but didn't see the logic before.
Maybe add a TODO that this statement may not be true?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good, added a TODO
bool ZeroCond; | ||
if (isConstZero(N0->getOperand(1))) | ||
ZeroCond = true; | ||
else if (isConstZero(N0->getOperand(2))) | ||
ZeroCond = false; | ||
else | ||
return SDValue(); | ||
|
||
SDValue M = N0->getOperand(ZeroCond ? 2 : 1); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I feel like you're using a boolean here to represent "operand 1 is zero or operand 2 is zero". But maybe using an integer with values 1 or 2 would make more sense? Like, "ZeroCond is false therefore N0->getOperand(2) is zero" doesn't make sense to me.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good, I've updated this to use an index.
… c)) (llvm#96352) Add folding for `(add (select cond, 0, (mul a, b)), c)` to `(select cond, c, (mad a, b, c))`. Also, refactor the DAG folding implementation to separate out the `ADD` and `FADD` folding cases.
… c)) (llvm#96352) Add folding for `(add (select cond, 0, (mul a, b)), c)` to `(select cond, c, (mad a, b, c))`. Also, refactor the DAG folding implementation to separate out the `ADD` and `FADD` folding cases.
Add folding for
(add (select cond, 0, (mul a, b)), c)
to(select cond, c, (mad a, b, c))
. Also, refactor the DAG folding implementation to separate out theADD
andFADD
folding cases.