Skip to content

Commit a952b64

Browse files
committed
Add constant-folding for unary NVVM intrinsics
Add support for constant-folding numerous NVVM unary arithmetic intrinsics (including f, d, and ftz_f variants): - nvvm.ceil.* - nvvm.cos.approx.* - nvvm.ex2.approx.* - nvvm.fabs.* - nvvm.floor.* - nvvm.lg2.approx.* - nvvm.rcp.* - nvvm.round.* - nvvm.rsqrt.approx.* - nvvm.saturate.* - nvvm.sin.approx.* - nvvm.sqrt.f - nvvm.sqrt.rn.* - nvvm.sqrt.approx.*
1 parent 23a341e commit a952b64

File tree

3 files changed

+1364
-4
lines changed

3 files changed

+1364
-4
lines changed

llvm/include/llvm/IR/NVVMIntrinsicUtils.h

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,128 @@ inline bool FMinFMaxIsXorSignAbs(Intrinsic::ID IntrinsicID) {
334334
return false;
335335
}
336336

337+
/// Reports whether \p IntrinsicID names one of the `.ftz` variants of the
/// unary NVVM math intrinsics listed below (ceil, cos, ex2, fabs, floor, lg2,
/// round, rsqrt, saturate, sin, sqrt).
///
/// \returns false for the plain `_f` / `_d` variants and true for the
/// `_ftz_*` variants. Any intrinsic that is not listed is a caller error and
/// hits llvm_unreachable.
inline bool UnaryMathIntrinsicShouldFTZ(Intrinsic::ID IntrinsicID) {
  switch (IntrinsicID) {
  // Variants without the ftz modifier.
  case Intrinsic::nvvm_ceil_d:
  case Intrinsic::nvvm_ceil_f:
  case Intrinsic::nvvm_cos_approx_f:
  case Intrinsic::nvvm_ex2_approx_d:
  case Intrinsic::nvvm_ex2_approx_f:
  case Intrinsic::nvvm_fabs_d:
  case Intrinsic::nvvm_fabs_f:
  case Intrinsic::nvvm_floor_d:
  case Intrinsic::nvvm_floor_f:
  case Intrinsic::nvvm_lg2_approx_d:
  case Intrinsic::nvvm_lg2_approx_f:
  case Intrinsic::nvvm_round_d:
  case Intrinsic::nvvm_round_f:
  case Intrinsic::nvvm_rsqrt_approx_d:
  case Intrinsic::nvvm_rsqrt_approx_f:
  case Intrinsic::nvvm_saturate_d:
  case Intrinsic::nvvm_saturate_f:
  case Intrinsic::nvvm_sin_approx_f:
  case Intrinsic::nvvm_sqrt_f:
  case Intrinsic::nvvm_sqrt_approx_f:
  case Intrinsic::nvvm_sqrt_rn_d:
  case Intrinsic::nvvm_sqrt_rn_f:
    return false;

  // Variants carrying the ftz modifier.
  case Intrinsic::nvvm_ceil_ftz_f:
  case Intrinsic::nvvm_cos_approx_ftz_f:
  case Intrinsic::nvvm_ex2_approx_ftz_f:
  case Intrinsic::nvvm_fabs_ftz_f:
  case Intrinsic::nvvm_floor_ftz_f:
  case Intrinsic::nvvm_lg2_approx_ftz_f:
  case Intrinsic::nvvm_round_ftz_f:
  case Intrinsic::nvvm_rsqrt_approx_ftz_d:
  case Intrinsic::nvvm_rsqrt_approx_ftz_f:
  case Intrinsic::nvvm_saturate_ftz_f:
  case Intrinsic::nvvm_sin_approx_ftz_f:
  case Intrinsic::nvvm_sqrt_approx_ftz_f:
  case Intrinsic::nvvm_sqrt_rn_ftz_f:
    return true;
  }
  llvm_unreachable("Checking FTZ flag for invalid unary intrinsic");
  return false;
}
380+
381+
/// Reports whether \p IntrinsicID is an `nvvm.rcp.*` intrinsic that carries
/// the `.ftz` modifier.
///
/// \returns true for every `rcp_*_ftz_*` variant (including the approximate
/// ones) and false for the plain `rm`/`rn`/`rp`/`rz` variants. Passing any
/// other intrinsic is a caller error and hits llvm_unreachable.
inline bool RCPShouldFTZ(Intrinsic::ID IntrinsicID) {
  switch (IntrinsicID) {
  // Plain variants, grouped by rounding suffix.
  case Intrinsic::nvvm_rcp_rm_f:
  case Intrinsic::nvvm_rcp_rm_d:
  case Intrinsic::nvvm_rcp_rn_f:
  case Intrinsic::nvvm_rcp_rn_d:
  case Intrinsic::nvvm_rcp_rp_f:
  case Intrinsic::nvvm_rcp_rp_d:
  case Intrinsic::nvvm_rcp_rz_f:
  case Intrinsic::nvvm_rcp_rz_d:
    return false;

  // ftz variants.
  case Intrinsic::nvvm_rcp_approx_ftz_d:
  case Intrinsic::nvvm_rcp_approx_ftz_f:
  case Intrinsic::nvvm_rcp_rm_ftz_f:
  case Intrinsic::nvvm_rcp_rn_ftz_f:
  case Intrinsic::nvvm_rcp_rp_ftz_f:
  case Intrinsic::nvvm_rcp_rz_ftz_f:
    return true;
  }
  llvm_unreachable("Checking FTZ flag for invalid rcp intrinsic");
  return false;
}
403+
404+
/// Maps an `nvvm.rcp.*` intrinsic to the APFloat rounding mode selected by its
/// rounding suffix:
///   - `rm` -> APFloat::rmTowardNegative
///   - `rn` -> APFloat::rmNearestTiesToEven
///   - `rp` -> APFloat::rmTowardPositive
///   - `rz` -> APFloat::rmTowardZero
/// The `approx` variants have no rounding suffix; they are evaluated with
/// round-to-nearest-even.
///
/// Passing any other intrinsic is a caller error and hits llvm_unreachable.
inline APFloat::roundingMode GetRCPRoundingMode(Intrinsic::ID IntrinsicID) {
  switch (IntrinsicID) {
  case Intrinsic::nvvm_rcp_rm_f:
  case Intrinsic::nvvm_rcp_rm_d:
  case Intrinsic::nvvm_rcp_rm_ftz_f:
    return APFloat::rmTowardNegative;

  case Intrinsic::nvvm_rcp_approx_ftz_f:
  case Intrinsic::nvvm_rcp_approx_ftz_d:
  case Intrinsic::nvvm_rcp_rn_f:
  case Intrinsic::nvvm_rcp_rn_d:
  case Intrinsic::nvvm_rcp_rn_ftz_f:
    return APFloat::rmNearestTiesToEven;

  case Intrinsic::nvvm_rcp_rp_f:
  case Intrinsic::nvvm_rcp_rp_d:
  case Intrinsic::nvvm_rcp_rp_ftz_f:
    // "rp" rounds toward positive infinity; it must not share the
    // nearest-even mode used by the "rn" variants above.
    return APFloat::rmTowardPositive;

  case Intrinsic::nvvm_rcp_rz_f:
  case Intrinsic::nvvm_rcp_rz_d:
  case Intrinsic::nvvm_rcp_rz_ftz_f:
    return APFloat::rmTowardZero;
  }
  llvm_unreachable("Checking rounding mode for invalid rcp intrinsic");
  return APFloat::roundingMode::Invalid;
}
431+
432+
/// Reports whether \p IntrinsicID is one of the approximate `nvvm.rcp`
/// intrinsics (`rcp_approx_ftz_f` / `rcp_approx_ftz_d`).
///
/// \returns false for every variant with an explicit rounding suffix
/// (`rm`, `rn`, `rp`, `rz`), with or without ftz. Passing any other intrinsic
/// is a caller error and hits llvm_unreachable.
inline bool RCPIsApprox(Intrinsic::ID IntrinsicID) {
  switch (IntrinsicID) {
  // Variants with an explicit rounding suffix.
  case Intrinsic::nvvm_rcp_rm_d:
  case Intrinsic::nvvm_rcp_rm_f:
  case Intrinsic::nvvm_rcp_rm_ftz_f:
  case Intrinsic::nvvm_rcp_rn_d:
  case Intrinsic::nvvm_rcp_rn_f:
  case Intrinsic::nvvm_rcp_rn_ftz_f:
  case Intrinsic::nvvm_rcp_rp_d:
  case Intrinsic::nvvm_rcp_rp_f:
  case Intrinsic::nvvm_rcp_rp_ftz_f:
  case Intrinsic::nvvm_rcp_rz_d:
  case Intrinsic::nvvm_rcp_rz_f:
  case Intrinsic::nvvm_rcp_rz_ftz_f:
    return false;

  // Approximate variants.
  case Intrinsic::nvvm_rcp_approx_ftz_d:
  case Intrinsic::nvvm_rcp_approx_ftz_f:
    return true;
  }
  llvm_unreachable("Checking approx flag for invalid rcp intrinsic");
  return false;
}
458+
337459
} // namespace nvvm
338460
} // namespace llvm
339461
#endif // LLVM_IR_NVVMINTRINSICUTILS_H

