Skip to content

Commit 035e3f3

Browse files
authored
Merge pull request #32 from honno/creation-refactor
Refactor assertions in `test_creation.py`
2 parents 797537e + ca2ef81 commit 035e3f3

File tree

8 files changed

+678
-365
lines changed

8 files changed

+678
-365
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,24 @@
1+
import itertools
12
from functools import reduce
2-
from operator import mul
33
from math import sqrt
4-
import itertools
5-
from typing import Tuple, Optional, List
4+
from operator import mul
5+
from typing import Any, List, NamedTuple, Optional, Tuple
66

77
from hypothesis import assume
8-
from hypothesis.strategies import (lists, integers, sampled_from,
9-
shared, floats, just, composite, one_of,
10-
none, booleans, SearchStrategy)
8+
from hypothesis.strategies import (SearchStrategy, booleans, composite, floats,
9+
integers, just, lists, none, one_of,
10+
sampled_from, shared)
1111

12-
from .pytest_helpers import nargs
13-
from .array_helpers import ndindex
14-
from .typing import DataType, Shape
15-
from . import dtype_helpers as dh
16-
from ._array_module import (full, float32, float64, bool as bool_dtype,
17-
_UndefinedStub, eye, broadcast_to)
1812
from . import _array_module as xp
13+
from . import dtype_helpers as dh
1914
from . import xps
20-
15+
from ._array_module import _UndefinedStub
16+
from ._array_module import bool as bool_dtype
17+
from ._array_module import broadcast_to, eye, float32, float64, full
18+
from .array_helpers import ndindex
2119
from .function_stubs import elementwise_functions
22-
20+
from .pytest_helpers import nargs
21+
from .typing import DataType, Shape
2322

2423
# Set this to True to not fail tests just because a dtype isn't implemented.
2524
# If no compatible dtype is implemented for a given test, the test will fail
@@ -382,3 +381,24 @@ def test_f(x, kw):
382381
if draw(booleans()):
383382
result[k] = draw(strat)
384383
return result
384+
385+
386+
class KVD(NamedTuple):
387+
keyword: str
388+
value: Any
389+
default: Any
390+
391+
392+
@composite
393+
def specified_kwargs(draw, *keys_values_defaults: KVD):
394+
"""Generates valid kwargs given expected defaults.
395+
396+
When we can't realistically use hh.kwargs() and thus test whether xp infact
397+
defaults correctly, this strategy lets us remove generated arguments if they
398+
are of the default value anyway.
399+
"""
400+
kw = {}
401+
for keyword, value, default in keys_values_defaults:
402+
if value is not default or draw(booleans()):
403+
kw[keyword] = value
404+
return kw

array_api_tests/meta/test_hypothesis_helpers.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from hypothesis import given, strategies as st, settings
55

66
from .. import _array_module as xp
7+
from .. import xps
78
from .._array_module import _UndefinedStub
89
from .. import array_helpers as ah
910
from .. import dtype_helpers as dh
@@ -76,6 +77,37 @@ def run(kw):
7677
assert len(c_results) > 0
7778
assert all(isinstance(kw["c"], str) for kw in c_results)
7879

