From 29c519a993f0091c3a184c0273f9b2e97f677365 Mon Sep 17 00:00:00 2001 From: jyaacoub Date: Thu, 23 May 2024 17:04:16 -0400 Subject: [PATCH 01/14] fix(datasets): filter missing sdfs for ligands #90 --- rayTrain_Tune.py | 28 ++++++++++++++++++++++++++-- src/data_prep/datasets.py | 22 ++++++++++++++++++++-- src/data_prep/downloaders.py | 14 ++++++++++---- src/utils/config.py | 1 + 4 files changed, 57 insertions(+), 8 deletions(-) diff --git a/rayTrain_Tune.py b/rayTrain_Tune.py index 89f8e81..3df554a 100644 --- a/rayTrain_Tune.py +++ b/rayTrain_Tune.py @@ -102,14 +102,14 @@ def train_func(config): # "output_dim": ray.tune.choice([128, 256, 512]), # } # } - # 'gvpL_aflow': ('nomsa', 'aflow', 'gvp', 'binary'): +# 'gvpL': ('nomsa', 'aflow', 'gvp', 'binary') search_space = { ## constants: "epochs": 20, "model": cfg.MODEL_OPT.GVPL, "dataset": cfg.DATA_OPT.davis, "feature_opt": cfg.PRO_FEAT_OPT.nomsa, - "edge_opt": cfg.PRO_EDGE_OPT.aflow, + "edge_opt": cfg.PRO_EDGE_OPT.binary, "lig_feat_opt": cfg.LIG_FEAT_OPT.gvp, "lig_edge_opt": cfg.LIG_EDGE_OPT.binary, @@ -126,6 +126,30 @@ def train_func(config): "output_dim": ray.tune.choice([128, 256, 512]), } } + # 'gvpL_aflow': ('nomsa', 'aflow', 'gvp', 'binary'): + # search_space = { + # ## constants: + # "epochs": 20, + # "model": cfg.MODEL_OPT.GVPL, + # "dataset": cfg.DATA_OPT.davis, + # "feature_opt": cfg.PRO_FEAT_OPT.nomsa, + # "edge_opt": cfg.PRO_EDGE_OPT.aflow, + # "lig_feat_opt": cfg.LIG_FEAT_OPT.gvp, + # "lig_edge_opt": cfg.LIG_EDGE_OPT.binary, + + # "fold_selection": 0, + # "save_checkpoint": False, + + # ## hyperparameters to tune: + # "lr": ray.tune.loguniform(1e-5, 1e-3), + # "batch_size": ray.tune.choice([32, 64, 128]), # local batch size + + # # model architecture hyperparams + # "architecture_kwargs":{ + # "dropout": ray.tune.uniform(0.0, 0.5), + # "output_dim": ray.tune.choice([128, 256, 512]), + # } + # } # search space for GVPL_RNG MODEL: # search_space = { # ## constants: diff --git a/src/data_prep/datasets.py b/src/data_prep/datasets.py index da504a8..809f8d5 100644 --- a/src/data_prep/datasets.py +++ b/src/data_prep/datasets.py @@ -221,9 +221,9 @@ def get_unique_prots(df, keep_len=False) -> pd.DataFrame: # ensures that the right codes are always present in unique_pro df.sort_values(by='sort_order', inplace=True) else: - df['seq_len'] = df['prot_seq'].str.len() + df = df.assign(seq_len=df['prot_seq'].str.len()) df.sort_values(by='seq_len', ascending=False, inplace=True) - df['sort_order'] = [i for i in range(len(df))] + df = df.assign(sort_order=[i for i in range(len(df))]) # Get unique protid codes idx_name = df.index.name @@ -365,6 +365,24 @@ def clean_XY(self, df:pd.DataFrame, max_seq_len=None): logging.debug(f'Number of codes: {len(filtered_df)}/{len(df)}') + # we are done filtering if ligand doesnt need filtering + if not (self.ligand_edge in cfg.OPT_REQUIRES_SDF or + self.ligand_feature in cfg.OPT_REQUIRES_SDF): + return filtered_df + + # removing rows with ligands that have missing sdf files: + unique_lig = filtered_df[['lig_id']].drop_duplicates() + missing = set() + for code, (lig_id,) in tqdm(unique_lig.iterrows(), desc='dropping missing sdfs from df', + total=len(unique_lig)): + fp = self.sdf_p(code, lig_id=lig_id) + if (not os.path.isfile(fp) or + os.path.getsize(fp) <= 20): + missing.add(lig_id) + + logging.debug(f'{len(missing)}/{len(unique_lig)} missing ligands') + filtered_df = filtered_df[~filtered_df.lig_id.isin(missing)] + logging.debug(f'Number of codes after ligand filter: {len(filtered_df)}/{len(df)}') return 
filtered_df def _create_protein_graphs(self, df, node_feat, edge): diff --git a/src/data_prep/downloaders.py b/src/data_prep/downloaders.py index 6030781..0e97fd6 100644 --- a/src/data_prep/downloaders.py +++ b/src/data_prep/downloaders.py @@ -6,6 +6,7 @@ from tqdm import tqdm from concurrent.futures import ThreadPoolExecutor, as_completed +import logging class Downloader: @staticmethod @@ -81,7 +82,7 @@ def get_file_obj(ID: str, url=lambda x: f'https://files.rcsb.org/download/{x}.pd @staticmethod def download_single_file(id: str, save_path: Callable[[str], str], url: Callable[[str], str], - url_backup: Callable[[str], str], max_retries=4) -> tuple: + url_backup: Callable[[str], str], max_retries=3) -> tuple: """ Helper function to download a single file. """ @@ -106,9 +107,11 @@ def fetch_url(url): resp = fetch_url(url(id)) if resp.status_code >= 400 and url_backup: + logging.debug(f'{id}-{resp.status_code} {resp}') resp = fetch_url(url_backup(id)) if resp.status_code >= 400: + logging.debug(f'\tbkup{id}-{resp.status_code} {resp}') return id, resp.status_code else: with open(fp, 'w') as f: @@ -122,7 +125,8 @@ def download(IDs: Iterable[str], tqdm_desc='Downloading files', url_backup=None, # for if the first url fails tqdm_disable=False, - max_workers=None) -> dict: + max_workers=None, + **kwargs) -> dict: """ Generalized multithreaded download function for downloading any file type from any site. @@ -146,7 +150,8 @@ def download(IDs: Iterable[str], """ ID_status = {} with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = {executor.submit(Downloader.download_single_file, id, save_path, url, url_backup): id for id in IDs} + futures = {executor.submit(Downloader.download_single_file, id, save_path, + url, url_backup, **kwargs): id for id in IDs} for future in tqdm(as_completed(futures), desc=tqdm_desc, total=len(IDs), disable=tqdm_disable): id, status = future.result() ID_status[id] = status @@ -172,6 +177,7 @@ def download_predicted_PDBs(UniProtID: Iterable[str], save_dir='./') -> dict: @staticmethod def download_SDFs(ligand_ids: List[str], save_dir='./data/structures/ligands/', + max_workers=None, **kwargs) -> dict: """ Wrapper of `Downloader.download` for downloading SDF files. 
@@ -209,7 +215,7 @@ def download_SDFs(ligand_ids: List[str], save_path = lambda x: os.path.join(save_dir, f'{x}.sdf') url_backup=lambda x: url(x).split('?')[0] # fallback to 2d conformer structure return Downloader.download(ligand_ids, save_path=save_path, url=url, url_backup=url_backup, - tqdm_desc='Downloading ligand sdfs', **kwargs) + tqdm_desc='Downloading ligand sdfs', max_workers=max_workers, **kwargs) if __name__ == '__main__': # downloading pdbs from X.csv list diff --git a/src/utils/config.py b/src/utils/config.py index 9e914fa..11835b9 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -70,6 +70,7 @@ class PRO_FEAT_OPT(StringEnum): 'aflow_ring3']) OPT_REQUIRES_AFLOW_CONF = StringEnum('alphaflow_confs', ['aflow', 'aflow_ring3']) OPT_REQUIRES_RING3 = StringEnum('ring3', ['ring3', 'aflow_ring3']) +OPT_REQUIRES_SDF = StringEnum('lig_sdf', ['gvp']) # ligand options class LIG_EDGE_OPT(StringEnum): From e294eefb38ca004c4ee074a39d9bc276820ea95c Mon Sep 17 00:00:00 2001 From: jyaacoub Date: Fri, 24 May 2024 11:20:38 -0400 Subject: [PATCH 02/14] fix(config): CC hhblits binary paths --- src/utils/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/utils/config.py b/src/utils/config.py index 11835b9..9da6561 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -128,8 +128,8 @@ class LIG_FEAT_OPT(StringEnum): UniRef_dir = '/cluster/projects/kumargroup/sequence_databases/UniRef30_2020_06/UniRef30_2020_06' hhsuite_bin_dir = '/cluster/tools/software/centos7/hhsuite/3.3.0/bin' else: - UniRef_dir = '' - hhsuite_bin_dir = '' + UniRef_dir = '/cvmfs/bio.data.computecanada.ca/content/databases/Core/alphafold2_dbs/2024_01/uniref30/UniRef30_2021_03' + hhsuite_bin_dir = '/cvmfs/soft.computecanada.ca/easybuild/software/2023/x86-64-v3/MPI/gcc12/openmpi4/hh-suite/3.3.0/bin' ########################### From fa63744a6d250851517dfea4031ef535e90d3cdc Mon Sep 17 00:00:00 2001 From: jyaacoub Date: Tue, 28 May 2024 11:38:54 -0400 Subject: [PATCH 03/14] fix(raytune): +kwargs dep on model type --- rayTrain_Tune.py | 89 +++++----------------------------------- src/models/gvp_models.py | 4 +- 2 files changed, 14 insertions(+), 79 deletions(-) diff --git a/rayTrain_Tune.py b/rayTrain_Tune.py index 3df554a..50c3fb5 100644 --- a/rayTrain_Tune.py +++ b/rayTrain_Tune.py @@ -78,38 +78,15 @@ def train_func(config): print("Cuda support:", torch.cuda.is_available(),":", torch.cuda.device_count(), "devices") print("CUDA VERSION:", torch.__version__) - - # search_space = { - # ## constants: - # "epochs": 20, - # "model": cfg.MODEL_OPT.DG, - # "dataset": cfg.DATA_OPT.PDBbind, - # "feature_opt": cfg.PRO_FEAT_OPT.nomsa, - # "edge_opt": cfg.PRO_EDGE_OPT.aflow, - # "lig_feat_opt": cfg.LIG_FEAT_OPT.original, - # "lig_edge_opt": cfg.LIG_EDGE_OPT.binary, - - # "fold_selection": 0, - # "save_checkpoint": False, - - # ## hyperparameters to tune: - # "lr": ray.tune.loguniform(1e-5, 1e-3), - # "batch_size": ray.tune.choice([32, 64, 128]), # local batch size - # # model architecture hyperparams - # "architecture_kwargs":{ - # "dropout": ray.tune.uniform(0.0, 0.5), - # "output_dim": ray.tune.choice([128, 256, 512]), - # } - # } -# 'gvpL': ('nomsa', 'aflow', 'gvp', 'binary') search_space = { ## constants: "epochs": 20, "model": cfg.MODEL_OPT.GVPL, - "dataset": cfg.DATA_OPT.davis, + + "dataset": cfg.DATA_OPT.kiba, "feature_opt": cfg.PRO_FEAT_OPT.nomsa, - "edge_opt": cfg.PRO_EDGE_OPT.binary, + "edge_opt": cfg.PRO_EDGE_OPT.aflow, "lig_feat_opt": cfg.LIG_FEAT_OPT.gvp, "lig_edge_opt": 
cfg.LIG_EDGE_OPT.binary, @@ -123,59 +100,15 @@ def train_func(config): # model architecture hyperparams "architecture_kwargs":{ "dropout": ray.tune.uniform(0.0, 0.5), - "output_dim": ray.tune.choice([128, 256, 512]), - } + "output_dim": ray.tune.choice([128, 256, 512]), + }, } - # 'gvpL_aflow': ('nomsa', 'aflow', 'gvp', 'binary'): - # search_space = { - # ## constants: - # "epochs": 20, - # "model": cfg.MODEL_OPT.GVPL, - # "dataset": cfg.DATA_OPT.davis, - # "feature_opt": cfg.PRO_FEAT_OPT.nomsa, - # "edge_opt": cfg.PRO_EDGE_OPT.aflow, - # "lig_feat_opt": cfg.LIG_FEAT_OPT.gvp, - # "lig_edge_opt": cfg.LIG_EDGE_OPT.binary, - - # "fold_selection": 0, - # "save_checkpoint": False, - - # ## hyperparameters to tune: - # "lr": ray.tune.loguniform(1e-5, 1e-3), - # "batch_size": ray.tune.choice([32, 64, 128]), # local batch size - - # # model architecture hyperparams - # "architecture_kwargs":{ - # "dropout": ray.tune.uniform(0.0, 0.5), - # "output_dim": ray.tune.choice([128, 256, 512]), - # } - # } - # search space for GVPL_RNG MODEL: -# search_space = { -# ## constants: -# "epochs": 20, -# "model": cfg.MODEL_OPT.GVPL_RNG, -# "dataset": cfg.DATA_OPT.PDBbind, -# "feature_opt": cfg.PRO_FEAT_OPT.nomsa, -# "edge_opt": cfg.PRO_EDGE_OPT.aflow_ring3, -# "lig_feat_opt": cfg.LIG_FEAT_OPT.gvp, -# "lig_edge_opt": cfg.LIG_EDGE_OPT.binary, -# -# "fold_selection": 0, -# "save_checkpoint": False, -# -# ## hyperparameters to tune: -# "lr": ray.tune.loguniform(1e-5, 1e-3), -# "batch_size": ray.tune.choice([16,32,64]), # local batch size -# -# # model architecture hyperparams -# "architecture_kwargs":{ -# "dropout": ray.tune.uniform(0.0, 0.5), -# "pro_emb_dim": ray.tune.choice([64, 128, 256]), -# "output_dim": ray.tune.choice([128, 256, 512]), -# "nheads_pro": ray.tune.choice([3, 4, 5]), -# } -# } + arch_kwargs = search_space['architecture_kwargs'] + if search_space['model'] == cfg.MODEL_OPT.GVPL: + arch_kwargs["num_GVPLayers"]= ray.tune.choice([2, 3, 4]) + elif search_space['model'] == cfg.MODEL_OPT.GVPL_RNG: + arch_kwargs["pro_emb_dim"] = ray.tune.choice([64, 128, 256]) + arch_kwargs["nheads_pro"] = ray.tune.choice([3, 4, 5]) # each worker is a node from the ray cluster. 
# WARNING: SBATCH GPU directive should match num_workers*GPU_per_worker diff --git a/src/models/gvp_models.py b/src/models/gvp_models.py index 05487ae..5e45c8f 100644 --- a/src/models/gvp_models.py +++ b/src/models/gvp_models.py @@ -65,13 +65,15 @@ def __init__(self, num_features_pro=54, num_features_mol=78, output_dim=512, dropout=0.2, + num_GVPLayers=3, edge_weight_opt='binary', **kwargs): output_dim = int(output_dim) super(GVPLigand_DGPro, self).__init__(num_features_pro, num_features_mol, output_dim, dropout, edge_weight_opt) - self.gvp_ligand = GVPBranchLigand(final_out=output_dim, + self.gvp_ligand = GVPBranchLigand(num_layers=num_GVPLayers, + final_out=output_dim, drop_rate=dropout) self.dense_out = nn.Sequential( From 4bfbbbbbfe97158f18a1cb6ddaa083193521bb00 Mon Sep 17 00:00:00 2001 From: jyaacoub Date: Tue, 28 May 2024 12:23:39 -0400 Subject: [PATCH 04/14] Merge branch 'development' of github.com:jyaacoub/MutDTA into development --- src/utils/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/utils/config.py b/src/utils/config.py index 11835b9..9da6561 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -128,8 +128,8 @@ class LIG_FEAT_OPT(StringEnum): UniRef_dir = '/cluster/projects/kumargroup/sequence_databases/UniRef30_2020_06/UniRef30_2020_06' hhsuite_bin_dir = '/cluster/tools/software/centos7/hhsuite/3.3.0/bin' else: - UniRef_dir = '' - hhsuite_bin_dir = '' + UniRef_dir = '/cvmfs/bio.data.computecanada.ca/content/databases/Core/alphafold2_dbs/2024_01/uniref30/UniRef30_2021_03' + hhsuite_bin_dir = '/cvmfs/soft.computecanada.ca/easybuild/software/2023/x86-64-v3/MPI/gcc12/openmpi4/hh-suite/3.3.0/bin' ########################### From 7d215099c2bf5e619ef8c54ff033c77b7e115eea Mon Sep 17 00:00:00 2001 From: jyaacoub Date: Wed, 29 May 2024 10:29:17 -0400 Subject: [PATCH 05/14] results: davis GVPL-only #90 #94 --- results/model_media/model_stats.csv | 9 +++++++-- results/model_media/model_stats_val.csv | 7 ++++++- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/results/model_media/model_stats.csv b/results/model_media/model_stats.csv index 4e3821a..d0c6d2c 100644 --- a/results/model_media/model_stats.csv +++ b/results/model_media/model_stats.csv @@ -172,13 +172,18 @@ GVPLM_PDBbind0D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE,0. 
GVPLM_PDBbind2D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE,0.7194491955332394,0.6365982015372369,0.6153139418604331,2.098544323479561,1.1343965978092618,1.4486353314342295 GVPLM_PDBbind4D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE,0.7203456545202407,0.6299801928832721,0.6142370717682227,2.207161887712393,1.160070003592779,1.485652007608913 GVPLM_PDBbind3D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE,0.709231747593396,0.5876231909596512,0.5904267672441162,2.3008469727213514,1.1705255782036554,1.5168543017446836 -DGM_PDBbind1D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.6904769032358286,0.544938998054608,0.5407934266049562,2.5551568669266334,1.2339192264420646,1.5984858044182415 +DGM_PDBbind1D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.6904769032358286,0.544938998054608,0.5407934266049562,2.555156866926634,1.2339192264420646,1.5984858044182415 DGM_PDBbind3D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.7015552107083158,0.6057001165731564,0.569729312002546,2.358233326230585,1.2048145031929016,1.5356540385876585 DGM_PDBbind4D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.7085153085744522,0.6208329976682688,0.5859156875580817,2.2240677925526744,1.1722256461779277,1.4913308796349234 DGM_PDBbind2D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.6947751328102371,0.5816010010184015,0.5528265670239861,2.3897321791041333,1.2105133893376303,1.545875861479224 DGM_PDBbind0D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.6855330290119371,0.5726521315119578,0.5325884299986027,2.4758763933999823,1.231696826806144,1.57349178370908 -GVPLM_davis0D_nomsaF_aflowE_128B_0.0001360163557088453LR_0.027175922988649594D_2000E_gvpLF_binaryLE,0.7581108896345126,0.48942831390709474,0.45529634961852566,0.5072034546494222,0.4161494350523831,0.7121821779919953 +GVPLM_davis0D_nomsaF_aflowE_128B_0.0001360163557088453LR_0.027175922988649594D_2000E_gvpLF_binaryLE,0.7581108896345126,0.4894283139070947,0.4552963496185256,0.5072034546494222,0.4161494350523831,0.7121821779919953 GVPLM_davis1D_nomsaF_aflowE_128B_0.0001360163557088453LR_0.027175922988649594D_2000E_gvpLF_binaryLE,0.768273891734927,0.550820575655664,0.4743629409308578,0.4386155850598172,0.3857342382535988,0.6622805939024767 GVPLM_davis3D_nomsaF_aflowE_128B_0.0001360163557088453LR_0.027175922988649594D_2000E_gvpLF_binaryLE,0.7603887296892994,0.5280132204660057,0.4597706307310624,0.4791434307834897,0.3716412002266697,0.6922018714099881 GVPLM_davis2D_nomsaF_aflowE_128B_0.0001360163557088453LR_0.027175922988649594D_2000E_gvpLF_binaryLE,0.7694890904708525,0.5217685104759502,0.4753118765332927,0.4673188139402668,0.4174866665020387,0.6836072073495618 GVPLM_davis4D_nomsaF_aflowE_128B_0.0001360163557088453LR_0.027175922988649594D_2000E_gvpLF_binaryLE,0.760520722940427,0.5078128106282308,0.4629361403886756,0.5235965111752061,0.3783202273117297,0.723599689866715 +GVPLM_davis3D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.756307395396296,0.5549704766818875,0.5027127535907578,0.5847335023746886,0.4515634416392335,0.7646786922457619 
+GVPLM_davis4D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.7610883812617106,0.6227046591753373,0.5053892668406388,0.5586851001177038,0.4105810923387186,0.7474524065903486 +GVPLM_davis0D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.7852543960759865,0.6625415568442652,0.5483708380017869,0.4902795019912594,0.373903283016137,0.7001996158177034 +GVPLM_davis1D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.7649810976587887,0.6493107697356035,0.5116033216497222,0.5055476857354936,0.379708959205043,0.7110187660923539 +GVPLM_davis2D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.7788049850583856,0.6551546686297811,0.539826183289417,0.4900774065483533,0.381190583578702,0.7000552882082625 diff --git a/results/model_media/model_stats_val.csv b/results/model_media/model_stats_val.csv index d018c6b..9cd1b5e 100644 --- a/results/model_media/model_stats_val.csv +++ b/results/model_media/model_stats_val.csv @@ -156,8 +156,13 @@ DGM_PDBbind3D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_20 DGM_PDBbind4D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.6980325998261337,0.5867948698567493,0.5613217730433377,2.504824019579536,1.2573465726293105,1.5826635838293417 DGM_PDBbind2D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.6760856277267194,0.5188577783623333,0.5051689763782492,2.7333317052016373,1.2894853849669785,1.653279076623677 DGM_PDBbind0D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.678631965986838,0.5169422908392893,0.5136610285075189,2.659271637069346,1.2811931355709805,1.6307273337591868 -GVPLM_davis0D_nomsaF_aflowE_128B_0.0001360163557088453LR_0.027175922988649594D_2000E_gvpLF_binaryLE,0.7727048764494124,0.5829543940410362,0.46680064762373447,0.4654286995546293,0.3997238369103796,0.6822233501974476 +GVPLM_davis0D_nomsaF_aflowE_128B_0.0001360163557088453LR_0.027175922988649594D_2000E_gvpLF_binaryLE,0.7727048764494124,0.5829543940410362,0.4668006476237344,0.4654286995546293,0.3997238369103796,0.6822233501974476 GVPLM_davis1D_nomsaF_aflowE_128B_0.0001360163557088453LR_0.027175922988649594D_2000E_gvpLF_binaryLE,0.7767585519522199,0.5805240837724788,0.4946235296015073,0.4138118485637668,0.3817763907346182,0.6432820909708017 GVPLM_davis3D_nomsaF_aflowE_128B_0.0001360163557088453LR_0.027175922988649594D_2000E_gvpLF_binaryLE,0.8202768328439823,0.6402255623870499,0.5435232833589796,0.3676188640274792,0.3097654533567193,0.6063158121206136 GVPLM_davis2D_nomsaF_aflowE_128B_0.0001360163557088453LR_0.027175922988649594D_2000E_gvpLF_binaryLE,0.7854232789039847,0.5996614387831698,0.527865354318397,0.4509265081187592,0.4061779738808213,0.67151061653466 GVPLM_davis4D_nomsaF_aflowE_128B_0.0001360163557088453LR_0.027175922988649594D_2000E_gvpLF_binaryLE,0.7849321063417397,0.6186046172952341,0.5194671423553721,0.4498645421549498,0.3525368623308257,0.670719421334249 +GVPLM_davis3D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.7613364651895133,0.530509070466495,0.4887270646961001,0.5667063347251928,0.4394238040568357,0.752799000215325 +GVPLM_davis4D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.7999980961404287,0.667340224773174,0.5456194297881923,0.4617459986364707,0.3392615584282555,0.6795189464882276 
+GVPLM_davis0D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.7961292743880335,0.675879144422058,0.5465922959565244,0.4938138142860839,0.3651142534725168,0.702718872868862 +GVPLM_davis1D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.7982821251801983,0.6893850245133243,0.5572686590418361,0.4836518480565703,0.35890867349935723,0.6954508236076583 +GVPLM_davis2D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.8093119089319818,0.7167257951603837,0.543747905282994,0.3626874080110078,0.3138066983637325,0.6022353427116411 From ec3b8f4ac4e63a930695b9c2abb415725b309a89 Mon Sep 17 00:00:00 2001 From: jyaacoub Date: Thu, 30 May 2024 14:13:41 -0400 Subject: [PATCH 06/14] refactor(datasets): + fallback to aflow conf if no pdb for mt #90 #94 --- src/analysis/figures.py | 24 +++++++++ src/data_prep/datasets.py | 111 +++++++++++++++++++++----------------- src/utils/residue.py | 2 + 3 files changed, 88 insertions(+), 49 deletions(-) diff --git a/src/analysis/figures.py b/src/analysis/figures.py index a0380f8..d40bc12 100644 --- a/src/analysis/figures.py +++ b/src/analysis/figures.py @@ -383,6 +383,30 @@ def fig_combined(df, datasets=['PDBbind','davis', 'kiba'], metrics=['cindex', 'm def custom_fig(df, models:OrderedDict=None, sel_dataset='PDBbind', sel_col='cindex', verbose=False, show=False, add_stats=True, ax=None): + + """ + Example usage with `fig_combined`. + ``` + from src.analysis.figures import custom_fig, prepare_df, fig_combined + + df = prepare_df() + + models = { + 'DG': ('nomsa', 'binary', 'original', 'binary'), + # 'DG-simple': ('nomsa', 'simple', 'original', 'binary'), + 'DG-anm': ('nomsa', 'anm', 'original', 'binary'), + 'DG-af2': ('nomsa', 'af2', 'original', 'binary'), + 'DG-ESM': ('ESM', 'binary', 'original', 'binary'), + # 'DG-saprot': ('foldseek', 'binary', 'original', 'binary'), + 'gvpP': ('gvp', 'binary', 'original', 'binary'), + 'gvpL-aflow': ('nomsa', 'aflow', 'gvp', 'binary'), + } + + fig_combined(df, datasets=['PDBbind'], metrics=['cindex', 'mse'], fig_scale=(10,5), + fig_callable=custom_fig, models=models, title_postfix=' test set performance', + add_stats=True) + ``` + """ if models is None: # example custom plot: # models to plot: # - Original model with (nomsa, binary) and (original, binary) features for protein and ligand respectively diff --git a/src/data_prep/datasets.py b/src/data_prep/datasets.py index 809f8d5..7e1ee81 100644 --- a/src/data_prep/datasets.py +++ b/src/data_prep/datasets.py @@ -1025,11 +1025,15 @@ def sdf_p(self, code, lig_id) -> str: """Needed for gvp ligand branch (uses coordinate info)""" return os.path.join(self.raw_paths[2], f'{lig_id}.sdf') - def pdb_p(self, code): - # code = code.split('_')[0] # removing additional information for mutations. 
- # no need to remove mutation information since that is curtial for the pdb id - return os.path.join(self.raw_paths[1], f'{code}.pdb') - + def pdb_p(self, pid, id_is_pdb=False): + if id_is_pdb: + fp = os.path.join(self.raw_paths[1], f'{pid}.pdb') + else: + fp = f'{self.af_conf_dir}/{pid}.pdb' + + fp = fp if os.path.exists(fp) else None + return fp + def cmap_p(self, prot_id): return os.path.join(self.raw_dir, 'contact_maps', f'{prot_id}.npy') @@ -1108,7 +1112,55 @@ def download(self): lambda x: os.path.join(self.raw_paths[2], f'{x}.sdf')) df_raw['smiles'] = df_raw['affin.lig_id'].map(smiles_dict) df_raw.to_csv(self.raw_paths[0]) + + def _get_prot_structure(self, muts, pdb_wt, pdb_mt, t_chain): + # Check if PDBs available (need at least 1 for sequence info) + missing_wt = pdb_wt == 'NO' + missing_mt = pdb_mt == 'NO' + assert not (missing_mt and missing_wt), f'missing pdbs for both mt and wt' + pdb_wt = pdb_mt if missing_wt else pdb_wt + pdb_mt = pdb_wt if missing_mt else pdb_mt + + # creating protein unique IDs for computational speed up by avoiding redundant compute + wt_id = f'{pdb_wt}_wt' + mt_id = f'{pdb_mt}_{"-".join(muts)}' + + chain_wt = Chain(self.pdb_p(pdb_wt, id_is_pdb=True), t_chain=t_chain) + chain_mt = Chain(self.pdb_p(pdb_mt, id_is_pdb=True), t_chain=t_chain) + + # Getting sequences: + if missing_wt: + mut_seq = chain_mt.sequence + ref_seq = chain_mt.get_mutated_seq(muts, reversed=True) + else: + # get mut_seq from wt to confirm that mapping the mutations works + mut_seq = chain_wt.get_mutated_seq(muts, reversed=False) + ref_seq = chain_wt.sequence + + if pdb_mt != pdb_wt and mut_seq != chain_mt.sequence: + # sequences dont match due to missing residues in either the wt or the mt pdb files (seperate files since pdb_mt != pdb_wt) + # we can just use the wildtype protein structure to avoid mismatches with graph network (still the same mutated sequence tho) + mt_id = f'{pdb_wt}_{"-".join(muts)}' + + # if we have aflow confs then we use those instead + fp = self.pdb_p(mt_id, id_is_pdb=False) + if fp is not None: + chain_mt = Chain(fp, model=0) # no t_chain for alphaflow confs since theres only one input sequence. + + # final check to make sure this aflow conf is correct for the mt sequence. + if mut_seq != chain_mt.sequence: + logging.warning(f'Mismatched AA: Using wt STRUCTURE ({pdb_wt}) for mutated {mt_id}') + chain_mt = chain_wt + # Getting and saving cmaps under the unique protein ID + if not os.path.isfile(self.cmap_p(wt_id)) or self.overwrite: + np.save(self.cmap_p(wt_id), chain_wt.get_contact_map()) + + if not os.path.isfile(self.cmap_p(mt_id)) or self.overwrite: + np.save(self.cmap_p(mt_id), chain_mt.get_contact_map()) + + return wt_id, ref_seq, mt_id, mut_seq + def pre_process(self): """ This method is used to create the processed data files for feature extraction. @@ -1134,7 +1186,7 @@ def pre_process(self): pd.DataFrame The XY.csv dataframe. 
""" - ### LOAD UP RAW CSV FILE + adjust values### + ### LOAD UP RAW CSV FILE + adjust values ### df_raw = pd.read_csv(self.raw_paths[0]) # fixing pkd values for binding affinity df_raw['affin.k_mt'] = df_raw['affin.k_mt'].str.extract(r'(\d+\.*\d+)', @@ -1157,53 +1209,14 @@ def pre_process(self): pdb_mt = row['mut.mt_pdb'] t_chain = row['affin.chain'] - # Check if PDBs available (need at least 1 for sequence info) - missing_wt = pdb_wt == 'NO' - missing_mt = pdb_mt == 'NO' - assert not (missing_mt and missing_wt), f'missing pdbs for both mt and wt on idx {i}' - pdb_wt = pdb_mt if missing_wt else pdb_wt - pdb_mt = pdb_wt if missing_mt else pdb_mt - try: - chain_wt = Chain(self.pdb_p(pdb_wt), t_chain=t_chain) - chain_mt = Chain(self.pdb_p(pdb_mt), t_chain=t_chain) - - # Getting sequences: - if missing_wt: - mut_seq = chain_wt.sequence - ref_seq = chain_wt.get_mutated_seq(muts, reversed=True) - else: - mut_seq = chain_wt.get_mutated_seq(muts, reversed=False) - ref_seq = chain_wt.sequence - - # creating protein unique IDs for computational speed up by avoiding redundant compute - wt_id = f'{pdb_wt}_wt' - mt_id = f'{pdb_mt}_{"-".join(muts)}' - if pdb_mt != pdb_wt and mut_seq != chain_mt.sequence: - # print(f'Mutated doesnt match with chain for {i}:{self.pdb_p(pdb_wt)} and {self.pdb_p(pdb_mt)}') - # using just the wildtype protein structure to avoid mismatches with graph network - mt_id = f'{pdb_wt}_{"-".join(muts)}' - chain_mt = chain_wt - - # Getting and saving cmaps under the unique protein ID - if not os.path.isfile(self.cmap_p(wt_id)): - np.save(self.cmap_p(wt_id), chain_wt.get_contact_map()) - - if not os.path.isfile(self.cmap_p(mt_id)): - np.save(self.cmap_p(mt_id), chain_mt.get_contact_map()) - + wt_id, ref_seq, mt_id, mut_seq = self._get_prot_structure(muts, pdb_wt, pdb_mt, t_chain) except Exception as e: raise Exception(f'Error with idx {i} on {pdb_wt} wt and {pdb_mt} mt.') from e - - # Saving sequence and additional relevant info - mt_pkd = row['affin.k_mt'] - wt_pkd = row['affin.k_wt'] - lig_id = row['affin.lig_id'] - smiles = row['smiles'] - + # Using index number for ID since pdb is not unique in this dataset. 
- prot_seq[f'{i}_mt'] = (mt_id, lig_id, mt_pkd, smiles, mut_seq) - prot_seq[f'{i}_wt'] = (wt_id, lig_id, wt_pkd, smiles, ref_seq) + prot_seq[f'{i}_mt'] = (mt_id, row['affin.lig_id'], row['affin.k_mt'], row['smiles'], mut_seq) + prot_seq[f'{i}_wt'] = (wt_id, row['affin.lig_id'], row['affin.k_wt'], row['smiles'], ref_seq) df = pd.DataFrame.from_dict(prot_seq, orient='index', columns=['prot_id', 'lig_id', diff --git a/src/utils/residue.py b/src/utils/residue.py index c4c9681..03d2037 100644 --- a/src/utils/residue.py +++ b/src/utils/residue.py @@ -182,6 +182,8 @@ def __init__(self, pdb_file:str, model:int=0, t_chain:str=None, # parse chain -> {: {: {: np.array([x,y,z], "name": )}}} self._chains = self._pdb_get_chains(pdb_file, model, self.grep_atoms) + if len(self._chains) == 0: + raise Exception(f'No chains parsed on {pdb_file}') # if t_chain is not specified then set it to be the largest chain self.t_chain = t_chain or max(self._chains, key=lambda x: len(self._chains[x])) From 25323c2129344777803ccc41e752e8f0912158de Mon Sep 17 00:00:00 2001 From: Jean Charle Yaacoub Date: Fri, 31 May 2024 06:48:59 -0700 Subject: [PATCH 07/14] results: davis_aflow from cedar #90 #94 --- results/model_media/model_stats.csv | 5 +++++ results/model_media/model_stats_val.csv | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/results/model_media/model_stats.csv b/results/model_media/model_stats.csv index d0c6d2c..a52ac48 100644 --- a/results/model_media/model_stats.csv +++ b/results/model_media/model_stats.csv @@ -187,3 +187,8 @@ GVPLM_davis4D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_ GVPLM_davis0D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.7852543960759865,0.6625415568442652,0.5483708380017869,0.4902795019912594,0.373903283016137,0.7001996158177034 GVPLM_davis1D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.7649810976587887,0.6493107697356035,0.5116033216497222,0.5055476857354936,0.379708959205043,0.7110187660923539 GVPLM_davis2D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.7788049850583856,0.6551546686297811,0.539826183289417,0.4900774065483533,0.381190583578702,0.7000552882082625 +DGM_davis1D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7269566564770428,0.4700493508446506,0.4015462422575715,0.478945718498584,0.4370024735832576,0.6920590426391263 +DGM_davis0D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7636053478309109,0.4921759585470889,0.4648605313492071,0.4664726186122307,0.4114446920507094,0.6829880076635539 +DGM_davis2D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7620051687791963,0.4911419064307628,0.4617636238896002,0.4676042055415506,0.406303394677743,0.683815914951934 +DGM_davis3D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7505447112791103,0.470876632673457,0.4415803526952756,0.4865336122502623,0.4025646362178013,0.6975196142405332 +DGM_davis4D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7627751294107745,0.49100000345979894,0.46303342294166805,0.46678858993885775,0.41589127070989734,0.6832192839336854 diff --git a/results/model_media/model_stats_val.csv b/results/model_media/model_stats_val.csv index 9cd1b5e..afbe8ff 100644 --- a/results/model_media/model_stats_val.csv +++ 
b/results/model_media/model_stats_val.csv @@ -166,3 +166,8 @@ GVPLM_davis4D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_ GVPLM_davis0D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.7961292743880335,0.675879144422058,0.5465922959565244,0.4938138142860839,0.3651142534725168,0.702718872868862 GVPLM_davis1D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.7982821251801983,0.6893850245133243,0.5572686590418361,0.4836518480565703,0.35890867349935723,0.6954508236076583 GVPLM_davis2D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.8093119089319818,0.7167257951603837,0.543747905282994,0.3626874080110078,0.3138066983637325,0.6022353427116411 +DGM_davis1D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.717549316980205,0.5039155814516447,0.3941489326606527,0.4529655824807378,0.430028497115348,0.6730271781144784 +DGM_davis0D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7656310259121064,0.4589840304826533,0.4566383520441913,0.5370825058314675,0.4312995326062656,0.7328591309600144 +DGM_davis2D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7745741060459999,0.5391599122196629,0.5098816231812036,0.5021897590160201,0.4168515041762707,0.7086534830338592 +DGM_davis3D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7980478317336308,0.5542099308186255,0.5089014580257839,0.4220264259213847,0.3624621766568135,0.6496356101087629 +DGM_davis4D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7677158296355124,0.5054133287375089,0.4922434839494996,0.5150349706238619,0.4282377090806979,0.717659369494931 From f2483f7f44fac096e4fcf9c1f099853cb91a8c23 Mon Sep 17 00:00:00 2001 From: Jean Charle Yaacoub Date: Tue, 4 Jun 2024 06:50:54 -0700 Subject: [PATCH 08/14] results: updated davis_aflow --- results/model_media/model_stats.csv | 10 +++++----- results/model_media/model_stats_val.csv | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/results/model_media/model_stats.csv b/results/model_media/model_stats.csv index a52ac48..e72df15 100644 --- a/results/model_media/model_stats.csv +++ b/results/model_media/model_stats.csv @@ -187,8 +187,8 @@ GVPLM_davis4D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_ GVPLM_davis0D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.7852543960759865,0.6625415568442652,0.5483708380017869,0.4902795019912594,0.373903283016137,0.7001996158177034 GVPLM_davis1D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.7649810976587887,0.6493107697356035,0.5116033216497222,0.5055476857354936,0.379708959205043,0.7110187660923539 GVPLM_davis2D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.7788049850583856,0.6551546686297811,0.539826183289417,0.4900774065483533,0.381190583578702,0.7000552882082625 -DGM_davis1D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7269566564770428,0.4700493508446506,0.4015462422575715,0.478945718498584,0.4370024735832576,0.6920590426391263 
-DGM_davis0D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7636053478309109,0.4921759585470889,0.4648605313492071,0.4664726186122307,0.4114446920507094,0.6829880076635539 -DGM_davis2D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7620051687791963,0.4911419064307628,0.4617636238896002,0.4676042055415506,0.406303394677743,0.683815914951934 -DGM_davis3D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7505447112791103,0.470876632673457,0.4415803526952756,0.4865336122502623,0.4025646362178013,0.6975196142405332 -DGM_davis4D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7627751294107745,0.49100000345979894,0.46303342294166805,0.46678858993885775,0.41589127070989734,0.6832192839336854 +DGM_davis1D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7618267865956071,0.4923371930328827,0.4613806752914631,0.4662829663578006,0.411145990216076,0.6828491534429846 +DGM_davis0D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7636053478309109,0.4921759568376547,0.4648783804454143,0.4664726196058547,0.4114446927293202,0.6829880083909634 +DGM_davis2D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7627598258454263,0.4909617783259909,0.4629240453488026,0.4675438018075736,0.4069470615043133,0.6837717468626308 +DGM_davis3D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7586087337447441,0.48792801067181446,0.45565623472699257,0.4699824575652958,0.406830881759383,0.6855526657852742 +DGM_davis4D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7627751294107745,0.4910000034597989,0.463033422941668,0.4667885899388577,0.4158912707098973,0.6832192839336854 diff --git a/results/model_media/model_stats_val.csv b/results/model_media/model_stats_val.csv index afbe8ff..95dd6c1 100644 --- a/results/model_media/model_stats_val.csv +++ b/results/model_media/model_stats_val.csv @@ -164,10 +164,10 @@ GVPLM_davis4D_nomsaF_aflowE_128B_0.0001360163557088453LR_0.027175922988649594D_2 GVPLM_davis3D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.7613364651895133,0.530509070466495,0.4887270646961001,0.5667063347251928,0.4394238040568357,0.752799000215325 GVPLM_davis4D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.7999980961404287,0.667340224773174,0.5456194297881923,0.4617459986364707,0.3392615584282555,0.6795189464882276 GVPLM_davis0D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.7961292743880335,0.675879144422058,0.5465922959565244,0.4938138142860839,0.3651142534725168,0.702718872868862 -GVPLM_davis1D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.7982821251801983,0.6893850245133243,0.5572686590418361,0.4836518480565703,0.35890867349935723,0.6954508236076583 +GVPLM_davis1D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.7982821251801983,0.6893850245133243,0.5572686590418361,0.4836518480565703,0.3589086734993572,0.6954508236076583 GVPLM_davis2D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE,0.8093119089319818,0.7167257951603837,0.543747905282994,0.3626874080110078,0.3138066983637325,0.6022353427116411 
-DGM_davis1D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.717549316980205,0.5039155814516447,0.3941489326606527,0.4529655824807378,0.430028497115348,0.6730271781144784 -DGM_davis0D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7656310259121064,0.4589840304826533,0.4566383520441913,0.5370825058314675,0.4312995326062656,0.7328591309600144 -DGM_davis2D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7745741060459999,0.5391599122196629,0.5098816231812036,0.5021897590160201,0.4168515041762707,0.7086534830338592 -DGM_davis3D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7980478317336308,0.5542099308186255,0.5089014580257839,0.4220264259213847,0.3624621766568135,0.6496356101087629 +DGM_davis1D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7604358543507954,0.5258641768225495,0.4678115715796511,0.4397487084865327,0.4022678869222504,0.6631355129131095 +DGM_davis0D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7656357706050457,0.4589840305150596,0.4566457476087189,0.5370825057205953,0.431299532494959,0.7328591308843708 +DGM_davis2D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7739145158374762,0.5392158863378831,0.5089040317525807,0.5015169159434333,0.4169669235899112,0.7081785904300082 +DGM_davis3D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.792242665227612,0.5533025236584668,0.49892823193027686,0.4181546354842427,0.3822523623535258,0.6466487728931701 DGM_davis4D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7677158296355124,0.5054133287375089,0.4922434839494996,0.5150349706238619,0.4282377090806979,0.717659369494931 From 33de4dba78fea4fed8768f04885324a0c8b26c86 Mon Sep 17 00:00:00 2001 From: jyaacoub Date: Tue, 4 Jun 2024 09:57:36 -0400 Subject: [PATCH 09/14] chore: clean up code + title for predictive performance #94 --- src/analysis/figures.py | 5 ++--- src/analysis/utils.py | 3 ++- src/data_prep/datasets.py | 5 +++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/analysis/figures.py b/src/analysis/figures.py index d40bc12..6a76606 100644 --- a/src/analysis/figures.py +++ b/src/analysis/figures.py @@ -576,8 +576,8 @@ def predictive_performance( if compare_overlap: return generate_markdown([results_with_overlap, results_without_overlap], names=['with overlap', 'without overlap'], cindex=True,verbose=verbose) - - return generate_markdown([results_without_overlap], names=['mean $\pm$ se'], cindex=True, verbose=verbose) + # 'mean $\pm$ se' + return generate_markdown([results_without_overlap], names=['mean predictive performance'], cindex=True, verbose=verbose) def get_dpkd(df, pkd_col='pkd', normalize=False) -> np.ndarray: """ @@ -826,7 +826,6 @@ def fig_sig_mutations_conf_matrix(true_dpkd, pred_dpkd, std=2, verbose=True, plo print(f"True Negative Rate (TNR): {tnr:.2f}") return conf_matrix, tpr, tnr - def generate_roc_curve(true_dpkd, pred_dpkd, thres_range=(0,5), step=0.1): """3. 
significant mutation impact analysis""" diff --git a/src/analysis/utils.py b/src/analysis/utils.py index 34feb2e..9fbb9a7 100644 --- a/src/analysis/utils.py +++ b/src/analysis/utils.py @@ -81,7 +81,7 @@ def generate_markdown(results, names=None, verbose=False, thresh_sig=False, cind ``` """ n_groups = len(results) - names = names if names else [str(i) for i in range(n_groups)] + names = names if len(names)>0 else [str(i) for i in range(n_groups)] # Convert results to DataFrame results_df = [None for _ in range(n_groups)] md_table = None @@ -112,6 +112,7 @@ def generate_markdown(results, names=None, verbose=False, thresh_sig=False, cind md_table = pd.concat([md_table, sig], axis=1) md_table.columns = [*names, 'p-val'] else: + md_table = pd.DataFrame(md_table) md_table.columns = names md_output = md_table.to_markdown() diff --git a/src/data_prep/datasets.py b/src/data_prep/datasets.py index 7e1ee81..3d25b94 100644 --- a/src/data_prep/datasets.py +++ b/src/data_prep/datasets.py @@ -127,6 +127,7 @@ def __init__(self, save_root:str, data_root:str, aln_dir:str, # Validating subset subset = subset or 'full' save_root = os.path.join(save_root, f'{self.pro_feat_opt}_{self.pro_edge_opt}_{self.ligand_feature}_{self.ligand_edge}') # e.g.: path/to/root/nomsa_anm + self.save_path = save_root if self.verbose: print('save_root:', save_root) if subset != 'full': @@ -402,7 +403,7 @@ def _create_protein_graphs(self, df, node_feat, edge): total=len(unique_df)): if node_feat == cfg.PRO_FEAT_OPT.gvp: - # gvp has its own unique graph to support the architecture implementation. + # gvp has its own unique graph to support the architecture's implementation. coords = Chain(self.pdb_p(code), grep_atoms={'CA', 'N', 'C'}).getCoords(get_all=True) processed_prots[prot_id] = GVPFeaturesProtein().featurize_as_graph(code, coords, pro_seq) continue @@ -448,7 +449,7 @@ def _create_protein_graphs(self, df, node_feat, edge): if len(pro_edge_weight.shape) == 2: pro_edge_weight = torch.Tensor(pro_edge_weight[edge_idx[0], edge_idx[1]]) - elif len(pro_edge_weight.shape) == 3: # edge attr! + elif len(pro_edge_weight.shape) == 3: # has edge attr! 
pro_edge_weight = torch.Tensor(pro_edge_weight[edge_idx[0], edge_idx[1], :]) pro = torchg.data.Data(x=torch.Tensor(pro_feat), From 6ff304dd6fb90c8360fa84700c25ba42b8e28e42 Mon Sep 17 00:00:00 2001 From: jyaacoub Date: Tue, 4 Jun 2024 10:00:29 -0400 Subject: [PATCH 10/14] feat(__init__.py): tuned model configs --- rayTrain_Tune.py | 2 +- src/__init__.py | 80 +++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 80 insertions(+), 2 deletions(-) diff --git a/rayTrain_Tune.py b/rayTrain_Tune.py index 50c3fb5..059fa05 100644 --- a/rayTrain_Tune.py +++ b/rayTrain_Tune.py @@ -86,7 +86,7 @@ def train_func(config): "dataset": cfg.DATA_OPT.kiba, "feature_opt": cfg.PRO_FEAT_OPT.nomsa, - "edge_opt": cfg.PRO_EDGE_OPT.aflow, + "edge_opt": cfg.PRO_EDGE_OPT.binary, "lig_feat_opt": cfg.LIG_FEAT_OPT.gvp, "lig_edge_opt": cfg.LIG_EDGE_OPT.binary, diff --git a/src/__init__.py b/src/__init__.py index ee5fbcf..f94f4f4 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -1 +1,79 @@ -from src.utils import config \ No newline at end of file +from src.utils import config +from src.utils import config as cfg + + +TUNED_MODEL_CONFIGS = { + #GVPLM_davis0D_nomsaF_aflowE_128B_0.0001360163557088453LR_0.027175922988649594D_2000E_gvpLF_binaryLE + 'davis_gvpl_aflow': { + "model": cfg.MODEL_OPT.GVPL, + + "dataset": cfg.DATA_OPT.davis, + "feature_opt": cfg.PRO_FEAT_OPT.nomsa, + "edge_opt": cfg.PRO_EDGE_OPT.aflow, + "lig_feat_opt": cfg.LIG_FEAT_OPT.gvp, + "lig_edge_opt": cfg.LIG_EDGE_OPT.binary, + + "lr": 0.0001360163557088453, + "batch_size": 128, # local batch size + + "architecture_kwargs":{ + "dropout": 0.027175922988649594, + "output_dim": 128, + } + }, + #GVPLM_davis3D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE + 'davis_gvpl': { + "model": cfg.MODEL_OPT.GVPL, + + "dataset": cfg.DATA_OPT.davis, + "feature_opt": cfg.PRO_FEAT_OPT.nomsa, + "edge_opt": cfg.PRO_EDGE_OPT.binary, + "lig_feat_opt": cfg.LIG_FEAT_OPT.gvp, + "lig_edge_opt": cfg.LIG_EDGE_OPT.binary, + + 'lr': 0.00020535607176845963, + 'batch_size': 128, + 'architecture_kwargs': { + 'dropout': 0.08845592454543601, + 'output_dim': 512 + } + }, + + 'davis_aflow':{ # not trained yet... 
+ "model": cfg.MODEL_OPT.DG, + + "dataset": cfg.DATA_OPT.davis, + "feature_opt": cfg.PRO_FEAT_OPT.nomsa, + "edge_opt": cfg.PRO_EDGE_OPT.aflow, + "lig_feat_opt": cfg.LIG_FEAT_OPT.original, + "lig_edge_opt": cfg.LIG_EDGE_OPT.binary, + + + 'lr': 0.0008279387625584954, + 'batch_size': 128, + + 'architecture_kwargs': { + 'dropout': 0.3480347297724069, + 'output_dim': 256 + } + }, + + #GVPLM_PDBbind0D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE + 'PDBbind_gvpl_aflow':{ + "model": cfg.MODEL_OPT.GVPL, + + "dataset": cfg.DATA_OPT.PDBbind, + "feature_opt": cfg.PRO_FEAT_OPT.nomsa, + "edge_opt": cfg.PRO_EDGE_OPT.aflow, + "lig_feat_opt": cfg.LIG_FEAT_OPT.gvp, + "lig_edge_opt": cfg.LIG_EDGE_OPT.binary, + + 'lr': 0.00022659, + 'batch_size': 128, + + 'architecture_kwargs': { + 'dropout': 0.02414, + 'output_dim': 256 + } + }, +} \ No newline at end of file From d2f023a9000d42e368e6f837b9f6b08d1da226ec Mon Sep 17 00:00:00 2001 From: jyaacoub Date: Tue, 4 Jun 2024 10:01:20 -0400 Subject: [PATCH 11/14] feat: full code for platinum inference + analysis #94 --- playground.py | 269 ++++++++++++++++++------------------- src/analysis/platinum.py | 47 +++++++ src/train_test/training.py | 2 +- 3 files changed, 182 insertions(+), 136 deletions(-) create mode 100644 src/analysis/platinum.py diff --git a/playground.py b/playground.py index 1c2ce53..df440d2 100644 --- a/playground.py +++ b/playground.py @@ -1,148 +1,147 @@ -#%% 1.Gather data for davis,kiba and pdbbind datasets -import os -import pandas as pd -import matplotlib.pyplot as plt -from src.analysis.utils import combine_dataset_pids -from src import config as cfg -df_prots = combine_dataset_pids(dbs=[cfg.DATA_OPT.davis, cfg.DATA_OPT.PDBbind], # just davis and pdbbind for now - subset='test') - - -#%% 2. Load TCGA data -df_tcga = pd.read_csv('../downloads/TCGA_ALL.maf', sep='\t') - -#%% 3. Pre filtering -df_tcga = df_tcga[df_tcga['Variant_Classification'] == 'Missense_Mutation'] -df_tcga['seq_len'] = pd.to_numeric(df_tcga['Protein_position'].str.split('/').str[1]) -df_tcga = df_tcga[df_tcga['seq_len'] < 5000] -df_tcga['seq_len'].plot.hist(bins=100, title="sequence length histogram capped at 5K") -plt.show() -df_tcga = df_tcga[df_tcga['seq_len'] < 1200] -df_tcga['seq_len'].plot.hist(bins=100, title="sequence length after capped at 1.2K") - -#%% 4. Merging df_prots with TCGA -df_tcga['uniprot'] = df_tcga['SWISSPROT'].str.split('.').str[0] - -dfm = df_tcga.merge(df_prots[df_prots.db != 'davis'], - left_on='uniprot', right_on='prot_id', how='inner') - -# for davis we have to merge on HUGO_SYMBOLS -dfm_davis = df_tcga.merge(df_prots[df_prots.db == 'davis'], - left_on='Hugo_Symbol', right_on='prot_id', how='inner') - -dfm = pd.concat([dfm,dfm_davis], axis=0) - -del dfm_davis # to save mem - -# %% 5. Post filtering step -# 5.1. Filter for only those sequences with matching sequence length (to get rid of nonmatched isoforms) -# seq_len_x is from tcga, seq_len_y is from our dataset -tmp = len(dfm) -# allow for some error due to missing amino acids from pdb file in PDBbind dataset -# - assumption here is that isoforms will differ by more than 50 amino acids -dfm = dfm[(dfm.seq_len_y <= dfm.seq_len_x) & (dfm.seq_len_x<= dfm.seq_len_y+50)] -print(f"Filter #1 (seq_len) : {tmp:5d} - {tmp-len(dfm):5d} = {len(dfm):5d}") - -# 5.2. 
Filter out those that dont have the same reference seq according to the "Protein_position" and "Amino_acids" col - -# Extract mutation location and reference amino acid from 'Protein_position' and 'Amino_acids' columns -dfm['mt_loc'] = pd.to_numeric(dfm['Protein_position'].str.split('/').str[0]) -dfm = dfm[dfm['mt_loc'] < dfm['seq_len_y']] -dfm[['ref_AA', 'mt_AA']] = dfm['Amino_acids'].str.split('/', expand=True) - -dfm['db_AA'] = dfm.apply(lambda row: row['prot_seq'][row['mt_loc']-1], axis=1) - -# Filter #2: Match proteins with the same reference amino acid at the mutation location -tmp = len(dfm) -dfm = dfm[dfm['db_AA'] == dfm['ref_AA']] -print(f"Filter #2 (ref_AA match): {tmp:5d} - {tmp-len(dfm):5d} = {len(dfm):5d}") -print('\n',dfm.db.value_counts()) - - -# %% final seq len distribution -n_bins = 25 -lengths = dfm.seq_len_x -fig, ax = plt.subplots(1, 1, figsize=(10, 5)) - -# Plot histogram -n, bins, patches = ax.hist(lengths, bins=n_bins, color='blue', alpha=0.7) -ax.set_title('TCGA final filtering for db matches') +#%% +# %% +import logging +from typing import OrderedDict -# Add counts to each bin -for count, x, patch in zip(n, bins, patches): - ax.text(x + 0.5, count, str(int(count)), ha='center', va='bottom') +import seaborn as sns +from matplotlib import pyplot as plt +from statannotations.Annotator import Annotator -ax.set_xlabel('Sequence Length') -ax.set_ylabel('Frequency') +from src.analysis.figures import prepare_df, custom_fig, fig_combined -plt.tight_layout() -plt.show() +df = prepare_df() +# %% +models = { + 'DG': ('nomsa', 'binary', 'original', 'binary'), + 'aflow': ('nomsa', 'aflow', 'original', 'binary'), + # 'aflow_ring3': ('nomsa', 'aflow_ring3', 'original', 'binary'), + # 'gvpP': ('gvp', 'binary', 'original', 'binary'), + # 'gvpL': ('nomsa', 'binary', 'gvp', 'binary'), + 'gvpL_aflow': ('nomsa', 'aflow', 'gvp', 'binary'), + 'gvpL': ('nomsa', 'binary', 'gvp', 'binary'), + # 'gvpL_aflow_rng3': ('nomsa', 'aflow_ring3', 'gvp', 'binary'), +} -# %% Getting updated sequences -def apply_mut(row): - ref_seq = list(row['prot_seq']) - ref_seq[row['mt_loc']-1] = row['mt_AA'] - return ''.join(ref_seq) +# %% +fig, axes = fig_combined(df, datasets=['davis','PDBbind'], fig_callable=custom_fig, + models=models, metrics=['cindex', 'mse'], + fig_scale=(8,5)) +plt.xticks(rotation=45) -dfm['mt_seq'] = dfm.apply(apply_mut, axis=1) +# %% # %% -dfm.to_csv("/cluster/home/t122995uhn/projects/data/tcga/tcga_maf_davis_pdbbind.csv") -# %% -from src.utils.seq_alignment import MSARunner -from tqdm import tqdm +######################################################################## +########################## PLATINUM ANALYSIS ########################### +######################################################################## +import torch, os import pandas as pd -import os - -DATA_DIR = '/cluster/home/t122995uhn/projects/data/tcga' -CSV = f'{DATA_DIR}/tcga_maf_davis_pdbbind.csv' -N_CPUS= 6 -NUM_ARRAYS = 10 -array_idx = 0#${SLURM_ARRAY_TASK_ID} -df = pd.read_csv(CSV, index_col=0) -df.sort_values(by='seq_len_y', inplace=True) - - -# %% -for DB in df.db.unique(): - print('DB', DB) - RAW_DIR = f'{DATA_DIR}/{DB}' - # should already be unique if these are proteins mapped form tcga! 
- unique_df = df[df['db'] == DB] - ########################## Get job partition - partition_size = len(unique_df) / NUM_ARRAYS - start, end = int(array_idx*partition_size), int((array_idx+1)*partition_size) - - unique_df = unique_df[start:end] - - #################################### create fastas - fa_dir = os.path.join(RAW_DIR, f'{DB}_fa') - fasta_fp = lambda idx,pid: os.path.join(fa_dir, f"{idx}-{pid}.fasta") - os.makedirs(fa_dir, exist_ok=True) - for idx, (prot_id, pro_seq) in tqdm( - unique_df[['prot_id', 'prot_seq']].iterrows(), - desc='Creating fastas', - total=len(unique_df)): - with open(fasta_fp(idx,prot_id), "w") as f: - f.write(f">{prot_id},{idx},{DB}\n{pro_seq}") - - ##################################### Run hhblits - aln_dir = os.path.join(RAW_DIR, f'{DB}_aln') - aln_fp = lambda idx,pid: os.path.join(aln_dir, f"{idx}-{pid}.a3m") - os.makedirs(aln_dir, exist_ok=True) - - # finally running - for idx, (prot_id, pro_seq) in tqdm( - unique_df[['prot_id', 'mt_seq']].iterrows(), - desc='Running hhblits', - total=len(unique_df)): - in_fp = fasta_fp(idx,prot_id) - out_fp = aln_fp(idx,prot_id) +from src import cfg +from src import TUNED_MODEL_CONFIGS +from src.utils.loader import Loader +from src.train_test.training import test +from src.analysis.figures import predictive_performance, tbl_stratified_dpkd_metrics, tbl_dpkd_metrics_overlap, tbl_dpkd_metrics_in_binding + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +INFERENCE = True +VERBOSE = True +out_dir = f'{cfg.MEDIA_SAVE_DIR}/test_set_pred/' +os.makedirs(out_dir, exist_ok=True) +cp_dir = cfg.CHECKPOINT_SAVE_DIR +RAW_PLT_CSV=f"{cfg.DATA_ROOT}/PlatinumDataset/raw/platinum_flat_file.csv" + +#%% load up model: +for KEY, CONFIG in TUNED_MODEL_CONFIGS.items(): + MODEL_KEY = lambda fold: Loader.get_model_key(CONFIG['model'], CONFIG['dataset'], CONFIG['feature_opt'], CONFIG['edge_opt'], + CONFIG['batch_size'], CONFIG['lr'], CONFIG['architecture_kwargs']['dropout'], + n_epochs=2000, fold=fold, + ligand_feature=CONFIG['lig_feat_opt'], ligand_edge=CONFIG['lig_edge_opt']) + print('\n\n'+ '## ' + KEY) + OUT_PLT = lambda i: f'{out_dir}/{MODEL_KEY(i)}_PLATINUM.csv' + db_p = f"{CONFIG['feature_opt']}_{CONFIG['edge_opt']}_{CONFIG['lig_feat_opt']}_{CONFIG['lig_edge_opt']}" + + if CONFIG['dataset'] in ['kiba', 'davis']: + db_p = f"DavisKibaDataset/{CONFIG['dataset']}/{db_p}" + else: + db_p = f"{CONFIG['dataset']}Dataset/{db_p}" - if not os.path.isfile(out_fp): - print(MSARunner.hhblits(in_fp, out_fp, n_cpus=N_CPUS, return_cmd=True)) - break + train_p = lambda set: f"{cfg.DATA_ROOT}/{db_p}/{set}0/cleaned_XY.csv" + + if not os.path.exists(OUT_PLT(0)) and INFERENCE: + print('running inference!') + cp = lambda fold: f"{cp_dir}/{MODEL_KEY(fold)}.model" + + model = Loader.init_model(model=CONFIG["model"], pro_feature=CONFIG["feature_opt"], + pro_edge=CONFIG["edge_opt"],**CONFIG['architecture_kwargs']) + + # load up platinum test db + loaders = Loader.load_DataLoaders(cfg.DATA_OPT.platinum, + pro_feature = CONFIG['feature_opt'], + edge_opt = CONFIG['edge_opt'], + ligand_feature = CONFIG['lig_feat_opt'], + ligand_edge = CONFIG['lig_edge_opt'], + datasets=['test']) + + for i in range(5): + model.safe_load_state_dict(torch.load(cp(i), map_location=device)) + model.to(device) + model.eval() + + loss, pred, actual = test(model, loaders['test'], device, verbose=True) + + # saving as csv with columns code, pred, actual + # get codes from test loader + codes, pid = [b['code'][0] for b in loaders['test']], [b['prot_id'][0] for b in 
loaders['test']]
+    df = pd.DataFrame({'prot_id': pid, 'pred': pred, 'actual': actual}, index=codes)
+    df.index.name = 'code'
+    df.to_csv(OUT_PLT(i))
+
+    # run platinum eval:
+    print('\n### 1. predictive performance')
+    mkdown = predictive_performance(OUT_PLT, train_p, verbose=VERBOSE, plot=False)
+    print('\n### 2 Mutation impact analysis')
+    print('\n#### 2.1 $\Delta pkd$ predictive performance')
+    mkdn = tbl_dpkd_metrics_overlap(OUT_PLT, train_p, verbose=VERBOSE, plot=False)
+    print('\n#### 2.2 Stratified by location of mutation (binding pocket vs not in binding pocket)')
+    m = tbl_dpkd_metrics_in_binding(OUT_PLT, RAW_PLT_CSV, verbose=VERBOSE, plot=False)
 # %%
+dfr = pd.read_csv(RAW_PLT_CSV, index_col=0)
+
+# add in_binding info to df
+def get_in_binding(df, dfr):
+    """
+    df is the predicted csv with index formatted as *_wt (or *_mt), where the raw_idx prefix
+    corresponds to an index in dfr, which contains the raw data for platinum including
+    'mut.in_binding_site'. Adds a 'pocket' column to df with values:
+        - 0: wildtype rows
+        - 1: close (<8 Ang)
+        - 2: far (>8 Ang)
+    """
+    pocket = dfr[dfr['mut.in_binding_site'] == 'YES'].index
+    pclass = []
+    for code in df.index:
+        if '_wt' in code:
+            pclass.append(0)
+        elif int(code.split('_')[0]) in pocket:
+            pclass.append(1)
+        else:
+            pclass.append(2)
+    df['pocket'] = pclass
+    return df
+
+df = get_in_binding(pd.read_csv(OUT_PLT(0), index_col=0), dfr)
+if VERBOSE:
+    cnts = df.pocket.value_counts()
+    cnts.index = ['wt', 'in pocket', 'not in pocket']
+    cnts.name = "counts"
+    print(cnts.to_markdown(), end="\n\n")
+
+tbl_stratified_dpkd_metrics(OUT_PLT, NORMALIZE=True, n_models=5, df_transform=get_in_binding,
+                            conditions=['(pocket == 0) | (pocket == 1)', '(pocket == 0) | (pocket == 2)'],
+                            names=['in pocket', 'not in pocket'],
+                            verbose=VERBOSE, plot=True, dfr=dfr)
+
diff --git a/src/analysis/platinum.py b/src/analysis/platinum.py
new file mode 100644
index 0000000..54e8963
--- /dev/null
+++ b/src/analysis/platinum.py
@@ -0,0 +1,47 @@
+import torch, os
+import pandas as pd
+
+from src import cfg
+from src import TUNED_MODEL_CONFIGS
+from src.utils.loader import Loader
+from src.train_test.training import test
+device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
+
+CONFIG = TUNED_MODEL_CONFIGS['davis_gvpl_aflow']
+#%% load up model:
+cp_dir = cfg.CHECKPOINT_SAVE_DIR
+
+MODEL_KEY = lambda fold: Loader.get_model_key(CONFIG['model'], CONFIG['dataset'], CONFIG['feature_opt'], CONFIG['edge_opt'],
+                                              CONFIG['batch_size'], CONFIG['lr'], CONFIG['architecture_kwargs']['dropout'],
+                                              n_epochs=2000, fold=fold,
+                                              ligand_feature=CONFIG['lig_feat_opt'], ligand_edge=CONFIG['lig_edge_opt'])
+cp = lambda fold: f"{cp_dir}/{MODEL_KEY(fold)}.model"
+
+out_dir = f'{cfg.MEDIA_SAVE_DIR}/test_set_pred/'
+os.makedirs(out_dir, exist_ok=True)
+
+model = Loader.init_model(model=CONFIG["model"], pro_feature=CONFIG["feature_opt"],
+                          pro_edge=CONFIG["edge_opt"],**CONFIG['architecture_kwargs'])
+
+#%%
+# load up platinum test db
+loaders = Loader.load_DataLoaders(cfg.DATA_OPT.platinum,
+                                  pro_feature = CONFIG['feature_opt'],
+                                  edge_opt = CONFIG['edge_opt'],
+                                  ligand_feature = CONFIG['lig_feat_opt'],
+                                  ligand_edge = CONFIG['lig_edge_opt'],
+                                  datasets=['test'])
+
+for i in range(5):
+    model.safe_load_state_dict(torch.load(cp(i), map_location=device))
+    model.to(device)
+    model.eval()
+
+    loss, pred, actual = test(model, loaders['test'], device, verbose=True)
+
+    # saving as csv with columns code, pred, actual
+    # get codes from test loader
+    codes, pid = [b['code'][0] for b in loaders['test']], [b['prot_id'][0] for b in loaders['test']]
+    df = pd.DataFrame({'prot_id': pid, 'pred': pred, 'actual': actual}, index=codes)
+    df.index.name = 'code'
+    df.to_csv(f'{out_dir}/{MODEL_KEY(i)}_PLATINUM.csv')
\ No newline at end of file
diff --git a/src/train_test/training.py b/src/train_test/training.py
index c0ba01d..2bff0f5 100644
--- a/src/train_test/training.py
+++ b/src/train_test/training.py
@@ -165,7 +165,7 @@ def train(model: BaseModel, train_loader:DataLoader, val_loader:DataLoader,
     # gamma = (lr_e/lr_0)**(step_size/epochs) # calculate gamma based on final lr chosen.
     SCHEDULER = ReduceLROnPlateau(OPTIMIZER, mode='min', patience=saver.patience-1,
                                   threshold=saver.min_delta*0.1,
-                                  min_lr=5e-5, factor=0.8,
+                                  min_lr=5e-7, factor=0.8,
                                   verbose=True)
 
     logs = {'train_loss': [], 'val_loss': []}

From df4f13caa9c872175be867f0857973c9a7d7aa2e Mon Sep 17 00:00:00 2001
From: jyaacoub
Date: Tue, 4 Jun 2024 10:43:19 -0400
Subject: [PATCH 12/14] fix: figures logging + gitignore test sets

---
 .gitignore              | 1 +
 playground.py           | 2 --
 src/analysis/figures.py | 2 +-
 3 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/.gitignore b/.gitignore
index 997e9c5..75ad2bf 100644
--- a/.gitignore
+++ b/.gitignore
@@ -222,3 +222,4 @@ results/model_checkpoints/ours/*.model
 results/model_media/*/train_log/*
 results/model_media/*/train_set_pred/*
 results/model_media/*/test_set_pred/*
+results/model_media/test_set_pred
\ No newline at end of file
diff --git a/playground.py b/playground.py
index df440d2..937a6ad 100644
--- a/playground.py
+++ b/playground.py
@@ -28,8 +28,6 @@
                      fig_scale=(8,5))
 plt.xticks(rotation=45)
 
-# %%
-
 # %%
 
 ########################################################################
diff --git a/src/analysis/figures.py b/src/analysis/figures.py
index 6a76606..ce835c0 100644
--- a/src/analysis/figures.py
+++ b/src/analysis/figures.py
@@ -436,7 +436,7 @@ def matched(df, tuple):
     for model, feat in models.items():
         plot_data[model] = filtered_df[matched(filtered_df, feat)][sel_col]
         if len(plot_data[model]) != 5:
-            logging.warning(f'Expected 5 results for {model}, got {len(plot_data[model])}')
+            logging.warning(f'Expected 5 results for {model} on {sel_dataset}, got {len(plot_data[model])}')
 
     # plot violin plot with annotations
     vals = list(plot_data.values())

From e2de41ab1221aac4b815d02eefdcdc0e96685335 Mon Sep 17 00:00:00 2001
From: jyaacoub
Date: Tue, 4 Jun 2024 10:43:53 -0400
Subject: [PATCH 13/14] feat: + kiba_gvpl tuned model #90

---
 rayTrain_Tune.py |  4 ++--
 src/__init__.py  | 26 ++++++++++++++++++++++++++
 2 files changed, 28 insertions(+), 2 deletions(-)

diff --git a/rayTrain_Tune.py b/rayTrain_Tune.py
index 059fa05..bcee304 100644
--- a/rayTrain_Tune.py
+++ b/rayTrain_Tune.py
@@ -86,8 +86,8 @@ def train_func(config):
 
         "dataset": cfg.DATA_OPT.kiba,
         "feature_opt": cfg.PRO_FEAT_OPT.nomsa,
-        "edge_opt": cfg.PRO_EDGE_OPT.binary,
-        "lig_feat_opt": cfg.LIG_FEAT_OPT.gvp,
+        "edge_opt": cfg.PRO_EDGE_OPT.aflow,
+        "lig_feat_opt": cfg.LIG_FEAT_OPT.original,
         "lig_edge_opt": cfg.LIG_EDGE_OPT.binary,
 
         "fold_selection": 0,
diff --git a/src/__init__.py b/src/__init__.py
index f94f4f4..b19b632 100644
--- a/src/__init__.py
+++ b/src/__init__.py
@@ -57,7 +57,33 @@
             'output_dim': 256
         }
     },
+    #####################################################
+    ############## kiba #################################
+    #####################################################
+    'kiba_gvpl': {
+        "model": cfg.MODEL_OPT.GVPL,
+
+        "dataset": cfg.DATA_OPT.kiba,
+        "feature_opt": cfg.PRO_FEAT_OPT.nomsa,
+        "edge_opt": cfg.PRO_EDGE_OPT.binary,
+        "lig_feat_opt": cfg.LIG_FEAT_OPT.gvp,
+        "lig_edge_opt": cfg.LIG_EDGE_OPT.binary,
+
+
+        'lr': 0.00003372637625954074,
+        'batch_size': 32,
+
+        'architecture_kwargs': {
+            'dropout': 0.09399264336737133,
+            'output_dim': 512,
+            'num_GVPLayers': 4
+        }
+    }
+
+
+    #####################################################
+    ########### PDBbind #################################
+    #####################################################
     #GVPLM_PDBbind0D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE
     'PDBbind_gvpl_aflow':{
         "model": cfg.MODEL_OPT.GVPL,

From 57dcb487b7a27df0dbf561ee736c7112a107ba41 Mon Sep 17 00:00:00 2001
From: Jean Charle Yaacoub <50300488+jyaacoub@users.noreply.github.com>
Date: Tue, 4 Jun 2024 11:07:38 -0400
Subject: [PATCH 14/14] Update __init__.py

---
 src/__init__.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/__init__.py b/src/__init__.py
index b19b632..1bb1995 100644
--- a/src/__init__.py
+++ b/src/__init__.py
@@ -78,7 +78,7 @@
             'output_dim': 512,
             'num_GVPLayers': 4
         }
-    }
+    },
 
 
     #####################################################
@@ -102,4 +102,4 @@
             'output_dim': 256
         }
     },
-}
\ No newline at end of file
+}
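
Note on the 'kiba_gvpl' entry added in PATCH 13: it is looked up from TUNED_MODEL_CONFIGS the same way the 'davis_gvpl_aflow' entry is consumed in the src/analysis/platinum.py diff above. A minimal sketch of that usage follows, assuming the Loader.init_model call shown in that diff; it is illustrative only and not part of the patches.

    # Sketch only (not part of the patch series): build a GVPL model from the
    # tuned 'kiba_gvpl' hyperparameters, assuming the Loader.init_model call
    # used in src/analysis/platinum.py above.
    from src import TUNED_MODEL_CONFIGS
    from src.utils.loader import Loader

    CONFIG = TUNED_MODEL_CONFIGS['kiba_gvpl']

    # architecture_kwargs carries the tuned dropout, output_dim and num_GVPLayers
    model = Loader.init_model(model=CONFIG["model"],
                              pro_feature=CONFIG["feature_opt"],
                              pro_edge=CONFIG["edge_opt"],
                              **CONFIG['architecture_kwargs'])

The same CONFIG fields would also feed Loader.load_DataLoaders when evaluating, mirroring the fold loop in platinum.py.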