Skip to content

Commit

Permalink
DEP: drop support for matplotlib<3.5
Browse files Browse the repository at this point in the history
  • Loading branch information
neutrinoceros committed Jun 19, 2023
1 parent 0a299d8 commit abe4d17
Show file tree
Hide file tree
Showing 12 changed files with 33 additions and 200 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ dependencies = [
"cmyt>=1.1.2",
"ewah-bool-utils>=1.0.2",
"ipywidgets>=8.0.0",
"matplotlib!=3.4.2,>=3.2", # keep in sync with tests/windows_conda_requirements.txt
"matplotlib>=3.5", # keep in sync with tests/windows_conda_requirements.txt
"more-itertools>=8.4",
"numpy>=1.17.5",
"packaging>=20.9",
Expand Down Expand Up @@ -208,7 +208,7 @@ minimal = [
"cmyt==1.1.2",
"ewah-bool-utils==1.0.2",
"ipywidgets==8.0.0",
"matplotlib==3.2",
"matplotlib==3.5",
"more-itertools==8.4",
"numpy==1.17.5",
"packaging==20.9",
Expand Down
2 changes: 1 addition & 1 deletion tests/windows_conda_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
numpy>=1.19.4
cartopy>=0.21.0
h5py>=3.1.0
matplotlib!=3.4.2,>=3.2 # keep in sync with pyproject.toml
matplotlib>=3.5 # keep in sync with pyproject.toml
scipy>=1.5.0
6 changes: 2 additions & 4 deletions yt/funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

import numpy as np
from more_itertools import always_iterable, collapse, first
from packaging.version import Version

from yt._maintenance.deprecation import issue_deprecation_warning
from yt.config import ytcfg
Expand Down Expand Up @@ -1002,11 +1001,10 @@ def matplotlib_style_context(style="yt.default", after_reset=False):
"""
# FUTURE: this function should be deprecated in favour of matplotlib.style.context
# after support for matplotlib 3.6 and older versions is dropped.
import matplotlib as mpl
import matplotlib.style

from yt.visualization._commons import MPL_VERSION

if style == "yt.default" and MPL_VERSION < Version("3.7"):
if style == "yt.default" and mpl.__version_info__ < (3, 7):
style = importlib_resources.files("yt") / "default.mplstyle"

return matplotlib.style.context(style, after_reset=after_reset)
Expand Down
81 changes: 5 additions & 76 deletions yt/visualization/_commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,11 @@
import sys
import warnings
from functools import wraps
from importlib.metadata import version
from typing import TYPE_CHECKING, Optional, Type, TypeVar

import matplotlib as mpl
import numpy as np
from matplotlib.ticker import SymmetricalLogLocator
from more_itertools import always_iterable
from packaging.version import Version

from yt.config import ytcfg

Expand All @@ -27,15 +24,13 @@
from matplotlib.backend_bases import FigureCanvasBase


MPL_VERSION = Version(version("matplotlib"))

_yt_style = mpl.rc_params_from_file(
importlib_resources.files("yt") / "default.mplstyle", use_default_template=False
)
DEFAULT_FONT_PROPERTIES = {"family": _yt_style["font.family"][0]}

if MPL_VERSION >= Version("3.4"):
DEFAULT_FONT_PROPERTIES["math_fontfamily"] = _yt_style["mathtext.fontset"]
DEFAULT_FONT_PROPERTIES = {
"family": _yt_style["font.family"][0],
"math_fontfamily": _yt_style["mathtext.fontset"],
}
del _yt_style


Expand Down Expand Up @@ -228,78 +223,12 @@ def _swap_arg_pair_order(*args):
return tuple(new_args)


def get_symlog_majorticks(linthresh: float, vmin: float, vmax: float) -> np.ndarray:
"""calculate positions of major ticks on a log colorbar
Parameters
----------
linthresh : float
the threshold for the linear region
vmin : float
the minimum value in the colorbar
vmax : float
the maximum value in the colorbar
"""
if MPL_VERSION >= Version("3.5"):
raise RuntimeError("get_symlog_majorticks is not needed with matplotlib>=3.5")

if vmin >= 0.0:
yticks = [vmin] + list(
10
** np.arange(
np.rint(np.log10(linthresh)),
np.ceil(np.log10(1.1 * vmax)),
)
)
elif vmax <= 0.0:
if MPL_VERSION >= Version("3.5.0b"):
offset = 0
else:
offset = 1

yticks = list(
-(
10
** np.arange(
np.floor(np.log10(-vmin)),
np.rint(np.log10(linthresh)) - offset,
-1,
)
)
) + [vmax]
else:
yticks = (
list(
-(
10
** np.arange(
np.floor(np.log10(-vmin)),
np.rint(np.log10(linthresh)) - 1,
-1,
)
)
)
+ [0]
+ list(
10
** np.arange(
np.rint(np.log10(linthresh)),
np.ceil(np.log10(1.1 * vmax)),
)
)
)
if yticks[-1] > vmax:
yticks.pop()
return np.array(yticks)


class _MPL38_SymmetricalLogLocator(SymmetricalLogLocator):
# Backporting behaviour from matplotlib 3.8 (in development at the time of writing)
# see https://github.com/matplotlib/matplotlib/pull/25970

def __init__(self, *args, **kwargs):
if MPL_VERSION >= Version("3.8"):
if mpl.__version_info__ >= (3, 8):
raise RuntimeError(
"_MPL38_SymmetricalLogLocator is not needed with matplotlib>=3.8"
)
Expand Down
6 changes: 3 additions & 3 deletions yt/visualization/_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from numbers import Real
from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union

import matplotlib as mpl
import numpy as np
import unyt as un
from matplotlib.colors import Colormap, LogNorm, Normalize, SymLogNorm
Expand All @@ -10,7 +11,6 @@
from yt._typing import Quantity, Unit
from yt.config import ytcfg
from yt.funcs import get_brewer_cmap, is_sequence, mylog
from yt.visualization.color_maps import _get_cmap


class NormHandler:
Expand Down Expand Up @@ -438,14 +438,14 @@ def draw_minorticks(self, newval) -> None:

@property
def cmap(self) -> Colormap:
return self._cmap or _get_cmap(ytcfg.get("yt", "default_colormap"))
return self._cmap or mpl.colormaps[ytcfg.get("yt", "default_colormap")]

@cmap.setter
def cmap(self, newval) -> None:
if isinstance(newval, Colormap) or newval is None:
self._cmap = newval
elif isinstance(newval, str):
self._cmap = _get_cmap(newval)
self._cmap = mpl.colormaps[newval]
elif is_sequence(newval):
# tuple colormaps are from palettable (or brewer2mpl)
self._cmap = get_brewer_cmap(newval)
Expand Down
59 changes: 3 additions & 56 deletions yt/visualization/base_plot_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,17 @@

import matplotlib
import numpy as np
from matplotlib.colors import LogNorm, Normalize, SymLogNorm
from matplotlib.ticker import LogFormatterMathtext
from packaging.version import Version

from yt.funcs import get_interactivity, is_sequence, matplotlib_style_context, mylog
from yt.visualization._handlers import ColorbarHandler, NormHandler

from ._commons import (
MPL_VERSION,
get_canvas,
get_symlog_majorticks,
validate_image_name,
)

if MPL_VERSION >= Version("3.8"):
if matplotlib.__version_info__ >= (3, 8):
from matplotlib.ticker import SymmetricalLogLocator
else:
from ._commons import _MPL38_SymmetricalLogLocator as SymmetricalLogLocator
Expand Down Expand Up @@ -164,8 +160,6 @@ def save(self, name, mpl_kwargs=None, canvas=None):

if mpl_kwargs is None:
mpl_kwargs = {}
if "papertype" not in mpl_kwargs and MPL_VERSION < Version("3.3.0"):
mpl_kwargs["papertype"] = "auto"

name = validate_image_name(name)

Expand Down Expand Up @@ -312,56 +306,9 @@ def _init_image(self, data, extent, aspect):
interpolation="nearest",
transform=transform,
)
self._set_axes(norm)
self._set_axes()

def _set_axes(self, norm: Normalize) -> None:
if MPL_VERSION >= Version("3.5"):
self._set_axes_mpl_ge35()
else:
self._set_axes_mpl_lt35(norm)

def _set_axes_mpl_lt35(self, norm: Normalize) -> None:
# bug-for-bug backward-compatibility for matplotlib older than 3.5
if isinstance(norm, SymLogNorm):
formatter = LogFormatterMathtext(linthresh=norm.linthresh)
self.cb = self.figure.colorbar(self.image, self.cax, format=formatter)
self.cb.set_ticks(
get_symlog_majorticks(
linthresh=norm.linthresh, vmin=norm.vmin, vmax=norm.vmax
)
)
else:
self.cb = self.figure.colorbar(self.image, self.cax)
self.cax.tick_params(which="both", axis="y", direction="in")

fmt_kwargs = {"style": "scientific", "scilimits": (-2, 3), "useMathText": True}
self.image.axes.ticklabel_format(**fmt_kwargs)
if type(norm) not in (LogNorm, SymLogNorm):
try:
self.cb.ax.ticklabel_format(**fmt_kwargs)
except AttributeError:
warnings.warn(
"Failed to format colorbar ticks. "
"This is expected when using the set_norm method "
"with some matplotlib classes (e.g. TwoSlopeNorm) "
"with matplotlib versions older than 3.5\n"
"Please try upgrading matplotlib to a more recent version. "
"If the problem persists, please file a report to "
"https://github.com/yt-project/yt/issues/new",
stacklevel=2,
)

if self.colorbar_handler.draw_minorticks:
if not isinstance(norm, SymLogNorm):
# no known working method to draw symlog minor ticks
# see https://github.com/yt-project/yt/issues/3535
self.cax.minorticks_on()
else:
self.cax.minorticks_off()

self.image.axes.set_facecolor(self.colorbar_handler.background_color)

def _set_axes_mpl_ge35(self) -> None:
def _set_axes(self) -> None:
fmt_kwargs = {"style": "scientific", "scilimits": (-2, 3), "useMathText": True}
self.image.axes.ticklabel_format(**fmt_kwargs)
self.image.axes.set_facecolor(self.colorbar_handler.background_color)
Expand Down
40 changes: 8 additions & 32 deletions yt/visualization/color_maps.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,14 @@
from copy import deepcopy
from typing import Optional, Tuple, Union
from typing import Tuple, Union

import cmyt # noqa: F401
import matplotlib as mpl
import numpy as np
from matplotlib.colors import Colormap, LinearSegmentedColormap
from packaging.version import Version
from matplotlib.colors import LinearSegmentedColormap

from yt.funcs import get_brewer_cmap
from yt.utilities.logger import ytLogger as mylog

from . import _colormap_data as _cm
from ._commons import MPL_VERSION


# wrap matplotlib.cm API, use non-deprecated API when available
def _get_cmap(name: str) -> Colormap:
if MPL_VERSION >= Version("3.5"):
return mpl.colormaps[name]
else:
# deprecated API
return mpl.cm.get_cmap(name)


def _register_cmap(cmap: Colormap, *, name: Optional[str] = None) -> None:
if MPL_VERSION >= Version("3.5"):
mpl.colormaps.register(cmap, name=name)
else:
# deprecated API
mpl.cm.register_cmap(name=name, cmap=cmap)


yt_colormaps = {}

Expand All @@ -40,7 +19,7 @@ def add_colormap(name, cdict):
"""
# Note: this function modifies the global variable 'yt_colormaps'
yt_colormaps[name] = LinearSegmentedColormap(name, cdict, 256)
_register_cmap(yt_colormaps[name], name=name)
mpl.colormaps.register(yt_colormaps[name], name=name)


# YTEP-0040 backward compatibility layer
Expand Down Expand Up @@ -69,14 +48,11 @@ def register_yt_colormaps_from_cmyt():
"""

for hist_name, alias in _HISTORICAL_ALIASES.items():
if MPL_VERSION >= Version("3.4.0"):
cmap = _get_cmap(alias).copy()
else:
cmap = deepcopy(_get_cmap(alias))
cmap = mpl.colormaps[alias].copy()
cmap.name = hist_name
try:
_register_cmap(cmap=cmap)
_register_cmap(cmap=_get_cmap(hist_name).reversed())
mpl.colormaps.register(cmap=cmap)
mpl.colormaps.register(cmap=mpl.colormaps[hist_name].reversed())
except ValueError:
# Matplotlib 3.4.0 hard-forbids name collisions, but more recent versions
# will emit a warning instead, so we emulate this behaviour regardless.
Expand Down Expand Up @@ -114,7 +90,7 @@ def get_colormap_lut(cmap_id: Union[Tuple[str, str], str]):
if isinstance(cmap_id, tuple) and len(cmap_id) == 2:
cmap = get_brewer_cmap(cmap_id)
elif isinstance(cmap_id, str):
cmap = _get_cmap(cmap_id)
cmap = mpl.colormaps[cmap_id]
else:
raise TypeError(
"Expected a string or a 2-tuple of strings as a colormap id. "
Expand Down Expand Up @@ -201,7 +177,7 @@ def show_colormaps(subset="all", filename=None):
for i, m in enumerate(maps):
plt.subplot(1, l, i + 1)
plt.axis("off")
plt.imshow(a, aspect="auto", cmap=_get_cmap(m), origin="lower")
plt.imshow(a, aspect="auto", cmap=mpl.colormaps[m], origin="lower")
plt.title(m, rotation=90, fontsize=10, verticalalignment="bottom")
if filename is not None:
plt.savefig(filename, dpi=100, facecolor="gray")
Expand Down
2 changes: 1 addition & 1 deletion yt/visualization/profile_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1555,6 +1555,6 @@ def _init_image(
shading=self._shading,
)

self._set_axes(norm)
self._set_axes()
self.axes.set_xscale(x_scale)
self.axes.set_yscale(y_scale)
Loading

0 comments on commit abe4d17

Please sign in to comment.