Skip to content

Commit d0d8e10

Browse files
[ArrowStringArray] CLN: assorted cleanup (#41306)
1 parent 2331098 commit d0d8e10

File tree

6 files changed

+60
-114
lines changed

6 files changed

+60
-114
lines changed

pandas/core/arrays/string_arrow.py

Lines changed: 24 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,6 @@
88
Sequence,
99
cast,
1010
)
11-
import warnings
1211

1312
import numpy as np
1413

@@ -767,20 +766,13 @@ def _str_map(self, f, na_value=None, dtype: Dtype | None = None):
767766
# -> We don't know the result type. E.g. `.get` can return anything.
768767
return lib.map_infer_mask(arr, f, mask.view("uint8"))
769768

770-
def _str_contains(self, pat, case=True, flags=0, na=np.nan, regex=True):
769+
def _str_contains(self, pat, case=True, flags=0, na=np.nan, regex: bool = True):
771770
if flags:
772771
return super()._str_contains(pat, case, flags, na, regex)
773772

774773
if regex:
775774
# match_substring_regex added in pyarrow 4.0.0
776775
if hasattr(pc, "match_substring_regex") and case:
777-
if re.compile(pat).groups:
778-
warnings.warn(
779-
"This pattern has match groups. To actually get the "
780-
"groups, use str.extract.",
781-
UserWarning,
782-
stacklevel=3,
783-
)
784776
result = pc.match_substring_regex(self._data, pat)
785777
else:
786778
return super()._str_contains(pat, case, flags, na, regex)
@@ -817,67 +809,44 @@ def _str_endswith(self, pat, na=None):
817809
return super()._str_endswith(pat, na)
818810

819811
def _str_isalnum(self):
820-
if hasattr(pc, "utf8_is_alnum"):
821-
result = pc.utf8_is_alnum(self._data)
822-
return BooleanDtype().__from_arrow__(result)
823-
else:
824-
return super()._str_isalnum()
812+
result = pc.utf8_is_alnum(self._data)
813+
return BooleanDtype().__from_arrow__(result)
825814

826815
def _str_isalpha(self):
827-
if hasattr(pc, "utf8_is_alpha"):
828-
result = pc.utf8_is_alpha(self._data)
829-
return BooleanDtype().__from_arrow__(result)
830-
else:
831-
return super()._str_isalpha()
816+
result = pc.utf8_is_alpha(self._data)
817+
return BooleanDtype().__from_arrow__(result)
832818

833819
def _str_isdecimal(self):
834-
if hasattr(pc, "utf8_is_decimal"):
835-
result = pc.utf8_is_decimal(self._data)
836-
return BooleanDtype().__from_arrow__(result)
837-
else:
838-
return super()._str_isdecimal()
820+
result = pc.utf8_is_decimal(self._data)
821+
return BooleanDtype().__from_arrow__(result)
839822

840823
def _str_isdigit(self):
841-
if hasattr(pc, "utf8_is_digit"):
842-
result = pc.utf8_is_digit(self._data)
843-
return BooleanDtype().__from_arrow__(result)
844-
else:
845-
return super()._str_isdigit()
824+
result = pc.utf8_is_digit(self._data)
825+
return BooleanDtype().__from_arrow__(result)
846826

847827
def _str_islower(self):
848-
if hasattr(pc, "utf8_is_lower"):
849-
result = pc.utf8_is_lower(self._data)
850-
return BooleanDtype().__from_arrow__(result)
851-
else:
852-
return super()._str_islower()
828+
result = pc.utf8_is_lower(self._data)
829+
return BooleanDtype().__from_arrow__(result)
853830

854831
def _str_isnumeric(self):
855-
if hasattr(pc, "utf8_is_numeric"):
856-
result = pc.utf8_is_numeric(self._data)
857-
return BooleanDtype().__from_arrow__(result)
858-
else:
859-
return super()._str_isnumeric()
832+
result = pc.utf8_is_numeric(self._data)
833+
return BooleanDtype().__from_arrow__(result)
860834

861835
def _str_isspace(self):
836+
# utf8_is_space added in pyarrow 2.0.0
862837
if hasattr(pc, "utf8_is_space"):
863838
result = pc.utf8_is_space(self._data)
864839
return BooleanDtype().__from_arrow__(result)
865840
else:
866841
return super()._str_isspace()
867842

