Skip to content

Commit

Permalink
Merge pull request #80 from deepskies/example_nb_n_bugfix
Browse files Browse the repository at this point in the history
Example nb n bugfix
  • Loading branch information
voetberg authored Jun 27, 2024
2 parents 4b6c349 + 8fea2de commit 4fd4f48
Show file tree
Hide file tree
Showing 6 changed files with 582 additions and 136 deletions.
681 changes: 558 additions & 123 deletions notebooks/example.ipynb

Large diffs are not rendered by default.

14 changes: 7 additions & 7 deletions src/deepdiagnostics/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,19 @@ def parser():
# List of metrics (cannot supply specific kwargs)
parser.add_argument(
"--metrics",
nargs="?",
default=list(Defaults["metrics"].keys()),
nargs="+",
default=[],
choices=Metrics.keys(),
help="List of metrics to run. To not run any, supply `--metrics `"
help="List of metrics to run."
)

# List of plots
parser.add_argument(
"--plots",
nargs="?",
default=list(Defaults["plots"].keys()),
nargs="+",
default=[],
choices=Plots.keys(),
help="List of plots to run. To not run any, supply `--plots `"
help="List of plots to run."

)

Expand Down Expand Up @@ -109,7 +109,7 @@ def main():
plots = config.get_section("plots", raise_exception=False)

for metrics_name, metrics_args in metrics.items():
Metrics[metrics_name](model, data, **metrics_args)()
Metrics[metrics_name](model, data, save=True)(**metrics_args)

for plot_name, plot_args in plots.items():
Plots[plot_name](model, data, save=True, show=False, out_dir=out_dir)(
Expand Down
8 changes: 7 additions & 1 deletion src/deepdiagnostics/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,14 @@
from deepdiagnostics.metrics.coverage_fraction import CoverageFraction
from deepdiagnostics.metrics.local_two_sample import LocalTwoSampleTest as LC2ST

def void(*args, **kwargs):
def void2(*args, **kwargs):
return None
return void2


Metrics = {
"": lambda **kwargs: None,
"": void,
CoverageFraction.__name__: CoverageFraction,
AllSBC.__name__: AllSBC,
"LC2ST": LC2ST
Expand Down
8 changes: 7 additions & 1 deletion src/deepdiagnostics/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,14 @@
from deepdiagnostics.plots.parity import Parity
from deepdiagnostics.plots.predictive_prior_check import PriorPC

def void(*args, **kwargs):
def void2(*args, **kwargs):
return None
return void2


Plots = {
"": lambda **kwargs: None,
"": void,
CDFRanks.__name__: CDFRanks,
CoverageFraction.__name__: CoverageFraction,
Ranks.__name__: Ranks,
Expand Down
5 changes: 2 additions & 3 deletions src/deepdiagnostics/plots/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,9 @@ def __init__(
"plots_common", "default_colorway", raise_exception=False
)

if save:
self.out_dir = out_dir if out_dir is not None else get_item("common", "out_dir", raise_exception=False)
self.out_dir = out_dir if out_dir is not None else get_item("common", "out_dir", raise_exception=False)

if self.out_dir is not None:
if self.out_dir is not None and self.save:
if not os.path.exists(os.path.dirname(self.out_dir)):
os.makedirs(os.path.dirname(self.out_dir))

Expand Down
2 changes: 1 addition & 1 deletion src/deepdiagnostics/utils/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"common": {
"out_dir": "./DeepDiagnosticsResources/results/",
"temp_config": "./DeepDiagnosticsResources/temp/temp_config.yml",
"sim_location": "deepdiagnosticsResources/simulators",
"sim_location": "./DeepDiagnosticsResources/simulators",
"random_seed": 42,
},
"model": {"model_engine": "SBIModel"},
Expand Down

0 comments on commit 4fd4f48

Please sign in to comment.