Skip to content

REF/CLN: Standardized matplotlib imports #58937

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pandas/plotting/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1598,7 +1598,7 @@ def area(

See Also
--------
DataFrame.plot : Make plots of DataFrame using matplotlib / pylab.
DataFrame.plot : Make plots of DataFrame using matplotlib.

Examples
--------
Expand Down
2 changes: 1 addition & 1 deletion pandas/plotting/_matplotlib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def plot(data, kind, **kwargs):
kwargs["ax"] = getattr(ax, "left_ax", ax)
plot_obj = PLOT_CLASSES[kind](data, **kwargs)
plot_obj.generate()
plot_obj.draw()
plt.draw_if_interactive()
return plot_obj.result


Expand Down
12 changes: 6 additions & 6 deletions pandas/plotting/_matplotlib/boxplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
)
import warnings

from matplotlib.artist import setp
import matplotlib as mpl
import numpy as np

from pandas._libs import lib
Expand Down Expand Up @@ -274,13 +274,13 @@ def maybe_color_bp(bp, color_tup, **kwds) -> None:
# GH#30346, when users specifying those arguments explicitly, our defaults
# for these four kwargs should be overridden; if not, use Pandas settings
if not kwds.get("boxprops"):
setp(bp["boxes"], color=color_tup[0], alpha=1)
mpl.artist.setp(bp["boxes"], color=color_tup[0], alpha=1)
if not kwds.get("whiskerprops"):
setp(bp["whiskers"], color=color_tup[1], alpha=1)
mpl.artist.setp(bp["whiskers"], color=color_tup[1], alpha=1)
if not kwds.get("medianprops"):
setp(bp["medians"], color=color_tup[2], alpha=1)
mpl.artist.setp(bp["medians"], color=color_tup[2], alpha=1)
if not kwds.get("capprops"):
setp(bp["caps"], color=color_tup[3], alpha=1)
mpl.artist.setp(bp["caps"], color=color_tup[3], alpha=1)


def _grouped_plot_by_column(
Expand Down Expand Up @@ -455,7 +455,7 @@ def plot_group(keys, values, ax: Axes, **kwds):

if ax is None:
rc = {"figure.figsize": figsize} if figsize is not None else {}
with plt.rc_context(rc):
with mpl.rc_context(rc):
ax = plt.gca()
data = data._get_numeric_data()
naxes = len(data.columns)
Expand Down
19 changes: 7 additions & 12 deletions pandas/plotting/_matplotlib/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,8 @@
)
import warnings

import matplotlib as mpl
import matplotlib.dates as mdates
from matplotlib.ticker import (
AutoLocator,
Formatter,
Locator,
)
from matplotlib.transforms import nonsingular
import matplotlib.units as munits
import numpy as np

Expand Down Expand Up @@ -174,7 +169,7 @@ def axisinfo(unit, axis) -> munits.AxisInfo | None:
if unit != "time":
return None

majloc = AutoLocator()
majloc = mpl.ticker.AutoLocator() # pyright: ignore[reportAttributeAccessIssue]
majfmt = TimeFormatter(majloc)
return munits.AxisInfo(majloc=majloc, majfmt=majfmt, label="time")

Expand All @@ -184,7 +179,7 @@ def default_units(x, axis) -> str:


# time formatter
class TimeFormatter(Formatter):
class TimeFormatter(mpl.ticker.Formatter): # pyright: ignore[reportAttributeAccessIssue]
def __init__(self, locs) -> None:
self.locs = locs

Expand Down Expand Up @@ -917,7 +912,7 @@ def get_finder(freq: BaseOffset):
raise NotImplementedError(f"Unsupported frequency: {dtype_code}")


class TimeSeries_DateLocator(Locator):
class TimeSeries_DateLocator(mpl.ticker.Locator): # pyright: ignore[reportAttributeAccessIssue]
"""
Locates the ticks along an axis controlled by a :class:`Series`.

Expand Down Expand Up @@ -998,15 +993,15 @@ def autoscale(self):
if vmin == vmax:
vmin -= 1
vmax += 1
return nonsingular(vmin, vmax)
return mpl.transforms.nonsingular(vmin, vmax)


# -------------------------------------------------------------------------
# --- Formatter ---
# -------------------------------------------------------------------------


class TimeSeries_DateFormatter(Formatter):
class TimeSeries_DateFormatter(mpl.ticker.Formatter): # pyright: ignore[reportAttributeAccessIssue]
"""
Formats the ticks along an axis controlled by a :class:`PeriodIndex`.

Expand Down Expand Up @@ -1082,7 +1077,7 @@ def __call__(self, x, pos: int | None = 0) -> str:
return period.strftime(fmt)


class TimeSeries_TimedeltaFormatter(Formatter):
class TimeSeries_TimedeltaFormatter(mpl.ticker.Formatter): # pyright: ignore[reportAttributeAccessIssue]
"""
Formats the ticks along an axis controlled by a :class:`TimedeltaIndex`.
"""
Expand Down
39 changes: 11 additions & 28 deletions pandas/plotting/_matplotlib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,7 @@ def _color_in_style(style: str) -> bool:
"""
Check if there is a color letter in the style string.
"""
from matplotlib.colors import BASE_COLORS

