Skip to content

Support structural subtyping with ExtensionArray #57634

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v3.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ enhancement2

Other enhancements
^^^^^^^^^^^^^^^^^^
- :class:`ExtensionArray` now supports structural subtyping (:issue:`57633`)
- :func:`DataFrame.to_excel` now raises a ``UserWarning`` when the character count in a cell exceeds Excel's limitation of 32767 characters (:issue:`56954`)
- :func:`read_stata` now returns ``datetime64`` resolutions better matching those natively stored in the stata format (:issue:`55642`)
- :meth:`Styler.set_tooltips` provides an alternative method of storing tooltips by using the title attribute of td elements. (:issue:`56981`)
Expand Down
5 changes: 4 additions & 1 deletion pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
Callable,
ClassVar,
Literal,
Protocol,
cast,
overload,
runtime_checkable,
)
import warnings

Expand Down Expand Up @@ -108,7 +110,8 @@
_extension_array_shared_docs: dict[str, str] = {}


class ExtensionArray:
@runtime_checkable
class ExtensionArray(Protocol):
"""
Abstract base class for custom 1-D array types.

Expand Down
3 changes: 2 additions & 1 deletion pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,8 @@ def _validate(self) -> None:
# Ravel if ndims > 2 b/c no cythonized version available
lib.convert_nans_to_NA(self._ndarray.ravel("K"))
else:
lib.convert_nans_to_NA(self._ndarray)
if self._ndarray.flags["WRITEABLE"]: # TODO: not this
lib.convert_nans_to_NA(self._ndarray)

@classmethod
def _from_sequence(
Expand Down
5 changes: 4 additions & 1 deletion pandas/core/dtypes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
from typing import (
TYPE_CHECKING,
Any,
Protocol,
TypeVar,
cast,
overload,
runtime_checkable,
)

import numpy as np
Expand Down Expand Up @@ -40,7 +42,8 @@
ExtensionDtypeT = TypeVar("ExtensionDtypeT", bound="ExtensionDtype")


class ExtensionDtype:
@runtime_checkable
class ExtensionDtype(Protocol):
"""
A custom data type, to be paired with an ExtensionArray.

Expand Down
4 changes: 4 additions & 0 deletions pandas/tests/extension/array_with_attr/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,7 @@ def _concat_same_type(cls, to_concat):
data = np.concatenate([x.data for x in to_concat])
attr = to_concat[0].attr if len(to_concat) else None
return cls(data, attr)

@property
def nbytes(self):
    """Placeholder ``nbytes`` attribute; always evaluates to ``None``."""
    # NOTE(review): presumably added only so the test array exposes an
    # ``nbytes`` member for the protocol-based isinstance check - confirm.
    return None
9 changes: 9 additions & 0 deletions pandas/tests/extension/date/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import numpy as np

from pandas.compat import PY312

from pandas.core.dtypes.dtypes import register_extension_dtype

