Skip to content

Commit d23c4bb

Browse files
authored
Add overload for DataFrameGroupBy.groupby("size") return Series (#739)
* Add overload for DataFrameGroupBy.groupby("size") return Series * Switch SeriesGroupBy.agg to assignment * Move tests to test_types_groupby_agg * Change error comment to different return types
1 parent 6fe90bb commit d23c4bb

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

pandas-stubs/core/groupby/generic.pyi

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,7 @@ class SeriesGroupBy(GroupBy, Generic[S1, ByT]):
5454
def aggregate(self, func: list[AggFuncTypeBase], *args, **kwargs) -> DataFrame: ...
5555
@overload
5656
def aggregate(self, func: AggFuncTypeBase, *args, **kwargs) -> Series: ...
57-
@overload
58-
def agg(self, func: list[AggFuncTypeBase], *args, **kwargs) -> DataFrame: ...
59-
@overload
60-
def agg(self, func: AggFuncTypeBase, *args, **kwargs) -> Series: ...
57+
agg = aggregate
6158
def transform(self, func: Callable | str, *args, **kwargs) -> Series: ...
6259
def filter(self, func, dropna: bool = ..., *args, **kwargs): ...
6360
def nunique(self, dropna: bool = ...) -> Series: ...
@@ -159,6 +156,10 @@ class DataFrameGroupBy(GroupBy, Generic[ByT]):
159156
def apply( # pyright: ignore[reportOverlappingOverload]
160157
self, func: Callable[[Iterable], float], *args, **kwargs
161158
) -> DataFrame: ...
159+
# error: overload 1 overlaps overload 2 because of different return types
160+
@overload
161+
def aggregate(self, arg: Literal["size"]) -> Series: ... # type: ignore[misc] # pyright: ignore[reportOverlappingOverload]
162+
@overload
162163
def aggregate(self, arg: AggFuncTypeFrame = ..., *args, **kwargs) -> DataFrame: ...
163164
agg = aggregate
164165
def transform(self, func: Callable | str, *args, **kwargs) -> DataFrame: ...

tests/test_frame.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -951,6 +951,9 @@ def wrapped_min(x: Any) -> Any:
951951

952952
cols_mixed: list[str | int] = ["col1", 0]
953953
check(assert_type(df.groupby(by=cols_mixed).sum(), pd.DataFrame), pd.DataFrame)
954+
# GH 736
955+
check(assert_type(df.groupby(by="col1").aggregate("size"), pd.Series), pd.Series)
956+
check(assert_type(df.groupby(by="col1").agg("size"), pd.Series), pd.Series)
954957

955958

956959
# This was added in 1.1.0 https://pandas.pydata.org/docs/whatsnew/v1.1.0.html

0 commit comments

Comments
 (0)