Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Add plot_pfid plot function #283

Merged
merged 6 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/integration-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ jobs:
- name: Install pyglotaran-extras
run: |
pip install wheel
pip install -r requirements_pinned.txt
pip install .

- name: ${{ matrix.example_name }}
Expand Down
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
tests/resources/PE2017

# virtual envs
.venv
.conda

# Do not ship old c_matrix*.c files, due to
# compatibility problems with newer numpy versions
*.c
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ repos:
hooks:
- id: mypy
exclude: ^docs
additional_dependencies: [types-all]
additional_dependencies: [types-tabulate]

- repo: https://github.com/econchick/interrogate
rev: 1.7.0
Expand Down
195 changes: 195 additions & 0 deletions pyglotaran_extras/plotting/plot_pfid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
"""Module containing PFID (Perturbed Free Induction Decay) plotting functionality."""

from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Literal

import matplotlib.pyplot as plt
import numpy as np

from pyglotaran_extras.io.load_data import load_data
from pyglotaran_extras.plotting.style import PlotStyle
from pyglotaran_extras.plotting.utils import abs_max
from pyglotaran_extras.plotting.utils import add_cycler_if_not_none
from pyglotaran_extras.plotting.utils import calculate_ticks_in_units_of_pi
from pyglotaran_extras.plotting.utils import condense_numbered_strings
from pyglotaran_extras.plotting.utils import extract_irf_location
from pyglotaran_extras.plotting.utils import shift_time_axis_by_irf_location

if TYPE_CHECKING:
from cycler import Cycler
from glotaran.project.result import Result
from matplotlib.figure import Figure
from matplotlib.pyplot import Axes

from pyglotaran_extras.types import DatasetConvertible


def plot_pfid( # noqa: C901
dataset: DatasetConvertible | Result,
*,
names: list[str] | None = None,
oscillation_type: Literal["cos", "sin"] = "cos",
show_clps: bool = False,
time_range: tuple[float, float] | None = None,
spectral: float | None = None,
main_irf_nr: int | None = 0,
normalize: bool = False,
figsize: tuple[float, float] | None = None,
show_zero_line: bool = True,
cycler: Cycler | None = PlotStyle().cycler,
title: str | None = "Perturbed Free Induction Decays",
legend_format_string: str = r"{label}: $\nu$={frequency:.0f}, $\gamma$={rate:.1f}",
) -> tuple[Figure, Axes]:
r"""Plot PFID (Perturbed Free Induction Decay) related data of the optimization result.

Parameters
----------
dataset : DatasetConvertible | Result
Result dataset from a pyglotaran optimization.
names : list[str] | None
List of oscillation names which should be plotted.
Defaults to None which means that all oscillations will be plotted.
oscillation_type : Literal["cos", "sin"]
Type of the oscillation to show in the oscillation plot. Defaults to "cos"
show_clps : bool, optional
If True, plots additional conditional linear parameters (CLPs), in which case the
oscillation_type argument is ignored (both cos and sin are plotted). Defaults to False.
time_range : tuple[float, float] | None
Start and end time for the Oscillation plot, if ``main_irf_nr`` is not None the value are
relative to the IRF location. Defaults to None which means that the full time range is
used.
spectral : float | None
Value of the spectral axis that should be used to select the data for the Oscillation
plot this value does not need to be an exact existing value and only has effect if the
IRF has dispersion. If None the Oscillation plot at lowest spectral value will be shown.
main_irf_nr : int | None
Index of the main ``irf`` component when using an ``irf`` parametrized with multiple peaks
and is used to shift the time axis. If it is none ``None`` the shifting will be
deactivated. Defaults to 0.
normalize : bool
Whether or not to normalize the PFID spectra plot. If the PFID spectra is normalized,
the Oscillation is scaled with the reciprocal of the normalization to compensate for this.
Defaults to False.
figsize : tuple[float, float] | None
Size of the figure (N, M) in inches. Defaults to None which then uses
(20,5) if show_clps=False, (20, 10) if show_clps=True.
show_zero_line : bool
Whether or not to add a horizontal line at zero. Defaults to True
cycler : Cycler | None
Plot style cycler to use. Defaults to PlotStyle().cycler
title : str | None
Title of the figure. Defaults to "Perturbed Free Induction Decays"
legend_format_string : str
Format string for each entry in the legend of the oscillation plot. Possible values which
can be replaced are ``label`` (label of the oscillation in the model definition),
``frequency`` (ν) and ``rate`` (γ). Use ``""`` to remove the legend. Defaults to
``r"{label}: $\nu$={frequency:.0f}, $\gamma$={rate:.1f}"``

Returns
-------
tuple[Figure, Axes]
Figure object which contains the plots and the Axes.

See Also
--------
calculate_ticks_in_units_of_pi
"""
dataset = load_data(dataset, _stacklevel=3)
if figsize is None:
figsize = (20, 10) if show_clps else (20, 5)
fig, axes = plt.subplots(2 if show_clps else 1, 3, figsize=figsize)

