Skip to content

[NVPTX] Fold (add (select 0, (mul a, b)), c) -> (select c, (mad a, b, c)) #96352

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 137 additions & 83 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5215,103 +5215,130 @@ bool NVPTXTargetLowering::allowUnsafeFPMath(MachineFunction &MF) const {
return F.getFnAttribute("unsafe-fp-math").getValueAsBool();
}

static bool isConstZero(const SDValue &Operand) {
const auto *Const = dyn_cast<ConstantSDNode>(Operand);
return Const && Const->getZExtValue() == 0;
}

/// PerformADDCombineWithOperands - Try DAG combinations for an ADD with
/// operands N0 and N1. This is a helper for PerformADDCombine that is
/// called with the default operands, and if that fails, with commuted
/// operands.
static SDValue PerformADDCombineWithOperands(
SDNode *N, SDValue N0, SDValue N1, TargetLowering::DAGCombinerInfo &DCI,
const NVPTXSubtarget &Subtarget, CodeGenOptLevel OptLevel) {
SelectionDAG &DAG = DCI.DAG;
// Skip non-integer, non-scalar case
EVT VT=N0.getValueType();
if (VT.isVector())
static SDValue
PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
TargetLowering::DAGCombinerInfo &DCI) {
EVT VT = N0.getValueType();

// Since integer multiply-add costs the same as integer multiply
// but is more costly than integer add, do the fusion only when
// the mul is only used in the add.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Is this statement true? According to https://docs.nvidia.com/cuda/cuda-c-programming-guide/#arithmetic-instructions, on sm70+ integer multiply, integer add, and integer madd all have the same throughput.

  2. If we really do want to add the hasOneUse() constraint for the simple madd transformation, we should definitely call it out in the commit message (or even do it as a separate commit?)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both this statement and the hasOneUse() constraint were here before this change. I've just refactored them around a little to avoid the need for code duplication. I agree the hasOneUse(), may be a little too conservative, but relaxing this constrain poses some risks and I think it makes sense to maintain it as is for the purposes of this commit.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see it. Sorry, I looked but didn't see the logic before.

Maybe add a TODO that this statement may not be true?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, added a TODO

if (!N0.getNode()->hasOneUse())
return SDValue();

// fold (add (mul a, b), c) -> (mad a, b, c)
//
if (N0.getOpcode() == ISD::MUL) {
assert (VT.isInteger());
// For integer:
// Since integer multiply-add costs the same as integer multiply
// but is more costly than integer add, do the fusion only when
// the mul is only used in the add.
if (OptLevel == CodeGenOptLevel::None || VT != MVT::i32 ||
!N0.getNode()->hasOneUse())
if (N0.getOpcode() == ISD::MUL)
return DCI.DAG.getNode(NVPTXISD::IMAD, SDLoc(N), VT, N0.getOperand(0),
N0.getOperand(1), N1);

// fold (add (select cond, 0, (mul a, b)), c)
// -> (select cond, c, (mad a, b, c))
//
if (N0.getOpcode() == ISD::SELECT) {
unsigned ZeroOpNum;
if (isConstZero(N0->getOperand(1)))
ZeroOpNum = 1;
else if (isConstZero(N0->getOperand(2)))
ZeroOpNum = 2;
else
return SDValue();

SDValue M = N0->getOperand((ZeroOpNum == 1) ? 2 : 1);
if (M->getOpcode() != ISD::MUL || !M.getNode()->hasOneUse())
return SDValue();

// Do the folding
return DAG.getNode(NVPTXISD::IMAD, SDLoc(N), VT,
N0.getOperand(0), N0.getOperand(1), N1);
SDValue MAD = DCI.DAG.getNode(NVPTXISD::IMAD, SDLoc(N), VT,
M->getOperand(0), M->getOperand(1), N1);
return DCI.DAG.getSelect(SDLoc(N), VT, N0->getOperand(0),
((ZeroOpNum == 1) ? N1 : MAD),
((ZeroOpNum == 1) ? MAD : N1));
}
else if (N0.getOpcode() == ISD::FMUL) {
if (VT == MVT::f32 || VT == MVT::f64) {
const auto *TLI = static_cast<const NVPTXTargetLowering *>(
&DAG.getTargetLoweringInfo());
if (!TLI->allowFMA(DAG.getMachineFunction(), OptLevel))
return SDValue();

// For floating point:
// Do the fusion only when the mul has less than 5 uses and all
// are add.
// The heuristic is that if a use is not an add, then that use
// cannot be fused into fma, therefore mul is still needed anyway.
// If there are more than 4 uses, even if they are all add, fusing
// them will increase register pressue.
//
int numUses = 0;
int nonAddCount = 0;
for (const SDNode *User : N0.getNode()->uses()) {
numUses++;
if (User->getOpcode() != ISD::FADD)
++nonAddCount;
}
return SDValue();
}

static SDValue
PerformFADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
TargetLowering::DAGCombinerInfo &DCI,
CodeGenOptLevel OptLevel) {
EVT VT = N0.getValueType();
if (N0.getOpcode() == ISD::FMUL) {
const auto *TLI = static_cast<const NVPTXTargetLowering *>(
&DCI.DAG.getTargetLoweringInfo());
if (!TLI->allowFMA(DCI.DAG.getMachineFunction(), OptLevel))
return SDValue();

// For floating point:
// Do the fusion only when the mul has less than 5 uses and all
// are add.
// The heuristic is that if a use is not an add, then that use
// cannot be fused into fma, therefore mul is still needed anyway.
// If there are more than 4 uses, even if they are all add, fusing
// them will increase register pressue.
//
int numUses = 0;
int nonAddCount = 0;
for (const SDNode *User : N0.getNode()->uses()) {
numUses++;
if (User->getOpcode() != ISD::FADD)
++nonAddCount;
if (numUses >= 5)
return SDValue();
if (nonAddCount) {
int orderNo = N->getIROrder();
int orderNo2 = N0.getNode()->getIROrder();
// simple heuristics here for considering potential register
// pressure, the logics here is that the differnce are used
// to measure the distance between def and use, the longer distance
// more likely cause register pressure.
if (orderNo - orderNo2 < 500)
return SDValue();

// Now, check if at least one of the FMUL's operands is live beyond the node N,
// which guarantees that the FMA will not increase register pressure at node N.
bool opIsLive = false;
const SDNode *left = N0.getOperand(0).getNode();
const SDNode *right = N0.getOperand(1).getNode();

if (isa<ConstantSDNode>(left) || isa<ConstantSDNode>(right))
opIsLive = true;

if (!opIsLive)
for (const SDNode *User : left->uses()) {
int orderNo3 = User->getIROrder();
if (orderNo3 > orderNo) {
opIsLive = true;
break;
}
}
}
if (nonAddCount) {
int orderNo = N->getIROrder();
int orderNo2 = N0.getNode()->getIROrder();
// simple heuristics here for considering potential register
// pressure, the logics here is that the differnce are used
// to measure the distance between def and use, the longer distance
// more likely cause register pressure.
if (orderNo - orderNo2 < 500)
return SDValue();

if (!opIsLive)
for (const SDNode *User : right->uses()) {
int orderNo3 = User->getIROrder();
if (orderNo3 > orderNo) {
opIsLive = true;
break;
}
// Now, check if at least one of the FMUL's operands is live beyond the
// node N, which guarantees that the FMA will not increase register
// pressure at node N.
bool opIsLive = false;
const SDNode *left = N0.getOperand(0).getNode();
const SDNode *right = N0.getOperand(1).getNode();

if (isa<ConstantSDNode>(left) || isa<ConstantSDNode>(right))
opIsLive = true;

if (!opIsLive)
for (const SDNode *User : left->uses()) {
int orderNo3 = User->getIROrder();
if (orderNo3 > orderNo) {
opIsLive = true;
break;
}
}

if (!opIsLive)
return SDValue();
}
if (!opIsLive)
for (const SDNode *User : right->uses()) {
int orderNo3 = User->getIROrder();
if (orderNo3 > orderNo) {
opIsLive = true;
break;
}
}

return DAG.getNode(ISD::FMA, SDLoc(N), VT,
N0.getOperand(0), N0.getOperand(1), N1);
if (!opIsLive)
return SDValue();
}

return DCI.DAG.getNode(ISD::FMA, SDLoc(N), VT, N0.getOperand(0),
N0.getOperand(1), N1);
}

return SDValue();
Expand All @@ -5332,18 +5359,44 @@ static SDValue PerformStoreRetvalCombine(SDNode *N) {
///
static SDValue PerformADDCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
const NVPTXSubtarget &Subtarget,
CodeGenOptLevel OptLevel) {
if (OptLevel == CodeGenOptLevel::None)
return SDValue();

SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);

// Skip non-integer, non-scalar case
EVT VT = N0.getValueType();
if (VT.isVector() || VT != MVT::i32)
return SDValue();

// First try with the default operand order.
if (SDValue Result = PerformADDCombineWithOperands(N, N0, N1, DCI))
return Result;

// If that didn't work, try again with the operands commuted.
return PerformADDCombineWithOperands(N, N1, N0, DCI);
}

/// PerformFADDCombine - Target-specific dag combine xforms for ISD::FADD.
///
static SDValue PerformFADDCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
CodeGenOptLevel OptLevel) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);

EVT VT = N0.getValueType();
if (VT.isVector() || !(VT == MVT::f32 || VT == MVT::f64))
return SDValue();

// First try with the default operand order.
if (SDValue Result =
PerformADDCombineWithOperands(N, N0, N1, DCI, Subtarget, OptLevel))
if (SDValue Result = PerformFADDCombineWithOperands(N, N0, N1, DCI, OptLevel))
return Result;

// If that didn't work, try again with the operands commuted.
return PerformADDCombineWithOperands(N, N1, N0, DCI, Subtarget, OptLevel);
return PerformFADDCombineWithOperands(N, N1, N0, DCI, OptLevel);
}

