Skip to content

Commit f570ba2

Browse files
authored
TYP: SelectionMixin (#41384)
1 parent e3a9618 commit f570ba2

File tree

6 files changed

+33
-55
lines changed

6 files changed

+33
-55
lines changed

pandas/_typing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
from pandas.core.generic import NDFrame
5757
from pandas.core.groupby.generic import (
5858
DataFrameGroupBy,
59+
GroupBy,
5960
SeriesGroupBy,
6061
)
6162
from pandas.core.indexes.base import Index
@@ -158,6 +159,7 @@
158159
AggObjType = Union[
159160
"Series",
160161
"DataFrame",
162+
"GroupBy",
161163
"SeriesGroupBy",
162164
"DataFrameGroupBy",
163165
"BaseWindow",

pandas/core/apply.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
AggFuncTypeDict,
2525
AggObjType,
2626
Axis,
27+
FrameOrSeries,
2728
FrameOrSeriesUnion,
2829
)
2930
from pandas.util._decorators import cache_readonly
@@ -60,10 +61,7 @@
6061
Index,
6162
Series,
6263
)
63-
from pandas.core.groupby import (
64-
DataFrameGroupBy,
65-
SeriesGroupBy,
66-
)
64+
from pandas.core.groupby import GroupBy
6765
from pandas.core.resample import Resampler
6866
from pandas.core.window.rolling import BaseWindow
6967

@@ -1089,11 +1087,9 @@ def apply_standard(self) -> FrameOrSeriesUnion:
10891087

10901088

10911089
class GroupByApply(Apply):
1092-
obj: SeriesGroupBy | DataFrameGroupBy
1093-
10941090
def __init__(
10951091
self,
1096-
obj: SeriesGroupBy | DataFrameGroupBy,
1092+
obj: GroupBy[FrameOrSeries],
10971093
func: AggFuncType,
10981094
args,
10991095
kwargs,

pandas/core/base.py

Lines changed: 18 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from typing import (
99
TYPE_CHECKING,
1010
Any,
11+
Generic,
12+
Hashable,
1113
TypeVar,
1214
cast,
1315
)
@@ -19,6 +21,7 @@
1921
ArrayLike,
2022
Dtype,
2123
DtypeObj,
24+
FrameOrSeries,
2225
IndexLabel,
2326
Shape,
2427
final,
@@ -165,13 +168,15 @@ class SpecificationError(Exception):
165168
pass
166169

167170

168-
class SelectionMixin:
171+
class SelectionMixin(Generic[FrameOrSeries]):
169172
"""
170173
mixin implementing the selection & aggregation interface on a group-like
171174
object sub-classes need to define: obj, exclusions
172175
"""
173176

177+
obj: FrameOrSeries
174178
_selection: IndexLabel | None = None
179+
exclusions: frozenset[Hashable]
175180
_internal_names = ["_cache", "__setstate__"]
176181
_internal_names_set = set(_internal_names)
177182

@@ -196,15 +201,10 @@ def _selection_list(self):
196201

197202
@cache_readonly
198203
def _selected_obj(self):
199-
# error: "SelectionMixin" has no attribute "obj"
200-
if self._selection is None or isinstance(
201-
self.obj, ABCSeries # type: ignore[attr-defined]
202-
):
203-
# error: "SelectionMixin" has no attribute "obj"
204-
return self.obj # type: ignore[attr-defined]
204+
if self._selection is None or isinstance(self.obj, ABCSeries):
205+
return self.obj
205206
else:
206-
# error: "SelectionMixin" has no attribute "obj"
207-
return self.obj[self._selection] # type: ignore[attr-defined]
207+
return self.obj[self._selection]
208208

209209
@cache_readonly
210210
def ndim(self) -> int:
@@ -213,49 +213,31 @@ def ndim(self) -> int:
213213
@final
214214
@cache_readonly
215215
def _obj_with_exclusions(self):
216-
# error: "SelectionMixin" has no attribute "obj"
217-
if self._selection is not None and isinstance(
218-
self.obj, ABCDataFrame # type: ignore[attr-defined]
219-
):
220-
# error: "SelectionMixin" has no attribute "obj"
221-
return self.obj.reindex( # type: ignore[attr-defined]
222-
columns=self._selection_list
223-
)
216+
if self._selection is not None and isinstance(self.obj, ABCDataFrame):
217+
return self.obj.reindex(columns=self._selection_list)
224218

225-
# error: "SelectionMixin" has no attribute "exclusions"
226-
if len(self.exclusions) > 0: # type: ignore[attr-defined]
227-
# error: "SelectionMixin" has no attribute "obj"
228-
# error: "SelectionMixin" has no attribute "exclusions"
229-
return self.obj.drop(self.exclusions, axis=1) # type: ignore[attr-defined]
219+
if len(self.exclusions) > 0:
220+
return self.obj.drop(self.exclusions, axis=1)
230221
else:
231-
# error: "SelectionMixin" has no attribute "obj"
232-
return self.obj # type: ignore[attr-defined]
222+
return self.obj
233223

234224
def __getitem__(self, key):
235225
if self._selection is not None:
236226
raise IndexError(f"Column(s) {self._selection} already selected")
237227

238228
if isinstance(key, (list, tuple, ABCSeries, ABCIndex, np.ndarray)):
239-
# error: "SelectionMixin" has no attribute "obj"
240-
if len(
241-
self.obj.columns.intersection(key) # type: ignore[attr-defined]
242-
) != len(key):
243-
# error: "SelectionMixin" has no attribute "obj"
244-
bad_keys = list(
245-
set(key).difference(self.obj.columns) # type: ignore[attr-defined]
246-
)
229+
if len(self.obj.columns.intersection(key)) != len(key):
230+
bad_keys = list(set(key).difference(self.obj.columns))
247231
raise KeyError(f"Columns not found: {str(bad_keys)[1:-1]}")
248232
return self._gotitem(list(key), ndim=2)
249233

250234
elif not getattr(self, "as_index", False):
251-
# error: "SelectionMixin" has no attribute "obj"
252-
if key not in self.obj.columns: # type: ignore[attr-defined]
235+
if key not in self.obj.columns:
253236
raise KeyError(f"Column not found: {key}")
254237
return self._gotitem(key, ndim=2)
255238

256239
else:
257-
# error: "SelectionMixin" has no attribute "obj"
258-
if key not in self.obj: # type: ignore[attr-defined]
240+
if key not in self.obj:
259241
raise KeyError(f"Column not found: {key}")
260242
return self._gotitem(key, ndim=1)
261243

pandas/core/groupby/groupby.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ class providing the base-class of operations.
2020
from typing import (
2121
TYPE_CHECKING,
2222
Callable,
23-
Generic,
2423
Hashable,
2524
Iterable,
2625
Iterator,
@@ -567,7 +566,7 @@ def group_selection_context(groupby: GroupBy) -> Iterator[GroupBy]:
567566
]
568567

569568

570-
class BaseGroupBy(PandasObject, SelectionMixin, Generic[FrameOrSeries]):
569+
class BaseGroupBy(PandasObject, SelectionMixin[FrameOrSeries]):
571570
_group_selection: IndexLabel | None = None
572571
_apply_allowlist: frozenset[str] = frozenset()
573572
_hidden_attrs = PandasObject._hidden_attrs | {
@@ -588,7 +587,6 @@ class BaseGroupBy(PandasObject, SelectionMixin, Generic[FrameOrSeries]):
588587

589588
axis: int
590589
grouper: ops.BaseGrouper
591-
obj: FrameOrSeries
592590
group_keys: bool
593591

594592
@final
@@ -840,7 +838,6 @@ class GroupBy(BaseGroupBy[FrameOrSeries]):
840838
more
841839
"""
842840

843-
obj: FrameOrSeries
844841
grouper: ops.BaseGrouper
845842
as_index: bool
846843

@@ -852,7 +849,7 @@ def __init__(
852849
axis: int = 0,
853850
level: IndexLabel | None = None,
854851
grouper: ops.BaseGrouper | None = None,
855-
exclusions: set[Hashable] | None = None,
852+
exclusions: frozenset[Hashable] | None = None,
856853
selection: IndexLabel | None = None,
857854
as_index: bool = True,
858855
sort: bool = True,
@@ -901,7 +898,7 @@ def __init__(
901898
self.obj = obj
902899
self.axis = obj._get_axis_number(axis)
903900
self.grouper = grouper
904-
self.exclusions = exclusions or set()
901+
self.exclusions = frozenset(exclusions) if exclusions else frozenset()
905902

906903
def __getattr__(self, attr: str):
907904
if attr in self._internal_names_set:

pandas/core/groupby/grouper.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -652,7 +652,7 @@ def get_grouper(
652652
mutated: bool = False,
653653
validate: bool = True,
654654
dropna: bool = True,
655-
) -> tuple[ops.BaseGrouper, set[Hashable], FrameOrSeries]:
655+
) -> tuple[ops.BaseGrouper, frozenset[Hashable], FrameOrSeries]:
656656
"""
657657
Create and return a BaseGrouper, which is an internal
658658
mapping of how to create the grouper indexers.
@@ -728,13 +728,13 @@ def get_grouper(
728728
if isinstance(key, Grouper):
729729
binner, grouper, obj = key._get_grouper(obj, validate=False)
730730
if key.key is None:
731-
return grouper, set(), obj
731+
return grouper, frozenset(), obj
732732
else:
733-
return grouper, {key.key}, obj
733+
return grouper, frozenset({key.key}), obj
734734

735735
# already have a BaseGrouper, just return it
736736
elif isinstance(key, ops.BaseGrouper):
737-
return key, set(), obj
737+
return key, frozenset(), obj
738738

739739
if not isinstance(key, list):
740740
keys = [key]
@@ -861,7 +861,7 @@ def is_in_obj(gpr) -> bool:
861861
grouper = ops.BaseGrouper(
862862
group_axis, groupings, sort=sort, mutated=mutated, dropna=dropna
863863
)
864-
return grouper, exclusions, obj
864+
return grouper, frozenset(exclusions), obj
865865

866866

867867
def _is_label_like(val) -> bool:

pandas/core/window/rolling.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
TYPE_CHECKING,
1414
Any,
1515
Callable,
16+
Hashable,
1617
)
1718
import warnings
1819

@@ -109,7 +110,7 @@ class BaseWindow(SelectionMixin):
109110
"""Provides utilities for performing windowing operations."""
110111

111112
_attributes: list[str] = []
112-
exclusions: set[str] = set()
113+
exclusions: frozenset[Hashable] = frozenset()
113114

114115
def __init__(
115116
self,

0 commit comments

Comments
 (0)