From 49902a38e771ebd5510cd7b9066a717439901bd0 Mon Sep 17 00:00:00 2001
From: itellaetxe
Date: Wed, 9 Oct 2024 12:09:41 +0200
Subject: [PATCH] BUG: Fixed uneven dataframe creation

Now common indices are computed as shown in issue #55, which is more robust
---
 src/ageml/ui.py             |  31 +++----
 tests/test_ageml/test_ui.py | 167 ++++++++++++++----------------
 2 files changed, 80 insertions(+), 118 deletions(-)

diff --git a/src/ageml/ui.py b/src/ageml/ui.py
index 033649d..be39854 100644
--- a/src/ageml/ui.py
+++ b/src/ageml/ui.py
@@ -611,22 +611,23 @@ def remove_missing_data(self):
                 warnings.warn(warn_message, category=UserWarning)
                 dfs[label] = df.drop(missing_subjects)
 
-        # Check that all dataframes have the same subjects
+        # Compute the intersection of the indices of the dataframes
+        indices_collection = [set(df.index) for df in dfs.values()]
+        shared_idx = list(indices_collection[0].intersection(*indices_collection[1:]))
+        # Remove subjects not shared among dataframes, report subjects left for analysis, and set dataframes
         print("Removing subjects not shared among dataframes...")
-        for l1, df1 in dfs.items():
-            for l2, df2 in dfs.items():
-                if l1 != l2:
-                    non_shared_idx = df1.index[~df1.index.isin(set(df2.index.to_list()))]
-                    if non_shared_idx.__len__() != 0:
-                        warn_message = "Subjects in dataframe %s not in dataframe %s: %s" % (l1, l2, non_shared_idx.to_list())
-                        print(warn_message)
-                        warnings.warn(warn_message, category=UserWarning)
-                        dfs[l1] = df1.drop(non_shared_idx)
-
-        # Set dataframes
-        for label, df in dfs.items():
-            print("Final number of subjects in dataframe %s: %d (%.2f %% of initial)" % (label, len(df), len(df) / init_count[label] * 100))
-            setattr(self, f"df_{label}", df)
+        for label in dfs.keys():
+            removed_subjects = set(dfs[label].index) - set(shared_idx)
+            warn_message = f"{len(removed_subjects)} subjects removed from {label} dataframe: {removed_subjects}"
+            dfs[label] = dfs[label].loc[shared_idx]
+            print(warn_message)
+            warnings.warn(warn_message, category=UserWarning)
+            setattr(self, f"df_{label}", dfs[label])
+            msg = (
+                f"Final number of subjects in dataframe {label}: {len(shared_idx)} "
+                f"({len(shared_idx) / init_count[label] * 100:.2f} % of initial)"
+            )
+            print(msg)
 
     def load_data(self, required=None):
         """Load data from csv files.
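Note (reviewer addition, not part of the patch): a minimal, self-contained sketch of the shared-index logic introduced above, using hypothetical toy dataframes. It keeps only the subjects whose index appears in every dataframe and reports the dropped ones, mirroring the new body of remove_missing_data.

import pandas as pd

# Toy stand-ins for the features/clinical dataframes (hypothetical data).
dfs = {
    "features": pd.DataFrame({"x": [0.1, 0.2, 0.3, 0.4]}, index=[1, 2, 3, 4]),
    "clinical": pd.DataFrame({"CN": [1, 0, 1]}, index=[2, 3, 5]),
}

# Intersection of all indices: subjects present in every dataframe.
indices_collection = [set(df.index) for df in dfs.values()]
shared_idx = list(indices_collection[0].intersection(*indices_collection[1:]))

for label, df in dfs.items():
    removed_subjects = set(df.index) - set(shared_idx)
    if removed_subjects:
        print(f"{len(removed_subjects)} subjects removed from {label} dataframe: {removed_subjects}")
    dfs[label] = df.loc[shared_idx]  # every dataframe now holds exactly the shared subjects

# Both toy dataframes end up with the same two rows (subjects 2 and 3).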
diff --git a/tests/test_ageml/test_ui.py b/tests/test_ageml/test_ui.py
index 0bff5ef..96c92c6 100644
--- a/tests/test_ageml/test_ui.py
+++ b/tests/test_ageml/test_ui.py
@@ -216,10 +216,8 @@ def factors():
 def covariates():
     df = pd.DataFrame(
         {
-            "id": [1, 2, 3, 4, 5, 6, 7, 8, 9,
-                   10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
-            "sex": [0, 1, 1, 1, 0, 0, 0, 1, 1, 0,
-                    0, 0, 1, 1, 0, 0, 0, 1, 1, 0],
+            "id": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
+            "sex": [0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0],
         }
     )
     df.set_index("id", inplace=True)
@@ -235,12 +233,9 @@ def systems():
 def clinical():
     df = pd.DataFrame(
         {
-            "id": [1, 2, 3, 4, 5, 6, 7, 8, 9,
-                   10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
-            "CN": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0,
-                   1, 0, 1, 0, 1, 0, 1, 0, 1, 0],
-            "group1": [0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
-                       0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
+            "id": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
+            "CN": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0],
+            "group1": [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
         }
     )
     df.set_index("id", inplace=True)
@@ -251,16 +246,11 @@ def clinical():
 def ages():
     df = pd.DataFrame(
         {
-            "id": [1, 2, 3, 4, 5, 6, 7, 8, 9,
-                   10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
-            "age": [50, 55, 60, 65, 70, 75, 80, 85, 90, 57,
-                    53, 57, 61, 65, 69, 73, 77, 81, 85, 89],
-            "predicted_age_all": [55, 67, 57, 75, 85, 64, 87, 93, 49, 51,
-                                  58, 73, 80, 89, 55, 67, 57, 75, 85, 64],
-            "corrected_age_all": [51, 58, 73, 80, 89, 67, 57, 75, 85, 64,
-                                  87, 93, 49, 55, 67, 57, 75, 85, 64, 87],
-            "delta_all": [1, -2, 3, 0, -1, 2, 1, 0, -3, 1,
-                          2, 1, 0, -1, 2, 1, 0, -3, 1, 2],
+            "id": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
+            "age": [50, 55, 60, 65, 70, 75, 80, 85, 90, 57, 53, 57, 61, 65, 69, 73, 77, 81, 85, 89],
+            "predicted_age_all": [55, 67, 57, 75, 85, 64, 87, 93, 49, 51, 58, 73, 80, 89, 55, 67, 57, 75, 85, 64],
+            "corrected_age_all": [51, 58, 73, 80, 89, 67, 57, 75, 85, 64, 87, 93, 49, 55, 67, 57, 75, 85, 64, 87],
+            "delta_all": [1, -2, 3, 0, -1, 2, 1, 0, -3, 1, 2, 1, 0, -1, 2, 1, 0, -3, 1, 2],
         }
     )
     df.set_index("id", inplace=True)
@@ -271,22 +261,14 @@ def ages():
 def ages_multisystem():
     df = pd.DataFrame(
         {
-            "id": [1, 2, 3, 4, 5, 6, 7, 8, 9,
-                   10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
-            "age": [50, 55, 60, 65, 70, 75, 80, 85, 90, 57,
-                    53, 57, 61, 65, 69, 73, 77, 81, 85, 89],
-            "predicted_age_pottongosystem": [55, 67, 57, 75, 85, 64, 87, 93, 49, 51,
-                                             58, 73, 80, 89, 55, 67, 57, 75, 85, 64],
-            "corrected_age_pottongosystem": [51, 58, 73, 80, 89, 67, 57, 75, 85, 64,
-                                             87, 93, 49, 55, 67, 57, 75, 85, 64, 87],
-            "delta_pottongosystem": [1, -2, 3, 0, -1, 2, 1, 0, -3, 1,
-                                     2, 1, 0, -1, 2, 1, 0, -3, 1, 2],
-            "predicted_age_mondongsystem": [55, 67, 57, 75, 85, 64, 87, 93, 49, 51,
-                                            58, 73, 80, 89, 55, 67, 57, 75, 85, 64],
-            "corrected_age_mondongsystem": [51, 58, 73, 80, 89, 67, 57, 75, 85, 64,
-                                            87, 93, 49, 55, 67, 57, 75, 85, 64, 87],
-            "delta_mondongsystem": [1, -2, 3, 0, -1, 2, 1, 0, -3, 1,
-                                    2, 1, 0, -1, 2, 1, 0, -3, 1, 2],
+            "id": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
+            "age": [50, 55, 60, 65, 70, 75, 80, 85, 90, 57, 53, 57, 61, 65, 69, 73, 77, 81, 85, 89],
+            "predicted_age_pottongosystem": [55, 67, 57, 75, 85, 64, 87, 93, 49, 51, 58, 73, 80, 89, 55, 67, 57, 75, 85, 64],
+            "corrected_age_pottongosystem": [51, 58, 73, 80, 89, 67, 57, 75, 85, 64, 87, 93, 49, 55, 67, 57, 75, 85, 64, 87],
+            "delta_pottongosystem": [1, -2, 3, 0, -1, 2, 1, 0, -3, 1, 2, 1, 0, -1, 2, 1, 0, -3, 1, 2],
+            "predicted_age_mondongsystem": [55, 67, 57, 75, 85, 64, 87, 93, 49, 51, 58, 73, 80, 89, 55, 67, 57, 75, 85, 64],
+            "corrected_age_mondongsystem": [51, 58, 73, 80, 89, 67, 57, 75, 85, 64, 87, 93, 49, 55, 67, 57, 75, 85, 64, 87],
+            "delta_mondongsystem": [1, -2, 3, 0, -1, 2, 1, 0, -3, 1, 2, 1, 0, -1, 2, 1, 0, -3, 1, 2],
         }
     )
     df.set_index("id", inplace=True)
@@ -433,7 +415,7 @@ def test_load_factors_not_float(dummy_interface, factors):
 
 def test_load_data_covariates_not_float(dummy_interface, covariates):
     # Change item to string
-    covariates.loc[2, 'sex'] = 'asdf'
+    covariates.loc[2, "sex"] = "asdf"
     covariates_path = create_csv(covariates, dummy_interface.dir_path)
     dummy_interface.args.covariates = covariates_path
 
@@ -542,7 +524,6 @@ def test_load_data_clinical_empty_column(dummy_interface, clinical):
 
 
 def test_load_data_clinical_empty_row(dummy_interface, clinical):
-    # Make a row all False
     clinical.loc[2, :] = 0
     clinical_path = create_csv(clinical, dummy_interface.dir_path)
 
@@ -571,8 +552,10 @@ def test_load_data_nan_values_warning(dummy_interface, features):
 
 def test_load_data_different_indexes_warning(dummy_interface, features, clinical):
     # Drop subjects 2 and 3 from features
-    features.drop([2, 3], axis=0, inplace=True)
-    clinical.drop([4], axis=0, inplace=True)
+    drop_features = [2, 3]
+    drop_clinical = [4]
+    features.drop(drop_features, axis=0, inplace=True)
+    clinical.drop(drop_clinical, axis=0, inplace=True)
     features_path = create_csv(features, dummy_interface.dir_path)
     clinical_path = create_csv(clinical, dummy_interface.dir_path)
     dummy_interface.args.features = features_path
@@ -581,23 +564,19 @@ def test_load_data_different_indexes_warning(dummy_interface, features, clinical
     with pytest.warns(UserWarning) as warn_record:
         dummy_interface.load_data()
     assert isinstance(warn_record.list[0].message, UserWarning)
-    expected = "Subjects in dataframe features not in dataframe clinical: [%d]" % (4)
+    expected = f"1 subjects removed from features dataframe: {set(drop_clinical)}"
     assert warn_record.list[0].message.args[0] == expected
 
     assert isinstance(warn_record.list[1].message, UserWarning)
-    expected = "Subjects in dataframe clinical not in dataframe features: [%d, %d]" % (
-        2,
-        3,
-    )
+    expected = f"2 subjects removed from clinical dataframe: {set(drop_features)}"
     assert warn_record.list[1].message.args[0] == expected
 
 
 def test_age_distribution_warning(dummy_interface):
-    # Create different distributions
     dist1 = np.random.normal(loc=50, scale=1, size=100)
     dist2 = np.random.normal(loc=0, scale=1, size=100)
-    dists = {'dist1': dist1, 'dist2': dist2}
-    
+    dists = {"dist1": dist1, "dist2": dist2}
+
     # Set visualized in temporal directory
     dummy_interface.set_visualizer(tempfile.mkdtemp())
 
@@ -624,9 +603,7 @@ def test_run_age(dummy_interface, features):
         "features_vs_age_controls_all",
         "chronological_vs_pred_age_all_all",
     ]
-    svg_paths = [
-        os.path.join(dummy_interface.dir_path, f"model_age/figures/{fig}.png") for fig in figs
-    ]
+    svg_paths = [os.path.join(dummy_interface.dir_path, f"model_age/figures/{fig}.png") for fig in figs]
     print(os.listdir(os.path.join(dummy_interface.dir_path, "model_age/figures")))
     assert all([os.path.exists(svg_path) for svg_path in svg_paths])
 
@@ -662,9 +639,7 @@ def test_run_age_clinical(dummy_interface, features, clinical):
         "features_vs_age_controls_all",
         "chronological_vs_pred_age_all_all",
     ]
-    svg_paths = [
-        os.path.join(dummy_interface.dir_path, f"model_age/figures/{fig}.png") for fig in figs
-    ]
+    svg_paths = [os.path.join(dummy_interface.dir_path, f"model_age/figures/{fig}.png") for fig in figs]
     assert all([os.path.exists(svg_path) for svg_path in svg_paths])
 
     # Check for the existence of the log
@@ -696,19 +671,20 @@ def test_run_age_cov(dummy_interface, features, covariates):
     assert os.path.exists(dummy_interface.dir_path)
 
     # Check for output figs
-    figs = ["age_bias_correction_sex_0_all",
-            "age_bias_correction_sex_1_all",
-            "chronological_vs_pred_age_sex_0_all",
-            "chronological_vs_pred_age_sex_1_all",
-            "age_distribution_controls",
-            "features_vs_age_controls_all"]
+    figs = [
+        "age_bias_correction_sex_0_all",
+        "age_bias_correction_sex_1_all",
+        "chronological_vs_pred_age_sex_0_all",
+        "chronological_vs_pred_age_sex_1_all",
+        "age_distribution_controls",
+        "features_vs_age_controls_all",
+    ]
     # Print files in path
     svg_paths = [os.path.join(dummy_interface.dir_path, f"model_age/figures/{fig}.png") for fig in figs]
     assert all([os.path.exists(svg_path) for svg_path in svg_paths])
 
     # Check for the existence of the output CSV
-    csv_path = os.path.join(dummy_interface.dir_path,
-                            f"model_age/predicted_age_{dummy_interface.args.covar_name}.csv")
+    csv_path = os.path.join(dummy_interface.dir_path, f"model_age/predicted_age_{dummy_interface.args.covar_name}.csv")
     assert os.path.exists(csv_path)
 
     # Check that the output CSV has the right columns
@@ -735,18 +711,19 @@ def test_run_age_cov_clinical(dummy_interface, features, covariates, clinical):
     assert os.path.exists(dummy_interface.dir_path)
 
     # Check for output figs
-    figs = ["age_bias_correction_sex_0_all",
-            "age_bias_correction_sex_1_all",
-            "chronological_vs_pred_age_sex_0_all",
-            "chronological_vs_pred_age_sex_1_all",
-            "age_distribution_controls",
-            "features_vs_age_controls_all"]
+    figs = [
+        "age_bias_correction_sex_0_all",
+        "age_bias_correction_sex_1_all",
+        "chronological_vs_pred_age_sex_0_all",
+        "chronological_vs_pred_age_sex_1_all",
+        "age_distribution_controls",
+        "features_vs_age_controls_all",
+    ]
     svg_paths = [os.path.join(dummy_interface.dir_path, f"model_age/figures/{fig}.png") for fig in figs]
     assert all([os.path.exists(svg_path) for svg_path in svg_paths])
 
     # Check for the existence of the output CSV
-    csv_path = os.path.join(dummy_interface.dir_path,
-                            f"model_age/predicted_age_{dummy_interface.args.covar_name}.csv")
+    csv_path = os.path.join(dummy_interface.dir_path, f"model_age/predicted_age_{dummy_interface.args.covar_name}.csv")
     assert os.path.exists(csv_path)
 
     # Check that the output CSV has the right columns
@@ -779,10 +756,9 @@ def test_run_age_systems(dummy_interface, systems, features):
     assert all([os.path.exists(svg_path) for svg_path in svg_paths])
 
     # Check existence of output CSV
-    csv_path = os.path.join(dummy_interface.dir_path,
-                            "model_age/predicted_age_multisystem.csv")
+    csv_path = os.path.join(dummy_interface.dir_path, "model_age/predicted_age_multisystem.csv")
     assert os.path.exists(csv_path)
-    
+
     # Check that the output CSV has the right columns
     df = pd.read_csv(csv_path, header=0, index_col=0)
     assert all(any(word in s for s in df.columns) for word in ["age", "predicted_age", "corrected_age", "delta"])
@@ -816,10 +792,9 @@ def test_run_age_systems_clinical(dummy_interface, systems, features, clinical):
     assert all([os.path.exists(svg_path) for svg_path in svg_paths])
 
     # Check existence of output CSV
-    csv_path = os.path.join(dummy_interface.dir_path,
-                            "model_age/predicted_age_multisystem.csv")
+    csv_path = os.path.join(dummy_interface.dir_path, "model_age/predicted_age_multisystem.csv")
     assert os.path.exists(csv_path)
-    
+
     # Check that the output CSV has the right columns
     df = pd.read_csv(csv_path, header=0, index_col=0)
     assert all(any(word in s for s in df.columns) for word in ["age", "predicted_age", "corrected_age", "delta"])
@@ -857,10 +832,9 @@ def test_run_age_cov_and_systems(dummy_interface, systems, features, covariates)
     assert all([os.path.exists(svg_path) for svg_path in svg_paths])
 
     # Check existence of output CSV
-    csv_path = os.path.join(dummy_interface.dir_path,
-                            f"model_age/predicted_age_{dummy_interface.args.covar_name}_multisystem.csv")
+    csv_path = os.path.join(dummy_interface.dir_path, f"model_age/predicted_age_{dummy_interface.args.covar_name}_multisystem.csv")
    assert os.path.exists(csv_path)
-    
+
     # Check that the output CSV has the right columns
     df = pd.read_csv(csv_path, header=0, index_col=0)
     assert all(any(word in s for s in df.columns) for word in ["age", "predicted_age", "corrected_age", "delta"])
@@ -901,10 +875,9 @@ def test_run_age_cov_and_systems_clinical(dummy_interface, systems, features, co
     assert all([os.path.exists(svg_path) for svg_path in svg_paths])
 
     # Check existence of output CSV
-    csv_path = os.path.join(dummy_interface.dir_path,
-                            f"model_age/predicted_age_{dummy_interface.args.covar_name}_multisystem.csv")
+    csv_path = os.path.join(dummy_interface.dir_path, f"model_age/predicted_age_{dummy_interface.args.covar_name}_multisystem.csv")
     assert os.path.exists(csv_path)
-    
+
     # Check that the output CSV has the right columns
     df = pd.read_csv(csv_path, header=0, index_col=0)
     assert all(any(word in s for s in df.columns) for word in ["age", "predicted_age", "corrected_age", "delta"])
@@ -928,9 +901,7 @@ def test_run_factor_correlation(dummy_interface, ages, factors, covariates):
 
     # Check for the existence of the output figures
     figs = ["factors_vs_deltas_cn"]
-    svg_paths = [
-        os.path.join(dummy_interface.dir_path, f"factor_correlation/figures/{fig}.png") for fig in figs
-    ]
+    svg_paths = [os.path.join(dummy_interface.dir_path, f"factor_correlation/figures/{fig}.png") for fig in figs]
     assert all([os.path.exists(svg_path) for svg_path in svg_paths])
 
     # Check for the existence of the log
@@ -952,9 +923,7 @@ def test_run_factor_correlation_systems(dummy_interface, ages_multisystem, facto
     # Check for the existence of the output figures
     figs = []
    figs.append("factors_vs_deltas_cn")
-    svg_paths = [
-        os.path.join(dummy_interface.dir_path, f"factor_correlation/figures/{fig}.png") for fig in figs
-    ]
+    svg_paths = [os.path.join(dummy_interface.dir_path, f"factor_correlation/figures/{fig}.png") for fig in figs]
     assert all([os.path.exists(svg_path) for svg_path in svg_paths])
 
     # Check for the existence of the log
@@ -992,9 +961,7 @@ def test_run_clinical(dummy_interface, ages, clinical, covariates):
 
     # Check for the existence of the output figures
     figs = ["age_distribution_clinical_groups", "clinical_groups_box_plot_all"]
-    svg_paths = [
-        os.path.join(dummy_interface.dir_path, f"clinical_groups/figures/{fig}.png") for fig in figs
-    ]
+    svg_paths = [os.path.join(dummy_interface.dir_path, f"clinical_groups/figures/{fig}.png") for fig in figs]
     assert all([os.path.exists(svg_path) for svg_path in svg_paths])
 
     # Check for the existence of the log
@@ -1018,9 +985,7 @@ def test_run_clinical_systems(dummy_interface, ages_multisystem, clinical):
     figs = ["age_distribution_clinical_groups"]
     for system in system_names:
         figs.append(f"clinical_groups_box_plot_{system}")
-    svg_paths = [
-        os.path.join(dummy_interface.dir_path, f"clinical_groups/figures/{fig}.png") for fig in figs
-    ]
+    svg_paths = [os.path.join(dummy_interface.dir_path, f"clinical_groups/figures/{fig}.png") for fig in figs]
     assert all([os.path.exists(svg_path) for svg_path in svg_paths])
 
     # Check for the existence of the log
@@ -1043,9 +1008,7 @@ def test_run_classification(dummy_interface, ages, clinical):
 
     # Check for the existence of the output figures
     figs = ["roc_curve_cn_vs_group1_all"]
-    svg_paths = [
-        os.path.join(dummy_interface.dir_path, f"clinical_classify/figures/{fig}.png") for fig in figs
-    ]
+    svg_paths = [os.path.join(dummy_interface.dir_path, f"clinical_classify/figures/{fig}.png") for fig in figs]
     assert all([os.path.exists(svg_path) for svg_path in svg_paths])
 
     # Check for the existence of the log
@@ -1073,9 +1036,7 @@ def test_run_classification_systems(dummy_interface, ages_multisystem, clinical)
     figs = []
     for system in system_names:
         figs.append(f"roc_curve_{dummy_interface.args.group1}_vs_{dummy_interface.args.group2}_{system}")
-    svg_paths = [
-        os.path.join(dummy_interface.dir_path, f"clinical_classify/figures/{fig}.png") for fig in figs
-    ]
+    svg_paths = [os.path.join(dummy_interface.dir_path, f"clinical_classify/figures/{fig}.png") for fig in figs]
     assert all([os.path.exists(svg_path) for svg_path in svg_paths])
 
     # Check for the existence of the log
@@ -1121,9 +1082,9 @@ def test_classification_few_subjects(dummy_interface, ages, clinical):
     clinical_path = create_csv(clinical, dummy_interface.dir_path)
     dummy_interface.args.ages = ages_path
     dummy_interface.args.clinical = clinical_path
-    dummy_interface.args.group1 = 'cn'
-    dummy_interface.args.group2 = 'group1'
-    
+    dummy_interface.args.group1 = "cn"
+    dummy_interface.args.group2 = "group1"
+
     # Run classification and capture error
     with pytest.raises(ValueError) as exc_info:
         dummy_interface.run_classification()
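Note (reviewer addition, not part of the patch): the updated expectations embed the repr of a Python set in the warning text, so the test builds the expected string from the same dropped-id list it used for the drop. A minimal sketch of that pattern follows, with a hypothetical emit_removal_warnings helper standing in for the interface call.

import warnings

import pytest


def emit_removal_warnings(removed_by_label):
    # Hypothetical helper that warns in the same format as remove_missing_data above.
    for label, removed in removed_by_label.items():
        warnings.warn(
            f"{len(removed)} subjects removed from {label} dataframe: {removed}",
            category=UserWarning,
        )


def test_removal_warning_format():
    drop_clinical = [4]
    with pytest.warns(UserWarning) as warn_record:
        emit_removal_warnings({"features": set(drop_clinical)})
    # Expected string is built the same way the warning is, so the set reprs match.
    expected = f"1 subjects removed from features dataframe: {set(drop_clinical)}"
    assert warn_record.list[0].message.args[0] == expected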