Skip to content

Commit

Permalink
⌨️🩹 Fix wrong typing using Axis instead of Axes (#314)
Browse files Browse the repository at this point in the history
* 🧪⌨️ Fix typing of Axes as Axis and np.ndarray[..., Axes] as Axes

* 🧰🩹 Change docformatter rev to a version compatible with pre-commit 4

🧰⬆️ Update pre-commit config
🧰✨ Add taplo to pre-commit toll for toml formatting and linting

* 🚇🩹 Don't run taplo lint in pre-commit CI

Pre-commit-CI doesn't support reachin out to get schema and causes a DNS error
See CI run:
https://results.pre-commit.ci/run/github/299106891/1734109072.qfn_wrsQS6mmHWJ4LP5PzQ
  • Loading branch information
s-weigand authored Dec 14, 2024
1 parent 29b37dc commit abaff2e
Show file tree
Hide file tree
Showing 13 changed files with 196 additions and 150 deletions.
25 changes: 16 additions & 9 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
ci:
skip: [flake8]
skip: [flake8, taplo-lint]

repos:
# Formatters
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
rev: v5.0.0
hooks:
- id: check-ast
- id: check-builtin-literals
Expand All @@ -16,14 +16,14 @@ repos:
args: [--remove]

- repo: https://github.com/PyCQA/docformatter
rev: v1.7.5
rev: 06907d0
hooks:
- id: docformatter
additional_dependencies: [tomli]
args: [--in-place, --config, ./pyproject.toml]

- repo: https://github.com/tox-dev/pyproject-fmt
rev: "2.2.4"
rev: "v2.5.0"
hooks:
- id: pyproject-fmt

Expand All @@ -36,17 +36,18 @@ repos:
rev: v4.0.0-alpha.8 # Use the sha or tag you want to point at
hooks:
- id: prettier
additional_dependencies: ["[email protected]"]

# Notebook tools
- repo: https://github.com/kynan/nbstripout
rev: 0.7.1
rev: 0.8.1
hooks:
- id: nbstripout
args: [--drop-empty-cells]

- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.6.7
rev: v0.8.3
hooks:
- id: ruff
name: "ruff sort imports notebooks"
Expand All @@ -72,7 +73,7 @@ repos:
# Linters

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.11.2
rev: v1.13.0
hooks:
- id: mypy
exclude: ^docs
Expand All @@ -88,7 +89,7 @@ repos:

- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.6.7
rev: v0.8.3
hooks:
- id: ruff
name: "ruff sort imports"
Expand Down Expand Up @@ -127,6 +128,12 @@ repos:
args: ["--ignore-words-list=doas"]

- repo: https://github.com/rhysd/actionlint
rev: "v1.7.2"
rev: "v1.7.4"
hooks:
- id: actionlint

- repo: https://github.com/ComPWA/taplo-pre-commit
rev: v0.9.3
hooks:
- id: taplo-format
- id: taplo-lint
2 changes: 1 addition & 1 deletion .ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ select = [
"RSE", # flake8-raise
"RET", # flake8-return
"SIM", # flake8-simplify
"TCH", # flake8-type-checking
"TC", # flake8-type-checking
"ARG", # flake8-unused-arguments
"PTH", # flake8-use-pathlib
"ERA", # eradicate
Expand Down
31 changes: 15 additions & 16 deletions pyglotaran_extras/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,38 +36,37 @@
from pyglotaran_extras.plotting.utils import add_subplot_labels

__all__ = [
"CONFIG",
"PerFunctionPlotConfig",
"add_subplot_labels",
"create_config_schema",
"load_data",
"setup_case_study",
"plot_coherent_artifact",
"plot_concentrations",
"plot_config_context",
"plot_das",
"plot_data_overview",
"plot_doas",
"plot_fitted_traces",
"plot_guidance",
"plot_irf_dispersion_center",
"plot_overview",
"plot_simple_overview",
"plot_residual",
"plot_das",
"plot_norm_das",
"plot_norm_sas",
"plot_sas",
"plot_spectra",
"plot_lsv_data",
"plot_lsv_residual",
"plot_norm_das",
"plot_norm_sas",
"plot_overview",
"plot_residual",
"plot_rsv_data",
"plot_rsv_residual",
"plot_sas",
"plot_simple_overview",
"plot_spectra",
"plot_sv_data",
"plot_sv_residual",
"plot_svd",
"plot_fitted_traces",
"select_plot_wavelengths",
"add_subplot_labels",
# Config
"PerFunctionPlotConfig",
"plot_config_context",
"setup_case_study",
"use_plot_config",
"create_config_schema",
"CONFIG",
]

__version__ = "0.7.4.dev0"
Expand Down
2 changes: 1 addition & 1 deletion pyglotaran_extras/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
from pyglotaran_extras.io.load_data import load_data
from pyglotaran_extras.io.setup_case_study import setup_case_study

__all__ = ["setup_case_study", "load_data"]
__all__ = ["load_data", "setup_case_study"]
8 changes: 4 additions & 4 deletions pyglotaran_extras/plotting/plot_concentrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
if TYPE_CHECKING:
import xarray as xr
from cycler import Cycler
from matplotlib.axis import Axis
from matplotlib.axes import Axes


@use_plot_config(exclude_from_config=("cycler",))
def plot_concentrations(
res: xr.Dataset,
ax: Axis,
ax: Axes,
center_λ: float | None,
linlog: bool = False,
linthresh: float = 1,
Expand All @@ -34,8 +34,8 @@ def plot_concentrations(
----------
res : xr.Dataset
Result dataset from a pyglotaran optimization.
ax : Axis
Axis to plot the traces on
ax : Axes
Axes to plot the traces on
center_λ : float | None
Center wavelength (λ in nm)
linlog : bool
Expand Down
24 changes: 12 additions & 12 deletions pyglotaran_extras/plotting/plot_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import cast

import matplotlib.pyplot as plt
import numpy as np
from glotaran.io.prepare_dataset import add_svd_to_dataset
from matplotlib.axis import Axis
from matplotlib.axes import Axes

from pyglotaran_extras.config.plot_config import use_plot_config
from pyglotaran_extras.io.load_data import load_data
Expand All @@ -27,8 +27,8 @@
import xarray as xr
from cycler import Cycler
from glotaran.project.result import Result
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.pyplot import Axes

from pyglotaran_extras.types import DatasetConvertible

Expand All @@ -48,7 +48,7 @@ def plot_data_overview(
vmax: float | None = None,
svd_cycler: Cycler | None = PlotStyle().cycler,
use_svd_number: bool = False,
) -> tuple[Figure, Axes] | tuple[Figure, Axis]:
) -> tuple[Figure, np.ndarray[Axes]] | tuple[Figure, Axes]:
"""Plot data as filled contour plot and SVD components.
Parameters
Expand Down Expand Up @@ -86,7 +86,7 @@ def plot_data_overview(
Returns
-------
tuple[Figure, Axes] | tuple[Figure, Axis]
tuple[Figure, np.ndarray[Axes]] | tuple[Figure, Axes]
Figure and axes which can then be refined by the user.
"""
dataset = load_data(dataset, _stacklevel=3)
Expand All @@ -103,11 +103,11 @@ def plot_data_overview(
)

fig = plt.figure(figsize=figsize)
data_ax = cast(Axis, plt.subplot2grid((4, 3), (0, 0), colspan=3, rowspan=3, fig=fig))
data_ax = plt.subplot2grid((4, 3), (0, 0), colspan=3, rowspan=3, fig=fig)
fig.subplots_adjust(hspace=0.5, wspace=0.25)
lsv_ax = cast(Axis, plt.subplot2grid((4, 3), (3, 0), fig=fig))
sv_ax = cast(Axis, plt.subplot2grid((4, 3), (3, 1), fig=fig))
rsv_ax = cast(Axis, plt.subplot2grid((4, 3), (3, 2), fig=fig))
lsv_ax = plt.subplot2grid((4, 3), (3, 0), fig=fig)
sv_ax = plt.subplot2grid((4, 3), (3, 1), fig=fig)
rsv_ax = plt.subplot2grid((4, 3), (3, 2), fig=fig)

if len(data.time) > 1:
data.plot(x="time", ax=data_ax, center=False, cmap=cmap, vmin=vmin, vmax=vmax)
Expand Down Expand Up @@ -146,7 +146,7 @@ def plot_data_overview(
if linlog:
data_ax.set_xscale("symlog", linthresh=linthresh)
data_ax.xaxis.set_minor_locator(MinorSymLogLocator(linthresh))
return fig, (data_ax, lsv_ax, sv_ax, rsv_ax)
return fig, np.array((data_ax, lsv_ax, sv_ax, rsv_ax))


def _plot_single_trace(
Expand All @@ -157,7 +157,7 @@ def _plot_single_trace(
linlog: bool = False,
linthresh: float = 1,
figsize: tuple[float, float] = (15, 10),
) -> tuple[Figure, Axis]:
) -> tuple[Figure, Axes]:
"""Plot single trace data in case ``plot_data_overview`` gets passed ingle trace data.
Parameters
Expand All @@ -178,7 +178,7 @@ def _plot_single_trace(
Returns
-------
tuple[Figure, Axis]
tuple[Figure, Axes]
Figure and axis which can then be refined by the user.
"""
fig, ax = plt.subplots(1, 1, figsize=figsize)
Expand Down
34 changes: 16 additions & 18 deletions pyglotaran_extras/plotting/plot_irf_dispersion_center.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@

import matplotlib.pyplot as plt
import xarray as xr
from matplotlib.axis import Axis
from matplotlib.figure import Figure

from pyglotaran_extras.config.plot_config import use_plot_config
from pyglotaran_extras.io.utils import result_dataset_mapping
Expand All @@ -20,26 +18,28 @@
from typing import Literal

from cycler import Cycler
from matplotlib.axes import Axes
from matplotlib.figure import Figure

from pyglotaran_extras.types import ResultLike


@use_plot_config(exclude_from_config=("cycler", "ax"))
def plot_irf_dispersion_center(
result: ResultLike,
ax: Axis | None = None,
ax: Axes | None = None,
figsize: tuple[float, float] = (12, 8),
cycler: Cycler | None = PlotStyle().cycler,
irf_location: float | None = None,
) -> tuple[Figure, Axis] | None:
) -> tuple[Figure, Axes] | None:
"""Plot the IRF dispersion center over the spectral dimension for one or multiple datasets.
Parameters
----------
result : ResultLike
Data structure which can be converted to a mapping.
ax : Axis | None
Axis to plot on. Defaults to None which means that a new figure and axis will be created.
ax : Axes | None
Axes to plot on. Defaults to None which means that a new figure and axis will be created.
figsize : tuple[float, float]
Size of the figure (N, M) in inches. Defaults to (12, 8).
cycler : Cycler | None
Expand All @@ -50,51 +50,49 @@ def plot_irf_dispersion_center(
Returns
-------
tuple[Figure, Axis] | None
Figure object which contains the plots and the Axis,
tuple[Figure, Axes] | None
Figure object which contains the plots and the Axes,
if ``ax`` is not None nothing will be returned.
"""
result_map = result_dataset_mapping(result)
if ax is None:
fig, axis = cast(tuple[Figure, Axis], plt.subplots(1, figsize=figsize))
else:
axis = ax
fig, ax = plt.subplots(1, figsize=figsize)
for dataset_name, dataset in result_map.items():
_plot_irf_dispersion_center(
dataset,
axis,
ax,
spectral_axis="x",
cycler=cycler,
label=dataset_name,
irf_location=irf_location,
)
axis.legend()
ax.legend()

if ax is None:
fig.suptitle("Instrument Response Functions", fontsize=16)
return fig, axis
return fig, ax
return None


def _plot_irf_dispersion_center(
res: xr.Dataset,
ax: Axis,
ax: Axes,
*,
spectral_axis: Literal["x", "y"] = "x",
cycler: Cycler | None = PlotStyle().cycler,
label: str = "IRF",
irf_location: float | None = None,
) -> None:
"""Plot the IRF dispersion center on an Axis ``ax``.
"""Plot the IRF dispersion center on an Axes ``ax``.
This is an internal function to be used by higher level functions.
Parameters
----------
res : xr.Dataset
Dataset containing the IRF data.
ax : Axis
Axis to plot on.
ax : Axes
Axes to plot on.
spectral_axis : Literal["x", "y"]
Direct of the spectral axis in the plot. Defaults to "x"
cycler : Cycler | None
Expand Down
8 changes: 4 additions & 4 deletions pyglotaran_extras/plotting/plot_residual.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
if TYPE_CHECKING:
import xarray as xr
from cycler import Cycler
from matplotlib.axis import Axis
from matplotlib.axes import Axes


@use_plot_config(exclude_from_config=("cycler",))
def plot_residual(
res: xr.Dataset,
ax: Axis,
ax: Axes,
linlog: bool = False,
linthresh: float = 1,
show_data: bool | None = False,
Expand All @@ -36,8 +36,8 @@ def plot_residual(
----------
res : xr.Dataset
Result dataset
ax : Axis
Axis to plot on.
ax : Axes
Axes to plot on.
linlog : bool
Whether to use 'symlog' scale or not. Defaults to False.
linthresh : float
Expand Down
Loading

0 comments on commit abaff2e

Please sign in to comment.