From 6908fa5633bb77eae695a9fd0bf1afc44bf4574c Mon Sep 17 00:00:00 2001 From: JGarciaCondado Date: Wed, 3 Apr 2024 13:33:33 +0200 Subject: [PATCH] [ENH] Fix bug with figure directory creation and correct string in figures --- src/ageml/ui.py | 3 --- src/ageml/visualizer.py | 4 ++-- tests/test_ageml/test_ui.py | 6 ++++++ 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/ageml/ui.py b/src/ageml/ui.py index be784c3..d403f5a 100644 --- a/src/ageml/ui.py +++ b/src/ageml/ui.py @@ -125,9 +125,6 @@ def __init__(self, args): # Set up directory for storage of results self.setup() - # Initialise objects form library - self.set_visualizer(self.dir_path) - def setup(self): """Create main directory.""" diff --git a/src/ageml/visualizer.py b/src/ageml/visualizer.py index 36eeef0..cf4853b 100644 --- a/src/ageml/visualizer.py +++ b/src/ageml/visualizer.py @@ -158,7 +158,7 @@ def true_vs_pred_age(self, y_true, y_pred, tag: NameTag): # Plot true vs predicted age plt.scatter(y_true, y_pred) plt.plot(age_range, age_range, color="k", linestyle="dashed") - plt.title(f"Chronological vs Predicted Age \n [Covariate: {tag.covar}, System:{tag.system}") + plt.title(f"Chronological vs Predicted Age \n [Covariate: {tag.covar}, System: {tag.system}]") plt.xlabel("Chronological Age") plt.ylabel("Predicted Age") @@ -210,7 +210,7 @@ def age_bias_correction(self, y_true, y_pred, y_corrected, tag: NameTag): filename = (f"age_bias_correction" f"{'_' + tag.covar if tag.covar != '' else ''}" f"{'_' + tag.system if tag.system != '' else ''}.png") - plt.suptitle(f"[Covariate: {tag.covar}, System:{tag.system}\n", y=1.00) + plt.suptitle(f"[Covariate: {tag.covar}, System: {tag.system}]\n", y=1.00) plt.savefig(os.path.join(self.path_for_fig, filename)) plt.close() diff --git a/tests/test_ageml/test_ui.py b/tests/test_ageml/test_ui.py index cf58ee2..4ffefca 100644 --- a/tests/test_ageml/test_ui.py +++ b/tests/test_ageml/test_ui.py @@ -389,9 +389,15 @@ def test_load_data_different_indexes_warning(dummy_interface, features, clinical def test_age_distribution_warning(dummy_interface): + + # Create different distributions dist1 = np.random.normal(loc=50, scale=1, size=100) dist2 = np.random.normal(loc=0, scale=1, size=100) dists = {'dist1': dist1, 'dist2': dist2} + + # Set visualized in temporal directory + dummy_interface.set_visualizer(tempfile.mkdtemp()) + with pytest.warns(UserWarning) as warn_record: dummy_interface.age_distribution(dists) assert isinstance(warn_record.list[0].message, UserWarning)