-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
ENH: Fix by
in DataFrame.plot.hist and DataFrame.plot.box
#28373
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 71 commits
7e461a1
1314059
8bcb313
e36592c
b2f45a6
8b6e00a
dc0c2ec
d803938
33dd762
d59d642
66eb06c
ea267ad
8095224
31decc1
e4bdbd0
4033159
57a3bdf
8060223
d666334
3216d59
45f4b7f
2b0785b
d79dba3
eca597b
321fbd2
a7b9ae5
d2d13fd
5abedb6
d73115a
d7998bb
1bbf7ea
a279f45
2b793ea
04de066
4adc324
d0103a4
f94dbb4
0415cb0
525200b
1ab4310
c005880
a1fabc5
70453f1
b6579a5
e99f3dc
99d6d67
8e2fcf6
d02f4ac
947189c
6b5203d
27d0d21
48ff521
f39d948
5d1705c
90471aa
46a8031
57a96e6
29127f0
61bb97f
62fb9e6
638174b
02de005
5adb25d
7051432
adbde9f
5dfad18
abd10f3
c20d81a
07112c0
fb0b87c
a6a8e57
7f77f48
c09bb19
a120d27
f87afee
82711ee
60f7298
071488b
f2a0210
bb07e15
867094a
b0f06b2
111e89c
6472053
83ec868
d6c8566
6a0ac8d
49d0791
2bfbe78
c5d7518
03356ce
7abc47d
db832b4
be99a97
9ae5987
10c2ad1
ce8cfd4
627cc02
4bfbf03
ee8972d
12ff785
163f920
0839be2
142ee53
ef65137
d793703
2710cf2
f76d2cb
a5ecbd7
439be51
5fd420e
4b4832f
7425dff
8ab4b90
b06e454
79294ed
9523bb9
cd59370
050ba95
add406f
bb22c53
25214e6
aaa5c95
77e46f4
9de9c61
af68d2e
b75015a
b90303d
898fa9b
2ac32f5
f7bcdb7
aeb32e5
dc17959
4aee3e0
4eb466f
5160224
e2de0d3
b2b33ac
1199a93
c4a5842
6556414
826f277
891dc55
4c4a158
ea7e5b1
006588e
4f0a1dc
e1579e2
f2c141f
5f96abd
06483af
08f534d
f1c3a1f
e6e96d3
52e47f1
bc2f282
a43d3bb
ceeb3c5
4fea841
3ea2603
9f48139
b1094e3
97bde59
444a964
b66dad0
982f562
c76ad67
2c1aa33
6896546
3c54302
2d20178
a169dfd
d0b56ff
dec313c
143f286
283286f
f2a0736
f1aeee0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -190,6 +190,8 @@ I/O | |
Plotting | ||
^^^^^^^^ | ||
|
||
- | ||
- Implement ``by`` argument for :meth:`DataFrame.plot.hist` (:issue:`15079`) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. move to other enhancements section There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. moved |
||
- :func:`.plot` for line/bar now accepts color by dictonary (:issue:`8193`). | ||
- | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1209,6 +1209,16 @@ def hist(self, by=None, bins=10, **kwargs): | |
... columns = ['one']) | ||
>>> df['two'] = df['one'] + np.random.randint(1, 7, 6000) | ||
>>> ax = df.plot.hist(bins=12, alpha=0.5) | ||
|
||
If `by` is defined, a grouped hist plot is generated: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add a versionadded tag 1.1 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah, added above in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is context on what by can be (str, list of str?) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it is already defined lines above in |
||
|
||
.. plot:: | ||
:context: close-figs | ||
|
||
>>> np.random.seed(159753) | ||
>>> df = pd.DataFrame(np.random.randn(30, 2), columns=['A', 'B']) | ||
>>> df['C'] = np.random.choice(['a', 'b', 'c'], 30) | ||
>>> ax = df.plot.hist(column=['A', 'B'], by=['C'], figsize=(8, 10)) | ||
""" | ||
return self(kind="hist", by=by, bins=bins, **kwargs) | ||
|
||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -23,7 +23,9 @@ | |||||
) | ||||||
from pandas.core.dtypes.missing import isna, notna | ||||||
|
||||||
from pandas import MultiIndex | ||||||
import pandas.core.common as com | ||||||
from pandas.core.reshape.concat import concat | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can include with the from pandas import There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. changed! |
||||||
|
||||||
from pandas.io.formats.printing import pprint_thing | ||||||
from pandas.plotting._matplotlib.compat import _mpl_ge_3_0_0 | ||||||
|
@@ -102,13 +104,15 @@ def __init__( | |||||
table=False, | ||||||
layout=None, | ||||||
include_bool=False, | ||||||
column=None, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you annotate this? I think should just be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Missing a push? I don't see the change here |
||||||
**kwds, | ||||||
): | ||||||
|
||||||
import matplotlib.pyplot as plt | ||||||
|
||||||
self.data = data | ||||||
self.by = by | ||||||
self.column = [column] if not isinstance(column, list) else column | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Minor but I think this should be called There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i agree, but here the reason i chose There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree with @WillAyd here. I see your point. But When |
||||||
|
||||||
self.kind = kind | ||||||
|
||||||
|
@@ -117,7 +121,9 @@ def __init__( | |||||
self.subplots = subplots | ||||||
|
||||||
if sharex is None: | ||||||
if ax is None: | ||||||
|
||||||
jreback marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
# if by is defined, subplots are used and sharex should be False | ||||||
if ax is None and by is None: | ||||||
self.sharex = True | ||||||
else: | ||||||
# if we get an axis, the users should do the visibility | ||||||
|
@@ -240,18 +246,30 @@ def _iter_data(self, data=None, keep_index=False, fillna=None): | |||||
if fillna is not None: | ||||||
data = data.fillna(fillna) | ||||||
|
||||||
for col, values in data.items(): | ||||||
if keep_index is True: | ||||||
yield col, values | ||||||
else: | ||||||
yield col, values.values | ||||||
if self.by is None: | ||||||
WillAyd marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
for col, values in data.items(): | ||||||
if keep_index is True: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same comment as @datapythonista below |
||||||
yield col, values | ||||||
else: | ||||||
yield col, values.values | ||||||
else: | ||||||
cols = data.columns.get_level_values(0).unique() | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So we only plot the first level of a multi index then? Is this limitation documented anywhere? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. emm, I would not call it a limitation, this branch will not be reached even dataset is a MI dataframe while There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK I think I understand. Doesn't this require I think this can be more cleary written if you groupby the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @WillAyd thanks very much for your comment, appreciate a lot! this data here in # GH15079 restructure data if by is defined
if self.by is not None:
self.subplots = True
grouped = data.groupby(self.by)
data_list = []
for key, group in grouped:
columns = MultiIndex.from_product([[key], self.column])
sub_group = group[self.column]
sub_group.columns = columns
data_list.append(sub_group)
data = concat(data_list, axis=1) And here, in |
||||||
|
||||||
for col in cols: | ||||||
data_values = data.loc[:, data.columns.get_level_values(0) == col] | ||||||
if keep_index is True: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. changed! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This looks still outstanding? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. oops, shame on me, i thought i removed :( |
||||||
yield col, data_values | ||||||
else: | ||||||
yield col, data_values.values | ||||||
|
||||||
@property | ||||||
def nseries(self): | ||||||
if self.data.ndim == 1: | ||||||
return 1 | ||||||
else: | ||||||
elif self.by is None: | ||||||
return self.data.shape[1] | ||||||
else: | ||||||
return len(set(self.data.columns.get_level_values(0))) | ||||||
|
||||||
def draw(self): | ||||||
self.plt.draw_if_interactive() | ||||||
|
@@ -378,6 +396,20 @@ def _compute_plot_data(self): | |||||
label = "None" | ||||||
data = data.to_frame(name=label) | ||||||
|
||||||
# GH15079 restructure data if by is defined | ||||||
if self.by is not None: | ||||||
self.subplots = True | ||||||
grouped = data.groupby(self.by) | ||||||
|
||||||
data_list = [] | ||||||
for key, group in grouped: | ||||||
columns = MultiIndex.from_product([[key], self.column]) | ||||||
sub_group = group[self.column] | ||||||
sub_group.columns = columns | ||||||
data_list.append(sub_group) | ||||||
|
||||||
data = concat(data_list, axis=1) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I know in all the other transformations inside this function it's done inline, and you're being consistent. But this is somehow complex, I think it should be in a separate function, so it's clear what's the input, what's the output, we can add a small docstring explaining and with an example. And most importantly, it'll be easier to test properly. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah, okay, I moved it out as an internal function to transform the data. However, i found it hard to test, so I added a couple inline comments and docstrings, please let me know if there is a good way to test this function out! @datapythonista |
||||||
|
||||||
# GH16953, _convert is needed as fallback, for ``Series`` | ||||||
# with ``dtype == object`` | ||||||
data = data._convert(datetime=True, timedelta=True) | ||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,13 @@ | ||
from typing import Union | ||
|
||
import numpy as np | ||
|
||
from pandas.core.dtypes.common import is_integer, is_list_like | ||
from pandas.core.dtypes.generic import ABCDataFrame, ABCIndexClass | ||
from pandas.core.dtypes.missing import isna, remove_na_arraylike | ||
|
||
import pandas.core.common as com | ||
from pandas.core.series import Series | ||
|
||
from pandas.io.formats.printing import pprint_thing | ||
from pandas.plotting._matplotlib.core import LinePlot, MPLPlot | ||
|
@@ -21,22 +24,38 @@ def __init__(self, data, bins=10, bottom=0, **kwargs): | |
MPLPlot.__init__(self, data, **kwargs) | ||
|
||
def _args_adjust(self): | ||
|
||
# calculate bin number separately in different subplots | ||
# where subplots are created based on by argument | ||
if is_integer(self.bins): | ||
# create common bin edge | ||
values = self.data._convert(datetime=True)._get_numeric_data() | ||
values = np.ravel(values) | ||
values = values[~isna(values)] | ||
|
||
_, self.bins = np.histogram( | ||
values, | ||
bins=self.bins, | ||
range=self.kwds.get("range", None), | ||
weights=self.kwds.get("weights", None), | ||
) | ||
if self.by is None: | ||
self.bins = self._caculcate_bins(self.data) | ||
|
||
else: | ||
grouped = self.data.groupby(self.by)[self.column] | ||
bins_list = [] | ||
for key, group in grouped: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use a list-comprehension There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. indeed! |
||
bins_list.append(self._caculcate_bins(group)) | ||
self.bins = bins_list | ||
|
||
if is_list_like(self.bottom): | ||
self.bottom = np.array(self.bottom) | ||
|
||
def _caculcate_bins(self, data: ABCDataFrame) -> np.array: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is a typo, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thanks! corrected! |
||
"""Calculate bins given data""" | ||
|
||
values = data._convert(datetime=True)._get_numeric_data() | ||
values = np.ravel(values) | ||
values = values[~isna(values)] | ||
|
||
hist, bins = np.histogram( | ||
values, | ||
bins=self.bins, | ||
range=self.kwds.get("range", None), | ||
weights=self.kwds.get("weights", None), | ||
) | ||
return bins | ||
|
||
@classmethod | ||
def _plot( | ||
cls, | ||
|
@@ -51,7 +70,6 @@ def _plot( | |
): | ||
if column_num == 0: | ||
cls._initialize_stacker(ax, stacking_id, len(bins) - 1) | ||
y = y[~isna(y)] | ||
|
||
base = np.zeros(len(bins) - 1) | ||
bottom = bottom + cls._get_stacked_values(ax, stacking_id, base, kwds["label"]) | ||
|
@@ -77,9 +95,32 @@ def _make_plot(self): | |
kwds["style"] = style | ||
|
||
kwds = self._make_plot_keywords(kwds, y) | ||
|
||
# the bins is multi-dimension array now and each plot need only 1-d and | ||
# when by is applied, label should be columns that are grouped | ||
if self.by is not None: | ||
kwds["bins"] = kwds["bins"][i] | ||
kwds["label"] = self.column | ||
kwds.pop("color") | ||
|
||
y = self._reformat_y(y) | ||
artists = self._plot(ax, y, column_num=i, stacking_id=stacking_id, **kwds) | ||
|
||
# when by is applied, show title for subplots to know which group it is | ||
if self.by is not None: | ||
ax.set_title(pprint_thing(label)) | ||
|
||
self._add_legend_handle(artists[0], label, index=i) | ||
|
||
def _reformat_y(self, y: Union[Series, np.array]) -> Union[Series, np.array]: | ||
"""Internal function to reformat y given `by` is applied or not.""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you clarify this? I think known from the function name that it is internal to reformat y, but curious what the existence of by should change? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sure, main difference is the dimension of y will be different, if by is None, y is 1-d array, while if by is not None, which means group by will happen, and y here is multi-dimension array. i will change to a more detailed description for this. |
||
if self.by is not None and len(y.shape) > 1: | ||
notna = [col[~isna(col)] for col in y.T] | ||
y = np.array(np.array(notna).T) | ||
else: | ||
y = y[~isna(y)] | ||
return y | ||
|
||
def _make_plot_keywords(self, kwds, y): | ||
"""merge BoxPlot/KdePlot properties to passed kwds""" | ||
# y is required for KdePlot | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,7 @@ | |
|
||
from datetime import date, datetime | ||
import itertools | ||
import re | ||
import string | ||
import warnings | ||
|
||
|
@@ -25,6 +26,15 @@ | |
import pandas.plotting as plotting | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def test_hist_df(): | ||
np.random.seed(0) | ||
df = DataFrame(np.random.randn(30, 2), columns=["A", "B"]) | ||
df["C"] = np.random.choice(["a", "b", "c"], 30) | ||
df["D"] = np.random.choice(["a", "b", "c"], 30) | ||
return df | ||
|
||
|
||
@td.skip_if_no_mpl | ||
class TestDataFramePlots(TestPlotBase): | ||
def setup_method(self, method): | ||
|
@@ -3256,6 +3266,93 @@ def test_subplots_sharex_false(self): | |
tm.assert_numpy_array_equal(axs[0].get_xticks(), expected_ax1) | ||
tm.assert_numpy_array_equal(axs[1].get_xticks(), expected_ax2) | ||
|
||
@pytest.mark.parametrize("by", ["C", ["C", "D"]]) | ||
@pytest.mark.parametrize("column", ["A", ["A", "B"]]) | ||
def test_hist_plot_by_argument(self, by, column, test_hist_df): | ||
# GH 15079 | ||
_check_plot_works(test_hist_df.plot.hist, column=column, by=by) | ||
|
||
@pytest.mark.slow | ||
@pytest.mark.parametrize( | ||
"by, column, layout, axes_num", | ||
[ | ||
(["C"], "A", (2, 2), 3), | ||
("C", "A", (2, 2), 3), | ||
(["C"], ["A"], (1, 3), 3), | ||
("C", ["A", "B"], (3, 1), 3), | ||
(["C", "D"], "A", (9, 1), 9), | ||
(["C", "D"], "A", (3, 3), 9), | ||
(["C", "D"], ["A"], (5, 2), 9), | ||
(["C", "D"], ["A", "B"], (9, 1), 9), | ||
(["C", "D"], ["A", "B"], (5, 2), 9), | ||
], | ||
) | ||
def test_hist_plot_layout_with_by(self, by, column, layout, axes_num, test_hist_df): | ||
# GH 15079 | ||
# _check_plot_works adds an ax so catch warning. see GH #13188 | ||
with tm.assert_produces_warning(UserWarning): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What warning is this throwing? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
axes = _check_plot_works( | ||
test_hist_df.plot.hist, column=column, by=by, layout=layout | ||
) | ||
self._check_axes_shape(axes, axes_num=axes_num, layout=layout) | ||
|
||
def test_hist_plot_invalid_layout_with_by_raises(self, test_hist_df): | ||
# GH 15079, test if error is raised when invalid layout is given | ||
|
||
# layout too small for all 3 plots | ||
msg = "larger than required size" | ||
with pytest.raises(ValueError, match=msg): | ||
test_hist_df.plot.hist(column=["A", "B"], by="C", layout=(1, 1)) | ||
|
||
# invalid format for layout | ||
msg = re.escape("Layout must be a tuple of (rows, columns)") | ||
with pytest.raises(ValueError, match=msg): | ||
test_hist_df.plot.hist(column=["A", "B"], by="C", layout=(1,)) | ||
|
||
msg = "At least one dimension of layout must be positive" | ||
with pytest.raises(ValueError, match=msg): | ||
test_hist_df.plot.hist(column=["A", "B"], by="C", layout=(-1, -1)) | ||
|
||
@pytest.mark.slow | ||
def test_axis_share_x_with_by(self, test_hist_df): | ||
# GH 15079 | ||
ax1, ax2, ax3 = test_hist_df.plot.hist(column="A", by="C", sharex=True) | ||
|
||
# share x | ||
assert ax1._shared_x_axes.joined(ax1, ax2) | ||
assert ax2._shared_x_axes.joined(ax1, ax2) | ||
assert ax3._shared_x_axes.joined(ax1, ax3) | ||
assert ax3._shared_x_axes.joined(ax2, ax3) | ||
|
||
# don't share y | ||
assert not ax1._shared_y_axes.joined(ax1, ax2) | ||
assert not ax2._shared_y_axes.joined(ax1, ax2) | ||
assert not ax3._shared_y_axes.joined(ax1, ax3) | ||
assert not ax3._shared_y_axes.joined(ax2, ax3) | ||
|
||
@pytest.mark.slow | ||
def test_axis_share_y_with_by(self, test_hist_df): | ||
# GH 15079 | ||
ax1, ax2, ax3 = test_hist_df.plot.hist(column="A", by="C", sharey=True) | ||
|
||
# share y | ||
assert ax1._shared_y_axes.joined(ax1, ax2) | ||
assert ax2._shared_y_axes.joined(ax1, ax2) | ||
assert ax3._shared_y_axes.joined(ax1, ax3) | ||
assert ax3._shared_y_axes.joined(ax2, ax3) | ||
|
||
# don't share x | ||
assert not ax1._shared_x_axes.joined(ax1, ax2) | ||
assert not ax2._shared_x_axes.joined(ax1, ax2) | ||
assert not ax3._shared_x_axes.joined(ax1, ax3) | ||
assert not ax3._shared_x_axes.joined(ax2, ax3) | ||
|
||
@pytest.mark.parametrize("figsize", [(12, 8), (20, 10)]) | ||
def test_figure_shape_hist_with_by(self, figsize, test_hist_df): | ||
# GH 15079 | ||
axes = test_hist_df.plot.hist(column="A", by="C", figsize=figsize) | ||
self._check_axes_shape(axes, axes_num=3, figsize=figsize) | ||
|
||
def test_plot_no_rows(self): | ||
# GH 27758 | ||
df = pd.DataFrame(columns=["foo"], dtype=int) | ||
|
Uh oh!
There was an error while loading. Please reload this page.