-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
ENH: Allow column grouping in DataFrame.plot #29944
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
469eff8
b96a659
c5b187b
99cc049
06da910
49c728a
b370ba5
d1a2ae0
10ac363
9774d42
572de08
4519f22
c6edb9f
73dd7eb
b804e57
d7ee47c
8fb298d
d84f356
ba0f982
fe98b86
f81d295
8255405
f54459e
f962425
dda80b3
a457b38
8d85ba9
bbaeffa
81f2a7f
fdd7610
fe965b0
4478597
ac51202
722c83b
c601c2d
c99abba
ea6c73d
9184e64
34515a8
ae90509
827c5b5
5aace24
b28d3b3
587cbc1
1517ae3
8c61b83
b75d86c
55928ff
b5f47f2
43070d3
e3c0146
a4c2676
71b7b39
4157618
9b44546
2d89f73
8cd6342
ffd4b18
66342bc
48c9f32
aa128dc
69efd18
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,8 @@ | |
from typing import ( | ||
TYPE_CHECKING, | ||
Hashable, | ||
Iterable, | ||
Sequence, | ||
) | ||
import warnings | ||
|
||
|
@@ -102,7 +104,7 @@ def __init__( | |
data, | ||
kind=None, | ||
by: IndexLabel | None = None, | ||
subplots=False, | ||
subplots: bool | Sequence[Sequence[str]] = False, | ||
sharex=None, | ||
sharey=False, | ||
use_index=True, | ||
|
@@ -166,8 +168,7 @@ def __init__( | |
self.kind = kind | ||
|
||
self.sort_columns = sort_columns | ||
|
||
self.subplots = subplots | ||
self.subplots = self._validate_subplots_kwarg(subplots) | ||
|
||
if sharex is None: | ||
|
||
|
@@ -253,6 +254,112 @@ def __init__( | |
|
||
self._validate_color_args() | ||
|
||
def _validate_subplots_kwarg(
    self, subplots: bool | Sequence[Sequence[str]]
) -> bool | list[tuple[int, ...]]:
    """
    Validate the subplots parameter

    - check type and content
    - check for duplicate columns
    - check for invalid column names
    - convert column names into indices
    - add missing columns in a group of their own
    See comments in code below for more details.

    Parameters
    ----------
    subplots : subplots parameters as passed to PlotAccessor

    Returns
    -------
    validated subplots : a bool or a list of tuples of column indices. Columns
        in the same tuple will be grouped together in the resulting plot.

    Raises
    ------
    ValueError
        If ``subplots`` is neither a bool nor an iterable, if the plot kind is
        not one of the supported kinds, if an entry is not list-like, if a
        label is not a column of the data, or if a column appears in more
        than one group.
    NotImplementedError
        If the data is a Series, or a DataFrame whose columns are a MultiIndex
        or are not unique.
    """
    if isinstance(subplots, bool):
        # A bool falls back to the pre-existing behaviour, so none of the
        # checks below apply.
        return subplots
    elif not isinstance(subplots, Iterable):
        raise ValueError("subplots should be a bool or an iterable")

    supported_kinds = (
        "line",
        "bar",
        "barh",
        "hist",
        "kde",
        "density",
        "area",
        "pie",
    )
    if self._kind not in supported_kinds:
        raise ValueError(
            "When subplots is an iterable, kind must be "
            f"one of {', '.join(supported_kinds)}. Got {self._kind}."
        )

    if isinstance(self.data, ABCSeries):
        raise NotImplementedError(
            "An iterable subplots for a Series is not supported."
        )

    columns = self.data.columns
    if isinstance(columns, ABCMultiIndex):
        raise NotImplementedError(
            "An iterable subplots for a DataFrame with a MultiIndex column "
            "is not supported."
        )

    # ``is_unique`` rather than ``nunique() != len``: nunique drops NaN by
    # default, which would misreport a unique index holding a NaN label.
    if not columns.is_unique:
        raise NotImplementedError(
            "An iterable subplots for a DataFrame with non-unique column "
            "labels is not supported."
        )

    # subplots is a list of tuples where each tuple is a group of
    # columns to be grouped together (one ax per group).
    # we consolidate the subplots list such that:
    # - the tuples contain indices instead of column names
    # - the columns that aren't yet in the list are added in a group
    #   of their own.
    # For example with columns from a to g, and
    # subplots = [(a, c), (b, f, e)],
    # we end up with [(ai, ci), (bi, fi, ei), (di,), (gi,)]
    # This way, we can handle self.subplots in a homogeneous manner
    # later.
    # TODO: also accept indices instead of just names?

    out = []
    seen_columns: set[Hashable] = set()
    for group in subplots:
        if not is_list_like(group):
            raise ValueError(
                "When subplots is an iterable, each entry "
                "should be a list/tuple of column names."
            )
        # Materialise once so a one-shot iterable is not exhausted by the
        # indexer lookup before the duplicate check reads it again.
        labels = list(group)
        idx_locs = columns.get_indexer_for(labels)
        # Keep the caller's own label objects so the message shows their
        # plain repr instead of NumPy scalar reprs.
        bad_labels = [label for label, loc in zip(labels, idx_locs) if loc == -1]
        if bad_labels:
            raise ValueError(
                f"Column label(s) {bad_labels} not found in the DataFrame."
            )

        unique_columns = set(labels)
        duplicates = seen_columns.intersection(unique_columns)
        if duplicates:
            raise ValueError(
                "Each column should be in only one subplot. "
                f"Columns {duplicates} were found in multiple subplots."
            )
        seen_columns = seen_columns.union(unique_columns)
        out.append(tuple(idx_locs))

    # NOTE(review): Index.difference sorts its result by default when it can,
    # so leftover columns may not follow the DataFrame's column order -
    # confirm this is intended.
    unseen_columns = columns.difference(seen_columns)
    for column in unseen_columns:
        idx_loc = columns.get_loc(column)
        out.append((idx_loc,))
    return out
|
||
def _validate_color_args(self): | ||
if ( | ||
"color" in self.kwds | ||
|
@@ -371,8 +478,11 @@ def _maybe_right_yaxis(self, ax: Axes, axes_num): | |
|
||
def _setup_subplots(self): | ||
if self.subplots: | ||
naxes = ( | ||
self.nseries if isinstance(self.subplots, bool) else len(self.subplots) | ||
) | ||
fig, axes = create_subplots( | ||
naxes=self.nseries, | ||
naxes=naxes, | ||
sharex=self.sharex, | ||
sharey=self.sharey, | ||
figsize=self.figsize, | ||
|
@@ -784,9 +894,23 @@ def _get_ax_layer(cls, ax, primary=True): | |
else: | ||
return getattr(ax, "right_ax", ax) | ||
|
||
def _col_idx_to_axis_idx(self, col_idx: int) -> int:
    """Return the index of the axis where the column at col_idx should be plotted"""
    if not isinstance(self.subplots, list):
        # subplots is True: one ax per column
        return col_idx

    # Subplots is a list: some columns will be grouped together in the same ax.
    # No default is passed to ``next``: per the review discussion, earlier
    # validation of ``subplots`` means a matching group must exist.
    matching_axes = (
        axis_pos
        for axis_pos, members in enumerate(self.subplots)
        if col_idx in members
    )
    return next(matching_axes)
|
||
def _get_ax(self, i: int): | ||
# get the twinx ax if appropriate | ||
if self.subplots: | ||
i = self._col_idx_to_axis_idx(i) | ||
ax = self.axes[i] | ||
ax = self._maybe_right_yaxis(ax, i) | ||
self.axes[i] = ax | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2071,6 +2071,90 @@ def test_plot_no_numeric_data(self): | |
with pytest.raises(TypeError, match="no numeric data to plot"): | ||
df.plot() | ||
|
||
@td.skip_if_no_scipy
@pytest.mark.parametrize(
    "kind", ("line", "bar", "barh", "hist", "kde", "density", "area", "pie")
)
def test_group_subplot(self, kind):
    """Plot two column groups and check the axes count and legend labels."""
    offsets = {"a": 0, "b": 1, "c": 1, "d": 0, "e": 0}
    df = DataFrame({name: np.arange(10) + shift for name, shift in offsets.items()})

    axes = df.plot(subplots=[("b", "e"), ("c", "d")], kind=kind)
    # 2 groups + single column a
    assert len(axes) == 3

    expected_labels = (["b", "e"], ["c", "d"], ["a"])
    check_legend = kind != "pie"
    check_lines = kind == "line"
    for ax, labels in zip(axes, expected_labels):
        if check_legend:
            self._check_legend_labels(ax, labels=labels)
        if check_lines:
            assert len(ax.lines) == len(labels)
|
||
def test_group_subplot_series_notimplemented(self):
    """An iterable ``subplots`` on a Series raises NotImplementedError."""
    expected = "An iterable subplots for a Series"
    single_value = Series(range(1))
    with pytest.raises(NotImplementedError, match=expected):
        single_value.plot(subplots=[("a",)])
|
||
def test_group_subplot_multiindex_notimplemented(self):
    """An iterable ``subplots`` with MultiIndex columns raises NotImplementedError."""
    multi_columns = MultiIndex.from_tuples([(0, 1), (1, 2)])
    frame = DataFrame(np.eye(2), columns=multi_columns)
    expected = "An iterable subplots for a DataFrame with a MultiIndex"
    with pytest.raises(NotImplementedError, match=expected):
        frame.plot(subplots=[(0, 1)])
|
||
def test_group_subplot_nonunique_cols_notimplemented(self):
    """An iterable ``subplots`` with repeated column labels raises NotImplementedError."""
    frame = DataFrame(np.eye(2), columns=["a", "a"])
    expected = "An iterable subplots for a DataFrame with non-unique"
    with pytest.raises(NotImplementedError, match=expected):
        frame.plot(subplots=[("a",)])
|
||
@pytest.mark.parametrize(
    "subplots, expected_msg",
    [
        (123, "subplots should be a bool or an iterable"),
        ("a", "each entry should be a list/tuple"),  # iterable of non-iterable
        ((1,), "each entry should be a list/tuple"),  # iterable of non-iterable
        (("a",), "each entry should be a list/tuple"),  # iterable of strings
    ],
)
def test_group_subplot_bad_input(self, subplots, expected_msg):
    """
    Malformed ``subplots`` values raise ValueError.

    Only iterables of iterables are permitted, and entries should not be
    strings.
    """
    frame = DataFrame({"a": np.arange(10), "b": np.arange(10)})

    with pytest.raises(ValueError, match=expected_msg):
        frame.plot(subplots=subplots)
|
||
def test_group_subplot_invalid_column_name(self):
    """A group naming a label that is not a column raises ValueError."""
    frame = DataFrame({"a": np.arange(10), "b": np.arange(10)})
    expected = r"Column label\(s\) \['bad_name'\]"

    with pytest.raises(ValueError, match=expected):
        frame.plot(subplots=[("a", "bad_name")])
|
||
def test_group_subplot_duplicated_column(self):
    """A column listed in two groups raises ValueError."""
    frame = DataFrame({name: np.arange(10) for name in ("a", "b", "c")})
    expected = "should be in only one subplot"

    # NOTE(review): a reviewer asked whether rejecting a column that appears
    # in several groups is really the desired behaviour.
    with pytest.raises(ValueError, match=expected):
        frame.plot(subplots=[("a", "b"), ("a", "c")])
|
||
@pytest.mark.parametrize("kind", ("box", "scatter", "hexbin"))
def test_group_subplot_invalid_kind(self, kind):
    """An iterable ``subplots`` combined with one of these kinds raises ValueError."""
    frame = DataFrame({"a": np.arange(10), "b": np.arange(10)})
    expected = "When subplots is an iterable, kind must be one of"
    with pytest.raises(ValueError, match=expected):
        frame.plot(subplots=[("a", "b")], kind=kind)
|
||
@pytest.mark.parametrize( | ||
"index_name, old_label, new_label", | ||
[ | ||
|
Uh oh!
There was an error while loading. Please reload this page.