@@ -2059,6 +2059,7 @@ def post_processor(vals: np.ndarray, inference: Optional[Type]) -> np.ndarray:
2059
2059
return self ._get_cythonized_result (
2060
2060
"group_quantile" ,
2061
2061
aggregate = True ,
2062
+ needs_counts = True ,
2062
2063
needs_values = True ,
2063
2064
needs_mask = True ,
2064
2065
cython_dtype = np .dtype (np .float64 ),
@@ -2072,6 +2073,7 @@ def post_processor(vals: np.ndarray, inference: Optional[Type]) -> np.ndarray:
2072
2073
self ._get_cythonized_result (
2073
2074
"group_quantile" ,
2074
2075
aggregate = True ,
2076
+ needs_counts = True ,
2075
2077
needs_values = True ,
2076
2078
needs_mask = True ,
2077
2079
cython_dtype = np .dtype (np .float64 ),
@@ -2348,9 +2350,10 @@ def _get_cythonized_result(
2348
2350
how : str ,
2349
2351
cython_dtype : np .dtype ,
2350
2352
aggregate : bool = False ,
2353
+ needs_counts : bool = False ,
2351
2354
needs_values : bool = False ,
2355
+ min_count : Optional [int ] = None ,
2352
2356
needs_mask : bool = False ,
2353
- needs_ngroups : bool = False ,
2354
2357
result_is_index : bool = False ,
2355
2358
pre_processing = None ,
2356
2359
post_processing = None ,
@@ -2367,14 +2370,16 @@ def _get_cythonized_result(
2367
2370
aggregate : bool, default False
2368
2371
Whether the result should be aggregated to match the number of
2369
2372
groups
2373
+ needs_counts : bool, default False
2374
+ Whether the counts should be a part of the Cython call
2370
2375
needs_values : bool, default False
2371
2376
Whether the values should be a part of the Cython call
2372
2377
signature
2378
+ min_count : int, default None
2379
+ When not None, min_count for the Cython call
2373
2380
needs_mask : bool, default False
2374
2381
Whether boolean mask needs to be part of the Cython call
2375
2382
signature
2376
- needs_ngroups : bool, default False
2377
- Whether number of groups is part of the Cython call signature
2378
2383
result_is_index : bool, default False
2379
2384
Whether the result of the Cython operation is an index of
2380
2385
values to be retrieved, instead of the actual values themselves
@@ -2414,74 +2419,63 @@ def _get_cythonized_result(
2414
2419
labels , _ , ngroups = grouper .group_info
2415
2420
output : Dict [base .OutputKey , np .ndarray ] = {}
2416
2421
base_func = getattr (libgroupby , how )
2422
+ inferences = None
2417
2423
2418
- if how == "group_quantile" :
2419
- values = self ._obj_with_exclusions ._values
2420
- result_sz = ngroups if aggregate else len (values )
2424
+ values = self ._obj_with_exclusions ._values
2425
+ result_sz = ngroups if aggregate else len (values )
2426
+ if self ._obj_with_exclusions .ndim == 1 :
2427
+ width = 1
2428
+ else :
2429
+ width = len (self ._obj_with_exclusions .columns )
2430
+ result = np .zeros ((result_sz , width ), dtype = cython_dtype )
2431
+ func = partial (base_func , result )
2421
2432
2422
- vals , inferences = pre_processing (values )
2423
- if self ._obj_with_exclusions .ndim == 1 :
2424
- width = 1
2425
- vals = np .reshape (vals , (- 1 , 1 ))
2426
- else :
2427
- width = len (self ._obj_with_exclusions .columns )
2428
- result = np .zeros ((result_sz , width ), dtype = cython_dtype )
2433
+ if needs_counts :
2429
2434
counts = np .zeros (self .ngroups , dtype = np .int64 )
2430
- mask = isna (vals ).view (np .uint8 )
2431
-
2432
- func = partial (base_func , result , counts , vals , labels , - 1 , mask )
2433
- func (** kwargs ) # Call func to modify indexer values in place
2434
- result = post_processing (result , inferences )
2435
+ func = partial (func , counts )
2435
2436
2437
+ if needs_values :
2438
+ vals = values
2439
+ if pre_processing :
2440
+ vals , inferences = pre_processing (vals )
2436
2441
if self ._obj_with_exclusions .ndim == 1 :
2437
- key = base .OutputKey (label = self ._obj_with_exclusions .name , position = 0 )
2438
- output [key ] = result [:, 0 ]
2439
- else :
2440
- for idx , name in enumerate (self ._obj_with_exclusions .columns ):
2441
- key = base .OutputKey (label = name , position = idx )
2442
- output [key ] = result [:, idx ]
2442
+ vals = np .reshape (vals , (- 1 , 1 ))
2443
+ func = partial (func , vals )
2443
2444
2444
- if aggregate :
2445
- return self ._wrap_aggregated_output (output )
2446
- else :
2447
- return self ._wrap_transformed_output (output )
2445
+ # Groupby always needs labels
2446
+ func = partial (func , labels )
2448
2447
2449
- for idx , obj in enumerate (self ._iterate_slices ()):
2450
- name = obj .name
2451
- values = obj ._values
2448
+ if min_count is not None :
2449
+ func = partial (func , min_count )
2452
2450
2453
- if aggregate :
2454
- result_sz = ngroups
2451
+ if needs_mask :
2452
+ if self ._obj_with_exclusions .ndim == 1 :
2453
+ # If needs_values is True, don't need to reshape again
2454
+ if needs_values :
2455
+ mask = isna (vals ).view (np .uint8 )
2456
+ else :
2457
+ mask = isna (np .reshape (values , (- 1 , 1 ))).view (np .uint8 )
2455
2458
else :
2456
- result_sz = len (values )
2457
-
2458
- result = np .zeros (result_sz , dtype = cython_dtype )
2459
- func = partial (base_func , result , labels )
2460
- inferences = None
2461
-
2462
- if needs_values :
2463
- vals = values
2464
- if pre_processing :
2465
- vals , inferences = pre_processing (vals )
2466
- func = partial (func , vals )
2467
-
2468
- if needs_mask :
2469
2459
mask = isna (values ).view (np .uint8 )
2470
- func = partial (func , mask )
2471
-
2472
- if needs_ngroups :
2473
- func = partial (func , ngroups )
2460
+ func = partial (func , mask )
2474
2461
2475
- func (** kwargs ) # Call func to modify indexer values in place
2462
+ func (** kwargs ) # Call func to modify indexer values in place
2476
2463
2477
- if result_is_index :
2478
- result = algorithms .take_nd (values , result )
2464
+ # TODO: Probably not correct
2465
+ if result_is_index :
2466
+ result = algorithms .take_nd (values , result )
2479
2467
2480
- if post_processing :
2481
- result = post_processing (result , inferences )
2468
+ if post_processing :
2469
+ result = post_processing (result , inferences )
2482
2470
2483
- key = base .OutputKey (label = name , position = idx )
2484
- output [key ] = result
2471
+ # TODO: Perhaps there is a better way to get result into output
2472
+ if self ._obj_with_exclusions .ndim == 1 :
2473
+ key = base .OutputKey (label = self ._obj_with_exclusions .name , position = 0 )
2474
+ output [key ] = result [:, 0 ]
2475
+ else :
2476
+ for idx , name in enumerate (self ._obj_with_exclusions .columns ):
2477
+ key = base .OutputKey (label = name , position = idx )
2478
+ output [key ] = result [:, idx ]
2485
2479
2486
2480
if aggregate :
2487
2481
return self ._wrap_aggregated_output (output )
0 commit comments