Skip to content

Commit b5849bb

Browse files
committed
ENH: Groupby.plot enhancement
1 parent f0ac930 commit b5849bb

File tree

4 files changed

+662
-277
lines changed

4 files changed

+662
-277
lines changed

pandas/core/groupby.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3268,9 +3268,13 @@ def _apply_to_column_groupbys(self, func):
32683268
in self._iterate_column_groupbys()),
32693269
keys=self._selected_obj.columns, axis=1)
32703270

3271-
from pandas.tools.plotting import boxplot_frame_groupby
3271+
3272+
from pandas.tools.plotting import boxplot_frame_groupby, plot_grouped
32723273
DataFrameGroupBy.boxplot = boxplot_frame_groupby
32733274

3275+
SeriesGroupBy.plot = plot_grouped
3276+
DataFrameGroupBy.plot = plot_grouped
3277+
32743278

32753279
class PanelGroupBy(NDFrameGroupBy):
32763280

pandas/tests/test_graphics.py

Lines changed: 96 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1612,6 +1612,7 @@ def test_line_lim(self):
16121612
for ax in axes:
16131613
xmin, xmax = ax.get_xlim()
16141614
lines = ax.get_lines()
1615+
self.assertTrue(hasattr(ax, 'left_ax'))
16151616
self.assertEqual(xmin, lines[0].get_data()[0][0])
16161617
self.assertEqual(xmax, lines[0].get_data()[0][-1])
16171618

@@ -3049,7 +3050,6 @@ def test_hexbin_cmap(self):
30493050
@slow
30503051
def test_no_color_bar(self):
30513052
df = self.hexbin_df
3052-
30533053
ax = df.plot(kind='hexbin', x='A', y='B', colorbar=None)
30543054
self.assertIs(ax.collections[0].colorbar, None)
30553055

@@ -3412,9 +3412,10 @@ def test_grouped_plot_fignums(self):
34123412
df = DataFrame({'height': height, 'weight': weight, 'gender': gender})
34133413
gb = df.groupby('gender')
34143414

3415-
res = gb.plot()
3416-
self.assertEqual(len(self.plt.get_fignums()), 2)
3417-
self.assertEqual(len(res), 2)
3415+
with tm.assertRaisesRegexp(ValueError, "To plot DataFrameGroupBy, specify 'suplots=True'"):
3416+
res = gb.plot()
3417+
self.assertEqual(len(plt.get_fignums()), 2)
3418+
self.assertEqual(len(res), 2)
34183419
tm.close()
34193420

34203421
res = gb.boxplot(return_type='axes')
@@ -3793,6 +3794,97 @@ def test_plotting_with_float_index_works(self):
37933794
df.groupby('def')['val'].apply(lambda x: x.plot())
37943795
tm.close()
37953796

