-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
ENH: Add numba engine to df.apply #55104
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
1fa802c
c6af7c9
0ac544d
31b9e20
6190772
55df7ad
3c89b0f
1418d3e
c143c67
0d827c4
7129ee8
b0ba283
f4e80a6
21e2186
b60bef8
ba1d0e0
088d27f
60539a1
76538d6
cca34f9
8b423bf
f135def
b2e50d2
f86024f
a15293d
8fe5d89
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,10 +13,8 @@ | |
import operator | ||
|
||
import numba | ||
from numba.core import ( | ||
cgutils, | ||
types, | ||
) | ||
from numba import types | ||
from numba.core import cgutils | ||
from numba.core.datamodel import models | ||
from numba.core.extending import ( | ||
NativeValue, | ||
|
@@ -40,7 +38,7 @@ | |
|
||
|
||
# TODO: Range index support | ||
# (not passing an index to series constructor doesn't work) | ||
# (this currently lowers OK, but does not round-trip) | ||
class IndexType(types.Type): | ||
""" | ||
The type class for Index objects. | ||
|
@@ -149,6 +147,7 @@ def typer(data, hashmap=None): | |
@register_model(IndexType) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This defines the numba representations of index/series. Only interesting thing here is that Index has a pointer to the original index object, so we can avoid calling the index constructor and then just return that object. Also, we add a hashmap to the index for indexing purposes. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does the hashmap support duplicate values like a pandas Index would? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point, I will update and add some tests. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. On second thought, it's probably easier for now to disallow duplicate indexes. I don't think most frames have duplicate columns/indexes. |
||
class IndexModel(models.StructModel): | ||
def __init__(self, dmm, fe_type) -> None: | ||
# We don't want the numpy string scalar type in our hashmap | ||
members = [ | ||
("data", fe_type.as_array), | ||
# This is an attempt to emulate our hashtable code with a numba | ||
|
@@ -240,6 +239,25 @@ def index_impl(data): | |
return context.compile_internal(builder, index_impl, sig, args) | ||
|
||
|
||
# Helper to convert the unicodecharseq (numpy string scalar) into a unicode_type | ||
# (regular string) | ||
|
||
|
||
def maybe_cast_str(x): | ||
# Dummy function that numba can overload | ||
pass | ||
|
||
|
||
@overload(maybe_cast_str) | ||
def maybe_cast_str_impl(x): | ||
"""Converts numba UnicodeCharSeq (numpy string scalar) -> unicode type (string). | ||
Is a no-op for other types.""" | ||
if isinstance(x, types.UnicodeCharSeq): | ||
return lambda x: str(x) | ||
else: | ||
return lambda x: x | ||
|
||
|
||
@unbox(IndexType) | ||
def unbox_index(typ, obj, c): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Code that transforms Series/Index -> numba representations and back. There's a lot of C API stuff here, it's maybe worth a closer look if you want. |
||
""" | ||
|
@@ -426,8 +444,12 @@ def series_binop_impl(series1, value): | |
series_reductions = [ | ||
("sum", np.sum), | ||
("mean", np.mean), | ||
("std", np.std), | ||
("var", np.var), | ||
# Disabled due to discrepancies between numba std. dev | ||
# and pandas std. dev (no way to specify dof) | ||
# ("std", np.std), | ||
# ("var", np.var), | ||
("min", np.min), | ||
("max", np.max), | ||
] | ||
for reduction, reduction_method in series_reductions: | ||
generate_series_reduction(reduction, reduction_method) | ||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -1075,6 +1075,14 @@ def apply_series_generator(self) -> tuple[ResType, Index]: | |||||
return results, res_index | ||||||
|
||||||
def apply_series_numba(self): | ||||||
if self.engine_kwargs.get("parallel", False): | ||||||
raise NotImplementedError( | ||||||
"Parallel apply is not supported when raw=False and engine='numba'" | ||||||
) | ||||||
if not self.obj.index.is_unique or not self.columns.is_unique: | ||||||
raise NotImplementedError( | ||||||
"The index/columns must be unique when raw=False and engine='numba'" | ||||||
) | ||||||
results = self.apply_with_numba() | ||||||
return results, self.result_index | ||||||
|
||||||
|
@@ -1128,6 +1136,7 @@ def generate_numba_apply_func( | |||||
# This isn't an entrypoint since we don't want users | ||||||
# using Series/DF in numba code outside of apply | ||||||
from pandas.core._numba.extensions import SeriesType # noqa: F401 | ||||||
from pandas.core._numba.extensions import maybe_cast_str | ||||||
|
||||||
numba = import_optional_dependency("numba") | ||||||
|
||||||
|
@@ -1138,7 +1147,9 @@ def numba_func(values, col_names, df_index): | |||||
results = {} | ||||||
for j in range(values.shape[1]): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
? (and below) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for pointing this out! I think it'll probably make more sense to disable parallel mode for now, since the dict in numba isn't thread-safe. The overhead from the boxing/unboxing is also really high (99% of the time spent is there), so I doubt parallel will give a good speedup, at least for now. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK makes sense. Would be good to put a TODO: comment explaining why we shouldn't use prange for now. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added a comment. |
||||||
# Create the series | ||||||
ser = Series(values[:, j], index=df_index, name=str(col_names[j])) | ||||||
ser = Series( | ||||||
values[:, j], index=df_index, name=maybe_cast_str(col_names[j]) | ||||||
) | ||||||
results[j] = jitted_udf(ser) | ||||||
return results | ||||||
|
||||||
|
@@ -1148,23 +1159,40 @@ def apply_with_numba(self) -> dict[int, Any]: | |||||
nb_func = self.generate_numba_apply_func( | ||||||
cast(Callable, self.func), **self.engine_kwargs | ||||||
) | ||||||
orig_values = self.columns.to_numpy() | ||||||
fixed_cols = False | ||||||
if orig_values.dtype == object: | ||||||
if not lib.is_string_array(orig_values): | ||||||
# Since numpy/numba doesn't support object array of strings well | |||||
# we'll do a sketchy thing where if index._data is object | ||||||
# we convert to string and directly set index._data to that, | ||||||
# setting it back after we call the function | ||||||
fixed_obj_colnames = False | ||||||
orig_cols = self.columns.to_numpy() | ||||||
if self.columns._data.dtype == object: | ||||||
if not lib.is_string_array(orig_cols): | ||||||
raise ValueError( | ||||||
"The numba engine only supports " | ||||||
"using string or numeric column names" | ||||||
) | ||||||
col_names_values = orig_values.astype("U") | ||||||
# Remember to set this back! | ||||||
self.columns._data = col_names_values | ||||||
fixed_cols = True | ||||||
# Remember to set this back!!! | ||||||
self.columns._data = orig_cols.astype("U") | ||||||
fixed_obj_colnames = True | ||||||
|
||||||
fixed_obj_index = False | ||||||
orig_index = self.index.to_numpy() | ||||||
if self.obj.index._data.dtype == object: | ||||||
if not lib.is_string_array(orig_index): | ||||||
raise ValueError( | ||||||
"The numba engine only supports " | ||||||
"using string or numeric index values" | ||||||
) | ||||||
# Remember to set this back!!! | ||||||
self.obj.index._data = orig_index.astype("U") | ||||||
fixed_obj_index = True | ||||||
df_index = self.obj.index | ||||||
|
||||||
res = dict(nb_func(self.values, self.columns, df_index)) | ||||||
if fixed_cols: | ||||||
self.columns._data = orig_values | ||||||
if fixed_obj_colnames: | ||||||
self.columns._data = orig_cols | ||||||
if fixed_obj_index: | ||||||
self.obj.index._data = orig_index | ||||||
return res | ||||||
|
||||||
@property | ||||||
|
@@ -1260,6 +1288,7 @@ def generate_numba_apply_func( | |||||
# using Series/DF in numba code outside of apply | ||||||
from pandas import Series | ||||||
from pandas.core._numba.extensions import SeriesType # noqa: F401 | ||||||
from pandas.core._numba.extensions import maybe_cast_str | ||||||
|
||||||
numba = import_optional_dependency("numba") | ||||||
|
||||||
|
@@ -1271,7 +1300,11 @@ def numba_func(values, col_names_index, index): | |||||
for i in range(values.shape[0]): | ||||||
# Create the series | ||||||
# TODO: values corrupted without the copy | ||||||
ser = Series(values[i].copy(), index=col_names_index, name=index[i]) | ||||||
ser = Series( | ||||||
values[i].copy(), | ||||||
index=col_names_index, | ||||||
name=maybe_cast_str(index[i]), | ||||||
) | ||||||
results[i] = jitted_udf(ser) | ||||||
|
||||||
return results | ||||||
|
@@ -1287,24 +1320,39 @@ def apply_with_numba(self) -> dict[int, Any]: | |||||
# we'll do a sketchy thing where if index._data is object | ||||||
# we convert to string and directly set index._data to that, | ||||||
# setting it back after we call the function | ||||||
fixed_obj_dtype = False | ||||||
orig_data = self.columns.to_numpy() | ||||||
fixed_obj_colnames = False | ||||||
orig_cols = self.columns.to_numpy() | ||||||
if self.columns._data.dtype == object: | ||||||
if not lib.is_string_array(orig_data): | ||||||
if not lib.is_string_array(orig_cols): | ||||||
raise ValueError( | ||||||
"The numba engine only supports " | ||||||
"using string or numeric column names" | ||||||
) | ||||||
# Remember to set this back!!! | ||||||
self.columns._data = orig_data.astype("U") | ||||||
fixed_obj_dtype = True | ||||||
self.columns._data = orig_cols.astype("U") | ||||||
fixed_obj_colnames = True | ||||||
|
||||||
fixed_obj_index = False | ||||||
orig_index = self.index.to_numpy() | ||||||
if self.obj.index._data.dtype == object: | ||||||
if not lib.is_string_array(orig_index): | ||||||
raise ValueError( | ||||||
"The numba engine only supports " | ||||||
"using string or numeric index values" | ||||||
) | ||||||
# Remember to set this back!!! | ||||||
self.obj.index._data = orig_index.astype("U") | ||||||
fixed_obj_index = True | ||||||
|
||||||
# Convert from numba dict to regular dict | ||||||
# Our isinstance checks in the df constructor don't pass for numbas typed dict | ||||||
res = dict(nb_func(self.values, self.columns, self.obj.index)) | ||||||
|
||||||
if fixed_obj_dtype: | ||||||
self.columns._data = orig_data | ||||||
if fixed_obj_colnames: | ||||||
self.columns._data = orig_cols | ||||||
|
||||||
if fixed_obj_index: | ||||||
self.obj.index._data = orig_index | ||||||
|
||||||
return res | ||||||
|
||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import numpy as np | ||
import pytest | ||
|
||
from pandas import DataFrame | ||
import pandas._testing as tm | ||
|
||
|
||
def test_numba_vs_python_noop(float_frame, apply_axis): | ||
func = lambda x: x | ||
result = float_frame.apply(func, engine="numba", axis=apply_axis) | ||
expected = float_frame.apply(func, engine="python", axis=apply_axis) | ||
tm.assert_frame_equal(result, expected) | ||
|
||
|
||
def test_numba_vs_python_indexing(float_frame): | ||
row_func = lambda x: x["A"] | ||
result = float_frame.apply(row_func, engine="numba", axis=1) | ||
expected = float_frame.apply(row_func, engine="python", axis=1) | ||
tm.assert_series_equal(result, expected) | ||
|
||
row_func = lambda x: x["ZqgszYBfuL"] # This is a label in the index | ||
mroeschke marked this conversation as resolved.
Show resolved
Hide resolved
|
||
result = float_frame.apply(row_func, engine="numba", axis=0) | ||
expected = float_frame.apply(row_func, engine="python", axis=0) | ||
tm.assert_series_equal(result, expected) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"reduction", | ||
[lambda x: x.mean(), lambda x: x.min(), lambda x: x.max(), lambda x: x.sum()], | ||
) | ||
def test_numba_vs_python_reductions(float_frame, reduction, apply_axis): | ||
result = float_frame.apply(reduction, engine="numba", axis=apply_axis) | ||
expected = float_frame.apply(reduction, engine="python", axis=apply_axis) | ||
tm.assert_series_equal(result, expected) | ||
|
||
|
||
@pytest.mark.parametrize("colnames", [[1, 2, 3], [1.0, 2.0, 3.0]]) | ||
def test_numba_numeric_colnames(colnames): | ||
# Check that numeric column names lower properly and can be indexed on | ||
df = DataFrame(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), columns=colnames) | ||
first_col = colnames[0] | ||
f = lambda x: x[first_col] # Get the first column | ||
result = df.apply(f, engine="numba", axis=1) | ||
expected = df.apply(f, engine="python", axis=1) | ||
tm.assert_series_equal(result, expected) | ||
|
||
|
||
def test_numba_parallel_unsupported(float_frame): | ||
f = lambda x: x | ||
with pytest.raises( | ||
NotImplementedError, | ||
match="Parallel apply is not supported when raw=False and engine='numba'", | ||
): | ||
float_frame.apply(f, engine="numba", engine_kwargs={"parallel": True}) | ||
|
||
|
||
def test_numba_nonunique_unsupported(): | ||
f = lambda x: x | ||
df = DataFrame({"a": [1, 2], "b": [1, 2]}) | ||
mroeschke marked this conversation as resolved.
Show resolved
Hide resolved
|
||
with pytest.raises( | ||
NotImplementedError, | ||
match="The index/columns must be unique when raw=False and engine='numba'", | ||
): | ||
df.apply(f, engine="numba", engine_kwargs={"parallel": True}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This block just defines the types for Index and Series, there isn't much to see here.
It is pretty boilerplate and standard.