Skip to content

More promotion test cases #30

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Oct 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
7ae4b73
Rudimentary where type promotion test
honno Oct 18, 2021
cbc2f26
Rudimentary `test_result_type`, library-agnostic `dh.result_type()`
honno Oct 19, 2021
8ce99e3
Rudimentary `test_meshgrid`
honno Oct 19, 2021
705bfd1
Make `hh.shapes` a wrapper function, rudimentary `test_concat`
honno Oct 19, 2021
59364aa
Rudimentary `test_stack`
honno Oct 19, 2021
3324f44
Rudimentary `test_matmul`
honno Oct 19, 2021
a9d5d20
Rudimentary tensor/vec dot tests
honno Oct 19, 2021
508abcc
Alias `Param` type hint as `Tuple`
honno Oct 20, 2021
ee9ab19
Include func/op and param dtypes in type promotion error messages
honno Oct 20, 2021
65db28a
Replace array filtering with `xps.from_dtype()` kwargs
honno Oct 20, 2021
b82358e
Scrap generating kwargs for promotion tests
honno Oct 20, 2021
deebaef
Better minimisation behaviour for `multi_promotable_dtypes()`
honno Oct 20, 2021
c041228
Use setdefault instead of manualy keys check in `hh.shapes()`
honno Oct 21, 2021
bbce580
Remove faulty matmul type promotion test
honno Oct 21, 2021
0c71410
`mutually_promotable_dtypes()` can generate 2-or-more dtypes
honno Oct 21, 2021
f438298
Factor out `assert_dtype` and `fmt_types`, add `typing.py`
honno Oct 21, 2021
ed76b32
Construct test case name in `ph.assert_dtype()`
honno Oct 21, 2021
6206a55
Use `ph.assert_dtype` in `test_matmul` (proof of concept)
honno Oct 21, 2021
d6665c0
Type hint some hypothesis helpers
honno Oct 22, 2021
dc211d4
Default `expected` and `out_name` in `ph.assert_dtype()`
honno Oct 22, 2021
27ded9b
Accept single dtype in `dh.resut_type()` and thus `ph.assert_dtype()`
honno Oct 22, 2021
8acad7b
Remove `sanity_check()` in elementwise
honno Oct 22, 2021
2d91340
Comment that non-elementwise promotion tests are temporary
honno Oct 25, 2021
866dcd8
Support incomplete case names for `xfails.txt`
honno Oct 25, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .github/workflows/numpy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ jobs:
array_api_tests/test_signatures.py::test_function_keyword_only_args[__dlpack__]
# floor_divide has an issue related to https://github.com/data-apis/array-api/issues/264
array_api_tests/test_elementwise_functions.py::test_floor_divide
# mesgrid doesn't return all arrays as the promoted dtype
array_api_tests/test_type_promotion.py::test_meshgrid
# https://github.com/numpy/numpy/pull/20066#issuecomment-947056094
array_api_tests/test_type_promotion.py::test_where
# shape mismatches are not handled
array_api_tests/test_type_promotion.py::test_tensordot

EOF

Expand Down
42 changes: 37 additions & 5 deletions array_api_tests/dtype_helpers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from warnings import warn
from typing import NamedTuple
from functools import lru_cache
from typing import NamedTuple, Tuple, Union

from . import _array_module as xp
from ._array_module import _UndefinedStub
from .typing import DataType, ScalarType


__all__ = [
Expand All @@ -28,19 +30,20 @@
'binary_op_to_symbol',
'unary_op_to_symbol',
'inplace_op_to_symbol',
'fmt_types',
]


_int_names = ('int8', 'int16', 'int32', 'int64')
_uint_names = ('uint8', 'uint16', 'uint32', 'uint64')
_int_names = ('int8', 'int16', 'int32', 'int64')
_float_names = ('float32', 'float64')
_dtype_names = ('bool',) + _int_names + _uint_names + _float_names
_dtype_names = ('bool',) + _uint_names + _int_names + _float_names


