Skip to content

Commit

Permalink
Done for metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
voetberg committed Jun 21, 2024
1 parent 0e8b23d commit f783b36
Show file tree
Hide file tree
Showing 13 changed files with 103 additions and 42 deletions.
13 changes: 7 additions & 6 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@ Metrics
.. autoclass:: metrics.metric.Metric
:members:

.. autoclass:: plots.CDFRanks
:private-members: _plot
.. autoclass:: metrics.AllSBC
:members: calculate

.. autoclass:: plots.Ranks
:private-members: _plot
.. autoclass:: metrics.LC2ST

.. autoclass:: plots.CoverageFraction
:private-members: _plot
.. autoclass:: metrics.local_two_sample.LocalTwoSampleTest
:members: calculate

.. autoclass:: metrics.CoverageFraction
:members: calculate

.. bibliography::
12 changes: 6 additions & 6 deletions docs/source/plots.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,22 @@ Plots
.. autoclass:: plots.plot.Display

.. autoclass:: plots.CDFRanks
:private-members: _plot
:members: plot

.. autoclass:: plots.Ranks
:private-members: _plot
:members: plot

.. autoclass:: plots.CoverageFraction
:private-members: _plot
:members: plot

.. autoclass:: plots.TARP
:private-members: _plot
:members: plot

.. autoclass:: plots.LC2ST
.. autoclass:: plots.local_two_sample.LocalTwoSampleTest
:private-members: _plot
:members: plot

.. autoclass:: plots.PPC
:private-members: _plot
:members: plot

.. bibliography::
4 changes: 2 additions & 2 deletions src/deepdiagnostics/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from deepdiagnostics.metrics.all_sbc import AllSBC
from deepdiagnostics.metrics.coverage_fraction import CoverageFraction
from deepdiagnostics.metrics.local_two_sample import LocalTwoSampleTest
from deepdiagnostics.metrics.local_two_sample import LocalTwoSampleTest as LC2ST

Metrics = {
CoverageFraction.__name__: CoverageFraction,
AllSBC.__name__: AllSBC,
"LC2ST": LocalTwoSampleTest
"LC2ST": LC2ST
}
20 changes: 19 additions & 1 deletion src/deepdiagnostics/metrics/all_sbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,18 @@


