diff --git a/inference.py b/inference.py new file mode 100644 index 0000000..e51c7cd --- /dev/null +++ b/inference.py @@ -0,0 +1,84 @@ +# %% +from src.data_prep.feature_extraction.ligand import smile_to_graph +from src.data_prep.feature_extraction.protein import target_to_graph +from src.data_prep.feature_extraction.protein_edges import get_target_edge_weights +from src import cfg +import torch +import torch_geometric as torchg +import numpy as np + +DATA = cfg.DATA_OPT.davis +lig_feature = cfg.LIG_FEAT_OPT.original +lig_edge = cfg.LIG_EDGE_OPT.binary +pro_feature = cfg.PRO_FEAT_OPT.nomsa +pro_edge = cfg.PRO_EDGE_OPT.binary + +lig_seq = "CN(C)CC=CC(=O)NC1=C(C=C2C(=C1)C(=NC=N2)NC3=CC(=C(C=C3)F)Cl)OC4CCOC4" + +#%% build ligand graph +mol_feat, mol_edge = smile_to_graph(lig_seq, lig_feature=lig_feature, lig_edge=lig_edge) +lig = torchg.data.Data(x=torch.Tensor(mol_feat), edge_index=torch.LongTensor(mol_edge), + lig_seq=lig_seq) + +#%% build protein graph +# predicted using - https://zhanggroup.org/NeBcon/ +prot_id = 'EGFR(L858R)' +pro_seq = 'MRPSGTAGAALLALLAALCPASRALEEKKVCQGTSNKLTQLGTFEDHFLSLQRMFNNCEVVLGNLEITYVQRNYDLSFLKTIQEVAGYVLIALNTVERIPLENLQIIRGNMYYENSYALAVLSNYDANKTGLKELPMRNLQEILHGAVRFSNNPALCNVESIQWRDIVSSDFLSNMSMDFQNHLGSCQKCDPSCPNGSCWGAGEENCQKLTKIICAQQCSGRCRGKSPSDCCHNQCAAGCTGPRESDCLVCRKFRDEATCKDTCPPLMLYNPTTYQMDVNPEGKYSFGATCVKKCPRNYVVTDHGSCVRACGADSYEMEEDGVRKCKKCEGPCRKVCNGIGIGEFKDSLSINATNIKHFKNCTSISGDLHILPVAFRGDSFTHTPPLDPQELDILKTVKEITGFLLIQAWPENRTDLHAFENLEIIRGRTKQHGQFSLAVVSLNITSLGLRSLKEISDGDVIISGNKNLCYANTINWKKLFGTSGQKTKIISNRGENSCKATGQVCHALCSPEGCWGPEPRDCVSCRNVSRGRECVDKCNLLEGEPREFVENSECIQCHPECLPQAMNITCTGRGPDNCIQCAHYIDGPHCVKTCPAGVMGENNTLVWKYADAGHVCHLCHPNCTYGCTGPGLEGCPTNGPKIPSIATGMVGALLLLLVVALGIGLFMRRRHIVRKRTLRRLLQERELVEPLTPSGEAPNQALLRILKETEFKKIKVLGSGAFGTVYKGLWIPEGEKVKIPVAIKELREATSPKANKEILDEAYVMASVDNPHVCRLLGICLTSTVQLITQLMPFGCLLDYVREHKDNIGSQYLLNWCVQIAKGMNYLEDRRLVHRDLAARNVLVKTPQHVKITDFGLAKLLGAEEKEYHAEGGKVPIKWMALESILHRIYTHQSDVWSYGVTVWELMTFGSKPYDGIPASEISSILEKGERLPQPPICTIDVYMIMVKCWMIDADSRPKFRELIIEFSKMARDPQRYLVIQGDERMHLPSPTDSNFYRALMDEEDMDDVVDADEYLIPQQGFFSSPSTSRTPLLSSLSATSNNSTVACIDRNGLQSCPIKEDSFLQRYSSDPTGALTEDSIDDTFLPVPEYINQSVPKRPAGSVQNPVYHNQPLNPAPSRDPHYQDPHSTAVGNPEYLNTVQPTCVNSTFDSPAHWAQKGSHQISLDNPDYQQDFFPKEAKPNGIFKGSTAENAEYLRVAPQSSEFIGA' +cmap_p = f'/cluster/home/t122995uhn/projects/data/davis/pconsc4/{prot_id}.npy' + +pro_cmap = np.load(cmap_p) +# updated_seq is for updated foldseek 3di combined seq +updated_seq, extra_feat, edge_idx = target_to_graph(target_sequence=pro_seq, + contact_map=pro_cmap, + threshold=8.0 if DATA is cfg.DATA_OPT.PDBbind else -0.5, + pro_feat=pro_feature) +pro_feat = torch.Tensor(extra_feat) + + + +pro = torchg.data.Data(x=torch.Tensor(pro_feat), + edge_index=torch.LongTensor(edge_idx), + pro_seq=updated_seq, # Protein sequence for downstream esm model + prot_id=prot_id, + edge_weight=None) + +#%% Loading the model +from src.utils.loader import Loader +from src import TUNED_MODEL_CONFIGS +import os + +def reformat_kwargs(model_kwargs): + return { + 'model': model_kwargs['model'], + 'data': model_kwargs['dataset'], + 'pro_feature': model_kwargs['feature_opt'], + 'edge': model_kwargs['edge_opt'], + 'batch_size': model_kwargs['batch_size'], + 'lr': model_kwargs['lr'], + 'dropout': model_kwargs['architecture_kwargs']['dropout'], + 'n_epochs': model_kwargs.get('n_epochs', 2000), # Assuming a default value for n_epochs + 'pro_overlap': model_kwargs.get('pro_overlap', False), # Assuming a default or None + 'fold': model_kwargs.get('fold', 0), # Assuming a default or None + 'ligand_feature': model_kwargs['lig_feat_opt'], + 'ligand_edge': model_kwargs['lig_edge_opt'] + } + + +model_kwargs = reformat_kwargs(TUNED_MODEL_CONFIGS['davis_esm']) + +MODEL_KEY = Loader.get_model_key(**model_kwargs) + +model_p_tmp = f'{cfg.MODEL_SAVE_DIR}/{MODEL_KEY}.model_tmp' +model_p = f'{cfg.MODEL_SAVE_DIR}/{MODEL_KEY}.model' + +# MODEL_KEY = 'DDP-' + MODEL_KEY # distributed model +model_p = model_p if os.path.isfile(model_p) else model_p_tmp +assert os.path.isfile(model_p), f"MISSING MODEL CHECKPOINT {model_p}" + +print(model_p) +# %% +args = model_kwargs +model = Loader.init_model(model=args['model'], pro_feature=args['pro_feature'], + pro_edge=args['edge'], **args['architecture_kwargs']) + + diff --git a/playground.py b/playground.py index 51c359f..a24f292 100644 --- a/playground.py +++ b/playground.py @@ -1,46 +1,47 @@ -# %% oncokb proteins +# %% +import pandas as pd +train_df = pd.read_csv('/cluster/home/t122995uhn/projects/data/DavisKibaDataset/davis/nomsa_binary_original_binary/full/XY.csv') +test_df = pd.read_csv('/cluster/home/t122995uhn/projects/MutDTA/splits/davis/test.csv') +train_df = train_df[~train_df.prot_id.isin(set(test_df.prot_id))] + +# %% import pandas as pd -kb_df = pd.read_csv('/cluster/home/t122995uhn/projects/downloads/oncoKB_DrugGenePairList.csv') -kb_prots = set(kb_df.gene) +davis_test_df = pd.read_csv(f"/home/jean/projects/MutDTA/splits/davis/test.csv") +davis_test_df['gene'] = davis_test_df['prot_id'].str.split('(').str[0] + +#%% ONCO KB MERGE +onco_df = pd.read_csv("../data/oncoKB_DrugGenePairList.csv") +davis_join_onco = davis_test_df.merge(onco_df.drop_duplicates("gene"), on="gene", how="inner") + +# %% +onco_df = pd.read_csv("../data/oncoKB_DrugGenePairList.csv") +onco_df.merge(davis_test_df.drop_duplicates("gene"), on="gene", how="inner").value_counts("gene") + + + -davis_df = pd.read_csv('/cluster/home/t122995uhn/projects/MutDTA/splits/davis/test.csv') -davis_df['gene'] = davis_df.prot_id.str.split('(').str[0] -kiba_df = pd.read_csv('/cluster/home/t122995uhn/projects/MutDTA/splits/kiba/test.csv') -pdb_df = pd.read_csv('/cluster/home/t122995uhn/projects/MutDTA/splits/pdbbind/test.csv') -davis_df['db'] = 'davis' -kiba_df['db'] = 'kiba' -pdb_df['db'] = 'pdbbind' -#%% -all_df = pd.concat([davis_df, kiba_df, pdb_df], axis=0) -new_order = ['db'] + [x for x in all_df.columns if x != 'db'] -all_df = all_df[new_order].drop(['seq_len', - 'gene_matched_on_pdb_id', - 'gene_matched_on_uniprot_id'], axis=1) -all_df.to_csv('/cluster/home/t122995uhn/projects/MutDTA/splits/all_tests.csv') -kb_overlap_test = all_df[all_df.gene.isin(kb_prots)] +# %% +from src.train_test.splitting import resplit +from src import cfg + +db_p = lambda x: f'{cfg.DATA_ROOT}/DavisKibaDataset/davis/nomsa_{x}_gvp_binary' + +db = resplit(dataset=db_p('binary'), split_files=db_p('aflow'), use_train_set=True) -kb_overlap_test.to_csv('/cluster/home/t122995uhn/projects/MutDTA/splits/all_tests_oncokb.csv') -['BRAF', 'ERBB2', 'FGFR2', 'FGFR3', 'KIT', 'PDGFRA', 'PIK3CA', - 'RAF1', 'CHEK1', 'CHEK2', 'FGFR1', 'MAP2K1', 'MAP2K2', 'MTOR', - 'EZH2', 'KDM6A', 'HRAS', 'KRAS', 'IDH1', 'PTEN', 'ESR1', 'BRIP1'] # %% ######################################################################## ########################## VIOLIN PLOTTING ############################# ######################################################################## import logging -from typing import OrderedDict - -import seaborn as sns from matplotlib import pyplot as plt -from statannotations.Annotator import Annotator from src.analysis.figures import prepare_df, fig_combined, custom_fig @@ -49,7 +50,7 @@ models = { 'DG': ('nomsa', 'binary', 'original', 'binary'), - # 'esm': ('ESM', 'binary', 'original', 'binary'), # esm model + 'esm': ('ESM', 'binary', 'original', 'binary'), # esm model 'aflow': ('nomsa', 'aflow', 'original', 'binary'), # 'gvpP': ('gvp', 'binary', 'original', 'binary'), 'gvpL': ('nomsa', 'binary', 'gvp', 'binary'), @@ -70,15 +71,19 @@ fig_scale=(10,5), add_stats=True, title_postfix=" validation set performance", box=True, fold_labels=True) plt.xticks(rotation=45) - -# %% +#%% +######################################################################## +########################## BUILD DATASETS ############################## +######################################################################## from src.data_prep.init_dataset import create_datasets from src import cfg +import logging +cfg.logger.setLevel(logging.DEBUG) -splits = '/cluster/home/t122995uhn/projects/MutDTA/splits/davis/' -create_datasets(cfg.DATA_OPT.davis, +splits = '/cluster/home/t122995uhn/projects/MutDTA/splits/pdbbind/' +create_datasets(cfg.DATA_OPT.PDBbind, feat_opt=cfg.PRO_FEAT_OPT.nomsa, - edge_opt=cfg.PRO_EDGE_OPT.aflow, + edge_opt=[cfg.PRO_EDGE_OPT.binary, cfg.PRO_EDGE_OPT.aflow], ligand_features=[cfg.LIG_FEAT_OPT.original, cfg.LIG_FEAT_OPT.gvp], ligand_edges=cfg.LIG_EDGE_OPT.binary, k_folds=5, diff --git a/results/v113/model_media/model_stats.csv b/results/v113/model_media/model_stats.csv index d5f0d02..97b9b98 100644 --- a/results/v113/model_media/model_stats.csv +++ b/results/v113/model_media/model_stats.csv @@ -24,7 +24,28 @@ EDIM_davis1D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.821300 EDIM_davis2D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.8225446485053062,0.6603442335278025,0.5949532834745241,0.6383847882308629,0.4154985330988705,0.7989898548985856 EDIM_davis3D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.8257508664241492,0.7310753856354779,0.5889781203307795,0.4852039936128347,0.3546071329782175,0.6965658573407361 EDIM_davis4D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.8404380957025455,0.7536920423101249,0.6146193037809486,0.4475845677765498,0.344083449582884,0.6690176139508958 -DGM_kiba0D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.5790729147254525,0.3237567466038684,0.220483622988913,0.5872366004827765,0.5416789341999021,0.7663136436752098 -DGM_kiba1D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6991041045925556,0.5827312646091369,0.5234334020017861,0.4359633916482334,0.4614902562500025,0.6602752393117838 -DGM_kiba3D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6703138545001528,0.5130147763454005,0.4537929112353873,0.4829267623645413,0.484057142425249,0.6949293218483023 +DGM_kiba0D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.7098487012268593,0.5707564567872483,0.5408755032739221,0.4702066637233398,0.456037305701769,0.68571616848616 +DGM_kiba1D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6991040727775232,0.5827312636226913,0.5234333190865307,0.4359633921504585,0.4614902574862928,0.6602752396920988 +DGM_kiba3D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.673975414774567,0.5176426001573622,0.4632216123889441,0.4814432352161062,0.4857926130830525,0.6938611065740076 DGM_kiba2D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.672607400192465,0.5094659678690476,0.4582769804170869,0.5474786138565632,0.5343923583458045,0.7399179777898109 +DGM_kiba4D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.725334039945205,0.5904997735444316,0.5827307698737493,0.4852863028425619,0.452352737630416,0.6966249369944791 +DGM_kiba3D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.6779600069482259,0.5366735812473975,0.4684844875359095,0.4911992003209447,0.4879214073204387,0.7008560482159976 +DGM_kiba1D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.6572547696170707,0.4910169881324348,0.4223030308489566,0.5056608650145961,0.4887140404038343,0.7110983511544631 +DGM_kiba0D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7034265325632763,0.5925355722791278,0.5277625252545413,0.4342528236949791,0.4414575484107895,0.6589786215765874 +DGM_kiba2D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7048145943290888,0.5949726563519002,0.5309972890766975,0.4319786589236721,0.4397173516202035,0.657250834098879 +DGM_kiba4D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.701385350951551,0.5879330884340543,0.5238489684293576,0.43711720760821254,0.44305434527103205,0.6611484005941575 +DGM_PDBbind0D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6655853516013849,0.4884783169365537,0.4768266531852068,2.716448143336347,1.29524747639392,1.648165083763258 +DGM_PDBbind1D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6616134720685233,0.4760514005042608,0.465349083935074,2.7772597147478297,1.309049959958126,1.6665112405104952 +DGM_PDBbind2D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6547720831157129,0.4674109621848075,0.4473761042353416,2.911396055546887,1.350302681084595,1.7062813529857517 +DGM_PDBbind3D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6590676461310532,0.4730640252598409,0.4579028540331872,2.813666484139417,1.3159578356701942,1.6773987254494436 +DGM_PDBbind4D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6560204986341623,0.4656808967835565,0.4491845247551415,2.830198214336796,1.3234003991911518,1.6823192961910638 +GVPLM_kiba0D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.6916246228601947,0.5614767553441977,0.4996457822085293,1.395078889806946,1.0094453810491737,1.181134577347961 +GVPLM_kiba3D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.6563609963973004,0.5357178296100711,0.4197069130416901,1.1612746598503312,0.8684839913064618,1.0776245449368398 +GVPLM_kiba1D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.6676856814770858,0.5583291036072489,0.443468456589727,1.3486743771266174,0.9835523785170164,1.161324406497434 +GVPLM_kiba4D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.706312738166959,0.618586606318348,0.5346567955843519,1.2316886588884002,0.945708981855823,1.1098146957435733 +GVPLM_kiba2D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.6999996243009051,0.5844167249718527,0.5172933123278608,1.1060115689114345,0.8743866055307609,1.051670846278166 +GVPLM_kiba4D_nomsaF_aflowE_16B_0.00010990897170411903LR_0.03599877069828837D_2000E_gvpLF_binaryLE,0.7028775926753292,0.5511820968815552,0.5274587057439738,0.5191856014604727,0.464434987594827,0.7205453500373676 +GVPLM_kiba3D_nomsaF_aflowE_16B_0.00010990897170411903LR_0.03599877069828837D_2000E_gvpLF_binaryLE,0.6974533040075589,0.5366909353583093,0.5092679820627712,0.5267452909346243,0.4756001094118965,0.7257722031978245 +GVPLM_kiba2D_nomsaF_aflowE_16B_0.00010990897170411903LR_0.03599877069828837D_2000E_gvpLF_binaryLE,0.6914933128392398,0.5535320315157304,0.49916906416324947,0.5167043183823063,0.46841519355949945,0.7188214787986696 +GVPLM_kiba0D_nomsaF_aflowE_16B_0.00010990897170411903LR_0.03599877069828837D_2000E_gvpLF_binaryLE,0.667835957156033,0.5118338017508158,0.4512406693094591,0.529732104745498,0.4987863530176246,0.7278269744558098 +GVPLM_kiba1D_nomsaF_aflowE_16B_0.00010990897170411903LR_0.03599877069828837D_2000E_gvpLF_binaryLE,0.673244203521782,0.5584520336614822,0.4601543459719329,0.4634166140494674,0.4859944296833685,0.6807470999199831 diff --git a/results/v113/model_media/model_stats_val.csv b/results/v113/model_media/model_stats_val.csv index 04c1400..3d6e7cb 100644 --- a/results/v113/model_media/model_stats_val.csv +++ b/results/v113/model_media/model_stats_val.csv @@ -24,7 +24,28 @@ EDIM_davis1D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.816451 EDIM_davis2D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.8216236400111289,0.7501182762791265,0.5974487525198227,0.350761726413648,0.3393165324142272,0.5922514047375895 EDIM_davis3D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.860215850035831,0.8257879154012255,0.6885140061655473,0.3049388665705435,0.3008343178958179,0.5522127004792116 EDIM_davis4D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.869855884432354,0.8198549396094414,0.6874469120379113,0.300964541267322,0.2915365434585408,0.5486023525900359 -DGM_kiba0D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.588709694581878,0.3937358060010086,0.244256676114009,0.5859551462931225,0.5731926417196901,0.7654770710433608 -DGM_kiba1D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.7435320925039093,0.6398140200524538,0.6268500070894489,0.4107691435515396,0.4127436749121904,0.6409127425410885 -DGM_kiba3D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6724758528185321,0.5028962013125743,0.4551672280896423,0.4654384967530096,0.4689483909867084,0.682230530504909 +DGM_kiba0D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6963043086640728,0.6044845580110094,0.5044084790491152,0.4332235417293724,0.462965043539043,0.6581971906118806 +DGM_kiba1D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.7435321292069627,0.6398140176186407,0.6268500067874493,0.4107691456559463,0.4127436667380659,0.6409127441828149 +DGM_kiba3D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6738115108510626,0.5062866384479864,0.4584816022310084,0.4642658758560831,0.4735249047745725,0.6813705862862611 DGM_kiba2D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.730532492692012,0.6193619256451531,0.586863514576392,0.512769142470763,0.5601343077809119,0.7160790057464071 +DGM_kiba4D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.7306965607165764,0.6725079045145821,0.5879382934587275,0.3752289881284924,0.3928903582254524,0.6125593751861875 +DGM_kiba3D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7337489077017221,0.6038998080790196,0.5957554788128485,0.4110459962390765,0.4166844796933485,0.6411286892965222 +DGM_kiba1D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7169797416293802,0.6527045139380728,0.5592661699267665,0.3368578002681257,0.3921905937173818,0.5803945212251109 +DGM_kiba0D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7128444726133566,0.6469405510839773,0.5484484334329713,0.336774733412883,0.4184534194805188,0.5803229561312244 +DGM_kiba2D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7033453012745192,0.6072209463563496,0.5174763874796517,0.4390207180018329,0.4649270940615934,0.6625863853127628 +DGM_kiba4D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7023830999958538,0.639949133525287,0.5151496892582399,0.40105347469958885,0.4409018537119612,0.6332878292684843 +DGM_PDBbind0D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.667185125919345,0.4916330475527956,0.4742033800137094,2.573515761007047,1.247796092498501,1.604218115159858 +DGM_PDBbind1D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6918743579083576,0.5545547118073524,0.5483030524092783,2.511534059876986,1.2495642605282011,1.5847820228274254 +DGM_PDBbind2D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6664571269002308,0.5086157319827234,0.4766146235124788,2.7894558492182133,1.3283703922279295,1.6701664136301548 +DGM_PDBbind3D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6611219451160404,0.5089255012900505,0.4600979435412906,2.783708713595093,1.3037198611195124,1.6684449986724446 +DGM_PDBbind4D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6534984333863754,0.4396685778489655,0.4407730853652159,2.576893253351179,1.2700988720333766,1.605270461122106 +GVPLM_kiba0D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.6733090385648975,0.5445900684839073,0.4521990415567309,1.2826114472884644,0.9411753130130744,1.1325243694015878 +GVPLM_kiba3D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.6662243896458074,0.4993110854957338,0.4421583024717247,1.2516487322969612,0.9116905590693156,1.1187710812748786 +GVPLM_kiba1D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.7040956155872347,0.5619996771694697,0.5255565499914184,1.6135555044624046,1.09047767273585,1.2702580464072664 +GVPLM_kiba4D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.7177100067210386,0.6378039890592351,0.5549280580594628,1.2182558614793937,0.9438271032744332,1.1037462849221251 +GVPLM_kiba2D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.7581840462826792,0.6821950063412816,0.643349960216416,0.8554240645192899,0.7355428603612219,0.9248913798491636 +GVPLM_kiba4D_nomsaF_aflowE_16B_0.00010990897170411903LR_0.03599877069828837D_2000E_gvpLF_binaryLE,0.6989113278386156,0.6451542686142409,0.5034844342783885,0.4038230338246696,0.4156683528230258,0.6354707183062566 +GVPLM_kiba3D_nomsaF_aflowE_16B_0.00010990897170411903LR_0.03599877069828837D_2000E_gvpLF_binaryLE,0.7422723180937305,0.5910071794468137,0.6044807220207702,0.4584843911137911,0.4115796661883425,0.677114754760071 +GVPLM_kiba2D_nomsaF_aflowE_16B_0.00010990897170411903LR_0.03599877069828837D_2000E_gvpLF_binaryLE,0.7926024955332317,0.7374112888096318,0.7063431040772125,0.32468252989801,0.3579667507508575,0.5698092048203591 +GVPLM_kiba0D_nomsaF_aflowE_16B_0.00010990897170411903LR_0.03599877069828837D_2000E_gvpLF_binaryLE,0.6669854507035494,0.5525438589938947,0.4355859541059588,0.4115128362890195,0.4456888421443692,0.6414926626930503 +GVPLM_kiba1D_nomsaF_aflowE_16B_0.00010990897170411903LR_0.03599877069828837D_2000E_gvpLF_binaryLE,0.6989377849970573,0.6478510815311954,0.515693387777559,0.3561253610107396,0.4216625264430977,0.5967623991261007 diff --git a/src/__init__.py b/src/__init__.py index 21b5535..370ef00 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -173,7 +173,6 @@ 'pro_emb_dim': 512 # just for reference since this is the default for EDI } }, - 'kiba_gvpl_aflow': { "model": cfg.MODEL_OPT.GVPL, @@ -214,6 +213,62 @@ ##################################################### ########### PDBbind ################################# ##################################################### + #DGM_PDBbind0D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E + 'PDBbind_DG': { + "model": cfg.MODEL_OPT.DG, + + "dataset": cfg.DATA_OPT.PDBbind, + "feature_opt": cfg.PRO_FEAT_OPT.nomsa, + "edge_opt": cfg.PRO_EDGE_OPT.binary, + "lig_feat_opt": cfg.LIG_FEAT_OPT.original, + "lig_edge_opt": cfg.LIG_EDGE_OPT.binary, + + 'lr': 0.0001, + 'batch_size': 64, + + 'architecture_kwargs': { + 'dropout': 0.4, + 'output_dim': 128, + } + }, + 'PDBbind_aflow':{ + "model": cfg.MODEL_OPT.DG, + + "dataset": cfg.DATA_OPT.PDBbind, + "feature_opt": cfg.PRO_FEAT_OPT.nomsa, + "edge_opt": cfg.PRO_EDGE_OPT.aflow, + "lig_feat_opt": cfg.LIG_FEAT_OPT.original, + "lig_edge_opt": cfg.LIG_EDGE_OPT.binary, + + 'lr': 0.0009185598967356679, + 'batch_size': 128, + + 'architecture_kwargs': { + 'dropout': 0.22880989869337157, + 'output_dim': 256 + } + }, + #EDIM_PDBbind1D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E + 'PDBbind_esm':{ + "model": cfg.MODEL_OPT.EDI, + + "dataset": cfg.DATA_OPT.PDBbind, + "feature_opt": cfg.PRO_FEAT_OPT.nomsa, + "edge_opt": cfg.PRO_EDGE_OPT.binary, + "lig_feat_opt": cfg.LIG_FEAT_OPT.original, + "lig_edge_opt": cfg.LIG_EDGE_OPT.binary, + + 'lr': 0.0001, + 'batch_size': 48, # global batch size (local was 12) + + 'architecture_kwargs': { + 'dropout': 0.4, + 'dropout_prot': 0.0, + 'output_dim': 128, + 'pro_extra_fc_lyr': False, + 'pro_emb_dim': 512 # just for reference since this is the default for EDI + } + }, #GVPLM_PDBbind0D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE 'PDBbind_gvpl_aflow':{ "model": cfg.MODEL_OPT.GVPL, @@ -250,21 +305,4 @@ 'output_dim': 512 } }, - 'PDBbind_aflow':{ - "model": cfg.MODEL_OPT.DG, - - "dataset": cfg.DATA_OPT.PDBbind, - "feature_opt": cfg.PRO_FEAT_OPT.nomsa, - "edge_opt": cfg.PRO_EDGE_OPT.aflow, - "lig_feat_opt": cfg.LIG_FEAT_OPT.original, - "lig_edge_opt": cfg.LIG_EDGE_OPT.binary, - - 'lr': 0.0009185598967356679, - 'batch_size': 128, - - 'architecture_kwargs': { - 'dropout': 0.22880989869337157, - 'output_dim': 256 - } - }, } diff --git a/src/data_prep/datasets.py b/src/data_prep/datasets.py index 2dc9d18..498ead7 100644 --- a/src/data_prep/datasets.py +++ b/src/data_prep/datasets.py @@ -654,7 +654,7 @@ def pre_process(self): # Get binding data: df_binding = PDBbindProcessor.get_binding_data(self.raw_paths[0]) # _data.2020 df_binding.drop(columns=['resolution', 'release_year'], inplace=True) - df_binding.rename({'lig_name':'lig_id'}, inplace=True) + df_binding.rename({'lig_name':'lig_id'}, inplace=True, axis=1) pdb_codes = df_binding.index # pdbcodes ############## validating codes ############# @@ -933,14 +933,15 @@ def pre_process(self): # WARNING: TEMPORARY FIX FOR DAVIS (TESK1 highQ structure is mismatched...) no_confs.append('TESK1') - print(f'Number of codes missing af2 configurations: {len(no_confs)} / {len(codes)}') + logging.warning(f'Number of codes missing {"aflow" if self.alphaflow else "af2"} ' + \ + f'conformations: {len(no_confs)} / {len(codes)}') invalid_codes = set(no_aln + no_cmap + no_confs) # filtering out invalid codes and storing their index vals. lig_r = [r for i,r in enumerate(lig_r) if codes[prot_c[i]] not in invalid_codes] prot_c = [c for c in prot_c if codes[c] not in invalid_codes] - assert len(prot_c) > 10, f"Not enough proteins in dataset, {len(prot_c)} total." + assert len(prot_c) > 10, f"Not enough proteins in dataset, {len(prot_c)} total from {self.af_conf_dir}" # creating binding dataframe: # code,SMILE,pkd,prot_seq diff --git a/src/data_prep/init_dataset.py b/src/data_prep/init_dataset.py index 586954e..7c3df45 100644 --- a/src/data_prep/init_dataset.py +++ b/src/data_prep/init_dataset.py @@ -62,7 +62,7 @@ def create_datasets(data_opt:list[str]|str, feat_opt:list[str]|str, edge_opt:lis for data, FEATURE, EDGE, ligand_feature, ligand_edge in itertools.product( data_opt, feat_opt, edge_opt, ligand_features, ligand_edges): - print('\n', data, FEATURE, EDGE) + print('\n', data, FEATURE, EDGE, ligand_feature, ligand_edge) if data in ['davis', 'kiba']: if FEATURE == 'msa': # position frequency matrix creation -> important for msa feature diff --git a/src/models/gvp_models.py b/src/models/gvp_models.py index f2a65aa..be8eb04 100644 --- a/src/models/gvp_models.py +++ b/src/models/gvp_models.py @@ -14,25 +14,43 @@ from src.models.ring3 import Ring3Branch -class GVPLigand_DGPro(DGraphDTA): +class GVPLigand_DGPro(BaseModel): """ DG model with GVP Ligand branch + + model = GVPLigand_DGPro(num_features_pro=num_feat_pro, + dropout=dropout, + edge_weight_opt=pro_edge, + **kwargs) """ def __init__(self, num_features_pro=54, - num_features_mol=78, output_dim=512, dropout=0.2, num_GVPLayers=3, edge_weight_opt='binary', **kwargs): output_dim = int(output_dim) - super(GVPLigand_DGPro, self).__init__(num_features_pro, - num_features_mol, output_dim, - dropout, edge_weight_opt) + super(GVPLigand_DGPro, self).__init__(pro_feat=None, + edge_weight_opt=edge_weight_opt) self.gvp_ligand = GVPBranchLigand(num_layers=num_GVPLayers, final_out=output_dim, drop_rate=dropout) + # protein branch: + emb_feat= 54 # to ensure constant embedding size regardless of input size (for fair comparison) + self.pro_conv1 = GCNConv(num_features_pro, emb_feat) + self.pro_conv2 = GCNConv(emb_feat, emb_feat * 2) + self.pro_conv3 = GCNConv(emb_feat * 2, emb_feat * 4) + self.pro_fc = nn.Sequential( + nn.Linear(emb_feat * 4, 1024), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(1024, output_dim), + nn.Dropout(dropout) + ) + self.relu = nn.ReLU() + + # concat branch to feedforward network self.dense_out = nn.Sequential( nn.Linear(2*output_dim, 1024), nn.Dropout(dropout), @@ -51,6 +69,30 @@ def __init__(self, num_features_pro=54, def forward_mol(self, data): return self.gvp_ligand(data) + def forward_pro(self, data): + # get protein input + target_x, ei, target_batch = data.x, data.edge_index, data.batch + # if edge_weight doesnt exist no error is thrown it just passes it as None + ew = data.edge_weight if self.edge_weight else None + + xt = self.pro_conv1(target_x, ei, ew) + xt = self.relu(xt) + + # target_edge_index, _ = dropout_adj(target_edge_index, training=self.training) + xt = self.pro_conv2(xt, ei, ew) + xt = self.relu(xt) + + # target_edge_index, _ = dropout_adj(target_edge_index, training=self.training) + xt = self.pro_conv3(xt, ei, ew) + xt = self.relu(xt) + + # xt = self.pro_conv4(xt, target_edge_index) + # xt = self.relu(xt) + xt = gep(xt, target_batch) # global pooling + + # FFNN + return self.pro_fc(xt) + def forward(self, data_pro, data_mol): xm = self.forward_mol(data_mol) xp = self.forward_pro(data_pro) diff --git a/src/models/prior_work.py b/src/models/prior_work.py index 4ce32d1..68ea377 100644 --- a/src/models/prior_work.py +++ b/src/models/prior_work.py @@ -68,7 +68,8 @@ def forward_pro(self, data): xt = gep(xt, target_batch) # global pooling # flatten - xt = self.relu(self.pro_fc_g1(xt)) + xt = self.pro_fc_g1(xt) + xt = self.relu(xt) xt = self.dropout(xt) xt = self.pro_fc_g2(xt) xt = self.dropout(xt) diff --git a/src/models/utils.py b/src/models/utils.py index 924ed48..103cee1 100644 --- a/src/models/utils.py +++ b/src/models/utils.py @@ -156,7 +156,6 @@ def __init__(self, in_dims, out_dims, h_dim=None, self.ws = nn.Linear(self.si, self.so) self.scalar_act, self.vector_act = activations - self.dummy_param = nn.Parameter(torch.empty(0)) def forward(self, x): ''' @@ -187,7 +186,7 @@ def forward(self, x): s = self.ws(x) if self.vo: # vector dim is zero v = torch.zeros(s.shape[0], self.vo, 3, - device=self.dummy_param.device) + device=x.device) if self.scalar_act: s = self.scalar_act(s) @@ -201,17 +200,15 @@ class _VDropout(nn.Module): def __init__(self, drop_rate): super(_VDropout, self).__init__() self.drop_rate = drop_rate - self.dummy_param = nn.Parameter(torch.empty(0)) def forward(self, x): ''' :param x: `torch.Tensor` corresponding to vector channels ''' - device = self.dummy_param.device if not self.training: return x mask = torch.bernoulli( - (1 - self.drop_rate) * torch.ones(x.shape[:-1], device=device) + (1 - self.drop_rate) * torch.ones(x.shape[:-1], device=x.device) ).unsqueeze(-1) x = mask * x / (1 - self.drop_rate) return x diff --git a/src/train_test/distributed.py b/src/train_test/distributed.py index cc6a871..ed42957 100644 --- a/src/train_test/distributed.py +++ b/src/train_test/distributed.py @@ -143,7 +143,7 @@ def dtrain(args, unknown_args): cp_saver = CheckpointSaver(model=model, save_path=f'{cfg.MODEL_SAVE_DIR}/{MODEL_KEY}.model', train_all=False, - patience=50, min_delta=(0.2 if DATA == cfg.DATA_OPT.PDBbind else 0.05), + patience=100, min_delta=(0.2 if DATA == cfg.DATA_OPT.PDBbind else 0.05), dist_rank=args.rank) # load ckpnt if os.path.exists(cp_saver.save_path + '_tmp') and args.rank == 0: diff --git a/src/train_test/training.py b/src/train_test/training.py index 2bff0f5..9b41267 100644 --- a/src/train_test/training.py +++ b/src/train_test/training.py @@ -16,7 +16,7 @@ class CheckpointSaver: # Adapted from https://stackoverflow.com/questions/71998978/early-stopping-in-pytorch - def __init__(self, model:BaseModel, save_path=None, train_all=True, patience=30, + def __init__(self, model:BaseModel, save_path=None, train_all=True, patience=100, min_delta=0.2, debug=False, dist_rank:int=None): """ Early stopping and checkpoint saving class. @@ -163,9 +163,10 @@ def train(model: BaseModel, train_loader:DataLoader, val_loader:DataLoader, CRITERION = torch.nn.MSELoss() OPTIMIZER = torch.optim.Adam(model.parameters(), lr=lr_0, **kwargs) # gamma = (lr_e/lr_0)**(step_size/epochs) # calculate gamma based on final lr chosen. - SCHEDULER = ReduceLROnPlateau(OPTIMIZER, mode='min', patience=saver.patience-1, - threshold=saver.min_delta*0.1, - min_lr=5e-7, factor=0.8, + SCHEDULER = ReduceLROnPlateau(OPTIMIZER, mode='min', + patience=int(saver.patience*0.9), + threshold=saver.min_delta*0.1, # more sensitive than early stopper + min_lr=5e-7, factor=0.5, verbose=True) logs = {'train_loss': [], 'val_loss': []} diff --git a/src/utils/config.py b/src/utils/config.py index 51c8390..3da43f1 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -84,7 +84,7 @@ class LIG_FEAT_OPT(StringEnum): ############################# # save paths ############################# -DATA_ROOT = os.path.abspath('../data/') +from pathlib import Path # Model save paths issue_number = 115 # 113 is for unifying all splits for cross validation so that we are more confident @@ -113,18 +113,21 @@ class LIG_FEAT_OPT(StringEnum): SLURM_PARTITION = 'gpu' SLURM_CONSTRAINT = 'gpu32g' SLURM_ACCOUNT = 'kumargroup_gpu' + DATA_ROOT = os.path.abspath('../data/') elif 'graham' in DOMAIN_NAME: CLUSTER = 'graham' SLURM_CONSTRAINT = 'cascade,v100' + DATA_ROOT = os.path.abspath(Path.home() / 'scratch' / 'data' ) elif 'cedar' in DOMAIN_NAME: CLUSTER = 'cedar' SLURM_GPU_NAME = 'v100l' + DATA_ROOT = os.path.abspath(Path.home() / 'scratch' / 'data' ) elif 'narval' in DOMAIN_NAME: CLUSTER = 'narval' SLURM_GPU_NAME = 'a100' + DATA_ROOT = os.path.abspath(Path.home() / 'scratch' / 'data' ) # bin paths -from pathlib import Path FOLDSEEK_BIN = f'{Path.home()}/lib/foldseek/bin/foldseek' MMSEQ2_BIN = f'{Path.home()}/lib/mmseqs/bin/mmseqs' RING3_BIN = f'{Path.home()}/lib/ring-3.0.0/ring/bin/ring' diff --git a/train_test.py b/train_test.py index e13a49c..8bc9bba 100644 --- a/train_test.py +++ b/train_test.py @@ -46,7 +46,7 @@ cp_saver = CheckpointSaver(model=None, save_path=None, train_all=False, # forces full training - patience=50) + patience=100) # %% Training loop metrics = {}