From 2c3953016fdc56693477237c41d4c3a17d4c65b3 Mon Sep 17 00:00:00 2001 From: voetberg Date: Mon, 24 Jun 2024 10:56:02 -0500 Subject: [PATCH] Merge resolve --- src/plots/__init__.py | 5 +++-- src/utils/defaults.py | 5 +++-- tests/test_plots.py | 19 +++++++++++++++---- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/src/plots/__init__.py b/src/plots/__init__.py index daa2baa..fb18498 100644 --- a/src/plots/__init__.py +++ b/src/plots/__init__.py @@ -6,6 +6,7 @@ from plots.predictive_posterior_check import PPC from plots.predictive_prior_check import PriorPC from plots.parity import Parity +from plots.predictive_prior_check import PriorPC Plots = { CDFRanks.__name__: CDFRanks, @@ -14,6 +15,6 @@ TARP.__name__: TARP, "LC2ST": LocalTwoSampleTest, PPC.__name__: PPC, - PriorPC.__name__: PriorPC, - "Parity": Parity + "Parity": Parity, + PriorPC.__name__: PriorPC } diff --git a/src/utils/defaults.py b/src/utils/defaults.py index e79a4c4..b132951 100644 --- a/src/utils/defaults.py +++ b/src/utils/defaults.py @@ -31,9 +31,10 @@ "coverage_sigma": 3 }, "LC2ST": {}, + "Parity":{}, "PPC": {}, - "PriorPC":{}, - "Parity":{} + "PriorPC":{} + }, "metrics_common": { "use_progress_bar": False, diff --git a/tests/test_plots.py b/tests/test_plots.py index 944c2ed..161786c 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -11,7 +11,7 @@ CoverageFraction, TARP, LocalTwoSampleTest, - PPC, + PPC, PriorPC, Parity ) @@ -80,10 +80,7 @@ def test_ppc(plot_config, mock_model, mock_data, mock_2d_data, result_output): plot = PPC(mock_model, mock_data, save=True, show=False) plot(**get_item("plots", "PPC", raise_exception=False)) assert os.path.exists(f"{plot.out_dir}/{plot.plot_name}") -<<<<<<< HEAD -======= ->>>>>>> 9d57556 (Update existing plots and metrics to work with 2d #72) plot = PPC( mock_model, mock_2d_data, save=True, show=False, @@ -98,6 +95,20 @@ def test_prior_pc(plot_config, mock_model, mock_data): plot(**get_item("plots", "PriorPC", raise_exception=False)) assert os.path.exists(f"{plot.out_dir}/{plot.plot_name}") + +def test_prior_pc(plot_config, mock_model, mock_2d_data, mock_data, result_output): + Config(plot_config) + plot = PriorPC(mock_model, mock_data, save=True, show=False) + plot(**get_item("plots", "PriorPC", raise_exception=False)) + assert os.path.exists(f"{plot.out_dir}/{plot.plot_name}") + plot = PPC( + mock_model, + mock_2d_data, save=True, show=False, + out_dir=f"{result_output.strip('/')}/mock_2d/") + assert type(plot.data.simulator).__name__ == "Mock2DSimulator" + plot(**get_item("plots", "PPC", raise_exception=False)) + assert os.path.exists(f"{plot.out_dir}/{plot.plot_name}") + def test_parity(plot_config, mock_model, mock_data): Config(plot_config) plot = Parity(mock_model, mock_data, save=True, show=False)