Skip to content

Commit 4e83066

Browse files
authored
ENH: Numba groupby support multiple labels (#53556)
* ENH: Numba groupby support multiple labels
* update regex
1 parent eca28a3 commit 4e83066

File tree

4 files changed

+87
-6
lines changed

4 files changed

+87
-6
lines changed

doc/source/whatsnew/v2.1.0.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ Other enhancements
9999
- :meth:`Categorical.from_codes` has gotten a ``validate`` parameter (:issue:`50975`)
100100
- :meth:`DataFrame.stack` gained the ``sort`` keyword to dictate whether the resulting :class:`MultiIndex` levels are sorted (:issue:`15105`)
101101
- :meth:`DataFrame.unstack` gained the ``sort`` keyword to dictate whether the resulting :class:`MultiIndex` levels are sorted (:issue:`15105`)
102+
- :meth:`DataFrameGroupBy.agg` and :meth:`DataFrameGroupBy.transform` now support grouping by multiple keys when the index is not a :class:`MultiIndex` for ``engine="numba"`` (:issue:`53486`)
102103
- :meth:`SeriesGroupBy.agg` and :meth:`DataFrameGroupBy.agg` now support passing in multiple functions for ``engine="numba"`` (:issue:`53486`)
103104
- Added ``engine_kwargs`` parameter to :meth:`DataFrame.to_excel` (:issue:`53220`)
104105
- Added a new parameter ``by_row`` to :meth:`Series.apply`. When set to ``False`` the supplied callables will always operate on the whole Series (:issue:`53400`).

pandas/core/groupby/groupby.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1453,13 +1453,14 @@ def _numba_prep(self, data: DataFrame):
14531453
sorted_ids = self.grouper._sorted_ids
14541454

14551455
sorted_data = data.take(sorted_index, axis=self.axis).to_numpy()
1456-
if len(self.grouper.groupings) > 1:
1457-
raise NotImplementedError(
1458-
"More than 1 grouping labels are not supported with engine='numba'"
1459-
)
14601456
# GH 46867
14611457
index_data = data.index
14621458
if isinstance(index_data, MultiIndex):
1459+
if len(self.grouper.groupings) > 1:
1460+
raise NotImplementedError(
1461+
"Grouping with more than 1 grouping labels and "
1462+
"a MultiIndex is not supported with engine='numba'"
1463+
)
14631464
group_key = self.grouper.groupings[0].name
14641465
index_data = index_data.get_level_values(group_key)
14651466
sorted_index_data = index_data.take(sorted_index).to_numpy()

pandas/tests/groupby/aggregate/test_numba.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,44 @@ def numba_func(values, index):
339339

340340
df = DataFrame([{"A": 1, "B": 2, "C": 3}]).set_index(["A", "B"])
341341
engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel}
342-
with pytest.raises(NotImplementedError, match="More than 1 grouping labels"):
342+
with pytest.raises(NotImplementedError, match="more than 1 grouping labels"):
343343
df.groupby(["A", "B"]).agg(
344344
numba_func, engine="numba", engine_kwargs=engine_kwargs
345345
)
346+
347+
348+
@td.skip_if_no("numba")
349+
def test_multilabel_numba_vs_cython(numba_supported_reductions):
350+
reduction, kwargs = numba_supported_reductions
351+
df = DataFrame(
352+
{
353+
"A": ["foo", "bar", "foo", "bar", "foo", "bar", "foo", "foo"],
354+
"B": ["one", "one", "two", "three", "two", "two", "one", "three"],
355+
"C": np.random.randn(8),
356+
"D": np.random.randn(8),
357+
}
358+
)
359+
gb = df.groupby(["A", "B"])
360+
res_agg = gb.agg(reduction, engine="numba", **kwargs)
361+
expected_agg = gb.agg(reduction, engine="cython", **kwargs)
362+
tm.assert_frame_equal(res_agg, expected_agg)
363+
# Test that calling the aggregation directly also works
364+
direct_res = getattr(gb, reduction)(engine="numba", **kwargs)
365+
direct_expected = getattr(gb, reduction)(engine="cython", **kwargs)
366+
tm.assert_frame_equal(direct_res, direct_expected)
367+
368+
369+
@td.skip_if_no("numba")
370+
def test_multilabel_udf_numba_vs_cython():
371+
df = DataFrame(
372+
{
373+
"A": ["foo", "bar", "foo", "bar", "foo", "bar", "foo", "foo"],
374+
"B": ["one", "one", "two", "three", "two", "two", "one", "three"],
375+
"C": np.random.randn(8),
376+
"D": np.random.randn(8),
377+
}
378+
)
379+
gb = df.groupby(["A", "B"])
380+
result = gb.agg(lambda values, index: values.min(), engine="numba")
381+
expected = gb.agg(lambda x: x.min(), engine="cython")
382+
tm.assert_frame_equal(result, expected)

pandas/tests/groupby/transform/test_numba.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numpy as np
12
import pytest
23

34
from pandas.errors import NumbaUtilError
@@ -224,7 +225,48 @@ def numba_func(values, index):
224225

225226
df = DataFrame([{"A": 1, "B": 2, "C": 3}]).set_index(["A", "B"])
226227
engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel}
227-
with pytest.raises(NotImplementedError, match="More than 1 grouping labels"):
228+
with pytest.raises(NotImplementedError, match="more than 1 grouping labels"):
228229
df.groupby(["A", "B"]).transform(
229230
numba_func, engine="numba", engine_kwargs=engine_kwargs
230231
)
232+
233+
234+
@td.skip_if_no("numba")
235+
@pytest.mark.xfail(
236+
reason="Groupby transform doesn't support strings as function inputs yet with numba"
237+
)
238+
def test_multilabel_numba_vs_cython(numba_supported_reductions):
239+
reduction, kwargs = numba_supported_reductions
240+
df = DataFrame(
241+
{
242+
"A": ["foo", "bar", "foo", "bar", "foo", "bar", "foo", "foo"],
243+
"B": ["one", "one", "two", "three", "two", "two", "one", "three"],
244+
"C": np.random.randn(8),
245+
"D": np.random.randn(8),
246+
}
247+
)
248+
gb = df.groupby(["A", "B"])
249+
res_agg = gb.transform(reduction, engine="numba", **kwargs)
250+
expected_agg = gb.transform(reduction, engine="cython", **kwargs)
251+
tm.assert_frame_equal(res_agg, expected_agg)
252+
253+
254+
@td.skip_if_no("numba")
255+
def test_multilabel_udf_numba_vs_cython():
256+
df = DataFrame(
257+
{
258+
"A": ["foo", "bar", "foo", "bar", "foo", "bar", "foo", "foo"],
259+
"B": ["one", "one", "two", "three", "two", "two", "one", "three"],
260+
"C": np.random.randn(8),
261+
"D": np.random.randn(8),
262+
}
263+
)
264+
gb = df.groupby(["A", "B"])
265+
result = gb.transform(
266+
lambda values, index: (values - values.min()) / (values.max() - values.min()),
267+
engine="numba",
268+
)
269+
expected = gb.transform(
270+
lambda x: (x - x.min()) / (x.max() - x.min()), engine="cython"
271+
)
272+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)