Skip to content

Commit bab7491

Browse files
committed
BUG: groupby.hist legend should use group keys
1 parent 89e44e6 commit bab7491

File tree

4 files changed

+102
-3
lines changed

4 files changed

+102
-3
lines changed

pandas/plotting/_core.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def hist_series(
2121
yrot=None,
2222
figsize=None,
2323
bins=10,
24+
legend=False,
2425
backend=None,
2526
**kwargs,
2627
):
@@ -50,6 +51,8 @@ def hist_series(
5051
bin edges are calculated and returned. If bins is a sequence, gives
5152
bin edges, including left edge of first bin and right edge of last
5253
bin. In this case, bins is returned unmodified.
54+
legend : bool, default False
55+
Whether to show the legend.
5356
backend : str, default None
5457
Backend to use instead of the backend specified in the option
5558
``plotting.backend``. For instance, 'matplotlib'. Alternatively, to
@@ -82,6 +85,7 @@ def hist_series(
8285
yrot=yrot,
8386
figsize=figsize,
8487
bins=bins,
88+
legend=legend,
8589
**kwargs,
8690
)
8791

@@ -101,6 +105,7 @@ def hist_frame(
101105
figsize=None,
102106
layout=None,
103107
bins=10,
108+
legend=False,
104109
backend=None,
105110
**kwargs,
106111
):
@@ -154,6 +159,8 @@ def hist_frame(
154159
bin edges are calculated and returned. If bins is a sequence, gives
155160
bin edges, including left edge of first bin and right edge of last
156161
bin. In this case, bins is returned unmodified.
162+
legend : bool, default False
163+
Whether to show the legend.
157164
backend : str, default None
158165
Backend to use instead of the backend specified in the option
159166
``plotting.backend``. For instance, 'matplotlib'. Alternatively, to
@@ -203,6 +210,7 @@ def hist_frame(
203210
sharey=sharey,
204211
figsize=figsize,
205212
layout=layout,
213+
legend=legend,
206214
bins=bins,
207215
**kwargs,
208216
)

pandas/plotting/_matplotlib/hist.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ def _grouped_hist(
225225
xrot=None,
226226
ylabelsize=None,
227227
yrot=None,
228+
legend=False,
228229
**kwargs,
229230
):
230231
"""
@@ -243,15 +244,27 @@ def _grouped_hist(
243244
sharey : bool, default False
244245
rot : int, default 90
245246
grid : bool, default True
247+
legend: : bool, default False
246248
kwargs : dict, keyword arguments passed to matplotlib.Axes.hist
247249
248250
Returns
249251
-------
250252
collection of Matplotlib Axes
251253
"""
252254

255+
if legend and "label" not in kwargs:
256+
if isinstance(data, ABCDataFrame):
257+
if column is None:
258+
kwargs["label"] = data.columns
259+
else:
260+
kwargs["label"] = column
261+
else:
262+
kwargs["label"] = data.name
263+
253264
def plot_group(group, ax):
254265
ax.hist(group.dropna().values, bins=bins, **kwargs)
266+
if legend:
267+
ax.legend()
255268

256269
if xrot is None:
257270
xrot = rot
@@ -290,6 +303,7 @@ def hist_series(
290303
yrot=None,
291304
figsize=None,
292305
bins=10,
306+
legend=False,
293307
**kwds,
294308
):
295309
import matplotlib.pyplot as plt
@@ -308,8 +322,11 @@ def hist_series(
308322
elif ax.get_figure() != fig:
309323
raise AssertionError("passed axis not bound to passed figure")
310324
values = self.dropna().values
311-
325+
if legend and "label" not in kwds:
326+
kwds["label"] = self.name
312327
ax.hist(values, bins=bins, **kwds)
328+
if legend:
329+
ax.legend()
313330
ax.grid(grid)
314331
axes = np.array([ax])
315332

@@ -334,6 +351,7 @@ def hist_series(
334351
xrot=xrot,
335352
ylabelsize=ylabelsize,
336353
yrot=yrot,
354+
legend=legend,
337355
**kwds,
338356
)
339357

@@ -358,6 +376,7 @@ def hist_frame(
358376
figsize=None,
359377
layout=None,
360378
bins=10,
379+
legend=False,
361380
**kwds,
362381
):
363382
if by is not None:
@@ -376,6 +395,7 @@ def hist_frame(
376395
xrot=xrot,
377396
ylabelsize=ylabelsize,
378397
yrot=yrot,
398+
legend=legend,
379399
**kwds,
380400
)
381401
return axes
@@ -401,11 +421,17 @@ def hist_frame(
401421
)
402422
_axes = _flatten(axes)
403423

424+
can_set_label = "label" not in kwds
425+
404426
for i, col in enumerate(data.columns):
405427
ax = _axes[i]
428+
if legend and can_set_label:
429+
kwds["label"] = col
406430
ax.hist(data[col].dropna().values, bins=bins, **kwds)
407431
ax.set_title(col)
408432
ax.grid(grid)
433+
if legend:
434+
ax.legend()
409435

410436
_set_ticks_props(
411437
axes, xlabelsize=xlabelsize, xrot=xrot, ylabelsize=ylabelsize, yrot=yrot

pandas/tests/plotting/test_groupby.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22

33

44
import numpy as np
5+
import pytest
56

67
import pandas.util._test_decorators as td
78

8-
from pandas import DataFrame, Series
9+
from pandas import DataFrame, Index, Series
910
import pandas._testing as tm
1011
from pandas.tests.plotting.common import TestPlotBase
1112

@@ -65,3 +66,18 @@ def test_plot_kwargs(self):
6566

6667
res = df.groupby("z").plot.scatter(x="x", y="y")
6768
assert len(res["a"].collections) == 1
69+
70+
71+
@td.skip_if_no_mpl
72+
@pytest.mark.parametrize("column", [None, "b"])
73+
@pytest.mark.parametrize("label", [None, "d"])
74+
def test_hist_with_legend(column, label):
75+
index = Index(15 * [1] + 15 * [2], name="c")
76+
df = DataFrame(np.random.randn(30, 2), index=index, columns=["a", "b"])
77+
g = df.groupby("c")
78+
79+
g.hist(column=column, label=label, legend=True)
80+
tm.close()
81+
if column != "b":
82+
g["a"].hist(label=label, legend=True)
83+
tm.close()

pandas/tests/plotting/test_hist_method.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import pandas.util._test_decorators as td
88

9-
from pandas import DataFrame, Series
9+
from pandas import DataFrame, Index, Series
1010
import pandas._testing as tm
1111
from pandas.tests.plotting.common import TestPlotBase, _check_plot_works
1212

@@ -293,6 +293,28 @@ def test_hist_column_order_unchanged(self, column, expected):
293293

294294
assert result == expected
295295

296+
@pytest.mark.slow
297+
@pytest.mark.parametrize("by", [None, "b"])
298+
@pytest.mark.parametrize("label", [None, "c"])
299+
def test_hist_with_legend(self, by, label):
300+
expected_labels = label or "a"
301+
expected_axes_num = 1 if by is None else 2
302+
expected_layout = (1, 1) if by is None else (1, 2)
303+
304+
index = 15 * [1] + 15 * [2]
305+
s = Series(np.random.randn(30), index=index, name="a")
306+
s.index.name = "b"
307+
308+
kwargs = {"legend": True, "by": by}
309+
if label is not None:
310+
# Behavior differs if kwargs contains "label": None
311+
kwargs["label"] = label
312+
313+
_check_plot_works(s.hist, **kwargs)
314+
axes = s.hist(**kwargs)
315+
self._check_axes_shape(axes, axes_num=expected_axes_num, layout=expected_layout)
316+
self._check_legend_labels(axes, expected_labels)
317+
296318

297319
@td.skip_if_no_mpl
298320
class TestDataFrameGroupByPlots(TestPlotBase):
@@ -484,3 +506,30 @@ def test_axis_share_xy(self):
484506

485507
assert ax1._shared_y_axes.joined(ax1, ax2)
486508
assert ax2._shared_y_axes.joined(ax1, ax2)
509+
510+
@pytest.mark.slow
511+
@pytest.mark.parametrize("by", [None, "c"])
512+
@pytest.mark.parametrize("column", [None, "b"])
513+
@pytest.mark.parametrize("label", [None, "d"])
514+
def test_hist_with_legend(self, by, column, label):
515+
expected_axes_num = 1 if by is None and column is not None else 2
516+
expected_layout = (1, expected_axes_num)
517+
expected_labels = label or column or ["a", "b"]
518+
if by is not None:
519+
expected_labels = [expected_labels] * 2
520+
521+
index = Index(15 * [1] + 15 * [2], name="c")
522+
df = DataFrame(np.random.randn(30, 2), index=index, columns=["a", "b"])
523+
524+
kwargs = {"legend": True, "by": by, "column": column}
525+
if label is not None:
526+
# Behavior differs if kwargs contains "label": None
527+
kwargs["label"] = label
528+
529+
_check_plot_works(df.hist, **kwargs)
530+
axes = df.hist(**kwargs)
531+
self._check_axes_shape(axes, axes_num=expected_axes_num, layout=expected_layout)
532+
if by is None:
533+
axes = axes[0]
534+
for expected_label, ax in zip(expected_labels, axes):
535+
self._check_legend_labels(ax, expected_label)

0 commit comments

Comments
 (0)