Skip to content

API: CategoricalDtype.__eq__ with categories=None stricter #38516

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Dec 22, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions pandas/core/dtypes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,12 +354,10 @@ def __eq__(self, other: Any) -> bool:
elif not (hasattr(other, "ordered") and hasattr(other, "categories")):
return False
elif self.categories is None or other.categories is None:
# We're forced into a suboptimal corner thanks to math and
# backwards compatibility. We require that `CDT(...) == 'category'`
# for all CDTs **including** `CDT(None, ...)`. Therefore, *all*
# CDT(., .) = CDT(None, False) and *all*
# CDT(., .) = CDT(None, True).
return True
# For non-fully-initialized dtypes, these are only equal to
# - the string "categorical" (handled above)
# - other CategoricalDtype with categories=None
return self.categories is other.categories
elif self.ordered or other.ordered:
# At least one has ordered=True; equal if both have ordered=True
# and the same values for categories in the same order.
Expand Down
5 changes: 3 additions & 2 deletions pandas/tests/dtypes/cast/test_infer_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from pandas import (
Categorical,
CategoricalDtype,
Interval,
Period,
Series,
Expand Down Expand Up @@ -153,8 +154,8 @@ def test_infer_dtype_from_scalar_errors():
(np.array([[1.0, 2.0]]), np.float_, False),
(Categorical(list("aabc")), np.object_, False),
(Categorical([1, 2, 3]), np.int64, False),
(Categorical(list("aabc")), "category", True),
(Categorical([1, 2, 3]), "category", True),
(Categorical(list("aabc")), CategoricalDtype(categories=["a", "b", "c"]), True),
(Categorical([1, 2, 3]), CategoricalDtype(categories=[1, 2, 3]), True),
(Timestamp("20160101"), np.object_, False),
(np.datetime64("2016-01-01"), np.dtype("=M8[D]"), False),
(date_range("20160101", periods=3), np.dtype("=M8[ns]"), False),
Expand Down
1 change: 0 additions & 1 deletion pandas/tests/dtypes/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,6 @@ def test_is_complex_dtype():
(pd.CategoricalIndex(["a", "b"]).dtype, CategoricalDtype(["a", "b"])),
(pd.CategoricalIndex(["a", "b"]), CategoricalDtype(["a", "b"])),
(CategoricalDtype(), CategoricalDtype()),
(CategoricalDtype(["a", "b"]), CategoricalDtype()),
(pd.DatetimeIndex([1, 2]), np.dtype("=M8[ns]")),
(pd.DatetimeIndex([1, 2]).dtype, np.dtype("=M8[ns]")),
("<M8[ns]", np.dtype("<M8[ns]")),
Expand Down
21 changes: 19 additions & 2 deletions pandas/tests/dtypes/test_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,10 +834,27 @@ def test_categorical_equality(self, ordered1, ordered2):
c1 = CategoricalDtype(list("abc"), ordered1)
c2 = CategoricalDtype(None, ordered2)
c3 = CategoricalDtype(None, ordered1)
assert c1 == c2
assert c2 == c1
assert c1 != c2
assert c2 != c1
assert c2 == c3

def test_categorical_dtype_equality_requires_categories(self):
# CategoricalDtype with categories=None is *not* equal to
# any fully-initialized CategoricalDtype
first = CategoricalDtype(["a", "b"])
second = CategoricalDtype()
third = CategoricalDtype(ordered=True)

assert second == second
assert third == third

assert first != second
assert second != first
assert first != third
assert third != first
assert second == third
assert third == second

@pytest.mark.parametrize("categories", [list("abc"), None])
@pytest.mark.parametrize("other", ["category", "not a category"])
def test_categorical_equality_strings(self, categories, ordered, other):
Expand Down
14 changes: 11 additions & 3 deletions pandas/tests/reshape/merge/test_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -1622,7 +1622,7 @@ def test_identical(self, left):
merged = pd.merge(left, left, on="X")
result = merged.dtypes.sort_index()
expected = Series(
[CategoricalDtype(), np.dtype("O"), np.dtype("O")],
[CategoricalDtype(categories=["foo", "bar"]), np.dtype("O"), np.dtype("O")],
index=["X", "Y_x", "Y_y"],
)
tm.assert_series_equal(result, expected)
Expand All @@ -1633,7 +1633,11 @@ def test_basic(self, left, right):
merged = pd.merge(left, right, on="X")
result = merged.dtypes.sort_index()
expected = Series(
[CategoricalDtype(), np.dtype("O"), np.dtype("int64")],
[
CategoricalDtype(categories=["foo", "bar"]),
np.dtype("O"),
np.dtype("int64"),
],
index=["X", "Y", "Z"],
)
tm.assert_series_equal(result, expected)
Expand Down Expand Up @@ -1713,7 +1717,11 @@ def test_other_columns(self, left, right):
merged = pd.merge(left, right, on="X")
result = merged.dtypes.sort_index()
expected = Series(
[CategoricalDtype(), np.dtype("O"), CategoricalDtype()],
[
CategoricalDtype(categories=["foo", "bar"]),
np.dtype("O"),
CategoricalDtype(categories=[1, 2]),
],
index=["X", "Y", "Z"],
)
tm.assert_series_equal(result, expected)
Expand Down