Skip to content

Commit 4a5ebea

Browse files
committed
more tests & fixes for non-unique / overlaps
rename _is_contained_in -> contains add sorting test
1 parent 340c98b commit 4a5ebea

File tree

10 files changed

+284
-87
lines changed

10 files changed

+284
-87
lines changed

pandas/core/indexing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1429,7 +1429,7 @@ def error():
14291429

14301430
try:
14311431
key = self._convert_scalar_indexer(key, axis)
1432-
if not ax._is_contained_in(key):
1432+
if not ax.contains(key):
14331433
error()
14341434
except TypeError as e:
14351435

@@ -1897,7 +1897,7 @@ def convert_to_index_sliceable(obj, key):
18971897
elif isinstance(key, compat.string_types):
18981898

18991899
# we are an actual column
1900-
if obj._data.items._is_contained_in(key):
1900+
if obj._data.items.contains(key):
19011901
return None
19021902

19031903
# We might have a datetimelike string that we can translate to a

pandas/indexes/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1585,7 +1585,7 @@ def __contains__(self, key):
15851585
except TypeError:
15861586
return False
15871587

1588-
_index_shared_docs['_is_contained_in'] = """
1588+
_index_shared_docs['contains'] = """
15891589
return a boolean if this key is IN the index
15901590
15911591
Parameters
@@ -1597,8 +1597,8 @@ def __contains__(self, key):
15971597
boolean
15981598
"""
15991599

1600-
@Appender(_index_shared_docs['_is_contained_in'] % _index_doc_kwargs)
1601-
def _is_contained_in(self, key):
1600+
@Appender(_index_shared_docs['contains'] % _index_doc_kwargs)
1601+
def contains(self, key):
16021602
hash(key)
16031603
try:
16041604
return key in self._engine

pandas/indexes/category.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -271,12 +271,12 @@ def __contains__(self, key):
271271

272272
return key in self.values
273273

274-
@Appender(_index_shared_docs['_is_contained_in'] % _index_doc_kwargs)
275-
def _is_contained_in(self, key):
274+
@Appender(_index_shared_docs['contains'] % _index_doc_kwargs)
275+
def contains(self, key):
276276
hash(key)
277277

278278
if self.categories._defer_to_indexing:
279-
return self.categories._is_contained_in(key)
279+
return self.categories.contains(key)
280280

281281
return key in self.values
282282

pandas/indexes/interval.py

Lines changed: 93 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def __contains__(self, key):
263263
except KeyError:
264264
return False
265265

266-
def _is_contained_in(self, key):
266+
def contains(self, key):
267267
"""
268268
return a boolean if this key is IN the index
269269
@@ -566,20 +566,31 @@ def _convert_list_indexer(self, keyarr, kind=None):
566566
indexer for matching intervals.
567567
"""
568568
locs = self.get_indexer_for(keyarr)
569-
check = locs == -1
570-
locs = locs[~check]
569+
570+
# we have missing values
571+
if (locs == -1).any():
572+
raise KeyError
573+
571574
return locs
572575

573576
def _maybe_cast_indexed(self, key):
574577
"""
575578
we need to cast the key, which could be a scalar
576579
or an array-like to the type of our subtype
577580
"""
578-
if is_float_dtype(self.dtype.subtype):
581+
if isinstance(key, IntervalIndex):
582+
return key
583+
584+
subtype = self.dtype.subtype
585+
if is_float_dtype(subtype):
579586
if is_integer(key):
580587
key = float(key)
581588
elif isinstance(key, (np.ndarray, Index)):
582589
key = key.astype('float64')
590+
elif is_integer_dtype(subtype):
591+
if is_integer(key):
592+
key = int(key)
593+
583594
return key
584595

585596
def _check_method(self, method):
@@ -616,6 +627,11 @@ def _searchsorted_monotonic(self, label, side, exclude_label=False):
616627

617628
def _get_loc_only_exact_matches(self, key):
618629
if isinstance(key, Interval):
630+
631+
if not self.is_unique:
632+
raise ValueError("cannot index with a slice Interval"
633+
" and a non-unique index")
634+
619635
# TODO: this expands to a tuple index, see if we can
620636
# do better
621637
return Index(self._multiindex.values).get_loc(key)
@@ -685,12 +701,28 @@ def get_value(self, series, key):
685701
loc = key
686702
elif is_list_like(key):
687703
loc = self.get_indexer(key)
704+
elif isinstance(key, slice):
705+
706+
if not (key.step is None or key.step == 1):
707+
raise ValueError("cannot support not-default "
708+
"step in a slice")
709+
710+
try:
711+
loc = self.get_loc(key)
712+
except TypeError:
713+
714+
# we didn't find exact intervals
715+
# or are non-unique
716+
raise ValueError("unable to slice with "
717+
"this key: {}".format(key))
718+
688719
else:
689720
loc = self.get_loc(key)
690721
return series.iloc[loc]
691722

692723
@Appender(_index_shared_docs['get_indexer'] % _index_doc_kwargs)
693724
def get_indexer(self, target, method=None, limit=None, tolerance=None):
725+
694726
self._check_method(method)
695727
target = _ensure_index(target)
696728
target = self._maybe_cast_indexed(target)
@@ -706,7 +738,22 @@ def get_indexer(self, target, method=None, limit=None, tolerance=None):
706738
return np.where(start_plus_one == stop, start, -1)
707739

708740
if not self.is_unique:
709-
raise ValueError("get_indexer cannot handle non-unique indices")
741+
raise ValueError("cannot handle non-unique indices")
742+
743+
# IntervalIndex
744+
if isinstance(target, IntervalIndex):
745+
indexer = self._get_reindexer(target)
746+
747+
# non IntervalIndex
748+
else:
749+
indexer = np.concatenate([self.get_loc(i) for i in target])
750+
751+
return _ensure_platform_int(indexer)
752+
753+
def _get_reindexer(self, target):
754+
"""
755+
Return an indexer for a target IntervalIndex with self
756+
"""
710757

711758
# find the left and right indexers
712759
lindexer = self._engine.get_indexer(target.left.values)
@@ -720,27 +767,59 @@ def get_indexer(self, target, method=None, limit=None, tolerance=None):
720767
indexer = []
721768
n = len(self)
722769

723-
for l, r in zip(lindexer, rindexer):
770+
for i, (l, r) in enumerate(zip(lindexer, rindexer)):
771+
772+
target_value = target[i]
773+
774+
# matching on the lhs bound
775+
if (l != -1 and
776+
self.closed == 'right' and
777+
target_value.left == self[l].right):
778+
l += 1
779+
780+
# matching on the lhs bound
781+
if (r != -1 and
782+
self.closed == 'left' and
783+
target_value.right == self[r].left):
784+
r -= 1
724785

725786
# not found
726787
if l == -1 and r == -1:
727788
indexer.append(np.array([-1]))
728789

729790
elif r == -1:
791+
730792
indexer.append(np.arange(l, n))
731793

732794
elif l == -1:
733-
if r == 0:
734-
indexer.append(np.array([-1]))
735-
else:
736-
indexer.append(np.arange(0, r + 1))
737795

738-
else:
739-
indexer.append(np.arange(l, r))
796+
# care about left/right closed here
797+
value = self[i]
740798

741-
indexer = np.concatenate(indexer)
799+
# target.closed same as self.closed
800+
if self.closed == target.closed:
801+
if target_value.left < value.left:
802+
indexer.append(np.array([-1]))
803+
continue
742804

743-
return _ensure_platform_int(indexer)
805+
# target.closed == 'left'
806+
elif self.closed == 'right':
807+
if target_value.left <= value.left:
808+
indexer.append(np.array([-1]))
809+
continue
810+
811+
# target.closed == 'right'
812+
elif self.closed == 'left':
813+
if target_value.left <= value.left:
814+
indexer.append(np.array([-1]))
815+
continue
816+
817+
indexer.append(np.arange(0, r + 1))
818+
819+
else:
820+
indexer.append(np.arange(l, r + 1))
821+
822+
return np.concatenate(indexer)
744823

745824
@Appender(_index_shared_docs['get_indexer_non_unique'] % _index_doc_kwargs)
746825
def get_indexer_non_unique(self, target):

pandas/indexes/multi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1327,7 +1327,7 @@ def __contains__(self, key):
13271327
except LookupError:
13281328
return False
13291329

1330-
_is_contained_in = __contains__
1330+
contains = __contains__
13311331

13321332
def __reduce__(self):
13331333
"""Necessary for making this object picklable"""

pandas/tests/indexes/test_interval.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -430,13 +430,12 @@ def test_get_indexer(self):
430430
self.assert_numpy_array_equal(actual, expected)
431431

432432
actual = self.index.get_indexer(index)
433-
expected = np.array([-1, 0], dtype='int64')
433+
expected = np.array([-1, 1], dtype='int64')
434434
self.assert_numpy_array_equal(actual, expected)
435435

436-
@pytest.mark.xfail(reason="what to return for overlaps")
437436
def test_get_indexer_subintervals(self):
438-
# TODO
439437

438+
# TODO: is this right?
440439
# return indexers for wholly contained subintervals
441440
target = IntervalIndex.from_breaks(np.linspace(0, 2, 5))
442441
actual = self.index.get_indexer(target)
@@ -445,7 +444,7 @@ def test_get_indexer_subintervals(self):
445444

446445
target = IntervalIndex.from_breaks([0, 0.67, 1.33, 2])
447446
actual = self.index.get_indexer(target)
448-
expected = np.array([-1, 0, 1], dtype='int64')
447+
expected = np.array([0, 0, 1, 1], dtype='int64')
449448
self.assert_numpy_array_equal(actual, expected)
450449

451450
actual = self.index.get_indexer(target[[0, -1]])
@@ -473,22 +472,22 @@ def test_contains(self):
473472
self.assertNotIn(Interval(3, 5), i)
474473
self.assertNotIn(Interval(-1, 0, closed='left'), i)
475474

476-
def test_is_contained_in(self):
475+
def testcontains(self):
477476
# can select values that are IN the range of a value
478477
i = IntervalIndex.from_arrays([0, 1], [1, 2])
479478

480-
assert i._is_contained_in(0.1)
481-
assert i._is_contained_in(0.5)
482-
assert i._is_contained_in(1)
483-
assert i._is_contained_in(Interval(0, 1))
484-
assert i._is_contained_in(Interval(0, 2))
479+
assert i.contains(0.1)
480+
assert i.contains(0.5)
481+
assert i.contains(1)
482+
assert i.contains(Interval(0, 1))
483+
assert i.contains(Interval(0, 2))
485484

486485
# these overlaps completely
487-
assert i._is_contained_in(Interval(0, 3))
488-
assert i._is_contained_in(Interval(1, 3))
486+
assert i.contains(Interval(0, 3))
487+
assert i.contains(Interval(1, 3))
489488

490-
assert not i._is_contained_in(20)
491-
assert not i._is_contained_in(-20)
489+
assert not i.contains(20)
490+
assert not i.contains(-20)
492491

493492
def test_dropna(self):
494493

0 commit comments

Comments
 (0)