llvm/lib/Analysis/ConstantFolding.cpp

Lines changed: 207 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1791,6 +1791,69 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) {
17911791
case Intrinsic::nearbyint:
17921792
case Intrinsic::rint:
17931793
case Intrinsic::canonicalize:
1794+
1795+
// NVVM math intrinsics:
1796+
case Intrinsic::nvvm_ceil_d:
1797+
case Intrinsic::nvvm_ceil_f:
1798+
case Intrinsic::nvvm_ceil_ftz_f:
1799+
1800+
case Intrinsic::nvvm_cos_approx_f:
1801+
case Intrinsic::nvvm_cos_approx_ftz_f:
1802+
1803+
case Intrinsic::nvvm_ex2_approx_d:
1804+
case Intrinsic::nvvm_ex2_approx_f:
1805+
case Intrinsic::nvvm_ex2_approx_ftz_f:
1806+
1807+
case Intrinsic::nvvm_fabs_d:
1808+
case Intrinsic::nvvm_fabs_f:
1809+
case Intrinsic::nvvm_fabs_ftz_f:
1810+
1811+
case Intrinsic::nvvm_floor_d:
1812+
case Intrinsic::nvvm_floor_f:
1813+
case Intrinsic::nvvm_floor_ftz_f:
1814+
1815+
case Intrinsic::nvvm_lg2_approx_d:
1816+
case Intrinsic::nvvm_lg2_approx_f:
1817+
case Intrinsic::nvvm_lg2_approx_ftz_f:
1818+
1819+
case Intrinsic::nvvm_rcp_rm_d:
1820+
case Intrinsic::nvvm_rcp_rm_f:
1821+
case Intrinsic::nvvm_rcp_rm_ftz_f:
1822+
case Intrinsic::nvvm_rcp_rn_d:
1823+
case Intrinsic::nvvm_rcp_rn_f:
1824+
case Intrinsic::nvvm_rcp_rn_ftz_f:
1825+
case Intrinsic::nvvm_rcp_rp_d:
1826+
case Intrinsic::nvvm_rcp_rp_f:
1827+
case Intrinsic::nvvm_rcp_rp_ftz_f:
1828+
case Intrinsic::nvvm_rcp_rz_d:
1829+
case Intrinsic::nvvm_rcp_rz_f:
1830+
case Intrinsic::nvvm_rcp_rz_ftz_f:
1831+
case Intrinsic::nvvm_rcp_approx_ftz_d:
1832+
case Intrinsic::nvvm_rcp_approx_ftz_f:
1833+
1834+
case Intrinsic::nvvm_round_d:
1835+
case Intrinsic::nvvm_round_f:
1836+
case Intrinsic::nvvm_round_ftz_f:
1837+
1838+
case Intrinsic::nvvm_rsqrt_approx_d:
1839+
case Intrinsic::nvvm_rsqrt_approx_f:
1840+
case Intrinsic::nvvm_rsqrt_approx_ftz_d:
1841+
case Intrinsic::nvvm_rsqrt_approx_ftz_f:
1842+
1843+
case Intrinsic::nvvm_saturate_d:
1844+
case Intrinsic::nvvm_saturate_f:
1845+
case Intrinsic::nvvm_saturate_ftz_f:
1846+
1847+
case Intrinsic::nvvm_sin_approx_f:
1848+
case Intrinsic::nvvm_sin_approx_ftz_f:
1849+
1850+
case Intrinsic::nvvm_sqrt_f:
1851+
case Intrinsic::nvvm_sqrt_rn_d:
1852+
case Intrinsic::nvvm_sqrt_rn_f:
1853+
case Intrinsic::nvvm_sqrt_rn_ftz_f:
1854+
case Intrinsic::nvvm_sqrt_approx_f:
1855+
case Intrinsic::nvvm_sqrt_approx_ftz_f:
1856+
17941857
// Constrained intrinsics can be folded if FP environment is known
17951858
// to compiler.
17961859
case Intrinsic::experimental_constrained_fma:
@@ -1944,16 +2007,23 @@ static const APFloat FTZPreserveSign(const APFloat &V) {
19442007
return V;
19452008
}
19462009

1947-
Constant *ConstantFoldFP(double (*NativeFP)(double), const APFloat &V,
1948-
Type *Ty) {
2010+
Constant *ConstantFoldFP(double (*NativeFP)(double), const APFloat &V, Type *Ty,
2011+
bool ShouldFTZPreservingSign = false) {
19492012
llvm_fenv_clearexcept();
1950-
double Result = NativeFP(V.convertToDouble());
2013+
auto Input = ShouldFTZPreservingSign ? FTZPreserveSign(V) : V;
2014+
double Result = NativeFP(Input.convertToDouble());
19512015
if (llvm_fenv_testexcept()) {
19522016
llvm_fenv_clearexcept();
19532017
return nullptr;
19542018
}
19552019

1956-
return GetConstantFoldFPValue(Result, Ty);
2020+
Constant *Output = GetConstantFoldFPValue(Result, Ty);
2021+
if (ShouldFTZPreservingSign) {
2022+
const auto *CFP = static_cast<ConstantFP *>(Output);
2023+
return ConstantFP::get(Ty->getContext(),
2024+
FTZPreserveSign(CFP->getValueAPF()));
2025+
}
2026+
return Output;
19572027
}
19582028

19592029
#if defined(HAS_IEE754_FLOAT128) && defined(HAS_LOGF128)
@@ -2524,6 +2594,139 @@ static Constant *ConstantFoldScalarCall1(StringRef Name,
25242594
return ConstantFoldFP(cosh, APF, Ty);
25252595
case Intrinsic::sqrt:
25262596
return ConstantFoldFP(sqrt, APF, Ty);
2597+
2598+
// NVVM Intrinsics:
2599+
case Intrinsic::nvvm_ceil_ftz_f:
2600+
case Intrinsic::nvvm_ceil_f:
2601+
case Intrinsic::nvvm_ceil_d:
2602+
return ConstantFoldFP(ceil, APF, Ty,
2603+
nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));
2604+
2605+
case Intrinsic::nvvm_cos_approx_ftz_f:
2606+
case Intrinsic::nvvm_cos_approx_f:
2607+
return ConstantFoldFP(cos, APF, Ty,
2608+
nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));
2609+
2610+
case Intrinsic::nvvm_ex2_approx_ftz_f:
2611+
case Intrinsic::nvvm_ex2_approx_d:
2612+
case Intrinsic::nvvm_ex2_approx_f:
2613+
return ConstantFoldFP(exp2, APF, Ty,
2614+
nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));
2615+
2616+
case Intrinsic::nvvm_fabs_ftz_f:
2617+
case Intrinsic::nvvm_fabs_d:
2618+
case Intrinsic::nvvm_fabs_f:
2619+
return ConstantFoldFP(fabs, APF, Ty,
2620+
nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));
2621+
2622+
case Intrinsic::nvvm_floor_ftz_f:
2623+
case Intrinsic::nvvm_floor_f:
2624+
case Intrinsic::nvvm_floor_d:
2625+
return ConstantFoldFP(floor, APF, Ty,
2626+
nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));
2627+
2628+
case Intrinsic::nvvm_lg2_approx_ftz_f:
2629+
case Intrinsic::nvvm_lg2_approx_d:
2630+
case Intrinsic::nvvm_lg2_approx_f: {
2631+
if (APF.isNegative() || APF.isZero())
2632+
return nullptr;
2633+
return ConstantFoldFP(log2, APF, Ty,
2634+
nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));
2635+
}
2636+
2637+
case Intrinsic::nvvm_rcp_rm_ftz_f:
2638+
case Intrinsic::nvvm_rcp_rn_ftz_f:
2639+
case Intrinsic::nvvm_rcp_rp_ftz_f:
2640+
case Intrinsic::nvvm_rcp_rz_ftz_f:
2641+
case Intrinsic::nvvm_rcp_approx_ftz_f:
2642+
case Intrinsic::nvvm_rcp_approx_ftz_d:
2643+
case Intrinsic::nvvm_rcp_rm_d:
2644+
case Intrinsic::nvvm_rcp_rm_f:
2645+
case Intrinsic::nvvm_rcp_rn_d:
2646+
case Intrinsic::nvvm_rcp_rn_f:
2647+
case Intrinsic::nvvm_rcp_rp_d:
2648+
case Intrinsic::nvvm_rcp_rp_f:
2649+
case Intrinsic::nvvm_rcp_rz_d:
2650+
case Intrinsic::nvvm_rcp_rz_f: {
2651+
APFloat::roundingMode RoundMode = nvvm::GetRCPRoundingMode(IntrinsicID);
2652+
bool IsApprox = nvvm::RCPIsApprox(IntrinsicID);
2653+
bool IsFTZ = nvvm::RCPShouldFTZ(IntrinsicID);
2654+
2655+
auto Denominator = IsFTZ ? FTZPreserveSign(APF) : APF;
2656+
if (IsApprox && Denominator.isZero()) {
2657+
// According to the PTX spec, approximate rcp should return infinity
2658+
// with the same sign as the denominator when dividing by 0.
2659+
APFloat Inf = APFloat::getInf(APF.getSemantics(), APF.isNegative());
2660+
return ConstantFP::get(Ty->getContext(), Inf);
2661+
}
2662+
APFloat Res = APFloat::getOne(APF.getSemantics());
2663+
APFloat::opStatus Status = Res.divide(Denominator, RoundMode);
2664+
2665+
if (Status == APFloat::opOK || Status == APFloat::opInexact) {
2666+
if (IsFTZ)
2667+
Res = FTZPreserveSign(Res);
2668+
return ConstantFP::get(Ty->getContext(), Res);
2669+
}
2670+
return nullptr;
2671+
}
2672+
2673+
case Intrinsic::nvvm_round_ftz_f:
2674+
case Intrinsic::nvvm_round_f:
2675+
case Intrinsic::nvvm_round_d:
2676+
return ConstantFoldFP(round, APF, Ty,
2677+
nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));
2678+
2679+
case Intrinsic::nvvm_rsqrt_approx_ftz_d:
2680+
case Intrinsic::nvvm_rsqrt_approx_ftz_f:
2681+
case Intrinsic::nvvm_rsqrt_approx_d:
2682+
case Intrinsic::nvvm_rsqrt_approx_f: {
2683+
bool IsFTZ = nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID);
2684+
auto V = IsFTZ ? FTZPreserveSign(APF) : APF;
2685+
APFloat SqrtV(sqrt(V.convertToDouble()));
2686+
2687+
bool lost;
2688+
SqrtV.convert(APF.getSemantics(), APFloat::rmNearestTiesToEven, &lost);
2689+
2690+
APFloat Res = APFloat::getOne(APF.getSemantics());
2691+
Res.divide(SqrtV, APFloat::rmNearestTiesToEven);
2692+
2693+
// We do not need to flush the output for ftz because it is impossible
2694+
// for 1/sqrt(x) to be a denormal value. If x is the largest fp value,
2695+
// sqrt(x) will be a number with the exponent approximately halved and
2696+
// the reciprocal of that number can't be small enough to be denormal.
2697+
return ConstantFP::get(Ty->getContext(), Res);
2698+
}
2699+
2700+
case Intrinsic::nvvm_saturate_ftz_f:
2701+
case Intrinsic::nvvm_saturate_d:
2702+
case Intrinsic::nvvm_saturate_f: {
2703+
bool IsFTZ = nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID);
2704+
auto V = IsFTZ ? FTZPreserveSign(APF) : APF;
2705+
if (V.isNegative() || V.isZero() || V.isNaN())
2706+
return ConstantFP::getZero(Ty);
2707+
APFloat One = APFloat::getOne(APF.getSemantics());
2708+
if (V > One)
2709+
return ConstantFP::get(Ty->getContext(), One);
2710+
return ConstantFP::get(Ty->getContext(), APF);
2711+
}
2712+
2713+
case Intrinsic::nvvm_sin_approx_ftz_f:
2714+
case Intrinsic::nvvm_sin_approx_f:
2715+
return ConstantFoldFP(sin, APF, Ty,
2716+
nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));
2717+
2718+
case Intrinsic::nvvm_sqrt_rn_ftz_f:
2719+
case Intrinsic::nvvm_sqrt_approx_ftz_f:
2720+
case Intrinsic::nvvm_sqrt_f:
2721+
case Intrinsic::nvvm_sqrt_rn_d:
2722+
case Intrinsic::nvvm_sqrt_rn_f:
2723+
case Intrinsic::nvvm_sqrt_approx_f:
2724+
if (APF.isNegative())
2725+
return nullptr;
2726+
return ConstantFoldFP(sqrt, APF, Ty,
2727+
nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));
2728+
2729+
// AMDGCN Intrinsics:
25272730
case Intrinsic::amdgcn_cos:
25282731
case Intrinsic::amdgcn_sin: {
25292732
double V = getValueAsDouble(Op);

0 commit comments

Comments
 (0)