Skip to content

Commit f97d61c

Browse files
authored
REF: mix NDArrayBacked into Categorical, PandasArray (#41131)
1 parent c2c2ce5 commit f97d61c

File tree

14 files changed

+124
-162
lines changed

14 files changed

+124
-162
lines changed

pandas/_libs/arrays.pyi

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from typing import Sequence
2+
3+
import numpy as np
4+
5+
from pandas._typing import (
6+
DtypeObj,
7+
Shape,
8+
)
9+
10+
class NDArrayBacked:
11+
_dtype: DtypeObj
12+
_ndarray: np.ndarray
13+
14+
# Initialise the instance with `values` as its `_ndarray` and `dtype` as its `_dtype`.
def __init__(self, values: np.ndarray, dtype: DtypeObj) -> None: ...
15+
16+
@classmethod
17+
def _simple_new(cls, values: np.ndarray, dtype: DtypeObj): ...
18+
19+
# Construct a new array with `values` as its `_ndarray`.
# This should round-trip: self == self._from_backing_data(self._ndarray)
# NOTE(review): presumably the result keeps type(self) and self's dtype — confirm against the Cython implementation.
def _from_backing_data(self, values: np.ndarray): ...
20+
21+
def __setstate__(self, state): ...
22+
23+
def __len__(self) -> int: ...
24+
25+
@property
26+
def shape(self) -> Shape: ...
27+
28+
@property
29+
def ndim(self) -> int: ...
30+
31+
@property
32+
def size(self) -> int: ...
33+
34+
@property
35+
def nbytes(self) -> int: ...
36+
37+
def copy(self): ...
38+
def delete(self, loc, axis=0): ...
39+
def swapaxes(self, axis1, axis2): ...
40+
def repeat(self, repeats: int | Sequence[int], axis: int | None = ...): ...
41+
def reshape(self, *args, **kwargs): ...
42+
def ravel(self, order="C"): ...
43+
44+
@property
45+
def T(self): ...

pandas/_libs/tslibs/dtypes.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ _period_code_map: dict[str, int]
77

88

99
class PeriodDtypeBase:
10+
_dtype_code: int # PeriodDtypeCode
11+
1012
# actually __cinit__
1113
def __new__(self, code: int): ...
1214

pandas/core/arrays/_mixins.py

Lines changed: 3 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,15 @@
1111
import numpy as np
1212

1313
from pandas._libs import lib
14+
from pandas._libs.arrays import NDArrayBacked
1415
from pandas._typing import (
1516
F,
1617
PositionalIndexer2D,
1718
Shape,
1819
type_t,
1920
)
20-
from pandas.compat.numpy import function as nv
2121
from pandas.errors import AbstractMethodError
22-
from pandas.util._decorators import (
23-
cache_readonly,
24-
doc,
25-
)
22+
from pandas.util._decorators import doc
2623
from pandas.util._validators import (
2724
validate_bool_kwarg,
2825
validate_fillna_kwargs,
@@ -69,24 +66,13 @@ def method(self, *args, **kwargs):
6966
return cast(F, method)
7067

7168

72-
class NDArrayBackedExtensionArray(ExtensionArray):
69+
class NDArrayBackedExtensionArray(NDArrayBacked, ExtensionArray):
7370
"""
7471
ExtensionArray that is backed by a single NumPy ndarray.
7572
"""
7673

7774
_ndarray: np.ndarray
7875

79-
def _from_backing_data(
80-
self: NDArrayBackedExtensionArrayT, arr: np.ndarray
81-
) -> NDArrayBackedExtensionArrayT:
82-
"""
83-
Construct a new ExtensionArray `new_array` with `arr` as its _ndarray.
84-
85-
This should round-trip:
86-
self == self._from_backing_data(self._ndarray)
87-
"""
88-
raise AbstractMethodError(self)
89-
9076
def _box_func(self, x):
9177
"""
9278
Wrap numpy type in our dtype.type if necessary.
@@ -142,46 +128,6 @@ def _validate_fill_value(self, fill_value):
142128

143129
# ------------------------------------------------------------------------
144130

145-
# TODO: make this a cache_readonly; for that to work we need to remove
146-
# the _index_data kludge in libreduction
147-
@property
148-
def shape(self) -> Shape:
149-
return self._ndarray.shape
150-
151-
def __len__(self) -> int:
152-
return self.shape[0]
153-
154-
@cache_readonly
155-
def ndim(self) -> int:
156-
return len(self.shape)
157-
158-
@cache_readonly
159-
def size(self) -> int:
160-
return self._ndarray.size
161-
162-
@cache_readonly
163-
def nbytes(self) -> int:
164-
return self._ndarray.nbytes
165-
166-
def reshape(
167-
self: NDArrayBackedExtensionArrayT, *args, **kwargs
168-
) -> NDArrayBackedExtensionArrayT:
169-
new_data = self._ndarray.reshape(*args, **kwargs)
170-
return self._from_backing_data(new_data)
171-
172-
def ravel(
173-
self: NDArrayBackedExtensionArrayT, *args, **kwargs
174-
) -> NDArrayBackedExtensionArrayT:
175-
new_data = self._ndarray.ravel(*args, **kwargs)
176-
return self._from_backing_data(new_data)
177-
178-
@property
179-
def T(self: NDArrayBackedExtensionArrayT) -> NDArrayBackedExtensionArrayT:
180-
new_data = self._ndarray.T
181-
return self._from_backing_data(new_data)
182-
183-
# ------------------------------------------------------------------------
184-
185131
def equals(self, other) -> bool:
186132
if type(self) is not type(other):
187133
return False
@@ -208,24 +154,6 @@ def argmax(self, axis: int = 0, skipna: bool = True): # type:ignore[override]
208154
raise NotImplementedError
209155
return nargminmax(self, "argmax", axis=axis)
210156

211-
def copy(self: NDArrayBackedExtensionArrayT) -> NDArrayBackedExtensionArrayT:
212-
new_data = self._ndarray.copy()
213-
return self._from_backing_data(new_data)
214-
215-
def repeat(
216-
self: NDArrayBackedExtensionArrayT, repeats, axis=None
217-
) -> NDArrayBackedExtensionArrayT:
218-
"""
219-
Repeat elements of an array.
220-
221-
See Also
222-
--------
223-
numpy.ndarray.repeat
224-
"""
225-
nv.validate_repeat((), {"axis": axis})
226-
new_data = self._ndarray.repeat(repeats, axis=axis)
227-
return self._from_backing_data(new_data)
228-
229157
def unique(self: NDArrayBackedExtensionArrayT) -> NDArrayBackedExtensionArrayT:
230158
new_data = unique(self._ndarray)
231159
return self._from_backing_data(new_data)
@@ -418,18 +346,6 @@ def where(
418346
res_values = np.where(mask, self._ndarray, value)
419347
return self._from_backing_data(res_values)
420348

421-
def delete(
422-
self: NDArrayBackedExtensionArrayT, loc, axis: int = 0
423-
) -> NDArrayBackedExtensionArrayT:
424-
res_values = np.delete(self._ndarray, loc, axis=axis)
425-
return self._from_backing_data(res_values)
426-
427-
def swapaxes(
428-
self: NDArrayBackedExtensionArrayT, axis1, axis2
429-
) -> NDArrayBackedExtensionArrayT:
430-
res_values = self._ndarray.swapaxes(axis1, axis2)
431-
return self._from_backing_data(res_values)
432-
433349
# ------------------------------------------------------------------------
434350
# Additional array methods
435351
# These are not part of the EA API, but we implement them because

pandas/core/arrays/categorical.py

Lines changed: 25 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
algos as libalgos,
2828
hashtable as htable,
2929
)
30+
from pandas._libs.arrays import NDArrayBacked
3031
from pandas._libs.lib import no_default
3132
from pandas._typing import (
3233
ArrayLike,
@@ -349,12 +350,13 @@ class Categorical(NDArrayBackedExtensionArray, PandasObject, ObjectStringArrayMi
349350
# For comparisons, so that numpy uses our implementation of the compare
350351
# ops, which raise
351352
__array_priority__ = 1000
352-
_dtype = CategoricalDtype(ordered=False)
353353
# tolist is not actually deprecated, just suppressed in the __dir__
354354
_hidden_attrs = PandasObject._hidden_attrs | frozenset(["tolist"])
355355
_typ = "categorical"
356356
_can_hold_na = True
357357

358+
_dtype: CategoricalDtype
359+
358360
def __init__(
359361
self,
360362
values,
@@ -373,8 +375,9 @@ def __init__(
373375
# infer categories in a factorization step further below
374376

375377
if fastpath:
376-
self._ndarray = coerce_indexer_dtype(values, dtype.categories)
377-
self._dtype = self._dtype.update_dtype(dtype)
378+
codes = coerce_indexer_dtype(values, dtype.categories)
379+
dtype = CategoricalDtype(ordered=False).update_dtype(dtype)
380+
super().__init__(codes, dtype)
378381
return
379382

380383
if not is_list_like(values):
@@ -463,8 +466,11 @@ def __init__(
463466
full_codes[~null_mask] = codes
464467
codes = full_codes
465468

466-
self._dtype = self._dtype.update_dtype(dtype)
467-
self._ndarray = coerce_indexer_dtype(codes, dtype.categories)
469+
dtype = CategoricalDtype(ordered=False).update_dtype(dtype)
470+
arr = coerce_indexer_dtype(codes, dtype.categories)
471+
# error: Argument 1 to "__init__" of "NDArrayBacked" has incompatible
472+
# type "Union[ExtensionArray, ndarray]"; expected "ndarray"
473+
super().__init__(arr, dtype) # type: ignore[arg-type]
468474

469475
@property
470476
def dtype(self) -> CategoricalDtype:
@@ -513,9 +519,7 @@ def astype(self, dtype: Dtype, copy: bool = True) -> ArrayLike:
513519
raise ValueError("Cannot convert float NaN to integer")
514520

515521
elif len(self.codes) == 0 or len(self.categories) == 0:
516-
# error: Incompatible types in assignment (expression has type "ndarray",
517-
# variable has type "Categorical")
518-
result = np.array( # type: ignore[assignment]
522+
result = np.array(
519523
self,
520524
dtype=dtype,
521525
copy=copy,
@@ -533,11 +537,7 @@ def astype(self, dtype: Dtype, copy: bool = True) -> ArrayLike:
533537
msg = f"Cannot cast {self.categories.dtype} dtype to {dtype}"
534538
raise ValueError(msg)
535539

536-
# error: Incompatible types in assignment (expression has type "ndarray",
537-
# variable has type "Categorical")
538-
result = take_nd( # type: ignore[assignment]
539-
new_cats, ensure_platform_int(self._codes)
540-
)
540+
result = take_nd(new_cats, ensure_platform_int(self._codes))
541541

542542
return result
543543

@@ -745,7 +745,7 @@ def categories(self, categories):
745745
"new categories need to have the same number of "
746746
"items as the old categories!"
747747
)
748-
self._dtype = new_dtype
748+
super().__init__(self._ndarray, new_dtype)
749749

750750
@property
751751
def ordered(self) -> Ordered:
@@ -809,7 +809,7 @@ def _set_categories(self, categories, fastpath=False):
809809
"items than the old categories!"
810810
)
811811

812-
self._dtype = new_dtype
812+
super().__init__(self._ndarray, new_dtype)
813813

814814
def _set_dtype(self, dtype: CategoricalDtype) -> Categorical:
815815
"""
@@ -842,7 +842,7 @@ def set_ordered(self, value, inplace=False):
842842
inplace = validate_bool_kwarg(inplace, "inplace")
843843
new_dtype = CategoricalDtype(self.categories, ordered=value)
844844
cat = self if inplace else self.copy()
845-
cat._dtype = new_dtype
845+
NDArrayBacked.__init__(cat, cat._ndarray, new_dtype)
846846
if not inplace:
847847
return cat
848848

@@ -961,12 +961,12 @@ def set_categories(
961961
):
962962
# remove all _codes which are larger and set to -1/NaN
963963
cat._codes[cat._codes >= len(new_dtype.categories)] = -1
964+
codes = cat._codes
964965
else:
965966
codes = recode_for_categories(
966967
cat.codes, cat.categories, new_dtype.categories
967968
)
968-
cat._ndarray = codes
969-
cat._dtype = new_dtype
969+
NDArrayBacked.__init__(cat, codes, new_dtype)
970970

971971
if not inplace:
972972
return cat
@@ -1182,8 +1182,8 @@ def add_categories(self, new_categories, inplace=no_default):
11821182
new_dtype = CategoricalDtype(new_categories, self.ordered)
11831183

11841184
cat = self if inplace else self.copy()
1185-
cat._dtype = new_dtype
1186-
cat._ndarray = coerce_indexer_dtype(cat._ndarray, new_dtype.categories)
1185+
codes = coerce_indexer_dtype(cat._ndarray, new_dtype.categories)
1186+
NDArrayBacked.__init__(cat, codes, new_dtype)
11871187
if not inplace:
11881188
return cat
11891189

@@ -1303,9 +1303,8 @@ def remove_unused_categories(self, inplace=no_default):
13031303
new_dtype = CategoricalDtype._from_fastpath(
13041304
new_categories, ordered=self.ordered
13051305
)
1306-
cat._dtype = new_dtype
1307-
cat._ndarray = coerce_indexer_dtype(inv, new_dtype.categories)
1308-
1306+
new_codes = coerce_indexer_dtype(inv, new_dtype.categories)
1307+
NDArrayBacked.__init__(cat, new_codes, new_dtype)
13091308
if not inplace:
13101309
return cat
13111310

@@ -1484,7 +1483,7 @@ def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs):
14841483
def __setstate__(self, state):
14851484
"""Necessary for making this object picklable"""
14861485
if not isinstance(state, dict):
1487-
raise Exception("invalid pickle state")
1486+
return super().__setstate__(state)
14881487

14891488
if "_dtype" not in state:
14901489
state["_dtype"] = CategoricalDtype(state["_categories"], state["_ordered"])
@@ -1493,8 +1492,7 @@ def __setstate__(self, state):
14931492
# backward compat, changed what is property vs attribute
14941493
state["_ndarray"] = state.pop("_codes")
14951494

1496-
for k, v in state.items():
1497-
setattr(self, k, v)
1495+
super().__setstate__(state)
14981496

14991497
@property
15001498
def nbytes(self) -> int:
@@ -1863,16 +1861,7 @@ def _codes(self) -> np.ndarray:
18631861

18641862
@_codes.setter
18651863
def _codes(self, value: np.ndarray):
1866-
self._ndarray = value
1867-
1868-
def _from_backing_data(self, arr: np.ndarray) -> Categorical:
1869-
assert isinstance(arr, np.ndarray)
1870-
assert arr.dtype == self._ndarray.dtype
1871-
1872-
res = object.__new__(type(self))
1873-
res._ndarray = arr
1874-
res._dtype = self.dtype
1875-
return res
1864+
NDArrayBacked.__init__(self, value, self.dtype)
18761865

18771866
def _box_func(self, i: int):
18781867
if i == -1:

0 commit comments

Comments
 (0)