From 4908b840db501cbd575e9b44fabb26945385add9 Mon Sep 17 00:00:00 2001 From: voetberg Date: Fri, 10 May 2024 15:27:33 -0500 Subject: [PATCH 1/5] lc2st draft --- pyproject.toml | 1 + src/data/data.py | 50 ++++++--- src/data/h5_data.py | 8 +- src/metrics/__init__.py | 9 +- src/metrics/local_two_sample.py | 167 ++++++++++++++++++++++++++++ src/models/sbi_model.py | 3 +- src/plots/__init__.py | 9 +- src/plots/local_two_sample.py | 151 ++++++++++++++++++++++++++ src/utils/defaults.py | 10 +- src/utils/plotting_utils.py | 11 ++ tests/test_evaluate.py | 185 -------------------------------- tests/test_plots.py | 11 +- 12 files changed, 392 insertions(+), 223 deletions(-) create mode 100644 src/metrics/local_two_sample.py create mode 100644 src/plots/local_two_sample.py create mode 100644 src/utils/plotting_utils.py delete mode 100644 tests/test_evaluate.py diff --git a/pyproject.toml b/pyproject.toml index eea7716..b636077 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ getdist = "^1.4.7" h5py = "^3.10.0" numpy = "^1.26.4" matplotlib = "^3.8.3" +scikit-learn = "^1.4.2" tarp = "^0.1.1" deprecation = "^2.1.0" diff --git a/src/data/data.py b/src/data/data.py index a839811..6e06479 100644 --- a/src/data/data.py +++ b/src/data/data.py @@ -1,30 +1,35 @@ import importlib.util import sys import os +import numpy as np from utils.config import get_item from utils.defaults import Defaults class Data: - def __init__(self, path:str, simulator_name: str): + def __init__(self, path:str, simulator_name: str, prior: str = "normal", prior_kwargs:dict=None): self.data = self._load(path) self.simulator = self._load_simulator(simulator_name) + self.prior_dist = self.load_prior(prior, prior_kwargs) + + self.n_dims = self.theta_true().shape[1] def _load_simulator(self, name): - try: - simulator_path = os.environ[f"{Defaults['common']['sim_location']}:{name}"] - except KeyError as e: - raise RuntimeError(f"Simulator cannot be found using env var {e}. Hint: have you registered your simulation with utils.register_simulator?") + if name is not None: + try: + simulator_path = os.environ[f"{Defaults['common']['sim_location']}:{name}"] + except KeyError as e: + raise RuntimeError(f"Simulator cannot be found using env var {e}. Hint: have you registered your simulation with utils.register_simulator?") - new_class = os.path.dirname(simulator_path) - sys.path.insert(1, new_class) + new_class = os.path.dirname(simulator_path) + sys.path.insert(1, new_class) - # TODO robust error checks - module_name = os.path.basename(simulator_path.rstrip('.py')) - m = importlib.import_module(module_name) - - simulator = getattr(m, name) - return simulator() + # TODO robust error checks + module_name = os.path.basename(simulator_path.rstrip('.py')) + m = importlib.import_module(module_name) + + simulator = getattr(m, name) + return simulator() def _load(self, path:str): raise NotImplementedError @@ -36,9 +41,8 @@ def x_true(self): def y_true(self): return self.simulator(self.theta_true(), self.x_true()) - def prior(self): - # From Data - raise NotImplementedError + def prior(self, n_samples:int): + return self.prior_dist(size=(n_samples, self.n_dims)) def theta_true(self): return get_item("data", "theta_true") @@ -47,4 +51,16 @@ def sigma_true(self): return get_item("data", "sigma_true") def save(self, data, path:str): - raise NotImplementedError \ No newline at end of file + raise NotImplementedError + + def load_prior(self, prior, prior_kwargs): + rng = np.random.default_rng(seed=42) + choices = { + "normal": rng.normal + } + + if prior not in choices.keys(): + raise NotImplementedError(f"{prior} is not an option for a prior, choose from {list(choices.keys())}") + if prior_kwargs is None: + prior_kwargs = {} + return lambda size: choices[prior](**prior_kwargs, size=size) diff --git a/src/data/h5_data.py b/src/data/h5_data.py index 888e3ea..675797c 100644 --- a/src/data/h5_data.py +++ b/src/data/h5_data.py @@ -7,8 +7,8 @@ from data.data import Data class H5Data(Data): - def __init__(self, path:str, simulator:Callable): - super().__init__(path, simulator) + def __init__(self, path:str, simulator: str, prior: str = "normal", prior_kwargs:dict=None): + super().__init__(path, simulator, prior, prior_kwargs) def _load(self, path): assert path.split(".")[-1] == "h5", "File extension must be h5" @@ -36,10 +36,6 @@ def x_true(self): def y_true(self): return self.simulator(self.theta_true(), self.x_true()) - def prior(self): - # From Data - raise NotImplementedError - def theta_true(self): return self.data['thetas'] diff --git a/src/metrics/__init__.py b/src/metrics/__init__.py index faaa00e..630f2fe 100644 --- a/src/metrics/__init__.py +++ b/src/metrics/__init__.py @@ -1,8 +1,7 @@ from metrics.all_sbc import AllSBC from metrics.coverage_fraction import CoverageFraction +from metrics.local_two_sample import LocalTwoSampleTest -Metrics = { - CoverageFraction.__name__: CoverageFraction, - AllSBC.__name__: AllSBC - -} \ No newline at end of file + +_all = [CoverageFraction, AllSBC, LocalTwoSampleTest] +Metrics = {m.__name__: m for m in _all} \ No newline at end of file diff --git a/src/metrics/local_two_sample.py b/src/metrics/local_two_sample.py new file mode 100644 index 0000000..f769be3 --- /dev/null +++ b/src/metrics/local_two_sample.py @@ -0,0 +1,167 @@ +from typing import Any, Union +import torch +import numpy as np + +from sklearn.model_selection import KFold +from sklearn.neural_network import MLPClassifier +from sklearn.utils import shuffle + +from metrics.metric import Metric +from utils.config import get_item + +class LocalTwoSampleTest(Metric): + def __init__(self, model: Any, data: Any, out_dir: str | None = None) -> None: + super().__init__(model, data, out_dir) + + def _collect_data_params(self): + + samples_per_inference = get_item( + "metrics_common", "samples_per_inference", raise_exception=False + ) + x_full = torch.tensor(self.data.x_true()) + x_sample, x_eval = x_full[:int(len(x_full)/2)], x_full[int(len(x_full)/2):] + + # P is the prior and x_P is generated via the simulator from the parameters P. + self.p = self.data.prior(samples_per_inference) + + # Q is the approximate posterior amortized in x. x_Q is a shuffled version of x_P, used to generate independent samples from Q | x. + self.q = self.model.sample_posterior(samples_per_inference, x_sample) + + self.x_given_p = self.x_given_q = x_sample + self.x_evaluation = x_eval + + def train_linear_classifier(self, p, q, x_p, x_q, classifier:str, classifier_kwargs:dict={}): + classifier_map = { + "MLP":MLPClassifier + } + try: + classifier = classifier_map[classifier](**classifier_kwargs) + except KeyError: + raise NotImplementedError( + f"{classifier} not implemented, choose from {list(classifier_map.keys())}.") + + joint_P_x = np.concatenate([p, x_p], axis=1) + joint_Q_x = np.concatenate([q, x_q], axis=1) + + features = np.concatenate([joint_P_x, joint_Q_x], axis=0) + labels = np.concatenate( + [np.array([0] * len(joint_P_x)), np.array([1] * len(joint_Q_x))] + ).ravel() + + # shuffle features and labels + features, labels = shuffle(features, labels) + + # train the classifier + classifier.fit(X=features, y=labels) + return classifier + + def _eval_model(self, P, classifier): + + x_evaluate = np.concatenate([P, self.x_evaluation.repeat(len(P), 1)], axis=1) + probability = classifier.predict_proba(x_evaluate)[:, 0] + return probability + + def _scores(self, p, q, x_p, x_q, classifier, cross_evaluate: bool=True, classifier_kwargs=None): + model_probabilities = [] + for model, model_args in zip(classifier, classifier_kwargs): + if cross_evaluate: + model_probabilities.append(self._cross_eval_scores(p, q, x_p, x_q, model)) + else: + trained_model = self.train_linear_classifier(p, q, x_p, x_q, model, model_args) + model_probabilities.append(self._eval_model(P=p, classifier=trained_model)) + + return np.mean(model_probabilities, axis=0) + + + def _cross_eval_score(self, p, q, x_p, x_q, classifier, classifier_kwargs, n_cross_folds=5): + kf = KFold(n_splits=n_cross_folds, shuffle=True, random_state=42) # TODO get seed from config + cv_splits = kf.split(p) + + # train classifiers over cv-folds + probabilities = [] + for train_index, val_index in cv_splits: + # get train split + p_train, x_p_train = p[train_index], x_p[train_index] + q_train, x_q_train = q[train_index], x_q[train_index] + + trained_nth_classifier = self.train_linear_classifier(p_train, q_train, x_p_train, x_q_train, classifier, classifier_kwargs) + + p_evaluate = p[val_index] + probabilities.append(self._eval_model(p_evaluate, trained_nth_classifier)) + return probabilities + + @staticmethod + def permute_data(P, Q, seed=42): + """Permute the concatenated data [P,Q] to create null-hyp samples. + + Args: + P (torch.Tensor): data of shape (n_samples, dim) + Q (torch.Tensor): data of shape (n_samples, dim) + seed (int, optional): random seed. Defaults to 42. + """ + # set seed + torch.manual_seed(seed) # TODO Get seed + # check inputs + assert P.shape[0] == Q.shape[0] + + n_samples = P.shape[0] + X = torch.cat([P, Q], dim=0) + X_perm = X[torch.randperm(n_samples * 2)] + return X_perm[:n_samples], X_perm[n_samples:] + + def calculate(self, + linear_classifier:Union[str, list[str]], + cross_evaluate:bool=True, + n_null_hypothesis_trials=100, + classifier_kwargs:Union[dict, list[dict]]=None): + + if isinstance(linear_classifier, str): + linear_classifier = [linear_classifier] + + if classifier_kwargs is None: + classifier_kwargs = {} + if isinstance(classifier_kwargs, dict): + classifier_kwargs = [classifier_kwargs] + + probabilities = self._scores( + self.p, + self.q, + self.x_given_p, + self.x_given_q, + self.x_evaluation, + classifier=linear_classifier, + cross_evaluate=cross_evaluate, + classifier_kwargs=classifier_kwargs + ) + + null_hypothesis_probabilities = [] + for trial in range(n_null_hypothesis_trials): + joint_P_x = torch.cat([self.p, self.x_given_p], dim=1) + joint_Q_x = torch.cat([self.q, self.x_given_q], dim=1) + joint_P_x_perm, joint_Q_x_perm = LocalTwoSampleTest.permute_data( + joint_P_x, joint_Q_x, seed=self.seed + trial, + ) + p_null = joint_P_x_perm[:, : self.p.shape[-1]] + p_given_x_null = joint_P_x_perm[:, self.p.shape[-1] :] + q_null = joint_Q_x_perm[:, : self.q.shape[-1]] + q_given_x_null = joint_Q_x_perm[:, self.q.shape[-1] :] + + null_result = self._scores( + p_null, + q_null, + p_given_x_null, + q_given_x_null, + self.x_evaluation, + classifier=linear_classifier, + cross_evaluate=cross_evaluate, + classifier_kwargs=classifier_kwargs + + ) + null_hypothesis_probabilities.append(null_result) + + null = np.array(null_hypothesis_probabilities) + self.output = { + "lc2st_probabilities": probabilities, + "lc2st_null_hypothesis_probabilities": null + } + return probabilities, null \ No newline at end of file diff --git a/src/models/sbi_model.py b/src/models/sbi_model.py index bcc3a0c..1501300 100644 --- a/src/models/sbi_model.py +++ b/src/models/sbi_model.py @@ -25,4 +25,5 @@ def sample_posterior(self, n_samples:int, y_true): # TODO typing def predict_posterior(self, data): posterior_samples = self.sample_posterior(data.y_true) posterior_predictive_samples = data.simulator(data.theta_true(), posterior_samples) - return posterior_predictive_samples \ No newline at end of file + return posterior_predictive_samples + \ No newline at end of file diff --git a/src/plots/__init__.py b/src/plots/__init__.py index 1d79d45..b1497ef 100644 --- a/src/plots/__init__.py +++ b/src/plots/__init__.py @@ -1,11 +1,8 @@ from plots.cdf_ranks import CDFRanks from plots.coverage_fraction import CoverageFraction from plots.ranks import Ranks +from plots.local_two_sample import LocalTwoSampleTest from plots.tarp import TARP -Plots = { - CDFRanks.__name__: CDFRanks, - CoverageFraction.__name__: CoverageFraction, - Ranks.__name__: Ranks, - TARP.__name__: TARP -} \ No newline at end of file +_all = [CoverageFraction, CDFRanks, Ranks, LocalTwoSampleTest, TARP] +Metrics = {m.__name__: m for m in _all} \ No newline at end of file diff --git a/src/plots/local_two_sample.py b/src/plots/local_two_sample.py new file mode 100644 index 0000000..a3af6be --- /dev/null +++ b/src/plots/local_two_sample.py @@ -0,0 +1,151 @@ +import os +from typing import Optional +import matplotlib.pyplot as plt +import numpy as np +from matplotlib import rcParams + +from plots.plot import Display +from metrics.local_two_sample import LocalTwoSampleTest as l2st +from utils.config import get_item +from utils.plotting_utils import get_hex_colors + +class LocalTwoSampleTest(Display): + def __init__(self, model, data, save:bool, show:bool, out_dir:Optional[str]=None): + super().__init__(model, data, save, show, out_dir) + + def _plot_name(self): + return "local_C2ST.png" + + def _data_setup(self): + self.percentiles = get_item("metrics_common", item='percentiles', raise_exception=False) + self.region_colors = get_hex_colors(n_colors=len(self.percentiles)) + + self.probability, self.null_hypothesis_probability = l2st.calculate() + + def _plot_settings(self): + self.param_names = get_item("plots_common", item="parameter_labels", raise_exception=False) + self.param_colors = get_item("plots_common", item="parameter_colors", raise_exception=False) + self.figure_size = get_item("plots_common", item="figure_size", raise_exception=False) + + def _make_pairplot_values(self, random_samples): + pp_vals = [np.mean(random_samples <= alpha) for alpha in self.cdf_alphas] + return pp_vals + + def lc2st_pairplot(self, subplot, confidence_region_alpha=0.2): + + subplot.plot( + self.cdf_alphas, self._make_pairplot_values([0.5] * len(self.probability)), "--", color="black", + ) + + null_hypothesis_pairplot = np.zeros_like(self.null_hypothesis_probability) + for t in range(len(self.null_hypothesis_probability)): + null_hypothesis_pairplot[t] = self._make_pairplot_values(self.null_hypothesis_probability[t]) + + for percentile, color in zip(self.percentiles, self.region_colors): + low_null = null_hypothesis_pairplot.quantile(percentile/100, axis=1) + up_null = null_hypothesis_pairplot.quantile((100-percentile)/100, axis=1) + + subplot.fill_between( + self.cdf_alphas, + low_null, + up_null, + color=color, + alpha=confidence_region_alpha, + label=f"{percentile}% confidence region", + ) + + for prob, label, color in zip(self.probability, self.param_names, self.param_colors): + pairplot_values = self._make_pairplot_values(self, prob) + subplot.plot(self.cdf_alphas, pairplot_values, label=label, color=color) + + def probability_intensity(self, subplot, dim, n_bins=20, vmin=0, vmax=1): + + if dim==1: + _, bins, patches = subplot.hist(df_probas.z, n_bins, density=True, color="green") + df_probas["bins"] = np.select( + [df_probas.z <= i for i in bins[1:]], list(range(n_bins)) + ) + # get mean predicted proba for each bin + weights = df_probas.groupby(["bins"]).mean().probas + + id = list(set(range(n_bins)) - set(df_probas.bins)) + patches = np.delete(patches, id) + bins = np.delete(bins, id) + + norm = Normalize(vmin=vmin, vmax=vmax) + + for w, p in zip(weights, patches): + p.set_facecolor(cmap(w)) + + else: + _, x, y = np.histogram2d(df_probas.z_1, df_probas.z_2, bins=n_bins) + df_probas["bins_x"] = np.select( + [df_probas.z_1 <= i for i in x[1:]], list(range(n_bins)) + ) + df_probas["bins_y"] = np.select( + [df_probas.z_2 <= i for i in y[1:]], list(range(n_bins)) + ) + # get mean predicted proba for each bin + prob_mean = df_probas.groupby(["bins_x", "bins_y"]).mean().probas + + weights = np.zeros((n_bins, n_bins)) + for i in range(n_bins): + for j in range(n_bins): + try: + weights[i, j] = prob_mean.loc[i].loc[j] + except KeyError: + # if no sample in bin, set color to white + weights[i, j] = np.nan + + norm = Normalize(vmin=vmin, vmax=vmax) + for i in range(len(x) - 1): + for j in range(len(y) - 1): + facecolor = cmap(norm(weights.T[j, i])) + # if no sample in bin, set color to white + if weights.T[j, i] == np.nan: + facecolor = "white" + rect = Rectangle( + (x[i], y[j]), + x[i + 1] - x[i], + y[j + 1] - y[j], + facecolor=facecolor, # color is mean predicted proba + edgecolor="none", + ) + subplot.add_patch(rect) + + def _plot(self, + use_intensity_plot:bool=True, + n_alpha_samples:int=100, + confidence_region_alpha:float=0.2, + n_intensity_bins:int=20, + intensity_dimension:int=1, + intensity_range:tuple=(0,1), + y_label="", + x_label="", + title="" + ): + + # Plots to make - + # pp_plot_lc2st: https://github.com/JuliaLinhart/lc2st/blob/e221cc326480cb0daadfd2ba50df4eefd374793b/lc2st/graphical_diagnostics.py#L49 + # eval_space_with_proba_intensity: https://github.com/JuliaLinhart/lc2st/blob/e221cc326480cb0daadfd2ba50df4eefd374793b/lc2st/graphical_diagnostics.py#L133 + + n_plots = 1 if not use_intensity_plot else 2 + if intensity_dimension and intensity_dimension not in (1,2): + raise NotImplementedError("LC2ST Intensity Plot only implemented in 1D and 2D") + + fig, subplots = plt.subplot(1, n_plots, figsize=self.figure_size) + self.cdf_alphas = np.linspace(0, 1, n_alpha_samples) + self.lc2st_pairplot(subplots[0], confidence_region_alpha=confidence_region_alpha) + if use_intensity_plot: + self.probability_intensity( + subplots[1], + intensity_dimension, + n_bins=n_intensity_bins, + vmin=intensity_range[0], + vmax=intensity_range[1] + ) + + fig.legend() + fig.supylabel(y_label) + fig.supxlabel(x_label) + fig.set_title(title) diff --git a/src/utils/defaults.py b/src/utils/defaults.py index 6722a03..0de746f 100644 --- a/src/utils/defaults.py +++ b/src/utils/defaults.py @@ -9,7 +9,10 @@ "model_engine": "SBIModel" }, "data":{ - "data_engine": "H5Data" + "data_engine": "H5Data", + "prior": "normal", + "prior_kwargs":{} + }, "plots_common": { "axis_spines": False, @@ -25,6 +28,7 @@ "CDFRanks":{}, "Ranks":{"num_bins":None}, "CoverageFraction":{}, + "LocalTwoSampleTest":{}, "TARP":{ "coverage_sigma":3 # How many sigma to show coverage over } @@ -38,5 +42,9 @@ "metrics":{ "AllSBC":{}, "CoverageFraction": {}, + "LocalTwoSampleTest":{ + "linear_classifier":"MLP", + "classifier_kwargs":{"alpha":0, "max_iter":2500} + } } } \ No newline at end of file diff --git a/src/utils/plotting_utils.py b/src/utils/plotting_utils.py new file mode 100644 index 0000000..2f5ac26 --- /dev/null +++ b/src/utils/plotting_utils.py @@ -0,0 +1,11 @@ +import numpy as np +import matplotlib as mpl + +def get_hex_colors(n_colors:int, colorway:str): + cmap = mpl.cm.get_cmap(colorway) + hex_colors = [] + arr=np.linspace(0, 1, n_colors) + for hit in arr: + hex_colors.append(mpl.colors.rgb2hex(cmap(hit))) + + return hex_colors \ No newline at end of file diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py deleted file mode 100644 index 3ed6116..0000000 --- a/tests/test_evaluate.py +++ /dev/null @@ -1,185 +0,0 @@ -import sys -import pytest -import torch -import numpy as np -import sbi -import os - -# flake8: noqa -#sys.path.append("..") -print(sys.path) -from scripts.evaluate import Diagnose_static, Diagnose_generative -from scripts.io import ModelLoader -#from src.scripts import evaluate - - -""" -""" - - -""" -Test the evaluate module -""" - - -@pytest.fixture -def diagnose_static_instance(): - return Diagnose_static() - -@pytest.fixture -def diagnose_generative_instance(): - return Diagnose_generative() - - -@pytest.fixture -def posterior_generative_sbi_model(): - # create a temporary directory for the saved model - #dir = "savedmodels/sbi/" - #os.makedirs(dir) - - # now save the model - low_bounds = torch.tensor([0, -10]) - high_bounds = torch.tensor([10, 10]) - - prior = sbi.utils.BoxUniform(low = low_bounds, high = high_bounds) - - posterior = sbi.inference.base.infer(simulator, prior, "SNPE", num_simulations=10000) - - # Provide the posterior to the tests - yield prior, posterior - - # Teardown: Remove the temporary directory and its contents - #shutil.rmtree(dataset_dir) - -@pytest.fixture -def setup_plot_dir(): - # create a temporary directory for the saved model - dir = "tests/plots/" - os.makedirs(dir) - yield dir - -def simulator(thetas): # , percent_errors): - # convert to numpy array (if tensor): - thetas = np.atleast_2d(thetas) - # Check if the input has the correct shape - if thetas.shape[1] != 2: - raise ValueError( - "Input tensor must have shape (n, 2) \ - where n is the number of parameter sets." - ) - - # Unpack the parameters - if thetas.shape[0] == 1: - # If there's only one set of parameters, extract them directly - m, b = thetas[0, 0], thetas[0, 1] - else: - # If there are multiple sets of parameters, extract them for each row - m, b = thetas[:, 0], thetas[:, 1] - x = np.linspace(0, 100, 101) - rs = np.random.RandomState() # 2147483648)# - # I'm thinking sigma could actually be a function of x - # if we want to get fancy down the road - # Generate random noise (epsilon) based - # on a normal distribution with mean 0 and standard deviation sigma - sigma = 5 - ε = rs.normal(loc=0, scale=sigma, size=(len(x), thetas.shape[0])) - - # Initialize an empty array to store the results for each set of parameters - y = np.zeros((len(x), thetas.shape[0])) - for i in range(thetas.shape[0]): - m, b = thetas[i, 0], thetas[i, 1] - y[:, i] = m * x + b + ε[:, i] - return torch.Tensor(y.T) - - -def test_generate_sbc_samples(diagnose_generative_instance, - posterior_generative_sbi_model): - # Mock data - #low_bounds = torch.tensor([0, -10]) - #high_bounds = torch.tensor([10, 10]) - - #prior = sbi.utils.BoxUniform(low=low_bounds, high=high_bounds) - prior, posterior = posterior_generative_sbi_model - #inference_instance # provide a mock posterior object - simulator_test = simulator # provide a mock simulator function - num_sbc_runs = 1000 - num_posterior_samples = 1000 - - # Generate SBC samples - thetas, ys, ranks, dap_samples = diagnose_generative_instance.generate_sbc_samples( - prior, posterior, simulator_test, num_sbc_runs, num_posterior_samples - ) - - # Add assertions based on the expected behavior of the method - - -def test_run_all_sbc(diagnose_generative_instance, - posterior_generative_sbi_model, - setup_plot_dir): - labels_list = ["$m$", "$b$"] - colorlist = ["#9C92A3", "#0F5257"] - - prior, posterior = posterior_generative_sbi_model - simulator_test = simulator # provide a mock simulator function - - save_path = setup_plot_dir - - diagnose_generative_instance.run_all_sbc( - prior, - posterior, - simulator_test, - labels_list, - colorlist, - num_sbc_runs=1_000, - num_posterior_samples=1_000, - samples_per_inference=1_000, - plot=False, - save=True, - path=save_path, - ) - # Check if PDF files were saved - assert os.path.exists(save_path), f"No 'plots' folder found at {save_path}" - - # List all files in the directory - files_in_directory = os.listdir(save_path) - - # Check if at least one PDF file is present - pdf_files = [file for file in files_in_directory if file.endswith(".pdf")] - assert pdf_files, "No PDF files found in the 'plots' folder" - - # We expect the pdfs to exist in the directory - expected_pdf_files = ["sbc_ranks.pdf", "sbc_ranks_cdf.pdf", "coverage.pdf"] - for expected_file in expected_pdf_files: - assert ( - expected_file in pdf_files - ), f"Expected PDF file '{expected_file}' not found" - - -""" -def test_sbc_statistics(diagnose_instance): - # Mock data - ranks = # provide mock ranks - thetas = # provide mock thetas - dap_samples = # provide mock dap_samples - num_posterior_samples = 1000 - - # Calculate SBC statistics - check_stats = diagnose_instance.sbc_statistics( - ranks, thetas, dap_samples, num_posterior_samples - ) - - # Add assertions based on the expected behavior of the method - -def test_plot_1d_ranks(diagnose_instance): - # Mock data - ranks = # provide mock ranks - num_posterior_samples = 1000 - labels_list = # provide mock labels_list - colorlist = # provide mock colorlist - - # Plot 1D ranks - diagnose_instance.plot_1d_ranks( - ranks, num_posterior_samples, labels_list, - colorlist, plot=False, save=False - ) -""" diff --git a/tests/test_plots.py b/tests/test_plots.py index 253343b..4006ac9 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -8,7 +8,8 @@ CDFRanks, Ranks, CoverageFraction, - TARP + TARP, + LocalTwoSampleTest ) @pytest.fixture @@ -55,4 +56,10 @@ def test_plot_coverage(plot_config, mock_model, mock_data): def test_plot_tarp(plot_config, mock_model, mock_data): plot = TARP(mock_model, mock_data, save=True, show=False) - plot(**get_item("plots", "TARP", raise_exception=False)) \ No newline at end of file + plot(**get_item("plots", "TARP", raise_exception=False)) + assert os.path.exists(f"{plot.out_path}/{plot.plot_name}") + +def test_plot_lc2st(plot_config, mock_model, mock_data): + plot = LocalTwoSampleTest(mock_model, mock_data, save=True, show=False) + plot(**get_item("plots", "LocalTwoSampleTest", raise_exception=False)) + assert os.path.exists(f"{plot.out_path}/{plot.plot_name}") \ No newline at end of file From 471fbe7239472e3e215f5767b7d5079827390107 Mon Sep 17 00:00:00 2001 From: voetberg Date: Fri, 17 May 2024 10:23:23 -0500 Subject: [PATCH 2/5] Outline of lc2st metric --- src/metrics/local_two_sample.py | 41 +++++++++++++++++++++++---------- tests/test_metrics.py | 15 +++++++----- 2 files changed, 38 insertions(+), 18 deletions(-) diff --git a/src/metrics/local_two_sample.py b/src/metrics/local_two_sample.py index f769be3..530e482 100644 --- a/src/metrics/local_two_sample.py +++ b/src/metrics/local_two_sample.py @@ -18,17 +18,31 @@ def _collect_data_params(self): samples_per_inference = get_item( "metrics_common", "samples_per_inference", raise_exception=False ) - x_full = torch.tensor(self.data.x_true()) - x_sample, x_eval = x_full[:int(len(x_full)/2)], x_full[int(len(x_full)/2):] + num_simulations = get_item( + "metrics_common", "number_simulations", raise_exception=False + ) # P is the prior and x_P is generated via the simulator from the parameters P. self.p = self.data.prior(samples_per_inference) + self.q = np.zeros((num_simulations, samples_per_inference, self.data.n_dims)) + x_sample = np.zeros((num_simulations, self.data.x_true().shape[-1])) + self.x_evaluation = np.zeros_like(x_sample) + + + print(x_sample.shape) + x_true = self.data.x_true() + for index, p in enumerate(self.p): + x_p = self.data.simulator(p, num_simulations) + print(x_p.shape) + # Q is the approximate posterior amortized in x + self.q[index] = self.model.sample_posterior(samples_per_inference, y_true=x_p) + x_sample[index] = x_p - # Q is the approximate posterior amortized in x. x_Q is a shuffled version of x_P, used to generate independent samples from Q | x. - self.q = self.model.sample_posterior(samples_per_inference, x_sample) + self.x_evaluation[index] = x_true[self.data.rng.integers(0, len(x_true)),:] + # x_Q is a shuffled version of x_P, used to generate independent samples from Q | x. self.x_given_p = self.x_given_q = x_sample - self.x_evaluation = x_eval + def train_linear_classifier(self, p, q, x_p, x_q, classifier:str, classifier_kwargs:dict={}): classifier_map = { @@ -65,14 +79,13 @@ def _scores(self, p, q, x_p, x_q, classifier, cross_evaluate: bool=True, classif model_probabilities = [] for model, model_args in zip(classifier, classifier_kwargs): if cross_evaluate: - model_probabilities.append(self._cross_eval_scores(p, q, x_p, x_q, model)) + model_probabilities.append(self._cross_eval_score(p, q, x_p, x_q, model, model_args)) else: trained_model = self.train_linear_classifier(p, q, x_p, x_q, model, model_args) model_probabilities.append(self._eval_model(P=p, classifier=trained_model)) return np.mean(model_probabilities, axis=0) - def _cross_eval_score(self, p, q, x_p, x_q, classifier, classifier_kwargs, n_cross_folds=5): kf = KFold(n_splits=n_cross_folds, shuffle=True, random_state=42) # TODO get seed from config cv_splits = kf.split(p) @@ -110,7 +123,7 @@ def permute_data(P, Q, seed=42): return X_perm[:n_samples], X_perm[n_samples:] def calculate(self, - linear_classifier:Union[str, list[str]], + linear_classifier:Union[str, list[str]]='MLP', cross_evaluate:bool=True, n_null_hypothesis_trials=100, classifier_kwargs:Union[dict, list[dict]]=None): @@ -128,7 +141,6 @@ def calculate(self, self.q, self.x_given_p, self.x_given_q, - self.x_evaluation, classifier=linear_classifier, cross_evaluate=cross_evaluate, classifier_kwargs=classifier_kwargs @@ -151,12 +163,11 @@ def calculate(self, q_null, p_given_x_null, q_given_x_null, - self.x_evaluation, classifier=linear_classifier, cross_evaluate=cross_evaluate, classifier_kwargs=classifier_kwargs - ) + null_hypothesis_probabilities.append(null_result) null = np.array(null_hypothesis_probabilities) @@ -164,4 +175,10 @@ def calculate(self, "lc2st_probabilities": probabilities, "lc2st_null_hypothesis_probabilities": null } - return probabilities, null \ No newline at end of file + return probabilities, null + + + def __call__(self, **kwds: Any) -> Any: + self._collect_data_params() + self.calculate(**kwds) + self._finish() \ No newline at end of file diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 5d39c4e..e0fc3c9 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -6,14 +6,15 @@ from metrics import ( Metrics, CoverageFraction, - AllSBC + AllSBC, + LocalTwoSampleTest ) @pytest.fixture def metric_config(config_factory): metrics_settings={"use_progress_bar":False, "samples_per_inference":10, "percentiles":[95]} config = config_factory(metrics_settings=metrics_settings) - return config + Config(config) def test_all_metrics_catalogued(): '''Each metrics gets its own file, and each metric is included in the Metrics dictionary @@ -30,7 +31,6 @@ def test_all_defaults(metric_config, mock_model, mock_data): Ensures each metric has a default set of parameters and is included in the defaults list Ensures each test can initialize, regardless of the veracity of the output """ - Config(metric_config) for metric_name, metric_obj in Metrics.items(): assert metric_name in Defaults['metrics'] @@ -38,7 +38,6 @@ def test_all_defaults(metric_config, mock_model, mock_data): def test_coverage_fraction(metric_config, mock_model, mock_data): - Config(metric_config) coverage_fraction = CoverageFraction(mock_model, mock_data) _, coverage = coverage_fraction.calculate() assert coverage_fraction.output.all() is not None @@ -47,7 +46,11 @@ def test_coverage_fraction(metric_config, mock_model, mock_data): assert coverage.shape def test_all_sbc(metric_config, mock_model, mock_data): - Config(metric_config) all_sbc = AllSBC(mock_model, mock_data) all_sbc() - # TODO What is this supposed to be \ No newline at end of file + # TODO What is this supposed to be + +def test_lc2st(metric_config, mock_model, mock_data): + lc2st = LocalTwoSampleTest(mock_model, mock_data) + lc2st() + assert lc2st.output is not None \ No newline at end of file From a509d4eefb58653a49779510a4f133ff26f82ac8 Mon Sep 17 00:00:00 2001 From: voetberg Date: Fri, 17 May 2024 13:10:02 -0500 Subject: [PATCH 3/5] merging --- src/data/data.py | 99 ++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 96 insertions(+), 3 deletions(-) diff --git a/src/data/data.py b/src/data/data.py index ae03835..8e52849 100644 --- a/src/data/data.py +++ b/src/data/data.py @@ -32,8 +32,8 @@ def _load_simulator(self, name, simulator_kwargs): f"Simulator cannot be found using env var {e}. Hint: have you registered your simulation with utils.register_simulator?" ) - new_class = os.path.dirname(simulator_path) - sys.path.insert(1, new_class) + new_class = os.path.dirname(simulator_path) + sys.path.insert(1, new_class) # TODO robust error checks module_name = os.path.basename(simulator_path.rstrip(".py")) @@ -99,4 +99,97 @@ def save(self, data, path: str): raise NotImplementedError def read_prior(self): - raise NotImplementedError \ No newline at end of file + raise NotImplementedError + + def load_prior(self, prior, prior_kwargs): + if prior is None: + prior = get_item("data", "prior", raise_exception=False) + try: + prior = self.read_prior() + except NotImplementedError: + choices = { + "normal": self.rng.normal, + "poisson": self.rng.poisson, + "uniform": self.rng.uniform, + "gamma": self.rng.gamma, + "beta": self.rng.beta, + "binominal": self.rng.binomial, + } + + if prior not in choices.keys(): + raise NotImplementedError( + f"{prior} is not an option for a prior, choose from {list(choices.keys())}" + ) + if prior_kwargs is None: + prior_kwargs = {} + return lambda size: choices[prior](**prior_kwargs, size=size) + + except KeyError as e: + raise RuntimeError(f"Data missing a prior specification - {e}") + +import importlib.util +import sys +import os +import numpy as np + +from utils.config import get_item +from utils.defaults import Defaults + +class Data: + def __init__(self, path:str, simulator_name: str, prior: str = "normal", prior_kwargs:dict=None): + self.data = self._load(path) + self.simulator = self._load_simulator(simulator_name) + self.prior_dist = self.load_prior(prior, prior_kwargs) + + self.n_dims = self.theta_true().shape[1] + + def _load_simulator(self, name): + if name is not None: + try: + simulator_path = os.environ[f"{Defaults['common']['sim_location']}:{name}"] + except KeyError as e: + raise RuntimeError(f"Simulator cannot be found using env var {e}. Hint: have you registered your simulation with utils.register_simulator?") + + new_class = os.path.dirname(simulator_path) + sys.path.insert(1, new_class) + + # TODO robust error checks + module_name = os.path.basename(simulator_path.rstrip('.py')) + m = importlib.import_module(module_name) + + simulator = getattr(m, name) + return simulator() + + def _load(self, path:str): + raise NotImplementedError + + def x_true(self): + # From Data + raise NotImplementedError + + def y_true(self): + return self.simulator(self.theta_true(), self.x_true()) + + def prior(self, n_samples:int): + return self.prior_dist(size=(n_samples, self.n_dims)) + + def theta_true(self): + return get_item("data", "theta_true") + + def sigma_true(self): + return get_item("data", "sigma_true") + + def save(self, data, path:str): + raise NotImplementedError + + def load_prior(self, prior, prior_kwargs): + rng = np.random.default_rng(seed=42) + choices = { + "normal": rng.normal + } + + if prior not in choices.keys(): + raise NotImplementedError(f"{prior} is not an option for a prior, choose from {list(choices.keys())}") + if prior_kwargs is None: + prior_kwargs = {} + return lambda size: choices[prior](**prior_kwargs, size=size) From 91a5eb1ab2ee65ddf352125d11ed7a31360ebfd5 Mon Sep 17 00:00:00 2001 From: voetberg Date: Fri, 17 May 2024 15:49:22 -0500 Subject: [PATCH 4/5] functional metric for lc2st --- pyproject.toml | 2 +- src/client/client.py | 14 +++-- src/data/data.py | 69 +------------------------ src/data/h5_data.py | 3 +- src/metrics/local_two_sample.py | 92 ++++++++++++++------------------- src/models/sbi_model.py | 2 +- src/plots/local_two_sample.py | 8 +-- tests/conftest.py | 4 +- 8 files changed, 61 insertions(+), 133 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b636077..095de07 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,9 +17,9 @@ getdist = "^1.4.7" h5py = "^3.10.0" numpy = "^1.26.4" matplotlib = "^3.8.3" -scikit-learn = "^1.4.2" tarp = "^0.1.1" deprecation = "^2.1.0" +scipy = "1.12.0" [tool.poetry.group.dev.dependencies] diff --git a/src/client/client.py b/src/client/client.py index 063b1d3..0ffe68f 100644 --- a/src/client/client.py +++ b/src/client/client.py @@ -97,9 +97,15 @@ def main(): plots = config.get_section("plots", raise_exception=False) for metrics_name, metrics_args in metrics.items(): - Metrics[metrics_name](model, data, **metrics_args)() + try: + Metrics[metrics_name](model, data, **metrics_args)() + except (NotImplementedError, RuntimeError) as error: + print(f"WARNING - skipping metric {metrics_name} due to error: {error}") for plot_name, plot_args in plots.items(): - Plots[plot_name](model, data, save=True, show=False, out_dir=out_dir)( - **plot_args - ) + try: + Plots[plot_name](model, data, save=True, show=False, out_dir=out_dir)( + **plot_args + ) + except (NotImplementedError, RuntimeError) as error: + print(f"WARNING - skipping plot {plot_name} due to error: {error}") diff --git a/src/data/data.py b/src/data/data.py index 8e52849..a022d6f 100644 --- a/src/data/data.py +++ b/src/data/data.py @@ -125,71 +125,4 @@ def load_prior(self, prior, prior_kwargs): return lambda size: choices[prior](**prior_kwargs, size=size) except KeyError as e: - raise RuntimeError(f"Data missing a prior specification - {e}") - -import importlib.util -import sys -import os -import numpy as np - -from utils.config import get_item -from utils.defaults import Defaults - -class Data: - def __init__(self, path:str, simulator_name: str, prior: str = "normal", prior_kwargs:dict=None): - self.data = self._load(path) - self.simulator = self._load_simulator(simulator_name) - self.prior_dist = self.load_prior(prior, prior_kwargs) - - self.n_dims = self.theta_true().shape[1] - - def _load_simulator(self, name): - if name is not None: - try: - simulator_path = os.environ[f"{Defaults['common']['sim_location']}:{name}"] - except KeyError as e: - raise RuntimeError(f"Simulator cannot be found using env var {e}. Hint: have you registered your simulation with utils.register_simulator?") - - new_class = os.path.dirname(simulator_path) - sys.path.insert(1, new_class) - - # TODO robust error checks - module_name = os.path.basename(simulator_path.rstrip('.py')) - m = importlib.import_module(module_name) - - simulator = getattr(m, name) - return simulator() - - def _load(self, path:str): - raise NotImplementedError - - def x_true(self): - # From Data - raise NotImplementedError - - def y_true(self): - return self.simulator(self.theta_true(), self.x_true()) - - def prior(self, n_samples:int): - return self.prior_dist(size=(n_samples, self.n_dims)) - - def theta_true(self): - return get_item("data", "theta_true") - - def sigma_true(self): - return get_item("data", "sigma_true") - - def save(self, data, path:str): - raise NotImplementedError - - def load_prior(self, prior, prior_kwargs): - rng = np.random.default_rng(seed=42) - choices = { - "normal": rng.normal - } - - if prior not in choices.keys(): - raise NotImplementedError(f"{prior} is not an option for a prior, choose from {list(choices.keys())}") - if prior_kwargs is None: - prior_kwargs = {} - return lambda size: choices[prior](**prior_kwargs, size=size) + raise RuntimeError(f"Data missing a prior specification - {e}") \ No newline at end of file diff --git a/src/data/h5_data.py b/src/data/h5_data.py index 80ddac0..c10b4a5 100644 --- a/src/data/h5_data.py +++ b/src/data/h5_data.py @@ -10,7 +10,8 @@ class H5Data(Data): def __init__(self, path: str, simulator: Callable): super().__init__(path, simulator) - + self.theta_true = self.get_theta_true() + def _load(self, path): assert path.split(".")[-1] == "h5", "File extension must be h5" loaded_data = {} diff --git a/src/metrics/local_two_sample.py b/src/metrics/local_two_sample.py index 530e482..ea1a97c 100644 --- a/src/metrics/local_two_sample.py +++ b/src/metrics/local_two_sample.py @@ -1,5 +1,4 @@ from typing import Any, Union -import torch import numpy as np from sklearn.model_selection import KFold @@ -14,35 +13,27 @@ def __init__(self, model: Any, data: Any, out_dir: str | None = None) -> None: super().__init__(model, data, out_dir) def _collect_data_params(self): - - samples_per_inference = get_item( - "metrics_common", "samples_per_inference", raise_exception=False - ) num_simulations = get_item( "metrics_common", "number_simulations", raise_exception=False ) # P is the prior and x_P is generated via the simulator from the parameters P. - self.p = self.data.prior(samples_per_inference) - self.q = np.zeros((num_simulations, samples_per_inference, self.data.n_dims)) - x_sample = np.zeros((num_simulations, self.data.x_true().shape[-1])) - self.x_evaluation = np.zeros_like(x_sample) + self.p = self.data.sample_prior(num_simulations) + self.q = np.zeros_like(self.p) + self.outcome_given_p = np.zeros((num_simulations, self.data.simulator.generate_context().shape[-1])) + self.outcome_given_q = np.zeros_like(self.outcome_given_p) + self.evaluation_context = np.zeros_like(self.outcome_given_p) - print(x_sample.shape) - x_true = self.data.x_true() for index, p in enumerate(self.p): - x_p = self.data.simulator(p, num_simulations) - print(x_p.shape) + context = self.data.simulator.generate_context() + self.outcome_given_p[index] = self.data.simulator.simulate(p, context) # Q is the approximate posterior amortized in x - self.q[index] = self.model.sample_posterior(samples_per_inference, y_true=x_p) - x_sample[index] = x_p - - self.x_evaluation[index] = x_true[self.data.rng.integers(0, len(x_true)),:] - - # x_Q is a shuffled version of x_P, used to generate independent samples from Q | x. - self.x_given_p = self.x_given_q = x_sample + q = self.model.sample_posterior(1, context).ravel() + self.q[index] = q + self.outcome_given_q[index] = self.data.simulator.simulate(q, context) + self.evaluation_context = np.array([self.data.simulator.generate_context() for _ in range(num_simulations)]) def train_linear_classifier(self, p, q, x_p, x_q, classifier:str, classifier_kwargs:dict={}): classifier_map = { @@ -69,10 +60,9 @@ def train_linear_classifier(self, p, q, x_p, x_q, classifier:str, classifier_kwa classifier.fit(X=features, y=labels) return classifier - def _eval_model(self, P, classifier): - - x_evaluate = np.concatenate([P, self.x_evaluation.repeat(len(P), 1)], axis=1) - probability = classifier.predict_proba(x_evaluate)[:, 0] + def _eval_model(self, P, evaluation_sample, classifier): + evaluation = np.concatenate([P, evaluation_sample], axis=1) + probability = classifier.predict_proba(evaluation)[:, 0] return probability def _scores(self, p, q, x_p, x_q, classifier, cross_evaluate: bool=True, classifier_kwargs=None): @@ -94,39 +84,37 @@ def _cross_eval_score(self, p, q, x_p, x_q, classifier, classifier_kwargs, n_cro probabilities = [] for train_index, val_index in cv_splits: # get train split - p_train, x_p_train = p[train_index], x_p[train_index] - q_train, x_q_train = q[train_index], x_q[train_index] - + p_train, x_p_train = p[train_index,:], x_p[train_index,:] + q_train, x_q_train = q[train_index,:], x_q[train_index,:] trained_nth_classifier = self.train_linear_classifier(p_train, q_train, x_p_train, x_q_train, classifier, classifier_kwargs) - p_evaluate = p[val_index] - probabilities.append(self._eval_model(p_evaluate, trained_nth_classifier)) + evaluation_data = np.zeros((len(val_index), self.evaluation_context.shape[-1])) + for index, p_validation in enumerate(p_evaluate): + evaluation_data[index] = self.data.simulator.simulate( + p_validation, self.evaluation_context[val_index][index] + ) + probabilities.append(self._eval_model(p_evaluate, evaluation_data, trained_nth_classifier)) return probabilities - @staticmethod - def permute_data(P, Q, seed=42): + def permute_data(self, P, Q): """Permute the concatenated data [P,Q] to create null-hyp samples. Args: P (torch.Tensor): data of shape (n_samples, dim) Q (torch.Tensor): data of shape (n_samples, dim) - seed (int, optional): random seed. Defaults to 42. """ - # set seed - torch.manual_seed(seed) # TODO Get seed - # check inputs - assert P.shape[0] == Q.shape[0] - n_samples = P.shape[0] - X = torch.cat([P, Q], dim=0) - X_perm = X[torch.randperm(n_samples * 2)] + X = np.concatenate([P, Q], axis=0) + X_perm = X[self.data.rng.permutation(np.arange(n_samples * 2))] return X_perm[:n_samples], X_perm[n_samples:] - def calculate(self, - linear_classifier:Union[str, list[str]]='MLP', - cross_evaluate:bool=True, - n_null_hypothesis_trials=100, - classifier_kwargs:Union[dict, list[dict]]=None): + def calculate( + self, + linear_classifier:Union[str, list[str]]='MLP', + cross_evaluate:bool=True, + n_null_hypothesis_trials=100, + classifier_kwargs:Union[dict, list[dict]]=None + ): if isinstance(linear_classifier, str): linear_classifier = [linear_classifier] @@ -139,19 +127,18 @@ def calculate(self, probabilities = self._scores( self.p, self.q, - self.x_given_p, - self.x_given_q, + self.outcome_given_p, + self.outcome_given_q, classifier=linear_classifier, cross_evaluate=cross_evaluate, classifier_kwargs=classifier_kwargs ) - null_hypothesis_probabilities = [] - for trial in range(n_null_hypothesis_trials): - joint_P_x = torch.cat([self.p, self.x_given_p], dim=1) - joint_Q_x = torch.cat([self.q, self.x_given_q], dim=1) - joint_P_x_perm, joint_Q_x_perm = LocalTwoSampleTest.permute_data( - joint_P_x, joint_Q_x, seed=self.seed + trial, + for _ in range(n_null_hypothesis_trials): + joint_P_x = np.concatenate([self.p, self.outcome_given_p], axis=1) + joint_Q_x = np.concatenate([self.q, self.outcome_given_q], axis=1) + joint_P_x_perm, joint_Q_x_perm = self.permute_data( + joint_P_x, joint_Q_x, ) p_null = joint_P_x_perm[:, : self.p.shape[-1]] p_given_x_null = joint_P_x_perm[:, self.p.shape[-1] :] @@ -177,7 +164,6 @@ def calculate(self, } return probabilities, null - def __call__(self, **kwds: Any) -> Any: self._collect_data_params() self.calculate(**kwds) diff --git a/src/models/sbi_model.py b/src/models/sbi_model.py index c544d01..a244da1 100644 --- a/src/models/sbi_model.py +++ b/src/models/sbi_model.py @@ -24,7 +24,7 @@ def sample_posterior(self, n_samples: int, y_true): # TODO typing def predict_posterior(self, data): posterior_samples = self.sample_posterior(data.y_true) posterior_predictive_samples = data.simulator( - data.theta_true(), posterior_samples + data.get_theta_true(), posterior_samples ) return posterior_predictive_samples diff --git a/src/plots/local_two_sample.py b/src/plots/local_two_sample.py index a3af6be..0dfa52a 100644 --- a/src/plots/local_two_sample.py +++ b/src/plots/local_two_sample.py @@ -1,8 +1,8 @@ -import os from typing import Optional import matplotlib.pyplot as plt import numpy as np -from matplotlib import rcParams +from matplotlib.colors import Normalize +from matplotlib.patches import Rectangle from plots.plot import Display from metrics.local_two_sample import LocalTwoSampleTest as l2st @@ -10,6 +10,9 @@ from utils.plotting_utils import get_hex_colors class LocalTwoSampleTest(Display): + + # https://github.com/JuliaLinhart/lc2st/blob/e221cc326480cb0daadfd2ba50df4eefd374793b/lc2st/graphical_diagnostics.py#L133 + def __init__(self, model, data, save:bool, show:bool, out_dir:Optional[str]=None): super().__init__(model, data, save, show, out_dir) @@ -19,7 +22,6 @@ def _plot_name(self): def _data_setup(self): self.percentiles = get_item("metrics_common", item='percentiles', raise_exception=False) self.region_colors = get_hex_colors(n_colors=len(self.percentiles)) - self.probability, self.null_hypothesis_probability = l2st.calculate() def _plot_settings(self): diff --git a/tests/conftest.py b/tests/conftest.py index 094fbb6..26b8af2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,8 +9,8 @@ class MockSimulator(Simulator): - def generate_context(self, n_samples: int) -> np.ndarray: - return np.linspace(0, 100, n_samples) + def generate_context(self, n_samples=None) -> np.ndarray: + return np.linspace(0, 100, 101) def simulate(self, theta: np.ndarray, context_samples: np.ndarray) -> np.ndarray: thetas = np.atleast_2d(theta) From a3b32cc108dd903fd3c4e8541f9efa6f30898a27 Mon Sep 17 00:00:00 2001 From: voetberg Date: Mon, 20 May 2024 14:34:33 -0500 Subject: [PATCH 5/5] plotting for local classifier 2 sample test --- src/metrics/local_two_sample.py | 37 +++--- src/plots/__init__.py | 2 +- src/plots/local_two_sample.py | 193 +++++++++++++++++++++----------- src/plots/plot.py | 11 +- src/utils/plotting_utils.py | 2 +- 5 files changed, 158 insertions(+), 87 deletions(-) diff --git a/src/metrics/local_two_sample.py b/src/metrics/local_two_sample.py index ea1a97c..e078670 100644 --- a/src/metrics/local_two_sample.py +++ b/src/metrics/local_two_sample.py @@ -1,4 +1,4 @@ -from typing import Any, Union +from typing import Any, Optional, Union import numpy as np from sklearn.model_selection import KFold @@ -9,19 +9,18 @@ from utils.config import get_item class LocalTwoSampleTest(Metric): - def __init__(self, model: Any, data: Any, out_dir: str | None = None) -> None: + def __init__(self, model: Any, data: Any, out_dir: str | None = None, num_simulations: Optional[int] = None) -> None: super().__init__(model, data, out_dir) - - def _collect_data_params(self): - num_simulations = get_item( + self.num_simulations = num_simulations if num_simulations is not None else get_item( "metrics_common", "number_simulations", raise_exception=False ) + def _collect_data_params(self): # P is the prior and x_P is generated via the simulator from the parameters P. - self.p = self.data.sample_prior(num_simulations) + self.p = self.data.sample_prior(self.num_simulations) self.q = np.zeros_like(self.p) - self.outcome_given_p = np.zeros((num_simulations, self.data.simulator.generate_context().shape[-1])) + self.outcome_given_p = np.zeros((self.num_simulations, self.data.simulator.generate_context().shape[-1])) self.outcome_given_q = np.zeros_like(self.outcome_given_p) self.evaluation_context = np.zeros_like(self.outcome_given_p) @@ -33,7 +32,7 @@ def _collect_data_params(self): self.q[index] = q self.outcome_given_q[index] = self.data.simulator.simulate(q, context) - self.evaluation_context = np.array([self.data.simulator.generate_context() for _ in range(num_simulations)]) + self.evaluation_context = np.array([self.data.simulator.generate_context() for _ in range(self.num_simulations)]) def train_linear_classifier(self, p, q, x_p, x_q, classifier:str, classifier_kwargs:dict={}): classifier_map = { @@ -77,23 +76,25 @@ def _scores(self, p, q, x_p, x_q, classifier, cross_evaluate: bool=True, classif return np.mean(model_probabilities, axis=0) def _cross_eval_score(self, p, q, x_p, x_q, classifier, classifier_kwargs, n_cross_folds=5): - kf = KFold(n_splits=n_cross_folds, shuffle=True, random_state=42) # TODO get seed from config + kf = KFold(n_splits=n_cross_folds, shuffle=True, random_state=42) # Getting the shape cv_splits = kf.split(p) - # train classifiers over cv-folds probabilities = [] - for train_index, val_index in cv_splits: + self.evaluation_data = np.zeros((n_cross_folds, len(next(cv_splits)[1]), self.evaluation_context.shape[-1])) + + kf = KFold(n_splits=n_cross_folds, shuffle=True, random_state=42) + cv_splits = kf.split(p) + for cross_trial, (train_index, val_index) in enumerate(cv_splits): # get train split p_train, x_p_train = p[train_index,:], x_p[train_index,:] q_train, x_q_train = q[train_index,:], x_q[train_index,:] trained_nth_classifier = self.train_linear_classifier(p_train, q_train, x_p_train, x_q_train, classifier, classifier_kwargs) p_evaluate = p[val_index] - evaluation_data = np.zeros((len(val_index), self.evaluation_context.shape[-1])) for index, p_validation in enumerate(p_evaluate): - evaluation_data[index] = self.data.simulator.simulate( + self.evaluation_data[cross_trial][index] = self.data.simulator.simulate( p_validation, self.evaluation_context[val_index][index] ) - probabilities.append(self._eval_model(p_evaluate, evaluation_data, trained_nth_classifier)) + probabilities.append(self._eval_model(p_evaluate, self.evaluation_data[cross_trial], trained_nth_classifier)) return probabilities def permute_data(self, P, Q): @@ -107,7 +108,7 @@ def permute_data(self, P, Q): X = np.concatenate([P, Q], axis=0) X_perm = X[self.data.rng.permutation(np.arange(n_samples * 2))] return X_perm[:n_samples], X_perm[n_samples:] - + def calculate( self, linear_classifier:Union[str, list[str]]='MLP', @@ -165,6 +166,10 @@ def calculate( return probabilities, null def __call__(self, **kwds: Any) -> Any: - self._collect_data_params() + try: + self._collect_data_params() + except NotImplementedError: + pass + self.calculate(**kwds) self._finish() \ No newline at end of file diff --git a/src/plots/__init__.py b/src/plots/__init__.py index b1497ef..b186bc2 100644 --- a/src/plots/__init__.py +++ b/src/plots/__init__.py @@ -5,4 +5,4 @@ from plots.tarp import TARP _all = [CoverageFraction, CDFRanks, Ranks, LocalTwoSampleTest, TARP] -Metrics = {m.__name__: m for m in _all} \ No newline at end of file +Plots = {m.__name__: m for m in _all} \ No newline at end of file diff --git a/src/plots/local_two_sample.py b/src/plots/local_two_sample.py index 0dfa52a..0923763 100644 --- a/src/plots/local_two_sample.py +++ b/src/plots/local_two_sample.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Sequence, Union import matplotlib.pyplot as plt import numpy as np from matplotlib.colors import Normalize @@ -13,39 +13,58 @@ class LocalTwoSampleTest(Display): # https://github.com/JuliaLinhart/lc2st/blob/e221cc326480cb0daadfd2ba50df4eefd374793b/lc2st/graphical_diagnostics.py#L133 - def __init__(self, model, data, save:bool, show:bool, out_dir:Optional[str]=None): + def __init__(self, + model, + data, + save:bool, + show:bool, + out_dir:Optional[str]=None, + percentiles: Optional[Sequence] = None, + parameter_names: Optional[Sequence] = None, + parameter_colors: Optional[Sequence]= None, + figure_size: Optional[Sequence] = None, + num_simulations: Optional[int] = None, + colorway: Optional[str]=None): super().__init__(model, data, save, show, out_dir) - + self.percentiles = percentiles if percentiles is not None else get_item("metrics_common", item='percentiles', raise_exception=False) + + self.param_names = parameter_names if parameter_names is not None else get_item("plots_common", item="parameter_labels", raise_exception=False) + self.param_colors = parameter_colors if parameter_colors is not None else get_item("plots_common", item="parameter_colors", raise_exception=False) + self.figure_size = figure_size if figure_size is not None else get_item("plots_common", item="figure_size", raise_exception=False) + + colorway = colorway if colorway is not None else get_item( + "plots_common", "default_colorway", raise_exception=False + ) + self.region_colors = get_hex_colors(n_colors=len(self.percentiles), colorway=colorway) + + num_simulations = num_simulations if num_simulations is not None else get_item( + "metrics_common", "number_simulations", raise_exception=False + ) + self.l2st = l2st(model, data, out_dir, num_simulations) + def _plot_name(self): return "local_C2ST.png" - def _data_setup(self): - self.percentiles = get_item("metrics_common", item='percentiles', raise_exception=False) - self.region_colors = get_hex_colors(n_colors=len(self.percentiles)) - self.probability, self.null_hypothesis_probability = l2st.calculate() - - def _plot_settings(self): - self.param_names = get_item("plots_common", item="parameter_labels", raise_exception=False) - self.param_colors = get_item("plots_common", item="parameter_colors", raise_exception=False) - self.figure_size = get_item("plots_common", item="figure_size", raise_exception=False) - def _make_pairplot_values(self, random_samples): - pp_vals = [np.mean(random_samples <= alpha) for alpha in self.cdf_alphas] + pp_vals = np.array([np.mean(random_samples <= alpha) for alpha in self.cdf_alphas]) return pp_vals def lc2st_pairplot(self, subplot, confidence_region_alpha=0.2): + null_cdf = self._make_pairplot_values([0.5] * len(self.probability)) subplot.plot( - self.cdf_alphas, self._make_pairplot_values([0.5] * len(self.probability)), "--", color="black", + self.cdf_alphas, null_cdf, "--", color="black", label="Theoretical Null CDF" ) - null_hypothesis_pairplot = np.zeros_like(self.null_hypothesis_probability) + null_hypothesis_pairplot = np.zeros((len(self.cdf_alphas), *null_cdf.shape)) + for t in range(len(self.null_hypothesis_probability)): null_hypothesis_pairplot[t] = self._make_pairplot_values(self.null_hypothesis_probability[t]) + for percentile, color in zip(self.percentiles, self.region_colors): - low_null = null_hypothesis_pairplot.quantile(percentile/100, axis=1) - up_null = null_hypothesis_pairplot.quantile((100-percentile)/100, axis=1) + low_null = np.quantile(null_hypothesis_pairplot, percentile/100, axis=1) + up_null = np.quantile(null_hypothesis_pairplot, (100-percentile)/100, axis=1) subplot.fill_between( self.cdf_alphas, @@ -53,101 +72,141 @@ def lc2st_pairplot(self, subplot, confidence_region_alpha=0.2): up_null, color=color, alpha=confidence_region_alpha, - label=f"{percentile}% confidence region", + label=f"{percentile}% Conf. region", ) for prob, label, color in zip(self.probability, self.param_names, self.param_colors): - pairplot_values = self._make_pairplot_values(self, prob) + pairplot_values = self._make_pairplot_values(prob) subplot.plot(self.cdf_alphas, pairplot_values, label=label, color=color) - def probability_intensity(self, subplot, dim, n_bins=20, vmin=0, vmax=1): - - if dim==1: - _, bins, patches = subplot.hist(df_probas.z, n_bins, density=True, color="green") - df_probas["bins"] = np.select( - [df_probas.z <= i for i in bins[1:]], list(range(n_bins)) - ) - # get mean predicted proba for each bin - weights = df_probas.groupby(["bins"]).mean().probas + def probability_intensity(self, subplot, plot_dims, features, n_bins=20): + evaluation_data = self.l2st.evaluation_data + + if len(evaluation_data.shape) >=3: # Used the kfold option + evaluation_data = evaluation_data.reshape(( + evaluation_data.shape[0]*evaluation_data.shape[1], + evaluation_data.shape[-1])) + self.probability = self.probability.ravel() - id = list(set(range(n_bins)) - set(df_probas.bins)) - patches = np.delete(patches, id) - bins = np.delete(bins, id) + if plot_dims==1: - norm = Normalize(vmin=vmin, vmax=vmax) + _, bins, patches = subplot.hist( + evaluation_data[:,features], n_bins, weights=self.probability, density=True, color=self.param_colors[features]) + + eval_bins = np.select( + [evaluation_data[:,features] <= i for i in bins[1:]], list(range(n_bins)) + ) + + # get mean predicted proba for each bin + weights = np.array([self.probability[eval_bins==i].mean() for i in np.unique(eval_bins)]) #df_probas.groupby(["bins"]).mean().probas + colors = plt.get_cmap(self.colorway) for w, p in zip(weights, patches): - p.set_facecolor(cmap(w)) - + p.set_facecolor(colors(w)) # color is mean predicted proba + else: - _, x, y = np.histogram2d(df_probas.z_1, df_probas.z_2, bins=n_bins) - df_probas["bins_x"] = np.select( - [df_probas.z_1 <= i for i in x[1:]], list(range(n_bins)) + + _, x_edges, y_edges, patches = subplot.hist2d( + evaluation_data[:,features[0]], + evaluation_data[:,features[1]], + n_bins, + density=True, color=self.param_colors[features[0]]) + + eval_bins_dim_1 = np.select( + [evaluation_data[:,features[0]] <= i for i in x_edges[1:]], list(range(n_bins)) ) - df_probas["bins_y"] = np.select( - [df_probas.z_2 <= i for i in y[1:]], list(range(n_bins)) + eval_bins_dim_2 = np.select( + [evaluation_data[:,features[1]] <= i for i in y_edges[1:]], list(range(n_bins)) ) - # get mean predicted proba for each bin - prob_mean = df_probas.groupby(["bins_x", "bins_y"]).mean().probas - weights = np.zeros((n_bins, n_bins)) + colors = plt.get_cmap(self.colorway) + + weights = np.empty((n_bins, n_bins)) for i in range(n_bins): for j in range(n_bins): try: - weights[i, j] = prob_mean.loc[i].loc[j] + weights[i, j] = self.probability[np.logical_and(eval_bins_dim_1==i, eval_bins_dim_2==j)].mean() except KeyError: - # if no sample in bin, set color to white - weights[i, j] = np.nan + pass - norm = Normalize(vmin=vmin, vmax=vmax) - for i in range(len(x) - 1): - for j in range(len(y) - 1): - facecolor = cmap(norm(weights.T[j, i])) + for i in range(len(x_edges) - 1): + for j in range(len(y_edges) - 1): + weight = weights[i,j] + facecolor = colors(weight) # if no sample in bin, set color to white - if weights.T[j, i] == np.nan: + if weight == np.nan: facecolor = "white" rect = Rectangle( - (x[i], y[j]), - x[i + 1] - x[i], - y[j + 1] - y[j], - facecolor=facecolor, # color is mean predicted proba + (x_edges[i], y_edges[j]), + x_edges[i + 1] - x_edges[i], + y_edges[j + 1] - y_edges[j], + facecolor=facecolor, edgecolor="none", ) subplot.add_patch(rect) + def _plot(self, use_intensity_plot:bool=True, n_alpha_samples:int=100, confidence_region_alpha:float=0.2, n_intensity_bins:int=20, - intensity_dimension:int=1, - intensity_range:tuple=(0,1), - y_label="", + intensity_dimension:int=2, + intensity_feature_index:Union[int, Sequence[int]]=[0,1], + linear_classifier:Union[str, list[str]]='MLP', + cross_evaluate:bool=True, + n_null_hypothesis_trials=100, + classifier_kwargs:Union[dict, list[dict]]=None, + y_label="Empirical CDF", x_label="", - title="" + title="Local Classifier 2-Sample Test" ): + if use_intensity_plot: + if intensity_dimension not in (1, 2): + raise NotImplementedError("LC2ST Intensity Plot only implemented in 1D and 2D") + + if intensity_dimension == 1: + try: + int(intensity_feature_index) + except TypeError: + raise ValueError(f"Cannot use {intensity_feature_index} to plot, please supply an integer value index.") + + else: + try: + assert len(intensity_feature_index) == intensity_dimension + int(intensity_feature_index[0]) + int(intensity_feature_index[1]) + except (AssertionError, TypeError): + raise ValueError(f"Cannot use {intensity_feature_index} to plot, please supply a list of 2 integer value indices.") + + self.l2st(**{ + "linear_classifier":linear_classifier, + "cross_evaluate": cross_evaluate, + "n_null_hypothesis_trials": n_null_hypothesis_trials, + "classifier_kwargs": classifier_kwargs}) + + self.probability, self.null_hypothesis_probability = self.l2st.output["lc2st_probabilities"], self.l2st.output["lc2st_null_hypothesis_probabilities"] + # Plots to make - # pp_plot_lc2st: https://github.com/JuliaLinhart/lc2st/blob/e221cc326480cb0daadfd2ba50df4eefd374793b/lc2st/graphical_diagnostics.py#L49 # eval_space_with_proba_intensity: https://github.com/JuliaLinhart/lc2st/blob/e221cc326480cb0daadfd2ba50df4eefd374793b/lc2st/graphical_diagnostics.py#L133 n_plots = 1 if not use_intensity_plot else 2 - if intensity_dimension and intensity_dimension not in (1,2): - raise NotImplementedError("LC2ST Intensity Plot only implemented in 1D and 2D") - - fig, subplots = plt.subplot(1, n_plots, figsize=self.figure_size) + figure_size = self.figure_size if n_plots==1 else (int(self.figure_size[0]*1.8),self.figure_size[1]) + fig, subplots = plt.subplots(1, n_plots, figsize=figure_size) self.cdf_alphas = np.linspace(0, 1, n_alpha_samples) - self.lc2st_pairplot(subplots[0], confidence_region_alpha=confidence_region_alpha) + + self.lc2st_pairplot(subplots[0] if n_plots == 2 else subplots, confidence_region_alpha=confidence_region_alpha) if use_intensity_plot: self.probability_intensity( subplots[1], intensity_dimension, n_bins=n_intensity_bins, - vmin=intensity_range[0], - vmax=intensity_range[1] + features=intensity_feature_index ) fig.legend() fig.supylabel(y_label) fig.supxlabel(x_label) - fig.set_title(title) + fig.suptitle(title) \ No newline at end of file diff --git a/src/plots/plot.py b/src/plots/plot.py index 0448800..e3ac508 100644 --- a/src/plots/plot.py +++ b/src/plots/plot.py @@ -27,7 +27,6 @@ def __init__( self.model = model self._common_settings() - self._plot_settings() self.plot_name = self._plot_name() def _plot_name(self): @@ -77,6 +76,14 @@ def _finish(self): plt.cla() def __call__(self, **plot_args) -> None: - self._data_setup() + try: + self._data_setup() + except NotImplementedError: + pass + try: + self._plot_settings() + except NotImplementedError: + pass + self._plot(**plot_args) self._finish() diff --git a/src/utils/plotting_utils.py b/src/utils/plotting_utils.py index 2f5ac26..dc138d6 100644 --- a/src/utils/plotting_utils.py +++ b/src/utils/plotting_utils.py @@ -2,7 +2,7 @@ import matplotlib as mpl def get_hex_colors(n_colors:int, colorway:str): - cmap = mpl.cm.get_cmap(colorway) + cmap = mpl.pyplot.get_cmap(colorway) hex_colors = [] arr=np.linspace(0, 1, n_colors) for hit in arr: