Skip to content

Commit a88c3b8

Browse files
authored
Merge pull request #30 from honno/promotion-test-cases
More promotion test cases
2 parents 7c53ded + 866dcd8 commit a88c3b8

14 files changed

+345
-164
lines changed

.github/workflows/numpy.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,12 @@ jobs:
4141
array_api_tests/test_signatures.py::test_function_keyword_only_args[__dlpack__]
4242
# floor_divide has an issue related to https://github.com/data-apis/array-api/issues/264
4343
array_api_tests/test_elementwise_functions.py::test_floor_divide
44+
# mesgrid doesn't return all arrays as the promoted dtype
45+
array_api_tests/test_type_promotion.py::test_meshgrid
46+
# https://github.com/numpy/numpy/pull/20066#issuecomment-947056094
47+
array_api_tests/test_type_promotion.py::test_where
48+
# shape mismatches are not handled
49+
array_api_tests/test_type_promotion.py::test_tensordot
4450
4551
EOF
4652

array_api_tests/dtype_helpers.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from warnings import warn
2-
from typing import NamedTuple
2+
from functools import lru_cache
3+
from typing import NamedTuple, Tuple, Union
34

45
from . import _array_module as xp
56
from ._array_module import _UndefinedStub
7+
from .typing import DataType, ScalarType
68

79

810
__all__ = [
@@ -28,19 +30,20 @@
2830
'binary_op_to_symbol',
2931
'unary_op_to_symbol',
3032
'inplace_op_to_symbol',
33+
'fmt_types',
3134
]
3235

3336

34-
_int_names = ('int8', 'int16', 'int32', 'int64')
3537
_uint_names = ('uint8', 'uint16', 'uint32', 'uint64')
38+
_int_names = ('int8', 'int16', 'int32', 'int64')
3639
_float_names = ('float32', 'float64')
37-
_dtype_names = ('bool',) + _int_names + _uint_names + _float_names
40+
_dtype_names = ('bool',) + _uint_names + _int_names + _float_names
3841

3942

40-
int_dtypes = tuple(getattr(xp, name) for name in _int_names)
4143
uint_dtypes = tuple(getattr(xp, name) for name in _uint_names)
44+
int_dtypes = tuple(getattr(xp, name) for name in _int_names)
4245
float_dtypes = tuple(getattr(xp, name) for name in _float_names)
43-
all_int_dtypes = int_dtypes + uint_dtypes
46+
all_int_dtypes = uint_dtypes + int_dtypes
4447
numeric_dtypes = all_int_dtypes + float_dtypes
4548
all_dtypes = (xp.bool,) + numeric_dtypes
4649
bool_and_all_int_dtypes = (xp.bool,) + all_int_dtypes
@@ -148,6 +151,17 @@ class MinMax(NamedTuple):
148151
}
149152

150153

154+
def result_type(*dtypes: DataType):
155+
if len(dtypes) == 0:
156+
raise ValueError()
157+
elif len(dtypes) == 1:
158+
return dtypes[0]
159+
result = promotion_table[dtypes[0], dtypes[1]]
160+
for i in range(2, len(dtypes)):
161+
result = promotion_table[result, dtypes[i]]
162+
return result
163+
164+
151165
dtype_nbits = {
152166
**{d: 8 for d in [xp.int8, xp.uint8]},
153167
**{d: 16 for d in [xp.int16, xp.uint16]},
@@ -163,6 +177,7 @@ class MinMax(NamedTuple):
163177

164178

165179
func_in_dtypes = {
180+
# elementwise
166181
'abs': numeric_dtypes,
167182
'acos': float_dtypes,
168183
'acosh': float_dtypes,
@@ -219,10 +234,13 @@ class MinMax(NamedTuple):
219234
'tan': float_dtypes,
220235
'tanh': float_dtypes,
221236
'trunc': numeric_dtypes,
237+
# searching
238+
'where': all_dtypes,
222239
}
223240

224241

225242
func_returns_bool = {
243+
# elementwise
226244
'abs': False,
227245
'acos': False,
228246
'acosh': False,
@@ -279,6 +297,8 @@ class MinMax(NamedTuple):
279297
'tan': False,
280298
'tanh': False,
281299
'trunc': False,
300+
# searching
301+
'where': False,
282302
}
283303

284304

@@ -352,3 +372,15 @@ class MinMax(NamedTuple):
352372
inplace_op_to_symbol[iop] = f'{symbol}='
353373
func_in_dtypes[iop] = func_in_dtypes[op]
354374
func_returns_bool[iop] = func_returns_bool[op]
375+
376+
377+
@lru_cache
378+
def fmt_types(types: Tuple[Union[DataType, ScalarType], ...]) -> str:
379+
f_types = []
380+
for type_ in types:
381+
try:
382+
f_types.append(dtype_to_name[type_])
383+
except KeyError:
384+
# i.e. dtype is bool, int, or float
385+
f_types.append(type_.__name__)
386+
return ', '.join(f_types)

array_api_tests/hypothesis_helpers.py

Lines changed: 51 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,16 @@
22
from operator import mul
33
from math import sqrt
44
import itertools
5-
from typing import Tuple
5+
from typing import Tuple, Optional, List
66

77
from hypothesis import assume
88
from hypothesis.strategies import (lists, integers, sampled_from,
99
shared, floats, just, composite, one_of,
10-
none, booleans)
11-
from hypothesis.strategies._internal.strategies import SearchStrategy
10+
none, booleans, SearchStrategy)
1211

