Skip to content

Commit c1dcd54

Browse files
authored
PERF/CLN: Avoid ravel in plotting (#58973)
* Avoid ravel in plotting * Use reshape instead of ravel * Add type ignore
1 parent 3bcc95f commit c1dcd54

File tree

4 files changed

+24
-30
lines changed

4 files changed

+24
-30
lines changed

pandas/plotting/_matplotlib/boxplot.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -311,8 +311,6 @@ def _grouped_plot_by_column(
311311
layout=layout,
312312
)
313313

314-
_axes = flatten_axes(axes)
315-
316314
# GH 45465: move the "by" label based on "vert"
317315
xlabel, ylabel = kwargs.pop("xlabel", None), kwargs.pop("ylabel", None)
318316
if kwargs.get("vert", True):
@@ -322,8 +320,7 @@ def _grouped_plot_by_column(
322320

323321
ax_values = []
324322

325-
for i, col in enumerate(columns):
326-
ax = _axes[i]
323+
for ax, col in zip(flatten_axes(axes), columns):
327324
gp_col = grouped[col]
328325
keys, values = zip(*gp_col)
329326
re_plotf = plotf(keys, values, ax, xlabel=xlabel, ylabel=ylabel, **kwargs)
@@ -531,10 +528,8 @@ def boxplot_frame_groupby(
531528
figsize=figsize,
532529
layout=layout,
533530
)
534-
axes = flatten_axes(axes)
535-
536531
data = {}
537-
for (key, group), ax in zip(grouped, axes):
532+
for (key, group), ax in zip(grouped, flatten_axes(axes)):
538533
d = group.boxplot(
539534
ax=ax, column=column, fontsize=fontsize, rot=rot, grid=grid, **kwds
540535
)

pandas/plotting/_matplotlib/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -586,7 +586,7 @@ def _axes_and_fig(self) -> tuple[Sequence[Axes], Figure]:
586586
fig.set_size_inches(self.figsize)
587587
axes = self.ax
588588

589-
axes = flatten_axes(axes)
589+
axes = np.fromiter(flatten_axes(axes), dtype=object)
590590

591591
if self.logx is True or self.loglog is True:
592592
[a.set_xscale("log") for a in axes]

pandas/plotting/_matplotlib/hist.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,12 @@ def _adjust_bins(self, bins: int | np.ndarray | list[np.ndarray]):
9595
def _calculate_bins(self, data: Series | DataFrame, bins) -> np.ndarray:
9696
"""Calculate bins given data"""
9797
nd_values = data.infer_objects()._get_numeric_data()
98-
values = np.ravel(nd_values)
98+
values = nd_values.values
99+
if nd_values.ndim == 2:
100+
values = values.reshape(-1)
99101
values = values[~isna(values)]
100102

101-
hist, bins = np.histogram(values, bins=bins, range=self._bin_range)
102-
return bins
103+
return np.histogram_bin_edges(values, bins=bins, range=self._bin_range)
103104

104105
# error: Signature of "_plot" incompatible with supertype "LinePlot"
105106
@classmethod
@@ -322,10 +323,7 @@ def _grouped_plot(
322323
naxes=naxes, figsize=figsize, sharex=sharex, sharey=sharey, ax=ax, layout=layout
323324
)
324325

325-
_axes = flatten_axes(axes)
326-
327-
for i, (key, group) in enumerate(grouped):
328-
ax = _axes[i]
326+
for ax, (key, group) in zip(flatten_axes(axes), grouped):
329327
if numeric_only and isinstance(group, ABCDataFrame):
330328
group = group._get_numeric_data()
331329
plotf(group, ax, **kwargs)
@@ -557,12 +555,9 @@ def hist_frame(
557555
figsize=figsize,
558556
layout=layout,
559557
)
560-
_axes = flatten_axes(axes)
561-
562558
can_set_label = "label" not in kwds
563559

564-
for i, col in enumerate(data.columns):
565-
ax = _axes[i]
560+
for ax, col in zip(flatten_axes(axes), data.columns):
566561
if legend and can_set_label:
567562
kwds["label"] = col
568563
ax.hist(data[col].dropna().values, bins=bins, **kwds)

pandas/plotting/_matplotlib/tools.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@
1818
)
1919

2020
if TYPE_CHECKING:
21-
from collections.abc import Iterable
21+
from collections.abc import (
22+
Generator,
23+
Iterable,
24+
)
2225

2326
from matplotlib.axes import Axes
2427
from matplotlib.axis import Axis
@@ -231,7 +234,7 @@ def create_subplots(
231234
else:
232235
if is_list_like(ax):
233236
if squeeze:
234-
ax = flatten_axes(ax)
237+
ax = np.fromiter(flatten_axes(ax), dtype=object)
235238
if layout is not None:
236239
warnings.warn(
237240
"When passing multiple axes, layout keyword is ignored.",
@@ -260,7 +263,7 @@ def create_subplots(
260263
if squeeze:
261264
return fig, ax
262265
else:
263-
return fig, flatten_axes(ax)
266+
return fig, np.fromiter(flatten_axes(ax), dtype=object)
264267
else:
265268
warnings.warn(
266269
"To output multiple subplots, the figure containing "
@@ -439,12 +442,13 @@ def handle_shared_axes(
439442
_remove_labels_from_axis(ax.yaxis)
440443

441444

442-
def flatten_axes(axes: Axes | Iterable[Axes]) -> np.ndarray:
445+
def flatten_axes(axes: Axes | Iterable[Axes]) -> Generator[Axes, None, None]:
443446
if not is_list_like(axes):
444-
return np.array([axes])
447+
yield axes # type: ignore[misc]
445448
elif isinstance(axes, (np.ndarray, ABCIndex)):
446-
return np.asarray(axes).ravel()
447-
return np.array(axes)
449+
yield from np.asarray(axes).reshape(-1)
450+
else:
451+
yield from axes # type: ignore[misc]
448452

449453

450454
def set_ticks_props(
@@ -456,13 +460,13 @@ def set_ticks_props(
456460
):
457461
for ax in flatten_axes(axes):
458462
if xlabelsize is not None:
459-
mpl.artist.setp(ax.get_xticklabels(), fontsize=xlabelsize)
463+
mpl.artist.setp(ax.get_xticklabels(), fontsize=xlabelsize) # type: ignore[arg-type]
460464
if xrot is not None:
461-
mpl.artist.setp(ax.get_xticklabels(), rotation=xrot)
465+
mpl.artist.setp(ax.get_xticklabels(), rotation=xrot) # type: ignore[arg-type]
462466
if ylabelsize is not None:
463-
mpl.artist.setp(ax.get_yticklabels(), fontsize=ylabelsize)
467+
mpl.artist.setp(ax.get_yticklabels(), fontsize=ylabelsize) # type: ignore[arg-type]
464468
if yrot is not None:
465-
mpl.artist.setp(ax.get_yticklabels(), rotation=yrot)
469+
mpl.artist.setp(ax.get_yticklabels(), rotation=yrot) # type: ignore[arg-type]
466470
return axes
467471

468472

0 commit comments

Comments
 (0)