from pandas.api.extensions import (
Expand Down Expand Up @@ -125,6 +127,13 @@ def __init__(
def dtype(self) -> ExtensionDtype:
return DateDtype()

@property
def T(self):
    """Defer to the parent class's ``T`` on Python 3.12+, else return ``None``."""
    # Before Python 3.12, protocol isinstance checks probed members with
    # hasattr(obj, "T"), which raises for this array, so older interpreters
    # get a plain ``None`` instead of the inherited property.
    return super().T if PY312 else None

def astype(self, dtype, copy=True):
dtype = pandas_dtype(dtype)

Expand Down
4 changes: 4 additions & 0 deletions pandas/tests/extension/list/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,10 @@ def _concat_same_type(cls, to_concat):
data = np.concatenate([x.data for x in to_concat])
return cls(data)

@property
def nbytes(self):
    """Placeholder ``nbytes`` attribute; always evaluates to ``None``."""
    # NOTE(review): presumably added only so the test array exposes an
    # ``nbytes`` member for the protocol-based isinstance check - confirm.
    return None


def make_data():
# TODO: Use a regular dict. See _NDFrameIndexer._setitem_with_indexer
Expand Down
7 changes: 7 additions & 0 deletions pandas/tests/extension/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@ def astype(self, dtype, copy=True):

return np.array(self, dtype=dtype, copy=copy)

def isna(self):
    """Placeholder ``isna``; performs no work and returns ``None``."""
    # NOTE(review): presumably a stub so the dummy array exposes ``isna``
    # for the protocol-based isinstance check - confirm.
    return None

@property
def nbytes(self):
    """Placeholder ``nbytes`` attribute; always evaluates to ``None``."""
    # NOTE(review): presumably added only so the dummy array exposes an
    # ``nbytes`` member for the protocol-based isinstance check - confirm.
    return None


class TestExtensionArrayDtype:
@pytest.mark.parametrize(
Expand Down
228 changes: 228 additions & 0 deletions pandas/tests/extension/test_protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
from pandas.core.arrays import ExtensionArray


class ProtocolImplementingClass:
    """
    Stand-alone class exposing the ExtensionArray member names.

    The class deliberately has no base class: it is used to verify that an
    object is recognised as an ``ExtensionArray`` purely through the members
    it defines (structural subtyping), not through inheritance.

    Every attribute is set to ``...`` and every method body is ``...`` (so
    each method returns ``None``). Runtime-checkable protocol ``isinstance``
    checks only look at whether members are present, not at their signatures
    or behavior, so no real implementation is needed. The parameter lists are
    nevertheless kept in line with the ``ExtensionArray`` interface for
    readability.
    """

    # ------------------------------------------------------------------
    # Class-level attributes
    _typ = ...
    __pandas_priority__ = ...
    __hash__ = ...

    # ------------------------------------------------------------------
    # Constructors
    @classmethod
    def _from_sequence(cls, scalars, *, dtype, copy):
        ...

    @classmethod
    def _from_scalars(cls, scalars, *, dtype, copy):
        ...

    @classmethod
    def _from_sequence_of_strings(cls, strings, *, dtype, copy):
        ...

    @classmethod
    def _from_factorized(cls, values, original):
        ...

    # ------------------------------------------------------------------
    # Container and comparison dunders
    def __getitem__(self, item):
        ...

    def __setitem__(self, key, value):
        ...

    def __len__(self):
        ...

    def __iter__(self):
        ...

    def __contains__(self, item):
        ...

    def __eq__(self, other):
        ...

    def __ne__(self, other):
        ...

    def to_numpy(self, dtype, copy, na_value):
        ...

    # ------------------------------------------------------------------
    # Required attributes
    @property
    def dtype(self):
        ...

    @property
    def shape(self):
        ...

    @property
    def size(self):
        ...

    @property
    def ndim(self):
        ...

    @property
    def nbytes(self):
        ...

    # ------------------------------------------------------------------
    # Additional methods
    def astype(self, dtype, copy):
        ...

    def isna(self):
        ...

    @property
    def _hasna(self):
        ...

    def _values_for_argsort(self):
        ...

    def argsort(self, *, ascending, kind, na_position, **kwargs):
        ...

    def argmin(self, skipna):
        ...

    def argmax(self, skipna):
        ...

    def interpolate(
        self, *, method, axis, index, limit, limit_direction, limit_area, copy, **kwargs
    ):
        ...

    def _pad_or_backfill(self, *, method, limit, limit_area, copy):
        ...

    def fillna(self, value, method, limit, copy):
        ...

    def dropna(self):
        ...

    def duplicated(self, keep):
        ...

    def shift(self, periods, fill_value):
        ...

    def unique(self):
        ...

    def searchsorted(self, value, side, sorter):
        ...

    def equals(self, other):
        ...

    def isin(self, values):
        ...

    def _values_for_factorize(self):
        ...

    # Parameter is spelled ``use_na_sentinel`` to mirror
    # ``ExtensionArray.factorize`` (it was previously misspelled).
    def factorize(self, use_na_sentinel):
        ...

    def repeat(self, repeats, axis):
        ...

    def take(self, indices, *, allow_fill, fill_value):
        ...

    def copy(self):
        ...

    def view(self, dtype):
        ...

    # ------------------------------------------------------------------
    # Printing
    def __repr__(self):
        ...

    def _get_repr_footer(self):
        ...

    def _repr_2d(self):
        ...

    def _formatter(self, boxed):
        ...

    # ------------------------------------------------------------------
    # Reshaping
    def transpose(self, *axes):
        ...

    @property
    def T(self):
        ...

    def ravel(self, order):
        ...

    @classmethod
    def _concat_same_type(cls, to_concat):
        ...

    @property
    def _can_hold_na(self):
        ...

    # ------------------------------------------------------------------
    # Reductions, hashing and miscellaneous helpers
    def _accumulate(self, name, *, skipna, **kwargs):
        ...

    def _reduce(self, name, *, skipna, keepdims, **kwargs):
        ...

    def _values_for_json(self):
        ...

    def _hash_pandas_object(self, *, encoding, hash_key, categorize):
        ...

    def _explode(self):
        ...

    def tolist(self):
        ...

    def delete(self, loc):
        ...

    def insert(self, loc, item):
        ...

    def _putmask(self, mask, value):
        ...

    def _where(self, mask, value):
        ...

    def _fill_mask_inplace(self, method, limit, mask):
        ...

    def _rank(self, *, axis, method, na_option, ascending, pct):
        ...

    @classmethod
    def _empty(cls, shape, dtype):
        ...

    def _quantile(self, qs, interpolation):
        ...

    def _mode(self, dropna):
        ...

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        ...

    def map(self, mapper, na_action):
        ...

    def _groupby_op(self, *, how, has_dropped_na, min_count, ngroups, ids, **kwargs):
        ...


def test_structural_subtyping_instance_check():
    """A class that merely defines the expected members passes isinstance."""
    # ProtocolImplementingClass does not inherit from ExtensionArray, so this
    # can only succeed through a structural (member-based) check.
    candidate = ProtocolImplementingClass()
    assert isinstance(candidate, ExtensionArray)
7 changes: 7 additions & 0 deletions pandas/tests/frame/methods/test_select_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@ def __getitem__(self, item):
def copy(self):
    """Return this very object; no new array is created."""
    return self

def isna(self):
    """Placeholder ``isna``; performs no work and returns ``None``."""
    # NOTE(review): presumably a stub so the dummy array exposes ``isna``
    # for the protocol-based isinstance check - confirm.
    return None

@property
def nbytes(self):
    """Placeholder ``nbytes`` attribute; always evaluates to ``None``."""
    # NOTE(review): presumably added only so the dummy array exposes an
    # ``nbytes`` member for the protocol-based isinstance check - confirm.
    return None


class TestSelectDtypes:
def test_select_dtypes_include_using_list_like(self):
Expand Down