Skip to content

Commit

Permalink
Merge pull request #52 from voetberg/display_base
Browse files Browse the repository at this point in the history
Display examples
  • Loading branch information
bnord authored Apr 8, 2024
2 parents 2ed43e6 + b6d3c3a commit 1087f26
Show file tree
Hide file tree
Showing 13 changed files with 271 additions and 44 deletions.
23 changes: 14 additions & 9 deletions src/metrics/coverage_fraction.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
from typing import Any
from torch import tensor
from tqdm import tqdm
from typing import Any

from metrics.metric import Metric
from utils.config import get_item
Expand Down Expand Up @@ -53,15 +54,19 @@ def calculate(self):
# find the percentile for the posterior for this observation
# this is n_params dimensional
# the units are in parameter space
confidence_lower = np.percentile(
samples.cpu(),
percentile_lower,
axis=0
confidence_lower = tensor(
np.percentile(
samples.cpu(),
percentile_lower,
axis=0
)
)
confidence_upper = np.percentile(
samples.cpu(),
percentile_upper,
axis=0
confidence_upper = tensor(
np.percentile(
samples.cpu(),
percentile_upper,
axis=0
)
)

# this is asking if the true parameter value
Expand Down
2 changes: 1 addition & 1 deletion src/models/sbi_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def sample_posterior(self, n_samples:int, y_true): # TODO typing
(n_samples,),
x=y_true,
show_progress_bars=False
).cpu()
).cpu() # TODO Unbind from cpu

def predict_posterior(self, data):
posterior_samples = self.sample_posterior(data.y_true)
Expand Down
8 changes: 7 additions & 1 deletion src/plots/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
from plots.cdf_ranks import CDFRanks
from plots.coverage_fraction import CoverageFraction
from plots.ranks import Ranks

Plots = {

CDFRanks.__name__: CDFRanks,
CoverageFraction.__name__: CoverageFraction,
Ranks.__name__: Ranks
}
35 changes: 35 additions & 0 deletions src/plots/cdf_ranks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from sbi.analysis import sbc_rank_plot, run_sbc
from torch import tensor

from plots.plot import Display
from utils.config import get_item

class CDFRanks(Display):
def __init__(self, model, data, save: bool, show: bool, out_dir: str | None = None):
super().__init__(model, data, save, show, out_dir)

def _plot_name(self):
return "cdf_ranks.png"

def _data_setup(self):
thetas = tensor(self.data.theta_true())
y_true = tensor(self.data.x_true())
self.num_samples = get_item("metrics_common", "samples_per_inference", raise_exception=False)

ranks, _ = run_sbc(
thetas, y_true, self.model.posterior, num_posterior_samples=self.num_samples
)
self.ranks = ranks

def _plot_settings(self):
self.colors = get_item("plots_common", "parameter_colors", raise_exception=False)
self.labels = get_item("plots_common", "parameter_labels", raise_exception=False)

def _plot(self):
sbc_rank_plot(
self.ranks,
self.num_samples,
plot_type="cdf",
parameter_labels=self.labels,
colors=self.colors,
)
78 changes: 78 additions & 0 deletions src/plots/coverage_fraction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colormaps as cm

from metrics.coverage_fraction import CoverageFraction as coverage_fraction_metric
from plots.plot import Display
from utils.config import get_item

class CoverageFraction(Display):
def __init__(self, model, data, save: bool, show: bool, out_dir: str | None = None):
super().__init__(model, data, save, show, out_dir)

def _plot_name(self):
return "coverage_fraction.png"

def _data_setup(self):
_, coverage = coverage_fraction_metric(self.model, self.data, out_dir=None).calculate()
self.coverage_fractions = coverage

def _plot_settings(self):
self.labels = get_item("plots_common", "parameter_labels", raise_exception=False)
self.n_parameters = len(self.labels)
self.figure_size = tuple(get_item("plots_common", "figure_size", raise_exception=False))
self.line_cycle = tuple(get_item("plots_common", "line_style_cycle", raise_exception=False))

def _plot(
self,
figure_alpha=1.0,
line_width=3,
legend_loc="lower right",
reference_line_label='Reference Line',
reference_line_style="k--",
x_label="Confidence Interval of the Posterior Volume",
y_label="Fraction of Lenses within Posterior Volume",
title="NPE"
):

