Skip to content

Commit

Permalink
[ENH] Add Classification functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
JGarciaCondado committed Nov 29, 2023
1 parent 7a00ea8 commit ebcb24e
Show file tree
Hide file tree
Showing 6 changed files with 235 additions and 10 deletions.
6 changes: 6 additions & 0 deletions src/ageml/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@
"the third column should be the predicted age, fourth age is corrected age and last \n"
"column is the delta. The first row should be the header for column names."
)

groups_long_description = (
"Clinical groups to do classification on (Required: run classification). \n"
"Two groups are required. (e.g. --groups cn ad)"
)

# UI information

emblem = """
Expand Down
112 changes: 112 additions & 0 deletions src/ageml/modelling.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
Classes:
--------
AgeML - able to fit age models and predict age.
Classifier - classifier of class labels based on deltas.
"""

import numpy as np
import scipy.stats as st

# Sklearn and Scipy do not automatically load submodules (avoids overheads)
from scipy import stats
Expand Down Expand Up @@ -264,3 +266,113 @@ def predict_age(self, X, y=None):
y_corrected = y_pred

return y_pred, y_corrected


class Classifier:
    """Classifier of class labels based on deltas.

    This class allows the differentiation of two groups based
    on differences in their deltas using a logistic regressor.

    Public methods:
    ---------------

    set_model(self): Sets the model to use in the pipeline.

    fit_model(self, X, y): Fit the model.

    predict(self, X): Predict class probabilities with the fitted model.
    """

    def __init__(self):
        """Initialise variables."""

        # Set required modelling parts
        self.set_model()

        # Set default parameters
        # TODO: let user choose this
        self.CV_split = 5
        self.seed = 0
        self.thr = 0.5
        self.ci_val = 0.95

        # Initialise flags
        self.modelFit = False

    def set_model(self):
        """Sets the model to use in the pipeline."""

        self.model = linear_model.LogisticRegression()

    @staticmethod
    def _sensitivity_specificity(tn, fp, fn, tp):
        """Return (sensitivity, specificity) from confusion matrix counts.

        Sensitivity is the true positive rate, tp / (tp + fn); specificity
        is the true negative rate, tn / (tn + fp)."""

        sensitivity = tp / (tp + fn)
        specificity = tn / (tn + fp)
        return sensitivity, specificity

    @staticmethod
    def _mean_ci(values, confidence):
        """Return (mean, lower, upper) of a Student-t confidence interval.

        The confidence level is passed positionally because SciPy renamed
        the keyword from `alpha` to `confidence`; the positional form is
        accepted by both the old and the new signature."""

        mean = np.mean(values)
        lower, upper = st.t.interval(confidence, df=len(values) - 1, loc=mean, scale=st.sem(values))
        return mean, lower, upper

    def fit_model(self, X, y):
        """Fit the model.

        Runs a shuffled K-fold cross-validation to report AUC, accuracy,
        sensitivity and specificity, then refits the model on all the data.

        Parameters
        ----------
        X: 2D-Array with features; shape=(n,m)
        y: 1D-Array with labels; shape=n

        Returns
        -------
        y_preds: 1D-Array with the out-of-fold predicted probabilities
                 (second column of predict_proba); shape=n"""

        # Arrays to store values
        accs, aucs, spes, sens = [], [], [], []
        y_preds = np.empty(shape=y.shape)

        kf = model_selection.KFold(n_splits=self.CV_split, shuffle=True, random_state=self.seed)
        for train_index, test_index in kf.split(X):
            X_train, X_test = X[train_index], X[test_index]
            y_train, y_test = y[train_index], y[test_index]

            # Fit the model using the training data
            self.model.fit(X_train, y_train)

            # Use model to predict probability of tests
            y_pred = self.model.predict_proba(X_test)[::, 1]
            y_preds[test_index] = y_pred

            # Calculate AUC of model
            aucs.append(metrics.roc_auc_score(y_test, y_pred))

            # Calculate relevant metrics on the thresholded predictions
            y_label = y_pred > self.thr
            accs.append(metrics.accuracy_score(y_test, y_label))
            tn, fp, fn, tp = metrics.confusion_matrix(y_test, y_label).ravel()
            sensitivity, specificity = self._sensitivity_specificity(tn, fp, fn, tp)
            sens.append(sensitivity)
            spes.append(specificity)

        # Print results with confidence intervals over the CV splits
        print('Summary metrics over all CV splits (%g%% CI)' % (100 * self.ci_val))
        print('AUC: %.3f [%.3f-%.3f]' % self._mean_ci(aucs, self.ci_val))
        print('Accuracy: %.3f [%.3f-%.3f]' % self._mean_ci(accs, self.ci_val))
        print('Sensitivity: %.3f [%.3f-%.3f]' % self._mean_ci(sens, self.ci_val))
        print('Specificity: %.3f [%.3f-%.3f]' % self._mean_ci(spes, self.ci_val))

        # Final model trained on all data
        self.model.fit(X, y)

        # Set flag
        self.modelFit = True

        return y_preds

    def predict(self, X):
        """Predict class probabilities with fitted model.

        Parameters:
        -----------
        X: 2D-Array with features; shape=(n,m)

        Returns
        -------
        y_pred: 1D-Array with the second column of predict_proba; shape=n

        Raises
        ------
        ValueError: if fit_model has not been called beforehand."""

        # Check that model has previously been fit
        if not self.modelFit:
            raise ValueError("Must fit the pipline before calling predict.")

        # Predict class probabilities
        y_pred = self.model.predict_proba(X)[::, 1]

        return y_pred
