Skip to content

Commit bb636d9

Browse files
authored
REF: raise more selectively in libreduction (#41298)
1 parent 53b7fff commit bb636d9

File tree

3 files changed

+14
-8
lines changed

3 files changed

+14
-8
lines changed

pandas/_libs/reduction.pyx

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,18 @@ from pandas._libs.util cimport (
2424
from pandas._libs.lib import is_scalar
2525

2626

27-
cpdef check_result_array(object obj):
27+
cdef cnp.dtype _dtype_obj = np.dtype("object")
2828

29-
if (is_array(obj) or
30-
(isinstance(obj, list) and len(obj) == 0) or
31-
getattr(obj, 'shape', None) == (0,)):
32-
raise ValueError('Must produce aggregated value')
29+
30+
cpdef check_result_array(object obj, object dtype):
31+
# Our operation is supposed to be an aggregation/reduction. If
32+
# it returns an ndarray, this likely means an invalid operation has
33+
# been passed. See test_apply_without_aggregation, test_agg_must_agg
34+
if is_array(obj):
35+
if dtype != _dtype_obj:
36+
# If it is object dtype, the function can be a reduction/aggregation
37+
# and still return an ndarray e.g. test_agg_over_numpy_arrays
38+
raise ValueError("Must produce aggregated value")
3339

3440

3541
cdef class _BaseGrouper:
@@ -86,7 +92,7 @@ cdef class _BaseGrouper:
8692
# On the first pass, we check the output shape to see
8793
# if this looks like a reduction.
8894
initialized = True
89-
check_result_array(res)
95+
check_result_array(res, cached_series.dtype)
9096

9197
return res, initialized
9298

pandas/core/groupby/generic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -517,7 +517,7 @@ def _aggregate_named(self, func, *args, **kwargs):
517517
output = libreduction.extract_result(output)
518518
if not initialized:
519519
# We only do this validation on the first iteration
520-
libreduction.check_result_array(output)
520+
libreduction.check_result_array(output, group.dtype)
521521
initialized = True
522522
result[name] = output
523523

pandas/core/groupby/ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1040,7 +1040,7 @@ def _aggregate_series_pure_python(self, obj: Series, func: F):
10401040

10411041
if not initialized:
10421042
# We only do this validation on the first iteration
1043-
libreduction.check_result_array(res)
1043+
libreduction.check_result_array(res, group.dtype)
10441044
initialized = True
10451045

10461046
counts[i] = group.shape[0]

0 commit comments

Comments
 (0)