868843
def _str_istitle(self):
869-
if hasattr(pc, "utf8_is_title"):
870-
result = pc.utf8_is_title(self._data)
871-
return BooleanDtype().__from_arrow__(result)
872-
else:
873-
return super()._str_istitle()
844+
result = pc.utf8_is_title(self._data)
845+
return BooleanDtype().__from_arrow__(result)
874846

875847
def _str_isupper(self):
876-
if hasattr(pc, "utf8_is_upper"):
877-
result = pc.utf8_is_upper(self._data)
878-
return BooleanDtype().__from_arrow__(result)
879-
else:
880-
return super()._str_isupper()
848+
result = pc.utf8_is_upper(self._data)
849+
return BooleanDtype().__from_arrow__(result)
881850

882851
def _str_len(self):
883852
# utf8_length added in pyarrow 4.0.0
@@ -895,27 +864,33 @@ def _str_upper(self):
895864

896865
def _str_strip(self, to_strip=None):
897866
if to_strip is None:
867+
# utf8_trim_whitespace added in pyarrow 4.0.0
898868
if hasattr(pc, "utf8_trim_whitespace"):
899869
return type(self)(pc.utf8_trim_whitespace(self._data))
900870
else:
871+
# utf8_trim added in pyarrow 4.0.0
901872
if hasattr(pc, "utf8_trim"):
902873
return type(self)(pc.utf8_trim(self._data, characters=to_strip))
903874
return super()._str_strip(to_strip)
904875

905876
def _str_lstrip(self, to_strip=None):
906877
if to_strip is None:
878+
# utf8_ltrim_whitespace added in pyarrow 4.0.0
907879
if hasattr(pc, "utf8_ltrim_whitespace"):
908880
return type(self)(pc.utf8_ltrim_whitespace(self._data))
909881
else:
882+
# utf8_ltrim added in pyarrow 4.0.0
910883
if hasattr(pc, "utf8_ltrim"):
911884
return type(self)(pc.utf8_ltrim(self._data, characters=to_strip))
912885
return super()._str_lstrip(to_strip)
913886

914887
def _str_rstrip(self, to_strip=None):
915888
if to_strip is None:
889+
# utf8_rtrim_whitespace added in pyarrow 4.0.0
916890
if hasattr(pc, "utf8_rtrim_whitespace"):
917891
return type(self)(pc.utf8_rtrim_whitespace(self._data))
918892
else:
893+
# utf8_rtrim added in pyarrow 4.0.0
919894
if hasattr(pc, "utf8_rtrim"):
920895
return type(self)(pc.utf8_rtrim(self._data, characters=to_strip))
921896
return super()._str_rstrip(to_strip)

pandas/core/strings/accessor.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -196,8 +196,6 @@ def _validate(data):
196196
-------
197197
dtype : inferred dtype of data
198198
"""
199-
from pandas import StringDtype
200-
201199
if isinstance(data, ABCMultiIndex):
202200
raise AttributeError(
203201
"Can only use .str accessor with Index, not MultiIndex"
@@ -209,10 +207,6 @@ def _validate(data):
209207
values = getattr(data, "values", data) # Series / Index
210208
values = getattr(values, "categories", values) # categorical / normal
211209

212-
# explicitly allow StringDtype
213-
if isinstance(values.dtype, StringDtype):
214-
return "string"
215-
216210
inferred_dtype = lib.infer_dtype(values, skipna=True)
217211

218212
if inferred_dtype not in allowed_types:
@@ -1133,6 +1127,14 @@ def contains(self, pat, case=True, flags=0, na=None, regex=True):
11331127
4 False
11341128
dtype: bool
11351129
"""
1130+
if regex and re.compile(pat).groups:
1131+
warnings.warn(
1132+
"This pattern has match groups. To actually get the "
1133+
"groups, use str.extract.",
1134+
UserWarning,
1135+
stacklevel=3,
1136+
)
1137+
11361138
result = self._data.array._str_contains(pat, case, flags, na, regex)
11371139
return self._wrap_result(result, fill_value=na, returns_string=False)
11381140

pandas/core/strings/object_array.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
Union,
88
)
99
import unicodedata
10-
import warnings
1110

1211
import numpy as np
1312

@@ -115,22 +114,14 @@ def _str_pad(self, width, side="left", fillchar=" "):
115114
raise ValueError("Invalid side")
116115
return self._str_map(f)
117116