n_steps = self.coverage_fractions.shape[0]
percentile_array = np.linspace(0, 1, n_steps)
color_cycler = iter(plt.cycler("color", cm.get_cmap(self.colorway).colors))
line_style_cycler = iter(plt.cycler("line_style", self.line_cycle))

# Plotting
fig, ax = plt.subplots(1, 1, figsize=self.figure_size)

# Iterate over the number of parameters in the model
for i in range(self.n_parameters):

color = next(color_cycler)["color"]
line_style = next(line_style_cycler)["line_style"]

ax.plot(
percentile_array,
self.coverage_fractions[:, i],
alpha=figure_alpha,
lw=line_width,
linestyle=line_style,
color=color,
label=self.labels[i],
)

ax.plot(
[0, 0.5, 1], [0, 0.5, 1], reference_line_style, lw=line_width, zorder=1000,
label=reference_line_label
)

ax.set_xlim([-0.05, 1.05])
ax.set_ylim([-0.05, 1.05])

ax.text(0.03, 0.93, "Under-confident", horizontalalignment="left")
ax.text(0.3, 0.05, "Overconfident", horizontalalignment="left")

ax.legend(loc=legend_loc)

ax.set_xlabel(x_label)
ax.set_ylabel(y_label)
ax.set_title(title)

48 changes: 28 additions & 20 deletions src/plots/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,34 @@
import matplotlib.pyplot as plt
from matplotlib import rcParams

from utils.config import get_section
from utils.config import get_item

class Display:
def __init__(self, model, data, save:bool, show:bool, out_path:Optional[str]):
def __init__(self, model, data, save:bool, show:bool, out_dir:Optional[str]=None):
self.save = save
self.show = show
self.out_path = out_path.rstrip("/")
if self.save:
assert self.out_path is not None, "out_path required to save files."
self.data = data

if not os.path.exists(os.path.dirname(out_path)):
os.makedirs(os.path.dirname(out_path))

self.out_path = None
if (out_dir is None) and self.save:
self.out_path = get_item("common", "out_dir", raise_exception=False)

elif self.save and (out_dir is not None):
self.out_path = out_dir

if self.out_path is not None:
if not os.path.exists(os.path.dirname(self.out_path)):
os.makedirs(os.path.dirname(self.out_path))

self.model = model
self._data_setup(data)
self._common_settings()
self._plot_settings()
self.plot_name = self._plot_name()

def _plot_name(self):
raise NotImplementedError

def _data_setup(self, data):
def _data_setup(self):
# Set all the vars used for the plot
raise NotImplementedError

Expand All @@ -38,25 +43,28 @@ def _plot(self, **kwrgs):
raise NotImplementedError

def _common_settings(self):
plot_common = get_section("plot_common", raise_exception=False)
rcParams["axes.spines.right"] = bool(plot_common['axis_spines'])
rcParams["axes.spines.top"] = bool(plot_common['axis_spines'])

rcParams["axes.spines.right"] = bool(get_item('plots_common', 'axis_spines', raise_exception=False))
rcParams["axes.spines.top"] = bool(get_item('plots_common','axis_spines', raise_exception=False))

# Style
self.colorway = plot_common["colorway"]
tight_layout = bool(plot_common['tight_layout'])
self.colorway = get_item('plots_common', "default_colorway", raise_exception=False)
tight_layout = bool(get_item('plots_common','tight_layout', raise_exception=False))
if tight_layout:
plt.tight_layout()
plot_style = plot_common['plot_style']
plot_style = get_item('plots_common','plot_style', raise_exception=False)
plt.style.use(plot_style)

def _finish(self):
assert os.path.splitext(self.plot_name)[-1] != '', f"plot name, {self.plot_name}, is malformed. Please supply a name with an extension."
if self.save:
plt.savefig(f"{self.out_path}/{self.plot_name}")
if self.plot:
plt.savefig(f"{self.out_path.rstrip('/')}/{self.plot_name}")
if self.show:
plt.show()

plt.cla()

