Skip to content

Operator tests #35

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Dec 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions array_api_tests/algos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
__all__ = ["broadcast_shapes"]


from .typing import Shape


# We use a custom exception to differentiate from potential bugs
class BroadcastError(ValueError):
pass


def _broadcast_shapes(shape1: Shape, shape2: Shape) -> Shape:
"""Broadcasts `shape1` and `shape2`"""
N1 = len(shape1)
N2 = len(shape2)
N = max(N1, N2)
shape = [None for _ in range(N)]
i = N - 1
while i >= 0:
n1 = N1 - N + i
if N1 - N + i >= 0:
d1 = shape1[n1]
else:
d1 = 1
n2 = N2 - N + i
if N2 - N + i >= 0:
d2 = shape2[n2]
else:
d2 = 1

if d1 == 1:
shape[i] = d2
elif d2 == 1:
shape[i] = d1
elif d1 == d2:
shape[i] = d1
else:
raise BroadcastError

i = i - 1

return tuple(shape)


def broadcast_shapes(*shapes: Shape):
if len(shapes) == 0:
raise ValueError("shapes=[] must be non-empty")
elif len(shapes) == 1:
return shapes[0]
result = _broadcast_shapes(shapes[0], shapes[1])
for i in range(2, len(shapes)):
result = _broadcast_shapes(result, shapes[i])
return result
17 changes: 14 additions & 3 deletions array_api_tests/dtype_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
'dtype_to_scalars',
'is_int_dtype',
'is_float_dtype',
'get_scalar_type',
'dtype_ranges',
'default_int',
'default_float',
Expand All @@ -30,6 +31,7 @@
'binary_op_to_symbol',
'unary_op_to_symbol',
'inplace_op_to_symbol',
'op_to_func',
'fmt_types',
]

Expand Down Expand Up @@ -74,6 +76,15 @@ def is_float_dtype(dtype):
return dtype in float_dtypes


def get_scalar_type(dtype: DataType) -> ScalarType:
if is_int_dtype(dtype):
return int
elif is_float_dtype(dtype):
return float
else:
return bool


class MinMax(NamedTuple):
min: int
max: int
Expand Down Expand Up @@ -332,7 +343,7 @@ def result_type(*dtypes: DataType):
}


