Skip to content

Commit 2a1ef8c

Browse files
jschendeljreback
authored andcommitted
BUG: Perform i8 conversion for datetimelike IntervalTree queries (#22988)
1 parent af271ba commit 2a1ef8c

File tree

3 files changed

+221
-7
lines changed

3 files changed

+221
-7
lines changed

doc/source/whatsnew/v0.24.0.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -757,7 +757,7 @@ Interval
757757
- Bug in the :class:`IntervalIndex` constructor where the ``closed`` parameter did not always override the inferred ``closed`` (:issue:`19370`)
758758
- Bug in the ``IntervalIndex`` repr where a trailing comma was missing after the list of intervals (:issue:`20611`)
759759
- Bug in :class:`Interval` where scalar arithmetic operations did not retain the ``closed`` value (:issue:`22313`)
760-
-
760+
- Bug in :class:`IntervalIndex` where indexing with datetime-like values raised a ``KeyError`` (:issue:`20636`)
761761

762762
Indexing
763763
^^^^^^^^

pandas/core/indexes/interval.py

Lines changed: 85 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66

77
from pandas.compat import add_metaclass
88
from pandas.core.dtypes.missing import isna
9-
from pandas.core.dtypes.cast import find_common_type, maybe_downcast_to_dtype
9+
from pandas.core.dtypes.cast import (
10+
find_common_type, maybe_downcast_to_dtype, infer_dtype_from_scalar)
1011
from pandas.core.dtypes.common import (
1112
ensure_platform_int,
1213
is_list_like,
1314
is_datetime_or_timedelta_dtype,
1415
is_datetime64tz_dtype,
16+
is_dtype_equal,
1517
is_integer_dtype,
1618
is_float_dtype,
1719
is_interval_dtype,
@@ -29,8 +31,8 @@
2931
Interval, IntervalMixin, IntervalTree,
3032
)
3133

32-
from pandas.core.indexes.datetimes import date_range
33-
from pandas.core.indexes.timedeltas import timedelta_range
34+
from pandas.core.indexes.datetimes import date_range, DatetimeIndex
35+
from pandas.core.indexes.timedeltas import timedelta_range, TimedeltaIndex
3436
from pandas.core.indexes.multi import MultiIndex
3537
import pandas.core.common as com
3638
from pandas.util._decorators import cache_readonly, Appender
@@ -192,7 +194,9 @@ def _isnan(self):
192194

193195
@cache_readonly
194196
def _engine(self):
195-
return IntervalTree(self.left, self.right, closed=self.closed)
197+
left = self._maybe_convert_i8(self.left)
198+
right = self._maybe_convert_i8(self.right)
199+
return IntervalTree(left, right, closed=self.closed)
196200

197201
def __contains__(self, key):
198202
"""
@@ -514,6 +518,78 @@ def _maybe_cast_indexed(self, key):
514518

515519
return key
516520

521+
def _needs_i8_conversion(self, key):
522+
"""
523+
Check if a given key needs i8 conversion. Conversion is necessary for
524+
Timestamp, Timedelta, DatetimeIndex, and TimedeltaIndex keys. An
525+
Interval-like requires conversion if it's endpoints are one of the
526+
aforementioned types.
527+
528+
Assumes that any list-like data has already been cast to an Index.
529+
530+
Parameters
531+
----------
532+
key : scalar or Index-like
533+
The key that should be checked for i8 conversion
534+
535+
Returns
536+
-------
537+
boolean
538+
"""
539+
if is_interval_dtype(key) or isinstance(key, Interval):
540+
return self._needs_i8_conversion(key.left)
541+
542+
i8_types = (Timestamp, Timedelta, DatetimeIndex, TimedeltaIndex)
543+
return isinstance(key, i8_types)
544+
545+
def _maybe_convert_i8(self, key):
546+
"""
547+
Maybe convert a given key to it's equivalent i8 value(s). Used as a
548+
preprocessing step prior to IntervalTree queries (self._engine), which
549+
expects numeric data.
550+
551+
Parameters
552+
----------
553+
key : scalar or list-like
554+
The key that should maybe be converted to i8.
555+
556+
Returns
557+
-------
558+
key: scalar or list-like
559+
The original key if no conversion occured, int if converted scalar,
560+
Int64Index if converted list-like.
561+
"""
562+
original = key
563+
if is_list_like(key):
564+
key = ensure_index(key)
565+
566+
if not self._needs_i8_conversion(key):
567+
return original
568+
569+
scalar = is_scalar(key)
570+
if is_interval_dtype(key) or isinstance(key, Interval):
571+
# convert left/right and reconstruct
572+
left = self._maybe_convert_i8(key.left)
573+
right = self._maybe_convert_i8(key.right)
574+
constructor = Interval if scalar else IntervalIndex.from_arrays
575+
return constructor(left, right, closed=self.closed)
576+
577+
if scalar:
578+
# Timestamp/Timedelta
579+
key_dtype, key_i8 = infer_dtype_from_scalar(key, pandas_dtype=True)
580+
else:
581+
# DatetimeIndex/TimedeltaIndex
582+
key_dtype, key_i8 = key.dtype, Index(key.asi8)
583+
584+
# ensure consistency with IntervalIndex subtype
585+
subtype = self.dtype.subtype
586+
msg = ('Cannot index an IntervalIndex of subtype {subtype} with '
587+
'values of dtype {other}')
588+
if not is_dtype_equal(subtype, key_dtype):
589+
raise ValueError(msg.format(subtype=subtype, other=key_dtype))
590+
591+
return key_i8
592+
517593
def _check_method(self, method):
518594
if method is None:
519595
return
@@ -648,6 +724,7 @@ def get_loc(self, key, method=None):
648724

649725
else:
650726
# use the interval tree
727+
key = self._maybe_convert_i8(key)
651728
if isinstance(key, Interval):
652729
left, right = _get_interval_closed_bounds(key)
653730
return self._engine.get_loc_interval(left, right)
@@ -711,8 +788,10 @@ def _get_reindexer(self, target):
711788
"""
712789

713790
# find the left and right indexers
714-
lindexer = self._engine.get_indexer(target.left.values)
715-
rindexer = self._engine.get_indexer(target.right.values)
791+
left = self._maybe_convert_i8(target.left)
792+
right = self._maybe_convert_i8(target.right)
793+
lindexer = self._engine.get_indexer(left.values)
794+
rindexer = self._engine.get_indexer(right.values)
716795

717796
# we want to return an indexer on the intervals
718797
# however, our keys could provide overlapping of multiple

pandas/tests/indexes/interval/test_interval.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from __future__ import division
22

3+
from itertools import permutations
34
import pytest
45
import numpy as np
6+
import re
57
from pandas import (
68
Interval, IntervalIndex, Index, isna, notna, interval_range, Timestamp,
79
Timedelta, date_range, timedelta_range)
@@ -498,6 +500,48 @@ def test_get_loc_length_one(self, item, closed):
498500
result = index.get_loc(item)
499501
assert result == 0
500502

503+
# Make consistent with test_interval_new.py (see #16316, #16386)
504+
@pytest.mark.parametrize('breaks', [
505+
date_range('20180101', periods=4),
506+
date_range('20180101', periods=4, tz='US/Eastern'),
507+
timedelta_range('0 days', periods=4)], ids=lambda x: str(x.dtype))
508+
def test_get_loc_datetimelike_nonoverlapping(self, breaks):
509+
# GH 20636
510+
# nonoverlapping = IntervalIndex method and no i8 conversion
511+
index = IntervalIndex.from_breaks(breaks)
512+
513+
value = index[0].mid
514+
result = index.get_loc(value)
515+
expected = 0
516+
assert result == expected
517+
518+
interval = Interval(index[0].left, index[1].right)
519+
result = index.get_loc(interval)
520+
expected = slice(0, 2)
521+
assert result == expected
522+
523+
# Make consistent with test_interval_new.py (see #16316, #16386)
524+
@pytest.mark.parametrize('arrays', [
525+
(date_range('20180101', periods=4), date_range('20180103', periods=4)),
526+
(date_range('20180101', periods=4, tz='US/Eastern'),
527+
date_range('20180103', periods=4, tz='US/Eastern')),
528+
(timedelta_range('0 days', periods=4),
529+
timedelta_range('2 days', periods=4))], ids=lambda x: str(x[0].dtype))
530+
def test_get_loc_datetimelike_overlapping(self, arrays):
531+
# GH 20636
532+
# overlapping = IntervalTree method with i8 conversion
533+
index = IntervalIndex.from_arrays(*arrays)
534+
535+
value = index[0].mid + Timedelta('12 hours')
536+
result = np.sort(index.get_loc(value))
537+
expected = np.array([0, 1], dtype='int64')
538+
assert tm.assert_numpy_array_equal(result, expected)
539+
540+
interval = Interval(index[0].left, index[1].right)
541+
result = np.sort(index.get_loc(interval))
542+
expected = np.array([0, 1, 2], dtype='int64')
543+
assert tm.assert_numpy_array_equal(result, expected)
544+
501545
# To be removed, replaced by test_interval_new.py (see #16316, #16386)
502546
def test_get_indexer(self):
503547
actual = self.index.get_indexer([-1, 0, 0.5, 1, 1.5, 2, 3])
@@ -555,6 +599,97 @@ def test_get_indexer_length_one(self, item, closed):
555599
expected = np.array([0] * len(item), dtype='intp')
556600
tm.assert_numpy_array_equal(result, expected)
557601

602+
# Make consistent with test_interval_new.py (see #16316, #16386)
603+
@pytest.mark.parametrize('arrays', [
604+
(date_range('20180101', periods=4), date_range('20180103', periods=4)),
605+
(date_range('20180101', periods=4, tz='US/Eastern'),
606+
date_range('20180103', periods=4, tz='US/Eastern')),
607+
(timedelta_range('0 days', periods=4),
608+
timedelta_range('2 days', periods=4))], ids=lambda x: str(x[0].dtype))
609+
def test_get_reindexer_datetimelike(self, arrays):
610+
# GH 20636
611+
index = IntervalIndex.from_arrays(*arrays)
612+
tuples = [(index[0].left, index[0].left + pd.Timedelta('12H')),
613+
(index[-1].right - pd.Timedelta('12H'), index[-1].right)]
614+
target = IntervalIndex.from_tuples(tuples)
615+
616+
result = index._get_reindexer(target)
617+
expected = np.array([0, 3], dtype='int64')
618+
tm.assert_numpy_array_equal(result, expected)
619+
620+
@pytest.mark.parametrize('breaks', [
621+
date_range('20180101', periods=4),
622+
date_range('20180101', periods=4, tz='US/Eastern'),
623+
timedelta_range('0 days', periods=4)], ids=lambda x: str(x.dtype))
624+
def test_maybe_convert_i8(self, breaks):
625+
# GH 20636
626+
index = IntervalIndex.from_breaks(breaks)
627+
628+
# intervalindex
629+
result = index._maybe_convert_i8(index)
630+
expected = IntervalIndex.from_breaks(breaks.asi8)
631+
tm.assert_index_equal(result, expected)
632+
633+
# interval
634+
interval = Interval(breaks[0], breaks[1])
635+
result = index._maybe_convert_i8(interval)
636+
expected = Interval(breaks[0].value, breaks[1].value)
637+
assert result == expected
638+
639+
# datetimelike index
640+
result = index._maybe_convert_i8(breaks)
641+
expected = Index(breaks.asi8)
642+
tm.assert_index_equal(result, expected)
643+
644+
# datetimelike scalar
645+
result = index._maybe_convert_i8(breaks[0])
646+
expected = breaks[0].value
647+
assert result == expected
648+
649+
# list-like of datetimelike scalars
650+
result = index._maybe_convert_i8(list(breaks))
651+
expected = Index(breaks.asi8)
652+
tm.assert_index_equal(result, expected)
653+
654+
@pytest.mark.parametrize('breaks', [
655+
np.arange(5, dtype='int64'),
656+
np.arange(5, dtype='float64')], ids=lambda x: str(x.dtype))
657+
@pytest.mark.parametrize('make_key', [
658+
IntervalIndex.from_breaks,
659+
lambda breaks: Interval(breaks[0], breaks[1]),
660+
lambda breaks: breaks,
661+
lambda breaks: breaks[0],
662+
list], ids=['IntervalIndex', 'Interval', 'Index', 'scalar', 'list'])
663+
def test_maybe_convert_i8_numeric(self, breaks, make_key):
664+
# GH 20636
665+
index = IntervalIndex.from_breaks(breaks)
666+
key = make_key(breaks)
667+
668+
# no conversion occurs for numeric
669+
result = index._maybe_convert_i8(key)
670+
assert result is key
671+
672+
@pytest.mark.parametrize('breaks1, breaks2', permutations([
673+
date_range('20180101', periods=4),
674+
date_range('20180101', periods=4, tz='US/Eastern'),
675+
timedelta_range('0 days', periods=4)], 2), ids=lambda x: str(x.dtype))
676+
@pytest.mark.parametrize('make_key', [
677+
IntervalIndex.from_breaks,
678+
lambda breaks: Interval(breaks[0], breaks[1]),
679+
lambda breaks: breaks,
680+
lambda breaks: breaks[0],
681+
list], ids=['IntervalIndex', 'Interval', 'Index', 'scalar', 'list'])
682+
def test_maybe_convert_i8_errors(self, breaks1, breaks2, make_key):
683+
# GH 20636
684+
index = IntervalIndex.from_breaks(breaks1)
685+
key = make_key(breaks2)
686+
687+
msg = ('Cannot index an IntervalIndex of subtype {dtype1} with '
688+
'values of dtype {dtype2}')
689+
msg = re.escape(msg.format(dtype1=breaks1.dtype, dtype2=breaks2.dtype))
690+
with tm.assert_raises_regex(ValueError, msg):
691+
index._maybe_convert_i8(key)
692+
558693
# To be removed, replaced by test_interval_new.py (see #16316, #16386)
559694
def test_contains(self):
560695
# Only endpoints are valid.

0 commit comments

Comments
 (0)