Skip to content

Fix for DataFrame.hist() with by- and weights-keyword #11028

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 24 additions & 8 deletions pandas/tools/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2700,7 +2700,7 @@ def plot_group(group, ax):
return fig


def hist_frame(data, column=None, by=None, grid=True, xlabelsize=None,
def hist_frame(data, column=None, weights=None, by=None, grid=True, xlabelsize=None,
xrot=None, ylabelsize=None, yrot=None, ax=None, sharex=False,
sharey=False, figsize=None, layout=None, bins=10, **kwds):
"""
Expand All @@ -2711,6 +2711,8 @@ def hist_frame(data, column=None, by=None, grid=True, xlabelsize=None,
data : DataFrame
column : string or sequence
If passed, will be used to limit data to a subset of columns
weights : string or sequence
If passed, will be used to weight the data
by : object, optional
If passed, then used to form histograms for separate groups
grid : boolean, default True
Expand Down Expand Up @@ -2742,7 +2744,7 @@ def hist_frame(data, column=None, by=None, grid=True, xlabelsize=None,
"""

if by is not None:
axes = grouped_hist(data, column=column, by=by, ax=ax, grid=grid, figsize=figsize,
axes = grouped_hist(data, column=column, weights=weights, by=by, ax=ax, grid=grid, figsize=figsize,
sharex=sharex, sharey=sharey, layout=layout, bins=bins,
xlabelsize=xlabelsize, xrot=xrot, ylabelsize=ylabelsize, yrot=yrot,
**kwds)
Expand Down Expand Up @@ -2846,7 +2848,7 @@ def hist_series(self, by=None, ax=None, grid=True, xlabelsize=None,
return axes


def grouped_hist(data, column=None, by=None, ax=None, bins=50, figsize=None,
def grouped_hist(data, column=None, weights=None, by=None, ax=None, bins=50, figsize=None,
layout=None, sharex=False, sharey=False, rot=90, grid=True,
xlabelsize=None, xrot=None, ylabelsize=None, yrot=None,
**kwargs):
Expand All @@ -2857,6 +2859,7 @@ def grouped_hist(data, column=None, by=None, ax=None, bins=50, figsize=None,
----------
data: Series/DataFrame
column: object, optional
weights: object, optional
by: object, optional
ax: axes, optional
bins: int, default 50
Expand All @@ -2872,12 +2875,14 @@ def grouped_hist(data, column=None, by=None, ax=None, bins=50, figsize=None,
-------
axes: collection of Matplotlib Axes
"""
def plot_group(group, ax):
ax.hist(group.dropna().values, bins=bins, **kwargs)
def plot_group(group, ax, weights=None):
if weights is not None:
weights=weights.dropna().values
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect this will fail if the missing values in the weights column don't align perfectly with the missing values in the group column. It might be cleaner to refactor this to drop rows missing either group or weight earlier on, so that plot_group only has to deal with valid observations.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added this because on the next line, group drops na values as well?
ax.hist(group.dropna().values

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, issues when na is different on weights and group arrays, did not think of this, will think a little bit more

ax.hist(group.dropna().values, weights=weights, bins=bins, **kwargs)

xrot = xrot or rot

fig, axes = _grouped_plot(plot_group, data, column=column,
fig, axes = _grouped_plot(plot_group, data, column=column, weights=weights,
by=by, sharex=sharex, sharey=sharey, ax=ax,
figsize=figsize, layout=layout, rot=rot)

Expand Down Expand Up @@ -2964,7 +2969,7 @@ def boxplot_frame_groupby(grouped, subplots=True, column=None, fontsize=None,
return ret


def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True,
def _grouped_plot(plotf, data, column=None, weights=None, by=None, numeric_only=True,
figsize=None, sharex=True, sharey=True, layout=None,
rot=0, ax=None, **kwargs):
from pandas import DataFrame
Expand All @@ -2977,6 +2982,8 @@ def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True,

grouped = data.groupby(by)
if column is not None:
if weights is not None:
weights = grouped[weights]
grouped = grouped[column]

naxes = len(grouped)
Expand All @@ -2986,11 +2993,20 @@ def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True,

_axes = _flatten(axes)

weight = None
for i, (key, group) in enumerate(grouped):
ax = _axes[i]
if weights is not None:
weight = weights.get_group(key)
if numeric_only and isinstance(group, DataFrame):
group = group._get_numeric_data()
plotf(group, ax, **kwargs)
if weight is not None:
weight = weight._get_numeric_data()
if weight is not None:
plotf(group, ax, weight, **kwargs)
else:
# scatterplot etc has not the weight implemented in plotf
plotf(group, ax, **kwargs)
ax.set_title(com.pprint_thing(key))

return fig, axes
Expand Down