-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[libc][math][c23] Add f16fmaf C23 math function #95483
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
0f9a842
06e598e
05bc63c
b112ac4
7e7c354
e409c82
bd971fe
5dae03b
9d4dc22
eb0ff6b
f3128ee
40bd58f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,20 +10,28 @@ | |
#define LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_FMA_H | ||
|
||
#include "src/__support/CPP/bit.h" | ||
#include "src/__support/CPP/limits.h" | ||
#include "src/__support/CPP/type_traits.h" | ||
#include "src/__support/FPUtil/FEnvImpl.h" | ||
#include "src/__support/FPUtil/FPBits.h" | ||
#include "src/__support/FPUtil/rounding_mode.h" | ||
#include "src/__support/big_int.h" | ||
#include "src/__support/macros/attributes.h" // LIBC_INLINE | ||
#include "src/__support/macros/optimization.h" // LIBC_UNLIKELY | ||
#include "src/__support/uint128.h" | ||
|
||
#include "hdr/fenv_macros.h" | ||
overmighty marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
namespace LIBC_NAMESPACE { | ||
namespace fputil { | ||
namespace generic { | ||
|
||
template <typename T> LIBC_INLINE T fma(T x, T y, T z); | ||
template <typename OutType, typename InType> | ||
LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<OutType> && | ||
cpp::is_floating_point_v<InType> && | ||
sizeof(OutType) <= sizeof(InType), | ||
OutType> | ||
fma(InType x, InType y, InType z); | ||
|
||
#ifndef LIBC_TARGET_CPU_HAS_FMA | ||
// TODO(lntue): Implement fmaf that is correctly rounded to all rounding modes. | ||
// The implementation below is only correct for the default rounding mode, | ||
// round-to-nearest tie-to-even. | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should we use the type-generic There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you double check if this implementation is correct for all rounding modes? I think the only thing that was blocking me from claiming this one as correctly rounded in other rounding modes is that Dekker's 2Sum might not be exact in other rounding modes. But a recent paper by Paul Zimmermann shows that in other rounding modes, the error from the 2Sum algorithm is still very close to double-double precision ULP. And in this case, it should mean exact, since the precisions of the summands are 46 and 23 respectively. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It passes the Also, I just realized that the |
||
|
@@ -74,65 +82,86 @@ template <> LIBC_INLINE float fma<float>(float x, float y, float z) { | |
|
||
return static_cast<float>(bit_sum.get_val()); | ||
} | ||
#endif // LIBC_TARGET_CPU_HAS_FMA | ||
|
||
namespace internal { | ||
|
||
// Extract the sticky bits and shift the `mantissa` to the right by | ||
// `shift_length`. | ||
LIBC_INLINE bool shift_mantissa(int shift_length, UInt128 &mant) { | ||
if (shift_length >= 128) { | ||
template <typename T> | ||
LIBC_INLINE cpp::enable_if_t<is_unsigned_integral_or_big_int_v<T>, bool> | ||
shift_mantissa(int shift_length, T &mant) { | ||
if (shift_length >= cpp::numeric_limits<T>::digits) { | ||
mant = 0; | ||
return true; // prod_mant is non-zero. | ||
} | ||
UInt128 mask = (UInt128(1) << shift_length) - 1; | ||
T mask = (T(1) << shift_length) - 1; | ||
bool sticky_bits = (mant & mask) != 0; | ||
mant >>= shift_length; | ||
return sticky_bits; | ||
} | ||
|
||
} // namespace internal | ||
|
||
template <> LIBC_INLINE double fma<double>(double x, double y, double z) { | ||
using FPBits = fputil::FPBits<double>; | ||
template <typename OutType, typename InType> | ||
LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<OutType> && | ||
cpp::is_floating_point_v<InType> && | ||
sizeof(OutType) <= sizeof(InType), | ||
OutType> | ||
fma(InType x, InType y, InType z) { | ||
using OutFPBits = fputil::FPBits<OutType>; | ||
using OutStorageType = typename OutFPBits::StorageType; | ||
using InFPBits = fputil::FPBits<InType>; | ||
using InStorageType = typename InFPBits::StorageType; | ||
|
||
constexpr int IN_EXPLICIT_MANT_LEN = InFPBits::FRACTION_LEN + 1; | ||
constexpr size_t PROD_LEN = 2 * (IN_EXPLICIT_MANT_LEN); | ||
overmighty marked this conversation as resolved.
Show resolved
Hide resolved
|
||
constexpr size_t TMP_RESULT_LEN = cpp::bit_ceil(PROD_LEN + 1); | ||
using TmpResultType = UInt<TMP_RESULT_LEN>; | ||
|
||
constexpr size_t EXTRA_FRACTION_LEN = | ||
TMP_RESULT_LEN - 1 - OutFPBits::FRACTION_LEN; | ||
constexpr TmpResultType EXTRA_FRACTION_STICKY_MASK = | ||
(TmpResultType(1) << (EXTRA_FRACTION_LEN - 1)) - 1; | ||
|
||
if (LIBC_UNLIKELY(x == 0 || y == 0 || z == 0)) { | ||
return x * y + z; | ||
return static_cast<OutType>(x * y + z); | ||
} | ||
|
||
int x_exp = 0; | ||
int y_exp = 0; | ||
int z_exp = 0; | ||
|
||
// Normalize denormal inputs. | ||
if (LIBC_UNLIKELY(FPBits(x).is_subnormal())) { | ||
x_exp -= 52; | ||
x *= 0x1.0p+52; | ||
if (LIBC_UNLIKELY(InFPBits(x).is_subnormal())) { | ||
x_exp -= InFPBits::FRACTION_LEN; | ||
x *= InType(InStorageType(1) << InFPBits::FRACTION_LEN); | ||
} | ||
if (LIBC_UNLIKELY(FPBits(y).is_subnormal())) { | ||
y_exp -= 52; | ||
y *= 0x1.0p+52; | ||
if (LIBC_UNLIKELY(InFPBits(y).is_subnormal())) { | ||
y_exp -= InFPBits::FRACTION_LEN; | ||
y *= InType(InStorageType(1) << InFPBits::FRACTION_LEN); | ||
} | ||
if (LIBC_UNLIKELY(FPBits(z).is_subnormal())) { | ||
z_exp -= 52; | ||
z *= 0x1.0p+52; | ||
if (LIBC_UNLIKELY(InFPBits(z).is_subnormal())) { | ||
z_exp -= InFPBits::FRACTION_LEN; | ||
z *= InType(InStorageType(1) << InFPBits::FRACTION_LEN); | ||
} | ||
|
||
FPBits x_bits(x), y_bits(y), z_bits(z); | ||
InFPBits x_bits(x), y_bits(y), z_bits(z); | ||
const Sign z_sign = z_bits.sign(); | ||
Sign prod_sign = (x_bits.sign() == y_bits.sign()) ? Sign::POS : Sign::NEG; | ||
x_exp += x_bits.get_biased_exponent(); | ||
y_exp += y_bits.get_biased_exponent(); | ||
z_exp += z_bits.get_biased_exponent(); | ||
|
||
if (LIBC_UNLIKELY(x_exp == FPBits::MAX_BIASED_EXPONENT || | ||
y_exp == FPBits::MAX_BIASED_EXPONENT || | ||
z_exp == FPBits::MAX_BIASED_EXPONENT)) | ||
return x * y + z; | ||
if (LIBC_UNLIKELY(x_exp == InFPBits::MAX_BIASED_EXPONENT || | ||
y_exp == InFPBits::MAX_BIASED_EXPONENT || | ||
z_exp == InFPBits::MAX_BIASED_EXPONENT)) | ||
return static_cast<OutType>(x * y + z); | ||
|
||
// Extract mantissa and append hidden leading bits. | ||
UInt128 x_mant = x_bits.get_explicit_mantissa(); | ||
UInt128 y_mant = y_bits.get_explicit_mantissa(); | ||
UInt128 z_mant = z_bits.get_explicit_mantissa(); | ||
InStorageType x_mant = x_bits.get_explicit_mantissa(); | ||
InStorageType y_mant = y_bits.get_explicit_mantissa(); | ||
TmpResultType z_mant = z_bits.get_explicit_mantissa(); | ||
|
||
// If the exponent of the product x*y > the exponent of z, then no extra | ||
// precision besides the entire product x*y is needed. On the other hand, when | ||
|
@@ -144,21 +173,24 @@ template <> LIBC_INLINE double fma<double>(double x, double y, double z) { | |
// - prod : 1bb...bb....b | ||
// In that case, in order to store the exact result, we need at least | ||
// (Length of prod) - (MantissaLength of z) = 2*(52 + 1) - 52 = 54. | ||
// TODO: 53? (Explicit mantissa.) ^ | ||
// Overall, before aligning the mantissas and exponents, we can simply left- | ||
// shift the mantissa of z by at least 54, and left-shift the product of x*y | ||
// by (that amount - 52). After that, it is enough to align the least | ||
// TODO: ^ 54? | ||
// significant bit, given that we keep track of the round and sticky bits | ||
// after the least significant bit. | ||
// We pick shifting z_mant by 64 bits so that technically we can simply use | ||
// the original mantissa as high part when constructing 128-bit z_mant. So the | ||
// mantissa of prod will be left-shifted by 64 - 54 = 10 initially. | ||
overmighty marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
UInt128 prod_mant = x_mant * y_mant << 10; | ||
TmpResultType prod_mant = TmpResultType(x_mant) * y_mant; | ||
int prod_lsb_exp = | ||
x_exp + y_exp - (FPBits::EXP_BIAS + 2 * FPBits::FRACTION_LEN + 10); | ||
x_exp + y_exp - (InFPBits::EXP_BIAS + 2 * InFPBits::FRACTION_LEN); | ||
|
||
z_mant <<= 64; | ||
int z_lsb_exp = z_exp - (FPBits::FRACTION_LEN + 64); | ||
constexpr int RESULT_MIN_LEN = PROD_LEN - InFPBits::FRACTION_LEN; | ||
z_mant <<= RESULT_MIN_LEN; | ||
int z_lsb_exp = z_exp - (InFPBits::FRACTION_LEN + RESULT_MIN_LEN); | ||
bool round_bit = false; | ||
bool sticky_bits = false; | ||
bool z_shifted = false; | ||
|
@@ -198,46 +230,40 @@ template <> LIBC_INLINE double fma<double>(double x, double y, double z) { | |
} | ||
} | ||
|
||
uint64_t result = 0; | ||
OutStorageType result = 0; | ||
int r_exp = 0; // Unbiased exponent of the result | ||
|
||
// Normalize the result. | ||
if (prod_mant != 0) { | ||
uint64_t prod_hi = static_cast<uint64_t>(prod_mant >> 64); | ||
int lead_zeros = | ||
prod_hi ? cpp::countl_zero(prod_hi) | ||
: 64 + cpp::countl_zero(static_cast<uint64_t>(prod_mant)); | ||
int lead_zeros = cpp::countl_zero(prod_mant); | ||
// Move the leading 1 to the most significant bit. | ||
prod_mant <<= lead_zeros; | ||
// The lower 64 bits are always sticky bits after moving the leading 1 to | ||
// the most significant bit. | ||
sticky_bits |= (static_cast<uint64_t>(prod_mant) != 0); | ||
result = static_cast<uint64_t>(prod_mant >> 64); | ||
// Change prod_lsb_exp to be the exponent of the least significant bit of | ||
// the result. | ||
prod_lsb_exp += 64 - lead_zeros; | ||
r_exp = prod_lsb_exp + 63; | ||
prod_lsb_exp -= lead_zeros; | ||
r_exp = prod_lsb_exp + (cpp::numeric_limits<TmpResultType>::digits - 1) - | ||
InFPBits::EXP_BIAS + OutFPBits::EXP_BIAS; | ||
|
||
if (r_exp > 0) { | ||
// The result is normal. We will shift the mantissa to the right by | ||
// 63 - 52 = 11 bits (from the locations of the most significant bit). | ||
// Then the rounding bit will correspond to the 11th bit, and the lowest | ||
// 10 bits are merged into sticky bits. | ||
overmighty marked this conversation as resolved.
Show resolved
Hide resolved
|
||
round_bit = (result & 0x0400ULL) != 0; | ||
sticky_bits |= (result & 0x03ffULL) != 0; | ||
result >>= 11; | ||
round_bit = | ||
(prod_mant & (TmpResultType(1) << (EXTRA_FRACTION_LEN - 1))) != 0; | ||
sticky_bits |= (prod_mant & EXTRA_FRACTION_STICKY_MASK) != 0; | ||
result = static_cast<OutStorageType>(prod_mant >> EXTRA_FRACTION_LEN); | ||
} else { | ||
if (r_exp < -52) { | ||
if (r_exp < -OutFPBits::FRACTION_LEN) { | ||
// The result is smaller than 1/2 of the smallest denormal number. | ||
sticky_bits = true; // since the result is non-zero. | ||
result = 0; | ||
} else { | ||
// The result is denormal. | ||
uint64_t mask = 1ULL << (11 - r_exp); | ||
round_bit = (result & mask) != 0; | ||
sticky_bits |= (result & (mask - 1)) != 0; | ||
if (r_exp > -52) | ||
result >>= 12 - r_exp; | ||
TmpResultType mask = TmpResultType(1) << (EXTRA_FRACTION_LEN - r_exp); | ||
round_bit = (prod_mant & mask) != 0; | ||
sticky_bits |= (prod_mant & (mask - 1)) != 0; | ||
if (r_exp > -OutFPBits::FRACTION_LEN) | ||
result = static_cast<OutStorageType>( | ||
prod_mant >> (EXTRA_FRACTION_LEN + 1 - r_exp)); | ||
else | ||
result = 0; | ||
} | ||
|
@@ -251,20 +277,21 @@ template <> LIBC_INLINE double fma<double>(double x, double y, double z) { | |
|
||
overmighty marked this conversation as resolved.
Show resolved
Hide resolved
|
||
// Finalize the result. | ||
int round_mode = fputil::quick_get_round(); | ||
if (LIBC_UNLIKELY(r_exp >= FPBits::MAX_BIASED_EXPONENT)) { | ||
if (LIBC_UNLIKELY(r_exp >= OutFPBits::MAX_BIASED_EXPONENT)) { | ||
if ((round_mode == FE_TOWARDZERO) || | ||
(round_mode == FE_UPWARD && prod_sign.is_neg()) || | ||
(round_mode == FE_DOWNWARD && prod_sign.is_pos())) { | ||
return FPBits::max_normal(prod_sign).get_val(); | ||
return OutFPBits::max_normal(prod_sign).get_val(); | ||
} | ||
return FPBits::inf(prod_sign).get_val(); | ||
return OutFPBits::inf(prod_sign).get_val(); | ||
} | ||
|
||
// Remove hidden bit and append the exponent field and sign bit. | ||
result = (result & FPBits::FRACTION_MASK) | | ||
(static_cast<uint64_t>(r_exp) << FPBits::FRACTION_LEN); | ||
result = static_cast<OutStorageType>( | ||
(result & OutFPBits::FRACTION_MASK) | | ||
(static_cast<OutStorageType>(r_exp) << OutFPBits::FRACTION_LEN)); | ||
if (prod_sign.is_neg()) { | ||
result |= FPBits::SIGN_MASK; | ||
result |= OutFPBits::SIGN_MASK; | ||
} | ||
|
||
// Rounding. | ||
|
@@ -277,7 +304,7 @@ template <> LIBC_INLINE double fma<double>(double x, double y, double z) { | |
++result; | ||
} | ||
|
||
return cpp::bit_cast<double>(result); | ||
return cpp::bit_cast<OutType>(result); | ||
} | ||
|
||
} // namespace generic | ||
|
Uh oh!
There was an error while loading. Please reload this page.