diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5b78f956..563b30a5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,10 +1,10 @@ ci: - skip: [flake8] + skip: [flake8, taplo-lint] repos: # Formatters - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v5.0.0 hooks: - id: check-ast - id: check-builtin-literals @@ -16,14 +16,14 @@ repos: args: [--remove] - repo: https://github.com/PyCQA/docformatter - rev: v1.7.5 + rev: 06907d0 hooks: - id: docformatter additional_dependencies: [tomli] args: [--in-place, --config, ./pyproject.toml] - repo: https://github.com/tox-dev/pyproject-fmt - rev: "2.2.4" + rev: "v2.5.0" hooks: - id: pyproject-fmt @@ -36,17 +36,18 @@ repos: rev: v4.0.0-alpha.8 # Use the sha or tag you want to point at hooks: - id: prettier + additional_dependencies: ["prettier@3.3.3"] # Notebook tools - repo: https://github.com/kynan/nbstripout - rev: 0.7.1 + rev: 0.8.1 hooks: - id: nbstripout args: [--drop-empty-cells] - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.6.7 + rev: v0.8.3 hooks: - id: ruff name: "ruff sort imports notebooks" @@ -72,7 +73,7 @@ repos: # Linters - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.11.2 + rev: v1.13.0 hooks: - id: mypy exclude: ^docs @@ -88,7 +89,7 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.6.7 + rev: v0.8.3 hooks: - id: ruff name: "ruff sort imports" @@ -127,6 +128,12 @@ repos: args: ["--ignore-words-list=doas"] - repo: https://github.com/rhysd/actionlint - rev: "v1.7.2" + rev: "v1.7.4" hooks: - id: actionlint + + - repo: https://github.com/ComPWA/taplo-pre-commit + rev: v0.9.3 + hooks: + - id: taplo-format + - id: taplo-lint diff --git a/.ruff.toml b/.ruff.toml index 2da1d052..92552c6c 100644 --- a/.ruff.toml +++ b/.ruff.toml @@ -36,7 +36,7 @@ select = [ "RSE", # flake8-raise "RET", # flake8-return "SIM", # flake8-simplify - "TCH", # flake8-type-checking + "TC", # flake8-type-checking "ARG", # flake8-unused-arguments "PTH", # flake8-use-pathlib "ERA", # eradicate diff --git a/pyglotaran_extras/__init__.py b/pyglotaran_extras/__init__.py index 7f7f4c7b..336656d4 100644 --- a/pyglotaran_extras/__init__.py +++ b/pyglotaran_extras/__init__.py @@ -36,38 +36,37 @@ from pyglotaran_extras.plotting.utils import add_subplot_labels __all__ = [ + "CONFIG", + "PerFunctionPlotConfig", + "add_subplot_labels", + "create_config_schema", "load_data", - "setup_case_study", "plot_coherent_artifact", "plot_concentrations", + "plot_config_context", + "plot_das", "plot_data_overview", "plot_doas", + "plot_fitted_traces", "plot_guidance", "plot_irf_dispersion_center", - "plot_overview", - "plot_simple_overview", - "plot_residual", - "plot_das", - "plot_norm_das", - "plot_norm_sas", - "plot_sas", - "plot_spectra", "plot_lsv_data", "plot_lsv_residual", + "plot_norm_das", + "plot_norm_sas", + "plot_overview", + "plot_residual", "plot_rsv_data", "plot_rsv_residual", + "plot_sas", + "plot_simple_overview", + "plot_spectra", "plot_sv_data", "plot_sv_residual", "plot_svd", - "plot_fitted_traces", "select_plot_wavelengths", - "add_subplot_labels", - # Config - "PerFunctionPlotConfig", - "plot_config_context", + "setup_case_study", "use_plot_config", - "create_config_schema", - "CONFIG", ] __version__ = "0.7.4.dev0" diff --git a/pyglotaran_extras/io/__init__.py b/pyglotaran_extras/io/__init__.py index 251ae00f..dbb4030c 100644 --- a/pyglotaran_extras/io/__init__.py +++ b/pyglotaran_extras/io/__init__.py @@ -5,4 +5,4 @@ from pyglotaran_extras.io.load_data import load_data from pyglotaran_extras.io.setup_case_study import setup_case_study -__all__ = ["setup_case_study", "load_data"] +__all__ = ["load_data", "setup_case_study"] diff --git a/pyglotaran_extras/plotting/plot_concentrations.py b/pyglotaran_extras/plotting/plot_concentrations.py index dfa2816e..07e5316a 100644 --- a/pyglotaran_extras/plotting/plot_concentrations.py +++ b/pyglotaran_extras/plotting/plot_concentrations.py @@ -13,13 +13,13 @@ if TYPE_CHECKING: import xarray as xr from cycler import Cycler - from matplotlib.axis import Axis + from matplotlib.axes import Axes @use_plot_config(exclude_from_config=("cycler",)) def plot_concentrations( res: xr.Dataset, - ax: Axis, + ax: Axes, center_λ: float | None, linlog: bool = False, linthresh: float = 1, @@ -34,8 +34,8 @@ def plot_concentrations( ---------- res : xr.Dataset Result dataset from a pyglotaran optimization. - ax : Axis - Axis to plot the traces on + ax : Axes + Axes to plot the traces on center_λ : float | None Center wavelength (λ in nm) linlog : bool diff --git a/pyglotaran_extras/plotting/plot_data.py b/pyglotaran_extras/plotting/plot_data.py index 40072f1c..408c388b 100644 --- a/pyglotaran_extras/plotting/plot_data.py +++ b/pyglotaran_extras/plotting/plot_data.py @@ -3,11 +3,11 @@ from __future__ import annotations from typing import TYPE_CHECKING -from typing import cast import matplotlib.pyplot as plt +import numpy as np from glotaran.io.prepare_dataset import add_svd_to_dataset -from matplotlib.axis import Axis +from matplotlib.axes import Axes from pyglotaran_extras.config.plot_config import use_plot_config from pyglotaran_extras.io.load_data import load_data @@ -27,8 +27,8 @@ import xarray as xr from cycler import Cycler from glotaran.project.result import Result + from matplotlib.axes import Axes from matplotlib.figure import Figure - from matplotlib.pyplot import Axes from pyglotaran_extras.types import DatasetConvertible @@ -48,7 +48,7 @@ def plot_data_overview( vmax: float | None = None, svd_cycler: Cycler | None = PlotStyle().cycler, use_svd_number: bool = False, -) -> tuple[Figure, Axes] | tuple[Figure, Axis]: +) -> tuple[Figure, np.ndarray[Axes]] | tuple[Figure, Axes]: """Plot data as filled contour plot and SVD components. Parameters @@ -86,7 +86,7 @@ def plot_data_overview( Returns ------- - tuple[Figure, Axes] | tuple[Figure, Axis] + tuple[Figure, np.ndarray[Axes]] | tuple[Figure, Axes] Figure and axes which can then be refined by the user. """ dataset = load_data(dataset, _stacklevel=3) @@ -103,11 +103,11 @@ def plot_data_overview( ) fig = plt.figure(figsize=figsize) - data_ax = cast(Axis, plt.subplot2grid((4, 3), (0, 0), colspan=3, rowspan=3, fig=fig)) + data_ax = plt.subplot2grid((4, 3), (0, 0), colspan=3, rowspan=3, fig=fig) fig.subplots_adjust(hspace=0.5, wspace=0.25) - lsv_ax = cast(Axis, plt.subplot2grid((4, 3), (3, 0), fig=fig)) - sv_ax = cast(Axis, plt.subplot2grid((4, 3), (3, 1), fig=fig)) - rsv_ax = cast(Axis, plt.subplot2grid((4, 3), (3, 2), fig=fig)) + lsv_ax = plt.subplot2grid((4, 3), (3, 0), fig=fig) + sv_ax = plt.subplot2grid((4, 3), (3, 1), fig=fig) + rsv_ax = plt.subplot2grid((4, 3), (3, 2), fig=fig) if len(data.time) > 1: data.plot(x="time", ax=data_ax, center=False, cmap=cmap, vmin=vmin, vmax=vmax) @@ -146,7 +146,7 @@ def plot_data_overview( if linlog: data_ax.set_xscale("symlog", linthresh=linthresh) data_ax.xaxis.set_minor_locator(MinorSymLogLocator(linthresh)) - return fig, (data_ax, lsv_ax, sv_ax, rsv_ax) + return fig, np.array((data_ax, lsv_ax, sv_ax, rsv_ax)) def _plot_single_trace( @@ -157,7 +157,7 @@ def _plot_single_trace( linlog: bool = False, linthresh: float = 1, figsize: tuple[float, float] = (15, 10), -) -> tuple[Figure, Axis]: +) -> tuple[Figure, Axes]: """Plot single trace data in case ``plot_data_overview`` gets passed ingle trace data. Parameters @@ -178,7 +178,7 @@ def _plot_single_trace( Returns ------- - tuple[Figure, Axis] + tuple[Figure, Axes] Figure and axis which can then be refined by the user. """ fig, ax = plt.subplots(1, 1, figsize=figsize) diff --git a/pyglotaran_extras/plotting/plot_irf_dispersion_center.py b/pyglotaran_extras/plotting/plot_irf_dispersion_center.py index 7f015e0d..dc74d290 100644 --- a/pyglotaran_extras/plotting/plot_irf_dispersion_center.py +++ b/pyglotaran_extras/plotting/plot_irf_dispersion_center.py @@ -7,8 +7,6 @@ import matplotlib.pyplot as plt import xarray as xr -from matplotlib.axis import Axis -from matplotlib.figure import Figure from pyglotaran_extras.config.plot_config import use_plot_config from pyglotaran_extras.io.utils import result_dataset_mapping @@ -20,6 +18,8 @@ from typing import Literal from cycler import Cycler + from matplotlib.axes import Axes + from matplotlib.figure import Figure from pyglotaran_extras.types import ResultLike @@ -27,19 +27,19 @@ @use_plot_config(exclude_from_config=("cycler", "ax")) def plot_irf_dispersion_center( result: ResultLike, - ax: Axis | None = None, + ax: Axes | None = None, figsize: tuple[float, float] = (12, 8), cycler: Cycler | None = PlotStyle().cycler, irf_location: float | None = None, -) -> tuple[Figure, Axis] | None: +) -> tuple[Figure, Axes] | None: """Plot the IRF dispersion center over the spectral dimension for one or multiple datasets. Parameters ---------- result : ResultLike Data structure which can be converted to a mapping. - ax : Axis | None - Axis to plot on. Defaults to None which means that a new figure and axis will be created. + ax : Axes | None + Axes to plot on. Defaults to None which means that a new figure and axis will be created. figsize : tuple[float, float] Size of the figure (N, M) in inches. Defaults to (12, 8). cycler : Cycler | None @@ -50,42 +50,40 @@ def plot_irf_dispersion_center( Returns ------- - tuple[Figure, Axis] | None - Figure object which contains the plots and the Axis, + tuple[Figure, Axes] | None + Figure object which contains the plots and the Axes, if ``ax`` is not None nothing will be returned. """ result_map = result_dataset_mapping(result) if ax is None: - fig, axis = cast(tuple[Figure, Axis], plt.subplots(1, figsize=figsize)) - else: - axis = ax + fig, ax = plt.subplots(1, figsize=figsize) for dataset_name, dataset in result_map.items(): _plot_irf_dispersion_center( dataset, - axis, + ax, spectral_axis="x", cycler=cycler, label=dataset_name, irf_location=irf_location, ) - axis.legend() + ax.legend() if ax is None: fig.suptitle("Instrument Response Functions", fontsize=16) - return fig, axis + return fig, ax return None def _plot_irf_dispersion_center( res: xr.Dataset, - ax: Axis, + ax: Axes, *, spectral_axis: Literal["x", "y"] = "x", cycler: Cycler | None = PlotStyle().cycler, label: str = "IRF", irf_location: float | None = None, ) -> None: - """Plot the IRF dispersion center on an Axis ``ax``. + """Plot the IRF dispersion center on an Axes ``ax``. This is an internal function to be used by higher level functions. @@ -93,8 +91,8 @@ def _plot_irf_dispersion_center( ---------- res : xr.Dataset Dataset containing the IRF data. - ax : Axis - Axis to plot on. + ax : Axes + Axes to plot on. spectral_axis : Literal["x", "y"] Direct of the spectral axis in the plot. Defaults to "x" cycler : Cycler | None diff --git a/pyglotaran_extras/plotting/plot_residual.py b/pyglotaran_extras/plotting/plot_residual.py index afed8cf3..7763af35 100644 --- a/pyglotaran_extras/plotting/plot_residual.py +++ b/pyglotaran_extras/plotting/plot_residual.py @@ -16,13 +16,13 @@ if TYPE_CHECKING: import xarray as xr from cycler import Cycler - from matplotlib.axis import Axis + from matplotlib.axes import Axes @use_plot_config(exclude_from_config=("cycler",)) def plot_residual( res: xr.Dataset, - ax: Axis, + ax: Axes, linlog: bool = False, linthresh: float = 1, show_data: bool | None = False, @@ -36,8 +36,8 @@ def plot_residual( ---------- res : xr.Dataset Result dataset - ax : Axis - Axis to plot on. + ax : Axes + Axes to plot on. linlog : bool Whether to use 'symlog' scale or not. Defaults to False. linthresh : float diff --git a/pyglotaran_extras/plotting/plot_spectra.py b/pyglotaran_extras/plotting/plot_spectra.py index b55b43ac..e249bd27 100644 --- a/pyglotaran_extras/plotting/plot_spectra.py +++ b/pyglotaran_extras/plotting/plot_spectra.py @@ -14,8 +14,7 @@ if TYPE_CHECKING: import xarray as xr from cycler import Cycler - from matplotlib.axis import Axis - from matplotlib.pyplot import Axes + from matplotlib.axes import Axes from pyglotaran_extras.types import UnsetType @@ -23,7 +22,7 @@ @use_plot_config(exclude_from_config=("cycler", "das_cycler")) def plot_spectra( res: xr.Dataset, - axes: Axes, + axes: np.ndarray[(2, 2), Axes], cycler: Cycler | None = PlotStyle().cycler, show_zero_line: bool = True, das_cycler: Cycler | None | UnsetType = Unset, @@ -34,7 +33,7 @@ def plot_spectra( ---------- res : xr.Dataset Result dataset - axes : Axes + axes : np.ndarray[(2, 2), Axes] Axes to plot the spectra on (needs to be at least 2x2). cycler : Cycler | None Plot style cycler to use. Defaults to PlotStyle().cycler. @@ -56,7 +55,7 @@ def plot_spectra( @use_plot_config(exclude_from_config=("cycler",)) def plot_sas( res: xr.Dataset, - ax: Axis, + ax: Axes, title: str = "SAS", cycler: Cycler | None = PlotStyle().cycler, show_zero_line: bool = True, @@ -67,8 +66,8 @@ def plot_sas( ---------- res : xr.Dataset Result dataset - ax : Axis - Axis to plot on. + ax : Axes + Axes to plot on. title : str Title of the plot. Defaults to "SAS". cycler : Cycler | None @@ -92,7 +91,7 @@ def plot_sas( @use_plot_config(exclude_from_config=("cycler",)) def plot_norm_sas( res: xr.Dataset, - ax: Axis, + ax: Axes, title: str = "norm SAS", cycler: Cycler | None = PlotStyle().cycler, show_zero_line: bool = True, @@ -103,8 +102,8 @@ def plot_norm_sas( ---------- res : xr.Dataset Result dataset - ax : Axis - Axis to plot on. + ax : Axes + Axes to plot on. title : str Title of the plot. Defaults to "norm SAS". cycler : Cycler | None @@ -130,7 +129,7 @@ def plot_norm_sas( @use_plot_config(exclude_from_config=("cycler",)) def plot_das( res: xr.Dataset, - ax: Axis, + ax: Axes, title: str = "DAS", cycler: Cycler | None = PlotStyle().cycler, show_zero_line: bool = True, @@ -141,8 +140,8 @@ def plot_das( ---------- res : xr.Dataset Result dataset - ax : Axis - Axis to plot on. + ax : Axes + Axes to plot on. title : str Title of the plot. Defaults to "DAS". cycler : Cycler | None @@ -166,7 +165,7 @@ def plot_das( @use_plot_config(exclude_from_config=("cycler",)) def plot_norm_das( res: xr.Dataset, - ax: Axis, + ax: Axes, title: str = "norm DAS", cycler: Cycler | None = PlotStyle().cycler, show_zero_line: bool = True, @@ -177,8 +176,8 @@ def plot_norm_das( ---------- res : xr.Dataset Result dataset - ax : Axis - Axis to plot on. + ax : Axes + Axes to plot on. title : str Title of the plot. Defaults to "norm DAS". cycler : Cycler | None diff --git a/pyglotaran_extras/plotting/plot_svd.py b/pyglotaran_extras/plotting/plot_svd.py index ada8a441..42f54584 100644 --- a/pyglotaran_extras/plotting/plot_svd.py +++ b/pyglotaran_extras/plotting/plot_svd.py @@ -18,16 +18,16 @@ if TYPE_CHECKING: from collections.abc import Sequence + import numpy as np import xarray as xr from cycler import Cycler - from matplotlib.axis import Axis - from matplotlib.pyplot import Axes + from matplotlib.axes import Axes @use_plot_config(exclude_from_config=("cycler",)) def plot_svd( res: xr.Dataset, - axes: Axes, + axes: np.ndarray[(2, 3), Axes], linlog: bool = False, linthresh: float = 1, cycler: Cycler | None = PlotStyle().cycler, @@ -44,7 +44,7 @@ def plot_svd( ---------- res : xr.Dataset Result dataset - axes : Axes + axes : np.ndarray[(2, 3), Axes] Axes to plot the SVDs on (needs to be at least 2x3). linlog : bool Whether to use 'symlog' scale or not. Defaults to False. @@ -120,7 +120,7 @@ def plot_svd( @use_plot_config(exclude_from_config=("cycler",)) def plot_lsv_data( res: xr.Dataset, - ax: Axis, + ax: Axes, indices: Sequence[int] = tuple(range(4)), linlog: bool = False, linthresh: float = 1, @@ -135,8 +135,8 @@ def plot_lsv_data( ---------- res : xr.Dataset Result dataset - ax : Axis - Axis to plot on. + ax : Axes + Axes to plot on. indices : Sequence[int] Indices of the singular vector to plot. Defaults to tuple(range(4)). linlog : bool @@ -175,7 +175,7 @@ def plot_lsv_data( @use_plot_config(exclude_from_config=("cycler",)) def plot_rsv_data( res: xr.Dataset, - ax: Axis, + ax: Axes, indices: Sequence[int] = tuple(range(4)), cycler: Cycler | None = PlotStyle().cycler, show_legend: bool = True, @@ -188,8 +188,8 @@ def plot_rsv_data( ---------- res : xr.Dataset Result dataset - ax : Axis - Axis to plot on. + ax : Axes + Axes to plot on. indices : Sequence[int] Indices of the singular vector to plot. Defaults to tuple(range(4)). cycler : Cycler | None @@ -220,7 +220,7 @@ def plot_rsv_data( @use_plot_config(exclude_from_config=("cycler",)) def plot_sv_data( res: xr.Dataset, - ax: Axis, + ax: Axes, indices: Sequence[int] = tuple(range(10)), cycler: Cycler | None | UnsetType = Unset, use_svd_number: bool = False, @@ -231,8 +231,8 @@ def plot_sv_data( ---------- res : xr.Dataset Result dataset - ax : Axis - Axis to plot on. + ax : Axes + Axes to plot on. indices : Sequence[int] Indices of the singular vector to plot. Defaults to tuple(range(10)). cycler : Cycler | None | UnsetType @@ -244,7 +244,7 @@ def plot_sv_data( if cycler is not Unset: warn_deprecated( deprecated_qual_name_usage="'cycler' argument in 'plot_sv_data'", - new_qual_name_usage="matplotlib on the axis directly", + new_qual_name_usage="matplotlib on the Axes directly", to_be_removed_in_version="0.9.0", ) dSV = res.data_singular_values # noqa: N806 @@ -263,7 +263,7 @@ def plot_sv_data( @use_plot_config(exclude_from_config=("cycler",)) def plot_lsv_residual( res: xr.Dataset, - ax: Axis, + ax: Axes, indices: Sequence[int] = tuple(range(2)), linlog: bool = False, linthresh: float = 1, @@ -278,8 +278,8 @@ def plot_lsv_residual( ---------- res : xr.Dataset Result dataset - ax : Axis - Axis to plot on. + ax : Axes + Axes to plot on. indices : Sequence[int] Indices of the singular vector to plot. Defaults to tuple(range(4)). linlog : bool @@ -323,7 +323,7 @@ def plot_lsv_residual( @use_plot_config(exclude_from_config=("cycler",)) def plot_rsv_residual( res: xr.Dataset, - ax: Axis, + ax: Axes, indices: Sequence[int] = tuple(range(2)), cycler: Cycler | None = PlotStyle().cycler, show_legend: bool = True, @@ -336,8 +336,8 @@ def plot_rsv_residual( ---------- res : xr.Dataset Result dataset - ax : Axis - Axis to plot on. + ax : Axes + Axes to plot on. indices : Sequence[int] Indices of the singular vector to plot. Defaults to tuple(range(4)). cycler : Cycler | None @@ -372,7 +372,7 @@ def plot_rsv_residual( @use_plot_config(exclude_from_config=("cycler",)) def plot_sv_residual( res: xr.Dataset, - ax: Axis, + ax: Axes, indices: Sequence[int] = tuple(range(10)), cycler: Cycler | None | UnsetType = Unset, use_svd_number: bool = False, @@ -383,8 +383,8 @@ def plot_sv_residual( ---------- res : xr.Dataset Result dataset - ax : Axis - Axis to plot on. + ax : Axes + Axes to plot on. indices : Sequence[int] Indices of the singular vector to plot. Defaults to tuple(range(10)). cycler : Cycler | None | UnsetType @@ -420,7 +420,7 @@ def _plot_svd_vectors( vector_data: xr.DataArray, indices: Sequence[int], sv_index_dim: str, - ax: Axis, + ax: Axes, show_legend: bool, irf_location: float | None, use_svd_number: bool = False, @@ -435,8 +435,8 @@ def _plot_svd_vectors( Indices of the singular vector to plot. sv_index_dim : str Name of the singular value index dimension. - ax : Axis - Axis to plot on. + ax : Axes + Axes to plot on. show_legend : bool Whether or not to show the legend. irf_location : float | None diff --git a/pyglotaran_extras/plotting/plot_traces.py b/pyglotaran_extras/plotting/plot_traces.py index bb215c1e..421b4035 100644 --- a/pyglotaran_extras/plotting/plot_traces.py +++ b/pyglotaran_extras/plotting/plot_traces.py @@ -3,11 +3,13 @@ from __future__ import annotations from typing import TYPE_CHECKING +from typing import Any from warnings import warn import matplotlib.pyplot as plt from pyglotaran_extras.config.plot_config import use_plot_config +from pyglotaran_extras.deprecation import warn_deprecated from pyglotaran_extras.io.utils import result_dataset_mapping from pyglotaran_extras.plotting.style import PlotStyle from pyglotaran_extras.plotting.utils import MinorSymLogLocator @@ -18,25 +20,27 @@ from pyglotaran_extras.plotting.utils import extract_irf_location from pyglotaran_extras.plotting.utils import get_next_cycler_color from pyglotaran_extras.plotting.utils import select_plot_wavelengths +from pyglotaran_extras.types import Unset +from pyglotaran_extras.types import UnsetType -__all__ = ["select_plot_wavelengths", "plot_fitted_traces"] +__all__ = ["plot_fitted_traces", "select_plot_wavelengths"] if TYPE_CHECKING: from collections.abc import Iterable + import numpy as np from cycler import Cycler - from matplotlib.axis import Axis + from matplotlib.axes import Axes from matplotlib.figure import Figure - from matplotlib.pyplot import Axes from pyglotaran_extras.types import ResultLike -@use_plot_config(exclude_from_config=("cycler",)) +@use_plot_config(exclude_from_config=("cycler", "ax", "axis")) def plot_data_and_fits( result: ResultLike, wavelength: float, - axis: Axis, + ax: Axes | UnsetType = Unset, center_λ: float | None = None, main_irf_nr: int = 0, linlog: bool = False, @@ -46,8 +50,9 @@ def plot_data_and_fits( y_label: str = "a.u.", cycler: Cycler | None = PlotStyle().data_cycler_solid, show_zero_line: bool = True, + axis: UnsetType = Unset, ) -> None: - """Plot data and fits for a given ``wavelength`` on a given ``axis``. + """Plot data and fits for a given ``wavelength`` on a given ``ax``. If the wavelength isn't part of a dataset, that dataset will be skipped. @@ -57,8 +62,8 @@ def plot_data_and_fits( Data structure which can be converted to a mapping. wavelength : float Wavelength to plot data and fits for. - axis : Axis - Axis to plot the data and fits on. + ax : Axes | UnsetType + Axes to plot the data and fits on. Defaults to Unset. center_λ : float | None Center wavelength (λ in nm) main_irf_nr : int @@ -80,13 +85,31 @@ def plot_data_and_fits( Plot style cycler to use. Defaults to PlotStyle().data_cycler_solid. show_zero_line : bool Whether or not to add a horizontal line at zero. Defaults to True. + axis : UnsetType + Deprecated use ``ax`` instead. Defaults to Unset. See Also -------- plot_fit_overview + + Raises + ------ + ValueError + If ``ax`` was not provided, ``ax`` should be a required argument but to facilitate the + deprecation ``axis`` -> ``ax`` it has a default of ``Unset``. """ + if isinstance(ax, UnsetType) and not isinstance(axis, UnsetType): + warn_deprecated( + deprecated_qual_name_usage="axis", + new_qual_name_usage="ax", + to_be_removed_in_version="0.9.0", + ) + ax = axis + if isinstance(ax, UnsetType): + msg = "Required argument ``ax`` wasn't set." + raise ValueError(msg) result_map = result_dataset_mapping(result) - add_cycler_if_not_none(axis, cycler) + add_cycler_if_not_none(ax, cycler) for dataset_name in result_map: if result_map[dataset_name].coords["time"].to_numpy().size == 1: continue @@ -96,18 +119,18 @@ def plot_data_and_fits( scale = extract_dataset_scale(result_data, divide_by_scale) irf_loc = extract_irf_location(result_data, center_λ, main_irf_nr) result_data = result_data.assign_coords(time=result_data.coords["time"] - irf_loc) - (result_data.data / scale).plot(x="time", ax=axis, label=f"{dataset_name}_data") - (result_data.fitted_data / scale).plot(x="time", ax=axis, label=f"{dataset_name}_fit") + (result_data.data / scale).plot(x="time", ax=ax, label=f"{dataset_name}_data") + (result_data.fitted_data / scale).plot(x="time", ax=ax, label=f"{dataset_name}_fit") else: - [get_next_cycler_color(axis) for _ in range(2)] + [get_next_cycler_color(ax) for _ in range(2)] if linlog: - axis.set_xscale("symlog", linthresh=linthresh) - axis.xaxis.set_minor_locator(MinorSymLogLocator(linthresh)) + ax.set_xscale("symlog", linthresh=linthresh) + ax.xaxis.set_minor_locator(MinorSymLogLocator(linthresh)) if show_zero_line is True: - axis.axhline(0, color="k", linewidth=1) - axis.set_ylabel(y_label) + ax.axhline(0, color="k", linewidth=1) + ax.set_ylabel(y_label) if per_axis_legend is True: - axis.legend() + ax.legend() @use_plot_config(exclude_from_config=("cycler",)) @@ -126,7 +149,7 @@ def plot_fitted_traces( y_label: str = "a.u.", cycler: Cycler | None = PlotStyle().data_cycler_solid, show_zero_line: bool = True, -) -> tuple[Figure, Axes]: +) -> tuple[Figure, np.ndarray[Any, Axes]]: """Plot data and their fit in per wavelength plot grid. Parameters @@ -167,7 +190,7 @@ def plot_fitted_traces( Returns ------- - tuple[Figure, Axes] + tuple[Figure, np.ndarray[Any, Axes]] Figure and axes which can then be refined by the user. See Also @@ -192,11 +215,11 @@ def plot_fitted_traces( ), stacklevel=2, ) - for wavelength, axis in zip(wavelengths, axes.flatten(), strict=True): + for wavelength, ax in zip(wavelengths, axes.flatten(), strict=True): plot_data_and_fits( result=result_map, wavelength=wavelength, - axis=axis, + ax=ax, center_λ=center_λ, main_irf_nr=main_irf_nr, linlog=linlog, diff --git a/pyglotaran_extras/plotting/utils.py b/pyglotaran_extras/plotting/utils.py index d222e1a7..5f3dd4c0 100644 --- a/pyglotaran_extras/plotting/utils.py +++ b/pyglotaran_extras/plotting/utils.py @@ -14,20 +14,23 @@ import xarray as xr from matplotlib.ticker import Locator +from pyglotaran_extras.deprecation import warn_deprecated from pyglotaran_extras.inspect.utils import pretty_format_numerical_iterable from pyglotaran_extras.io.utils import result_dataset_mapping +from pyglotaran_extras.types import Unset +from pyglotaran_extras.types import UnsetType if TYPE_CHECKING: from collections.abc import Callable from collections.abc import Hashable from collections.abc import Mapping + from collections.abc import Sequence + from typing import Any from typing import Literal - from collecations.abs import Sequence from cycler import Cycler - from matplotlib.axis import Axis + from matplotlib.axes import Axes from matplotlib.figure import Figure - from matplotlib.pyplot import Axes from pyglotaran_extras.types import BuiltinSubPlotLabelFormatFunctionKey from pyglotaran_extras.types import CyclerColor @@ -187,7 +190,7 @@ def maximum_coordinate_range( return min(minima), max(maxima) -def add_unique_figure_legend(fig: Figure, axes: Axes) -> None: +def add_unique_figure_legend(fig: Figure, axes: Axes | np.ndarray[Any, Axes]) -> None: """Add a legend with unique elements sorted by label to a figure. The handles and labels are extracted from the ``axes`` @@ -196,7 +199,7 @@ def add_unique_figure_legend(fig: Figure, axes: Axes) -> None: ---------- fig : Figure Figure to add the legend to. - axes : Axes + axes : Axes | np.ndarray[Any, Axes] Axes plotted on the figure. See Also @@ -205,7 +208,7 @@ def add_unique_figure_legend(fig: Figure, axes: Axes) -> None: """ handles = [] labels = [] - for ax in axes.flatten(): + for ax in ensure_axes_array(axes).flatten(): ax_handles, ax_labels = ax.get_legend_handles_labels() handles += ax_handles labels += ax_labels @@ -409,27 +412,44 @@ def get_shifted_traces( return shift_time_axis_by_irf_location(traces, irf_location) -def ensure_axes_array(axes: Axis | Axes) -> Axes: +def ensure_axes_array( + axes: Axes | np.ndarray[Any, Axes] | UnsetType = Unset, axis: UnsetType = Unset +) -> np.ndarray[Any, Axes]: """Ensure that axes have flatten method even if it is a single axis. Parameters ---------- - axes : Axis | Axes - Axis or Axes to convert for API consistency. + axes : Axes | np.ndarray[Any, Axes] | UnsetType + Axes or array of Axes to convert for API consistency. + axis : UnsetType + Deprecated use ``axes`` instead. Defaults to Unset. Returns ------- - Axes + np.ndarray[Any, Axes] Numpy ndarray of axes. + + Raises + ------ + ValueError + If ``axes`` was not provided, ``ax`` should be a required argument but to facilitate the + deprecation ``axis`` -> ``axes`` it has a default of ``Unset``. """ - # We can't use `Axis` in isinstance so we check for the np.ndarray attribute of `Axes` - if hasattr(axes, "flatten") is False: - axes = np.array([axes]) - return axes + if isinstance(axes, UnsetType) and not isinstance(axis, UnsetType): + warn_deprecated( + deprecated_qual_name_usage="axis", + new_qual_name_usage="axes", + to_be_removed_in_version="0.9.0", + ) + axes = axis + if isinstance(axes, UnsetType): + msg = "Required argument ``axes`` wasn't set." + raise ValueError(msg) + return np.array([axes]) if isinstance(axes, np.ndarray) is False else axes -def add_cycler_if_not_none(axis: Axis | Axes, cycler: Cycler | None) -> None: - """Add cycler to and axis if it is not None. +def add_cycler_if_not_none(axes: Axes | np.ndarray[Any, Axes], cycler: Cycler | None) -> None: + """Add cycler to ``Axes`` if it is not None. This is a convenience function that allow to opt out of using a cycler, which is needed to run a plotting function in a loop @@ -438,14 +458,13 @@ def add_cycler_if_not_none(axis: Axis | Axes, cycler: Cycler | None) -> None: Parameters ---------- - axis : Axis | Axes - Axis to plot on. + axes : Axes | np.ndarray[Any, Axes] + Axes to add ``cycler`` to. cycler : Cycler | None Plot style cycler to use. """ if cycler is not None: - axis = ensure_axes_array(axis) - for ax in axis.flatten(): + for ax in ensure_axes_array(axes).flatten(): ax.set_prop_cycle(cycler) @@ -736,7 +755,7 @@ def get_subplot_label_format_function( def add_subplot_labels( - axes: Axis | Axes, + axes: Axes | np.ndarray[Any, Axes], *, label_position: tuple[float, float] = (-0.05, 1.05), label_coords: SubPlotLabelCoord = "axes fraction", @@ -750,7 +769,7 @@ def add_subplot_labels( Parameters ---------- - axes : Axis | Axes + axes : Axes | np.ndarray[Any, Axes] Axes (subplots) on which the labels should be added. label_position : tuple[float, float] Position of the label in ``label_coords`` coordinates. diff --git a/pyproject.toml b/pyproject.toml index 407fbf01..64e89a01 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Topic :: Scientific/Engineering", "Topic :: Scientific/Engineering :: Chemistry", "Topic :: Scientific/Engineering :: Physics",