Skip to content

[libc][math][c23] Add f16fmaf C23 math function #95483

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jun 14, 2024

Conversation

overmighty
Copy link
Member

Part of #93566.

@overmighty
Copy link
Member Author

cc @lntue

@llvmbot
Copy link
Member

llvmbot commented Jun 13, 2024

@llvm/pr-subscribers-libc

Author: OverMighty (overmighty)

Changes

Part of #93566.


Patch is 54.68 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/95483.diff

27 Files Affected:

  • (modified) libc/config/linux/aarch64/entrypoints.txt (+1)
  • (modified) libc/config/linux/x86_64/entrypoints.txt (+1)
  • (modified) libc/docs/math/index.rst (+2)
  • (modified) libc/spec/stdc.td (+2)
  • (modified) libc/src/__support/FPUtil/FMA.h (+10-22)
  • (modified) libc/src/__support/FPUtil/generic/FMA.h (+86-59)
  • (modified) libc/src/__support/FPUtil/multiply_add.h (+2-2)
  • (modified) libc/src/__support/big_int.h (+16)
  • (modified) libc/src/math/CMakeLists.txt (+2)
  • (added) libc/src/math/f16fmaf.h (+20)
  • (modified) libc/src/math/generic/CMakeLists.txt (+12)
  • (modified) libc/src/math/generic/expm1f.cpp (+1-1)
  • (added) libc/src/math/generic/f16fmaf.cpp (+19)
  • (modified) libc/src/math/generic/fma.cpp (+1-1)
  • (modified) libc/src/math/generic/fmaf.cpp (+1-1)
  • (modified) libc/src/math/generic/range_reduction_fma.h (+13-12)
  • (modified) libc/test/src/math/CMakeLists.txt (+19-2)
  • (modified) libc/test/src/math/FmaTest.h (+88-59)
  • (added) libc/test/src/math/f16fmaf_test.cpp (+25)
  • (modified) libc/test/src/math/smoke/CMakeLists.txt (+16-2)
  • (modified) libc/test/src/math/smoke/FmaTest.h (+65-31)
  • (added) libc/test/src/math/smoke/f16fmaf_test.cpp (+13)
  • (modified) libc/test/src/math/smoke/fma_test.cpp (+1-5)
  • (modified) libc/test/src/math/smoke/fmaf_test.cpp (+1-5)
  • (modified) libc/test/src/sched/CMakeLists.txt (+17-17)
  • (modified) libc/utils/MPFRWrapper/MPFRUtils.cpp (+46-32)
  • (modified) libc/utils/MPFRWrapper/MPFRUtils.h (+33-15)
diff --git a/libc/config/linux/aarch64/entrypoints.txt b/libc/config/linux/aarch64/entrypoints.txt
index 2b2d0985a8992..dab747ca0ac74 100644
--- a/libc/config/linux/aarch64/entrypoints.txt
+++ b/libc/config/linux/aarch64/entrypoints.txt
@@ -503,6 +503,7 @@ if(LIBC_TYPES_HAS_FLOAT16)
     libc.src.math.canonicalizef16
     libc.src.math.ceilf16
     libc.src.math.copysignf16
+    libc.src.math.f16fmaf
     libc.src.math.f16sqrtf
     libc.src.math.fabsf16
     libc.src.math.fdimf16
diff --git a/libc/config/linux/x86_64/entrypoints.txt b/libc/config/linux/x86_64/entrypoints.txt
index 2d36ca296c3a4..45914fe9f7ad2 100644
--- a/libc/config/linux/x86_64/entrypoints.txt
+++ b/libc/config/linux/x86_64/entrypoints.txt
@@ -535,6 +535,7 @@ if(LIBC_TYPES_HAS_FLOAT16)
     libc.src.math.canonicalizef16
     libc.src.math.ceilf16
     libc.src.math.copysignf16
+    libc.src.math.f16fmaf
     libc.src.math.f16sqrtf
     libc.src.math.fabsf16
     libc.src.math.fdimf16
diff --git a/libc/docs/math/index.rst b/libc/docs/math/index.rst
index 790786147c164..293edd1c15100 100644
--- a/libc/docs/math/index.rst
+++ b/libc/docs/math/index.rst
@@ -124,6 +124,8 @@ Basic Operations
 +------------------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
 | dsub             | N/A              | N/A             |                        | N/A                  |                        | 7.12.14.2              | F.10.11                    |
 +------------------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
