-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
13 changed files
with
271 additions
and
44 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,9 @@ | ||
from plots.cdf_ranks import CDFRanks | ||
from plots.coverage_fraction import CoverageFraction | ||
from plots.ranks import Ranks | ||
|
||
Plots = { | ||
|
||
CDFRanks.__name__: CDFRanks, | ||
CoverageFraction.__name__: CoverageFraction, | ||
Ranks.__name__: Ranks | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from sbi.analysis import sbc_rank_plot, run_sbc | ||
from torch import tensor | ||
|
||
from plots.plot import Display | ||
from utils.config import get_item | ||
|
||
class CDFRanks(Display): | ||
def __init__(self, model, data, save: bool, show: bool, out_dir: str | None = None): | ||
super().__init__(model, data, save, show, out_dir) | ||
|
||
def _plot_name(self): | ||
return "cdf_ranks.png" | ||
|
||
def _data_setup(self): | ||
thetas = tensor(self.data.theta_true()) | ||
y_true = tensor(self.data.x_true()) | ||
self.num_samples = get_item("metrics_common", "samples_per_inference", raise_exception=False) | ||
|
||
ranks, _ = run_sbc( | ||
thetas, y_true, self.model.posterior, num_posterior_samples=self.num_samples | ||
) | ||
self.ranks = ranks | ||
|
||
def _plot_settings(self): | ||
self.colors = get_item("plots_common", "parameter_colors", raise_exception=False) | ||
self.labels = get_item("plots_common", "parameter_labels", raise_exception=False) | ||
|
||
def _plot(self): | ||
sbc_rank_plot( | ||
self.ranks, | ||
self.num_samples, | ||
plot_type="cdf", | ||
parameter_labels=self.labels, | ||
colors=self.colors, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
from matplotlib import colormaps as cm | ||
|
||
from metrics.coverage_fraction import CoverageFraction as coverage_fraction_metric | ||
from plots.plot import Display | ||
from utils.config import get_item | ||
|
||
class CoverageFraction(Display): | ||
def __init__(self, model, data, save: bool, show: bool, out_dir: str | None = None): | ||
super().__init__(model, data, save, show, out_dir) | ||
|
||
def _plot_name(self): | ||
return "coverage_fraction.png" | ||
|
||
def _data_setup(self): | ||
_, coverage = coverage_fraction_metric(self.model, self.data, out_dir=None).calculate() | ||
self.coverage_fractions = coverage | ||
|
||
def _plot_settings(self): | ||
self.labels = get_item("plots_common", "parameter_labels", raise_exception=False) | ||
self.n_parameters = len(self.labels) | ||
self.figure_size = tuple(get_item("plots_common", "figure_size", raise_exception=False)) | ||
self.line_cycle = tuple(get_item("plots_common", "line_style_cycle", raise_exception=False)) | ||
|
||
def _plot( | ||
self, | ||
figure_alpha=1.0, | ||
line_width=3, | ||
legend_loc="lower right", | ||
reference_line_label='Reference Line', | ||
reference_line_style="k--", | ||
x_label="Confidence Interval of the Posterior Volume", | ||
y_label="Fraction of Lenses within Posterior Volume", | ||
title="NPE" | ||
): | ||
|
||
n_steps = self.coverage_fractions.shape[0] | ||
percentile_array = np.linspace(0, 1, n_steps) | ||
color_cycler = iter(plt.cycler("color", cm.get_cmap(self.colorway).colors)) | ||
line_style_cycler = iter(plt.cycler("line_style", self.line_cycle)) | ||
|
||
# Plotting | ||
fig, ax = plt.subplots(1, 1, figsize=self.figure_size) | ||
|
||
# Iterate over the number of parameters in the model | ||
for i in range(self.n_parameters): | ||
|
||
color = next(color_cycler)["color"] | ||
line_style = next(line_style_cycler)["line_style"] | ||
|
||
ax.plot( | ||
percentile_array, | ||
self.coverage_fractions[:, i], | ||
alpha=figure_alpha, | ||
lw=line_width, | ||
linestyle=line_style, | ||
color=color, | ||
label=self.labels[i], | ||
) | ||
|
||
ax.plot( | ||
[0, 0.5, 1], [0, 0.5, 1], reference_line_style, lw=line_width, zorder=1000, | ||
label=reference_line_label | ||
) | ||
|
||
ax.set_xlim([-0.05, 1.05]) | ||
ax.set_ylim([-0.05, 1.05]) | ||
|
||
ax.text(0.03, 0.93, "Under-confident", horizontalalignment="left") | ||
ax.text(0.3, 0.05, "Overconfident", horizontalalignment="left") | ||
|
||
ax.legend(loc=legend_loc) | ||
|
||
ax.set_xlabel(x_label) | ||
ax.set_ylabel(y_label) | ||
ax.set_title(title) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from sbi.analysis import sbc_rank_plot, run_sbc | ||
from torch import tensor | ||
|
||
from plots.plot import Display | ||
from utils.config import get_item | ||
|
||
class Ranks(Display): | ||
def __init__(self, model, data, save: bool, show: bool, out_dir: str | None = None): | ||
super().__init__(model, data, save, show, out_dir) | ||
|
||
def _plot_name(self): | ||
return "ranks.png" | ||
|
||
def _data_setup(self): | ||
thetas = tensor(self.data.theta_true()) | ||
y_true = tensor(self.data.x_true()) | ||
self.num_samples = get_item("metrics_common", "samples_per_inference", raise_exception=False) | ||
|
||
ranks, _ = run_sbc( | ||
thetas, y_true, self.model.posterior, num_posterior_samples=self.num_samples | ||
) | ||
self.ranks = ranks | ||
|
||
def _plot_settings(self): | ||
self.colors = get_item("plots_common", "parameter_colors", raise_exception=False) | ||
self.labels = get_item("plots_common", "parameter_labels", raise_exception=False) | ||
|
||
def _plot(self, num_bins=None): | ||
sbc_rank_plot( | ||
ranks=self.ranks, | ||
num_posterior_samples=self.num_samples, | ||
plot_type="hist", | ||
num_bins=num_bins, | ||
parameter_labels=self.labels, | ||
colors=self.colors | ||
) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
Binary file not shown.
Binary file not shown.
Oops, something went wrong.