static SDValue PerformANDCombine(SDNode *N,
Expand Down Expand Up @@ -5876,8 +5929,9 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
switch (N->getOpcode()) {
default: break;
case ISD::ADD:
return PerformADDCombine(N, DCI, OptLevel);
case ISD::FADD:
return PerformADDCombine(N, DCI, STI, OptLevel);
return PerformFADDCombine(N, DCI, OptLevel);
case ISD::MUL:
return PerformMULCombine(N, DCI, OptLevel);
case ISD::SHL:
Expand Down
49 changes: 49 additions & 0 deletions llvm/test/CodeGen/NVPTX/combine-mad.ll
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,52 @@ define i32 @test3(i32 %n, i32 %m, i32 %s) {
%mul = mul i32 %sel, %m
ret i32 %mul
}

;; (add (select 0, (mul a, b)), c) -> (select (mad a, b, c), c)
define i32 @test4(i32 %a, i32 %b, i32 %c, i1 %p) {
; CHECK-LABEL: test4(
; CHECK: {
; CHECK-NEXT: .reg .pred %p<2>;
; CHECK-NEXT: .reg .b16 %rs<3>;
; CHECK-NEXT: .reg .b32 %r<6>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u8 %rs1, [test4_param_3];
; CHECK-NEXT: and.b16 %rs2, %rs1, 1;
; CHECK-NEXT: setp.eq.b16 %p1, %rs2, 1;
; CHECK-NEXT: ld.param.u32 %r1, [test4_param_0];
; CHECK-NEXT: ld.param.u32 %r2, [test4_param_1];
; CHECK-NEXT: ld.param.u32 %r3, [test4_param_2];
; CHECK-NEXT: mad.lo.s32 %r4, %r1, %r2, %r3;
; CHECK-NEXT: selp.b32 %r5, %r4, %r3, %p1;
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r5;
; CHECK-NEXT: ret;
%mul = mul i32 %a, %b
%sel = select i1 %p, i32 %mul, i32 0
%add = add i32 %c, %sel
ret i32 %add
}

define i32 @test4_rev(i32 %a, i32 %b, i32 %c, i1 %p) {
; CHECK-LABEL: test4_rev(
; CHECK: {
; CHECK-NEXT: .reg .pred %p<2>;
; CHECK-NEXT: .reg .b16 %rs<3>;
; CHECK-NEXT: .reg .b32 %r<6>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u8 %rs1, [test4_rev_param_3];
; CHECK-NEXT: and.b16 %rs2, %rs1, 1;
; CHECK-NEXT: setp.eq.b16 %p1, %rs2, 1;
; CHECK-NEXT: ld.param.u32 %r1, [test4_rev_param_0];
; CHECK-NEXT: ld.param.u32 %r2, [test4_rev_param_1];
; CHECK-NEXT: ld.param.u32 %r3, [test4_rev_param_2];
; CHECK-NEXT: mad.lo.s32 %r4, %r1, %r2, %r3;
; CHECK-NEXT: selp.b32 %r5, %r3, %r4, %p1;
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r5;
; CHECK-NEXT: ret;
%mul = mul i32 %a, %b
%sel = select i1 %p, i32 0, i32 %mul
%add = add i32 %c, %sel
ret i32 %add
}
Loading