92 changes: 87 additions & 5 deletions src/ageml/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
--------
Interface - reads, parses and executes user commands.
CLI - reads and parses user commands via command line.
InteractiveCLI - reads and parses user commands via an interactive command line interface.
"""

import argparse
Expand All @@ -22,7 +23,7 @@
import ageml.messages as messages
from ageml.visualizer import Visualizer
from ageml.utils import create_directory, feature_extractor, significant_markers, convert, log
from ageml.modelling import AgeML
from ageml.modelling import AgeML, Classifier
from ageml.processing import find_correlations


Expand All @@ -46,6 +47,8 @@ class Interface:
set_model(self): Set model with parameters.
set_classifier(self): Set classifier with parameters.
check_file(self, file): Check that file exists.
load_csv(self, file): Use panda to load csv into dataframe.
Expand All @@ -64,6 +67,8 @@ class Interface:
deltas_by_group(self, df, labels): Calculate summary metrics of deltas by group.
classify(self, df1, df2, groups): Classify two groups based on deltas.
run_wrapper(self, run): Wrapper for running modelling with log.
run_age(self): Run basic age modelling.
Expand All @@ -90,6 +95,7 @@ def __init__(self, args):
# Initialise objects from library
self.set_visualizer()
self.set_model()
self.set_classifier()

def setup(self):
"""Create required directories and files to store results."""
Expand Down Expand Up @@ -127,6 +133,11 @@ def set_model(self):
self.args.seed,
)

def set_classifier(self):
    """Create the classifier object and attach it to the interface."""

    # The classifier takes no constructor arguments; it sets its own defaults
    classifier = Classifier()
    self.classifier = classifier

def check_file(self, file):
"""Check that file exists."""
if not os.path.exists(file):
Expand Down Expand Up @@ -476,6 +487,33 @@ def deltas_by_group(self, df, labels):
# Use visualizer
self.visualizer.deltas_by_groups(deltas, labels)

def classify(self, df1, df2, groups):
    """Classify two groups based on deltas.

    Builds a single-feature data set from the "delta" column of each
    dataframe, labels rows from df1 as 0 and rows from df2 as 1, fits the
    classifier and plots the resulting ROC curve.

    Parameters
    ----------
    df1: dataframe with delta information; shape=(n,m)
    df2: dataframe with delta information; shape=(n,m)
    groups: list of labels for each dataframe; shape=(2,)"""

    # Announce which groups are being compared
    first_group, second_group = groups[0], groups[1]
    print("-----------------------------------")
    print("Classification between groups %s and %s" % (first_group, second_group))

    # Pull the delta column out of each dataframe
    first_deltas, second_deltas = (frame["delta"].to_numpy() for frame in (df1, df2))

    # Features are the stacked deltas as one column; labels are 0 then 1
    features = np.concatenate((first_deltas, second_deltas)).reshape(-1, 1)
    labels = np.concatenate(
        (np.zeros(first_deltas.shape), np.ones(second_deltas.shape))
    )

    # Fit the classifier and collect the predictions it returns
    predictions = self.classifier.fit_model(features, labels)

    # Plot the ROC curve for these predictions
    self.visualizer.classification_auc(labels, predictions, groups)

@log
def run_wrapper(self, run):
"""Wrapper for running modelling with log."""
Expand Down Expand Up @@ -578,7 +616,33 @@ def run_classification(self):
"""Run classification between two different clinical groups."""

print("Running classification...")
pass

# Load data
self.load_data(required=["clinical"])

# Run age if not ages found
if self.df_ages is None:
print("No age data detected...")
print("-----------------------------------")
self.run_age()
print("-----------------------------------")
print("Resuming clinical outcomes...")

# Check that arguments given for each group
if self.args.group1 is None or self.args.group2 is None:
raise ValueError("Must provide two groups to classify.")

# Check that those groups exist
groups = [self.args.group1.lower(), self.args.group2.lower()]
if groups[0] not in self.df_clinical.columns or groups[1] not in self.df_clinical.columns:
raise ValueError("Classes must be one of the following: %s" % self.df_clinical.columns.to_list())

# Obtain dataframes for each clinical group
df_group1 = self.df_ages.loc[self.df_clinical[groups[0]]]
df_group2 = self.df_ages.loc[self.df_clinical[groups[1]]]

# Classify between groups
self.classify(df_group1, df_group2, groups)


class CLI(Interface):
Expand Down Expand Up @@ -679,6 +743,10 @@ def configure_parser(self):
"--ages", metavar="FILE", help=messages.ages_long_description
)

self.parser.add_argument(
"--groups", metavar="GROUP", nargs=2, help=messages.groups_long_description
)

def configure_args(self, args):
"""Configure arguments with required formatting for modelling.
Expand Down Expand Up @@ -732,6 +800,12 @@ def configure_args(self, args):
else:
args.model_params = {}

