Skip to content

Fix for DataFrame.hist() with by- and weights-keyword #11441

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 72 additions & 17 deletions pandas/tools/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2770,9 +2770,10 @@ def plot_group(group, ax):
return fig


def hist_frame(data, column=None, by=None, grid=True, xlabelsize=None,
xrot=None, ylabelsize=None, yrot=None, ax=None, sharex=False,
sharey=False, figsize=None, layout=None, bins=10, **kwds):
def hist_frame(data, column=None, weights=None, by=None, grid=True,
xlabelsize=None, xrot=None, ylabelsize=None, yrot=None, ax=None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't change the order of the parameters; just add the new one at the end.

sharex=False, sharey=False, figsize=None, layout=None, bins=10,
**kwds):
"""
Draw histogram of the DataFrame's series using matplotlib / pylab.

Expand All @@ -2781,6 +2782,8 @@ def hist_frame(data, column=None, by=None, grid=True, xlabelsize=None,
data : DataFrame
column : string or sequence
If passed, will be used to limit data to a subset of columns
weights : string or sequence
If passed, will be used to weight the data
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here

by : object, optional
If passed, then used to form histograms for separate groups
grid : boolean, default True
Expand Down Expand Up @@ -2810,19 +2813,35 @@ def hist_frame(data, column=None, by=None, grid=True, xlabelsize=None,
kwds : other plotting keyword arguments
To be passed to hist function
"""

if by is not None:
axes = grouped_hist(data, column=column, by=by, ax=ax, grid=grid, figsize=figsize,
axes = grouped_hist(data, column=column, weights=weights, by=by, ax=ax, grid=grid, figsize=figsize,
sharex=sharex, sharey=sharey, layout=layout, bins=bins,
xlabelsize=xlabelsize, xrot=xrot, ylabelsize=ylabelsize, yrot=yrot,
**kwds)
return axes

inx_na = np.zeros(len(data), dtype=bool)
if weights is not None:
# first figure out whether weights were given by column name or as an array
if isinstance(weights, str):
weights = data[weights]
if isinstance(weights, np.ndarray) == False:
weights = weights.values
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code is pretty close to what DataFrame.sample does; it needs to be consolidated into a shared function somewhere.

# remove fields where we have nan in weights OR in group
# for both data sets
inx_na = (np.isnan(weights))

if column is not None:
if not isinstance(column, (list, np.ndarray, Index)):
column = [column]
data = data[column]
data = data._get_numeric_data()
inx_na |= np.isnan(data.T.values)[0]

data = data.ix[~inx_na]
if weights is not None:
weights = weights[~inx_na]

naxes = len(data.columns)

fig, axes = _subplots(naxes=naxes, ax=ax, squeeze=False,
Expand All @@ -2832,7 +2851,7 @@ def hist_frame(data, column=None, by=None, grid=True, xlabelsize=None,

for i, col in enumerate(com._try_sort(data.columns)):
ax = _axes[i]
ax.hist(data[col].dropna().values, bins=bins, **kwds)
ax.hist(data[col].values, bins=bins, weights=weights, **kwds)
ax.set_title(col)
ax.grid(grid)

Expand Down Expand Up @@ -2916,17 +2935,18 @@ def hist_series(self, by=None, ax=None, grid=True, xlabelsize=None,
return axes


def grouped_hist(data, column=None, by=None, ax=None, bins=50, figsize=None,
layout=None, sharex=False, sharey=False, rot=90, grid=True,
xlabelsize=None, xrot=None, ylabelsize=None, yrot=None,
**kwargs):
def grouped_hist(data, column=None, weights=None, by=None, ax=None, bins=50,
figsize=None, layout=None, sharex=False, sharey=False, rot=90,
grid=True, xlabelsize=None, xrot=None, ylabelsize=None,
yrot=None, **kwargs):
"""
Grouped histogram

Parameters
----------
data: Series/DataFrame
column: object, optional
weights: object, optional
by: object, optional
ax: axes, optional
bins: int, default 50
Expand All @@ -2942,12 +2962,26 @@ def grouped_hist(data, column=None, by=None, ax=None, bins=50, figsize=None,
-------
axes: collection of Matplotlib Axes
"""
def plot_group(group, ax):
ax.hist(group.dropna().values, bins=bins, **kwargs)
def plot_group(group, ax, weights=None):
if isinstance(group, np.ndarray) == False:
group = group.values
inx_na = np.isnan(group)
if weights is not None:
# remove fields where we have nan in weights OR in group
# for both data sets
if isinstance(weights, np.ndarray) == False:
weights = weights.values
inx_na |= (np.isnan(weights))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here; this is pretty duplicative.

weights = weights[~inx_na]
group = group[~inx_na]
if len(group) > 0:
# if length is 0, we had only NaN's for this group
# nothing to plot!
ax.hist(group, weights=weights, bins=bins, **kwargs)

xrot = xrot or rot

fig, axes = _grouped_plot(plot_group, data, column=column,
fig, axes = _grouped_plot(plot_group, data, column=column, weights=weights,
by=by, sharex=sharex, sharey=sharey, ax=ax,
figsize=figsize, layout=layout, rot=rot)

Expand Down Expand Up @@ -3034,9 +3068,9 @@ def boxplot_frame_groupby(grouped, subplots=True, column=None, fontsize=None,
return ret


def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True,
figsize=None, sharex=True, sharey=True, layout=None,
rot=0, ax=None, **kwargs):
def _grouped_plot(plotf, data, column=None, weights=None, by=None,
numeric_only=True, figsize=None, sharex=True, sharey=True,
layout=None, rot=0, ax=None, **kwargs):
from pandas import DataFrame

if figsize == 'default':
Expand All @@ -3045,22 +3079,43 @@ def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True,
"size by tuple instead", FutureWarning, stacklevel=4)
figsize = None

added_weights_dummy_column = False
if isinstance(weights, np.ndarray):
# weights supplied as an array instead of a part of the dataframe
data['weights'] = weights
weights = 'weights'
added_weights_dummy_column = True

grouped = data.groupby(by)
if column is not None:
if weights is not None:
weights = grouped[weights]
grouped = grouped[column]

if added_weights_dummy_column:
data = data.drop('weights', axis=1)

naxes = len(grouped)
fig, axes = _subplots(naxes=naxes, figsize=figsize,
sharex=sharex, sharey=sharey, ax=ax,
layout=layout)

_axes = _flatten(axes)

weight = None
for i, (key, group) in enumerate(grouped):
ax = _axes[i]
if weights is not None:
weight = weights.get_group(key)
if numeric_only and isinstance(group, DataFrame):
group = group._get_numeric_data()
plotf(group, ax, **kwargs)
if weight is not None:
weight = weight._get_numeric_data()
if weight is not None:
plotf(group, ax, weight, **kwargs)
else:
# scatterplot etc. do not implement weights in their plotf
plotf(group, ax, **kwargs)
ax.set_title(com.pprint_thing(key))

return fig, axes
Expand Down