8
8
from __future__ import annotations
9
9
10
10
import collections
11
+ import functools
11
12
from typing import (
12
13
Dict ,
13
14
Generic ,
95
96
get_indexer_dict ,
96
97
)
97
98
99
+ _CYTHON_FUNCTIONS = {
100
+ "aggregate" : {
101
+ "add" : "group_add" ,
102
+ "prod" : "group_prod" ,
103
+ "min" : "group_min" ,
104
+ "max" : "group_max" ,
105
+ "mean" : "group_mean" ,
106
+ "median" : "group_median" ,
107
+ "var" : "group_var" ,
108
+ "first" : "group_nth" ,
109
+ "last" : "group_last" ,
110
+ "ohlc" : "group_ohlc" ,
111
+ },
112
+ "transform" : {
113
+ "cumprod" : "group_cumprod" ,
114
+ "cumsum" : "group_cumsum" ,
115
+ "cummin" : "group_cummin" ,
116
+ "cummax" : "group_cummax" ,
117
+ "rank" : "group_rank" ,
118
+ },
119
+ }
120
+
121
+
122
+ @functools .lru_cache (maxsize = None )
123
+ def _get_cython_function (kind : str , how : str , dtype : np .dtype , is_numeric : bool ):
124
+
125
+ dtype_str = dtype .name
126
+ ftype = _CYTHON_FUNCTIONS [kind ][how ]
127
+
128
+ # see if there is a fused-type version of function
129
+ # only valid for numeric
130
+ f = getattr (libgroupby , ftype , None )
131
+ if f is not None and is_numeric :
132
+ return f
133
+
134
+ # otherwise find dtype-specific version, falling back to object
135
+ for dt in [dtype_str , "object" ]:
136
+ f2 = getattr (libgroupby , f"{ ftype } _{ dt } " , None )
137
+ if f2 is not None :
138
+ return f2
139
+
140
+ if hasattr (f , "__signatures__" ):
141
+ # inspect what fused types are implemented
142
+ if dtype_str == "object" and "object" not in f .__signatures__ :
143
+ # disallow this function so we get a NotImplementedError below
144
+ # instead of a TypeError at runtime
145
+ f = None
146
+
147
+ func = f
148
+
149
+ if func is None :
150
+ raise NotImplementedError (
151
+ f"function is not implemented for this dtype: "
152
+ f"[how->{ how } ,dtype->{ dtype_str } ]"
153
+ )
154
+
155
+ return func
156
+
98
157
99
158
class BaseGrouper :
100
159
"""
@@ -385,28 +444,6 @@ def get_group_levels(self) -> List[Index]:
385
444
# ------------------------------------------------------------
386
445
# Aggregation functions
387
446
388
- _cython_functions = {
389
- "aggregate" : {
390
- "add" : "group_add" ,
391
- "prod" : "group_prod" ,
392
- "min" : "group_min" ,
393
- "max" : "group_max" ,
394
- "mean" : "group_mean" ,
395
- "median" : "group_median" ,
396
- "var" : "group_var" ,
397
- "first" : "group_nth" ,
398
- "last" : "group_last" ,
399
- "ohlc" : "group_ohlc" ,
400
- },
401
- "transform" : {
402
- "cumprod" : "group_cumprod" ,
403
- "cumsum" : "group_cumsum" ,
404
- "cummin" : "group_cummin" ,
405
- "cummax" : "group_cummax" ,
406
- "rank" : "group_rank" ,
407
- },
408
- }
409
-
410
447
_cython_arity = {"ohlc" : 4 } # OHLC
411
448
412
449
@final
@@ -417,43 +454,6 @@ def _is_builtin_func(self, arg):
417
454
"""
418
455
return SelectionMixin ._builtin_table .get (arg , arg )
419
456
420
- @final
421
- def _get_cython_function (
422
- self , kind : str , how : str , values : np .ndarray , is_numeric : bool
423
- ):
424
-
425
- dtype_str = values .dtype .name
426
- ftype = self ._cython_functions [kind ][how ]
427
-
428
- # see if there is a fused-type version of function
429
- # only valid for numeric
430
- f = getattr (libgroupby , ftype , None )
431
- if f is not None and is_numeric :
432
- return f
433
-
434
- # otherwise find dtype-specific version, falling back to object
435
- for dt in [dtype_str , "object" ]:
436
- f2 = getattr (libgroupby , f"{ ftype } _{ dt } " , None )
437
- if f2 is not None :
438
- return f2
439
-
440
- if hasattr (f , "__signatures__" ):
441
- # inspect what fused types are implemented
442
- if dtype_str == "object" and "object" not in f .__signatures__ :
443
- # disallow this function so we get a NotImplementedError below
444
- # instead of a TypeError at runtime
445
- f = None
446
-
447
- func = f
448
-
449
- if func is None :
450
- raise NotImplementedError (
451
- f"function is not implemented for this dtype: "
452
- f"[how->{ how } ,dtype->{ dtype_str } ]"
453
- )
454
-
455
- return func
456
-
457
457
@final
458
458
def _get_cython_func_and_vals (
459
459
self , kind : str , how : str , values : np .ndarray , is_numeric : bool
@@ -474,7 +474,7 @@ def _get_cython_func_and_vals(
474
474
values : np.ndarray
475
475
"""
476
476
try :
477
- func = self . _get_cython_function (kind , how , values , is_numeric )
477
+ func = _get_cython_function (kind , how , values . dtype , is_numeric )
478
478
except NotImplementedError :
479
479
if is_numeric :
480
480
try :
@@ -484,7 +484,7 @@ def _get_cython_func_and_vals(
484
484
values = values .astype (complex )
485
485
else :
486
486
raise
487
- func = self . _get_cython_function (kind , how , values , is_numeric )
487
+ func = _get_cython_function (kind , how , values . dtype , is_numeric )
488
488
else :
489
489
raise
490
490
return func , values
0 commit comments