Skip to content

Commit 2077986

Browse files
committed
Rudimentary values testing refactor, updates to logical elwise tests
1 parent 47424e8 commit 2077986

File tree

1 file changed

+78
-30
lines changed

1 file changed

+78
-30
lines changed

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 78 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from . import pytest_helpers as ph
2626
from . import shape_helpers as sh
2727
from . import xps
28-
from .typing import Array, DataType, Param, Scalar, Shape
28+
from .typing import Array, DataType, Param, Scalar, ScalarType, Shape
2929

3030
pytestmark = pytest.mark.ci
3131

@@ -38,12 +38,68 @@ def boolean_and_all_integer_dtypes() -> st.SearchStrategy[DataType]:
3838
return xps.boolean_dtypes() | all_integer_dtypes()
3939

4040

41-
def isclose(n1: Union[int, float], n2: Union[int, float]):
41+
def isclose(n1: Union[int, float], n2: Union[int, float]) -> bool:
4242
if not (math.isfinite(n1) and math.isfinite(n2)):
4343
raise ValueError(f"{n1=} and {n1=}, but input must be finite")
4444
return math.isclose(n1, n2, rel_tol=0.25, abs_tol=1)
4545

4646

47+
def unary_assert_against_refimpl(
48+
func_name: str,
49+
in_stype: ScalarType,
50+
in_: Array,
51+
res: Array,
52+
refimpl: Callable[[Scalar], Scalar],
53+
expr_template: str,
54+
res_stype: Optional[ScalarType] = None,
55+
):
56+
if in_.shape != res.shape:
57+
raise ValueError(f"{res.shape=}, but should be {in_.shape=}")
58+
if res_stype is None:
59+
res_stype = in_stype
60+
for idx in sh.ndindex(in_.shape):
61+
scalar_i = in_stype(in_[idx])
62+
expected = refimpl(scalar_i)
63+
scalar_o = res_stype(res[idx])
64+
f_i = sh.fmt_idx("x", idx)
65+
f_o = sh.fmt_idx("out", idx)
66+
expr = expr_template.format(scalar_i, expected)
67+
assert scalar_o == expected, (
68+
f"{f_o}={scalar_o}, but should be {expr} [{func_name}()]\n"
69+
f"{f_i}={scalar_i}"
70+
)
71+
72+
73+
def binary_assert_against_refimpl(
74+
func_name: str,
75+
in_stype: ScalarType,
76+
left: Array,
77+
right: Array,
78+
res: Array,
79+
refimpl: Callable[[Scalar, Scalar], Scalar],
80+
expr_template: str,
81+
res_stype: Optional[ScalarType] = None,
82+
left_sym: str = "x1",
83+
right_sym: str = "x2",
84+
res_sym: str = "out",
85+
):
86+
if res_stype is None:
87+
res_stype = in_stype
88+
for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape):
89+
scalar_l = in_stype(left[l_idx])
90+
scalar_r = in_stype(right[r_idx])
91+
expected = refimpl(scalar_l, scalar_r)
92+
scalar_o = res_stype(res[o_idx])
93+
f_l = sh.fmt_idx(left_sym, l_idx)
94+
f_r = sh.fmt_idx(right_sym, r_idx)
95+
f_o = sh.fmt_idx(res_sym, o_idx)
96+
expr = expr_template.format(scalar_l, scalar_r, expected)
97+
assert scalar_o == expected, (
98+
f"{f_o}={scalar_o}, but should be {expr} [{func_name}()]\n"
99+
f"{f_l}={scalar_l}, {f_r}={scalar_r}"
100+
)
101+
102+
47103
# When appropiate, this module tests operators alongside their respective
48104
# elementwise methods. We do this by parametrizing a generalised test method
49105
# with every relevant method and operator.
@@ -1249,53 +1305,45 @@ def test_logical_and(x1, x2):
12491305
out = ah.logical_and(x1, x2)
12501306
ph.assert_dtype("logical_and", (x1.dtype, x2.dtype), out.dtype)
12511307
ph.assert_result_shape("logical_and", (x1.shape, x2.shape), out.shape)
1252-
for l_idx, r_idx, o_idx in sh.iter_indices(x1.shape, x2.shape, out.shape):
1253-
scalar_l = bool(x1[l_idx])
1254-
scalar_r = bool(x2[r_idx])
1255-
expected = scalar_l and scalar_r
1256-
scalar_o = bool(out[o_idx])
1257-
f_l = sh.fmt_idx("x1", l_idx)
1258-
f_r = sh.fmt_idx("x2", r_idx)
1259-
f_o = sh.fmt_idx("out", o_idx)
1260-
assert scalar_o == expected, (
1261-
f"{f_o}={scalar_o}, but should be ({f_l} and {f_r})={expected} "
1262-
f"[logical_and()]\n{f_l}={scalar_l}, {f_r}={scalar_r}"
1263-
)
1308+
binary_assert_against_refimpl(
1309+
"logical_and",
1310+
bool,
1311+
x1,
1312+
x2,
1313+
out,
1314+
lambda l, r: l and r,
1315+
"({} and {})={}",
1316+
)
12641317