_op_to_func = {
op_to_func = {
'__abs__': 'abs',
'__add__': 'add',
'__and__': 'bitwise_and',
Expand All @@ -341,14 +352,14 @@ def result_type(*dtypes: DataType):
'__ge__': 'greater_equal',
'__gt__': 'greater',
'__le__': 'less_equal',
'__lshift__': 'bitwise_left_shift',
'__lt__': 'less',
# '__matmul__': 'matmul', # TODO: support matmul
'__mod__': 'remainder',
'__mul__': 'multiply',
'__ne__': 'not_equal',
'__or__': 'bitwise_or',
'__pow__': 'pow',
'__lshift__': 'bitwise_left_shift',
'__rshift__': 'bitwise_right_shift',
'__sub__': 'subtract',
'__truediv__': 'divide',
Expand All @@ -359,7 +370,7 @@ def result_type(*dtypes: DataType):
}


for op, elwise_func in _op_to_func.items():
for op, elwise_func in op_to_func.items():
func_in_dtypes[op] = func_in_dtypes[elwise_func]
func_returns_bool[op] = func_returns_bool[elwise_func]

Expand Down
4 changes: 2 additions & 2 deletions array_api_tests/hypothesis_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from .array_helpers import ndindex
from .function_stubs import elementwise_functions
from .pytest_helpers import nargs
from .typing import DataType, Shape, Array
from .typing import Array, DataType, Shape
from .algos import broadcast_shapes

# Set this to True to not fail tests just because a dtype isn't implemented.
# If no compatible dtype is implemented for a given test, the test will fail
Expand Down Expand Up @@ -218,7 +219,6 @@ def two_broadcastable_shapes(draw):
This will produce two shapes (shape1, shape2) such that shape2 can be
broadcast to shape1.
"""
from .test_broadcasting import broadcast_shapes
shape1, shape2 = draw(two_mutually_broadcastable_shapes)
assume(broadcast_shapes(shape1, shape2) == shape1)
return (shape1, shape2)
Expand Down
35 changes: 35 additions & 0 deletions array_api_tests/meta/test_broadcasting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""
https://github.com/data-apis/array-api/blob/master/spec/API_specification/broadcasting.md
"""

import pytest

from ..algos import BroadcastError, _broadcast_shapes


@pytest.mark.parametrize(
"shape1, shape2, expected",
[
[(8, 1, 6, 1), (7, 1, 5), (8, 7, 6, 5)],
[(5, 4), (1,), (5, 4)],
[(5, 4), (4,), (5, 4)],
[(15, 3, 5), (15, 1, 5), (15, 3, 5)],
[(15, 3, 5), (3, 5), (15, 3, 5)],
[(15, 3, 5), (3, 1), (15, 3, 5)],
],
)
def test_broadcast_shapes(shape1, shape2, expected):
assert _broadcast_shapes(shape1, shape2) == expected


@pytest.mark.parametrize(
"shape1, shape2",
[
[(3,), (4,)], # dimension does not match
[(2, 1), (8, 4, 3)], # second dimension does not match
[(15, 3, 5), (15, 3)], # singleton dimensions can only be prepended
],
)
def test_broadcast_shapes_fails_on_bad_shapes(shape1, shape2):
with pytest.raises(BroadcastError):
_broadcast_shapes(shape1, shape2)
9 changes: 5 additions & 4 deletions array_api_tests/meta/test_hypothesis_helpers.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from math import prod

import pytest
from hypothesis import given, strategies as st, settings
from hypothesis import given, settings
from hypothesis import strategies as st

from .. import _array_module as xp
from .. import xps
from .._array_module import _UndefinedStub
from .. import array_helpers as ah
from .. import dtype_helpers as dh
from .. import hypothesis_helpers as hh
from ..test_broadcasting import broadcast_shapes
from .. import xps
from .._array_module import _UndefinedStub
from ..algos import broadcast_shapes

UNDEFINED_DTYPES = any(isinstance(d, _UndefinedStub) for d in dh.all_dtypes)
pytestmark = [pytest.mark.skipif(UNDEFINED_DTYPES, reason="undefined dtypes")]
Expand Down
37 changes: 33 additions & 4 deletions array_api_tests/pytest_helpers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from array_api_tests.algos import broadcast_shapes
import math
from inspect import getfullargspec
from typing import Any, Dict, Optional, Tuple, Union
Expand All @@ -17,6 +18,7 @@
"assert_default_float",
"assert_default_int",
"assert_shape",
"assert_result_shape",
"assert_fill",
]

Expand Down Expand Up @@ -69,15 +71,15 @@ def assert_dtype(
out_dtype: DataType,
expected: Optional[DataType] = None,
*,
out_name: str = "out.dtype",
repr_name: str = "out.dtype",
):
f_in_dtypes = dh.fmt_types(in_dtypes)
f_out_dtype = dh.dtype_to_name[out_dtype]
if expected is None:
expected = dh.result_type(*in_dtypes)
f_expected = dh.dtype_to_name[expected]
msg = (
f"{out_name}={f_out_dtype}, but should be {f_expected} "
f"{repr_name}={f_out_dtype}, but should be {f_expected} "
f"[{func_name}({f_in_dtypes})]"
)
assert out_dtype == expected, msg
Expand Down Expand Up @@ -114,14 +116,41 @@ def assert_default_int(func_name: str, dtype: DataType):


def assert_shape(
func_name: str, out_shape: Union[int, Shape], expected: Union[int, Shape], /, **kw
func_name: str,
out_shape: Union[int, Shape],
expected: Union[int, Shape],
/,
repr_name="out.shape",
**kw,
):
if isinstance(out_shape, int):
out_shape = (out_shape,)
if isinstance(expected, int):
expected = (expected,)
msg = (
f"out.shape={out_shape}, but should be {expected} [{func_name}({fmt_kw(kw)})]"
f"{repr_name}={out_shape}, but should be {expected} [{func_name}({fmt_kw(kw)})]"
)
assert out_shape == expected, msg


def assert_result_shape(
func_name: str,
in_shapes: Tuple[Shape],
out_shape: Shape,
/,
expected: Optional[Shape] = None,
*,
repr_name="out.shape",
**kw,
):
if expected is None:
expected = broadcast_shapes(*in_shapes)
f_in_shapes = " . ".join(str(s) for s in in_shapes)
f_sig = f" {f_in_shapes} "
if kw:
f_sig += f", {fmt_kw(kw)}"
msg = (
f"{repr_name}={out_shape}, but should be {expected} [{func_name}({f_sig})]"
)
assert out_shape == expected, msg

Expand Down
Loading