+| f16fma           | |check|          |                 |                        | N/A                  |                        | 7.12.14.5              | F.10.11                    |
++------------------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
 | fabs             | |check|          | |check|         | |check|                | |check|              | |check|                | 7.12.7.3               | F.10.4.3                   |
 +------------------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
 | fadd             | N/A              |                 |                        | N/A                  |                        | 7.12.14.1              | F.10.11                    |
diff --git a/libc/spec/stdc.td b/libc/spec/stdc.td
index 7c4135032a0b2..b089b596b0958 100644
--- a/libc/spec/stdc.td
+++ b/libc/spec/stdc.td
@@ -715,6 +715,8 @@ def StdC : StandardSpec<"stdc"> {
 
           GuardedFunctionSpec<"totalordermagf16", RetValSpec<IntType>, [ArgSpec<Float16Ptr>, ArgSpec<Float16Ptr>], "LIBC_TYPES_HAS_FLOAT16">,
 
+          GuardedFunctionSpec<"f16fmaf", RetValSpec<Float16Type>, [ArgSpec<FloatType>, ArgSpec<FloatType>, ArgSpec<FloatType>], "LIBC_TYPES_HAS_FLOAT16">,
+
           GuardedFunctionSpec<"f16sqrtf", RetValSpec<Float16Type>, [ArgSpec<FloatType>], "LIBC_TYPES_HAS_FLOAT16">,
       ]
   >;
diff --git a/libc/src/__support/FPUtil/FMA.h b/libc/src/__support/FPUtil/FMA.h
index c277da49538bf..cf01a317d7359 100644
--- a/libc/src/__support/FPUtil/FMA.h
+++ b/libc/src/__support/FPUtil/FMA.h
@@ -10,41 +10,29 @@
 #define LLVM_LIBC_SRC___SUPPORT_FPUTIL_FMA_H
 
 #include "src/__support/CPP/type_traits.h"
+#include "src/__support/FPUtil/generic/FMA.h"
 #include "src/__support/macros/properties/architectures.h"
 #include "src/__support/macros/properties/cpu_features.h" // LIBC_TARGET_CPU_HAS_FMA
 