add_cycler_if_not_none(axes, cycler)

time_sel_kwargs = {"time": slice(time_range[0], time_range[1])} if time_range else {}
names = dataset.pfid.to_numpy() if names is None else names
osc_sel_kwargs = {"pfid": names}
if spectral is None and "spectral" in dataset.coords:
spectral = dataset.coords["spectral"].min().item()
irf_location = extract_irf_location(dataset, spectral, main_irf_nr)

pfid_phase = dataset["pfid_phase"].sel(**osc_sel_kwargs)
oscillations_spectra = dataset["pfid_associated_spectra"].sel(**osc_sel_kwargs)
oscillation_types = ["cos", "sin"] if show_clps else [oscillation_type]

for axes_index, osc_type in enumerate(oscillation_types):
oscillations = dataset[f"pfid_{osc_type}"]

if "spectral" in oscillations.coords:
oscillations = oscillations.sel(spectral=spectral, method="nearest")

oscillations = shift_time_axis_by_irf_location(
oscillations.sel(**osc_sel_kwargs), irf_location, _internal_call=True
)
osc_max = abs_max(oscillations, result_dims="pfid")
spectra_max = abs_max(oscillations_spectra, result_dims="pfid")
scales = np.sqrt(osc_max * spectra_max)
norm_factor = scales.max() if normalize else 1

oscillations_to_plot = (oscillations / osc_max * scales * norm_factor).sel(
**time_sel_kwargs
)

for oscillation_label in oscillations_to_plot.pfid.to_numpy():
oscillation = oscillations_to_plot.sel(pfid=[oscillation_label])
frequency = oscillation.pfid_frequency.item()
rate = oscillation.pfid_rate.item()
oscillation.plot.line(
x="time",
ax=axes[axes_index, 0] if show_clps else axes[0],
label=legend_format_string.format(
label=oscillation_label, frequency=frequency, rate=rate
),
)

spectra_ax = axes[0, 2] if show_clps else axes[1]
phases_ax = axes[1, 2] if show_clps else axes[2]

(oscillations_spectra / spectra_max * scales / norm_factor).plot.line(
x="spectral", ax=spectra_ax
)
pfid_phase.plot.line(x="spectral", ax=phases_ax)

spectra_ax.set_title("Spectra")
spectra_ax.set_ylabel("Normalized PFID" if normalize else "PFID")

phases_ax.set_title("Phases")
phases_ax.set_yticks(*calculate_ticks_in_units_of_pi(pfid_phase), rotation="horizontal")
phases_ax.set_ylabel("Phase (π)")

if show_clps:
for axes_index, osc_type in enumerate(oscillation_types):
for pfidname in names:
clp_label = f"{pfidname}_{osc_type}"
dataset["clp"].sel(clp_label=clp_label).plot(ax=axes[axes_index, 1])
axes[axes_index, 1].set_title(f"clps={condense_numbered_strings(names)}")

axes[0, 0].set_title(f"Cos Oscillations {spectral}")
axes[1, 0].set_title(f"Sin Oscillations {spectral}")
axes[0, 0].set_xlabel("Time (ps)")
axes[1, 0].set_xlabel("Time (ps)")
axes[0, 1].set_xlabel("Wavenumber (1/cm)")
axes[1, 1].set_xlabel("Wavenumber (1/cm)")
axes[0, 2].set_xlabel("Wavenumber (1/cm)")
axes[1, 2].set_xlabel("Wavenumber (1/cm)")
else:
axes[0].set_title(f"{oscillation_type.capitalize()} Oscillations {spectral}")

for ax in axes.flatten():
legend = ax.get_legend()
if legend is not None:
legend.remove()

if legend_format_string:
axes[0].legend() if not show_clps else axes[0, 0].legend()

if show_zero_line:
[ax.axhline(0, color="k", linewidth=1) for ax in axes.flatten()]

if title:
fig.suptitle(title, fontsize=16)

fig.tight_layout()
return fig, axes
12 changes: 7 additions & 5 deletions pyglotaran_extras/plotting/plot_spectra.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,11 @@ def plot_sas(
keys = [
v for v in res.data_vars if v.startswith(("species_associated_spectra", "species_spectra"))
]
for key in keys:
for key in reversed(keys):
sas = res[key]
sas.plot.line(x="spectral", ax=ax)
for zorder, species in zip(range(100)[::-1], sas.coords["species"], strict=False):
sas.sel(species=species).plot.line(x="spectral", ax=ax, zorder=zorder)
ax.set_title(title)
ax.get_legend().remove()
if show_zero_line is True:
ax.axhline(0, color="k", linewidth=1)

Expand Down Expand Up @@ -114,9 +114,11 @@ def plot_norm_sas(
]
for key in keys:
sas = res[key]
(sas / np.abs(sas).max(dim="spectral")).plot.line(x="spectral", ax=ax)
for zorder, species in zip(range(100)[::-1], sas.coords["species"], strict=False):
(sas / np.abs(sas).max(dim="spectral")).sel(species=species).plot.line(
x="spectral", ax=ax, zorder=zorder
)
ax.set_title(title)
ax.get_legend().remove()
if show_zero_line is True:
ax.axhline(0, color="k", linewidth=1)

Expand Down
40 changes: 40 additions & 0 deletions pyglotaran_extras/plotting/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import itertools
from collections.abc import Iterable
from math import ceil
from math import log
Expand All @@ -22,6 +23,7 @@
from collections.abc import Mapping
from typing import Literal

from collecations.abs import Sequence
from cycler import Cycler
from matplotlib.axis import Axis
from matplotlib.figure import Figure
Expand Down Expand Up @@ -775,3 +777,41 @@ def add_subplot_labels(
xycoords=label_coords,
fontsize=fontsize,
)


def condense_numbered_strings(inputs: Sequence[str] | None) -> str:
"""Condense a list of numbered strings into a compact string representation.

This function takes a sequence of strings with a common prefix followed by numbers,
and condenses consecutive numbers into ranges. For example, the input
["osc1", "osc2", "osc3", "osc5"] will be condensed to "osc1-3,5".

Parameters
----------
inputs : Sequence[str] | None
A sequence of strings to condense. If None or empty, an empty string is returned.

Returns
-------
str
A condensed string representation of the input sequence.
"""
if inputs is None:
return ""

# Extract prefix and numbers
prefix = "".join(c for c in inputs[0] if not c.isdigit())
nums = sorted(int("".join(c for c in s if c.isdigit())) for s in inputs)

# Find and format ranges
ranges = []
for _, group in itertools.groupby(enumerate(nums), key=lambda x: x[0] - x[1]):
group_list = list(group)
start = group_list[0][1]
end = group_list[-1][1]
if start == end:
ranges.append(str(start))
else:
ranges.append(f"{start}-{end}")

return f"{prefix}{','.join(ranges)}"
13 changes: 13 additions & 0 deletions tests/plotting/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pyglotaran_extras.plotting.utils import add_cycler_if_not_none
from pyglotaran_extras.plotting.utils import add_subplot_labels
from pyglotaran_extras.plotting.utils import calculate_ticks_in_units_of_pi
from pyglotaran_extras.plotting.utils import condense_numbered_strings
from pyglotaran_extras.plotting.utils import ensure_axes_array
from pyglotaran_extras.plotting.utils import format_sub_plot_number_upper_case_letter
from pyglotaran_extras.plotting.utils import get_next_cycler_color
Expand Down Expand Up @@ -206,3 +207,15 @@ def test_add_subplot_labels_label_format_template(label_format_template: str, ex
add_subplot_labels(ax, label_format_template=label_format_template)

assert ax.texts[0].get_text() == expected


def test_condense_numbered_strings():
"""Test condensed labels."""
assert condense_numbered_strings(None) == ""
assert condense_numbered_strings(["osc1", "osc2", "osc3", "osc4"]) == "osc1-4"
assert condense_numbered_strings(["osc3", "osc4"]) == "osc3-4"
assert condense_numbered_strings(["osc1"]) == "osc1"
assert condense_numbered_strings(["osc1", "osc3", "osc4", "osc5", "osc7"]) == "osc1,3-5,7"
assert condense_numbered_strings(["abc1", "abc3", "abc4", "abc5", "abc7"]) == "abc1,3-5,7"
assert condense_numbered_strings(["aa1", "aa3", "aa4", "aa5", "aa7"]) == "aa1,3-5,7"
# TODO: handle case where substring is not all the same (e.g. abc1, def2)
Loading