Skip to content

Commit 2331098

Browse files
authored
REF: Share py_fallback (#41289)
1 parent 5272b56 commit 2331098

File tree

2 files changed

+59
-79
lines changed

2 files changed

+59
-79
lines changed

pandas/core/groupby/generic.py

Lines changed: 10 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@
6969
validate_func_kwargs,
7070
)
7171
from pandas.core.apply import GroupByApply
72-
from pandas.core.arrays import Categorical
7372
from pandas.core.base import (
7473
DataError,
7574
SpecificationError,
@@ -84,7 +83,6 @@
8483
_agg_template,
8584
_apply_docs,
8685
_transform_template,
87-
get_groupby,
8886
group_selection_context,
8987
)
9088
from pandas.core.indexes.api import (
@@ -353,6 +351,7 @@ def _cython_agg_general(
353351

354352
obj = self._selected_obj
355353
objvals = obj._values
354+
data = obj._mgr
356355

357356
if numeric_only and not is_numeric_dtype(obj.dtype):
358357
raise DataError("No numeric types to aggregate")
@@ -362,28 +361,15 @@ def _cython_agg_general(
362361
def array_func(values: ArrayLike) -> ArrayLike:
363362
try:
364363
result = self.grouper._cython_operation(
365-
"aggregate", values, how, axis=0, min_count=min_count
364+
"aggregate", values, how, axis=data.ndim - 1, min_count=min_count
366365
)
367366
except NotImplementedError:
368-
ser = Series(values) # equiv 'obj' from outer frame
369-
if self.ngroups > 0:
370-
res_values, _ = self.grouper.agg_series(ser, alt)
371-
else:
372-
# equiv: res_values = self._python_agg_general(alt)
373-
# error: Incompatible types in assignment (expression has
374-
# type "Union[DataFrame, Series]", variable has type
375-
# "Union[ExtensionArray, ndarray]")
376-
res_values = self._python_apply_general( # type: ignore[assignment]
377-
alt, ser
378-
)
367+
# generally if we have numeric_only=False
368+
# and non-applicable functions
369+
# try to python agg
370+
# TODO: shouldn't min_count matter?
371+
result = self._agg_py_fallback(values, ndim=data.ndim, alt=alt)
379372

380-
if isinstance(values, Categorical):
381-
# Because we only get here with known dtype-preserving
382-
# reductions, we cast back to Categorical.
383-
# TODO: if we ever get "rank" working, exclude it here.
384-
result = type(values)._from_sequence(res_values, dtype=values.dtype)
385-
else:
386-
result = res_values
387373
return result
388374

389375
result = array_func(objvals)
@@ -1116,72 +1102,17 @@ def _cython_agg_general(
11161102
if numeric_only:
11171103
data = data.get_numeric_data(copy=False)
11181104

1119-
def cast_agg_result(result: ArrayLike, values: ArrayLike) -> ArrayLike:
1120-
# see if we can cast the values to the desired dtype
1121-
# this may not be the original dtype
1122-
1123-
if isinstance(result.dtype, np.dtype) and result.ndim == 1:
1124-
# We went through a SeriesGroupByPath and need to reshape
1125-
# GH#32223 includes case with IntegerArray values
1126-
# We only get here with values.dtype == object
1127-
result = result.reshape(1, -1)
1128-
# test_groupby_duplicate_columns gets here with
1129-
# result.dtype == int64, values.dtype=object, how="min"
1130-
1131-
return result
1132-
1133-
def py_fallback(values: ArrayLike) -> ArrayLike:
1134-
# if self.grouper.aggregate fails, we fall back to a pure-python
1135-
# solution
1136-
1137-
# We get here with a) EADtypes and b) object dtype
1138-
obj: FrameOrSeriesUnion
1139-
1140-
# call our grouper again with only this block
1141-
if values.ndim == 1:
1142-
# We only get here with ExtensionArray
1143-
1144-
obj = Series(values)
1145-
else:
1146-
# We only get here with values.dtype == object
1147-
# TODO special case not needed with ArrayManager
1148-
df = DataFrame(values.T)
1149-
# bc we split object blocks in grouped_reduce, we have only 1 col
1150-
# otherwise we'd have to worry about block-splitting GH#39329
1151-
assert df.shape[1] == 1
1152-
# Avoid call to self.values that can occur in DataFrame
1153-
# reductions; see GH#28949
1154-
obj = df.iloc[:, 0]
1155-
1156-
# Create SeriesGroupBy with observed=True so that it does
1157-
# not try to add missing categories if grouping over multiple
1158-
# Categoricals. This will be done later by self._reindex_output()
1159-
# Doing it here creates an error. See GH#34951
1160-
sgb = get_groupby(obj, self.grouper, observed=True)
1161-
1162-
# Note: bc obj is always a Series here, we can ignore axis and pass
1163-
# `alt` directly instead of `lambda x: alt(x, axis=self.axis)`
1164-
# use _agg_general bc it will go through _cython_agg_general
1165-
# which will correctly cast Categoricals.
1166-
res_ser = sgb._agg_general(
1167-
numeric_only=False, min_count=min_count, alias=how, npfunc=alt
1168-
)
1169-
1170-
# unwrap Series to get array
1171-
res_values = res_ser._mgr.arrays[0]
1172-
return cast_agg_result(res_values, values)
1173-
11741105
def array_func(values: ArrayLike) -> ArrayLike:
1175-
11761106
try:
11771107
result = self.grouper._cython_operation(
1178-
"aggregate", values, how, axis=1, min_count=min_count
1108+
"aggregate", values, how, axis=data.ndim - 1, min_count=min_count
11791109
)
11801110
except NotImplementedError:
11811111
# generally if we have numeric_only=False
11821112
# and non-applicable functions
11831113
# try to python agg
1184-
result = py_fallback(values)
1114+
# TODO: shouldn't min_count matter?
1115+
result = self._agg_py_fallback(values, ndim=data.ndim, alt=alt)
11851116

11861117
return result
11871118

pandas/core/groupby/groupby.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ class providing the base-class of operations.
101101
Index,
102102
MultiIndex,
103103
)
104+
from pandas.core.internals.blocks import ensure_block_shape
104105
from pandas.core.series import Series
105106
from pandas.core.sorting import get_group_index_sorter
106107
from pandas.core.util.numba_ import NUMBA_FUNC_CACHE
@@ -1317,6 +1318,54 @@ def _agg_general(
13171318
result = self.aggregate(lambda x: npfunc(x, axis=self.axis))
13181319
return result.__finalize__(self.obj, method="groupby")
13191320

1321+
def _agg_py_fallback(
    self, values: ArrayLike, ndim: int, alt: Callable
) -> ArrayLike:
    """
    Fallback to pure-python aggregation if _cython_operation raises
    NotImplementedError.

    Parameters
    ----------
    values : ArrayLike
        Values to aggregate. Either 1-dimensional, or 2-dimensional such
        that the transposed values form exactly one column.
    ndim : int
        Number of dimensions passed to ``ensure_block_shape`` for the result.
    alt : Callable
        Python-level reduction, passed to ``self.grouper.agg_series`` (or to
        ``_python_apply_general`` when there are no groups).

    Returns
    -------
    ArrayLike
        Aggregated values after ``ensure_block_shape(..., ndim=ndim)``; cast
        back to the dtype of ``values`` when ``values`` is a Categorical.
    """
    # We get here with a) EADtypes and b) object dtype

    if values.ndim == 1:
        # For DataFrameGroupBy we only get here with ExtensionArray
        ser = Series(values)
    else:
        # We only get here with values.dtype == object
        # TODO: special case not needed with ArrayManager
        df = DataFrame(values.T)
        # bc we split object blocks in grouped_reduce, we have only 1 col
        # otherwise we'd have to worry about block-splitting GH#39329
        assert df.shape[1] == 1
        # Avoid call to self.values that can occur in DataFrame
        # reductions; see GH#28949
        ser = df.iloc[:, 0]

    if self.ngroups > 0:
        res_values, _ = self.grouper.agg_series(ser, alt)
    else:
        # equiv: res_values = self._python_agg_general(alt)
        # The SeriesGroupBy is only needed on this path, so it is built here
        # rather than on every call.
        # Create SeriesGroupBy with observed=True so that it does
        # not try to add missing categories if grouping over multiple
        # Categoricals. This will be done later by self._reindex_output()
        # Doing it here creates an error. See GH#34951
        # For SeriesGroupBy we could just use self instead of sgb
        sgb = get_groupby(ser, self.grouper, observed=True)
        res_values = sgb._python_apply_general(alt, ser)._values

    if isinstance(values, Categorical):
        # Because we only get here with known dtype-preserving
        # reductions, we cast back to Categorical.
        # TODO: if we ever get "rank" working, exclude it here.
        res_values = type(values)._from_sequence(res_values, dtype=values.dtype)

    # If we are DataFrameGroupBy and went through a SeriesGroupByPath
    # then we need to reshape
    # GH#32223 includes case with IntegerArray values, ndarray res_values
    # test_groupby_duplicate_columns with object dtype values
    return ensure_block_shape(res_values, ndim=ndim)
1368+
13201369
def _cython_agg_general(
13211370
self, how: str, alt=None, numeric_only: bool = True, min_count: int = -1
13221371
):

0 commit comments

Comments
 (0)