|
8 | 8 |
|
9 | 9 | #include "src/__support/CPP/type_traits.h"
|
10 | 10 | #include "src/__support/FPUtil/FPBits.h"
|
| 11 | +#include "src/__support/macros/properties/types.h" |
11 | 12 | #include "test/UnitTest/FPMatcher.h"
|
12 | 13 | #include "test/UnitTest/Test.h"
|
13 | 14 | #include "utils/MPFRWrapper/MPFRUtils.h"
|
@@ -68,6 +69,43 @@ struct UnaryOpChecker : public virtual LIBC_NAMESPACE::testing::Test {
|
68 | 69 | }
|
69 | 70 | };
|
70 | 71 |
|
| 72 | +template <typename T> using BinaryOp = T(T, T); |
| 73 | + |
| 74 | +template <typename T, mpfr::Operation Op, BinaryOp<T> Func> |
| 75 | +struct BinaryOpChecker : public virtual LIBC_NAMESPACE::testing::Test { |
| 76 | + using FloatType = T; |
| 77 | + using FPBits = LIBC_NAMESPACE::fputil::FPBits<FloatType>; |
| 78 | + using StorageType = typename FPBits::StorageType; |
| 79 | + |
| 80 | + static constexpr BinaryOp<FloatType> *FUNC = Func; |
| 81 | + |
| 82 | + // Check in a range, return the number of failures. |
| 83 | + uint64_t check(StorageType x_start, StorageType x_stop, StorageType y_start, |
| 84 | + StorageType y_stop, mpfr::RoundingMode rounding) { |
| 85 | + mpfr::ForceRoundingMode r(rounding); |
| 86 | + if (!r.success) |
| 87 | + return (x_stop > x_start || y_stop > y_start); |
| 88 | + StorageType xbits = x_start; |
| 89 | + uint64_t failed = 0; |
| 90 | + do { |
| 91 | + FloatType x = FPBits(xbits).get_val(); |
| 92 | + StorageType ybits = y_start; |
| 93 | + do { |
| 94 | + FloatType y = FPBits(ybits).get_val(); |
| 95 | + mpfr::BinaryInput<FloatType> input{x, y}; |
| 96 | + bool correct = TEST_MPFR_MATCH_ROUNDING_SILENTLY(Op, input, FUNC(x, y), |
| 97 | + 0.5, rounding); |
| 98 | + failed += (!correct); |
| 99 | + // Uncomment to print out failed values. |
| 100 | + // if (!correct) { |
| 101 | + // TEST_MPFR_MATCH(Op::Operation, x, Op::func(x, y), 0.5, rounding); |
| 102 | + // } |
| 103 | + } while (ybits++ < y_stop); |
| 104 | + } while (xbits++ < x_stop); |
| 105 | + return failed; |
| 106 | + } |
| 107 | +}; |
| 108 | + |
71 | 109 | // Checker class needs inherit from LIBC_NAMESPACE::testing::Test and provide
|
72 | 110 | // StorageType and check method.
|
73 | 111 | template <typename Checker>
|
@@ -167,6 +205,118 @@ struct LlvmLibcExhaustiveMathTest
|
167 | 205 | };
|
168 | 206 | };
|
169 | 207 |
|
| 208 | +template <typename Checker> |
| 209 | +struct LlvmLibcBinaryInputExhaustiveMathTest |
| 210 | + : public virtual LIBC_NAMESPACE::testing::Test, |
| 211 | + public Checker { |
| 212 | + using FloatType = typename Checker::FloatType; |
| 213 | + using FPBits = typename Checker::FPBits; |
| 214 | + using StorageType = typename Checker::StorageType; |
| 215 | + |
| 216 | + static constexpr StorageType Increment = (1 << 2); |
| 217 | + |
| 218 | + // Break [start, stop) into `nthreads` subintervals and apply *check to each |
| 219 | + // subinterval in parallel. |
| 220 | + void test_full_range(StorageType x_start, StorageType x_stop, |
| 221 | + StorageType y_start, StorageType y_stop, |
| 222 | + mpfr::RoundingMode rounding) { |
| 223 | + int n_threads = std::thread::hardware_concurrency(); |
| 224 | + std::vector<std::thread> thread_list; |
| 225 | + std::mutex mx_cur_val; |
| 226 | + int current_percent = -1; |
| 227 | + StorageType current_value = x_start; |
| 228 | + std::atomic<uint64_t> failed(0); |
| 229 | + |
| 230 | + for (int i = 0; i < n_threads; ++i) { |
| 231 | + thread_list.emplace_back([&, this]() { |
| 232 | + while (true) { |
| 233 | + StorageType range_begin, range_end; |
| 234 | + int new_percent = -1; |
| 235 | + { |
| 236 | + std::lock_guard<std::mutex> lock(mx_cur_val); |
| 237 | + if (current_value == x_stop) |
| 238 | + return; |
| 239 | + |
| 240 | + range_begin = current_value; |
| 241 | + if (x_stop >= Increment && x_stop - Increment >= current_value) { |
| 242 | + range_end = current_value + Increment; |
| 243 | + } else { |
| 244 | + range_end = x_stop; |
| 245 | + } |
| 246 | + current_value = range_end; |
| 247 | + int pc = 100.0 * (range_end - x_start) / (x_stop - x_start); |
| 248 | + if (current_percent != pc) { |
| 249 | + new_percent = pc; |
| 250 | + current_percent = pc; |
| 251 | + } |
| 252 | + } |
| 253 | + if (new_percent >= 0) { |
| 254 | + std::stringstream msg; |
| 255 | + msg << new_percent << "% is in process \r"; |
| 256 | + std::cout << msg.str() << std::flush; |
| 257 | + } |
| 258 | + |
| 259 | + uint64_t failed_in_range = |
| 260 | + Checker::check(range_begin, range_end, y_start, y_stop, rounding); |
| 261 | + if (failed_in_range > 0) { |
| 262 | + using T = LIBC_NAMESPACE::cpp::conditional_t< |
| 263 | + LIBC_NAMESPACE::cpp::is_same_v<FloatType, float16>, float, |
| 264 | + FloatType>; |
| 265 | + std::stringstream msg; |
| 266 | + msg << "Test failed for " << std::dec << failed_in_range |
| 267 | + << " inputs in range: " << range_begin << " to " << range_end |
| 268 | + << " [0x" << std::hex << range_begin << ", 0x" << range_end |
| 269 | + << "), [" << std::hexfloat |
| 270 | + << static_cast<T>(FPBits(range_begin).get_val()) << ", " |
| 271 | + << static_cast<T>(FPBits(range_end).get_val()) << ")\n"; |
| 272 | + std::cerr << msg.str() << std::flush; |
| 273 | + |
| 274 | + failed.fetch_add(failed_in_range); |
| 275 | + } |
| 276 | + } |
| 277 | + }); |
| 278 | + } |
| 279 | + |
| 280 | + for (auto &thread : thread_list) { |
| 281 | + if (thread.joinable()) { |
| 282 | + thread.join(); |
| 283 | + } |
| 284 | + } |
| 285 | + |
| 286 | + std::cout << std::endl; |
| 287 | + std::cout << "Test " << ((failed > 0) ? "FAILED" : "PASSED") << std::endl; |
| 288 | + ASSERT_EQ(failed.load(), uint64_t(0)); |
| 289 | + } |
| 290 | + |
| 291 | + void test_full_range_all_roundings(StorageType x_start, StorageType x_stop, |
| 292 | + StorageType y_start, StorageType y_stop) { |
| 293 | + test_full_range(x_start, x_stop, y_start, y_stop, |
| 294 | + mpfr::RoundingMode::Nearest); |
| 295 | + |
| 296 | + std::cout << "-- Testing for FE_UPWARD in x range [0x" << std::hex |
| 297 | + << x_start << ", 0x" << x_stop << ") y range [0x" << std::hex |
| 298 | + << y_start << ", 0x" << y_stop << ") --" << std::dec << std::endl; |
| 299 | + test_full_range(x_start, x_stop, y_start, y_stop, |
| 300 | + mpfr::RoundingMode::Upward); |
| 301 | + |
| 302 | + std::cout << "-- Testing for FE_DOWNWARD in x range [0x" << std::hex |
| 303 | + << x_start << ", 0x" << x_stop << ") y range [0x" << std::hex |
| 304 | + << y_start << ", 0x" << y_stop << ") --" << std::dec << std::endl; |
| 305 | + test_full_range(x_start, x_stop, y_start, y_stop, |
| 306 | + mpfr::RoundingMode::Downward); |
| 307 | + |
| 308 | + std::cout << "-- Testing for FE_TOWARDZERO in x range [0x" << std::hex |
| 309 | + << x_start << ", 0x" << x_stop << ") y range [0x" << std::hex |
| 310 | + << y_start << ", 0x" << y_stop << ") --" << std::dec << std::endl; |
| 311 | + test_full_range(x_start, x_stop, y_start, y_stop, |
| 312 | + mpfr::RoundingMode::TowardZero); |
| 313 | + }; |
| 314 | +}; |
| 315 | + |
170 | 316 | template <typename FloatType, mpfr::Operation Op, UnaryOp<FloatType> Func>
|
171 | 317 | using LlvmLibcUnaryOpExhaustiveMathTest =
|
172 | 318 | LlvmLibcExhaustiveMathTest<UnaryOpChecker<FloatType, Op, Func>>;
|
| 319 | + |
| 320 | +template <typename FloatType, mpfr::Operation Op, BinaryOp<FloatType> Func> |
| 321 | +using LlvmLibcBinaryOpExhaustiveMathTest = |
| 322 | + LlvmLibcBinaryInputExhaustiveMathTest<BinaryOpChecker<FloatType, Op, Func>>; |
0 commit comments