Skip to content

Commit 67a2d6b

Browse files
committed
mutually_promotable_dtypes() can generate 2-or-more dtypes
1 parent 2e342bd commit 67a2d6b

File tree

5 files changed

+44
-37
lines changed

5 files changed

+44
-37
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,12 @@
22
from operator import mul
33
from math import sqrt
44
import itertools
5-
from typing import Tuple
5+
from typing import Tuple, Optional
66

77
from hypothesis import assume
88
from hypothesis.strategies import (lists, integers, sampled_from,
99
shared, floats, just, composite, one_of,
10-
none, booleans)
11-
from hypothesis.strategies._internal.strategies import SearchStrategy
10+
none, booleans, SearchStrategy)
1211

1312
from .pytest_helpers import nargs
1413
from .array_helpers import ndindex
@@ -77,10 +76,34 @@ def _dtypes_sorter(dtype_pair):
7776
]
7877

7978

80-
def mutually_promotable_dtypes(dtype_objs=dh.all_dtypes):
81-
return sampled_from(
82-
[(i, j) for i, j in promotable_dtypes if i in dtype_objs and j in dtype_objs]
83-
)
79+
def mutually_promotable_dtypes(
80+
max_size: Optional[int] = 2,
81+
*,
82+
dtypes=dh.all_dtypes,
83+
) -> SearchStrategy[Tuple]:
84+
if max_size == 2:
85+
return sampled_from(
86+
[(i, j) for i, j in promotable_dtypes if i in dtypes and j in dtypes]
87+
)
88+
if isinstance(max_size, int) and max_size < 2:
89+
raise ValueError(f'{max_size=} should be >=2')
90+
strats = []
91+
category_samples = {
92+
category: [d for d in dtypes if d in category] for category in _dtype_categories
93+
}
94+
for samples in category_samples.values():
95+
if len(samples) > 0:
96+
strat = lists(sampled_from(samples), min_size=2, max_size=max_size)
97+
strats.append(strat)
98+
if len(category_samples[dh.uint_dtypes]) > 0 and len(category_samples[dh.int_dtypes]) > 0:
99+
mixed_samples = category_samples[dh.uint_dtypes] + category_samples[dh.int_dtypes]
100+
strat = lists(sampled_from(mixed_samples), min_size=2, max_size=max_size)
101+
if xp.uint64 in mixed_samples:
102+
strat = strat.filter(
103+
lambda l: not (xp.uint64 in l and any(d in dh.int_dtypes for d in l))
104+
)
105+
return one_of(strats).map(tuple)
106+
84107

85108
# shared() allows us to draw either the function or the function name and they
86109
# will both correspond to the same function.
@@ -324,9 +347,9 @@ def multiaxis_indices(draw, shapes):
324347

325348

