Skip to content

Commit

Permalink
BUG: Fixed uneven dataframe creation
Browse files Browse the repository at this point in the history
Common indices are now computed as the intersection of all dataframes' indices, as described in issue #55, which is more robust
  • Loading branch information
itellaetxe committed Oct 9, 2024
1 parent fece052 commit 49902a3
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 118 deletions.
31 changes: 16 additions & 15 deletions src/ageml/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,22 +611,23 @@ def remove_missing_data(self):
warnings.warn(warn_message, category=UserWarning)
dfs[label] = df.drop(missing_subjects)

# Check that all dataframes have the same subjects
# Compute the intersection of the indices of the dataframes
indices_collection = [set(df.index) for df in dfs.values()]
shared_idx = list(indices_collection[0].intersection(*indices_collection[1:]))
# Remove subjects not shared among dataframes, report subjects left for analysis, and set dataframes
print("Removing subjects not shared among dataframes...")
for l1, df1 in dfs.items():
for l2, df2 in dfs.items():
if l1 != l2:
non_shared_idx = df1.index[~df1.index.isin(set(df2.index.to_list()))]
if non_shared_idx.__len__() != 0:
warn_message = "Subjects in dataframe %s not in dataframe %s: %s" % (l1, l2, non_shared_idx.to_list())
print(warn_message)
warnings.warn(warn_message, category=UserWarning)
dfs[l1] = df1.drop(non_shared_idx)

# Set dataframes
for label, df in dfs.items():
print("Final number of subjects in dataframe %s: %d (%.2f %% of initial)" % (label, len(df), len(df) / init_count[label] * 100))
setattr(self, f"df_{label}", df)
for label in dfs.keys():
removed_subjects = set(dfs[label].index) - set(shared_idx)
warn_message = f"{len(removed_subjects)} subjects removed from {label} dataframe: {removed_subjects}"
dfs[label] = dfs[label].loc[shared_idx]
print(warn_message)
warnings.warn(warn_message, category=UserWarning)
setattr(self, f"df_{label}", dfs[label])
msg = (
f"Final number of subjects in dataframe {label}: {len(shared_idx)} "
f"({len(shared_idx) / init_count[label] * 100:.2f} % of initial)"
)
print(msg)

