Skip to content

implement dpnp.apply_over_axes #2174

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 85 additions & 2 deletions dpnp/dpnp_iface_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,14 @@


import numpy
from dpctl.tensor._numpy_helper import normalize_axis_index
from dpctl.tensor._numpy_helper import (
normalize_axis_index,
normalize_axis_tuple,
)

import dpnp

__all__ = ["apply_along_axis"]
__all__ = ["apply_along_axis", "apply_over_axes"]


def apply_along_axis(func1d, axis, arr, *args, **kwargs):
Expand Down Expand Up @@ -185,3 +188,83 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs):
buff = dpnp.moveaxis(buff, -1, axis)

return buff


def apply_over_axes(func, a, axes):
"""
Apply a function repeatedly over multiple axes.

`func` is called as ``res = func(a, axis)``, where `axis` is the first
element of `axes`. The result `res` of the function call must have
either the same dimensions as `a` or one less dimension. If `res`
has one less dimension than `a`, a dimension is inserted before
`axis`. The call to `func` is then repeated for each axis in `axes`,
with `res` as the first argument.

For full documentation refer to :obj:`numpy.apply_over_axes`.

Parameters
----------
func : function
This function must take two arguments, ``func(a, axis)``.
a : {dpnp.ndarray, usm_ndarray}
Input array.
axes : {int, sequence of ints}
Axes over which `func` is applied.

Returns
-------
out : dpnp.ndarray
The output array. The number of dimensions is the same as `a`,
but the shape can be different. This depends on whether `func`
changes the shape of its output with respect to its input.

See Also
--------
:obj:`dpnp.apply_along_axis` : Apply a function to 1-D slices of an array
along the given axis.

Examples
--------
>>> import dpnp as np
>>> a = np.arange(24).reshape(2, 3, 4)
>>> a
array([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])

Sum over axes 0 and 2. The result has same number of dimensions
as the original array:

>>> np.apply_over_axes(np.sum, a, [0, 2])
array([[[ 60],
[ 92],
[124]]])

Tuple axis arguments to ufuncs are equivalent:

>>> np.sum(a, axis=(0, 2), keepdims=True)
array([[[ 60],
[ 92],
[124]]])

"""

dpnp.check_supported_arrays_type(a)
if isinstance(axes, int):
axes = (axes,)
axes = normalize_axis_tuple(axes, a.ndim)

for axis in axes:
res = func(a, axis)
if res.ndim != a.ndim:
res = dpnp.expand_dims(res, axis)
if res.ndim != a.ndim:
raise ValueError(
"function is not returning an array of the correct shape"
)
a = res
return res
21 changes: 20 additions & 1 deletion dpnp/tests/test_functional.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy
import pytest
from numpy.testing import assert_array_equal
from numpy.testing import assert_array_equal, assert_raises

import dpnp

Expand Down Expand Up @@ -46,3 +46,22 @@ def test_args(self, dtype):
# positional args: axis, dtype, out, keepdims
result = dpnp.apply_along_axis(dpnp.mean, 0, ia, 0, dtype, None, True)
assert_array_equal(result, expected)


class TestApplyOverAxes:
@pytest.mark.parametrize("func", ["sum", "cumsum"])
@pytest.mark.parametrize("axes", [1, [0, 2], (-1, -2)])
def test_basic(self, func, axes):
a = numpy.arange(24).reshape(2, 3, 4)
ia = dpnp.array(a)

expected = numpy.apply_over_axes(getattr(numpy, func), a, axes)
result = dpnp.apply_over_axes(getattr(dpnp, func), ia, axes)
assert_array_equal(result, expected)

def test_custom_func(self):
def custom_func(x, axis):
return dpnp.expand_dims(dpnp.expand_dims(x, axis), axis)

ia = dpnp.arange(24).reshape(2, 3, 4)
assert_raises(ValueError, dpnp.apply_over_axes, custom_func, ia, 1)
12 changes: 12 additions & 0 deletions dpnp/tests/test_sycl_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -2215,6 +2215,18 @@ def test_apply_along_axis(device):
assert_sycl_queue_equal(result.sycl_queue, x.sycl_queue)


@pytest.mark.parametrize(
"device",
valid_devices,
ids=[device.filter_string for device in valid_devices],
)
def test_apply_over_axes(device):
x = dpnp.arange(18, device=device).reshape(2, 3, 3)
result = dpnp.apply_over_axes(dpnp.sum, x, [0, 1])

assert_sycl_queue_equal(result.sycl_queue, x.sycl_queue)


@pytest.mark.parametrize(
"device_x",
valid_devices,
Expand Down
8 changes: 8 additions & 0 deletions dpnp/tests/test_usm_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -784,6 +784,14 @@ def test_apply_along_axis(usm_type):
assert x.usm_type == y.usm_type


@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)
def test_apply_over_axes(usm_type):
x = dp.arange(18, usm_type=usm_type).reshape(2, 3, 3)
y = dp.apply_over_axes(dp.sum, x, [0, 1])

assert x.usm_type == y.usm_type


@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)
def test_broadcast_to(usm_type):
x = dp.ones(7, usm_type=usm_type)
Expand Down
Loading