118-
def _str_contains(self, pat, case=True, flags=0, na=np.nan, regex=True):
117+
def _str_contains(self, pat, case=True, flags=0, na=np.nan, regex: bool = True):
119118
if regex:
120119
if not case:
121120
flags |= re.IGNORECASE
122121

123-
regex = re.compile(pat, flags=flags)
122+
pat = re.compile(pat, flags=flags)
124123

125-
if regex.groups > 0:
126-
warnings.warn(
127-
"This pattern has match groups. To actually get the "
128-
"groups, use str.extract.",
129-
UserWarning,
130-
stacklevel=3,
131-
)
132-
133-
f = lambda x: regex.search(x) is not None
124+
f = lambda x: pat.search(x) is not None
134125
else:
135126
if case:
136127
f = lambda x: pat in x

pandas/tests/strings/conftest.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import numpy as np
22
import pytest
33

4+
import pandas.util._test_decorators as td
5+
46
from pandas import Series
57
from pandas.core import strings as strings
68

@@ -173,3 +175,24 @@ def any_allowed_skipna_inferred_dtype(request):
173175

174176
# correctness of inference tested in tests/dtypes/test_inference.py
175177
return inferred_dtype, values
178+
179+
180+
@pytest.fixture(
181+
params=[
182+
"object",
183+
"string",
184+
pytest.param(
185+
"arrow_string", marks=td.skip_if_no("pyarrow", min_version="1.0.0")
186+
),
187+
]
188+
)
189+
def any_string_dtype(request):
190+
"""
191+
Parametrized fixture for string dtypes.
192+
* 'object'
193+
* 'string'
194+
* 'arrow_string'
195+
"""
196+
from pandas.core.arrays.string_arrow import ArrowStringDtype # noqa: F401
197+
198+
return request.param

pandas/tests/strings/test_find_replace.py

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
import numpy as np
55
import pytest
66

7-
import pandas.util._test_decorators as td
8-
97
import pandas as pd
108
from pandas import (
119
Index,
@@ -14,27 +12,6 @@
1412
)
1513

1614

17-
@pytest.fixture(
18-
params=[
19-
"object",
20-
"string",
21-
pytest.param(
22-
"arrow_string", marks=td.skip_if_no("pyarrow", min_version="1.0.0")
23-
),
24-
]
25-
)
26-
def any_string_dtype(request):
27-
"""
28-
Parametrized fixture for string dtypes.
29-
* 'object'
30-
* 'string'
31-
* 'arrow_string'
32-
"""
33-
from pandas.core.arrays.string_arrow import ArrowStringDtype # noqa: F401
34-
35-
return request.param
36-
37-
3815
def test_contains(any_string_dtype):
3916
values = np.array(
4017
["foo", np.nan, "fooommm__foo", "mmm_", "foommm[_]+bar"], dtype=np.object_
@@ -770,6 +747,7 @@ def test_flags_kwarg(any_string_dtype):
770747
result = data.str.count(pat, flags=re.IGNORECASE)
771748
assert result[0] == 1
772749

773-
with tm.assert_produces_warning(UserWarning):
750+
msg = "This pattern has match groups"
751+
with tm.assert_produces_warning(UserWarning, match=msg):
774752
result = data.str.contains(pat, flags=re.IGNORECASE)
775753
assert result[0]

pandas/tests/strings/test_strings.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
import numpy as np
77
import pytest
88

9-
import pandas.util._test_decorators as td
10-
119
from pandas import (
1210
DataFrame,
1311
Index,
@@ -18,27 +16,6 @@
1816
import pandas._testing as tm
1917

2018

21-
@pytest.fixture(
22-
params=[
23-
"object",
24-
"string",
25-
pytest.param(
26-
"arrow_string", marks=td.skip_if_no("pyarrow", min_version="1.0.0")
27-
),
28-
]
29-
)
30-
def any_string_dtype(request):
31-
"""
32-
Parametrized fixture for string dtypes.
33-
* 'object'
34-
* 'string'
35-
* 'arrow_string'
36-
"""
37-
from pandas.core.arrays.string_arrow import ArrowStringDtype # noqa: F401
38-
39-
return request.param
40-
41-
4219
def assert_series_or_index_equal(left, right):
4320
if isinstance(left, Series):
4421
tm.assert_series_equal(left, right)

0 commit comments

Comments
 (0)