1312
from .pytest_helpers import nargs
1413
from .array_helpers import ndindex
14+
from .typing import DataType, Shape
1515
from . import dtype_helpers as dh
1616
from ._array_module import (full, float32, float64, bool as bool_dtype,
1717
_UndefinedStub, eye, broadcast_to)
@@ -50,7 +50,7 @@
5050
_dtype_categories = [(xp.bool,), dh.uint_dtypes, dh.int_dtypes, dh.float_dtypes]
5151
_sorted_dtypes = [d for category in _dtype_categories for d in category]
5252

53-
def _dtypes_sorter(dtype_pair):
53+
def _dtypes_sorter(dtype_pair: Tuple[DataType, DataType]):
5454
dtype1, dtype2 = dtype_pair
5555
if dtype1 == dtype2:
5656
return _sorted_dtypes.index(dtype1)
@@ -67,7 +67,7 @@ def _dtypes_sorter(dtype_pair):
6767
key += 1
6868
return key
6969

70-
promotable_dtypes = sorted(dh.promotion_table.keys(), key=_dtypes_sorter)
70+
promotable_dtypes: List[Tuple[DataType, DataType]] = sorted(dh.promotion_table.keys(), key=_dtypes_sorter)
7171

7272
if FILTER_UNDEFINED_DTYPES:
7373
promotable_dtypes = [
@@ -77,10 +77,34 @@ def _dtypes_sorter(dtype_pair):
7777
]
7878

7979

80-
def mutually_promotable_dtypes(dtype_objs=dh.all_dtypes):
81-
return sampled_from(
82-
[(i, j) for i, j in promotable_dtypes if i in dtype_objs and j in dtype_objs]
83-
)
80+
def mutually_promotable_dtypes(
81+
max_size: Optional[int] = 2,
82+
*,
83+
dtypes: Tuple[DataType, ...] = dh.all_dtypes,
84+
) -> SearchStrategy[Tuple[DataType, ...]]:
85+
if max_size == 2:
86+
return sampled_from(
87+
[(i, j) for i, j in promotable_dtypes if i in dtypes and j in dtypes]
88+
)
89+
if isinstance(max_size, int) and max_size < 2:
90+
raise ValueError(f'{max_size=} should be >=2')
91+
strats = []
92+
category_samples = {
93+
category: [d for d in dtypes if d in category] for category in _dtype_categories
94+
}
95+
for samples in category_samples.values():
96+
if len(samples) > 0:
97+
strat = lists(sampled_from(samples), min_size=2, max_size=max_size)
98+
strats.append(strat)
99+
if len(category_samples[dh.uint_dtypes]) > 0 and len(category_samples[dh.int_dtypes]) > 0:
100+
mixed_samples = category_samples[dh.uint_dtypes] + category_samples[dh.int_dtypes]
101+
strat = lists(sampled_from(mixed_samples), min_size=2, max_size=max_size)
102+
if xp.uint64 in mixed_samples:
103+
strat = strat.filter(
104+
lambda l: not (xp.uint64 in l and any(d in dh.int_dtypes for d in l))
105+
)
106+
return one_of(strats).map(tuple)
107+
84108

