Platinum analysis #94
### 1. Predictive performance

Raw: _(results not preserved)_

Z-normalized: _(results not preserved)_
Code:

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from src import config as cfg
from src.analysis.metrics import get_metrics
# generate_markdown is not imported in the original snippet; it is assumed to be
# defined in the same module as this function.


def predictive_performance(
    MODEL=lambda i: f"results/model_media/test_set_pred/GVPLM_PDBbind{i}D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE_PLATINUM.csv",
    TRAIN_DATA_P=lambda set: f'{cfg.DATA_ROOT}/PDBbindDataset/nomsa_aflow_gvp_binary/{set}0/cleaned_XY.csv',
    NORMALIZE=True,
    n_models=5,
    compare_overlap=False,
    verbose=True,
    plot=False,
):
    # PDB codes seen during training (train + val), upper-cased for matching
    df_t = pd.Index.append(pd.read_csv(TRAIN_DATA_P('train'), index_col=0).index,
                           pd.read_csv(TRAIN_DATA_P('val'), index_col=0).index)
    df_t = df_t.str.upper()

    results_with_overlap = []
    results_without_overlap = []
    for i in range(n_models):  # one prediction CSV per CV fold
        df = pd.read_csv(MODEL(i), index_col=0).dropna()
        df['pdb'] = df['prot_id'].str.split('_').str[0]
        if NORMALIZE:
            # z-normalize actual and pred independently
            mean_df = df[['actual', 'pred']].mean(axis=0, numeric_only=True)
            std_df = df[['actual', 'pred']].std(axis=0, numeric_only=True)
            df[['actual', 'pred']] = (df[['actual', 'pred']] - mean_df) / std_df
        if i == 0:
            print(df)

        # with overlap
        cindex, p_corr, s_corr, mse, mae, rmse = get_metrics(df['actual'], df['pred'])
        results_with_overlap.append([cindex, p_corr[0], s_corr[0], mse, mae, rmse])
        # without overlap: drop test proteins whose PDB code appears in train/val
        df_no_overlap = df[~(df['pdb'].isin(df_t))]
        cindex, p_corr, s_corr, mse, mae, rmse = get_metrics(df_no_overlap['actual'], df_no_overlap['pred'])
        results_without_overlap.append([cindex, p_corr[0], s_corr[0], mse, mae, rmse])

        if i == 0 and plot:
            n_plots = int(compare_overlap) + 1
            fig = plt.figure(figsize=(14, 5 * n_plots))
            axes = fig.subplots(n_plots, 1)
            ax = axes[0] if compare_overlap else axes
            # top panel: full set when comparing, otherwise the no-overlap set
            df_top = df if compare_overlap else df_no_overlap
            sns.histplot(df_top['actual'], kde=True, ax=ax, alpha=0.5, label='True pkd')
            sns.histplot(df_top['pred'], kde=True, ax=ax, alpha=0.5, label='Predicted pkd', color='orange')
            ax.set_title(f"{'Normalized ' if NORMALIZE else ''}pkd distribution")
            ax.legend()
            if compare_overlap:
                sns.histplot(df_no_overlap['actual'], kde=True, ax=axes[1], alpha=0.5, label='True pkd')
                sns.histplot(df_no_overlap['pred'], kde=True, ax=axes[1], alpha=0.5, label='Predicted pkd', color='orange')
                axes[1].set_title(f"{'Normalized ' if NORMALIZE else ''}pkd distribution (no overlap)")
                axes[1].legend()

    if compare_overlap:
        return generate_markdown([results_with_overlap, results_without_overlap],
                                 names=['with overlap', 'without overlap'], cindex=True, verbose=verbose)
    return generate_markdown([results_without_overlap], names=[r'mean $\pm$ se'], cindex=True, verbose=verbose)
```
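The `mean $\pm$ se` column name suggests each table cell aggregates the `n_models=5` per-fold metrics into a mean and a standard error. `generate_markdown` is not shown here, so this is only a minimal sketch under that assumption, taking the standard error as sample std / √n. `mean_se` and `format_mean_se` are hypothetical names, and the fold values are made up.

```python
import math
import statistics


def mean_se(values):
    """Return (mean, standard error) for a list of per-fold metric values.

    Standard error is assumed to be the sample standard deviation divided by
    sqrt(n); this is a stand-in, not the repo's generate_markdown logic.
    """
    n = len(values)
    if n < 2:
        raise ValueError("need at least two folds to estimate a standard error")
    return statistics.mean(values), statistics.stdev(values) / math.sqrt(n)


def format_mean_se(values, digits=3):
    """Render one table cell as 'mean ± se'."""
    m, se = mean_se(values)
    return f"{m:.{digits}f} ± {se:.{digits}f}"


# Made-up per-fold cindex values for n_models=5 (not real results).
folds = [0.60, 0.62, 0.64, 0.66, 0.68]
print(format_mean_se(folds))  # 0.640 ± 0.014
```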
### 2. Mutation impact analysis

Same analysis as section 1, but applied to Δpkd: the change in pkd between each mutant and its wildtype.

#### 2.1 Δpkd predictive performance

Raw:

| metric | With Overlap | Without Overlap | Significance |
|---|---|---|---|
| pcorr | 0.176 | 0.037 | * |
| scorr | 0.099 | 0.046 | |
| mse | 1.505 | 1.303 | * |
| mae | 0.905 | 0.847 | * |
| rmse | 1.227 | 1.141 | * |
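Δpkd requires pairing each mutant with its wildtype. A minimal sketch, assuming the `<raw_idx>_wt` / `<raw_idx>_mt` index convention described in the `get_in_binding` docstring in 2.3.2 and assuming Δpkd = pkd(mutant) − pkd(wildtype). The real `get_dpkd` helper is not shown (it also takes a `NORMALIZE` flag) and may differ in sign or pairing; `compute_dpkd` is a hypothetical name.

```python
def compute_dpkd(pkd_by_code):
    """Pair '<raw_idx>_mt' with '<raw_idx>_wt' and return {raw_idx: Δpkd}.

    Assumptions: Δpkd = pkd(mutant) - pkd(wildtype); codes without a
    partner are dropped. Hypothetical helper, not the repo's get_dpkd.
    """
    wt, mt = {}, {}
    for code, pkd in pkd_by_code.items():
        raw_idx, _, kind = code.rpartition("_")
        if kind == "wt":
            wt[raw_idx] = pkd
        elif kind == "mt":
            mt[raw_idx] = pkd
    return {idx: mt[idx] - wt[idx] for idx in mt if idx in wt}


# Made-up pkd values; "14" has no mutant and is therefore dropped.
pkd = {"12_wt": 6.0, "12_mt": 7.5, "13_wt": 5.0, "13_mt": 4.0, "14_wt": 8.0}
print(compute_dpkd(pkd))  # {'12': 1.5, '13': -1.0}
```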
Z-normalized:

| metric | With Overlap | Without Overlap | Significance |
|---|---|---|---|
| pcorr | 0.176 | 0.037 | * |
| scorr | 0.099 | 0.046 | |
| mse | 1.649 | 1.927 | * |
| mae | 0.899 | 1.014 | * |
| rmse | 1.284 | 1.387 | * |
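In the z-normalized table, mse is almost exactly 2 × (1 − pcorr): 2 × (1 − 0.176) = 1.648 vs 1.649, and 2 × (1 − 0.037) = 1.926 vs 1.927. This follows algebraically if true and predicted Δpkd are each z-scored with the sample standard deviation (pandas' default `ddof=1`): the mean squared error reduces to $\frac{2(n-1)}{n}(1 - r)$, so after normalization mse carries essentially the same information as pcorr. The NumPy sketch below checks the identity on synthetic data; that the Δpkd columns are z-scored exactly this way is an assumption, not something the source states.

```python
import numpy as np


def zscore(x):
    """Z-normalize with the sample std (ddof=1), matching pandas' .std() default."""
    x = np.asarray(x, dtype=float)
    return (x - x.mean()) / x.std(ddof=1)


def mse_from_pearson(r, n):
    """MSE between two z-scored vectors of length n implied by their Pearson r."""
    return 2.0 * (n - 1) / n * (1.0 - r)


rng = np.random.default_rng(0)
actual = rng.normal(size=500)
pred = 0.2 * actual + rng.normal(size=500)  # weakly correlated toy predictions

za, zp = zscore(actual), zscore(pred)
mse = np.mean((za - zp) ** 2)
r = np.corrcoef(actual, pred)[0, 1]  # Pearson r is unchanged by z-scoring
print(mse, mse_from_pearson(r, len(actual)))  # the two numbers agree
```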
Code:

```python
from src.analysis.figures import tbl_dpkd_metrics_overlap, tbl_dpkd_metrics_n_mut

MODEL = lambda i: f"results/model_media/test_set_pred/GVPLM_PDBbind{i}D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE_PLATINUM.csv"
TRAIN_DATA_P = lambda set: f'/cluster/home/t122995uhn/projects/data/PDBbindDataset/nomsa_aflow_gvp_binary/{set}0/cleaned_XY.csv'
NORMALIZE = True

# %%
print('OVERLAP')
mkdo = tbl_dpkd_metrics_overlap(MODEL, TRAIN_DATA_P, NORMALIZE, plot=False)
print('NUM MUTATIONS:')
mkdnm = tbl_dpkd_metrics_n_mut(MODEL, NORMALIZE, plot=False)
```
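The "with overlap" / "without overlap" split drops Platinum test proteins whose PDB code also appears in the PDBbind train/val splits, as done in `predictive_performance` in section 1. A stdlib sketch of that filter, with hypothetical IDs; `split_by_overlap` is a hypothetical name. Note one deliberate difference: this sketch upper-cases both sides, whereas the original only upper-cases the training index and so assumes Platinum PDB codes are already upper-case.

```python
def split_by_overlap(prot_ids, train_codes):
    """Split test prot_ids into (overlapping, non_overlapping) lists.

    The PDB code is taken as the part of prot_id before the first '_'.
    Both sides are upper-cased so the comparison is case-insensitive.
    """
    train = {code.upper() for code in train_codes}
    overlap, no_overlap = [], []
    for pid in prot_ids:
        pdb = pid.split("_")[0].upper()
        (overlap if pdb in train else no_overlap).append(pid)
    return overlap, no_overlap


# Hypothetical IDs, for illustration only.
ov, no = split_by_overlap(["1ABC_A", "2xyz_B", "3DEF_C"], ["1abc", "9zzz"])
print(ov, no)  # ['1ABC_A'] ['2xyz_B', '3DEF_C']
```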
#### 2.2 Stratify by mutation count

Two classes, "single mutation" vs "2+ mutations":

| metric | 1 mutation | 2+ mutations | Sig |
|---|---|---|---|
| pcorr | 0.076 | 0.336 | * |
| scorr | 0.053 | 0.252 | * |
| mse | 1.848 | 1.328 | * |
| mae | 0.961 | 0.833 | * |
| rmse | 1.359 | 1.152 | * |
Three classes, "single", "2", and "3+":

| metric | 1 mutation | 2 mutations | 3+ mutations |
|---|---|---|---|
| pcorr | 0.076 | 0.207 | 0.509 |
| scorr | 0.053 | 0.131 | 0.496 |
| mse | 1.848 | 1.586 | 0.982 |
| mae | 0.961 | 0.964 | 0.732 |
| rmse | 1.359 | 1.259 | 0.989 |
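Both tables bin mutants by how many mutations they carry. Where that count comes from (presumably the raw Platinum file) is not shown, and the exact meaning of `conditions=[1,2]` in `tbl_dpkd_metrics_n_mut` is not documented here, so this sketch only illustrates the binning step. `mutation_count_class` and `max_exact` are hypothetical names and the counts are made up.

```python
from collections import Counter


def mutation_count_class(n_mut, max_exact=2):
    """Label a mutant by its number of mutations.

    max_exact=1 -> '1', '2+'        (the two-class split)
    max_exact=2 -> '1', '2', '3+'   (the three-class split)
    """
    if n_mut < 1:
        raise ValueError("a mutant must carry at least one mutation")
    return str(n_mut) if n_mut <= max_exact else f"{max_exact + 1}+"


counts = [1, 1, 2, 3, 5]  # made-up mutation counts
print(Counter(mutation_count_class(n, max_exact=1) for n in counts))
print(Counter(mutation_count_class(n, max_exact=2) for n in counts))
```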
Code:

```python
from src.analysis.figures import tbl_dpkd_metrics_overlap, tbl_dpkd_metrics_n_mut

MODEL = lambda i: f"results/model_media/test_set_pred/GVPLM_PDBbind{i}D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE_PLATINUM.csv"
TRAIN_DATA_P = lambda set: f'/cluster/home/t122995uhn/projects/data/PDBbindDataset/nomsa_aflow_gvp_binary/{set}0/cleaned_XY.csv'
NORMALIZE = True
print('NUM MUTATIONS:')
mkdnm = tbl_dpkd_metrics_n_mut(MODEL, NORMALIZE, conditions=[1,2], plot=True)
```
#### 2.3 Stratify by location of mutation

##### 2.3.1 Binding pocket vs not in binding pocket

| metric | mutation in pocket | mutation NOT in pocket | Sig |
|---|---|---|---|
| pcorr | 0.180 | 0.110 | * |
| scorr | 0.108 | 0.050 | |
| mse | 1.640 | 1.781 | * |
| mae | 0.904 | 0.981 | * |
| rmse | 1.280 | 1.334 | * |
Code:

```python
# %%
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from src import config as cfg
from src.analysis.figures import get_dpkd

# NOTE: get_in_binding is not defined in this snippet. It is assumed to add a
# `pocket` column (0 = wildtype, 1 = mutation in pocket, 2 = mutation not in pocket);
# it is a different helper from the distance-based get_in_binding in 2.3.2.

NORMALIZE = True
dfr = pd.read_csv(f"{cfg.DATA_ROOT}/PlatinumDataset/raw/platinum_flat_file.csv", index_col=0)
df = pd.read_csv("/cluster/home/t122995uhn/projects/data/PlatinumDataset/nomsa_binary_original_binary/full/cleaned_XY.csv", index_col=0).dropna()
df = get_in_binding(df, dfr=dfr)

fig = plt.figure(figsize=(14, 5))
ax = fig.subplots(1, 1)
# must include 0 (wildtype) in both subsets since it is the reference for Δpkd
true_dpkd_out = get_dpkd(df.query('(pocket == 0) | (pocket == 2)'), 'pkd', NORMALIZE)
sns.histplot(true_dpkd_out, kde=True, ax=ax, alpha=0.6, color='orange', label='not in pocket', stat='proportion')
true_dpkd_in = get_dpkd(df.query('(pocket == 0) | (pocket == 1)'), 'pkd', NORMALIZE)
sns.histplot(true_dpkd_in, kde=True, ax=ax, alpha=0.6, label='in pocket', stat='proportion')
ax.set_title(f"{'Normalized ' if NORMALIZE else ''}TRUE Δpkd distribution")
ax.set_xlabel('Δpkd')
ax.legend()
```
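The histograms use `stat='proportion'` because the two groups differ a lot in size (708 in-pocket vs 241 not-in-pocket mutants in the counts reported later); with raw counts the smaller group would look flat. A NumPy sketch of the same normalization on shared bin edges, so bar heights in each group sum to 1. `proportion_hist` is a hypothetical name and the data are synthetic stand-ins sized like the two groups.

```python
import numpy as np


def proportion_hist(values, bins):
    """Histogram whose bar heights sum to 1, like seaborn's stat='proportion'."""
    counts, edges = np.histogram(values, bins=bins)
    return counts / counts.sum(), edges


rng = np.random.default_rng(1)
in_pocket = rng.normal(0.0, 1.0, size=708)      # synthetic stand-ins, not real Δpkd
not_in_pocket = rng.normal(0.3, 1.2, size=241)
# Shared edges computed from both groups, so the bars are directly comparable.
edges = np.histogram_bin_edges(np.concatenate([in_pocket, not_in_pocket]), bins=20)
p_in, _ = proportion_hist(in_pocket, edges)
p_out, _ = proportion_hist(not_in_pocket, edges)
print(p_in.sum(), p_out.sum())
```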
##### 2.3.2 Distance to ligand

- We treat this as a binary classification problem: a mutation is "near" the ligand if it lies < 4 Å from it.

| | counts |
|---|---|
| wt | 981 |
| near lig | 577 |
| not near lig | 372 |
| metric | mutation near lig (<4 Å) | mutation not near lig (>4 Å) | p-val |
|---|---|---|---|
| pcorr | 0.164 | 0.198 | 0.1405 |
| scorr | 0.104 | 0.079 | 0.1844 |
| mse | 1.672 | 1.604 | 0.1405 |
| mae | 0.912 | 0.939 | 0.0408 |
| rmse | 1.293 | 1.266 | 0.1381 |
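How the p-vals are computed is not stated (the section 3 code imports scipy's `ttest_ind`, hinting at a t-test over the `n_models=5` per-fold metrics). As a distribution-free alternative, here is an exact two-sided permutation test on the difference in means; `permutation_test` is a hypothetical name. With 5 vs 5 values the smallest attainable p is 2/252 ≈ 0.0079, so if the p-vals were computed on per-fold values, entries such as 0.0001 could not come from this test. Identical groups give p = 1, as in the Davis-trained 2.1 tables.

```python
from itertools import combinations
from statistics import mean


def permutation_test(a, b):
    """Exact two-sided permutation test on the difference in means.

    Enumerates every split of the pooled values into groups of len(a) and
    len(b) and returns the fraction whose |mean difference| is at least the
    observed one (the observed split itself is included).
    """
    pooled = list(a) + list(b)
    observed = abs(mean(a) - mean(b))
    hits = total = 0
    for idx in combinations(range(len(pooled)), len(a)):
        chosen = set(idx)
        ga = [pooled[i] for i in idx]
        gb = [pooled[i] for i in range(len(pooled)) if i not in chosen]
        if abs(mean(ga) - mean(gb)) >= observed - 1e-12:  # tolerance for float ties
            hits += 1
        total += 1
    return hits / total


print(permutation_test([1, 2, 3, 4, 5], [6, 7, 8, 9, 10]))  # 2/252 ≈ 0.0079
```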
Code:

```python
from src.analysis.figures import tbl_dpkd_metrics_overlap, tbl_dpkd_metrics_n_mut, tbl_dpkd_metrics_in_binding, predictive_performance, tbl_stratified_dpkd_metrics
from src.analysis.metrics import get_metrics
from src import config as cfg
import pandas as pd
import seaborn as sns

#%%
MODEL = lambda i: f"results/model_media/test_set_pred/GVPLM_PDBbind{i}D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE_PLATINUM.csv"
RAW_PLT_CSV = f"{cfg.DATA_ROOT}/PlatinumDataset/raw/platinum_flat_file.csv"
NORMALIZE = True
n_models = 5
verbose = True
plot = True
dfr = pd.read_csv(RAW_PLT_CSV, index_col=0)
dfp = pd.read_csv(MODEL(0), index_col=0)

#%%
sns.histplot(dfr['mut.distance_to_lig'])

#%%
# add near-ligand info to df
thres = 4


def get_in_binding(df, dfr, thres=thres):
    """
    df is the predicted csv with index as <raw_idx>_wt (or *_mt), where raw_idx
    corresponds to an index in dfr, the raw Platinum data (including
    'mut.distance_to_lig'). Adds a `near_lig` column:
      - 0: wildtype rows
      - 1: mutation close to the ligand (< thres Å)
      - 2: mutation far from the ligand (>= thres Å, or distance missing)
    """
    near_lig = set(dfr[dfr['mut.distance_to_lig'] < thres].index)
    pclass = []
    for code in df.index:
        if '_wt' in code:
            pclass.append(0)
        elif int(code.split('_')[0]) in near_lig:
            pclass.append(1)
        else:
            pclass.append(2)
    df['near_lig'] = pclass
    return df


conditions = ['(near_lig == 0) | (near_lig == 1)', '(near_lig == 0) | (near_lig == 2)']
names = [f'mutation near lig (<{thres}A)', f'mutation not near lig (>{thres}A)']
df = get_in_binding(dfp, dfr)
if verbose:
    # sort by class id so the labels line up regardless of which class is most frequent
    cnts = df.near_lig.value_counts().sort_index()
    cnts.index = cnts.index.map({0: 'wt', 1: 'near lig', 2: 'not near lig'})
    cnts.name = "counts"
    print(cnts.to_markdown(), end="\n\n")

#%%
tbl_stratified_dpkd_metrics(MODEL, NORMALIZE, n_models=n_models, df_transform=get_in_binding,
                            conditions=conditions, names=names, verbose=verbose, plot=plot, dfr=dfr)
```
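`value_counts()` orders classes by frequency rather than by class id, so attaching labels to its index positionally only works when wt > near > not near, as happens to hold for these counts. A stdlib sketch of the same near/far rule with a fixed-order tally, assuming distances are keyed by the integer raw Platinum index as in `get_in_binding` above; `near_lig_class` and `tally` are hypothetical names and the distances are made up.

```python
from collections import Counter

LABELS = {0: "wt", 1: "near lig", 2: "not near lig"}


def near_lig_class(code, distance_by_idx, thres=4.0):
    """0 for wildtype rows, 1 if the mutation is < thres Å from the ligand, else 2.

    Missing or NaN distances fall into class 2, and a distance of exactly
    `thres` counts as "not near", matching the strict '<' in get_in_binding.
    """
    if code.endswith("_wt"):
        return 0
    dist = distance_by_idx.get(int(code.split("_")[0]))
    return 1 if dist is not None and dist < thres else 2


def tally(classes):
    """Counts in a fixed wt / near / not-near order, independent of frequency."""
    counter = Counter(classes)
    return {LABELS[k]: counter.get(k, 0) for k in sorted(LABELS)}


# Made-up distances (Å) keyed by raw Platinum index.
distances = {1: 2.5, 2: 4.0, 3: 9.1, 4: float("nan")}
codes = ["1_wt", "1_mt", "2_mt", "3_mt", "4_mt"]
classes = [near_lig_class(c, distances) for c in codes]
print(classes, tally(classes))  # [0, 1, 2, 2, 2] {'wt': 1, 'near lig': 1, 'not near lig': 3}
```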
### 3. Significant mutation impact analysis

With overlap, the best threshold is 0.1 × STD; without overlap, it is 0.3 × STD.

Code:

```python
#%%
import pandas as pd

from src.analysis.figures import get_dpkd, fig_sig_mutations_conf_matrix, generate_roc_curve

MODEL = lambda i: f"results/model_media/test_set_pred/GVPLM_PDBbind{i}D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE_PLATINUM.csv"
data_p = lambda set: f'/cluster/home/t122995uhn/projects/data/PDBbindDataset/nomsa_aflow_gvp_binary/{set}0/cleaned_XY.csv'

# PDB codes seen during training (train + val), upper-cased for matching
df_t = pd.Index.append(pd.read_csv(data_p('train'), index_col=0).index,
                       pd.read_csv(data_p('val'), index_col=0).index)
df_t = df_t.str.upper()

i = 0  # only the first fold's predictions are analysed here
df = pd.read_csv(MODEL(i), index_col=0).dropna()
df['pdb'] = df['prot_id'].str.split('_').str[0]
df_no = df[~(df['pdb'].isin(df_t))]  # without overlap

#%%
true_dpkd = get_dpkd(df, pkd_col='actual')
pred_dpkd = get_dpkd(df, pkd_col='pred')
true_dpkd_no = get_dpkd(df_no, pkd_col='actual')
pred_dpkd_no = get_dpkd(df_no, pkd_col='pred')

# %%
# ROC sweep over thresholds 0 to 5 (step 0.1), then confusion matrix at the best one: with overlap
_, _, _, best_threshold = generate_roc_curve(true_dpkd, pred_dpkd, thres_range=(0,5), step=0.1)
_ = fig_sig_mutations_conf_matrix(true_dpkd, pred_dpkd, std=round(best_threshold, 3))

# %%
# same, without overlap
_, _, _, best_threshold = generate_roc_curve(true_dpkd_no, pred_dpkd_no, thres_range=(0,5), step=0.1)
_ = fig_sig_mutations_conf_matrix(true_dpkd_no, pred_dpkd_no, std=round(best_threshold, 3))
```
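`fig_sig_mutations_conf_matrix` is not shown, so what counts as a "significant" mutation is an assumption here. A minimal NumPy sketch, assuming a mutation is significant when |Δpkd| exceeds the threshold (in the same units as Δpkd, e.g. a multiple of its std) and that true and predicted labels use the same threshold. `sig_confusion` is a hypothetical name; the real helper may, for example, separate stabilizing from destabilizing mutations.

```python
import numpy as np


def sig_confusion(true_dpkd, pred_dpkd, thres):
    """2x2 confusion counts for 'significant' mutations (|Δpkd| > thres)."""
    t = np.abs(np.asarray(true_dpkd, dtype=float)) > thres
    p = np.abs(np.asarray(pred_dpkd, dtype=float)) > thres
    return {
        "TP": int(np.sum(t & p)),    # significant and predicted significant
        "FP": int(np.sum(~t & p)),   # predicted significant, actually not
        "FN": int(np.sum(t & ~p)),   # missed significant mutation
        "TN": int(np.sum(~t & ~p)),  # correctly called not significant
    }


# Made-up Δpkd values in std units.
print(sig_confusion([0.05, 0.5, -0.4, 0.0], [0.2, 0.6, 0.01, -0.05], thres=0.1))
```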
# Pretrained Davis results
Code:

```python
# %%
import os

import pandas as pd
import torch

from src import cfg
from src import TUNED_MODEL_CONFIGS
from src.utils.loader import Loader
from src.train_test.training import test
from src.analysis.figures import predictive_performance, tbl_stratified_dpkd_metrics, tbl_dpkd_metrics_overlap, tbl_dpkd_metrics_in_binding

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
INFERENCE = False  # set True to re-run inference even if prediction CSVs already exist
VERBOSE = True
out_dir = f'{cfg.MEDIA_SAVE_DIR}/test_set_pred/'
os.makedirs(out_dir, exist_ok=True)
cp_dir = cfg.CHECKPOINT_SAVE_DIR
RAW_PLT_CSV = f"{cfg.DATA_ROOT}/PlatinumDataset/raw/platinum_flat_file.csv"

#%% load up model:
for KEY, CONFIG in TUNED_MODEL_CONFIGS.items():
    MODEL_KEY = lambda fold: Loader.get_model_key(CONFIG['model'], CONFIG['dataset'], CONFIG['feature_opt'], CONFIG['edge_opt'],
                                                  CONFIG['batch_size'], CONFIG['lr'], CONFIG['architecture_kwargs']['dropout'],
                                                  n_epochs=2000, fold=fold,
                                                  ligand_feature=CONFIG['lig_feat_opt'], ligand_edge=CONFIG['lig_edge_opt'])
    print('\n\n' + '## ' + KEY)
    OUT_PLT = lambda i: f'{out_dir}/{MODEL_KEY(i)}_PLATINUM.csv'

    # path to the training dataset, used for the train/test overlap check
    db_p = f"{CONFIG['feature_opt']}_{CONFIG['edge_opt']}_{CONFIG['lig_feat_opt']}_{CONFIG['lig_edge_opt']}"
    if CONFIG['dataset'] in ['kiba', 'davis']:
        db_p = f"DavisKibaDataset/{CONFIG['dataset']}/{db_p}"
    else:
        db_p = f"{CONFIG['dataset']}Dataset/{db_p}"
    train_p = lambda set: f"{cfg.DATA_ROOT}/{db_p}/{set}0/cleaned_XY.csv"

    if not os.path.exists(OUT_PLT(0)) or INFERENCE:
        print('running inference!')
        cp = lambda fold: f"{cp_dir}/{MODEL_KEY(fold)}.model"
        model = Loader.init_model(model=CONFIG["model"], pro_feature=CONFIG["feature_opt"],
                                  pro_edge=CONFIG["edge_opt"], **CONFIG['architecture_kwargs'])
        # load up platinum test db
        loaders = Loader.load_DataLoaders(cfg.DATA_OPT.platinum,
                                          pro_feature=CONFIG['feature_opt'],
                                          edge_opt=CONFIG['edge_opt'],
                                          ligand_feature=CONFIG['lig_feat_opt'],
                                          ligand_edge=CONFIG['lig_edge_opt'],
                                          datasets=['test'])
        for i in range(5):  # one checkpoint per CV fold
            model.safe_load_state_dict(torch.load(cp(i), map_location=device))
            model.to(device)
            model.eval()
            loss, pred, actual = test(model, loaders['test'], device, verbose=True)
            # save as csv with columns code, prot_id, pred, actual
            # (codes are read from the test loader, which is assumed to use batch size 1)
            codes, pid = [b['code'][0] for b in loaders['test']], [b['prot_id'][0] for b in loaders['test']]
            df = pd.DataFrame({'prot_id': pid, 'pred': pred, 'actual': actual}, index=codes)
            df.index.name = 'code'
            df.to_csv(OUT_PLT(i))

    # run platinum eval:
    print('\n### 1. predictive performance')
    mkdown = predictive_performance(OUT_PLT, train_p, verbose=VERBOSE, plot=False)
    print('\n### 2 Mutation impact analysis')
    print('\n#### 2.1 $\\Delta pkd$ predictive performance')
    mkdn = tbl_dpkd_metrics_overlap(OUT_PLT, train_p, verbose=VERBOSE, plot=False)
    print('\n#### 2.2 Stratified by location of mutation (binding pocket vs not in binding pocket)')
    m = tbl_dpkd_metrics_in_binding(OUT_PLT, RAW_PLT_CSV, verbose=VERBOSE, plot=False)
```

## davis_gvpl_aflow

### 1. Predictive performance

_(results not preserved)_
### 2. Mutation impact analysis

#### 2.1 Δpkd predictive performance

| metric | with overlap | without overlap | p-val |
|---|---|---|---|
| pcorr | 0.005 | 0.005 | 1 |
| scorr | 0.003 | 0.003 | 1 |
| mse | 1.990 | 1.990 | 1 |
| mae | 0.980 | 0.980 | 1 |
| rmse | 1.411 | 1.411 | 1 |
#### 2.2 Stratified by location of mutation (binding pocket vs not in binding pocket)

| | counts |
|---|---|
| wt | 981 |
| pckt | 708 |
| not pckt | 241 |

| metric | mutation in pocket | mutation NOT in pocket | p-val |
|---|---|---|---|
| pcorr | 0.004 | 0.025 | 0.5777 |
| scorr | -0.009 | 0.067 | 0.3433 |
| mse | 1.992 | 1.950 | 0.5777 |
| mae | 1.000 | 0.981 | 0.4374 |
| rmse | 1.411 | 1.396 | 0.5578 |
## davis_gvpl

### 1. Predictive performance

| metric | mean predictive performance |
|---|---|
| cindex | 0.469 |
| pcorr | -0.058 |
| scorr | -0.095 |
| mse | 2.114 |
| mae | 1.148 |
| rmse | 1.454 |
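cindex is the concordance index: the fraction of pairs with different true pkd that the model ranks in the correct order, so 0.5 is chance level. Read that way, both Davis-trained models (0.469 here, 0.446 for davis_aflow) rank Platinum complexes slightly worse than chance, while the PDBbind model (0.641) is above it. A minimal O(n²) sketch of the usual Harrell-style definition, with ties in the prediction counted as half; the repo's `get_metrics` may treat ties differently, and `concordance_index` is a hypothetical name.

```python
from itertools import combinations


def concordance_index(actual, pred):
    """Harrell-style concordance index.

    Over all pairs whose true values differ, score 1 when pred orders them the
    same way, 0.5 when pred is tied, 0 otherwise; return the average score.
    O(n^2), which is fine for a few thousand rows.
    """
    concordant = 0.0
    comparable = 0
    for i, j in combinations(range(len(actual)), 2):
        d_true = actual[i] - actual[j]
        if d_true == 0:
            continue  # tied true values carry no ordering information
        comparable += 1
        d_pred = pred[i] - pred[j]
        if d_pred == 0:
            concordant += 0.5
        elif (d_true > 0) == (d_pred > 0):
            concordant += 1.0
    if comparable == 0:
        raise ValueError("no comparable pairs: all true values are tied")
    return concordant / comparable


# Made-up pkd values: 5 of the 6 pairs are concordant.
print(concordance_index([5.0, 6.5, 7.2, 8.1], [5.5, 6.0, 7.9, 7.0]))
```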
### 2. Mutation impact analysis

#### 2.1 Δpkd predictive performance

| metric | with overlap | without overlap | p-val |
|---|---|---|---|
| pcorr | 0.008 | 0.008 | 1 |
| scorr | 0.023 | 0.023 | 1 |
| mse | 1.983 | 1.983 | 1 |
| mae | 0.974 | 0.974 | 1 |
| rmse | 1.408 | 1.408 | 1 |

#### 2.2 Stratified by location of mutation (binding pocket vs not in binding pocket)

| | counts |
|---|---|
| wt | 981 |
| pckt | 725 |
| not pckt | 256 |

| metric | mutation in pocket | mutation NOT in pocket | p-val |
|---|---|---|---|
| pcorr | 0.012 | -0.008 | 0.7244 |
| scorr | 0.014 | 0.025 | 0.7289 |
| mse | 1.977 | 2.016 | 0.7244 |
| mae | 0.985 | 0.979 | 0.8043 |
| rmse | 1.406 | 1.418 | 0.7599 |
## davis_aflow

### 1. Predictive performance

| metric | mean predictive performance |
|---|---|
| cindex | 0.446 |
| pcorr | -0.127 |
| scorr | -0.172 |
| mse | 2.253 |
| mae | 1.236 |
| rmse | 1.499 |

### 2. Mutation impact analysis

#### 2.1 Δpkd predictive performance

| metric | with overlap | without overlap | p-val |
|---|---|---|---|
| pcorr | -0.007 | -0.007 | 1 |
| scorr | -0.019 | -0.019 | 1 |
| mse | 2.015 | 2.015 | 1 |
| mae | 0.922 | 0.922 | 1 |
| rmse | 1.419 | 1.419 | 1 |

#### 2.2 Stratified by location of mutation (binding pocket vs not in binding pocket)

| | counts |
|---|---|
| wt | 981 |
| pckt | 708 |
| not pckt | 241 |

| metric | mutation in pocket | mutation NOT in pocket | p-val |
|---|---|---|---|
| pcorr | -0.013 | 0.031 | 0.117 |
| scorr | -0.039 | 0.063 | 0.0208 |
| mse | 2.027 | 1.937 | 0.117 |
| mae | 0.922 | 0.951 | 0.4358 |
| rmse | 1.424 | 1.391 | 0.1167 |
## PDBbind_gvpl_aflow

### 1. Predictive performance

| metric | mean predictive performance |
|---|---|
| cindex | 0.641 |
| pcorr | 0.415 |
| scorr | 0.411 |
| mse | 0.947 |
| mae | 0.754 |
| rmse | 0.972 |

### 2. Mutation impact analysis

#### 2.1 Δpkd predictive performance

| metric | with overlap | without overlap | p-val |
|---|---|---|---|
| pcorr | 0.176 | 0.037 | 0.0058 |
| scorr | 0.099 | 0.046 | 0.0974 |
| mse | 1.649 | 1.927 | 0.0058 |
| mae | 0.899 | 1.014 | 0.0001 |
| rmse | 1.284 | 1.387 | 0.005 |

#### 2.2 Stratified by location of mutation (binding pocket vs not in binding pocket)

| | counts |
|---|---|
| wt | 981 |
| pckt | 708 |
| not pckt | 241 |

| metric | mutation in pocket | mutation NOT in pocket | p-val |
|---|---|---|---|
| pcorr | 0.180 | 0.110 | 0.0374 |
| scorr | 0.108 | 0.050 | 0.2533 |
| mse | 1.640 | 1.781 | 0.0374 |
| mae | 0.904 | 0.981 | 0.0003 |
| rmse | 1.280 | 1.334 | 0.0365 |
Still need to move this to root dir and create a sample SBATCH script for running it.
GVPLM_PDBbind1D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE