Skip to content

Commit 2e6f5a5

Browse files
authored
Merge pull request #18 from honno/nd-arrays
Implement `test_*_like` tests
2 parents 1303ef5 + 10cb57a commit 2e6f5a5

File tree

6 files changed

+287
-108
lines changed

6 files changed

+287
-108
lines changed

.github/workflows/numpy.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ jobs:
5151
"array_api_tests/test_signatures.py::test_function_positional_args[__index__]",
5252
"array_api_tests/test_signatures.py::test_function_keyword_only_args[prod]",
5353
"array_api_tests/test_signatures.py::test_function_keyword_only_args[sum]",
54-
5554
)
5655
5756
def pytest_collection_modifyitems(config, items):

array_api_tests/hypothesis_helpers.py

Lines changed: 57 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22
from operator import mul
33
from math import sqrt
44

5-
from hypothesis.strategies import (lists, integers, builds, sampled_from,
5+
from hypothesis import assume
6+
from hypothesis.strategies import (lists, integers, sampled_from,
67
shared, floats, just, composite, one_of,
78
none, booleans)
8-
from hypothesis.extra.numpy import mutually_broadcastable_shapes
9-
from hypothesis import assume
9+
from hypothesis.extra.array_api import make_strategies_namespace
1010

1111
from .pytest_helpers import nargs
1212
from .array_helpers import (dtype_ranges, integer_dtype_objects,
@@ -15,10 +15,14 @@
1515
integer_or_boolean_dtype_objects, dtype_objects)
1616
from ._array_module import full, float32, float64, bool as bool_dtype, _UndefinedStub
1717
from . import _array_module
18+
from . import _array_module as xp
1819

1920
from .function_stubs import elementwise_functions
2021

2122

23+
xps = make_strategies_namespace(xp)
24+
25+
2226
# Set this to True to not fail tests just because a dtype isn't implemented.
2327
# If no compatible dtype is implemented for a given test, the test will fail
2428
# with a hypothesis health check error. Note that this functionality will not
@@ -42,8 +46,13 @@
4246
boolean_dtypes = boolean_dtypes.filter(lambda x: not isinstance(x, _UndefinedStub))
4347
dtypes = dtypes.filter(lambda x: not isinstance(x, _UndefinedStub))
4448

45-
shared_dtypes = shared(dtypes)
49+
shared_dtypes = shared(dtypes, key="dtype")
4650

51+
# TODO: Importing things from test_type_promotion should be replaced by
52+
# something that won't cause a circular import. Right now we use @st.composite
53+
# only because it returns a lazy-evaluated strategy - in the future this method
54+
# should remove the composite wrapper, just returning sampled_from(dtype_pairs)
55+
# instead of drawing from it.
4756
@composite
4857
def mutually_promotable_dtypes(draw, dtype_objects=dtype_objects):
4958
from .test_type_promotion import dtype_mapping, promotion_table
@@ -55,17 +64,20 @@ def mutually_promotable_dtypes(draw, dtype_objects=dtype_objects):
5564
# pairs (XXX: Can we redesign the strategies so that they can prefer
5665
# shrinking dtypes over values?)
5766
sorted_table = sorted(promotion_table)
58-
sorted_table = sorted(sorted_table, key=lambda ij: -1 if ij[0] == ij[1] else sorted_table.index(ij))
59-
dtype_pairs = [(dtype_mapping[i], dtype_mapping[j]) for i, j in
60-
sorted_table]
61-
62-
filtered_dtype_pairs = [(i, j) for i, j in dtype_pairs if i in
63-
dtype_objects and j in dtype_objects]
67+
sorted_table = sorted(
68+
sorted_table, key=lambda ij: -1 if ij[0] == ij[1] else sorted_table.index(ij)
69+
)
70+
dtype_pairs = [(dtype_mapping[i], dtype_mapping[j]) for i, j in sorted_table]
6471
if FILTER_UNDEFINED_DTYPES:
65-
filtered_dtype_pairs = [(i, j) for i, j in filtered_dtype_pairs
66-
if not isinstance(i, _UndefinedStub)
67-
and not isinstance(j, _UndefinedStub)]
68-
return draw(sampled_from(filtered_dtype_pairs))
72+
dtype_pairs = [(i, j) for i, j in dtype_pairs
73+
if not isinstance(i, _UndefinedStub)
74+
and not isinstance(j, _UndefinedStub)]
75+
dtype_pairs = [(i, j) for i, j in dtype_pairs if i in dtype_objects and j in dtype_objects]
76+
return draw(sampled_from(dtype_pairs))
77+
78+
shared_mutually_promotable_dtype_pairs = shared(
79+
mutually_promotable_dtypes(), key="mutually_promotable_dtype_pair"
80+
)
6981

7082
# shared() allows us to draw either the function or the function name and they
7183
# will both correspond to the same function.
@@ -96,36 +108,35 @@ def tuples(elements, *, min_size=0, max_size=None, unique_by=None, unique=False)
96108
return lists(elements, min_size=min_size, max_size=max_size,
97109
unique_by=unique_by, unique=unique).map(tuple)
98110

99-
shapes = tuples(integers(0, 10)).filter(lambda shape: prod(shape) < MAX_ARRAY_SIZE)
100-
101111
# Use this to avoid memory errors with NumPy.
102112
# See https://github.com/numpy/numpy/issues/15753
103-
shapes = tuples(integers(0, 10)).filter(
104-
lambda shape: prod([i for i in shape if i]) < MAX_ARRAY_SIZE)
113+
shapes = xps.array_shapes(min_dims=0, min_side=0).filter(
114+
lambda shape: prod(i for i in shape if i) < MAX_ARRAY_SIZE
115+
)
105116

106-
two_mutually_broadcastable_shapes = mutually_broadcastable_shapes(num_shapes=2)\
117+
two_mutually_broadcastable_shapes = xps.mutually_broadcastable_shapes(num_shapes=2)\
107118
.map(lambda S: S.input_shapes)\
108-
.filter(lambda S: all(prod([i for i in shape if i]) < MAX_ARRAY_SIZE for shape in S))
119+
.filter(lambda S: all(prod(i for i in shape if i) < MAX_ARRAY_SIZE for shape in S))
109120

110121
@composite
111-
def two_broadcastable_shapes(draw, shapes=shapes):
122+
def two_broadcastable_shapes(draw):
112123
"""
113124
This will produce two shapes (shape1, shape2) such that shape2 can be
114125
broadcast to shape1.
115-
116126
"""
117127
from .test_broadcasting import broadcast_shapes
118-
119-
shape1, shape2 = draw(two_mutually_broadcastable_shapes)
120-
if broadcast_shapes(shape1, shape2) != shape1:
121-
assume(False)
128+
shape1, shape2 = draw(two_mutually_broadcastable_shapes)
129+
assume(broadcast_shapes(shape1, shape2) == shape1)
122130
return (shape1, shape2)
123131

124132
sizes = integers(0, MAX_ARRAY_SIZE)
125133
sqrt_sizes = integers(0, SQRT_MAX_ARRAY_SIZE)
126134

127135
# TODO: Generate general arrays here, rather than just scalars.
128-
numeric_arrays = builds(full, just((1,)), floats())
136+
numeric_arrays = xps.arrays(
137+
dtype=shared(xps.floating_dtypes(), key='dtypes'),
138+
shape=shared(xps.array_shapes(), key='shapes'),
139+
)
129140

130141
@composite
131142
def scalars(draw, dtypes, finite=False):
@@ -230,3 +241,22 @@ def multiaxis_indices(draw, shapes):
230241
extra = draw(lists(one_of(integer_indices(sizes), slices(sizes)), min_size=0, max_size=3))
231242
res += extra
232243
return tuple(res)
244+
245+
246+
shared_arrays1 = xps.arrays(
247+
dtype=shared_mutually_promotable_dtype_pairs.map(lambda pair: pair[0]),
248+
shape=shared(two_mutually_broadcastable_shapes, key="shape_pair").map(lambda pair: pair[0]),
249+
)
250+
shared_arrays2 = xps.arrays(
251+
dtype=shared_mutually_promotable_dtype_pairs.map(lambda pair: pair[1]),
252+
shape=shared(two_mutually_broadcastable_shapes, key="shape_pair").map(lambda pair: pair[1]),
253+
)
254+
255+
256+
@composite
257+
def kwargs(draw, **kw):
258+
result = {}
259+
for k, strat in kw.items():
260+
if draw(booleans()):
261+
result[k] = draw(strat)
262+
return result
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from math import prod
2+
3+
import pytest
4+
from hypothesis import given, strategies as st
5+
6+
from .. import _array_module as xp
7+
from .._array_module import _UndefinedStub
8+
from .. import array_helpers as ah
9+
from .. import hypothesis_helpers as hh
10+
11+
UNDEFINED_DTYPES = any(isinstance(d, _UndefinedStub) for d in ah.dtype_objects)
12+
pytestmark = [pytest.mark.skipif(UNDEFINED_DTYPES, reason="undefined dtypes")]
13+
14+
15+
@given(hh.mutually_promotable_dtypes([xp.float32, xp.float64]))
16+
def test_mutually_promotable_dtypes(pairs):
17+
assert pairs in (
18+
(xp.float32, xp.float32),
19+
(xp.float32, xp.float64),
20+
(xp.float64, xp.float32),
21+
(xp.float64, xp.float64),
22+
)
23+
24+
25+
def valid_shape(shape) -> bool:
26+
return (
27+
all(isinstance(side, int) for side in shape)
28+
and all(side >= 0 for side in shape)
29+
and prod(shape) < hh.MAX_ARRAY_SIZE
30+
)
31+
32+
33+
@given(hh.shapes)
34+
def test_shapes(shape):
35+
assert valid_shape(shape)
36+
37+
38+
@given(hh.two_mutually_broadcastable_shapes)
39+
def test_two_mutually_broadcastable_shapes(pair):
40+
for shape in pair:
41+
assert valid_shape(shape)
42+
43+
44+
@given(hh.two_broadcastable_shapes())
45+
def test_two_broadcastable_shapes(pair):
46+
for shape in pair:
47+
assert valid_shape(shape)
48+
49+
from ..test_broadcasting import broadcast_shapes
50+
51+
assert broadcast_shapes(pair[0], pair[1]) == pair[0]
52+
53+
54+
def test_kwargs():
55+
results = []
56+
57+
@given(hh.kwargs(n=st.integers(0, 10), c=st.from_regex("[a-f]")))
58+
def run(kw):
59+
results.append(kw)
60+
61+
run()
62+
assert all(isinstance(kw, dict) for kw in results)
63+
for size in [0, 1, 2]:
64+
assert any(len(kw) == size for kw in results)
65+
66+
n_results = [kw for kw in results if "n" in kw]
67+
assert len(n_results) > 0
68+
assert all(isinstance(kw["n"], int) for kw in n_results)
69+
70+
c_results = [kw for kw in results if "c" in kw]
71+
assert len(c_results) > 0
72+
assert all(isinstance(kw["c"], str) for kw in c_results)

0 commit comments

Comments
 (0)