-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
24 changed files
with
327 additions
and
5,809 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
from .api import get_data_barcode, get_nn_barcodes, plot_barcode | ||
from .api import evaluate_barcode, get_data_barcode, get_nn_barcodes, plot_barcode |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
import heapq | ||
|
||
import numpy as np | ||
|
||
|
||
def _get_available_metrics(): | ||
return { | ||
# absolute length based metrics | ||
"max_length": _compute_longest_interval_metric, | ||
"mean_length": _compute_length_mean_metric, | ||
"median_length": _compute_length_median_metric, | ||
"stdev_length": _compute_length_stdev_metric, | ||
"sum_length": _compute_length_sum_metric, | ||
# relative length based metrics | ||
"ratio_2_1": _compute_two_to_one_ratio_metric, | ||
"ratio_3_1": _compute_three_to_one_ratio_metric, | ||
# entopy based metrics | ||
"h": _compute_entropy_metric, | ||
"normh": _compute_normed_entropy_metric, | ||
# signal to noise ration | ||
"snr": _compute_snr_metric, | ||
# birth-death based metrics | ||
"mean_birth": _compute_births_mean_metric, | ||
"stdev_birth": _compute_births_stdev_metric, | ||
"mean_death": _compute_deaths_mean_metric, | ||
"stdev_death": _compute_deaths_stdev_metric, | ||
} | ||
|
||
|
||
def compute_metric(barcode, metric_name=None): | ||
metrics = _get_available_metrics() | ||
if metric_name is None: | ||
return {name: fn(barcode) for (name, fn) in metrics.items()} | ||
else: | ||
return metrics[metric_name](barcode) | ||
|
||
|
||
def _get_lengths(barcode): | ||
diag = barcode["H0"] | ||
return [d[1] - d[0] for d in diag] | ||
|
||
|
||
def _compute_longest_interval_metric(barcode): | ||
lengths = _get_lengths(barcode) | ||
return np.max(lengths).item() | ||
|
||
|
||
def _compute_length_mean_metric(barcode): | ||
lengths = _get_lengths(barcode) | ||
return np.mean(lengths).item() | ||
|
||
|
||
def _compute_length_median_metric(barcode): | ||
lengths = _get_lengths(barcode) | ||
return np.median(lengths).item() | ||
|
||
|
||
def _compute_length_stdev_metric(barcode): | ||
lengths = _get_lengths(barcode) | ||
return np.std(lengths).item() | ||
|
||
|
||
def _compute_length_sum_metric(barcode): | ||
lengths = _get_lengths(barcode) | ||
return np.sum(lengths).item() | ||
|
||
|
||
# Proportion between the longest intervals: 2/1 ratio, 3/1 ratio | ||
def _compute_two_to_one_ratio_metric(barcode): | ||
lengths = _get_lengths(barcode) | ||
value = heapq.nlargest(2, lengths)[1] / lengths[0] | ||
return value.item() | ||
|
||
|
||
def _compute_three_to_one_ratio_metric(barcode): | ||
lengths = _get_lengths(barcode) | ||
value = heapq.nlargest(3, lengths)[2] / lengths[0] | ||
return value.item() | ||
|
||
|
||
# Compute the persistent entropy and normed persistent entropy | ||
def _get_entropy(values, normalize: bool): | ||
values_sum = np.sum(values) | ||
entropy = (-1) * np.sum(np.divide(values, values_sum) * np.log(np.divide(values, values_sum))) | ||
if normalize: | ||
entropy = entropy / np.log(values_sum) | ||
return entropy | ||
|
||
|
||
def _compute_entropy_metric(barcode): | ||
return _get_entropy(_get_lengths(barcode), normalize=False).item() | ||
|
||
|
||
def _compute_normed_entropy_metric(barcode): | ||
return _get_entropy(_get_lengths(barcode), normalize=True).item() | ||
|
||
|
||
# Compute births | ||
def _get_births(barcode): | ||
diag = barcode["H0"] | ||
return np.array([x[0] for x in diag]) | ||
|
||
|
||
# Comput deaths | ||
def _get_deaths(barcode): | ||
diag = barcode["H0"] | ||
return np.array([x[1] for x in diag]) | ||
|
||
|
||
# def _get_birth(barcode, dim): | ||
# diag = barcode['H0'] | ||
# temp = np.array([x[0] for x in diag if x[2] == dim]) | ||
# return temp[0] | ||
|
||
|
||
# def _get_death(barcode, dim): | ||
# diag = barcode['H0'] | ||
# temp = np.array([x[1] for x in diag if x[2] == dim]) | ||
# return temp[-1] | ||
|
||
|
||
# Compute SNR | ||
def _compute_snr_metric(barcode): | ||
births = _get_births(barcode) | ||
deaths = _get_deaths(barcode) | ||
signal = np.mean(deaths - births) | ||
noise = np.std(births) | ||
snr = signal / noise | ||
return snr.item() | ||
|
||
|
||
# Compute the birth-death pair indices: Birth mean, birth stdev, death mean, death stdev | ||
def _compute_births_mean_metric(barcode): | ||
return np.mean(_get_births(barcode)).item() | ||
|
||
|
||
def _compute_births_stdev_metric(barcode): | ||
return np.std(_get_births(barcode)).item() | ||
|
||
|
||
def _compute_deaths_mean_metric(barcode): | ||
return np.mean(_get_deaths(barcode)).item() | ||
|
||
|
||
def _compute_deaths_stdev_metric(barcode): | ||
return np.std(_get_deaths(barcode)).item() |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,6 @@ | ||
from .api import get_random_input, reduce_dim, visualize_layer_manifolds | ||
from .api import ( | ||
get_random_input, | ||
reduce_dim, | ||
visualize_layer_manifolds, | ||
visualize_recurrent_layer_manifolds, | ||
) |
Oops, something went wrong.