-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
BUG: made behavior of operator equal for CategoricalIndex consistent,… #10637
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -396,6 +396,66 @@ def test_symmetric_diff(self): | |
with tm.assertRaisesRegexp(TypeError, msg): | ||
result = first.sym_diff([1, 2, 3]) | ||
|
||
def test_equals_op(self): | ||
# GH9947, GH10637 | ||
index_a = self.create_index() | ||
if isinstance(index_a, PeriodIndex): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. OK, another way is to override this method in TestPeriodIndex, just FYI (not sure which is better, though). |
||
return | ||
|
||
n = len(index_a) | ||
index_b = index_a[0:-1] | ||
index_c = index_a[0:-1].append(index_a[-2:-1]) | ||
index_d = index_a[0:1] | ||
with tm.assertRaisesRegexp(ValueError, "Lengths must match"): | ||
index_a == index_b | ||
expected1 = np.array([True] * n) | ||
expected2 = np.array([True] * (n - 1) + [False]) | ||
assert_numpy_array_equivalent(index_a == index_a, expected1) | ||
assert_numpy_array_equivalent(index_a == index_c, expected2) | ||
|
||
# test comparisons with numpy arrays | ||
array_a = np.array(index_a) | ||
array_b = np.array(index_a[0:-1]) | ||
array_c = np.array(index_a[0:-1].append(index_a[-2:-1])) | ||
array_d = np.array(index_a[0:1]) | ||
with tm.assertRaisesRegexp(ValueError, "Lengths must match"): | ||
index_a == array_b | ||
assert_numpy_array_equivalent(index_a == array_a, expected1) | ||
assert_numpy_array_equivalent(index_a == array_c, expected2) | ||
|
||
# test comparisons with Series | ||
series_a = Series(array_a) | ||
series_b = Series(array_b) | ||
series_c = Series(array_c) | ||
series_d = Series(array_d) | ||
with tm.assertRaisesRegexp(ValueError, "Lengths must match"): | ||
index_a == series_b | ||
assert_numpy_array_equivalent(index_a == series_a, expected1) | ||
assert_numpy_array_equivalent(index_a == series_c, expected2) | ||
|
||
# cases where length is 1 for one of them | ||
with tm.assertRaisesRegexp(ValueError, "Lengths must match"): | ||
index_a == index_d | ||
with tm.assertRaisesRegexp(ValueError, "Lengths must match"): | ||
index_a == series_d | ||
with tm.assertRaisesRegexp(ValueError, "Lengths must match"): | ||
index_a == array_d | ||
with tm.assertRaisesRegexp(ValueError, "Series lengths must match"): | ||
series_a == series_d | ||
with tm.assertRaisesRegexp(ValueError, "Lengths must match"): | ||
series_a == array_d | ||
|
||
# comparing with a scalar should broadcast; note that we are excluding | ||
# MultiIndex because in this case each item in the index is a tuple of | ||
# length 2, and therefore is considered an array of length 2 in the | ||
# comparison instead of a scalar | ||
if not isinstance(index_a, MultiIndex): | ||
expected3 = np.array([False] * (len(index_a) - 2) + [True, False]) | ||
# assuming the 2nd to last item is unique in the data | ||
item = index_a[-2] | ||
assert_numpy_array_equivalent(index_a == item, expected3) | ||
assert_numpy_array_equivalent(series_a == item, expected3) | ||
|
||
|
||
class TestIndex(Base, tm.TestCase): | ||
_holder = Index | ||
|
@@ -1548,54 +1608,7 @@ def test_groupby(self): | |
exp = {1: [0, 1], 2: [2, 3, 4]} | ||
tm.assert_dict_equal(groups, exp) | ||
|
||
def test_equals_op(self): | ||
# GH9947 | ||
index_a = Index(['foo', 'bar', 'baz']) | ||
index_b = Index(['foo', 'bar', 'baz', 'qux']) | ||
index_c = Index(['foo', 'bar', 'qux']) | ||
index_d = Index(['foo']) | ||
with tm.assertRaisesRegexp(ValueError, "Lengths must match"): | ||
index_a == index_b | ||
assert_numpy_array_equivalent(index_a == index_a, np.array([True, True, True])) | ||
assert_numpy_array_equivalent(index_a == index_c, np.array([True, True, False])) | ||
|
||
# test comparisons with numpy arrays | ||
array_a = np.array(['foo', 'bar', 'baz']) | ||
array_b = np.array(['foo', 'bar', 'baz', 'qux']) | ||
array_c = np.array(['foo', 'bar', 'qux']) | ||
array_d = np.array(['foo']) | ||
with tm.assertRaisesRegexp(ValueError, "Lengths must match"): | ||
index_a == array_b | ||
assert_numpy_array_equivalent(index_a == array_a, np.array([True, True, True])) | ||
assert_numpy_array_equivalent(index_a == array_c, np.array([True, True, False])) | ||
|
||
# test comparisons with Series | ||
series_a = Series(['foo', 'bar', 'baz']) | ||
series_b = Series(['foo', 'bar', 'baz', 'qux']) | ||
series_c = Series(['foo', 'bar', 'qux']) | ||
series_d = Series(['foo']) | ||
with tm.assertRaisesRegexp(ValueError, "Lengths must match"): | ||
index_a == series_b | ||
assert_numpy_array_equivalent(index_a == series_a, np.array([True, True, True])) | ||
assert_numpy_array_equivalent(index_a == series_c, np.array([True, True, False])) | ||
|
||
# cases where length is 1 for one of them | ||
with tm.assertRaisesRegexp(ValueError, "Lengths must match"): | ||
index_a == index_d | ||
with tm.assertRaisesRegexp(ValueError, "Lengths must match"): | ||
index_a == series_d | ||
with tm.assertRaisesRegexp(ValueError, "Lengths must match"): | ||
index_a == array_d | ||
with tm.assertRaisesRegexp(ValueError, "Series lengths must match"): | ||
series_a == series_d | ||
with tm.assertRaisesRegexp(ValueError, "Lengths must match"): | ||
series_a == array_d | ||
|
||
# comparing with scalar should broadcast | ||
assert_numpy_array_equivalent(index_a == 'foo', np.array([True, False, False])) | ||
assert_numpy_array_equivalent(series_a == 'foo', np.array([True, False, False])) | ||
assert_numpy_array_equivalent(array_a == 'foo', np.array([True, False, False])) | ||
|
||
def test_equals_op_multiindex(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe this should move to the above (e.g. comparing a random index to a multi-index)? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, maybe that'd be easier than I expected; I'll give that a shot. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @jreback I'm able to include everything, and it works the same way, except for one wrinkle when testing the "broadcast when comparing with a scalar" rule:
Note that instead of broadcasting, it thinks that the tuple is an array of length 2 rather than a scalar (the inline code example was lost in extraction). That's why, for MultiIndex, the scalar comparison needs special handling. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Instead, include MultiIndex and special-case that test (you can just do an isinstance MultiIndex check to detect it), and put your comment from above there. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sure, sounds good, it's updated. |
||
# GH9785 | ||
# test comparisons of multiindex | ||
from pandas.compat import StringIO | ||
|
@@ -1609,6 +1622,8 @@ def test_equals_op(self): | |
mi3 = MultiIndex.from_tuples([(1, 2), (4, 5), (8, 9)]) | ||
with tm.assertRaisesRegexp(ValueError, "Lengths must match"): | ||
df.index == mi3 | ||
|
||
index_a = Index(['foo', 'bar', 'baz']) | ||
with tm.assertRaisesRegexp(ValueError, "Lengths must match"): | ||
df.index == index_a | ||
assert_numpy_array_equivalent(index_a == mi3, np.array([False, False, False])) | ||
|
@@ -1966,7 +1981,8 @@ def test_equals(self): | |
self.assertTrue((ci1 == ci1.values).all()) | ||
|
||
# invalid comparisons | ||
self.assertRaises(TypeError, lambda : ci1 == Index(['a','b','c'])) | ||
with tm.assertRaisesRegexp(ValueError, "Lengths must match"): | ||
ci1 == Index(['a','b','c']) | ||
self.assertRaises(TypeError, lambda : ci1 == ci2) | ||
self.assertRaises(TypeError, lambda : ci1 == Categorical(ci1.values, ordered=False)) | ||
self.assertRaises(TypeError, lambda : ci1 == Categorical(ci1.values, categories=list('abc'))) | ||
|
@@ -2082,7 +2098,7 @@ def setUp(self): | |
self.setup_indices() | ||
|
||
def create_index(self): | ||
return Float64Index(np.arange(5,dtype='float64')) | ||
return Float64Index(np.arange(5, dtype='float64')) | ||
|
||
def test_repr_roundtrip(self): | ||
for ind in (self.mixed, self.float): | ||
|
@@ -2253,7 +2269,7 @@ def setUp(self): | |
self.setup_indices() | ||
|
||
def create_index(self): | ||
return Int64Index(np.arange(5,dtype='int64')) | ||
return Int64Index(np.arange(5, dtype='int64')) | ||
|
||
def test_too_many_names(self): | ||
def testit(): | ||
|
@@ -2743,7 +2759,7 @@ def setUp(self): | |
self.setup_indices() | ||
|
||
def create_index(self): | ||
return date_range('20130101',periods=5) | ||
return date_range('20130101', periods=5) | ||
|
||
def test_pickle_compat_construction(self): | ||
pass | ||
|
@@ -2936,7 +2952,7 @@ def setUp(self): | |
self.setup_indices() | ||
|
||
def create_index(self): | ||
return pd.to_timedelta(range(5),unit='d') + pd.offsets.Hour(1) | ||
return pd.to_timedelta(range(5), unit='d') + pd.offsets.Hour(1) | ||
|
||
def test_get_loc(self): | ||
idx = pd.to_timedelta(['0 days', '1 days', '2 days']) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you show what hits this case now, i.e. something that would fail with this PR (but passes now)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jreback the unit test at line 1980 below hits this case.
Currently (without this PR), `ci1 == Index(['a','b','c'])` raises a `TypeError`.
With this PR, the same comparison raises `ValueError: Lengths must match`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh ok, so it effectively defers to the regular Index checking for the length comparison. ok. better error message then.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, definitely a better error message; also note that it was a `TypeError` and now it is a `ValueError`, like with other index types.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right, `ValueError` is fine (though the `is_dtype_equal` check will raise a `TypeError` for invalid dtypes, which is correct as well).