Skip to content

Commit 38b5e4a

Browse files
PERF: cache _get_cython_function in groupby ops (#40178)
1 parent 6945116 commit 38b5e4a

File tree

1 file changed

+61
-61
lines changed

1 file changed

+61
-61
lines changed

pandas/core/groupby/ops.py

Lines changed: 61 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from __future__ import annotations
99

1010
import collections
11+
import functools
1112
from typing import (
1213
Dict,
1314
Generic,
@@ -95,6 +96,64 @@
9596
get_indexer_dict,
9697
)
9798

99+
_CYTHON_FUNCTIONS = {
100+
"aggregate": {
101+
"add": "group_add",
102+
"prod": "group_prod",
103+
"min": "group_min",
104+
"max": "group_max",
105+
"mean": "group_mean",
106+
"median": "group_median",
107+
"var": "group_var",
108+
"first": "group_nth",
109+
"last": "group_last",
110+
"ohlc": "group_ohlc",
111+
},
112+
"transform": {
113+
"cumprod": "group_cumprod",
114+
"cumsum": "group_cumsum",
115+
"cummin": "group_cummin",
116+
"cummax": "group_cummax",
117+
"rank": "group_rank",
118+
},
119+
}
120+
121+
122+
@functools.lru_cache(maxsize=None)
123+
def _get_cython_function(kind: str, how: str, dtype: np.dtype, is_numeric: bool):
124+
125+
dtype_str = dtype.name
126+
ftype = _CYTHON_FUNCTIONS[kind][how]
127+
128+
# see if there is a fused-type version of function
129+
# only valid for numeric
130+
f = getattr(libgroupby, ftype, None)
131+
if f is not None and is_numeric:
132+
return f
133+
134+
# otherwise find dtype-specific version, falling back to object
135+
for dt in [dtype_str, "object"]:
136+
f2 = getattr(libgroupby, f"{ftype}_{dt}", None)
137+
if f2 is not None:
138+
return f2
139+
140+
if hasattr(f, "__signatures__"):
141+
# inspect what fused types are implemented
142+
if dtype_str == "object" and "object" not in f.__signatures__:
143+
# disallow this function so we get a NotImplementedError below
144+
# instead of a TypeError at runtime
145+
f = None
146+
147+
func = f
148+
149+
if func is None:
150+
raise NotImplementedError(
151+
f"function is not implemented for this dtype: "
152+
f"[how->{how},dtype->{dtype_str}]"
153+
)
154+
155+
return func
156+
98157

99158
class BaseGrouper:
100159
"""
@@ -385,28 +444,6 @@ def get_group_levels(self) -> List[Index]:
385444
# ------------------------------------------------------------
386445
# Aggregation functions
387446

388-
_cython_functions = {
389-
"aggregate": {
390-
"add": "group_add",
391-
"prod": "group_prod",
392-
"min": "group_min",
393-
"max": "group_max",
394-
"mean": "group_mean",
395-
"median": "group_median",
396-
"var": "group_var",
397-
"first": "group_nth",
398-
"last": "group_last",
399-
"ohlc": "group_ohlc",
400-
},
401-
"transform": {
402-
"cumprod": "group_cumprod",
403-
"cumsum": "group_cumsum",
404-
"cummin": "group_cummin",
405-
"cummax": "group_cummax",
406-
"rank": "group_rank",
407-
},
408-
}
409-
410447
_cython_arity = {"ohlc": 4} # OHLC
411448

412449
@final
@@ -417,43 +454,6 @@ def _is_builtin_func(self, arg):
417454
"""
418455
return SelectionMixin._builtin_table.get(arg, arg)
419456

420-
@final
421-
def _get_cython_function(
422-
self, kind: str, how: str, values: np.ndarray, is_numeric: bool
423-
):
424-
425-
dtype_str = values.dtype.name
426-
ftype = self._cython_functions[kind][how]
427-
428-
# see if there is a fused-type version of function
429-
# only valid for numeric
430-
f = getattr(libgroupby, ftype, None)
431-
if f is not None and is_numeric:
432-
return f
433-
434-
# otherwise find dtype-specific version, falling back to object
435-
for dt in [dtype_str, "object"]:
436-
f2 = getattr(libgroupby, f"{ftype}_{dt}", None)
437-
if f2 is not None:
438-
return f2
439-
440-
if hasattr(f, "__signatures__"):
441-
# inspect what fused types are implemented
442-
if dtype_str == "object" and "object" not in f.__signatures__:
443-
# disallow this function so we get a NotImplementedError below
444-
# instead of a TypeError at runtime
445-
f = None
446-
447-
func = f
448-
449-
if func is None:
450-
raise NotImplementedError(
451-
f"function is not implemented for this dtype: "
452-
f"[how->{how},dtype->{dtype_str}]"
453-
)
454-
455-
return func
456-
457457
@final
458458
def _get_cython_func_and_vals(
459459
self, kind: str, how: str, values: np.ndarray, is_numeric: bool
@@ -474,7 +474,7 @@ def _get_cython_func_and_vals(
474474
values : np.ndarray
475475
"""
476476
try:
477-
func = self._get_cython_function(kind, how, values, is_numeric)
477+
func = _get_cython_function(kind, how, values.dtype, is_numeric)
478478
except NotImplementedError:
479479
if is_numeric:
480480
try:
@@ -484,7 +484,7 @@ def _get_cython_func_and_vals(
484484
values = values.astype(complex)
485485
else:
486486
raise
487-
func = self._get_cython_function(kind, how, values, is_numeric)
487+
func = _get_cython_function(kind, how, values.dtype, is_numeric)
488488
else:
489489
raise
490490
return func, values

0 commit comments

Comments
 (0)