int_dtypes = tuple(getattr(xp, name) for name in _int_names)
uint_dtypes = tuple(getattr(xp, name) for name in _uint_names)
int_dtypes = tuple(getattr(xp, name) for name in _int_names)
float_dtypes = tuple(getattr(xp, name) for name in _float_names)
all_int_dtypes = int_dtypes + uint_dtypes
all_int_dtypes = uint_dtypes + int_dtypes
numeric_dtypes = all_int_dtypes + float_dtypes
all_dtypes = (xp.bool,) + numeric_dtypes
bool_and_all_int_dtypes = (xp.bool,) + all_int_dtypes
Expand Down Expand Up @@ -148,6 +151,17 @@ class MinMax(NamedTuple):
}


def result_type(*dtypes: DataType):
if len(dtypes) == 0:
raise ValueError()
elif len(dtypes) == 1:
return dtypes[0]
result = promotion_table[dtypes[0], dtypes[1]]
for i in range(2, len(dtypes)):
result = promotion_table[result, dtypes[i]]
return result


dtype_nbits = {
**{d: 8 for d in [xp.int8, xp.uint8]},
**{d: 16 for d in [xp.int16, xp.uint16]},
Expand All @@ -163,6 +177,7 @@ class MinMax(NamedTuple):


func_in_dtypes = {
# elementwise
'abs': numeric_dtypes,
'acos': float_dtypes,
'acosh': float_dtypes,
Expand Down Expand Up @@ -219,10 +234,13 @@ class MinMax(NamedTuple):
'tan': float_dtypes,
'tanh': float_dtypes,
'trunc': numeric_dtypes,
# searching
'where': all_dtypes,
}


func_returns_bool = {
# elementwise
'abs': False,
'acos': False,
'acosh': False,
Expand Down Expand Up @@ -279,6 +297,8 @@ class MinMax(NamedTuple):
'tan': False,
'tanh': False,
'trunc': False,
# searching
'where': False,
}


Expand Down Expand Up @@ -352,3 +372,15 @@ class MinMax(NamedTuple):
inplace_op_to_symbol[iop] = f'{symbol}='
func_in_dtypes[iop] = func_in_dtypes[op]
func_returns_bool[iop] = func_returns_bool[op]


@lru_cache
def fmt_types(types: Tuple[Union[DataType, ScalarType], ...]) -> str:
f_types = []
for type_ in types:
try:
f_types.append(dtype_to_name[type_])
except KeyError:
# i.e. dtype is bool, int, or float
f_types.append(type_.__name__)
return ', '.join(f_types)
71 changes: 51 additions & 20 deletions array_api_tests/hypothesis_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@
from operator import mul
from math import sqrt
import itertools
from typing import Tuple
from typing import Tuple, Optional, List

from hypothesis import assume
from hypothesis.strategies import (lists, integers, sampled_from,
shared, floats, just, composite, one_of,
none, booleans)
from hypothesis.strategies._internal.strategies import SearchStrategy
none, booleans, SearchStrategy)

from .pytest_helpers import nargs
from .array_helpers import ndindex
from .typing import DataType, Shape
from . import dtype_helpers as dh
from ._array_module import (full, float32, float64, bool as bool_dtype,
_UndefinedStub, eye, broadcast_to)
Expand Down Expand Up @@ -50,7 +50,7 @@
_dtype_categories = [(xp.bool,), dh.uint_dtypes, dh.int_dtypes, dh.float_dtypes]
_sorted_dtypes = [d for category in _dtype_categories for d in category]

def _dtypes_sorter(dtype_pair):
def _dtypes_sorter(dtype_pair: Tuple[DataType, DataType]):
dtype1, dtype2 = dtype_pair
if dtype1 == dtype2:
return _sorted_dtypes.index(dtype1)
Expand All @@ -67,7 +67,7 @@ def _dtypes_sorter(dtype_pair):
key += 1
return key

promotable_dtypes = sorted(dh.promotion_table.keys(), key=_dtypes_sorter)
promotable_dtypes: List[Tuple[DataType, DataType]] = sorted(dh.promotion_table.keys(), key=_dtypes_sorter)

if FILTER_UNDEFINED_DTYPES:
promotable_dtypes = [
Expand All @@ -77,10 +77,34 @@ def _dtypes_sorter(dtype_pair):
]


