Skip to content

Commit

Permalink
Merge pull request #118 from jyaacoub/v115-aflow_subset
Browse files Browse the repository at this point in the history
V115 aflow subset
  • Loading branch information
jyaacoub authored Jul 10, 2024
2 parents 4537407 + e9e6b58 commit 2e2449f
Show file tree
Hide file tree
Showing 7 changed files with 108 additions and 28 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -225,4 +225,4 @@ results/model_media/*/test_set_pred/*
results/model_media/test_set_pred

splits/**/*.csv
results/v113/model_media/*/train_log/*.json
results/*/model_media/*/train_log/*.json
15 changes: 8 additions & 7 deletions playground.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
from matplotlib import pyplot as plt
from statannotations.Annotator import Annotator

from src.analysis.figures import prepare_df, custom_fig, fig_combined
from src.analysis.figures import prepare_df, fig_combined, custom_fig

dft = prepare_df('./results/v115/model_media/model_stats.csv')
dfv = prepare_df('./results/v115/model_media/model_stats_val.csv')

models = {
'DG': ('nomsa', 'binary', 'original', 'binary'),
Expand All @@ -24,16 +27,14 @@
# 'gvpl_esm_aflow': ('ESM', 'aflow', 'gvp', 'binary'),
}

df = prepare_df('/cluster/home/t122995uhn/projects/MutDTA/results/v113/model_media/model_stats.csv')
fig, axes = fig_combined(df, datasets=['davis'], fig_callable=custom_fig,
fig, axes = fig_combined(dft, datasets=['davis'], fig_callable=custom_fig,
models=models, metrics=['cindex', 'mse'],
fig_scale=(10,5), add_stats=True, title_postfix=" test set performance")
fig_scale=(10,5), add_stats=True, title_postfix=" test set performance", box=True, fold_labels=True)
plt.xticks(rotation=45)

df = prepare_df('/cluster/home/t122995uhn/projects/MutDTA/results/v113/model_media/model_stats_val.csv')
fig, axes = fig_combined(df, datasets=['davis'], fig_callable=custom_fig,
fig, axes = fig_combined(dfv, datasets=['davis'], fig_callable=custom_fig,
models=models, metrics=['cindex', 'mse'],
fig_scale=(10,5), add_stats=True, title_postfix=" validation set performance")
fig_scale=(10,5), add_stats=True, title_postfix=" validation set performance", box=True, fold_labels=True)
plt.xticks(rotation=45)


Expand Down
21 changes: 21 additions & 0 deletions results/v115/model_media/model_stats.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
run,cindex,pearson,spearman,mse,mae,rmse
DGM_davis0D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.7604957959117081,0.5074166361960291,0.4880557945875629,0.7899633560405003,0.5045533987507759,0.8887988276547738
DGM_davis2D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.7766413950490045,0.4809036117056705,0.5211140501386226,0.8209668727190276,0.5056958898068159,0.906072222683726
DGM_davis3D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.7487323569555477,0.4373251804322002,0.4659790811104093,0.8641685761814115,0.5812529990799629,0.929606678214723
DGM_davis1D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.7882312387834606,0.5814898274925953,0.539606500881958,0.6954483888899063,0.4763443782989743,0.8339354824504749
DGM_davis4D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.770028362176337,0.485165974095078,0.5055126264655252,0.8178595017482873,0.5139869788455826,0.904355849070645
GVPLM_davis0D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.7482588250040787,0.4191458736488045,0.4654648560339079,0.916438806542663,0.5086004597213724,0.9573081042917494
GVPLM_davis1D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.6468906559014417,0.2941056073867227,0.2779337470015688,1.0269896906785398,0.5577064489189487,1.0134049983489029
GVPLM_davis3D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.7471535855123098,0.441710216665843,0.4613052477455492,0.8885834379930146,0.5305747291450008,0.9426470378635976
GVPLM_davis2D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.7296485716445884,0.4054334215308577,0.4335391243636884,0.9631544322037824,0.5330878745876734,0.9814043163771914
GVPLM_davis4D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.7766749700560678,0.49090463554510744,0.517758159045675,0.8430872429869354,0.5006691044680186,0.9181978234492475
DGM_davis0D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.6936121435080361,0.3264373697380284,0.3623624786171266,0.955541722461442,0.5866414782306556,0.9775181443131592
DGM_davis2D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.6954896081622584,0.3276610988554868,0.3660103054199092,0.9501814404206184,0.5941079377785988,0.9747725070090038
DGM_davis1D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.6958698762052185,0.3287337103560256,0.3663455614887089,0.9508110868966404,0.5924060765434714,0.9750954245081044
DGM_davis4D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.6921726461681715,0.328433593374192,0.3581768615124882,0.9489810461762982,0.598616964334737,0.974156581960158
DGM_davis3D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.691988107981202,0.3206628430008865,0.3595820594502304,0.9616201903122608,0.5948554296233561,0.9806223484666564
GVPLM_davis0D_nomsaF_aflowE_128B_0.00014968791626986144LR_0.00039427600918916277D_2000E_gvpLF_binaryLE,0.7120142318237347,0.3830589592680732,0.400661326698907,0.9422431676298167,0.5708391885333287,0.9706921075345244
GVPLM_davis2D_nomsaF_aflowE_128B_0.00014968791626986144LR_0.00039427600918916277D_2000E_gvpLF_binaryLE,0.6758544466241947,0.3482148623183911,0.3313455573966994,0.9952058429887072,0.5672084597638211,0.9976000415941788
GVPLM_davis1D_nomsaF_aflowE_128B_0.00014968791626986144LR_0.00039427600918916277D_2000E_gvpLF_binaryLE,0.7715096417472135,0.5070999916573449,0.505118493158149,0.8228862016569322,0.521357725135223,0.90713075223858
GVPLM_davis3D_nomsaF_aflowE_128B_0.00014968791626986144LR_0.00039427600918916277D_2000E_gvpLF_binaryLE,0.7784890152525039,0.5548561075227736,0.519896821064786,0.7201384467294399,0.534564816284727,0.8486097140201967
GVPLM_davis4D_nomsaF_aflowE_128B_0.00014968791626986144LR_0.00039427600918916277D_2000E_gvpLF_binaryLE,0.7591851271174638,0.4580465758811774,0.49007259211942494,0.8589132362281114,0.5267642440549613,0.9267757205646421
21 changes: 21 additions & 0 deletions results/v115/model_media/model_stats_val.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
run,cindex,pearson,spearman,mse,mae,rmse
DGM_davis0D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.8234270583289791,0.713459681179556,0.5820109820998994,0.3866681321000501,0.3584109226012612,0.621826448536929
DGM_davis2D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.8237245738420456,0.6651808918664764,0.5909037786837793,0.3846215914244413,0.3463886419507059,0.6201786770152948
DGM_davis3D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.7906653075036156,0.6441159371983189,0.5554869523309671,0.4952404807604716,0.421062613040014,0.7037332454563104
DGM_davis1D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.7980215958223652,0.6639328311818086,0.5425833351376672,0.4060308122972151,0.3693855847505962,0.6372054710195253
DGM_davis4D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.8156238576892975,0.6549994289099051,0.5710358975674463,0.4207196688631731,0.3768108406693877,0.6486290687775048
GVPLM_davis0D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.8008002702925682,0.6882770547003477,0.5444236226040993,0.4178925669282142,0.3592081186818142,0.6464461051999727
GVPLM_davis1D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.7867644660140576,0.648368021836038,0.5216309719594355,0.4187849589651989,0.3629059747737996,0.6471359663665736
GVPLM_davis3D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.787998863216255,0.6496858069147042,0.5512058721043596,0.514104358724952,0.3910173775713428,0.7170107103279225
GVPLM_davis2D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.7967836291150775,0.6451761145051786,0.5463163388068719,0.4095255233260228,0.3658814668230407,0.6399418124533064
GVPLM_davis4D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.8219808050688133,0.6626969205508393,0.5802516690923312,0.40973373000931657,0.367932745329649,0.6401044680435504
DGM_davis0D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7523695779428693,0.4682407034650613,0.4649581513699286,0.6092224587753515,0.4834460555550886,0.7805270391058541
DGM_davis2D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7762332212240666,0.5470881337730895,0.5112932834407308,0.4740730581497179,0.4375338560757152,0.688529634910305
DGM_davis1D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7570127033757169,0.5114232635236531,0.4736185104310603,0.5234106486856192,0.4567883369677207,0.723471249384258
DGM_davis4D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7801250458793092,0.5321007699350608,0.5131027258344268,0.5202149722288628,0.4510516971865327,0.7212592961126136
DGM_davis3D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7450873842292931,0.4893503545803986,0.4778696104467196,0.6419698895173639,0.490655592843598,0.8012302350244678
GVPLM_davis0D_nomsaF_aflowE_128B_0.00014968791626986144LR_0.00039427600918916277D_2000E_gvpLF_binaryLE,0.7676686087397216,0.6398302090119784,0.4874074774095251,0.4720718172086983,0.4122438906773314,0.687074826499049
GVPLM_davis2D_nomsaF_aflowE_128B_0.00014968791626986144LR_0.00039427600918916277D_2000E_gvpLF_binaryLE,0.8056817717122009,0.6600120446400118,0.5615058375759591,0.3972760533222704,0.3633167063701174,0.630298384356386
GVPLM_davis1D_nomsaF_aflowE_128B_0.00014968791626986144LR_0.00039427600918916277D_2000E_gvpLF_binaryLE,0.7796812905632664,0.6418365067480541,0.508433155533638,0.4350043675747687,0.3998515603735166,0.6595486089552223
GVPLM_davis3D_nomsaF_aflowE_128B_0.00014968791626986144LR_0.00039427600918916277D_2000E_gvpLF_binaryLE,0.7920895987899996,0.6645194989409928,0.5585099170472267,0.469367754560377,0.4311579323672002,0.6851041924848928
GVPLM_davis4D_nomsaF_aflowE_128B_0.00014968791626986144LR_0.00039427600918916277D_2000E_gvpLF_binaryLE,0.829644425075257,0.7189969797010738,0.591263254476021,0.35142261273003306,0.3509021611774669,0.5928090862411212
42 changes: 33 additions & 9 deletions src/analysis/figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,9 +351,13 @@ def fig_combined(df, datasets=['PDBbind','davis', 'kiba'], metrics=['cindex', 'm
fig_scale[1]*len(metrics)))
for i, dataset in enumerate(datasets):
for j, metric in enumerate(metrics):
# Set current subplot
if len(datasets) == 1 or len(metrics) == 1:
ax = axes[j] if len(datasets) == 1 else axes[i]
# Set current subplot
if len(datasets) == 1 and len(metrics) == 1:
ax = axes
elif len(datasets) == 1:
ax = axes[j]
elif len(metrics) == 1:
ax = axes[i]
else:
ax = axes[j, i]

Expand Down Expand Up @@ -382,7 +386,8 @@ def fig_combined(df, datasets=['PDBbind','davis', 'kiba'], metrics=['cindex', 'm
return fig, axes

def custom_fig(df, models:OrderedDict=None, sel_dataset='PDBbind', sel_col='cindex',
verbose=False, show=False, add_stats=True, ax=None):
verbose=False, show=False, add_stats=True, ax=None, box=False,
fold_points=True, fold_labels=False, alpha=0.7):

"""
Example usage with `fig_combined`.
Expand Down Expand Up @@ -432,15 +437,34 @@ def matched(df, tuple):
filtered_df = filtered_df[sum(filter_conditions) > 0]

# Group each model results
plot_data = OrderedDict()
all_data = OrderedDict()
for model, feat in models.items():
plot_data[model] = filtered_df[matched(filtered_df, feat)][sel_col]
if len(plot_data[model]) != 5:
logging.warning(f'Expected 5 results for {model} on {sel_dataset}, got {len(plot_data[model])}')
all_data[model] = filtered_df[matched(filtered_df, feat)][[sel_col, 'fold']]
if len(all_data[model]) != 5:
logging.warning(f'Expected 5 results for {model} on {sel_dataset}, got {len(all_data[model])}')

# plot violin plot with annotations
plot_data = OrderedDict({k: v[sel_col] for k, v in all_data.items()})
fold_data = OrderedDict({k: v['fold'] for k, v in all_data.items()})
folds = list(fold_data.values())
vals = list(plot_data.values())
ax = sns.violinplot(data=vals, ax=ax)
if box:
ax = sns.boxplot(data=vals, ax=ax, boxprops=dict(alpha=alpha))
else:
ax = sns.violinplot(data=vals, ax=ax)
for violin in ax.collections:
violin.set_alpha(alpha)

if fold_points or fold_labels:
sns.stripplot(data=vals, dodge=True, ax=ax, alpha=.8, linewidth=1)

if fold_labels:
adjs = -0.5 if box else -0.2
for i in range(len(models)):
for f, v in zip(folds[i], vals[i]):
ax.text(i+adjs, v, f,
horizontalalignment='left', size='medium', color='red')

ax.set_xticklabels(list(plot_data.keys()))
ax.set_ylabel(sel_col)
ax.set_xlabel('Model Type')
Expand Down
33 changes: 23 additions & 10 deletions src/train_test/splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def balanced_kfold_split(dataset: BaseDataset |str,


@init_dataset_object(strict=True)
def resplit(dataset:str|BaseDataset, split_files:dict|str=None, **kwargs):
def resplit(dataset:str|BaseDataset, split_files:dict|str=None, use_train_set=False, **kwargs):
"""
1.Takes as input the target dataset path or dataset object, and a dict defining the 6 splits for all 5 folds +
1 test set.
Expand All @@ -308,24 +308,36 @@ def resplit(dataset:str|BaseDataset, split_files:dict|str=None, **kwargs):
Returns:
BaseDataset: dataset object for "full" dataset
"""
key_names = ['val0', 'val1', 'val2', 'val3', 'val4', 'test']
if use_train_set: key_names += ['train0', 'train1', 'train2', 'train3', 'train4']

# getting split files from directory
if isinstance(split_files, str):
csv_files = {}
for split in ['test'] + [f'val{i}' for i in range(5)]:
for split in key_names:
csv_files[split] = f'{split_files}/{split}/cleaned_XY.csv'
split_files = csv_files
print('Using split files from:', split_files)

assert 'test' in split_files, 'Missing test csv from split files.'

##### Validate split files and if they exist #####
missing = []
for k in key_names:
if k not in split_files:
missing.append(k)
assert len(missing) == 0, f'Missing split files for: {missing}'

# Check if split files exist and are in the correct format
if split_files is None:
raise ValueError('split_files must be provided')
if len(split_files) != 6:
raise ValueError('split_files must contain 6 files for the 5 folds and test set')
num_files = 5*(1+use_train_set) + 1
if len(split_files) != num_files:
raise ValueError(f'split_files must contain {num_files} files for the 5 folds +1 for the test set.')

for f in split_files.values():
if not os.path.exists(f):
raise ValueError(f'{f} does not exist')

##### perform the resplitting #####
# Getting indices for each split based on db.df
split_files = split_files.copy()
test_prots = set(pd.read_csv(split_files['test'])['prot_id'])
Expand All @@ -339,10 +351,11 @@ def resplit(dataset:str|BaseDataset, split_files:dict|str=None, **kwargs):
val_idxs = [i for i in range(len(dataset.df)) if dataset.df.iloc[i]['prot_id'] in prots]
dataset.save_subset(val_idxs, k)

# Build training set from all proteins not in the val/test set
idxs = set(val_idxs + test_idxs)
train_idxs = [i for i in range(len(dataset.df)) if i not in idxs]
dataset.save_subset(train_idxs, k.replace('val', 'train'))
if not use_train_set:
# Build training set from all proteins not in the val/test set
idxs = set(val_idxs + test_idxs)
train_idxs = [i for i in range(len(dataset.df)) if i not in idxs]
dataset.save_subset(train_idxs, k.replace('val', 'train'))

return dataset

2 changes: 1 addition & 1 deletion src/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class LIG_FEAT_OPT(StringEnum):
DATA_ROOT = os.path.abspath('../data/')

# Model save paths
issue_number = 113 # 113 is for unifying all splits for cross validation so that we are more confident
issue_number = 115 # 113 is for unifying all splits for cross validation so that we are more confident
# when comparing results that they were trained in the same manner.
RESULTS_PATH = os.path.abspath(f'results/v{issue_number}/')
MEDIA_SAVE_DIR = f'{RESULTS_PATH}/model_media/'
Expand Down

0 comments on commit 2e2449f

Please sign in to comment.