Skip to content

ENH: allow EADtype to specify _supports_2d #54832

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions pandas/core/dtypes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,33 @@ def index_class(self) -> type_t[Index]:

return Index

@property
def _supports_2d(self) -> bool:
"""
Do ExtensionArrays with this dtype support 2D arrays?

Historically ExtensionArrays were limited to 1D. By returning True here,
authors can indicate that their arrays support 2D instances. This can
improve performance in some cases, particularly operations with `axis=1`.

Arrays that support 2D values should:

- implement Array.reshape
- subclass the Dim2CompatTests in tests.extension.base
- _concat_same_type should support `axis` keyword
- _reduce and reductions should support `axis` keyword
"""
return False

@property
def _can_fast_transpose(self) -> bool:
"""
Is transposing an array with this dtype zero-copy?

Only relevant for cases where _supports_2d is True.
"""
return False


class StorageExtensionDtype(ExtensionDtype):
"""ExtensionDtype that may be backed by more than one implementation."""
Expand Down
8 changes: 1 addition & 7 deletions pandas/core/dtypes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1256,13 +1256,7 @@ def is_1d_only_ea_dtype(dtype: DtypeObj | None) -> bool:
"""
Analogue to is_extension_array_dtype but excluding DatetimeTZDtype.
"""
# Note: if other EA dtypes are ever held in HybridBlock, exclude those
# here too.
# NB: need to check DatetimeTZDtype and not is_datetime64tz_dtype
# to exclude ArrowTimestampUSDtype
return isinstance(dtype, ExtensionDtype) and not isinstance(
dtype, (DatetimeTZDtype, PeriodDtype)
)
return isinstance(dtype, ExtensionDtype) and not dtype._supports_2d


def is_extension_array_dtype(arr_or_dtype) -> bool:
Expand Down
8 changes: 8 additions & 0 deletions pandas/core/dtypes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,8 @@ class CategoricalDtype(PandasExtensionDtype, ExtensionDtype):
base = np.dtype("O")
_metadata = ("categories", "ordered")
_cache_dtypes: dict[str_type, PandasExtensionDtype] = {}
_supports_2d = False
_can_fast_transpose = False

def __init__(self, categories=None, ordered: Ordered = False) -> None:
self._finalize(categories, ordered, fastpath=False)
Expand Down Expand Up @@ -727,6 +729,8 @@ class DatetimeTZDtype(PandasExtensionDtype):
_metadata = ("unit", "tz")
_match = re.compile(r"(datetime64|M8)\[(?P<unit>.+), (?P<tz>.+)\]")
_cache_dtypes: dict[str_type, PandasExtensionDtype] = {}
_supports_2d = True
_can_fast_transpose = True

@property
def na_value(self) -> NaTType:
Expand Down Expand Up @@ -970,6 +974,8 @@ class PeriodDtype(PeriodDtypeBase, PandasExtensionDtype):
_cache_dtypes: dict[BaseOffset, int] = {} # type: ignore[assignment]
__hash__ = PeriodDtypeBase.__hash__
_freq: BaseOffset
_supports_2d = True
_can_fast_transpose = True

def __new__(cls, freq):
"""
Expand Down Expand Up @@ -1432,6 +1438,8 @@ class NumpyEADtype(ExtensionDtype):
"""

_metadata = ("_dtype",)
_supports_2d = False
_can_fast_transpose = False

def __init__(self, dtype: npt.DTypeLike | NumpyEADtype | None) -> None:
if isinstance(dtype, NumpyEADtype):
Expand Down
1 change: 1 addition & 0 deletions pandas/tests/extension/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ class ExtensionTests(
BaseReduceTests,
BaseReshapingTests,
BaseSetitemTests,
Dim2CompatTests,
):
pass

Expand Down
11 changes: 11 additions & 0 deletions pandas/tests/extension/base/dim2.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,17 @@ class Dim2CompatTests:
# Note: these are ONLY for ExtensionArray subclasses that support 2D arrays.
# i.e. not for pyarrow-backed EAs.

@pytest.fixture(autouse=True)
def skip_if_doesnt_support_2d(self, dtype, request):
if not dtype._supports_2d:
node = request.node
# In cases where we are mixed in to ExtensionTests, we only want to
# skip tests that are defined in Dim2CompatTests
test_func = node._obj
if test_func.__qualname__.startswith("Dim2CompatTests"):
# TODO: is there a less hacky way of checking this?
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could the tests in Dim2CompatTests just be skipped if data.dtype._supports_2d?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is skipping if not data.dtype.supports_2d. The acrobatics here are for test classes that subclass ExtensionTests, since we want this fixture to only apply to tests defined in Dim2CompatTests, not all tests in ExtensionTests

pytest.skip("Test is only for EAs that support 2D.")

def test_transpose(self, data):
arr2d = data.repeat(2).reshape(-1, 2)
shape = arr2d.shape
Expand Down