Skip to content

Commit 70ee340

Browse files
GH1045 Split overload of groupby on as_index for all cases (#1046)
* GH1045 Split overload of groupby on as_index for all cases * GH1045 PR Feedback
1 parent 03396ef commit 70ee340

File tree

2 files changed

+123
-18
lines changed

2 files changed

+123
-18
lines changed

pandas-stubs/core/frame.pyi

Lines changed: 86 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1112,7 +1112,7 @@ class DataFrame(NDFrame, OpsMixin):
11121112
dropna: _bool = ...,
11131113
) -> DataFrameGroupBy[Timestamp, Literal[True]]: ...
11141114
@overload
1115-
def groupby(
1115+
def groupby( # pyright: ignore reportOverlappingOverload
11161116
self,
11171117
by: DatetimeIndex,
11181118
axis: AxisIndex | NoDefault = ...,
@@ -1124,77 +1124,149 @@ class DataFrame(NDFrame, OpsMixin):
11241124
dropna: _bool = ...,
11251125
) -> DataFrameGroupBy[Timestamp, Literal[False]]: ...
11261126
@overload
1127-
def groupby(
1127+
def groupby( # pyright: ignore reportOverlappingOverload
11281128
self,
11291129
by: TimedeltaIndex,
11301130
axis: AxisIndex | NoDefault = ...,
11311131
level: IndexLabel | None = ...,
1132-
as_index: _bool = ...,
1132+
as_index: Literal[True] = True,
11331133
sort: _bool = ...,
11341134
group_keys: _bool = ...,
11351135
observed: _bool | NoDefault = ...,
11361136
dropna: _bool = ...,
1137-
) -> DataFrameGroupBy[Timedelta, bool]: ...
1137+
) -> DataFrameGroupBy[Timedelta, Literal[True]]: ...
11381138
@overload
11391139
def groupby(
1140+
self,
1141+
by: TimedeltaIndex,
1142+
axis: AxisIndex | NoDefault = ...,
1143+
level: IndexLabel | None = ...,
1144+
as_index: Literal[False] = ...,
1145+
sort: _bool = ...,
1146+
group_keys: _bool = ...,
1147+
observed: _bool | NoDefault = ...,
1148+
dropna: _bool = ...,
1149+
) -> DataFrameGroupBy[Timedelta, Literal[False]]: ...
1150+
@overload
1151+
def groupby( # pyright: ignore reportOverlappingOverload
11401152
self,
11411153
by: PeriodIndex,
11421154
axis: AxisIndex | NoDefault = ...,
11431155
level: IndexLabel | None = ...,
1144-
as_index: _bool = ...,
1156+
as_index: Literal[True] = True,
11451157
sort: _bool = ...,
11461158
group_keys: _bool = ...,
11471159
observed: _bool | NoDefault = ...,
11481160
dropna: _bool = ...,
1149-
) -> DataFrameGroupBy[Period, bool]: ...
1161+
) -> DataFrameGroupBy[Period, Literal[True]]: ...
11501162
@overload
11511163
def groupby(
1164+
self,
1165+
by: PeriodIndex,
1166+
axis: AxisIndex | NoDefault = ...,
1167+
level: IndexLabel | None = ...,
1168+
as_index: Literal[False] = ...,
1169+
sort: _bool = ...,
1170+
group_keys: _bool = ...,
1171+
observed: _bool | NoDefault = ...,
1172+
dropna: _bool = ...,
1173+
) -> DataFrameGroupBy[Period, Literal[False]]: ...
1174+
@overload
1175+
def groupby( # pyright: ignore reportOverlappingOverload
11521176
self,
11531177
by: IntervalIndex[IntervalT],
11541178
axis: AxisIndex | NoDefault = ...,
11551179
level: IndexLabel | None = ...,
1156-
as_index: _bool = ...,
1180+
as_index: Literal[True] = True,
11571181
sort: _bool = ...,
11581182
group_keys: _bool = ...,
11591183
observed: _bool | NoDefault = ...,
11601184
dropna: _bool = ...,
1161-
) -> DataFrameGroupBy[IntervalT, bool]: ...
1185+
) -> DataFrameGroupBy[IntervalT, Literal[True]]: ...
11621186
@overload
11631187
def groupby(
1188+
self,
1189+
by: IntervalIndex[IntervalT],
1190+
axis: AxisIndex | NoDefault = ...,
1191+
level: IndexLabel | None = ...,
1192+
as_index: Literal[False] = ...,
1193+
sort: _bool = ...,
1194+
group_keys: _bool = ...,
1195+
observed: _bool | NoDefault = ...,
1196+
dropna: _bool = ...,
1197+
) -> DataFrameGroupBy[IntervalT, Literal[False]]: ...
1198+
@overload
1199+
def groupby( # type: ignore[overload-overlap] # pyright: ignore reportOverlappingOverload
11641200
self,
11651201
by: MultiIndex | GroupByObjectNonScalar | None = ...,
11661202
axis: AxisIndex | NoDefault = ...,
11671203
level: IndexLabel | None = ...,
1168-
as_index: _bool = ...,
1204+
as_index: Literal[True] = True,
11691205
sort: _bool = ...,
11701206
group_keys: _bool = ...,
11711207
observed: _bool | NoDefault = ...,
11721208
dropna: _bool = ...,
1173-
) -> DataFrameGroupBy[tuple, bool]: ...
1209+
) -> DataFrameGroupBy[tuple, Literal[True]]: ...
1210+
@overload
1211+
def groupby( # type: ignore[overload-overlap]
1212+
self,
1213+
by: MultiIndex | GroupByObjectNonScalar | None = ...,
1214+
axis: AxisIndex | NoDefault = ...,
1215+
level: IndexLabel | None = ...,
1216+
as_index: Literal[False] = ...,
1217+
sort: _bool = ...,
1218+
group_keys: _bool = ...,
1219+
observed: _bool | NoDefault = ...,
1220+
dropna: _bool = ...,
1221+
) -> DataFrameGroupBy[tuple, Literal[False]]: ...
1222+
@overload
1223+
def groupby( # pyright: ignore reportOverlappingOverload
1224+
self,
1225+
by: Series[SeriesByT],
1226+
axis: AxisIndex | NoDefault = ...,
1227+
level: IndexLabel | None = ...,
1228+
as_index: Literal[True] = True,
1229+
sort: _bool = ...,
1230+
group_keys: _bool = ...,
1231+
observed: _bool | NoDefault = ...,
1232+
dropna: _bool = ...,
1233+
) -> DataFrameGroupBy[SeriesByT, Literal[True]]: ...
11741234
@overload
11751235
def groupby(
11761236
self,
11771237
by: Series[SeriesByT],
11781238
axis: AxisIndex | NoDefault = ...,
11791239
level: IndexLabel | None = ...,
1180-
as_index: _bool = ...,
1240+
as_index: Literal[False] = ...,
1241+
sort: _bool = ...,
1242+
group_keys: _bool = ...,
1243+
observed: _bool | NoDefault = ...,
1244+
dropna: _bool = ...,
1245+
) -> DataFrameGroupBy[SeriesByT, Literal[False]]: ...
1246+
@overload
1247+
def groupby(
1248+
self,
1249+
by: CategoricalIndex | Index | Series,
1250+
axis: AxisIndex | NoDefault = ...,
1251+
level: IndexLabel | None = ...,
1252+
as_index: Literal[True] = True,
11811253
sort: _bool = ...,
11821254
group_keys: _bool = ...,
11831255
observed: _bool | NoDefault = ...,
11841256
dropna: _bool = ...,
1185-
) -> DataFrameGroupBy[SeriesByT, bool]: ...
1257+
) -> DataFrameGroupBy[Any, Literal[True]]: ...
11861258
@overload
11871259
def groupby(
11881260
self,
11891261
by: CategoricalIndex | Index | Series,
11901262
axis: AxisIndex | NoDefault = ...,
11911263
level: IndexLabel | None = ...,
1192-
as_index: _bool = ...,
1264+
as_index: Literal[False] = ...,
11931265
sort: _bool = ...,
11941266
group_keys: _bool = ...,
11951267
observed: _bool | NoDefault = ...,
11961268
dropna: _bool = ...,
1197-
) -> DataFrameGroupBy[Any, bool]: ...
1269+
) -> DataFrameGroupBy[Any, Literal[False]]: ...
11981270
def pivot(
11991271
self,
12001272
*,

tests/test_frame.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -504,8 +504,8 @@ def test_types_mean() -> None:
504504
s2: pd.Series = df.mean(axis=0)
505505
df2: pd.DataFrame = df.groupby(level=0).mean()
506506
if TYPE_CHECKING_INVALID_USAGE:
507-
df3: pd.DataFrame = df.groupby(axis=1, level=0).mean() # type: ignore[call-overload] # pyright: ignore[reportArgumentType]
508-
df4: pd.DataFrame = df.groupby(axis=1, level=0, dropna=True).mean() # type: ignore[call-overload] # pyright: ignore[reportArgumentType]
507+
df3: pd.DataFrame = df.groupby(axis=1, level=0).mean() # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
508+
df4: pd.DataFrame = df.groupby(axis=1, level=0, dropna=True).mean() # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
509509
s3: pd.Series = df.mean(axis=1, skipna=True, numeric_only=False)
510510

511511

@@ -515,8 +515,8 @@ def test_types_median() -> None:
515515
s2: pd.Series = df.median(axis=0)
516516
df2: pd.DataFrame = df.groupby(level=0).median()
517517
if TYPE_CHECKING_INVALID_USAGE:
518-
df3: pd.DataFrame = df.groupby(axis=1, level=0).median() # type: ignore[call-overload] # pyright: ignore[reportArgumentType]
519-
df4: pd.DataFrame = df.groupby(axis=1, level=0, dropna=True).median() # type: ignore[call-overload] # pyright: ignore[reportArgumentType]
518+
df3: pd.DataFrame = df.groupby(axis=1, level=0).median() # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
519+
df4: pd.DataFrame = df.groupby(axis=1, level=0, dropna=True).median() # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
520520
s3: pd.Series = df.median(axis=1, skipna=True, numeric_only=False)
521521

522522

@@ -1064,6 +1064,39 @@ def test_types_groupby_as_index() -> None:
10641064
),
10651065
pd.Series,
10661066
)
1067+
check(
1068+
assert_type(
1069+
df.groupby("a").size(),
1070+
"pd.Series[int]",
1071+
),
1072+
pd.Series,
1073+
)
1074+
1075+
1076+
def test_types_groupby_as_index_list() -> None:
1077+
"""Test type of groupby.size method depending on list of grouper GH1045."""
1078+
df = pd.DataFrame({"a": [1, 1, 2], "b": [2, 3, 2]})
1079+
check(
1080+
assert_type(
1081+
df.groupby(["a", "b"], as_index=False).size(),
1082+
pd.DataFrame,
1083+
),
1084+
pd.DataFrame,
1085+
)
1086+
check(
1087+
assert_type(
1088+
df.groupby(["a", "b"], as_index=True).size(),
1089+
"pd.Series[int]",
1090+
),
1091+
pd.Series,
1092+
)
1093+
check(
1094+
assert_type(
1095+
df.groupby(["a", "b"]).size(),
1096+
"pd.Series[int]",
1097+
),
1098+
pd.Series,
1099+
)
10671100

10681101

10691102
def test_types_groupby_as_index_value_counts() -> None:

0 commit comments

Comments
 (0)