12651318

12661319
@given(xps.arrays(dtype=xp.bool, shape=hh.shapes()))
12671320
def test_logical_not(x):
12681321
out = ah.logical_not(x)
12691322
ph.assert_dtype("logical_not", x.dtype, out.dtype)
12701323
ph.assert_shape("logical_not", out.shape, x.shape)
1271-
for idx in sh.ndindex(x.shape):
1272-
assert out[idx] == (not bool(x[idx]))
1324+
unary_assert_against_refimpl(
1325+
"logical_not", bool, x, out, lambda i: not i, "(not {})={}"
1326+
)
12731327

12741328

12751329
@given(*hh.two_mutual_arrays([xp.bool]))
12761330
def test_logical_or(x1, x2):
12771331
out = ah.logical_or(x1, x2)
12781332
ph.assert_dtype("logical_or", (x1.dtype, x2.dtype), out.dtype)
1279-
# See the comments in test_equal
1280-
shape = sh.broadcast_shapes(x1.shape, x2.shape)
1281-
ph.assert_shape("logical_or", out.shape, shape)
1282-
_x1 = xp.broadcast_to(x1, shape)
1283-
_x2 = xp.broadcast_to(x2, shape)
1284-
for idx in sh.ndindex(shape):
1285-
assert out[idx] == (bool(_x1[idx]) or bool(_x2[idx]))
1333+
ph.assert_result_shape("logical_or", (x1.shape, x2.shape), out.shape)
1334+
binary_assert_against_refimpl(
1335+
"logical_or", bool, x1, x2, out, lambda l, r: l or r, "({} or {})={}"
1336+
)
12861337

12871338

12881339
@given(*hh.two_mutual_arrays([xp.bool]))
12891340
def test_logical_xor(x1, x2):
12901341
out = xp.logical_xor(x1, x2)
12911342
ph.assert_dtype("logical_xor", (x1.dtype, x2.dtype), out.dtype)
1292-
# See the comments in test_equal
1293-
shape = sh.broadcast_shapes(x1.shape, x2.shape)
1294-
ph.assert_shape("logical_xor", out.shape, shape)
1295-
_x1 = xp.broadcast_to(x1, shape)
1296-
_x2 = xp.broadcast_to(x2, shape)
1297-
for idx in sh.ndindex(shape):
1298-
assert out[idx] == (bool(_x1[idx]) ^ bool(_x2[idx]))
1343+
ph.assert_result_shape("logical_xor", (x1.shape, x2.shape), out.shape)
1344+
binary_assert_against_refimpl(
1345+
"logical_xor", bool, x1, x2, out, lambda l, r: l ^ r, "({} ^ {})={}"
1346+
)
12991347

13001348

13011349
@pytest.mark.parametrize("ctx", make_binary_params("multiply", xps.numeric_dtypes()))

0 commit comments

Comments
 (0)