Skip to content

ENH: Add numba engine to df.apply #55104

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 26 commits into from
Oct 22, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
1fa802c
ENH: Add numba engine to df.apply
lithomas1 Sep 12, 2023
c6af7c9
Merge branch 'main' of github.com:pandas-dev/pandas into numba-apply
lithomas1 Sep 12, 2023
0ac544d
complete?
lithomas1 Sep 14, 2023
31b9e20
wip: pass tests
lithomas1 Sep 19, 2023
6190772
Merge branch 'main' of github.com:pandas-dev/pandas into numba-apply
lithomas1 Sep 24, 2023
55df7ad
fix existing tests
lithomas1 Sep 24, 2023
3c89b0f
go for green
lithomas1 Sep 25, 2023
1418d3e
fix checks?
lithomas1 Sep 25, 2023
c143c67
fix pyright
lithomas1 Sep 25, 2023
0d827c4
update docs
lithomas1 Sep 28, 2023
7129ee8
Merge branch 'main' of github.com:pandas-dev/pandas into numba-apply
lithomas1 Sep 28, 2023
b0ba283
Merge branch 'main' into numba-apply
lithomas1 Sep 29, 2023
f4e80a6
eliminate a blank line
lithomas1 Sep 29, 2023
21e2186
update from code review + more tests
lithomas1 Oct 7, 2023
b60bef8
Merge branch 'main' of github.com:pandas-dev/pandas into numba-apply
lithomas1 Oct 9, 2023
ba1d0e0
fix failing tests
lithomas1 Oct 10, 2023
088d27f
Simplify w/ context manager
lithomas1 Oct 12, 2023
60539a1
skip if no numba
lithomas1 Oct 12, 2023
76538d6
simplify more
lithomas1 Oct 12, 2023
cca34f9
specify dtypes
lithomas1 Oct 12, 2023
8b423bf
Merge branch 'main' of github.com:pandas-dev/pandas into numba-apply
lithomas1 Oct 15, 2023
f135def
Merge branch 'main' into numba-apply
lithomas1 Oct 16, 2023
b2e50d2
Merge branch 'numba-apply' of github.com:lithomas1/pandas into numba-…
lithomas1 Oct 16, 2023
f86024f
address code review
lithomas1 Oct 16, 2023
a15293d
add errors for invalid columns
lithomas1 Oct 19, 2023
8fe5d89
adjust message
lithomas1 Oct 19, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions pandas/core/_numba/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,8 @@
import operator

