revert(datasets): using codes is a bad idea for pdbbind alphaflow file names #125
jyaacoub committed Jul 25, 2024
1 parent 60e8a70 commit bbeecea
Showing 4 changed files with 49 additions and 54 deletions.
15 changes: 9 additions & 6 deletions playground.py
@@ -8,7 +8,7 @@
 from multiprocessing import Pool, cpu_count
 from src.data_prep.datasets import BaseDataset

-df = pd.read_csv("/cluster/home/t122995uhn/projects/MutDTA/df_base.csv", index_col=0)
+# df = pd.read_csv("/cluster/home/t122995uhn/projects/MutDTA/df_base.csv", index_col=0)
 df = pd.read_csv("/cluster/home/t122995uhn/projects/MutDTA/df_base_filtered.csv", index_col=0)
 logging.getLogger().setLevel(logging.DEBUG)

@@ -54,6 +54,8 @@ def process_protein_multiprocessing(args):
         # return pid, af_seq
         return pid, matching_code

+    if matching_code != code:
+        return None, matching_code
     return None, None

 #%% check_missing_confs method
@@ -131,17 +133,18 @@ def process_protein_multiprocessing(args):
 ########################################################################
 ########################## BUILD DATASETS ##############################
 ########################################################################
+import os
 from src.data_prep.init_dataset import create_datasets
 from src import cfg
 import logging
 cfg.logger.setLevel(logging.DEBUG)

-splits = '/cluster/home/t122995uhn/projects/MutDTA/splits/pdbbind/'
-create_datasets([cfg.DATA_OPT.PDBbind, cfg.DATA_OPT.davis],
+splits = '/cluster/home/t122995uhn/projects/MutDTA/splits/davis/'
+create_datasets(cfg.DATA_OPT.davis,
                 feat_opt=cfg.PRO_FEAT_OPT.nomsa,
-                edge_opt=cfg.PRO_EDGE_OPT.aflow,
-                ligand_features=cfg.LIG_FEAT_OPT.original, #[cfg.LIG_FEAT_OPT.original, cfg.LIG_FEAT_OPT.gvp],
-                ligand_edges=cfg.LIG_EDGE_OPT.binary, overwrite=False,
+                edge_opt=[cfg.PRO_EDGE_OPT.aflow, cfg.PRO_EDGE_OPT.binary],
+                ligand_features=[cfg.LIG_FEAT_OPT.original, cfg.LIG_FEAT_OPT.gvp],
+                ligand_edges=cfg.LIG_EDGE_OPT.binary, overwrite=True,
                 k_folds=5,
                 test_prots_csv=f'{splits}/test.csv',
                 val_prots_csv=[f'{splits}/val{i}.csv' for i in range(5)],
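
The create_datasets call above expects a fixed split layout on disk: test.csv plus val0.csv through val4.csv under the splits directory. A minimal sketch of a sanity check for that layout; treating the first CSV column as the protein identifier is an assumption, not something this commit confirms:

    # Sketch: verify the 5-fold split layout assumed by the create_datasets call.
    # Assumption: each split CSV stores protein IDs in its first column.
    import os
    import pandas as pd

    splits = '/cluster/home/t122995uhn/projects/MutDTA/splits/davis/'
    expected = ['test.csv'] + [f'val{i}.csv' for i in range(5)]
    assert all(os.path.exists(os.path.join(splits, f)) for f in expected), 'missing split CSVs'

    test_ids = set(pd.read_csv(os.path.join(splits, 'test.csv'), index_col=0).index)
    for i in range(5):
        val_ids = set(pd.read_csv(os.path.join(splits, f'val{i}.csv'), index_col=0).index)
        assert not (test_ids & val_ids), f'fold {i} overlaps with the test set'
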
70 changes: 30 additions & 40 deletions src/data_prep/datasets.py
@@ -315,25 +315,18 @@ def process_protein_multiprocessing(args):
     """
     Checks if protein has conf file and correct sequence, returns:
         - None, None - if it has a conf file and is correct
-        - pid, None - is missing a conf file
-        - pid, code_of_correct_seq - has the correct number of conf files but is not the correct sequence.
+        - pid, None - is missing a conf file
+        - pid, matching_code - has the correct number of conf files but is not the correct sequence.
+        - None, matching_code - correct seq, # of confs, but under a different file name
     """
-    group_codes, code, pid, seq, af_conf_dir, is_pdbbind, files = args
+    pid, seq, af_conf_dir, is_pdbbind, files = args
     MIN_MODEL_COUNT = 5

-    correct_seq = False
-    matching_code = None
-    af_confs = []
-    if is_pdbbind:
-        for c in group_codes:
-            af_fp = os.path.join(af_conf_dir, f'{c}.pdb')
-            if os.path.exists(af_fp):
-                af_confs = [af_fp]
-                matching_code = code
-                if Chain(af_fp).sequence == seq:
-                    correct_seq = True
-                    break
-
+    fp = os.path.join(af_conf_dir, f'{pid}.pdb')
+    if os.path.exists(fp):
+        af_confs = [os.path.join(af_conf_dir, f'{pid}.pdb')]
+    else:
+        af_confs = [os.path.join(af_conf_dir, f) for f in files if f.startswith(pid)]

@@ -345,20 +338,17 @@

     if model_count < MIN_MODEL_COUNT:
         return pid, None
-    elif not correct_seq: # final check
-        af_seq = Chain(af_confs[0]).sequence
-        if seq != af_seq:
-            logging.debug(f'Mismatched sequence for {pid}')
-            # if matching_code == code: # something wrong here -> incorrect seq but for the right code?
-            # return pid, af_seq
-            return pid, matching_code

+    af_seq = Chain(af_confs[0]).sequence
+    if seq != af_seq:
+        logging.debug(f'Mismatched sequence for {pid}')
+        return pid, af_seq

     return None, None

 @staticmethod
-def check_missing_confs(df:pd.DataFrame, af_conf_dir:str, is_pdbbind=False):
+def check_missing_confs(df_unique:pd.DataFrame, af_conf_dir:str, is_pdbbind=False):
     logging.debug(f'Getting af_confs from {af_conf_dir}')
-    df_unique:pd.DataFrame = df.drop_duplicates('prot_id')
-    df_pid_groups = df.groupby(['prot_id']).groups

     missing = set()
     mismatched = {}
@@ -368,14 +358,14 @@ def check_missing_confs(df:pd.DataFrame, af_conf_dir:str, is_pdbbind=False):
     files = [f for f in os.listdir(af_conf_dir) if f.endswith('.pdb')]

     with Pool(processes=cpu_count()) as pool:
-        tasks = [(df_pid_groups[pid], code, pid, seq, af_conf_dir, is_pdbbind, files) \
-                    for code, (pid, seq) in df_unique[['prot_id', 'prot_seq']].iterrows()]
+        tasks = [(pid, seq, af_conf_dir, is_pdbbind, files) \
+                    for _, (pid, seq) in df_unique[['prot_id', 'prot_seq']].iterrows()]

-        for pid, new_seq in tqdm(pool.imap_unordered(BaseDataset.process_protein_multiprocessing, tasks),
+        for pid, correct_seq in tqdm(pool.imap_unordered(BaseDataset.process_protein_multiprocessing, tasks),
                         desc='Filtering out proteins with missing PDB files for multiple confirmations',
                         total=len(tasks)):
-            if new_seq is not None:
-                mismatched[pid] = new_seq
+            if correct_seq is not None:
+                mismatched[pid] = correct_seq
             elif pid is not None: # just pid -> missing af files
                 missing.add(pid)

@@ -404,7 +394,7 @@ def clean_XY(self, df:pd.DataFrame, max_seq_len=None):
     missing = set()
     mismatched = {}
     if self.pro_edge_opt in cfg.OPT_REQUIRES_CONF:
-        missing, mismatched = self.check_missing_confs(df, self.af_conf_dir, self.__class__ is PDBbindDataset)
+        missing, mismatched = self.check_missing_confs(df_unique, self.af_conf_dir, self.__class__ is PDBbindDataset)

     if len(missing) > 0:
         filtered_df = df[~df.prot_id.isin(missing)]
@@ -413,7 +403,7 @@ def clean_XY(self, df:pd.DataFrame, max_seq_len=None):
         filtered_df = df

     if len(mismatched) > 0:
-        filtered_df = df[~df.prot_id.isin(mismatched)]
+        filtered_df = filtered_df[~filtered_df.prot_id.isin(mismatched)]
         # adjust all in dataframe to the ones with the correct pid
         logging.warning(f'{len(mismatched)} mismatched pids')

@@ -581,13 +571,13 @@ def file_real(fp):
     self.df.to_csv(self.processed_paths[3])
     logging.info('Created cleaned_XY.csv file')

-    ###### Get Protein Graphs ######
-    processed_prots = self._create_protein_graphs(self.df, self.pro_feat_opt, self.pro_edge_opt)
-

     ###### Get Ligand Graphs ######
     processed_ligs = self._create_ligand_graphs(self.df, self.ligand_feature, self.ligand_edge)

+    ###### Get Protein Graphs ######
+    processed_prots = self._create_protein_graphs(self.df, self.pro_feat_opt, self.pro_edge_opt)
+
     ###### Save ######
     logging.info('Saving...')
     torch.save(processed_prots, self.processed_paths[1])
@@ -624,12 +614,12 @@ def __init__(self, save_root=f'{cfg.DATA_ROOT}/PDBbindDataset',
                      aln_dir=aln_dir, cmap_threshold=cmap_threshold,
                      feature_opt=feature_opt, *args, **kwargs)

-    def af_conf_files(self, pid) -> list[str]|str:
-        if self.df is not None and pid in self.df.index:
+    def af_conf_files(self, pid, map_to_pid=True) -> list[str]|str:
+        if self.df is not None and pid in self.df.index and map_to_pid:
             pid = self.df.loc[pid]['prot_id']

         if self.alphaflow:
-            fp = f'{self.af_conf_dir}/{pid}.pdb'
+            fp = os.path.join(self.af_conf_dir, f'{pid}.pdb')
             fp = fp if os.path.exists(fp) else None
             return fp

@@ -960,12 +950,12 @@ def pre_process(self):
     no_aln = [c for c in codes if (not Processor.check_aln_lines(self.aln_p(c)))]

     # filters out those that do not have aln file
-    print(f'Number of codes with invalid aln files: {len(no_aln)} / {len(codes)}')
+    print(f'Number of codes with valid aln files: {len(codes)-len(no_aln)} / {len(codes)}')

     # Checking that contact maps are present for each code:
     # (Created by psconsc4)
     no_cmap = [c for c in codes if not os.path.isfile(self.cmap_p(c))]
-    print(f'Number of codes without cmap files: {len(no_cmap)} out of {len(codes)}')
+    print(f'Number of codes with cmap files: {len(codes)-len(no_cmap)} / {len(codes)}')

     # Checking that structure and af_confs files are present if required:
     no_confs = []
@@ -986,7 +976,7 @@ def pre_process(self):
                 (len(self.af_conf_files(c)) < 5))] # only if not for foldseek

     # WARNING: TEMPORARY FIX FOR DAVIS (TESK1 highQ structure is mismatched...)
-    no_confs.append('TESK1')
+    if not self.alphaflow: no_confs.append('TESK1')

     logging.warning(f'Number of codes missing {"aflow" if self.alphaflow else "af2"} ' + \
                     f'conformations: {len(no_confs)} / {len(codes)}')
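
After this revert, process_protein_multiprocessing looks conformations up purely by prot_id-named files and returns a two-tuple that check_missing_confs buckets into missing and mismatched. A minimal self-contained sketch of that contract, with two stated simplifications: read_seq stands in for Chain(fp).sequence, and the model count is reduced to a file count (the real code counts MODEL records across files):

    # Sketch of the (pid, seq_or_None) return contract after the revert.
    import os

    MIN_MODEL_COUNT = 5

    def check_protein(pid, seq, af_conf_dir, files, read_seq):
        # Prefer a single '{pid}.pdb'; otherwise gather all files starting with pid.
        fp = os.path.join(af_conf_dir, f'{pid}.pdb')
        if os.path.exists(fp):
            af_confs = [fp]
        else:
            af_confs = [os.path.join(af_conf_dir, f) for f in files if f.startswith(pid)]

        # Simplification: the real code counts MODEL records across these files.
        if len(af_confs) < MIN_MODEL_COUNT:
            return pid, None            # missing conformations
        af_seq = read_seq(af_confs[0])  # stand-in for Chain(fp).sequence
        if seq != af_seq:
            return pid, af_seq          # present, but mismatched sequence
        return None, None               # present and correct

    # check_missing_confs buckets the two-tuples like so:
    missing, mismatched = set(), {}
    pid, af_seq = check_protein('P12345', 'MKT', '/tmp/confs', [], lambda fp: 'MKT')
    if af_seq is not None:
        mismatched[pid] = af_seq
    elif pid is not None:
        missing.add(pid)                # here: no conformation files were found
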
2 changes: 1 addition & 1 deletion src/data_prep/init_dataset.py
@@ -89,7 +89,7 @@ def create_datasets(data_opt:list[str]|str, feat_opt:list[str]|str, edge_opt:lis
     elif data == 'PDBbind':
         if 'af_conf_dir' not in kwargs:
             if EDGE in cfg.OPT_REQUIRES_AFLOW_CONF:
-                kwargs['af_conf_dir'] = f'{data_root}/pdbbind/alphaflow_io/out_pdb_MD-distilled/'
+                kwargs['af_conf_dir'] = f'{data_root}/pdbbind/alphaflow_io/out_pid_ln/'
             else:
                 kwargs['af_conf_dir'] = f'{data_root}/pdbbind/pdbbind_af2_out/all_ln/'
         dataset = PDBbindDataset(
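
The changed line above points AlphaFlow-based PDBbind builds at pid-named symlinks instead of the MD-distilled output directory. A sketch of the surrounding resolution logic, assuming (not confirmed by this diff) that 'aflow' is among the edge options in cfg.OPT_REQUIRES_AFLOW_CONF:

    # Sketch of create_datasets' af_conf_dir resolution for PDBbind (simplified).
    def resolve_af_conf_dir(edge: str, data_root: str) -> str:
        OPT_REQUIRES_AFLOW_CONF = {'aflow'}  # assumed stand-in for cfg's set
        if edge in OPT_REQUIRES_AFLOW_CONF:
            # pid-named symlinks, matching the revert to prot_id-based file names
            return f'{data_root}/pdbbind/alphaflow_io/out_pid_ln/'
        return f'{data_root}/pdbbind/pdbbind_af2_out/all_ln/'

    assert resolve_af_conf_dir('aflow', '/data').endswith('out_pid_ln/')
    assert resolve_af_conf_dir('binary', '/data').endswith('all_ln/')
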
16 changes: 9 additions & 7 deletions src/utils/alphaflow.py
@@ -8,7 +8,7 @@
 #%%
 from src.data_prep.datasets import BaseDataset
 import pandas as pd
-csv_p = "/cluster/home/t122995uhn/projects/data/PDBbindDataset/nomsa_ring3_original_binary/full/XY.csv"
+csv_p = "/cluster/home/t122995uhn/projects/data/PDBbindDataset/nomsa_binary_original_binary/full/XY.csv"

 df = pd.read_csv(csv_p, index_col=0)
 df_unique = BaseDataset.get_unique_prots(df)
@@ -30,25 +30,27 @@


 # %% files are .pdb with 50 "models" in each
+created = {}
 for file in tqdm(os.listdir(alphaflow_dir)):
     if not file.endswith('.pdb'):
         continue

     code, _ = os.path.splitext(file)
-    pid = df_unique.loc[code].prot_id
+    pid = df.loc[code].prot_id
     src, dst = f"{alphaflow_dir}/{file}", f"{ln_dir}/{pid}.pdb"
     if not os.path.exists(dst):
+        created[src] = dst
         os.symlink(src,dst)


 # %% RUN RING3
-# %% Run RING3 on finished confirmations from AlphaFlow
-from src.utils.residue import Ring3Runner
+# # %% Run RING3 on finished confirmations from AlphaFlow
+# from src.utils.residue import Ring3Runner

-files = [os.path.join(ln_dir, f) for f in \
-            os.listdir(ln_dir) if f.endswith('.pdb')]
+# files = [os.path.join(ln_dir, f) for f in \
+#             os.listdir(ln_dir) if f.endswith('.pdb')]

-Ring3Runner.run_multiprocess(pdb_fps=files)
+# Ring3Runner.run_multiprocess(pdb_fps=files)


 # %% checking the number of models in each file, flagging any issues:
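
The trailing cell above checks the number of models in each file; the earlier comment in this script says each .pdb carries 50 models. A minimal sketch of such a check, where the ln_dir path is a hypothetical stand-in and the hard 50-model expectation comes only from that comment:

    # Sketch: flag symlinked .pdb files whose MODEL record count is not 50.
    import os

    EXPECTED_MODELS = 50
    ln_dir = '/path/to/alphaflow_io/out_pid_ln'  # hypothetical symlink directory

    def count_models(pdb_fp: str) -> int:
        # Each AlphaFlow conformation is delimited by a PDB 'MODEL' record.
        with open(pdb_fp) as f:
            return sum(1 for line in f if line.startswith('MODEL'))

    flagged = {}
    for fname in os.listdir(ln_dir):
        if fname.endswith('.pdb'):
            n = count_models(os.path.join(ln_dir, fname))
            if n != EXPECTED_MODELS:
                flagged[fname] = n
    print(f'{len(flagged)} file(s) with unexpected model counts: {flagged}')
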
