Skip to content

Commit

Permalink
results(davis): retrained aflow models #113 due to issue #116
Browse files Browse the repository at this point in the history
still need to train esm variants to complete #113 for davis
  • Loading branch information
jyaacoub committed Jul 9, 2024
1 parent 1361c7e commit a0e4405
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 14 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -225,3 +225,4 @@ results/model_media/*/test_set_pred/*
results/model_media/test_set_pred

splits/**/*.csv
results/v113/model_media/*/train_log/*.json
40 changes: 39 additions & 1 deletion playground.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,42 @@
# %%
########################################################################
########################## VIOLIN PLOTTING #############################
########################################################################
import logging
from typing import OrderedDict

import seaborn as sns
from matplotlib import pyplot as plt
from statannotations.Annotator import Annotator

from src.analysis.figures import prepare_df, custom_fig, fig_combined

models = {
'DG': ('nomsa', 'binary', 'original', 'binary'),
# 'esm': ('ESM', 'binary', 'original', 'binary'), # esm model
'aflow': ('nomsa', 'aflow', 'original', 'binary'),
# 'gvpP': ('gvp', 'binary', 'original', 'binary'),
'gvpL': ('nomsa', 'binary', 'gvp', 'binary'),
# 'aflow_ring3': ('nomsa', 'aflow_ring3', 'original', 'binary'),
'gvpL_aflow': ('nomsa', 'aflow', 'gvp', 'binary'),
# 'gvpL_aflow_rng3': ('nomsa', 'aflow_ring3', 'gvp', 'binary'),
#GVPL_ESMM_davis3D_nomsaF_aflowE_48B_0.00010636872718329864LR_0.23282479481785903D_2000E_gvpLF_binaryLE
# 'gvpl_esm_aflow': ('ESM', 'aflow', 'gvp', 'binary'),
}

df = prepare_df('/cluster/home/t122995uhn/projects/MutDTA/results/v113/model_media/model_stats.csv')
fig, axes = fig_combined(df, datasets=['davis'], fig_callable=custom_fig,
models=models, metrics=['cindex', 'mse'],
fig_scale=(10,5), add_stats=True, title_postfix=" test set performance")
plt.xticks(rotation=45)

df = prepare_df('/cluster/home/t122995uhn/projects/MutDTA/results/v113/model_media/model_stats_val.csv')
fig, axes = fig_combined(df, datasets=['davis'], fig_callable=custom_fig,
models=models, metrics=['cindex', 'mse'],
fig_scale=(10,5), add_stats=True, title_postfix=" validation set performance")
plt.xticks(rotation=45)


# %%
from src.data_prep.init_dataset import create_datasets
from src import cfg
Expand Down Expand Up @@ -49,6 +88,5 @@
sfile = f"{src}/val{i}/XY.csv"
dfile = f"{dst}/val{i}.csv"
shutil.copyfile(sfile, dfile)


# %%
10 changes: 10 additions & 0 deletions results/v113/model_media/model_stats.csv
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,13 @@ GVPLM_davis2D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_
GVPLM_davis3D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.7881126560725141,0.5117816919373207,0.5344788271118109,0.7824263605065905,0.4710016337041163,0.8845486761657555
GVPLM_davis0D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.758263866811987,0.4372814709152363,0.4796308343784581,0.8577483542500933,0.5066641469660601,0.9261470478547632
GVPLM_davis1D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.6851987657280194,0.293523499357519,0.346229131471335,1.0881947803478098,0.5547743755228379,1.0431657492209998
DGM_davis0D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.6936121435080361,0.3264373697380284,0.3623624786171266,0.955541722461442,0.5866414782306556,0.9775181443131592
DGM_davis2D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.6954896081622584,0.3276610988554868,0.3660103054199092,0.9501814404206184,0.5941079377785988,0.9747725070090038
DGM_davis1D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.6958698762052185,0.3287337103560256,0.3663455614887089,0.9508110868966404,0.5924060765434714,0.9750954245081044
DGM_davis4D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.6921726461681715,0.328433593374192,0.3581768615124882,0.9489810461762982,0.598616964334737,0.974156581960158
DGM_davis3D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.691988107981202,0.3206628430008865,0.3595820594502304,0.9616201903122608,0.5948554296233561,0.9806223484666564
GVPLM_davis0D_nomsaF_aflowE_128B_0.00014968791626986144LR_0.00039427600918916277D_2000E_gvpLF_binaryLE,0.7120142318237347,0.3830589592680732,0.400661326698907,0.9422431676298167,0.5708391885333287,0.9706921075345244
GVPLM_davis2D_nomsaF_aflowE_128B_0.00014968791626986144LR_0.00039427600918916277D_2000E_gvpLF_binaryLE,0.6758544466241947,0.3482148623183911,0.3313455573966994,0.9952058429887072,0.5672084597638211,0.9976000415941788
GVPLM_davis1D_nomsaF_aflowE_128B_0.00014968791626986144LR_0.00039427600918916277D_2000E_gvpLF_binaryLE,0.7715096417472135,0.5070999916573449,0.505118493158149,0.8228862016569322,0.521357725135223,0.90713075223858
GVPLM_davis3D_nomsaF_aflowE_128B_0.00014968791626986144LR_0.00039427600918916277D_2000E_gvpLF_binaryLE,0.7784890152525039,0.5548561075227736,0.519896821064786,0.7201384467294399,0.534564816284727,0.8486097140201967
GVPLM_davis4D_nomsaF_aflowE_128B_0.00014968791626986144LR_0.00039427600918916277D_2000E_gvpLF_binaryLE,0.7591851271174638,0.4580465758811774,0.49007259211942494,0.8589132362281114,0.5267642440549613,0.9267757205646421
22 changes: 10 additions & 12 deletions results/v113/model_media/model_stats_val.csv
Original file line number Diff line number Diff line change
@@ -1,23 +1,21 @@
run,cindex,pearson,spearman,mse,mae,rmse
<<<<<<< Updated upstream
DGM_davis1D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.8163523994132658,0.7137851124164986,0.5878867118585559,0.4173222599402789,0.35940127305806,0.6460048451368449
DGM_davis2D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.8343753726176008,0.7720660560854502,0.618611791650495,0.3283613673160558,0.3143092008198009,0.5730282430352416
DGM_davis0D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.8345962630012087,0.7841559792978283,0.6289360094161347,0.3611685706886787,0.3225311759640188,0.6009730199340722
DGM_davis4D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.8557152702768875,0.7904879365029139,0.6656571912339156,0.3427956306740941,0.3299113169392162,0.585487515387044
DGM_davis3D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.8349019526166668,0.7875041565317034,0.6443896656067043,0.3662570748291923,0.32390817027678465,0.605191766987285
=======
DGM_davis0D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7542333456540382,0.4661222527360288,0.4622454723166427,0.5745961603253643,0.4490743226299249,0.7580212136380909
DGM_davis2D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7889097125722226,0.6050398686544554,0.5194217790291483,0.3712609331151108,0.3486874123904465,0.6093118521045777
DGM_davis1D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7619768709032059,0.5908949206748343,0.4620148211231741,0.3627365784893167,0.3464703883892624,0.6022761646365533
DGM_davis3D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7518315515465352,0.4909815923227958,0.4818553258238582,0.5928763165737386,0.5397916083517371,0.769984620998198
DGM_davis4D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7852144067428423,0.5455471328321575,0.5134570446387595,0.4461269801785316,0.3969394130741848,0.6679273764254102
GVPLM_davis3D_nomsaF_aflowE_128B_0.00014968791626986144LR_0.00039427600918916277D_2000E_gvpLF_binaryLE,0.8041134218658649,0.64988689371186,0.573599346178764,0.4519291998976836,0.3978824110179624,0.672256796096316
GVPLM_davis2D_nomsaF_aflowE_128B_0.00014968791626986144LR_0.00039427600918916277D_2000E_gvpLF_binaryLE,0.7802128282015651,0.6222348456680393,0.5034858314455757,0.3639218510862791,0.367958717825743,0.6032593564017712
GVPLM_davis1D_nomsaF_aflowE_128B_0.00014968791626986144LR_0.00039427600918916277D_2000E_gvpLF_binaryLE,0.7568603850512082,0.5231446926168597,0.4563947560300728,0.4067614037303542,0.3910246412845944,0.6377784911161195
GVPLM_davis0D_nomsaF_aflowE_128B_0.00014968791626986144LR_0.00039427600918916277D_2000E_gvpLF_binaryLE,0.7314242613494631,0.520657218580497,0.4200042081423011,0.549750266646302,0.4693360195214201,0.7414514593999407
GVPLM_davis4D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.8317179299256213,0.7525555462006857,0.6231692649585702,0.4056508528712579,0.3161008745272529,0.6369072560987651
GVPLM_davis2D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.8288787269604654,0.7293571227266307,0.6105211859324264,0.3761462606386942,0.3491652655728998,0.6133076394752427
GVPLM_davis3D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.8164773292665711,0.7250154328888457,0.6154067444267105,0.4597009403992768,0.3855288712417378,0.6780124928047246
GVPLM_davis0D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.8055708777738954,0.7243877845881463,0.5809665306027836,0.4394688312618948,0.3957247064712851,0.6629244536611203
GVPLM_davis1D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.775744215597699,0.6842257371933662,0.5160465530857263,0.4672185947997873,0.3549506181382878,0.68353390171943
>>>>>>> Stashed changes
DGM_davis0D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7523695779428693,0.4682407034650613,0.4649581513699286,0.6092224587753515,0.4834460555550886,0.7805270391058541
DGM_davis2D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7762332212240666,0.5470881337730895,0.5112932834407308,0.4740730581497179,0.4375338560757152,0.688529634910305
DGM_davis1D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7570127033757169,0.5114232635236531,0.4736185104310603,0.5234106486856192,0.4567883369677207,0.723471249384258
DGM_davis4D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7801250458793092,0.5321007699350608,0.5131027258344268,0.5202149722288628,0.4510516971865327,0.7212592961126136
DGM_davis3D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7450873842292931,0.4893503545803986,0.4778696104467196,0.6419698895173639,0.490655592843598,0.8012302350244678
GVPLM_davis0D_nomsaF_aflowE_128B_0.00014968791626986144LR_0.00039427600918916277D_2000E_gvpLF_binaryLE,0.7676686087397216,0.6398302090119784,0.4874074774095251,0.4720718172086983,0.4122438906773314,0.687074826499049
GVPLM_davis2D_nomsaF_aflowE_128B_0.00014968791626986144LR_0.00039427600918916277D_2000E_gvpLF_binaryLE,0.8056817717122009,0.6600120446400118,0.5615058375759591,0.3972760533222704,0.3633167063701174,0.630298384356386
GVPLM_davis1D_nomsaF_aflowE_128B_0.00014968791626986144LR_0.00039427600918916277D_2000E_gvpLF_binaryLE,0.7796812905632664,0.6418365067480541,0.508433155533638,0.4350043675747687,0.3998515603735166,0.6595486089552223
GVPLM_davis3D_nomsaF_aflowE_128B_0.00014968791626986144LR_0.00039427600918916277D_2000E_gvpLF_binaryLE,0.7920895987899996,0.6645194989409928,0.5585099170472267,0.469367754560377,0.4311579323672002,0.6851041924848928
GVPLM_davis4D_nomsaF_aflowE_128B_0.00014968791626986144LR_0.00039427600918916277D_2000E_gvpLF_binaryLE,0.829644425075257,0.7189969797010738,0.591263254476021,0.35142261273003306,0.3509021611774669,0.5928090862411212
20 changes: 20 additions & 0 deletions src/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,26 @@
'num_GVPLayers': 3
}
},
'davis_esm':{
"model": cfg.MODEL_OPT.EDI,

"dataset": cfg.DATA_OPT.davis,
"feature_opt": cfg.PRO_FEAT_OPT.nomsa,
"edge_opt": cfg.PRO_EDGE_OPT.binary,
"lig_feat_opt": cfg.LIG_FEAT_OPT.original,
"lig_edge_opt": cfg.LIG_EDGE_OPT.binary,

'lr': 0.0001,
'batch_size': 48, # global batch size (local was 12)

'architecture_kwargs': {
'dropout': 0.4,
'dropout_prot': 0.0,
'output_dim': 128,
'pro_extra_fc_lyr': False,
# 'pro_emb_dim': 512 # just for reference since this is the default for EDI
}
},
'davis_gvpl_esm_aflow': {
"model": cfg.MODEL_OPT.GVPL_ESM,

Expand Down
2 changes: 1 addition & 1 deletion train_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
cp_saver = CheckpointSaver(model=None,
save_path=None,
train_all=False, # forces full training
patience=150, min_delta=0.5)
patience=50, min_delta=0.2)

# %% Training loop
metrics = {}
Expand Down

0 comments on commit a0e4405

Please sign in to comment.