# Set groups
if args.groups is not None:
args.group1, args.group2 = args.groups
else:
args.group1, args.group2 = None, None

return args


Expand Down Expand Up @@ -1080,9 +1154,15 @@ def run_command(self):
# Split into items and remove command
self.line = self.line.split()[1:]

# Check that only one argument input
if len(self.line) != 1:
error = "Must provide one argument only."
# Check that at least one argument given
if len(self.line) < 1:
error = "Must provide at least one argument."
return error
elif len(self.line) > 1 and self.line[0] in ['age', 'lifestyle', 'clinical']:
error = "Too many arguments given for run type %s" % self.line[0]
return error
elif len(self.line) != 3 and self.line[0] in ['classification']:
error = "For run type %s two arguments should be given" % self.line[0]
return error

# Run specificed modelling
Expand All @@ -1095,6 +1175,8 @@ def run_command(self):
self.run = self.run_clinical
elif case == "classification":
self.run = self.run_classification
self.args.group1 = self.line[1]
self.args.group2 = self.line[2]
else:
error = "Choose a valid run type: age, lifestyle, clinical, classification"

Expand Down
27 changes: 27 additions & 0 deletions src/ageml/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import os

from sklearn.linear_model import LinearRegression
from sklearn.metrics import roc_curve, roc_auc_score

from .utils import insert_newlines, create_directory

Expand All @@ -38,7 +39,11 @@ class Visualizer:
age_bias_correction(self, y_true, y_pred, y_corrected): Plot before and after age bias correction procedure.
factors_vs_deltas(self, corrs, groups, labels, markers): Plot bar graph for correlation between factors and deltas.
deltas_by_groups(self, deltas, labels): Plot box plot for deltas in each group.
classification_auc(self, y, y_pred, groups): Plot ROC curve.
"""

def __init__(self, out_dir):
Expand Down Expand Up @@ -222,3 +227,25 @@ def deltas_by_groups(self, deltas, labels):
plt.ylabel("Delta")
plt.savefig(os.path.join(self.path_for_fig, "clinical_groups_box_plot.svg"))
plt.close()

def classification_auc(self, y, y_pred, groups):
    """Plot the ROC curve of a two-group classification and save it as SVG.

    Parameters
    ----------
    y: 1D-Array with true labels; shape=n
    y_pred: 1D-Array with predicted scores used to rank samples; shape=n
    groups: names of the two groups, used in the title and file name"""

    # Compute ROC curve and AUC
    fpr, tpr, _ = roc_curve(y, y_pred)
    auc = roc_auc_score(y, y_pred)

    # Text elements of the figure
    group_pair = (groups[0], groups[1])
    curve_label = 'ROC curve (area = %0.2f)' % auc
    figure_file = "classification_auc_%s_vs_%s.svg" % group_pair

    # ROC curve plus the diagonal chance line for reference
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=curve_label)
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')

    # Axes, title and legend
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC curve %s vs %s' % group_pair)
    plt.legend(loc="lower right")

    # Save into the figure directory and release the figure
    plt.savefig(os.path.join(self.path_for_fig, figure_file))
    plt.close()
3 changes: 2 additions & 1 deletion tests/test_ageml/test_modelling.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def test_set_pipeline_none_model():
# TODO: test: metrics, summary_metrics, fit_age_bias, predict_age_bias, fit_age, predict_age
# TODO: check all errors raised


def test_fit_age():
pass

# TODO: test Classifier fit_model and predict
5 changes: 1 addition & 4 deletions tests/test_ageml/test_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,10 +686,7 @@ def test_run_command_interactiveCLI(dummy_cli):
# Test no input arguments
dummy_cli.line = "r"
error = dummy_cli.run_command()
assert error == "Must provide one argument only."
dummy_cli.line = "r type1 type1"
error = dummy_cli.run_command()
assert error == "Must provide one argument only."
assert error == "Must provide at least one argument."

# Test passing invalid run type
dummy_cli.line = "r type1"
Expand Down

0 comments on commit ebcb24e

Please sign in to comment.