80+
81+
def test_specified_kwargs():
82+
results = []
83+
84+
@given(n=st.integers(0, 10), d=st.none() | xps.scalar_dtypes(), data=st.data())
85+
@settings(max_examples=100)
86+
def run(n, d, data):
87+
kw = data.draw(
88+
hh.specified_kwargs(
89+
hh.KVD("n", n, 0),
90+
hh.KVD("d", d, None),
91+
),
92+
label="kw",
93+
)
94+
results.append(kw)
95+
run()
96+
97+
assert all(isinstance(kw, dict) for kw in results)
98+
99+
assert any(len(kw) == 0 for kw in results)
100+
101+
assert any("n" not in kw.keys() for kw in results)
102+
assert any("n" in kw.keys() and kw["n"] == 0 for kw in results)
103+
assert any("n" in kw.keys() and kw["n"] != 0 for kw in results)
104+
105+
assert any("d" not in kw.keys() for kw in results)
106+
assert any("d" in kw.keys() and kw["d"] is None for kw in results)
107+
assert any("d" in kw.keys() and kw["d"] is xp.float64 for kw in results)
108+
109+
110+
79111
@given(m=hh.symmetric_matrices(hh.shared_floating_dtypes,
80112
finite=st.shared(st.booleans(), key='finite')),
81113
dtype=hh.shared_floating_dtypes,

array_api_tests/meta/test_utils.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,26 @@
1+
import pytest
2+
13
from ..test_signatures import extension_module
4+
from ..test_creation_functions import frange
25

36

47
def test_extension_module_is_extension():
5-
assert extension_module('linalg')
8+
assert extension_module("linalg")
69

710

811
def test_extension_func_is_not_extension():
9-
assert not extension_module('linalg.cross')
12+
assert not extension_module("linalg.cross")
13+
14+
15+
@pytest.mark.parametrize(
16+
"r, size, elements",
17+
[
18+
(frange(0, 1, 1), 1, [0]),
19+
(frange(1, 0, -1), 1, [1]),
20+
(frange(0, 1, -1), 0, []),
21+
(frange(0, 1, 2), 1, [0]),
22+
],
23+
)
24+
def test_frange(r, size, elements):
25+
assert len(r) == size
26+
assert list(r) == elements

array_api_tests/pytest_helpers.py

Lines changed: 79 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,27 @@
1+
import math
12
from inspect import getfullargspec
2-
from typing import Optional, Tuple
3+
from typing import Any, Dict, Optional, Tuple, Union
34

5+
from . import array_helpers as ah
46
from . import dtype_helpers as dh
57
from . import function_stubs
6-
from .typing import DataType
8+
from .typing import Array, DataType, Scalar, Shape
79

10+
__all__ = [
11+
"raises",
12+
"doesnt_raise",
13+
"nargs",
14+
"fmt_kw",
15+
"assert_dtype",
16+
"assert_kw_dtype",
17+
"assert_default_float",
18+
"assert_default_int",
19+
"assert_shape",
20+
"assert_fill",
21+
]
822

9-
def raises(exceptions, function, message=''):
23+
24+
def raises(exceptions, function, message=""):
1025
"""
1126
Like pytest.raises() except it allows custom error messages
1227
"""
@@ -16,11 +31,14 @@ def raises(exceptions, function, message=''):
1631
return
1732
except Exception as e:
1833
if message:
19-
raise AssertionError(f"Unexpected exception {e!r} (expected {exceptions}): {message}")
34+
raise AssertionError(
35+
f"Unexpected exception {e!r} (expected {exceptions}): {message}"
36+
)
2037
raise AssertionError(f"Unexpected exception {e!r} (expected {exceptions})")
2138
raise AssertionError(message)
2239

23-
def doesnt_raise(function, message=''):
40+
41+
def doesnt_raise(function, message=""):
2442
"""
2543
The inverse of raises().
2644
@@ -36,10 +54,15 @@ def doesnt_raise(function, message=''):
3654
raise AssertionError(f"Unexpected exception {e!r}: {message}")
3755
raise AssertionError(f"Unexpected exception {e!r}")
3856

57+
3958
def nargs(func_name):
4059
return len(getfullargspec(getattr(function_stubs, func_name)).args)
4160

4261

62+
def fmt_kw(kw: Dict[str, Any]) -> str:
63+
return ", ".join(f"{k}={v}" for k, v in kw.items())
64+
65+
4366
def assert_dtype(
4467
func_name: str,
4568
in_dtypes: Tuple[DataType, ...],
@@ -60,3 +83,54 @@ def assert_dtype(
6083
assert out_dtype == expected, msg
6184

6285

86+
def assert_kw_dtype(func_name: str, kw_dtype: DataType, out_dtype: DataType):
87+
f_kw_dtype = dh.dtype_to_name[kw_dtype]
88+
f_out_dtype = dh.dtype_to_name[out_dtype]
89+
msg = (
90+
f"out.dtype={f_out_dtype}, but should be {f_kw_dtype} "
91+
f"[{func_name}(dtype={f_kw_dtype})]"
92+
)
93+
assert out_dtype == kw_dtype, msg
94+
95+
96+
def assert_default_float(func_name: str, dtype: DataType):
97+
f_dtype = dh.dtype_to_name[dtype]
98+
f_default = dh.dtype_to_name[dh.default_float]
99+
msg = (
100+
f"out.dtype={f_dtype}, should be default "
101+
f"floating-point dtype {f_default} [{func_name}()]"
102+
)
103+
assert dtype == dh.default_float, msg
104+
105+
106+
def assert_default_int(func_name: str, dtype: DataType):
107+
f_dtype = dh.dtype_to_name[dtype]
108+
f_default = dh.dtype_to_name[dh.default_int]
109+
msg = (
110+
f"out.dtype={f_dtype}, should be default "
111+
f"integer dtype {f_default} [{func_name}()]"
112+
)
113+
assert dtype == dh.default_int, msg
114+
115+
116+
def assert_shape(
117+
func_name: str, out_shape: Union[int, Shape], expected: Union[int, Shape], /, **kw
118+
):
119+
if isinstance(out_shape, int):
120+
out_shape = (out_shape,)
121+
if isinstance(expected, int):
122+
expected = (expected,)
123+
msg = (
124+
f"out.shape={out_shape}, but should be {expected} [{func_name}({fmt_kw(kw)})]"
125+
)
126+
assert out_shape == expected, msg
127+
128+
129+
def assert_fill(
130+
func_name: str, fill_value: Scalar, dtype: DataType, out: Array, /, **kw
131+
):
132+
msg = f"out not filled with {fill_value} [{func_name}({fmt_kw(kw)})]\n{out=}"
133+
if math.isnan(fill_value):
134+
assert ah.all(ah.isnan(out)), msg
135+
else:
136+
assert ah.all(ah.equal(out, ah.asarray(fill_value, dtype=dtype))), msg

0 commit comments

Comments
 (0)