326349
def two_mutual_arrays(
327-
dtype_objs=dh.all_dtypes, two_shapes=two_mutually_broadcastable_shapes
350+
dtypes=dh.all_dtypes, two_shapes=two_mutually_broadcastable_shapes
328351
):
329-
mutual_dtypes = shared(mutually_promotable_dtypes(dtype_objs))
352+
mutual_dtypes = shared(mutually_promotable_dtypes(dtypes=dtypes))
330353
mutual_shapes = shared(two_shapes)
331354
arrays1 = xps.arrays(
332355
dtype=mutual_dtypes.map(lambda pair: pair[0]),

array_api_tests/meta_tests/test_hypothesis_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
UNDEFINED_DTYPES = any(isinstance(d, _UndefinedStub) for d in dh.all_dtypes)
1515
pytestmark = [pytest.mark.skipif(UNDEFINED_DTYPES, reason="undefined dtypes")]
1616

17-
@given(hh.mutually_promotable_dtypes([xp.float32, xp.float64]))
17+
@given(hh.mutually_promotable_dtypes(dtypes=dh.float_dtypes))
1818
def test_mutually_promotable_dtypes(pairs):
1919
assert pairs in (
2020
(xp.float32, xp.float32),

array_api_tests/test_elementwise_functions.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,11 @@
2929
integer_or_boolean_scalars = hh.array_scalars(hh.integer_or_boolean_dtypes)
3030
boolean_scalars = hh.array_scalars(hh.boolean_dtypes)
3131

32-
two_integer_dtypes = hh.mutually_promotable_dtypes(dh.all_int_dtypes)
33-
two_floating_dtypes = hh.mutually_promotable_dtypes(dh.float_dtypes)
34-
two_numeric_dtypes = hh.mutually_promotable_dtypes(dh.numeric_dtypes)
35-
two_integer_or_boolean_dtypes = hh.mutually_promotable_dtypes(dh.bool_and_all_int_dtypes)
36-
two_boolean_dtypes = hh.mutually_promotable_dtypes((xp.bool,))
32+
two_integer_dtypes = hh.mutually_promotable_dtypes(dtypes=dh.all_int_dtypes)
33+
two_floating_dtypes = hh.mutually_promotable_dtypes(dtypes=dh.float_dtypes)
34+
two_numeric_dtypes = hh.mutually_promotable_dtypes(dtypes=dh.numeric_dtypes)
35+
two_integer_or_boolean_dtypes = hh.mutually_promotable_dtypes(dtypes=dh.bool_and_all_int_dtypes)
36+
two_boolean_dtypes = hh.mutually_promotable_dtypes(dtypes=(xp.bool,))
3737
two_any_dtypes = hh.mutually_promotable_dtypes()
3838

3939
@st.composite

array_api_tests/test_linalg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def cross_args(draw, dtype_objects=dh.numeric_dtypes):
110110
axis = kw.get('axis', -1)
111111
shape[axis] = 3
112112

113-
mutual_dtypes = shared(mutually_promotable_dtypes(dtype_objects))
113+
mutual_dtypes = shared(mutually_promotable_dtypes(dtypes=dtype_objects))
114114
arrays1 = xps.arrays(
115115
dtype=mutual_dtypes.map(lambda pair: pair[0]),
116116
shape=shape,
@@ -342,7 +342,7 @@ def test_matrix_transpose(x):
342342
_test_stacks(_array_module.matrix_transpose, x, res=res, true_val=true_val)
343343

344344
@given(
345-
*two_mutual_arrays(dtype_objs=dh.numeric_dtypes,
345+
*two_mutual_arrays(dtypes=dh.numeric_dtypes,
346346
two_shapes=tuples(one_d_shapes, one_d_shapes))
347347
)
348348
def test_outer(x1, x2):

array_api_tests/test_type_promotion.py

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -42,23 +42,7 @@ def assert_dtype(test_case: str, result_name: str, dtype: DT, expected: DT):
4242
assert dtype == expected, msg
4343

4444

45-
def multi_promotable_dtypes(
46-
allow_bool: bool = True,
47-
) -> st.SearchStrategy[Tuple[DT, ...]]:
48-
strats = [
49-
st.lists(st.sampled_from(dh.uint_dtypes), min_size=2),
50-
st.lists(st.sampled_from(dh.int_dtypes), min_size=2),
51-
st.lists(st.sampled_from(dh.float_dtypes), min_size=2),
52-
st.lists(st.sampled_from(dh.all_int_dtypes), min_size=2).filter(
53-
lambda l: not (xp.uint64 in l and any(d in dh.int_dtypes for d in l))
54-
),
55-
]
56-
if allow_bool:
57-
strats.insert(0, st.lists(st.just(xp.bool), min_size=2))
58-
return st.one_of(strats).map(tuple)
59-
60-
61-
@given(multi_promotable_dtypes())
45+
@given(hh.mutually_promotable_dtypes(None))
6246
def test_result_type(dtypes):
6347
out = xp.result_type(*dtypes)
6448
assert_dtype(
@@ -67,7 +51,7 @@ def test_result_type(dtypes):
6751

6852

6953
@given(
70-
dtypes=multi_promotable_dtypes(allow_bool=False),
54+
dtypes=hh.mutually_promotable_dtypes(None, dtypes=dh.numeric_dtypes),
7155
data=st.data(),
7256
)
7357
def test_meshgrid(dtypes, data):
@@ -85,7 +69,7 @@ def test_meshgrid(dtypes, data):
8569

8670
@given(
8771
shape=hh.shapes(min_dims=1),
88-
dtypes=multi_promotable_dtypes(allow_bool=False),
72+
dtypes=hh.mutually_promotable_dtypes(None, dtypes=dh.numeric_dtypes),
8973
data=st.data(),
9074
)
9175
def test_concat(shape, dtypes, data):
@@ -101,7 +85,7 @@ def test_concat(shape, dtypes, data):
10185

10286
@given(
10387
shape=hh.shapes(),
104-
dtypes=multi_promotable_dtypes(),
88+
dtypes=hh.mutually_promotable_dtypes(None),
10589
data=st.data(),
10690
)
10791
def test_stack(shape, dtypes, data):

0 commit comments

Comments
 (0)