diff --git a/src/metrics/coverage_fraction.py b/src/metrics/coverage_fraction.py index 85dba4a..9b3ff6c 100644 --- a/src/metrics/coverage_fraction.py +++ b/src/metrics/coverage_fraction.py @@ -1,6 +1,7 @@ import numpy as np -from typing import Any +from torch import tensor from tqdm import tqdm +from typing import Any from metrics.metric import Metric from utils.config import get_item @@ -53,15 +54,19 @@ def calculate(self): # find the percentile for the posterior for this observation # this is n_params dimensional # the units are in parameter space - confidence_lower = np.percentile( - samples.cpu(), - percentile_lower, - axis=0 + confidence_lower = tensor( + np.percentile( + samples.cpu(), + percentile_lower, + axis=0 + ) ) - confidence_upper = np.percentile( - samples.cpu(), - percentile_upper, - axis=0 + confidence_upper = tensor( + np.percentile( + samples.cpu(), + percentile_upper, + axis=0 + ) ) # this is asking if the true parameter value diff --git a/src/models/sbi_model.py b/src/models/sbi_model.py index 7e9b382..bcc3a0c 100644 --- a/src/models/sbi_model.py +++ b/src/models/sbi_model.py @@ -20,7 +20,7 @@ def sample_posterior(self, n_samples:int, y_true): # TODO typing (n_samples,), x=y_true, show_progress_bars=False - ).cpu() + ).cpu() # TODO Unbind from cpu def predict_posterior(self, data): posterior_samples = self.sample_posterior(data.y_true) diff --git a/src/plots/__init__.py b/src/plots/__init__.py index b8536cc..7ad7227 100644 --- a/src/plots/__init__.py +++ b/src/plots/__init__.py @@ -1,3 +1,9 @@ +from plots.cdf_ranks import CDFRanks +from plots.coverage_fraction import CoverageFraction +from plots.ranks import Ranks + Plots = { - + CDFRanks.__name__: CDFRanks, + CoverageFraction.__name__: CoverageFraction, + Ranks.__name__: Ranks } \ No newline at end of file diff --git a/src/plots/cdf_ranks.py b/src/plots/cdf_ranks.py new file mode 100644 index 0000000..03d5f66 --- /dev/null +++ b/src/plots/cdf_ranks.py @@ -0,0 +1,35 @@ +from sbi.analysis import 
sbc_rank_plot, run_sbc +from torch import tensor + +from plots.plot import Display +from utils.config import get_item + +class CDFRanks(Display): + def __init__(self, model, data, save: bool, show: bool, out_dir: str | None = None): + super().__init__(model, data, save, show, out_dir) + + def _plot_name(self): + return "cdf_ranks.png" + + def _data_setup(self): + thetas = tensor(self.data.theta_true()) + y_true = tensor(self.data.x_true()) + self.num_samples = get_item("metrics_common", "samples_per_inference", raise_exception=False) + + ranks, _ = run_sbc( + thetas, y_true, self.model.posterior, num_posterior_samples=self.num_samples + ) + self.ranks = ranks + + def _plot_settings(self): + self.colors = get_item("plots_common", "parameter_colors", raise_exception=False) + self.labels = get_item("plots_common", "parameter_labels", raise_exception=False) + + def _plot(self): + sbc_rank_plot( + self.ranks, + self.num_samples, + plot_type="cdf", + parameter_labels=self.labels, + colors=self.colors, + ) \ No newline at end of file diff --git a/src/plots/coverage_fraction.py b/src/plots/coverage_fraction.py new file mode 100644 index 0000000..c70d440 --- /dev/null +++ b/src/plots/coverage_fraction.py @@ -0,0 +1,78 @@ +import numpy as np +import matplotlib.pyplot as plt +from matplotlib import colormaps as cm + +from metrics.coverage_fraction import CoverageFraction as coverage_fraction_metric +from plots.plot import Display +from utils.config import get_item + +class CoverageFraction(Display): + def __init__(self, model, data, save: bool, show: bool, out_dir: str | None = None): + super().__init__(model, data, save, show, out_dir) + + def _plot_name(self): + return "coverage_fraction.png" + + def _data_setup(self): + _, coverage = coverage_fraction_metric(self.model, self.data, out_dir=None).calculate() + self.coverage_fractions = coverage + + def _plot_settings(self): + self.labels = get_item("plots_common", "parameter_labels", raise_exception=False) + 
self.n_parameters = len(self.labels) + self.figure_size = tuple(get_item("plots_common", "figure_size", raise_exception=False)) + self.line_cycle = tuple(get_item("plots_common", "line_style_cycle", raise_exception=False)) + + def _plot( + self, + figure_alpha=1.0, + line_width=3, + legend_loc="lower right", + reference_line_label='Reference Line', + reference_line_style="k--", + x_label="Confidence Interval of the Posterior Volume", + y_label="Fraction of Lenses within Posterior Volume", + title="NPE" + ): + + n_steps = self.coverage_fractions.shape[0] + percentile_array = np.linspace(0, 1, n_steps) + color_cycler = iter(plt.cycler("color", cm.get_cmap(self.colorway).colors)) + line_style_cycler = iter(plt.cycler("line_style", self.line_cycle)) + + # Plotting + fig, ax = plt.subplots(1, 1, figsize=self.figure_size) + + # Iterate over the number of parameters in the model + for i in range(self.n_parameters): + + color = next(color_cycler)["color"] + line_style = next(line_style_cycler)["line_style"] + + ax.plot( + percentile_array, + self.coverage_fractions[:, i], + alpha=figure_alpha, + lw=line_width, + linestyle=line_style, + color=color, + label=self.labels[i], + ) + + ax.plot( + [0, 0.5, 1], [0, 0.5, 1], reference_line_style, lw=line_width, zorder=1000, + label=reference_line_label + ) + + ax.set_xlim([-0.05, 1.05]) + ax.set_ylim([-0.05, 1.05]) + + ax.text(0.03, 0.93, "Under-confident", horizontalalignment="left") + ax.text(0.3, 0.05, "Overconfident", horizontalalignment="left") + + ax.legend(loc=legend_loc) + + ax.set_xlabel(x_label) + ax.set_ylabel(y_label) + ax.set_title(title) + \ No newline at end of file diff --git a/src/plots/plot.py b/src/plots/plot.py index 3501cdf..2bbed32 100644 --- a/src/plots/plot.py +++ b/src/plots/plot.py @@ -3,21 +3,26 @@ import matplotlib.pyplot as plt from matplotlib import rcParams -from utils.config import get_section +from utils.config import get_item class Display: - def __init__(self, model, data, save:bool, show:bool, 
out_path:Optional[str]): + def __init__(self, model, data, save:bool, show:bool, out_dir:Optional[str]=None): self.save = save self.show = show - self.out_path = out_path.rstrip("/") - if self.save: - assert self.out_path is not None, "out_path required to save files." + self.data = data - if not os.path.exists(os.path.dirname(out_path)): - os.makedirs(os.path.dirname(out_path)) - + self.out_path = None + if (out_dir is None) and self.save: + self.out_path = get_item("common", "out_dir", raise_exception=False) + + elif self.save and (out_dir is not None): + self.out_path = out_dir + + if self.out_path is not None: + if not os.path.exists(os.path.dirname(self.out_path)): + os.makedirs(os.path.dirname(self.out_path)) + self.model = model - self._data_setup(data) self._common_settings() self._plot_settings() self.plot_name = self._plot_name() @@ -25,7 +30,7 @@ def __init__(self, model, data, save:bool, show:bool, out_path:Optional[str]): def _plot_name(self): raise NotImplementedError - def _data_setup(self, data): + def _data_setup(self): # Set all the vars used for the plot raise NotImplementedError @@ -38,25 +43,28 @@ def _plot(self, **kwrgs): raise NotImplementedError def _common_settings(self): - plot_common = get_section("plot_common", raise_exception=False) - rcParams["axes.spines.right"] = bool(plot_common['axis_spines']) - rcParams["axes.spines.top"] = bool(plot_common['axis_spines']) + + rcParams["axes.spines.right"] = bool(get_item('plots_common', 'axis_spines', raise_exception=False)) + rcParams["axes.spines.top"] = bool(get_item('plots_common','axis_spines', raise_exception=False)) # Style - self.colorway = plot_common["colorway"] - tight_layout = bool(plot_common['tight_layout']) + self.colorway = get_item('plots_common', "default_colorway", raise_exception=False) + tight_layout = bool(get_item('plots_common','tight_layout', raise_exception=False)) if tight_layout: plt.tight_layout() - plot_style = plot_common['plot_style'] + plot_style = 
get_item('plots_common','plot_style', raise_exception=False) plt.style.use(plot_style) def _finish(self): assert os.path.splitext(self.plot_name)[-1] != '', f"plot name, {self.plot_name}, is malformed. Please supply a name with an extension." if self.save: - plt.savefig(f"{self.out_path}/{self.plot_name}") - if self.plot: + plt.savefig(f"{self.out_path.rstrip('/')}/{self.plot_name}") + if self.show: plt.show() + + plt.cla() - def __call__(self, **kwargs) -> None: - self._plot(**kwargs) + def __call__(self, **plot_args) -> None: + self._data_setup() + self._plot(**plot_args) self._finish() \ No newline at end of file diff --git a/src/plots/ranks.py b/src/plots/ranks.py new file mode 100644 index 0000000..52581d7 --- /dev/null +++ b/src/plots/ranks.py @@ -0,0 +1,37 @@ +from sbi.analysis import sbc_rank_plot, run_sbc +from torch import tensor + +from plots.plot import Display +from utils.config import get_item + +class Ranks(Display): + def __init__(self, model, data, save: bool, show: bool, out_dir: str | None = None): + super().__init__(model, data, save, show, out_dir) + + def _plot_name(self): + return "ranks.png" + + def _data_setup(self): + thetas = tensor(self.data.theta_true()) + y_true = tensor(self.data.x_true()) + self.num_samples = get_item("metrics_common", "samples_per_inference", raise_exception=False) + + ranks, _ = run_sbc( + thetas, y_true, self.model.posterior, num_posterior_samples=self.num_samples + ) + self.ranks = ranks + + def _plot_settings(self): + self.colors = get_item("plots_common", "parameter_colors", raise_exception=False) + self.labels = get_item("plots_common", "parameter_labels", raise_exception=False) + + def _plot(self, num_bins=None): + sbc_rank_plot( + ranks=self.ranks, + num_posterior_samples=self.num_samples, + plot_type="hist", + num_bins=num_bins, + parameter_labels=self.labels, + colors=self.colors + ) + \ No newline at end of file diff --git a/src/utils/defaults.py b/src/utils/defaults.py index b2d75aa..0f511ac 100644 --- 
a/src/utils/defaults.py +++ b/src/utils/defaults.py @@ -13,17 +13,22 @@ "plots_common": { "axis_spines": False, "tight_layout": True, - "colorway": "virdids", - "plot_style": "fast" + "default_colorway": "viridis", + "plot_style": "fast", + "parameter_labels" : ['$m$','$b$'], + "parameter_colors": ['#9C92A3','#0F5257'], + "line_style_cycle": ["-", "-."], + "figure_size": [6, 6] }, "plots":{ - "type_of_plot":{"specific_kwargs"} + "CDFRanks":{}, + "Ranks":{"num_bins":None}, + "CoverageFraction":{} }, "metrics_common": { "use_progress_bar": False, "samples_per_inference":1000, "percentiles":[75, 85, 95] - }, "metrics":{ "AllSBC":{}, diff --git a/tests/conftest.py b/tests/conftest.py index 8d9f116..3be2a31 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,11 +14,9 @@ def __init__(self): def __call__(self, thetas, samples): thetas = np.atleast_2d(thetas) - # Check if the input has the correct shape if thetas.shape[1] != 2: raise ValueError("Input tensor must have shape (n, 2) where n is the number of parameter sets.") - # Unpack the parameters if thetas.shape[0] == 1: # If there's only one set of parameters, extract them directly m, b = thetas[0, 0], thetas[0, 1] @@ -26,10 +24,7 @@ def __call__(self, thetas, samples): # If there are multiple sets of parameters, extract them for each row m, b = thetas[:, 0], thetas[:, 1] x = np.linspace(0, 100, samples) - rs = np.random.RandomState()#2147483648)# - # I'm thinking sigma could actually be a function of x - # if we want to get fancy down the road - # Generate random noise (epsilon) based on a normal distribution with mean 0 and standard deviation sigma + rs = np.random.RandomState() sigma = 1 epsilon = rs.normal(loc=0, scale=sigma, size=(len(x), thetas.shape[0])) @@ -77,7 +72,7 @@ def factory( plots=None, metrics=None ): - config = { "common": {}, "model": {}, "data":{}, "plot_common": {}, "plots":{}, "metric_common": {},"metrics":{}} + config = { "common": {}, "model": {}, "data":{}, "plots_common": {}, 
"plots":{}, "metrics_common": {},"metrics":{}} # Single settings if out_dir is not None: @@ -95,9 +90,11 @@ def factory( # Dict settings if plot_settings is not None: - config['plots_common'] = plot_settings + for key, item in plot_settings.items(): + config['plots_common'][key] = item if metrics_settings is not None: - config['metrics_common'] = metrics_settings + for key, item in metrics_settings.items(): + config['metrics_common'][key] = item if metrics is not None: if isinstance(metrics, dict): diff --git a/tests/plots/coverage.pdf b/tests/plots/coverage.pdf new file mode 100644 index 0000000..bb08e02 Binary files /dev/null and b/tests/plots/coverage.pdf differ diff --git a/tests/plots/sbc_ranks.pdf b/tests/plots/sbc_ranks.pdf new file mode 100644 index 0000000..f139a2f Binary files /dev/null and b/tests/plots/sbc_ranks.pdf differ diff --git a/tests/plots/sbc_ranks_cdf.pdf b/tests/plots/sbc_ranks_cdf.pdf new file mode 100644 index 0000000..d436496 Binary files /dev/null and b/tests/plots/sbc_ranks_cdf.pdf differ diff --git a/tests/test_plots.py b/tests/test_plots.py index e69de29..e19bdc5 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -0,0 +1,56 @@ +import os +import pytest + +from utils.defaults import Defaults +from utils.config import Config, get_item +from plots import ( + Plots, + CDFRanks, + Ranks, + CoverageFraction +) + +@pytest.fixture +def plot_config(config_factory): + out_dir = "./temp_results/" + metrics_settings={"use_progress_bar":False, "samples_per_inference":10, "percentiles":[95]} + config = config_factory(out_dir=out_dir, metrics_settings=metrics_settings) + return config + +def test_all_plot_catalogued(): + '''Each metrics gets its own file, and each metric is included in the Metrics dictionary + so the client can use it. 
+ This test verifies all plots are cataloged''' + + all_files = os.listdir("src/plots/") + files_ignore = ['plot.py', '__init__.py', '__pycache__'] # All files not containing a plot + num_files = len([file for file in all_files if file not in files_ignore]) + assert len(Plots) == num_files + +def test_all_defaults(plot_config, mock_model, mock_data): + """ + Ensures each plot has a default set of parameters and is included in the defaults list + Ensures each plot can initialize, regardless of the veracity of the output + """ + Config(plot_config) + for plot_name, plot_obj in Plots.items(): + assert plot_name in Defaults['plots'] + plot_obj(mock_model, mock_data, save=True, show=False) + +def test_plot_cdf(plot_config, mock_model, mock_data): + Config(plot_config) + plot = CDFRanks(mock_model, mock_data, save=True, show=False) + plot(**get_item("plots", "CDFRanks", raise_exception=False)) + assert os.path.exists(f"{plot.out_path}/{plot.plot_name}") + +def test_plot_ranks(plot_config, mock_model, mock_data): + Config(plot_config) + plot = Ranks(mock_model, mock_data, save=True, show=False) + plot(**get_item("plots", "Ranks", raise_exception=False)) + assert os.path.exists(f"{plot.out_path}/{plot.plot_name}") + +def test_plot_coverage(plot_config, mock_model, mock_data): + Config(plot_config) + plot = CoverageFraction(mock_model, mock_data, save=True, show=False) + plot(**get_item("plots", "CoverageFraction", raise_exception=False)) + assert os.path.exists(f"{plot.out_path}/{plot.plot_name}") \ No newline at end of file