Merge pull request #63 from voetberg/two_sample_test
Plotting and Metric for Local classifier 2 sample test
bnord authored May 21, 2024
2 parents c6076c7 + a3b32cc commit c889c23
Showing 16 changed files with 459 additions and 211 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -19,6 +19,7 @@ numpy = "^1.26.4"
 matplotlib = "^3.8.3"
 tarp = "^0.1.1"
 deprecation = "^2.1.0"
+scipy = "1.12.0"
 
 
 [tool.poetry.group.dev.dependencies]
14 changes: 10 additions & 4 deletions src/client/client.py
@@ -97,9 +97,15 @@ def main():
     plots = config.get_section("plots", raise_exception=False)
 
     for metrics_name, metrics_args in metrics.items():
-        Metrics[metrics_name](model, data, **metrics_args)()
+        try:
+            Metrics[metrics_name](model, data, **metrics_args)()
+        except (NotImplementedError, RuntimeError) as error:
+            print(f"WARNING - skipping metric {metrics_name} due to error: {error}")
 
     for plot_name, plot_args in plots.items():
-        Plots[plot_name](model, data, save=True, show=False, out_dir=out_dir)(
-            **plot_args
-        )
+        try:
+            Plots[plot_name](model, data, save=True, show=False, out_dir=out_dir)(
+                **plot_args
+            )
+        except (NotImplementedError, RuntimeError) as error:
+            print(f"WARNING - skipping plot {plot_name} due to error: {error}")
2 changes: 1 addition & 1 deletion src/data/data.py
@@ -125,4 +125,4 @@ def load_prior(self, prior, prior_kwargs):
             return lambda size: choices[prior](**prior_kwargs, size=size)
 
         except KeyError as e:
-            raise RuntimeError(f"Data missing a prior specification - {e}")
\ No newline at end of file
+            raise RuntimeError(f"Data missing a prior specification - {e}")
3 changes: 2 additions & 1 deletion src/data/h5_data.py
@@ -10,7 +10,8 @@
 class H5Data(Data):
     def __init__(self, path: str, simulator: Callable):
         super().__init__(path, simulator)
-
+        self.theta_true = self.get_theta_true()
+
     def _load(self, path):
         assert path.split(".")[-1] == "h5", "File extension must be h5"
         loaded_data = {}
5 changes: 4 additions & 1 deletion src/metrics/__init__.py
@@ -1,4 +1,7 @@
 from metrics.all_sbc import AllSBC
 from metrics.coverage_fraction import CoverageFraction
+from metrics.local_two_sample import LocalTwoSampleTest
 
-Metrics = {CoverageFraction.__name__: CoverageFraction, AllSBC.__name__: AllSBC}
+
+_all = [CoverageFraction, AllSBC, LocalTwoSampleTest]
+Metrics = {m.__name__: m for m in _all}
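
Because `Metrics` is now built by comprehension over `_all`, registering a new metric only requires appending its class to `_all`; the config-driven loop in client.py then resolves it by class name. A hedged sketch of that resolution (the config dict and its values are illustrative, not repository defaults):

    from metrics import Metrics

    metrics_config = {"LocalTwoSampleTest": {"num_simulations": 100}}  # illustrative
    for name, kwargs in metrics_config.items():
        metric_cls = Metrics[name]  # resolves to the LocalTwoSampleTest class
        # the client then runs: metric_cls(model, data, **kwargs)()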
175 changes: 175 additions & 0 deletions src/metrics/local_two_sample.py
@@ -0,0 +1,175 @@
from typing import Any, Optional, Union

import numpy as np

from sklearn.model_selection import KFold
from sklearn.neural_network import MLPClassifier
from sklearn.utils import shuffle

from metrics.metric import Metric
from utils.config import get_item


class LocalTwoSampleTest(Metric):
    def __init__(
        self,
        model: Any,
        data: Any,
        out_dir: Optional[str] = None,
        num_simulations: Optional[int] = None,
    ) -> None:
        super().__init__(model, data, out_dir)
        self.num_simulations = num_simulations if num_simulations is not None else get_item(
            "metrics_common", "number_simulations", raise_exception=False
        )
    def _collect_data_params(self):
        # P is drawn from the prior; x_P is generated via the simulator from the parameters P.
        self.p = self.data.sample_prior(self.num_simulations)
        self.q = np.zeros_like(self.p)

        self.outcome_given_p = np.zeros(
            (self.num_simulations, self.data.simulator.generate_context().shape[-1])
        )
        self.outcome_given_q = np.zeros_like(self.outcome_given_p)
        self.evaluation_context = np.zeros_like(self.outcome_given_p)

        for index, p in enumerate(self.p):
            context = self.data.simulator.generate_context()
            self.outcome_given_p[index] = self.data.simulator.simulate(p, context)
            # Q is the approximate posterior, amortized in x
            q = self.model.sample_posterior(1, context).ravel()
            self.q[index] = q
            self.outcome_given_q[index] = self.data.simulator.simulate(q, context)

        self.evaluation_context = np.array(
            [self.data.simulator.generate_context() for _ in range(self.num_simulations)]
        )

    def train_linear_classifier(self, p, q, x_p, x_q, classifier: str, classifier_kwargs: dict = {}):
        classifier_map = {"MLP": MLPClassifier}
        try:
            classifier = classifier_map[classifier](**classifier_kwargs)
        except KeyError:
            raise NotImplementedError(
                f"{classifier} not implemented, choose from {list(classifier_map.keys())}."
            )

        joint_P_x = np.concatenate([p, x_p], axis=1)
        joint_Q_x = np.concatenate([q, x_q], axis=1)

        features = np.concatenate([joint_P_x, joint_Q_x], axis=0)
        labels = np.concatenate(
            [np.array([0] * len(joint_P_x)), np.array([1] * len(joint_Q_x))]
        ).ravel()

        # shuffle features and labels
        features, labels = shuffle(features, labels)

        # train the classifier
        classifier.fit(X=features, y=labels)
        return classifier

    def _eval_model(self, P, evaluation_sample, classifier):
        evaluation = np.concatenate([P, evaluation_sample], axis=1)
        probability = classifier.predict_proba(evaluation)[:, 0]
        return probability

    def _scores(self, p, q, x_p, x_q, classifier, cross_evaluate: bool = True, classifier_kwargs=None):
        model_probabilities = []
        for model, model_args in zip(classifier, classifier_kwargs):
            if cross_evaluate:
                model_probabilities.append(self._cross_eval_score(p, q, x_p, x_q, model, model_args))
            else:
                trained_model = self.train_linear_classifier(p, q, x_p, x_q, model, model_args)
                # _eval_model requires an evaluation sample; x_p is assumed here, so the
                # classifier is scored on the same P-joint it was trained on.
                model_probabilities.append(self._eval_model(P=p, evaluation_sample=x_p, classifier=trained_model))

        return np.mean(model_probabilities, axis=0)

    def _cross_eval_score(self, p, q, x_p, x_q, classifier, classifier_kwargs, n_cross_folds=5):
        # First pass over the folds is only used to get the validation-fold size
        kf = KFold(n_splits=n_cross_folds, shuffle=True, random_state=42)
        cv_splits = kf.split(p)
        # train classifiers over cv-folds
        probabilities = []
        self.evaluation_data = np.zeros(
            (n_cross_folds, len(next(cv_splits)[1]), self.evaluation_context.shape[-1])
        )

        # Rebuild the splits (the generator above was partially consumed);
        # the fixed random_state reproduces the same folds.
        kf = KFold(n_splits=n_cross_folds, shuffle=True, random_state=42)
        cv_splits = kf.split(p)
        for cross_trial, (train_index, val_index) in enumerate(cv_splits):
            # get train split
            p_train, x_p_train = p[train_index, :], x_p[train_index, :]
            q_train, x_q_train = q[train_index, :], x_q[train_index, :]
            trained_nth_classifier = self.train_linear_classifier(
                p_train, q_train, x_p_train, x_q_train, classifier, classifier_kwargs
            )
            p_evaluate = p[val_index]
            for index, p_validation in enumerate(p_evaluate):
                self.evaluation_data[cross_trial][index] = self.data.simulator.simulate(
                    p_validation, self.evaluation_context[val_index][index]
                )
            probabilities.append(
                self._eval_model(p_evaluate, self.evaluation_data[cross_trial], trained_nth_classifier)
            )
        return probabilities

    def permute_data(self, P, Q):
        """Permute the concatenated data [P, Q] to create null-hypothesis samples.

        Args:
            P (np.ndarray): data of shape (n_samples, dim)
            Q (np.ndarray): data of shape (n_samples, dim)
        """
        n_samples = P.shape[0]
        X = np.concatenate([P, Q], axis=0)
        X_perm = X[self.data.rng.permutation(np.arange(n_samples * 2))]
        return X_perm[:n_samples], X_perm[n_samples:]

    def calculate(
        self,
        linear_classifier: Union[str, list[str]] = "MLP",
        cross_evaluate: bool = True,
        n_null_hypothesis_trials=100,
        classifier_kwargs: Union[dict, list[dict]] = None,
    ):
        if isinstance(linear_classifier, str):
            linear_classifier = [linear_classifier]

        if classifier_kwargs is None:
            classifier_kwargs = {}
        if isinstance(classifier_kwargs, dict):
            classifier_kwargs = [classifier_kwargs]

        probabilities = self._scores(
            self.p,
            self.q,
            self.outcome_given_p,
            self.outcome_given_q,
            classifier=linear_classifier,
            cross_evaluate=cross_evaluate,
            classifier_kwargs=classifier_kwargs,
        )

        null_hypothesis_probabilities = []
        for _ in range(n_null_hypothesis_trials):
            joint_P_x = np.concatenate([self.p, self.outcome_given_p], axis=1)
            joint_Q_x = np.concatenate([self.q, self.outcome_given_q], axis=1)
            joint_P_x_perm, joint_Q_x_perm = self.permute_data(joint_P_x, joint_Q_x)
            p_null = joint_P_x_perm[:, : self.p.shape[-1]]
            p_given_x_null = joint_P_x_perm[:, self.p.shape[-1] :]
            q_null = joint_Q_x_perm[:, : self.q.shape[-1]]
            q_given_x_null = joint_Q_x_perm[:, self.q.shape[-1] :]

            null_result = self._scores(
                p_null,
                q_null,
                p_given_x_null,
                q_given_x_null,
                classifier=linear_classifier,
                cross_evaluate=cross_evaluate,
                classifier_kwargs=classifier_kwargs,
            )

            null_hypothesis_probabilities.append(null_result)

        null = np.array(null_hypothesis_probabilities)
        self.output = {
            "lc2st_probabilities": probabilities,
            "lc2st_null_hypothesis_probabilities": null,
        }
        return probabilities, null

    def __call__(self, **kwds: Any) -> Any:
        try:
            self._collect_data_params()
        except NotImplementedError:
            pass

        self.calculate(**kwds)
        self._finish()
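
The idea behind the metric: if the approximate posterior is exact, joint samples (theta, x) drawn via P and via Q are identically distributed, so a classifier trained to separate them can do no better than chance and its predicted probabilities sit near 0.5; the permutation loop in calculate() builds the null distribution those probabilities are compared against. A hedged usage sketch, assuming model and data objects that provide the interfaces used above (sample_prior, sample_posterior, simulator.generate_context, simulator.simulate); the constructor arguments and classifier_kwargs values are illustrative, not repository defaults:

    # Illustrative only; model/data and the kwarg values are assumptions.
    from metrics.local_two_sample import LocalTwoSampleTest

    lc2st = LocalTwoSampleTest(model, data, out_dir="results/", num_simulations=200)
    probabilities, null = lc2st.calculate(
        linear_classifier="MLP",
        cross_evaluate=True,
        n_null_hypothesis_trials=100,
        classifier_kwargs={"alpha": 0.0, "max_iter": 2500},
    )
    # probabilities near 0.5 => the classifier cannot tell P from Q;
    # compare against the null trials to judge significance.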
4 changes: 3 additions & 1 deletion src/models/sbi_model.py
@@ -24,6 +24,8 @@ def sample_posterior(self, n_samples: int, y_true):  # TODO typing
     def predict_posterior(self, data):
         posterior_samples = self.sample_posterior(data.y_true)
         posterior_predictive_samples = data.simulator(
-            data.theta_true(), posterior_samples
+            data.get_theta_true(), posterior_samples
         )
         return posterior_predictive_samples
+
+
9 changes: 3 additions & 6 deletions src/plots/__init__.py
@@ -1,11 +1,8 @@
 from plots.cdf_ranks import CDFRanks
 from plots.coverage_fraction import CoverageFraction
 from plots.ranks import Ranks
+from plots.local_two_sample import LocalTwoSampleTest
 from plots.tarp import TARP
 
-Plots = {
-    CDFRanks.__name__: CDFRanks,
-    CoverageFraction.__name__: CoverageFraction,
-    Ranks.__name__: Ranks,
-    TARP.__name__: TARP,
-}
+_all = [CoverageFraction, CDFRanks, Ranks, LocalTwoSampleTest, TARP]
+Plots = {m.__name__: m for m in _all}
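
Plot lookup mirrors the metrics registry, so a config section keyed by class name selects the plot. A short sketch of resolving one entry (the key is taken from the registry above; the invocation line matches the client.py loop):

    from plots import Plots

    plot_cls = Plots["LocalTwoSampleTest"]
    # the client invokes it as:
    # plot_cls(model, data, save=True, show=False, out_dir=out_dir)(**plot_args)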