Skip to content

Commit a30bb6f

Browse files
authored
REF/CLN: Standardized matplotlib imports (#58937)
* Use standard import matplotlib as mpl * Standardaize more matplotlib imports * Fix matplotlib units * Reduce diff a little more * Import matplotlib dates * satisfy pyright
1 parent c95716a commit a30bb6f

File tree

14 files changed

+73
-125
lines changed

14 files changed

+73
-125
lines changed

pandas/plotting/_core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1598,7 +1598,7 @@ def area(
15981598
15991599
See Also
16001600
--------
1601-
DataFrame.plot : Make plots of DataFrame using matplotlib / pylab.
1601+
DataFrame.plot : Make plots of DataFrame using matplotlib.
16021602
16031603
Examples
16041604
--------

pandas/plotting/_matplotlib/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def plot(data, kind, **kwargs):
6969
kwargs["ax"] = getattr(ax, "left_ax", ax)
7070
plot_obj = PLOT_CLASSES[kind](data, **kwargs)
7171
plot_obj.generate()
72-
plot_obj.draw()
72+
plt.draw_if_interactive()
7373
return plot_obj.result
7474

7575

pandas/plotting/_matplotlib/boxplot.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
)
88
import warnings
99

10-
from matplotlib.artist import setp
10+
import matplotlib as mpl
1111
import numpy as np
1212

1313
from pandas._libs import lib
@@ -274,13 +274,13 @@ def maybe_color_bp(bp, color_tup, **kwds) -> None:
274274
# GH#30346, when users specifying those arguments explicitly, our defaults
275275
# for these four kwargs should be overridden; if not, use Pandas settings
276276
if not kwds.get("boxprops"):
277-
setp(bp["boxes"], color=color_tup[0], alpha=1)
277+
mpl.artist.setp(bp["boxes"], color=color_tup[0], alpha=1)
278278
if not kwds.get("whiskerprops"):
279-
setp(bp["whiskers"], color=color_tup[1], alpha=1)
279+
mpl.artist.setp(bp["whiskers"], color=color_tup[1], alpha=1)
280280
if not kwds.get("medianprops"):
281-
setp(bp["medians"], color=color_tup[2], alpha=1)
281+
mpl.artist.setp(bp["medians"], color=color_tup[2], alpha=1)
282282
if not kwds.get("capprops"):
283-
setp(bp["caps"], color=color_tup[3], alpha=1)
283+
mpl.artist.setp(bp["caps"], color=color_tup[3], alpha=1)
284284

285285

286286
def _grouped_plot_by_column(
@@ -455,7 +455,7 @@ def plot_group(keys, values, ax: Axes, **kwds):
455455

456456
if ax is None:
457457
rc = {"figure.figsize": figsize} if figsize is not None else {}
458-
with plt.rc_context(rc):
458+
with mpl.rc_context(rc):
459459
ax = plt.gca()
460460
data = data._get_numeric_data()
461461
naxes = len(data.columns)

pandas/plotting/_matplotlib/converter.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,8 @@
1414
)
1515
import warnings
1616

17+
import matplotlib as mpl
1718
import matplotlib.dates as mdates
18-
from matplotlib.ticker import (
19-
AutoLocator,
20-
Formatter,
21-
Locator,
22-
)
23-
from matplotlib.transforms import nonsingular
2419
import matplotlib.units as munits
2520
import numpy as np
2621

@@ -174,7 +169,7 @@ def axisinfo(unit, axis) -> munits.AxisInfo | None:
174169
if unit != "time":
175170
return None
176171

177-
majloc = AutoLocator()
172+
majloc = mpl.ticker.AutoLocator() # pyright: ignore[reportAttributeAccessIssue]
178173
majfmt = TimeFormatter(majloc)
179174
return munits.AxisInfo(majloc=majloc, majfmt=majfmt, label="time")
180175

@@ -184,7 +179,7 @@ def default_units(x, axis) -> str:
184179

185180

186181
# time formatter
187-
class TimeFormatter(Formatter):
182+
class TimeFormatter(mpl.ticker.Formatter): # pyright: ignore[reportAttributeAccessIssue]
188183
def __init__(self, locs) -> None:
189184
self.locs = locs
190185

@@ -917,7 +912,7 @@ def get_finder(freq: BaseOffset):
917912
raise NotImplementedError(f"Unsupported frequency: {dtype_code}")
918913

919914

920-
class TimeSeries_DateLocator(Locator):
915+
class TimeSeries_DateLocator(mpl.ticker.Locator): # pyright: ignore[reportAttributeAccessIssue]
921916
"""
922917
Locates the ticks along an axis controlled by a :class:`Series`.
923918
@@ -998,15 +993,15 @@ def autoscale(self):
998993
if vmin == vmax:
999994
vmin -= 1
1000995
vmax += 1
1001-
return nonsingular(vmin, vmax)
996+
return mpl.transforms.nonsingular(vmin, vmax)
1002997

1003998

1004999
# -------------------------------------------------------------------------
10051000
# --- Formatter ---
10061001
# -------------------------------------------------------------------------
10071002

10081003

1009-
class TimeSeries_DateFormatter(Formatter):
1004+
class TimeSeries_DateFormatter(mpl.ticker.Formatter): # pyright: ignore[reportAttributeAccessIssue]
10101005
"""
10111006
Formats the ticks along an axis controlled by a :class:`PeriodIndex`.
10121007
@@ -1082,7 +1077,7 @@ def __call__(self, x, pos: int | None = 0) -> str:
10821077
return period.strftime(fmt)
10831078

10841079

1085-
class TimeSeries_TimedeltaFormatter(Formatter):
1080+
class TimeSeries_TimedeltaFormatter(mpl.ticker.Formatter): # pyright: ignore[reportAttributeAccessIssue]
10861081
"""
10871082
Formats the ticks along an axis controlled by a :class:`TimedeltaIndex`.
10881083
"""

pandas/plotting/_matplotlib/core.py

Lines changed: 11 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,7 @@ def _color_in_style(style: str) -> bool:
107107
"""
108108
Check if there is a color letter in the style string.
109109
"""
110-
from matplotlib.colors import BASE_COLORS
111-
112-
return not set(BASE_COLORS).isdisjoint(style)
110+
return not set(mpl.colors.BASE_COLORS).isdisjoint(style)
113111

114112

115113
class MPLPlot(ABC):
@@ -176,8 +174,6 @@ def __init__(
176174
style=None,
177175
**kwds,
178176
) -> None:
179-
import matplotlib.pyplot as plt
180-
181177
# if users assign an empty list or tuple, raise `ValueError`
182178
# similar to current `df.box` and `df.hist` APIs.
183179
if by in ([], ()):
@@ -238,7 +234,7 @@ def __init__(
238234
self.rot = self._default_rot
239235

240236
if grid is None:
241-
grid = False if secondary_y else plt.rcParams["axes.grid"]
237+
grid = False if secondary_y else mpl.rcParams["axes.grid"]
242238

243239
self.grid = grid
244240
self.legend = legend
@@ -497,10 +493,6 @@ def _get_nseries(self, data: Series | DataFrame) -> int:
497493
def nseries(self) -> int:
498494
return self._get_nseries(self.data)
499495

500-
@final
501-
def draw(self) -> None:
502-
self.plt.draw_if_interactive()
503-
504496
@final
505497
def generate(self) -> None:
506498
self._compute_plot_data()
@@ -570,6 +562,8 @@ def axes(self) -> Sequence[Axes]:
570562
@final
571563
@cache_readonly
572564
def _axes_and_fig(self) -> tuple[Sequence[Axes], Figure]:
565+
import matplotlib.pyplot as plt
566+
573567
if self.subplots:
574568
naxes = (
575569
self.nseries if isinstance(self.subplots, bool) else len(self.subplots)
@@ -584,7 +578,7 @@ def _axes_and_fig(self) -> tuple[Sequence[Axes], Figure]:
584578
layout_type=self._layout_type,
585579
)
586580
elif self.ax is None:
587-
fig = self.plt.figure(figsize=self.figsize)
581+
fig = plt.figure(figsize=self.figsize)
588582
axes = fig.add_subplot(111)
589583
else:
590584
fig = self.ax.get_figure()
@@ -918,13 +912,6 @@ def _get_ax_legend(ax: Axes):
918912
ax = other_ax
919913
return ax, leg
920914

921-
@final
922-
@cache_readonly
923-
def plt(self):
924-
import matplotlib.pyplot as plt
925-
926-
return plt
927-
928915
_need_to_set_index = False
929916

930917
@final
@@ -1219,9 +1206,9 @@ def _get_errorbars(
12191206
@final
12201207
def _get_subplots(self, fig: Figure) -> list[Axes]:
12211208
if Version(mpl.__version__) < Version("3.8"):
1222-
from matplotlib.axes import Subplot as Klass
1209+
Klass = mpl.axes.Subplot
12231210
else:
1224-
from matplotlib.axes import Axes as Klass
1211+
Klass = mpl.axes.Axes
12251212

12261213
return [
12271214
ax
@@ -1386,7 +1373,7 @@ def _get_c_values(self, color, color_by_categorical: bool, c_is_column: bool):
13861373
if c is not None and color is not None:
13871374
raise TypeError("Specify exactly one of `c` and `color`")
13881375
if c is None and color is None:
1389-
c_values = self.plt.rcParams["patch.facecolor"]
1376+
c_values = mpl.rcParams["patch.facecolor"]
13901377
elif color is not None:
13911378
c_values = color
13921379
elif color_by_categorical:
@@ -1411,12 +1398,10 @@ def _get_norm_and_cmap(self, c_values, color_by_categorical: bool):
14111398
cmap = None
14121399

14131400
if color_by_categorical and cmap is not None:
1414-
from matplotlib import colors
1415-
14161401
n_cats = len(self.data[c].cat.categories)
1417-
cmap = colors.ListedColormap([cmap(i) for i in range(cmap.N)])
1402+
cmap = mpl.colors.ListedColormap([cmap(i) for i in range(cmap.N)])
14181403
bounds = np.linspace(0, n_cats, n_cats + 1)
1419-
norm = colors.BoundaryNorm(bounds, cmap.N)
1404+
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
14201405
# TODO: warn that we are ignoring self.norm if user specified it?
14211406
# Doesn't happen in any tests 2023-11-09
14221407
else:
@@ -1676,8 +1661,6 @@ def _update_stacker(cls, ax: Axes, stacking_id: int | None, values) -> None:
16761661
ax._stacker_neg_prior[stacking_id] += values # type: ignore[attr-defined]
16771662

16781663
def _post_plot_logic(self, ax: Axes, data) -> None:
1679-
from matplotlib.ticker import FixedLocator
1680-
16811664
def get_label(i):
16821665
if is_float(i) and i.is_integer():
16831666
i = int(i)
@@ -1691,7 +1674,7 @@ def get_label(i):
16911674
xticklabels = [get_label(x) for x in xticks]
16921675
# error: Argument 1 to "FixedLocator" has incompatible type "ndarray[Any,
16931676
# Any]"; expected "Sequence[float]"
1694-
ax.xaxis.set_major_locator(FixedLocator(xticks)) # type: ignore[arg-type]
1677+
ax.xaxis.set_major_locator(mpl.ticker.FixedLocator(xticks)) # type: ignore[arg-type]
16951678
ax.set_xticklabels(xticklabels)
16961679

16971680
# If the index is an irregular time series, then by default

pandas/plotting/_matplotlib/misc.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
import random
44
from typing import TYPE_CHECKING
55

6-
from matplotlib import patches
7-
import matplotlib.lines as mlines
6+
import matplotlib as mpl
87
import numpy as np
98

109
from pandas.core.dtypes.missing import notna
@@ -129,7 +128,7 @@ def scatter_matrix(
129128

130129

131130
def _get_marker_compat(marker):
132-
if marker not in mlines.lineMarkers:
131+
if marker not in mpl.lines.lineMarkers:
133132
return "o"
134133
return marker
135134

@@ -190,10 +189,10 @@ def normalize(series):
190189
)
191190
ax.legend()
192191

193-
ax.add_patch(patches.Circle((0.0, 0.0), radius=1.0, facecolor="none"))
192+
ax.add_patch(mpl.patches.Circle((0.0, 0.0), radius=1.0, facecolor="none"))
194193

195194
for xy, name in zip(s, df.columns):
196-
ax.add_patch(patches.Circle(xy, radius=0.025, facecolor="gray"))
195+
ax.add_patch(mpl.patches.Circle(xy, radius=0.025, facecolor="gray"))
197196

198197
if xy[0] < 0.0 and xy[1] < 0.0:
199198
ax.text(

pandas/plotting/_matplotlib/style.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -260,9 +260,7 @@ def _get_colors_from_color_type(color_type: str, num_colors: int) -> list[Color]
260260

261261
def _get_default_colors(num_colors: int) -> list[Color]:
262262
"""Get `num_colors` of default colors from matplotlib rc params."""
263-
import matplotlib.pyplot as plt
264-
265-
colors = [c["color"] for c in plt.rcParams["axes.prop_cycle"]]
263+
colors = [c["color"] for c in mpl.rcParams["axes.prop_cycle"]]
266264
return colors[0:num_colors]
267265

268266

pandas/plotting/_matplotlib/timeseries.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ def format_dateaxis(
333333
default, changing the limits of the x axis will intelligently change
334334
the positions of the ticks.
335335
"""
336-
from matplotlib import pylab
336+
import matplotlib.pyplot as plt
337337

338338
# handle index specific formatting
339339
# Note: DatetimeIndex does not use this
@@ -365,4 +365,4 @@ def format_dateaxis(
365365
else:
366366
raise TypeError("index type not supported")
367367

368-
pylab.draw_if_interactive()
368+
plt.draw_if_interactive()

pandas/plotting/_matplotlib/tools.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55
from typing import TYPE_CHECKING
66
import warnings
77

8-
from matplotlib import ticker
9-
import matplotlib.table
8+
import matplotlib as mpl
109
import numpy as np
1110

1211
from pandas.util._exceptions import find_stack_level
@@ -77,7 +76,7 @@ def table(
7776

7877
# error: Argument "cellText" to "table" has incompatible type "ndarray[Any,
7978
# Any]"; expected "Sequence[Sequence[str]] | None"
80-
return matplotlib.table.table(
79+
return mpl.table.table(
8180
ax,
8281
cellText=cellText, # type: ignore[arg-type]
8382
rowLabels=rowLabels,
@@ -327,10 +326,10 @@ def _remove_labels_from_axis(axis: Axis) -> None:
327326

328327
# set_visible will not be effective if
329328
# minor axis has NullLocator and NullFormatter (default)
330-
if isinstance(axis.get_minor_locator(), ticker.NullLocator):
331-
axis.set_minor_locator(ticker.AutoLocator())
332-
if isinstance(axis.get_minor_formatter(), ticker.NullFormatter):
333-
axis.set_minor_formatter(ticker.FormatStrFormatter(""))
329+
if isinstance(axis.get_minor_locator(), mpl.ticker.NullLocator):
330+
axis.set_minor_locator(mpl.ticker.AutoLocator())
331+
if isinstance(axis.get_minor_formatter(), mpl.ticker.NullFormatter):
332+
axis.set_minor_formatter(mpl.ticker.FormatStrFormatter(""))
334333
for t in axis.get_minorticklabels():
335334
t.set_visible(False)
336335

@@ -455,17 +454,15 @@ def set_ticks_props(
455454
ylabelsize: int | None = None,
456455
yrot=None,
457456
):
458-
import matplotlib.pyplot as plt
459-
460457
for ax in flatten_axes(axes):
461458
if xlabelsize is not None:
462-
plt.setp(ax.get_xticklabels(), fontsize=xlabelsize)
459+
mpl.artist.setp(ax.get_xticklabels(), fontsize=xlabelsize)
463460
if xrot is not None:
464-
plt.setp(ax.get_xticklabels(), rotation=xrot)
461+
mpl.artist.setp(ax.get_xticklabels(), rotation=xrot)
465462
if ylabelsize is not None:
466-
plt.setp(ax.get_yticklabels(), fontsize=ylabelsize)
463+
mpl.artist.setp(ax.get_yticklabels(), fontsize=ylabelsize)
467464
if yrot is not None:
468-
plt.setp(ax.get_yticklabels(), rotation=yrot)
465+
mpl.artist.setp(ax.get_yticklabels(), rotation=yrot)
469466
return axes
470467

471468

pandas/tests/io/formats/style/test_matplotlib.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,9 @@
77
Series,
88
)
99

10-
pytest.importorskip("matplotlib")
10+
mpl = pytest.importorskip("matplotlib")
1111
pytest.importorskip("jinja2")
1212

13-
import matplotlib as mpl
14-
1513
from pandas.io.formats.style import Styler
1614

1715
pytestmark = pytest.mark.usefixtures("mpl_cleanup")

0 commit comments

Comments
 (0)