import numba
from numba.core import (
cgutils,
types,
)
from numba import types
from numba.core import cgutils
from numba.core.datamodel import models
from numba.core.extending import (
NativeValue,
Expand All @@ -40,7 +38,7 @@


# TODO: Range index support
# (not passing an index to series constructor doesn't work)
# (this currently lowers OK, but does not round-trip)
class IndexType(types.Type):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This block just defines the types for Index and Series, there isn't much to see here.

It is pretty boilerplate and standard.

"""
The type class for Index objects.
Expand Down Expand Up @@ -149,6 +147,7 @@ def typer(data, hashmap=None):
@register_model(IndexType)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This defines the numba representations of index/series.

Only interesting thing here is that Index has a pointer to the original index object, so we can avoid calling the index constructor and then just return that object.

Also, we add a hashmap to the index for indexing purposes.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the hashmap support duplicate values like a pandas Index would?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, I will update and add some tests.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On second thought, it's probably easier for now to disallow duplicate indexes.

I don't think most frames have duplicate columns/indexes.

class IndexModel(models.StructModel):
def __init__(self, dmm, fe_type) -> None:
# We don't want the numpy string scalar type in our hashmap
members = [
("data", fe_type.as_array),
# This is an attempt to emulate our hashtable code with a numba
Expand Down Expand Up @@ -240,6 +239,25 @@ def index_impl(data):
return context.compile_internal(builder, index_impl, sig, args)


# Helper to convert the unicodecharseq (numpy string scalar) into a unicode_type
# (regular string)


def maybe_cast_str(x):
    """Stub with no Python-level behavior; it exists so numba can overload it."""


@overload(maybe_cast_str)
def maybe_cast_str_impl(x):
    """
    Numba overload of ``maybe_cast_str``.

    A ``UnicodeCharSeq`` argument (numpy string scalar) is converted with
    ``str``; an argument of any other type is returned unchanged.
    """
    # Guard clause: everything that is not a UnicodeCharSeq passes through.
    if not isinstance(x, types.UnicodeCharSeq):
        return lambda x: x
    return lambda x: str(x)


@unbox(IndexType)
def unbox_index(typ, obj, c):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code that transforms Series/Index -> numba representations and back.

There's a lot of C API stuff here, it's maybe worth a closer look if you want.

"""
Expand Down Expand Up @@ -426,8 +444,12 @@ def series_binop_impl(series1, value):
series_reductions = [
("sum", np.sum),
("mean", np.mean),
("std", np.std),
("var", np.var),
# Disabled due to discrepancies between numba std. dev
# and pandas std. dev (no way to specify dof)
# ("std", np.std),
# ("var", np.var),
("min", np.min),
("max", np.max),
]
for reduction, reduction_method in series_reductions:
generate_series_reduction(reduction, reduction_method)
Expand Down
86 changes: 67 additions & 19 deletions pandas/core/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -1075,6 +1075,14 @@ def apply_series_generator(self) -> tuple[ResType, Index]:
return results, res_index

def apply_series_numba(self):
    """
    Run the numba-backed apply after rejecting unsupported configurations.

    Returns
    -------
    tuple
        ``(results, self.result_index)``, where ``results`` is the value
        returned by ``self.apply_with_numba()``.

    Raises
    ------
    NotImplementedError
        If ``engine_kwargs`` asks for ``parallel``, or if either
        ``self.obj.index`` or ``self.columns`` is not unique.
    """
    parallel_requested = self.engine_kwargs.get("parallel", False)
    if parallel_requested:
        raise NotImplementedError(
            "Parallel apply is not supported when raw=False and engine='numba'"
        )

    # The columns are only inspected when the index is already unique.
    labels_unique = self.obj.index.is_unique and self.columns.is_unique
    if not labels_unique:
        raise NotImplementedError(
            "The index/columns must be unique when raw=False and engine='numba'"
        )

    return self.apply_with_numba(), self.result_index

Expand Down Expand Up @@ -1128,6 +1136,7 @@ def generate_numba_apply_func(
# This isn't an entrypoint since we don't want users
# using Series/DF in numba code outside of apply
from pandas.core._numba.extensions import SeriesType # noqa: F401
from pandas.core._numba.extensions import maybe_cast_str

numba = import_optional_dependency("numba")

Expand All @@ -1138,7 +1147,9 @@ def numba_func(values, col_names, df_index):
results = {}
for j in range(values.shape[1]):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for j in range(values.shape[1]):
for j in numba.prange(values.shape[1]):

? (and below)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing this out!

I think it probably makes more sense to disable parallel mode for now, since the dict in numba isn't thread-safe.

The overhead from the boxing/unboxing is also really high (99% of the time spent is there), so I doubt parallel will give a good speedup, at least for now.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK makes sense. Would be good to put a TODO: comment explaining why we shouldn't use prange for now

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added a comment.

# Create the series
ser = Series(values[:, j], index=df_index, name=str(col_names[j]))
ser = Series(
values[:, j], index=df_index, name=maybe_cast_str(col_names[j])
)
results[j] = jitted_udf(ser)
return results

Expand All @@ -1148,23 +1159,40 @@ def apply_with_numba(self) -> dict[int, Any]:
nb_func = self.generate_numba_apply_func(
cast(Callable, self.func), **self.engine_kwargs
)
orig_values = self.columns.to_numpy()
fixed_cols = False
if orig_values.dtype == object:
if not lib.is_string_array(orig_values):
# Since numpy/numba doesn't support object array of strings well
# we'll do a sketchy thing where if index._data is object
# we convert to string and directly set index._data to that,
# setting it back after we call the function
fixed_obj_colnames = False
orig_cols = self.columns.to_numpy()
if self.columns._data.dtype == object:
if not lib.is_string_array(orig_cols):
raise ValueError(
"The numba engine only supports "
"using string or numeric column names"
)
col_names_values = orig_values.astype("U")
# Remember to set this back!
self.columns._data = col_names_values
fixed_cols = True
# Remember to set this back!!!
self.columns._data = orig_cols.astype("U")
fixed_obj_colnames = True

fixed_obj_index = False
orig_index = self.index.to_numpy()
if self.obj.index._data.dtype == object:
if not lib.is_string_array(orig_index):
raise ValueError(
"The numba engine only supports "
"using string or numeric index values"
)
# Remember to set this back!!!
self.obj.index._data = orig_index.astype("U")
fixed_obj_index = True
df_index = self.obj.index

res = dict(nb_func(self.values, self.columns, df_index))
if fixed_cols:
self.columns._data = orig_values
if fixed_obj_colnames:
self.columns._data = orig_cols
if fixed_obj_index:
self.obj.index._data = orig_index
return res

@property
Expand Down Expand Up @@ -1260,6 +1288,7 @@ def generate_numba_apply_func(
# using Series/DF in numba code outside of apply
from pandas import Series
from pandas.core._numba.extensions import SeriesType # noqa: F401
from pandas.core._numba.extensions import maybe_cast_str

numba = import_optional_dependency("numba")

Expand All @@ -1271,7 +1300,11 @@ def numba_func(values, col_names_index, index):
for i in range(values.shape[0]):
# Create the series
# TODO: values corrupted without the copy
ser = Series(values[i].copy(), index=col_names_index, name=index[i])
ser = Series(
values[i].copy(),
index=col_names_index,
name=maybe_cast_str(index[i]),
)
results[i] = jitted_udf(ser)

return results
Expand All @@ -1287,24 +1320,39 @@ def apply_with_numba(self) -> dict[int, Any]:
# we'll do a sketchy thing where if index._data is object
# we convert to string and directly set index._data to that,
# setting it back after we call the function
fixed_obj_dtype = False
orig_data = self.columns.to_numpy()
fixed_obj_colnames = False
orig_cols = self.columns.to_numpy()
if self.columns._data.dtype == object:
if not lib.is_string_array(orig_data):
if not lib.is_string_array(orig_cols):
raise ValueError(
"The numba engine only supports "
"using string or numeric column names"
)
# Remember to set this back!!!
self.columns._data = orig_data.astype("U")
fixed_obj_dtype = True
self.columns._data = orig_cols.astype("U")
fixed_obj_colnames = True

fixed_obj_index = False
orig_index = self.index.to_numpy()
if self.obj.index._data.dtype == object:
if not lib.is_string_array(orig_index):
raise ValueError(
"The numba engine only supports "
"using string or numeric index values"
)
# Remember to set this back!!!
self.obj.index._data = orig_index.astype("U")
fixed_obj_index = True

# Convert from numba dict to regular dict
# Our isinstance checks in the df constructor don't pass for numbas typed dict
res = dict(nb_func(self.values, self.columns, self.obj.index))

if fixed_obj_dtype:
self.columns._data = orig_data
if fixed_obj_colnames:
self.columns._data = orig_cols

if fixed_obj_index:
self.obj.index._data = orig_index

return res

Expand Down
12 changes: 12 additions & 0 deletions pandas/tests/apply/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,15 @@ def int_frame_const_col():
columns=["A", "B", "C"],
)
return df


@pytest.fixture(params=["python", "numba"])
def engine(request):
    """Provide each engine name; the "numba" case is skipped if numba is missing."""
    selected = request.param
    if selected == "numba":
        pytest.importorskip("numba")
    return selected


@pytest.fixture(params=[0, 1])
def apply_axis(request):
    """Provide each of the two parametrized values, 0 and 1."""
    axis = request.param
    return axis
22 changes: 11 additions & 11 deletions pandas/tests/apply/test_frame_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,6 @@
from pandas.tests.frame.common import zip_frames


@pytest.fixture(params=["python", "numba"])
def engine(request):
if request.param == "numba":
pytest.importorskip("numba")
return request.param


def test_apply(float_frame, engine, request):
if engine == "numba":
mark = pytest.mark.xfail(reason="numba engine not supporting numpy ufunc yet")
Expand Down Expand Up @@ -102,7 +95,7 @@ def test_apply_mixed_datetimelike():


@pytest.mark.parametrize("func", [np.sqrt, np.mean])
def test_apply_empty(func, engine=engine):
def test_apply_empty(func, engine):
# empty
empty_frame = DataFrame()

Expand Down Expand Up @@ -983,15 +976,17 @@ def test_result_type_shorter_list(int_frame_const_col):
tm.assert_frame_equal(result, expected)


def test_result_type_broadcast(int_frame_const_col, request):
def test_result_type_broadcast(int_frame_const_col, request, engine):
# result_type should be consistent no matter which
# path we take in the code
if engine == "numba":
mark = pytest.mark.xfail(reason="numba engine doesn't support list return")
request.node.add_marker(mark)
df = int_frame_const_col
# broadcast result
result = df.apply(lambda x: [1, 2, 3], axis=1, result_type="broadcast")
result = df.apply(
lambda x: [1, 2, 3], axis=1, result_type="broadcast", engine=engine
)
expected = df.copy()
tm.assert_frame_equal(result, expected)

Expand Down Expand Up @@ -1550,8 +1545,13 @@ def sum_div2(s):
tm.assert_frame_equal(result, expected)


def test_apply_getitem_axis_1(engine):
def test_apply_getitem_axis_1(engine, request):
# GH 13427
if engine == "numba":
mark = pytest.mark.xfail(
reason="numba engine not supporting duplicate index values"
)
request.node.add_marker(mark)
df = DataFrame({"a": [0, 1, 2], "b": [1, 2, 3]})
result = df[["a", "a"]].apply(
lambda x: x.iloc[0] + x.iloc[1], axis=1, engine=engine
Expand Down
64 changes: 64 additions & 0 deletions pandas/tests/apply/test_numba.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import numpy as np
import pytest

from pandas import DataFrame
import pandas._testing as tm


def test_numba_vs_python_noop(float_frame, apply_axis):
    """An identity UDF must give the same frame under both engines."""
    identity = lambda x: x
    # Dict literals evaluate in order, so the numba call still runs first.
    by_engine = {
        eng: float_frame.apply(identity, engine=eng, axis=apply_axis)
        for eng in ("numba", "python")
    }
    tm.assert_frame_equal(by_engine["numba"], by_engine["python"])


def test_numba_vs_python_indexing(float_frame):
    """Label lookup inside the UDF must agree between engines on both axes."""
    cases = (
        (lambda x: x["A"], 1),
        (lambda x: x["ZqgszYBfuL"], 0),  # This is a label in the index
    )
    for getter, axis in cases:
        result = float_frame.apply(getter, engine="numba", axis=axis)
        expected = float_frame.apply(getter, engine="python", axis=axis)
        tm.assert_series_equal(result, expected)


@pytest.mark.parametrize(
    "reduction",
    [
        lambda x: x.mean(),
        lambda x: x.min(),
        lambda x: x.max(),
        lambda x: x.sum(),
    ],
)
def test_numba_vs_python_reductions(float_frame, reduction, apply_axis):
    """Each reduction must give the same Series under both engines."""
    numba_out = float_frame.apply(reduction, engine="numba", axis=apply_axis)
    python_out = float_frame.apply(reduction, engine="python", axis=apply_axis)
    tm.assert_series_equal(numba_out, python_out)


@pytest.mark.parametrize("colnames", [[1, 2, 3], [1.0, 2.0, 3.0]])
def test_numba_numeric_colnames(colnames):
    """Numeric column names must lower properly and be usable for indexing."""
    data = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    df = DataFrame(data, columns=colnames)
    first_col = colnames[0]
    pick_first = lambda x: x[first_col]  # Get the first column
    numba_out = df.apply(pick_first, engine="numba", axis=1)
    python_out = df.apply(pick_first, engine="python", axis=1)
    tm.assert_series_equal(numba_out, python_out)


def test_numba_parallel_unsupported(float_frame):
    """Requesting ``parallel`` through ``engine_kwargs`` must be rejected."""
    msg = "Parallel apply is not supported when raw=False and engine='numba'"
    identity = lambda x: x
    with pytest.raises(NotImplementedError, match=msg):
        float_frame.apply(
            identity, engine="numba", engine_kwargs={"parallel": True}
        )


def test_numba_nonunique_unsupported():
    """
    A frame with duplicate index labels must be rejected by the numba engine.

    The frame is built with a repeated index label so the uniqueness check
    is the one that fires. ``parallel`` is deliberately not requested,
    because that option is rejected with a different message before the
    uniqueness check is reached.
    """
    f = lambda x: x
    df = DataFrame({"a": [1, 2], "b": [1, 2]}, index=[0, 0])
    with pytest.raises(
        NotImplementedError,
        match="The index/columns must be unique when raw=False and engine='numba'",
    ):
        df.apply(f, engine="numba")