Skip to content

Commit 03396ef

Browse files
Add @ operator type hints for Series (#1047)
* Add @ operator type hints for Series * Fix test * Formatting * Fix test
1 parent 92bd9cb commit 03396ef

File tree

2 files changed

+19
-8
lines changed

2 files changed

+19
-8
lines changed

pandas-stubs/core/series.pyi

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -800,8 +800,18 @@ class Series(IndexOpsMixin[S1], NDFrame):
800800
def dot(
801801
self, other: ArrayLike | dict[_str, np.ndarray] | Sequence[S1] | Index[S1]
802802
) -> np.ndarray: ...
803-
def __matmul__(self, other): ...
804-
def __rmatmul__(self, other): ...
803+
@overload
804+
def __matmul__(self, other: Series) -> Scalar: ...
805+
@overload
806+
def __matmul__(self, other: DataFrame) -> Series: ...
807+
@overload
808+
def __matmul__(self, other: np.ndarray) -> np.ndarray: ...
809+
@overload
810+
def __rmatmul__(self, other: Series) -> Scalar: ...
811+
@overload
812+
def __rmatmul__(self, other: DataFrame) -> Series: ...
813+
@overload
814+
def __rmatmul__(self, other: np.ndarray) -> np.ndarray: ...
805815
@overload
806816
def searchsorted(
807817
self,

tests/test_series.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1238,16 +1238,17 @@ def test_types_as_type() -> None:
12381238

12391239

12401240
def test_types_dot() -> None:
1241+
"""Test typing of multiplication methods (dot and @) for Series."""
12411242
s1 = pd.Series([0, 1, 2, 3])
12421243
s2 = pd.Series([-1, 2, -3, 4])
12431244
df1 = pd.DataFrame([[0, 1], [-2, 3], [4, -5], [6, 7]])
12441245
n1 = np.array([[0, 1], [1, 2], [-1, -1], [2, 0]])
1245-
sc1: Scalar = s1.dot(s2)
1246-
sc2: Scalar = s1 @ s2
1247-
s3: pd.Series = s1.dot(df1)
1248-
s4: pd.Series = s1 @ df1
1249-
n2: np.ndarray = s1.dot(n1)
1250-
n3: np.ndarray = s1 @ n1
1246+
check(assert_type(s1.dot(s2), Scalar), np.int64)
1247+
check(assert_type(s1 @ s2, Scalar), np.int64)
1248+
check(assert_type(s1.dot(df1), "pd.Series[int]"), pd.Series, np.int64)
1249+
check(assert_type(s1 @ df1, pd.Series), pd.Series)
1250+
check(assert_type(s1.dot(n1), np.ndarray), np.ndarray)
1251+
check(assert_type(s1 @ n1, np.ndarray), np.ndarray)
12511252

12521253

12531254
def test_series_loc_setitem() -> None:

0 commit comments

Comments
 (0)