Skip to content

Commit e0c9cf1

Browse files
authored
implement dpnp.nanmedian (#2191)
In this PR, `dpnp.nanmedian` is implemented.
1 parent e53fa72 commit e0c9cf1

File tree

8 files changed

+440
-86
lines changed

8 files changed

+440
-86
lines changed

dpnp/dpnp_iface_nanfunctions.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import warnings
4141

4242
import dpnp
43+
from dpnp.dpnp_utils.dpnp_utils_statistics import dpnp_median
4344

4445
__all__ = [
4546
"nanargmax",
@@ -48,6 +49,7 @@
4849
"nancumsum",
4950
"nanmax",
5051
"nanmean",
52+
"nanmedian",
5153
"nanmin",
5254
"nanprod",
5355
"nanstd",
@@ -568,6 +570,107 @@ def nanmean(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True):
568570
return avg
569571

570572

573+
def nanmedian(a, axis=None, out=None, overwrite_input=False, keepdims=False):
574+
"""
575+
Compute the median along the specified axis, while ignoring NaNs.
576+
577+
For full documentation refer to :obj:`numpy.nanmedian`.
578+
579+
Parameters
580+
----------
581+
a : {dpnp.ndarray, usm_ndarray}
582+
Input array.
583+
axis : {None, int, tuple or list of ints}, optional
584+
Axis or axes along which the medians are computed. The default,
585+
``axis=None``, will compute the median along a flattened version of
586+
the array. If a sequence of axes, the array is first flattened along
587+
the given axes, then the median is computed along the resulting
588+
flattened axis.
589+
Default: ``None``.
590+
out : {None, dpnp.ndarray, usm_ndarray}, optional
591+
Alternative output array in which to place the result. It must have
592+
the same shape as the expected output but the type (of the calculated
593+
values) will be cast if necessary.
594+
Default: ``None``.
595+
overwrite_input : bool, optional
596+
If ``True``, then allow use of memory of input array `a` for
597+
calculations. The input array will be modified by the call to
598+
:obj:`dpnp.nanmedian`. This will save memory when you do not need to
599+
preserve the contents of the input array. Treat the input as undefined,
600+
but it will probably be fully or partially sorted.
601+
Default: ``False``.
602+
keepdims : bool, optional
603+
If ``True``, the reduced axes (dimensions) are included in the result
604+
as singleton dimensions, so that the returned array remains
605+
compatible with the input array according to Array Broadcasting
606+
rules. Otherwise, if ``False``, the reduced axes are not included in
607+
the returned array.
608+
Default: ``False``.
609+
610+
Returns
611+
-------
612+
out : dpnp.ndarray
613+
A new array holding the result. If `a` has a floating-point data type,
614+
the returned array will have the same data type as `a`. If `a` has a
615+
boolean or integral data type, the returned array will have the
616+
default floating point data type for the device where input array `a`
617+
is allocated.
618+
619+
See Also
620+
--------
621+
:obj:`dpnp.mean` : Compute the arithmetic mean along the specified axis.
622+
:obj:`dpnp.median` : Compute the median along the specified axis.
623+
:obj:`dpnp.percentile` : Compute the q-th percentile of the data
624+
along the specified axis.
625+
626+
Notes
627+
-----
628+
Given a vector ``V`` of length ``N``, the median of ``V`` is the
629+
middle value of a sorted copy of ``V``, ``V_sorted`` - i.e.,
630+
``V_sorted[(N-1)/2]``, when ``N`` is odd, and the average of the
631+
two middle values of ``V_sorted`` when ``N`` is even.
632+
633+
Examples
634+
--------
635+
>>> import dpnp as np
636+
>>> a = np.array([[10.0, 7, 4], [3, 2, 1]])
637+
>>> a[0, 1] = np.nan
638+
>>> a
639+
array([[10., nan, 4.],
640+
[ 3., 2., 1.]])
641+
>>> np.median(a)
642+
array(nan)
643+
>>> np.nanmedian(a)
644+
array(3.)
645+
646+
>>> np.nanmedian(a, axis=0)
647+
array([6.5, 2., 2.5])
648+
>>> np.nanmedian(a, axis=1)
649+
array([7., 2.])
650+
651+
>>> b = a.copy()
652+
>>> np.nanmedian(b, axis=1, overwrite_input=True)
653+
array([7., 2.])
654+
>>> assert not np.all(a==b)
655+
>>> b = a.copy()
656+
>>> np.nanmedian(b, axis=None, overwrite_input=True)
657+
array(3.)
658+
>>> assert not np.all(a==b)
659+
660+
"""
661+
662+
dpnp.check_supported_arrays_type(a)
663+
ignore_nan = False
664+
if dpnp.issubdtype(a.dtype, dpnp.inexact):
665+
mask = dpnp.isnan(a)
666+
if dpnp.any(mask):
667+
ignore_nan = True
668+
669+
return dpnp_median(
670+
a, axis, out, overwrite_input, keepdims, ignore_nan=ignore_nan
671+
)
672+
673+
571674
def nanmin(a, axis=None, out=None, keepdims=False, initial=None, where=True):
572675
"""
573676
Return the minimum of an array or minimum along an axis, ignoring any NaNs.

dpnp/dpnp_iface_statistics.py

Lines changed: 11 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -39,19 +39,15 @@
3939

4040
import dpctl.tensor as dpt
4141
import numpy
42-
from dpctl.tensor._numpy_helper import (
43-
normalize_axis_index,
44-
normalize_axis_tuple,
45-
)
42+
from dpctl.tensor._numpy_helper import normalize_axis_index
4643

4744
import dpnp
4845

4946
# pylint: disable=no-name-in-module
5047
from .dpnp_algo import dpnp_correlate
51-
from .dpnp_array import dpnp_array
5248
from .dpnp_utils import call_origin, get_usm_allocations
5349
from .dpnp_utils.dpnp_utils_reduction import dpnp_wrap_reduction_call
54-
from .dpnp_utils.dpnp_utils_statistics import dpnp_cov
50+
from .dpnp_utils.dpnp_utils_statistics import dpnp_cov, dpnp_median
5551

5652
__all__ = [
5753
"amax",
@@ -113,22 +109,6 @@ def _count_reduce_items(arr, axis, where=True):
113109
return items
114110

115111

116-
def _flatten_array_along_axes(arr, axes_to_flatten):
117-
"""Flatten an array along a specific set of axes."""
118-
119-
axes_to_keep = (
120-
axis for axis in range(arr.ndim) if axis not in axes_to_flatten
121-
)
122-
123-
# Move the axes_to_flatten to the front
124-
arr_moved = dpnp.moveaxis(arr, axes_to_flatten, range(len(axes_to_flatten)))
125-
126-
new_shape = (-1,) + tuple(arr.shape[axis] for axis in axes_to_keep)
127-
flattened_arr = arr_moved.reshape(new_shape)
128-
129-
return flattened_arr
130-
131-
132112
def _get_comparison_res_dt(a, _dtype, _out):
133113
"""Get a data type used by dpctl for result array in comparison function."""
134114

@@ -765,7 +745,7 @@ def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
765745
preserve the contents of the input array. Treat the input as undefined,
766746
but it will probably be fully or partially sorted.
767747
Default: ``False``.
768-
keepdims : {None, bool}, optional
748+
keepdims : bool, optional
769749
If ``True``, the reduced axes (dimensions) are included in the result
770750
as singleton dimensions, so that the returned array remains
771751
compatible with the input array according to Array Broadcasting
@@ -775,7 +755,7 @@ def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
775755
776756
Returns
777757
-------
778-
dpnp.median : dpnp.ndarray
758+
out : dpnp.ndarray
779759
A new array holding the result. If `a` has a floating-point data type,
780760
the returned array will have the same data type as `a`. If `a` has a
781761
boolean or integral data type, the returned array will have the
@@ -808,20 +788,20 @@ def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
808788
>>> np.median(a, axis=0)
809789
array([6.5, 4.5, 2.5])
810790
>>> np.median(a, axis=1)
811-
array([7., 2.])
791+
array([7., 2.])
812792
>>> np.median(a, axis=(0, 1))
813793
array(3.5)
814794
815795
>>> m = np.median(a, axis=0)
816796
>>> out = np.zeros_like(m)
817797
>>> np.median(a, axis=0, out=m)
818-
array([6.5, 4.5, 2.5])
798+
array([6.5, 4.5, 2.5])
819799
>>> m
820-
array([6.5, 4.5, 2.5])
800+
array([6.5, 4.5, 2.5])
821801
822802
>>> b = a.copy()
823803
>>> np.median(b, axis=1, overwrite_input=True)
824-
array([7., 2.])
804+
array([7., 2.])
825805
>>> assert not np.all(a==b)
826806
>>> b = a.copy()
827807
>>> np.median(b, axis=None, overwrite_input=True)
@@ -831,62 +811,9 @@ def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
831811
"""
832812

833813
dpnp.check_supported_arrays_type(a)
834-
a_ndim = a.ndim
835-
a_shape = a.shape
836-
_axis = range(a_ndim) if axis is None else axis
837-
_axis = normalize_axis_tuple(_axis, a_ndim)
838-
839-
if isinstance(axis, (tuple, list)):
840-
if len(axis) == 1:
841-
axis = axis[0]
842-
else:
843-
# Need to flatten if `axis` is a sequence of axes since `dpnp.sort`
844-
# only accepts integer `axis`
845-
# Note that the output of _flatten_array_along_axes is not
846-
# necessarily a view of the input since `reshape` is used there.
847-
# If this is the case, using overwrite_input is meaningless
848-
a = _flatten_array_along_axes(a, _axis)
849-
axis = 0
850-
851-
if overwrite_input:
852-
if axis is None:
853-
a_sorted = dpnp.ravel(a)
854-
a_sorted.sort()
855-
else:
856-
if isinstance(a, dpt.usm_ndarray):
857-
# dpnp.ndarray.sort only works with dpnp_array
858-
a = dpnp_array._create_from_usm_ndarray(a)
859-
a.sort(axis=axis)
860-
a_sorted = a
861-
else:
862-
a_sorted = dpnp.sort(a, axis=axis)
863-
864-
if axis is None:
865-
axis = 0
866-
indexer = [slice(None)] * a_sorted.ndim
867-
index, remainder = divmod(a_sorted.shape[axis], 2)
868-
if remainder == 1:
869-
# index with slice to allow mean (below) to work
870-
indexer[axis] = slice(index, index + 1)
871-
else:
872-
indexer[axis] = slice(index - 1, index + 1)
873-
874-
# Use `mean` in odd and even case to coerce data type and use `out` array
875-
res = dpnp.mean(a_sorted[tuple(indexer)], axis=axis, out=out)
876-
nan_mask = dpnp.isnan(a_sorted).any(axis=axis)
877-
if nan_mask.any():
878-
res[nan_mask] = dpnp.nan
879-
880-
if keepdims:
881-
# We can't use dpnp.mean(..., keepdims) and dpnp.any(..., keepdims)
882-
# above because of the reshape hack might have been used in
883-
# `_flatten_array_along_axes` to handle cases when axis is a tuple.
884-
res_shape = list(a_shape)
885-
for i in _axis:
886-
res_shape[i] = 1
887-
res = res.reshape(tuple(res_shape))
888-
889-
return res
814+
return dpnp_median(
815+
a, axis, out, overwrite_input, keepdims, ignore_nan=False
816+
)
890817

891818

892819
def min(a, axis=None, out=None, keepdims=False, initial=None, where=True):

0 commit comments

Comments
 (0)