-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
BUG: Fixed groupby quantile for listlike q #27827
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
0236046
c1a447f
ca2411a
e060d1d
f66a67b
421c80f
624d33b
a9fb4f6
945648e
dc5147a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1870,6 +1870,7 @@ def quantile(self, q=0.5, interpolation="linear"): | |
a 2.0 | ||
b 3.0 | ||
""" | ||
from pandas import concat | ||
|
||
def pre_processor(vals: np.ndarray) -> Tuple[np.ndarray, Optional[Type]]: | ||
if is_object_dtype(vals): | ||
|
@@ -1897,18 +1898,53 @@ def post_processor(vals: np.ndarray, inference: Optional[Type]) -> np.ndarray: | |
|
||
return vals | ||
|
||
return self._get_cythonized_result( | ||
"group_quantile", | ||
self.grouper, | ||
aggregate=True, | ||
needs_values=True, | ||
needs_mask=True, | ||
cython_dtype=np.float64, | ||
pre_processing=pre_processor, | ||
post_processing=post_processor, | ||
q=q, | ||
interpolation=interpolation, | ||
) | ||
if is_scalar(q): | ||
return self._get_cythonized_result( | ||
"group_quantile", | ||
self.grouper, | ||
aggregate=True, | ||
needs_values=True, | ||
needs_mask=True, | ||
cython_dtype=np.float64, | ||
pre_processing=pre_processor, | ||
post_processing=post_processor, | ||
q=q, | ||
interpolation=interpolation, | ||
) | ||
else: | ||
results = [ | ||
self._get_cythonized_result( | ||
"group_quantile", | ||
self.grouper, | ||
aggregate=True, | ||
needs_values=True, | ||
needs_mask=True, | ||
cython_dtype=np.float64, | ||
pre_processing=pre_processor, | ||
post_processing=post_processor, | ||
q=qi, | ||
interpolation=interpolation, | ||
) | ||
for qi in q | ||
] | ||
# fix levels to place quantiles on the inside | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I would welcome improvements to this. Essentially, we have two things to do after concatenating the results.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, this isn't fun. Just checking: can we have mixed dtypes here? If we had a single dtype, then stack/unstack would be cheap and there might be a prettier workaround. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Unfortunately that sorts the axis. In theory, we should be able to follow it up with a `.loc`:
diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py
index c00bd8398..617466b8d 100644
--- a/pandas/core/groupby/groupby.py
+++ b/pandas/core/groupby/groupby.py
@@ -1928,23 +1928,10 @@ class GroupBy(_GroupBy):
for qi in q
]
# fix levels to place quantiles on the inside
- result = concat(results, axis=0, keys=q)
- order = np.roll(list(range(result.index.nlevels)), -1)
- result = result.reorder_levels(order)
- result = result.reindex(q, level=-1)
-
- # fix order.
- hi = len(q) * self.ngroups
- arr = np.arange(0, hi, self.ngroups)
- arrays = []
-
- for i in range(self.ngroups):
- arr = arr + i
- arrays.append(arr)
-
- indices = np.concatenate(arrays)
- assert len(indices) == len(result)
- return result.take(indices)
+ result = concat(results, axis=1, keys=q).stack(0)
+ slices = [slice(None)] * result.index.ndim
+ result = result.loc[tuple(slices), :]
+ return result
@Substitution(name="groupby")
def ngroup(self, ascending=True):
But that runs into the (I think known) issue with `loc` and a MultiIndex:
In [21]: s = pd.Series(list(range(12)), index=pd.MultiIndex.from_product([['a', 'b', 'c'], [1, 2, 3, 4]]))
In [22]: s.loc[pd.IndexSlice[:, [3, 1, 2, 4]]]
Out[22]:
a 1 0
2 1
3 2
4 3
b 1 4
2 5
3 6
4 7
c 1 8
2 9
3 10
4 11
dtype: int64
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Darn. Can you add a my-idea-doesn't-work case to the tests? |
||
result = concat(results, axis=0, keys=q) | ||
order = np.roll(list(range(result.index.nlevels)), -1) | ||
result = result.reorder_levels(order) | ||
result = result.reindex(q, level=-1) | ||
|
||
# fix order. | ||
hi = len(q) * self.ngroups | ||
arr = np.arange(0, hi, self.ngroups) | ||
arrays = [] | ||
|
||
for i in range(self.ngroups): | ||
arr = arr + i | ||
arrays.append(arr) | ||
|
||
indices = np.concatenate(arrays) | ||
assert len(indices) == len(result) | ||
return result.take(indices) | ||
|
||
@Substitution(name="groupby") | ||
def ngroup(self, ascending=True): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Typo: pandas
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Multiple -> list-like