Skip to content

Commit b617d6c

Browse files
authored
add support for kwarg ndmin for dpnp.array (#2135)
* add support for ndmin for dpnp.array * update dpnp.ascontiguousarray and dpnp.asfortranarray * update to fix issue for 0-d arrays * address comments
1 parent 6ba840a commit b617d6c

File tree

6 files changed

+322
-368
lines changed

6 files changed

+322
-368
lines changed

dpnp/dpnp_iface_arraycreation.py

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,11 @@ def array(
319319
order : {"C", "F", "A", "K"}, optional
320320
Memory layout of the newly output array.
321321
Default: ``"K"``.
322+
ndmin : int, optional
323+
Specifies the minimum number of dimensions that the resulting array
324+
should have. Ones will be prepended to the shape as needed to meet
325+
this requirement.
326+
Default: ``0``.
322327
device : {None, string, SyclDevice, SyclQueue}, optional
323328
An array API concept of device where the output array is created.
324329
The `device` can be ``None`` (the default), an OneAPI filter selector
@@ -345,7 +350,6 @@ def array(
345350
Limitations
346351
-----------
347352
Parameter `subok` is supported only with default value ``False``.
348-
Parameter `ndmin` is supported only with default value ``0``.
349353
Parameter `like` is supported only with default value ``None``.
350354
Otherwise, the function raises ``NotImplementedError`` exception.
351355
@@ -373,6 +377,11 @@ def array(
373377
>>> x
374378
array([1, 2, 3])
375379
380+
Upcasting:
381+
382+
>>> np.array([1, 2, 3.0])
383+
array([ 1., 2., 3.])
384+
376385
More than one dimension:
377386
378387
>>> x2 = np.array([[1, 2], [3, 4]])
@@ -382,6 +391,16 @@ def array(
382391
array([[1, 2],
383392
[3, 4]])
384393
394+
Minimum dimensions 2:
395+
396+
>>> np.array([1, 2, 3], ndmin=2)
397+
array([[1, 2, 3]])
398+
399+
Type provided:
400+
401+
>>> np.array([1, 2, 3], dtype=complex)
402+
array([ 1.+0.j, 2.+0.j, 3.+0.j])
403+
385404
Creating an array on a different device or with a specified usm_type
386405
387406
>>> x = np.array([1, 2, 3]) # default case
@@ -399,13 +418,10 @@ def array(
399418
"""
400419

401420
dpnp.check_limitations(subok=subok, like=like)
402-
if ndmin != 0:
403-
raise NotImplementedError(
404-
"Keyword argument `ndmin` is supported only with "
405-
f"default value ``0``, but got {ndmin}"
406-
)
421+
if not isinstance(ndmin, (int, dpnp.integer)):
422+
raise TypeError(f"`ndmin` should be an integer, got {type(ndmin)}")
407423

408-
return dpnp_container.asarray(
424+
result = dpnp_container.asarray(
409425
a,
410426
dtype=dtype,
411427
copy=copy,
@@ -415,6 +431,14 @@ def array(
415431
sycl_queue=sycl_queue,
416432
)
417433

434+
res_ndim = result.ndim
435+
if res_ndim >= ndmin:
436+
return result
437+
438+
num_axes = ndmin - res_ndim
439+
new_shape = (1,) * num_axes + result.shape
440+
return result.reshape(new_shape)
441+
418442

419443
def asanyarray(
420444
a,
@@ -635,7 +659,7 @@ def ascontiguousarray(
635659
a, dtype=None, *, like=None, device=None, usm_type=None, sycl_queue=None
636660
):
637661
"""
638-
Return a contiguous array in memory (C order).
662+
Return a contiguous array ``(ndim >= 1)`` in memory (C order).
639663
640664
For full documentation refer to :obj:`numpy.ascontiguousarray`.
641665
@@ -731,14 +755,12 @@ def ascontiguousarray(
731755

732756
dpnp.check_limitations(like=like)
733757

734-
# at least 1-d array has to be returned
735-
if dpnp.isscalar(a) or hasattr(a, "ndim") and a.ndim == 0:
736-
a = [a]
737-
738-
return asarray(
758+
return dpnp.array(
739759
a,
740760
dtype=dtype,
761+
copy=None,
741762
order="C",
763+
ndmin=1,
742764
device=device,
743765
usm_type=usm_type,
744766
sycl_queue=sycl_queue,
@@ -849,14 +871,12 @@ def asfortranarray(
849871

850872
dpnp.check_limitations(like=like)
851873

852-
# at least 1-d array has to be returned
853-
if dpnp.isscalar(a) or hasattr(a, "ndim") and a.ndim == 0:
854-
a = [a]
855-
856-
return asarray(
874+
return dpnp.array(
857875
a,
858876
dtype=dtype,
877+
copy=None,
859878
order="F",
879+
ndmin=1,
860880
device=device,
861881
usm_type=usm_type,
862882
sycl_queue=sycl_queue,

0 commit comments

Comments
 (0)