Skip to content

Commit d42689b

Browse files
committed
collect into one function
1 parent 0dad72c commit d42689b

File tree

6 files changed

+128
-93
lines changed

6 files changed

+128
-93
lines changed

pandas/core/arrays/base.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -555,17 +555,17 @@ def searchsorted(self, value, side="left", sorter=None):
555555
.. versionadded:: 0.24.0
556556
557557
Find the indices into a sorted array `self` (a) such that, if the
558-
corresponding elements in `v` were inserted before the indices, the
559-
order of `self` would be preserved.
558+
corresponding elements in `value` were inserted before the indices,
559+
the order of `self` would be preserved.
560560
561-
Assuming that `a` is sorted:
561+
Assuming that `self` is sorted:
562562
563-
====== ============================
563+
====== ================================
564564
`side` returned index `i` satisfies
565-
====== ============================
566-
left ``self[i-1] < v <= self[i]``
567-
right ``self[i-1] <= v < self[i]``
568-
====== ============================
565+
====== ================================
566+
left ``self[i-1] < value <= self[i]``
567+
right ``self[i-1] <= value < self[i]``
568+
====== ================================
569569
570570
Parameters
571571
----------
@@ -581,7 +581,7 @@ def searchsorted(self, value, side="left", sorter=None):
581581
582582
Returns
583583
-------
584-
indices : array of ints
584+
array of ints
585585
Array of insertion points with the same shape as `value`.
586586
587587
See Also

pandas/core/arrays/numpy_.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@
44

55
from pandas._libs import lib
66
from pandas.compat.numpy import function as nv
7+
from pandas.util._decorators import Appender
78
from pandas.util._validators import validate_fillna_kwargs
89

910
from pandas.core.dtypes.dtypes import ExtensionDtype
1011
from pandas.core.dtypes.generic import ABCIndexClass, ABCSeries
1112
from pandas.core.dtypes.inference import is_array_like, is_list_like
1213

1314
from pandas import compat
14-
from pandas.core import nanops
15+
from pandas.core import common as com, nanops
1516
from pandas.core.missing import backfill_1d, pad_1d
1617

1718
from .base import ExtensionArray, ExtensionOpsMixin
@@ -423,6 +424,11 @@ def to_numpy(self, dtype=None, copy=False):
423424

424425
return result
425426

427+
@Appender(ExtensionArray.searchsorted.__doc__)
428+
def searchsorted(self, value, side='left', sorter=None):
429+
return com.searchsorted(self.to_numpy(), value,
430+
side=side, sorter=sorter)
431+
426432
# ------------------------------------------------------------------------
427433
# Ops
428434

