1
1
import math
2
2
3
- from hypothesis import given
3
+ from hypothesis import assume , given
4
4
from hypothesis import strategies as st
5
5
6
6
from . import _array_module as xp
9
9
from . import hypothesis_helpers as hh
10
10
from . import pytest_helpers as ph
11
11
from . import xps
12
+ from .typing import Scalar , ScalarType
12
13
13
- RTOL = 0.05
14
+
15
+ def assert_equals (
16
+ func_name : str , type_ : ScalarType , out : Scalar , expected : Scalar , / , ** kw
17
+ ):
18
+ f_func = f"{ func_name } ({ ph .fmt_kw (kw )} )"
19
+ if type_ is bool or type_ is int :
20
+ msg = f"{ out = } , should be { expected } [{ f_func } ]"
21
+ assert out == expected , msg
22
+ elif math .isnan (expected ):
23
+ msg = f"{ out = } , should be { expected } [{ f_func } ]"
24
+ assert math .isnan (out ), msg
25
+ else :
26
+ msg = f"{ out = } , should be roughly { expected } [{ f_func } ]"
27
+ assert math .isclose (out , expected , rel_tol = 0.05 ), msg
14
28
15
29
16
30
@given (
@@ -34,7 +48,7 @@ def test_min(x, data):
34
48
f_func = f"min({ ph .fmt_kw (kw )} )"
35
49
36
50
# TODO: support axis
37
- if kw .get ("axis" ) is None :
51
+ if kw .get ("axis" , None ) is None :
38
52
keepdims = kw .get ("keepdims" , False )
39
53
if keepdims :
40
54
idx = tuple (1 for _ in x .shape )
@@ -53,11 +67,7 @@ def test_min(x, data):
53
67
elements .append (s )
54
68
min_ = scalar_type (_out )
55
69
expected = min (elements )
56
- msg = f"out={ min_ } , should be { expected } [{ f_func } ]"
57
- if math .isnan (min_ ):
58
- assert math .isnan (expected ), msg
59
- else :
60
- assert min_ == expected , msg
70
+ assert_equals ("min" , dh .get_scalar_type (out .dtype ), min_ , expected )
61
71
62
72
63
73
@given (
@@ -81,7 +91,7 @@ def test_max(x, data):
81
91
f_func = f"max({ ph .fmt_kw (kw )} )"
82
92
83
93
# TODO: support axis
84
- if kw .get ("axis" ) is None :
94
+ if kw .get ("axis" , None ) is None :
85
95
keepdims = kw .get ("keepdims" , False )
86
96
if keepdims :
87
97
idx = tuple (1 for _ in x .shape )
@@ -100,11 +110,7 @@ def test_max(x, data):
100
110
elements .append (s )
101
111
max_ = scalar_type (_out )
102
112
expected = max (elements )
103
- msg = f"out={ max_ } , should be { expected } [{ f_func } ]"
104
- if math .isnan (max_ ):
105
- assert math .isnan (expected ), msg
106
- else :
107
- assert max_ == expected , msg
113
+ assert_equals ("mean" , dh .get_scalar_type (out .dtype ), max_ , expected )
108
114
109
115
110
116
@given (
@@ -128,7 +134,7 @@ def test_mean(x, data):
128
134
f_func = f"mean({ ph .fmt_kw (kw )} )"
129
135
130
136
# TODO: support axis
131
- if kw .get ("axis" ) is None :
137
+ if kw .get ("axis" , None ) is None :
132
138
keepdims = kw .get ("keepdims" , False )
133
139
if keepdims :
134
140
idx = tuple (1 for _ in x .shape )
@@ -146,15 +152,75 @@ def test_mean(x, data):
146
152
elements .append (s )
147
153
mean = float (_out )
148
154
expected = sum (elements ) / len (elements )
149
- msg = f"out={ mean } , should be roughly { expected } [{ f_func } ]"
150
- assert math .isclose (mean , expected , rel_tol = RTOL ), msg
155
+ assert_equals ("mean" , float , mean , expected )
151
156
152
157
153
158
# TODO: generate kwargs
154
- @given (xps .arrays (dtype = xps .numeric_dtypes (), shape = hh .shapes (min_side = 1 )))
155
- def test_prod (x ):
156
- xp .prod (x )
157
- # TODO
159
+ @given (
160
+ x = xps .arrays (dtype = xps .numeric_dtypes (), shape = hh .shapes (min_side = 1 )),
161
+ data = st .data (),
162
+ )
163
+ def test_prod (x , data ):
164
+ axis_strats = [st .none ()]
165
+ if x .shape != ():
166
+ axis_strats .append (
167
+ st .integers (- x .ndim , x .ndim - 1 ) | xps .valid_tuple_axes (x .ndim )
168
+ )
169
+ kw = data .draw (
170
+ hh .kwargs (
171
+ axis = st .one_of (axis_strats ),
172
+ dtype = st .none () | st .just (x .dtype ), # TODO: all valid dtypes
173
+ keepdims = st .booleans (),
174
+ ),
175
+ label = "kw" ,
176
+ )
177
+
178
+ out = xp .prod (x , ** kw )
179
+
180
+ dtype = kw .get ("dtype" , None )
181
+ if dtype is None :
182
+ if dh .is_int_dtype (x .dtype ):
183
+ m , M = dh .dtype_ranges [x .dtype ]
184
+ d_m , d_M = dh .dtype_ranges [dh .default_int ]
185
+ if m < d_m or M > d_M :
186
+ _dtype = x .dtype
187
+ else :
188
+ _dtype = dh .default_int
189
+ else :
190
+ if dh .dtype_nbits [x .dtype ] > dh .dtype_nbits [dh .default_float ]:
191
+ _dtype = x .dtype
192
+ else :
193
+ _dtype = dh .default_float
194
+ else :
195
+ _dtype = dtype
196
+ ph .assert_dtype ("prod" , x .dtype , out .dtype , _dtype )
197
+
198
+ f_func = f"prod({ ph .fmt_kw (kw )} )"
199
+
200
+ # TODO: support axis
201
+ if kw .get ("axis" , None ) is None :
202
+ keepdims = kw .get ("keepdims" , False )
203
+ if keepdims :
204
+ idx = tuple (1 for _ in x .shape )
205
+ msg = f"{ out .shape = } , should be reduced dimension { idx } [{ f_func } ]"
206
+ assert out .shape == idx , msg
207
+ else :
208
+ ph .assert_shape ("prod" , out .shape , (), ** kw )
209
+
210
+ # TODO: figure out NaN behaviour
211
+ if dh .is_int_dtype (x .dtype ) or not xp .any (xp .isnan (x )):
212
+ _out = xp .reshape (out , ()) if keepdims else out
213
+ scalar_type = dh .get_scalar_type (out .dtype )
214
+ elements = []
215
+ for idx in ah .ndindex (x .shape ):
216
+ s = scalar_type (x [idx ])
217
+ elements .append (s )
218
+ prod = scalar_type (_out )
219
+ expected = math .prod (elements )
220
+ if dh .is_int_dtype (out .dtype ):
221
+ m , M = dh .dtype_ranges [out .dtype ]
222
+ assume (m <= expected <= M )
223
+ assert_equals ("prod" , dh .get_scalar_type (out .dtype ), prod , expected )
158
224
159
225
160
226
# TODO: generate kwargs
0 commit comments