Skip to content

Commit d6a3379

Browse files
committed
Merge branch 'master' into type-promotion-refactor
2 parents 19cfff7 + 873eeff commit d6a3379

10 files changed

+1037
-549
lines changed

array_api_tests/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from hypothesis.extra.array_api import make_strategies_namespace
2+
3+
from . import _array_module as xp
4+
5+
6+
xps = make_strategies_namespace(xp)

array_api_tests/function_stubs/creation_functions.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
from ._types import (List, NestedSequence, Optional, SupportsBufferProtocol, SupportsDLPack, Tuple,
1414
Union, array, device, dtype)
15-
from collections.abc import Sequence
1615

1716
def arange(start: Union[int, float], /, stop: Optional[Union[int, float]] = None, step: Union[int, float] = 1, *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array:
1817
pass
@@ -26,7 +25,7 @@ def empty(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[dtype] = None,
2625
def empty_like(x: array, /, *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array:
2726
pass
2827

29-
def eye(n_rows: int, n_cols: Optional[int] = None, /, *, k: Optional[int] = 0, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array:
28+
def eye(n_rows: int, n_cols: Optional[int] = None, /, *, k: int = 0, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array:
3029
pass
3130

3231
def from_dlpack(x: object, /) -> array:
@@ -41,7 +40,7 @@ def full_like(x: array, /, fill_value: Union[int, float], *, dtype: Optional[dty
4140
def linspace(start: Union[int, float], stop: Union[int, float], /, num: int, *, dtype: Optional[dtype] = None, device: Optional[device] = None, endpoint: bool = True) -> array:
4241
pass
4342

44-
def meshgrid(*arrays: Sequence[array], indexing: str = 'xy') -> List[array, ...]:
43+
def meshgrid(*arrays: array, indexing: str = 'xy') -> List[array, ...]:
4544
pass
4645

4746
def ones(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array:

array_api_tests/function_stubs/data_type_functions.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,8 @@
1111
from __future__ import annotations
1212

1313
from ._types import List, Tuple, Union, array, dtype, finfo_object, iinfo_object
14-
from collections.abc import Sequence
1514

16-
def broadcast_arrays(*arrays: Sequence[array]) -> List[array]:
15+
def broadcast_arrays(*arrays: array) -> List[array]:
1716
pass
1817

1918
def broadcast_to(x: array, /, shape: Tuple[int, ...]) -> array:
@@ -28,7 +27,7 @@ def finfo(type: Union[dtype, array], /) -> finfo_object:
2827
def iinfo(type: Union[dtype, array], /) -> iinfo_object:
2928
pass
3029

31-
def result_type(*arrays_and_dtypes: Sequence[Union[array, dtype]]) -> dtype:
30+
def result_type(*arrays_and_dtypes: Union[array, dtype]) -> dtype:
3231
pass
3332

3433
__all__ = ['broadcast_arrays', 'broadcast_to', 'can_cast', 'finfo', 'iinfo', 'result_type']

array_api_tests/function_stubs/linalg.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def solve(x1: array, x2: array, /) -> array:
6868
def svd(x: array, /, *, full_matrices: bool = True) -> Tuple[array, array, array]:
6969
pass
7070

71-
def svdvals(x: array, /) -> Union[array, Tuple[array, ...]]:
71+
def svdvals(x: array, /) -> array:
7272
pass
7373

7474
def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> array:
@@ -77,10 +77,10 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int],
7777
def trace(x: array, /, *, offset: int = 0) -> array:
7878
pass
7979

80-
def vecdot(x1: array, x2: array, /, *, axis: Optional[int] = None) -> array:
80+
def vecdot(x1: array, x2: array, /, *, axis: int = None) -> array:
8181
pass
8282

83-
def vector_norm(x: array, /, *, axis: Optional[Union[int, Tuple[int, int]]] = None, keepdims: bool = False, ord: Optional[Union[int, float, Literal[inf, -inf]]] = 2) -> array:
83+
def vector_norm(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Union[int, float, Literal[inf, -inf]] = 2) -> array:
8484
pass
8585

8686
__all__ = ['cholesky', 'cross', 'det', 'diagonal', 'eigh', 'eigvalsh', 'inv', 'matmul', 'matrix_norm', 'matrix_power', 'matrix_rank', 'matrix_transpose', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'svdvals', 'tensordot', 'trace', 'vecdot', 'vector_norm']

array_api_tests/function_stubs/linear_algebra_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from __future__ import annotations
1212

13-
from ._types import Optional, Tuple, Union, array
13+
from ._types import Tuple, Union, array
1414
from collections.abc import Sequence
1515

1616
def matmul(x1: array, x2: array, /) -> array:
@@ -22,7 +22,7 @@ def matrix_transpose(x: array, /) -> array:
2222
def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> array:
2323
pass
2424

25-
def vecdot(x1: array, x2: array, /, *, axis: Optional[int] = None) -> array:
25+
def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array:
2626
pass
2727

2828
__all__ = ['matmul', 'matrix_transpose', 'tensordot', 'vecdot']

array_api_tests/hypothesis_helpers.py

Lines changed: 91 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,28 @@
11
from functools import reduce
22
from operator import mul
33
from math import sqrt
4+
import itertools
45

56
from hypothesis import assume
67
from hypothesis.strategies import (lists, integers, sampled_from,
78
shared, floats, just, composite, one_of,
89
none, booleans)
9-
from hypothesis.extra.array_api import make_strategies_namespace
1010

1111
from .pytest_helpers import nargs
1212
from .array_helpers import (dtype_ranges, integer_dtype_objects,
1313
floating_dtype_objects, numeric_dtype_objects,
1414
boolean_dtype_objects,
15-
integer_or_boolean_dtype_objects, dtype_objects)
16-
from ._array_module import full, float32, float64, bool as bool_dtype, _UndefinedStub
15+
integer_or_boolean_dtype_objects, dtype_objects,
16+
ndindex)
1717
from .dtype_helpers import promotion_table
18-
from . import _array_module
18+
from ._array_module import (full, float32, float64, bool as bool_dtype,
19+
_UndefinedStub, eye, broadcast_to)
1920
from . import _array_module as xp
21+
from . import xps
2022

2123
from .function_stubs import elementwise_functions
2224

2325

24-
xps = make_strategies_namespace(xp)
25-
26-
2726
# Set this to True to not fail tests just because a dtype isn't implemented.
2827
# If no compatible dtype is implemented for a given test, the test will fail
2928
# with a hypothesis health check error. Note that this functionality will not
@@ -48,6 +47,7 @@
4847
dtypes = dtypes.filter(lambda x: not isinstance(x, _UndefinedStub))
4948

5049
shared_dtypes = shared(dtypes, key="dtype")
50+
shared_floating_dtypes = shared(floating_dtypes, key="dtype")
5151

5252

5353
sorted_table = sorted(promotion_table)
@@ -72,10 +72,6 @@ def mutually_promotable_dtypes(dtype_objects=dtype_objects):
7272
[(i, j) for i, j in sorted_table if i in dtype_objects and j in dtype_objects]
7373
)
7474

75-
shared_mutually_promotable_dtype_pairs = shared(
76-
mutually_promotable_dtypes(), key="mutually_promotable_dtype_pair"
77-
)
78-
7975
# shared() allows us to draw either the function or the function name and they
8076
# will both correspond to the same function.
8177

@@ -86,10 +82,10 @@ def mutually_promotable_dtypes(dtype_objects=dtype_objects):
8682
lambda func_name: nargs(func_name) > 1)
8783

8884
elementwise_function_objects = elementwise_functions_names.map(
89-
lambda i: getattr(_array_module, i))
85+
lambda i: getattr(xp, i))
9086
array_functions = elementwise_function_objects
9187
multiarg_array_functions = multiarg_array_functions_names.map(
92-
lambda i: getattr(_array_module, i))
88+
lambda i: getattr(xp, i))
9389

9490
# Limit the total size of an array shape
9591
MAX_ARRAY_SIZE = 10000
@@ -111,25 +107,72 @@ def tuples(elements, *, min_size=0, max_size=None, unique_by=None, unique=False)
111107
lambda shape: prod(i for i in shape if i) < MAX_ARRAY_SIZE
112108
)
113109

110+
# Matrix shapes assume stacks of matrices
111+
matrix_shapes = xps.array_shapes(min_dims=2, min_side=1).filter(
112+
lambda shape: prod(i for i in shape if i) < MAX_ARRAY_SIZE
113+
)
114+
115+
square_matrix_shapes = matrix_shapes.filter(lambda shape: shape[-1] == shape[-2])
116+
114117
two_mutually_broadcastable_shapes = xps.mutually_broadcastable_shapes(num_shapes=2)\
115118
.map(lambda S: S.input_shapes)\
116119
.filter(lambda S: all(prod(i for i in shape if i) < MAX_ARRAY_SIZE for shape in S))
117120

121+
# Note: This should become hermitian_matrices when complex dtypes are added
122+
@composite
123+
def symmetric_matrices(draw, dtypes=xps.floating_dtypes(), finite=True):
124+
shape = draw(square_matrix_shapes)
125+
dtype = draw(dtypes)
126+
elements = {'allow_nan': False, 'allow_infinity': False} if finite else None
127+
a = draw(xps.arrays(dtype=dtype, shape=shape, elements=elements))
128+
upper = xp.triu(a)
129+
lower = xp.triu(a, k=1).mT
130+
return upper + lower
131+
132+
@composite
133+
def positive_definite_matrices(draw, dtypes=xps.floating_dtypes()):
134+
# For now just generate stacks of identity matrices
135+
# TODO: Generate arbitrary positive definite matrices, for instance, by
136+
# using something like
137+
# https://github.com/scikit-learn/scikit-learn/blob/844b4be24/sklearn/datasets/_samples_generator.py#L1351.
138+
n = draw(integers(0))
139+
shape = draw(shapes) + (n, n)
140+
assume(prod(i for i in shape if i) < MAX_ARRAY_SIZE)
141+
dtype = draw(dtypes)
142+
return broadcast_to(eye(n, dtype=dtype), shape)
143+
144+
@composite
145+
def invertible_matrices(draw, dtypes=xps.floating_dtypes()):
146+
# For now, just generate stacks of diagonal matrices.
147+
n = draw(integers(0, SQRT_MAX_ARRAY_SIZE),)
148+
stack_shape = draw(shapes)
149+
shape = stack_shape + (n, n)
150+
d = draw(xps.arrays(dtypes, shape=n*prod(stack_shape),
151+
elements=dict(allow_nan=False, allow_infinity=False)))
152+
# Functions that require invertible matrices may do anything when it is
153+
# singular, including raising an exception, so we make sure the diagonals
154+
# are sufficiently nonzero to avoid any numerical issues.
155+
assume(xp.all(xp.abs(d) > 0.5))
156+
157+
a = xp.zeros(shape)
158+
for j, (idx, i) in enumerate(itertools.product(ndindex(stack_shape), range(n))):
159+
a[idx + (i, i)] = d[j]
160+
return a
161+
118162
@composite
119163
def two_broadcastable_shapes(draw):
120164
"""
121165
This will produce two shapes (shape1, shape2) such that shape2 can be
122166
broadcast to shape1.
123167
"""
124168
from .test_broadcasting import broadcast_shapes
125-
shape1, shape2 = draw(two_mutually_broadcastable_shapes)
169+
shape1, shape2 = draw(two_mutually_broadcastable_shapes)
126170
assume(broadcast_shapes(shape1, shape2) == shape1)
127171
return (shape1, shape2)
128172

129173
sizes = integers(0, MAX_ARRAY_SIZE)
130174
sqrt_sizes = integers(0, SQRT_MAX_ARRAY_SIZE)
131175

132-
# TODO: Generate general arrays here, rather than just scalars.
133176
numeric_arrays = xps.arrays(
134177
dtype=shared(xps.floating_dtypes(), key='dtypes'),
135178
shape=shared(xps.array_shapes(), key='shapes'),
@@ -233,25 +276,46 @@ def multiaxis_indices(draw, shapes):
233276

234277
# Avoid using 'in', which might do == on an array.
235278
res_has_ellipsis = any(i is ... for i in res)
236-
if n_entries == len(shape) and not res_has_ellipsis:
237-
# note("Adding extra")
238-
extra = draw(lists(one_of(integer_indices(sizes), slices(sizes)), min_size=0, max_size=3))
239-
res += extra
279+
if not res_has_ellipsis:
280+
if n_entries < len(shape):
281+
# The spec requires either an ellipsis or exactly as many indices
282+
# as dimensions.
283+
assume(False)
284+
elif n_entries == len(shape):
285+
# note("Adding extra")
286+
extra = draw(lists(one_of(integer_indices(sizes), slices(sizes)), min_size=0, max_size=3))
287+
res += extra
240288
return tuple(res)
241289

242290

243-
shared_arrays1 = xps.arrays(
244-
dtype=shared_mutually_promotable_dtype_pairs.map(lambda pair: pair[0]),
245-
shape=shared(two_mutually_broadcastable_shapes, key="shape_pair").map(lambda pair: pair[0]),
246-
)
247-
shared_arrays2 = xps.arrays(
248-
dtype=shared_mutually_promotable_dtype_pairs.map(lambda pair: pair[1]),
249-
shape=shared(two_mutually_broadcastable_shapes, key="shape_pair").map(lambda pair: pair[1]),
250-
)
291+
def two_mutual_arrays(dtype_objects=dtype_objects):
292+
mutual_dtypes = shared(mutually_promotable_dtypes(dtype_objects))
293+
mutual_shapes = shared(two_mutually_broadcastable_shapes)
294+
arrays1 = xps.arrays(
295+
dtype=mutual_dtypes.map(lambda pair: pair[0]),
296+
shape=mutual_shapes.map(lambda pair: pair[0]),
297+
)
298+
arrays2 = xps.arrays(
299+
dtype=mutual_dtypes.map(lambda pair: pair[1]),
300+
shape=mutual_shapes.map(lambda pair: pair[1]),
301+
)
302+
return arrays1, arrays2
251303

252304

253305
@composite
254306
def kwargs(draw, **kw):
307+
"""
308+
Strategy for keyword arguments
309+
310+
For a signature like f(x, /, dtype=None, val=1) use
311+
312+
@given(x=arrays(), kw=kwargs(a=none() | dtypes, val=integers()))
313+
def test_f(x, kw):
314+
res = f(x, **kw)
315+
316+
kw may omit the keyword argument, meaning the default for f will be used.
317+
318+
"""
255319
result = {}
256320
for k, strat in kw.items():
257321
if draw(booleans()):

array_api_tests/meta_tests/test_hypothesis_helpers.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
11
from math import prod
22

33
import pytest
4-
from hypothesis import given, strategies as st
4+
from hypothesis import given, strategies as st, settings
55

66
from .. import _array_module as xp
77
from .._array_module import _UndefinedStub
88
from .. import array_helpers as ah
99
from .. import hypothesis_helpers as hh
10+
from ..test_broadcasting import broadcast_shapes
11+
from ..test_elementwise_functions import sanity_check
1012

1113
UNDEFINED_DTYPES = any(isinstance(d, _UndefinedStub) for d in ah.dtype_objects)
1214
pytestmark = [pytest.mark.skipif(UNDEFINED_DTYPES, reason="undefined dtypes")]
1315

14-
1516
@given(hh.mutually_promotable_dtypes([xp.float32, xp.float64]))
1617
def test_mutually_promotable_dtypes(pairs):
1718
assert pairs in (
@@ -45,20 +46,24 @@ def test_two_mutually_broadcastable_shapes(pair):
4546
def test_two_broadcastable_shapes(pair):
4647
for shape in pair:
4748
assert valid_shape(shape)
49+
assert broadcast_shapes(pair[0], pair[1]) == pair[0]
4850

49-
from ..test_broadcasting import broadcast_shapes
5051

51-
assert broadcast_shapes(pair[0], pair[1]) == pair[0]
52+
@given(*hh.two_mutual_arrays())
53+
def test_two_mutual_arrays(x1, x2):
54+
sanity_check(x1, x2)
55+
assert broadcast_shapes(x1.shape, x2.shape) in (x1.shape, x2.shape)
5256

5357

5458
def test_kwargs():
5559
results = []
5660

5761
@given(hh.kwargs(n=st.integers(0, 10), c=st.from_regex("[a-f]")))
62+
@settings(max_examples=100)
5863
def run(kw):
5964
results.append(kw)
60-
6165
run()
66+
6267
assert all(isinstance(kw, dict) for kw in results)
6368
for size in [0, 1, 2]:
6469
assert any(len(kw) == size for kw in results)
@@ -70,3 +75,21 @@ def run(kw):
7075
c_results = [kw for kw in results if "c" in kw]
7176
assert len(c_results) > 0
7277
assert all(isinstance(kw["c"], str) for kw in c_results)
78+
79+
@given(m=hh.symmetric_matrices(hh.shared_floating_dtypes,
80+
finite=st.shared(st.booleans(), key='finite')),
81+
dtype=hh.shared_floating_dtypes,
82+
finite=st.shared(st.booleans(), key='finite'))
83+
def test_symmetric_matrices(m, dtype, finite):
84+
assert m.dtype == dtype
85+
# TODO: This part of this test should be part of the .mT test
86+
ah.assert_exactly_equal(m, m.mT)
87+
88+
if finite:
89+
ah.assert_finite(m)
90+
91+
@given(m=hh.positive_definite_matrices(hh.shared_floating_dtypes),
92+
dtype=hh.shared_floating_dtypes)
93+
def test_positive_definite_matrices(m, dtype):
94+
assert m.dtype == dtype
95+
# TODO: Test that it actually is positive definite

array_api_tests/test_creation_functions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
assert_exactly_equal, isintegral, is_float_dtype)
99
from .hypothesis_helpers import (numeric_dtypes, dtypes, MAX_ARRAY_SIZE,
1010
shapes, sizes, sqrt_sizes, shared_dtypes,
11-
scalars, xps, kwargs)
11+
scalars, kwargs)
12+
from . import xps
1213

1314
from hypothesis import assume, given
1415
from hypothesis.strategies import integers, floats, one_of, none, booleans, just, shared, composite

0 commit comments

Comments
 (0)