diff --git a/pyproject.toml b/pyproject.toml index eea7716..095de07 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ numpy = "^1.26.4" matplotlib = "^3.8.3" tarp = "^0.1.1" deprecation = "^2.1.0" +scipy = "1.12.0" [tool.poetry.group.dev.dependencies] diff --git a/src/client/client.py b/src/client/client.py index 063b1d3..0ffe68f 100644 --- a/src/client/client.py +++ b/src/client/client.py @@ -97,9 +97,15 @@ def main(): plots = config.get_section("plots", raise_exception=False) for metrics_name, metrics_args in metrics.items(): - Metrics[metrics_name](model, data, **metrics_args)() + try: + Metrics[metrics_name](model, data, **metrics_args)() + except (NotImplementedError, RuntimeError) as error: + print(f"WARNING - skipping metric {metrics_name} due to error: {error}") for plot_name, plot_args in plots.items(): - Plots[plot_name](model, data, save=True, show=False, out_dir=out_dir)( - **plot_args - ) + try: + Plots[plot_name](model, data, save=True, show=False, out_dir=out_dir)( + **plot_args + ) + except (NotImplementedError, RuntimeError) as error: + print(f"WARNING - skipping plot {plot_name} due to error: {error}") diff --git a/src/data/data.py b/src/data/data.py index 129877d..a022d6f 100644 --- a/src/data/data.py +++ b/src/data/data.py @@ -125,4 +125,4 @@ def load_prior(self, prior, prior_kwargs): return lambda size: choices[prior](**prior_kwargs, size=size) except KeyError as e: - raise RuntimeError(f"Data missing a prior specification - {e}") + raise RuntimeError(f"Data missing a prior specification - {e}") \ No newline at end of file diff --git a/src/data/h5_data.py b/src/data/h5_data.py index 80ddac0..c10b4a5 100644 --- a/src/data/h5_data.py +++ b/src/data/h5_data.py @@ -10,7 +10,8 @@ class H5Data(Data): def __init__(self, path: str, simulator: Callable): super().__init__(path, simulator) - + self.theta_true = self.get_theta_true() + def _load(self, path): assert path.split(".")[-1] == "h5", "File extension must be h5" loaded_data = {} diff --git a/src/metrics/__init__.py b/src/metrics/__init__.py index 6d58c90..450669f 100644 --- a/src/metrics/__init__.py +++ b/src/metrics/__init__.py @@ -1,4 +1,7 @@ from metrics.all_sbc import AllSBC from metrics.coverage_fraction import CoverageFraction +from metrics.local_two_sample import LocalTwoSampleTest -Metrics = {CoverageFraction.__name__: CoverageFraction, AllSBC.__name__: AllSBC} + +_all = [CoverageFraction, AllSBC, LocalTwoSampleTest] +Metrics = {m.__name__: m for m in _all} diff --git a/src/metrics/local_two_sample.py b/src/metrics/local_two_sample.py new file mode 100644 index 0000000..e078670 --- /dev/null +++ b/src/metrics/local_two_sample.py @@ -0,0 +1,175 @@ +from typing import Any, Optional, Union +import numpy as np + +from sklearn.model_selection import KFold +from sklearn.neural_network import MLPClassifier +from sklearn.utils import shuffle + +from metrics.metric import Metric +from utils.config import get_item + +class LocalTwoSampleTest(Metric): + def __init__(self, model: Any, data: Any, out_dir: str | None = None, num_simulations: Optional[int] = None) -> None: + super().__init__(model, data, out_dir) + self.num_simulations = num_simulations if num_simulations is not None else get_item( + "metrics_common", "number_simulations", raise_exception=False + ) + def _collect_data_params(self): + + # P is the prior and x_P is generated via the simulator from the parameters P. 
+ self.p = self.data.sample_prior(self.num_simulations) + self.q = np.zeros_like(self.p) + + self.outcome_given_p = np.zeros((self.num_simulations, self.data.simulator.generate_context().shape[-1])) + self.outcome_given_q = np.zeros_like(self.outcome_given_p) + self.evaluation_context = np.zeros_like(self.outcome_given_p) + + for index, p in enumerate(self.p): + context = self.data.simulator.generate_context() + self.outcome_given_p[index] = self.data.simulator.simulate(p, context) + # Q is the approximate posterior amortized in x + q = self.model.sample_posterior(1, context).ravel() + self.q[index] = q + self.outcome_given_q[index] = self.data.simulator.simulate(q, context) + + self.evaluation_context = np.array([self.data.simulator.generate_context() for _ in range(self.num_simulations)]) + + def train_linear_classifier(self, p, q, x_p, x_q, classifier:str, classifier_kwargs:dict={}): + classifier_map = { + "MLP":MLPClassifier + } + try: + classifier = classifier_map[classifier](**classifier_kwargs) + except KeyError: + raise NotImplementedError( + f"{classifier} not implemented, choose from {list(classifier_map.keys())}.") + + joint_P_x = np.concatenate([p, x_p], axis=1) + joint_Q_x = np.concatenate([q, x_q], axis=1) + + features = np.concatenate([joint_P_x, joint_Q_x], axis=0) + labels = np.concatenate( + [np.array([0] * len(joint_P_x)), np.array([1] * len(joint_Q_x))] + ).ravel() + + # shuffle features and labels + features, labels = shuffle(features, labels) + + # train the classifier + classifier.fit(X=features, y=labels) + return classifier + + def _eval_model(self, P, evaluation_sample, classifier): + evaluation = np.concatenate([P, evaluation_sample], axis=1) + probability = classifier.predict_proba(evaluation)[:, 0] + return probability + + def _scores(self, p, q, x_p, x_q, classifier, cross_evaluate: bool=True, classifier_kwargs=None): + model_probabilities = [] + for model, model_args in zip(classifier, classifier_kwargs): + if cross_evaluate: + model_probabilities.append(self._cross_eval_score(p, q, x_p, x_q, model, model_args)) + else: + trained_model = self.train_linear_classifier(p, q, x_p, x_q, model, model_args) + model_probabilities.append(self._eval_model(P=p, classifier=trained_model)) + + return np.mean(model_probabilities, axis=0) + + def _cross_eval_score(self, p, q, x_p, x_q, classifier, classifier_kwargs, n_cross_folds=5): + kf = KFold(n_splits=n_cross_folds, shuffle=True, random_state=42) # Getting the shape + cv_splits = kf.split(p) + # train classifiers over cv-folds + probabilities = [] + self.evaluation_data = np.zeros((n_cross_folds, len(next(cv_splits)[1]), self.evaluation_context.shape[-1])) + + kf = KFold(n_splits=n_cross_folds, shuffle=True, random_state=42) + cv_splits = kf.split(p) + for cross_trial, (train_index, val_index) in enumerate(cv_splits): + # get train split + p_train, x_p_train = p[train_index,:], x_p[train_index,:] + q_train, x_q_train = q[train_index,:], x_q[train_index,:] + trained_nth_classifier = self.train_linear_classifier(p_train, q_train, x_p_train, x_q_train, classifier, classifier_kwargs) + p_evaluate = p[val_index] + for index, p_validation in enumerate(p_evaluate): + self.evaluation_data[cross_trial][index] = self.data.simulator.simulate( + p_validation, self.evaluation_context[val_index][index] + ) + probabilities.append(self._eval_model(p_evaluate, self.evaluation_data[cross_trial], trained_nth_classifier)) + return probabilities + + def permute_data(self, P, Q): + """Permute the concatenated data [P,Q] to create null-hyp 
samples. + + Args: + P (torch.Tensor): data of shape (n_samples, dim) + Q (torch.Tensor): data of shape (n_samples, dim) + """ + n_samples = P.shape[0] + X = np.concatenate([P, Q], axis=0) + X_perm = X[self.data.rng.permutation(np.arange(n_samples * 2))] + return X_perm[:n_samples], X_perm[n_samples:] + + def calculate( + self, + linear_classifier:Union[str, list[str]]='MLP', + cross_evaluate:bool=True, + n_null_hypothesis_trials=100, + classifier_kwargs:Union[dict, list[dict]]=None + ): + + if isinstance(linear_classifier, str): + linear_classifier = [linear_classifier] + + if classifier_kwargs is None: + classifier_kwargs = {} + if isinstance(classifier_kwargs, dict): + classifier_kwargs = [classifier_kwargs] + + probabilities = self._scores( + self.p, + self.q, + self.outcome_given_p, + self.outcome_given_q, + classifier=linear_classifier, + cross_evaluate=cross_evaluate, + classifier_kwargs=classifier_kwargs + ) + null_hypothesis_probabilities = [] + for _ in range(n_null_hypothesis_trials): + joint_P_x = np.concatenate([self.p, self.outcome_given_p], axis=1) + joint_Q_x = np.concatenate([self.q, self.outcome_given_q], axis=1) + joint_P_x_perm, joint_Q_x_perm = self.permute_data( + joint_P_x, joint_Q_x, + ) + p_null = joint_P_x_perm[:, : self.p.shape[-1]] + p_given_x_null = joint_P_x_perm[:, self.p.shape[-1] :] + q_null = joint_Q_x_perm[:, : self.q.shape[-1]] + q_given_x_null = joint_Q_x_perm[:, self.q.shape[-1] :] + + null_result = self._scores( + p_null, + q_null, + p_given_x_null, + q_given_x_null, + classifier=linear_classifier, + cross_evaluate=cross_evaluate, + classifier_kwargs=classifier_kwargs + ) + + null_hypothesis_probabilities.append(null_result) + + null = np.array(null_hypothesis_probabilities) + self.output = { + "lc2st_probabilities": probabilities, + "lc2st_null_hypothesis_probabilities": null + } + return probabilities, null + + def __call__(self, **kwds: Any) -> Any: + try: + self._collect_data_params() + except NotImplementedError: + pass + + self.calculate(**kwds) + self._finish() \ No newline at end of file diff --git a/src/models/sbi_model.py b/src/models/sbi_model.py index 9085402..a244da1 100644 --- a/src/models/sbi_model.py +++ b/src/models/sbi_model.py @@ -24,6 +24,8 @@ def sample_posterior(self, n_samples: int, y_true): # TODO typing def predict_posterior(self, data): posterior_samples = self.sample_posterior(data.y_true) posterior_predictive_samples = data.simulator( - data.theta_true(), posterior_samples + data.get_theta_true(), posterior_samples ) return posterior_predictive_samples + + \ No newline at end of file diff --git a/src/plots/__init__.py b/src/plots/__init__.py index f576bd7..b186bc2 100644 --- a/src/plots/__init__.py +++ b/src/plots/__init__.py @@ -1,11 +1,8 @@ from plots.cdf_ranks import CDFRanks from plots.coverage_fraction import CoverageFraction from plots.ranks import Ranks +from plots.local_two_sample import LocalTwoSampleTest from plots.tarp import TARP -Plots = { - CDFRanks.__name__: CDFRanks, - CoverageFraction.__name__: CoverageFraction, - Ranks.__name__: Ranks, - TARP.__name__: TARP, -} +_all = [CoverageFraction, CDFRanks, Ranks, LocalTwoSampleTest, TARP] +Plots = {m.__name__: m for m in _all} \ No newline at end of file diff --git a/src/plots/local_two_sample.py b/src/plots/local_two_sample.py new file mode 100644 index 0000000..0923763 --- /dev/null +++ b/src/plots/local_two_sample.py @@ -0,0 +1,212 @@ +from typing import Optional, Sequence, Union +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.colors import 
Normalize +from matplotlib.patches import Rectangle + +from plots.plot import Display +from metrics.local_two_sample import LocalTwoSampleTest as l2st +from utils.config import get_item +from utils.plotting_utils import get_hex_colors + +class LocalTwoSampleTest(Display): + + # https://github.com/JuliaLinhart/lc2st/blob/e221cc326480cb0daadfd2ba50df4eefd374793b/lc2st/graphical_diagnostics.py#L133 + + def __init__(self, + model, + data, + save:bool, + show:bool, + out_dir:Optional[str]=None, + percentiles: Optional[Sequence] = None, + parameter_names: Optional[Sequence] = None, + parameter_colors: Optional[Sequence]= None, + figure_size: Optional[Sequence] = None, + num_simulations: Optional[int] = None, + colorway: Optional[str]=None): + super().__init__(model, data, save, show, out_dir) + self.percentiles = percentiles if percentiles is not None else get_item("metrics_common", item='percentiles', raise_exception=False) + + self.param_names = parameter_names if parameter_names is not None else get_item("plots_common", item="parameter_labels", raise_exception=False) + self.param_colors = parameter_colors if parameter_colors is not None else get_item("plots_common", item="parameter_colors", raise_exception=False) + self.figure_size = figure_size if figure_size is not None else get_item("plots_common", item="figure_size", raise_exception=False) + + colorway = colorway if colorway is not None else get_item( + "plots_common", "default_colorway", raise_exception=False + ) + self.region_colors = get_hex_colors(n_colors=len(self.percentiles), colorway=colorway) + + num_simulations = num_simulations if num_simulations is not None else get_item( + "metrics_common", "number_simulations", raise_exception=False + ) + self.l2st = l2st(model, data, out_dir, num_simulations) + + def _plot_name(self): + return "local_C2ST.png" + + def _make_pairplot_values(self, random_samples): + pp_vals = np.array([np.mean(random_samples <= alpha) for alpha in self.cdf_alphas]) + return pp_vals + + def lc2st_pairplot(self, subplot, confidence_region_alpha=0.2): + + null_cdf = self._make_pairplot_values([0.5] * len(self.probability)) + subplot.plot( + self.cdf_alphas, null_cdf, "--", color="black", label="Theoretical Null CDF" + ) + + null_hypothesis_pairplot = np.zeros((len(self.cdf_alphas), *null_cdf.shape)) + + for t in range(len(self.null_hypothesis_probability)): + null_hypothesis_pairplot[t] = self._make_pairplot_values(self.null_hypothesis_probability[t]) + + + for percentile, color in zip(self.percentiles, self.region_colors): + low_null = np.quantile(null_hypothesis_pairplot, percentile/100, axis=1) + up_null = np.quantile(null_hypothesis_pairplot, (100-percentile)/100, axis=1) + + subplot.fill_between( + self.cdf_alphas, + low_null, + up_null, + color=color, + alpha=confidence_region_alpha, + label=f"{percentile}% Conf. 
region", + ) + + for prob, label, color in zip(self.probability, self.param_names, self.param_colors): + pairplot_values = self._make_pairplot_values(prob) + subplot.plot(self.cdf_alphas, pairplot_values, label=label, color=color) + + def probability_intensity(self, subplot, plot_dims, features, n_bins=20): + evaluation_data = self.l2st.evaluation_data + + if len(evaluation_data.shape) >=3: # Used the kfold option + evaluation_data = evaluation_data.reshape(( + evaluation_data.shape[0]*evaluation_data.shape[1], + evaluation_data.shape[-1])) + self.probability = self.probability.ravel() + + if plot_dims==1: + + _, bins, patches = subplot.hist( + evaluation_data[:,features], n_bins, weights=self.probability, density=True, color=self.param_colors[features]) + + eval_bins = np.select( + [evaluation_data[:,features] <= i for i in bins[1:]], list(range(n_bins)) + ) + + # get mean predicted proba for each bin + weights = np.array([self.probability[eval_bins==i].mean() for i in np.unique(eval_bins)]) #df_probas.groupby(["bins"]).mean().probas + colors = plt.get_cmap(self.colorway) + + for w, p in zip(weights, patches): + p.set_facecolor(colors(w)) # color is mean predicted proba + + else: + + _, x_edges, y_edges, patches = subplot.hist2d( + evaluation_data[:,features[0]], + evaluation_data[:,features[1]], + n_bins, + density=True, color=self.param_colors[features[0]]) + + eval_bins_dim_1 = np.select( + [evaluation_data[:,features[0]] <= i for i in x_edges[1:]], list(range(n_bins)) + ) + eval_bins_dim_2 = np.select( + [evaluation_data[:,features[1]] <= i for i in y_edges[1:]], list(range(n_bins)) + ) + + colors = plt.get_cmap(self.colorway) + + weights = np.empty((n_bins, n_bins)) + for i in range(n_bins): + for j in range(n_bins): + try: + weights[i, j] = self.probability[np.logical_and(eval_bins_dim_1==i, eval_bins_dim_2==j)].mean() + except KeyError: + pass + + for i in range(len(x_edges) - 1): + for j in range(len(y_edges) - 1): + weight = weights[i,j] + facecolor = colors(weight) + # if no sample in bin, set color to white + if weight == np.nan: + facecolor = "white" + rect = Rectangle( + (x_edges[i], y_edges[j]), + x_edges[i + 1] - x_edges[i], + y_edges[j + 1] - y_edges[j], + facecolor=facecolor, + edgecolor="none", + ) + subplot.add_patch(rect) + + + def _plot(self, + use_intensity_plot:bool=True, + n_alpha_samples:int=100, + confidence_region_alpha:float=0.2, + n_intensity_bins:int=20, + intensity_dimension:int=2, + intensity_feature_index:Union[int, Sequence[int]]=[0,1], + linear_classifier:Union[str, list[str]]='MLP', + cross_evaluate:bool=True, + n_null_hypothesis_trials=100, + classifier_kwargs:Union[dict, list[dict]]=None, + y_label="Empirical CDF", + x_label="", + title="Local Classifier 2-Sample Test" + ): + + if use_intensity_plot: + if intensity_dimension not in (1, 2): + raise NotImplementedError("LC2ST Intensity Plot only implemented in 1D and 2D") + + if intensity_dimension == 1: + try: + int(intensity_feature_index) + except TypeError: + raise ValueError(f"Cannot use {intensity_feature_index} to plot, please supply an integer value index.") + + else: + try: + assert len(intensity_feature_index) == intensity_dimension + int(intensity_feature_index[0]) + int(intensity_feature_index[1]) + except (AssertionError, TypeError): + raise ValueError(f"Cannot use {intensity_feature_index} to plot, please supply a list of 2 integer value indices.") + + self.l2st(**{ + "linear_classifier":linear_classifier, + "cross_evaluate": cross_evaluate, + "n_null_hypothesis_trials": 
n_null_hypothesis_trials, + "classifier_kwargs": classifier_kwargs}) + + self.probability, self.null_hypothesis_probability = self.l2st.output["lc2st_probabilities"], self.l2st.output["lc2st_null_hypothesis_probabilities"] + + # Plots to make - + # pp_plot_lc2st: https://github.com/JuliaLinhart/lc2st/blob/e221cc326480cb0daadfd2ba50df4eefd374793b/lc2st/graphical_diagnostics.py#L49 + # eval_space_with_proba_intensity: https://github.com/JuliaLinhart/lc2st/blob/e221cc326480cb0daadfd2ba50df4eefd374793b/lc2st/graphical_diagnostics.py#L133 + + n_plots = 1 if not use_intensity_plot else 2 + figure_size = self.figure_size if n_plots==1 else (int(self.figure_size[0]*1.8),self.figure_size[1]) + fig, subplots = plt.subplots(1, n_plots, figsize=figure_size) + self.cdf_alphas = np.linspace(0, 1, n_alpha_samples) + + self.lc2st_pairplot(subplots[0] if n_plots == 2 else subplots, confidence_region_alpha=confidence_region_alpha) + if use_intensity_plot: + self.probability_intensity( + subplots[1], + intensity_dimension, + n_bins=n_intensity_bins, + features=intensity_feature_index + ) + + fig.legend() + fig.supylabel(y_label) + fig.supxlabel(x_label) + fig.suptitle(title) \ No newline at end of file diff --git a/src/plots/plot.py b/src/plots/plot.py index 0448800..e3ac508 100644 --- a/src/plots/plot.py +++ b/src/plots/plot.py @@ -27,7 +27,6 @@ def __init__( self.model = model self._common_settings() - self._plot_settings() self.plot_name = self._plot_name() def _plot_name(self): @@ -77,6 +76,14 @@ def _finish(self): plt.cla() def __call__(self, **plot_args) -> None: - self._data_setup() + try: + self._data_setup() + except NotImplementedError: + pass + try: + self._plot_settings() + except NotImplementedError: + pass + self._plot(**plot_args) self._finish() diff --git a/src/utils/defaults.py b/src/utils/defaults.py index 3e5a1ed..3073bdd 100644 --- a/src/utils/defaults.py +++ b/src/utils/defaults.py @@ -10,7 +10,10 @@ "data_engine": "H5Data", "prior":"normal", "prior_kwargs": None, - "simulator_kwargs": None, + "simulator_kwargs": None, + "prior": "normal", + "prior_kwargs":{} + }, "plots_common": { "axis_spines": False, @@ -26,6 +29,7 @@ "CDFRanks": {}, "Ranks": {"num_bins": None}, "CoverageFraction": {}, + "LocalTwoSampleTest":{}, "TARP": { "coverage_sigma": 3 # How many sigma to show coverage over }, @@ -39,5 +43,9 @@ "metrics": { "AllSBC": {}, "CoverageFraction": {}, + "LocalTwoSampleTest":{ + "linear_classifier":"MLP", + "classifier_kwargs":{"alpha":0, "max_iter":2500} + } }, } diff --git a/src/utils/plotting_utils.py b/src/utils/plotting_utils.py new file mode 100644 index 0000000..dc138d6 --- /dev/null +++ b/src/utils/plotting_utils.py @@ -0,0 +1,11 @@ +import numpy as np +import matplotlib as mpl + +def get_hex_colors(n_colors:int, colorway:str): + cmap = mpl.pyplot.get_cmap(colorway) + hex_colors = [] + arr=np.linspace(0, 1, n_colors) + for hit in arr: + hex_colors.append(mpl.colors.rgb2hex(cmap(hit))) + + return hex_colors \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 094fbb6..26b8af2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,8 +9,8 @@ class MockSimulator(Simulator): - def generate_context(self, n_samples: int) -> np.ndarray: - return np.linspace(0, 100, n_samples) + def generate_context(self, n_samples=None) -> np.ndarray: + return np.linspace(0, 100, 101) def simulate(self, theta: np.ndarray, context_samples: np.ndarray) -> np.ndarray: thetas = np.atleast_2d(theta) diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py deleted 
file mode 100644 index 3ed6116..0000000 --- a/tests/test_evaluate.py +++ /dev/null @@ -1,185 +0,0 @@ -import sys -import pytest -import torch -import numpy as np -import sbi -import os - -# flake8: noqa -#sys.path.append("..") -print(sys.path) -from scripts.evaluate import Diagnose_static, Diagnose_generative -from scripts.io import ModelLoader -#from src.scripts import evaluate - - -""" -""" - - -""" -Test the evaluate module -""" - - -@pytest.fixture -def diagnose_static_instance(): - return Diagnose_static() - -@pytest.fixture -def diagnose_generative_instance(): - return Diagnose_generative() - - -@pytest.fixture -def posterior_generative_sbi_model(): - # create a temporary directory for the saved model - #dir = "savedmodels/sbi/" - #os.makedirs(dir) - - # now save the model - low_bounds = torch.tensor([0, -10]) - high_bounds = torch.tensor([10, 10]) - - prior = sbi.utils.BoxUniform(low = low_bounds, high = high_bounds) - - posterior = sbi.inference.base.infer(simulator, prior, "SNPE", num_simulations=10000) - - # Provide the posterior to the tests - yield prior, posterior - - # Teardown: Remove the temporary directory and its contents - #shutil.rmtree(dataset_dir) - -@pytest.fixture -def setup_plot_dir(): - # create a temporary directory for the saved model - dir = "tests/plots/" - os.makedirs(dir) - yield dir - -def simulator(thetas): # , percent_errors): - # convert to numpy array (if tensor): - thetas = np.atleast_2d(thetas) - # Check if the input has the correct shape - if thetas.shape[1] != 2: - raise ValueError( - "Input tensor must have shape (n, 2) \ - where n is the number of parameter sets." - ) - - # Unpack the parameters - if thetas.shape[0] == 1: - # If there's only one set of parameters, extract them directly - m, b = thetas[0, 0], thetas[0, 1] - else: - # If there are multiple sets of parameters, extract them for each row - m, b = thetas[:, 0], thetas[:, 1] - x = np.linspace(0, 100, 101) - rs = np.random.RandomState() # 2147483648)# - # I'm thinking sigma could actually be a function of x - # if we want to get fancy down the road - # Generate random noise (epsilon) based - # on a normal distribution with mean 0 and standard deviation sigma - sigma = 5 - ε = rs.normal(loc=0, scale=sigma, size=(len(x), thetas.shape[0])) - - # Initialize an empty array to store the results for each set of parameters - y = np.zeros((len(x), thetas.shape[0])) - for i in range(thetas.shape[0]): - m, b = thetas[i, 0], thetas[i, 1] - y[:, i] = m * x + b + ε[:, i] - return torch.Tensor(y.T) - - -def test_generate_sbc_samples(diagnose_generative_instance, - posterior_generative_sbi_model): - # Mock data - #low_bounds = torch.tensor([0, -10]) - #high_bounds = torch.tensor([10, 10]) - - #prior = sbi.utils.BoxUniform(low=low_bounds, high=high_bounds) - prior, posterior = posterior_generative_sbi_model - #inference_instance # provide a mock posterior object - simulator_test = simulator # provide a mock simulator function - num_sbc_runs = 1000 - num_posterior_samples = 1000 - - # Generate SBC samples - thetas, ys, ranks, dap_samples = diagnose_generative_instance.generate_sbc_samples( - prior, posterior, simulator_test, num_sbc_runs, num_posterior_samples - ) - - # Add assertions based on the expected behavior of the method - - -def test_run_all_sbc(diagnose_generative_instance, - posterior_generative_sbi_model, - setup_plot_dir): - labels_list = ["$m$", "$b$"] - colorlist = ["#9C92A3", "#0F5257"] - - prior, posterior = posterior_generative_sbi_model - simulator_test = simulator # provide a mock 
simulator function - - save_path = setup_plot_dir - - diagnose_generative_instance.run_all_sbc( - prior, - posterior, - simulator_test, - labels_list, - colorlist, - num_sbc_runs=1_000, - num_posterior_samples=1_000, - samples_per_inference=1_000, - plot=False, - save=True, - path=save_path, - ) - # Check if PDF files were saved - assert os.path.exists(save_path), f"No 'plots' folder found at {save_path}" - - # List all files in the directory - files_in_directory = os.listdir(save_path) - - # Check if at least one PDF file is present - pdf_files = [file for file in files_in_directory if file.endswith(".pdf")] - assert pdf_files, "No PDF files found in the 'plots' folder" - - # We expect the pdfs to exist in the directory - expected_pdf_files = ["sbc_ranks.pdf", "sbc_ranks_cdf.pdf", "coverage.pdf"] - for expected_file in expected_pdf_files: - assert ( - expected_file in pdf_files - ), f"Expected PDF file '{expected_file}' not found" - - -""" -def test_sbc_statistics(diagnose_instance): - # Mock data - ranks = # provide mock ranks - thetas = # provide mock thetas - dap_samples = # provide mock dap_samples - num_posterior_samples = 1000 - - # Calculate SBC statistics - check_stats = diagnose_instance.sbc_statistics( - ranks, thetas, dap_samples, num_posterior_samples - ) - - # Add assertions based on the expected behavior of the method - -def test_plot_1d_ranks(diagnose_instance): - # Mock data - ranks = # provide mock ranks - num_posterior_samples = 1000 - labels_list = # provide mock labels_list - colorlist = # provide mock colorlist - - # Plot 1D ranks - diagnose_instance.plot_1d_ranks( - ranks, num_posterior_samples, labels_list, - colorlist, plot=False, save=False - ) -""" diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 1cec089..127ae28 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -6,7 +6,8 @@ from metrics import ( Metrics, CoverageFraction, - AllSBC + AllSBC, + LocalTwoSampleTest ) @pytest.fixture @@ -31,7 +32,6 @@ def test_all_defaults(metric_config, mock_model, mock_data): Ensures each metric has a default set of parameters and is included in the defaults list Ensures each test can initialize, regardless of the veracity of the output """ - Config(metric_config) for metric_name, metric_obj in Metrics.items(): assert metric_name in Defaults['metrics'] @@ -39,7 +39,6 @@ def test_all_defaults(metric_config, mock_model, mock_data): def test_coverage_fraction(metric_config, mock_model, mock_data): - Config(metric_config) coverage_fraction = CoverageFraction(mock_model, mock_data) _, coverage = coverage_fraction.calculate() assert coverage_fraction.output.all() is not None @@ -48,7 +47,11 @@ def test_coverage_fraction(metric_config, mock_model, mock_data): assert coverage.shape def test_all_sbc(metric_config, mock_model, mock_data): - Config(metric_config) all_sbc = AllSBC(mock_model, mock_data) all_sbc() - # TODO What is this supposed to be \ No newline at end of file + # TODO What is this supposed to be + +def test_lc2st(metric_config, mock_model, mock_data): + lc2st = LocalTwoSampleTest(mock_model, mock_data) + lc2st() + assert lc2st.output is not None \ No newline at end of file diff --git a/tests/test_plots.py b/tests/test_plots.py index 253343b..4006ac9 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -8,7 +8,8 @@ CDFRanks, Ranks, CoverageFraction, - TARP + TARP, + LocalTwoSampleTest ) @pytest.fixture @@ -55,4 +56,10 @@ def test_plot_coverage(plot_config, mock_model, mock_data): def test_plot_tarp(plot_config, mock_model, mock_data): 
plot = TARP(mock_model, mock_data, save=True, show=False) - plot(**get_item("plots", "TARP", raise_exception=False)) \ No newline at end of file + plot(**get_item("plots", "TARP", raise_exception=False)) + assert os.path.exists(f"{plot.out_path}/{plot.plot_name}") + +def test_plot_lc2st(plot_config, mock_model, mock_data): + plot = LocalTwoSampleTest(mock_model, mock_data, save=True, show=False) + plot(**get_item("plots", "LocalTwoSampleTest", raise_exception=False)) + assert os.path.exists(f"{plot.out_path}/{plot.plot_name}") \ No newline at end of file
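Usage sketch (illustration only, not part of the patch): how the LocalTwoSampleTest metric and plot introduced above can be driven directly, in the same spirit as the new tests in tests/test_metrics.py and tests/test_plots.py and the client changes. The `model` and `data` objects are assumed to be the project's trained SBI model and H5Data instances, the runtime configuration is assumed to already be initialized (as the client does via Config), and `out_dir`, `num_simulations`, and the classifier kwargs are placeholder values echoing the defaults added in src/utils/defaults.py.

# Hypothetical example: `model`, `data`, and the literal values below are assumptions.
from metrics import Metrics
from plots import Plots

# Metric: draws parameters P from the prior and amortized posterior draws Q, trains the
# MLP classifier(s), and stores results in .output under "lc2st_probabilities" and
# "lc2st_null_hypothesis_probabilities".
lc2st_metric = Metrics["LocalTwoSampleTest"](model, data, num_simulations=100)
lc2st_metric(linear_classifier="MLP", classifier_kwargs={"alpha": 0, "max_iter": 2500})
probabilities = lc2st_metric.output["lc2st_probabilities"]

# Plot: empirical-CDF pair plot plus an optional probability-intensity panel over the
# evaluation data.
lc2st_plot = Plots["LocalTwoSampleTest"](model, data, save=True, show=False, out_dir="./out")
lc2st_plot(use_intensity_plot=True, intensity_dimension=2, intensity_feature_index=[0, 1])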