Skip to content

Commit 8ecfcbd

Browse files
topper-123Terji Petersen
authored and
Terji Petersen
committed
API: ensure IntervalIndex.left/right are 64bit if numeric part II
1 parent a0ee90a commit 8ecfcbd

File tree

3 files changed

+31
-22
lines changed

3 files changed

+31
-22
lines changed

pandas/core/arrays/interval.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
Dtype,
3636
IntervalClosedType,
3737
NpDtype,
38+
NumpyIndexT,
3839
PositionalIndexer,
3940
ScalarIndexer,
4041
SequenceIndexer,
@@ -55,7 +56,9 @@
5556
is_list_like,
5657
is_object_dtype,
5758
is_scalar,
59+
is_signed_integer_dtype,
5860
is_string_dtype,
61+
is_unsigned_integer_dtype,
5962
needs_i8_conversion,
6063
pandas_dtype,
6164
)
@@ -177,6 +180,21 @@
177180
"""
178181

179182

183+
def maybe_convert_numeric_to_64bit(arr: NumpyIndexT) -> NumpyIndexT:
184+
# IntervalTree only supports 64 bit numpy array
185+
dtype = arr.dtype
186+
if not np.issubclass_(dtype.type, np.number):
187+
return arr
188+
elif is_signed_integer_dtype(dtype) and dtype != np.int64:
189+
return arr.astype(np.int64)
190+
elif is_unsigned_integer_dtype(dtype) and dtype != np.uint64:
191+
return arr.astype(np.uint64)
192+
elif is_float_dtype(dtype) and dtype != np.float64:
193+
return arr.astype(np.float64)
194+
else:
195+
return arr
196+
197+
180198
@Appender(
181199
_interval_shared_docs["class"]
182200
% {
@@ -248,6 +266,7 @@ def __new__(
248266

249267
# might need to convert empty or purely na data
250268
data = _maybe_convert_platform_interval(data)
269+
data = maybe_convert_numeric_to_64bit(data)
251270
left, right, infer_closed = intervals_to_interval_bounds(
252271
data, validate_closed=closed is None
253272
)

pandas/core/indexes/interval.py

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,6 @@
5959
is_number,
6060
is_object_dtype,
6161
is_scalar,
62-
is_signed_integer_dtype,
63-
is_unsigned_integer_dtype,
6462
)
6563
from pandas.core.dtypes.dtypes import IntervalDtype
6664
from pandas.core.dtypes.missing import is_valid_na_for_dtype
@@ -69,6 +67,7 @@
6967
from pandas.core.arrays.interval import (
7068
IntervalArray,
7169
_interval_shared_docs,
70+
maybe_convert_numeric_to_64bit,
7271
)
7372
import pandas.core.common as com
7473
from pandas.core.indexers import is_valid_positional_slice
@@ -520,13 +519,12 @@ def _maybe_convert_i8(self, key):
520519
The original key if no conversion occurred, int if converted scalar,
521520
Int64Index if converted list-like.
522521
"""
523-
original = key
524522
if is_list_like(key):
525523
key = ensure_index(key)
526-
key = self._maybe_convert_numeric_to_64bit(key)
524+
key = maybe_convert_numeric_to_64bit(key)
527525

528526
if not self._needs_i8_conversion(key):
529-
return original
527+
return key
530528

531529
scalar = is_scalar(key)
532530
if is_interval_dtype(key) or isinstance(key, Interval):
@@ -569,20 +567,6 @@ def _maybe_convert_i8(self, key):
569567

570568
return key_i8
571569

572-
def _maybe_convert_numeric_to_64bit(self, idx: Index) -> Index:
573-
# IntervalTree only supports 64 bit numpy array
574-
dtype = idx.dtype
575-
if np.issubclass_(dtype.type, np.number):
576-
return idx
577-
elif is_signed_integer_dtype(dtype) and dtype != np.int64:
578-
return idx.astype(np.int64)
579-
elif is_unsigned_integer_dtype(dtype) and dtype != np.uint64:
580-
return idx.astype(np.uint64)
581-
elif is_float_dtype(dtype) and dtype != np.float64:
582-
return idx.astype(np.float64)
583-
else:
584-
return idx
585-
586570
def _searchsorted_monotonic(self, label, side: Literal["left", "right"] = "left"):
587571
if not self.is_non_overlapping_monotonic:
588572
raise KeyError(

pandas/tests/indexes/interval/test_interval.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@
1818
timedelta_range,
1919
)
2020
import pandas._testing as tm
21-
from pandas.core.api import Float64Index
21+
from pandas.core.api import (
22+
Float64Index,
23+
NumericIndex,
24+
)
2225
import pandas.core.common as com
2326

2427

@@ -435,9 +438,12 @@ def test_maybe_convert_i8_numeric(self, breaks, make_key):
435438
index = IntervalIndex.from_breaks(breaks)
436439
key = make_key(breaks)
437440

438-
# no conversion occurs for numeric
439441
result = index._maybe_convert_i8(key)
440-
assert result is key
442+
if not isinstance(result, NumericIndex):
443+
assert result is key
444+
else:
445+
expected = NumericIndex(key)
446+
tm.assert_index_equal(result, expected)
441447

442448
@pytest.mark.parametrize(
443449
"breaks1, breaks2",

0 commit comments

Comments
 (0)