-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
ENH/BUG: Use Kleene logic for groupby any/all #40819
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 17 commits
088ca14
2554921
9a8f9c9
5ca9c4b
68fd995
6530491
26146c2
20f475d
924b38e
423f43f
b1408ac
47ef037
4415060
bb04c1c
9c90886
ef3fbe2
f4c8a8a
1c3cb7d
7cbf85b
809b8a4
58fd33a
80a65bb
c9b9d5f
7514568
740ad7b
a116bed
8a428d4
8e3c5be
b627618
23b3b64
a30496c
3051a99
4cd2833
98cd401
a92c637
7c5c8e6
c66d1fd
0950234
c81c1a5
d2b8ad0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -393,9 +393,11 @@ def group_any_all(uint8_t[::1] out, | |
const intp_t[:] labels, | ||
const uint8_t[::1] mask, | ||
str val_test, | ||
bint skipna) -> None: | ||
bint skipna, | ||
bint use_kleene_logic = False) -> None: | ||
""" | ||
Aggregated boolean values to show truthfulness of group elements. | ||
Aggregated boolean values to show truthfulness of group elements. If the | ||
input is a nullable type, Kleene logic will be used. | ||
|
||
Parameters | ||
---------- | ||
|
@@ -412,11 +414,14 @@ def group_any_all(uint8_t[::1] out, | |
String object dictating whether to use any or all truth testing | ||
skipna : bool | ||
Flag to ignore nan values during truth testing | ||
use_kleene_logic : bool, default False | ||
Whether or not to compute the result using Kleene logic | ||
|
||
Notes | ||
----- | ||
This method modifies the `out` parameter rather than returning an object. | ||
The returned values will either be 0 or 1 (False or True, respectively). | ||
The returned values will either be 0, 1 (False or True, respectively), or | ||
2 to signify a masked position in the case of a nullable input. | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Hmm, is this what we do elsewhere to encode Kleene logic? (e.g. I would rather use -1 on nulls) There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The other algo is a bit different since it can be vectorized and computes both data and mask (so it doesn't need to store something signifying NULL). Agree that -1 is more intuitive; I have changed it to use -1 as the mask signal. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. At some point, it might be interesting to look into directly filling in the mask as well instead of using the temporary -1 (but not for this PR, to be clear ;)). Personally, I would maybe leave it at using 2, so you don't need to change the uint8 type. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yep, I don't think that would be too bad; that's how I had approached it originally, but the changes were slightly more invasive, so I stuck with this. For the -1 vs 2 question, I think I slightly prefer -1 because it seems more immediately apparent what it represents. Although it forces changing the dtype, it is still the same number of bytes, so the memory layout is unaffected and constructing the view looks equivalent speed-wise:
But not a strong opinion, fine with using 2 as well :) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If speed-wise the int8 vs uint8 doesn't matter, then indeed the -1 is fine as well |
||
""" | ||
cdef: | ||
Py_ssize_t i, N = len(labels) | ||
|
@@ -444,6 +449,16 @@ def group_any_all(uint8_t[::1] out, | |
if lab < 0 or (skipna and mask[i]): | ||
continue | ||
|
||
if use_kleene_logic and mask[i]: | ||
# Set the position as masked if `out[lab] != flag_val`, which | ||
# would indicate True/False has not yet been seen for any/all, | ||
# so by Kleene logic the result is currently unknown | ||
if out[lab] != flag_val: | ||
out[lab] = 2 | ||
continue | ||
|
||
# If True and 'any' or False and 'all', the result is | ||
# already determined | ||
if values[i] == flag_val: | ||
out[lab] = flag_val | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -80,6 +80,8 @@ class providing the base-class of operations. | |
Categorical, | ||
ExtensionArray, | ||
) | ||
from pandas.core.arrays.boolean import BooleanArray | ||
from pandas.core.arrays.masked import BaseMaskedArray | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You can import both from `pandas.core.arrays` (module name lost in extraction; presumably the package-level import). There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. thanks! |
||
from pandas.core.base import ( | ||
DataError, | ||
PandasObject, | ||
|
@@ -1413,16 +1415,25 @@ def _bool_agg(self, val_test, skipna): | |
Shared func to call any / all Cython GroupBy implementations. | ||
""" | ||
|
||
def objs_to_bool(vals: np.ndarray) -> tuple[np.ndarray, type]: | ||
def objs_to_bool(vals: ArrayLike) -> tuple[np.ndarray, type]: | ||
if is_object_dtype(vals): | ||
vals = np.array([bool(x) for x in vals]) | ||
elif isinstance(vals, BaseMaskedArray): | ||
vals = vals._data.astype(bool, copy=False) | ||
else: | ||
vals = vals.astype(bool) | ||
|
||
return vals.view(np.uint8), bool | ||
|
||
def result_to_bool(result: np.ndarray, inference: type) -> np.ndarray: | ||
return result.astype(inference, copy=False) | ||
def result_to_bool( | ||
result: np.ndarray, | ||
inference: type, | ||
result_is_nullable: bool = False, | ||
) -> ArrayLike: | ||
if result_is_nullable: | ||
return BooleanArray(result.astype(bool, copy=False), result == 2) | ||
else: | ||
return result.astype(inference, copy=False) | ||
|
||
return self._get_cythonized_result( | ||
"group_any_all", | ||
|
@@ -2735,6 +2746,10 @@ def _get_cythonized_result( | |
mask = isna(values).view(np.uint8) | ||
func = partial(func, mask) | ||
|
||
if how == "group_any_all" and isinstance(values, BaseMaskedArray): | ||
post_processing = partial(post_processing, result_is_nullable=True) | ||
func = partial(func, use_kleene_logic=True) | ||
rhshadrach marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
if needs_ngroups: | ||
func = partial(func, ngroups) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,7 @@ | |
import numpy as np | ||
import pytest | ||
|
||
import pandas as pd | ||
from pandas import ( | ||
DataFrame, | ||
Index, | ||
|
@@ -68,3 +69,43 @@ def test_bool_aggs_dup_column_labels(bool_agg_func): | |
|
||
expected = df | ||
tm.assert_frame_equal(result, expected) | ||
|
||
|
||
@pytest.mark.parametrize("bool_agg_func", ["any", "all"]) | ||
@pytest.mark.parametrize("skipna", [True, False]) | ||
@pytest.mark.parametrize( | ||
# expected indexed as [skipna][bool_agg_func == "all"], i.e. [[skipna=False/any, skipna=False/all], [skipna=True/any, skipna=True/all]] | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you maybe also write out what that results in in practice? So [skipna=False / any, skipna=False / all], [skipna=True / any, skipna=True / all], if I understand correctly. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yep, I have changed it; that is a much clearer way to describe it. |
||
"data,expected", | ||
[ | ||
([False, False, False], [[False, False], [False, False]]), | ||
([True, True, True], [[True, True], [True, True]]), | ||
([pd.NA, pd.NA, pd.NA], [[pd.NA, pd.NA], [False, True]]), | ||
([False, pd.NA, False], [[pd.NA, False], [False, False]]), | ||
([True, pd.NA, True], [[True, pd.NA], [True, True]]), | ||
([True, pd.NA, False], [[True, False], [True, False]]), | ||
], | ||
) | ||
def test_masked_kleene_logic(bool_agg_func, data, expected, skipna): | ||
# GH#37506 | ||
df = DataFrame(data, dtype="boolean") | ||
expected = DataFrame( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This `expected` is really hard to parse; can you build it in multiple steps / make it simpler? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Addressing your comment above makes this simpler; let me know if you think any piece is still confusing. |
||
[expected[skipna][bool_agg_func == "all"]], dtype="boolean", index=[1] | ||
) | ||
|
||
result = df.groupby([1, 1, 1]).agg(bool_agg_func, skipna=skipna) | ||
tm.assert_frame_equal(result, expected) | ||
jorisvandenbossche marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
@pytest.mark.parametrize("bool_agg_func", ["any", "all"]) | ||
@pytest.mark.parametrize("dtype", ["Int64", "Float64", "boolean"]) | ||
@pytest.mark.parametrize("skipna", [True, False]) | ||
def test_masked_bool_aggs_skipna(bool_agg_func, dtype, skipna, frame_or_series): | ||
# GH#40585 | ||
obj = frame_or_series([pd.NA, 1], dtype=dtype) | ||
expected_res = True | ||
if not skipna and bool_agg_func == "all": | ||
expected_res = pd.NA | ||
expected = frame_or_series([expected_res], index=[1], dtype="boolean") | ||
|
||
result = obj.groupby([1, 1]).agg(bool_agg_func, skipna=skipna) | ||
tm.assert_equal(result, expected) |
Uh oh!
There was an error while loading. Please reload this page.