85109
# shared() allows us to draw either the function or the function name and they
86110
# will both correspond to the same function.
@@ -113,15 +137,19 @@ def tuples(elements, *, min_size=0, max_size=None, unique_by=None, unique=False)
113137

114138
# Use this to avoid memory errors with NumPy.
115139
# See https://github.com/numpy/numpy/issues/15753
116-
shapes = xps.array_shapes(min_dims=0, min_side=0).filter(
117-
lambda shape: prod(i for i in shape if i) < MAX_ARRAY_SIZE
118-
)
140+
def shapes(**kw):
141+
kw.setdefault('min_dims', 0)
142+
kw.setdefault('min_side', 0)
143+
return xps.array_shapes(**kw).filter(
144+
lambda shape: prod(i for i in shape if i) < MAX_ARRAY_SIZE
145+
)
146+
119147

120148
one_d_shapes = xps.array_shapes(min_dims=1, max_dims=1, min_side=0, max_side=SQRT_MAX_ARRAY_SIZE)
121149

122150
# Matrix shapes assume stacks of matrices
123151
@composite
124-
def matrix_shapes(draw, stack_shapes=shapes):
152+
def matrix_shapes(draw, stack_shapes=shapes()):
125153
stack_shape = draw(stack_shapes)
126154
mat_shape = draw(xps.array_shapes(max_dims=2, min_dims=2))
127155
shape = stack_shape + mat_shape
@@ -135,9 +163,11 @@ def matrix_shapes(draw, stack_shapes=shapes):
135163
elements=dict(allow_nan=False,
136164
allow_infinity=False))
137165

138-
def mutually_broadcastable_shapes(num_shapes: int) -> SearchStrategy[Tuple[Tuple]]:
166+
def mutually_broadcastable_shapes(
167+
num_shapes: int, **kw
168+
) -> SearchStrategy[Tuple[Shape, ...]]:
139169
return (
140-
xps.mutually_broadcastable_shapes(num_shapes)
170+
xps.mutually_broadcastable_shapes(num_shapes, **kw)
141171
.map(lambda BS: BS.input_shapes)
142172
.filter(lambda shapes: all(
143173
prod(i for i in s if i > 0) < MAX_ARRAY_SIZE for s in shapes
@@ -164,13 +194,13 @@ def positive_definite_matrices(draw, dtypes=xps.floating_dtypes()):
164194
# using something like
165195
# https://github.com/scikit-learn/scikit-learn/blob/844b4be24/sklearn/datasets/_samples_generator.py#L1351.
166196
n = draw(integers(0))
167-
shape = draw(shapes) + (n, n)
197+
shape = draw(shapes()) + (n, n)
168198
assume(prod(i for i in shape if i) < MAX_ARRAY_SIZE)
169199
dtype = draw(dtypes)
170200
return broadcast_to(eye(n, dtype=dtype), shape)
171201

172202
@composite
173-
def invertible_matrices(draw, dtypes=xps.floating_dtypes(), stack_shapes=shapes):
203+
def invertible_matrices(draw, dtypes=xps.floating_dtypes(), stack_shapes=shapes()):
174204
# For now, just generate stacks of diagonal matrices.
175205
n = draw(integers(0, SQRT_MAX_ARRAY_SIZE),)
176206
stack_shape = draw(stack_shapes)
@@ -318,9 +348,10 @@ def multiaxis_indices(draw, shapes):
318348

319349

320350
def two_mutual_arrays(
321-
dtype_objs=dh.all_dtypes, two_shapes=two_mutually_broadcastable_shapes
322-
):
323-
mutual_dtypes = shared(mutually_promotable_dtypes(dtype_objs))
351+
dtypes: Tuple[DataType, ...] = dh.all_dtypes,
352+
two_shapes: SearchStrategy[Tuple[Shape, Shape]] = two_mutually_broadcastable_shapes,
353+
) -> SearchStrategy:
354+
mutual_dtypes = shared(mutually_promotable_dtypes(dtypes=dtypes))
324355
mutual_shapes = shared(two_shapes)
325356
arrays1 = xps.arrays(
326357
dtype=mutual_dtypes.map(lambda pair: pair[0]),

array_api_tests/meta/test_hypothesis_helpers.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,11 @@
99
from .. import dtype_helpers as dh
1010
from .. import hypothesis_helpers as hh
1111
from ..test_broadcasting import broadcast_shapes
12-
from ..test_elementwise_functions import sanity_check
1312

1413
UNDEFINED_DTYPES = any(isinstance(d, _UndefinedStub) for d in dh.all_dtypes)
1514
pytestmark = [pytest.mark.skipif(UNDEFINED_DTYPES, reason="undefined dtypes")]
1615

17-
@given(hh.mutually_promotable_dtypes([xp.float32, xp.float64]))
16+
@given(hh.mutually_promotable_dtypes(dtypes=dh.float_dtypes))
1817
def test_mutually_promotable_dtypes(pairs):
1918
assert pairs in (
2019
(xp.float32, xp.float32),
@@ -32,7 +31,7 @@ def valid_shape(shape) -> bool:
3231
)
3332

3433

35-
@given(hh.shapes)
34+
@given(hh.shapes())
3635
def test_shapes(shape):
3736
assert valid_shape(shape)
3837

@@ -52,7 +51,7 @@ def test_two_broadcastable_shapes(pair):
5251

5352
@given(*hh.two_mutual_arrays())
5453
def test_two_mutual_arrays(x1, x2):
55-
sanity_check(x1, x2)
54+
assert (x1.dtype, x2.dtype) in dh.promotion_table.keys()
5655
assert broadcast_shapes(x1.shape, x2.shape) in (x1.shape, x2.shape)
5756

5857

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from pytest import raises
2+
3+
from .. import pytest_helpers as ph
4+
from .. import _array_module as xp
5+
6+
7+
def test_assert_dtype():
8+
ph.assert_dtype("promoted_func", (xp.uint8, xp.int8), xp.int16)
9+
with raises(AssertionError):
10+
ph.assert_dtype("bad_func", (xp.uint8, xp.int8), xp.float32)
11+
ph.assert_dtype("bool_func", (xp.uint8, xp.int8), xp.bool, xp.bool)
12+
ph.assert_dtype("single_promoted_func", (xp.uint8,), xp.uint8)
13+
ph.assert_dtype("single_bool_func", (xp.uint8,), xp.bool, xp.bool)

array_api_tests/pytest_helpers.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
from inspect import getfullargspec
2+
from typing import Optional, Tuple
3+
4+
from . import dtype_helpers as dh
25
from . import function_stubs
6+
from .typing import DataType
7+
38

49
def raises(exceptions, function, message=''):
510
"""
@@ -33,3 +38,25 @@ def doesnt_raise(function, message=''):
3338

3439
def nargs(func_name):
3540
return len(getfullargspec(getattr(function_stubs, func_name)).args)
41+
42+
43+
def assert_dtype(
44+
func_name: str,
45+
in_dtypes: Tuple[DataType, ...],
46+
out_dtype: DataType,
47+
expected: Optional[DataType] = None,
48+
*,
49+
out_name: str = "out.dtype",
50+
):
51+
f_in_dtypes = dh.fmt_types(in_dtypes)
52+
f_out_dtype = dh.dtype_to_name[out_dtype]
53+
if expected is None:
54+
expected = dh.result_type(*in_dtypes)
55+
f_expected = dh.dtype_to_name[expected]
56+
msg = (
57+
f"{out_name}={f_out_dtype}, but should be {f_expected} "
58+
f"[{func_name}({f_in_dtypes})]"
59+
)
60+
assert out_dtype == expected, msg
61+
62+

array_api_tests/test_broadcasting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def test_broadcast_shapes_explicit_spec():
110110
@pytest.mark.parametrize('func_name', [i for i in
111111
elementwise_functions.__all__ if
112112
nargs(i) > 1])
113-
@given(shape1=shapes, shape2=shapes, data=data())
113+
@given(shape1=shapes(), shape2=shapes(), data=data())
114114
def test_broadcasting_hypothesis(func_name, shape1, shape2, data):
115115
# Internal consistency checks
116116
assert nargs(func_name) == 2

0 commit comments

Comments
 (0)