Skip to content

Commit c85bbc6

Browse files
authored
BUG: Logical and comparison ops with ArrowDtype & masked (#52633)
1 parent f351f74 commit c85bbc6

File tree

3 files changed

+41
-1
lines changed

3 files changed

+41
-1
lines changed

doc/source/whatsnew/v2.0.1.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ Bug fixes
3434
- Bug in :meth:`DataFrame.max` and related methods always casting different :class:`Timestamp` resolutions to nanoseconds (:issue:`52524`)
3535
- Bug in :meth:`Series.describe` not returning :class:`ArrowDtype` with ``pyarrow.float64`` type with numeric data (:issue:`52427`)
3636
- Bug in :meth:`Series.dt.tz_localize` incorrectly localizing timestamps with :class:`ArrowDtype` (:issue:`52677`)
37+
- Bug in logical and comparison operations between :class:`ArrowDtype` and NumPy-backed masked dtypes (e.g. ``"boolean"``) (:issue:`52625`)
3738
- Fixed bug in :func:`merge` when merging with ``ArrowDtype`` on one side and a NumPy dtype on the other side (:issue:`52406`)
3839
- Fixed segfault in :meth:`Series.to_numpy` with ``null[pyarrow]`` dtype (:issue:`52443`)
3940

pandas/core/arrays/arrow/array.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
ExtensionArray,
5656
ExtensionArraySupportsAnyAll,
5757
)
58+
from pandas.core.arrays.masked import BaseMaskedArray
5859
from pandas.core.arrays.string_ import StringDtype
5960
import pandas.core.common as com
6061
from pandas.core.indexers import (
@@ -450,6 +451,9 @@ def _cmp_method(self, other, op):
450451
result = pc_func(self._pa_array, other._pa_array)
451452
elif isinstance(other, (np.ndarray, list)):
452453
result = pc_func(self._pa_array, other)
454+
elif isinstance(other, BaseMaskedArray):
455+
# GH 52625
456+
result = pc_func(self._pa_array, other.__arrow_array__())
453457
elif is_scalar(other):
454458
try:
455459
result = pc_func(self._pa_array, pa.scalar(other))
@@ -497,6 +501,9 @@ def _evaluate_op_method(self, other, op, arrow_funcs):
497501
result = pc_func(self._pa_array, other._pa_array)
498502
elif isinstance(other, (np.ndarray, list)):
499503
result = pc_func(self._pa_array, pa.array(other, from_pandas=True))
504+
elif isinstance(other, BaseMaskedArray):
505+
# GH 52625
506+
result = pc_func(self._pa_array, other.__arrow_array__())
500507
elif is_scalar(other):
501508
if isna(other) and op.__name__ in ARROW_LOGICAL_FUNCS:
502509
# pyarrow kleene ops require null to be typed

pandas/tests/extension/test_arrow.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
BytesIO,
2222
StringIO,
2323
)
24+
import operator
2425
import pickle
2526
import re
2627

@@ -1216,7 +1217,7 @@ def test_add_series_with_extension_array(self, data, request):
12161217

12171218

12181219
class TestBaseComparisonOps(base.BaseComparisonOpsTests):
1219-
def test_compare_array(self, data, comparison_op, na_value, request):
1220+
def test_compare_array(self, data, comparison_op, na_value):
12201221
ser = pd.Series(data)
12211222
# pd.Series([ser.iloc[0]] * len(ser)) may not return ArrowExtensionArray
12221223
# since ser.iloc[0] is a python scalar
@@ -1255,6 +1256,20 @@ def test_invalid_other_comp(self, data, comparison_op):
12551256
):
12561257
comparison_op(data, object())
12571258

1259+
@pytest.mark.parametrize("masked_dtype", ["boolean", "Int64", "Float64"])
1260+
def test_comp_masked_numpy(self, masked_dtype, comparison_op):
1261+
# GH 52625
1262+
data = [1, 0, None]
1263+
ser_masked = pd.Series(data, dtype=masked_dtype)
1264+
ser_pa = pd.Series(data, dtype=f"{masked_dtype.lower()}[pyarrow]")
1265+
result = comparison_op(ser_pa, ser_masked)
1266+
if comparison_op in [operator.lt, operator.gt, operator.ne]:
1267+
exp = [False, False, None]
1268+
else:
1269+
exp = [True, True, None]
1270+
expected = pd.Series(exp, dtype=ArrowDtype(pa.bool_()))
1271+
tm.assert_series_equal(result, expected)
1272+
12581273

12591274
class TestLogicalOps:
12601275
"""Various Series and DataFrame logical ops methods."""
@@ -1399,6 +1414,23 @@ def test_kleene_xor_scalar(self, other, expected):
13991414
a, pd.Series([True, False, None], dtype="boolean[pyarrow]")
14001415
)
14011416

1417+
@pytest.mark.parametrize(
1418+
"op, exp",
1419+
[
1420+
["__and__", True],
1421+
["__or__", True],
1422+
["__xor__", False],
1423+
],
1424+
)
1425+
def test_logical_masked_numpy(self, op, exp):
1426+
# GH 52625
1427+
data = [True, False, None]
1428+
ser_masked = pd.Series(data, dtype="boolean")
1429+
ser_pa = pd.Series(data, dtype="boolean[pyarrow]")
1430+
result = getattr(ser_pa, op)(ser_masked)
1431+
expected = pd.Series([exp, False, None], dtype=ArrowDtype(pa.bool_()))
1432+
tm.assert_series_equal(result, expected)
1433+
14021434

14031435
def test_arrowdtype_construct_from_string_type_with_unsupported_parameters():
14041436
with pytest.raises(NotImplementedError, match="Passing pyarrow type"):

0 commit comments

Comments
 (0)