Skip to content

[ArrowStringArray] CLN: assorted cleanup #41306

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 24 additions & 49 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
Sequence,
cast,
)
import warnings

import numpy as np

Expand Down Expand Up @@ -766,20 +765,13 @@ def _str_map(self, f, na_value=None, dtype: Dtype | None = None):
# -> We don't know the result type. E.g. `.get` can return anything.
return lib.map_infer_mask(arr, f, mask.view("uint8"))

def _str_contains(self, pat, case=True, flags=0, na=np.nan, regex=True):
def _str_contains(self, pat, case=True, flags=0, na=np.nan, regex: bool = True):
if flags:
return super()._str_contains(pat, case, flags, na, regex)

if regex:
# match_substring_regex added in pyarrow 4.0.0
if hasattr(pc, "match_substring_regex") and case:
if re.compile(pat).groups:
warnings.warn(
"This pattern has match groups. To actually get the "
"groups, use str.extract.",
UserWarning,
stacklevel=3,
)
result = pc.match_substring_regex(self._data, pat)
else:
return super()._str_contains(pat, case, flags, na, regex)
Expand Down Expand Up @@ -816,67 +808,44 @@ def _str_endswith(self, pat, na=None):
return super()._str_endswith(pat, na)

def _str_isalnum(self):
if hasattr(pc, "utf8_is_alnum"):
result = pc.utf8_is_alnum(self._data)
return BooleanDtype().__from_arrow__(result)
else:
return super()._str_isalnum()
result = pc.utf8_is_alnum(self._data)
return BooleanDtype().__from_arrow__(result)

def _str_isalpha(self):
if hasattr(pc, "utf8_is_alpha"):
result = pc.utf8_is_alpha(self._data)
return BooleanDtype().__from_arrow__(result)
else:
return super()._str_isalpha()
result = pc.utf8_is_alpha(self._data)
return BooleanDtype().__from_arrow__(result)

def _str_isdecimal(self):
if hasattr(pc, "utf8_is_decimal"):
result = pc.utf8_is_decimal(self._data)
return BooleanDtype().__from_arrow__(result)
else:
return super()._str_isdecimal()
result = pc.utf8_is_decimal(self._data)
return BooleanDtype().__from_arrow__(result)

def _str_isdigit(self):
if hasattr(pc, "utf8_is_digit"):
result = pc.utf8_is_digit(self._data)
return BooleanDtype().__from_arrow__(result)
else:
return super()._str_isdigit()
result = pc.utf8_is_digit(self._data)
return BooleanDtype().__from_arrow__(result)

def _str_islower(self):
if hasattr(pc, "utf8_is_lower"):
result = pc.utf8_is_lower(self._data)
return BooleanDtype().__from_arrow__(result)
else:
return super()._str_islower()
result = pc.utf8_is_lower(self._data)
return BooleanDtype().__from_arrow__(result)

def _str_isnumeric(self):
if hasattr(pc, "utf8_is_numeric"):
result = pc.utf8_is_numeric(self._data)
return BooleanDtype().__from_arrow__(result)
else:
return super()._str_isnumeric()
result = pc.utf8_is_numeric(self._data)
return BooleanDtype().__from_arrow__(result)

def _str_isspace(self):
# utf8_is_space added in pyarrow 2.0.0
if hasattr(pc, "utf8_is_space"):
result = pc.utf8_is_space(self._data)
return BooleanDtype().__from_arrow__(result)
else:
return super()._str_isspace()

def _str_istitle(self):
if hasattr(pc, "utf8_is_title"):
result = pc.utf8_is_title(self._data)
return BooleanDtype().__from_arrow__(result)
else:
return super()._str_istitle()
result = pc.utf8_is_title(self._data)
return BooleanDtype().__from_arrow__(result)

def _str_isupper(self):
if hasattr(pc, "utf8_is_upper"):
result = pc.utf8_is_upper(self._data)
return BooleanDtype().__from_arrow__(result)
else:
return super()._str_isupper()
result = pc.utf8_is_upper(self._data)
return BooleanDtype().__from_arrow__(result)

def _str_lower(self):
return type(self)(pc.utf8_lower(self._data))
Expand All @@ -886,27 +855,33 @@ def _str_upper(self):

def _str_strip(self, to_strip=None):
if to_strip is None:
# utf8_trim_whitespace added in pyarrow 4.0.0
if hasattr(pc, "utf8_trim_whitespace"):
return type(self)(pc.utf8_trim_whitespace(self._data))
else:
# utf8_trim added in pyarrow 4.0.0
if hasattr(pc, "utf8_trim"):
return type(self)(pc.utf8_trim(self._data, characters=to_strip))
return super()._str_strip(to_strip)

def _str_lstrip(self, to_strip=None):
if to_strip is None:
# utf8_ltrim_whitespace added in pyarrow 4.0.0
if hasattr(pc, "utf8_ltrim_whitespace"):
return type(self)(pc.utf8_ltrim_whitespace(self._data))
else:
# utf8_ltrim added in pyarrow 4.0.0
if hasattr(pc, "utf8_ltrim"):
return type(self)(pc.utf8_ltrim(self._data, characters=to_strip))
return super()._str_lstrip(to_strip)