class AllSBC(Metric):
"""
Calculate SBC diagnostics metrics and add them to the output.
Adapted from :cite:p:`centero2020sbi`.
More information about specific metrics can be found `here <https://sbi-dev.github.io/sbi/tutorial/13_diagnostics_simulation_based_calibration/#a-shifted-posterior-mean>`_.
.. code-block:: python
from deepdiagnostics.metrics import AllSBC
metrics = AllSBC(model, data, save=False)()
metrics = metrics.output
"""
def __init__(
self,
model: Any,
Expand All @@ -29,7 +41,13 @@ def _collect_data_params(self):
self.thetas = tensor(self.data.get_theta_true())
self.context = tensor(self.data.true_context())

def calculate(self):
def calculate(self) -> dict[str, Sequence]:
"""
Calculate all SBC diagnostic metrics
Returns:
dict[str, Sequence]: Dictionary with all calculations, labeled by their name.
"""
ranks, dap_samples = run_sbc(
self.thetas,
self.context,
Expand Down
18 changes: 16 additions & 2 deletions src/deepdiagnostics/metrics/coverage_fraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,15 @@
from deepdiagnostics.metrics.metric import Metric

class CoverageFraction(Metric):
""" """
"""
Calculate the coverage of a set number of inferences over different confidence regions.
.. code-block:: python
from deepdiagnostics.metrics import CoverageFraction
samples, coverage = CoverageFraction(model, data, save=False).calculate()
"""

def __init__(
self,
Expand Down Expand Up @@ -36,7 +44,13 @@ def _run_model_inference(self, samples_per_inference, y_inference):
samples = self.model.sample_posterior(samples_per_inference, y_inference)
return samples

def calculate(self):
def calculate(self) -> tuple[Sequence, Sequence]:
"""
Calculate the coverage fraction of the given model and data
Returns:
tuple[Sequence, Sequence]: A tuple of the samples tested (M samples, Samples per inference, N parameters) and the coverage over those samples.
"""
all_samples = np.empty(
(len(self.context), self.samples_per_inference, np.shape(self.thetas)[1])
)
Expand Down
38 changes: 33 additions & 5 deletions src/deepdiagnostics/metrics/local_two_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,21 @@
from deepdiagnostics.metrics.metric import Metric

class LocalTwoSampleTest(Metric):
"""
Adapted from :cite:p:`linhart2023lc2st`.
Train a classifier to verify the quality of the posterior via classifier accuracy.
Produces an array of inference accuracies for the trained classier, representing the cases of either denying the null hypothesis
(that the posterior output of the simulation is not significantly different from a given random sample.)
Code referenced from:
`github.com/JuliaLinhart/lc2st/lc2st.py::train_lc2st <https://github.com/JuliaLinhart/lc2st/blob/e221cc326480cb0daadfd2ba50df4eefd374793b/lc2st/lc2st.py#L25>`_.
.. code-block:: python
from deepdiagnostics.metrics import LC2ST
true_probabilities, null_hypothesis_probabilities = LC2ST(model, data, save=False).calculate()
"""
def __init__(
self,
model: Any,
Expand All @@ -32,7 +47,6 @@ def __init__(
number_simulations
)


def _collect_data_params(self):
# P is the prior and x_P is generated via the simulator from the parameters P.
self.p = self.data.sample_prior(self.number_simulations)
Expand Down Expand Up @@ -74,7 +88,6 @@ def _collect_data_params(self):
self.outcome_given_q[index] = q_outcome.ravel() # Q is the approximate posterior amortized in x



def train_linear_classifier(
self, p, q, x_p, x_q, classifier: str, classifier_kwargs: dict = {}
):
Expand Down Expand Up @@ -188,7 +201,7 @@ def _cross_eval_score(
)
return probabilities

def permute_data(self, P, Q):
def _permute_data(self, P, Q):
"""Permute the concatenated data [P,Q] to create null-hyp samples.
Args:
Expand All @@ -206,7 +219,22 @@ def calculate(
cross_evaluate: bool = True,
n_null_hypothesis_trials=100,
classifier_kwargs: Union[dict, list[dict]] = None,
):
) -> tuple[np.ndarray, np.ndarray]:
"""
Perform the calculation for the LC2ST.
Adds the results to the lc2st.output (dict) under the parameters
"lc2st_probabilities", "lc2st_null_hypothesis_probabilities" as lists.
Args:
linear_classifier (Union[str, list[str]], optional): linear classifier to use for the test. Only MLP is implemented. Defaults to "MLP".
cross_evaluate (bool, optional): Use a k-fold'd dataset for evaluation. Defaults to True.
n_null_hypothesis_trials (int, optional): Number of times to draw and test the null hypothesis. Defaults to 100.
classifier_kwargs (Union[dict, list[dict]], optional): Additional kwargs for the linear classifier. Defaults to None.
Returns:
tuple[np.ndarray, np.ndarray]: arrays storing the true and null hypothesis probabilities given the linear classifier.
"""
if isinstance(linear_classifier, str):
linear_classifier = [linear_classifier]

Expand All @@ -228,7 +256,7 @@ def calculate(
for _ in range(n_null_hypothesis_trials):
joint_P_x = np.concatenate([self.p, self.outcome_given_p], axis=1)
joint_Q_x = np.concatenate([self.q, self.outcome_given_q], axis=1)
joint_P_x_perm, joint_Q_x_perm = self.permute_data(
joint_P_x_perm, joint_Q_x_perm = self._permute_data(
joint_P_x,
joint_Q_x,
)
Expand Down
6 changes: 3 additions & 3 deletions src/deepdiagnostics/plots/cdf_ranks.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(

super().__init__(model, data, save, show, out_dir, percentiles, use_progress_bar, samples_per_inference, number_simulations, parameter_names, parameter_colors, colorway)

def _plot_name(self):
def plot_name(self):
return "cdf_ranks.png"

def _data_setup(self):
Expand All @@ -49,10 +49,10 @@ def _data_setup(self):
)
self.ranks = ranks

def _plot_settings(self):
def plot_settings(self):
pass

def _plot(self):
def plot(self):
"""
"""
sbc_rank_plot(
Expand Down
4 changes: 2 additions & 2 deletions src/deepdiagnostics/plots/coverage_fraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(
self.n_parameters = len(self.parameter_names)
self.line_cycle = tuple(get_item("plots_common", "line_style_cycle", raise_exception=False))

def _plot_name(self):
def plot_name(self):
return "coverage_fraction.png"

def _data_setup(self):
Expand All @@ -52,7 +52,7 @@ def _data_setup(self):
).calculate()
self.coverage_fractions = coverage

def _plot(
def plot(
self,
figure_alpha=1.0,
line_width=3,
Expand Down
6 changes: 3 additions & 3 deletions src/deepdiagnostics/plots/local_two_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(
self.region_colors = get_hex_colors(n_colors=len(self.percentiles), colorway=self.colorway)
self.l2st = l2st(model, data, out_dir, True, self.use_progress_bar, self.samples_per_inference, self.percentiles, self.number_simulations)

def _plot_name(self):
def plot_name(self):
return "local_C2ST.png"

def _make_pairplot_values(self, random_samples):
Expand Down Expand Up @@ -182,7 +182,7 @@ def probability_intensity(self, subplot, features, n_bins=20):
)
subplot.add_patch(rect)

def _plot(
def plot(
self,
use_intensity_plot: bool = True,
n_alpha_samples: int = 100,
Expand Down Expand Up @@ -291,4 +291,4 @@ def _plot(
self._finish()

def __call__(self, **plot_args) -> None:
self._plot(**plot_args)
self.plot(**plot_args)
6 changes: 3 additions & 3 deletions src/deepdiagnostics/plots/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,14 @@ def __init__(
self._common_settings()
self.plot_name = self._plot_name()

def _plot_name(self):
def plot_name(self):
raise NotImplementedError

def _data_setup(self):
# Set all the vars used for the plot
raise NotImplementedError

def _plot(self, **kwrgs):
def plot(self, **kwrgs):
# Make the plot object with plt.
raise NotImplementedError

Expand Down Expand Up @@ -113,5 +113,5 @@ def __call__(self, **plot_args) -> None:
except NotImplementedError:
pass

self._plot(**plot_args)
self.plot(**plot_args)
self._finish()
8 changes: 4 additions & 4 deletions src/deepdiagnostics/plots/predictive_posterior_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(
):
super().__init__(model, data, save, show, out_dir, percentiles, use_progress_bar, samples_per_inference, number_simulations, parameter_names, parameter_colors, colorway)

def _plot_name(self):
def plot_name(self):
return "predictive_posterior_check.png"

def get_posterior_2d(self, n_simulator_draws):
Expand Down Expand Up @@ -91,7 +91,7 @@ def get_posterior_1d(self, n_simulator_draws):
theta=self.data.get_theta_true()[sample, :], context_samples=context_sample
)

def _plot_1d(self,
def plot_1d(self,
subplots: np.ndarray,
subplot_index: int,
n_coverage_sigma: Optional[int] = 3,
Expand Down Expand Up @@ -134,7 +134,7 @@ def _plot_1d(self,
label='Theta True'
)

def _plot_2d(self, subplots, subplot_index, include_axis_ticks):
def plot_2d(self, subplots, subplot_index, include_axis_ticks):
subplots[1, subplot_index].imshow(self.posterior_predictive_samples[subplot_index])
subplots[0, subplot_index].imshow(self.posterior_true_samples[subplot_index])

Expand All @@ -145,7 +145,7 @@ def _plot_2d(self, subplots, subplot_index, include_axis_ticks):
subplots[0, subplot_index].set_xticks([])
subplots[0, subplot_index].set_yticks([])

def _plot(
def plot(
self,
n_coverage_sigma: Optional[int] = 3,
true_sigma: Optional[float] = None,
Expand Down
4 changes: 2 additions & 2 deletions src/deepdiagnostics/plots/ranks.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(
):
super().__init__(model, data, save, show, out_dir, percentiles, use_progress_bar, samples_per_inference, number_simulations, parameter_names, parameter_colors, colorway)

def _plot_name(self):
def plot_name(self):
return "ranks.png"

def _data_setup(self):
Expand All @@ -48,7 +48,7 @@ def _data_setup(self):
)
self.ranks = ranks

def _plot(self, num_bins:int=20):
def plot(self, num_bins:int=20):
"""
Args:
num_bins (int): Number of histogram bins. Defaults to 20.
Expand Down
6 changes: 3 additions & 3 deletions src/deepdiagnostics/plots/tarp.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(
self.line_style = get_item(
"plots_common", "line_style_cycle", raise_exception=False
)
def _plot_name(self):
def plot_name(self):
return "tarp.png"

def _data_setup(self):
Expand All @@ -67,7 +67,7 @@ def _data_setup(self):

self.posterior_samples = np.swapaxes(self.posterior_samples, 0, 1)

def _plot_settings(self):
def plot_settings(self):
self.line_style = get_item(
"plots_common", "line_style_cycle", raise_exception=False
)
Expand All @@ -82,7 +82,7 @@ def _get_hex_sigma_colors(self, n_colors):

return hex_colors

def _plot(
def plot(
self,
coverage_sigma: int = 3,
reference_point: Union[str, np.ndarray] = "random",
Expand Down

0 comments on commit f783b36

Please sign in to comment.