diff --git a/playground.py b/playground.py
index 1aa4c38..d890a3c 100644
--- a/playground.py
+++ b/playground.py
@@ -8,7 +8,7 @@ from multiprocessing import Pool, cpu_count
 
 from src.data_prep.datasets import BaseDataset
 
-df = pd.read_csv("/cluster/home/t122995uhn/projects/MutDTA/df_base.csv", index_col=0)
+# df = pd.read_csv("/cluster/home/t122995uhn/projects/MutDTA/df_base.csv", index_col=0)
 df = pd.read_csv("/cluster/home/t122995uhn/projects/MutDTA/df_base_filtered.csv", index_col=0)
 
 logging.getLogger().setLevel(logging.DEBUG)
@@ -54,6 +54,8 @@ def process_protein_multiprocessing(args):
         #     return pid, af_seq
         return pid, matching_code
 
+    if matching_code != code:
+        return None, matching_code
     return None, None
 
 #%% check_missing_confs method
@@ -131,17 +133,18 @@ def process_protein_multiprocessing(args):
 ########################################################################
 ########################## BUILD DATASETS ##############################
 ########################################################################
+import os
 from src.data_prep.init_dataset import create_datasets
 from src import cfg
 import logging
 cfg.logger.setLevel(logging.DEBUG)
 
-splits = '/cluster/home/t122995uhn/projects/MutDTA/splits/pdbbind/'
-create_datasets([cfg.DATA_OPT.PDBbind, cfg.DATA_OPT.davis],
+splits = '/cluster/home/t122995uhn/projects/MutDTA/splits/davis/'
+create_datasets(cfg.DATA_OPT.davis,
                 feat_opt=cfg.PRO_FEAT_OPT.nomsa,
-                edge_opt=cfg.PRO_EDGE_OPT.aflow,
-                ligand_features=cfg.LIG_FEAT_OPT.original, #[cfg.LIG_FEAT_OPT.original, cfg.LIG_FEAT_OPT.gvp],
-                ligand_edges=cfg.LIG_EDGE_OPT.binary, overwrite=False,
+                edge_opt=[cfg.PRO_EDGE_OPT.aflow, cfg.PRO_EDGE_OPT.binary],
+                ligand_features=[cfg.LIG_FEAT_OPT.original, cfg.LIG_FEAT_OPT.gvp],
+                ligand_edges=cfg.LIG_EDGE_OPT.binary, overwrite=True,
                 k_folds=5,
                 test_prots_csv=f'{splits}/test.csv',
                 val_prots_csv=[f'{splits}/val{i}.csv' for i in range(5)],
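# Illustration only (not from the repo): the playground block above now builds
# just the davis dataset from the davis splits. A minimal sketch that checks
# the expected split csvs exist before calling create_datasets; the path below
# is the hypothetical cluster path used in the diff -- adjust to your setup.
import os

splits = '/cluster/home/t122995uhn/projects/MutDTA/splits/davis/'
expected = ['test.csv'] + [f'val{i}.csv' for i in range(5)]
missing_csvs = [f for f in expected if not os.path.exists(os.path.join(splits, f))]
if missing_csvs:
    raise FileNotFoundError(f'missing split csvs under {splits}: {missing_csvs}')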
diff --git a/src/data_prep/datasets.py b/src/data_prep/datasets.py
index eab2a0d..a10a307 100644
--- a/src/data_prep/datasets.py
+++ b/src/data_prep/datasets.py
@@ -315,25 +315,18 @@ def process_protein_multiprocessing(args):
         """
         Checks if protein has conf file and correct sequence, returns:
             - None, None - if it has a conf file and is correct
-            - pid, None - is missing a conf file
-            - pid, code_of_correct_seq - has the correct number of conf files but is not the correct sequence.
+            - pid, None - is missing a conf file or has fewer than MIN_MODEL_COUNT models
+            - pid, af_seq - has enough conf files, but the sequence does not match
+              (af_seq is the sequence parsed from the first conf file)
         """
-        group_codes, code, pid, seq, af_conf_dir, is_pdbbind, files = args
+        pid, seq, af_conf_dir, is_pdbbind, files = args
         MIN_MODEL_COUNT = 5
 
-        correct_seq = False
-        matching_code = None
         af_confs = []
         if is_pdbbind:
-            for c in group_codes:
-                af_fp = os.path.join(af_conf_dir, f'{c}.pdb')
-                if os.path.exists(af_fp):
-                    af_confs = [af_fp]
-                    matching_code = code
-                    if Chain(af_fp).sequence == seq:
-                        correct_seq = True
-                        break
-
+            fp = os.path.join(af_conf_dir, f'{pid}.pdb')
+            if os.path.exists(fp):
+                af_confs = [fp]
         else:
             af_confs = [os.path.join(af_conf_dir, f) for f in files if f.startswith(pid)]
@@ -345,20 +338,17 @@
 
         if model_count < MIN_MODEL_COUNT:
             return pid, None
-        elif not correct_seq: # final check
-            af_seq = Chain(af_confs[0]).sequence
-            if seq != af_seq:
-                logging.debug(f'Mismatched sequence for {pid}')
-            # if matching_code == code: # something wrong here -> incorrect seq but for the right code?
-            #     return pid, af_seq
-            return pid, matching_code
+
+        af_seq = Chain(af_confs[0]).sequence
+        if seq != af_seq:
+            logging.debug(f'Mismatched sequence for {pid}')
+            return pid, af_seq
+
         return None, None
 
     @staticmethod
-    def check_missing_confs(df:pd.DataFrame, af_conf_dir:str, is_pdbbind=False):
+    def check_missing_confs(df_unique:pd.DataFrame, af_conf_dir:str, is_pdbbind=False):
         logging.debug(f'Getting af_confs from {af_conf_dir}')
-        df_unique:pd.DataFrame = df.drop_duplicates('prot_id')
-        df_pid_groups = df.groupby(['prot_id']).groups
 
         missing = set()
         mismatched = {}
@@ -368,14 +358,14 @@
             files = [f for f in os.listdir(af_conf_dir) if f.endswith('.pdb')]
 
         with Pool(processes=cpu_count()) as pool:
-            tasks = [(df_pid_groups[pid], code, pid, seq, af_conf_dir, is_pdbbind, files) \
-                        for code, (pid, seq) in df_unique[['prot_id', 'prot_seq']].iterrows()]
+            tasks = [(pid, seq, af_conf_dir, is_pdbbind, files) \
+                        for _, (pid, seq) in df_unique[['prot_id', 'prot_seq']].iterrows()]
 
-            for pid, new_seq in tqdm(pool.imap_unordered(BaseDataset.process_protein_multiprocessing, tasks),
+            for pid, af_seq in tqdm(pool.imap_unordered(BaseDataset.process_protein_multiprocessing, tasks),
                         desc='Filtering out proteins with missing PDB files for multiple confirmations',
                         total=len(tasks)):
-                if new_seq is not None:
-                    mismatched[pid] = new_seq
+                if af_seq is not None:
+                    mismatched[pid] = af_seq
                 elif pid is not None: # just pid -> missing af files
                     missing.add(pid)
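# Illustration only (not from the repo): the simplified worker above now takes
# a flat (pid, seq, af_conf_dir, is_pdbbind, files) tuple instead of the old
# (group_codes, code, ...) form. A self-contained sketch of the same
# Pool/imap_unordered pattern, with a stub worker standing in for
# BaseDataset.process_protein_multiprocessing:
from multiprocessing import Pool, cpu_count

def _check(args):
    pid, seq = args
    if not seq:                  # stand-in for "no conf files found"
        return pid, None         # -> missing
    if seq != seq.upper():       # stand-in for "sequence mismatch"
        return pid, seq.upper()  # -> mismatched, returns the AF sequence
    return None, None            # -> ok

if __name__ == '__main__':
    tasks = [('P1', 'MKT'), ('P2', ''), ('P3', 'mkt')]
    missing, mismatched = set(), {}
    with Pool(processes=cpu_count()) as pool:
        for pid, af_seq in pool.imap_unordered(_check, tasks):
            if af_seq is not None:
                mismatched[pid] = af_seq
            elif pid is not None:
                missing.add(pid)
    print(missing, mismatched)  # {'P2'} {'P3': 'MKT'}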
@@ -404,7 +394,7 @@ def clean_XY(self, df:pd.DataFrame, max_seq_len=None):
         missing = set()
         mismatched = {}
         if self.pro_edge_opt in cfg.OPT_REQUIRES_CONF:
-            missing, mismatched = self.check_missing_confs(df, self.af_conf_dir, self.__class__ is PDBbindDataset)
+            missing, mismatched = self.check_missing_confs(df_unique, self.af_conf_dir, self.__class__ is PDBbindDataset)
 
         if len(missing) > 0:
             filtered_df = df[~df.prot_id.isin(missing)]
@@ -413,7 +403,7 @@
             filtered_df = df
 
         if len(mismatched) > 0:
-            filtered_df = df[~df.prot_id.isin(mismatched)]
+            filtered_df = filtered_df[~filtered_df.prot_id.isin(mismatched)]
             # adjust all in dataframe to the ones with the correct pid
             logging.warning(f'{len(mismatched)} mismatched pids')
@@ -581,13 +571,13 @@ def file_real(fp):
             self.df.to_csv(self.processed_paths[3])
             logging.info('Created cleaned_XY.csv file')
 
+            ###### Get Protein Graphs ######
+            processed_prots = self._create_protein_graphs(self.df, self.pro_feat_opt, self.pro_edge_opt)
+
             ###### Get Ligand Graphs ######
             processed_ligs = self._create_ligand_graphs(self.df, self.ligand_feature, self.ligand_edge)
 
-            ###### Get Protein Graphs ######
-            processed_prots = self._create_protein_graphs(self.df, self.pro_feat_opt, self.pro_edge_opt)
-
             ###### Save ######
             logging.info('Saving...')
             torch.save(processed_prots, self.processed_paths[1])
@@ -624,12 +614,12 @@ def __init__(self, save_root=f'{cfg.DATA_ROOT}/PDBbindDataset',
                          aln_dir=aln_dir, cmap_threshold=cmap_threshold,
                          feature_opt=feature_opt, *args, **kwargs)
 
-    def af_conf_files(self, pid) -> list[str]|str:
-        if self.df is not None and pid in self.df.index:
+    def af_conf_files(self, pid, map_to_pid=True) -> list[str]|str:
+        if self.df is not None and pid in self.df.index and map_to_pid:
             pid = self.df.loc[pid]['prot_id']
 
         if self.alphaflow:
-            fp = f'{self.af_conf_dir}/{pid}.pdb'
+            fp = os.path.join(self.af_conf_dir, f'{pid}.pdb')
             fp = fp if os.path.exists(fp) else None
             return fp
@@ -960,12 +950,12 @@ def pre_process(self):
         no_aln = [c for c in codes if (not Processor.check_aln_lines(self.aln_p(c)))]
 
         # filters out those that do not have aln file
-        print(f'Number of codes with invalid aln files: {len(no_aln)} / {len(codes)}')
+        print(f'Number of codes with valid aln files: {len(codes)-len(no_aln)} / {len(codes)}')
 
         # Checking that contact maps are present for each code:
         # (Created by psconsc4)
         no_cmap = [c for c in codes if not os.path.isfile(self.cmap_p(c))]
-        print(f'Number of codes without cmap files: {len(no_cmap)} out of {len(codes)}')
+        print(f'Number of codes with cmap files: {len(codes)-len(no_cmap)} / {len(codes)}')
 
         # Checking that structure and af_confs files are present if required:
         no_confs = []
@@ -986,7 +976,7 @@
                             (len(self.af_conf_files(c)) < 5))] # only if not for foldseek
 
             # WARNING: TEMPORARY FIX FOR DAVIS (TESK1 highQ structure is mismatched...)
-            no_confs.append('TESK1')
+            if not self.alphaflow: no_confs.append('TESK1')
 
         logging.warning(f'Number of codes missing {"aflow" if self.alphaflow else "af2"} ' + \
                         f'conformations: {len(no_confs)} / {len(codes)}')
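# Illustration only (not from the repo): the clean_XY fix chains the second
# filter on filtered_df instead of the original df, so the "missing" filter is
# no longer silently discarded. A toy pandas example of the bug being fixed:
import pandas as pd

df = pd.DataFrame({'prot_id': ['A', 'B', 'C']})
missing, mismatched = {'A'}, {'B': 'SEQ'}

filtered_df = df[~df.prot_id.isin(missing)]                 # drops A
buggy = df[~df.prot_id.isin(mismatched)]                    # old: A sneaks back in
fixed = filtered_df[~filtered_df.prot_id.isin(mismatched)]  # new: only C survives
print(list(buggy.prot_id), list(fixed.prot_id))             # ['A', 'C'] ['C']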
diff --git a/src/data_prep/init_dataset.py b/src/data_prep/init_dataset.py
index b543f9b..7c3df45 100644
--- a/src/data_prep/init_dataset.py
+++ b/src/data_prep/init_dataset.py
@@ -89,7 +89,7 @@ def create_datasets(data_opt:list[str]|str, feat_opt:list[str]|str, edge_opt:lis
         elif data == 'PDBbind':
             if 'af_conf_dir' not in kwargs:
                 if EDGE in cfg.OPT_REQUIRES_AFLOW_CONF:
-                    kwargs['af_conf_dir'] = f'{data_root}/pdbbind/alphaflow_io/out_pdb_MD-distilled/'
+                    kwargs['af_conf_dir'] = f'{data_root}/pdbbind/alphaflow_io/out_pid_ln/'
                 else:
                     kwargs['af_conf_dir'] = f'{data_root}/pdbbind/pdbbind_af2_out/all_ln/'
             dataset = PDBbindDataset(
diff --git a/src/utils/alphaflow.py b/src/utils/alphaflow.py
index b9b1a19..481f43e 100644
--- a/src/utils/alphaflow.py
+++ b/src/utils/alphaflow.py
@@ -8,7 +8,7 @@
 #%%
 from src.data_prep.datasets import BaseDataset
 import pandas as pd
 
-csv_p = "/cluster/home/t122995uhn/projects/data/PDBbindDataset/nomsa_ring3_original_binary/full/XY.csv"
+csv_p = "/cluster/home/t122995uhn/projects/data/PDBbindDataset/nomsa_binary_original_binary/full/XY.csv"
 
 df = pd.read_csv(csv_p, index_col=0)
 df_unique = BaseDataset.get_unique_prots(df)
@@ -30,25 +30,27 @@
 
 # %% files are .pdb with 50 "models" in each
+created = {}
 for file in tqdm(os.listdir(alphaflow_dir)):
     if not file.endswith('.pdb'):
         continue
 
     code, _ = os.path.splitext(file)
-    pid = df_unique.loc[code].prot_id
+    pid = df.loc[code].prot_id
     src, dst = f"{alphaflow_dir}/{file}", f"{ln_dir}/{pid}.pdb"
     if not os.path.exists(dst):
+        created[src] = dst
         os.symlink(src,dst)
 
 # %% RUN RING3
-# %% Run RING3 on finished confirmations from AlphaFlow
-from src.utils.residue import Ring3Runner
+# # %% Run RING3 on finished conformations from AlphaFlow
+# from src.utils.residue import Ring3Runner
 
-files = [os.path.join(ln_dir, f) for f in \
-         os.listdir(ln_dir) if f.endswith('.pdb')]
+# files = [os.path.join(ln_dir, f) for f in \
+#          os.listdir(ln_dir) if f.endswith('.pdb')]
 
-Ring3Runner.run_multiprocess(pdb_fps=files)
+# Ring3Runner.run_multiprocess(pdb_fps=files)
 
 # %% checking the number of models in each file, flagging any issues:
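# Illustration only (not from the repo): a minimal sketch for the "checking the
# number of models" step the last cell introduces. count_models and
# flag_low_model_files are hypothetical helpers; the 5-model floor mirrors
# MIN_MODEL_COUNT from datasets.py.
import os

def count_models(pdb_fp: str) -> int:
    """Counts MODEL records in a (possibly multi-model) PDB file."""
    with open(pdb_fp) as f:
        return sum(1 for line in f if line.startswith('MODEL'))

def flag_low_model_files(ln_dir: str, min_models: int = 5) -> list[str]:
    """Returns the .pdb files in ln_dir with fewer than min_models models."""
    return [f for f in os.listdir(ln_dir)
            if f.endswith('.pdb')
            and count_models(os.path.join(ln_dir, f)) < min_models]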