Skip to content

Commit 9815b8d

Browse files
committed
TST: add a test that wrapping preserves a view/copy semantics for unary functions
If a bare library returns a copy, so does the wrapped library; if the bare library returns a view, so does the wrapped library.
1 parent 8e3ab3e commit 9815b8d

File tree

1 file changed

+66
-0
lines changed

1 file changed

+66
-0
lines changed

tests/test_copies_or_views.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
"""
2+
A collection of tests to make sure that wrapped namespaces agree with the bare ones
3+
on whether to return a view or a copy of inputs.
4+
"""
5+
import pytest
6+
from ._helpers import import_
7+
8+
9+
LIB_NAMES = ['numpy', 'cupy', 'torch', 'dask.array', 'array_api_strict']
10+
11+
FUNC_INPUTS = [
12+
# func_name, arr_input, dtype, scalar_value
13+
('abs', [1, 2], 'int8', 3),
14+
('abs', [1, 2], 'float32', 3.),
15+
('ceil', [1, 2], 'int8', 3),
16+
('clip', [1, 2], 'int8', 3),
17+
('conj', [1, 2], 'int8', 3),
18+
('floor', [1, 2], 'int8', 3),
19+
('imag', [1j, 2j], 'complex64', 3),
20+
('positive', [1, 2], 'int8', 3),
21+
('real', [1., 2.], 'float32', 3.),
22+
('round', [1, 2], 'int8', 3),
23+
('sign', [0, 0], 'float32', 3),
24+
('trunc', [1, 2], 'int8', 3),
25+
('trunc', [1, 2], 'float32', 3),
26+
]
27+
28+
29+
def ensure_unary(func, arr):
30+
"""Make a trivial unary function from func."""
31+
if func.__name__ == 'clip':
32+
return lambda x: func(x, arr[0], arr[1])
33+
return func
34+
35+
36+
def is_view(func, a, value):
37+
"""Apply `func`, mutate the output; does the input change?"""
38+
b = func(a)
39+
b[0] = value
40+
return a[0] == value
41+
42+
43+
@pytest.mark.parametrize('xp_name', LIB_NAMES)
44+
@pytest.mark.parametrize('inputs', FUNC_INPUTS, ids=[inp[0] for inp in FUNC_INPUTS])
45+
def test_view_or_copy(inputs, xp_name):
46+
bare_xp = import_(xp_name, wrapper=False)
47+
wrapped_xp = import_(xp_name, wrapper=True)
48+
49+
func_name, arr_input, dtype_str, value = inputs
50+
dtype = getattr(bare_xp, dtype_str)
51+
52+
bare_func = getattr(bare_xp, func_name)
53+
bare_func = ensure_unary(bare_func, arr_input)
54+
55+
wrapped_func = getattr(wrapped_xp, func_name)
56+
wrapped_func = ensure_unary(wrapped_func, arr_input)
57+
58+
# bare namespace: mutate the output, does the input change?
59+
a = bare_xp.asarray(arr_input, dtype=dtype)
60+
is_view_bare = is_view(bare_func, a, value)
61+
62+
# wrapped namespace: mutate the output, does the input change?
63+
a1 = wrapped_xp.asarray(arr_input, dtype=dtype)
64+
is_view_wrapped = is_view(wrapped_func, a1, value)
65+
66+
assert is_view_bare == is_view_wrapped

0 commit comments

Comments
 (0)