Skip to content

Commit 47424e8

Browse files
committed
Values testing for test_add and test_subtract
1 parent f11a6d0 commit 47424e8

File tree

1 file changed

+64
-8
lines changed

1 file changed

+64
-8
lines changed

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 64 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -310,10 +310,37 @@ def test_add(ctx, data):
310310

311311
assert_binary_param_dtype(ctx, left, right, res)
312312
assert_binary_param_shape(ctx, left, right, res)
313-
if not ctx.right_is_scalar:
314-
# add is commutative
315-
expected = ctx.func(right, left)
316-
ah.assert_exactly_equal(res, expected)
313+
m, M = dh.dtype_ranges[res.dtype]
314+
scalar_type = dh.get_scalar_type(res.dtype)
315+
if ctx.right_is_scalar:
316+
for idx in sh.ndindex(res.shape):
317+
scalar_l = scalar_type(left[idx])
318+
expected = scalar_l + right
319+
if not math.isfinite(expected) or expected <= m or expected >= M:
320+
continue
321+
scalar_o = scalar_type(res[idx])
322+
f_l = sh.fmt_idx(ctx.left_sym, idx)
323+
f_o = sh.fmt_idx(ctx.res_name, idx)
324+
assert isclose(scalar_o, expected), (
325+
f"{f_o}={scalar_o}, but should be roughly ({f_l} + {right})={expected} "
326+
f"[{ctx.func_name}()]\n{f_l}={scalar_l}"
327+
)
328+
else:
329+
ph.assert_array(ctx.func_name, res, ctx.func(right, left)) # cumulative
330+
for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape):
331+
scalar_l = scalar_type(left[l_idx])
332+
scalar_r = scalar_type(right[r_idx])
333+
expected = scalar_l + scalar_r
334+
if not math.isfinite(expected) or expected <= m or expected >= M:
335+
continue
336+
scalar_o = scalar_type(res[o_idx])
337+
f_l = sh.fmt_idx(ctx.left_sym, l_idx)
338+
f_r = sh.fmt_idx(ctx.right_sym, r_idx)
339+
f_o = sh.fmt_idx(ctx.res_name, o_idx)
340+
assert isclose(scalar_o, expected), (
341+
f"{f_o}={scalar_o}, but should be roughly ({f_l} + {f_r})={expected} "
342+
f"[{ctx.func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}"
343+
)
317344

318345

319346
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
@@ -1487,9 +1514,9 @@ def test_sign(x):
14871514
expr = f"({f_x} / |{f_x}|)={expected}"
14881515
scalar_o = scalar_type(out[idx])
14891516
f_o = sh.fmt_idx("out", idx)
1490-
assert scalar_o == expected, (
1491-
f"{f_o}={scalar_o}, but should be {expr} [sign()]\n{f_x}={scalar_x}"
1492-
)
1517+
assert (
1518+
scalar_o == expected
1519+
), f"{f_o}={scalar_o}, but should be {expr} [sign()]\n{f_x}={scalar_x}"
14931520

14941521

14951522
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
@@ -1535,7 +1562,36 @@ def test_subtract(ctx, data):
15351562

15361563
assert_binary_param_dtype(ctx, left, right, res)
15371564
assert_binary_param_shape(ctx, left, right, res)
1538-
# TODO
1565+
m, M = dh.dtype_ranges[res.dtype]
1566+
scalar_type = dh.get_scalar_type(res.dtype)
1567+
if ctx.right_is_scalar:
1568+
for idx in sh.ndindex(res.shape):
1569+
scalar_l = scalar_type(left[idx])
1570+
expected = scalar_l - right
1571+
if not math.isfinite(expected) or expected <= m or expected >= M:
1572+
continue
1573+
scalar_o = scalar_type(res[idx])
1574+
f_l = sh.fmt_idx(ctx.left_sym, idx)
1575+
f_o = sh.fmt_idx(ctx.res_name, idx)
1576+
assert isclose(scalar_o, expected), (
1577+
f"{f_o}={scalar_o}, but should be roughly ({f_l} - {right})={expected} "
1578+
f"[{ctx.func_name}()]\n{f_l}={scalar_l}"
1579+
)
1580+
else:
1581+
for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape):
1582+
scalar_l = scalar_type(left[l_idx])
1583+
scalar_r = scalar_type(right[r_idx])
1584+
expected = scalar_l - scalar_r
1585+
if not math.isfinite(expected) or expected <= m or expected >= M:
1586+
continue
1587+
scalar_o = scalar_type(res[o_idx])
1588+
f_l = sh.fmt_idx(ctx.left_sym, l_idx)
1589+
f_r = sh.fmt_idx(ctx.right_sym, r_idx)
1590+
f_o = sh.fmt_idx(ctx.res_name, o_idx)
1591+
assert isclose(scalar_o, expected), (
1592+
f"{f_o}={scalar_o}, but should be roughly ({f_l} - {f_r})={expected} "
1593+
f"[{ctx.func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}"
1594+
)
15391595

15401596

15411597
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))

0 commit comments

Comments
 (0)