14
14
Sequence ,
15
15
cast ,
16
16
final ,
17
+ overload ,
17
18
)
18
19
import warnings
19
20
101
102
Categorical ,
102
103
DataFrame ,
103
104
Index ,
105
+ MultiIndex ,
104
106
Series ,
105
107
)
106
108
from pandas .core .arrays import (
@@ -1792,7 +1794,7 @@ def safe_sort(
1792
1794
na_sentinel : int = - 1 ,
1793
1795
assume_unique : bool = False ,
1794
1796
verify : bool = True ,
1795
- ) -> np .ndarray | tuple [np .ndarray , np .ndarray ]:
1797
+ ) -> np .ndarray | MultiIndex | tuple [np .ndarray | MultiIndex , np .ndarray ]:
1796
1798
"""
1797
1799
Sort ``values`` and reorder corresponding ``codes``.
1798
1800
@@ -1821,7 +1823,7 @@ def safe_sort(
1821
1823
1822
1824
Returns
1823
1825
-------
1824
- ordered : ndarray
1826
+ ordered : ndarray or MultiIndex
1825
1827
Sorted ``values``
1826
1828
new_codes : ndarray
1827
1829
Reordered ``codes``; returned when ``codes`` is not None.
@@ -1839,6 +1841,7 @@ def safe_sort(
1839
1841
raise TypeError (
1840
1842
"Only list-like objects are allowed to be passed to safe_sort as values"
1841
1843
)
1844
+ original_values = values
1842
1845
1843
1846
if not isinstance (values , (np .ndarray , ABCExtensionArray )):
1844
1847
# don't convert to string types
@@ -1850,6 +1853,7 @@ def safe_sort(
1850
1853
values = np .asarray (values , dtype = dtype ) # type: ignore[arg-type]
1851
1854
1852
1855
sorter = None
1856
+ ordered : np .ndarray | MultiIndex
1853
1857
1854
1858
if (
1855
1859
not is_extension_array_dtype (values )
@@ -1865,7 +1869,7 @@ def safe_sort(
1865
1869
# which would work, but which fails for special case of 1d arrays
1866
1870
# with tuples.
1867
1871
if values .size and isinstance (values [0 ], tuple ):
1868
- ordered = _sort_tuples (values )
1872
+ ordered = _sort_tuples (values , original_values )
1869
1873
else :
1870
1874
ordered = _sort_mixed (values )
1871
1875
@@ -1927,19 +1931,33 @@ def _sort_mixed(values) -> np.ndarray:
1927
1931
)
1928
1932
1929
1933
1930
- def _sort_tuples (values : np .ndarray ) -> np .ndarray :
1934
+ @overload
1935
+ def _sort_tuples (values : np .ndarray , original_values : np .ndarray ) -> np .ndarray :
1936
+ ...
1937
+
1938
+
1939
+ @overload
1940
+ def _sort_tuples (values : np .ndarray , original_values : MultiIndex ) -> MultiIndex :
1941
+ ...
1942
+
1943
+
1944
+ def _sort_tuples (
1945
+ values : np .ndarray , original_values : np .ndarray | MultiIndex
1946
+ ) -> np .ndarray | MultiIndex :
1931
1947
"""
1932
1948
Convert array of tuples (1d) to array or array (2d).
1933
1949
We need to keep the columns separately as they contain different types and
1934
1950
nans (can't use `np.sort` as it may fail when str and nan are mixed in a
1935
1951
column as types cannot be compared).
1952
+ We have to apply the indexer to the original values to keep the dtypes in
1953
+ case of MultiIndexes
1936
1954
"""
1937
1955
from pandas .core .internals .construction import to_arrays
1938
1956
from pandas .core .sorting import lexsort_indexer
1939
1957
1940
1958
arrays , _ = to_arrays (values , None )
1941
1959
indexer = lexsort_indexer (arrays , orders = True )
1942
- return values [indexer ]
1960
+ return original_values [indexer ]
1943
1961
1944
1962
1945
1963
def union_with_duplicates (lvals : ArrayLike , rvals : ArrayLike ) -> ArrayLike :
0 commit comments