39
39
40
40
import dpctl .tensor as dpt
41
41
import numpy
42
- from dpctl .tensor ._numpy_helper import (
43
- normalize_axis_index ,
44
- normalize_axis_tuple ,
45
- )
42
+ from dpctl .tensor ._numpy_helper import normalize_axis_index
46
43
47
44
import dpnp
48
45
49
46
# pylint: disable=no-name-in-module
50
47
from .dpnp_algo import dpnp_correlate
51
- from .dpnp_array import dpnp_array
52
48
from .dpnp_utils import call_origin , get_usm_allocations
53
49
from .dpnp_utils .dpnp_utils_reduction import dpnp_wrap_reduction_call
54
- from .dpnp_utils .dpnp_utils_statistics import dpnp_cov
50
+ from .dpnp_utils .dpnp_utils_statistics import dpnp_cov , dpnp_median
55
51
56
52
__all__ = [
57
53
"amax" ,
@@ -113,22 +109,6 @@ def _count_reduce_items(arr, axis, where=True):
113
109
return items
114
110
115
111
116
- def _flatten_array_along_axes (arr , axes_to_flatten ):
117
- """Flatten an array along a specific set of axes."""
118
-
119
- axes_to_keep = (
120
- axis for axis in range (arr .ndim ) if axis not in axes_to_flatten
121
- )
122
-
123
- # Move the axes_to_flatten to the front
124
- arr_moved = dpnp .moveaxis (arr , axes_to_flatten , range (len (axes_to_flatten )))
125
-
126
- new_shape = (- 1 ,) + tuple (arr .shape [axis ] for axis in axes_to_keep )
127
- flattened_arr = arr_moved .reshape (new_shape )
128
-
129
- return flattened_arr
130
-
131
-
132
112
def _get_comparison_res_dt (a , _dtype , _out ):
133
113
"""Get a data type used by dpctl for result array in comparison function."""
134
114
@@ -765,7 +745,7 @@ def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
765
745
preserve the contents of the input array. Treat the input as undefined,
766
746
but it will probably be fully or partially sorted.
767
747
Default: ``False``.
768
- keepdims : {None, bool} , optional
748
+ keepdims : bool, optional
769
749
If ``True``, the reduced axes (dimensions) are included in the result
770
750
as singleton dimensions, so that the returned array remains
771
751
compatible with the input array according to Array Broadcasting
@@ -775,7 +755,7 @@ def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
775
755
776
756
Returns
777
757
-------
778
- dpnp.median : dpnp.ndarray
758
+ out : dpnp.ndarray
779
759
A new array holding the result. If `a` has a floating-point data type,
780
760
the returned array will have the same data type as `a`. If `a` has a
781
761
boolean or integral data type, the returned array will have the
@@ -808,20 +788,20 @@ def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
808
788
>>> np.median(a, axis=0)
809
789
array([6.5, 4.5, 2.5])
810
790
>>> np.median(a, axis=1)
811
- array([7., 2.])
791
+ array([7., 2.])
812
792
>>> np.median(a, axis=(0, 1))
813
793
array(3.5)
814
794
815
795
>>> m = np.median(a, axis=0)
816
796
>>> out = np.zeros_like(m)
817
797
>>> np.median(a, axis=0, out=m)
818
- array([6.5, 4.5, 2.5])
798
+ array([6.5, 4.5, 2.5])
819
799
>>> m
820
- array([6.5, 4.5, 2.5])
800
+ array([6.5, 4.5, 2.5])
821
801
822
802
>>> b = a.copy()
823
803
>>> np.median(b, axis=1, overwrite_input=True)
824
- array([7., 2.])
804
+ array([7., 2.])
825
805
>>> assert not np.all(a==b)
826
806
>>> b = a.copy()
827
807
>>> np.median(b, axis=None, overwrite_input=True)
@@ -831,62 +811,9 @@ def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
831
811
"""
832
812
833
813
dpnp .check_supported_arrays_type (a )
834
- a_ndim = a .ndim
835
- a_shape = a .shape
836
- _axis = range (a_ndim ) if axis is None else axis
837
- _axis = normalize_axis_tuple (_axis , a_ndim )
838
-
839
- if isinstance (axis , (tuple , list )):
840
- if len (axis ) == 1 :
841
- axis = axis [0 ]
842
- else :
843
- # Need to flatten if `axis` is a sequence of axes since `dpnp.sort`
844
- # only accepts integer `axis`
845
- # Note that the output of _flatten_array_along_axes is not
846
- # necessarily a view of the input since `reshape` is used there.
847
- # If this is the case, using overwrite_input is meaningless
848
- a = _flatten_array_along_axes (a , _axis )
849
- axis = 0
850
-
851
- if overwrite_input :
852
- if axis is None :
853
- a_sorted = dpnp .ravel (a )
854
- a_sorted .sort ()
855
- else :
856
- if isinstance (a , dpt .usm_ndarray ):
857
- # dpnp.ndarray.sort only works with dpnp_array
858
- a = dpnp_array ._create_from_usm_ndarray (a )
859
- a .sort (axis = axis )
860
- a_sorted = a
861
- else :
862
- a_sorted = dpnp .sort (a , axis = axis )
863
-
864
- if axis is None :
865
- axis = 0
866
- indexer = [slice (None )] * a_sorted .ndim
867
- index , remainder = divmod (a_sorted .shape [axis ], 2 )
868
- if remainder == 1 :
869
- # index with slice to allow mean (below) to work
870
- indexer [axis ] = slice (index , index + 1 )
871
- else :
872
- indexer [axis ] = slice (index - 1 , index + 1 )
873
-
874
- # Use `mean` in odd and even case to coerce data type and use `out` array
875
- res = dpnp .mean (a_sorted [tuple (indexer )], axis = axis , out = out )
876
- nan_mask = dpnp .isnan (a_sorted ).any (axis = axis )
877
- if nan_mask .any ():
878
- res [nan_mask ] = dpnp .nan
879
-
880
- if keepdims :
881
- # We can't use dpnp.mean(..., keepdims) and dpnp.any(..., keepdims)
882
- # above because of the reshape hack might have been used in
883
- # `_flatten_array_along_axes` to handle cases when axis is a tuple.
884
- res_shape = list (a_shape )
885
- for i in _axis :
886
- res_shape [i ] = 1
887
- res = res .reshape (tuple (res_shape ))
888
-
889
- return res
814
+ return dpnp_median (
815
+ a , axis , out , overwrite_input , keepdims , ignore_nan = False
816
+ )
890
817
891
818
892
819
def min (a , axis = None , out = None , keepdims = False , initial = None , where = True ):
0 commit comments