Skip to content

Commit 4f4a446

Browse files
ndgrigorianvtavana
authored andcommitted
Implement dpnp.compress and dpnp_array.compress method (#2177)
* Implement dpnp.compress * Add `dpnp_array.compress` method * Break up `compress` to satisfy pylint Also disable checks for protected access, as `compress` uses dpctl.tensor private functions * Unmute third-party tests for `compress` * Use `get_usm_allocations` in `compress` * Fix bug where `out` in `compress` is dpnp_array Also removes an unnecessary check per PR review * Apply comments per PR review by @antonwolfy Also fix a typo when `condition` is not an array * Remove branching when `condition` is an array Also tweaks to docstring * Add tests for `compress` * Re-use `_take_index` for `dpnp.take` Should slightly improve efficiency by escaping an additional copy where `out` is not `None` and flattening of indices * Change error for incorrect out array dtype to `TypeError` * Move compress tests into a TestCompress class * Use NumPy in compress tests * Add `no_none=True` to `test_compress_condition_all_dtypes` * Add USM type and SYCL queue tests for `compress` * More tests for compress added * Docstring change per PR review * Integrate test for compute follows data in compress into test_2in_1out * Add test for `dpnp_array.compress` and add a test for strided inputs to `compress` * Refactor `test_compress` in test_usm_type.py into `test_2in_1out`
1 parent 354ea03 commit 4f4a446

File tree

6 files changed

+278
-24
lines changed

6 files changed

+278
-24
lines changed

dpnp/dpnp_array.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -786,7 +786,14 @@ def clip(self, min=None, max=None, out=None, **kwargs):
786786

787787
return dpnp.clip(self, min, max, out=out, **kwargs)
788788

789-
# 'compress',
789+
def compress(self, condition, axis=None, out=None):
790+
"""
791+
Select slices of an array along a given axis.
792+
793+
Refer to :obj:`dpnp.compress` for full documentation.
794+
"""
795+
796+
return dpnp.compress(condition, self, axis=axis, out=out)
790797

791798
def conj(self):
792799
"""

dpnp/dpnp_iface_indexing.py

Lines changed: 168 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,16 @@
3737
3838
"""
3939

40+
# pylint: disable=protected-access
41+
4042
import operator
4143

4244
import dpctl.tensor as dpt
45+
import dpctl.tensor._tensor_impl as ti
46+
import dpctl.utils as dpu
4347
import numpy
48+
from dpctl.tensor._copy_utils import _nonzero_impl
49+
from dpctl.tensor._indexing_functions import _get_indexing_mode
4450
from dpctl.tensor._numpy_helper import normalize_axis_index
4551

4652
import dpnp
@@ -55,6 +61,7 @@
5561

5662
__all__ = [
5763
"choose",
64+
"compress",
5865
"diag_indices",
5966
"diag_indices_from",
6067
"diagonal",
@@ -155,6 +162,157 @@ def choose(x1, choices, out=None, mode="raise"):
155162
return call_origin(numpy.choose, x1, choices, out, mode)
156163

157164

165+
def _take_index(x, inds, axis, q, usm_type, out=None, mode=0):
166+
# arg validation assumed done by caller
167+
x_sh = x.shape
168+
axis_end = axis + 1
169+
if 0 in x_sh[axis:axis_end] and inds.size != 0:
170+
raise IndexError("cannot take non-empty indices from an empty axis")
171+
res_sh = x_sh[:axis] + inds.shape + x_sh[axis_end:]
172+
173+
if out is not None:
174+
out = dpnp.get_usm_ndarray(out)
175+
176+
if not out.flags.writable:
177+
raise ValueError("provided `out` array is read-only")
178+
179+
if out.shape != res_sh:
180+
raise ValueError(
181+
"The shape of input and output arrays are inconsistent. "
182+
f"Expected output shape is {res_sh}, got {out.shape}"
183+
)
184+
185+
if x.dtype != out.dtype:
186+
raise TypeError(
187+
f"Output array of type {x.dtype} is needed, " f"got {out.dtype}"
188+
)
189+
190+
if dpu.get_execution_queue((q, out.sycl_queue)) is None:
191+
raise dpu.ExecutionPlacementError(
192+
"Input and output allocation queues are not compatible"
193+
)
194+
195+
if ti._array_overlap(x, out):
196+
# Allocate a temporary buffer to avoid memory overlapping.
197+
out = dpt.empty_like(out)
198+
else:
199+
out = dpt.empty(res_sh, dtype=x.dtype, usm_type=usm_type, sycl_queue=q)
200+
201+
_manager = dpu.SequentialOrderManager[q]
202+
dep_evs = _manager.submitted_events
203+
204+
h_ev, take_ev = ti._take(
205+
src=x,
206+
ind=(inds,),
207+
dst=out,
208+
axis_start=axis,
209+
mode=mode,
210+
sycl_queue=q,
211+
depends=dep_evs,
212+
)
213+
_manager.add_event_pair(h_ev, take_ev)
214+
215+
return out
216+
217+
218+
def compress(condition, a, axis=None, out=None):
219+
"""
220+
Return selected slices of an array along given axis.
221+
222+
A slice of `a` is returned for each index along `axis` where `condition`
223+
is ``True``.
224+
225+
For full documentation refer to :obj:`numpy.choose`.
226+
227+
Parameters
228+
----------
229+
condition : {array_like, dpnp.ndarray, usm_ndarray}
230+
Array that selects which entries to extract. If the length of
231+
`condition` is less than the size of `a` along `axis`, then
232+
the output is truncated to the length of `condition`.
233+
a : {dpnp.ndarray, usm_ndarray}
234+
Array to extract from.
235+
axis : {None, int}, optional
236+
Axis along which to extract slices. If ``None``, works over the
237+
flattened array.
238+
Default: ``None``.
239+
out : {None, dpnp.ndarray, usm_ndarray}, optional
240+
If provided, the result will be placed in this array. It should
241+
be of the appropriate shape and dtype.
242+
Default: ``None``.
243+
244+
Returns
245+
-------
246+
out : dpnp.ndarray
247+
A copy of the slices of `a` where `condition` is ``True``.
248+
249+
See also
250+
--------
251+
:obj:`dpnp.take` : Take elements from an array along an axis.
252+
:obj:`dpnp.choose` : Construct an array from an index array and a set of
253+
arrays to choose from.
254+
:obj:`dpnp.diag` : Extract a diagonal or construct a diagonal array.
255+
:obj:`dpnp.diagonal` : Return specified diagonals.
256+
:obj:`dpnp.select` : Return an array drawn from elements in `choicelist`,
257+
depending on conditions.
258+
:obj:`dpnp.ndarray.compress` : Equivalent method.
259+
:obj:`dpnp.extract` : Equivalent function when working on 1-D arrays.
260+
261+
Examples
262+
--------
263+
>>> import numpy as np
264+
>>> a = np.array([[1, 2], [3, 4], [5, 6]])
265+
>>> a
266+
array([[1, 2],
267+
[3, 4],
268+
[5, 6]])
269+
>>> np.compress([0, 1], a, axis=0)
270+
array([[3, 4]])
271+
>>> np.compress([False, True, True], a, axis=0)
272+
array([[3, 4],
273+
[5, 6]])
274+
>>> np.compress([False, True], a, axis=1)
275+
array([[2],
276+
[4],
277+
[6]])
278+
279+
Working on the flattened array does not return slices along an axis but
280+
selects elements.
281+
282+
>>> np.compress([False, True], a)
283+
array([2])
284+
"""
285+
286+
dpnp.check_supported_arrays_type(a)
287+
if axis is None:
288+
if a.ndim != 1:
289+
a = dpnp.ravel(a)
290+
axis = 0
291+
axis = normalize_axis_index(operator.index(axis), a.ndim)
292+
293+
a_ary = dpnp.get_usm_ndarray(a)
294+
cond_ary = dpnp.as_usm_ndarray(
295+
condition,
296+
dtype=dpnp.bool,
297+
usm_type=a_ary.usm_type,
298+
sycl_queue=a_ary.sycl_queue,
299+
)
300+
301+
if not cond_ary.ndim == 1:
302+
raise ValueError(
303+
"`condition` must be a 1-D array or un-nested sequence"
304+
)
305+
306+
res_usm_type, exec_q = get_usm_allocations([a_ary, cond_ary])
307+
308+
# _nonzero_impl synchronizes and returns a tuple of usm_ndarray indices
309+
inds = _nonzero_impl(cond_ary)
310+
311+
res = _take_index(a_ary, inds[0], axis, exec_q, res_usm_type, out=out)
312+
313+
return dpnp.get_result_array(res, out=out)
314+
315+
158316
def diag_indices(n, ndim=2, device=None, usm_type="device", sycl_queue=None):
159317
"""
160318
Return the indices to access the main diagonal of an array.
@@ -1806,8 +1964,8 @@ def take(a, indices, /, *, axis=None, out=None, mode="wrap"):
18061964
18071965
"""
18081966

1809-
if mode not in ("wrap", "clip"):
1810-
raise ValueError(f"`mode` must be 'wrap' or 'clip', but got `{mode}`.")
1967+
# sets mode to 0 for "wrap" and 1 for "clip", raises otherwise
1968+
mode = _get_indexing_mode(mode)
18111969

18121970
usm_a = dpnp.get_usm_ndarray(a)
18131971
if not dpnp.is_supported_array_type(indices):
@@ -1817,34 +1975,28 @@ def take(a, indices, /, *, axis=None, out=None, mode="wrap"):
18171975
else:
18181976
usm_ind = dpnp.get_usm_ndarray(indices)
18191977

1978+
res_usm_type, exec_q = get_usm_allocations([usm_a, usm_ind])
1979+
18201980
a_ndim = a.ndim
18211981
if axis is None:
1822-
res_shape = usm_ind.shape
1823-
18241982
if a_ndim > 1:
1825-
# dpt.take requires flattened input array
1983+
# flatten input array
18261984
usm_a = dpt.reshape(usm_a, -1)
1985+
axis = 0
18271986
elif a_ndim == 0:
18281987
axis = normalize_axis_index(operator.index(axis), 1)
1829-
res_shape = usm_ind.shape
18301988
else:
18311989
axis = normalize_axis_index(operator.index(axis), a_ndim)
1832-
a_sh = a.shape
1833-
res_shape = a_sh[:axis] + usm_ind.shape + a_sh[axis + 1 :]
1834-
1835-
if usm_ind.ndim != 1:
1836-
# dpt.take supports only 1-D array of indices
1837-
usm_ind = dpt.reshape(usm_ind, -1)
18381990

18391991
if not dpnp.issubdtype(usm_ind.dtype, dpnp.integer):
18401992
# dpt.take supports only integer dtype for array of indices
18411993
usm_ind = dpt.astype(usm_ind, dpnp.intp, copy=False, casting="safe")
18421994

1843-
usm_res = dpt.take(usm_a, usm_ind, axis=axis, mode=mode)
1995+
usm_res = _take_index(
1996+
usm_a, usm_ind, axis, exec_q, res_usm_type, out=out, mode=mode
1997+
)
18441998

1845-
# need to reshape the result if shape of indices array was changed
1846-
result = dpnp.reshape(usm_res, res_shape)
1847-
return dpnp.get_result_array(result, out)
1999+
return dpnp.get_result_array(usm_res, out=out)
18482000

18492001

18502002
def take_along_axis(a, indices, axis, mode="wrap"):

dpnp/tests/test_indexing.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import functools
22

3+
import dpctl
34
import dpctl.tensor as dpt
45
import numpy
56
import pytest
67
from dpctl.tensor._numpy_helper import AxisError
8+
from dpctl.utils import ExecutionPlacementError
79
from numpy.testing import (
810
assert_,
911
assert_array_equal,
@@ -1333,3 +1335,101 @@ def test_error(self):
13331335
dpnp.select([x0], [x1], default=x1)
13341336
with pytest.raises(TypeError):
13351337
dpnp.select([x1], [x1])
1338+
1339+
1340+
class TestCompress:
1341+
def test_compress_basic(self):
1342+
conditions = [True, False, True]
1343+
a_np = numpy.arange(16).reshape(4, 4)
1344+
a = dpnp.arange(16).reshape(4, 4)
1345+
cond_np = numpy.array(conditions)
1346+
cond = dpnp.array(conditions)
1347+
expected = numpy.compress(cond_np, a_np, axis=0)
1348+
result = dpnp.compress(cond, a, axis=0)
1349+
assert_array_equal(expected, result)
1350+
1351+
def test_compress_method_basic(self):
1352+
conditions = [True, True, False, True]
1353+
a_np = numpy.arange(3 * 4).reshape(3, 4)
1354+
a = dpnp.arange(3 * 4).reshape(3, 4)
1355+
cond_np = numpy.array(conditions)
1356+
cond = dpnp.array(conditions)
1357+
expected = a_np.compress(cond_np, axis=1)
1358+
result = a.compress(cond, axis=1)
1359+
assert_array_equal(expected, result)
1360+
1361+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
1362+
def test_compress_condition_all_dtypes(self, dtype):
1363+
a_np = numpy.arange(10, dtype="i4")
1364+
a = dpnp.arange(10, dtype="i4")
1365+
cond_np = numpy.tile(numpy.asarray([0, 1], dtype=dtype), 5)
1366+
cond = dpnp.tile(dpnp.asarray([0, 1], dtype=dtype), 5)
1367+
expected = numpy.compress(cond_np, a_np)
1368+
result = dpnp.compress(cond, a)
1369+
assert_array_equal(expected, result)
1370+
1371+
def test_compress_invalid_out_errors(self):
1372+
q1 = dpctl.SyclQueue()
1373+
q2 = dpctl.SyclQueue()
1374+
a = dpnp.ones(10, dtype="i4", sycl_queue=q1)
1375+
condition = dpnp.asarray([True], sycl_queue=q1)
1376+
out_bad_shape = dpnp.empty_like(a)
1377+
with pytest.raises(ValueError):
1378+
dpnp.compress(condition, a, out=out_bad_shape)
1379+
out_bad_queue = dpnp.empty(1, dtype="i4", sycl_queue=q2)
1380+
with pytest.raises(ExecutionPlacementError):
1381+
dpnp.compress(condition, a, out=out_bad_queue)
1382+
out_bad_dt = dpnp.empty(1, dtype="i8", sycl_queue=q1)
1383+
with pytest.raises(TypeError):
1384+
dpnp.compress(condition, a, out=out_bad_dt)
1385+
out_read_only = dpnp.empty(1, dtype="i4", sycl_queue=q1)
1386+
out_read_only.flags.writable = False
1387+
with pytest.raises(ValueError):
1388+
dpnp.compress(condition, a, out=out_read_only)
1389+
1390+
def test_compress_empty_axis(self):
1391+
a = dpnp.ones((10, 0, 5), dtype="i4")
1392+
condition = [True, False, True]
1393+
r = dpnp.compress(condition, a, axis=0)
1394+
assert r.shape == (2, 0, 5)
1395+
# empty take from empty axis is permitted
1396+
assert dpnp.compress([False], a, axis=1).shape == (10, 0, 5)
1397+
# non-empty take from empty axis raises IndexError
1398+
with pytest.raises(IndexError):
1399+
dpnp.compress(condition, a, axis=1)
1400+
1401+
def test_compress_in_overlaps_out(self):
1402+
conditions = [False, True, True]
1403+
a_np = numpy.arange(6)
1404+
a = dpnp.arange(6)
1405+
cond_np = numpy.array(conditions)
1406+
cond = dpnp.array(conditions)
1407+
out = a[2:4]
1408+
expected = numpy.compress(cond_np, a_np, axis=None)
1409+
result = dpnp.compress(cond, a, axis=None, out=out)
1410+
assert_array_equal(expected, result)
1411+
assert result is out
1412+
assert (a[2:4] == out).all()
1413+
1414+
def test_compress_condition_not_1d(self):
1415+
a = dpnp.arange(4)
1416+
cond = dpnp.ones((1, 4), dtype="?")
1417+
with pytest.raises(ValueError):
1418+
dpnp.compress(cond, a, axis=None)
1419+
1420+
def test_compress_strided(self):
1421+
a = dpnp.arange(20)
1422+
a_np = dpnp.asnumpy(a)
1423+
cond = dpnp.tile(dpnp.array([True, False, False, True]), 5)
1424+
cond_np = dpnp.asnumpy(cond)
1425+
result = dpnp.compress(cond, a)
1426+
expected = numpy.compress(cond_np, a_np)
1427+
assert_array_equal(result, expected)
1428+
# use axis keyword
1429+
a = dpnp.arange(50).reshape(10, 5)
1430+
a_np = dpnp.asnumpy(a)
1431+
cond = dpnp.array(dpnp.array([True, False, False, True, False]))
1432+
cond_np = dpnp.asnumpy(cond)
1433+
result = dpnp.compress(cond, a)
1434+
expected = numpy.compress(cond_np, a_np)
1435+
assert_array_equal(result, expected)

dpnp/tests/test_sycl_queue.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -718,6 +718,7 @@ def test_reduce_hypot(device):
718718
),
719719
pytest.param("append", [1, 2, 3], [4, 5, 6]),
720720
pytest.param("arctan2", [-1, +1, +1, -1], [-1, -1, +1, +1]),
721+
pytest.param("compress", [0, 1, 1, 0], [0, 1, 2, 3]),
721722
pytest.param("copysign", [0.0, 1.0, 2.0], [-1.0, 0.0, 1.0]),
722723
pytest.param(
723724
"corrcoef",

dpnp/tests/test_usm_type.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -686,6 +686,7 @@ def test_1in_1out(func, data, usm_type):
686686
),
687687
pytest.param("append", [1, 2, 3], [4, 5, 6]),
688688
pytest.param("arctan2", [-1, +1, +1, -1], [-1, -1, +1, +1]),
689+
pytest.param("compress", [False, True, True], [0, 1, 2, 3, 4]),
689690
pytest.param("copysign", [0.0, 1.0, 2.0], [-1.0, 0.0, 1.0]),
690691
pytest.param("cross", [1.0, 2.0, 3.0], [4.0, 5.0, 6.0]),
691692
pytest.param("digitize", [0.2, 6.4, 3.0], [0.0, 1.0, 2.5, 4.0]),

0 commit comments

Comments
 (0)