refactor(datasets): return code to map to if mismatch seq #125
jyaacoub committed Jul 25, 2024
1 parent fb5bd35 commit 822e277
Showing 2 changed files with 156 additions and 105 deletions.
227 changes: 140 additions & 87 deletions playground.py
@@ -1,39 +1,151 @@
# %%
import os
import logging
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
import pandas as pd
train_df = pd.read_csv('/cluster/home/t122995uhn/projects/data/DavisKibaDataset/davis/nomsa_binary_original_binary/full/XY.csv')
test_df = pd.read_csv('/cluster/home/t122995uhn/projects/MutDTA/splits/davis/test.csv')
train_df = train_df[~train_df.prot_id.isin(set(test_df.prot_id))]
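# sanity check (illustrative, not from the original script): after the filter
# above, train and test should share no prot_ids
assert len(set(train_df.prot_id) & set(test_df.prot_id)) == 0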

# %%
from src.utils.residue import Chain
from multiprocessing import Pool, cpu_count
from src.data_prep.datasets import BaseDataset

# df = pd.read_csv("/cluster/home/t122995uhn/projects/MutDTA/df_base.csv", index_col=0) # superseded by the filtered csv below
df = pd.read_csv("/cluster/home/t122995uhn/projects/MutDTA/df_base_filtered.csv", index_col=0)
logging.getLogger().setLevel(logging.DEBUG)

def process_protein_multiprocessing(args):
    """
    Checks if protein has conf file and correct sequence, returns:
        - (None, None)         - has a conf file and the sequence is correct
        - (pid, None)          - missing a conf file (or too few models)
        - (pid, matching_code) - has a conf file but the sequence does not match;
                                 matching_code is the df code to map to.
    """
    group_codes, code, pid, seq, af_conf_dir, is_pdbbind, files = args
    MIN_MODEL_COUNT = 5

    correct_seq = False
    matching_code = None
    af_confs = []
    if is_pdbbind:
        for c in group_codes:
            af_fp = os.path.join(af_conf_dir, f'{c}.pdb')
            if os.path.exists(af_fp):
                af_confs = [af_fp]
                matching_code = c  # record which code's conf file was found
                if Chain(af_fp).sequence == seq:
                    correct_seq = True
                    break
    else:
        af_confs = [os.path.join(af_conf_dir, f) for f in files if f.startswith(pid)]

    if len(af_confs) == 0:
        return pid, None

    # either all models in one pdb file (alphaflow) or spread out across multiple files (AF2 MSA subsampling)
    model_count = len(af_confs) if len(af_confs) > 1 else 5  # Chain.get_model_count(af_confs[0])

    if model_count < MIN_MODEL_COUNT:
        return pid, None
    elif not correct_seq:  # final check
        af_seq = Chain(af_confs[0]).sequence
        if seq != af_seq:
            logging.debug(f'Mismatched sequence for {pid}')
            # if matching_code == code: # something wrong here -> incorrect seq but for the right code?
            #     return pid, af_seq
            return pid, matching_code

    return None, None
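
#%% illustrative check of the return contract above -- the tuple layout follows
# the unpacking in the function; the codes, pid, seq, and path below are
# hypothetical stand-ins, not real dataset entries.
example_args = (['1a2b', '3c4d'],    # group_codes: all df codes sharing this prot_id
                '1a2b',              # code: representative code from df_unique
                'FAKE_PID',          # pid (hypothetical)
                'MKTAYIAK',          # seq (hypothetical, truncated)
                '/path/to/af_confs', True, None)
ex_pid, ex_code = process_protein_multiprocessing(example_args)
if ex_pid is None:
    print('conf file found and sequence matches')
elif ex_code is None:
    print(f'{ex_pid}: missing conf files (or too few models)')
else:
    print(f'{ex_pid}: sequence mismatch -> remap rows to code {ex_code}')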

#%% check_missing_confs method
af_conf_dir: str = '/cluster/home/t122995uhn/projects/data/pdbbind/alphaflow_io/out_pdb_MD-distilled/'
is_pdbbind = True

df_unique: pd.DataFrame = df.drop_duplicates('prot_id')
df_pid_groups = df.groupby(['prot_id']).groups

missing = set()
mismatched = {}
# total of 3728 unique proteins with alphaflow confs (named by pdb ID)
files = None
if not is_pdbbind:
    files = [f for f in os.listdir(af_conf_dir) if f.endswith('.pdb')]

with Pool(processes=cpu_count()) as pool:
    tasks = [(df_pid_groups[pid], code, pid, seq, af_conf_dir, is_pdbbind, files)
             for code, (pid, seq) in df_unique[['prot_id', 'prot_seq']].iterrows()]

    for pid, new_seq in tqdm(pool.imap_unordered(process_protein_multiprocessing, tasks),
                             desc='Filtering out proteins with missing PDB files for multiple conformations',
                             total=len(tasks)):
        if new_seq is not None:
            mismatched[pid] = new_seq
        elif pid is not None:  # just pid -> missing af files
            missing.add(pid)

print(len(missing), len(mismatched))

#%% make substitutions for rows
df = pd.read_csv("/cluster/home/t122995uhn/projects/MutDTA/df_base.csv", index_col=0)
df_mismatched = pd.DataFrame.from_dict(mismatched, orient='index', columns=['code'])
df_mismatched_sub = df.loc[df_mismatched['code']][['prot_id', 'prot_seq']].reset_index()
df_mismatched = df_mismatched.merge(df_mismatched_sub, on='code')

dff = pd.read_csv("/cluster/home/t122995uhn/projects/MutDTA/df_base_filtered.csv")
dffm = dff.merge(df_mismatched, on='code')
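
#%% toy illustration of the remap above (codes and sequences are hypothetical):
# rows whose prot_id had a sequence mismatch adopt the prot_id/prot_seq of the
# code their conf file actually matches.
toy = pd.DataFrame({'prot_id': ['P1', 'P2'], 'prot_seq': ['AAA', 'BBB']},
                   index=pd.Index(['codeA', 'codeB'], name='code'))
toy_mm = pd.DataFrame.from_dict({'P1': 'codeB'}, orient='index', columns=['code'])
toy_sub = toy.loc[toy_mm['code']][['prot_id', 'prot_seq']].reset_index()
print(toy_mm.merge(toy_sub, on='code'))
#     code prot_id prot_seq
# 0  codeB      P2      BBB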
#%%
from src.data_prep.datasets import BaseDataset
import pandas as pd
csv_p = "/cluster/home/t122995uhn/projects/data/PDBbindDataset/nomsa_binary_original_binary/full/XY.csv"

davis_test_df = pd.read_csv("/home/jean/projects/MutDTA/splits/davis/test.csv")
davis_test_df['gene'] = davis_test_df['prot_id'].str.split('(').str[0]

#%% ONCO KB MERGE
onco_df = pd.read_csv("../data/oncoKB_DrugGenePairList.csv")
davis_join_onco = davis_test_df.merge(onco_df.drop_duplicates("gene"), on="gene", how="inner")
df = pd.read_csv(csv_p, index_col=0)

# %%
onco_df = pd.read_csv("../data/oncoKB_DrugGenePairList.csv")
onco_df.merge(davis_test_df.drop_duplicates("gene"), on="gene", how="inner").value_counts("gene")

# %% create prot_id-named symlinks for the alphaflow pdbs
import os
from tqdm import tqdm
alphaflow_dir = "/cluster/home/t122995uhn/projects/data/pdbbind/alphaflow_io/out_pdb_MD-distilled/"
ln_dir = "/cluster/home/t122995uhn/projects/data/pdbbind/alphaflow_io/out_pid_ln/"

os.makedirs(ln_dir, exist_ok=True)

# First, check and remove any broken links in the destination directory
for link_file in tqdm(os.listdir(ln_dir), desc="Checking for broken links"):
    ln_p = os.path.join(ln_dir, link_file)
    if os.path.islink(ln_p) and not os.path.exists(ln_p):
        print(f"Removing broken link: {ln_p}")
        os.remove(ln_p)


# %% files are .pdb with 50 "models" in each
for file in tqdm(os.listdir(alphaflow_dir)):
    if not file.endswith('.pdb'):
        continue

    code, _ = os.path.splitext(file)
    pid = df.loc[code].prot_id
    src, dst = f"{alphaflow_dir}/{file}", f"{ln_dir}/{pid}.pdb"
    if not os.path.exists(dst):
        os.symlink(src, dst)
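
# quick verification (illustrative): every link in ln_dir should now resolve
n_ok = sum(os.path.exists(os.path.join(ln_dir, f)) for f in os.listdir(ln_dir))
print(f'{n_ok}/{len(os.listdir(ln_dir))} symlinks resolve in {ln_dir}')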

# %%
from src.train_test.splitting import resplit
#%%
########################################################################
########################## BUILD DATASETS ##############################
########################################################################
from src.data_prep.init_dataset import create_datasets
from src import cfg
import logging
cfg.logger.setLevel(logging.DEBUG)

db_p = lambda x: f'{cfg.DATA_ROOT}/DavisKibaDataset/davis/nomsa_{x}_gvp_binary'

db = resplit(dataset=db_p('binary'), split_files=db_p('aflow'), use_train_set=True)

splits = '/cluster/home/t122995uhn/projects/MutDTA/splits/pdbbind/'
create_datasets([cfg.DATA_OPT.PDBbind, cfg.DATA_OPT.davis],
                feat_opt=cfg.PRO_FEAT_OPT.nomsa,
                edge_opt=cfg.PRO_EDGE_OPT.aflow,
                ligand_features=cfg.LIG_FEAT_OPT.original,  # [cfg.LIG_FEAT_OPT.original, cfg.LIG_FEAT_OPT.gvp]
                ligand_edges=cfg.LIG_EDGE_OPT.binary, overwrite=False,
                k_folds=5,
                test_prots_csv=f'{splits}/test.csv',
                val_prots_csv=[f'{splits}/val{i}.csv' for i in range(5)],
                data_root=os.path.abspath('../data/test/'))


# %%
@@ -69,62 +181,3 @@
models=models, metrics=['cindex', 'mse'],
fig_scale=(10,5), add_stats=True, title_postfix=" validation set performance")
plt.xticks(rotation=45)

#%%
########################################################################
########################## BUILD DATASETS ##############################
########################################################################
from src.data_prep.init_dataset import create_datasets
from src import cfg
import logging
cfg.logger.setLevel(logging.DEBUG)

splits = '/cluster/home/t122995uhn/projects/MutDTA/splits/pdbbind/'
create_datasets(cfg.DATA_OPT.PDBbind,
                feat_opt=cfg.PRO_FEAT_OPT.nomsa,
                edge_opt=[cfg.PRO_EDGE_OPT.binary, cfg.PRO_EDGE_OPT.aflow],
                ligand_features=[cfg.LIG_FEAT_OPT.original, cfg.LIG_FEAT_OPT.gvp],
                ligand_edges=cfg.LIG_EDGE_OPT.binary,
                k_folds=5,
                test_prots_csv=f'{splits}/test.csv',
                val_prots_csv=[f'{splits}/val{i}.csv' for i in range(5)])

# %%
from src.utils.loader import Loader

db_aflow = Loader.load_dataset('../data/DavisKibaDataset/davis/nomsa_aflow_original_binary/full')
db = Loader.load_dataset('../data/DavisKibaDataset/davis/nomsa_binary_original_binary/full')

# %%
# 5-fold cross validation + test set
import pandas as pd
from src import cfg
from src.train_test.splitting import balanced_kfold_split
from src.utils.loader import Loader
test_df = pd.read_csv('/cluster/home/t122995uhn/projects/MutDTA/splits/pdbbind_test.csv')
test_prots = set(test_df.prot_id)
db = Loader.load_dataset(f'{cfg.DATA_ROOT}/PDBbindDataset/nomsa_binary_original_binary/full/')

train, val, test = balanced_kfold_split(db,
                                        k_folds=5, test_split=0.1, val_split=0.1,
                                        test_prots=test_prots, random_seed=0, verbose=True)
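
# hedged leakage check -- assumes train/val hold per-fold index lists into db.df
# (suggested by save_subset_folds below); adjust the indexing if the structure differs
for i, fold_idx in enumerate(train):
    overlap = set(db.df.loc[fold_idx].prot_id) & test_prots
    assert not overlap, f'fold {i} shares {len(overlap)} proteins with the test set'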

#%%
db.save_subset_folds(train, 'train')
db.save_subset_folds(val, 'val')
db.save_subset(test, 'test')

#%%
import shutil, os

src = "/cluster/home/t122995uhn/projects/data/PDBbindDataset/nomsa_binary_original_binary/"
dst = "/cluster/home/t122995uhn/projects/MutDTA/splits/pdbbind"
os.makedirs(dst, exist_ok=True)

for i in range(5):
    sfile = f"{src}/val{i}/XY.csv"
    dfile = f"{dst}/val{i}.csv"
    shutil.copyfile(sfile, dfile)

# %%
34 changes: 16 additions & 18 deletions src/data_prep/datasets.py
@@ -242,8 +242,6 @@ def get_unique_prots(df, keep_len=False) -> pd.DataFrame:

    def load(self):
        # loading cleaned XY.csv file
        # if self.df is None: # WARNING: HOT FIX to be compatible with old datasets
        #     self.df = pd.read_csv(self.processed_paths[3], index_col=0)
        self.df = pd.read_csv(self.processed_paths[3], index_col=0)

        self._indices = self.df.index
@@ -318,18 +316,20 @@ def process_protein_multiprocessing(args):
    Checks if protein has conf file and correct sequence, returns:
        - None, None - if it has a conf file and is correct
        - pid, None - is missing a conf file
        - pid, seq - has the correct number of conf files but is not the correct sequence.
        - pid, code_of_correct_seq - has the correct number of conf files but is not the correct sequence.
    """
    codes, pid, seq, af_conf_dir, is_pdbbind, files = args
    group_codes, code, pid, seq, af_conf_dir, is_pdbbind, files = args
    MIN_MODEL_COUNT = 5

    correct_seq = False
    matching_code = None
    af_confs = []
    if is_pdbbind:
        af_confs = []
        for code in codes:
            af_fp = os.path.join(af_conf_dir, f'{code}.pdb')
        for c in group_codes:
            af_fp = os.path.join(af_conf_dir, f'{c}.pdb')
            if os.path.exists(af_fp):
                af_confs = [af_fp]
                matching_code = c  # record which code's conf file was found
                if Chain(af_fp).sequence == seq:
                    correct_seq = True
                    break
@@ -349,10 +349,11 @@ def process_protein_multiprocessing(args):
        af_seq = Chain(af_confs[0]).sequence
        if seq != af_seq:
            logging.debug(f'Mismatched sequence for {pid}')
            return pid, af_seq
            # if matching_code == code: # something wrong here -> incorrect seq but for the right code?
            #     return pid, af_seq
            return pid, matching_code
    return None, None

    @staticmethod
    def check_missing_confs(df:pd.DataFrame, af_conf_dir:str, is_pdbbind=False):
        logging.debug(f'Getting af_confs from {af_conf_dir}')
@@ -364,14 +365,13 @@ def check_missing_confs(df:pd.DataFrame, af_conf_dir:str, is_pdbbind=False):
        # total of 3728 unique proteins with alphaflow confs (named by pdb ID)
        files = None
        if not is_pdbbind:
            logging.debug('Dataset is NOT PDBbind.')
            files = [f for f in os.listdir(af_conf_dir) if f.endswith('.pdb')]

        with Pool(processes=cpu_count()) as pool:
            tasks = [(df_pid_groups[pid], pid, seq, af_conf_dir, is_pdbbind, files)
                     for _, (pid, seq) in df_unique[['prot_id', 'prot_seq']].iterrows()]
            tasks = [(df_pid_groups[pid], code, pid, seq, af_conf_dir, is_pdbbind, files)
                     for code, (pid, seq) in df_unique[['prot_id', 'prot_seq']].iterrows()]

            for pid, new_seq in tqdm(pool.imap_unordered(BaseDataset.process_protein_multiprocessing, tasks),
            for pid, new_seq in tqdm(pool.imap_unordered(process_protein_multiprocessing, tasks),
                                     desc='Filtering out proteins with missing PDB files for multiple conformations',
                                     total=len(tasks)):
@@ -413,12 +413,10 @@ def clean_XY(self, df:pd.DataFrame, max_seq_len=None):
        filtered_df = df

        if len(mismatched) > 0:
            filtered_df = df[~df.prot_id.isin(mismatched)]
            # adjust all in dataframe to the ones with the correct pid
            logging.warning(f'{len(mismatched)} mismatched pids')
            # TODO: update sequences and cmaps so feature extraction goes smoothly

        logging.debug(f'Number of codes: {len(filtered_df)}/{len(df)}')

        # we are done filtering if ligand doesn't need filtering
