Skip to content

Commit b8399a1

Browse files
committed
Internally flatmap promotable_dtypes
1 parent 4807b86 commit b8399a1

File tree

2 files changed

+10
-7
lines changed

2 files changed

+10
-7
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
from math import sqrt
44

55
from hypothesis.strategies import (lists, integers, builds, sampled_from,
6-
shared, floats, just, composite, one_of,
7-
none, booleans)
6+
shared, tuples as hypotheses_tuples,
7+
floats, just, composite, one_of, none,
8+
booleans, SearchStrategy)
89
from hypothesis.extra.numpy import mutually_broadcastable_shapes
910
from hypothesis.extra.array_api import make_strategies_namespace
1011
from hypothesis import assume
@@ -70,6 +71,8 @@ def make_dtype_pairs():
7071
return dtype_pairs
7172

7273
def promotable_dtypes(dtype):
74+
if isinstance(dtype, SearchStrategy):
75+
return dtype.flatmap(promotable_dtypes)
7376
dtype_pairs = make_dtype_pairs()
7477
dtypes = [j for i, j in dtype_pairs if i == dtype]
7578
return sampled_from(dtypes)

array_api_tests/test_creation_functions.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def test_empty(shape, dtype):
8484
dtype=shared_dtypes,
8585
shape=xps.array_shapes(),
8686
),
87-
dtype=one_of(none(), shared_dtypes.flatmap(promotable_dtypes)),
87+
dtype=one_of(none(), promotable_dtypes(shared_dtypes)),
8888
)
8989
def test_empty_like(a, dtype):
9090
kwargs = {} if dtype is None else {'dtype': dtype}
@@ -155,8 +155,8 @@ def test_full(shape, fill_value, dtype):
155155
dtype=shared_dtypes,
156156
shape=xps.array_shapes(),
157157
),
158-
fill_value=shared_dtypes.flatmap(promotable_dtypes).flatmap(xps.from_dtype),
159-
dtype=one_of(none(), shared_dtypes.flatmap(promotable_dtypes)),
158+
fill_value=promotable_dtypes(shared_dtypes).flatmap(xps.from_dtype),
159+
dtype=one_of(none(), promotable_dtypes(shared_dtypes)),
160160
)
161161
def test_full_like(a, fill_value, dtype):
162162
kwargs = {} if dtype is None else {'dtype': dtype}
@@ -247,7 +247,7 @@ def test_ones(shape, dtype):
247247
dtype=shared_dtypes,
248248
shape=xps.array_shapes(),
249249
),
250-
dtype=one_of(none(), shared_dtypes.flatmap(promotable_dtypes)),
250+
dtype=one_of(none(), promotable_dtypes(shared_dtypes)),
251251
)
252252
def test_ones_like(a, dtype):
253253
kwargs = {} if dtype is None else {'dtype': dtype}
@@ -298,7 +298,7 @@ def test_zeros(shape, dtype):
298298
dtype=shared_dtypes,
299299
shape=xps.array_shapes(),
300300
),
301-
dtype=one_of(none(), shared_dtypes.flatmap(promotable_dtypes)),
301+
dtype=one_of(none(), promotable_dtypes(shared_dtypes)),
302302
)
303303
def test_zeros_like(a, dtype):
304304
kwargs = {} if dtype is None else {'dtype': dtype}

0 commit comments

Comments
 (0)