Skip to content

Commit

Permalink
Resolve merge confict
Browse files Browse the repository at this point in the history
  • Loading branch information
voetberg committed Jun 24, 2024
2 parents e559faa + 5593326 commit 41cc428
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 3 deletions.
4 changes: 3 additions & 1 deletion src/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
from plots.tarp import TARP
from plots.local_two_sample import LocalTwoSampleTest
from plots.predictive_posterior_check import PPC
from plots.predictive_prior_check import PriorPC

Plots = {
CDFRanks.__name__: CDFRanks,
CoverageFraction.__name__: CoverageFraction,
Ranks.__name__: Ranks,
TARP.__name__: TARP,
"LC2ST": LocalTwoSampleTest,
PPC.__name__: PPC
PPC.__name__: PPC,
PriorPC.__name__: PriorPC
}
133 changes: 133 additions & 0 deletions src/plots/predictive_prior_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
from typing import Optional, Sequence
import matplotlib.pyplot as plt
import numpy as np

from plots.plot import Display

class PriorPC(Display):
def __init__(
self,
model,
data,
save:bool,
show:bool,
out_dir:Optional[str]=None,
percentiles: Optional[Sequence] = None,
use_progress_bar: Optional[bool] = None,
samples_per_inference: Optional[int] = None,
number_simulations: Optional[int] = None,
parameter_names: Optional[Sequence] = None,
parameter_colors: Optional[Sequence]= None,
colorway: Optional[str]=None
):
super().__init__(model, data, save, show, out_dir, percentiles, use_progress_bar, samples_per_inference, number_simulations, parameter_names, parameter_colors, colorway)

def _plot_name(self):
return "predictive_prior_check.png"

def get_prior_samples(self, n_columns, n_rows):

context_shape = self.data.true_context().shape

self.prior_predictive_samples = np.zeros((n_rows, n_columns, context_shape[-1]))
self.prior_sample = np.zeros((n_rows, n_columns, self.data.n_dims))
self.context = np.zeros((n_rows, n_columns, context_shape[-1]))
random_context_indices = self.data.rng.integers(0, context_shape[0], (n_rows, n_columns))

for row_index in range(n_rows):
for column_index in range(n_columns):

sample = random_context_indices[row_index, column_index]
context_sample = self.data.true_context()[sample, :]

prior_sample = self.data.sample_prior(1)[0]
# get the posterior samples for that context
self.prior_predictive_samples[row_index, column_index] = self.data.simulator.simulate(
theta=prior_sample, context_samples = context_sample
)
self.prior_sample[row_index, column_index] = prior_sample
self.context[row_index, column_index] = context_sample

def fill_text(self, row_index, column_index, row_parameter_index, column_parameter_index, label_samples, round_parameters):
if label_samples in ['both', 'rows', 'columns']:
row_name = self.parameter_names[row_parameter_index]
row_value = self.prior_sample[row_index, column_index, row_parameter_index]

col_name = self.parameter_names[column_parameter_index]
col_value = self.prior_sample[row_index, column_index, column_parameter_index]
if round_parameters:
row_value = round(row_value, 4)
col_value = round(col_value, 4)

if label_samples == "both":
return f"{row_name}={row_value}, {col_name}={col_value}"
elif label_samples == "rows":
return f"{row_name}={row_value}"
else:
return f"{col_name}={col_value}"

else:
raise ValueError(f"Cannot use {label_samples} to assign labels. Choose from 'both', 'rows', 'columns'.")


def _plot(
self,
n_rows: Optional[int] = 3,
n_columns: Optional[int] = 3,
row_parameter_index: Optional[int] = 0,
column_parameter_index: Optional[int] = 1,
round_parameters: Optional[bool] = True,
sort_rows: bool = True,
sort_columns: bool = True,
label_samples: Optional[str] = 'both',
title:Optional[str]="Simulated output from prior",
y_label:Optional[str]=None,
x_label:str=None):


self.get_prior_samples(n_rows, n_columns)
figure, subplots = plt.subplots(
n_columns,
n_rows,
figsize=(int(self.figure_size[0]*n_rows*.6), int(self.figure_size[1]*n_columns*.6)),
sharex=False,
sharey=True
)

if x_label is None:
x_label = f"$theta_{row_parameter_index}$ = {self.parameter_names[row_parameter_index]}"

if y_label is None:
y_label = f"$theta_{column_parameter_index}$ = {self.parameter_names[column_parameter_index]}"

column_order = np.argsort(
self.prior_sample[:, :, column_parameter_index], axis=-1
)
row_order = np.argsort(
self.prior_sample[:, :, row_parameter_index], axis=-1
)

for plot_row_index in range(n_rows):
for plot_column_index in range(n_columns):

row_index = plot_row_index if not sort_rows else row_order[plot_row_index, plot_column_index]
column_index = plot_column_index if not sort_rows else column_order[plot_row_index, plot_column_index]

text = self.fill_text(
row_index,
column_index,
row_parameter_index,
column_parameter_index,
label_samples=label_samples,
round_parameters=round_parameters
)

subplots[plot_row_index, plot_column_index].title.set_text(text)
subplots[plot_row_index, plot_column_index].plot(
self.context[column_index, row_index],
self.prior_predictive_samples[column_index, row_index]
)

figure.supylabel(y_label)
figure.supxlabel(x_label)
figure.suptitle(title)
4 changes: 3 additions & 1 deletion src/utils/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@
"coverage_sigma": 3 # How many sigma to show coverage over
},
"LC2ST": {},
"PPC":{}
"PPC": {},
"PriorPC":{}

},
"metrics_common": {
"use_progress_bar": False,
Expand Down
9 changes: 8 additions & 1 deletion tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
CoverageFraction,
TARP,
LocalTwoSampleTest,
PPC
PPC,
PriorPC
)


Expand Down Expand Up @@ -85,4 +86,10 @@ def test_ppc(plot_config, mock_model, mock_data, mock_2d_data, result_output):
out_dir=f"{result_output.strip('/')}/mock_2d/")
assert type(plot.data.simulator).__name__ == "Mock2DSimulator"
plot(**get_item("plots", "PPC", raise_exception=False))


def test_prior_pc(plot_config, mock_model, mock_data):
Config(plot_config)
plot = PriorPC(mock_model, mock_data, save=True, show=False)
plot(**get_item("plots", "PriorPC", raise_exception=False))
assert os.path.exists(f"{plot.out_dir}/{plot.plot_name}")

0 comments on commit 41cc428

Please sign in to comment.