Skip to content

CLN: plotting cleanups for groupby plotting #10717

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 31, 2015
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 94 additions & 109 deletions pandas/tools/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,7 +774,12 @@ class MPLPlot(object):
data :

"""
_kind = 'base'

@property
def _kind(self):
"""Specify kind str. Must be overridden in child class"""
raise NotImplementedError

_layout_type = 'vertical'
_default_rot = 0
orientation = None
Expand Down Expand Up @@ -938,7 +943,10 @@ def generate(self):
self._make_plot()
self._add_table()
self._make_legend()
self._post_plot_logic()

for ax in self.axes:
self._post_plot_logic_common(ax, self.data)
self._post_plot_logic(ax, self.data)
self._adorn_subplots()

def _args_adjust(self):
Expand Down Expand Up @@ -1055,12 +1063,34 @@ def _add_table(self):
ax = self._get_ax(0)
table(ax, data)

def _post_plot_logic(self):
def _post_plot_logic_common(self, ax, data):
"""Common post process for each axes"""
labels = [com.pprint_thing(key) for key in data.index]
labels = dict(zip(range(len(data.index)), labels))

if self.orientation == 'vertical' or self.orientation is None:
if self._need_to_set_index:
xticklabels = [labels.get(x, '') for x in ax.get_xticks()]
ax.set_xticklabels(xticklabels)
self._apply_axis_properties(ax.xaxis, rot=self.rot,
fontsize=self.fontsize)
self._apply_axis_properties(ax.yaxis, fontsize=self.fontsize)
elif self.orientation == 'horizontal':
if self._need_to_set_index:
yticklabels = [labels.get(y, '') for y in ax.get_yticks()]
ax.set_yticklabels(yticklabels)
self._apply_axis_properties(ax.yaxis, rot=self.rot,
fontsize=self.fontsize)
self._apply_axis_properties(ax.xaxis, fontsize=self.fontsize)
else: # pragma no cover
raise ValueError

def _post_plot_logic(self, ax, data):
"""Post process for each axes. Overridden in child classes"""
pass

def _adorn_subplots(self):
to_adorn = self.axes

"""Common post process unrelated to data"""
if len(self.axes) > 0:
all_axes = self._get_axes()
nrows, ncols = self._get_axes_layout()
Expand All @@ -1069,7 +1099,7 @@ def _adorn_subplots(self):
ncols=ncols, sharex=self.sharex,
sharey=self.sharey)

for ax in to_adorn:
for ax in self.axes:
if self.yticks is not None:
ax.set_yticks(self.yticks)

Expand All @@ -1090,25 +1120,6 @@ def _adorn_subplots(self):
else:
self.axes[0].set_title(self.title)

labels = [com.pprint_thing(key) for key in self.data.index]
labels = dict(zip(range(len(self.data.index)), labels))

for ax in self.axes:
if self.orientation == 'vertical' or self.orientation is None:
if self._need_to_set_index:
xticklabels = [labels.get(x, '') for x in ax.get_xticks()]
ax.set_xticklabels(xticklabels)
self._apply_axis_properties(ax.xaxis, rot=self.rot,
fontsize=self.fontsize)
self._apply_axis_properties(ax.yaxis, fontsize=self.fontsize)
elif self.orientation == 'horizontal':
if self._need_to_set_index:
yticklabels = [labels.get(y, '') for y in ax.get_yticks()]
ax.set_yticklabels(yticklabels)
self._apply_axis_properties(ax.yaxis, rot=self.rot,
fontsize=self.fontsize)
self._apply_axis_properties(ax.xaxis, fontsize=self.fontsize)

def _apply_axis_properties(self, axis, rot=None, fontsize=None):
labels = axis.get_majorticklabels() + axis.get_minorticklabels()
for label in labels:
Expand Down Expand Up @@ -1419,34 +1430,48 @@ def _get_axes_layout(self):
y_set.add(points[0][1])
return (len(y_set), len(x_set))

class ScatterPlot(MPLPlot):
_kind = 'scatter'

class PlanePlot(MPLPlot):
"""
Abstract class for plotting on plane, currently scatter and hexbin.
"""

_layout_type = 'single'

def __init__(self, data, x, y, c=None, **kwargs):
def __init__(self, data, x, y, **kwargs):
MPLPlot.__init__(self, data, **kwargs)
if x is None or y is None:
raise ValueError( 'scatter requires and x and y column')
raise ValueError(self._kind + ' requires and x and y column')
if com.is_integer(x) and not self.data.columns.holds_integer():
x = self.data.columns[x]
if com.is_integer(y) and not self.data.columns.holds_integer():
y = self.data.columns[y]
if com.is_integer(c) and not self.data.columns.holds_integer():
c = self.data.columns[c]
self.x = x
self.y = y
self.c = c

@property
def nseries(self):
return 1

def _post_plot_logic(self, ax, data):
x, y = self.x, self.y
ax.set_ylabel(com.pprint_thing(y))
ax.set_xlabel(com.pprint_thing(x))


class ScatterPlot(PlanePlot):
_kind = 'scatter'

def __init__(self, data, x, y, c=None, **kwargs):
super(ScatterPlot, self).__init__(data, x, y, **kwargs)
if com.is_integer(c) and not self.data.columns.holds_integer():
c = self.data.columns[c]
self.c = c

def _make_plot(self):
import matplotlib as mpl
mpl_ge_1_3_1 = str(mpl.__version__) >= LooseVersion('1.3.1')

import matplotlib.pyplot as plt

x, y, c, data = self.x, self.y, self.c, self.data
ax = self.axes[0]

Expand All @@ -1457,7 +1482,7 @@ def _make_plot(self):

# pandas uses colormap, matplotlib uses cmap.
cmap = self.colormap or 'Greys'
cmap = plt.cm.get_cmap(cmap)
cmap = self.plt.cm.get_cmap(cmap)

if c is None:
c_values = self.plt.rcParams['patch.facecolor']
Expand Down Expand Up @@ -1491,46 +1516,22 @@ def _make_plot(self):
err_kwds['ecolor'] = scatter.get_facecolor()[0]
ax.errorbar(data[x].values, data[y].values, linestyle='none', **err_kwds)

def _post_plot_logic(self):
ax = self.axes[0]
x, y = self.x, self.y
ax.set_ylabel(com.pprint_thing(y))
ax.set_xlabel(com.pprint_thing(x))


class HexBinPlot(MPLPlot):
class HexBinPlot(PlanePlot):
_kind = 'hexbin'
_layout_type = 'single'

def __init__(self, data, x, y, C=None, **kwargs):
MPLPlot.__init__(self, data, **kwargs)

if x is None or y is None:
raise ValueError('hexbin requires and x and y column')
if com.is_integer(x) and not self.data.columns.holds_integer():
x = self.data.columns[x]
if com.is_integer(y) and not self.data.columns.holds_integer():
y = self.data.columns[y]

super(HexBinPlot, self).__init__(data, x, y, **kwargs)
if com.is_integer(C) and not self.data.columns.holds_integer():
C = self.data.columns[C]

self.x = x
self.y = y
self.C = C

@property
def nseries(self):
return 1

def _make_plot(self):
import matplotlib.pyplot as plt

x, y, data, C = self.x, self.y, self.data, self.C
ax = self.axes[0]
# pandas uses colormap, matplotlib uses cmap.
cmap = self.colormap or 'BuGn'
cmap = plt.cm.get_cmap(cmap)
cmap = self.plt.cm.get_cmap(cmap)
cb = self.kwds.pop('colorbar', True)

if C is None:
Expand All @@ -1547,12 +1548,6 @@ def _make_plot(self):
def _make_legend(self):
pass

def _post_plot_logic(self):
ax = self.axes[0]
x, y = self.x, self.y
ax.set_ylabel(com.pprint_thing(y))
ax.set_xlabel(com.pprint_thing(x))


class LinePlot(MPLPlot):
_kind = 'line'
Expand Down Expand Up @@ -1685,26 +1680,23 @@ def _update_stacker(cls, ax, stacking_id, values):
elif (values <= 0).all():
ax._stacker_neg_prior[stacking_id] += values

def _post_plot_logic(self):
df = self.data

def _post_plot_logic(self, ax, data):
condition = (not self._use_dynamic_x()
and df.index.is_all_dates
and data.index.is_all_dates
and not self.subplots
or (self.subplots and self.sharex))

index_name = self._get_index_name()

for ax in self.axes:
if condition:
# irregular TS rotated 30 deg. by default
# probably a better place to check / set this.
if not self._rot_set:
self.rot = 30
format_date_labels(ax, rot=self.rot)
if condition:
# irregular TS rotated 30 deg. by default
# probably a better place to check / set this.
if not self._rot_set:
self.rot = 30
format_date_labels(ax, rot=self.rot)

if index_name is not None and self.use_index:
ax.set_xlabel(index_name)
if index_name is not None and self.use_index:
ax.set_xlabel(index_name)


class AreaPlot(LinePlot):
Expand Down Expand Up @@ -1758,16 +1750,14 @@ def _add_legend_handle(self, handle, label, index=None):
handle = Rectangle((0, 0), 1, 1, fc=handle.get_color(), alpha=alpha)
LinePlot._add_legend_handle(self, handle, label, index=index)

def _post_plot_logic(self):
LinePlot._post_plot_logic(self)
def _post_plot_logic(self, ax, data):
LinePlot._post_plot_logic(self, ax, data)

if self.ylim is None:
if (self.data >= 0).all().all():
for ax in self.axes:
ax.set_ylim(0, None)
elif (self.data <= 0).all().all():
for ax in self.axes:
ax.set_ylim(None, 0)
if (data >= 0).all().all():
ax.set_ylim(0, None)
elif (data <= 0).all().all():
ax.set_ylim(None, 0)


class BarPlot(MPLPlot):
Expand Down Expand Up @@ -1865,19 +1855,17 @@ def _make_plot(self):
start=start, label=label, log=self.log, **kwds)
self._add_legend_handle(rect, label, index=i)

def _post_plot_logic(self):
for ax in self.axes:
if self.use_index:
str_index = [com.pprint_thing(key) for key in self.data.index]
else:
str_index = [com.pprint_thing(key) for key in
range(self.data.shape[0])]
name = self._get_index_name()
def _post_plot_logic(self, ax, data):
if self.use_index:
str_index = [com.pprint_thing(key) for key in data.index]
else:
str_index = [com.pprint_thing(key) for key in range(data.shape[0])]
name = self._get_index_name()

s_edge = self.ax_pos[0] - 0.25 + self.lim_offset
e_edge = self.ax_pos[-1] + 0.25 + self.bar_width + self.lim_offset
s_edge = self.ax_pos[0] - 0.25 + self.lim_offset
e_edge = self.ax_pos[-1] + 0.25 + self.bar_width + self.lim_offset

self._decorate_ticks(ax, name, str_index, s_edge, e_edge)
self._decorate_ticks(ax, name, str_index, s_edge, e_edge)

def _decorate_ticks(self, ax, name, ticklabels, start_edge, end_edge):
ax.set_xlim((start_edge, end_edge))
Expand Down Expand Up @@ -1975,13 +1963,11 @@ def _make_plot_keywords(self, kwds, y):
kwds['bins'] = self.bins
return kwds

def _post_plot_logic(self):
def _post_plot_logic(self, ax, data):
if self.orientation == 'horizontal':
for ax in self.axes:
ax.set_xlabel('Frequency')
ax.set_xlabel('Frequency')
else:
for ax in self.axes:
ax.set_ylabel('Frequency')
ax.set_ylabel('Frequency')

@property
def orientation(self):
Expand Down Expand Up @@ -2038,9 +2024,8 @@ def _make_plot_keywords(self, kwds, y):
kwds['ind'] = self._get_ind(y)
return kwds

def _post_plot_logic(self):
for ax in self.axes:
ax.set_ylabel('Density')
def _post_plot_logic(self, ax, data):
ax.set_ylabel('Density')


class PiePlot(MPLPlot):
Expand Down Expand Up @@ -2242,7 +2227,7 @@ def _set_ticklabels(self, ax, labels):
def _make_legend(self):
pass

def _post_plot_logic(self):
def _post_plot_logic(self, ax, data):
pass

@property
Expand Down