Skip to content

Commit

Permalink
Merge pull request #1157 from VincentRouvreau/persistence_graphical_t…
Browse files Browse the repository at this point in the history
…ools_for_sklearn_itf

Persistence graphical tools for sklearn itf
  • Loading branch information
VincentRouvreau authored Jan 24, 2025
2 parents 350a3db + 77870dd commit f9f0cf3
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 83 deletions.
3 changes: 3 additions & 0 deletions .github/next_release.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ Below is a list of changes:
* Delaunay Čech complex (using minimal enclosing ball)
* Alpha complex (moved in this new section)

- [Persistence graphical tools](https://gudhi.inria.fr/python/latest/persistence_graphical_tools_user.html)
- Can now take the outputs of scikit-learn like interfaces as inputs

- [Module](link)
- **...**

Expand Down
4 changes: 2 additions & 2 deletions src/python/doc/persistence_graphical_tools_user.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,12 @@ This function can display the persistence result as a diagram:
ax.set_aspect("equal") # forces to be square shaped
plt.show()

Note that (as barcode and density) it can also take a simple `np.array`
Note that (as barcode and density) it can also take a simple `np.array`
of shape (N x 2) encoding a persistence diagram (in a given dimension).

.. plot::
:include-source:

import matplotlib.pyplot as plt
import gudhi
import numpy as np
Expand Down
167 changes: 98 additions & 69 deletions src/python/gudhi/persistence_graphical_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
# - 2020/02 Theo Lacombe: Added more options for improved rendering and more flexibility.
# - 2022/11 Vincent Rouvreau: "Automatic" legend display detected by _array_handler that returns if the persistence
# was a nx2 array.
# - 2024/11 Vincent Rouvreau: Support for sklearn like persistence feedback: New _format_handler function that
# enhances former _array_handler function.
# - YYYY/MM Author: Description of the modification

from os import path
Expand Down Expand Up @@ -56,27 +58,47 @@ def _min_birth_max_death(persistence, band=0.0):
return (min_birth, max_death)


def _array_handler(a):
def _format_handler(a):
"""
:param a: if array, assumes it is a (n x 2) np.array and returns a
persistence-compatible list (padding with 0), so that the
plot can be performed seamlessly.
:param a: * If array, assumes it is a (n x 2) np.array
* If iterable of array, assumes it is an iterable on (n x 2) np.array
Returns a persistence-compatible list so that the plot can be performed seamlessly. It is padding with
the index on the array, 0 in case of an array, in order to simulate the dimension required by the plot.
:returns: * List[dimension, [birth, death]] Persistence, compatible with plot functions, list.
* boolean Modification status (True if output is different from input)
* int Modification status: 0 if not modified, 1 if input was a (n x 2) np.array, 2 if input was an
iterable on (n x 2) np.array
"""
if isinstance(a[0][1], (np.floating, float)):
return [[0, x] for x in a], True
else:
return a, False
# TODO: _format_handler should return a list of numpy arrays as it is close from what matplotlib expects
# Array
try:
first_death_value = a[0][1]
if isinstance(first_death_value, (np.floating, float, np.integer, int)):
return [[0, x] for x in a], 1
except IndexError:
pass
# Iterable of array
try:
pers = []
fake_dim = 0
for elt in a:
first_death_value = elt[0][1]
if not isinstance(first_death_value, (np.floating, float, np.integer, int)):
raise TypeError("Should be a list of (birth,death)")
pers.extend([fake_dim, x] for x in elt)
fake_dim = fake_dim + 1
return pers, 2
except TypeError:
pass
# Nothing to be done otherwise
return a, 0


def _limit_to_max_intervals(persistence, max_intervals, key):
"""This function returns truncated persistence if length is bigger than max_intervals.
:param persistence: Persistence intervals values list. Can be grouped by dimension or not.
:type persistence: an array of (dimension, (birth, death)) or an array of (birth, death).
:param max_intervals: maximal number of intervals to display.
Selected intervals are those with the longest life time. Set it
to 0 to see all. Default value is 1000.
:param max_intervals: maximal number of intervals to display. Selected intervals are those with the longest life
time. Set it to 0 to see all. Default value is 1000.
:type max_intervals: int.
:param key: key function for sort algorithm.
:type key: function or lambda.
Expand Down Expand Up @@ -118,6 +140,7 @@ def _matplotlib_can_use_tex() -> bool:
return False
return True

# TODO: a new homology_dimensions like argument for plot_persistence_barcode and plot_persistence_diagram

def plot_persistence_barcode(
persistence=[],
Expand All @@ -130,35 +153,33 @@ def plot_persistence_barcode(
axes=None,
fontsize=16,
):
"""This function plots the persistence bar code from persistence values list
, a np.array of shape (N x 2) (representing a diagram
in a single homology dimension),
or from a `persistence diagram <fileformats.html#persistence-diagram>`_ file.
"""This function plots the persistence bar code from persistence values list, a np.array of shape (N x 2)
(representing a diagram in a single homology dimension), a list of np.array of shape (N x 2)
(representing a diagram in a range of homology dimensions), or from a `persistence diagram
<fileformats.html#persistence-diagram>`_ file.
:param persistence: Persistence intervals values list. Can be grouped by dimension or not.
:type persistence: an array of (dimension, (birth, death)) or an array of (birth, death)
:type persistence: an array of (dimension, (birth, death)), an array of (birth, death) or an array of array of
(birth, death)
:param persistence_file: A `persistence diagram <fileformats.html#persistence-diagram>`_ file style name
(reset persistence if both are set).
:type persistence_file: string
:param alpha: barcode transparency value (0.0 transparent through 1.0
opaque - default is 0.6).
:param alpha: barcode transparency value (0.0 transparent through 1.0 opaque - default is 0.6).
:type alpha: float
:param max_intervals: maximal number of intervals to display.
Selected intervals are those with the longest life time. Set it
to 0 to see all. Default value is 20000.
:param max_intervals: maximal number of intervals to display. Selected intervals are those with the longest life
time. Set it to 0 to see all. Default value is 20000.
:type max_intervals: int
:param inf_delta: Infinity is placed at :code:`((max_death - min_birth) x
inf_delta)` above :code:`max_death` value. A reasonable value is
between 0.05 and 0.5 - default is 0.1.
:param inf_delta: Infinity is placed at :code:`((max_death - min_birth) x inf_delta)` above :code:`max_death`
value. A reasonable value is between 0.05 and 0.5 - default is 0.1.
:type inf_delta: float
:param legend: Display the dimension color legend. Default is None, meaning the legend is displayed if dimension
is specified in the persistence argument, and not displayed if dimension is not specified.
:param legend: Display the color legend. Default is None, meaning the legend is displayed if dimension is specified
in the persistence argument or if persistence is a range over an array of (birth, death), and not displayed
otherwise.
:type legend: boolean or None
:param colormap: A matplotlib-like qualitative colormaps. Default is None
which means :code:`matplotlib.cm.Set1.colors`.
:param colormap: A matplotlib-like qualitative colormap. Default is None which means
:code:`matplotlib.cm.Set1.colors`.
:type colormap: tuple of colors (3-tuple of float between 0. and 1.)
:param axes: A matplotlib-like subplot axes. If None, the plot is drawn on
a new set of axes.
:param axes: A matplotlib-like subplot axes. If None, the plot is drawn on a new set of axes.
:type axes: `matplotlib.axes.Axes`
:param fontsize: Fontsize to use in axis.
:type fontsize: int
Expand All @@ -174,8 +195,8 @@ def plot_persistence_barcode(
plt.rc("text", usetex=False)
plt.rc("font", family="DejaVu Sans")

# By default, let's say the persistence is not an array of shape (N x 2) - Can be from a persistence file
nx2_array = False
# By default, let's say the persistence is List[dimension, [birth, death]] - Can be from a persistence file
input_type = 0
if persistence_file != "":
if path.isfile(persistence_file):
# Reset persistence
Expand All @@ -188,7 +209,7 @@ def plot_persistence_barcode(
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), persistence_file)

try:
persistence, nx2_array = _array_handler(persistence)
persistence, input_type = _format_handler(persistence)
persistence = _limit_to_max_intervals(
persistence, max_intervals, key=lambda life_time: life_time[1][1] - life_time[1][0]
)
Expand All @@ -199,8 +220,7 @@ def plot_persistence_barcode(
pass

delta = (max_death - min_birth) * inf_delta
# Replace infinity values with max_death + delta for bar code to be more
# readable
# Replace infinity values with max_death + delta for bar code to be more readable
infinity = max_death + delta
axis_start = min_birth - delta

Expand All @@ -215,14 +235,19 @@ def plot_persistence_barcode(

axes.barh(range(len(x)), y, left=x, alpha=alpha, color=c, linewidth=0)

if legend is None and not nx2_array:
# By default, if persistence is an array of (dimension, (birth, death)), display the legend
if legend is None and input_type != 1:
# By default, if persistence is an array of (dimension, (birth, death)), or an
# iterator[iterator[birth, death]], display the legend
legend = True

if legend:
title = "Dimension"
if input_type == 2:
title = "Range"
dimensions = {item[0] for item in persistence}
axes.legend(
handles=[mpatches.Patch(color=colormap[dim], label=str(dim)) for dim in dimensions],
title=title,
loc="best",
)

Expand All @@ -249,36 +274,34 @@ def plot_persistence_diagram(
fontsize=16,
greyblock=True,
):
r"""This function plots the persistence diagram from persistence values
list, a np.array of shape (N x 2) representing a diagram in a single
homology dimension, or from a `persistence diagram <fileformats.html#persistence-diagram>`_ file`.
"""This function plots the persistence diagram from persistence values list, a np.array of shape (N x 2)
(representing a diagram in a single homology dimension), a list of np.array of shape (N x 2)
(representing a diagram in a range of homology dimensions), or from a `persistence diagram
<fileformats.html#persistence-diagram>`_ file.
:param persistence: Persistence intervals values list. Can be grouped by dimension or not.
:type persistence: an array of (dimension, (birth, death)) or an array of (birth, death)
:type persistence: an array of (dimension, (birth, death)), an array of (birth, death) or an array of array of
(birth, death)
:param persistence_file: A `persistence diagram <fileformats.html#persistence-diagram>`_ file style name
(reset persistence if both are set).
:type persistence_file: string
:param alpha: plot transparency value (0.0 transparent through 1.0
opaque - default is 0.6).
:param alpha: plot transparency value (0.0 transparent through 1.0 opaque - default is 0.6).
:type alpha: float
:param band: band (not displayed if :math:`\\leq` 0. - default is 0.)
:type band: float
:param max_intervals: maximal number of intervals to display.
Selected intervals are those with the longest life time. Set it
to 0 to see all. Default value is 1000000.
:param max_intervals: maximal number of intervals to display. Selected intervals are those with the longest life
time. Set it to 0 to see all. Default value is 1000000.
:type max_intervals: int
:param inf_delta: Infinity is placed at :code:`((max_death - min_birth) x
inf_delta)` above :code:`max_death` value. A reasonable value is
between 0.05 and 0.5 - default is 0.1.
:param inf_delta: Infinity is placed at :code:`((max_death - min_birth) x inf_delta)` above :code:`max_death`
value. A reasonable value is between 0.05 and 0.5 - default is 0.1.
:type inf_delta: float
:param legend: Display the color legend. Default is None, meaning the legend is displayed if dimension is specified
in the persistence argument or if persistence is a range over an array of (birth, death), and not displayed
otherwise.
:type legend: boolean or None
:param colormap: A matplotlib-like qualitative colormaps. Default is None
which means :code:`matplotlib.cm.Set1.colors`.
:param colormap: A matplotlib-like qualitative colormap. Default is None which means
:code:`matplotlib.cm.Set1.colors`.
:type colormap: tuple of colors (3-tuple of float between 0. and 1.)
:param axes: A matplotlib-like subplot axes. If None, the plot is drawn on
a new set of axes.
:param axes: A matplotlib-like subplot axes. If None, the plot is drawn on a new set of axes.
:type axes: `matplotlib.axes.Axes`
:param fontsize: Fontsize to use in axis.
:type fontsize: int
Expand All @@ -296,8 +319,8 @@ def plot_persistence_diagram(
plt.rc("text", usetex=False)
plt.rc("font", family="DejaVu Sans")

# By default, let's say the persistence is not an array of shape (N x 2) - Can be from a persistence file
nx2_array = False
# By default, let's say the persistence is List[dimension, [birth, death]] - Can be from a persistence file
input_type = 0
if persistence_file != "":
if path.isfile(persistence_file):
# Reset persistence
Expand All @@ -310,7 +333,7 @@ def plot_persistence_diagram(
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), persistence_file)

try:
persistence, nx2_array = _array_handler(persistence)
persistence, input_type = _format_handler(persistence)
persistence = _limit_to_max_intervals(
persistence, max_intervals, key=lambda life_time: life_time[1][1] - life_time[1][0]
)
Expand Down Expand Up @@ -363,14 +386,19 @@ def plot_persistence_diagram(
axes.set_yticks(yt)
axes.set_yticklabels(ytl)

if legend is None and not nx2_array:
# By default, if persistence is an array of (dimension, (birth, death)), display the legend
if legend is None and input_type != 1:
# By default, if persistence is an array of (dimension, (birth, death)), or an
# iterator[iterator[birth, death]], display the legend
legend = True

if legend:
title = "Dimension"
if input_type == 2:
title = "Range"
dimensions = list({item[0] for item in persistence})
axes.legend(
handles=[mpatches.Patch(color=colormap[dim], label=str(dim)) for dim in dimensions],
title=title,
loc="lower right",
)

Expand All @@ -395,17 +423,20 @@ def plot_persistence_density(
fontsize=16,
greyblock=False,
):
"""This function plots the persistence density from persistence values list, np.array of shape (N x 2) representing
a diagram in a single homology dimension, or from a `persistence diagram <fileformats.html#persistence-diagram>`_
file. Be aware that this function does not distinguish the dimension, it is up to you to select the required one.
"""This function plots the persistence density from persistence values list, a np.array of shape (N x 2)
(representing a diagram in a single homology dimension), a list of np.array of shape (N x 2)
(representing a diagram in a range of homology dimensions), or from a `persistence diagram
<fileformats.html#persistence-diagram>`_ file.
Be aware that this function does not distinguish the dimension, it is up to you to select the required one.
This function also does not handle degenerate data set (scipy correlation matrix inversion can fail).
:Requires: `SciPy <installation.html#scipy>`_
:param persistence: Persistence intervals values list. Can be grouped by dimension or not.
:type persistence: an array of (dimension, (birth, death)) or an array of (birth, death)
:param persistence_file: A `persistence diagram <fileformats.html#persistence-diagram>`_
file style name (reset persistence if both are set).
:type persistence: an array of (dimension, (birth, death)), an array of (birth, death) or an array of array of
(birth, death)
:param persistence_file: A `persistence diagram <fileformats.html#persistence-diagram>`_ file style name (reset
persistence if both are set).
:type persistence_file: string
:param nbins: Evaluate a gaussian kde on a regular grid of nbins x nbins over data extents (default is 300)
:type nbins: int
Expand Down Expand Up @@ -463,7 +494,7 @@ def plot_persistence_density(

try:
# if not read from file but given by an argument
persistence, _ = _array_handler(persistence)
persistence, _ = _format_handler(persistence)
persistence_dim = np.array(
[
(dim_interval[1][0], dim_interval[1][1])
Expand All @@ -473,9 +504,7 @@ def plot_persistence_density(
)
persistence_dim = persistence_dim[np.isfinite(persistence_dim[:, 1])]
persistence_dim = np.array(
_limit_to_max_intervals(
persistence_dim, max_intervals, key=lambda life_time: life_time[1] - life_time[0]
)
_limit_to_max_intervals(persistence_dim, max_intervals, key=lambda life_time: life_time[1] - life_time[0])
)

# Set as numpy array birth and death (remove undefined values - inf and NaN)
Expand Down
Loading

0 comments on commit f9f0cf3

Please sign in to comment.