Skip to content

Commit 4747609

Browse files
committed
solved the mypy type checking error, and implemented searchsorted under MultiIndex()
1 parent 0e0b9b5 commit 4747609

File tree

1 file changed

+13
-12
lines changed

1 file changed

+13
-12
lines changed

pandas/core/indexes/multi.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3780,7 +3780,7 @@ def _reorder_indexer(
37803780

37813781
def searchsorted(
37823782
self,
3783-
value: Union[Tuple[Hashable, ...], Sequence[Tuple[Hashable, ...]]],
3783+
value: Any,
37843784
side: Literal["left", "right"] = "left",
37853785
sorter: npt.NDArray[np.intp] | None = None,
37863786
) -> npt.NDArray[np.intp]:
@@ -3789,7 +3789,7 @@ def searchsorted(
37893789
37903790
Parameters
37913791
----------
3792-
value : tuple
3792+
value : Any
37933793
The value(s) to search for in the MultiIndex.
37943794
side : {'left', 'right'}, default 'left'
37953795
If 'left', the index of the first suitable location found is given.
@@ -3800,7 +3800,7 @@ def searchsorted(
38003800
38013801
Returns
38023802
-------
3803-
numpy.ndarray
3803+
npt.NDArray[np.intp]
38043804
Array of insertion points.
38053805
38063806
See Also
@@ -3813,18 +3813,19 @@ def searchsorted(
38133813
>>> mi.searchsorted(("b", "y"))
38143814
1
38153815
"""
3816+
3817+
if not value:
3818+
raise ValueError("searchsorted requires a non-empty value")
3819+
38163820
if not isinstance(value, (tuple, list)):
38173821
raise TypeError("value must be a tuple or list")
38183822

38193823
if isinstance(value, tuple):
3820-
values = [value]
3824+
value = [value]
3825+
38213826
if side not in ["left", "right"]:
38223827
raise ValueError("side must be either 'left' or 'right'")
38233828

3824-
if not value:
3825-
raise ValueError("searchsorted requires a non-empty value")
3826-
3827-
38283829
indexer = self.get_indexer(value)
38293830
result = []
38303831

@@ -3834,20 +3835,20 @@ def searchsorted(
38343835
else:
38353836
dtype = np.dtype(
38363837
[
3837-
(f"level_{i}", level.dtype)
3838+
(f"level_{i}", np.asarray(level).dtype)
38383839
for i, level in enumerate(self.levels)
38393840
]
38403841
)
3841-
3842-
val_array = np.array(values, dtype=dtype)
3842+
3843+
val_array = np.array([v], dtype=dtype)
38433844

38443845
pos = np.searchsorted(
38453846
np.asarray(self.values, dtype=dtype),
38463847
val_array,
38473848
side=side,
38483849
sorter=sorter,
38493850
)
3850-
result.append(pos)
3851+
result.append(int(pos[0]))
38513852

38523853
return np.array(result, dtype=np.intp)
38533854

0 commit comments

Comments
 (0)