Skip to content

Commit 650ab62

Browse files
Dr-Irvtwoertwein
authored andcommitted
add operators for Index (pandas-dev#504)
* add operators for Index * Fix OpsMixin using TypeVar * detect Never as argument
1 parent 8c3be26 commit 650ab62

File tree

7 files changed

+133
-75
lines changed

7 files changed

+133
-75
lines changed

pandas-stubs/core/arraylike.pyi

Lines changed: 30 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,40 @@
11
from typing import (
22
Any,
3-
Protocol,
3+
TypeVar,
44
)
55

6-
from pandas import DataFrame
7-
8-
class OpsMixinProtocol(Protocol): ...
6+
_OpsMixinT = TypeVar("_OpsMixinT", bound=OpsMixin)
97

108
class OpsMixin:
11-
def __eq__(self: OpsMixinProtocol, other: object) -> DataFrame: ... # type: ignore[override]
12-
def __ne__(self: OpsMixinProtocol, other: object) -> DataFrame: ... # type: ignore[override]
13-
def __lt__(self: OpsMixinProtocol, other: Any) -> DataFrame: ...
14-
def __le__(self: OpsMixinProtocol, other: Any) -> DataFrame: ...
15-
def __gt__(self: OpsMixinProtocol, other: Any) -> DataFrame: ...
16-
def __ge__(self: OpsMixinProtocol, other: Any) -> DataFrame: ...
9+
def __eq__(self: _OpsMixinT, other: object) -> _OpsMixinT: ... # type: ignore[override]
10+
def __ne__(self: _OpsMixinT, other: object) -> _OpsMixinT: ... # type: ignore[override]
11+
def __lt__(self: _OpsMixinT, other: Any) -> _OpsMixinT: ...
12+
def __le__(self: _OpsMixinT, other: Any) -> _OpsMixinT: ...
13+
def __gt__(self: _OpsMixinT, other: Any) -> _OpsMixinT: ...
14+
def __ge__(self: _OpsMixinT, other: Any) -> _OpsMixinT: ...
1715
# -------------------------------------------------------------
1816
# Logical Methods
19-
def __and__(self: OpsMixinProtocol, other: Any) -> DataFrame: ...
20-
def __rand__(self: OpsMixinProtocol, other: Any) -> DataFrame: ...
21-
def __or__(self: OpsMixinProtocol, other: Any) -> DataFrame: ...
22-
def __ror__(self: OpsMixinProtocol, other: Any) -> DataFrame: ...
23-
def __xor__(self: OpsMixinProtocol, other: Any) -> DataFrame: ...
24-
def __rxor__(self: OpsMixinProtocol, other: Any) -> DataFrame: ...
17+
def __and__(self: _OpsMixinT, other: Any) -> _OpsMixinT: ...
18+
def __rand__(self: _OpsMixinT, other: Any) -> _OpsMixinT: ...
19+
def __or__(self: _OpsMixinT, other: Any) -> _OpsMixinT: ...
20+
def __ror__(self: _OpsMixinT, other: Any) -> _OpsMixinT: ...
21+
def __xor__(self: _OpsMixinT, other: Any) -> _OpsMixinT: ...
22+
def __rxor__(self: _OpsMixinT, other: Any) -> _OpsMixinT: ...
2523
# -------------------------------------------------------------
2624
# Arithmetic Methods
27-
def __add__(self: OpsMixinProtocol, other: Any) -> DataFrame: ...
28-
def __radd__(self: OpsMixinProtocol, other: Any) -> DataFrame: ...
29-
def __sub__(self: OpsMixinProtocol, other: Any) -> DataFrame: ...
30-
def __rsub__(self: OpsMixinProtocol, other: Any) -> DataFrame: ...
31-
def __mul__(self: OpsMixinProtocol, other: Any) -> DataFrame: ...
32-
def __rmul__(self: OpsMixinProtocol, other: Any) -> DataFrame: ...
33-
def __truediv__(self: OpsMixinProtocol, other: Any) -> DataFrame: ...
34-
def __rtruediv__(self: OpsMixinProtocol, other: Any) -> DataFrame: ...
35-
def __floordiv__(self: OpsMixinProtocol, other: Any) -> DataFrame: ...
36-
def __rfloordiv__(self: OpsMixinProtocol, other: Any) -> DataFrame: ...
37-
def __mod__(self: OpsMixinProtocol, other: Any) -> DataFrame: ...
38-
def __rmod__(self: OpsMixinProtocol, other: Any) -> DataFrame: ...
39-
def __divmod__(
40-
self: OpsMixinProtocol, other: DataFrame
41-
) -> tuple[DataFrame, DataFrame]: ...
42-
def __rdivmod__(
43-
self: OpsMixinProtocol, other: DataFrame
44-
) -> tuple[DataFrame, DataFrame]: ...
45-
def __pow__(self: OpsMixinProtocol, other: Any) -> DataFrame: ...
46-
def __rpow__(self: OpsMixinProtocol, other: Any) -> DataFrame: ...
25+
def __add__(self: _OpsMixinT, other: Any) -> _OpsMixinT: ...
26+
def __radd__(self: _OpsMixinT, other: Any) -> _OpsMixinT: ...
27+
def __sub__(self: _OpsMixinT, other: Any) -> _OpsMixinT: ...
28+
def __rsub__(self: _OpsMixinT, other: Any) -> _OpsMixinT: ...
29+
def __mul__(self: _OpsMixinT, other: Any) -> _OpsMixinT: ...
30+
def __rmul__(self: _OpsMixinT, other: Any) -> _OpsMixinT: ...
31+
def __truediv__(self: _OpsMixinT, other: Any) -> _OpsMixinT: ...
32+
def __rtruediv__(self: _OpsMixinT, other: Any) -> _OpsMixinT: ...
33+
def __floordiv__(self: _OpsMixinT, other: Any) -> _OpsMixinT: ...
34+
def __rfloordiv__(self: _OpsMixinT, other: Any) -> _OpsMixinT: ...
35+
def __mod__(self: _OpsMixinT, other: Any) -> _OpsMixinT: ...
36+
def __rmod__(self: _OpsMixinT, other: Any) -> _OpsMixinT: ...
37+
def __divmod__(self: _OpsMixinT, other: Any) -> tuple[_OpsMixinT, _OpsMixinT]: ...
38+
def __rdivmod__(self: _OpsMixinT, other: Any) -> tuple[_OpsMixinT, _OpsMixinT]: ...
39+
def __pow__(self: _OpsMixinT, other: Any) -> _OpsMixinT: ...
40+
def __rpow__(self: _OpsMixinT, other: Any) -> _OpsMixinT: ...

pandas-stubs/core/base.pyi

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ from typing import (
66
import numpy as np
77
from pandas import Index
88
from pandas.core.accessor import DirNamesMixin
9+
from pandas.core.arraylike import OpsMixin
910
from pandas.core.arrays import ExtensionArray
1011
from pandas.core.arrays.categorical import Categorical
1112

@@ -27,7 +28,7 @@ class SelectionMixin(Generic[NDFrameT]):
2728
def ndim(self) -> int: ...
2829
def __getitem__(self, key): ...
2930

30-
class IndexOpsMixin:
31+
class IndexOpsMixin(OpsMixin):
3132
__array_priority__: int = ...
3233
def transpose(self, *args, **kwargs) -> IndexOpsMixin: ...
3334
@property

pandas-stubs/core/indexes/base.pyi

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ from pandas.core.base import (
2424
)
2525
from pandas.core.indexes.numeric import NumericIndex
2626
from pandas.core.strings import StringMethods
27+
from typing_extensions import Never
2728

2829
from pandas._typing import (
2930
T1,
@@ -147,14 +148,12 @@ class Index(IndexOpsMixin, PandasObject):
147148
def duplicated(
148149
self, keep: Literal["first", "last", False] = ...
149150
) -> np_ndarray_bool: ...
150-
def __add__(self, other) -> Index: ...
151-
def __radd__(self, other) -> Index: ...
152-
def __iadd__(self, other) -> Index: ...
153-
def __sub__(self, other) -> Index: ...
154-
def __rsub__(self, other) -> Index: ...
155-
def __and__(self, other) -> Index: ...
156-
def __or__(self, other) -> Index: ...
157-
def __xor__(self, other) -> Index: ...
151+
def __and__(self, other: Never) -> Never: ...
152+
def __rand__(self, other: Never) -> Never: ...
153+
def __or__(self, other: Never) -> Never: ...
154+
def __ror__(self, other: Never) -> Never: ...
155+
def __xor__(self, other: Never) -> Never: ...
156+
def __rxor__(self, other: Never) -> Never: ...
158157
def __neg__(self: IndexT) -> IndexT: ...
159158
def __nonzero__(self) -> None: ...
160159
__bool__ = ...
@@ -226,10 +225,10 @@ class Index(IndexOpsMixin, PandasObject):
226225
def __eq__(self, other: object) -> np_ndarray_bool: ... # type: ignore[override]
227226
def __iter__(self) -> Iterator[IndexIterScalar | tuple[Hashable, ...]]: ...
228227
def __ne__(self, other: object) -> np_ndarray_bool: ... # type: ignore[override]
229-
def __le__(self, other: Index | Scalar) -> np_ndarray_bool: ...
230-
def __ge__(self, other: Index | Scalar) -> np_ndarray_bool: ...
231-
def __lt__(self, other: Index | Scalar) -> np_ndarray_bool: ...
232-
def __gt__(self, other: Index | Scalar) -> np_ndarray_bool: ...
228+
def __le__(self, other: Index | Scalar) -> np_ndarray_bool: ... # type: ignore[override]
229+
def __ge__(self, other: Index | Scalar) -> np_ndarray_bool: ... # type: ignore[override]
230+
def __lt__(self, other: Index | Scalar) -> np_ndarray_bool: ... # type: ignore[override]
231+
def __gt__(self, other: Index | Scalar) -> np_ndarray_bool: ... # type: ignore[override]
233232

234233
def ensure_index_from_sequences(
235234
sequences: Sequence[Sequence[Dtype]], names: list[str] = ...

pandas-stubs/core/series.pyi

Lines changed: 12 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1278,7 +1278,7 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
12781278
self, other: num | _str | Timedelta | _ListLike | Series[S1] | np.timedelta64
12791279
) -> Series: ...
12801280
# ignore needed for mypy as we want different results based on the arguments
1281-
@overload
1281+
@overload # type: ignore[override]
12821282
def __and__( # type: ignore[misc]
12831283
self, other: bool | list[bool] | np_ndarray_bool | Series[bool]
12841284
) -> Series[bool]: ...
@@ -1289,29 +1289,17 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
12891289
# def __array__(self, dtype: Optional[_bool] = ...) -> _np_ndarray
12901290
def __div__(self, other: num | _ListLike | Series[S1]) -> Series[S1]: ...
12911291
def __eq__(self, other: object) -> Series[_bool]: ... # type: ignore[override]
1292-
def __floordiv__(self, other: num | _ListLike | Series[S1]) -> Series[int]: ...
1293-
def __ge__(
1292+
def __floordiv__(self, other: num | _ListLike | Series[S1]) -> Series[int]: ... # type: ignore[override]
1293+
def __ge__( # type: ignore[override]
12941294
self, other: S1 | _ListLike | Series[S1] | datetime | timedelta
12951295
) -> Series[_bool]: ...
1296-
def __gt__(
1296+
def __gt__( # type: ignore[override]
12971297
self, other: S1 | _ListLike | Series[S1] | datetime | timedelta
12981298
) -> Series[_bool]: ...
1299-
# def __iadd__(self, other: S1) -> Series[S1]: ...
1300-
# def __iand__(self, other: S1) -> Series[_bool]: ...
1301-
# def __idiv__(self, other: S1) -> Series[S1]: ...
1302-
# def __ifloordiv__(self, other: S1) -> Series[S1]: ...
1303-
# def __imod__(self, other: S1) -> Series[S1]: ...
1304-
# def __imul__(self, other: S1) -> Series[S1]: ...
1305-
# def __ior__(self, other: S1) -> Series[_bool]: ...
1306-
# def __ipow__(self, other: S1) -> Series[S1]: ...
1307-
# def __isub__(self, other: S1) -> Series[S1]: ...
1308-
# def __itruediv__(self, other: S1) -> Series[S1]: ...
1309-
# def __itruediv__(self, other) -> None: ...
1310-
# def __ixor__(self, other: S1) -> Series[_bool]: ...
1311-
def __le__(
1299+
def __le__( # type: ignore[override]
13121300
self, other: S1 | _ListLike | Series[S1] | datetime | timedelta
13131301
) -> Series[_bool]: ...
1314-
def __lt__(
1302+
def __lt__( # type: ignore[override]
13151303
self, other: S1 | _ListLike | Series[S1] | datetime | timedelta
13161304
) -> Series[_bool]: ...
13171305
@overload
@@ -1324,7 +1312,7 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
13241312
def __ne__(self, other: object) -> Series[_bool]: ... # type: ignore[override]
13251313
def __pow__(self, other: num | _ListLike | Series[S1]) -> Series[S1]: ...
13261314
# ignore needed for mypy as we want different results based on the arguments
1327-
@overload
1315+
@overload # type: ignore[override]
13281316
def __or__( # type: ignore[misc]
13291317
self, other: bool | list[bool] | np_ndarray_bool | Series[bool]
13301318
) -> Series[bool]: ...
@@ -1334,7 +1322,7 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
13341322
) -> Series[int]: ...
13351323
def __radd__(self, other: num | _str | _ListLike | Series[S1]) -> Series[S1]: ...
13361324
# ignore needed for mypy as we want different results based on the arguments
1337-
@overload
1325+
@overload # type: ignore[override]
13381326
def __rand__( # type: ignore[misc]
13391327
self, other: bool | list[bool] | np_ndarray_bool | Series[bool]
13401328
) -> Series[bool]: ...
@@ -1343,14 +1331,14 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
13431331
self, other: int | list[int] | np_ndarray_anyint | Series[int]
13441332
) -> Series[int]: ...
13451333
def __rdiv__(self, other: num | _ListLike | Series[S1]) -> Series[S1]: ...
1346-
def __rdivmod__(self, other: num | _ListLike | Series[S1]) -> Series[S1]: ...
1334+
def __rdivmod__(self, other: num | _ListLike | Series[S1]) -> Series[S1]: ... # type: ignore[override]
13471335
def __rfloordiv__(self, other: num | _ListLike | Series[S1]) -> Series[S1]: ...
13481336
def __rmod__(self, other: num | _ListLike | Series[S1]) -> Series[S1]: ...
13491337
def __rmul__(self, other: num | _ListLike | Series) -> Series: ...
13501338
def __rnatmul__(self, other: num | _ListLike | Series[S1]) -> Series[S1]: ...
13511339
def __rpow__(self, other: num | _ListLike | Series[S1]) -> Series[S1]: ...
13521340
# ignore needed for mypy as we want different results based on the arguments
1353-
@overload
1341+
@overload # type: ignore[override]
13541342
def __ror__( # type: ignore[misc]
13551343
self, other: bool | list[bool] | np_ndarray_bool | Series[bool]
13561344
) -> Series[bool]: ...
@@ -1364,7 +1352,7 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
13641352
@overload
13651353
def __rtruediv__(self, other: num | _ListLike | Series[S1]) -> Series: ...
13661354
# ignore needed for mypy as we want different results based on the arguments
1367-
@overload
1355+
@overload # type: ignore[override]
13681356
def __rxor__( # type: ignore[misc]
13691357
self, other: bool | list[bool] | np_ndarray_bool | Series[bool]
13701358
) -> Series[bool]: ...
@@ -1390,7 +1378,7 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
13901378
def __sub__(self, other: num | _ListLike | Series) -> Series: ...
13911379
def __truediv__(self, other: num | _ListLike | Series[S1]) -> Series: ...
13921380
# ignore needed for mypy as we want different results based on the arguments
1393-
@overload
1381+
@overload # type: ignore[override]
13941382
def __xor__( # type: ignore[misc]
13951383
self, other: bool | list[bool] | np_ndarray_bool | Series[bool]
13961384
) -> Series[bool]: ...

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ poethepoet = ">=0.16.5"
4343
loguru = ">=0.6.0"
4444
pandas = "1.5.3"
4545
numpy = ">=1.24.1"
46-
typing-extensions = ">=4.2.0"
46+
typing-extensions = ">=4.4.0"
4747
matplotlib = ">=3.5.1"
4848
pre-commit = ">=2.19.0"
4949
black = ">=22.12.0"

tests/test_indexes.py

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,18 @@
1313
from numpy import typing as npt
1414
import pandas as pd
1515
from pandas.core.indexes.numeric import NumericIndex
16-
from typing_extensions import assert_type
16+
from typing_extensions import (
17+
Never,
18+
assert_type,
19+
)
1720

1821
from pandas._typing import Scalar
1922

2023
if TYPE_CHECKING:
2124
from pandas._typing import IndexIterScalar
2225

2326
from tests import (
27+
TYPE_CHECKING_INVALID_USAGE,
2428
check,
2529
pytest_warns_bounded,
2630
)
@@ -684,3 +688,72 @@ def test_sorted_and_list() -> None:
684688
),
685689
list,
686690
)
691+
692+
693+
def test_index_operators() -> None:
694+
# GH 405
695+
i1 = pd.Index([1, 2, 3])
696+
i2 = pd.Index([4, 5, 6])
697+
698+
check(assert_type(i1 + i2, pd.Index), pd.Index)
699+
check(assert_type(i1 + 10, pd.Index), pd.Index)
700+
check(assert_type(10 + i1, pd.Index), pd.Index)
701+
check(assert_type(i1 - i2, pd.Index), pd.Index)
702+
check(assert_type(i1 - 10, pd.Index), pd.Index)
703+
check(assert_type(10 - i1, pd.Index), pd.Index)
704+
check(assert_type(i1 * i2, pd.Index), pd.Index)
705+
check(assert_type(i1 * 10, pd.Index), pd.Index)
706+
check(assert_type(10 * i1, pd.Index), pd.Index)
707+
check(assert_type(i1 / i2, pd.Index), pd.Index)
708+
check(assert_type(i1 / 10, pd.Index), pd.Index)
709+
check(assert_type(10 / i1, pd.Index), pd.Index)
710+
check(assert_type(i1 // i2, pd.Index), pd.Index)
711+
check(assert_type(i1 // 10, pd.Index), pd.Index)
712+
check(assert_type(10 // i1, pd.Index), pd.Index)
713+
check(assert_type(i1**i2, pd.Index), pd.Index)
714+
check(assert_type(i1**2, pd.Index), pd.Index)
715+
check(assert_type(2**i1, pd.Index), pd.Index)
716+
check(assert_type(i1 % i2, pd.Index), pd.Index)
717+
check(assert_type(i1 % 10, pd.Index), pd.Index)
718+
check(assert_type(10 % i1, pd.Index), pd.Index)
719+
check(assert_type(divmod(i1, i2), Tuple[pd.Index, pd.Index]), tuple)
720+
check(assert_type(divmod(i1, 10), Tuple[pd.Index, pd.Index]), tuple)
721+
check(assert_type(divmod(10, i1), Tuple[pd.Index, pd.Index]), tuple)
722+
723+
if TYPE_CHECKING_INVALID_USAGE:
724+
assert_type(
725+
i1 & i2, # type:ignore[operator] # pyright: ignore[reportGeneralTypeIssues]
726+
Never,
727+
)
728+
assert_type( # type: ignore[assert-type]
729+
i1 & 10, # type:ignore[operator] # pyright: ignore[reportGeneralTypeIssues]
730+
Never,
731+
)
732+
assert_type( # type: ignore[assert-type]
733+
10 & i1, # type:ignore[operator] # pyright: ignore[reportGeneralTypeIssues]
734+
Never,
735+
)
736+
assert_type(
737+
i1 | i2, # type:ignore[operator] # pyright: ignore[reportGeneralTypeIssues]
738+
Never,
739+
)
740+
assert_type( # type: ignore[assert-type]
741+
i1 | 10, # type:ignore[operator] # pyright: ignore[reportGeneralTypeIssues]
742+
Never,
743+
)
744+
assert_type( # type: ignore[assert-type]
745+
10 | i1, # type:ignore[operator] # pyright: ignore[reportGeneralTypeIssues]
746+
Never,
747+
)
748+
assert_type(
749+
i1 ^ i2, # type:ignore[operator] # pyright: ignore[reportGeneralTypeIssues]
750+
Never,
751+
)
752+
assert_type( # type: ignore[assert-type]
753+
i1 ^ 10, # type:ignore[operator] # pyright: ignore[reportGeneralTypeIssues]
754+
Never,
755+
)
756+
assert_type( # type: ignore[assert-type]
757+
10 ^ i1, # type:ignore[operator] # pyright: ignore[reportGeneralTypeIssues]
758+
Never,
759+
)

tests/test_series.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
Iterator,
1313
List,
1414
Sequence,
15+
Tuple,
1516
TypeVar,
1617
cast,
1718
)
@@ -452,6 +453,8 @@ def test_types_element_wise_arithmetic() -> None:
452453
res_pow: pd.Series = s ** s2.abs()
453454
res_pow2: pd.Series = s.pow(s2.abs(), fill_value=0)
454455

456+
check(assert_type(divmod(s, s2), Tuple[pd.Series, pd.Series]), tuple)
457+
455458

456459
def test_types_scalar_arithmetic() -> None:
457460
s = pd.Series([0, 1, -10])

0 commit comments

Comments
 (0)