Missing aflows #126

Merged
merged 6 commits into from Jul 29, 2024

Changes from 3 commits
16,266 changes: 16,266 additions & 0 deletions df_base.csv

Large diffs are not rendered by default.

16,266 changes: 16,266 additions & 0 deletions df_base_filtered.csv

Large diffs are not rendered by default.

16,266 changes: 16,266 additions & 0 deletions df_iteractive.csv

Large diffs are not rendered by default.

227 changes: 140 additions & 87 deletions playground.py
@@ -1,39 +1,151 @@
# %%
import os
import logging
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
import pandas as pd
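# build a train set that excludes any protein appearing in the davis test split (avoids protein leakage)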
train_df = pd.read_csv('/cluster/home/t122995uhn/projects/data/DavisKibaDataset/davis/nomsa_binary_original_binary/full/XY.csv')
test_df = pd.read_csv('/cluster/home/t122995uhn/projects/MutDTA/splits/davis/test.csv')
train_df = train_df[~train_df.prot_id.isin(set(test_df.prot_id))]

# %%
from src.utils.residue import Chain
from multiprocessing import Pool, cpu_count
from src.data_prep.datasets import BaseDataset

df = pd.read_csv("/cluster/home/t122995uhn/projects/MutDTA/df_base.csv", index_col=0)
df = pd.read_csv("/cluster/home/t122995uhn/projects/MutDTA/df_base_filtered.csv", index_col=0)
logging.getLogger().setLevel(logging.DEBUG)

def process_protein_multiprocessing(args):
    """
    Checks if a protein has a conf file with the correct sequence. Returns:
        - None, None          - has a conf file and the sequence is correct
        - pid, None           - missing a conf file (or too few models)
        - pid, matching_code  - has a conf file but the sequence does not match
    """
    group_codes, code, pid, seq, af_conf_dir, is_pdbbind, files = args
    MIN_MODEL_COUNT = 5

    correct_seq = False
    matching_code = None
    af_confs = []
    if is_pdbbind:
        for c in group_codes:
            af_fp = os.path.join(af_conf_dir, f'{c}.pdb')
            if os.path.exists(af_fp):
                af_confs = [af_fp]
                matching_code = code
                if Chain(af_fp).sequence == seq:
                    correct_seq = True
                    break
    else:
        af_confs = [os.path.join(af_conf_dir, f) for f in files if f.startswith(pid)]

    if len(af_confs) == 0:
        return pid, None

    # either all models are in one pdb file (alphaflow) or spread across multiple files (AF2 MSA subsampling)
    model_count = len(af_confs) if len(af_confs) > 1 else 5  # Chain.get_model_count(af_confs[0])

    if model_count < MIN_MODEL_COUNT:
        return pid, None
    elif not correct_seq:  # final check
        af_seq = Chain(af_confs[0]).sequence
        if seq != af_seq:
            logging.debug(f'Mismatched sequence for {pid}')
            # if matching_code == code: # something wrong here -> incorrect seq but for the right code?
            #     return pid, af_seq
            return pid, matching_code

    return None, None
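
# Illustrative sketch (not from this PR): roughly what the commented-out
# Chain.get_model_count call above presumably does -- count MODEL records
# in a multi-model PDB file. The name and details here are assumptions;
# the real implementation lives in src.utils.residue.Chain.
def _count_pdb_models_sketch(pdb_fp: str) -> int:
    with open(pdb_fp) as f:
        n_models = sum(1 for line in f if line.startswith('MODEL'))
    return n_models or 1  # files without explicit MODEL records hold a single structure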

#%% check_missing_confs method
af_conf_dir: str = '/cluster/home/t122995uhn/projects/data/pdbbind/alphaflow_io/out_pdb_MD-distilled/'
is_pdbbind = True

df_unique: pd.DataFrame = df.drop_duplicates('prot_id')
df_pid_groups = df.groupby(['prot_id']).groups

missing = set()
mismatched = {}
# total of 3728 unique proteins with alphaflow confs (named by pdb ID)
files = None
if not is_pdbbind:
files = [f for f in os.listdir(af_conf_dir) if f.endswith('.pdb')]

with Pool(processes=cpu_count()) as pool:
    tasks = [(df_pid_groups[pid], code, pid, seq, af_conf_dir, is_pdbbind, files)
             for code, (pid, seq) in df_unique[['prot_id', 'prot_seq']].iterrows()]

    for pid, new_seq in tqdm(pool.imap_unordered(process_protein_multiprocessing, tasks),
                             desc='Filtering out proteins with missing PDB files for multiple conformations',
                             total=len(tasks)):
        if new_seq is not None:
            mismatched[pid] = new_seq
        elif pid is not None:  # just pid -> missing af files
            missing.add(pid)

print(len(missing),len(mismatched))
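# missing    -> prot_ids with no alphaflow confs at all
# mismatched -> prot_ids whose conf sequence differs, mapped to the pdb code that was found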

#%% make substitutions for rows
df = pd.read_csv("/cluster/home/t122995uhn/projects/MutDTA/df_base.csv", index_col=0)
df_mismatched = pd.DataFrame.from_dict(mismatched, orient='index', columns=['code'])
df_mismatched_sub = df.loc[df_mismatched['code']][['prot_id', 'prot_seq']].reset_index()
df_mismatched = df_mismatched.merge(df_mismatched_sub, on='code')

dff = pd.read_csv("/cluster/home/t122995uhn/projects/MutDTA/df_base_filtered.csv")
dffm = dff.merge(df_mismatched, on='code')
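# dffm: the filtered rows joined with their substitute prot_id/prot_seq from the matching pdb code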
#%%
from src.data_prep.datasets import BaseDataset
import pandas as pd
csv_p = "/cluster/home/t122995uhn/projects/data/PDBbindDataset/nomsa_binary_original_binary/full/XY.csv"

davis_test_df = pd.read_csv("/home/jean/projects/MutDTA/splits/davis/test.csv")
davis_test_df['gene'] = davis_test_df['prot_id'].str.split('(').str[0]

#%% ONCO KB MERGE
onco_df = pd.read_csv("../data/oncoKB_DrugGenePairList.csv")
davis_join_onco = davis_test_df.merge(onco_df.drop_duplicates("gene"), on="gene", how="inner")
df = pd.read_csv(csv_p, index_col=0)

# %%
onco_df = pd.read_csv("../data/oncoKB_DrugGenePairList.csv")
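# per-gene count of oncoKB drug-gene pairs whose gene shows up in the davis test set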
onco_df.merge(davis_test_df.drop_duplicates("gene"), on="gene", how="inner").value_counts("gene")

# %%
import os
from tqdm import tqdm
alphaflow_dir = "/cluster/home/t122995uhn/projects/data/pdbbind/alphaflow_io/out_pdb_MD-distilled/"
ln_dir = "/cluster/home/t122995uhn/projects/data/pdbbind/alphaflow_io/out_pid_ln/"

os.makedirs(ln_dir, exist_ok=True)

# First, check and remove any broken links in the destination directory
for link_file in tqdm(os.listdir(ln_dir), desc="Checking for broken links"):
    ln_p = os.path.join(ln_dir, link_file)
    if os.path.islink(ln_p) and not os.path.exists(ln_p):
        print(f"Removing broken link: {ln_p}")
        os.remove(ln_p)


# %% files are .pdb with 50 "models" in each
for file in tqdm(os.listdir(alphaflow_dir)):
    if not file.endswith('.pdb'):
        continue

    code, _ = os.path.splitext(file)
    pid = df.loc[code].prot_id
    src, dst = f"{alphaflow_dir}/{file}", f"{ln_dir}/{pid}.pdb"
    if not os.path.exists(dst):
        os.symlink(src, dst)
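# symlink creation is skipped when dst already exists, so this cell is safe to re-run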

# %%
from src.train_test.splitting import resplit
#%%
########################################################################
########################## BUILD DATASETS ##############################
########################################################################
from src.data_prep.init_dataset import create_datasets
from src import cfg
import logging
cfg.logger.setLevel(logging.DEBUG)

db_p = lambda x: f'{cfg.DATA_ROOT}/DavisKibaDataset/davis/nomsa_{x}_gvp_binary'

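# re-split the binary davis dataset so its folds match the aflow dataset's split files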
db = resplit(dataset=db_p('binary'), split_files=db_p('aflow'), use_train_set=True)

splits = '/cluster/home/t122995uhn/projects/MutDTA/splits/pdbbind/'
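# build the PDBbind and davis datasets with aflow edges, reusing the fixed test split
# and the five per-fold val splits from the csvs below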
create_datasets([cfg.DATA_OPT.PDBbind, cfg.DATA_OPT.davis],
                feat_opt=cfg.PRO_FEAT_OPT.nomsa,
                edge_opt=cfg.PRO_EDGE_OPT.aflow,
                ligand_features=cfg.LIG_FEAT_OPT.original,  # [cfg.LIG_FEAT_OPT.original, cfg.LIG_FEAT_OPT.gvp]
                ligand_edges=cfg.LIG_EDGE_OPT.binary, overwrite=False,
                k_folds=5,
                test_prots_csv=f'{splits}/test.csv',
                val_prots_csv=[f'{splits}/val{i}.csv' for i in range(5)],
                data_root=os.path.abspath('../data/test/'))


# %%
@@ -69,62 +181,3 @@
models=models, metrics=['cindex', 'mse'],
fig_scale=(10,5), add_stats=True, title_postfix=" validation set performance")
plt.xticks(rotation=45)

#%%
########################################################################
########################## BUILD DATASETS ##############################
########################################################################
from src.data_prep.init_dataset import create_datasets
from src import cfg
import logging
cfg.logger.setLevel(logging.DEBUG)

splits = '/cluster/home/t122995uhn/projects/MutDTA/splits/pdbbind/'
create_datasets(cfg.DATA_OPT.PDBbind,
                feat_opt=cfg.PRO_FEAT_OPT.nomsa,
                edge_opt=[cfg.PRO_EDGE_OPT.binary, cfg.PRO_EDGE_OPT.aflow],
                ligand_features=[cfg.LIG_FEAT_OPT.original, cfg.LIG_FEAT_OPT.gvp],
                ligand_edges=cfg.LIG_EDGE_OPT.binary,
                k_folds=5,
                test_prots_csv=f'{splits}/test.csv',
                val_prots_csv=[f'{splits}/val{i}.csv' for i in range(5)])

# %%
from src.utils.loader import Loader

db_aflow = Loader.load_dataset('../data/DavisKibaDataset/davis/nomsa_aflow_original_binary/full')
db = Loader.load_dataset('../data/DavisKibaDataset/davis/nomsa_binary_original_binary/full')

# %%
# 5-fold cross validation + test set
import pandas as pd
from src import cfg
from src.train_test.splitting import balanced_kfold_split
from src.utils.loader import Loader
test_df = pd.read_csv('/cluster/home/t122995uhn/projects/MutDTA/splits/pdbbind_test.csv')
test_prots = set(test_df.prot_id)
db = Loader.load_dataset(f'{cfg.DATA_ROOT}/PDBbindDataset/nomsa_binary_original_binary/full/')

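# balanced split: per-fold train/val subsets plus a held-out test set pinned to test_prots,
# consumed by save_subset_folds / save_subset below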
train, val, test = balanced_kfold_split(db,
                                        k_folds=5, test_split=0.1, val_split=0.1,
                                        test_prots=test_prots, random_seed=0, verbose=True)

#%%
db.save_subset_folds(train, 'train')
db.save_subset_folds(val, 'val')
db.save_subset(test, 'test')

#%%
import shutil, os

src = "/cluster/home/t122995uhn/projects/data/PDBbindDataset/nomsa_binary_original_binary/"
dst = "/cluster/home/t122995uhn/projects/MutDTA/splits/pdbbind"
os.makedirs(dst, exist_ok=True)

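# copy each fold's val split out of the dataset directory into the repo's splits folder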
for i in range(5):
    sfile = f"{src}/val{i}/XY.csv"
    dfile = f"{dst}/val{i}.csv"
    shutil.copyfile(sfile, dfile)

# %%