def load_data(self, required=None):
"""Load data from csv files.
Expand Down
167 changes: 64 additions & 103 deletions tests/test_ageml/test_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,10 +216,8 @@ def factors():
def covariates():
df = pd.DataFrame(
{
"id": [1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
"sex": [0, 1, 1, 1, 0, 0, 0, 1, 1, 0,
0, 0, 1, 1, 0, 0, 0, 1, 1, 0],
"id": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
"sex": [0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0],
}
)
df.set_index("id", inplace=True)
Expand All @@ -235,12 +233,9 @@ def systems():
def clinical():
df = pd.DataFrame(
{
"id": [1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
"CN": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0,
1, 0, 1, 0, 1, 0, 1, 0, 1, 0],
"group1": [0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
"id": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
"CN": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0],
"group1": [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
}
)
df.set_index("id", inplace=True)
Expand All @@ -251,16 +246,11 @@ def clinical():
def ages():
df = pd.DataFrame(
{
"id": [1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
"age": [50, 55, 60, 65, 70, 75, 80, 85, 90, 57,
53, 57, 61, 65, 69, 73, 77, 81, 85, 89],
"predicted_age_all": [55, 67, 57, 75, 85, 64, 87, 93, 49, 51,
58, 73, 80, 89, 55, 67, 57, 75, 85, 64],
"corrected_age_all": [51, 58, 73, 80, 89, 67, 57, 75, 85, 64,
87, 93, 49, 55, 67, 57, 75, 85, 64, 87],
"delta_all": [1, -2, 3, 0, -1, 2, 1, 0, -3, 1,
2, 1, 0, -1, 2, 1, 0, -3, 1, 2],
"id": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
"age": [50, 55, 60, 65, 70, 75, 80, 85, 90, 57, 53, 57, 61, 65, 69, 73, 77, 81, 85, 89],
"predicted_age_all": [55, 67, 57, 75, 85, 64, 87, 93, 49, 51, 58, 73, 80, 89, 55, 67, 57, 75, 85, 64],
"corrected_age_all": [51, 58, 73, 80, 89, 67, 57, 75, 85, 64, 87, 93, 49, 55, 67, 57, 75, 85, 64, 87],
"delta_all": [1, -2, 3, 0, -1, 2, 1, 0, -3, 1, 2, 1, 0, -1, 2, 1, 0, -3, 1, 2],
}
)
df.set_index("id", inplace=True)
Expand All @@ -271,22 +261,14 @@ def ages():
def ages_multisystem():
df = pd.DataFrame(
{
"id": [1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
"age": [50, 55, 60, 65, 70, 75, 80, 85, 90, 57,
53, 57, 61, 65, 69, 73, 77, 81, 85, 89],
"predicted_age_pottongosystem": [55, 67, 57, 75, 85, 64, 87, 93, 49, 51,
58, 73, 80, 89, 55, 67, 57, 75, 85, 64],
"corrected_age_pottongosystem": [51, 58, 73, 80, 89, 67, 57, 75, 85, 64,
87, 93, 49, 55, 67, 57, 75, 85, 64, 87],
"delta_pottongosystem": [1, -2, 3, 0, -1, 2, 1, 0, -3, 1,
2, 1, 0, -1, 2, 1, 0, -3, 1, 2],
"predicted_age_mondongsystem": [55, 67, 57, 75, 85, 64, 87, 93, 49, 51,
58, 73, 80, 89, 55, 67, 57, 75, 85, 64],
"corrected_age_mondongsystem": [51, 58, 73, 80, 89, 67, 57, 75, 85, 64,
87, 93, 49, 55, 67, 57, 75, 85, 64, 87],
"delta_mondongsystem": [1, -2, 3, 0, -1, 2, 1, 0, -3, 1,
2, 1, 0, -1, 2, 1, 0, -3, 1, 2],
"id": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
"age": [50, 55, 60, 65, 70, 75, 80, 85, 90, 57, 53, 57, 61, 65, 69, 73, 77, 81, 85, 89],
"predicted_age_pottongosystem": [55, 67, 57, 75, 85, 64, 87, 93, 49, 51, 58, 73, 80, 89, 55, 67, 57, 75, 85, 64],
"corrected_age_pottongosystem": [51, 58, 73, 80, 89, 67, 57, 75, 85, 64, 87, 93, 49, 55, 67, 57, 75, 85, 64, 87],
"delta_pottongosystem": [1, -2, 3, 0, -1, 2, 1, 0, -3, 1, 2, 1, 0, -1, 2, 1, 0, -3, 1, 2],
"predicted_age_mondongsystem": [55, 67, 57, 75, 85, 64, 87, 93, 49, 51, 58, 73, 80, 89, 55, 67, 57, 75, 85, 64],
"corrected_age_mondongsystem": [51, 58, 73, 80, 89, 67, 57, 75, 85, 64, 87, 93, 49, 55, 67, 57, 75, 85, 64, 87],
"delta_mondongsystem": [1, -2, 3, 0, -1, 2, 1, 0, -3, 1, 2, 1, 0, -1, 2, 1, 0, -3, 1, 2],
}
)
df.set_index("id", inplace=True)
Expand Down Expand Up @@ -433,7 +415,7 @@ def test_load_factors_not_float(dummy_interface, factors):

def test_load_data_covariates_not_float(dummy_interface, covariates):
# Change item to string
covariates.loc[2, 'sex'] = 'asdf'
covariates.loc[2, "sex"] = "asdf"
covariates_path = create_csv(covariates, dummy_interface.dir_path)
dummy_interface.args.covariates = covariates_path

Expand Down Expand Up @@ -542,7 +524,6 @@ def test_load_data_clinical_empty_column(dummy_interface, clinical):


def test_load_data_clinical_empty_row(dummy_interface, clinical):

# Make a row all False
clinical.loc[2, :] = 0
clinical_path = create_csv(clinical, dummy_interface.dir_path)
Expand Down Expand Up @@ -571,8 +552,10 @@ def test_load_data_nan_values_warning(dummy_interface, features):

def test_load_data_different_indexes_warning(dummy_interface, features, clinical):
# Drop subjects 2 and 3 from features
features.drop([2, 3], axis=0, inplace=True)
clinical.drop([4], axis=0, inplace=True)
drop_features = [2, 3]
drop_clinical = [4]
features.drop(drop_features, axis=0, inplace=True)
clinical.drop(drop_clinical, axis=0, inplace=True)
features_path = create_csv(features, dummy_interface.dir_path)
clinical_path = create_csv(clinical, dummy_interface.dir_path)
dummy_interface.args.features = features_path
Expand All @@ -581,23 +564,19 @@ def test_load_data_different_indexes_warning(dummy_interface, features, clinical
with pytest.warns(UserWarning) as warn_record:
dummy_interface.load_data()
assert isinstance(warn_record.list[0].message, UserWarning)
expected = "Subjects in dataframe features not in dataframe clinical: [%d]" % (4)
expected = f"1 subjects removed from features dataframe: {set(drop_clinical)}"
assert warn_record.list[0].message.args[0] == expected
assert isinstance(warn_record.list[1].message, UserWarning)
expected = "Subjects in dataframe clinical not in dataframe features: [%d, %d]" % (
2,
3,
)
expected = f"2 subjects removed from clinical dataframe: {set(drop_features)}"
assert warn_record.list[1].message.args[0] == expected


def test_age_distribution_warning(dummy_interface):

# Create different distributions
dist1 = np.random.normal(loc=50, scale=1, size=100)
dist2 = np.random.normal(loc=0, scale=1, size=100)
dists = {'dist1': dist1, 'dist2': dist2}
dists = {"dist1": dist1, "dist2": dist2}

# Set visualized in temporal directory
dummy_interface.set_visualizer(tempfile.mkdtemp())

Expand All @@ -624,9 +603,7 @@ def test_run_age(dummy_interface, features):
"features_vs_age_controls_all",
"chronological_vs_pred_age_all_all",
]
svg_paths = [
os.path.join(dummy_interface.dir_path, f"model_age/figures/{fig}.png") for fig in figs
]
svg_paths = [os.path.join(dummy_interface.dir_path, f"model_age/figures/{fig}.png") for fig in figs]
print(os.listdir(os.path.join(dummy_interface.dir_path, "model_age/figures")))
assert all([os.path.exists(svg_path) for svg_path in svg_paths])

Expand Down Expand Up @@ -662,9 +639,7 @@ def test_run_age_clinical(dummy_interface, features, clinical):
"features_vs_age_controls_all",
"chronological_vs_pred_age_all_all",
]
svg_paths = [
os.path.join(dummy_interface.dir_path, f"model_age/figures/{fig}.png") for fig in figs
]
svg_paths = [os.path.join(dummy_interface.dir_path, f"model_age/figures/{fig}.png") for fig in figs]
assert all([os.path.exists(svg_path) for svg_path in svg_paths])

# Check for the existence of the log
Expand Down Expand Up @@ -696,19 +671,20 @@ def test_run_age_cov(dummy_interface, features, covariates):
assert os.path.exists(dummy_interface.dir_path)

# Check for output figs
figs = ["age_bias_correction_sex_0_all",
"age_bias_correction_sex_1_all",
"chronological_vs_pred_age_sex_0_all",
"chronological_vs_pred_age_sex_1_all",
"age_distribution_controls",
"features_vs_age_controls_all"]
figs = [
"age_bias_correction_sex_0_all",
"age_bias_correction_sex_1_all",
"chronological_vs_pred_age_sex_0_all",
"chronological_vs_pred_age_sex_1_all",
"age_distribution_controls",
"features_vs_age_controls_all",
]
# Print files in path
svg_paths = [os.path.join(dummy_interface.dir_path, f"model_age/figures/{fig}.png") for fig in figs]
assert all([os.path.exists(svg_path) for svg_path in svg_paths])

# Check for the existence of the output CSV
csv_path = os.path.join(dummy_interface.dir_path,
f"model_age/predicted_age_{dummy_interface.args.covar_name}.csv")
csv_path = os.path.join(dummy_interface.dir_path, f"model_age/predicted_age_{dummy_interface.args.covar_name}.csv")
assert os.path.exists(csv_path)

# Check that the output CSV has the right columns
Expand All @@ -735,18 +711,19 @@ def test_run_age_cov_clinical(dummy_interface, features, covariates, clinical):
assert os.path.exists(dummy_interface.dir_path)

# Check for output figs
figs = ["age_bias_correction_sex_0_all",
"age_bias_correction_sex_1_all",
"chronological_vs_pred_age_sex_0_all",
"chronological_vs_pred_age_sex_1_all",
"age_distribution_controls",
"features_vs_age_controls_all"]
figs = [
"age_bias_correction_sex_0_all",
"age_bias_correction_sex_1_all",
"chronological_vs_pred_age_sex_0_all",
"chronological_vs_pred_age_sex_1_all",
"age_distribution_controls",
"features_vs_age_controls_all",
]
svg_paths = [os.path.join(dummy_interface.dir_path, f"model_age/figures/{fig}.png") for fig in figs]
assert all([os.path.exists(svg_path) for svg_path in svg_paths])

# Check for the existence of the output CSV
csv_path = os.path.join(dummy_interface.dir_path,
f"model_age/predicted_age_{dummy_interface.args.covar_name}.csv")
csv_path = os.path.join(dummy_interface.dir_path, f"model_age/predicted_age_{dummy_interface.args.covar_name}.csv")
assert os.path.exists(csv_path)

# Check that the output CSV has the right columns
Expand Down Expand Up @@ -779,10 +756,9 @@ def test_run_age_systems(dummy_interface, systems, features):
assert all([os.path.exists(svg_path) for svg_path in svg_paths])

# Check existence of output CSV
csv_path = os.path.join(dummy_interface.dir_path,
"model_age/predicted_age_multisystem.csv")
csv_path = os.path.join(dummy_interface.dir_path, "model_age/predicted_age_multisystem.csv")
assert os.path.exists(csv_path)

# Check that the output CSV has the right columns
df = pd.read_csv(csv_path, header=0, index_col=0)
assert all(any(word in s for s in df.columns) for word in ["age", "predicted_age", "corrected_age", "delta"])
Expand Down Expand Up @@ -816,10 +792,9 @@ def test_run_age_systems_clinical(dummy_interface, systems, features, clinical):
assert all([os.path.exists(svg_path) for svg_path in svg_paths])

# Check existence of output CSV
csv_path = os.path.join(dummy_interface.dir_path,
"model_age/predicted_age_multisystem.csv")
csv_path = os.path.join(dummy_interface.dir_path, "model_age/predicted_age_multisystem.csv")
assert os.path.exists(csv_path)

# Check that the output CSV has the right columns
df = pd.read_csv(csv_path, header=0, index_col=0)
assert all(any(word in s for s in df.columns) for word in ["age", "predicted_age", "corrected_age", "delta"])
Expand Down Expand Up @@ -857,10 +832,9 @@ def test_run_age_cov_and_systems(dummy_interface, systems, features, covariates)
assert all([os.path.exists(svg_path) for svg_path in svg_paths])

# Check existence of output CSV
csv_path = os.path.join(dummy_interface.dir_path,
f"model_age/predicted_age_{dummy_interface.args.covar_name}_multisystem.csv")
csv_path = os.path.join(dummy_interface.dir_path, f"model_age/predicted_age_{dummy_interface.args.covar_name}_multisystem.csv")
assert os.path.exists(csv_path)

# Check that the output CSV has the right columns
df = pd.read_csv(csv_path, header=0, index_col=0)
assert all(any(word in s for s in df.columns) for word in ["age", "predicted_age", "corrected_age", "delta"])
Expand Down Expand Up @@ -901,10 +875,9 @@ def test_run_age_cov_and_systems_clinical(dummy_interface, systems, features, co
assert all([os.path.exists(svg_path) for svg_path in svg_paths])

# Check existence of output CSV
csv_path = os.path.join(dummy_interface.dir_path,
f"model_age/predicted_age_{dummy_interface.args.covar_name}_multisystem.csv")
csv_path = os.path.join(dummy_interface.dir_path, f"model_age/predicted_age_{dummy_interface.args.covar_name}_multisystem.csv")
assert os.path.exists(csv_path)

# Check that the output CSV has the right columns
df = pd.read_csv(csv_path, header=0, index_col=0)
assert all(any(word in s for s in df.columns) for word in ["age", "predicted_age", "corrected_age", "delta"])
Expand All @@ -928,9 +901,7 @@ def test_run_factor_correlation(dummy_interface, ages, factors, covariates):

# Check for the existence of the output figures
figs = ["factors_vs_deltas_cn"]
svg_paths = [
os.path.join(dummy_interface.dir_path, f"factor_correlation/figures/{fig}.png") for fig in figs
]
svg_paths = [os.path.join(dummy_interface.dir_path, f"factor_correlation/figures/{fig}.png") for fig in figs]
assert all([os.path.exists(svg_path) for svg_path in svg_paths])

# Check for the existence of the log
Expand All @@ -952,9 +923,7 @@ def test_run_factor_correlation_systems(dummy_interface, ages_multisystem, facto
# Check for the existence of the output figures
figs = []
figs.append("factors_vs_deltas_cn")
svg_paths = [
os.path.join(dummy_interface.dir_path, f"factor_correlation/figures/{fig}.png") for fig in figs
]
svg_paths = [os.path.join(dummy_interface.dir_path, f"factor_correlation/figures/{fig}.png") for fig in figs]
assert all([os.path.exists(svg_path) for svg_path in svg_paths])

# Check for the existence of the log
Expand Down Expand Up @@ -992,9 +961,7 @@ def test_run_clinical(dummy_interface, ages, clinical, covariates):

# Check for the existence of the output figures
figs = ["age_distribution_clinical_groups", "clinical_groups_box_plot_all"]
svg_paths = [
os.path.join(dummy_interface.dir_path, f"clinical_groups/figures/{fig}.png") for fig in figs
]
svg_paths = [os.path.join(dummy_interface.dir_path, f"clinical_groups/figures/{fig}.png") for fig in figs]
assert all([os.path.exists(svg_path) for svg_path in svg_paths])

# Check for the existence of the log
Expand All @@ -1018,9 +985,7 @@ def test_run_clinical_systems(dummy_interface, ages_multisystem, clinical):
figs = ["age_distribution_clinical_groups"]
for system in system_names:
figs.append(f"clinical_groups_box_plot_{system}")
svg_paths = [
os.path.join(dummy_interface.dir_path, f"clinical_groups/figures/{fig}.png") for fig in figs
]
svg_paths = [os.path.join(dummy_interface.dir_path, f"clinical_groups/figures/{fig}.png") for fig in figs]
assert all([os.path.exists(svg_path) for svg_path in svg_paths])

# Check for the existence of the log
Expand All @@ -1043,9 +1008,7 @@ def test_run_classification(dummy_interface, ages, clinical):

# Check for the existence of the output figures
figs = ["roc_curve_cn_vs_group1_all"]
svg_paths = [
os.path.join(dummy_interface.dir_path, f"clinical_classify/figures/{fig}.png") for fig in figs
]
svg_paths = [os.path.join(dummy_interface.dir_path, f"clinical_classify/figures/{fig}.png") for fig in figs]
assert all([os.path.exists(svg_path) for svg_path in svg_paths])

# Check for the existence of the log
Expand Down Expand Up @@ -1073,9 +1036,7 @@ def test_run_classification_systems(dummy_interface, ages_multisystem, clinical)
figs = []
for system in system_names:
figs.append(f"roc_curve_{dummy_interface.args.group1}_vs_{dummy_interface.args.group2}_{system}")
svg_paths = [
os.path.join(dummy_interface.dir_path, f"clinical_classify/figures/{fig}.png") for fig in figs
]
svg_paths = [os.path.join(dummy_interface.dir_path, f"clinical_classify/figures/{fig}.png") for fig in figs]
assert all([os.path.exists(svg_path) for svg_path in svg_paths])

# Check for the existence of the log
Expand Down Expand Up @@ -1121,9 +1082,9 @@ def test_classification_few_subjects(dummy_interface, ages, clinical):
clinical_path = create_csv(clinical, dummy_interface.dir_path)
dummy_interface.args.ages = ages_path
dummy_interface.args.clinical = clinical_path
dummy_interface.args.group1 = 'cn'
dummy_interface.args.group2 = 'group1'
dummy_interface.args.group1 = "cn"
dummy_interface.args.group2 = "group1"

# Run classification and capture error
with pytest.raises(ValueError) as exc_info:
dummy_interface.run_classification()
Expand Down

0 comments on commit 49902a3

Please sign in to comment.