Skip to content

Commit f5d1a77

Browse files
committed
Test prod()
1 parent 98d04a4 commit f5d1a77

File tree

1 file changed

+87
-21
lines changed

1 file changed

+87
-21
lines changed

array_api_tests/test_statistical_functions.py

Lines changed: 87 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import math
22

3-
from hypothesis import given
3+
from hypothesis import assume, given
44
from hypothesis import strategies as st
55

66
from . import _array_module as xp
@@ -9,8 +9,22 @@
99
from . import hypothesis_helpers as hh
1010
from . import pytest_helpers as ph
1111
from . import xps
12+
from .typing import Scalar, ScalarType
1213

13-
RTOL = 0.05
14+
15+
def assert_equals(
16+
func_name: str, type_: ScalarType, out: Scalar, expected: Scalar, /, **kw
17+
):
18+
f_func = f"{func_name}({ph.fmt_kw(kw)})"
19+
if type_ is bool or type_ is int:
20+
msg = f"{out=}, should be {expected} [{f_func}]"
21+
assert out == expected, msg
22+
elif math.isnan(expected):
23+
msg = f"{out=}, should be {expected} [{f_func}]"
24+
assert math.isnan(out), msg
25+
else:
26+
msg = f"{out=}, should be roughly {expected} [{f_func}]"
27+
assert math.isclose(out, expected, rel_tol=0.05), msg
1428

1529

1630
@given(
@@ -34,7 +48,7 @@ def test_min(x, data):
3448
f_func = f"min({ph.fmt_kw(kw)})"
3549

3650
# TODO: support axis
37-
if kw.get("axis") is None:
51+
if kw.get("axis", None) is None:
3852
keepdims = kw.get("keepdims", False)
3953
if keepdims:
4054
idx = tuple(1 for _ in x.shape)
@@ -53,11 +67,7 @@ def test_min(x, data):
5367
elements.append(s)
5468
min_ = scalar_type(_out)
5569
expected = min(elements)
56-
msg = f"out={min_}, should be {expected} [{f_func}]"
57-
if math.isnan(min_):
58-
assert math.isnan(expected), msg
59-
else:
60-
assert min_ == expected, msg
70+
assert_equals("min", dh.get_scalar_type(out.dtype), min_, expected)
6171

6272

6373
@given(
@@ -81,7 +91,7 @@ def test_max(x, data):
8191
f_func = f"max({ph.fmt_kw(kw)})"
8292

8393
# TODO: support axis
84-
if kw.get("axis") is None:
94+
if kw.get("axis", None) is None:
8595
keepdims = kw.get("keepdims", False)
8696
if keepdims:
8797
idx = tuple(1 for _ in x.shape)
@@ -100,11 +110,7 @@ def test_max(x, data):
100110
elements.append(s)
101111
max_ = scalar_type(_out)
102112
expected = max(elements)
103-
msg = f"out={max_}, should be {expected} [{f_func}]"
104-
if math.isnan(max_):
105-
assert math.isnan(expected), msg
106-
else:
107-
assert max_ == expected, msg
113+
assert_equals("mean", dh.get_scalar_type(out.dtype), max_, expected)
108114

109115

110116
@given(
@@ -128,7 +134,7 @@ def test_mean(x, data):
128134
f_func = f"mean({ph.fmt_kw(kw)})"
129135

130136
# TODO: support axis
131-
if kw.get("axis") is None:
137+
if kw.get("axis", None) is None:
132138
keepdims = kw.get("keepdims", False)
133139
if keepdims:
134140
idx = tuple(1 for _ in x.shape)
@@ -146,15 +152,75 @@ def test_mean(x, data):
146152
elements.append(s)
147153
mean = float(_out)
148154
expected = sum(elements) / len(elements)
149-
msg = f"out={mean}, should be roughly {expected} [{f_func}]"
150-
assert math.isclose(mean, expected, rel_tol=RTOL), msg
155+
assert_equals("mean", float, mean, expected)
151156

152157

153158
# TODO: generate kwargs
154-
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1)))
155-
def test_prod(x):
156-
xp.prod(x)
157-
# TODO
159+
@given(
160+
x=xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1)),
161+
data=st.data(),
162+
)
163+
def test_prod(x, data):
164+
axis_strats = [st.none()]
165+
if x.shape != ():
166+
axis_strats.append(
167+
st.integers(-x.ndim, x.ndim - 1) | xps.valid_tuple_axes(x.ndim)
168+
)
169+
kw = data.draw(
170+
hh.kwargs(
171+
axis=st.one_of(axis_strats),
172+
dtype=st.none() | st.just(x.dtype), # TODO: all valid dtypes
173+
keepdims=st.booleans(),
174+
),
175+
label="kw",
176+
)
177+
178+
out = xp.prod(x, **kw)
179+
180+
dtype = kw.get("dtype", None)
181+
if dtype is None:
182+
if dh.is_int_dtype(x.dtype):
183+
m, M = dh.dtype_ranges[x.dtype]
184+
d_m, d_M = dh.dtype_ranges[dh.default_int]
185+
if m < d_m or M > d_M:
186+
_dtype = x.dtype
187+
else:
188+
_dtype = dh.default_int
189+
else:
190+
if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]:
191+
_dtype = x.dtype
192+
else:
193+
_dtype = dh.default_float
194+
else:
195+
_dtype = dtype
196+
ph.assert_dtype("prod", x.dtype, out.dtype, _dtype)
197+
198+
f_func = f"prod({ph.fmt_kw(kw)})"
199+
200+
# TODO: support axis
201+
if kw.get("axis", None) is None:
202+
keepdims = kw.get("keepdims", False)
203+
if keepdims:
204+
idx = tuple(1 for _ in x.shape)
205+
msg = f"{out.shape=}, should be reduced dimension {idx} [{f_func}]"
206+
assert out.shape == idx, msg
207+
else:
208+
ph.assert_shape("prod", out.shape, (), **kw)
209+
210+
# TODO: figure out NaN behaviour
211+
if dh.is_int_dtype(x.dtype) or not xp.any(xp.isnan(x)):
212+
_out = xp.reshape(out, ()) if keepdims else out
213+
scalar_type = dh.get_scalar_type(out.dtype)
214+
elements = []
215+
for idx in ah.ndindex(x.shape):
216+
s = scalar_type(x[idx])
217+
elements.append(s)
218+
prod = scalar_type(_out)
219+
expected = math.prod(elements)
220+
if dh.is_int_dtype(out.dtype):
221+
m, M = dh.dtype_ranges[out.dtype]
222+
assume(m <= expected <= M)
223+
assert_equals("prod", dh.get_scalar_type(out.dtype), prod, expected)
158224

159225

160226
# TODO: generate kwargs

0 commit comments

Comments
 (0)