Skip to content

API: CategoricalDtype.__eq__ with categories=None stricter #38516

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Dec 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v1.3.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ See :ref:`install.dependencies` and :ref:`install.optional_dependencies` for mor

Other API changes
^^^^^^^^^^^^^^^^^

- Partially initialized :class:`CategoricalDtype` objects (i.e. those with ``categories=None``) will no longer compare as equal to fully initialized dtype objects.
-
-

Expand Down
2 changes: 2 additions & 0 deletions pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
is_scalar,
is_timedelta64_dtype,
needs_i8_conversion,
pandas_dtype,
)
from pandas.core.dtypes.dtypes import CategoricalDtype
from pandas.core.dtypes.generic import ABCIndex, ABCSeries
Expand Down Expand Up @@ -409,6 +410,7 @@ def astype(self, dtype: Dtype, copy: bool = True) -> ArrayLike:
If copy is set to False and dtype is categorical, the original
object is returned.
"""
dtype = pandas_dtype(dtype)
if self.dtype is dtype:
result = self.copy() if copy else self

Expand Down
13 changes: 13 additions & 0 deletions pandas/core/dtypes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,19 @@ def is_dtype_equal(source, target) -> bool:
>>> is_dtype_equal(DatetimeTZDtype(tz="UTC"), "datetime64")
False
"""
if isinstance(target, str):
if not isinstance(source, str):
# GH#38516 ensure we get the same behavior from
# is_dtype_equal(CDT, "category") and CDT == "category"
try:
src = get_dtype(source)
if isinstance(src, ExtensionDtype):
return src == target
except (TypeError, AttributeError):
return False
elif isinstance(source, str):
return is_dtype_equal(target, source)

try:
source = get_dtype(source)
target = get_dtype(target)
Expand Down
10 changes: 4 additions & 6 deletions pandas/core/dtypes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,12 +354,10 @@ def __eq__(self, other: Any) -> bool:
elif not (hasattr(other, "ordered") and hasattr(other, "categories")):
return False
elif self.categories is None or other.categories is None:
# We're forced into a suboptimal corner thanks to math and
# backwards compatibility. We require that `CDT(...) == 'category'`
# for all CDTs **including** `CDT(None, ...)`. Therefore, *all*
# CDT(., .) = CDT(None, False) and *all*
# CDT(., .) = CDT(None, True).
return True
# For non-fully-initialized dtypes, these are only equal to
# - the string "category" (handled above)
# - other CategoricalDtype with categories=None
return self.categories is other.categories
elif self.ordered or other.ordered:
# At least one has ordered=True; equal if both have ordered=True
# and the same values for categories in the same order.
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/arrays/categorical/test_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def test_astype(self, ordered):
expected = np.array(cat)
tm.assert_numpy_array_equal(result, expected)

msg = r"Cannot cast object dtype to <class 'float'>"
msg = r"Cannot cast object dtype to float64"
with pytest.raises(ValueError, match=msg):
cat.astype(float)

Expand Down
1 change: 0 additions & 1 deletion pandas/tests/dtypes/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,6 @@ def test_is_complex_dtype():
(pd.CategoricalIndex(["a", "b"]).dtype, CategoricalDtype(["a", "b"])),
(pd.CategoricalIndex(["a", "b"]), CategoricalDtype(["a", "b"])),
(CategoricalDtype(), CategoricalDtype()),
(CategoricalDtype(["a", "b"]), CategoricalDtype()),
(pd.DatetimeIndex([1, 2]), np.dtype("=M8[ns]")),
(pd.DatetimeIndex([1, 2]).dtype, np.dtype("=M8[ns]")),
("<M8[ns]", np.dtype("<M8[ns]")),
Expand Down
32 changes: 30 additions & 2 deletions pandas/tests/dtypes/test_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,20 @@ def test_hash_vs_equality(self, dtype):
assert hash(dtype) == hash(dtype2)

def test_equality(self, dtype):
assert dtype == "category"
assert is_dtype_equal(dtype, "category")
assert "category" == dtype
assert is_dtype_equal("category", dtype)

assert dtype == CategoricalDtype()
assert is_dtype_equal(dtype, CategoricalDtype())
assert CategoricalDtype() == dtype
assert is_dtype_equal(CategoricalDtype(), dtype)

assert dtype != "foo"
assert not is_dtype_equal(dtype, "foo")
assert "foo" != dtype
assert not is_dtype_equal("foo", dtype)

def test_construction_from_string(self, dtype):
result = CategoricalDtype.construct_from_string("category")
Expand Down Expand Up @@ -834,10 +845,27 @@ def test_categorical_equality(self, ordered1, ordered2):
c1 = CategoricalDtype(list("abc"), ordered1)
c2 = CategoricalDtype(None, ordered2)
c3 = CategoricalDtype(None, ordered1)
assert c1 == c2
assert c2 == c1
assert c1 != c2
assert c2 != c1
assert c2 == c3

def test_categorical_dtype_equality_requires_categories(self):
# CategoricalDtype with categories=None is *not* equal to
# any fully-initialized CategoricalDtype
first = CategoricalDtype(["a", "b"])
second = CategoricalDtype()
third = CategoricalDtype(ordered=True)

assert second == second
assert third == third

assert first != second
assert second != first
assert first != third
assert third != first
assert second == third
assert third == second

@pytest.mark.parametrize("categories", [list("abc"), None])
@pytest.mark.parametrize("other", ["category", "not a category"])
def test_categorical_equality_strings(self, categories, ordered, other):
Expand Down
14 changes: 11 additions & 3 deletions pandas/tests/reshape/merge/test_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -1622,7 +1622,7 @@ def test_identical(self, left):
merged = pd.merge(left, left, on="X")
result = merged.dtypes.sort_index()
expected = Series(
[CategoricalDtype(), np.dtype("O"), np.dtype("O")],
[CategoricalDtype(categories=["foo", "bar"]), np.dtype("O"), np.dtype("O")],
index=["X", "Y_x", "Y_y"],
)
tm.assert_series_equal(result, expected)
Expand All @@ -1633,7 +1633,11 @@ def test_basic(self, left, right):
merged = pd.merge(left, right, on="X")
result = merged.dtypes.sort_index()
expected = Series(
[CategoricalDtype(), np.dtype("O"), np.dtype("int64")],
[
CategoricalDtype(categories=["foo", "bar"]),
np.dtype("O"),
np.dtype("int64"),
],
index=["X", "Y", "Z"],
)
tm.assert_series_equal(result, expected)
Expand Down Expand Up @@ -1713,7 +1717,11 @@ def test_other_columns(self, left, right):
merged = pd.merge(left, right, on="X")
result = merged.dtypes.sort_index()
expected = Series(
[CategoricalDtype(), np.dtype("O"), CategoricalDtype()],
[
CategoricalDtype(categories=["foo", "bar"]),
np.dtype("O"),
CategoricalDtype(categories=[1, 2]),
],
index=["X", "Y", "Z"],
)
tm.assert_series_equal(result, expected)
Expand Down