|
| 1 | +""" |
| 2 | +A collection of tests to make sure that wrapped namespaces agree with the bare ones |
| 3 | +on whether to return a view or a copy of inputs. |
| 4 | +""" |
| 5 | +import pytest |
| 6 | +from ._helpers import import_ |
| 7 | + |
| 8 | + |
| 9 | +LIB_NAMES = ['numpy', 'cupy', 'torch', 'dask.array', 'array_api_strict'] |
| 10 | + |
| 11 | +FUNC_INPUTS = [ |
| 12 | + # func_name, arr_input, dtype, scalar_value |
| 13 | + ('abs', [1, 2], 'int8', 3), |
| 14 | + ('abs', [1, 2], 'float32', 3.), |
| 15 | + ('ceil', [1, 2], 'int8', 3), |
| 16 | + ('clip', [1, 2], 'int8', 3), |
| 17 | + ('conj', [1, 2], 'int8', 3), |
| 18 | + ('floor', [1, 2], 'int8', 3), |
| 19 | + ('imag', [1j, 2j], 'complex64', 3), |
| 20 | + ('positive', [1, 2], 'int8', 3), |
| 21 | + ('real', [1., 2.], 'float32', 3.), |
| 22 | + ('round', [1, 2], 'int8', 3), |
| 23 | + ('sign', [0, 0], 'float32', 3), |
| 24 | + ('trunc', [1, 2], 'int8', 3), |
| 25 | + ('trunc', [1, 2], 'float32', 3), |
| 26 | +] |
| 27 | + |
| 28 | + |
| 29 | +def ensure_unary(func, arr): |
| 30 | + """Make a trivial unary function from func.""" |
| 31 | + if func.__name__ == 'clip': |
| 32 | + return lambda x: func(x, arr[0], arr[1]) |
| 33 | + return func |
| 34 | + |
| 35 | + |
| 36 | +def is_view(func, a, value): |
| 37 | + """Apply `func`, mutate the output; does the input change?""" |
| 38 | + b = func(a) |
| 39 | + b[0] = value |
| 40 | + return a[0] == value |
| 41 | + |
| 42 | + |
| 43 | +@pytest.mark.parametrize('xp_name', LIB_NAMES) |
| 44 | +@pytest.mark.parametrize('inputs', FUNC_INPUTS, ids=[inp[0] for inp in FUNC_INPUTS]) |
| 45 | +def test_view_or_copy(inputs, xp_name): |
| 46 | + bare_xp = import_(xp_name, wrapper=False) |
| 47 | + wrapped_xp = import_(xp_name, wrapper=True) |
| 48 | + |
| 49 | + func_name, arr_input, dtype_str, value = inputs |
| 50 | + dtype = getattr(bare_xp, dtype_str) |
| 51 | + |
| 52 | + bare_func = getattr(bare_xp, func_name) |
| 53 | + bare_func = ensure_unary(bare_func, arr_input) |
| 54 | + |
| 55 | + wrapped_func = getattr(wrapped_xp, func_name) |
| 56 | + wrapped_func = ensure_unary(wrapped_func, arr_input) |
| 57 | + |
| 58 | + # bare namespace: mutate the output, does the input change? |
| 59 | + a = bare_xp.asarray(arr_input, dtype=dtype) |
| 60 | + is_view_bare = is_view(bare_func, a, value) |
| 61 | + |
| 62 | + # wrapped namespace: mutate the output, does the input change? |
| 63 | + a1 = wrapped_xp.asarray(arr_input, dtype=dtype) |
| 64 | + is_view_wrapped = is_view(wrapped_func, a1, value) |
| 65 | + |
| 66 | + assert is_view_bare == is_view_wrapped |
0 commit comments