-#if defined(LIBC_TARGET_CPU_HAS_FMA)
-
 namespace LIBC_NAMESPACE {
 namespace fputil {
 
-template <typename T>
-LIBC_INLINE cpp::enable_if_t<cpp::is_same_v<T, float>, T> fma(T x, T y, T z) {
-  return __builtin_fmaf(x, y, z);
+template <typename OutType, typename InType>
+LIBC_INLINE OutType fma(InType x, InType y, InType z) {
+  return generic::fma<OutType>(x, y, z);
 }
 
-template <typename T>
-LIBC_INLINE cpp::enable_if_t<cpp::is_same_v<T, double>, T> fma(T x, T y, T z) {
-  return __builtin_fma(x, y, z);
+#ifdef LIBC_TARGET_CPU_HAS_FMA
+template <> LIBC_INLINE float fma(float x, float y, float z) {
+  return __builtin_fmaf(x, y, z);
 }
 
-} // namespace fputil
-} // namespace LIBC_NAMESPACE
-
-#else
-// FMA instructions are not available
-#include "generic/FMA.h"
-
-namespace LIBC_NAMESPACE {
-namespace fputil {
-
-template <typename T> LIBC_INLINE T fma(T x, T y, T z) {
-  return generic::fma(x, y, z);
+template <> LIBC_INLINE double fma(double x, double y, double z) {
+  return __builtin_fma(x, y, z);
 }
+#endif // LIBC_TARGET_CPU_HAS_FMA
 
 } // namespace fputil
 } // namespace LIBC_NAMESPACE
 
-#endif
-
 #endif // LLVM_LIBC_SRC___SUPPORT_FPUTIL_FMA_H
diff --git a/libc/src/__support/FPUtil/generic/FMA.h b/libc/src/__support/FPUtil/generic/FMA.h
index f403aa7333b39..3bbb30476e5d8 100644
--- a/libc/src/__support/FPUtil/generic/FMA.h
+++ b/libc/src/__support/FPUtil/generic/FMA.h
@@ -10,20 +10,28 @@
 #define LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_FMA_H
 
 #include "src/__support/CPP/bit.h"
+#include "src/__support/CPP/limits.h"
 #include "src/__support/CPP/type_traits.h"
-#include "src/__support/FPUtil/FEnvImpl.h"
 #include "src/__support/FPUtil/FPBits.h"
 #include "src/__support/FPUtil/rounding_mode.h"
+#include "src/__support/big_int.h"
 #include "src/__support/macros/attributes.h"   // LIBC_INLINE
 #include "src/__support/macros/optimization.h" // LIBC_UNLIKELY
-#include "src/__support/uint128.h"
+
+#include "hdr/fenv_macros.h"
 
 namespace LIBC_NAMESPACE {
 namespace fputil {
 namespace generic {
 
-template <typename T> LIBC_INLINE T fma(T x, T y, T z);
+template <typename OutType, typename InType>
+LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<OutType> &&
+                                 cpp::is_floating_point_v<InType> &&
+                                 sizeof(OutType) <= sizeof(InType),
+                             OutType>
+fma(InType x, InType y, InType z);
 
+#ifndef LIBC_TARGET_CPU_HAS_FMA
 // TODO(lntue): Implement fmaf that is correctly rounded to all rounding modes.
 // The implementation below only is only correct for the default rounding mode,
 // round-to-nearest tie-to-even.
@@ -74,17 +82,20 @@ template <> LIBC_INLINE float fma<float>(float x, float y, float z) {
 
   return static_cast<float>(bit_sum.get_val());
 }
+#endif // LIBC_TARGET_CPU_HAS_FMA
 
 namespace internal {
 
 // Extract the sticky bits and shift the `mantissa` to the right by
 // `shift_length`.
-LIBC_INLINE bool shift_mantissa(int shift_length, UInt128 &mant) {
-  if (shift_length >= 128) {
+template <typename T>
+LIBC_INLINE cpp::enable_if_t<is_unsigned_integral_or_big_int_v<T>, bool>
+shift_mantissa(int shift_length, T &mant) {
+  if (shift_length >= cpp::numeric_limits<T>::digits) {
     mant = 0;
     return true; // prod_mant is non-zero.
   }
-  UInt128 mask = (UInt128(1) << shift_length) - 1;
+  T mask = (T(1) << shift_length) - 1;
   bool sticky_bits = (mant & mask) != 0;
   mant >>= shift_length;
   return sticky_bits;
@@ -92,11 +103,29 @@ LIBC_INLINE bool shift_mantissa(int shift_length, UInt128 &mant) {
 
 } // namespace internal
 
-template <> LIBC_INLINE double fma<double>(double x, double y, double z) {
-  using FPBits = fputil::FPBits<double>;
+template <typename OutType, typename InType>
+LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<OutType> &&
+                                 cpp::is_floating_point_v<InType> &&
+                                 sizeof(OutType) <= sizeof(InType),
+                             OutType>
+fma(InType x, InType y, InType z) {
+  using OutFPBits = fputil::FPBits<OutType>;
+  using OutStorageType = typename OutFPBits::StorageType;
+  using InFPBits = fputil::FPBits<InType>;
+  using InStorageType = typename InFPBits::StorageType;
+
+  constexpr int IN_EXPLICIT_MANT_LEN = InFPBits::FRACTION_LEN + 1;
+  constexpr size_t PROD_LEN = 2 * (IN_EXPLICIT_MANT_LEN);
+  constexpr size_t TMP_RESULT_LEN = cpp::bit_ceil(PROD_LEN + 1);
+  using TmpResultType = UInt<TMP_RESULT_LEN>;
+
+  constexpr size_t EXTRA_FRACTION_LEN =
+      TMP_RESULT_LEN - 1 - OutFPBits::FRACTION_LEN;
+  constexpr TmpResultType EXTRA_FRACTION_STICKY_MASK =
+      (TmpResultType(1) << (EXTRA_FRACTION_LEN - 1)) - 1;
 
   if (LIBC_UNLIKELY(x == 0 || y == 0 || z == 0)) {
-    return x * y + z;
+    return static_cast<OutType>(x * y + z);
   }
 
   int x_exp = 0;
@@ -104,35 +133,35 @@ template <> LIBC_INLINE double fma<double>(double x, double y, double z) {
   int z_exp = 0;
 
   // Normalize denormal inputs.
-  if (LIBC_UNLIKELY(FPBits(x).is_subnormal())) {
-    x_exp -= 52;
-    x *= 0x1.0p+52;
+  if (LIBC_UNLIKELY(InFPBits(x).is_subnormal())) {
+    x_exp -= InFPBits::FRACTION_LEN;
+    x *= InType(InStorageType(1) << InFPBits::FRACTION_LEN);
   }
-  if (LIBC_UNLIKELY(FPBits(y).is_subnormal())) {
-    y_exp -= 52;
-    y *= 0x1.0p+52;
+  if (LIBC_UNLIKELY(InFPBits(y).is_subnormal())) {
+    y_exp -= InFPBits::FRACTION_LEN;
+    y *= InType(InStorageType(1) << InFPBits::FRACTION_LEN);
   }
-  if (LIBC_UNLIKELY(FPBits(z).is_subnormal())) {
-    z_exp -= 52;
-    z *= 0x1.0p+52;
+  if (LIBC_UNLIKELY(InFPBits(z).is_subnormal())) {
+    z_exp -= InFPBits::FRACTION_LEN;
+    z *= InType(InStorageType(1) << InFPBits::FRACTION_LEN);
   }
 
-  FPBits x_bits(x), y_bits(y), z_bits(z);
+  InFPBits x_bits(x), y_bits(y), z_bits(z);
   const Sign z_sign = z_bits.sign();
   Sign prod_sign = (x_bits.sign() == y_bits.sign()) ? Sign::POS : Sign::NEG;
   x_exp += x_bits.get_biased_exponent();
   y_exp += y_bits.get_biased_exponent();
   z_exp += z_bits.get_biased_exponent();
 
-  if (LIBC_UNLIKELY(x_exp == FPBits::MAX_BIASED_EXPONENT ||
-                    y_exp == FPBits::MAX_BIASED_EXPONENT ||
-                    z_exp == FPBits::MAX_BIASED_EXPONENT))
-    return x * y + z;
+  if (LIBC_UNLIKELY(x_exp == InFPBits::MAX_BIASED_EXPONENT ||
+                    y_exp == InFPBits::MAX_BIASED_EXPONENT ||
+                    z_exp == InFPBits::MAX_BIASED_EXPONENT))
+    return static_cast<OutType>(x * y + z);
 
   // Extract mantissa and append hidden leading bits.
-  UInt128 x_mant = x_bits.get_explicit_mantissa();
-  UInt128 y_mant = y_bits.get_explicit_mantissa();
-  UInt128 z_mant = z_bits.get_explicit_mantissa();
+  InStorageType x_mant = x_bits.get_explicit_mantissa();
+  InStorageType y_mant = y_bits.get_explicit_mantissa();
+  TmpResultType z_mant = z_bits.get_explicit_mantissa();
 
   // If the exponent of the product x*y > the exponent of z, then no extra
   // precision beside the entire product x*y is needed.  On the other hand, when
@@ -144,21 +173,24 @@ template <> LIBC_INLINE double fma<double>(double x, double y, double z) {
   // - prod :     1bb...bb....b
   // In that case, in order to store the exact result, we need at least
   //   (Length of prod) - (MantissaLength of z) = 2*(52 + 1) - 52 = 54.
+  // TODO:                            53? (Explicit mantissa.) ^
   // Overall, before aligning the mantissas and exponents, we can simply left-
   // shift the mantissa of z by at least 54, and left-shift the product of x*y
   // by (that amount - 52).  After that, it is enough to align the least
+  // TODO:             ^ 54?
   // significant bit, given that we keep track of the round and sticky bits
   // after the least significant bit.
   // We pick shifting z_mant by 64 bits so that technically we can simply use
   // the original mantissa as high part when constructing 128-bit z_mant. So the
   // mantissa of prod will be left-shifted by 64 - 54 = 10 initially.
 
-  UInt128 prod_mant = x_mant * y_mant << 10;
+  TmpResultType prod_mant = TmpResultType(x_mant) * y_mant;
   int prod_lsb_exp =
-      x_exp + y_exp - (FPBits::EXP_BIAS + 2 * FPBits::FRACTION_LEN + 10);
+      x_exp + y_exp - (InFPBits::EXP_BIAS + 2 * InFPBits::FRACTION_LEN);
 
-  z_mant <<= 64;
-  int z_lsb_exp = z_exp - (FPBits::FRACTION_LEN + 64);
+  constexpr int RESULT_MIN_LEN = PROD_LEN - InFPBits::FRACTION_LEN;
+  z_mant <<= RESULT_MIN_LEN;
+  int z_lsb_exp = z_exp - (InFPBits::FRACTION_LEN + RESULT_MIN_LEN);
   bool round_bit = false;
   bool sticky_bits = false;
   bool z_shifted = false;
@@ -198,46 +230,40 @@ template <> LIBC_INLINE double fma<double>(double x, double y, double z) {
     }
   }
 
-  uint64_t result = 0;
+  OutStorageType result = 0;
   int r_exp = 0; // Unbiased exponent of the result
 
   // Normalize the result.
   if (prod_mant != 0) {
-    uint64_t prod_hi = static_cast<uint64_t>(prod_mant >> 64);
-    int lead_zeros =
-        prod_hi ? cpp::countl_zero(prod_hi)
-                : 64 + cpp::countl_zero(static_cast<uint64_t>(prod_mant));
+    int lead_zeros = cpp::countl_zero(prod_mant);
     // Move the leading 1 to the most significant bit.
     prod_mant <<= lead_zeros;
-    // The lower 64 bits are always sticky bits after moving the leading 1 to
-    // the most significant bit.
-    sticky_bits |= (static_cast<uint64_t>(prod_mant) != 0);
-    result = static_cast<uint64_t>(prod_mant >> 64);
-    // Change prod_lsb_exp the be the exponent of the least significant bit of
-    // the result.
-    prod_lsb_exp += 64 - lead_zeros;
-    r_exp = prod_lsb_exp + 63;
+    prod_lsb_exp -= lead_zeros;
+    r_exp = prod_lsb_exp + (cpp::numeric_limits<TmpResultType>::digits - 1) -
+            InFPBits::EXP_BIAS + OutFPBits::EXP_BIAS;
 
     if (r_exp > 0) {
       // The result is normal.  We will shift the mantissa to the right by
       // 63 - 52 = 11 bits (from the locations of the most significant bit).
       // Then the rounding bit will correspond the 11th bit, and the lowest
       // 10 bits are merged into sticky bits.
-      round_bit = (result & 0x0400ULL) != 0;
-      sticky_bits |= (result & 0x03ffULL) != 0;
-      result >>= 11;
+      round_bit =
+          (prod_mant & (TmpResultType(1) << (EXTRA_FRACTION_LEN - 1))) != 0;
+      sticky_bits |= (prod_mant & EXTRA_FRACTION_STICKY_MASK) != 0;
+      result = static_cast<OutStorageType>(prod_mant >> EXTRA_FRACTION_LEN);
     } else {
-      if (r_exp < -52) {
+      if (r_exp < -OutFPBits::FRACTION_LEN) {
         // The result is smaller than 1/2 of the smallest denormal number.
         sticky_bits = true; // since the result is non-zero.
         result = 0;
       } else {
         // The result is denormal.
-        uint64_t mask = 1ULL << (11 - r_exp);
-        round_bit = (result & mask) != 0;
-        sticky_bits |= (result & (mask - 1)) != 0;
-        if (r_exp > -52)
-          result >>= 12 - r_exp;
+        TmpResultType mask = TmpResultType(1) << (EXTRA_FRACTION_LEN - r_exp);
+        round_bit = (prod_mant & mask) != 0;
+        sticky_bits |= (prod_mant & (mask - 1)) != 0;
+        if (r_exp > -OutFPBits::FRACTION_LEN)
+          result = static_cast<OutStorageType>(
+              prod_mant >> (EXTRA_FRACTION_LEN + 1 - r_exp));
         else
           result = 0;
       }
@@ -251,20 +277,21 @@ template <> LIBC_INLINE double fma<double>(double x, double y, double z) {
 
   // Finalize the result.
   int round_mode = fputil::quick_get_round();
-  if (LIBC_UNLIKELY(r_exp >= FPBits::MAX_BIASED_EXPONENT)) {
+  if (LIBC_UNLIKELY(r_exp >= OutFPBits::MAX_BIASED_EXPONENT)) {
     if ((round_mode == FE_TOWARDZERO) ||
         (round_mode == FE_UPWARD && prod_sign.is_neg()) ||
         (round_mode == FE_DOWNWARD && prod_sign.is_pos())) {
-      return FPBits::max_normal(prod_sign).get_val();
+      return OutFPBits::max_normal(prod_sign).get_val();
     }
-    return FPBits::inf(prod_sign).get_val();
+    return OutFPBits::inf(prod_sign).get_val();
   }
 
   // Remove hidden bit and append the exponent field and sign bit.
-  result = (result & FPBits::FRACTION_MASK) |
-           (static_cast<uint64_t>(r_exp) << FPBits::FRACTION_LEN);
+  result = static_cast<OutStorageType>(
+      (result & OutFPBits::FRACTION_MASK) |
+      (static_cast<OutStorageType>(r_exp) << OutFPBits::FRACTION_LEN));
   if (prod_sign.is_neg()) {
-    result |= FPBits::SIGN_MASK;
+    result |= OutFPBits::SIGN_MASK;
   }
 
   // Rounding.
@@ -277,7 +304,7 @@ template <> LIBC_INLINE double fma<double>(double x, double y, double z) {
       ++result;
   }
 
-  return cpp::bit_cast<double>(result);
+  return cpp::bit_cast<OutType>(result);
 }
 
 } // namespace generic
diff --git a/libc/src/__support/FPUtil/multiply_add.h b/libc/src/__support/FPUtil/multiply_add.h
index 82932da5417c8..622914e4265c9 100644
--- a/libc/src/__support/FPUtil/multiply_add.h
+++ b/libc/src/__support/FPUtil/multiply_add.h
@@ -45,11 +45,11 @@ namespace LIBC_NAMESPACE {
 namespace fputil {
 
 LIBC_INLINE float multiply_add(float x, float y, float z) {
-  return fma(x, y, z);
+  return fma<float>(x, y, z);
 }
 
 LIBC_INLINE double multiply_add(double x, double y, double z) {
-  return fma(x, y, z);
+  return fma<double>(x, y, z);
 }
 
 } // namespace fputil
diff --git a/libc/src/__support/big_int.h b/libc/src/__support/big_int.h
index 40ad6eeed7ac2..27cb3b783470c 100644
--- a/libc/src/__support/big_int.h
+++ b/libc/src/__support/big_int.h
@@ -983,6 +983,12 @@ using UInt = BigInt<Bits, false, internal::WordTypeSelectorT<Bits>>;
 template <size_t Bits>
 using Int = BigInt<Bits, true, internal::WordTypeSelectorT<Bits>>;
 
+// Provides limits of BigInt.
+template <size_t Bits, bool Signed, typename T>
+struct cpp::numeric_limits<BigInt<Bits, Signed, T>> {
+  LIBC_INLINE_VAR static constexpr int digits = Bits - Signed;
+};
+
 // Provides limits of U/Int<128>.
 template <> class cpp::numeric_limits<UInt<128>> {
 public:
@@ -1073,6 +1079,16 @@ template <typename T>
 using make_integral_or_big_int_signed_t =
     typename make_integral_or_big_int_signed<T>::type;
 
+// is_unsigned_integral_or_big_int
+template <typename T>
+struct is_unsigned_integral_or_big_int
+    : cpp::bool_constant<
+          cpp::is_same_v<T, make_integral_or_big_int_unsigned_t<T>>> {};
+
+template <typename T>
+LIBC_INLINE_VAR constexpr bool is_unsigned_integral_or_big_int_v =
+    is_unsigned_integral_or_big_int<T>::value;
+
 namespace cpp {
 
 // Specialization of cpp::bit_cast ('bit.h') from T to BigInt.
diff --git a/libc/src/math/CMakeLists.txt b/libc/src/math/CMakeLists.txt
index df8e6c0b253da..4472367d6c073 100644
--- a/libc/src/math/CMakeLists.txt
+++ b/libc/src/math/CMakeLists.txt
@@ -99,6 +99,8 @@ add_math_entrypoint_object(exp10f)
 add_math_entrypoint_object(expm1)
 add_math_entrypoint_object(expm1f)
 
+add_math_entrypoint_object(f16fmaf)
+
 add_math_entrypoint_object(f16sqrtf)
 
 add_math_entrypoint_object(fabs)
diff --git a/libc/src/math/f16fmaf.h b/libc/src/math/f16fmaf.h
new file mode 100644
index 0000000000000..d92cb43c292eb
--- /dev/null
+++ b/libc/src/math/f16fmaf.h
@@ -0,0 +1,20 @@
+//===-- Implementation header for f16fmaf -----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_MATH_F16FMAF_H
+#define LLVM_LIBC_SRC_MATH_F16FMAF_H
+
+#include "src/__support/macros/properties/types.h"
+
+namespace LIBC_NAMESPACE {
+
+float16 f16fmaf(float x, float y, float z);
+
+} // namespace LIBC_NAMESPACE
+
+#endif // LLVM_LIBC_SRC_MATH_F16FMAF_H
diff --git a/libc/src/math/generic/CMakeLists.txt b/libc/src/math/generic/CMakeLists.txt
index f1f7d6c367be2..706bfc4b08670 100644
--- a/libc/src/math/generic/CMakeLists.txt
+++ b/libc/src/math/generic/CMakeLists.txt
@@ -3602,6 +3602,18 @@ add_entrypoint_object(
     -O3
 )
 
+add_entrypoint_object(
+  f16fmaf
+  SRCS
+    f16fmaf.cpp
+  HDRS
+    ../f16fmaf.h
+  DEPENDS
+    libc.src.__support.FPUtil.fma
+  COMPILE_OPTIONS
+    -O3
+)
+
 add_entrypoint_object(
   f16sqrtf
   SRCS
diff --git a/libc/src/math/generic/expm1f.cpp b/libc/src/math/generic/expm1f.cpp
index 037e60021b296..6b9f07476a650 100644
--- a/libc/src/math/generic/expm1f.cpp
+++ b/libc/src/math/generic/expm1f.cpp
@@ -104,7 +104,7 @@ LLVM_LIBC_FUNCTION(float, expm1f, (float x)) {
         // intermediate results as it is more efficient than using an emulated
         // version of...
[truncated]

Comment on lines 34 to 37
#ifndef LIBC_TARGET_CPU_HAS_FMA
// TODO(lntue): Implement fmaf that is correctly rounded to all rounding modes.
// The implementation below only is only correct for the default rounding mode,
// round-to-nearest tie-to-even.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we use the type-generic fma from this PR to implement fmaf correctly rounded for all rounding modes, or is a specialized implementation that is both faster and correct for all rounding modes planned?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you double check if this implementation is correct for all rounding modes? I think the only thing that was blocking me to claim this one as correctly rounded to other rounding mode is that Dekker's 2Sum might not be exact in other rounding modes. But a recent paper by Paul Zimmermann shows that in other rounding modes, the errors from the 2Sum algorithm is still very close to double-double precision ULP. And in this case, it should mean exacts, since the precisions of the summands are 46 and 23 respectively.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It passes the fmaf MPFR unit test (libc.test.src.math.fmaf_test.__unit__).

Also, I just realized that the #ifndef LIBC_TARGET_CPU_HAS_FMA guard is useless, as this is fputil::generic::fma, not fputil::fma.

Comment on lines +48 to +49
OutConstants out;
InConstants in;
Copy link
Member Author

@overmighty overmighty Jun 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't change the const constants defined by DECLARE_SPECIAL_CONSTANTS to static constexpr, or fmod* tests would fail (it might have been fmodf16 in particular, I only remember seeing FModTest.h in the console output). I instantiate OutConstants and InConstants as a workaround.

Comment on lines +290 to +293
(op == Operation::Sqrt && cpp::is_floating_point_v<InputType> &&
cpp::is_floating_point_v<OutputType> &&
sizeof(OutputType) <= sizeof(InputType)) ||
(op == Operation::Fma && internal::IsTernaryInput<InputType>::VALUE &&
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should just change this to check for operation ranges, like the code below does.

@lntue lntue merged commit f3aceee into llvm:main Jun 14, 2024
5 of 6 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants