def process_protein_multiprocessing(args):
    """
    Check whether one protein has usable conformation (.pdb) files.

    Takes a single tuple so that it can be mapped over a
    ``multiprocessing.Pool``.

    Args:
        args: tuple ``(group_codes, code, pid, seq, af_conf_dir, is_pdbbind, files)``
            - group_codes: all codes (dataframe index values) that share ``pid``
            - code: the code of the row being checked
            - pid: protein id of that row
            - seq: expected protein sequence
            - af_conf_dir: directory holding the conformation ``.pdb`` files
            - is_pdbbind: if True, files are looked up as ``<c>.pdb`` for every
              ``c`` in ``group_codes``; otherwise every name in ``files`` that
              starts with ``pid`` is used
            - files: file names inside ``af_conf_dir`` (only read when
              ``is_pdbbind`` is False)

    Returns:
        - ``(None, None)``: conformations exist, are named after ``code`` (or
          not pdbbind) and the sequence matches.
        - ``(None, matching_code)``: sequence matches, but the file is stored
          under a sibling code of the same protein. NOTE: ``pid`` is None
          here, so callers must not use the first element as a dict key.
        - ``(pid, None)``: no conformation file, or too few models. For
          non-pdbbind data a sequence mismatch is also reported this way,
          because there is no code to return.
        - ``(pid, matching_code)``: pdbbind files exist but none has the
          expected sequence; ``matching_code`` is the last code with a file.
    """
    group_codes, code, pid, seq, af_conf_dir, is_pdbbind, files = args
    MIN_MODEL_COUNT = 5

    correct_seq = False
    matching_code = None
    af_confs = []
    if is_pdbbind:
        for c in group_codes:
            af_fp = os.path.join(af_conf_dir, f'{c}.pdb')
            if not os.path.exists(af_fp):
                continue
            af_confs = [af_fp]
            # Record the code whose file we are looking at. The original stored
            # `code` here, which made the "different file name" result below
            # impossible to produce.
            matching_code = c
            if Chain(af_fp).sequence == seq:
                correct_seq = True
                break
    else:
        af_confs = [os.path.join(af_conf_dir, f) for f in files if f.startswith(pid)]

    if not af_confs:
        return pid, None

    # Either all models live in one pdb file (alphaflow) or they are spread
    # across multiple files (AF2 msa subsampling). Counting the models of a
    # single file is disabled here, so a lone file is assumed to hold enough.
    model_count = len(af_confs) if len(af_confs) > 1 else MIN_MODEL_COUNT
    if model_count < MIN_MODEL_COUNT:
        return pid, None

    if not correct_seq:
        # For pdbbind every existing file was already compared in the loop
        # above, so reaching this point means a mismatch; only the non-pdbbind
        # path still needs its (first) file parsed.
        if is_pdbbind or Chain(af_confs[0]).sequence != seq:
            logging.debug(f'Mismatched sequence for {pid}')
            return pid, matching_code

    if matching_code is not None and matching_code != code:
        return None, matching_code
    return None, None