pandas/core/base.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1494,15 +1494,11 @@ def factorize(self, sort=False, na_sentinel=-1):
14941494
array([3])
14951495
""")
14961496

1497-
@Substitution(klass='IndexOpsMixin')
1497+
@Substitution(klass='Index')
14981498
@Appender(_shared_docs['searchsorted'])
14991499
def searchsorted(self, value, side='left', sorter=None):
1500-
result = com.searchsorted(self._values, value,
1501-
side=side, sorter=sorter)
1502-
1503-
if is_scalar(value):
1504-
return result if is_scalar(result) else result[0]
1505-
return result
1500+
return com.searchsorted(self._values, value,
1501+
side=side, sorter=sorter)
15061502

15071503
def drop_duplicates(self, keep='first', inplace=False):
15081504
inplace = validate_bool_kwarg(inplace, 'inplace')

pandas/core/common.py

Lines changed: 64 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,19 @@
1212
import numpy as np
1313

1414
from pandas._libs import lib, tslibs
15+
from pandas.compat import PY36, OrderedDict, iteritems
16+
1517
from pandas.core.dtypes.cast import construct_1d_object_array_from_listlike
16-
from pandas import compat
17-
from pandas.compat import iteritems, PY36, OrderedDict
18-
from pandas.core.dtypes.generic import ABCSeries, ABCIndex, ABCIndexClass
1918
from pandas.core.dtypes.common import (
20-
is_integer, is_integer_dtype, is_bool_dtype,
21-
is_extension_array_dtype, is_array_like, is_object_dtype,
22-
is_categorical_dtype, is_numeric_dtype, is_scalar, ensure_platform_int)
19+
ensure_platform_int, is_array_like, is_bool_dtype, is_categorical_dtype,
20+
is_extension_array_dtype, is_integer, is_integer_dtype, is_numeric_dtype,
21+
is_object_dtype, is_scalar)
22+
from pandas.core.dtypes.generic import ABCIndex, ABCIndexClass, ABCSeries
2323
from pandas.core.dtypes.inference import _iterable_not_string
2424
from pandas.core.dtypes.missing import isna, isnull, notnull # noqa
2525

26+
from pandas import compat
27+
2628

2729
class SettingWithCopyError(ValueError):
2830
pass
@@ -484,87 +486,79 @@ def f(x):
484486
return f
485487

486488

487-
def searchsorted_integer(arr, value, side="left", sorter=None):
488-
"""
489-
searchsorted implementation for searching integer arrays.
490-
491-
We get a speedup if we ensure the dtype of arr and value are the same
492-
(if possible) before searchingm as numpy implicitly converts the dtypes
493-
if they're different, which would cause a slowdown.
494-
495-
See :func:`searchsorted` for a more general searchsorted implementation.
496-
497-
Parameters
498-
----------
499-
arr : numpy.array
500-
a numpy array of integers
501-
value : int or numpy.array
502-
an integer or an array of integers that we want to find the
503-
location(s) for in `arr`
504-
side : str
505-
One of {'left', 'right'}
506-
sorter : numpy.array, optional
507-
508-
Returns
509-
-------
510-
int or numpy.array
511-
The locations(s) of `value` in `arr`.
512-
"""
513-
from .arrays.array_ import array
514-
if sorter is not None:
515-
sorter = ensure_platform_int(sorter)
516-
517-
# below we try to give `value` the same dtype as `arr`, while guarding
518-
# against integer overflows. If the value of `value` is outside of the
519-
# bound of `arr`, `arr` would be recast by numpy, causing a slower search.
520-
value_arr = np.array([value]) if is_scalar(value) else np.array(value)
521-
iinfo = np.iinfo(arr.dtype.type)
522-
if (value_arr >= iinfo.min).all() and (value_arr <= iinfo.max).all():
523-
dtype = arr.dtype
524-
else:
525-
dtype = value_arr.dtype
526-
527-
if is_scalar(value):
528-
value = dtype.type(value)
529-
else:
530-
value = array(value, dtype=dtype)
531-
532-
return arr.searchsorted(value, side=side, sorter=sorter)
533-
534-
535489
def searchsorted(arr, value, side="left", sorter=None):
536490
"""
537491
Find indices where elements should be inserted to maintain order.
538492
539-
Find the indices into a sorted array-like `arr` such that, if the
493+
.. versionadded:: 0.25.0
494+
495+
Find the indices into a sorted array `self` (a) such that, if the
540496
corresponding elements in `value` were inserted before the indices,
541-
the order of `arr` would be preserved.
497+
the order of `self` would be preserved.
498+
499+
Assuming that `self` is sorted:
542500
543-
See :class:`IndexOpsMixin.searchsorted` for more details and examples.
501+
====== ================================
502+
`side` returned index `i` satisfies
503+
====== ================================
504+
left ``self[i-1] < value <= self[i]``
505+
right ``self[i-1] <= value < self[i]``
506+
====== ================================
544507
545508
Parameters
546509
----------
547-
arr : numpy.array or ExtensionArray
548-
value : scalar or numpy.array
549-
side : str
550-
One of {'left', 'right'}
551-
sorter : numpy.array, optional
510+
arr: numpy.array or ExtensionArray
511+
array to search in. Cannot be Index, Series or PandasArray, as that
512+
would cause a RecursionError.
513+
value : array_like
514+
Values to insert into `arr`.
515+
side : {'left', 'right'}, optional
516+
If 'left', the index of the first suitable location found is given.
517+
If 'right', return the last such index. If there is no suitable
518+
index, return either 0 or N (where N is the length of `self`).
519+
sorter : 1-D array_like, optional
520+
Optional array of integer indices that sort array a into ascending
521+
order. They are typically the result of argsort.
552522
553523
Returns
554524
-------
555-
int or numpy.array
556-
The locations(s) of `value` in `arr`.
525+
array of ints
526+
Array of insertion points with the same shape as `value`.
527+
528+
See Also
529+
--------
530+
numpy.searchsorted : Similar method from NumPy.
557531
"""
558532
if sorter is not None:
559533
sorter = ensure_platform_int(sorter)
560534

561535
if is_integer_dtype(arr) and (
562536
is_integer(value) or is_integer_dtype(value)):
563-
return searchsorted_integer(arr, value, side=side, sorter=sorter)
564-
if not (is_object_dtype(arr) or is_numeric_dtype(arr) or
565-
is_categorical_dtype(arr)):
537+
from .arrays.array_ import array
538+
# if `arr` and `value` have different dtypes, `arr` would be
539+
# recast by numpy, causing a slow search.
540+
# Before searching below, we therefore try to give `value` the
541+
# same dtype as `arr`, while guarding against integer overflows.
542+
iinfo = np.iinfo(arr.dtype.type)
543+
value_arr = np.array([value]) if is_scalar(value) else np.array(value)
544+
if (value_arr >= iinfo.min).all() and (value_arr <= iinfo.max).all():
545+
# value within bounds, so no overflow, so can convert value dtype
546+
# to dtype of arr
547+
dtype = arr.dtype
548+
else:
549+
dtype = value_arr.dtype
550+
551+
if is_scalar(value):
552+
value = dtype.type(value)
553+
else:
554+
value = array(value, dtype=dtype)
555+
elif not (is_object_dtype(arr) or is_numeric_dtype(arr) or
556+
is_categorical_dtype(arr)):
557+
from pandas.core.series import Series
566558
# E.g. if `arr` is an array with dtype='datetime64[ns]'
567559
# and `value` is a pd.Timestamp, we may need to convert value
568-
from pandas.core.series import Series
569-
value = Series(value)._values
570-
return arr.searchsorted(value, side=side, sorter=sorter)
560+
value_ser = Series(value)._values
561+
value = value_ser[0] if is_scalar(value) else value_ser
562+
563+
result = arr.searchsorted(value, side=side, sorter=sorter)
564+
return result

pandas/core/series.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2331,12 +2331,8 @@ def __rmatmul__(self, other):
23312331
@Substitution(klass='Series')
23322332
@Appender(base._shared_docs['searchsorted'])
23332333
def searchsorted(self, value, side='left', sorter=None):
2334-
result = com.searchsorted(self._values, value,
2335-
side=side, sorter=sorter)
2336-
2337-
if is_scalar(value):
2338-
return result if is_scalar(result) else result[0]
2339-
return result
2334+
return com.searchsorted(self._values, value,
2335+
side=side, sorter=sorter)
23402336

23412337
# -------------------------------------------------------------------
23422338
# Combination

pandas/tests/arrays/test_array.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import pandas as pd
1111
from pandas.api.extensions import register_extension_dtype
12+
from pandas.api.types import is_scalar
1213
from pandas.core.arrays import PandasArray, integer_array, period_array
1314
from pandas.tests.extension.decimal import (
1415
DecimalArray, DecimalDtype, to_decimal)
@@ -254,3 +255,45 @@ def test_array_not_registered(registry_without_decimal):
254255
result = pd.array(data, dtype=DecimalDtype)
255256
expected = DecimalArray._from_sequence(data)
256257
tm.assert_equal(result, expected)
258+
259+
260+
class TestArrayAnalytics(object):
261+
def test_searchsorted(self, string_dtype):
262+
arr = pd.array(['a', 'b', 'c'], dtype=string_dtype)
263+
264+
result = arr.searchsorted('a', side='left')
265+
assert is_scalar(result)
266+
assert result == 0
267+
268+
result = arr.searchsorted('a', side='right')
269+
assert is_scalar(result)
270+
assert result == 1
271+
272+
def test_searchsorted_numeric_dtypes_scalar(self, any_real_dtype):
273+
arr = pd.array([1, 3, 90], dtype=any_real_dtype)
274+
result = arr.searchsorted(30)
275+
assert is_scalar(result)
276+
assert result == 2
277+
278+
result = arr.searchsorted([30])
279+
expected = np.array([2], dtype=np.intp)
280+
tm.assert_numpy_array_equal(result, expected)
281+
282+
def test_searchsorted_numeric_dtypes_vector(self, any_real_dtype):
283+
arr = pd.array([1, 3, 90], dtype=any_real_dtype)
284+
result = arr.searchsorted([2, 30])
285+
expected = np.array([1, 2], dtype=np.intp)
286+
tm.assert_numpy_array_equal(result, expected)
287+
288+
def test_search_sorted_datetime64_scalar(self):
289+
arr = pd.array(pd.date_range('20120101', periods=10, freq='2D'))
290+
val = pd.Timestamp('20120102')
291+
result = arr.searchsorted(val)
292+
assert is_scalar(result)
293+
assert result == 1
294+
295+
def test_searchsorted_sorter(self, any_real_dtype):
296+
arr = pd.array([3, 1, 2], dtype=any_real_dtype)
297+
result = arr.searchsorted([0, 3], sorter=np.argsort(arr))
298+
expected = np.array([0, 2], dtype=np.intp)
299+
tm.assert_numpy_array_equal(result, expected)

0 commit comments

Comments
 (0)