Skip to content

Commit 54fed1c

Browse files
committed
ENH: Groupby.plot enhancement
1 parent b32f218 commit 54fed1c

File tree

4 files changed

+664
-259
lines changed

4 files changed

+664
-259
lines changed

pandas/core/groupby.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3297,9 +3297,13 @@ def _apply_to_column_groupbys(self, func):
32973297
in self._iterate_column_groupbys()),
32983298
keys=self._selected_obj.columns, axis=1)
32993299

3300-
from pandas.tools.plotting import boxplot_frame_groupby
3300+
3301+
from pandas.tools.plotting import boxplot_frame_groupby, plot_grouped
33013302
DataFrameGroupBy.boxplot = boxplot_frame_groupby
33023303

3304+
SeriesGroupBy.plot = plot_grouped
3305+
DataFrameGroupBy.plot = plot_grouped
3306+
33033307

33043308
class PanelGroupBy(NDFrameGroupBy):
33053309

pandas/tests/test_graphics.py

Lines changed: 94 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1648,6 +1648,7 @@ def test_line_lim(self):
16481648
self.assertFalse(hasattr(ax, 'right_ax'))
16491649
xmin, xmax = ax.get_xlim()
16501650
lines = ax.get_lines()
1651+
self.assertTrue(hasattr(ax, 'left_ax'))
16511652
self.assertEqual(xmin, lines[0].get_data()[0][0])
16521653
self.assertEqual(xmax, lines[0].get_data()[0][-1])
16531654

@@ -3085,7 +3086,6 @@ def test_hexbin_cmap(self):
30853086
@slow
30863087
def test_no_color_bar(self):
30873088
df = self.hexbin_df
3088-
30893089
ax = df.plot(kind='hexbin', x='A', y='B', colorbar=None)
30903090
self.assertIs(ax.collections[0].colorbar, None)
30913091

@@ -3448,9 +3448,8 @@ def test_grouped_plot_fignums(self):
34483448
df = DataFrame({'height': height, 'weight': weight, 'gender': gender})
34493449
gb = df.groupby('gender')
34503450

3451-
res = gb.plot()
3452-
self.assertEqual(len(self.plt.get_fignums()), 2)
3453-
self.assertEqual(len(res), 2)
3451+
with tm.assertRaisesRegexp(ValueError, "To plot DataFrameGroupBy, specify 'suplots=True'"):
3452+
res = gb.plot()
34543453
tm.close()
34553454

34563455
res = gb.boxplot(return_type='axes')
@@ -3829,6 +3828,97 @@ def test_plotting_with_float_index_works(self):
38293828
df.groupby('def')['val'].apply(lambda x: x.plot())
38303829
tm.close()
38313830

