Skip to content

Commit 30a87ca

Browse files
authored
fix Series.split with expand=True (#199)
* fix Series.split with expand=True * align asterisk in split params
1 parent 3d24a9a commit 30a87ca

File tree

5 files changed

+34
-6
lines changed

5 files changed

+34
-6
lines changed

pandas-stubs/core/indexes/base.pyi

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ from typing import (
1212
import numpy as np
1313
from pandas import (
1414
DataFrame,
15+
MultiIndex,
1516
Series,
1617
)
1718
from pandas.core.arrays import ExtensionArray
@@ -58,7 +59,7 @@ class Index(IndexOpsMixin, PandasObject):
5859
tupleize_cols: bool = ...,
5960
): ...
6061
@property
61-
def str(self) -> StringMethods[Index]: ...
62+
def str(self) -> StringMethods[Index, MultiIndex]: ...
6263
@property
6364
def asi8(self) -> np_ndarray_int64: ...
6465
def is_(self, other) -> bool: ...

pandas-stubs/core/series.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -854,7 +854,7 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
854854
) -> Series[S1]: ...
855855
def to_period(self, freq: _str | None = ..., copy: _bool = ...) -> DataFrame: ...
856856
@property
857-
def str(self) -> StringMethods[Series]: ...
857+
def str(self) -> StringMethods[Series, DataFrame]: ...
858858
@property
859859
def dt(self) -> CombinedDatetimelikeProperties: ...
860860
@property

pandas-stubs/core/strings.pyi

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,25 @@ from typing import (
55
Generic,
66
Literal,
77
Sequence,
8+
TypeVar,
89
overload,
910
)
1011

1112
import numpy as np
1213
import pandas as pd
13-
from pandas import Series
14+
from pandas import (
15+
DataFrame,
16+
MultiIndex,
17+
Series,
18+
)
1419
from pandas.core.base import NoNewAttributesMixin
1520

1621
from pandas._typing import T
1722

18-
class StringMethods(NoNewAttributesMixin, Generic[T]):
23+
# The _TS type is what is used for the result of str.split with expand=True
24+
_TS = TypeVar("_TS", DataFrame, MultiIndex)
25+
26+
class StringMethods(NoNewAttributesMixin, Generic[T, _TS]):
1927
def __init__(self, data: T) -> None: ...
2028
def __getitem__(self, key: slice | int) -> T: ...
2129
def __iter__(self) -> T: ...
@@ -44,11 +52,21 @@ class StringMethods(NoNewAttributesMixin, Generic[T]):
4452
na_rep: str | None = ...,
4553
join: Literal["left", "right", "outer", "inner"] = ...,
4654
) -> T: ...
55+
@overload
56+
def split(
57+
self, pat: str = ..., n: int = ..., *, expand: Literal[True], regex: bool = ...
58+
) -> _TS: ...
59+
@overload
4760
def split(
48-
self, pat: str = ..., n: int = ..., expand: bool = ..., *, regex: bool = ...
61+
self, pat: str = ..., n: int = ..., *, expand: bool = ..., regex: bool = ...
4962
) -> T: ...
63+
@overload
64+
def rsplit(
65+
self, pat: str = ..., n: int = ..., *, expand: Literal[True], regex: bool = ...
66+
) -> T: ...
67+
@overload
5068
def rsplit(
51-
self, pat: str = ..., n: int = ..., expand: bool = ..., *, regex: bool = ...
69+
self, pat: str = ..., n: int = ..., *, expand: bool = ..., regex: bool = ...
5270
) -> T: ...
5371
@overload
5472
def partition(self, sep: str = ...) -> pd.DataFrame: ...

tests/test_indexes.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,10 @@ def test_difference_none() -> None:
6868
# https://github.com/pandas-dev/pandas-stubs/issues/17
6969
ind = pd.Index([1, 2, 3])
7070
check(assert_type(ind.difference([1, None]), "pd.Index"), pd.Index, int)
71+
72+
73+
def test_str_split() -> None:
74+
# GH 194
75+
ind = pd.Index(["a-b", "c-d"])
76+
check(assert_type(ind.str.split("-"), pd.Index), pd.Index)
77+
check(assert_type(ind.str.split("-", expand=True), pd.MultiIndex), pd.MultiIndex)

tests/test_series.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -977,6 +977,8 @@ def test_string_accessors():
977977
check(assert_type(s.str.slice(0, 4, 2), pd.Series), pd.Series)
978978
check(assert_type(s.str.slice_replace(0, 2, "XX"), pd.Series), pd.Series)
979979
check(assert_type(s.str.split("a"), pd.Series), pd.Series)
980+
# GH 194
981+
check(assert_type(s.str.split("a", expand=True), pd.DataFrame), pd.DataFrame)
980982
check(assert_type(s.str.startswith("a"), "pd.Series[bool]"), pd.Series, bool)
981983
check(assert_type(s.str.strip(), pd.Series), pd.Series)
982984
check(assert_type(s.str.swapcase(), pd.Series), pd.Series)

0 commit comments

Comments
 (0)