Skip to content

Commit

Permalink
[WIP+ENH] Working on gender separation, needs testing and harmonization.
Browse files Browse the repository at this point in the history
  • Loading branch information
itellaetxe committed Nov 29, 2023
1 parent 0acf844 commit 24e0815
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 16 deletions.
6 changes: 4 additions & 2 deletions src/ageml/__main__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import ageml
import sys


def main():
"""Choose between interactive command line or command line
"""Choose between interactive command line or command line
based on wether there are no flags when running script"""

if len(sys.argv) > 1:
ageml.ui.CLI()
else:
ageml.ui.InteractiveCLI()


if __name__ == '__main__':
main()
main()
32 changes: 24 additions & 8 deletions src/ageml/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def features_vs_age(self, dfs, labels: list = None, significance: float = 0.05,
# Use visualizer to show
self.visualizer.features_vs_age(X, y, corr, order, significant, feature_names, indices, labels)

def model_age(self, df, model):
def model_age(self, df, model, name: str = ""):
"""Use AgeML to fit age model with data.
Parameters
Expand All @@ -375,16 +375,19 @@ def model_age(self, df, model):

# Show training pipeline
print("-----------------------------------")
print("Training Age Model")
if name == "":
print("Training Age Model")
else:
print("Training Model for covariate %s" % name)
print(self.ageml.pipeline)

# Select data to model
X, y, _ = feature_extractor(df)

# Fit model and plot results
y_pred, y_corrected = model.fit_age(X, y)
self.visualizer.true_vs_pred_age(y, y_pred)
self.visualizer.age_bias_correction(y, y_pred, y_corrected)
self.visualizer.true_vs_pred_age(y, y_pred, name)
self.visualizer.age_bias_correction(y, y_pred, y_corrected, name)

# Calculate deltas
deltas = y_corrected - y
Expand Down Expand Up @@ -537,12 +540,21 @@ def run_age(self):
dfs_covars = [covar_df_dict[category] for category in categories]
# Relationship between features and age
self.features_vs_age(dfs_covars, labels=categories, name="covariates")

# Model age for each covariate. # TODO: Plot for each model? Or do the same as in features_vs_age?
self.covar_ageml = {}
dfs_ages_covar = {}
for category, df in zip(categories, dfs_covars):
model_name = f"{self.args.covar_name}_{category}"
self.covar_ageml[model_name], dfs_ages_covar[model_name] = self.model_age(df, self.ageml, category)

# Concatenate all dfs in dfs_ages_covar
df_ages_cn = pd.concat(dfs_ages_covar.values(), axis=0)
else:
# If no covariates found, do not separate data
self.features_vs_age(df_cn)

# Model age
self.ageml, df_ages_cn = self.model_age(df_cn, self.ageml)
# Model age
self.ageml, df_ages_cn = self.model_age(df_cn, self.ageml)

# Apply to clinical data
if self.flags["clinical"]:
Expand All @@ -553,7 +565,11 @@ def run_age(self):
self.df_ages = df_ages_cn

# Save dataframe
self.df_ages.to_csv(os.path.join(self.dir_path, "predicted_age.csv"))
if self.flags["covariates"]:
filename = f"predicted_age_{self.args.covar_name}.csv"
else:
filename = "predicted_age.csv"
self.df_ages.to_csv(os.path.join(self.dir_path, filename))

def run_lifestyle(self):
"""Run age modelling with lifestyle factors."""
Expand Down
23 changes: 17 additions & 6 deletions src/ageml/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

from .utils import insert_newlines, create_directory

plt.rcParams.update({'font.size': 8})


class Visualizer:

Expand Down Expand Up @@ -119,7 +121,7 @@ def features_vs_age(self, X, Y, corr, order, markers, feature_names, idxs: list
plt.savefig(os.path.join(self.path_for_fig, "features_vs_age.svg"))
plt.close()

def true_vs_pred_age(self, y_true, y_pred):
def true_vs_pred_age(self, y_true, y_pred, name: str = ""):
"""Plot true age vs predicted age.
Parameters
Expand All @@ -133,12 +135,17 @@ def true_vs_pred_age(self, y_true, y_pred):
# Plot true vs predicted age
plt.scatter(y_true, y_pred)
plt.plot(age_range, age_range, color="k", linestyle="dashed")
plt.title(f"True vs Predicted Age {name}")
plt.xlabel("True Age")
plt.ylabel("Predicted Age")
plt.savefig(os.path.join(self.path_for_fig, "true_vs_pred_age.svg"))
if name == "":
filename = "true_vs_pred_age.svg"
else:
filename = f"true_vs_pred_age_{name}.svg"
plt.savefig(os.path.join(self.path_for_fig, filename))
plt.close()

def age_bias_correction(self, y_true, y_pred, y_corrected):
def age_bias_correction(self, y_true, y_pred, y_corrected, name: str = ""):
"""Plot before and after age bias correction procedure.
Parameters
Expand All @@ -159,7 +166,7 @@ def age_bias_correction(self, y_true, y_pred, y_corrected):
plt.plot(age_range, age_range, color="k", linestyle="dashed")
plt.plot(age_range, LR_age_bias.predict(age_range.reshape(-1, 1)), color="r")
plt.scatter(y_true, y_pred)
plt.title("Before age-bias correction")
plt.title(f"Before age-bias correction {name}")
plt.ylabel("Predicted Age")
plt.xlabel("True Age")

Expand All @@ -169,11 +176,15 @@ def age_bias_correction(self, y_true, y_pred, y_corrected):
plt.plot(age_range, age_range, color="k", linestyle="dashed")
plt.plot(age_range, LR_age_bias.predict(age_range.reshape(-1, 1)), color="r")
plt.scatter(y_true, y_corrected)
plt.title("After age-bias correction")
plt.title(f"After age-bias correction {name}")
plt.ylabel("Predicted Age")
plt.xlabel("True Age")
plt.tight_layout()
plt.savefig(os.path.join(self.path_for_fig, "age_bias_correction.svg"))
if name == "":
filename = "age_bias_correction.svg"
else:
filename = f"age_bias_correction_{name}.svg"
plt.savefig(os.path.join(self.path_for_fig, filename))
plt.close()

def factors_vs_deltas(self, corrs, groups, labels, markers):
Expand Down

0 comments on commit 24e0815

Please sign in to comment.