Skip to content

Commit 2612765

Browse files
authored
[NVPTX] Fold (add (select cond, 0, (mul a, b)), c) -> (select cond, c, (mad a, b, c)) (#96352)
Add folding for `(add (select cond, 0, (mul a, b)), c)` to `(select cond, c, (mad a, b, c))`. Also, refactor the DAG folding implementation to separate out the `ADD` and `FADD` folding cases.
1 parent b7762f2 commit 2612765

File tree

2 files changed

+187
-83
lines changed

2 files changed

+187
-83
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 138 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -5217,103 +5217,131 @@ bool NVPTXTargetLowering::allowUnsafeFPMath(MachineFunction &MF) const {
52175217
return F.getFnAttribute("unsafe-fp-math").getValueAsBool();
52185218
}
52195219

5220+
static bool isConstZero(const SDValue &Operand) {
5221+
const auto *Const = dyn_cast<ConstantSDNode>(Operand);
5222+
return Const && Const->getZExtValue() == 0;
5223+
}
5224+
52205225
/// PerformADDCombineWithOperands - Try DAG combinations for an ADD with
52215226
/// operands N0 and N1. This is a helper for PerformADDCombine that is
52225227
/// called with the default operands, and if that fails, with commuted
52235228
/// operands.
5224-
static SDValue PerformADDCombineWithOperands(
5225-
SDNode *N, SDValue N0, SDValue N1, TargetLowering::DAGCombinerInfo &DCI,
5226-
const NVPTXSubtarget &Subtarget, CodeGenOptLevel OptLevel) {
5227-
SelectionDAG &DAG = DCI.DAG;
5228-
// Skip non-integer, non-scalar case
5229-
EVT VT=N0.getValueType();
5230-
if (VT.isVector())
5229+
static SDValue
5230+
PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
5231+
TargetLowering::DAGCombinerInfo &DCI) {
5232+
EVT VT = N0.getValueType();
5233+
5234+
// Since integer multiply-add costs the same as integer multiply
5235+
// but is more costly than integer add, do the fusion only when
5236+
// the mul is only used in the add.
5237+
// TODO: this may not be true for later architectures, consider relaxing this
5238+
if (!N0.getNode()->hasOneUse())
52315239
return SDValue();
52325240

52335241
// fold (add (mul a, b), c) -> (mad a, b, c)
52345242
//
5235-
if (N0.getOpcode() == ISD::MUL) {
5236-
assert (VT.isInteger());
5237-
// For integer:
5238-
// Since integer multiply-add costs the same as integer multiply
5239-
// but is more costly than integer add, do the fusion only when
5240-
// the mul is only used in the add.
5241-
if (OptLevel == CodeGenOptLevel::None || VT != MVT::i32 ||
5242-
!N0.getNode()->hasOneUse())
5243+
if (N0.getOpcode() == ISD::MUL)
5244+
return DCI.DAG.getNode(NVPTXISD::IMAD, SDLoc(N), VT, N0.getOperand(0),
5245+
N0.getOperand(1), N1);
5246+
5247+
// fold (add (select cond, 0, (mul a, b)), c)
5248+
// -> (select cond, c, (mad a, b, c))
5249+
//
5250+
if (N0.getOpcode() == ISD::SELECT) {
5251+
unsigned ZeroOpNum;
5252+
if (isConstZero(N0->getOperand(1)))
5253+
ZeroOpNum = 1;
5254+
else if (isConstZero(N0->getOperand(2)))
5255+
ZeroOpNum = 2;
5256+
else
5257+
return SDValue();
5258+
5259+
SDValue M = N0->getOperand((ZeroOpNum == 1) ? 2 : 1);
5260+
if (M->getOpcode() != ISD::MUL || !M.getNode()->hasOneUse())
52435261
return SDValue();
52445262

5245-
// Do the folding
5246-
return DAG.getNode(NVPTXISD::IMAD, SDLoc(N), VT,
5247-
N0.getOperand(0), N0.getOperand(1), N1);
5263+
SDValue MAD = DCI.DAG.getNode(NVPTXISD::IMAD, SDLoc(N), VT,
5264+
M->getOperand(0), M->getOperand(1), N1);
5265+
return DCI.DAG.getSelect(SDLoc(N), VT, N0->getOperand(0),
5266+
((ZeroOpNum == 1) ? N1 : MAD),
5267+
((ZeroOpNum == 1) ? MAD : N1));
52485268
}
5249-
else if (N0.getOpcode() == ISD::FMUL) {
5250-
if (VT == MVT::f32 || VT == MVT::f64) {
5251-
const auto *TLI = static_cast<const NVPTXTargetLowering *>(
5252-
&DAG.getTargetLoweringInfo());
5253-
if (!TLI->allowFMA(DAG.getMachineFunction(), OptLevel))
5254-
return SDValue();
52555269

5256-
// For floating point:
5257-
// Do the fusion only when the mul has less than 5 uses and all
5258-
// are add.
5259-
// The heuristic is that if a use is not an add, then that use
5260-
// cannot be fused into fma, therefore mul is still needed anyway.
5261-
// If there are more than 4 uses, even if they are all add, fusing
5262-
// them will increase register pressue.
5263-
//
5264-
int numUses = 0;
5265-
int nonAddCount = 0;
5266-
for (const SDNode *User : N0.getNode()->uses()) {
5267-
numUses++;
5268-
if (User->getOpcode() != ISD::FADD)
5269-
++nonAddCount;
5270-
}
5270+
return SDValue();
5271+
}
5272+
5273+
static SDValue
5274+
PerformFADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
5275+
TargetLowering::DAGCombinerInfo &DCI,
5276+
CodeGenOptLevel OptLevel) {
5277+
EVT VT = N0.getValueType();
5278+
if (N0.getOpcode() == ISD::FMUL) {
5279+
const auto *TLI = static_cast<const NVPTXTargetLowering *>(
5280+
&DCI.DAG.getTargetLoweringInfo());
5281+
if (!TLI->allowFMA(DCI.DAG.getMachineFunction(), OptLevel))
5282+
return SDValue();
5283+
5284+
// For floating point:
5285+
// Do the fusion only when the mul has less than 5 uses and all
5286+
// are add.
5287+
// The heuristic is that if a use is not an add, then that use
5288+
// cannot be fused into fma, therefore mul is still needed anyway.
5289+
// If there are more than 4 uses, even if they are all add, fusing
5290+
// them will increase register pressure.
5291+
//
5292+
int numUses = 0;
5293+
int nonAddCount = 0;
5294+
for (const SDNode *User : N0.getNode()->uses()) {
5295+
numUses++;
5296+
if (User->getOpcode() != ISD::FADD)
5297+
++nonAddCount;
52715298
if (numUses >= 5)
52725299
return SDValue();
5273-
if (nonAddCount) {
5274-
int orderNo = N->getIROrder();
5275-
int orderNo2 = N0.getNode()->getIROrder();
5276-
// simple heuristics here for considering potential register
5277-
// pressure, the logics here is that the differnce are used
5278-
// to measure the distance between def and use, the longer distance
5279-
// more likely cause register pressure.
5280-
if (orderNo - orderNo2 < 500)
5281-
return SDValue();
5282-
5283-
// Now, check if at least one of the FMUL's operands is live beyond the node N,
5284-
// which guarantees that the FMA will not increase register pressure at node N.
5285-
bool opIsLive = false;
5286-
const SDNode *left = N0.getOperand(0).getNode();
5287-
const SDNode *right = N0.getOperand(1).getNode();
5288-
5289-
if (isa<ConstantSDNode>(left) || isa<ConstantSDNode>(right))
5290-
opIsLive = true;
5291-
5292-
if (!opIsLive)
5293-
for (const SDNode *User : left->uses()) {
5294-
int orderNo3 = User->getIROrder();
5295-
if (orderNo3 > orderNo) {
5296-
opIsLive = true;
5297-
break;
5298-
}
5299-
}
5300+
}
5301+
if (nonAddCount) {
5302+
int orderNo = N->getIROrder();
5303+
int orderNo2 = N0.getNode()->getIROrder();
5304+
// Simple heuristic for estimating potential register pressure: the
5305+
// difference in IR order is used to measure the distance between the
5306+
// def and the use; the longer the distance, the more likely the fusion
5307+
// is to cause register pressure.
5308+
if (orderNo - orderNo2 < 500)
5309+
return SDValue();
53005310

5301-
if (!opIsLive)
5302-
for (const SDNode *User : right->uses()) {
5303-
int orderNo3 = User->getIROrder();
5304-
if (orderNo3 > orderNo) {
5305-
opIsLive = true;
5306-
break;
5307-
}
5311+
// Now, check if at least one of the FMUL's operands is live beyond the
5312+
// node N, which guarantees that the FMA will not increase register
5313+
// pressure at node N.
5314+
bool opIsLive = false;
5315+
const SDNode *left = N0.getOperand(0).getNode();
5316+
const SDNode *right = N0.getOperand(1).getNode();
5317+
5318+
if (isa<ConstantSDNode>(left) || isa<ConstantSDNode>(right))
5319+
opIsLive = true;
5320+
5321+
if (!opIsLive)
5322+
for (const SDNode *User : left->uses()) {
5323+
int orderNo3 = User->getIROrder();
5324+
if (orderNo3 > orderNo) {
5325+
opIsLive = true;
5326+
break;
53085327
}
5328+
}
53095329

5310-
if (!opIsLive)
5311-
return SDValue();
5312-
}
5330+
if (!opIsLive)
5331+
for (const SDNode *User : right->uses()) {
5332+
int orderNo3 = User->getIROrder();
5333+
if (orderNo3 > orderNo) {
5334+
opIsLive = true;
5335+
break;
5336+
}
5337+
}
53135338

5314-
return DAG.getNode(ISD::FMA, SDLoc(N), VT,
5315-
N0.getOperand(0), N0.getOperand(1), N1);
5339+
if (!opIsLive)
5340+
return SDValue();
53165341
}
5342+
5343+
return DCI.DAG.getNode(ISD::FMA, SDLoc(N), VT, N0.getOperand(0),
5344+
N0.getOperand(1), N1);
53175345
}
53185346

53195347
return SDValue();
@@ -5334,18 +5362,44 @@ static SDValue PerformStoreRetvalCombine(SDNode *N) {
53345362
///
53355363
static SDValue PerformADDCombine(SDNode *N,
53365364
TargetLowering::DAGCombinerInfo &DCI,
5337-
const NVPTXSubtarget &Subtarget,
5365+
CodeGenOptLevel OptLevel) {
5366+
if (OptLevel == CodeGenOptLevel::None)
5367+
return SDValue();
5368+
5369+
SDValue N0 = N->getOperand(0);
5370+
SDValue N1 = N->getOperand(1);
5371+
5372+
// Only scalar i32 is handled; skip vectors and all other types.
5373+
EVT VT = N0.getValueType();
5374+
if (VT.isVector() || VT != MVT::i32)
5375+
return SDValue();
5376+
5377+
// First try with the default operand order.
5378+
if (SDValue Result = PerformADDCombineWithOperands(N, N0, N1, DCI))
5379+
return Result;
5380+
5381+
// If that didn't work, try again with the operands commuted.
5382+
return PerformADDCombineWithOperands(N, N1, N0, DCI);
5383+
}
5384+
5385+
/// PerformFADDCombine - Target-specific dag combine xforms for ISD::FADD.
5386+
///
5387+
static SDValue PerformFADDCombine(SDNode *N,
5388+
TargetLowering::DAGCombinerInfo &DCI,
53385389
CodeGenOptLevel OptLevel) {
53395390
SDValue N0 = N->getOperand(0);
53405391
SDValue N1 = N->getOperand(1);
53415392

5393+
EVT VT = N0.getValueType();
5394+
if (VT.isVector() || !(VT == MVT::f32 || VT == MVT::f64))
5395+
return SDValue();
5396+
53425397
// First try with the default operand order.
5343-
if (SDValue Result =
5344-
PerformADDCombineWithOperands(N, N0, N1, DCI, Subtarget, OptLevel))
5398+
if (SDValue Result = PerformFADDCombineWithOperands(N, N0, N1, DCI, OptLevel))
53455399
return Result;
53465400

53475401
// If that didn't work, try again with the operands commuted.
5348-
return PerformADDCombineWithOperands(N, N1, N0, DCI, Subtarget, OptLevel);
5402+
return PerformFADDCombineWithOperands(N, N1, N0, DCI, OptLevel);
53495403
}
53505404

53515405
static SDValue PerformANDCombine(SDNode *N,
@@ -5878,8 +5932,9 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
58785932
switch (N->getOpcode()) {
58795933
default: break;
58805934
case ISD::ADD:
5935+
return PerformADDCombine(N, DCI, OptLevel);
58815936
case ISD::FADD:
5882-
return PerformADDCombine(N, DCI, STI, OptLevel);
5937+
return PerformFADDCombine(N, DCI, OptLevel);
58835938
case ISD::MUL:
58845939
return PerformMULCombine(N, DCI, OptLevel);
58855940
case ISD::SHL:

llvm/test/CodeGen/NVPTX/combine-mad.ll

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,52 @@ define i32 @test3(i32 %n, i32 %m, i32 %s) {
134134
%mul = mul i32 %sel, %m
135135
ret i32 %mul
136136
}
137+
138+
;; (add (select cond, (mul a, b), 0), c) -> (select cond, (mad a, b, c), c)
139+
define i32 @test4(i32 %a, i32 %b, i32 %c, i1 %p) {
140+
; CHECK-LABEL: test4(
141+
; CHECK: {
142+
; CHECK-NEXT: .reg .pred %p<2>;
143+
; CHECK-NEXT: .reg .b16 %rs<3>;
144+
; CHECK-NEXT: .reg .b32 %r<6>;
145+
; CHECK-EMPTY:
146+
; CHECK-NEXT: // %bb.0:
147+
; CHECK-NEXT: ld.param.u8 %rs1, [test4_param_3];
148+
; CHECK-NEXT: and.b16 %rs2, %rs1, 1;
149+
; CHECK-NEXT: setp.eq.b16 %p1, %rs2, 1;
150+
; CHECK-NEXT: ld.param.u32 %r1, [test4_param_0];
151+
; CHECK-NEXT: ld.param.u32 %r2, [test4_param_1];
152+
; CHECK-NEXT: ld.param.u32 %r3, [test4_param_2];
153+
; CHECK-NEXT: mad.lo.s32 %r4, %r1, %r2, %r3;
154+
; CHECK-NEXT: selp.b32 %r5, %r4, %r3, %p1;
155+
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r5;
156+
; CHECK-NEXT: ret;
157+
%mul = mul i32 %a, %b
158+
%sel = select i1 %p, i32 %mul, i32 0
159+
%add = add i32 %c, %sel
160+
ret i32 %add
161+
}
162+
163+
define i32 @test4_rev(i32 %a, i32 %b, i32 %c, i1 %p) {
164+
; CHECK-LABEL: test4_rev(
165+
; CHECK: {
166+
; CHECK-NEXT: .reg .pred %p<2>;
167+
; CHECK-NEXT: .reg .b16 %rs<3>;
168+
; CHECK-NEXT: .reg .b32 %r<6>;
169+
; CHECK-EMPTY:
170+
; CHECK-NEXT: // %bb.0:
171+
; CHECK-NEXT: ld.param.u8 %rs1, [test4_rev_param_3];
172+
; CHECK-NEXT: and.b16 %rs2, %rs1, 1;
173+
; CHECK-NEXT: setp.eq.b16 %p1, %rs2, 1;
174+
; CHECK-NEXT: ld.param.u32 %r1, [test4_rev_param_0];
175+
; CHECK-NEXT: ld.param.u32 %r2, [test4_rev_param_1];
176+
; CHECK-NEXT: ld.param.u32 %r3, [test4_rev_param_2];
177+
; CHECK-NEXT: mad.lo.s32 %r4, %r1, %r2, %r3;
178+
; CHECK-NEXT: selp.b32 %r5, %r3, %r4, %p1;
179+
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r5;
180+
; CHECK-NEXT: ret;
181+
%mul = mul i32 %a, %b
182+
%sel = select i1 %p, i32 0, i32 %mul
183+
%add = add i32 %c, %sel
184+
ret i32 %add
185+
}

0 commit comments

Comments
 (0)