Skip to content

Commit 5933c60

Browse files
authored
Backport PR #55362 on branch 2.1.x (BUG: rank raising for arrow string dtypes) (#55406)
1 parent 2a32088 commit 5933c60

File tree

4 files changed

+68
-6
lines changed

4 files changed

+68
-6
lines changed

doc/source/whatsnew/v2.1.2.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ Bug fixes
2727
- Fixed bug in :meth:`DataFrame.interpolate` raising incorrect error message (:issue:`55347`)
2828
- Fixed bug in :meth:`Index.insert` raising when inserting ``None`` into :class:`Index` with ``dtype="string[pyarrow_numpy]"`` (:issue:`55365`)
2929
- Fixed bug in :meth:`Series.all` and :meth:`Series.any` not treating missing values correctly for ``dtype="string[pyarrow_numpy]"`` (:issue:`55367`)
30+
- Fixed bug in :meth:`Series.rank` for ``string[pyarrow_numpy]`` dtype (:issue:`55362`)
3031
- Silence ``Period[B]`` warnings introduced by :issue:`53446` during normal plotting activity (:issue:`55138`)
3132
-
3233

pandas/core/arrays/arrow/array.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1712,7 +1712,7 @@ def __setitem__(self, key, value) -> None:
17121712
data = pa.chunked_array([data])
17131713
self._pa_array = data
17141714

1715-
def _rank(
1715+
def _rank_calc(
17161716
self,
17171717
*,
17181718
axis: AxisInt = 0,
@@ -1721,9 +1721,6 @@ def _rank(
17211721
ascending: bool = True,
17221722
pct: bool = False,
17231723
):
1724-
"""
1725-
See Series.rank.__doc__.
1726-
"""
17271724
if pa_version_under9p0 or axis != 0:
17281725
ranked = super()._rank(
17291726
axis=axis,
@@ -1738,7 +1735,7 @@ def _rank(
17381735
else:
17391736
pa_type = pa.uint64()
17401737
result = pa.array(ranked, type=pa_type, from_pandas=True)
1741-
return type(self)(result)
1738+
return result
17421739

17431740
data = self._pa_array.combine_chunks()
17441741
sort_keys = "ascending" if ascending else "descending"
@@ -1777,7 +1774,29 @@ def _rank(
17771774
divisor = pc.count(result)
17781775
result = pc.divide(result, divisor)
17791776

1780-
return type(self)(result)
1777+
return result
1778+
1779+
def _rank(
1780+
self,
1781+
*,
1782+
axis: AxisInt = 0,
1783+
method: str = "average",
1784+
na_option: str = "keep",
1785+
ascending: bool = True,
1786+
pct: bool = False,
1787+
):
1788+
"""
1789+
See Series.rank.__doc__.
1790+
"""
1791+
return type(self)(
1792+
self._rank_calc(
1793+
axis=axis,
1794+
method=method,
1795+
na_option=na_option,
1796+
ascending=ascending,
1797+
pct=pct,
1798+
)
1799+
)
17811800

17821801
def _quantile(self, qs: npt.NDArray[np.float64], interpolation: str) -> Self:
17831802
"""

pandas/core/arrays/string_arrow.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848

4949
if TYPE_CHECKING:
5050
from pandas._typing import (
51+
AxisInt,
5152
Dtype,
5253
Scalar,
5354
npt,
@@ -444,6 +445,31 @@ def _str_rstrip(self, to_strip=None):
444445
result = pc.utf8_rtrim(self._pa_array, characters=to_strip)
445446
return type(self)(result)
446447

448+
def _convert_int_dtype(self, result):
449+
return Int64Dtype().__from_arrow__(result)
450+
451+
def _rank(
452+
self,
453+
*,
454+
axis: AxisInt = 0,
455+
method: str = "average",
456+
na_option: str = "keep",
457+
ascending: bool = True,
458+
pct: bool = False,
459+
):
460+
"""
461+
See Series.rank.__doc__.
462+
"""
463+
return self._convert_int_dtype(
464+
self._rank_calc(
465+
axis=axis,
466+
method=method,
467+
na_option=na_option,
468+
ascending=ascending,
469+
pct=pct,
470+
)
471+
)
472+
447473

448474
class ArrowStringArrayNumpySemantics(ArrowStringArray):
449475
_storage = "pyarrow_numpy"
@@ -527,6 +553,10 @@ def _str_map(
527553
return lib.map_infer_mask(arr, f, mask.view("uint8"))
528554

529555
def _convert_int_dtype(self, result):
556+
if isinstance(result, pa.Array):
557+
result = result.to_numpy(zero_copy_only=False)
558+
elif not isinstance(result, np.ndarray):
559+
result = result.to_numpy()
530560
if result.dtype == np.int32:
531561
result = result.astype(np.int64)
532562
return result

pandas/tests/frame/methods/test_rank.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -488,3 +488,15 @@ def test_rank_mixed_axis_zero(self, data, expected):
488488
df.rank()
489489
result = df.rank(numeric_only=True)
490490
tm.assert_frame_equal(result, expected)
491+
492+
@pytest.mark.parametrize(
493+
"dtype, exp_dtype",
494+
[("string[pyarrow]", "Int64"), ("string[pyarrow_numpy]", "float64")],
495+
)
496+
def test_rank_string_dtype(self, dtype, exp_dtype):
497+
# GH#55362
498+
pytest.importorskip("pyarrow")
499+
obj = Series(["foo", "foo", None, "foo"], dtype=dtype)
500+
result = obj.rank(method="first")
501+
expected = Series([1, 2, None, 3], dtype=exp_dtype)
502+
tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)