Skip to content

Commit 23c0339

Browse files
phoflmeeseeksmachine
authored andcommitted
Backport PR pandas-dev#48412: BUG: safe_sort losing MultiIndex dtypes
1 parent 5a9db39 commit 23c0339

File tree

2 files changed

+36
-5
lines changed

2 files changed

+36
-5
lines changed

pandas/core/algorithms.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
Sequence,
1515
cast,
1616
final,
17+
overload,
1718
)
1819
import warnings
1920

@@ -101,6 +102,7 @@
101102
Categorical,
102103
DataFrame,
103104
Index,
105+
MultiIndex,
104106
Series,
105107
)
106108
from pandas.core.arrays import (
@@ -1792,7 +1794,7 @@ def safe_sort(
17921794
na_sentinel: int = -1,
17931795
assume_unique: bool = False,
17941796
verify: bool = True,
1795-
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
1797+
) -> np.ndarray | MultiIndex | tuple[np.ndarray | MultiIndex, np.ndarray]:
17961798
"""
17971799
Sort ``values`` and reorder corresponding ``codes``.
17981800
@@ -1821,7 +1823,7 @@ def safe_sort(
18211823
18221824
Returns
18231825
-------
1824-
ordered : ndarray
1826+
ordered : ndarray or MultiIndex
18251827
Sorted ``values``
18261828
new_codes : ndarray
18271829
Reordered ``codes``; returned when ``codes`` is not None.
@@ -1839,6 +1841,7 @@ def safe_sort(
18391841
raise TypeError(
18401842
"Only list-like objects are allowed to be passed to safe_sort as values"
18411843
)
1844+
original_values = values
18421845

18431846
if not isinstance(values, (np.ndarray, ABCExtensionArray)):
18441847
# don't convert to string types
@@ -1850,6 +1853,7 @@ def safe_sort(
18501853
values = np.asarray(values, dtype=dtype) # type: ignore[arg-type]
18511854

18521855
sorter = None
1856+
ordered: np.ndarray | MultiIndex
18531857

18541858
if (
18551859
not is_extension_array_dtype(values)
@@ -1865,7 +1869,7 @@ def safe_sort(
18651869
# which would work, but which fails for special case of 1d arrays
18661870
# with tuples.
18671871
if values.size and isinstance(values[0], tuple):
1868-
ordered = _sort_tuples(values)
1872+
ordered = _sort_tuples(values, original_values)
18691873
else:
18701874
ordered = _sort_mixed(values)
18711875

@@ -1927,19 +1931,33 @@ def _sort_mixed(values) -> np.ndarray:
19271931
)
19281932

19291933

1930-
def _sort_tuples(values: np.ndarray) -> np.ndarray:
1934+
@overload
1935+
def _sort_tuples(values: np.ndarray, original_values: np.ndarray) -> np.ndarray:
1936+
...
1937+
1938+
1939+
@overload
1940+
def _sort_tuples(values: np.ndarray, original_values: MultiIndex) -> MultiIndex:
1941+
...
1942+
1943+
1944+
def _sort_tuples(
1945+
values: np.ndarray, original_values: np.ndarray | MultiIndex
1946+
) -> np.ndarray | MultiIndex:
19311947
"""
19321948
Convert array of tuples (1d) to array or array (2d).
19331949
We need to keep the columns separately as they contain different types and
19341950
nans (can't use `np.sort` as it may fail when str and nan are mixed in a
19351951
column as types cannot be compared).
1952+
We have to apply the indexer to the original values to keep the dtypes in
1953+
case of MultiIndexes
19361954
"""
19371955
from pandas.core.internals.construction import to_arrays
19381956
from pandas.core.sorting import lexsort_indexer
19391957

19401958
arrays, _ = to_arrays(values, None)
19411959
indexer = lexsort_indexer(arrays, orders=True)
1942-
return values[indexer]
1960+
return original_values[indexer]
19431961

19441962

19451963
def union_with_duplicates(lvals: ArrayLike, rvals: ArrayLike) -> ArrayLike:

pandas/tests/test_sorting.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
)
1212

1313
from pandas import (
14+
NA,
1415
DataFrame,
1516
MultiIndex,
1617
Series,
@@ -510,3 +511,15 @@ def test_mixed_str_nan():
510511
result = safe_sort(values)
511512
expected = np.array([np.nan, "a", "b", "b"], dtype=object)
512513
tm.assert_numpy_array_equal(result, expected)
514+
515+
516+
def test_safe_sort_multiindex():
517+
# GH#48412
518+
arr1 = Series([2, 1, NA, NA], dtype="Int64")
519+
arr2 = [2, 1, 3, 3]
520+
midx = MultiIndex.from_arrays([arr1, arr2])
521+
result = safe_sort(midx)
522+
expected = MultiIndex.from_arrays(
523+
[Series([1, 2, NA, NA], dtype="Int64"), [1, 2, 3, 3]]
524+
)
525+
tm.assert_index_equal(result, expected)

0 commit comments

Comments
 (0)