Skip to content

Commit

Permalink
[ENH] Add training of model with CN only
Browse files Browse the repository at this point in the history
  • Loading branch information
JGarciaCondado committed Nov 21, 2023
1 parent 5e4b55e commit a5b7bcd
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 13 deletions.
20 changes: 17 additions & 3 deletions src/ageml/modelling.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,18 +235,32 @@ def fit_age(self, X, y):
# Final model trained on all data
self.pipeline.fit(X, y)
self.pipelineFit = True
y_pred = self.pipeline.predict(X)
self.fit_age_bias(y, y_pred)

return pred_age, corrected_age

def predict_age(self, X):
def predict_age(self, X, y=None):
"""Predict age with fitted model.
Parameters:
-----------
X: 2D-Array with features; shape=(n,m)"""
X: 2D-Array with features; shape=(n,m)
y: 1D-Array with age; shape=n"""

# Check that model has previously been fit
if not self.pipelineFit:
raise ValueError("Must fit the pipline before calling predict.")
if y is not None and not self.age_biasFit:
raise ValueError("Must fit the age bias before calling predict with bias correction.")

# Predict age
y_pred = self.pipeline.predict(X)

# Apply age bias correction
if y is not None:
y_corrected = self.predict_age_bias(y, y_pred)
else:
y_corrected = y_pred

return self.pipeline.predict(X)
return y_pred, y_corrected
58 changes: 48 additions & 10 deletions src/ageml/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ def __init__(self, args):
# Arguments with which to run modelling
self.args = args

# Flags
self.flags = {'CN': False}

# Set up directory for storage of results
self.setup()

Expand Down Expand Up @@ -134,13 +137,25 @@ def load_csv(self, file):
def load_data(self):
"""Load data from csv files."""

# Load data
# Load features
self.df_features = self.load_csv(self.args.features)
if 'age' not in self.df_features.columns:
raise KeyError("Features file must contain a column name 'age', or any other case-insensitive variation.")

# Load covariates
self.df_covariates = self.load_csv(self.args.covariates)

# Load factors
self.df_factors = self.load_csv(self.args.factors)

# Load clinical
self.df_clinical = self.load_csv(self.args.clinical)
if self.df_clinical is not None:
if 'cn' not in self.df_clinical.columns:
raise KeyError("Clinical file must contian a column name 'CN' or any other case-insensitive variation.")
else:
self.flags['CN'] = True
self.cn_subjects = self.df_clinical[self.df_clinical['cn']].index

# Remove subjects with missing features
self.subjects_missing_data = self.df_features[self.df_features.isnull().any(axis=1)].index.to_list()
Expand All @@ -162,11 +177,17 @@ def age_distribution(self):
def features_vs_age(self):
"""Use visualizer to explore relationship between features and age."""

# Select which dataframe to use
if self.flags['CN']:
df = self.df_features.loc[self.df_features.index.isin(self.cn_subjects)]
else:
df = self.df_features

# Select data to visualize
feature_names = [name for name in self.df_features.columns
if name != 'age']
X = self.df_features[feature_names].to_numpy()
Y = self.df_features['age'].to_numpy()
X = df[feature_names].to_numpy()
Y = df['age'].to_numpy()

# Use visualizer to show
self.visualizer.features_vs_age(X, Y, feature_names)
Expand All @@ -179,21 +200,38 @@ def model_age(self):
print('Training Age Model')
print(self.ageml.pipeline)

# Select which dataframe to use
if self.flags['CN']:
df_cn = self.df_features.loc[self.df_features.index.isin(self.cn_subjects)]
else:
df_cn = self.df_features

# Select data to model
feature_names = [name for name in self.df_features.columns
if name != 'age']
X = self.df_features[feature_names].to_numpy()
y = self.df_features['age'].to_numpy()
X_cn = df_cn[feature_names].to_numpy()
y_cn = df_cn['age'].to_numpy()

# Fit model and plot results
y_pred, y_corrected = self.ageml.fit_age(X, y)
self.visualizer.true_vs_pred_age(y, y_pred)
self.visualizer.age_bias_correction(y, y_pred, y_corrected)
y_cn_pred, y_cn_corrected = self.ageml.fit_age(X_cn, y_cn)
self.visualizer.true_vs_pred_age(y_cn, y_cn_pred)
self.visualizer.age_bias_correction(y_cn, y_cn_pred, y_cn_corrected)

# Save to dataframe and csv
data = np.stack((y, y_pred, y_corrected), axis=1)
data = np.stack((y_cn, y_cn_pred, y_cn_corrected), axis=1)
cols = ['Age', 'Predicted Age', 'Corrected Age']
self.df_age = pd.DataFrame(data, index=self.df_features.index, columns=cols)
self.df_age = pd.DataFrame(data, index=df_cn.index, columns=cols)

# Calculate for the rest of the subjects
if self.flags['CN']:
df_clin = self.df_features.loc[~self.df_features.index.isin(self.cn_subjects)]
X_clin = df_clin[feature_names].to_numpy()
y_clin = df_clin['age'].to_numpy()
y_clin_pred, y_clin_corrected = self.ageml.predict_age(X_clin, y_clin)
data = np.stack((y_clin, y_clin_pred, y_clin_corrected), axis=1)
self.df_age = pd.concat([self.df_age, pd.DataFrame(data, index=df_clin.index, columns=cols)])

# Save results
self.df_age.to_csv(os.path.join(self.dir_path, 'predicted_age.csv'))

@log
Expand Down

0 comments on commit a5b7bcd

Please sign in to comment.