From f783b36a8bb5876a92a7dabd6bc7fe2259dae53f Mon Sep 17 00:00:00 2001 From: voetberg Date: Fri, 21 Jun 2024 13:53:04 -0500 Subject: [PATCH] Done for metrics --- docs/source/metrics.rst | 13 ++++--- docs/source/plots.rst | 12 +++--- src/deepdiagnostics/metrics/__init__.py | 4 +- src/deepdiagnostics/metrics/all_sbc.py | 20 +++++++++- .../metrics/coverage_fraction.py | 18 ++++++++- .../metrics/local_two_sample.py | 38 ++++++++++++++++--- src/deepdiagnostics/plots/cdf_ranks.py | 6 +-- .../plots/coverage_fraction.py | 4 +- src/deepdiagnostics/plots/local_two_sample.py | 6 +-- src/deepdiagnostics/plots/plot.py | 6 +-- .../plots/predictive_posterior_check.py | 8 ++-- src/deepdiagnostics/plots/ranks.py | 4 +- src/deepdiagnostics/plots/tarp.py | 6 +-- 13 files changed, 103 insertions(+), 42 deletions(-) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 9bcd3e0..12f4241 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -6,14 +6,15 @@ Metrics .. autoclass:: metrics.metric.Metric :members: -.. autoclass:: plots.CDFRanks - :private-members: _plot +.. autoclass:: metrics.AllSBC + :members: calculate -.. autoclass:: plots.Ranks - :private-members: _plot +.. autoclass:: metrics.LC2ST -.. autoclass:: plots.CoverageFraction - :private-members: _plot +.. autoclass:: metrics.local_two_sample.LocalTwoSampleTest + :members: calculate +.. autoclass:: metrics.CoverageFraction + :members: calculate .. bibliography:: \ No newline at end of file diff --git a/docs/source/plots.rst b/docs/source/plots.rst index 17cad93..009f344 100644 --- a/docs/source/plots.rst +++ b/docs/source/plots.rst @@ -6,22 +6,22 @@ Plots .. autoclass:: plots.plot.Display .. autoclass:: plots.CDFRanks - :private-members: _plot + :members: plot .. autoclass:: plots.Ranks - :private-members: _plot + :members: plot .. autoclass:: plots.CoverageFraction - :private-members: _plot + :members: plot .. autoclass:: plots.TARP - :private-members: _plot + :members: plot .. autoclass:: plots.LC2ST .. autoclass:: plots.local_two_sample.LocalTwoSampleTest - :private-members: _plot + :members: plot .. autoclass:: plots.PPC - :private-members: _plot + :members: plot .. bibliography:: \ No newline at end of file diff --git a/src/deepdiagnostics/metrics/__init__.py b/src/deepdiagnostics/metrics/__init__.py index 9211289..5505faa 100644 --- a/src/deepdiagnostics/metrics/__init__.py +++ b/src/deepdiagnostics/metrics/__init__.py @@ -1,9 +1,9 @@ from deepdiagnostics.metrics.all_sbc import AllSBC from deepdiagnostics.metrics.coverage_fraction import CoverageFraction -from deepdiagnostics.metrics.local_two_sample import LocalTwoSampleTest +from deepdiagnostics.metrics.local_two_sample import LocalTwoSampleTest as LC2ST Metrics = { CoverageFraction.__name__: CoverageFraction, AllSBC.__name__: AllSBC, - "LC2ST": LocalTwoSampleTest + "LC2ST": LC2ST } diff --git a/src/deepdiagnostics/metrics/all_sbc.py b/src/deepdiagnostics/metrics/all_sbc.py index b56fc77..28a6dba 100644 --- a/src/deepdiagnostics/metrics/all_sbc.py +++ b/src/deepdiagnostics/metrics/all_sbc.py @@ -6,6 +6,18 @@ class AllSBC(Metric): + """ + Calculate SBC diagnostics metrics and add them to the output. + Adapted from :cite:p:`centero2020sbi`. + More information about specific metrics can be found `here `_. + + .. code-block:: python + + from deepdiagnostics.metrics import AllSBC + + metrics = AllSBC(model, data, save=False)() + metrics = metrics.output + """ def __init__( self, model: Any, @@ -29,7 +41,13 @@ def _collect_data_params(self): self.thetas = tensor(self.data.get_theta_true()) self.context = tensor(self.data.true_context()) - def calculate(self): + def calculate(self) -> dict[str, Sequence]: + """ + Calculate all SBC diagnostic metrics + + Returns: + dict[str, Sequence]: Dictionary with all calculations, labeled by their name. + """ ranks, dap_samples = run_sbc( self.thetas, self.context, diff --git a/src/deepdiagnostics/metrics/coverage_fraction.py b/src/deepdiagnostics/metrics/coverage_fraction.py index 4399809..04e4933 100644 --- a/src/deepdiagnostics/metrics/coverage_fraction.py +++ b/src/deepdiagnostics/metrics/coverage_fraction.py @@ -6,7 +6,15 @@ from deepdiagnostics.metrics.metric import Metric class CoverageFraction(Metric): - """ """ + """ + Calculate the coverage of a set number of inferences over different confidence regions. + + .. code-block:: python + + from deepdiagnostics.metrics import CoverageFraction + + samples, coverage = CoverageFraction(model, data, save=False).calculate() + """ def __init__( self, @@ -36,7 +44,13 @@ def _run_model_inference(self, samples_per_inference, y_inference): samples = self.model.sample_posterior(samples_per_inference, y_inference) return samples - def calculate(self): + def calculate(self) -> tuple[Sequence, Sequence]: + """ + Calculate the coverage fraction of the given model and data + + Returns: + tuple[Sequence, Sequence]: A tuple of the samples tested (M samples, Samples per inference, N parameters) and the coverage over those samples. + """ all_samples = np.empty( (len(self.context), self.samples_per_inference, np.shape(self.thetas)[1]) ) diff --git a/src/deepdiagnostics/metrics/local_two_sample.py b/src/deepdiagnostics/metrics/local_two_sample.py index dcb77f4..978b977 100644 --- a/src/deepdiagnostics/metrics/local_two_sample.py +++ b/src/deepdiagnostics/metrics/local_two_sample.py @@ -9,6 +9,21 @@ from deepdiagnostics.metrics.metric import Metric class LocalTwoSampleTest(Metric): + """ + Adapted from :cite:p:`linhart2023lc2st`. + Train a classifier to verify the quality of the posterior via classifier accuracy. + Produces an array of inference accuracies for the trained classier, representing the cases of either denying the null hypothesis + (that the posterior output of the simulation is not significantly different from a given random sample.) + + Code referenced from: + `github.com/JuliaLinhart/lc2st/lc2st.py::train_lc2st `_. + + .. code-block:: python + + from deepdiagnostics.metrics import LC2ST + + true_probabilities, null_hypothesis_probabilities = LC2ST(model, data, save=False).calculate() + """ def __init__( self, model: Any, @@ -32,7 +47,6 @@ def __init__( number_simulations ) - def _collect_data_params(self): # P is the prior and x_P is generated via the simulator from the parameters P. self.p = self.data.sample_prior(self.number_simulations) @@ -74,7 +88,6 @@ def _collect_data_params(self): self.outcome_given_q[index] = q_outcome.ravel() # Q is the approximate posterior amortized in x - def train_linear_classifier( self, p, q, x_p, x_q, classifier: str, classifier_kwargs: dict = {} ): @@ -188,7 +201,7 @@ def _cross_eval_score( ) return probabilities - def permute_data(self, P, Q): + def _permute_data(self, P, Q): """Permute the concatenated data [P,Q] to create null-hyp samples. Args: @@ -206,7 +219,22 @@ def calculate( cross_evaluate: bool = True, n_null_hypothesis_trials=100, classifier_kwargs: Union[dict, list[dict]] = None, - ): + ) -> tuple[np.ndarray, np.ndarray]: + """ + Perform the calculation for the LC2ST. + Adds the results to the lc2st.output (dict) under the parameters + "lc2st_probabilities", "lc2st_null_hypothesis_probabilities" as lists. + + Args: + linear_classifier (Union[str, list[str]], optional): linear classifier to use for the test. Only MLP is implemented. Defaults to "MLP". + cross_evaluate (bool, optional): Use a k-fold'd dataset for evaluation. Defaults to True. + n_null_hypothesis_trials (int, optional): Number of times to draw and test the null hypothesis. Defaults to 100. + classifier_kwargs (Union[dict, list[dict]], optional): Additional kwargs for the linear classifier. Defaults to None. + + Returns: + tuple[np.ndarray, np.ndarray]: arrays storing the true and null hypothesis probabilities given the linear classifier. + + """ if isinstance(linear_classifier, str): linear_classifier = [linear_classifier] @@ -228,7 +256,7 @@ def calculate( for _ in range(n_null_hypothesis_trials): joint_P_x = np.concatenate([self.p, self.outcome_given_p], axis=1) joint_Q_x = np.concatenate([self.q, self.outcome_given_q], axis=1) - joint_P_x_perm, joint_Q_x_perm = self.permute_data( + joint_P_x_perm, joint_Q_x_perm = self._permute_data( joint_P_x, joint_Q_x, ) diff --git a/src/deepdiagnostics/plots/cdf_ranks.py b/src/deepdiagnostics/plots/cdf_ranks.py index 236d592..ac4b2e7 100644 --- a/src/deepdiagnostics/plots/cdf_ranks.py +++ b/src/deepdiagnostics/plots/cdf_ranks.py @@ -37,7 +37,7 @@ def __init__( super().__init__(model, data, save, show, out_dir, percentiles, use_progress_bar, samples_per_inference, number_simulations, parameter_names, parameter_colors, colorway) - def _plot_name(self): + def plot_name(self): return "cdf_ranks.png" def _data_setup(self): @@ -49,10 +49,10 @@ def _data_setup(self): ) self.ranks = ranks - def _plot_settings(self): + def plot_settings(self): pass - def _plot(self): + def plot(self): """ """ sbc_rank_plot( diff --git a/src/deepdiagnostics/plots/coverage_fraction.py b/src/deepdiagnostics/plots/coverage_fraction.py index 8b33c01..42beead 100644 --- a/src/deepdiagnostics/plots/coverage_fraction.py +++ b/src/deepdiagnostics/plots/coverage_fraction.py @@ -43,7 +43,7 @@ def __init__( self.n_parameters = len(self.parameter_names) self.line_cycle = tuple(get_item("plots_common", "line_style_cycle", raise_exception=False)) - def _plot_name(self): + def plot_name(self): return "coverage_fraction.png" def _data_setup(self): @@ -52,7 +52,7 @@ def _data_setup(self): ).calculate() self.coverage_fractions = coverage - def _plot( + def plot( self, figure_alpha=1.0, line_width=3, diff --git a/src/deepdiagnostics/plots/local_two_sample.py b/src/deepdiagnostics/plots/local_two_sample.py index 1f99bbe..4814979 100644 --- a/src/deepdiagnostics/plots/local_two_sample.py +++ b/src/deepdiagnostics/plots/local_two_sample.py @@ -57,7 +57,7 @@ def __init__( self.region_colors = get_hex_colors(n_colors=len(self.percentiles), colorway=self.colorway) self.l2st = l2st(model, data, out_dir, True, self.use_progress_bar, self.samples_per_inference, self.percentiles, self.number_simulations) - def _plot_name(self): + def plot_name(self): return "local_C2ST.png" def _make_pairplot_values(self, random_samples): @@ -182,7 +182,7 @@ def probability_intensity(self, subplot, features, n_bins=20): ) subplot.add_patch(rect) - def _plot( + def plot( self, use_intensity_plot: bool = True, n_alpha_samples: int = 100, @@ -291,4 +291,4 @@ def _plot( self._finish() def __call__(self, **plot_args) -> None: - self._plot(**plot_args) + self.plot(**plot_args) diff --git a/src/deepdiagnostics/plots/plot.py b/src/deepdiagnostics/plots/plot.py index 62024e6..7813fe5 100644 --- a/src/deepdiagnostics/plots/plot.py +++ b/src/deepdiagnostics/plots/plot.py @@ -66,14 +66,14 @@ def __init__( self._common_settings() self.plot_name = self._plot_name() - def _plot_name(self): + def plot_name(self): raise NotImplementedError def _data_setup(self): # Set all the vars used for the plot raise NotImplementedError - def _plot(self, **kwrgs): + def plot(self, **kwrgs): # Make the plot object with plt. raise NotImplementedError @@ -113,5 +113,5 @@ def __call__(self, **plot_args) -> None: except NotImplementedError: pass - self._plot(**plot_args) + self.plot(**plot_args) self._finish() diff --git a/src/deepdiagnostics/plots/predictive_posterior_check.py b/src/deepdiagnostics/plots/predictive_posterior_check.py index 6f836c0..b56a798 100644 --- a/src/deepdiagnostics/plots/predictive_posterior_check.py +++ b/src/deepdiagnostics/plots/predictive_posterior_check.py @@ -35,7 +35,7 @@ def __init__( ): super().__init__(model, data, save, show, out_dir, percentiles, use_progress_bar, samples_per_inference, number_simulations, parameter_names, parameter_colors, colorway) - def _plot_name(self): + def plot_name(self): return "predictive_posterior_check.png" def get_posterior_2d(self, n_simulator_draws): @@ -91,7 +91,7 @@ def get_posterior_1d(self, n_simulator_draws): theta=self.data.get_theta_true()[sample, :], context_samples=context_sample ) - def _plot_1d(self, + def plot_1d(self, subplots: np.ndarray, subplot_index: int, n_coverage_sigma: Optional[int] = 3, @@ -134,7 +134,7 @@ def _plot_1d(self, label='Theta True' ) - def _plot_2d(self, subplots, subplot_index, include_axis_ticks): + def plot_2d(self, subplots, subplot_index, include_axis_ticks): subplots[1, subplot_index].imshow(self.posterior_predictive_samples[subplot_index]) subplots[0, subplot_index].imshow(self.posterior_true_samples[subplot_index]) @@ -145,7 +145,7 @@ def _plot_2d(self, subplots, subplot_index, include_axis_ticks): subplots[0, subplot_index].set_xticks([]) subplots[0, subplot_index].set_yticks([]) - def _plot( + def plot( self, n_coverage_sigma: Optional[int] = 3, true_sigma: Optional[float] = None, diff --git a/src/deepdiagnostics/plots/ranks.py b/src/deepdiagnostics/plots/ranks.py index d689c15..746cf71 100644 --- a/src/deepdiagnostics/plots/ranks.py +++ b/src/deepdiagnostics/plots/ranks.py @@ -37,7 +37,7 @@ def __init__( ): super().__init__(model, data, save, show, out_dir, percentiles, use_progress_bar, samples_per_inference, number_simulations, parameter_names, parameter_colors, colorway) - def _plot_name(self): + def plot_name(self): return "ranks.png" def _data_setup(self): @@ -48,7 +48,7 @@ def _data_setup(self): ) self.ranks = ranks - def _plot(self, num_bins:int=20): + def plot(self, num_bins:int=20): """ Args: num_bins (int): Number of histogram bins. Defaults to 20. diff --git a/src/deepdiagnostics/plots/tarp.py b/src/deepdiagnostics/plots/tarp.py index e2e3740..ba39fd6 100644 --- a/src/deepdiagnostics/plots/tarp.py +++ b/src/deepdiagnostics/plots/tarp.py @@ -45,7 +45,7 @@ def __init__( self.line_style = get_item( "plots_common", "line_style_cycle", raise_exception=False ) - def _plot_name(self): + def plot_name(self): return "tarp.png" def _data_setup(self): @@ -67,7 +67,7 @@ def _data_setup(self): self.posterior_samples = np.swapaxes(self.posterior_samples, 0, 1) - def _plot_settings(self): + def plot_settings(self): self.line_style = get_item( "plots_common", "line_style_cycle", raise_exception=False ) @@ -82,7 +82,7 @@ def _get_hex_sigma_colors(self, n_colors): return hex_colors - def _plot( + def plot( self, coverage_sigma: int = 3, reference_point: Union[str, np.ndarray] = "random",