3797+
def test_line_groupby(self):
3798+
df = DataFrame(np.random.rand(30, 5), columns=['A', 'B', 'C', 'D', 'E'])
3799+
df['by'] = ['Group {0}'.format(i) for i in [0]*10 + [1]*10 + [2]*10]
3800+
grouped = df.groupby(by='by')
3801+
3802+
# SeriesGroupBy
3803+
sgb = grouped['A']
3804+
ax = _check_plot_works(sgb.plot, colors=['r', 'g', 'b'])
3805+
self._check_legend_labels(ax, labels=['Group 0', 'Group 1', 'Group 2'])
3806+
self._check_colors(ax.get_lines(), linecolors=['r', 'g', 'b'])
3807+
3808+
axes = _check_plot_works(sgb.plot, subplots=True)
3809+
self._check_axes_shape(axes, axes_num=3, layout=(3, 1))
3810+
self._check_legend_labels(axes[0], labels=['Group 0'])
3811+
self._check_legend_labels(axes[1], labels=['Group 1'])
3812+
self._check_legend_labels(axes[2], labels=['Group 2'])
3813+
3814+
# DataFrameGroupBy
3815+
axes = _check_plot_works(grouped.plot, subplots=True)
3816+
self._check_axes_shape(axes, axes_num=3, layout=(3, 1))
3817+
for ax in axes:
3818+
self._check_legend_labels(ax, labels=['A', 'B', 'C', 'D', 'E'])
3819+
self._check_colors(ax.get_lines(), linecolors=['k', 'k', 'k', 'k', 'k'])
3820+
3821+
axes = _check_plot_works(grouped.plot, subplots=True, axis=1)
3822+
self._check_axes_shape(axes, axes_num=5, layout=(5, 1))
3823+
for ax in axes:
3824+
self._check_legend_labels(ax, labels=['Group 0', 'Group 1', 'Group 2'])
3825+
self._check_colors(ax.get_lines(), linecolors=['k', 'k', 'k'])
3826+
3827+
def test_hist_groupby(self):
3828+
df = DataFrame(np.random.rand(30, 5), columns=['A', 'B', 'C', 'D', 'E'])
3829+
df['by'] = ['Group {0}'.format(i) for i in [0]*10 + [1]*10 + [2]*10]
3830+
grouped = df.groupby(by='by')
3831+
3832+
# SeriesGroupBy
3833+
sgb = grouped['A']
3834+
ax = sgb.plot(kind='hist', colors=['r', 'g', 'b'])
3835+
self._check_legend_labels(ax, labels=['Group 0', 'Group 1', 'Group 2'])
3836+
self._check_colors(ax.patches[::10], facecolors=['r', 'g', 'b'])
3837+
3838+
axes = sgb.plot(kind='hist', subplots=True)
3839+
self._check_axes_shape(axes, axes_num=3, layout=(3, 1))
3840+
self._check_legend_labels(axes[0], labels=['Group 0'])
3841+
self._check_legend_labels(axes[1], labels=['Group 1'])
3842+
self._check_legend_labels(axes[2], labels=['Group 2'])
3843+
self._check_colors([axes[0].patches[0]], facecolors=['b'])
3844+
self._check_colors([axes[1].patches[0]], facecolors=['b'])
3845+
self._check_colors([axes[2].patches[0]], facecolors=['b'])
3846+
3847+
# DataFrameGroupBy
3848+
axes = grouped.plot(kind='hist', subplots=True)
3849+
self._check_axes_shape(axes, axes_num=3, layout=(3, 1))
3850+
for ax in axes:
3851+
self._check_legend_labels(ax, labels=['A', 'B', 'C', 'D', 'E'])
3852+
# self._check_colors(axes[0].patches(), facecolors=['k', 'k', 'k', 'k', 'k'])
3853+
# self._check_colors(axes[1].patches(), facecolors=['k', 'k', 'k', 'k', 'k'])
3854+
# self._check_colors(axes[2].patches(), facecolors=['k', 'k', 'k', 'k', 'k'])
3855+
3856+
axes = grouped.plot(kind='hist', subplots=True, axis=1)
3857+
self._check_axes_shape(axes, axes_num=5, layout=(5, 1))
3858+
for ax in axes:
3859+
self._check_legend_labels(ax, labels=['Group 0', 'Group 1', 'Group 2'])
3860+
self._check_colors(ax.patches[::10], facecolors=['b', 'g', 'r'])
3861+
3862+
def test_scatter_groupby(self):
3863+
df = DataFrame(np.random.rand(30, 5), columns=['A', 'B', 'C', 'D', 'E'])
3864+
df['by'] = ['Group {0}'.format(i) for i in [0]*10 + [1]*10 + [2]*10]
3865+
grouped = df.groupby(by='by')
3866+
3867+
ax = _check_plot_works(grouped.plot, kind='scatter', x='A', y='B', subplots=False)
3868+
self._check_legend_labels(ax, labels=['Group 0', 'Group 1', 'Group 2'])
3869+
3870+
axes = _check_plot_works(grouped.plot, kind='scatter', x='A', y='B', subplots=True)
3871+
self._check_axes_shape(axes, axes_num=3, layout=(1, 3))
3872+
self._check_legend_labels(axes[0], labels=['Group 0'])
3873+
self._check_legend_labels(axes[1], labels=['Group 1'])
3874+
self._check_legend_labels(axes[2], labels=['Group 2'])
3875+
3876+
def test_hexbin_groupby(self):
3877+
df = DataFrame(np.random.rand(30, 5), columns=['A', 'B', 'C', 'D', 'E'])
3878+
df['by'] = ['Group {0}'.format(i) for i in [0]*10 + [1]*10 + [2]*10]
3879+
grouped = df.groupby(by='by')
3880+
3881+
msg = "To plot DataFrameGroupBy, specify 'suplots=True'"
3882+
with tm.assertRaisesRegexp(ValueError, msg):
3883+
grouped.plot(kind='hexbin', x='A', y='B', subplots=False)
3884+
3885+
axes = _check_plot_works(grouped.plot, kind='hexbin', x='A', y='B', subplots=True)
3886+
self._check_axes_shape(axes, axes_num=3, layout=(1, 3))
3887+
37963888

37973889
def assert_is_valid_plot_return_object(objs):
37983890
import matplotlib.pyplot as plt

0 commit comments

Comments
 (0)