Skip to content

Commit abbacb1

Browse files
committed
Simplify implementation
1 parent c65f8f8 commit abbacb1

File tree

1 file changed

+54
-55
lines changed

1 file changed

+54
-55
lines changed

pandas/core/common.py

Lines changed: 54 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,12 @@
1414
from pandas._libs import lib, tslibs
1515
from pandas.core.dtypes.cast import construct_1d_object_array_from_listlike
1616
from pandas import compat
17-
from pandas.compat import iteritems, PY2, PY36, OrderedDict
17+
from pandas.compat import iteritems, PY36, OrderedDict
1818
from pandas.core.dtypes.generic import ABCSeries, ABCIndex, ABCIndexClass
19-
from pandas.core.dtypes.common import (is_integer, is_integer_dtype,
20-
is_bool_dtype, is_extension_array_dtype,
21-
is_array_like,
22-
is_float_dtype, is_object_dtype,
23-
is_categorical_dtype, is_numeric_dtype,
24-
is_scalar, ensure_platform_int)
19+
from pandas.core.dtypes.common import (
20+
is_integer, is_integer_dtype, is_bool_dtype,
21+
is_extension_array_dtype, is_array_like, is_object_dtype,
22+
is_categorical_dtype, is_numeric_dtype, is_scalar, ensure_platform_int)
2523
from pandas.core.dtypes.inference import _iterable_not_string
2624
from pandas.core.dtypes.missing import isna, isnull, notnull # noqa
2725

@@ -486,58 +484,47 @@ def f(x):
486484
return f
487485

488486

489-
def ensure_integer_dtype(arr, value):
487+
def searchsorted_integer(arr, value, side="left", sorter=None):
490488
"""
491-
Ensure optimal dtype for :func:`searchsorted_integer` is returned.
489+
searchsorted implementation for searching integer arrays.
490+
491+
We get a speedup if we ensure the dtype of arr and value are the same
492+
(if possible) before searchingm as numpy implicitly converts the dtypes
493+
if they're different, which would cause a slowdown.
494+
495+
See :func:`searchsorted` for a more general searchsorted implementation.
492496
493497
Parameters
494498
----------
495-
arr : a numpy integer array
496-
value : a number or array of numbers
499+
arr : numpy.array
500+
a numpy array of integers
501+
value : int or numpy.array
502+
an integer or an array of integers that we want to find the
503+
location(s) for in `arr`
504+
side : str
505+
One of {'left', 'right'}
506+
sorter : numpy.array, optional
497507
498508
Returns
499509
-------
500-
dtype : an numpy integer dtype
501-
502-
Raises
503-
------
504-
TypeError : if value is not a number
505-
"""
506-
value_arr = np.array([value]) if is_scalar(value) else np.array(value)
507-
508-
if PY2 and not is_numeric_dtype(value_arr):
509-
# python 2 allows "a" < 1, avoid such nonsense
510-
msg = "value must be numeric, was type {}"
511-
raise TypeError(msg.format(value))
512-
513-
iinfo = np.iinfo(arr.dtype)
514-
if not ((value_arr < iinfo.min).any() or (value_arr > iinfo.max).any()):
515-
return arr.dtype
516-
else:
517-
return value_arr.dtype
518-
519-
520-
def searchsorted_integer(arr, value, side="left", sorter=None):
521-
"""
522-
searchsorted implementation, but only for integer arrays.
523-
524-
We get a speedup if the dtype of arr and value is the same.
525-
526-
See :func:`searchsorted` for a more general searchsorted implementation.
510+
int or numpy.array
511+
The locations(s) of `value` in `arr`.
527512
"""
528513
if sorter is not None:
529514
sorter = ensure_platform_int(sorter)
530515

531-
dtype = ensure_integer_dtype(arr, value)
532-
533-
if is_integer(value) or is_integer_dtype(value):
534-
value = np.asarray(value, dtype=dtype)
535-
elif hasattr(value, 'is_integer') and value.is_integer():
536-
# float 2.0 can be converted to int 2 for better speed,
537-
# but float 2.2 should *not* be converted to int 2
538-
value = np.asarray(value, dtype=dtype)
516+
# below we try to give `value` the same dtype as `arr`, while guarding
517+
# against integer overflows. If the value of `value` is outside of the
518+
# bound of `arr`, `arr` would be recast by numpy, causing a slower search.
519+
value_arr = np.array([value]) if is_scalar(value) else np.array(value)
520+
iinfo = np.iinfo(arr.dtype)
521+
if (value_arr >= iinfo.min).all() and (value_arr <= iinfo.max).all():
522+
dtype = arr.dtype
523+
else:
524+
dtype = value_arr.dtype
525+
value = np.asarray(value, dtype=dtype)
539526

540-
return np.searchsorted(arr, value, side=side, sorter=sorter)
527+
return arr.searchsorted(value, side=side, sorter=sorter)
541528

542529

543530
def searchsorted(arr, value, side="left", sorter=None):
@@ -549,18 +536,30 @@ def searchsorted(arr, value, side="left", sorter=None):
549536
the order of `arr` would be preserved.
550537
551538
See :class:`IndexOpsMixin.searchsorted` for more details and examples.
539+
540+
Parameters
541+
----------
542+
arr : numpy.array or ExtensionArray
543+
value : scalar or numpy.array
544+
side : str
545+
One of {'left', 'right'}
546+
sorter : numpy.array, optional
547+
548+
Returns
549+
-------
550+
int or numpy.array
551+
The locations(s) of `value` in `arr`.
552552
"""
553553
if sorter is not None:
554554
sorter = ensure_platform_int(sorter)
555555

556-
if is_integer_dtype(arr):
556+
if is_integer_dtype(arr) and (
557+
is_integer(value) or is_integer_dtype(value)):
557558
return searchsorted_integer(arr, value, side=side, sorter=sorter)
558-
elif (is_object_dtype(arr) or is_float_dtype(arr) or
559-
is_categorical_dtype(arr)):
560-
return arr.searchsorted(value, side=side, sorter=sorter)
561-
else:
562-
# fallback solution. E.g. arr is an array with dtype='datetime64[ns]'
563-
# and value is a pd.Timestamp, need to convert value
559+
if not (is_object_dtype(arr) or is_numeric_dtype(arr) or
560+
is_categorical_dtype(arr)):
561+
# E.g. if `arr` is an array with dtype='datetime64[ns]'
562+
# and `value` is a pd.Timestamp, we may need to convert value
564563
from pandas.core.series import Series
565564
value = Series(value)._values
566-
return arr.searchsorted(value, side=side, sorter=sorter)
565+
return arr.searchsorted(value, side=side, sorter=sorter)

0 commit comments

Comments
 (0)