-davis_test_df['gene'] = davis_test_df['prot_id'].str.split('(').str[0] - -#%% ONCO KB MERGE -onco_df = pd.read_csv("../data/oncoKB_DrugGenePairList.csv") -davis_join_onco = davis_test_df.merge(onco_df.drop_duplicates("gene"), on="gene", how="inner") +df = pd.read_csv(csv_p, index_col=0) # %% -onco_df = pd.read_csv("../data/oncoKB_DrugGenePairList.csv") -onco_df.merge(davis_test_df.drop_duplicates("gene"), on="gene", how="inner").value_counts("gene") - - - - - - - - +import os +from tqdm import tqdm +alphaflow_dir = "/cluster/home/t122995uhn/projects/data/pdbbind/alphaflow_io/out_pdb_MD-distilled/" +ln_dir = "/cluster/home/t122995uhn/projects/data/pdbbind/alphaflow_io/out_pid_ln/" + +os.makedirs(ln_dir, exist_ok=True) + +# First, check and remove any broken links in the destination directory +for link_file in tqdm(os.listdir(ln_dir), desc="Checking for broken links"): + ln_p = os.path.join(ln_dir, link_file) + if os.path.islink(ln_p) and not os.path.exists(ln_p): + print(f"Removing broken link: {ln_p}") + os.remove(ln_p) + + +# %% files are .pdb with 50 "models" in each +for file in tqdm(os.listdir(alphaflow_dir)): + if not file.endswith('.pdb'): + continue + + code, _ = os.path.splitext(file) + pid = df.loc[code].prot_id + src, dst = f"{alphaflow_dir}/{file}", f"{ln_dir}/{pid}.pdb" + if not os.path.exists(dst): + os.symlink(src,dst) -# %% -from src.train_test.splitting import resplit +#%% +######################################################################## +########################## BUILD DATASETS ############################## +######################################################################## +import os +from src.data_prep.init_dataset import create_datasets from src import cfg +import logging +cfg.logger.setLevel(logging.DEBUG) -db_p = lambda x: f'{cfg.DATA_ROOT}/DavisKibaDataset/davis/nomsa_{x}_gvp_binary' - -db = resplit(dataset=db_p('binary'), split_files=db_p('aflow'), use_train_set=True) - +splits = 
'/cluster/home/t122995uhn/projects/MutDTA/splits/davis/' +create_datasets(cfg.DATA_OPT.davis, + feat_opt=cfg.PRO_FEAT_OPT.nomsa, + edge_opt=[cfg.PRO_EDGE_OPT.aflow, cfg.PRO_EDGE_OPT.binary], + ligand_features=[cfg.LIG_FEAT_OPT.original, cfg.LIG_FEAT_OPT.gvp], + ligand_edges=cfg.LIG_EDGE_OPT.binary, overwrite=True, + k_folds=5, + test_prots_csv=f'{splits}/test.csv', + val_prots_csv=[f'{splits}/val{i}.csv' for i in range(5)], + data_root=os.path.abspath('../data/test/')) # %% @@ -69,62 +184,3 @@ models=models, metrics=['cindex', 'mse'], fig_scale=(10,5), add_stats=True, title_postfix=" validation set performance") plt.xticks(rotation=45) - -#%% -######################################################################## -########################## BUILD DATASETS ############################## -######################################################################## -from src.data_prep.init_dataset import create_datasets -from src import cfg -import logging -cfg.logger.setLevel(logging.DEBUG) - -splits = '/cluster/home/t122995uhn/projects/MutDTA/splits/pdbbind/' -create_datasets(cfg.DATA_OPT.PDBbind, - feat_opt=cfg.PRO_FEAT_OPT.nomsa, - edge_opt=[cfg.PRO_EDGE_OPT.binary, cfg.PRO_EDGE_OPT.aflow], - ligand_features=[cfg.LIG_FEAT_OPT.original, cfg.LIG_FEAT_OPT.gvp], - ligand_edges=cfg.LIG_EDGE_OPT.binary, - k_folds=5, - test_prots_csv=f'{splits}/test.csv', - val_prots_csv=[f'{splits}/val{i}.csv' for i in range(5)]) - -# %% -from src.utils.loader import Loader - -db_aflow = Loader.load_dataset('../data/DavisKibaDataset/davis/nomsa_aflow_original_binary/full') -db = Loader.load_dataset('../data/DavisKibaDataset/davis/nomsa_binary_original_binary/full') - -# %% -# 5-fold cross validation + test set -import pandas as pd -from src import cfg -from src.train_test.splitting import balanced_kfold_split -from src.utils.loader import Loader -test_df = pd.read_csv('/cluster/home/t122995uhn/projects/MutDTA/splits/pdbbind_test.csv') -test_prots = set(test_df.prot_id) -db = 
Loader.load_dataset(f'{cfg.DATA_ROOT}/PDBbindDataset/nomsa_binary_original_binary/full/') - -train, val, test = balanced_kfold_split(db, - k_folds=5, test_split=0.1, val_split=0.1, - test_prots=test_prots, random_seed=0, verbose=True - ) - -#%% -db.save_subset_folds(train, 'train') -db.save_subset_folds(val, 'val') -db.save_subset(test, 'test') - -#%% -import shutil, os - -src = "/cluster/home/t122995uhn/projects/data/PDBbindDataset/nomsa_binary_original_binary/" -dst = "/cluster/home/t122995uhn/projects/MutDTA/splits/pdbbind" -os.makedirs(dst, exist_ok=True) - -for i in range(5): - sfile = f"{src}/val{i}/XY.csv" - dfile = f"{dst}/val{i}.csv" - shutil.copyfile(sfile, dfile) - -# %% diff --git a/results/v113/model_media/model_stats.csv b/results/v113/model_media/model_stats.csv index 97b9b98..fa597ba 100644 --- a/results/v113/model_media/model_stats.csv +++ b/results/v113/model_media/model_stats.csv @@ -33,7 +33,7 @@ DGM_kiba3D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_o DGM_kiba1D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.6572547696170707,0.4910169881324348,0.4223030308489566,0.5056608650145961,0.4887140404038343,0.7110983511544631 DGM_kiba0D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7034265325632763,0.5925355722791278,0.5277625252545413,0.4342528236949791,0.4414575484107895,0.6589786215765874 DGM_kiba2D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7048145943290888,0.5949726563519002,0.5309972890766975,0.4319786589236721,0.4397173516202035,0.657250834098879 -DGM_kiba4D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.701385350951551,0.5879330884340543,0.5238489684293576,0.43711720760821254,0.44305434527103205,0.6611484005941575 
+DGM_kiba4D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.701385350951551,0.5879330884340543,0.5238489684293576,0.4371172076082125,0.443054345271032,0.6611484005941575 DGM_PDBbind0D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6655853516013849,0.4884783169365537,0.4768266531852068,2.716448143336347,1.29524747639392,1.648165083763258 DGM_PDBbind1D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6616134720685233,0.4760514005042608,0.465349083935074,2.7772597147478297,1.309049959958126,1.6665112405104952 DGM_PDBbind2D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6547720831157129,0.4674109621848075,0.4473761042353416,2.911396055546887,1.350302681084595,1.7062813529857517 @@ -46,6 +46,16 @@ GVPLM_kiba4D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_200 GVPLM_kiba2D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.6999996243009051,0.5844167249718527,0.5172933123278608,1.1060115689114345,0.8743866055307609,1.051670846278166 GVPLM_kiba4D_nomsaF_aflowE_16B_0.00010990897170411903LR_0.03599877069828837D_2000E_gvpLF_binaryLE,0.7028775926753292,0.5511820968815552,0.5274587057439738,0.5191856014604727,0.464434987594827,0.7205453500373676 GVPLM_kiba3D_nomsaF_aflowE_16B_0.00010990897170411903LR_0.03599877069828837D_2000E_gvpLF_binaryLE,0.6974533040075589,0.5366909353583093,0.5092679820627712,0.5267452909346243,0.4756001094118965,0.7257722031978245 -GVPLM_kiba2D_nomsaF_aflowE_16B_0.00010990897170411903LR_0.03599877069828837D_2000E_gvpLF_binaryLE,0.6914933128392398,0.5535320315157304,0.49916906416324947,0.5167043183823063,0.46841519355949945,0.7188214787986696 +GVPLM_kiba2D_nomsaF_aflowE_16B_0.00010990897170411903LR_0.03599877069828837D_2000E_gvpLF_binaryLE,0.6914933128392398,0.5535320315157304,0.4991690641632494,0.5167043183823063,0.4684151935594994,0.7188214787986696 
GVPLM_kiba0D_nomsaF_aflowE_16B_0.00010990897170411903LR_0.03599877069828837D_2000E_gvpLF_binaryLE,0.667835957156033,0.5118338017508158,0.4512406693094591,0.529732104745498,0.4987863530176246,0.7278269744558098 GVPLM_kiba1D_nomsaF_aflowE_16B_0.00010990897170411903LR_0.03599877069828837D_2000E_gvpLF_binaryLE,0.673244203521782,0.5584520336614822,0.4601543459719329,0.4634166140494674,0.4859944296833685,0.6807470999199831 +DGM_PDBbind0D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.5,,,3.4182907626961336,1.4994075870215695,1.8488620182956144 +DGM_PDBbind1D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.5007066821937181,0.038214548483191,0.0037893479451263,3.4747609352089994,1.5185222442497732,1.864071064956752 +DGM_PDBbind2D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.5000473143129389,0.0065482511633214,0.0077853856938565,3.55397849703766,1.5411690941151477,1.8851998559934329 +DGM_PDBbind3D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.5,,,3.6010645245915747,1.553675811370748,1.8976471022272752 +DGM_PDBbind4D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.5,,,3.472537530411251,1.5178422226480432,1.8634745853945127 +GVPLM_PDBbind0D_nomsaF_aflowE_128B_0.00020048122460779208LR_0.042268679447260635D_2000E_gvpLF_binaryLE,0.5675107531975828,0.2012549479389283,0.1972868580568649,3.521371768804649,1.5340581613048625,1.876531845934049 +GVPLM_PDBbind1D_nomsaF_aflowE_128B_0.00020048122460779208LR_0.042268679447260635D_2000E_gvpLF_binaryLE,0.6351386534675363,0.4162576037033809,0.393776329694548,2.952582892735166,1.3614234164872832,1.7183081483643048 
+GVPLM_PDBbind2D_nomsaF_aflowE_128B_0.00020048122460779208LR_0.042268679447260635D_2000E_gvpLF_binaryLE,0.6073794252279185,0.300904510298725,0.3141960696982953,3.5527159664685373,1.497829202900039,1.8848649730069624 +GVPLM_PDBbind3D_nomsaF_aflowE_128B_0.00020048122460779208LR_0.042268679447260635D_2000E_gvpLF_binaryLE,0.5544194125815396,0.1962498205943712,0.1602389373350445,3.559013089849939,1.5262607894024975,1.886534677616592 +GVPLM_PDBbind4D_nomsaF_aflowE_128B_0.00020048122460779208LR_0.042268679447260635D_2000E_gvpLF_binaryLE,0.6222734472562935,0.38283753210340116,0.3603062880559139,3.0631035175351524,1.389417946315663,1.7501724250870692 diff --git a/results/v113/model_media/model_stats_val.csv b/results/v113/model_media/model_stats_val.csv index 3d6e7cb..c5c5f27 100644 --- a/results/v113/model_media/model_stats_val.csv +++ b/results/v113/model_media/model_stats_val.csv @@ -33,7 +33,7 @@ DGM_kiba3D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_o DGM_kiba1D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7169797416293802,0.6527045139380728,0.5592661699267665,0.3368578002681257,0.3921905937173818,0.5803945212251109 DGM_kiba0D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7128444726133566,0.6469405510839773,0.5484484334329713,0.336774733412883,0.4184534194805188,0.5803229561312244 DGM_kiba2D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7033453012745192,0.6072209463563496,0.5174763874796517,0.4390207180018329,0.4649270940615934,0.6625863853127628 -DGM_kiba4D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7023830999958538,0.639949133525287,0.5151496892582399,0.40105347469958885,0.4409018537119612,0.6332878292684843 
+DGM_kiba4D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7023830999958538,0.639949133525287,0.5151496892582399,0.4010534746995888,0.4409018537119612,0.6332878292684843 DGM_PDBbind0D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.667185125919345,0.4916330475527956,0.4742033800137094,2.573515761007047,1.247796092498501,1.604218115159858 DGM_PDBbind1D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6918743579083576,0.5545547118073524,0.5483030524092783,2.511534059876986,1.2495642605282011,1.5847820228274254 DGM_PDBbind2D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6664571269002308,0.5086157319827234,0.4766146235124788,2.7894558492182133,1.3283703922279295,1.6701664136301548 @@ -49,3 +49,13 @@ GVPLM_kiba3D_nomsaF_aflowE_16B_0.00010990897170411903LR_0.03599877069828837D_200 GVPLM_kiba2D_nomsaF_aflowE_16B_0.00010990897170411903LR_0.03599877069828837D_2000E_gvpLF_binaryLE,0.7926024955332317,0.7374112888096318,0.7063431040772125,0.32468252989801,0.3579667507508575,0.5698092048203591 GVPLM_kiba0D_nomsaF_aflowE_16B_0.00010990897170411903LR_0.03599877069828837D_2000E_gvpLF_binaryLE,0.6669854507035494,0.5525438589938947,0.4355859541059588,0.4115128362890195,0.4456888421443692,0.6414926626930503 GVPLM_kiba1D_nomsaF_aflowE_16B_0.00010990897170411903LR_0.03599877069828837D_2000E_gvpLF_binaryLE,0.6989377849970573,0.6478510815311954,0.515693387777559,0.3561253610107396,0.4216625264430977,0.5967623991261007 +DGM_PDBbind0D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.5125715405607719,0.2793381742071084,0.1926536143404282,3.346598030093213,1.4805904418459894,1.829370938353732 +DGM_PDBbind1D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.5107971942443016,0.0312432003865885,0.0308374960271485,3.5631682917172403,1.5458264870189105,1.887635635316636 
+DGM_PDBbind2D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.5277274707626068,0.358671071267003,0.2830406326274937,3.6322764505727583,1.5248606967243108,1.905853208033808 +DGM_PDBbind3D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.51955000029021,0.3102068840467967,0.2386819219389774,3.410729951552425,1.478064934361954,1.8468161661498488 +DGM_PDBbind4D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.5275270713061986,0.3631722071531414,0.2847957039355699,3.1869078443560386,1.4534702110049291,1.7851912626819677 +GVPLM_PDBbind0D_nomsaF_aflowE_128B_0.00020048122460779208LR_0.042268679447260635D_2000E_gvpLF_binaryLE,0.5596497975350556,0.2543017795335043,0.1766358807164376,3.113015575178155,1.445382858916236,1.7643739896003214 +GVPLM_PDBbind1D_nomsaF_aflowE_128B_0.00020048122460779208LR_0.042268679447260635D_2000E_gvpLF_binaryLE,0.6600679137627058,0.4711013339427787,0.4650543790596504,2.884138926557529,1.3550213658092014,1.698275279970102 +GVPLM_PDBbind2D_nomsaF_aflowE_128B_0.00020048122460779208LR_0.042268679447260635D_2000E_gvpLF_binaryLE,0.6141731947137768,0.3644634107975009,0.3328003860926398,3.1989141378067982,1.408361079518094,1.788550848538223 +GVPLM_PDBbind3D_nomsaF_aflowE_128B_0.00020048122460779208LR_0.042268679447260635D_2000E_gvpLF_binaryLE,0.5536881344717306,0.192581089791613,0.1606937190574807,3.4912679264408366,1.4810306765899366,1.8684934911422184 +GVPLM_PDBbind4D_nomsaF_aflowE_128B_0.00020048122460779208LR_0.042268679447260635D_2000E_gvpLF_binaryLE,0.628428758706968,0.3807074065686098,0.3759327635419598,2.9551973425576046,1.3590244197523815,1.7190687428249065 diff --git a/src/data_prep/datasets.py b/src/data_prep/datasets.py index 498ead7..a10a307 100644 --- a/src/data_prep/datasets.py +++ b/src/data_prep/datasets.py @@ -14,6 +14,7 @@ import numpy as np import pandas as pd from tqdm import tqdm +from multiprocessing 
@staticmethod
def process_protein_multiprocessing(args):
    """
    Pool worker: check the conformation files of a single protein.

    Args:
        args: tuple ``(pid, seq, af_conf_dir, is_pdbbind, files)``
            - pid: protein id
            - seq: expected protein sequence
            - af_conf_dir: directory holding the conformation ``.pdb`` files
            - is_pdbbind: if True the single file ``<pid>.pdb`` is used,
              otherwise every name in ``files`` that starts with ``pid``
            - files: file names inside ``af_conf_dir`` (only read when
              ``is_pdbbind`` is False)

    Returns:
        - ``(None, None)``: files exist, enough models, sequence matches
        - ``(pid, None)``: no file found, or fewer than 5 models
        - ``(pid, af_seq)``: files exist but the sequence read from the first
          file (``af_seq``) differs from ``seq``
    """
    pid, seq, af_conf_dir, is_pdbbind, files = args
    MIN_MODEL_COUNT = 5

    if is_pdbbind:
        fp = os.path.join(af_conf_dir, f'{pid}.pdb')
        af_confs = [fp] if os.path.exists(fp) else []
    else:
        # NOTE(review): prefix matching would also pick up files of any other
        # pid that merely starts with this one - confirm the naming scheme.
        af_confs = [os.path.join(af_conf_dir, f) for f in files if f.startswith(pid)]

    if not af_confs:
        return pid, None

    # either all models in one pdb file (alphaflow) or spread out across
    # multiple files (AF2 msa subsampling)
    if len(af_confs) > 1:
        model_count = len(af_confs)
    else:
        model_count = Chain.get_model_count(af_confs[0])

    if model_count < MIN_MODEL_COUNT:
        return pid, None

    af_seq = Chain(af_confs[0]).sequence
    if seq != af_seq:
        logging.debug(f'Mismatched sequence for {pid}')
        return pid, af_seq

    return None, None


@staticmethod
def check_missing_confs(df_unique:pd.DataFrame, af_conf_dir:str, is_pdbbind=False) -> tuple[set, dict]:
    """
    Find proteins whose conformation files are missing or mismatched.

    Every ``(prot_id, prot_seq)`` row of ``df_unique`` is checked in a
    process pool by ``BaseDataset.process_protein_multiprocessing``.

    Args:
        df_unique: dataframe with ``prot_id`` and ``prot_seq`` columns
        af_conf_dir: directory holding the conformation ``.pdb`` files
        is_pdbbind: forwarded to the worker; when False the directory is
            listed once here and the ``.pdb`` names are shared with all tasks

    Returns:
        ``(missing, mismatched)`` where ``missing`` is the set of pids without
        usable files and ``mismatched`` maps pid -> sequence found in the file.
    """
    logging.debug(f'Getting af_confs from {af_conf_dir}')

    missing = set()
    mismatched = {}

    proteins = list(df_unique[['prot_id', 'prot_seq']].itertuples(index=False, name=None))
    if not proteins:
        # nothing to check: avoid listing the directory and spawning workers
        return missing, mismatched

    files = None
    if not is_pdbbind:
        files = [f for f in os.listdir(af_conf_dir) if f.endswith('.pdb')]

    tasks = [(pid, seq, af_conf_dir, is_pdbbind, files) for pid, seq in proteins]

    # never start more workers than there are tasks
    with Pool(processes=min(cpu_count(), len(tasks))) as pool:
        results = pool.imap_unordered(BaseDataset.process_protein_multiprocessing, tasks)
        for pid, af_seq in tqdm(results,
                        desc='Filtering out proteins with missing PDB files for multiple confirmations',
                        total=len(tasks)):
            if af_seq is not None:  # sequence in the file differs from the dataframe
                mismatched[pid] = af_seq
            elif pid is not None:  # just pid -> missing af files
                missing.add(pid)

    return missing, mismatched
missing_conf.add(pid) - continue + missing, mismatched = self.check_missing_confs(df_unique, self.af_conf_dir, self.__class__ is PDBbindDataset) - if len(missing_conf) > 0: - filtered_df = df[~df.prot_id.isin(missing_conf)] - logging.warning(f'{len(missing_conf)} mismatched or missing pids') + if len(missing) > 0: + filtered_df = df[~df.prot_id.isin(missing)] + logging.warning(f'{len(missing)} missing pids') else: filtered_df = df + if len(mismatched) > 0: + filtered_df = filtered_df[~filtered_df.prot_id.isin(mismatched)] + # adjust all in dataframe to the ones with the correct pid + logging.warning(f'{len(mismatched)} mismatched pids') + + logging.debug(f'Number of codes: {len(filtered_df)}/{len(df)}') # we are done filtering if ligand doesnt need filtering if not (self.ligand_edge in cfg.OPT_REQUIRES_SDF or self.ligand_feature in cfg.OPT_REQUIRES_SDF): return filtered_df - + + ########### + # filter ligands # removing rows with ligands that have missing sdf files: unique_lig = filtered_df[['lig_id']].drop_duplicates() missing = set() @@ -526,13 +571,13 @@ def file_real(fp): self.df.to_csv(self.processed_paths[3]) logging.info('Created cleaned_XY.csv file') + ###### Get Protein Graphs ###### + processed_prots = self._create_protein_graphs(self.df, self.pro_feat_opt, self.pro_edge_opt) + ###### Get Ligand Graphs ###### processed_ligs = self._create_ligand_graphs(self.df, self.ligand_feature, self.ligand_edge) - ###### Get Protein Graphs ###### - processed_prots = self._create_protein_graphs(self.df, self.pro_feat_opt, self.pro_edge_opt) - ###### Save ###### logging.info('Saving...') torch.save(processed_prots, self.processed_paths[1]) @@ -569,12 +614,12 @@ def __init__(self, save_root=f'{cfg.DATA_ROOT}/PDBbindDataset', aln_dir=aln_dir, cmap_threshold=cmap_threshold, feature_opt=feature_opt, *args, **kwargs) - def af_conf_files(self, pid) -> list[str]|str: - if self.df is not None and pid in self.df.index: + def af_conf_files(self, pid, map_to_pid=True) -> 
list[str]|str: + if self.df is not None and pid in self.df.index and map_to_pid: pid = self.df.loc[pid]['prot_id'] if self.alphaflow: - fp = f'{self.af_conf_dir}/{pid}.pdb' + fp = os.path.join(self.af_conf_dir, f'{pid}.pdb') fp = fp if os.path.exists(fp) else None return fp @@ -905,12 +950,12 @@ def pre_process(self): no_aln = [c for c in codes if (not Processor.check_aln_lines(self.aln_p(c)))] # filters out those that do not have aln file - print(f'Number of codes with invalid aln files: {len(no_aln)} / {len(codes)}') + print(f'Number of codes with valid aln files: {len(codes)-len(no_aln)} / {len(codes)}') # Checking that contact maps are present for each code: # (Created by psconsc4) no_cmap = [c for c in codes if not os.path.isfile(self.cmap_p(c))] - print(f'Number of codes without cmap files: {len(no_cmap)} out of {len(codes)}') + print(f'Number of codes with cmap files: {len(codes)-len(no_cmap)} / {len(codes)}') # Checking that structure and af_confs files are present if required: no_confs = [] @@ -931,7 +976,7 @@ def pre_process(self): (len(self.af_conf_files(c)) < 5))] # only if not for foldseek # WARNING: TEMPORARY FIX FOR DAVIS (TESK1 highQ structure is mismatched...) 
- no_confs.append('TESK1') + if not self.alphaflow: no_confs.append('TESK1') logging.warning(f'Number of codes missing {"aflow" if self.alphaflow else "af2"} ' + \ f'conformations: {len(no_confs)} / {len(codes)}') diff --git a/src/utils/alphaflow.py b/src/utils/alphaflow.py index b9b1a19..481f43e 100644 --- a/src/utils/alphaflow.py +++ b/src/utils/alphaflow.py @@ -8,7 +8,7 @@ #%% from src.data_prep.datasets import BaseDataset import pandas as pd -csv_p = "/cluster/home/t122995uhn/projects/data/PDBbindDataset/nomsa_ring3_original_binary/full/XY.csv" +csv_p = "/cluster/home/t122995uhn/projects/data/PDBbindDataset/nomsa_binary_original_binary/full/XY.csv" df = pd.read_csv(csv_p, index_col=0) df_unique = BaseDataset.get_unique_prots(df) @@ -30,25 +30,27 @@ # %% files are .pdb with 50 "models" in each +created = {} for file in tqdm(os.listdir(alphaflow_dir)): if not file.endswith('.pdb'): continue code, _ = os.path.splitext(file) - pid = df_unique.loc[code].prot_id + pid = df.loc[code].prot_id src, dst = f"{alphaflow_dir}/{file}", f"{ln_dir}/{pid}.pdb" if not os.path.exists(dst): + created[src] = dst os.symlink(src,dst) # %% RUN RING3 -# %% Run RING3 on finished confirmations from AlphaFlow -from src.utils.residue import Ring3Runner +# # %% Run RING3 on finished confirmations from AlphaFlow +# from src.utils.residue import Ring3Runner -files = [os.path.join(ln_dir, f) for f in \ - os.listdir(ln_dir) if f.endswith('.pdb')] +# files = [os.path.join(ln_dir, f) for f in \ +# os.listdir(ln_dir) if f.endswith('.pdb')] -Ring3Runner.run_multiprocess(pdb_fps=files) +# Ring3Runner.run_multiprocess(pdb_fps=files) # %% checking the number of models in each file, flagging any issues: diff --git a/src/utils/config.py b/src/utils/config.py index 2d637de..a835978 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -117,15 +117,15 @@ class LIG_FEAT_OPT(StringEnum): elif 'graham' in DOMAIN_NAME: CLUSTER = 'graham' SLURM_CONSTRAINT = 'cascade,v100' - DATA_ROOT = 
os.path.abspath(Path.home() / 'scratch' / 'data' ) + DATA_ROOT = os.path.abspath(Path.home() / 'scratch' / 'data') elif 'cedar' in DOMAIN_NAME: CLUSTER = 'cedar' SLURM_GPU_NAME = 'v100l' - DATA_ROOT = os.path.abspath(Path.home() / 'scratch' / 'data' ) + DATA_ROOT = os.path.abspath(Path.home() / 'scratch' / 'data') elif 'narval' in DOMAIN_NAME: CLUSTER = 'narval' SLURM_GPU_NAME = 'a100' - DATA_ROOT = os.path.abspath(Path.home() / 'scratch' / 'data' ) + DATA_ROOT = os.path.abspath(Path.home() / 'scratch' / 'data') # bin paths FOLDSEEK_BIN = f'{Path.home()}/lib/foldseek/bin/foldseek'