-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
Fix for DataFrame.hist() with by- and weights-keyword #11441
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
20a02c3
5bbba11
787577e
b6bcb5c
cbe68ec
868fbf0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2770,9 +2770,10 @@ def plot_group(group, ax): | |
return fig | ||
|
||
|
||
def hist_frame(data, column=None, by=None, grid=True, xlabelsize=None, | ||
xrot=None, ylabelsize=None, yrot=None, ax=None, sharex=False, | ||
sharey=False, figsize=None, layout=None, bins=10, **kwds): | ||
def hist_frame(data, column=None, weights=None, by=None, grid=True, | ||
xlabelsize=None, xrot=None, ylabelsize=None, yrot=None, ax=None, | ||
sharex=False, sharey=False, figsize=None, layout=None, bins=10, | ||
**kwds): | ||
""" | ||
Draw histogram of the DataFrame's series using matplotlib / pylab. | ||
|
||
|
@@ -2781,6 +2782,8 @@ def hist_frame(data, column=None, by=None, grid=True, xlabelsize=None, | |
data : DataFrame | ||
column : string or sequence | ||
If passed, will be used to limit data to a subset of columns | ||
weights : string or sequence | ||
If passed, will be used to weight the data | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. same here |
||
by : object, optional | ||
If passed, then used to form histograms for separate groups | ||
grid : boolean, default True | ||
|
@@ -2810,19 +2813,35 @@ def hist_frame(data, column=None, by=None, grid=True, xlabelsize=None, | |
kwds : other plotting keyword arguments | ||
To be passed to hist function | ||
""" | ||
|
||
if by is not None: | ||
axes = grouped_hist(data, column=column, by=by, ax=ax, grid=grid, figsize=figsize, | ||
axes = grouped_hist(data, column=column, weights=weights, by=by, ax=ax, grid=grid, figsize=figsize, | ||
sharex=sharex, sharey=sharey, layout=layout, bins=bins, | ||
xlabelsize=xlabelsize, xrot=xrot, ylabelsize=ylabelsize, yrot=yrot, | ||
**kwds) | ||
return axes | ||
|
||
inx_na = np.zeros(len(data), dtype=bool) | ||
if weights is not None: | ||
# first figure out if given my column name, or by an array | ||
if isinstance(weights, str): | ||
weights = data[weights] | ||
if isinstance(weights, np.ndarray) == False: | ||
weights = weights.values | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. this code is pretty close to |
||
# remove fields where we have nan in weights OR in group | ||
# for both data sets | ||
inx_na = (np.isnan(weights)) | ||
|
||
if column is not None: | ||
if not isinstance(column, (list, np.ndarray, Index)): | ||
column = [column] | ||
data = data[column] | ||
data = data._get_numeric_data() | ||
inx_na |= np.isnan(data.T.values)[0] | ||
|
||
data = data.ix[~inx_na] | ||
if weights is not None: | ||
weights = weights[~inx_na] | ||
|
||
naxes = len(data.columns) | ||
|
||
fig, axes = _subplots(naxes=naxes, ax=ax, squeeze=False, | ||
|
@@ -2832,7 +2851,7 @@ def hist_frame(data, column=None, by=None, grid=True, xlabelsize=None, | |
|
||
for i, col in enumerate(com._try_sort(data.columns)): | ||
ax = _axes[i] | ||
ax.hist(data[col].dropna().values, bins=bins, **kwds) | ||
ax.hist(data[col].values, bins=bins, weights=weights, **kwds) | ||
ax.set_title(col) | ||
ax.grid(grid) | ||
|
||
|
@@ -2916,17 +2935,18 @@ def hist_series(self, by=None, ax=None, grid=True, xlabelsize=None, | |
return axes | ||
|
||
|
||
def grouped_hist(data, column=None, by=None, ax=None, bins=50, figsize=None, | ||
layout=None, sharex=False, sharey=False, rot=90, grid=True, | ||
xlabelsize=None, xrot=None, ylabelsize=None, yrot=None, | ||
**kwargs): | ||
def grouped_hist(data, column=None, weights=None, by=None, ax=None, bins=50, | ||
figsize=None, layout=None, sharex=False, sharey=False, rot=90, | ||
grid=True, xlabelsize=None, xrot=None, ylabelsize=None, | ||
yrot=None, **kwargs): | ||
""" | ||
Grouped histogram | ||
|
||
Parameters | ||
---------- | ||
data: Series/DataFrame | ||
column: object, optional | ||
weights: object, optional | ||
by: object, optional | ||
ax: axes, optional | ||
bins: int, default 50 | ||
|
@@ -2942,12 +2962,26 @@ def grouped_hist(data, column=None, by=None, ax=None, bins=50, figsize=None, | |
------- | ||
axes: collection of Matplotlib Axes | ||
""" | ||
def plot_group(group, ax): | ||
ax.hist(group.dropna().values, bins=bins, **kwargs) | ||
def plot_group(group, ax, weights=None): | ||
if isinstance(group, np.ndarray) == False: | ||
group = group.values | ||
inx_na = np.isnan(group) | ||
if weights is not None: | ||
# remove fields where we have nan in weights OR in group | ||
# for both data sets | ||
if isinstance(weights, np.ndarray) == False: | ||
weights = weights.values | ||
inx_na |= (np.isnan(weights)) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. same here; this is pretty duplicative |
||
weights = weights[~inx_na] | ||
group = group[~inx_na] | ||
if len(group) > 0: | ||
# if length is less than 0, we had only NaN's for this group | ||
# nothing to print! | ||
ax.hist(group, weights=weights, bins=bins, **kwargs) | ||
|
||
xrot = xrot or rot | ||
|
||
fig, axes = _grouped_plot(plot_group, data, column=column, | ||
fig, axes = _grouped_plot(plot_group, data, column=column, weights=weights, | ||
by=by, sharex=sharex, sharey=sharey, ax=ax, | ||
figsize=figsize, layout=layout, rot=rot) | ||
|
||
|
@@ -3034,9 +3068,9 @@ def boxplot_frame_groupby(grouped, subplots=True, column=None, fontsize=None, | |
return ret | ||
|
||
|
||
def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True, | ||
figsize=None, sharex=True, sharey=True, layout=None, | ||
rot=0, ax=None, **kwargs): | ||
def _grouped_plot(plotf, data, column=None, weights=None, by=None, | ||
numeric_only=True, figsize=None, sharex=True, sharey=True, | ||
layout=None, rot=0, ax=None, **kwargs): | ||
from pandas import DataFrame | ||
|
||
if figsize == 'default': | ||
|
@@ -3045,22 +3079,43 @@ def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True, | |
"size by tuple instead", FutureWarning, stacklevel=4) | ||
figsize = None | ||
|
||
added_weights_dummy_column = False | ||
if isinstance(weights, np.ndarray): | ||
# weights supplied as an array instead of a part of the dataframe | ||
data['weights'] = weights | ||
weights = 'weights' | ||
added_weights_dummy_column = True | ||
|
||
grouped = data.groupby(by) | ||
if column is not None: | ||
if weights is not None: | ||
weights = grouped[weights] | ||
grouped = grouped[column] | ||
|
||
if added_weights_dummy_column: | ||
data = data.drop('weights', axis=1) | ||
|
||
naxes = len(grouped) | ||
fig, axes = _subplots(naxes=naxes, figsize=figsize, | ||
sharex=sharex, sharey=sharey, ax=ax, | ||
layout=layout) | ||
|
||
_axes = _flatten(axes) | ||
|
||
weight = None | ||
for i, (key, group) in enumerate(grouped): | ||
ax = _axes[i] | ||
if weights is not None: | ||
weight = weights.get_group(key) | ||
if numeric_only and isinstance(group, DataFrame): | ||
group = group._get_numeric_data() | ||
plotf(group, ax, **kwargs) | ||
if weight is not None: | ||
weight = weight._get_numeric_data() | ||
if weight is not None: | ||
plotf(group, ax, weight, **kwargs) | ||
else: | ||
# scatterplot etc has not the weight implemented in plotf | ||
plotf(group, ax, **kwargs) | ||
ax.set_title(com.pprint_thing(key)) | ||
|
||
return fig, axes | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't change the order of parameters; just add new ones at the end.