return not set(BASE_COLORS).isdisjoint(style)
return not set(mpl.colors.BASE_COLORS).isdisjoint(style)


class MPLPlot(ABC):
Expand Down Expand Up @@ -176,8 +174,6 @@ def __init__(
style=None,
**kwds,
) -> None:
import matplotlib.pyplot as plt

# if users assign an empty list or tuple, raise `ValueError`
# similar to current `df.box` and `df.hist` APIs.
if by in ([], ()):
Expand Down Expand Up @@ -238,7 +234,7 @@ def __init__(
self.rot = self._default_rot

if grid is None:
grid = False if secondary_y else plt.rcParams["axes.grid"]
grid = False if secondary_y else mpl.rcParams["axes.grid"]

self.grid = grid
self.legend = legend
Expand Down Expand Up @@ -497,10 +493,6 @@ def _get_nseries(self, data: Series | DataFrame) -> int:
def nseries(self) -> int:
return self._get_nseries(self.data)

@final
def draw(self) -> None:
self.plt.draw_if_interactive()

@final
def generate(self) -> None:
self._compute_plot_data()
Expand Down Expand Up @@ -570,6 +562,8 @@ def axes(self) -> Sequence[Axes]:
@final
@cache_readonly
def _axes_and_fig(self) -> tuple[Sequence[Axes], Figure]:
import matplotlib.pyplot as plt

if self.subplots:
naxes = (
self.nseries if isinstance(self.subplots, bool) else len(self.subplots)
Expand All @@ -584,7 +578,7 @@ def _axes_and_fig(self) -> tuple[Sequence[Axes], Figure]:
layout_type=self._layout_type,
)
elif self.ax is None:
fig = self.plt.figure(figsize=self.figsize)
fig = plt.figure(figsize=self.figsize)
axes = fig.add_subplot(111)
else:
fig = self.ax.get_figure()
Expand Down Expand Up @@ -918,13 +912,6 @@ def _get_ax_legend(ax: Axes):
ax = other_ax
return ax, leg

@final
@cache_readonly
def plt(self):
import matplotlib.pyplot as plt

return plt

_need_to_set_index = False

@final
Expand Down Expand Up @@ -1219,9 +1206,9 @@ def _get_errorbars(
@final
def _get_subplots(self, fig: Figure) -> list[Axes]:
if Version(mpl.__version__) < Version("3.8"):
from matplotlib.axes import Subplot as Klass
Klass = mpl.axes.Subplot
else:
from matplotlib.axes import Axes as Klass
Klass = mpl.axes.Axes

return [
ax
Expand Down Expand Up @@ -1386,7 +1373,7 @@ def _get_c_values(self, color, color_by_categorical: bool, c_is_column: bool):
if c is not None and color is not None:
raise TypeError("Specify exactly one of `c` and `color`")
if c is None and color is None:
c_values = self.plt.rcParams["patch.facecolor"]
c_values = mpl.rcParams["patch.facecolor"]
elif color is not None:
c_values = color
elif color_by_categorical:
Expand All @@ -1411,12 +1398,10 @@ def _get_norm_and_cmap(self, c_values, color_by_categorical: bool):
cmap = None

if color_by_categorical and cmap is not None:
from matplotlib import colors

n_cats = len(self.data[c].cat.categories)
cmap = colors.ListedColormap([cmap(i) for i in range(cmap.N)])
cmap = mpl.colors.ListedColormap([cmap(i) for i in range(cmap.N)])
bounds = np.linspace(0, n_cats, n_cats + 1)
norm = colors.BoundaryNorm(bounds, cmap.N)
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
# TODO: warn that we are ignoring self.norm if user specified it?
# Doesn't happen in any tests 2023-11-09
else:
Expand Down Expand Up @@ -1676,8 +1661,6 @@ def _update_stacker(cls, ax: Axes, stacking_id: int | None, values) -> None:
ax._stacker_neg_prior[stacking_id] += values # type: ignore[attr-defined]

def _post_plot_logic(self, ax: Axes, data) -> None:
from matplotlib.ticker import FixedLocator

def get_label(i):
if is_float(i) and i.is_integer():
i = int(i)
Expand All @@ -1691,7 +1674,7 @@ def get_label(i):
xticklabels = [get_label(x) for x in xticks]
# error: Argument 1 to "FixedLocator" has incompatible type "ndarray[Any,
# Any]"; expected "Sequence[float]"
ax.xaxis.set_major_locator(FixedLocator(xticks)) # type: ignore[arg-type]
ax.xaxis.set_major_locator(mpl.ticker.FixedLocator(xticks)) # type: ignore[arg-type]
ax.set_xticklabels(xticklabels)

# If the index is an irregular time series, then by default
Expand Down
9 changes: 4 additions & 5 deletions pandas/plotting/_matplotlib/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
import random
from typing import TYPE_CHECKING

from matplotlib import patches
import matplotlib.lines as mlines
import matplotlib as mpl
import numpy as np

from pandas.core.dtypes.missing import notna
Expand Down Expand Up @@ -129,7 +128,7 @@ def scatter_matrix(


def _get_marker_compat(marker):
if marker not in mlines.lineMarkers:
if marker not in mpl.lines.lineMarkers:
return "o"
return marker

Expand Down Expand Up @@ -190,10 +189,10 @@ def normalize(series):
)
ax.legend()

ax.add_patch(patches.Circle((0.0, 0.0), radius=1.0, facecolor="none"))
ax.add_patch(mpl.patches.Circle((0.0, 0.0), radius=1.0, facecolor="none"))

for xy, name in zip(s, df.columns):
ax.add_patch(patches.Circle(xy, radius=0.025, facecolor="gray"))
ax.add_patch(mpl.patches.Circle(xy, radius=0.025, facecolor="gray"))

if xy[0] < 0.0 and xy[1] < 0.0:
ax.text(
Expand Down
4 changes: 1 addition & 3 deletions pandas/plotting/_matplotlib/style.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,7 @@ def _get_colors_from_color_type(color_type: str, num_colors: int) -> list[Color]

def _get_default_colors(num_colors: int) -> list[Color]:
"""Get `num_colors` of default colors from matplotlib rc params."""
import matplotlib.pyplot as plt

colors = [c["color"] for c in plt.rcParams["axes.prop_cycle"]]
colors = [c["color"] for c in mpl.rcParams["axes.prop_cycle"]]
return colors[0:num_colors]


Expand Down
4 changes: 2 additions & 2 deletions pandas/plotting/_matplotlib/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def format_dateaxis(
default, changing the limits of the x axis will intelligently change
the positions of the ticks.
"""
from matplotlib import pylab
import matplotlib.pyplot as plt

# handle index specific formatting
# Note: DatetimeIndex does not use this
Expand Down Expand Up @@ -365,4 +365,4 @@ def format_dateaxis(
else:
raise TypeError("index type not supported")

pylab.draw_if_interactive()
plt.draw_if_interactive()
23 changes: 10 additions & 13 deletions pandas/plotting/_matplotlib/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
from typing import TYPE_CHECKING
import warnings

from matplotlib import ticker
import matplotlib.table
import matplotlib as mpl
import numpy as np

from pandas.util._exceptions import find_stack_level
Expand Down Expand Up @@ -77,7 +76,7 @@ def table(

# error: Argument "cellText" to "table" has incompatible type "ndarray[Any,
# Any]"; expected "Sequence[Sequence[str]] | None"
return matplotlib.table.table(
return mpl.table.table(
ax,
cellText=cellText, # type: ignore[arg-type]
rowLabels=rowLabels,
Expand Down Expand Up @@ -327,10 +326,10 @@ def _remove_labels_from_axis(axis: Axis) -> None:

# set_visible will not be effective if
# minor axis has NullLocator and NullFormatter (default)
if isinstance(axis.get_minor_locator(), ticker.NullLocator):
axis.set_minor_locator(ticker.AutoLocator())
if isinstance(axis.get_minor_formatter(), ticker.NullFormatter):
axis.set_minor_formatter(ticker.FormatStrFormatter(""))
if isinstance(axis.get_minor_locator(), mpl.ticker.NullLocator):
axis.set_minor_locator(mpl.ticker.AutoLocator())
if isinstance(axis.get_minor_formatter(), mpl.ticker.NullFormatter):
axis.set_minor_formatter(mpl.ticker.FormatStrFormatter(""))
for t in axis.get_minorticklabels():
t.set_visible(False)

Expand Down Expand Up @@ -455,17 +454,15 @@ def set_ticks_props(
ylabelsize: int | None = None,
yrot=None,
):
import matplotlib.pyplot as plt

for ax in flatten_axes(axes):
if xlabelsize is not None:
plt.setp(ax.get_xticklabels(), fontsize=xlabelsize)
mpl.artist.setp(ax.get_xticklabels(), fontsize=xlabelsize)
if xrot is not None:
plt.setp(ax.get_xticklabels(), rotation=xrot)
mpl.artist.setp(ax.get_xticklabels(), rotation=xrot)
if ylabelsize is not None:
plt.setp(ax.get_yticklabels(), fontsize=ylabelsize)
mpl.artist.setp(ax.get_yticklabels(), fontsize=ylabelsize)
if yrot is not None:
plt.setp(ax.get_yticklabels(), rotation=yrot)
mpl.artist.setp(ax.get_yticklabels(), rotation=yrot)
return axes


Expand Down
4 changes: 1 addition & 3 deletions pandas/tests/io/formats/style/test_matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,9 @@
Series,
)

pytest.importorskip("matplotlib")
mpl = pytest.importorskip("matplotlib")
pytest.importorskip("jinja2")

import matplotlib as mpl

from pandas.io.formats.style import Styler

pytestmark = pytest.mark.usefixtures("mpl_cleanup")
Expand Down
Loading