|
12 | 12 | import numpy as np
|
13 | 13 |
|
14 | 14 | from pandas._libs import lib, tslibs
|
15 |
| -import pandas.compat as compat |
16 |
| -from pandas.compat import PY36, OrderedDict, iteritems |
17 |
| - |
18 | 15 | from pandas.core.dtypes.cast import construct_1d_object_array_from_listlike
|
19 |
| -from pandas.core.dtypes.common import ( |
20 |
| - is_array_like, is_bool_dtype, is_extension_array_dtype, is_integer) |
21 |
| -from pandas.core.dtypes.generic import ABCIndex, ABCIndexClass, ABCSeries |
| 16 | +from pandas import compat |
| 17 | +from pandas.compat import iteritems, PY2, PY36, OrderedDict |
| 18 | +from pandas.core.dtypes.generic import ABCSeries, ABCIndex, ABCIndexClass |
| 19 | +from pandas.core.dtypes.common import (is_integer, is_integer_dtype, |
| 20 | + is_bool_dtype, is_extension_array_dtype, |
| 21 | + is_array_like, |
| 22 | + is_float_dtype, is_object_dtype, |
| 23 | + is_categorical_dtype, is_numeric_dtype, |
| 24 | + is_scalar, ensure_platform_int) |
22 | 25 | from pandas.core.dtypes.inference import _iterable_not_string
|
23 | 26 | from pandas.core.dtypes.missing import isna, isnull, notnull # noqa
|
24 | 27 |
|
@@ -481,3 +484,83 @@ def f(x):
|
481 | 484 | f = mapper
|
482 | 485 |
|
483 | 486 | return f
|
| 487 | + |
| 488 | + |
def ensure_integer_dtype(arr, value):
    """
    Ensure optimal dtype for :func:`searchsorted_integer` is returned.

    The dtype of `arr` is returned when no element of `value` falls
    outside the range representable by that dtype; otherwise the dtype
    numpy infers for `value` is returned.

    Parameters
    ----------
    arr : a numpy integer array
    value : a number or array of numbers

    Returns
    -------
    dtype : a numpy dtype
        ``arr.dtype`` if `value` is within the bounds of ``arr.dtype``,
        else the dtype numpy inferred for `value` (which is not
        necessarily an integer dtype).

    Raises
    ------
    TypeError : if value is not a number (explicitly checked on python 2)
    """
    value_arr = np.array([value]) if is_scalar(value) else np.array(value)

    if PY2 and not is_numeric_dtype(value_arr):
        # python 2 allows "a" < 1, avoid such nonsense
        msg = "value must be numeric, was type {}"
        # the message announces a type, so report the type, not the value
        raise TypeError(msg.format(type(value).__name__))

    iinfo = np.iinfo(arr.dtype)
    out_of_bounds = ((value_arr < iinfo.min).any() or
                     (value_arr > iinfo.max).any())
    if not out_of_bounds:
        # every element fits: arr's own dtype can be used safely
        return arr.dtype
    else:
        # casting to arr.dtype would overflow; keep numpy's inferred dtype
        return value_arr.dtype
| 518 | + |
| 519 | + |
def searchsorted_integer(arr, value, side="left", sorter=None):
    """
    searchsorted implementation, but only for integer arrays.

    `value` is cast to the dtype chosen by :func:`ensure_integer_dtype`
    when it is integer-like, because the search is faster when `arr` and
    `value` share a dtype.

    See :func:`searchsorted` for a more general searchsorted implementation.
    """
    if sorter is not None:
        sorter = ensure_platform_int(sorter)

    target_dtype = ensure_integer_dtype(arr, value)

    # Integers and integer-dtyped arrays can always be cast. A float such
    # as 2.0 can be converted to int 2 for better speed, but a float such
    # as 2.2 must *not* be converted to int 2.
    castable = (is_integer(value) or is_integer_dtype(value) or
                (hasattr(value, 'is_integer') and value.is_integer()))
    if castable:
        value = np.asarray(value, dtype=target_dtype)

    return np.searchsorted(arr, value, side=side, sorter=sorter)
| 541 | + |
| 542 | + |
def searchsorted(arr, value, side="left", sorter=None):
    """
    Find indices where elements should be inserted to maintain order.

    Find the indices into a sorted array-like `arr` such that, if the
    corresponding elements in `value` were inserted before the indices,
    the order of `arr` would be preserved.

    Integer arrays are dispatched to :func:`searchsorted_integer`; object,
    float and categorical arrays use ``arr.searchsorted`` directly; for any
    other dtype `value` is first converted through a ``Series``.

    See :class:`IndexOpsMixin.searchsorted` for more details and examples.
    """
    if sorter is not None:
        sorter = ensure_platform_int(sorter)

    if is_integer_dtype(arr):
        return searchsorted_integer(arr, value, side=side, sorter=sorter)

    if not (is_object_dtype(arr) or is_float_dtype(arr) or
            is_categorical_dtype(arr)):
        # fallback solution. E.g. arr is an array with dtype='datetime64[ns]'
        # and value is a pd.Timestamp, need to convert value
        from pandas.core.series import Series
        converted = Series(value)._values
        # Series() wraps a scalar in a length-1 array; unwrap it again so
        # that a scalar `value` is searched as a scalar, as in the other
        # branches, instead of as a one-element array
        value = converted[0] if is_scalar(value) else converted

    return arr.searchsorted(value, side=side, sorter=sorter)
0 commit comments