Skip to content

Commit a131266

Browse files
authored
ENH: Use MaskedEngine for numeric pyarrow dtypes (#51316)
1 parent 95a087d commit a131266

File tree

5 files changed

+56
-18
lines changed

5 files changed

+56
-18
lines changed

doc/source/whatsnew/v2.0.0.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1117,7 +1117,7 @@ Performance improvements
11171117
- Performance improvement in :meth:`DataFrame.loc` and :meth:`Series.loc` for tuple-based indexing of a :class:`MultiIndex` (:issue:`48384`)
11181118
- Performance improvement for :meth:`Series.replace` with categorical dtype (:issue:`49404`)
11191119
- Performance improvement for :meth:`MultiIndex.unique` (:issue:`48335`)
1120-
- Performance improvement for indexing operations with nullable dtypes (:issue:`49420`)
1120+
- Performance improvement for indexing operations with nullable and arrow dtypes (:issue:`49420`, :issue:`51316`)
11211121
- Performance improvement for :func:`concat` with extension array backed indexes (:issue:`49128`, :issue:`49178`)
11221122
- Performance improvement for :func:`api.types.infer_dtype` (:issue:`51054`)
11231123
- Reduce memory usage of :meth:`DataFrame.to_pickle`/:meth:`Series.to_pickle` when using BZ2 or LZMA (:issue:`49068`)

pandas/_libs/index.pyx

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1141,12 +1141,26 @@ cdef class ExtensionEngine(SharedEngine):
11411141

11421142
cdef class MaskedIndexEngine(IndexEngine):
11431143
def __init__(self, object values):
1144-
super().__init__(values._data)
1145-
self.mask = values._mask
1144+
super().__init__(self._get_data(values))
1145+
self.mask = self._get_mask(values)
1146+
1147+
def _get_data(self, object values) -> np.ndarray:
1148+
if hasattr(values, "_mask"):
1149+
return values._data
1150+
# We are an ArrowExtensionArray
1151+
# Set 1 as na_value to avoid ending up with NA and an object array
1152+
# TODO: Remove when arrow engine is implemented
1153+
return values.to_numpy(na_value=1, dtype=values.dtype.numpy_dtype)
1154+
1155+
def _get_mask(self, object values) -> np.ndarray:
1156+
if hasattr(values, "_mask"):
1157+
return values._mask
1158+
# We are an ArrowExtensionArray
1159+
return values.isna()
11461160

11471161
def get_indexer(self, object values) -> np.ndarray:
11481162
self._ensure_mapping_populated()
1149-
return self.mapping.lookup(values._data, values._mask)
1163+
return self.mapping.lookup(self._get_data(values), self._get_mask(values))
11501164

11511165
def get_indexer_non_unique(self, object targets):
11521166
"""
@@ -1171,8 +1185,8 @@ cdef class MaskedIndexEngine(IndexEngine):
11711185
Py_ssize_t count = 0, count_missing = 0
11721186
Py_ssize_t i, j, n, n_t, n_alloc, start, end, na_idx
11731187

1174-
target_vals = targets._data
1175-
target_mask = targets._mask
1188+
target_vals = self._get_data(targets)
1189+
target_mask = self._get_mask(targets)
11761190

11771191
values = self.values
11781192
assert not values.dtype == object # go through object path instead

pandas/core/arrays/arrow/array.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
TypeVar,
1212
cast,
1313
)
14+
import warnings
1415

1516
import numpy as np
1617

@@ -890,7 +891,10 @@ def to_numpy(
890891
mask = ~self.isna()
891892
result[mask] = np.asarray(self[mask]._data)
892893
else:
893-
result = np.asarray(self._data, dtype=dtype)
894+
with warnings.catch_warnings():
895+
# int dtype with NA raises Warning
896+
warnings.filterwarnings("ignore", category=RuntimeWarning)
897+
result = np.asarray(self._data, dtype=dtype)
894898
if copy or self._hasna:
895899
result = result.copy()
896900
if self._hasna:

pandas/core/indexes/base.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,19 @@
221221
"Int16": libindex.MaskedInt16Engine,
222222
"Int8": libindex.MaskedInt8Engine,
223223
"boolean": libindex.MaskedBoolEngine,
224+
"double[pyarrow]": libindex.MaskedFloat64Engine,
225+
"float64[pyarrow]": libindex.MaskedFloat64Engine,
226+
"float32[pyarrow]": libindex.MaskedFloat32Engine,
227+
"float[pyarrow]": libindex.MaskedFloat32Engine,
228+
"uint64[pyarrow]": libindex.MaskedUInt64Engine,
229+
"uint32[pyarrow]": libindex.MaskedUInt32Engine,
230+
"uint16[pyarrow]": libindex.MaskedUInt16Engine,
231+
"uint8[pyarrow]": libindex.MaskedUInt8Engine,
232+
"int64[pyarrow]": libindex.MaskedInt64Engine,
233+
"int32[pyarrow]": libindex.MaskedInt32Engine,
234+
"int16[pyarrow]": libindex.MaskedInt16Engine,
235+
"int8[pyarrow]": libindex.MaskedInt8Engine,
236+
"bool[pyarrow]": libindex.MaskedBoolEngine,
224237
}
225238

226239

@@ -796,7 +809,7 @@ def _engine(
796809
# For base class (object dtype) we get ObjectEngine
797810
target_values = self._get_engine_target()
798811
if isinstance(target_values, ExtensionArray):
799-
if isinstance(target_values, BaseMaskedArray):
812+
if isinstance(target_values, (BaseMaskedArray, ArrowExtensionArray)):
800813
return _masked_engines[target_values.dtype.name](target_values)
801814
elif self._engine_type is libindex.ObjectEngine:
802815
return libindex.ExtensionEngine(target_values)
@@ -4932,6 +4945,10 @@ def _get_engine_target(self) -> ArrayLike:
49324945
type(self) is Index
49334946
and isinstance(self._values, ExtensionArray)
49344947
and not isinstance(self._values, BaseMaskedArray)
4948+
and not (
4949+
isinstance(self._values, ArrowExtensionArray)
4950+
and is_numeric_dtype(self.dtype)
4951+
)
49354952
):
49364953
# TODO(ExtensionIndex): remove special-case, just use self._values
49374954
return self._values.astype(object)

pandas/tests/indexes/numeric/test_indexing.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -317,26 +317,26 @@ def test_get_indexer_uint64(self, index_large):
317317
tm.assert_numpy_array_equal(indexer, expected)
318318

319319
@pytest.mark.parametrize("val, val2", [(4, 5), (4, 4), (4, NA), (NA, NA)])
320-
def test_get_loc_masked(self, val, val2, any_numeric_ea_dtype):
320+
def test_get_loc_masked(self, val, val2, any_numeric_ea_and_arrow_dtype):
321321
# GH#39133
322-
idx = Index([1, 2, 3, val, val2], dtype=any_numeric_ea_dtype)
322+
idx = Index([1, 2, 3, val, val2], dtype=any_numeric_ea_and_arrow_dtype)
323323
result = idx.get_loc(2)
324324
assert result == 1
325325

326326
with pytest.raises(KeyError, match="9"):
327327
idx.get_loc(9)
328328

329-
def test_get_loc_masked_na(self, any_numeric_ea_dtype):
329+
def test_get_loc_masked_na(self, any_numeric_ea_and_arrow_dtype):
330330
# GH#39133
331-
idx = Index([1, 2, NA], dtype=any_numeric_ea_dtype)
331+
idx = Index([1, 2, NA], dtype=any_numeric_ea_and_arrow_dtype)
332332
result = idx.get_loc(NA)
333333
assert result == 2
334334

335-
idx = Index([1, 2, NA, NA], dtype=any_numeric_ea_dtype)
335+
idx = Index([1, 2, NA, NA], dtype=any_numeric_ea_and_arrow_dtype)
336336
result = idx.get_loc(NA)
337337
tm.assert_numpy_array_equal(result, np.array([False, False, True, True]))
338338

339-
idx = Index([1, 2, 3], dtype=any_numeric_ea_dtype)
339+
idx = Index([1, 2, 3], dtype=any_numeric_ea_and_arrow_dtype)
340340
with pytest.raises(KeyError, match="NA"):
341341
idx.get_loc(NA)
342342

@@ -371,16 +371,19 @@ def test_get_loc_masked_na_and_nan(self):
371371
idx.get_loc(NA)
372372

373373
@pytest.mark.parametrize("val", [4, 2])
374-
def test_get_indexer_masked_na(self, any_numeric_ea_dtype, val):
374+
def test_get_indexer_masked_na(self, any_numeric_ea_and_arrow_dtype, val):
375375
# GH#39133
376-
idx = Index([1, 2, NA, 3, val], dtype=any_numeric_ea_dtype)
376+
idx = Index([1, 2, NA, 3, val], dtype=any_numeric_ea_and_arrow_dtype)
377377
result = idx.get_indexer_for([1, NA, 5])
378378
expected = np.array([0, 2, -1])
379379
tm.assert_numpy_array_equal(result, expected, check_dtype=False)
380380

381-
def test_get_indexer_masked_na_boolean(self):
381+
@pytest.mark.parametrize("dtype", ["boolean", "bool[pyarrow]"])
382+
def test_get_indexer_masked_na_boolean(self, dtype):
382383
# GH#39133
383-
idx = Index([True, False, NA], dtype="boolean")
384+
if dtype == "bool[pyarrow]":
385+
pytest.importorskip("pyarrow")
386+
idx = Index([True, False, NA], dtype=dtype)
384387
result = idx.get_loc(False)
385388
assert result == 1
386389
result = idx.get_loc(NA)

0 commit comments

Comments
 (0)