
resolves #113 #117

Merged · 32 commits · Jul 9, 2024

Changes from all commits

Commits
dfc29e0
feat(downloader): KLIFS pocket sequences #109
jyaacoub Jun 21, 2024
cf077ae
results(davis): gvpl_aflow with corrected hparams #90
jyaacoub Jun 21, 2024
1c9c441
feat(tcga): analysis script #95 #111
jyaacoub Jun 21, 2024
7027018
results(tuned): tuned davis_gvpl_esm_aflow and retuned kiba_GVPL_aflo…
jyaacoub Jun 25, 2024
76b8acf
fix(ESM_GVPL): safe_load_state_dict via inheritance of BaseModel #90
jyaacoub Jun 25, 2024
4152347
fix(ESM_GVPL): syntax, missing comma
jyaacoub Jun 26, 2024
891a39c
fix(tuned): davis_DG params + loader #90
jyaacoub Jun 26, 2024
cf4f2e9
refactor(playground): prediction with tuned models moved to its own s…
jyaacoub Jun 27, 2024
42542a0
results(DG,GVPL): updated training with unified test set #90
jyaacoub Jun 28, 2024
d9f8091
chore: update gitignore for new models folder
jyaacoub Jun 28, 2024
bd796de
results(GVPL_ESM): davis GVPL+esm performance #90
jyaacoub Jul 2, 2024
3b8b0a8
results(kiba): updated gvpL_aflow results #90
jyaacoub Jul 2, 2024
f442919
fix(prepare_df): parse for GVPL_ESM model results #90
jyaacoub Jul 3, 2024
9ac093f
feat(resplit): resplit stub for #113
jyaacoub Jul 3, 2024
2e61fd2
refactor(datasets): logging for #114
jyaacoub Jul 3, 2024
cdf930f
fix(split): explicit val_size for balanced_kfold_split
jyaacoub Jul 3, 2024
765caf2
fix(loader): init_dataset_object and splitting overlap issue #112 #113
jyaacoub Jul 3, 2024
0c43a7c
fix(loader): max_seq_len kwarg #112
jyaacoub Jul 4, 2024
c47be94
feat(resplit): for resplitting existing datasets into proper folds #1…
jyaacoub Jul 4, 2024
ef0106c
feat(resplit): extract csvs from "like_dataset" #112 #113
jyaacoub Jul 4, 2024
c4c7741
feat: davis splits #112 #113
jyaacoub Jul 4, 2024
099c3a3
fix(splitting): created davis splits #113
jyaacoub Jul 4, 2024
b032b5b
fix(config): new results dir for issue #113
jyaacoub Jul 4, 2024
e80e225
chore(gitignore): ignoring checkpoints for #113
jyaacoub Jul 5, 2024
c7fdc86
refactor(playground): #113
jyaacoub Jul 5, 2024
256563c
chore(pdbbind): created pdbbind test set #113
jyaacoub Jul 6, 2024
b15e83d
fix(init_dataset): adding `resplit` to create_datasets #113
jyaacoub Jul 8, 2024
20c9343
fix(datasets): paths for aflow files #116
jyaacoub Jul 8, 2024
3fd9367
results(davis_DGM): retrained davis_DGM on new splits #113 #112
jyaacoub Jul 8, 2024
69add71
feat(splits): created kiba and pdbbind splits #113
jyaacoub Jul 8, 2024
1361c7e
results(davis_gvpl): retrained davis_gvpl on new splits #113 #112
jyaacoub Jul 8, 2024
a0e4405
results(davis): retrained aflow models #113 due to issue #116
jyaacoub Jul 9, 2024
.gitignore (9 changes: 6 additions & 3 deletions)
@@ -214,12 +214,15 @@ slurm_out*/
 *.swp
 *.swo
 /*.sh
-results/model_checkpoints/ours/*.model_tmp
-results/model_checkpoints/ours/*.model
+results/**/model_checkpoints/*/*.model_tmp
+results/**/model_checkpoints/*/*.model
 *.swp
 *.swo
 
 results/model_media/*/train_log/*
-results/model_media/test_set_pred
+results/model_media/*/train_set_pred/*
+results/model_media/*/test_set_pred/*
 results/model_media/test_set_pred
 
+splits/**/*.csv
+results/v113/model_media/*/train_log/*.json
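For reference, under standard gitignore glob semantics the new results/**/model_checkpoints/*/*.model pattern ignores model checkpoints at any depth under results/ (for example a hypothetical path like results/v113/model_checkpoints/davis_DG/fold0.model), whereas the old results/model_checkpoints/ours/*.model only covered the single ours/ directory. This keeps checkpoints under the new versioned results/v113/ layout out of the repo as well.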
playground.py (183 changes: 65 additions & 118 deletions)
@@ -1,5 +1,7 @@
-#%%
+# %%
+########################################################################
+########################## VIOLIN PLOTTING #############################
+########################################################################
 import logging
 from typing import OrderedDict
 
@@ -9,137 +11,82 @@
 
 from src.analysis.figures import prepare_df, custom_fig, fig_combined
 
-df = prepare_df()
+# %%
 models = {
     'DG': ('nomsa', 'binary', 'original', 'binary'),
-    # 'esm': ('ESM', 'binary', 'original', 'binary'), # esm model
     'aflow': ('nomsa', 'aflow', 'original', 'binary'),
-    # 'aflow_ring3': ('nomsa', 'aflow_ring3', 'original', 'binary'),
-    # 'gvpP': ('gvp', 'binary', 'original', 'binary'),
-    # 'gvpL': ('nomsa', 'binary', 'gvp', 'binary'),
-    'gvpL_aflow': ('nomsa', 'aflow', 'gvp', 'binary'),
+    'gvpL': ('nomsa', 'binary', 'gvp', 'binary'),
+    # 'aflow_ring3': ('nomsa', 'aflow_ring3', 'original', 'binary'),
+    'gvpL_aflow': ('nomsa', 'aflow', 'gvp', 'binary'),
+    # 'gvpL_aflow_rng3': ('nomsa', 'aflow_ring3', 'gvp', 'binary'),
+    #GVPL_ESMM_davis3D_nomsaF_aflowE_48B_0.00010636872718329864LR_0.23282479481785903D_2000E_gvpLF_binaryLE
+    # 'gvpl_esm_aflow': ('ESM', 'aflow', 'gvp', 'binary'),
 }
 
 # %%
-fig, axes = fig_combined(df, datasets=['davis','PDBbind'], fig_callable=custom_fig,
+df = prepare_df('/cluster/home/t122995uhn/projects/MutDTA/results/v113/model_media/model_stats.csv')
+fig, axes = fig_combined(df, datasets=['davis'], fig_callable=custom_fig,
                          models=models, metrics=['cindex', 'mse'],
-                         fig_scale=(8,5))
+                         fig_scale=(10,5), add_stats=True, title_postfix=" test set performance")
 plt.xticks(rotation=45)
 
+df = prepare_df('/cluster/home/t122995uhn/projects/MutDTA/results/v113/model_media/model_stats_val.csv')
+fig, axes = fig_combined(df, datasets=['davis'], fig_callable=custom_fig,
+                         models=models, metrics=['cindex', 'mse'],
+                         fig_scale=(10,5), add_stats=True, title_postfix=" validation set performance")
+plt.xticks(rotation=45)
+
 # %%
+########################################################################
+########################## PLATINUM ANALYSIS ###########################
+########################################################################
 import torch, os
 import pandas as pd
 
 # %%
+from src.data_prep.init_dataset import create_datasets
 from src import cfg
 from src import TUNED_MODEL_CONFIGS
 
+splits = '/cluster/home/t122995uhn/projects/MutDTA/splits/davis/'
+create_datasets(cfg.DATA_OPT.davis,
+                feat_opt=cfg.PRO_FEAT_OPT.nomsa,
+                edge_opt=cfg.PRO_EDGE_OPT.aflow,
+                ligand_features=[cfg.LIG_FEAT_OPT.original, cfg.LIG_FEAT_OPT.gvp],
+                ligand_edges=cfg.LIG_EDGE_OPT.binary,
+                k_folds=5,
+                test_prots_csv=f'{splits}/test.csv',
+                val_prots_csv=[f'{splits}/val{i}.csv' for i in range(5)])
 
-# %%
-from src.utils.loader import Loader
-from src.train_test.training import test
-from src.analysis.figures import predictive_performance, tbl_stratified_dpkd_metrics, tbl_dpkd_metrics_overlap, tbl_dpkd_metrics_in_binding
-
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
-INFERENCE = True
-VERBOSE = True
-out_dir = f'{cfg.MEDIA_SAVE_DIR}/test_set_pred/'
-os.makedirs(out_dir, exist_ok=True)
-cp_dir = cfg.CHECKPOINT_SAVE_DIR
-RAW_PLT_CSV=f"{cfg.DATA_ROOT}/PlatinumDataset/raw/platinum_flat_file.csv"
-
-#%% load up model:
-for KEY, CONFIG in TUNED_MODEL_CONFIGS.items():
-    MODEL_KEY = lambda fold: Loader.get_model_key(CONFIG['model'], CONFIG['dataset'], CONFIG['feature_opt'], CONFIG['edge_opt'],
-                                                  CONFIG['batch_size'], CONFIG['lr'], CONFIG['architecture_kwargs']['dropout'],
-                                                  n_epochs=2000, fold=fold,
-                                                  ligand_feature=CONFIG['lig_feat_opt'], ligand_edge=CONFIG['lig_edge_opt'])
-    print('\n\n'+ '## ' + KEY)
-    OUT_PLT = lambda i: f'{out_dir}/{MODEL_KEY(i)}_PLATINUM.csv'
-    db_p = f"{CONFIG['feature_opt']}_{CONFIG['edge_opt']}_{CONFIG['lig_feat_opt']}_{CONFIG['lig_edge_opt']}"
-
-    if CONFIG['dataset'] in ['kiba', 'davis']:
-        db_p = f"DavisKibaDataset/{CONFIG['dataset']}/{db_p}"
-    else:
-        db_p = f"{CONFIG['dataset']}Dataset/{db_p}"
-
-    train_p = lambda set: f"{cfg.DATA_ROOT}/{db_p}/{set}0/cleaned_XY.csv"
-
-    if not os.path.exists(OUT_PLT(0)) and INFERENCE:
-        print('running inference!')
-        cp = lambda fold: f"{cp_dir}/{MODEL_KEY(fold)}.model"
-
-        model = Loader.init_model(model=CONFIG["model"], pro_feature=CONFIG["feature_opt"],
-                                  pro_edge=CONFIG["edge_opt"],**CONFIG['architecture_kwargs'])
-
-        # load up platinum test db
-        loaders = Loader.load_DataLoaders(cfg.DATA_OPT.platinum,
-                                          pro_feature = CONFIG['feature_opt'],
-                                          edge_opt = CONFIG['edge_opt'],
-                                          ligand_feature = CONFIG['lig_feat_opt'],
-                                          ligand_edge = CONFIG['lig_edge_opt'],
-                                          datasets=['test'])
-
-        for i in range(5):
-            model.safe_load_state_dict(torch.load(cp(i), map_location=device))
-            model.to(device)
-            model.eval()
-
-            loss, pred, actual = test(model, loaders['test'], device, verbose=True)
-
-            # saving as csv with columns code, pred, actual
-            # get codes from test loader
-            codes, pid = [b['code'][0] for b in loaders['test']], [b['prot_id'][0] for b in loaders['test']]
-            df = pd.DataFrame({'prot_id': pid, 'pred': pred, 'actual': actual}, index=codes)
-            df.index.name = 'code'
-            df.to_csv(OUT_PLT(i))
-
-    # run platinum eval:
-    print('\n### 1. predictive performance')
-    mkdown = predictive_performance(OUT_PLT, train_p, verbose=VERBOSE, plot=False)
-    print('\n### 2 Mutation impact analysis')
-    print('\n#### 2.1 $\Delta pkd$ predictive performance')
-    mkdn = tbl_dpkd_metrics_overlap(OUT_PLT, train_p, verbose=VERBOSE, plot=False)
-    print('\n#### 2.2 Stratified by location of mutation (binding pocket vs not in binding pocket)')
-    m = tbl_dpkd_metrics_in_binding(OUT_PLT, RAW_PLT_CSV, verbose=VERBOSE, plot=False)
-
-
-db_aflow = Loader.load_dataset('../data/DavisKibaDataset/davis/nomsa_aflow_original_binary/full')
-db = Loader.load_dataset('../data/DavisKibaDataset/davis/nomsa_binary_original_binary/full')
-
-# %%
-dfr = pd.read_csv(RAW_PLT_CSV, index_col=0)
-
-# add in_binding info to df
-def get_in_binding(df, dfr):
-    """
-    df is the predicted csv with index as <raw_idx>_wt (or *_mt) where raw_idx
-    corresponds to an index in dfr which contains the raw data for platinum including
-    ('mut.in_binding_site')
-        - 0: wildtype rows
-        - 1: close (<8 Ang)
-        - 2: Far (>8 Ang)
-    """
-    pocket = dfr[dfr['mut.in_binding_site'] == 'YES'].index
-    pclass = []
-    for code in df.index:
-        if '_wt' in code:
-            pclass.append(0)
-        elif int(code.split('_')[0]) in pocket:
-            pclass.append(1)
-        else:
-            pclass.append(2)
-    df['pocket'] = pclass
-    return df
-
-df = get_in_binding(pd.read_csv(OUT_PLT(0), index_col=0), dfr)
-if VERBOSE:
-    cnts = df.pocket.value_counts()
-    cnts.index = ['wt', 'in pocket', 'not in pocket']
-    cnts.name = "counts"
-    print(cnts.to_markdown(), end="\n\n")
-
-tbl_stratified_dpkd_metrics(OUT_PLT, NORMALIZE=True, n_models=5, df_transform=get_in_binding,
-                            conditions=['(pocket == 0) | (pocket == 1)', '(pocket == 0) | (pocket == 2)'],
-                            names=['in pocket', 'not in pocket'],
-                            verbose=VERBOSE, plot=True, dfr=dfr)
+# 5-fold cross validation + test set
+import pandas as pd
+from src import cfg
+from src.train_test.splitting import balanced_kfold_split
+from src.utils.loader import Loader
+test_df = pd.read_csv('/cluster/home/t122995uhn/projects/MutDTA/splits/pdbbind_test.csv')
+test_prots = set(test_df.prot_id)
+db = Loader.load_dataset(f'{cfg.DATA_ROOT}/PDBbindDataset/nomsa_binary_original_binary/full/')
+
+train, val, test = balanced_kfold_split(db,
+                                        k_folds=5, test_split=0.1, val_split=0.1,
+                                        test_prots=test_prots, random_seed=0, verbose=True
+                                        )
+
+#%%
+db.save_subset_folds(train, 'train')
+db.save_subset_folds(val, 'val')
+db.save_subset(test, 'test')
+
+#%%
+import shutil, os
+
+src = "/cluster/home/t122995uhn/projects/data/PDBbindDataset/nomsa_binary_original_binary/"
+dst = "/cluster/home/t122995uhn/projects/MutDTA/splits/pdbbind"
+os.makedirs(dst, exist_ok=True)
+
+for i in range(5):
+    sfile = f"{src}/val{i}/XY.csv"
+    dfile = f"{dst}/val{i}.csv"
+    shutil.copyfile(sfile, dfile)
+
+# %%
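As a quick sanity check on the new fixed splits (a minimal sketch, not part of this PR; it assumes the copied val{i}.csv files carry a prot_id column like splits/pdbbind_test.csv above):

# verify that no protein from the held-out test set leaks into any
# validation fold, the overlap issue tracked in #112/#113.
# assumption: each split csv has a 'prot_id' column.
import pandas as pd

splits = "/cluster/home/t122995uhn/projects/MutDTA/splits"
test_prots = set(pd.read_csv(f"{splits}/pdbbind_test.csv").prot_id)

for i in range(5):
    val_prots = set(pd.read_csv(f"{splits}/pdbbind/val{i}.csv").prot_id)
    overlap = val_prots & test_prots
    assert not overlap, f"val{i} shares {len(overlap)} proteins with the test set"
print("no validation/test protein overlap across the 5 folds")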