Skip to content

Commit 4da554f

Browse files
jbrockmendel authored and jreback committed
REF: simplify index.pyx (#31168)
1 parent 754dc4c commit 4da554f

File tree

2 files changed: 37 additions and 41 deletions

pandas/_libs/index.pyx

Lines changed: 19 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -26,19 +26,13 @@ from pandas._libs import algos, hashtable as _hash
2626
from pandas._libs.tslibs import Timestamp, Timedelta, period as periodlib
2727
from pandas._libs.missing import checknull
2828

29-
cdef int64_t NPY_NAT = util.get_nat()
30-
3129

3230
cdef inline bint is_definitely_invalid_key(object val):
33-
if isinstance(val, tuple):
34-
try:
35-
hash(val)
36-
except TypeError:
37-
return True
38-
39-
# we have a _data, means we are a NDFrame
40-
return (isinstance(val, slice) or util.is_array(val)
41-
or isinstance(val, list) or hasattr(val, '_data'))
31+
try:
32+
hash(val)
33+
except TypeError:
34+
return True
35+
return False
4236

4337

4438
cpdef get_value_at(ndarray arr, object loc, object tz=None):
@@ -168,6 +162,15 @@ cdef class IndexEngine:
168162
int count
169163

170164
indexer = self._get_index_values() == val
165+
return self._unpack_bool_indexer(indexer, val)
166+
167+
cdef _unpack_bool_indexer(self,
168+
ndarray[uint8_t, ndim=1, cast=True] indexer,
169+
object val):
170+
cdef:
171+
ndarray[intp_t, ndim=1] found
172+
int count
173+
171174
found = np.where(indexer)[0]
172175
count = len(found)
173176

@@ -446,7 +449,7 @@ cdef class DatetimeEngine(Int64Engine):
446449
cdef:
447450
int64_t loc
448451
if is_definitely_invalid_key(val):
449-
raise TypeError
452+
raise TypeError(f"'{val}' is an invalid key")
450453

451454
try:
452455
conv = self._unbox_scalar(val)
@@ -651,7 +654,10 @@ cdef class BaseMultiIndexCodesEngine:
651654
# integers representing labels: we will use its get_loc and get_indexer
652655
self._base.__init__(self, lambda: lab_ints, len(lab_ints))
653656

654-
def _extract_level_codes(self, object target, object method=None):
657+
def _codes_to_ints(self, codes):
658+
raise NotImplementedError("Implemented by subclass")
659+
660+
def _extract_level_codes(self, object target):
655661
"""
656662
Map the requested list of (tuple) keys to their integer representations
657663
for searching in the underlying integer index.

pandas/_libs/index_class_helper.pxi.in

Lines changed: 18 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -10,24 +10,26 @@ WARNING: DO NOT edit .pxi FILE directly, .pxi is generated from .pxi.in
1010

1111
{{py:
1212

13-
# name, dtype, ctype, hashtable_name, hashtable_dtype
14-
dtypes = [('Float64', 'float64', 'float64_t', 'Float64', 'float64'),
15-
('Float32', 'float32', 'float32_t', 'Float64', 'float64'),
16-
('Int64', 'int64', 'int64_t', 'Int64', 'int64'),
17-
('Int32', 'int32', 'int32_t', 'Int64', 'int64'),
18-
('Int16', 'int16', 'int16_t', 'Int64', 'int64'),
19-
('Int8', 'int8', 'int8_t', 'Int64', 'int64'),
20-
('UInt64', 'uint64', 'uint64_t', 'UInt64', 'uint64'),
21-
('UInt32', 'uint32', 'uint32_t', 'UInt64', 'uint64'),
22-
('UInt16', 'uint16', 'uint16_t', 'UInt64', 'uint64'),
23-
('UInt8', 'uint8', 'uint8_t', 'UInt64', 'uint64'),
13+
# name, dtype, hashtable_name
14+
dtypes = [('Float64', 'float64', 'Float64'),
15+
('Float32', 'float32', 'Float64'),
16+
('Int64', 'int64', 'Int64'),
17+
('Int32', 'int32', 'Int64'),
18+
('Int16', 'int16', 'Int64'),
19+
('Int8', 'int8', 'Int64'),
20+
('UInt64', 'uint64', 'UInt64'),
21+
('UInt32', 'uint32', 'UInt64'),
22+
('UInt16', 'uint16', 'UInt64'),
23+
('UInt8', 'uint8', 'UInt64'),
2424
]
2525
}}
2626

27-
{{for name, dtype, ctype, hashtable_name, hashtable_dtype in dtypes}}
27+
{{for name, dtype, hashtable_name in dtypes}}
2828

2929

3030
cdef class {{name}}Engine(IndexEngine):
31+
# constructor-caller is responsible for ensuring that vgetter()
32+
# returns an ndarray with dtype {{dtype}}_t
3133

3234
cdef _make_hash_table(self, Py_ssize_t n):
3335
return _hash.{{hashtable_name}}HashTable(n)
@@ -41,22 +43,18 @@ cdef class {{name}}Engine(IndexEngine):
4143
cdef void _call_map_locations(self, values):
4244
# self.mapping is of type {{hashtable_name}}HashTable,
4345
# so convert dtype of values
44-
self.mapping.map_locations(algos.ensure_{{hashtable_dtype}}(values))
45-
46-
cdef _get_index_values(self):
47-
return algos.ensure_{{dtype}}(self.vgetter())
46+
self.mapping.map_locations(algos.ensure_{{hashtable_name.lower()}}(values))
4847

4948
cdef _maybe_get_bool_indexer(self, object val):
5049
cdef:
5150
ndarray[uint8_t, ndim=1, cast=True] indexer
5251
ndarray[intp_t, ndim=1] found
53-
ndarray[{{ctype}}] values
52+
ndarray[{{dtype}}_t, ndim=1] values
5453
int count = 0
5554

5655
self._check_type(val)
5756

58-
# A view is needed for some subclasses, such as PeriodEngine:
59-
values = self._get_index_values().view('{{dtype}}')
57+
values = self._get_index_values()
6058
try:
6159
with warnings.catch_warnings():
6260
# e.g. if values is float64 and `val` is a str, suppress warning
@@ -67,14 +65,6 @@ cdef class {{name}}Engine(IndexEngine):
6765
# when trying to cast it to ndarray
6866
raise KeyError(val)
6967

70-
found = np.where(indexer)[0]
71-
count = len(found)
72-
73-
if count > 1:
74-
return indexer
75-
if count == 1:
76-
return int(found[0])
77-
78-
raise KeyError(val)
68+
return self._unpack_bool_indexer(indexer, val)
7969

8070
{{endfor}}

0 commit comments

Comments
 (0)