Skip to content

Commit f2bb7a9

Browse files
committed
Refactored _get_cythonized_result
1 parent 53ae9d6 commit f2bb7a9

File tree

2 files changed

+51
-58
lines changed

2 files changed

+51
-58
lines changed

pandas/_libs/groupby.pyx

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -718,7 +718,6 @@ def group_quantile(floating[:, :] out,
718718
int64_t[:] counts,
719719
floating[:, :] values,
720720
const int64_t[:] labels,
721-
Py_ssize_t min_count,
722721
const uint8_t[:, :] mask,
723722
float64_t q,
724723
object interpolation):

pandas/core/groupby/groupby.py

Lines changed: 51 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -2059,6 +2059,7 @@ def post_processor(vals: np.ndarray, inference: Optional[Type]) -> np.ndarray:
20592059
return self._get_cythonized_result(
20602060
"group_quantile",
20612061
aggregate=True,
2062+
needs_counts=True,
20622063
needs_values=True,
20632064
needs_mask=True,
20642065
cython_dtype=np.dtype(np.float64),
@@ -2072,6 +2073,7 @@ def post_processor(vals: np.ndarray, inference: Optional[Type]) -> np.ndarray:
20722073
self._get_cythonized_result(
20732074
"group_quantile",
20742075
aggregate=True,
2076+
needs_counts=True,
20752077
needs_values=True,
20762078
needs_mask=True,
20772079
cython_dtype=np.dtype(np.float64),
@@ -2348,9 +2350,10 @@ def _get_cythonized_result(
23482350
how: str,
23492351
cython_dtype: np.dtype,
23502352
aggregate: bool = False,
2353+
needs_counts: bool = False,
23512354
needs_values: bool = False,
2355+
min_count: Optional[int] = None,
23522356
needs_mask: bool = False,
2353-
needs_ngroups: bool = False,
23542357
result_is_index: bool = False,
23552358
pre_processing=None,
23562359
post_processing=None,
@@ -2367,14 +2370,16 @@ def _get_cythonized_result(
23672370
aggregate : bool, default False
23682371
Whether the result should be aggregated to match the number of
23692372
groups
2373+
needs_counts : bool, default False
2374+
Whether the counts should be a part of the Cython call
23702375
needs_values : bool, default False
23712376
Whether the values should be a part of the Cython call
23722377
signature
2378+
min_count : int, default None
2379+
When not None, min_count for the Cython call
23732380
needs_mask : bool, default False
23742381
Whether boolean mask needs to be part of the Cython call
23752382
signature
2376-
needs_ngroups : bool, default False
2377-
Whether number of groups is part of the Cython call signature
23782383
result_is_index : bool, default False
23792384
Whether the result of the Cython operation is an index of
23802385
values to be retrieved, instead of the actual values themselves
@@ -2414,74 +2419,63 @@ def _get_cythonized_result(
24142419
labels, _, ngroups = grouper.group_info
24152420
output: Dict[base.OutputKey, np.ndarray] = {}
24162421
base_func = getattr(libgroupby, how)
2422+
inferences = None
24172423

2418-
if how == "group_quantile":
2419-
values = self._obj_with_exclusions._values
2420-
result_sz = ngroups if aggregate else len(values)
2424+
values = self._obj_with_exclusions._values
2425+
result_sz = ngroups if aggregate else len(values)
2426+
if self._obj_with_exclusions.ndim == 1:
2427+
width = 1
2428+
else:
2429+
width = len(self._obj_with_exclusions.columns)
2430+
result = np.zeros((result_sz, width), dtype=cython_dtype)
2431+
func = partial(base_func, result)
24212432

2422-
vals, inferences = pre_processing(values)
2423-
if self._obj_with_exclusions.ndim == 1:
2424-
width = 1
2425-
vals = np.reshape(vals, (-1, 1))
2426-
else:
2427-
width = len(self._obj_with_exclusions.columns)
2428-
result = np.zeros((result_sz, width), dtype=cython_dtype)
2433+
if needs_counts:
24292434
counts = np.zeros(self.ngroups, dtype=np.int64)
2430-
mask = isna(vals).view(np.uint8)
2431-
2432-
func = partial(base_func, result, counts, vals, labels, -1, mask)
2433-
func(**kwargs) # Call func to modify indexer values in place
2434-
result = post_processing(result, inferences)
2435+
func = partial(func, counts)
24352436

2437+
if needs_values:
2438+
vals = values
2439+
if pre_processing:
2440+
vals, inferences = pre_processing(vals)
24362441
if self._obj_with_exclusions.ndim == 1:
2437-
key = base.OutputKey(label=self._obj_with_exclusions.name, position=0)
2438-
output[key] = result[:, 0]
2439-
else:
2440-
for idx, name in enumerate(self._obj_with_exclusions.columns):
2441-
key = base.OutputKey(label=name, position=idx)
2442-
output[key] = result[:, idx]
2442+
vals = np.reshape(vals, (-1, 1))
2443+
func = partial(func, vals)
24432444

2444-
if aggregate:
2445-
return self._wrap_aggregated_output(output)
2446-
else:
2447-
return self._wrap_transformed_output(output)
2445+
# Groupby always needs labels
2446+
func = partial(func, labels)
24482447

2449-
for idx, obj in enumerate(self._iterate_slices()):
2450-
name = obj.name
2451-
values = obj._values
2448+
if min_count is not None:
2449+
func = partial(func, min_count)
24522450

2453-
if aggregate:
2454-
result_sz = ngroups
2451+
if needs_mask:
2452+
if self._obj_with_exclusions.ndim == 1:
2453+
# If needs_values is True, don't need to reshape again
2454+
if needs_values:
2455+
mask = isna(vals).view(np.uint8)
2456+
else:
2457+
mask = isna(np.reshape(values, (-1, 1))).view(np.uint8)
24552458
else:
2456-
result_sz = len(values)
2457-
2458-
result = np.zeros(result_sz, dtype=cython_dtype)
2459-
func = partial(base_func, result, labels)
2460-
inferences = None
2461-
2462-
if needs_values:
2463-
vals = values
2464-
if pre_processing:
2465-
vals, inferences = pre_processing(vals)
2466-
func = partial(func, vals)
2467-
2468-
if needs_mask:
24692459
mask = isna(values).view(np.uint8)
2470-
func = partial(func, mask)
2471-
2472-
if needs_ngroups:
2473-
func = partial(func, ngroups)
2460+
func = partial(func, mask)
24742461

2475-
func(**kwargs) # Call func to modify indexer values in place
2462+
func(**kwargs) # Call func to modify indexer values in place
24762463

2477-
if result_is_index:
2478-
result = algorithms.take_nd(values, result)
2464+
# TODO: Probably not correct
2465+
if result_is_index:
2466+
result = algorithms.take_nd(values, result)
24792467

2480-
if post_processing:
2481-
result = post_processing(result, inferences)
2468+
if post_processing:
2469+
result = post_processing(result, inferences)
24822470

2483-
key = base.OutputKey(label=name, position=idx)
2484-
output[key] = result
2471+
# TODO: Perhaps there is a better way to get result into output
2472+
if self._obj_with_exclusions.ndim == 1:
2473+
key = base.OutputKey(label=self._obj_with_exclusions.name, position=0)
2474+
output[key] = result[:, 0]
2475+
else:
2476+
for idx, name in enumerate(self._obj_with_exclusions.columns):
2477+
key = base.OutputKey(label=name, position=idx)
2478+
output[key] = result[:, idx]
24852479

24862480
if aggregate:
24872481
return self._wrap_aggregated_output(output)

0 commit comments

Comments
 (0)