diff --git a/playground.py b/playground.py
index 60c24f1..1aa4c38 100644
--- a/playground.py
+++ b/playground.py
@@ -1,39 +1,151 @@
-# %%
+#%%
+import os
+import logging
+from tqdm import tqdm
+from concurrent.futures import ThreadPoolExecutor
 import pandas as pd
-train_df = pd.read_csv('/cluster/home/t122995uhn/projects/data/DavisKibaDataset/davis/nomsa_binary_original_binary/full/XY.csv')
-test_df = pd.read_csv('/cluster/home/t122995uhn/projects/MutDTA/splits/davis/test.csv')
-train_df = train_df[~train_df.prot_id.isin(set(test_df.prot_id))]
-
-# %%
+from src.utils.residue import Chain
+from multiprocessing import Pool, cpu_count
+from src.data_prep.datasets import BaseDataset
+
+df = pd.read_csv("/cluster/home/t122995uhn/projects/MutDTA/df_base.csv", index_col=0)
+df = pd.read_csv("/cluster/home/t122995uhn/projects/MutDTA/df_base_filtered.csv", index_col=0)
+logging.getLogger().setLevel(logging.DEBUG)
+
+def process_protein_multiprocessing(args):
+    """
+    Checks if protein has conf file and correct sequence, returns:
+        - None, None - if it has a conf file and is correct
+        - pid, None - is missing a conf file
+        - pid, matching_code - has a conf file but the sequence does not match.
+    """
+    group_codes, code, pid, seq, af_conf_dir, is_pdbbind, files = args
+    MIN_MODEL_COUNT = 5
+
+    correct_seq = False
+    matching_code = None
+    af_confs = []
+    if is_pdbbind:
+        for c in group_codes:
+            af_fp = os.path.join(af_conf_dir, f'{c}.pdb')
+            if os.path.exists(af_fp):
+                af_confs = [af_fp]
+                matching_code = code
+                if Chain(af_fp).sequence == seq:
+                    correct_seq = True
+                    break
+
+    else:
+        af_confs = [os.path.join(af_conf_dir, f) for f in files if f.startswith(pid)]
+
+    if len(af_confs) == 0:
+        return pid, None
+
+    # either all models in one pdb file (alphaflow) or spread out across multiple files (AF2 MSA subsampling)
+    model_count = len(af_confs) if len(af_confs) > 1 else 5  # Chain.get_model_count(af_confs[0])
+
+    if model_count < MIN_MODEL_COUNT:
+        return pid, None
+    elif not correct_seq:  # final check
+        af_seq = Chain(af_confs[0]).sequence
+        if seq != af_seq:
+            logging.debug(f'Mismatched sequence for {pid}')
+            # if matching_code == code: # something wrong here -> incorrect seq but for the right code?
+            #     return pid, af_seq
+            return pid, matching_code
+
+    return None, None
+
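+# NOTE: the cell below is an interactive copy of BaseDataset.check_missing_confs
+# (it mirrors the src/data_prep/datasets.py changes later in this diff) so it can be run on df directly.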
+#%% check_missing_confs method
+af_conf_dir:str = '/cluster/home/t122995uhn/projects/data/pdbbind/alphaflow_io/out_pdb_MD-distilled/'
+is_pdbbind=True
+
+df_unique:pd.DataFrame = df.drop_duplicates('prot_id')
+df_pid_groups = df.groupby(['prot_id']).groups
+
+missing = set()
+mismatched = {}
+# total of 3728 unique proteins with alphaflow confs (named by pdb ID)
+files = None
+if not is_pdbbind:
+    files = [f for f in os.listdir(af_conf_dir) if f.endswith('.pdb')]
+
+with Pool(processes=cpu_count()) as pool:
+    tasks = [(df_pid_groups[pid], code, pid, seq, af_conf_dir, is_pdbbind, files) \
+                for code, (pid, seq) in df_unique[['prot_id', 'prot_seq']].iterrows()]
+
+    for pid, new_seq in tqdm(pool.imap_unordered(process_protein_multiprocessing, tasks),
+                             desc='Filtering out proteins with missing PDB files for multiple conformations',
+                             total=len(tasks)):
+        if new_seq is not None:
+            mismatched[pid] = new_seq
+        elif pid is not None:  # just pid -> missing af files
+            missing.add(pid)
+
+print(len(missing), len(mismatched))
+
+#%% make substitutions for rows
+df = pd.read_csv("/cluster/home/t122995uhn/projects/MutDTA/df_base.csv", index_col=0)
+df_mismatched = pd.DataFrame.from_dict(mismatched, orient='index', columns=['code'])
+df_mismatched_sub = df.loc[df_mismatched['code']][['prot_id', 'prot_seq']].reset_index()
+df_mismatched = df_mismatched.merge(df_mismatched_sub, on='code')
+
+df_mismatched = df_mismatched.merge(df_mismatched_sub, on='code')
+dff = pd.read_csv("/cluster/home/t122995uhn/projects/MutDTA/df_base_filtered.csv")
+dffm = dff.merge(df_mismatched, on='code')
+#%%
+from src.data_prep.datasets import BaseDataset
 import pandas as pd
+csv_p = "/cluster/home/t122995uhn/projects/data/PDBbindDataset/nomsa_binary_original_binary/full/XY.csv"
 
-davis_test_df = pd.read_csv(f"/home/jean/projects/MutDTA/splits/davis/test.csv")
-davis_test_df['gene'] = davis_test_df['prot_id'].str.split('(').str[0]
-
-#%% ONCO KB MERGE
-onco_df = pd.read_csv("../data/oncoKB_DrugGenePairList.csv")
-davis_join_onco = davis_test_df.merge(onco_df.drop_duplicates("gene"), on="gene", how="inner")
+df = pd.read_csv(csv_p, index_col=0)
 
 # %%
-onco_df = pd.read_csv("../data/oncoKB_DrugGenePairList.csv")
-onco_df.merge(davis_test_df.drop_duplicates("gene"), on="gene", how="inner").value_counts("gene")
-
-
-
-
-
-
-
-
+import os
+from tqdm import tqdm
+alphaflow_dir = "/cluster/home/t122995uhn/projects/data/pdbbind/alphaflow_io/out_pdb_MD-distilled/"
+ln_dir = "/cluster/home/t122995uhn/projects/data/pdbbind/alphaflow_io/out_pid_ln/"
+
+os.makedirs(ln_dir, exist_ok=True)
+
+# First, check and remove any broken links in the destination directory
+for link_file in tqdm(os.listdir(ln_dir), desc="Checking for broken links"):
+    ln_p = os.path.join(ln_dir, link_file)
+    if os.path.islink(ln_p) and not os.path.exists(ln_p):
+        print(f"Removing broken link: {ln_p}")
+        os.remove(ln_p)
+
+
+# %% files are .pdb with 50 "models" in each
+for file in tqdm(os.listdir(alphaflow_dir)):
+    if not file.endswith('.pdb'):
+        continue
+
+    code, _ = os.path.splitext(file)
+    pid = df.loc[code].prot_id
+    src, dst = f"{alphaflow_dir}/{file}", f"{ln_dir}/{pid}.pdb"
+    if not os.path.exists(dst):
+        os.symlink(src, dst)
 
-# %%
-from src.train_test.splitting import resplit
+#%%
+########################################################################
+########################## BUILD DATASETS ##############################
+########################################################################
+from src.data_prep.init_dataset import create_datasets
 from src import cfg
+import logging
+cfg.logger.setLevel(logging.DEBUG)
 
-db_p = lambda x: f'{cfg.DATA_ROOT}/DavisKibaDataset/davis/nomsa_{x}_gvp_binary'
-
-db = resplit(dataset=db_p('binary'), split_files=db_p('aflow'), use_train_set=True)
-
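+# Build the PDBbind and davis datasets with aflow protein edges, using the predefined test/val split CSVs below.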
+splits = '/cluster/home/t122995uhn/projects/MutDTA/splits/pdbbind/'
+create_datasets([cfg.DATA_OPT.PDBbind, cfg.DATA_OPT.davis],
+                feat_opt=cfg.PRO_FEAT_OPT.nomsa,
+                edge_opt=cfg.PRO_EDGE_OPT.aflow,
+                ligand_features=cfg.LIG_FEAT_OPT.original, # [cfg.LIG_FEAT_OPT.original, cfg.LIG_FEAT_OPT.gvp]
+                ligand_edges=cfg.LIG_EDGE_OPT.binary, overwrite=False,
+                k_folds=5,
+                test_prots_csv=f'{splits}/test.csv',
+                val_prots_csv=[f'{splits}/val{i}.csv' for i in range(5)],
+                data_root=os.path.abspath('../data/test/'))
 
 # %%
@@ -69,62 +181,3 @@ models=models, metrics=['cindex', 'mse'],
 fig_scale=(10,5), add_stats=True,
 title_postfix=" validation set performance")
 plt.xticks(rotation=45)
-
-#%%
-########################################################################
-########################## BUILD DATASETS ##############################
-########################################################################
-from src.data_prep.init_dataset import create_datasets
-from src import cfg
-import logging
-cfg.logger.setLevel(logging.DEBUG)
-
-splits = '/cluster/home/t122995uhn/projects/MutDTA/splits/pdbbind/'
-create_datasets(cfg.DATA_OPT.PDBbind,
-                feat_opt=cfg.PRO_FEAT_OPT.nomsa,
-                edge_opt=[cfg.PRO_EDGE_OPT.binary, cfg.PRO_EDGE_OPT.aflow],
-                ligand_features=[cfg.LIG_FEAT_OPT.original, cfg.LIG_FEAT_OPT.gvp],
-                ligand_edges=cfg.LIG_EDGE_OPT.binary,
-                k_folds=5,
-                test_prots_csv=f'{splits}/test.csv',
-                val_prots_csv=[f'{splits}/val{i}.csv' for i in range(5)])
-
-# %%
-from src.utils.loader import Loader
-
-db_aflow = Loader.load_dataset('../data/DavisKibaDataset/davis/nomsa_aflow_original_binary/full')
-db = Loader.load_dataset('../data/DavisKibaDataset/davis/nomsa_binary_original_binary/full')
-
-# %%
-# 5-fold cross validation + test set
-import pandas as pd
-from src import cfg
-from src.train_test.splitting import balanced_kfold_split
-from src.utils.loader import Loader
-test_df = pd.read_csv('/cluster/home/t122995uhn/projects/MutDTA/splits/pdbbind_test.csv')
-test_prots = set(test_df.prot_id)
-db = Loader.load_dataset(f'{cfg.DATA_ROOT}/PDBbindDataset/nomsa_binary_original_binary/full/')
-
-train, val, test = balanced_kfold_split(db,
-                k_folds=5, test_split=0.1, val_split=0.1,
-                test_prots=test_prots, random_seed=0, verbose=True
-                )
-
-#%%
-db.save_subset_folds(train, 'train')
-db.save_subset_folds(val, 'val')
-db.save_subset(test, 'test')
-
-#%%
-import shutil, os
-
-src = "/cluster/home/t122995uhn/projects/data/PDBbindDataset/nomsa_binary_original_binary/"
-dst = "/cluster/home/t122995uhn/projects/MutDTA/splits/pdbbind"
-os.makedirs(dst, exist_ok=True)
-
-for i in range(5):
-    sfile = f"{src}/val{i}/XY.csv"
-    dfile = f"{dst}/val{i}.csv"
-    shutil.copyfile(sfile, dfile)
-
-# %%
diff --git a/src/data_prep/datasets.py b/src/data_prep/datasets.py
index ee3e3f9..8e368d4 100644
--- a/src/data_prep/datasets.py
+++ b/src/data_prep/datasets.py
@@ -242,8 +242,6 @@ def get_unique_prots(df, keep_len=False) -> pd.DataFrame:
     def load(self):
         # loading cleaned XY.csv file
-        # if self.df is None: # WARNING: HOT FIX to be compatible with old datasets
-        #     self.df = pd.read_csv(self.processed_paths[3], index_col=0)
         self.df = pd.read_csv(self.processed_paths[3], index_col=0)
         self._indices = self.df.index
@@ -318,18 +316,20 @@ def process_protein_multiprocessing(args):
         Checks if protein has conf file and correct sequence, returns:
             - None, None - if it has a conf file and is correct
             - pid, None - is missing a conf file
-            - pid, seq - has the correct number of conf files but is not the correct sequence.
+            - pid, code_of_correct_seq - has the correct number of conf files but is not the correct sequence.
         """
-        codes, pid, seq, af_conf_dir, is_pdbbind, files = args
+        group_codes, code, pid, seq, af_conf_dir, is_pdbbind, files = args
         MIN_MODEL_COUNT = 5
 
         correct_seq = False
+        matching_code = None
+        af_confs = []
         if is_pdbbind:
-            af_confs = []
-            for code in codes:
-                af_fp = os.path.join(af_conf_dir, f'{code}.pdb')
+            for c in group_codes:
+                af_fp = os.path.join(af_conf_dir, f'{c}.pdb')
                 if os.path.exists(af_fp):
                     af_confs = [af_fp]
+                    matching_code = code
                     if Chain(af_fp).sequence == seq:
                         correct_seq = True
                         break
@@ -349,10 +349,11 @@ def process_protein_multiprocessing(args):
             af_seq = Chain(af_confs[0]).sequence
             if seq != af_seq:
                 logging.debug(f'Mismatched sequence for {pid}')
-                return pid, af_seq
-
+                # if matching_code == code: # something wrong here -> incorrect seq but for the right code?
+                #     return pid, af_seq
+                return pid, matching_code
         return None, None
-    
+
     @staticmethod
     def check_missing_confs(df:pd.DataFrame, af_conf_dir:str, is_pdbbind=False):
         logging.debug(f'Getting af_confs from {af_conf_dir}')
@@ -364,14 +365,13 @@ def check_missing_confs(df:pd.DataFrame, af_conf_dir:str, is_pdbbind=False):
        # total of 3728 unique proteins with alphaflow confs (named by pdb ID)
         files = None
         if not is_pdbbind:
-            logging.debug('Dataset is NOT PDBbind.')
             files = [f for f in os.listdir(af_conf_dir) if f.endswith('.pdb')]
 
         with Pool(processes=cpu_count()) as pool:
-            tasks = [(df_pid_groups[pid], pid, seq, af_conf_dir, is_pdbbind, files) \
-                        for _, (pid, seq) in df_unique[['prot_id', 'prot_seq']].iterrows()]
+            tasks = [(df_pid_groups[pid], code, pid, seq, af_conf_dir, is_pdbbind, files) \
+                        for code, (pid, seq) in df_unique[['prot_id', 'prot_seq']].iterrows()]
 
-            for pid, new_seq in tqdm(pool.imap_unordered(BaseDataset.process_protein_multiprocessing, tasks),
+            for pid, new_seq in tqdm(pool.imap_unordered(process_protein_multiprocessing, tasks),
                        desc='Filtering out proteins with missing PDB files for multiple confirmations',
                        total=len(tasks)):
                 if new_seq is not None:
@@ -413,12 +413,10 @@ def clean_XY(self, df:pd.DataFrame, max_seq_len=None):
            filtered_df = df
            if len(mismatched) > 0:
-               filtered_df = df[~df.prot_id.isin(mismatched)]
+               # adjust all in dataframe to the ones with the correct pid
                logging.warning(f'{len(mismatched)} mismatched pids')
-               # TODO: update sequences and cmaps so feature extraction goes smoothly
-
-
+           logging.debug(f'Number of codes: {len(filtered_df)}/{len(df)}')
            # we are done filtering if ligand doesnt need filtering