Skip to content

Commit bd989f9

Browse files
committed
fixup! [libc][math][c23] Add MPFR exhaustive test for fmodf16
1 parent 5d66a52 commit bd989f9

File tree

1 file changed

+65
-121
lines changed

1 file changed

+65
-121
lines changed

libc/test/src/math/exhaustive/exhaustive_test.h

Lines changed: 65 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -69,22 +69,22 @@ struct UnaryOpChecker : public virtual LIBC_NAMESPACE::testing::Test {
6969
}
7070
};
7171

72-
template <typename T> using BinaryOp = T(T, T);
72+
template <typename OutType, typename InType = OutType>
73+
using BinaryOp = OutType(InType, InType);
7374

74-
template <typename T, mpfr::Operation Op, BinaryOp<T> Func>
75+
template <typename OutType, typename InType, mpfr::Operation Op,
76+
BinaryOp<OutType, InType> Func>
7577
struct BinaryOpChecker : public virtual LIBC_NAMESPACE::testing::Test {
76-
using FloatType = T;
78+
using FloatType = InType;
7779
using FPBits = LIBC_NAMESPACE::fputil::FPBits<FloatType>;
7880
using StorageType = typename FPBits::StorageType;
7981

80-
static constexpr BinaryOp<FloatType> *FUNC = Func;
81-
8282
// Check in a range, return the number of failures.
8383
uint64_t check(StorageType x_start, StorageType x_stop, StorageType y_start,
8484
StorageType y_stop, mpfr::RoundingMode rounding) {
8585
mpfr::ForceRoundingMode r(rounding);
8686
if (!r.success)
87-
return (x_stop > x_start || y_stop > y_start);
87+
return x_stop > x_start || y_stop > y_start;
8888
StorageType xbits = x_start;
8989
uint64_t failed = 0;
9090
do {
@@ -93,12 +93,12 @@ struct BinaryOpChecker : public virtual LIBC_NAMESPACE::testing::Test {
9393
do {
9494
FloatType y = FPBits(ybits).get_val();
9595
mpfr::BinaryInput<FloatType> input{x, y};
96-
bool correct = TEST_MPFR_MATCH_ROUNDING_SILENTLY(Op, input, FUNC(x, y),
96+
bool correct = TEST_MPFR_MATCH_ROUNDING_SILENTLY(Op, input, Func(x, y),
9797
0.5, rounding);
9898
failed += (!correct);
9999
// Uncomment to print out failed values.
100100
// if (!correct) {
101-
// TEST_MPFR_MATCH(Op::Operation, x, Op::func(x, y), 0.5, rounding);
101+
// EXPECT_MPFR_MATCH_ROUNDING(Op, input, Func(x, y), 0.5, rounding);
102102
// }
103103
} while (ybits++ < y_stop);
104104
} while (xbits++ < x_stop);
@@ -108,20 +108,45 @@ struct BinaryOpChecker : public virtual LIBC_NAMESPACE::testing::Test {
108108

109109
// Checker class needs inherit from LIBC_NAMESPACE::testing::Test and provide
110110
// StorageType and check method.
111-
template <typename Checker>
111+
template <typename Checker, size_t Increment = 1 << 20>
112112
struct LlvmLibcExhaustiveMathTest
113113
: public virtual LIBC_NAMESPACE::testing::Test,
114114
public Checker {
115115
using FloatType = typename Checker::FloatType;
116116
using FPBits = typename Checker::FPBits;
117117
using StorageType = typename Checker::StorageType;
118118

119-
static constexpr StorageType INCREMENT = (1 << 20);
119+
static constexpr StorageType INCREMENT = Increment;
120+
121+
void explain_failed_range(std::stringstream &msg, StorageType x_begin,
122+
StorageType x_end) {
123+
#ifdef LIBC_TYPES_HAS_FLOAT16
124+
using T = LIBC_NAMESPACE::cpp::conditional_t<
125+
LIBC_NAMESPACE::cpp::is_same_v<FloatType, float16>, float, FloatType>;
126+
#else
127+
using T = FloatType;
128+
#endif
129+
130+
msg << x_begin << " to " << x_end << " [0x" << std::hex << x_begin << ", 0x"
131+
<< x_end << "), [" << std::hexfloat
132+
<< static_cast<T>(FPBits(x_begin).get_val()) << ", "
133+
<< static_cast<T>(FPBits(x_end).get_val()) << ")";
134+
}
135+
136+
void explain_failed_range(std::stringstream &msg, StorageType x_begin,
137+
StorageType x_end, StorageType y_begin,
138+
StorageType y_end) {
139+
msg << "x ";
140+
explain_failed_range(msg, x_begin, x_end);
141+
msg << ", y ";
142+
explain_failed_range(msg, y_begin, y_end);
143+
}
120144

121145
// Break [start, stop) into `nthreads` subintervals and apply *check to each
122146
// subinterval in parallel.
123-
void test_full_range(StorageType start, StorageType stop,
124-
mpfr::RoundingMode rounding) {
147+
template <typename... T>
148+
void test_full_range(mpfr::RoundingMode rounding, StorageType start,
149+
StorageType stop, T... extra_range_bounds) {
125150
int n_threads = std::thread::hardware_concurrency();
126151
std::vector<std::thread> thread_list;
127152
std::mutex mx_cur_val;
@@ -158,15 +183,14 @@ struct LlvmLibcExhaustiveMathTest
158183
std::cout << msg.str() << std::flush;
159184
}
160185

161-
uint64_t failed_in_range =
162-
Checker::check(range_begin, range_end, rounding);
186+
uint64_t failed_in_range = Checker::check(
187+
range_begin, range_end, extra_range_bounds..., rounding);
163188
if (failed_in_range > 0) {
164189
std::stringstream msg;
165190
msg << "Test failed for " << std::dec << failed_in_range
166-
<< " inputs in range: " << range_begin << " to " << range_end
167-
<< " [0x" << std::hex << range_begin << ", 0x" << range_end
168-
<< "), [" << std::hexfloat << FPBits(range_begin).get_val()
169-
<< ", " << FPBits(range_end).get_val() << ")\n";
191+
<< " inputs in range: ";
192+
explain_failed_range(msg, start, stop, extra_range_bounds...);
193+
msg << "\n";
170194
std::cerr << msg.str() << std::flush;
171195

172196
failed.fetch_add(failed_in_range);
@@ -189,127 +213,46 @@ struct LlvmLibcExhaustiveMathTest
189213
void test_full_range_all_roundings(StorageType start, StorageType stop) {
190214
std::cout << "-- Testing for FE_TONEAREST in range [0x" << std::hex << start
191215
<< ", 0x" << stop << ") --" << std::dec << std::endl;
192-
test_full_range(start, stop, mpfr::RoundingMode::Nearest);
216+
test_full_range(mpfr::RoundingMode::Nearest, start, stop);
193217

194218
std::cout << "-- Testing for FE_UPWARD in range [0x" << std::hex << start
195219
<< ", 0x" << stop << ") --" << std::dec << std::endl;
196-
test_full_range(start, stop, mpfr::RoundingMode::Upward);
220+
test_full_range(mpfr::RoundingMode::Upward, start, stop);
197221

198222
std::cout << "-- Testing for FE_DOWNWARD in range [0x" << std::hex << start
199223
<< ", 0x" << stop << ") --" << std::dec << std::endl;
200-
test_full_range(start, stop, mpfr::RoundingMode::Downward);
224+
test_full_range(mpfr::RoundingMode::Downward, start, stop);
201225

202226
std::cout << "-- Testing for FE_TOWARDZERO in range [0x" << std::hex
203227
<< start << ", 0x" << stop << ") --" << std::dec << std::endl;
204-
test_full_range(start, stop, mpfr::RoundingMode::TowardZero);
228+
test_full_range(mpfr::RoundingMode::TowardZero, start, stop);
205229
};
206-
};
207-
208-
template <typename Checker>
209-
struct LlvmLibcBinaryInputExhaustiveMathTest
210-
: public virtual LIBC_NAMESPACE::testing::Test,
211-
public Checker {
212-
using FloatType = typename Checker::FloatType;
213-
using FPBits = typename Checker::FPBits;
214-
using StorageType = typename Checker::StorageType;
215-
216-
static constexpr StorageType Increment = (1 << 2);
217-
218-
// Break [start, stop) into `nthreads` subintervals and apply *check to each
219-
// subinterval in parallel.
220-
void test_full_range(StorageType x_start, StorageType x_stop,
221-
StorageType y_start, StorageType y_stop,
222-
mpfr::RoundingMode rounding) {
223-
int n_threads = std::thread::hardware_concurrency();
224-
std::vector<std::thread> thread_list;
225-
std::mutex mx_cur_val;
226-
int current_percent = -1;
227-
StorageType current_value = x_start;
228-
std::atomic<uint64_t> failed(0);
229-
230-
for (int i = 0; i < n_threads; ++i) {
231-
thread_list.emplace_back([&, this]() {
232-
while (true) {
233-
StorageType range_begin, range_end;
234-
int new_percent = -1;
235-
{
236-
std::lock_guard<std::mutex> lock(mx_cur_val);
237-
if (current_value == x_stop)
238-
return;
239-
240-
range_begin = current_value;
241-
if (x_stop >= Increment && x_stop - Increment >= current_value) {
242-
range_end = current_value + Increment;
243-
} else {
244-
range_end = x_stop;
245-
}
246-
current_value = range_end;
247-
int pc = 100.0 * (range_end - x_start) / (x_stop - x_start);
248-
if (current_percent != pc) {
249-
new_percent = pc;
250-
current_percent = pc;
251-
}
252-
}
253-
if (new_percent >= 0) {
254-
std::stringstream msg;
255-
msg << new_percent << "% is in process \r";
256-
std::cout << msg.str() << std::flush;
257-
}
258-
259-
uint64_t failed_in_range =
260-
Checker::check(range_begin, range_end, y_start, y_stop, rounding);
261-
if (failed_in_range > 0) {
262-
using T = LIBC_NAMESPACE::cpp::conditional_t<
263-
LIBC_NAMESPACE::cpp::is_same_v<FloatType, float16>, float,
264-
FloatType>;
265-
std::stringstream msg;
266-
msg << "Test failed for " << std::dec << failed_in_range
267-
<< " inputs in range: " << range_begin << " to " << range_end
268-
<< " [0x" << std::hex << range_begin << ", 0x" << range_end
269-
<< "), [" << std::hexfloat
270-
<< static_cast<T>(FPBits(range_begin).get_val()) << ", "
271-
<< static_cast<T>(FPBits(range_end).get_val()) << ")\n";
272-
std::cerr << msg.str() << std::flush;
273-
274-
failed.fetch_add(failed_in_range);
275-
}
276-
}
277-
});
278-
}
279-
280-
for (auto &thread : thread_list) {
281-
if (thread.joinable()) {
282-
thread.join();
283-
}
284-
}
285-
286-
std::cout << std::endl;
287-
std::cout << "Test " << ((failed > 0) ? "FAILED" : "PASSED") << std::endl;
288-
ASSERT_EQ(failed.load(), uint64_t(0));
289-
}
290230

291231
void test_full_range_all_roundings(StorageType x_start, StorageType x_stop,
292232
StorageType y_start, StorageType y_stop) {
293-
test_full_range(x_start, x_stop, y_start, y_stop,
294-
mpfr::RoundingMode::Nearest);
233+
std::cout << "-- Testing for FE_TONEAREST in x range [0x" << std::hex
234+
<< x_start << ", 0x" << x_stop << "), y range [0x" << y_start
235+
<< ", 0x" << y_stop << ") --" << std::dec << std::endl;
236+
test_full_range(mpfr::RoundingMode::Nearest, x_start, x_stop, y_start,
237+
y_stop);
295238

296239
std::cout << "-- Testing for FE_UPWARD in x range [0x" << std::hex
297-
<< x_start << ", 0x" << x_stop << ") y range [0x" << std::hex
298-
<< y_start << ", 0x" << y_stop << ") --" << std::dec << std::endl;
299-
test_full_range(x_start, x_stop, y_start, y_stop,
300-
mpfr::RoundingMode::Upward);
240+
<< x_start << ", 0x" << x_stop << "), y range [0x" << y_start
241+
<< ", 0x" << y_stop << ") --" << std::dec << std::endl;
242+
test_full_range(mpfr::RoundingMode::Upward, x_start, x_stop, y_start,
243+
y_stop);
301244

302245
std::cout << "-- Testing for FE_DOWNWARD in x range [0x" << std::hex
303-
<< x_start << ", 0x" << x_stop << ") y range [0x" << std::hex
304-
<< y_start << ", 0x" << y_stop << ") --" << std::dec << std::endl;
305-
test_full_range(x_start, x_stop, y_start, y_stop,
306-
mpfr::RoundingMode::Downward);
246+
<< x_start << ", 0x" << x_stop << "), y range [0x" << y_start
247+
<< ", 0x" << y_stop << ") --" << std::dec << std::endl;
248+
test_full_range(mpfr::RoundingMode::Downward, x_start, x_stop, y_start,
249+
y_stop);
307250

308251
std::cout << "-- Testing for FE_TOWARDZERO in x range [0x" << std::hex
309-
<< x_start << ", 0x" << x_stop << ") y range [0x" << std::hex
310-
<< y_start << ", 0x" << y_stop << ") --" << std::dec << std::endl;
311-
test_full_range(x_start, x_stop, y_start, y_stop,
312-
mpfr::RoundingMode::TowardZero);
252+
<< x_start << ", 0x" << x_stop << "), y range [0x" << y_start
253+
<< ", 0x" << y_stop << ") --" << std::dec << std::endl;
254+
test_full_range(mpfr::RoundingMode::TowardZero, x_start, x_stop, y_start,
255+
y_stop);
313256
};
314257
};
315258

@@ -324,4 +267,5 @@ using LlvmLibcUnaryNarrowingOpExhaustiveMathTest =
324267

325268
template <typename FloatType, mpfr::Operation Op, BinaryOp<FloatType> Func>
326269
using LlvmLibcBinaryOpExhaustiveMathTest =
327-
LlvmLibcBinaryInputExhaustiveMathTest<BinaryOpChecker<FloatType, Op, Func>>;
270+
LlvmLibcExhaustiveMathTest<BinaryOpChecker<FloatType, FloatType, Op, Func>,
271+
1 << 2>;

0 commit comments

Comments
 (0)