37
37
38
38
"""
39
39
40
+ # pylint: disable=protected-access
41
+
40
42
import operator
41
43
42
44
import dpctl .tensor as dpt
45
+ import dpctl .tensor ._tensor_impl as ti
46
+ import dpctl .utils as dpu
43
47
import numpy
48
+ from dpctl .tensor ._copy_utils import _nonzero_impl
49
+ from dpctl .tensor ._indexing_functions import _get_indexing_mode
44
50
from dpctl .tensor ._numpy_helper import normalize_axis_index
45
51
46
52
import dpnp
55
61
56
62
__all__ = [
57
63
"choose" ,
64
+ "compress" ,
58
65
"diag_indices" ,
59
66
"diag_indices_from" ,
60
67
"diagonal" ,
@@ -155,6 +162,157 @@ def choose(x1, choices, out=None, mode="raise"):
155
162
return call_origin (numpy .choose , x1 , choices , out , mode )
156
163
157
164
165
+ def _take_index (x , inds , axis , q , usm_type , out = None , mode = 0 ):
166
+ # arg validation assumed done by caller
167
+ x_sh = x .shape
168
+ axis_end = axis + 1
169
+ if 0 in x_sh [axis :axis_end ] and inds .size != 0 :
170
+ raise IndexError ("cannot take non-empty indices from an empty axis" )
171
+ res_sh = x_sh [:axis ] + inds .shape + x_sh [axis_end :]
172
+
173
+ if out is not None :
174
+ out = dpnp .get_usm_ndarray (out )
175
+
176
+ if not out .flags .writable :
177
+ raise ValueError ("provided `out` array is read-only" )
178
+
179
+ if out .shape != res_sh :
180
+ raise ValueError (
181
+ "The shape of input and output arrays are inconsistent. "
182
+ f"Expected output shape is { res_sh } , got { out .shape } "
183
+ )
184
+
185
+ if x .dtype != out .dtype :
186
+ raise TypeError (
187
+ f"Output array of type { x .dtype } is needed, " f"got { out .dtype } "
188
+ )
189
+
190
+ if dpu .get_execution_queue ((q , out .sycl_queue )) is None :
191
+ raise dpu .ExecutionPlacementError (
192
+ "Input and output allocation queues are not compatible"
193
+ )
194
+
195
+ if ti ._array_overlap (x , out ):
196
+ # Allocate a temporary buffer to avoid memory overlapping.
197
+ out = dpt .empty_like (out )
198
+ else :
199
+ out = dpt .empty (res_sh , dtype = x .dtype , usm_type = usm_type , sycl_queue = q )
200
+
201
+ _manager = dpu .SequentialOrderManager [q ]
202
+ dep_evs = _manager .submitted_events
203
+
204
+ h_ev , take_ev = ti ._take (
205
+ src = x ,
206
+ ind = (inds ,),
207
+ dst = out ,
208
+ axis_start = axis ,
209
+ mode = mode ,
210
+ sycl_queue = q ,
211
+ depends = dep_evs ,
212
+ )
213
+ _manager .add_event_pair (h_ev , take_ev )
214
+
215
+ return out
216
+
217
+
218
+ def compress (condition , a , axis = None , out = None ):
219
+ """
220
+ Return selected slices of an array along given axis.
221
+
222
+ A slice of `a` is returned for each index along `axis` where `condition`
223
+ is ``True``.
224
+
225
+ For full documentation refer to :obj:`numpy.choose`.
226
+
227
+ Parameters
228
+ ----------
229
+ condition : {array_like, dpnp.ndarray, usm_ndarray}
230
+ Array that selects which entries to extract. If the length of
231
+ `condition` is less than the size of `a` along `axis`, then
232
+ the output is truncated to the length of `condition`.
233
+ a : {dpnp.ndarray, usm_ndarray}
234
+ Array to extract from.
235
+ axis : {None, int}, optional
236
+ Axis along which to extract slices. If ``None``, works over the
237
+ flattened array.
238
+ Default: ``None``.
239
+ out : {None, dpnp.ndarray, usm_ndarray}, optional
240
+ If provided, the result will be placed in this array. It should
241
+ be of the appropriate shape and dtype.
242
+ Default: ``None``.
243
+
244
+ Returns
245
+ -------
246
+ out : dpnp.ndarray
247
+ A copy of the slices of `a` where `condition` is ``True``.
248
+
249
+ See also
250
+ --------
251
+ :obj:`dpnp.take` : Take elements from an array along an axis.
252
+ :obj:`dpnp.choose` : Construct an array from an index array and a set of
253
+ arrays to choose from.
254
+ :obj:`dpnp.diag` : Extract a diagonal or construct a diagonal array.
255
+ :obj:`dpnp.diagonal` : Return specified diagonals.
256
+ :obj:`dpnp.select` : Return an array drawn from elements in `choicelist`,
257
+ depending on conditions.
258
+ :obj:`dpnp.ndarray.compress` : Equivalent method.
259
+ :obj:`dpnp.extract` : Equivalent function when working on 1-D arrays.
260
+
261
+ Examples
262
+ --------
263
+ >>> import numpy as np
264
+ >>> a = np.array([[1, 2], [3, 4], [5, 6]])
265
+ >>> a
266
+ array([[1, 2],
267
+ [3, 4],
268
+ [5, 6]])
269
+ >>> np.compress([0, 1], a, axis=0)
270
+ array([[3, 4]])
271
+ >>> np.compress([False, True, True], a, axis=0)
272
+ array([[3, 4],
273
+ [5, 6]])
274
+ >>> np.compress([False, True], a, axis=1)
275
+ array([[2],
276
+ [4],
277
+ [6]])
278
+
279
+ Working on the flattened array does not return slices along an axis but
280
+ selects elements.
281
+
282
+ >>> np.compress([False, True], a)
283
+ array([2])
284
+ """
285
+
286
+ dpnp .check_supported_arrays_type (a )
287
+ if axis is None :
288
+ if a .ndim != 1 :
289
+ a = dpnp .ravel (a )
290
+ axis = 0
291
+ axis = normalize_axis_index (operator .index (axis ), a .ndim )
292
+
293
+ a_ary = dpnp .get_usm_ndarray (a )
294
+ cond_ary = dpnp .as_usm_ndarray (
295
+ condition ,
296
+ dtype = dpnp .bool ,
297
+ usm_type = a_ary .usm_type ,
298
+ sycl_queue = a_ary .sycl_queue ,
299
+ )
300
+
301
+ if not cond_ary .ndim == 1 :
302
+ raise ValueError (
303
+ "`condition` must be a 1-D array or un-nested sequence"
304
+ )
305
+
306
+ res_usm_type , exec_q = get_usm_allocations ([a_ary , cond_ary ])
307
+
308
+ # _nonzero_impl synchronizes and returns a tuple of usm_ndarray indices
309
+ inds = _nonzero_impl (cond_ary )
310
+
311
+ res = _take_index (a_ary , inds [0 ], axis , exec_q , res_usm_type , out = out )
312
+
313
+ return dpnp .get_result_array (res , out = out )
314
+
315
+
158
316
def diag_indices (n , ndim = 2 , device = None , usm_type = "device" , sycl_queue = None ):
159
317
"""
160
318
Return the indices to access the main diagonal of an array.
@@ -1806,8 +1964,8 @@ def take(a, indices, /, *, axis=None, out=None, mode="wrap"):
1806
1964
1807
1965
"""
1808
1966
1809
- if mode not in ( "wrap" , "clip" ):
1810
- raise ValueError ( f"` mode` must be 'wrap' or 'clip', but got ` { mode } `." )
1967
+ # sets mode to 0 for "wrap" and 1 for "clip", raises otherwise
1968
+ mode = _get_indexing_mode ( mode )
1811
1969
1812
1970
usm_a = dpnp .get_usm_ndarray (a )
1813
1971
if not dpnp .is_supported_array_type (indices ):
@@ -1817,34 +1975,28 @@ def take(a, indices, /, *, axis=None, out=None, mode="wrap"):
1817
1975
else :
1818
1976
usm_ind = dpnp .get_usm_ndarray (indices )
1819
1977
1978
+ res_usm_type , exec_q = get_usm_allocations ([usm_a , usm_ind ])
1979
+
1820
1980
a_ndim = a .ndim
1821
1981
if axis is None :
1822
- res_shape = usm_ind .shape
1823
-
1824
1982
if a_ndim > 1 :
1825
- # dpt.take requires flattened input array
1983
+ # flatten input array
1826
1984
usm_a = dpt .reshape (usm_a , - 1 )
1985
+ axis = 0
1827
1986
elif a_ndim == 0 :
1828
1987
axis = normalize_axis_index (operator .index (axis ), 1 )
1829
- res_shape = usm_ind .shape
1830
1988
else :
1831
1989
axis = normalize_axis_index (operator .index (axis ), a_ndim )
1832
- a_sh = a .shape
1833
- res_shape = a_sh [:axis ] + usm_ind .shape + a_sh [axis + 1 :]
1834
-
1835
- if usm_ind .ndim != 1 :
1836
- # dpt.take supports only 1-D array of indices
1837
- usm_ind = dpt .reshape (usm_ind , - 1 )
1838
1990
1839
1991
if not dpnp .issubdtype (usm_ind .dtype , dpnp .integer ):
1840
1992
# dpt.take supports only integer dtype for array of indices
1841
1993
usm_ind = dpt .astype (usm_ind , dpnp .intp , copy = False , casting = "safe" )
1842
1994
1843
- usm_res = dpt .take (usm_a , usm_ind , axis = axis , mode = mode )
1995
+ usm_res = _take_index (
1996
+ usm_a , usm_ind , axis , exec_q , res_usm_type , out = out , mode = mode
1997
+ )
1844
1998
1845
- # need to reshape the result if shape of indices array was changed
1846
- result = dpnp .reshape (usm_res , res_shape )
1847
- return dpnp .get_result_array (result , out )
1999
+ return dpnp .get_result_array (usm_res , out = out )
1848
2000
1849
2001
1850
2002
def take_along_axis (a , indices , axis , mode = "wrap" ):
0 commit comments