def _str_rstrip(self, to_strip=None):
if to_strip is None:
# utf8_rtrim_whitespace added in pyarrow 4.0.0
if hasattr(pc, "utf8_rtrim_whitespace"):
return type(self)(pc.utf8_rtrim_whitespace(self._data))
else:
# utf8_rtrim added in pyarrow 4.0.0
if hasattr(pc, "utf8_rtrim"):
return type(self)(pc.utf8_rtrim(self._data, characters=to_strip))
return super()._str_rstrip(to_strip)
14 changes: 8 additions & 6 deletions pandas/core/strings/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,6 @@ def _validate(data):
-------
dtype : inferred dtype of data
"""
from pandas import StringDtype

if isinstance(data, ABCMultiIndex):
raise AttributeError(
"Can only use .str accessor with Index, not MultiIndex"
Expand All @@ -208,10 +206,6 @@ def _validate(data):
values = getattr(data, "values", data) # Series / Index
values = getattr(values, "categories", values) # categorical / normal

# explicitly allow StringDtype
if isinstance(values.dtype, StringDtype):
return "string"

inferred_dtype = lib.infer_dtype(values, skipna=True)

if inferred_dtype not in allowed_types:
Expand Down Expand Up @@ -1132,6 +1126,14 @@ def contains(self, pat, case=True, flags=0, na=None, regex=True):
4 False
dtype: bool
"""
if regex and re.compile(pat).groups:
warnings.warn(
"This pattern has match groups. To actually get the "
"groups, use str.extract.",
UserWarning,
stacklevel=3,
)

result = self._data.array._str_contains(pat, case, flags, na, regex)
return self._wrap_result(result, fill_value=na, returns_string=False)

Expand Down
15 changes: 3 additions & 12 deletions pandas/core/strings/object_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
Union,
)
import unicodedata
import warnings

import numpy as np

Expand Down Expand Up @@ -115,22 +114,14 @@ def _str_pad(self, width, side="left", fillchar=" "):
raise ValueError("Invalid side")
return self._str_map(f)

def _str_contains(self, pat, case=True, flags=0, na=np.nan, regex=True):
def _str_contains(self, pat, case=True, flags=0, na=np.nan, regex: bool = True):
if regex:
if not case:
flags |= re.IGNORECASE

regex = re.compile(pat, flags=flags)
pat = re.compile(pat, flags=flags)

if regex.groups > 0:
warnings.warn(
"This pattern has match groups. To actually get the "
"groups, use str.extract.",
UserWarning,
stacklevel=3,
)

f = lambda x: regex.search(x) is not None
f = lambda x: pat.search(x) is not None
else:
if case:
f = lambda x: pat in x
Expand Down
23 changes: 23 additions & 0 deletions pandas/tests/strings/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import numpy as np
import pytest

import pandas.util._test_decorators as td

from pandas import Series
from pandas.core import strings as strings

Expand Down Expand Up @@ -173,3 +175,24 @@ def any_allowed_skipna_inferred_dtype(request):

# correctness of inference tested in tests/dtypes/test_inference.py
return inferred_dtype, values


@pytest.fixture(
params=[
"object",
"string",
pytest.param(
"arrow_string", marks=td.skip_if_no("pyarrow", min_version="1.0.0")
),
]
)
def any_string_dtype(request):
"""
Parametrized fixture for string dtypes.
* 'object'
* 'string'
* 'arrow_string'
"""
from pandas.core.arrays.string_arrow import ArrowStringDtype # noqa: F401

return request.param
26 changes: 2 additions & 24 deletions pandas/tests/strings/test_find_replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
import numpy as np
import pytest

import pandas.util._test_decorators as td

import pandas as pd
from pandas import (
Index,
Expand All @@ -14,27 +12,6 @@
)


@pytest.fixture(
params=[
"object",
"string",
pytest.param(
"arrow_string", marks=td.skip_if_no("pyarrow", min_version="1.0.0")
),
]
)
def any_string_dtype(request):
"""
Parametrized fixture for string dtypes.
* 'object'
* 'string'
* 'arrow_string'
"""
from pandas.core.arrays.string_arrow import ArrowStringDtype # noqa: F401

return request.param


def test_contains(any_string_dtype):
values = np.array(
["foo", np.nan, "fooommm__foo", "mmm_", "foommm[_]+bar"], dtype=np.object_
Expand Down Expand Up @@ -751,6 +728,7 @@ def test_flags_kwarg(any_string_dtype):
result = data.str.count(pat, flags=re.IGNORECASE)
assert result[0] == 1

with tm.assert_produces_warning(UserWarning):
msg = "This pattern has match groups"
with tm.assert_produces_warning(UserWarning, match=msg):
result = data.str.contains(pat, flags=re.IGNORECASE)
assert result[0]
23 changes: 0 additions & 23 deletions pandas/tests/strings/test_strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
import numpy as np
import pytest

import pandas.util._test_decorators as td

from pandas import (
DataFrame,
Index,
Expand All @@ -19,27 +17,6 @@
import pandas._testing as tm


@pytest.fixture(
params=[
"object",
"string",
pytest.param(
"arrow_string", marks=td.skip_if_no("pyarrow", min_version="1.0.0")
),
]
)
def any_string_dtype(request):
"""
Parametrized fixture for string dtypes.
* 'object'
* 'string'
* 'arrow_string'
"""
from pandas.core.arrays.string_arrow import ArrowStringDtype # noqa: F401

return request.param


def assert_series_or_index_equal(left, right):
if isinstance(left, Series):
tm.assert_series_equal(left, right)
Expand Down