Skip to content

Commit ac5587c

Browse files
authored
ENH: Add numba engine to df.apply (#55104)
* ENH: Add numba engine to df.apply
* complete?
* wip: pass tests
* fix existing tests
* go for green
* fix checks?
* fix pyright
* update docs
* eliminate a blank line
* update from code review + more tests
* fix failing tests
* Simplify w/ context manager
* skip if no numba
* simplify more
* specify dtypes
* address code review
* add errors for invalid columns
* adjust message
1 parent 206f981 commit ac5587c

9 files changed

+948
-67
lines changed

pandas/core/_numba/extensions.py

+575
Large diffs are not rendered by default.

pandas/core/apply.py

+175-9
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import abc
44
from collections import defaultdict
5+
import functools
56
from functools import partial
67
import inspect
78
from typing import (
@@ -29,14 +30,17 @@
2930
NDFrameT,
3031
npt,
3132
)
33+
from pandas.compat._optional import import_optional_dependency
3234
from pandas.errors import SpecificationError
3335
from pandas.util._decorators import cache_readonly
3436
from pandas.util._exceptions import find_stack_level
3537

3638
from pandas.core.dtypes.cast import is_nested_object
3739
from pandas.core.dtypes.common import (
3840
is_dict_like,
41+
is_extension_array_dtype,
3942
is_list_like,
43+
is_numeric_dtype,
4044
is_sequence,
4145
)
4246
from pandas.core.dtypes.dtypes import (
@@ -121,6 +125,8 @@ def __init__(
121125
result_type: str | None,
122126
*,
123127
by_row: Literal[False, "compat", "_compat"] = "compat",
128+
engine: str = "python",
129+
engine_kwargs: dict[str, bool] | None = None,
124130
args,
125131
kwargs,
126132
) -> None:
@@ -133,6 +139,9 @@ def __init__(
133139
self.args = args or ()
134140
self.kwargs = kwargs or {}
135141

142+
self.engine = engine
143+
self.engine_kwargs = {} if engine_kwargs is None else engine_kwargs
144+
136145
if result_type not in [None, "reduce", "broadcast", "expand"]:
137146
raise ValueError(
138147
"invalid value for result_type, must be one "
@@ -601,6 +610,13 @@ def apply_list_or_dict_like(self) -> DataFrame | Series:
601610
result: Series, DataFrame, or None
602611
Result when self.func is a list-like or dict-like, None otherwise.
603612
"""
613+
614+
if self.engine == "numba":
615+
raise NotImplementedError(
616+
"The 'numba' engine doesn't support list-like/"
617+
"dict likes of callables yet."
618+
)
619+
604620
if self.axis == 1 and isinstance(self.obj, ABCDataFrame):
605621
return self.obj.T.apply(self.func, 0, args=self.args, **self.kwargs).T
606622

@@ -768,10 +784,16 @@ def __init__(
768784
) -> None:
769785
if by_row is not False and by_row != "compat":
770786
raise ValueError(f"by_row={by_row} not allowed")
771-
self.engine = engine
772-
self.engine_kwargs = engine_kwargs
773787
super().__init__(
774-
obj, func, raw, result_type, by_row=by_row, args=args, kwargs=kwargs
788+
obj,
789+
func,
790+
raw,
791+
result_type,
792+
by_row=by_row,
793+
engine=engine,
794+
engine_kwargs=engine_kwargs,
795+
args=args,
796+
kwargs=kwargs,
775797
)
776798

777799
# ---------------------------------------------------------------
@@ -792,6 +814,32 @@ def result_columns(self) -> Index:
792814
def series_generator(self) -> Generator[Series, None, None]:
793815
pass
794816

817+
@staticmethod
818+
@functools.cache
819+
@abc.abstractmethod
820+
def generate_numba_apply_func(
821+
func, nogil=True, nopython=True, parallel=False
822+
) -> Callable[[npt.NDArray, Index, Index], dict[int, Any]]:
# Abstract factory hook: concrete subclasses return a callable that takes
# (ndarray, Index, Index) and returns a dict keyed by integer position,
# as stated by the return annotation above.
# functools.cache memoizes on the call arguments (func plus the three flags).
823+
pass
824+
825+
@abc.abstractmethod
826+
def apply_with_numba(self):
# Abstract hook: subclasses run the numba code path and return its results
# (the overrides visible later in this diff are annotated dict[int, Any]).
827+
pass
828+
829+
def validate_values_for_numba(self):
# Check every column dtype of self.obj before the numba path is taken.
# Raises ValueError for the first column that is either non-numeric, or
# numeric but backed by an extension array.
830+
# Validate column dtypes all OK
831+
for colname, dtype in self.obj.dtypes.items():
832+
if not is_numeric_dtype(dtype):
833+
raise ValueError(
834+
f"Column {colname} must have a numeric dtype. "
835+
f"Found '{dtype}' instead"
836+
)
837+
if is_extension_array_dtype(dtype):
838+
raise ValueError(
839+
f"Column {colname} is backed by an extension array, "
840+
# NOTE(review): the f-prefix on the next literal has no placeholders.
f"which is not supported by the numba engine."
841+
)
842+
795843
@abc.abstractmethod
796844
def wrap_results_for_axis(
797845
self, results: ResType, res_index: Index
@@ -815,13 +863,12 @@ def values(self):
815863
def apply(self) -> DataFrame | Series:
816864
"""compute the results"""
817865

818-
if self.engine == "numba" and not self.raw:
819-
raise ValueError(
820-
"The numba engine in DataFrame.apply can only be used when raw=True"
821-
)
822-
823866
# dispatch to handle list-like or dict-like
824867
if is_list_like(self.func):
868+
if self.engine == "numba":
869+
raise NotImplementedError(
870+
"the 'numba' engine doesn't support lists of callables yet"
871+
)
825872
return self.apply_list_or_dict_like()
826873

827874
# all empty
@@ -830,17 +877,31 @@ def apply(self) -> DataFrame | Series:
830877

831878
# string dispatch
832879
if isinstance(self.func, str):
880+
if self.engine == "numba":
881+
raise NotImplementedError(
882+
"the 'numba' engine doesn't support using "
883+
"a string as the callable function"
884+
)
833885
return self.apply_str()
834886

835887
# ufunc
836888
elif isinstance(self.func, np.ufunc):
889+
if self.engine == "numba":
890+
raise NotImplementedError(
891+
"the 'numba' engine doesn't support "
892+
"using a numpy ufunc as the callable function"
893+
)
837894
with np.errstate(all="ignore"):
838895
results = self.obj._mgr.apply("apply", func=self.func)
839896
# _constructor will retain self.index and self.columns
840897
return self.obj._constructor_from_mgr(results, axes=results.axes)
841898

842899
# broadcasting
843900
if self.result_type == "broadcast":
901+
if self.engine == "numba":
902+
raise NotImplementedError(
903+
"the 'numba' engine doesn't support result_type='broadcast'"
904+
)
844905
return self.apply_broadcast(self.obj)
845906

846907
# one axis empty
@@ -997,7 +1058,10 @@ def apply_broadcast(self, target: DataFrame) -> DataFrame:
9971058
return result
9981059

9991060
def apply_standard(self):
# Compute the per-Series results with the selected engine, then wrap them.
# The `1000-` entry below is the line removed by this commit; the `+` lines
# replace it with a dispatch on self.engine. Any engine value other than
# "python" is routed to apply_series_numba().
1000-
results, res_index = self.apply_series_generator()
1061+
if self.engine == "python":
1062+
results, res_index = self.apply_series_generator()
1063+
else:
1064+
results, res_index = self.apply_series_numba()
10011065

10021066
# wrap results
10031067
return self.wrap_results(results, res_index)
@@ -1021,6 +1085,19 @@ def apply_series_generator(self) -> tuple[ResType, Index]:
10211085

10221086
return results, res_index
10231087

1088+
def apply_series_numba(self):
# numba branch of apply_standard: returns (results, self.result_index).
# Raises NotImplementedError when engine_kwargs requests parallel=True or
# when the index or columns contain duplicate labels; the dtype validation
# call below may additionally raise ValueError.
1089+
if self.engine_kwargs.get("parallel", False):
1090+
raise NotImplementedError(
1091+
"Parallel apply is not supported when raw=False and engine='numba'"
1092+
)
1093+
if not self.obj.index.is_unique or not self.columns.is_unique:
1094+
raise NotImplementedError(
1095+
"The index/columns must be unique when raw=False and engine='numba'"
1096+
)
1097+
self.validate_values_for_numba()
1098+
results = self.apply_with_numba()
1099+
return results, self.result_index
1100+
10241101
def wrap_results(self, results: ResType, res_index: Index) -> DataFrame | Series:
10251102
from pandas import Series
10261103

@@ -1060,6 +1137,49 @@ class FrameRowApply(FrameApply):
10601137
def series_generator(self) -> Generator[Series, None, None]:
10611138
return (self.obj._ixs(i, axis=1) for i in range(len(self.columns)))
10621139

1140+
@staticmethod
1141+
@functools.cache
1142+
def generate_numba_apply_func(
1143+
func, nogil=True, nopython=True, parallel=False
1144+
) -> Callable[[npt.NDArray, Index, Index], dict[int, Any]]:
# Build the jitted apply function for this class; functools.cache memoizes
# the result per (func, nogil, nopython, parallel) combination.
# The returned function loops over the columns of `values` (axis 1), wraps
# each column in a Series indexed by `df_index` and named after the matching
# entry of `col_names`, calls the user function on it, and stores the result
# under the column's integer position.
1145+
numba = import_optional_dependency("numba")
1146+
from pandas import Series
1147+
1148+
# Import helper from extensions to cast string object -> np strings
1149+
# Note: This also has the side effect of loading our numba extensions
1150+
from pandas.core._numba.extensions import maybe_cast_str
1151+
1152+
jitted_udf = numba.extending.register_jitable(func)
1153+
1154+
# NOTE(review): `parallel` is still forwarded to numba.jit below; the raw=False
1155+
# path (apply_series_numba) rejects parallel=True first, since the dicts in numba aren't thread-safe.
1156+
@numba.jit(nogil=nogil, nopython=nopython, parallel=parallel)
1157+
def numba_func(values, col_names, df_index):
1158+
results = {}
1159+
for j in range(values.shape[1]):
1160+
# Create the series
1161+
ser = Series(
1162+
values[:, j], index=df_index, name=maybe_cast_str(col_names[j])
1163+
)
1164+
results[j] = jitted_udf(ser)
1165+
return results
1166+
1167+
return numba_func
1168+
1169+
def apply_with_numba(self) -> dict[int, Any]:
# Obtain the (cached) jitted function for self.func / self.engine_kwargs and
# run it on self.values, with the index and columns passed through the
# set_numba_data context manager. Returns the results as a plain dict.
1170+
nb_func = self.generate_numba_apply_func(
1171+
cast(Callable, self.func), **self.engine_kwargs
1172+
)
1173+
from pandas.core._numba.extensions import set_numba_data
1174+
1175+
# Convert from numba dict to regular dict
1176+
# Our isinstance checks in the df constructor don't pass for numba's typed dict
1177+
with set_numba_data(self.obj.index) as index, set_numba_data(
1178+
self.columns
1179+
) as columns:
1180+
res = dict(nb_func(self.values, columns, index))
1181+
return res
1182+
10631183
@property
10641184
def result_index(self) -> Index:
# The result index for this apply class is self.columns.
10651185
return self.columns
@@ -1143,6 +1263,52 @@ def series_generator(self) -> Generator[Series, None, None]:
11431263
object.__setattr__(ser, "_name", name)
11441264
yield ser
11451265

1266+
@staticmethod
1267+
@functools.cache
1268+
def generate_numba_apply_func(
1269+
func, nogil=True, nopython=True, parallel=False
1270+
) -> Callable[[npt.NDArray, Index, Index], dict[int, Any]]:
# Build the jitted apply function for this class; functools.cache memoizes
# the result per (func, nogil, nopython, parallel) combination.
# The returned function loops over the rows of `values` (axis 0), wraps a
# copy of each row in a Series indexed by `col_names_index` and named after
# the matching entry of `index`, calls the user function on it, and stores
# the result under the row's integer position.
1271+
numba = import_optional_dependency("numba")
1272+
from pandas import Series
1273+
from pandas.core._numba.extensions import maybe_cast_str
1274+
1275+
jitted_udf = numba.extending.register_jitable(func)
1276+
1277+
@numba.jit(nogil=nogil, nopython=nopython, parallel=parallel)
1278+
def numba_func(values, col_names_index, index):
1279+
results = {}
1280+
# NOTE(review): `parallel` is still forwarded to numba.jit above; the raw=False
1281+
# path (apply_series_numba) rejects parallel=True first, since the dicts in numba aren't thread-safe.
1282+
for i in range(values.shape[0]):
1283+
# Create the series
1284+
# TODO: values corrupted without the copy
1285+
ser = Series(
1286+
values[i].copy(),
1287+
index=col_names_index,
1288+
name=maybe_cast_str(index[i]),
1289+
)
1290+
results[i] = jitted_udf(ser)
1291+
1292+
return results
1293+
1294+
return numba_func
1295+
1296+
def apply_with_numba(self) -> dict[int, Any]:
# Obtain the (cached) jitted function for self.func / self.engine_kwargs and
# run it on self.values, with the index and columns passed through the
# set_numba_data context manager. Returns the results as a plain dict.
1297+
nb_func = self.generate_numba_apply_func(
1298+
cast(Callable, self.func), **self.engine_kwargs
1299+
)
1300+
1301+
from pandas.core._numba.extensions import set_numba_data
1302+
1303+
# Convert from numba dict to regular dict
1304+
# Our isinstance checks in the df constructor don't pass for numba's typed dict
1305+
with set_numba_data(self.obj.index) as index, set_numba_data(
1306+
self.columns
1307+
) as columns:
1308+
res = dict(nb_func(self.values, columns, index))
1309+
1310+
return res
1311+
11461312
@property
11471313
def result_index(self) -> Index:
# The result index for this apply class is self.index.
11481314
return self.index

pandas/core/frame.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -10090,6 +10090,9 @@ def apply(
1009010090
- nogil (release the GIL inside the JIT compiled function)
1009110091
- parallel (try to apply the function in parallel over the DataFrame)
1009210092
10093+
Note: Due to limitations within numba and in how pandas interfaces with numba,
10094+
you should only use this if raw=True
10095+
1009310096
Note: The numba compiler only supports a subset of
1009410097
valid Python/numpy operations.
1009510098
@@ -10099,8 +10102,6 @@ def apply(
1009910102
<https://numba.pydata.org/numba-doc/dev/reference/numpysupported.html>`_
1010010103
in numba to learn what you can or cannot use in the passed function.
1010110104
10102-
As of right now, the numba engine can only be used with raw=True.
10103-
1010410105
.. versionadded:: 2.2.0
1010510106
1010610107
engine_kwargs : dict

pandas/tests/apply/conftest.py

+12
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,15 @@ def int_frame_const_col():
1616
columns=["A", "B", "C"],
1717
)
1818
return df
19+
20+
21+
@pytest.fixture(params=["python", "numba"])
22+
def engine(request):
# Parametrized fixture yielding each engine name in turn; the "numba"
# parametrization is skipped by pytest.importorskip when numba cannot be
# imported.
23+
if request.param == "numba":
24+
pytest.importorskip("numba")
25+
return request.param
26+
27+
28+
@pytest.fixture(params=[0, 1])
29+
def apply_axis(request):
# Parametrized fixture yielding the axis values 0 and 1.
30+
return request.param

0 commit comments

Comments
 (0)