Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add warning that simulator is missing #81

Merged
merged 2 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions src/deepdiagnostics/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from deepdiagnostics.models import ModelModules
from deepdiagnostics.metrics import Metrics
from deepdiagnostics.plots import Plots
from deepdiagnostics.utils.simulator_utils import SimulatorMissingError


def parser():
Expand Down Expand Up @@ -109,9 +110,15 @@ def main():
plots = config.get_section("plots", raise_exception=False)

for metrics_name, metrics_args in metrics.items():
Metrics[metrics_name](model, data, save=True)(**metrics_args)
try:
Metrics[metrics_name](model, data, save=True)(**metrics_args)
except SimulatorMissingError:
print(f"Cannot run {metrics_name} - simulator missing.")

for plot_name, plot_args in plots.items():
Plots[plot_name](model, data, save=True, show=False, out_dir=out_dir)(
**plot_args
)
try:
Plots[plot_name](model, data, save=True, show=False, out_dir=out_dir)(
**plot_args
)
except SimulatorMissingError:
print(f"Cannot run {plot_name} - simulator missing.")
8 changes: 6 additions & 2 deletions src/deepdiagnostics/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np

from deepdiagnostics.utils.config import get_item
from deepdiagnostics.utils.register import load_simulator
from deepdiagnostics.utils.simulator_utils import load_simulator

class Data:
"""
Expand Down Expand Up @@ -35,7 +35,11 @@ def __init__(
get_item("common", "random_seed", raise_exception=False)
)
self.data = self._load(path)
self.simulator = load_simulator(simulator_name, simulator_kwargs)
try:
self.simulator = load_simulator(simulator_name, simulator_kwargs)
except RuntimeError:
print("Warning: Simulator not loaded. Can only run non-generative metrics.")

self.prior_dist = self.load_prior(prior, prior_kwargs)
self.n_dims = self.get_theta_true().shape[1]
self.simulator_dimensions = simulation_dimensions if simulation_dimensions is not None else get_item("data", "simulator_dimensions", raise_exception=False)
Expand Down
3 changes: 1 addition & 2 deletions src/deepdiagnostics/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@ def void2(*args, **kwargs):
return None
return void2


Metrics = {
"": void,
"": void,
CoverageFraction.__name__: CoverageFraction,
AllSBC.__name__: AllSBC,
"LC2ST": LC2ST
Expand Down
3 changes: 3 additions & 0 deletions src/deepdiagnostics/metrics/local_two_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from sklearn.utils import shuffle

from deepdiagnostics.metrics.metric import Metric
from deepdiagnostics.utils.simulator_utils import SimulatorMissingError

class LocalTwoSampleTest(Metric):
"""
Expand Down Expand Up @@ -46,6 +47,8 @@ def __init__(
percentiles,
number_simulations
)
if not hasattr(self.data, "simulator"):
raise SimulatorMissingError("Missing a simulator to run LC2ST.")

def _collect_data_params(self):
# P is the prior and x_P is generated via the simulator from the parameters P.
Expand Down
2 changes: 1 addition & 1 deletion src/deepdiagnostics/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
from deepdiagnostics.plots.parity import Parity
from deepdiagnostics.plots.predictive_prior_check import PriorPC


def void(*args, **kwargs):
def void2(*args, **kwargs):
return None
return void2


Plots = {
"": void,
CDFRanks.__name__: CDFRanks,
Expand Down
3 changes: 3 additions & 0 deletions src/deepdiagnostics/plots/predictive_posterior_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from deepdiagnostics.plots.plot import Display
from deepdiagnostics.utils.plotting_utils import get_hex_colors
from deepdiagnostics.utils.simulator_utils import SimulatorMissingError

class PPC(Display):
"""
Expand Down Expand Up @@ -33,6 +34,8 @@ def __init__(
colorway =None):

super().__init__(model, data, save, show, out_dir, percentiles, use_progress_bar, samples_per_inference, number_simulations, parameter_names, parameter_colors, colorway)
if not hasattr(self.data, "simulator"):
raise SimulatorMissingError("Missing a simulator to run PPC.")

def plot_name(self):
return "predictive_posterior_check.png"
Expand Down
5 changes: 4 additions & 1 deletion src/deepdiagnostics/plots/predictive_prior_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np

from deepdiagnostics.plots.plot import Display
from deepdiagnostics.utils.simulator_utils import SimulatorMissingError

class PriorPC(Display):
"""
Expand Down Expand Up @@ -36,7 +37,9 @@ def __init__(
colorway = None):

super().__init__(model, data, save, show, out_dir, percentiles, use_progress_bar, samples_per_inference, number_simulations, parameter_names, parameter_colors, colorway)

if not hasattr(self.data, "simulator"):
raise SimulatorMissingError("Missing a simulator to run PriorPC.")

if self.data.simulator_dimensions == 1:
self.plot_image = False

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,8 @@ def load_simulator(name, simulator_kwargs):
"Simulator improperly formed - requires a simulate method."
)

return simulator_instance
return simulator_instance


class SimulatorMissingError(Exception):
pass
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from deepdiagnostics.data.simulator import Simulator
from deepdiagnostics.models import SBIModel
from deepdiagnostics.utils.config import get_item
from deepdiagnostics.utils.register import register_simulator
from deepdiagnostics.utils.simulator_utils import register_simulator


class MockSimulator(Simulator):
Expand Down
22 changes: 22 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,25 @@ def test_main_missing_args(model_path):
process = subprocess.run(command)
exit_code = process.returncode
assert exit_code == 1


def test_missing_simulator(model_path, data_path):
command = [
"diagnose",
"--model_path",
model_path,
"--data_path",
data_path,
"--simulator",
"Not_A_Registered_Name",
"--plots",
"PPC",
"--metrics",
""
]
process = subprocess.run(command, capture_output=True)
exit_code = process.returncode
stdout = process.stdout.decode("utf-8")
assert exit_code == 0
plot_name = "PPC"
assert f"Cannot run {plot_name} - simulator missing." in stdout
Loading