def mutually_promotable_dtypes(dtype_objs=dh.all_dtypes):
return sampled_from(
[(i, j) for i, j in promotable_dtypes if i in dtype_objs and j in dtype_objs]
)
def mutually_promotable_dtypes(
max_size: Optional[int] = 2,
*,
dtypes: Tuple[DataType, ...] = dh.all_dtypes,
) -> SearchStrategy[Tuple[DataType, ...]]:
if max_size == 2:
return sampled_from(
[(i, j) for i, j in promotable_dtypes if i in dtypes and j in dtypes]
)
if isinstance(max_size, int) and max_size < 2:
raise ValueError(f'{max_size=} should be >=2')
strats = []
category_samples = {
category: [d for d in dtypes if d in category] for category in _dtype_categories
}
for samples in category_samples.values():
if len(samples) > 0:
strat = lists(sampled_from(samples), min_size=2, max_size=max_size)
strats.append(strat)
if len(category_samples[dh.uint_dtypes]) > 0 and len(category_samples[dh.int_dtypes]) > 0:
mixed_samples = category_samples[dh.uint_dtypes] + category_samples[dh.int_dtypes]
strat = lists(sampled_from(mixed_samples), min_size=2, max_size=max_size)
if xp.uint64 in mixed_samples:
strat = strat.filter(
lambda l: not (xp.uint64 in l and any(d in dh.int_dtypes for d in l))
)
return one_of(strats).map(tuple)


# shared() allows us to draw either the function or the function name and they
# will both correspond to the same function.
Expand Down Expand Up @@ -113,15 +137,19 @@ def tuples(elements, *, min_size=0, max_size=None, unique_by=None, unique=False)

# Use this to avoid memory errors with NumPy.
# See https://github.com/numpy/numpy/issues/15753
shapes = xps.array_shapes(min_dims=0, min_side=0).filter(
lambda shape: prod(i for i in shape if i) < MAX_ARRAY_SIZE
)
def shapes(**kw):
kw.setdefault('min_dims', 0)
kw.setdefault('min_side', 0)
return xps.array_shapes(**kw).filter(
lambda shape: prod(i for i in shape if i) < MAX_ARRAY_SIZE
)


one_d_shapes = xps.array_shapes(min_dims=1, max_dims=1, min_side=0, max_side=SQRT_MAX_ARRAY_SIZE)

# Matrix shapes assume stacks of matrices
@composite
def matrix_shapes(draw, stack_shapes=shapes):
def matrix_shapes(draw, stack_shapes=shapes()):
stack_shape = draw(stack_shapes)
mat_shape = draw(xps.array_shapes(max_dims=2, min_dims=2))
shape = stack_shape + mat_shape
Expand All @@ -135,9 +163,11 @@ def matrix_shapes(draw, stack_shapes=shapes):
elements=dict(allow_nan=False,
allow_infinity=False))

def mutually_broadcastable_shapes(num_shapes: int) -> SearchStrategy[Tuple[Tuple]]:
def mutually_broadcastable_shapes(
num_shapes: int, **kw
) -> SearchStrategy[Tuple[Shape, ...]]:
return (
xps.mutually_broadcastable_shapes(num_shapes)
xps.mutually_broadcastable_shapes(num_shapes, **kw)
.map(lambda BS: BS.input_shapes)
.filter(lambda shapes: all(
prod(i for i in s if i > 0) < MAX_ARRAY_SIZE for s in shapes
Expand All @@ -164,13 +194,13 @@ def positive_definite_matrices(draw, dtypes=xps.floating_dtypes()):
# using something like
# https://github.com/scikit-learn/scikit-learn/blob/844b4be24/sklearn/datasets/_samples_generator.py#L1351.
n = draw(integers(0))
shape = draw(shapes) + (n, n)
shape = draw(shapes()) + (n, n)
assume(prod(i for i in shape if i) < MAX_ARRAY_SIZE)
dtype = draw(dtypes)
return broadcast_to(eye(n, dtype=dtype), shape)

@composite
def invertible_matrices(draw, dtypes=xps.floating_dtypes(), stack_shapes=shapes):
def invertible_matrices(draw, dtypes=xps.floating_dtypes(), stack_shapes=shapes()):
# For now, just generate stacks of diagonal matrices.
n = draw(integers(0, SQRT_MAX_ARRAY_SIZE),)
stack_shape = draw(stack_shapes)
Expand Down Expand Up @@ -318,9 +348,10 @@ def multiaxis_indices(draw, shapes):


def two_mutual_arrays(
dtype_objs=dh.all_dtypes, two_shapes=two_mutually_broadcastable_shapes
):
mutual_dtypes = shared(mutually_promotable_dtypes(dtype_objs))
dtypes: Tuple[DataType, ...] = dh.all_dtypes,
two_shapes: SearchStrategy[Tuple[Shape, Shape]] = two_mutually_broadcastable_shapes,
) -> SearchStrategy:
mutual_dtypes = shared(mutually_promotable_dtypes(dtypes=dtypes))
mutual_shapes = shared(two_shapes)
arrays1 = xps.arrays(
dtype=mutual_dtypes.map(lambda pair: pair[0]),
Expand Down
7 changes: 3 additions & 4 deletions array_api_tests/meta/test_hypothesis_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,11 @@
from .. import dtype_helpers as dh
from .. import hypothesis_helpers as hh
from ..test_broadcasting import broadcast_shapes
from ..test_elementwise_functions import sanity_check

UNDEFINED_DTYPES = any(isinstance(d, _UndefinedStub) for d in dh.all_dtypes)
pytestmark = [pytest.mark.skipif(UNDEFINED_DTYPES, reason="undefined dtypes")]

@given(hh.mutually_promotable_dtypes([xp.float32, xp.float64]))
@given(hh.mutually_promotable_dtypes(dtypes=dh.float_dtypes))
def test_mutually_promotable_dtypes(pairs):
assert pairs in (
(xp.float32, xp.float32),
Expand All @@ -32,7 +31,7 @@ def valid_shape(shape) -> bool:
)


@given(hh.shapes)
@given(hh.shapes())
def test_shapes(shape):
assert valid_shape(shape)

Expand All @@ -52,7 +51,7 @@ def test_two_broadcastable_shapes(pair):

@given(*hh.two_mutual_arrays())
def test_two_mutual_arrays(x1, x2):
sanity_check(x1, x2)
assert (x1.dtype, x2.dtype) in dh.promotion_table.keys()
assert broadcast_shapes(x1.shape, x2.shape) in (x1.shape, x2.shape)


Expand Down
13 changes: 13 additions & 0 deletions array_api_tests/meta/test_pytest_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from pytest import raises

from .. import pytest_helpers as ph
from .. import _array_module as xp


def test_assert_dtype():
ph.assert_dtype("promoted_func", (xp.uint8, xp.int8), xp.int16)
with raises(AssertionError):
ph.assert_dtype("bad_func", (xp.uint8, xp.int8), xp.float32)
ph.assert_dtype("bool_func", (xp.uint8, xp.int8), xp.bool, xp.bool)
ph.assert_dtype("single_promoted_func", (xp.uint8,), xp.uint8)
ph.assert_dtype("single_bool_func", (xp.uint8,), xp.bool, xp.bool)
27 changes: 27 additions & 0 deletions array_api_tests/pytest_helpers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from inspect import getfullargspec
from typing import Optional, Tuple

from . import dtype_helpers as dh
from . import function_stubs
from .typing import DataType


def raises(exceptions, function, message=''):
"""
Expand Down Expand Up @@ -33,3 +38,25 @@ def doesnt_raise(function, message=''):

def nargs(func_name):
return len(getfullargspec(getattr(function_stubs, func_name)).args)


def assert_dtype(
func_name: str,
in_dtypes: Tuple[DataType, ...],
out_dtype: DataType,
expected: Optional[DataType] = None,
*,
out_name: str = "out.dtype",
):
f_in_dtypes = dh.fmt_types(in_dtypes)
f_out_dtype = dh.dtype_to_name[out_dtype]
if expected is None:
expected = dh.result_type(*in_dtypes)
f_expected = dh.dtype_to_name[expected]
msg = (
f"{out_name}={f_out_dtype}, but should be {f_expected} "
f"[{func_name}({f_in_dtypes})]"
)
assert out_dtype == expected, msg


2 changes: 1 addition & 1 deletion array_api_tests/test_broadcasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def test_broadcast_shapes_explicit_spec():
@pytest.mark.parametrize('func_name', [i for i in
elementwise_functions.__all__ if
nargs(i) > 1])
@given(shape1=shapes, shape2=shapes, data=data())
@given(shape1=shapes(), shape2=shapes(), data=data())
def test_broadcasting_hypothesis(func_name, shape1, shape2, data):
# Internal consistency checks
assert nargs(func_name) == 2
Expand Down
Loading