Skip to content

Commit b6bcb5c

Browse files
committed
Fixed a bug where weights were not used, since they are no longer included in **kwds.
Also improved the logic for using weights without a group-by: NaNs are now aligned between the data and the weights, and the code detects whether weights is supplied as a column name or as an array.
1 parent 787577e commit b6bcb5c

File tree

1 file changed

+23
-2
lines changed

1 file changed

+23
-2
lines changed

pandas/tools/plotting.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2813,19 +2813,35 @@ def hist_frame(data, column=None, weights=None, by=None, grid=True,
28132813
kwds : other plotting keyword arguments
28142814
To be passed to hist function
28152815
"""
2816-
28172816
if by is not None:
28182817
axes = grouped_hist(data, column=column, weights=weights, by=by, ax=ax, grid=grid, figsize=figsize,
28192818
sharex=sharex, sharey=sharey, layout=layout, bins=bins,
28202819
xlabelsize=xlabelsize, xrot=xrot, ylabelsize=ylabelsize, yrot=yrot,
28212820
**kwds)
28222821
return axes
28232822

2823+
inx_na = np.zeros(len(data), dtype=bool)
2824+
if weights is not None:
2825+
# first figure out if given by column name, or by an array
2826+
if isinstance(weights, str):
2827+
weights = data[weights]
2828+
if isinstance(weights, np.ndarray) == False:
2829+
weights = weights.values
2830+
# remove fields where we have NaN in weights OR in the data
2831+
# for both data sets
2832+
inx_na = (np.isnan(weights))
2833+
28242834
if column is not None:
28252835
if not isinstance(column, (list, np.ndarray, Index)):
28262836
column = [column]
28272837
data = data[column]
28282838
data = data._get_numeric_data()
2839+
inx_na |= np.isnan(data.T.values)[0]
2840+
2841+
data = data.ix[~inx_na]
2842+
if weights is not None:
2843+
weights = weights[~inx_na]
2844+
28292845
naxes = len(data.columns)
28302846

28312847
fig, axes = _subplots(naxes=naxes, ax=ax, squeeze=False,
@@ -2835,7 +2851,7 @@ def hist_frame(data, column=None, weights=None, by=None, grid=True,
28352851

28362852
for i, col in enumerate(com._try_sort(data.columns)):
28372853
ax = _axes[i]
2838-
ax.hist(data[col].dropna().values, bins=bins, **kwds)
2854+
ax.hist(data[col].values, bins=bins, weights=weights, **kwds)
28392855
ax.set_title(col)
28402856
ax.grid(grid)
28412857

@@ -3063,17 +3079,22 @@ def _grouped_plot(plotf, data, column=None, weights=None, by=None,
30633079
"size by tuple instead", FutureWarning, stacklevel=4)
30643080
figsize = None
30653081

3082+
added_weights_dummy_column = False
30663083
if isinstance(weights, np.ndarray):
30673084
# weights supplied as an array instead of a part of the dataframe
30683085
data['weights'] = weights
30693086
weights = 'weights'
3087+
added_weights_dummy_column = True
30703088

30713089
grouped = data.groupby(by)
30723090
if column is not None:
30733091
if weights is not None:
30743092
weights = grouped[weights]
30753093
grouped = grouped[column]
30763094

3095+
if added_weights_dummy_column:
3096+
data = data.drop('weights', axis=1)
3097+
30773098
naxes = len(grouped)
30783099
fig, axes = _subplots(naxes=naxes, figsize=figsize,
30793100
sharex=sharex, sharey=sharey, ax=ax,

0 commit comments

Comments
 (0)