-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
ENH: Add numba engine to groupby.transform #32854
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
db1b3aa
e82ede7
b0f4faa
1316662
8d7343e
3e12a51
6e6891f
73be08d
e85dcd5
0760c44
52fff03
156a2b4
6e9d2bf
2a57a14
e1e5f73
ce3b2b3
195c35f
d8ea389
a9ece86
367dc12
8256a0a
2c5543d
80e0ddc
9c4fa56
1c71c9b
6c0a573
538f4fa
083ead7
3132a51
1de2cf1
d4d58f9
f63e3e7
e984283
909e92e
e2f2a54
cd7a0be
930466a
145ca50
fc0654d
6d5c63d
9dbded0
5909abb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -75,6 +75,13 @@ | |
import pandas.core.indexes.base as ibase | ||
from pandas.core.internals import BlockManager, make_block | ||
from pandas.core.series import Series | ||
from pandas.core.util.numba_ import ( | ||
check_kwargs_and_nopython, | ||
get_jit_arguments, | ||
jit_user_function, | ||
split_for_numba, | ||
validate_udf, | ||
) | ||
|
||
from pandas.plotting import boxplot_frame_groupby | ||
|
||
|
@@ -154,6 +161,8 @@ def pinner(cls): | |
class SeriesGroupBy(GroupBy[Series]): | ||
_apply_whitelist = base.series_apply_whitelist | ||
|
||
_numba_func_cache: Dict[Callable, Callable] = {} | ||
|
||
def _iterate_slices(self) -> Iterable[Series]: | ||
yield self._selected_obj | ||
|
||
|
@@ -463,11 +472,13 @@ def _aggregate_named(self, func, *args, **kwargs): | |
|
||
@Substitution(klass="Series", selected="A.") | ||
@Appender(_transform_template) | ||
def transform(self, func, *args, **kwargs): | ||
def transform(self, func, *args, engine="cython", engine_kwargs=None, **kwargs): | ||
func = self._get_cython_func(func) or func | ||
|
||
if not isinstance(func, str): | ||
return self._transform_general(func, *args, **kwargs) | ||
return self._transform_general( | ||
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs | ||
) | ||
|
||
elif func not in base.transform_kernel_whitelist: | ||
msg = f"'{func}' is not a valid function name for transform(name)" | ||
|
@@ -482,16 +493,33 @@ def transform(self, func, *args, **kwargs): | |
result = getattr(self, func)(*args, **kwargs) | ||
return self._transform_fast(result, func) | ||
|
||
def _transform_general(self, func, *args, **kwargs): | ||
def _transform_general( | ||
self, func, *args, engine="cython", engine_kwargs=None, **kwargs | ||
): | ||
""" | ||
Transform with a non-str `func`. | ||
""" | ||
|
||
jreback marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if engine == "numba": | ||
nopython, nogil, parallel = get_jit_arguments(engine_kwargs) | ||
jreback marked this conversation as resolved.
Show resolved
Hide resolved
|
||
check_kwargs_and_nopython(kwargs, nopython) | ||
validate_udf(func) | ||
numba_func = self._numba_func_cache.get( | ||
func, jit_user_function(func, nopython, nogil, parallel) | ||
) | ||
|
||
klass = type(self._selected_obj) | ||
|
||
results = [] | ||
for name, group in self: | ||
object.__setattr__(group, "name", name) | ||
res = func(group, *args, **kwargs) | ||
if engine == "numba": | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. like to see this as
jreback marked this conversation as resolved.
Show resolved
Hide resolved
|
||
values, index = split_for_numba(group) | ||
res = numba_func(values, index, *args) | ||
if func not in self._numba_func_cache: | ||
self._numba_func_cache[func] = numba_func | ||
else: | ||
res = func(group, *args, **kwargs) | ||
|
||
if isinstance(res, (ABCDataFrame, ABCSeries)): | ||
res = res._values | ||
|
@@ -819,6 +847,8 @@ class DataFrameGroupBy(GroupBy[DataFrame]): | |
|
||
_apply_whitelist = base.dataframe_apply_whitelist | ||
|
||
_numba_func_cache: Dict[Callable, Callable] = {} | ||
|
||
_agg_see_also_doc = dedent( | ||
""" | ||
See Also | ||
|
@@ -1355,19 +1385,35 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False): | |
# Handle cases like BinGrouper | ||
return self._concat_objects(keys, values, not_indexed_same=not_indexed_same) | ||
|
||
def _transform_general(self, func, *args, **kwargs): | ||
def _transform_general( | ||
self, func, *args, engine="cython", engine_kwargs=None, **kwargs | ||
): | ||
from pandas.core.reshape.concat import concat | ||
|
||
applied = [] | ||
obj = self._obj_with_exclusions | ||
gen = self.grouper.get_iterator(obj, axis=self.axis) | ||
fast_path, slow_path = self._define_paths(func, *args, **kwargs) | ||
if engine == "numba": | ||
nopython, nogil, parallel = get_jit_arguments(engine_kwargs) | ||
check_kwargs_and_nopython(kwargs, nopython) | ||
validate_udf(func) | ||
numba_func = self._numba_func_cache.get( | ||
func, jit_user_function(func, nopython, nogil, parallel) | ||
) | ||
else: | ||
fast_path, slow_path = self._define_paths(func, *args, **kwargs) | ||
|
||
path = None | ||
for name, group in gen: | ||
object.__setattr__(group, "name", name) | ||
|
||
if path is None: | ||
if engine == "numba": | ||
values, index = split_for_numba(group) | ||
res = numba_func(values, index, *args) | ||
if func not in self._numba_func_cache: | ||
self._numba_func_cache[func] = numba_func | ||
# Return the result as a DataFrame for concatenation later | ||
res = DataFrame(res, index=group.index, columns=group.columns) | ||
else: | ||
# Try slow path and fast path. | ||
try: | ||
path, res = self._choose_path(fast_path, slow_path, group) | ||
|
@@ -1376,8 +1422,6 @@ def _transform_general(self, func, *args, **kwargs): | |
except ValueError as err: | ||
msg = "transform must return a scalar value for each group" | ||
raise ValueError(msg) from err | ||
else: | ||
res = path(group) | ||
|
||
if isinstance(res, Series): | ||
|
||
|
@@ -1411,13 +1455,15 @@ def _transform_general(self, func, *args, **kwargs): | |
|
||
@Substitution(klass="DataFrame", selected="") | ||
@Appender(_transform_template) | ||
def transform(self, func, *args, **kwargs): | ||
def transform(self, func, *args, engine="cython", engine_kwargs=None, **kwargs): | ||
|
||
# optimized transforms | ||
func = self._get_cython_func(func) or func | ||
|
||
if not isinstance(func, str): | ||
return self._transform_general(func, *args, **kwargs) | ||
return self._transform_general( | ||
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs | ||
) | ||
|
||
elif func not in base.transform_kernel_whitelist: | ||
msg = f"'{func}' is not a valid function name for transform(name)" | ||
|
@@ -1439,7 +1485,9 @@ def transform(self, func, *args, **kwargs): | |
): | ||
return self._transform_fast(result, func) | ||
|
||
return self._transform_general(func, *args, **kwargs) | ||
return self._transform_general( | ||
func, engine=engine, engine_kwargs=engine_kwargs, *args, **kwargs | ||
) | ||
|
||
def _transform_fast(self, result: DataFrame, func_nm: str) -> DataFrame: | ||
""" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,25 +1,58 @@ | ||
"""Common utilities for Numba operations""" | ||
import inspect | ||
import types | ||
from typing import Callable, Dict, Optional | ||
from typing import Callable, Dict, Optional, Tuple | ||
|
||
import numpy as np | ||
|
||
from pandas._typing import FrameOrSeries | ||
from pandas.compat._optional import import_optional_dependency | ||
|
||
|
||
def check_kwargs_and_nopython( | ||
kwargs: Optional[Dict] = None, nopython: Optional[bool] = None | ||
): | ||
) -> None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you add a doc-string |
||
""" | ||
Validate that **kwargs and nopython=True was passed | ||
https://github.com/numba/numba/issues/2916 | ||
|
||
Parameters | ||
---------- | ||
kwargs : dict, default None | ||
user passed keyword arguments to pass into the JITed function | ||
nopython : bool, default None | ||
nopython parameter | ||
|
||
Returns | ||
------- | ||
None | ||
|
||
Raises | ||
------ | ||
ValueError | ||
""" | ||
if kwargs and nopython: | ||
raise ValueError( | ||
"numba does not support kwargs with nopython=True: " | ||
"https://github.com/numba/numba/issues/2916" | ||
) | ||
|
||
|
||
def get_jit_arguments(engine_kwargs: Optional[Dict[str, bool]] = None): | ||
def get_jit_arguments( | ||
engine_kwargs: Optional[Dict[str, bool]] = None | ||
) -> Tuple[bool, bool, bool]: | ||
""" | ||
Return arguments to pass to numba.JIT, falling back on pandas default JIT settings. | ||
|
||
Parameters | ||
---------- | ||
engine_kwargs : dict, default None | ||
user passed keyword arguments for numba.JIT | ||
|
||
Returns | ||
------- | ||
(bool, bool, bool) | ||
nopython, nogil, parallel | ||
""" | ||
if engine_kwargs is None: | ||
engine_kwargs = {} | ||
|
@@ -30,9 +63,28 @@ def get_jit_arguments(engine_kwargs: Optional[Dict[str, bool]] = None): | |
return nopython, nogil, parallel | ||
|
||
|
||
def jit_user_function(func: Callable, nopython: bool, nogil: bool, parallel: bool): | ||
def jit_user_function( | ||
func: Callable, nopython: bool, nogil: bool, parallel: bool | ||
) -> Callable: | ||
""" | ||
JIT the user's function given the configurable arguments. | ||
|
||
Parameters | ||
---------- | ||
func : function | ||
user defined function | ||
|
||
nopython : bool | ||
nopython parameter for numba.JIT | ||
nogil : bool | ||
nogil parameter for numba.JIT | ||
parallel : bool | ||
parallel parameter for numba.JIT | ||
|
||
Returns | ||
------- | ||
function | ||
Numba JITed function | ||
""" | ||
numba = import_optional_dependency("numba") | ||
|
||
|
@@ -56,3 +108,50 @@ def impl(data, *_args): | |
return impl | ||
|
||
return numba_func | ||
|
||
|
||
def split_for_numba(arg: FrameOrSeries) -> Tuple[np.ndarray, np.ndarray]: | ||
""" | ||
Split pandas object into its components as numpy arrays for numba functions. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you add Parameters / Returns section |
||
|
||
Parameters | ||
---------- | ||
arg : Series or DataFrame | ||
|
||
Returns | ||
------- | ||
(ndarray, ndarray) | ||
values, index | ||
""" | ||
return arg.to_numpy(), arg.index.to_numpy() | ||
|
||
|
||
def validate_udf(func: Callable) -> None: | ||
""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same |
||
Validate user defined function for ops when using Numba. | ||
|
||
The first signature arguments should include: | ||
|
||
def f(values, index, ...): | ||
... | ||
|
||
Parameters | ||
---------- | ||
func : function, default False | ||
user defined function | ||
|
||
Returns | ||
------- | ||
None | ||
""" | ||
udf_signature = list(inspect.signature(func).parameters.keys()) | ||
expected_args = ["values", "index"] | ||
min_number_args = len(expected_args) | ||
if ( | ||
len(udf_signature) < min_number_args | ||
or udf_signature[:min_number_args] != expected_args | ||
): | ||
raise ValueError( | ||
f"The first {min_number_args} arguments to {func.__name__} must be " | ||
f"{expected_args}" | ||
) |
Uh oh!
There was an error while loading. Please reload this page.