@@ -5217,103 +5217,131 @@ bool NVPTXTargetLowering::allowUnsafeFPMath(MachineFunction &MF) const {
5217
5217
return F.getFnAttribute (" unsafe-fp-math" ).getValueAsBool ();
5218
5218
}
5219
5219
5220
/// isConstZero - Return true when Operand is a ConstantSDNode whose
/// zero-extended value equals 0. Any operand that is not a ConstantSDNode
/// yields false.
// NOTE(review): getZExtValue() presumably requires the constant to fit in 64
// bits; the caller visible in this diff only reaches here for i32 - confirm
// before reusing this helper for wider types.
+ static bool isConstZero (const SDValue &Operand) {
5221
// Const is null when Operand is not a ConstantSDNode; the check below guards it.
+ const auto *Const = dyn_cast<ConstantSDNode>(Operand);
5222
+ return Const && Const->getZExtValue () == 0 ;
5223
+ }
5224
+
5220
5225
/// PerformADDCombineWithOperands - Try DAG combinations for an ADD with
5221
5226
/// operands N0 and N1. This is a helper for PerformADDCombine that is
5222
5227
/// called with the default operands, and if that fails, with commuted
5223
5228
/// operands.
///
/// In the added ('+') version below, the result is an empty SDValue unless N0
/// has a single use. When N0 is a MUL, the add is folded into an
/// NVPTXISD::IMAD node. When N0 is a SELECT with one constant-zero arm and a
/// single-use MUL on the other arm, the add is folded into a select between
/// N1 and an IMAD. Every other shape returns an empty SDValue.
5224
- static SDValue PerformADDCombineWithOperands (
5225
- SDNode *N, SDValue N0, SDValue N1, TargetLowering::DAGCombinerInfo &DCI,
5226
- const NVPTXSubtarget &Subtarget, CodeGenOptLevel OptLevel) {
5227
- SelectionDAG &DAG = DCI.DAG ;
5228
- // Skip non-integer, non-scalar case
5229
- EVT VT=N0.getValueType ();
5230
- if (VT.isVector ())
5229
+ static SDValue
5230
+ PerformADDCombineWithOperands (SDNode *N, SDValue N0, SDValue N1,
5231
+ TargetLowering::DAGCombinerInfo &DCI) {
5232
+ EVT VT = N0.getValueType ();
5233
+
5234
// Note: the single-use test below is applied to N0 whatever its opcode, so it
// also gates the SELECT fold further down.
+ // Since integer multiply-add costs the same as integer multiply
5235
+ // but is more costly than integer add, do the fusion only when
5236
+ // the mul is only used in the add.
5237
+ // TODO: this may not be true for later architectures, consider relaxing this
5238
+ if (!N0.getNode ()->hasOneUse ())
5231
5239
return SDValue ();
5232
5240
5233
5241
// fold (add (mul a, b), c) -> (mad a, b, c)
5234
5242
//
5235
- if (N0.getOpcode () == ISD::MUL) {
5236
- assert (VT.isInteger ());
5237
- // For integer:
5238
- // Since integer multiply-add costs the same as integer multiply
5239
- // but is more costly than integer add, do the fusion only when
5240
- // the mul is only used in the add.
5241
- if (OptLevel == CodeGenOptLevel::None || VT != MVT::i32 ||
5242
- !N0.getNode ()->hasOneUse ())
5243
+ if (N0.getOpcode () == ISD::MUL)
5244
+ return DCI.DAG .getNode (NVPTXISD::IMAD, SDLoc (N), VT, N0.getOperand (0 ),
5245
+ N0.getOperand (1 ), N1);
5246
+
5247
+ // fold (add (select cond, 0, (mul a, b)), c)
5248
+ // -> (select cond, c, (mad a, b, c))
5249
+ //
5250
+ if (N0.getOpcode () == ISD::SELECT) {
5251
// ZeroOpNum records which select value operand (1 or 2) is the constant zero;
// if neither is, there is nothing to fold.
+ unsigned ZeroOpNum;
5252
+ if (isConstZero (N0->getOperand (1 )))
5253
+ ZeroOpNum = 1 ;
5254
+ else if (isConstZero (N0->getOperand (2 )))
5255
+ ZeroOpNum = 2 ;
5256
+ else
5257
+ return SDValue ();
5258
+
5259
// M is the other (non-zero) arm of the select; it must be a single-use MUL.
+ SDValue M = N0->getOperand ((ZeroOpNum == 1 ) ? 2 : 1 );
5260
+ if (M->getOpcode () != ISD::MUL || !M.getNode ()->hasOneUse ())
5243
5261
return SDValue ();
5244
5262
5245
- // Do the folding
5246
- return DAG.getNode (NVPTXISD::IMAD, SDLoc (N), VT,
5247
- N0.getOperand (0 ), N0.getOperand (1 ), N1);
5263
// Build the IMAD from the mul's operands, then rebuild the select so that N1
// sits on the arm that held the zero and the IMAD on the arm that held the mul.
+ SDValue MAD = DCI.DAG .getNode (NVPTXISD::IMAD, SDLoc (N), VT,
5264
+ M->getOperand (0 ), M->getOperand (1 ), N1);
5265
+ return DCI.DAG .getSelect (SDLoc (N), VT, N0->getOperand (0 ),
5266
+ ((ZeroOpNum == 1 ) ? N1 : MAD),
5267
+ ((ZeroOpNum == 1 ) ? MAD : N1));
5248
5268
}
5249
- else if (N0.getOpcode () == ISD::FMUL) {
5250
- if (VT == MVT::f32 || VT == MVT::f64) {
5251
- const auto *TLI = static_cast <const NVPTXTargetLowering *>(
5252
- &DAG.getTargetLoweringInfo ());
5253
- if (!TLI->allowFMA (DAG.getMachineFunction (), OptLevel))
5254
- return SDValue ();
5255
5269
5256
- // For floating point:
5257
- // Do the fusion only when the mul has less than 5 uses and all
5258
- // are add.
5259
- // The heuristic is that if a use is not an add, then that use
5260
- // cannot be fused into fma, therefore mul is still needed anyway.
5261
- // If there are more than 4 uses, even if they are all add, fusing
5262
- // them will increase register pressue.
5263
- //
5264
- int numUses = 0 ;
5265
- int nonAddCount = 0 ;
5266
- for (const SDNode *User : N0.getNode ()->uses ()) {
5267
- numUses++;
5268
- if (User->getOpcode () != ISD::FADD)
5269
- ++nonAddCount;
5270
- }
5270
+ return SDValue ();
5271
+ }
5272
+
5273
+ static SDValue
5274
+ PerformFADDCombineWithOperands (SDNode *N, SDValue N0, SDValue N1,
5275
+ TargetLowering::DAGCombinerInfo &DCI,
5276
+ CodeGenOptLevel OptLevel) {
5277
+ EVT VT = N0.getValueType ();
5278
+ if (N0.getOpcode () == ISD::FMUL) {
5279
+ const auto *TLI = static_cast <const NVPTXTargetLowering *>(
5280
+ &DCI.DAG .getTargetLoweringInfo ());
5281
+ if (!TLI->allowFMA (DCI.DAG .getMachineFunction (), OptLevel))
5282
+ return SDValue ();
5283
+
5284
+ // For floating point:
5285
+ // Do the fusion only when the mul has less than 5 uses and all
5286
+ // are add.
5287
+ // The heuristic is that if a use is not an add, then that use
5288
+ // cannot be fused into fma, therefore mul is still needed anyway.
5289
+ // If there are more than 4 uses, even if they are all add, fusing
5290
+ // them will increase register pressure.
5291
+ //
5292
+ int numUses = 0 ;
5293
+ int nonAddCount = 0 ;
5294
+ for (const SDNode *User : N0.getNode ()->uses ()) {
5295
+ numUses++;
5296
+ if (User->getOpcode () != ISD::FADD)
5297
+ ++nonAddCount;
5271
5298
if (numUses >= 5 )
5272
5299
return SDValue ();
5273
- if (nonAddCount) {
5274
- int orderNo = N->getIROrder ();
5275
- int orderNo2 = N0.getNode ()->getIROrder ();
5276
- // simple heuristics here for considering potential register
5277
- // pressure, the logics here is that the differnce are used
5278
- // to measure the distance between def and use, the longer distance
5279
- // more likely cause register pressure.
5280
- if (orderNo - orderNo2 < 500 )
5281
- return SDValue ();
5282
-
5283
- // Now, check if at least one of the FMUL's operands is live beyond the node N,
5284
- // which guarantees that the FMA will not increase register pressure at node N.
5285
- bool opIsLive = false ;
5286
- const SDNode *left = N0.getOperand (0 ).getNode ();
5287
- const SDNode *right = N0.getOperand (1 ).getNode ();
5288
-
5289
- if (isa<ConstantSDNode>(left) || isa<ConstantSDNode>(right))
5290
- opIsLive = true ;
5291
-
5292
- if (!opIsLive)
5293
- for (const SDNode *User : left->uses ()) {
5294
- int orderNo3 = User->getIROrder ();
5295
- if (orderNo3 > orderNo) {
5296
- opIsLive = true ;
5297
- break ;
5298
- }
5299
- }
5300
+ }
5301
+ if (nonAddCount) {
5302
+ int orderNo = N->getIROrder ();
5303
+ int orderNo2 = N0.getNode ()->getIROrder ();
5304
+ // Simple heuristic for estimating potential register pressure:
5305
+ // the difference in IR order is used to measure the distance
5306
+ // between def and use; the longer the distance, the more
5307
+ // likely it is to cause register pressure.
5308
+ if (orderNo - orderNo2 < 500 )
5309
+ return SDValue ();
5300
5310
5301
- if (!opIsLive)
5302
- for (const SDNode *User : right->uses ()) {
5303
- int orderNo3 = User->getIROrder ();
5304
- if (orderNo3 > orderNo) {
5305
- opIsLive = true ;
5306
- break ;
5307
- }
5311
+ // Now, check if at least one of the FMUL's operands is live beyond the
5312
+ // node N, which guarantees that the FMA will not increase register
5313
+ // pressure at node N.
5314
+ bool opIsLive = false ;
5315
+ const SDNode *left = N0.getOperand (0 ).getNode ();
5316
+ const SDNode *right = N0.getOperand (1 ).getNode ();
5317
+
5318
+ if (isa<ConstantSDNode>(left) || isa<ConstantSDNode>(right))
5319
+ opIsLive = true ;
5320
+
5321
+ if (!opIsLive)
5322
+ for (const SDNode *User : left->uses ()) {
5323
+ int orderNo3 = User->getIROrder ();
5324
+ if (orderNo3 > orderNo) {
5325
+ opIsLive = true ;
5326
+ break ;
5308
5327
}
5328
+ }
5309
5329
5310
- if (!opIsLive)
5311
- return SDValue ();
5312
- }
5330
+ if (!opIsLive)
5331
+ for (const SDNode *User : right->uses ()) {
5332
+ int orderNo3 = User->getIROrder ();
5333
+ if (orderNo3 > orderNo) {
5334
+ opIsLive = true ;
5335
+ break ;
5336
+ }
5337
+ }
5313
5338
5314
- return DAG. getNode (ISD::FMA, SDLoc (N), VT,
5315
- N0. getOperand ( 0 ), N0. getOperand ( 1 ), N1 );
5339
+ if (!opIsLive)
5340
+ return SDValue ( );
5316
5341
}
5342
+
5343
+ return DCI.DAG .getNode (ISD::FMA, SDLoc (N), VT, N0.getOperand (0 ),
5344
+ N0.getOperand (1 ), N1);
5317
5345
}
5318
5346
5319
5347
return SDValue ();
@@ -5334,18 +5362,44 @@ static SDValue PerformStoreRetvalCombine(SDNode *N) {
5334
5362
// /
5335
5363
// Added ('+') version: combines are attempted only when optimizing and only
// for scalar i32 adds; both operand orders are tried via
// PerformADDCombineWithOperands.
static SDValue PerformADDCombine (SDNode *N,
5336
5364
TargetLowering::DAGCombinerInfo &DCI,
5337
- const NVPTXSubtarget &Subtarget,
5365
+ CodeGenOptLevel OptLevel) {
5366
// No combining at all when OptLevel is CodeGenOptLevel::None.
+ if (OptLevel == CodeGenOptLevel::None)
5367
+ return SDValue ();
5368
+
5369
+ SDValue N0 = N->getOperand (0 );
5370
+ SDValue N1 = N->getOperand (1 );
5371
+
5372
+ // Only scalar i32 is handled; bail out on vectors and on every other type.
5373
+ EVT VT = N0.getValueType ();
5374
+ if (VT.isVector () || VT != MVT::i32)
5375
+ return SDValue ();
5376
+
5377
+ // First try with the default operand order.
5378
+ if (SDValue Result = PerformADDCombineWithOperands (N, N0, N1, DCI))
5379
+ return Result;
5380
+
5381
+ // If that didn't work, try again with the operands commuted.
5382
+ return PerformADDCombineWithOperands (N, N1, N0, DCI);
5383
+ }
5384
+
5385
+ /// PerformFADDCombine - Target-specific DAG combine transforms for ISD::FADD.
5386
+ /// Only scalar f32 and f64 are handled; both operand orders are tried.
5387
+ static SDValue PerformFADDCombine (SDNode *N,
5388
+ TargetLowering::DAGCombinerInfo &DCI,
5338
5389
CodeGenOptLevel OptLevel) {
5339
5390
SDValue N0 = N->getOperand (0 );
5340
5391
SDValue N1 = N->getOperand (1 );
5341
5392
5393
// Bail out with an empty SDValue on vectors and on any type other than
// f32 / f64.
+ EVT VT = N0.getValueType ();
5394
+ if (VT.isVector () || !(VT == MVT::f32 || VT == MVT::f64))
5395
+ return SDValue ();
5396
+
5342
5397
// First try with the default operand order.
5343
- if (SDValue Result =
5344
- PerformADDCombineWithOperands (N, N0, N1, DCI, Subtarget, OptLevel))
5398
+ if (SDValue Result = PerformFADDCombineWithOperands (N, N0, N1, DCI, OptLevel))
5345
5399
return Result;
5346
5400
5347
5401
// If that didn't work, try again with the operands commuted.
5348
- return PerformADDCombineWithOperands (N, N1, N0, DCI, Subtarget , OptLevel);
5402
+ return PerformFADDCombineWithOperands (N, N1, N0, DCI, OptLevel);
5349
5403
}
5350
5404
5351
5405
static SDValue PerformANDCombine (SDNode *N,
@@ -5878,8 +5932,9 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
5878
5932
switch (N->getOpcode ()) {
5879
5933
default : break ;
5880
5934
case ISD::ADD:
5935
+ return PerformADDCombine (N, DCI, OptLevel);
5881
5936
case ISD::FADD:
5882
- return PerformADDCombine (N, DCI, STI , OptLevel);
5937
+ return PerformFADDCombine (N, DCI, OptLevel);
5883
5938
case ISD::MUL:
5884
5939
return PerformMULCombine (N, DCI, OptLevel);
5885
5940
case ISD::SHL:
0 commit comments