Skip to content

Commit ad8a4ea

Browse files
REF: move _str_extract function in accessor.py to array method (#41663)
1 parent 104769c commit ad8a4ea

File tree

6 files changed

+52
-51
lines changed

6 files changed

+52
-51
lines changed

pandas/core/arrays/categorical.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2453,7 +2453,9 @@ def replace(self, to_replace, value, inplace: bool = False):
24532453

24542454
# ------------------------------------------------------------------------
24552455
# String methods interface
2456-
def _str_map(self, f, na_value=np.nan, dtype=np.dtype("object")):
2456+
def _str_map(
2457+
self, f, na_value=np.nan, dtype=np.dtype("object"), convert: bool = True
2458+
):
24572459
# Optimization to apply the callable `f` to the categories once
24582460
# and rebuild the result by `take`ing from the result with the codes.
24592461
# Returns the same type as the object-dtype implementation though.

pandas/core/arrays/string_.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,9 @@ def _cmp_method(self, other, op):
410410
# String methods interface
411411
_str_na_value = StringDtype.na_value
412412

413-
def _str_map(self, f, na_value=None, dtype: Dtype | None = None):
413+
def _str_map(
414+
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
415+
):
414416
from pandas.arrays import BooleanArray
415417

416418
if dtype is None:

pandas/core/arrays/string_arrow.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -742,7 +742,9 @@ def value_counts(self, dropna: bool = True) -> Series:
742742

743743
_str_na_value = ArrowStringDtype.na_value
744744

745-
def _str_map(self, f, na_value=None, dtype: Dtype | None = None):
745+
def _str_map(
746+
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
747+
):
746748
# TODO: de-duplicate with StringArray method. This method is moreless copy and
747749
# paste.
748750

pandas/core/strings/accessor.py

Lines changed: 6 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,7 @@
1313
import numpy as np
1414

1515
import pandas._libs.lib as lib
16-
from pandas._typing import (
17-
ArrayLike,
18-
FrameOrSeriesUnion,
19-
)
16+
from pandas._typing import FrameOrSeriesUnion
2017
from pandas.util._decorators import Appender
2118

2219
from pandas.core.dtypes.common import (
@@ -160,7 +157,6 @@ class StringMethods(NoNewAttributesMixin):
160157
# TODO: Dispatch all the methods
161158
# Currently the following are not dispatched to the array
162159
# * cat
163-
# * extract
164160
# * extractall
165161

166162
def __init__(self, data):
@@ -243,7 +239,7 @@ def _wrap_result(
243239
self,
244240
result,
245241
name=None,
246-
expand=None,
242+
expand: bool | None = None,
247243
fill_value=np.nan,
248244
returns_string=True,
249245
):
@@ -2385,10 +2381,7 @@ def extract(
23852381
2 NaN
23862382
dtype: object
23872383
"""
2388-
from pandas import (
2389-
DataFrame,
2390-
array as pd_array,
2391-
)
2384+
from pandas import DataFrame
23922385

23932386
if not isinstance(expand, bool):
23942387
raise ValueError("expand must be True or False")
@@ -2400,8 +2393,6 @@ def extract(
24002393
if not expand and regex.groups > 1 and isinstance(self._data, ABCIndex):
24012394
raise ValueError("only one regex group is supported with Index")
24022395

2403-
# TODO: dispatch
2404-
24052396
obj = self._data
24062397
result_dtype = _result_dtype(obj)
24072398

@@ -2415,8 +2406,8 @@ def extract(
24152406
result = DataFrame(columns=columns, dtype=result_dtype)
24162407

24172408
else:
2418-
result_list = _str_extract(
2419-
obj.array, pat, flags=flags, expand=returns_df
2409+
result_list = self._data.array._str_extract(
2410+
pat, flags=flags, expand=returns_df
24202411
)
24212412

24222413
result_index: Index | None
@@ -2431,9 +2422,7 @@ def extract(
24312422

24322423
else:
24332424
name = _get_single_group_name(regex)
2434-
result_arr = _str_extract(obj.array, pat, flags=flags, expand=returns_df)
2435-
# not dispatching, so we have to reconstruct here.
2436-
result = pd_array(result_arr, dtype=result_dtype)
2425+
result = self._data.array._str_extract(pat, flags=flags, expand=returns_df)
24372426
return self._wrap_result(result, name=name)
24382427

24392428
@forbid_nonstring_types(["bytes"])
@@ -3121,33 +3110,6 @@ def _get_group_names(regex: re.Pattern) -> list[Hashable]:
31213110
return [names.get(1 + i, i) for i in range(regex.groups)]
31223111

31233112

3124-
def _str_extract(arr: ArrayLike, pat: str, flags=0, expand: bool = True):
3125-
"""
3126-
Find groups in each string in the array using passed regular expression.
3127-
3128-
Returns
3129-
-------
3130-
np.ndarray or list of lists is expand is True
3131-
"""
3132-
regex = re.compile(pat, flags=flags)
3133-
3134-
empty_row = [np.nan] * regex.groups
3135-
3136-
def f(x):
3137-
if not isinstance(x, str):
3138-
return empty_row
3139-
m = regex.search(x)
3140-
if m:
3141-
return [np.nan if item is None else item for item in m.groups()]
3142-
else:
3143-
return empty_row
3144-
3145-
if expand:
3146-
return [f(val) for val in np.asarray(arr)]
3147-
3148-
return np.array([f(val)[0] for val in np.asarray(arr)], dtype=object)
3149-
3150-
31513113
def str_extractall(arr, pat, flags=0):
31523114
regex = re.compile(pat, flags=flags)
31533115
# the regex must contain capture groups.

pandas/core/strings/base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,3 +230,7 @@ def _str_split(self, pat=None, n=-1, expand=False):
230230
@abc.abstractmethod
231231
def _str_rsplit(self, pat=None, n=-1):
232232
pass
233+
234+
@abc.abstractmethod
235+
def _str_extract(self, pat: str, flags: int = 0, expand: bool = True):
236+
pass

pandas/core/strings/object_array.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ def __len__(self):
3232
# For typing, _str_map relies on the object being sized.
3333
raise NotImplementedError
3434

35-
def _str_map(self, f, na_value=None, dtype: Dtype | None = None):
35+
def _str_map(
36+
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
37+
):
3638
"""
3739
Map a callable over valid element of the array.
3840
@@ -47,6 +49,8 @@ def _str_map(self, f, na_value=None, dtype: Dtype | None = None):
4749
for object-dtype and Categorical and ``pd.NA`` for StringArray.
4850
dtype : Dtype, optional
4951
The dtype of the result array.
52+
convert : bool, default True
53+
Whether to call `maybe_convert_objects` on the resulting ndarray
5054
"""
5155
if dtype is None:
5256
dtype = np.dtype("object")
@@ -60,9 +64,9 @@ def _str_map(self, f, na_value=None, dtype: Dtype | None = None):
6064

6165
arr = np.asarray(self, dtype=object)
6266
mask = isna(arr)
63-
convert = not np.all(mask)
67+
map_convert = convert and not np.all(mask)
6468
try:
65-
result = lib.map_infer_mask(arr, f, mask.view(np.uint8), convert)
69+
result = lib.map_infer_mask(arr, f, mask.view(np.uint8), map_convert)
6670
except (TypeError, AttributeError) as e:
6771
# Reraise the exception if callable `f` got wrong number of args.
6872
# The user may want to be warned by this, instead of getting NaN
@@ -88,7 +92,7 @@ def g(x):
8892
return result
8993
if na_value is not np.nan:
9094
np.putmask(result, mask, na_value)
91-
if result.dtype == object:
95+
if convert and result.dtype == object:
9296
result = lib.maybe_convert_objects(result)
9397
return result
9498

@@ -410,3 +414,28 @@ def _str_lstrip(self, to_strip=None):
410414

411415
def _str_rstrip(self, to_strip=None):
412416
return self._str_map(lambda x: x.rstrip(to_strip))
417+
418+
def _str_extract(self, pat: str, flags: int = 0, expand: bool = True):
419+
regex = re.compile(pat, flags=flags)
420+
na_value = self._str_na_value
421+
422+
if not expand:
423+
424+
def g(x):
425+
m = regex.search(x)
426+
return m.groups()[0] if m else na_value
427+
428+
return self._str_map(g, convert=False)
429+
430+
empty_row = [na_value] * regex.groups
431+
432+
def f(x):
433+
if not isinstance(x, str):
434+
return empty_row
435+
m = regex.search(x)
436+
if m:
437+
return [na_value if item is None else item for item in m.groups()]
438+
else:
439+
return empty_row
440+
441+
return [f(val) for val in np.asarray(self)]

0 commit comments

Comments
 (0)