Skip to content

Commit 1461f9a

Browse files
committed
Better minimisation behaviour for multi_promotable_dtypes()
1 parent 8985b33 commit 1461f9a

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

array_api_tests/dtype_helpers.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,16 @@
2727
]
2828

2929

30-
_int_names = ('int8', 'int16', 'int32', 'int64')
3130
_uint_names = ('uint8', 'uint16', 'uint32', 'uint64')
31+
_int_names = ('int8', 'int16', 'int32', 'int64')
3232
_float_names = ('float32', 'float64')
33-
_dtype_names = ('bool',) + _int_names + _uint_names + _float_names
33+
_dtype_names = ('bool',) + _uint_names + _int_names + _float_names
3434

3535

36-
int_dtypes = tuple(getattr(xp, name) for name in _int_names)
3736
uint_dtypes = tuple(getattr(xp, name) for name in _uint_names)
37+
int_dtypes = tuple(getattr(xp, name) for name in _int_names)
3838
float_dtypes = tuple(getattr(xp, name) for name in _float_names)
39-
all_int_dtypes = int_dtypes + uint_dtypes
39+
all_int_dtypes = uint_dtypes + int_dtypes
4040
numeric_dtypes = all_int_dtypes + float_dtypes
4141
all_dtypes = (xp.bool,) + numeric_dtypes
4242
bool_and_all_int_dtypes = (xp.bool,) + all_int_dtypes

array_api_tests/test_type_promotion.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,15 @@ def multi_promotable_dtypes(
4646
allow_bool: bool = True,
4747
) -> st.SearchStrategy[Tuple[DT, ...]]:
4848
strats = [
49+
st.lists(st.sampled_from(dh.uint_dtypes), min_size=2),
50+
st.lists(st.sampled_from(dh.int_dtypes), min_size=2),
51+
st.lists(st.sampled_from(dh.float_dtypes), min_size=2),
4952
st.lists(st.sampled_from(dh.all_int_dtypes), min_size=2).filter(
5053
lambda l: not (xp.uint64 in l and any(d in dh.int_dtypes for d in l))
5154
),
52-
st.lists(st.sampled_from(dh.float_dtypes), min_size=2),
5355
]
5456
if allow_bool:
55-
strats.append(st.lists(st.just(xp.bool), min_size=2))
57+
strats.insert(0, st.lists(st.just(xp.bool), min_size=2))
5658
return st.one_of(strats).map(tuple)
5759

5860

0 commit comments

Comments
 (0)