Skip to content

Commit 20a02c3

Browse files
committed
Fix for DataFrame.hist() with by- and weights-keyword
This change makes the following example work:

    import pandas as pd
    d = {'one' : ['A', 'A', 'B', 'C'], 'two' : [4., 3., 2., 1.], 'three' : [10., 8., 5., 7.]}
    df = pd.DataFrame(d)
    df.hist('two', by='one', weights='three', bins=range(0, 10))
1 parent 2dd4335 commit 20a02c3

File tree

1 file changed

+38
-15
lines changed

1 file changed

+38
-15
lines changed

pandas/tools/plotting.py

Lines changed: 38 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2770,9 +2770,10 @@ def plot_group(group, ax):
27702770
return fig
27712771

27722772

2773-
def hist_frame(data, column=None, by=None, grid=True, xlabelsize=None,
2774-
xrot=None, ylabelsize=None, yrot=None, ax=None, sharex=False,
2775-
sharey=False, figsize=None, layout=None, bins=10, **kwds):
2773+
def hist_frame(data, column=None, weights=None, by=None, grid=True,
2774+
xlabelsize=None, xrot=None, ylabelsize=None, yrot=None, ax=None,
2775+
sharex=False, sharey=False, figsize=None, layout=None, bins=10,
2776+
**kwds):
27762777
"""
27772778
Draw histogram of the DataFrame's series using matplotlib / pylab.
27782779
@@ -2781,6 +2782,8 @@ def hist_frame(data, column=None, by=None, grid=True, xlabelsize=None,
27812782
data : DataFrame
27822783
column : string or sequence
27832784
If passed, will be used to limit data to a subset of columns
2785+
weights : string or sequence
2786+
If passed, will be used to weight the data
27842787
by : object, optional
27852788
If passed, then used to form histograms for separate groups
27862789
grid : boolean, default True
@@ -2812,7 +2815,7 @@ def hist_frame(data, column=None, by=None, grid=True, xlabelsize=None,
28122815
"""
28132816

28142817
if by is not None:
2815-
axes = grouped_hist(data, column=column, by=by, ax=ax, grid=grid, figsize=figsize,
2818+
axes = grouped_hist(data, column=column, weights=weights, by=by, ax=ax, grid=grid, figsize=figsize,
28162819
sharex=sharex, sharey=sharey, layout=layout, bins=bins,
28172820
xlabelsize=xlabelsize, xrot=xrot, ylabelsize=ylabelsize, yrot=yrot,
28182821
**kwds)
@@ -2916,17 +2919,18 @@ def hist_series(self, by=None, ax=None, grid=True, xlabelsize=None,
29162919
return axes
29172920

29182921

2919-
def grouped_hist(data, column=None, by=None, ax=None, bins=50, figsize=None,
2920-
layout=None, sharex=False, sharey=False, rot=90, grid=True,
2921-
xlabelsize=None, xrot=None, ylabelsize=None, yrot=None,
2922-
**kwargs):
2922+
def grouped_hist(data, column=None, weights=None, by=None, ax=None, bins=50,
2923+
figsize=None, layout=None, sharex=False, sharey=False, rot=90,
2924+
grid=True, xlabelsize=None, xrot=None, ylabelsize=None,
2925+
yrot=None, **kwargs):
29232926
"""
29242927
Grouped histogram
29252928
29262929
Parameters
29272930
----------
29282931
data: Series/DataFrame
29292932
column: object, optional
2933+
weights: object, optional
29302934
by: object, optional
29312935
ax: axes, optional
29322936
bins: int, default 50
@@ -2942,12 +2946,20 @@ def grouped_hist(data, column=None, by=None, ax=None, bins=50, figsize=None,
29422946
-------
29432947
axes: collection of Matplotlib Axes
29442948
"""
2945-
def plot_group(group, ax):
2946-
ax.hist(group.dropna().values, bins=bins, **kwargs)
2949+
def plot_group(group, ax, weights=None):
2950+
if weights is not None:
2951+
# remove fields where we have nan in weights OR in group
2952+
# for both data sets
2953+
inx_na = (np.isnan(weights)) | (np.isnan(group))
2954+
weights = weights.ix[~inx_na]
2955+
group = group.ix[~inx_na]
2956+
else:
2957+
group = group.dropna()
2958+
ax.hist(group.values, weights=weights.values, bins=bins, **kwargs)
29472959

29482960
xrot = xrot or rot
29492961

2950-
fig, axes = _grouped_plot(plot_group, data, column=column,
2962+
fig, axes = _grouped_plot(plot_group, data, column=column, weights=weights,
29512963
by=by, sharex=sharex, sharey=sharey, ax=ax,
29522964
figsize=figsize, layout=layout, rot=rot)
29532965

@@ -3034,9 +3046,9 @@ def boxplot_frame_groupby(grouped, subplots=True, column=None, fontsize=None,
30343046
return ret
30353047

30363048

3037-
def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True,
3038-
figsize=None, sharex=True, sharey=True, layout=None,
3039-
rot=0, ax=None, **kwargs):
3049+
def _grouped_plot(plotf, data, column=None, weights=None, by=None,
3050+
numeric_only=True, figsize=None, sharex=True, sharey=True,
3051+
layout=None, rot=0, ax=None, **kwargs):
30403052
from pandas import DataFrame
30413053

30423054
if figsize == 'default':
@@ -3047,6 +3059,8 @@ def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True,
30473059

30483060
grouped = data.groupby(by)
30493061
if column is not None:
3062+
if weights is not None:
3063+
weights = grouped[weights]
30503064
grouped = grouped[column]
30513065

30523066
naxes = len(grouped)
@@ -3056,11 +3070,20 @@ def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True,
30563070

30573071
_axes = _flatten(axes)
30583072

3073+
weight = None
30593074
for i, (key, group) in enumerate(grouped):
30603075
ax = _axes[i]
3076+
if weights is not None:
3077+
weight = weights.get_group(key)
30613078
if numeric_only and isinstance(group, DataFrame):
30623079
group = group._get_numeric_data()
3063-
plotf(group, ax, **kwargs)
3080+
if weight is not None:
3081+
weight = weight._get_numeric_data()
3082+
if weight is not None:
3083+
plotf(group, ax, weight, **kwargs)
3084+
else:
3085+
# scatterplot etc has not the weight implemented in plotf
3086+
plotf(group, ax, **kwargs)
30643087
ax.set_title(com.pprint_thing(key))
30653088

30663089
return fig, axes

0 commit comments

Comments
 (0)