Skip to content

PERF/CLN: Avoid ravel in plotting #58973

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 2 additions & 7 deletions pandas/plotting/_matplotlib/boxplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,6 @@ def _grouped_plot_by_column(
layout=layout,
)

_axes = flatten_axes(axes)

# GH 45465: move the "by" label based on "vert"
xlabel, ylabel = kwargs.pop("xlabel", None), kwargs.pop("ylabel", None)
if kwargs.get("vert", True):
Expand All @@ -322,8 +320,7 @@ def _grouped_plot_by_column(

ax_values = []

for i, col in enumerate(columns):
ax = _axes[i]
for ax, col in zip(flatten_axes(axes), columns):
gp_col = grouped[col]
keys, values = zip(*gp_col)
re_plotf = plotf(keys, values, ax, xlabel=xlabel, ylabel=ylabel, **kwargs)
Expand Down Expand Up @@ -531,10 +528,8 @@ def boxplot_frame_groupby(
figsize=figsize,
layout=layout,
)
axes = flatten_axes(axes)

data = {}
for (key, group), ax in zip(grouped, axes):
for (key, group), ax in zip(grouped, flatten_axes(axes)):
d = group.boxplot(
ax=ax, column=column, fontsize=fontsize, rot=rot, grid=grid, **kwds
)
Expand Down
2 changes: 1 addition & 1 deletion pandas/plotting/_matplotlib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ def _axes_and_fig(self) -> tuple[Sequence[Axes], Figure]:
fig.set_size_inches(self.figsize)
axes = self.ax

axes = flatten_axes(axes)
axes = np.fromiter(flatten_axes(axes), dtype=object)

if self.logx is True or self.loglog is True:
[a.set_xscale("log") for a in axes]
Expand Down
17 changes: 6 additions & 11 deletions pandas/plotting/_matplotlib/hist.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,12 @@ def _adjust_bins(self, bins: int | np.ndarray | list[np.ndarray]):
def _calculate_bins(self, data: Series | DataFrame, bins) -> np.ndarray:
"""Calculate bins given data"""
nd_values = data.infer_objects()._get_numeric_data()
values = np.ravel(nd_values)
values = nd_values.values
if nd_values.ndim == 2:
values = values.reshape(-1)
values = values[~isna(values)]

hist, bins = np.histogram(values, bins=bins, range=self._bin_range)
return bins
return np.histogram_bin_edges(values, bins=bins, range=self._bin_range)

# error: Signature of "_plot" incompatible with supertype "LinePlot"
@classmethod
Expand Down Expand Up @@ -322,10 +323,7 @@ def _grouped_plot(
naxes=naxes, figsize=figsize, sharex=sharex, sharey=sharey, ax=ax, layout=layout
)

_axes = flatten_axes(axes)

for i, (key, group) in enumerate(grouped):
ax = _axes[i]
for ax, (key, group) in zip(flatten_axes(axes), grouped):
if numeric_only and isinstance(group, ABCDataFrame):
group = group._get_numeric_data()
plotf(group, ax, **kwargs)
Expand Down Expand Up @@ -557,12 +555,9 @@ def hist_frame(
figsize=figsize,
layout=layout,
)
_axes = flatten_axes(axes)

can_set_label = "label" not in kwds

for i, col in enumerate(data.columns):
ax = _axes[i]
for ax, col in zip(flatten_axes(axes), data.columns):
if legend and can_set_label:
kwds["label"] = col
ax.hist(data[col].dropna().values, bins=bins, **kwds)
Expand Down
26 changes: 15 additions & 11 deletions pandas/plotting/_matplotlib/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@
)

if TYPE_CHECKING:
from collections.abc import Iterable
from collections.abc import (
Generator,
Iterable,
)

from matplotlib.axes import Axes
from matplotlib.axis import Axis
Expand Down Expand Up @@ -231,7 +234,7 @@ def create_subplots(
else:
if is_list_like(ax):
if squeeze:
ax = flatten_axes(ax)
ax = np.fromiter(flatten_axes(ax), dtype=object)
if layout is not None:
warnings.warn(
"When passing multiple axes, layout keyword is ignored.",
Expand Down Expand Up @@ -260,7 +263,7 @@ def create_subplots(
if squeeze:
return fig, ax
else:
return fig, flatten_axes(ax)
return fig, np.fromiter(flatten_axes(ax), dtype=object)
else:
warnings.warn(
"To output multiple subplots, the figure containing "
Expand Down Expand Up @@ -439,12 +442,13 @@ def handle_shared_axes(
_remove_labels_from_axis(ax.yaxis)


def flatten_axes(axes: Axes | Iterable[Axes]) -> np.ndarray:
def flatten_axes(axes: Axes | Iterable[Axes]) -> Generator[Axes, None, None]:
if not is_list_like(axes):
return np.array([axes])
yield axes # type: ignore[misc]
elif isinstance(axes, (np.ndarray, ABCIndex)):
return np.asarray(axes).ravel()
return np.array(axes)
yield from np.asarray(axes).reshape(-1)
else:
yield from axes # type: ignore[misc]


def set_ticks_props(
Expand All @@ -456,13 +460,13 @@ def set_ticks_props(
):
for ax in flatten_axes(axes):
if xlabelsize is not None:
mpl.artist.setp(ax.get_xticklabels(), fontsize=xlabelsize)
mpl.artist.setp(ax.get_xticklabels(), fontsize=xlabelsize) # type: ignore[arg-type]
if xrot is not None:
mpl.artist.setp(ax.get_xticklabels(), rotation=xrot)
mpl.artist.setp(ax.get_xticklabels(), rotation=xrot) # type: ignore[arg-type]
if ylabelsize is not None:
mpl.artist.setp(ax.get_yticklabels(), fontsize=ylabelsize)
mpl.artist.setp(ax.get_yticklabels(), fontsize=ylabelsize) # type: ignore[arg-type]
if yrot is not None:
mpl.artist.setp(ax.get_yticklabels(), rotation=yrot)
mpl.artist.setp(ax.get_yticklabels(), rotation=yrot) # type: ignore[arg-type]
return axes


Expand Down