3831+
def test_line_groupby(self):
3832+
df = DataFrame(np.random.rand(30, 5), columns=['A', 'B', 'C', 'D', 'E'])
3833+
df['by'] = ['Group {0}'.format(i) for i in [0]*10 + [1]*10 + [2]*10]
3834+
grouped = df.groupby(by='by')
3835+
3836+
# SeriesGroupBy
3837+
sgb = grouped['A']
3838+
ax = _check_plot_works(sgb.plot, colors=['r', 'g', 'b'])
3839+
self._check_legend_labels(ax, labels=['Group 0', 'Group 1', 'Group 2'])
3840+
self._check_colors(ax.get_lines(), linecolors=['r', 'g', 'b'])
3841+
3842+
axes = _check_plot_works(sgb.plot, subplots=True)
3843+
self._check_axes_shape(axes, axes_num=3, layout=(3, 1))
3844+
self._check_legend_labels(axes[0], labels=['Group 0'])
3845+
self._check_legend_labels(axes[1], labels=['Group 1'])
3846+
self._check_legend_labels(axes[2], labels=['Group 2'])
3847+
3848+
# DataFrameGroupBy
3849+
axes = _check_plot_works(grouped.plot, subplots=True)
3850+
self._check_axes_shape(axes, axes_num=3, layout=(3, 1))
3851+
for ax in axes:
3852+
self._check_legend_labels(ax, labels=['A', 'B', 'C', 'D', 'E'])
3853+
self._check_colors(ax.get_lines(), linecolors=['k', 'k', 'k', 'k', 'k'])
3854+
3855+
axes = _check_plot_works(grouped.plot, subplots=True, axis=1)
3856+
self._check_axes_shape(axes, axes_num=5, layout=(5, 1))
3857+
for ax in axes:
3858+
self._check_legend_labels(ax, labels=['Group 0', 'Group 1', 'Group 2'])
3859+
self._check_colors(ax.get_lines(), linecolors=['k', 'k', 'k'])
3860+
3861+
def test_hist_groupby(self):
3862+
df = DataFrame(np.random.rand(30, 5), columns=['A', 'B', 'C', 'D', 'E'])
3863+
df['by'] = ['Group {0}'.format(i) for i in [0]*10 + [1]*10 + [2]*10]
3864+
grouped = df.groupby(by='by')
3865+
3866+
# SeriesGroupBy
3867+
sgb = grouped['A']
3868+
ax = sgb.plot(kind='hist', colors=['r', 'g', 'b'])
3869+
self._check_legend_labels(ax, labels=['Group 0', 'Group 1', 'Group 2'])
3870+
self._check_colors(ax.patches[::10], facecolors=['r', 'g', 'b'])
3871+
3872+
axes = sgb.plot(kind='hist', subplots=True)
3873+
self._check_axes_shape(axes, axes_num=3, layout=(3, 1))
3874+
self._check_legend_labels(axes[0], labels=['Group 0'])
3875+
self._check_legend_labels(axes[1], labels=['Group 1'])
3876+
self._check_legend_labels(axes[2], labels=['Group 2'])
3877+
self._check_colors([axes[0].patches[0]], facecolors=['b'])
3878+
self._check_colors([axes[1].patches[0]], facecolors=['b'])
3879+
self._check_colors([axes[2].patches[0]], facecolors=['b'])
3880+
3881+
# DataFrameGroupBy
3882+
axes = grouped.plot(kind='hist', subplots=True)
3883+
self._check_axes_shape(axes, axes_num=3, layout=(3, 1))
3884+
for ax in axes:
3885+
self._check_legend_labels(ax, labels=['A', 'B', 'C', 'D', 'E'])
3886+
# self._check_colors(axes[0].patches(), facecolors=['k', 'k', 'k', 'k', 'k'])
3887+
# self._check_colors(axes[1].patches(), facecolors=['k', 'k', 'k', 'k', 'k'])
3888+
# self._check_colors(axes[2].patches(), facecolors=['k', 'k', 'k', 'k', 'k'])
3889+
3890+
axes = grouped.plot(kind='hist', subplots=True, axis=1)
3891+
self._check_axes_shape(axes, axes_num=5, layout=(5, 1))
3892+
for ax in axes:
3893+
self._check_legend_labels(ax, labels=['Group 0', 'Group 1', 'Group 2'])
3894+
self._check_colors(ax.patches[::10], facecolors=['b', 'g', 'r'])
3895+
3896+
def test_scatter_groupby(self):
3897+
df = DataFrame(np.random.rand(30, 5), columns=['A', 'B', 'C', 'D', 'E'])
3898+
df['by'] = ['Group {0}'.format(i) for i in [0]*10 + [1]*10 + [2]*10]
3899+
grouped = df.groupby(by='by')
3900+
3901+
ax = _check_plot_works(grouped.plot, kind='scatter', x='A', y='B', subplots=False)
3902+
self._check_legend_labels(ax, labels=['Group 0', 'Group 1', 'Group 2'])
3903+
3904+
axes = _check_plot_works(grouped.plot, kind='scatter', x='A', y='B', subplots=True)
3905+
self._check_axes_shape(axes, axes_num=3, layout=(1, 3))
3906+
self._check_legend_labels(axes[0], labels=['Group 0'])
3907+
self._check_legend_labels(axes[1], labels=['Group 1'])
3908+
self._check_legend_labels(axes[2], labels=['Group 2'])
3909+
3910+
def test_hexbin_groupby(self):
3911+
df = DataFrame(np.random.rand(30, 5), columns=['A', 'B', 'C', 'D', 'E'])
3912+
df['by'] = ['Group {0}'.format(i) for i in [0]*10 + [1]*10 + [2]*10]
3913+
grouped = df.groupby(by='by')
3914+
3915+
msg = "To plot DataFrameGroupBy, specify 'suplots=True'"
3916+
with tm.assertRaisesRegexp(ValueError, msg):
3917+
grouped.plot(kind='hexbin', x='A', y='B', subplots=False)
3918+
3919+
axes = _check_plot_works(grouped.plot, kind='hexbin', x='A', y='B', subplots=True)
3920+
self._check_axes_shape(axes, axes_num=3, layout=(1, 3))
3921+
38323922

38333923
def assert_is_valid_plot_return_object(objs):
38343924
import matplotlib.pyplot as plt

0 commit comments

Comments (0)