-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
Implement Arrow String Array that is compatible with NumPy semantics #54533
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 22 commits
b24afc9
b306c6f
2dbcfb0
d9e61e5
3188c25
df231f0
cd19bfb
c73c6b0
6b26309
da6d67c
4be0ee8
6cf2639
0c260fb
8df070a
1e732c9
5b0d24c
3a913e1
a95d2ee
ec56cef
31a59c4
ed77967
d05c51d
aad3d2e
5383eeb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
from __future__ import annotations | ||
|
||
from functools import partial | ||
import re | ||
from typing import ( | ||
TYPE_CHECKING, | ||
|
@@ -27,6 +28,7 @@ | |
) | ||
from pandas.core.dtypes.missing import isna | ||
|
||
from pandas.core.arrays._arrow_string_mixins import ArrowStringArrayMixin | ||
from pandas.core.arrays.arrow import ArrowExtensionArray | ||
from pandas.core.arrays.boolean import BooleanDtype | ||
from pandas.core.arrays.integer import Int64Dtype | ||
|
@@ -113,10 +115,11 @@ class ArrowStringArray(ObjectStringArrayMixin, ArrowExtensionArray, BaseStringAr | |
# error: Incompatible types in assignment (expression has type "StringDtype", | ||
# base class "ArrowExtensionArray" defined the type as "ArrowDtype") | ||
_dtype: StringDtype # type: ignore[assignment] | ||
_storage = "pyarrow" | ||
|
||
def __init__(self, values) -> None: | ||
super().__init__(values) | ||
self._dtype = StringDtype(storage="pyarrow") | ||
self._dtype = StringDtype(storage=self._storage) | ||
|
||
if not pa.types.is_string(self._pa_array.type) and not ( | ||
pa.types.is_dictionary(self._pa_array.type) | ||
|
@@ -144,7 +147,10 @@ def _from_sequence(cls, scalars, dtype: Dtype | None = None, copy: bool = False) | |
|
||
if dtype and not (isinstance(dtype, str) and dtype == "string"): | ||
dtype = pandas_dtype(dtype) | ||
assert isinstance(dtype, StringDtype) and dtype.storage == "pyarrow" | ||
assert isinstance(dtype, StringDtype) and dtype.storage in ( | ||
"pyarrow", | ||
"pyarrow_numpy", | ||
) | ||
|
||
if isinstance(scalars, BaseMaskedArray): | ||
# avoid costly conversion to object dtype in ensure_string_array and | ||
|
@@ -178,6 +184,10 @@ def insert(self, loc: int, item) -> ArrowStringArray: | |
raise TypeError("Scalar must be NA or str") | ||
return super().insert(loc, item) | ||
|
||
@staticmethod | ||
jorisvandenbossche marked this conversation as resolved.
Show resolved
Hide resolved
|
||
def _result_converter(values, **kwargs): | ||
jorisvandenbossche marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return BooleanDtype().__from_arrow__(values) | ||
|
||
def _maybe_convert_setitem_value(self, value): | ||
"""Maybe convert value to be pyarrow compatible.""" | ||
if is_scalar(value): | ||
|
@@ -313,7 +323,7 @@ def _str_contains( | |
result = pc.match_substring_regex(self._pa_array, pat, ignore_case=not case) | ||
else: | ||
result = pc.match_substring(self._pa_array, pat, ignore_case=not case) | ||
result = BooleanDtype().__from_arrow__(result) | ||
result = self._result_converter(result, na=na) | ||
if not isna(na): | ||
result[isna(result)] = bool(na) | ||
return result | ||
|
@@ -322,7 +332,7 @@ def _str_startswith(self, pat: str, na=None): | |
result = pc.starts_with(self._pa_array, pattern=pat) | ||
if not isna(na): | ||
result = result.fill_null(na) | ||
result = BooleanDtype().__from_arrow__(result) | ||
result = self._result_converter(result) | ||
if not isna(na): | ||
result[isna(result)] = bool(na) | ||
return result | ||
|
@@ -331,7 +341,7 @@ def _str_endswith(self, pat: str, na=None): | |
result = pc.ends_with(self._pa_array, pattern=pat) | ||
if not isna(na): | ||
result = result.fill_null(na) | ||
result = BooleanDtype().__from_arrow__(result) | ||
result = self._result_converter(result) | ||
if not isna(na): | ||
result[isna(result)] = bool(na) | ||
return result | ||
|
@@ -369,39 +379,39 @@ def _str_fullmatch( | |
|
||
def _str_isalnum(self): | ||
result = pc.utf8_is_alnum(self._pa_array) | ||
return BooleanDtype().__from_arrow__(result) | ||
return self._result_converter(result) | ||
|
||
def _str_isalpha(self): | ||
result = pc.utf8_is_alpha(self._pa_array) | ||
return BooleanDtype().__from_arrow__(result) | ||
return self._result_converter(result) | ||
|
||
def _str_isdecimal(self): | ||
result = pc.utf8_is_decimal(self._pa_array) | ||
return BooleanDtype().__from_arrow__(result) | ||
return self._result_converter(result) | ||
|
||
def _str_isdigit(self): | ||
result = pc.utf8_is_digit(self._pa_array) | ||
return BooleanDtype().__from_arrow__(result) | ||
return self._result_converter(result) | ||
|
||
def _str_islower(self): | ||
result = pc.utf8_is_lower(self._pa_array) | ||
return BooleanDtype().__from_arrow__(result) | ||
return self._result_converter(result) | ||
|
||
def _str_isnumeric(self): | ||
result = pc.utf8_is_numeric(self._pa_array) | ||
return BooleanDtype().__from_arrow__(result) | ||
return self._result_converter(result) | ||
|
||
def _str_isspace(self): | ||
result = pc.utf8_is_space(self._pa_array) | ||
return BooleanDtype().__from_arrow__(result) | ||
return self._result_converter(result) | ||
|
||
def _str_istitle(self): | ||
result = pc.utf8_is_title(self._pa_array) | ||
return BooleanDtype().__from_arrow__(result) | ||
return self._result_converter(result) | ||
|
||
def _str_isupper(self): | ||
result = pc.utf8_is_upper(self._pa_array) | ||
return BooleanDtype().__from_arrow__(result) | ||
return self._result_converter(result) | ||
|
||
def _str_len(self): | ||
result = pc.utf8_length(self._pa_array) | ||
|
@@ -433,3 +443,114 @@ def _str_rstrip(self, to_strip=None): | |
else: | ||
result = pc.utf8_rtrim(self._pa_array, characters=to_strip) | ||
return type(self)(result) | ||
|
||
|
||
class ArrowStringArrayNumpySemantics(ArrowStringArray): | ||
_storage = "pyarrow_numpy" | ||
|
||
@staticmethod | ||
def _result_converter(values, na=None): | ||
if not isna(na): | ||
values = values.fill_null(bool(na)) | ||
return ArrowExtensionArray(values).to_numpy(na_value=np.nan) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why only For example (using this branch):
If we want that the second example has a proper numpy bool dtype that can be used for boolean indexing, we probably should convert the NaN into False? (either with the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure about that. This would break behaviour compared to object dtype, which is something I tried to avoid (and will break a huge number of tests). I think it makes more sense to keep as is for consistency sake There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, OK, so also the current object-dtype string methods propagate the NaN, wasn't aware of that:
Personally, I think that's something we should change, but that's for another issue then. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah happy to discuss that, but maybe better as a breaking change for 3.0? Don't know though. Let's open an issue about that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Opened #54805 for the question whether string predicate methods like startswith should propagate NaN or not |
||
|
||
def __getattribute__(self, item): | ||
# ArrowStringArray and we both inherit from ArrowExtensionArray, which | ||
# creates inheritance problems (Diamnond inheritance) | ||
jorisvandenbossche marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if item in ArrowStringArrayMixin.__dict__ and item != "_pa_array": | ||
return partial(getattr(ArrowStringArrayMixin, item), self) | ||
return super().__getattribute__(item) | ||
|
||
def _str_map( | ||
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True | ||
): | ||
if dtype is None: | ||
dtype = self.dtype | ||
if na_value is None: | ||
na_value = self.dtype.na_value | ||
|
||
mask = isna(self) | ||
arr = np.asarray(self) | ||
|
||
if is_integer_dtype(dtype) or is_bool_dtype(dtype): | ||
if is_integer_dtype(dtype): | ||
na_value = np.nan | ||
else: | ||
na_value = False | ||
try: | ||
result = lib.map_infer_mask( | ||
arr, | ||
f, | ||
mask.view("uint8"), | ||
convert=False, | ||
na_value=na_value, | ||
dtype=np.dtype(dtype), # type: ignore[arg-type] | ||
) | ||
return result | ||
|
||
except ValueError: | ||
result = lib.map_infer_mask( | ||
arr, | ||
f, | ||
mask.view("uint8"), | ||
convert=False, | ||
na_value=na_value, | ||
) | ||
if convert and result.dtype == object: | ||
result = lib.maybe_convert_objects(result) | ||
return result | ||
|
||
elif is_string_dtype(dtype) and not is_object_dtype(dtype): | ||
# i.e. StringDtype | ||
result = lib.map_infer_mask( | ||
arr, f, mask.view("uint8"), convert=False, na_value=na_value | ||
) | ||
result = pa.array(result, mask=mask, type=pa.string(), from_pandas=True) | ||
return type(self)(result) | ||
else: | ||
# This is when the result type is object. We reach this when | ||
# -> We know the result type is truly object (e.g. .encode returns bytes | ||
# or .findall returns a list). | ||
# -> We don't know the result type. E.g. `.get` can return anything. | ||
return lib.map_infer_mask(arr, f, mask.view("uint8")) | ||
|
||
def _convert_int_dtype(self, result): | ||
if result.dtype == np.int32: | ||
result = result.astype(np.int64) | ||
return result | ||
|
||
def _str_count(self, pat: str, flags: int = 0): | ||
if flags: | ||
return super()._str_count(pat, flags) | ||
result = pc.count_substring_regex(self._pa_array, pat).to_numpy() | ||
return self._convert_int_dtype(result) | ||
|
||
def _str_len(self): | ||
result = pc.utf8_length(self._pa_array).to_numpy() | ||
return self._convert_int_dtype(result) | ||
|
||
def _str_find(self, sub: str, start: int = 0, end: int | None = None): | ||
if start != 0 and end is not None: | ||
slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end) | ||
result = pc.find_substring(slices, sub) | ||
not_found = pc.equal(result, -1) | ||
offset_result = pc.add(result, end - start) | ||
result = pc.if_else(not_found, result, offset_result) | ||
elif start == 0 and end is None: | ||
slices = self._pa_array | ||
result = pc.find_substring(slices, sub) | ||
else: | ||
return super()._str_find(sub, start, end) | ||
return self._convert_int_dtype(result.to_numpy()) | ||
|
||
def _cmp_method(self, other, op): | ||
result = super()._cmp_method(other, op) | ||
return result.to_numpy(np.bool_, na_value=False) | ||
|
||
def value_counts(self, dropna: bool = True): | ||
from pandas import Series | ||
|
||
result = super().value_counts(dropna) | ||
return Series( | ||
result._values.to_numpy(), index=result.index, name=result.name, copy=False | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -500,7 +500,7 @@ def use_inf_as_na_cb(key) -> None: | |
"string_storage", | ||
"python", | ||
string_storage_doc, | ||
validator=is_one_of_factory(["python", "pyarrow"]), | ||
validator=is_one_of_factory(["python", "pyarrow", "pyarrow_numpy"]), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we want the user to allow setting this for the new option as well? Because we also have the new There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd say yes, you might want to astype columns after operations for example There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Probably something to further discuss in a follow up issue, but I would expect that if you opt-in for the future string dtype with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah agree, let's do this as a follow up There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Opened #54793 for this interaction between |
||
) | ||
|
||
|
||
|
Uh oh!
There was an error while loading. Please reload this page.