[libc][math][c23] Add f16fmaf C23 math function #95483

Merged: 12 commits, Jun 14, 2024
Changes from 4 commits
1 change: 1 addition & 0 deletions libc/config/linux/aarch64/entrypoints.txt
@@ -503,6 +503,7 @@ if(LIBC_TYPES_HAS_FLOAT16)
libc.src.math.canonicalizef16
libc.src.math.ceilf16
libc.src.math.copysignf16
libc.src.math.f16fmaf
libc.src.math.f16sqrtf
libc.src.math.fabsf16
libc.src.math.fdimf16
1 change: 1 addition & 0 deletions libc/config/linux/x86_64/entrypoints.txt
@@ -535,6 +535,7 @@ if(LIBC_TYPES_HAS_FLOAT16)
libc.src.math.canonicalizef16
libc.src.math.ceilf16
libc.src.math.copysignf16
libc.src.math.f16fmaf
libc.src.math.f16sqrtf
libc.src.math.fabsf16
libc.src.math.fdimf16
2 changes: 2 additions & 0 deletions libc/docs/math/index.rst
@@ -124,6 +124,8 @@ Basic Operations
+------------------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
| dsub | N/A | N/A | | N/A | | 7.12.14.2 | F.10.11 |
+------------------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
| f16fma | |check| | | | N/A | | 7.12.14.5 | F.10.11 |
+------------------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
| fabs | |check| | |check| | |check| | |check| | |check| | 7.12.7.3 | F.10.4.3 |
+------------------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
| fadd | N/A | | | N/A | | 7.12.14.1 | F.10.11 |
2 changes: 2 additions & 0 deletions libc/spec/stdc.td
@@ -715,6 +715,8 @@ def StdC : StandardSpec<"stdc"> {

GuardedFunctionSpec<"totalordermagf16", RetValSpec<IntType>, [ArgSpec<Float16Ptr>, ArgSpec<Float16Ptr>], "LIBC_TYPES_HAS_FLOAT16">,

GuardedFunctionSpec<"f16fmaf", RetValSpec<Float16Type>, [ArgSpec<FloatType>, ArgSpec<FloatType>, ArgSpec<FloatType>], "LIBC_TYPES_HAS_FLOAT16">,

GuardedFunctionSpec<"f16sqrtf", RetValSpec<Float16Type>, [ArgSpec<FloatType>], "LIBC_TYPES_HAS_FLOAT16">,
]
>;
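For reference, the spec entry above corresponds to the C23 declaration _Float16 f16fmaf(float x, float y, float z). A minimal caller sketch, assuming a toolchain with _Float16 support and a libc build where LIBC_TYPES_HAS_FLOAT16 is defined (the values are illustrative only):

#include <math.h>

int main(void) {
  float a = 0x1.0p-12f;
  float b = 0x1.0p-12f;
  float c = 1.0f;
  // The product and sum are computed exactly and rounded once, directly to
  // _Float16, rather than being rounded to float first and then narrowed.
  _Float16 r = f16fmaf(a, b, c);
  return (float)r == 1.0f ? 0 : 1;
}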
32 changes: 10 additions & 22 deletions libc/src/__support/FPUtil/FMA.h
@@ -10,41 +10,29 @@
#define LLVM_LIBC_SRC___SUPPORT_FPUTIL_FMA_H

#include "src/__support/CPP/type_traits.h"
#include "src/__support/FPUtil/generic/FMA.h"
#include "src/__support/macros/properties/architectures.h"
#include "src/__support/macros/properties/cpu_features.h" // LIBC_TARGET_CPU_HAS_FMA

#if defined(LIBC_TARGET_CPU_HAS_FMA)

namespace LIBC_NAMESPACE {
namespace fputil {

template <typename T>
LIBC_INLINE cpp::enable_if_t<cpp::is_same_v<T, float>, T> fma(T x, T y, T z) {
return __builtin_fmaf(x, y, z);
template <typename OutType, typename InType>
LIBC_INLINE OutType fma(InType x, InType y, InType z) {
return generic::fma<OutType>(x, y, z);
}

template <typename T>
LIBC_INLINE cpp::enable_if_t<cpp::is_same_v<T, double>, T> fma(T x, T y, T z) {
return __builtin_fma(x, y, z);
#ifdef LIBC_TARGET_CPU_HAS_FMA
template <> LIBC_INLINE float fma(float x, float y, float z) {
return __builtin_fmaf(x, y, z);
}

} // namespace fputil
} // namespace LIBC_NAMESPACE

#else
// FMA instructions are not available
#include "generic/FMA.h"

namespace LIBC_NAMESPACE {
namespace fputil {

template <typename T> LIBC_INLINE T fma(T x, T y, T z) {
return generic::fma(x, y, z);
template <> LIBC_INLINE double fma(double x, double y, double z) {
return __builtin_fma(x, y, z);
}
#endif // LIBC_TARGET_CPU_HAS_FMA

} // namespace fputil
} // namespace LIBC_NAMESPACE

#endif

#endif // LLVM_LIBC_SRC___SUPPORT_FPUTIL_FMA_H
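With this dispatcher in place, an entrypoint such as the new f16fmaf can sit on top of the type-generic path, while the float and double specializations keep using the FMA builtins when LIBC_TARGET_CPU_HAS_FMA is defined. A sketch of what such an entrypoint looks like (the entrypoint source file itself is not shown in this excerpt; the sketch assumes the usual LLVM_LIBC_FUNCTION macro and the libc float16 alias):

// Sketch only: roughly how f16fmaf can be written in terms of fputil::fma.
#include "src/__support/FPUtil/FMA.h"
#include "src/__support/common.h"

namespace LIBC_NAMESPACE {

LLVM_LIBC_FUNCTION(float16, f16fmaf, (float x, float y, float z)) {
  // The output type is given explicitly; the input type is deduced as float.
  return fputil::fma<float16>(x, y, z);
}

} // namespace LIBC_NAMESPACE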
3 changes: 3 additions & 0 deletions libc/src/__support/FPUtil/generic/CMakeLists.txt
@@ -19,12 +19,15 @@ add_header_library(
HDRS
FMA.h
DEPENDS
libc.hdr.fenv_macros
libc.src.__support.common
libc.src.__support.CPP.bit
libc.src.__support.CPP.limits
libc.src.__support.CPP.type_traits
libc.src.__support.FPUtil.fenv_impl
libc.src.__support.FPUtil.fp_bits
libc.src.__support.FPUtil.rounding_mode
libc.src.__support.big_int
libc.src.__support.macros.optimization
libc.src.__support.uint128
)
145 changes: 86 additions & 59 deletions libc/src/__support/FPUtil/generic/FMA.h
@@ -10,20 +10,28 @@
#define LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_FMA_H

#include "src/__support/CPP/bit.h"
#include "src/__support/CPP/limits.h"
#include "src/__support/CPP/type_traits.h"
#include "src/__support/FPUtil/FEnvImpl.h"
#include "src/__support/FPUtil/FPBits.h"
#include "src/__support/FPUtil/rounding_mode.h"
#include "src/__support/big_int.h"
#include "src/__support/macros/attributes.h" // LIBC_INLINE
#include "src/__support/macros/optimization.h" // LIBC_UNLIKELY
#include "src/__support/uint128.h"

#include "hdr/fenv_macros.h"

namespace LIBC_NAMESPACE {
namespace fputil {
namespace generic {

template <typename T> LIBC_INLINE T fma(T x, T y, T z);
template <typename OutType, typename InType>
LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<OutType> &&
cpp::is_floating_point_v<InType> &&
sizeof(OutType) <= sizeof(InType),
OutType>
fma(InType x, InType y, InType z);

#ifndef LIBC_TARGET_CPU_HAS_FMA
// TODO(lntue): Implement fmaf that is correctly rounded to all rounding modes.
// The implementation below is only correct for the default rounding mode,
// round-to-nearest tie-to-even.
Member Author:

Should we use the type-generic fma from this PR to implement fmaf correctly rounded for all rounding modes, or is a specialized implementation that is both faster and correct for all rounding modes planned?

Contributor:

Can you double-check whether this implementation is correct for all rounding modes? The only thing that kept me from claiming the previous one as correctly rounded in other rounding modes is that Dekker's 2Sum might not be exact in those modes. But a recent paper by Paul Zimmermann shows that even in other rounding modes, the error of the 2Sum algorithm stays very close to double-double precision ULP, and in this case that should mean it is exact, since the precisions of the summands are 46 and 23 bits respectively.
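For context, a minimal sketch of the 2Sum error-free transformation being discussed (the branch-free Knuth/Møller form; Dekker's Fast2Sum variant additionally assumes |a| >= |b|), written for double only to illustrate the error term whose exactness depends on the rounding mode:

// In round-to-nearest, a + b == s + e holds exactly; in directed rounding
// modes the error term e may itself be inexact, which is the concern above.
static void two_sum(double a, double b, double &s, double &e) {
  s = a + b;
  double b_virtual = s - a;               // the part of b that went into s
  double a_virtual = s - b_virtual;       // the part of a that went into s
  e = (a - a_virtual) + (b - b_virtual);  // what rounding discarded
}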

Member Author:

It passes the fmaf MPFR unit test (libc.test.src.math.fmaf_test.__unit__).

Also, I just realized that the #ifndef LIBC_TARGET_CPU_HAS_FMA guard is useless, as this is fputil::generic::fma, not fputil::fma.

@@ -74,65 +82,86 @@ template <> LIBC_INLINE float fma<float>(float x, float y, float z) {

return static_cast<float>(bit_sum.get_val());
}
#endif // LIBC_TARGET_CPU_HAS_FMA

namespace internal {

// Extract the sticky bits and shift the `mantissa` to the right by
// `shift_length`.
LIBC_INLINE bool shift_mantissa(int shift_length, UInt128 &mant) {
if (shift_length >= 128) {
template <typename T>
LIBC_INLINE cpp::enable_if_t<is_unsigned_integral_or_big_int_v<T>, bool>
shift_mantissa(int shift_length, T &mant) {
if (shift_length >= cpp::numeric_limits<T>::digits) {
mant = 0;
return true; // prod_mant is non-zero.
}
UInt128 mask = (UInt128(1) << shift_length) - 1;
T mask = (T(1) << shift_length) - 1;
bool sticky_bits = (mant & mask) != 0;
mant >>= shift_length;
return sticky_bits;
}

} // namespace internal
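// Illustrative use of shift_mantissa (hypothetical values, not part of this
// patch): non-zero bits that are shifted out are reported back as sticky bits
// so later rounding decisions stay correct. For example:
//   uint32_t mant = 0b10110;
//   bool sticky = internal::shift_mantissa(3, mant);
//   // Now mant == 0b10 and sticky == true, since the dropped bits 0b110 != 0.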

template <> LIBC_INLINE double fma<double>(double x, double y, double z) {
using FPBits = fputil::FPBits<double>;
template <typename OutType, typename InType>
LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<OutType> &&
cpp::is_floating_point_v<InType> &&
sizeof(OutType) <= sizeof(InType),
OutType>
fma(InType x, InType y, InType z) {
using OutFPBits = fputil::FPBits<OutType>;
using OutStorageType = typename OutFPBits::StorageType;
using InFPBits = fputil::FPBits<InType>;
using InStorageType = typename InFPBits::StorageType;

constexpr int IN_EXPLICIT_MANT_LEN = InFPBits::FRACTION_LEN + 1;
constexpr size_t PROD_LEN = 2 * (IN_EXPLICIT_MANT_LEN);
constexpr size_t TMP_RESULT_LEN = cpp::bit_ceil(PROD_LEN + 1);
using TmpResultType = UInt<TMP_RESULT_LEN>;

constexpr size_t EXTRA_FRACTION_LEN =
TMP_RESULT_LEN - 1 - OutFPBits::FRACTION_LEN;
constexpr TmpResultType EXTRA_FRACTION_STICKY_MASK =
(TmpResultType(1) << (EXTRA_FRACTION_LEN - 1)) - 1;

if (LIBC_UNLIKELY(x == 0 || y == 0 || z == 0)) {
return x * y + z;
return static_cast<OutType>(x * y + z);
}

int x_exp = 0;
int y_exp = 0;
int z_exp = 0;

// Normalize denormal inputs.
if (LIBC_UNLIKELY(FPBits(x).is_subnormal())) {
x_exp -= 52;
x *= 0x1.0p+52;
if (LIBC_UNLIKELY(InFPBits(x).is_subnormal())) {
x_exp -= InFPBits::FRACTION_LEN;
x *= InType(InStorageType(1) << InFPBits::FRACTION_LEN);
}
if (LIBC_UNLIKELY(FPBits(y).is_subnormal())) {
y_exp -= 52;
y *= 0x1.0p+52;
if (LIBC_UNLIKELY(InFPBits(y).is_subnormal())) {
y_exp -= InFPBits::FRACTION_LEN;
y *= InType(InStorageType(1) << InFPBits::FRACTION_LEN);
}
if (LIBC_UNLIKELY(FPBits(z).is_subnormal())) {
z_exp -= 52;
z *= 0x1.0p+52;
if (LIBC_UNLIKELY(InFPBits(z).is_subnormal())) {
z_exp -= InFPBits::FRACTION_LEN;
z *= InType(InStorageType(1) << InFPBits::FRACTION_LEN);
}

FPBits x_bits(x), y_bits(y), z_bits(z);
InFPBits x_bits(x), y_bits(y), z_bits(z);
const Sign z_sign = z_bits.sign();
Sign prod_sign = (x_bits.sign() == y_bits.sign()) ? Sign::POS : Sign::NEG;
x_exp += x_bits.get_biased_exponent();
y_exp += y_bits.get_biased_exponent();
z_exp += z_bits.get_biased_exponent();

if (LIBC_UNLIKELY(x_exp == FPBits::MAX_BIASED_EXPONENT ||
y_exp == FPBits::MAX_BIASED_EXPONENT ||
z_exp == FPBits::MAX_BIASED_EXPONENT))
return x * y + z;
if (LIBC_UNLIKELY(x_exp == InFPBits::MAX_BIASED_EXPONENT ||
y_exp == InFPBits::MAX_BIASED_EXPONENT ||
z_exp == InFPBits::MAX_BIASED_EXPONENT))
return static_cast<OutType>(x * y + z);

// Extract mantissa and append hidden leading bits.
UInt128 x_mant = x_bits.get_explicit_mantissa();
UInt128 y_mant = y_bits.get_explicit_mantissa();
UInt128 z_mant = z_bits.get_explicit_mantissa();
InStorageType x_mant = x_bits.get_explicit_mantissa();
InStorageType y_mant = y_bits.get_explicit_mantissa();
TmpResultType z_mant = z_bits.get_explicit_mantissa();

// If the exponent of the product x*y > the exponent of z, then no extra
// precision beside the entire product x*y is needed. On the other hand, when
@@ -144,21 +173,24 @@ template <> LIBC_INLINE double fma<double>(double x, double y, double z) {
// - prod : 1bb...bb....b
// In that case, in order to store the exact result, we need at least
// (Length of prod) - (MantissaLength of z) = 2*(52 + 1) - 52 = 54.
// TODO: 53? (Explicit mantissa.) ^
// Overall, before aligning the mantissas and exponents, we can simply left-
// shift the mantissa of z by at least 54, and left-shift the product of x*y
// by (that amount - 52). After that, it is enough to align the least
// TODO: ^ 54?
// significant bit, given that we keep track of the round and sticky bits
// after the least significant bit.
// We pick shifting z_mant by 64 bits so that technically we can simply use
// the original mantissa as high part when constructing 128-bit z_mant. So the
// mantissa of prod will be left-shifted by 64 - 54 = 10 initially.

UInt128 prod_mant = x_mant * y_mant << 10;
TmpResultType prod_mant = TmpResultType(x_mant) * y_mant;
int prod_lsb_exp =
x_exp + y_exp - (FPBits::EXP_BIAS + 2 * FPBits::FRACTION_LEN + 10);
x_exp + y_exp - (InFPBits::EXP_BIAS + 2 * InFPBits::FRACTION_LEN);

z_mant <<= 64;
int z_lsb_exp = z_exp - (FPBits::FRACTION_LEN + 64);
constexpr int RESULT_MIN_LEN = PROD_LEN - InFPBits::FRACTION_LEN;
z_mant <<= RESULT_MIN_LEN;
int z_lsb_exp = z_exp - (InFPBits::FRACTION_LEN + RESULT_MIN_LEN);
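  // Worked sizes for the f16fmaf case (OutType = float16, InType = float), as
  // an illustration of the constants above: InFPBits::FRACTION_LEN = 23, so
  // PROD_LEN = 48, TMP_RESULT_LEN = bit_ceil(49) = 64,
  // EXTRA_FRACTION_LEN = 64 - 1 - 10 = 53, and RESULT_MIN_LEN = 48 - 23 = 25;
  // z_mant is left-shifted by 25 while the 48-bit product fits the temporary.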
bool round_bit = false;
bool sticky_bits = false;
bool z_shifted = false;
@@ -198,46 +230,40 @@ template <> LIBC_INLINE double fma<double>(double x, double y, double z) {
}
}

uint64_t result = 0;
OutStorageType result = 0;
int r_exp = 0; // Unbiased exponent of the result

// Normalize the result.
if (prod_mant != 0) {
uint64_t prod_hi = static_cast<uint64_t>(prod_mant >> 64);
int lead_zeros =
prod_hi ? cpp::countl_zero(prod_hi)
: 64 + cpp::countl_zero(static_cast<uint64_t>(prod_mant));
int lead_zeros = cpp::countl_zero(prod_mant);
// Move the leading 1 to the most significant bit.
prod_mant <<= lead_zeros;
// The lower 64 bits are always sticky bits after moving the leading 1 to
// the most significant bit.
sticky_bits |= (static_cast<uint64_t>(prod_mant) != 0);
result = static_cast<uint64_t>(prod_mant >> 64);
// Change prod_lsb_exp the be the exponent of the least significant bit of
// the result.
prod_lsb_exp += 64 - lead_zeros;
r_exp = prod_lsb_exp + 63;
prod_lsb_exp -= lead_zeros;
r_exp = prod_lsb_exp + (cpp::numeric_limits<TmpResultType>::digits - 1) -
InFPBits::EXP_BIAS + OutFPBits::EXP_BIAS;

if (r_exp > 0) {
// The result is normal. We will shift the mantissa to the right by
// 63 - 52 = 11 bits (from the locations of the most significant bit).
// Then the rounding bit will correspond the 11th bit, and the lowest
// 10 bits are merged into sticky bits.
round_bit = (result & 0x0400ULL) != 0;
sticky_bits |= (result & 0x03ffULL) != 0;
result >>= 11;
round_bit =
(prod_mant & (TmpResultType(1) << (EXTRA_FRACTION_LEN - 1))) != 0;
sticky_bits |= (prod_mant & EXTRA_FRACTION_STICKY_MASK) != 0;
result = static_cast<OutStorageType>(prod_mant >> EXTRA_FRACTION_LEN);
} else {
if (r_exp < -52) {
if (r_exp < -OutFPBits::FRACTION_LEN) {
// The result is smaller than 1/2 of the smallest denormal number.
sticky_bits = true; // since the result is non-zero.
result = 0;
} else {
// The result is denormal.
uint64_t mask = 1ULL << (11 - r_exp);
round_bit = (result & mask) != 0;
sticky_bits |= (result & (mask - 1)) != 0;
if (r_exp > -52)
result >>= 12 - r_exp;
TmpResultType mask = TmpResultType(1) << (EXTRA_FRACTION_LEN - r_exp);
round_bit = (prod_mant & mask) != 0;
sticky_bits |= (prod_mant & (mask - 1)) != 0;
if (r_exp > -OutFPBits::FRACTION_LEN)
result = static_cast<OutStorageType>(
prod_mant >> (EXTRA_FRACTION_LEN + 1 - r_exp));
else
result = 0;
}
@@ -251,20 +277,21 @@ template <> LIBC_INLINE double fma<double>(double x, double y, double z) {

// Finalize the result.
int round_mode = fputil::quick_get_round();
if (LIBC_UNLIKELY(r_exp >= FPBits::MAX_BIASED_EXPONENT)) {
if (LIBC_UNLIKELY(r_exp >= OutFPBits::MAX_BIASED_EXPONENT)) {
if ((round_mode == FE_TOWARDZERO) ||
(round_mode == FE_UPWARD && prod_sign.is_neg()) ||
(round_mode == FE_DOWNWARD && prod_sign.is_pos())) {
return FPBits::max_normal(prod_sign).get_val();
return OutFPBits::max_normal(prod_sign).get_val();
}
return FPBits::inf(prod_sign).get_val();
return OutFPBits::inf(prod_sign).get_val();
}

// Remove hidden bit and append the exponent field and sign bit.
result = (result & FPBits::FRACTION_MASK) |
(static_cast<uint64_t>(r_exp) << FPBits::FRACTION_LEN);
result = static_cast<OutStorageType>(
(result & OutFPBits::FRACTION_MASK) |
(static_cast<OutStorageType>(r_exp) << OutFPBits::FRACTION_LEN));
if (prod_sign.is_neg()) {
result |= FPBits::SIGN_MASK;
result |= OutFPBits::SIGN_MASK;
}

// Rounding.
@@ -277,7 +304,7 @@ template <> LIBC_INLINE double fma<double>(double x, double y, double z) {
++result;
}

return cpp::bit_cast<double>(result);
return cpp::bit_cast<OutType>(result);
}

} // namespace generic
4 changes: 2 additions & 2 deletions libc/src/__support/FPUtil/multiply_add.h
@@ -45,11 +45,11 @@ namespace LIBC_NAMESPACE {
namespace fputil {

LIBC_INLINE float multiply_add(float x, float y, float z) {
return fma(x, y, z);
return fma<float>(x, y, z);
}

LIBC_INLINE double multiply_add(double x, double y, double z) {
return fma(x, y, z);
return fma<double>(x, y, z);
}

} // namespace fputil