def __call__(self, **kwargs) -> None:
self._plot(**kwargs)
def __call__(self, **plot_args) -> None:
self._data_setup()
self._plot(**plot_args)
self._finish()
37 changes: 37 additions & 0 deletions src/plots/ranks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from sbi.analysis import sbc_rank_plot, run_sbc
from torch import tensor

from plots.plot import Display
from utils.config import get_item

class Ranks(Display):
def __init__(self, model, data, save: bool, show: bool, out_dir: str | None = None):
super().__init__(model, data, save, show, out_dir)

def _plot_name(self):
return "ranks.png"

def _data_setup(self):
thetas = tensor(self.data.theta_true())
y_true = tensor(self.data.x_true())
self.num_samples = get_item("metrics_common", "samples_per_inference", raise_exception=False)

ranks, _ = run_sbc(
thetas, y_true, self.model.posterior, num_posterior_samples=self.num_samples
)
self.ranks = ranks

def _plot_settings(self):
self.colors = get_item("plots_common", "parameter_colors", raise_exception=False)
self.labels = get_item("plots_common", "parameter_labels", raise_exception=False)

def _plot(self, num_bins=None):
sbc_rank_plot(
ranks=self.ranks,
num_posterior_samples=self.num_samples,
plot_type="hist",
num_bins=num_bins,
parameter_labels=self.labels,
colors=self.colors
)

13 changes: 9 additions & 4 deletions src/utils/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,22 @@
"plots_common": {
"axis_spines": False,
"tight_layout": True,
"colorway": "virdids",
"plot_style": "fast"
"default_colorway": "viridis",
"plot_style": "fast",
"parameter_labels" : ['$m$','$b$'],
"parameter_colors": ['#9C92A3','#0F5257'],
"line_style_cycle": ["-", "-."],
"figure_size": [6, 6]
},
"plots":{
"type_of_plot":{"specific_kwargs"}
"CDFRanks":{},
"Ranks":{"num_bins":None},
"CoverageFraction":{}
},
"metrics_common": {
"use_progress_bar": False,
"samples_per_inference":1000,
"percentiles":[75, 85, 95]

},
"metrics":{
"AllSBC":{},
Expand Down
15 changes: 6 additions & 9 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,17 @@ def __init__(self):

def __call__(self, thetas, samples):
thetas = np.atleast_2d(thetas)
# Check if the input has the correct shape
if thetas.shape[1] != 2:
raise ValueError("Input tensor must have shape (n, 2) where n is the number of parameter sets.")

# Unpack the parameters
if thetas.shape[0] == 1:
# If there's only one set of parameters, extract them directly
m, b = thetas[0, 0], thetas[0, 1]
else:
# If there are multiple sets of parameters, extract them for each row
m, b = thetas[:, 0], thetas[:, 1]
x = np.linspace(0, 100, samples)
rs = np.random.RandomState()#2147483648)#
# I'm thinking sigma could actually be a function of x
# if we want to get fancy down the road
# Generate random noise (epsilon) based on a normal distribution with mean 0 and standard deviation sigma
rs = np.random.RandomState()
sigma = 1
epsilon = rs.normal(loc=0, scale=sigma, size=(len(x), thetas.shape[0]))

Expand Down Expand Up @@ -77,7 +72,7 @@ def factory(
plots=None,
metrics=None
):
config = { "common": {}, "model": {}, "data":{}, "plot_common": {}, "plots":{}, "metric_common": {},"metrics":{}}
config = { "common": {}, "model": {}, "data":{}, "plots_common": {}, "plots":{}, "metrics_common": {},"metrics":{}}

# Single settings
if out_dir is not None:
Expand All @@ -95,9 +90,11 @@ def factory(

# Dict settings
if plot_settings is not None:
config['plots_common'] = plot_settings
for key, item in plot_settings.items():
config['plots_common'][key] = item
if metrics_settings is not None:
config['metrics_common'] = metrics_settings
for key, item in metrics_settings.items():
config['metrics_common'][key] = item

if metrics is not None:
if isinstance(metrics, dict):
Expand Down
Binary file added tests/plots/coverage.pdf
Binary file not shown.
Binary file added tests/plots/sbc_ranks.pdf
Binary file not shown.
Binary file added tests/plots/sbc_ranks_cdf.pdf
Binary file not shown.
Loading

0 comments on commit 1087f26

Please sign in to comment.