Skip to content

Leverage on dpctl.tensor implementation in dpnp.put_along_axis #2134

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 29 additions & 57 deletions dpnp/dpnp_iface_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,57 +84,6 @@
]


def _build_along_axis_index(a, ind, axis):
"""
Build a fancy index used by a family of `_along_axis` functions.

The fancy index consists of orthogonal arranges, with the
requested index inserted at the right location.

The resulting index is going to be used inside `dpnp.put_along_axis`
and `dpnp.take_along_axis` implementations.

"""

if not dpnp.issubdtype(ind.dtype, dpnp.integer):
raise IndexError("`indices` must be an integer array")

# normalize array shape and input axis
if axis is None:
a_shape = (a.size,)
axis = 0
else:
a_shape = a.shape
axis = normalize_axis_index(axis, a.ndim)

if len(a_shape) != ind.ndim:
raise ValueError(
"`indices` and `a` must have the same number of dimensions"
)

# compute dimensions to iterate over
dest_dims = list(range(axis)) + [None] + list(range(axis + 1, ind.ndim))
shape_ones = (1,) * ind.ndim

# build the index
fancy_index = []
for dim, n in zip(dest_dims, a_shape):
if dim is None:
fancy_index.append(ind)
else:
ind_shape = shape_ones[:dim] + (-1,) + shape_ones[dim + 1 :]
fancy_index.append(
dpnp.arange(
n,
dtype=ind.dtype,
usm_type=ind.usm_type,
sycl_queue=ind.sycl_queue,
).reshape(ind_shape)
)

return tuple(fancy_index)


def _ravel_multi_index_checks(multi_index, dims, order):
dpnp.check_supported_arrays_type(*multi_index)
ndim = len(dims)
Expand Down Expand Up @@ -1371,7 +1320,7 @@ def put(a, ind, v, /, *, axis=None, mode="wrap"):
in_usm_a[:] = dpt.reshape(usm_a, in_usm_a.shape, copy=False)


def put_along_axis(a, ind, values, axis):
def put_along_axis(a, ind, values, axis, mode="wrap"):
"""
Put values into the destination array by matching 1d index and data slices.

Expand All @@ -1395,9 +1344,18 @@ def put_along_axis(a, ind, values, axis):
values : {scalar, array_like}, (Ni..., J, Nk...)
Values to insert at those indices. Its shape and dimension are
broadcast to match that of `ind`.
axis : int
axis : {None, int}
The axis to take 1d slices along. If axis is ``None``, the destination
array is treated as if a flattened 1d view had been created of it.
mode : {"wrap", "clip"}, optional
Specifies how out-of-bounds indices will be handled. Possible values
are:

- ``"wrap"``: clamps indices to (``-n <= i < n``), then wraps
negative indices.
- ``"clip"``: clips indices to (``0 <= i < n``).

Default: ``"wrap"``.

See Also
--------
Expand Down Expand Up @@ -1426,12 +1384,26 @@ def put_along_axis(a, ind, values, axis):

"""

dpnp.check_supported_arrays_type(a, ind)

if axis is None:
a = a.ravel()
dpnp.check_supported_arrays_type(ind)
if ind.ndim != 1:
raise ValueError(
"when axis=None, `ind` must have a single dimension."
)

a = dpnp.ravel(a)
axis = 0

usm_a = dpnp.get_usm_ndarray(a)
usm_ind = dpnp.get_usm_ndarray(ind)
if dpnp.is_supported_array_type(values):
usm_vals = dpnp.get_usm_ndarray(values)
else:
usm_vals = dpt.asarray(
values, usm_type=a.usm_type, sycl_queue=a.sycl_queue
)

a[_build_along_axis_index(a, ind, axis)] = values
dpt.put_along_axis(usm_a, usm_ind, usm_vals, axis=axis, mode=mode)


def putmask(x1, mask, values):
Expand Down
51 changes: 35 additions & 16 deletions tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,38 +594,57 @@ def test_replace_max(self, arr_dt, axis):
],
)
def test_values(self, arr_dt, idx_dt, ndim, values):
np_a = numpy.arange(4**ndim, dtype=arr_dt).reshape((4,) * ndim)
np_ai = numpy.array([3, 0, 2, 1], dtype=idx_dt).reshape(
a = numpy.arange(4**ndim, dtype=arr_dt).reshape((4,) * ndim)
ind = numpy.array([3, 0, 2, 1], dtype=idx_dt).reshape(
(1,) * (ndim - 1) + (4,)
)

dp_a = dpnp.array(np_a, dtype=arr_dt)
dp_ai = dpnp.array(np_ai, dtype=idx_dt)
ia, iind = dpnp.array(a), dpnp.array(ind)

for axis in range(ndim):
numpy.put_along_axis(np_a, np_ai, values, axis)
dpnp.put_along_axis(dp_a, dp_ai, values, axis)
assert_array_equal(np_a, dp_a)
numpy.put_along_axis(a, ind, values, axis)
dpnp.put_along_axis(ia, iind, values, axis)
assert_array_equal(ia, a)

@pytest.mark.parametrize("xp", [numpy, dpnp])
@pytest.mark.parametrize("dt", [bool, numpy.float32])
def test_invalid_indices_dtype(self, xp, dt):
a = xp.ones((10, 10))
ind = xp.ones(10, dtype=dt)
ind = xp.ones_like(a, dtype=dt)
assert_raises(IndexError, xp.put_along_axis, a, ind, 7, axis=1)

@pytest.mark.parametrize("arr_dt", get_all_dtypes())
@pytest.mark.parametrize("idx_dt", get_integer_dtypes())
def test_broadcast(self, arr_dt, idx_dt):
np_a = numpy.ones((3, 4, 1), dtype=arr_dt)
np_ai = numpy.arange(10, dtype=idx_dt).reshape((1, 2, 5)) % 4
a = numpy.ones((3, 4, 1), dtype=arr_dt)
ind = numpy.arange(10, dtype=idx_dt).reshape((1, 2, 5)) % 4
ia, iind = dpnp.array(a), dpnp.array(ind)

numpy.put_along_axis(a, ind, 20, axis=1)
dpnp.put_along_axis(ia, iind, 20, axis=1)
assert_array_equal(ia, a)

def test_mode_wrap(self):
a = numpy.array([-2, -1, 0, 1, 2])
ind = numpy.array([-2, 2, -5, 4])
ia, iind = dpnp.array(a), dpnp.array(ind)

dpnp.put_along_axis(ia, iind, 3, axis=0, mode="wrap")
numpy.put_along_axis(a, ind, 3, axis=0)
assert_array_equal(ia, a)

def test_mode_clip(self):
a = dpnp.array([-2, -1, 0, 1, 2])
ind = dpnp.array([-2, 2, -5, 4])

dp_a = dpnp.array(np_a, dtype=arr_dt)
dp_ai = dpnp.array(np_ai, dtype=idx_dt)
# numpy does not support keyword `mode`
dpnp.put_along_axis(a, ind, 4, axis=0, mode="clip")
assert (a == dpnp.array([4, -1, 4, 1, 4])).all()

numpy.put_along_axis(np_a, np_ai, 20, axis=1)
dpnp.put_along_axis(dp_a, dp_ai, 20, axis=1)
assert_array_equal(np_a, dp_a)
@pytest.mark.parametrize("xp", [numpy, dpnp])
def test_indices_ndim_axis_none(self, xp):
a = xp.ones((10, 10))
ind = xp.ones((10, 2), dtype=xp.intp)
assert_raises(ValueError, xp.put_along_axis, a, ind, -1, axis=None)


class TestTake:
Expand Down
Loading