Platinum analysis #94

Closed
8 tasks done
jyaacoub opened this issue Apr 30, 2024 · 5 comments · Fixed by #140
@jyaacoub
Owner

jyaacoub commented Apr 30, 2024

  1. Build dataset
    1. Copy all aflow platinum conf files from #narval to #h4h ✅ 2024-04-29
    2. Init new platinum dataset with new confs ✅ 2024-04-29
create_datasets(cfg.DATA_OPT.platinum, 
                cfg.PRO_FEAT_OPT.nomsa, cfg.PRO_EDGE_OPT.aflow, 
                ligand_features=cfg.LIG_FEAT_OPT.gvp,
                ligand_edges=cfg.LIG_EDGE_OPT.binary,
                k_folds=None, train_split=0, val_split=0)		
  2. Run inference
    1. Find trained weights for GVPL-aflow model on pdbbind (all 5 for each split to use as an ensemble) ✅ 2024-04-29
      • Model key is GVPLM_PDBbind1D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE
      • They are all on #h4h
    2. Save predictions of each model to a csv file ✅ 2024-04-29
loaders = Loader.load_DataLoaders(cfg.DATA_OPT.platinum,
                               cfg.PRO_FEAT_OPT.nomsa, cfg.PRO_EDGE_OPT.aflow,
                               ligand_feature=cfg.LIG_FEAT_OPT.gvp, ligand_edge=cfg.LIG_EDGE_OPT.binary,
                               datasets=['test'])

#%%
model = Loader.init_model(cfg.MODEL_OPT.GVPL, cfg.PRO_FEAT_OPT.nomsa, cfg.PRO_EDGE_OPT.aflow,
                          dropout=0.02414, output_dim=256)

#%%
cp_dir = "/cluster/home/t122995uhn/projects/MutDTA/results/model_checkpoints/ours"
MODEL_KEY = lambda fold: f"GVPLM_PDBbind{fold}D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE"
cp = lambda fold: f"{cp_dir}/{MODEL_KEY(fold)}.model"

out_dir = f'{cfg.MEDIA_SAVE_DIR}/test_set_pred/'
os.makedirs(out_dir, exist_ok=True)

for i in range(5):
    model.safe_load_state_dict(torch.load(cp(i), map_location=device))
    model.to(device)
    model.eval()

    loss, pred, actual = test(model, loaders['test'], device, verbose=True)
    
    # saving as csv with columns code, pred, actual
    # get codes from test loader
    codes, pid = [b['code'][0] for b in loaders['test']], [b['prot_id'][0] for b in loaders['test']]
    df = pd.DataFrame({'prot_id': pid, 'pred': pred, 'actual': actual}, index=codes)
    df.index.name = 'code'
    df.to_csv(f'{out_dir}/{MODEL_KEY(i)}_PLATINUM.csv')
  3. Analysis
    1. Check for overlap in PDB IDs with the training set and remove the overlapping entries
      • Provide a plot of performance with and without the overlap?
    2. How can I analyze the model's ability to detect dangerous mutations?
      1. Standard predictive performance analysis with cindex and MSE scores
        • Statistical t-test to determine how different the predicted and experimental distributions are from each other
      2. Mutation impact analysis:
        • Calculate $\Delta pkd$ (the change in binding affinity between the mutated and wildtype protein for each protein-ligand pair), then correlate the predicted $\Delta pkd$ with the experimental $\Delta pkd$
          • A higher correlation indicates that the model captures the impact of mutations
        • Split up scores by number of mutations for further analysis (the model may struggle with a larger number of mutations)
      3. Identifying "significant" mutations
        • Based on the distribution of $\Delta pkd$ scores, classify all mutations above 1 or 2 standard deviations as "significant"?
        • Then build a confusion matrix to analyze the model's true positive rate and true negative rate in identifying these mutations.
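
A minimal sketch of the "significant mutations" idea from the list above, written in plain Python so it runs without the project's dependencies. It is not the repo's implementation: `is_significant`, `confusion_counts`, the toy values, and the choice to measure the absolute deviation from each distribution's own mean are all assumptions made for illustration.

```python
from statistics import mean, pstdev

def is_significant(values, k=1.0):
    """Flag values lying more than k standard deviations from the mean (assumed rule)."""
    mu, sigma = mean(values), pstdev(values)
    return [abs(v - mu) > k * sigma for v in values]

def confusion_counts(true_flags, pred_flags):
    """Return (tp, fp, tn, fn) for two equal-length lists of booleans."""
    pairs = list(zip(true_flags, pred_flags))
    tp = sum(t and p for t, p in pairs)
    fp = sum((not t) and p for t, p in pairs)
    tn = sum((not t) and (not p) for t, p in pairs)
    fn = sum(t and (not p) for t, p in pairs)
    return tp, fp, tn, fn

# Toy experimental and predicted delta-pkd values (invented, not from the dataset).
true_dpkd = [0.1, -0.2, 2.5, 0.0, -2.4, 0.3]
pred_dpkd = [0.0, 0.1, 1.9, -0.1, 0.2, 0.1]

true_sig = is_significant(true_dpkd, k=1.0)
pred_sig = is_significant(pred_dpkd, k=1.0)
tp, fp, tn, fn = confusion_counts(true_sig, pred_sig)

# True positive rate and true negative rate, guarding against empty classes.
tpr = tp / (tp + fn) if (tp + fn) else float("nan")
tnr = tn / (tn + fp) if (tn + fp) else float("nan")
print(f"TP={tp} FP={fp} TN={tn} FN={fn} TPR={tpr:.2f} TNR={tnr:.2f}")
```

Changing `k` to 2.0 gives the stricter "2 standard deviations" variant mentioned in the list.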
jyaacoub added a commit that referenced this issue Apr 30, 2024
significant mutation analysis
@jyaacoub
Owner Author

jyaacoub commented May 1, 2024

There is, unfortunately, a large overlap between the PDBbind training data and the Platinum dataset.

Removing all exact PDB ID matches leaves us with 975 rows (967 if we drop NaN):
image

However, if we treat mutated and wildtype proteins as the same protein, we are left with 480 rows.
image

code
import pandas as pd

df = pd.read_csv("results/model_media/test_set_pred/GVPLM_PDBbind0D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE_PLATINUM.csv",
                 index_col=0)

# training set codes:
data_p = lambda set: f'/cluster/home/t122995uhn/projects/data/PDBbindDataset/nomsa_aflow_gvp_binary/{set}0/cleaned_XY.csv'
df_t = pd.Index.append(pd.read_csv(data_p('train'), index_col=0).index, 
                       pd.read_csv(data_p('val'), index_col=0).index)

df_t = df_t.str.upper()
df['pdb'] = df['prot_id'].str.split('_').str[0]

#%% remove training codes from df
# dont remove wt prots:
# df = df[~(df['pdb'].isin(df_t) & df.index.str.contains('_mt'))]
print(df)

# remove all 
df = df[~(df['pdb'].isin(df_t))]
print(df)

# %% treat mutated and wt proteins as the same 
wt_df = df[df.index.str.contains("_wt")]
mt_df = df[df.index.str.contains("_mt")]

missing_wt = delta_pkds = 0
for m in mt_df.index:
    i_wt = m.split('_')[0] + '_wt'
    if i_wt not in wt_df.index:
        missing_wt += 1
    else:
        delta_pkds += 1

print("missing wt:", missing_wt)
print("delta_pkds:", delta_pkds)

@jyaacoub
Owner Author

jyaacoub commented May 1, 2024

1. predictive performance

raw:

| | with overlap | without overlap | p-val |
|---|---|---|---|
| cindex | 0.674 $\pm$ 0.010 | 0.641 $\pm$ 0.011 | 0.0624 |
| pcorr | 0.530 $\pm$ 0.023 | 0.415 $\pm$ 0.028 | 0.0139 |
| scorr | 0.503 $\pm$ 0.030 | 0.411 $\pm$ 0.039 | 0.0973 |
| mse | 3.106 $\pm$ 0.121 | 3.467 $\pm$ 0.112 | 0.0598 |
| mae | 1.380 $\pm$ 0.026 | 1.467 $\pm$ 0.014 | 0.0198 |
| rmse | 1.761 $\pm$ 0.035 | 1.861 $\pm$ 0.030 | 0.062 |

distribution

image
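
For reference, the `cindex` row and the `mean $\pm$ se` style entries in the table above could be computed roughly as follows. This is a sketch under assumptions: the repo's `get_metrics` may define the concordance index differently (for example in how it handles ties), and `concordance_index`, `mean_and_se`, and the toy numbers are hypothetical names and values, not taken from the codebase or the results.

```python
from itertools import combinations
from math import sqrt
from statistics import mean, stdev

def concordance_index(actual, pred):
    """Fraction of comparable pairs that pred ranks in the same order as actual.

    Pairs with equal `actual` values are skipped; tied predictions count as 0.5.
    """
    concordant, comparable = 0.0, 0
    for (a1, p1), (a2, p2) in combinations(zip(actual, pred), 2):
        if a1 == a2:
            continue
        comparable += 1
        if p1 == p2:
            concordant += 0.5
        elif (p1 < p2) == (a1 < a2):
            concordant += 1.0
    return concordant / comparable

def mean_and_se(per_model_scores):
    """Mean and standard error of one metric across the ensemble members."""
    n = len(per_model_scores)
    return mean(per_model_scores), stdev(per_model_scores) / sqrt(n)

# Toy pkd values and predictions (invented for illustration).
actual = [5.0, 6.2, 7.1, 4.3]
pred = [5.5, 6.0, 6.1, 5.9]
ci = concordance_index(actual, pred)

# One hypothetical metric value per trained fold.
m, se = mean_and_se([0.66, 0.68, 0.70, 0.65, 0.68])
print(f"cindex={ci:.3f}  mean={m:.3f} +/- {se:.3f}")
```

A cindex of 0.5 corresponds to random ranking and 1.0 to perfect ranking, which is a useful yardstick when reading the tables in this thread.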

z-normalized:

| | with overlap | without overlap | p-val |
|---|---|---|---|
| cindex | 0.674 $\pm$ 0.010 | 0.641 $\pm$ 0.011 | 0.0624 |
| pcorr | 0.530 $\pm$ 0.023 | 0.415 $\pm$ 0.028 | 0.0139 |
| scorr | 0.503 $\pm$ 0.030 | 0.411 $\pm$ 0.039 | 0.0973 |
| mse | 0.939 $\pm$ 0.046 | 0.947 $\pm$ 0.034 | 0.9008 |
| mae | 0.758 $\pm$ 0.020 | 0.754 $\pm$ 0.009 | 0.8436 |
| rmse | 0.968 $\pm$ 0.024 | 0.972 $\pm$ 0.018 | 0.8879 |

distribution

image

code
def predictive_performance(
    MODEL = lambda i: f"results/model_media/test_set_pred/GVPLM_PDBbind{i}D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE_PLATINUM.csv",
    TRAIN_DATA_P = lambda set: f'{cfg.DATA_ROOT}/PDBbindDataset/nomsa_aflow_gvp_binary/{set}0/cleaned_XY.csv',
    NORMALIZE = True,
    n_models=5,
    compare_overlap=False,
    verbose=True,
    plot=False,
    ):
    df_t = pd.Index.append(pd.read_csv(TRAIN_DATA_P('train'), index_col=0).index, 
                        pd.read_csv(TRAIN_DATA_P('val'), index_col=0).index)
    df_t = df_t.str.upper()

    results_with_overlap = []
    results_without_overlap = []

    for i in range(n_models):
        df = pd.read_csv(MODEL(i), index_col=0).dropna()
        df['pdb'] = df['prot_id'].str.split('_').str[0]
        if NORMALIZE:
            mean_df = df[['actual','pred']].mean(axis=0, numeric_only=True)
            std_df = df[['actual','pred']].std(axis=0, numeric_only=True)
            
            df[['actual','pred']] = (df[['actual','pred']] - mean_df) / std_df # z-normalization
        if i==0: print(df)

        # with overlap
        cindex, p_corr, s_corr, mse, mae, rmse = get_metrics(df['actual'], df['pred'])
        results_with_overlap.append([cindex, p_corr[0], s_corr[0], mse, mae, rmse])

        # without overlap
        df_no_overlap = df[~(df['pdb'].isin(df_t))]
        cindex, p_corr, s_corr, mse, mae, rmse = get_metrics(df_no_overlap['actual'], df_no_overlap['pred'])
        results_without_overlap.append([cindex, p_corr[0], s_corr[0], mse, mae, rmse])

        if i==0 and plot:
            n_plots = int(compare_overlap)+1
            fig = plt.figure(figsize=(14,5*n_plots))
            axes = fig.subplots(n_plots,1)
            ax = axes[0] if compare_overlap else axes
            
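            # first panel: full set when comparing overlap, otherwise the no-overlap subset
            df_plot = df if compare_overlap else df_no_overlap
            sns.histplot(df_plot['actual'], kde=True, ax=ax, alpha=0.5, label='True pkd')
            sns.histplot(df_plot['pred'], kde=True, ax=ax, alpha=0.5, label='Predicted pkd', color='orange')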
            ax.set_title(f"{'Normalized 'if NORMALIZE else ''} pkd distribution")
            ax.legend()
            
            if compare_overlap:
                sns.histplot(df_no_overlap['actual'], kde=True, ax=axes[1], alpha=0.5, label='True pkd')
                sns.histplot(df_no_overlap['pred'], kde=True, ax=axes[1], alpha=0.5, label='Predicted pkd', color='orange')
                axes[1].set_title(f"{'Normalized 'if NORMALIZE else ''}  pkd distribution (no overlap)")
                axes[1].legend()

    if compare_overlap:
        return generate_markdown([results_with_overlap, results_without_overlap], names=['with overlap', 'without overlap'], 
                             cindex=True,verbose=verbose)
    
    # raw string so that "\p" is not treated as an (invalid) escape sequence
    return generate_markdown([results_without_overlap], names=[r'mean $\pm$ se'], cindex=True, verbose=verbose)

@jyaacoub
Owner Author

jyaacoub commented May 1, 2024

2. Mutation impact analysis

Same analysis as above, but on $\Delta pkd$ (the difference between mutant and wildtype pkd) this time.

2.1 delta pkd predictive performance

raw $\Delta pkd$

Model-0 Distribution

image

| metric | With Overlap | Without Overlap | Significance |
|---|---|---|---|
| pcorr | 0.176 $\pm$ 0.026 | 0.037 $\pm$ 0.079 | * |
| scorr | 0.099 $\pm$ 0.019 | 0.046 $\pm$ 0.060 | |
| mse | 1.505 $\pm$ 0.009 | 1.303 $\pm$ 0.006 | * |
| mae | 0.905 $\pm$ 0.003 | 0.847 $\pm$ 0.002 | * |
| rmse | 1.227 $\pm$ 0.004 | 1.141 $\pm$ 0.003 | * |

Z-normalized

Model-0 Distribution

image

| | With Overlap | Without Overlap | Significance |
|---|---|---|---|
| pcorr | 0.176 $\pm$ 0.026 | 0.037 $\pm$ 0.079 | * |
| scorr | 0.099 $\pm$ 0.019 | 0.046 $\pm$ 0.060 | |
| mse | 1.649 $\pm$ 0.053 | 1.927 $\pm$ 0.158 | * |
| mae | 0.899 $\pm$ 0.029 | 1.014 $\pm$ 0.023 | * |
| rmse | 1.284 $\pm$ 0.021 | 1.387 $\pm$ 0.057 | * |

Code
from src.analysis.figures import tbl_dpkd_metrics_overlap, tbl_dpkd_metrics_n_mut
MODEL = lambda i: f"results/model_media/test_set_pred/GVPLM_PDBbind{i}D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE_PLATINUM.csv"
TRAIN_DATA_P = lambda set: f'/cluster/home/t122995uhn/projects/data/PDBbindDataset/nomsa_aflow_gvp_binary/{set}0/cleaned_XY.csv'
NORMALIZE = True

# %%
print('OVERLAP')
mkdo = tbl_dpkd_metrics_overlap(MODEL, TRAIN_DATA_P, NORMALIZE, plot=False)

print('NUM MUTATIONS:')
mkdnm = tbl_dpkd_metrics_n_mut(MODEL, NORMALIZE, plot=False)

2.2. Stratify by mutation count

image

2 classes "single mutation" vs "2+ mutations"

histogram

image

| | 1 mutation | 2+ mutations | Sig |
|---|---|---|---|
| pcorr | 0.076 $\pm$ 0.019 | 0.336 $\pm$ 0.047 | * |
| scorr | 0.053 $\pm$ 0.015 | 0.252 $\pm$ 0.043 | * |
| mse | 1.848 $\pm$ 0.038 | 1.328 $\pm$ 0.095 | * |
| mae | 0.961 $\pm$ 0.031 | 0.833 $\pm$ 0.028 | * |
| rmse | 1.359 $\pm$ 0.014 | 1.152 $\pm$ 0.041 | * |

3 classes "single", "2", "3+"

histogram

image

| | 1 mutation | 2 mutations | 3+ mutations |
|---|---|---|---|
| pcorr | 0.076 $\pm$ 0.019 | 0.207 $\pm$ 0.055 | 0.509 $\pm$ 0.070 |
| scorr | 0.053 $\pm$ 0.015 | 0.131 $\pm$ 0.038 | 0.496 $\pm$ 0.078 |
| mse | 1.848 $\pm$ 0.038 | 1.586 $\pm$ 0.110 | 0.982 $\pm$ 0.140 |
| mae | 0.961 $\pm$ 0.031 | 0.964 $\pm$ 0.016 | 0.732 $\pm$ 0.022 |
| rmse | 1.359 $\pm$ 0.014 | 1.259 $\pm$ 0.043 | 0.989 $\pm$ 0.070 |

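
The stratification by mutation count can be pictured with the sketch below, which groups rows into "1", "2", "3+" (or "1", "2+") buckets and reports a per-bucket MSE. It only loosely mirrors the idea of `conditions=[1,2]`; how `tbl_dpkd_metrics_n_mut` actually interprets that argument is not shown in this thread, and `bucket`, `mse_by_bucket`, and the toy rows are hypothetical.

```python
from collections import defaultdict
from statistics import mean

def bucket(n_mut, cutoffs=(1, 2)):
    """Label a mutation count: one class per cutoff, then '<last cutoff + 1>+'."""
    if n_mut < 1:
        raise ValueError("mutant rows must carry at least one mutation")
    for c in cutoffs:
        if n_mut == c:
            return str(c)
    return f"{cutoffs[-1] + 1}+"

def mse_by_bucket(rows, cutoffs=(1, 2)):
    """rows: iterable of (n_mut, true_dpkd, pred_dpkd) tuples."""
    sq_err = defaultdict(list)
    for n_mut, t, p in rows:
        sq_err[bucket(n_mut, cutoffs)].append((t - p) ** 2)
    return {b: mean(errs) for b, errs in sq_err.items()}

# Toy rows (mutation count, experimental delta-pkd, predicted delta-pkd).
rows = [(1, 0.5, 0.1), (1, -1.0, 0.0), (2, 1.2, 0.8), (3, 2.0, 1.5), (5, -1.5, -1.0)]

print(mse_by_bucket(rows))                # three classes: '1', '2', '3+'
print(mse_by_bucket(rows, cutoffs=(1,)))  # two classes: '1', '2+'
```

When reading the tables above, keep in mind that the classes with more mutations contain fewer rows, so their estimates rest on smaller samples.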
CODE
from src.analysis.figures import tbl_dpkd_metrics_overlap, tbl_dpkd_metrics_n_mut
MODEL = lambda i: f"results/model_media/test_set_pred/GVPLM_PDBbind{i}D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE_PLATINUM.csv"
TRAIN_DATA_P = lambda set: f'/cluster/home/t122995uhn/projects/data/PDBbindDataset/nomsa_aflow_gvp_binary/{set}0/cleaned_XY.csv'
NORMALIZE = True

print('NUM MUTATIONS:')
mkdnm = tbl_dpkd_metrics_n_mut(MODEL, NORMALIZE, conditions=[1,2], plot=True)

2.3. Stratify by location of mutation

2.3.1. binding pocket vs not in binding pocket

histogram

image

| | mutation in pocket | mutation NOT in pocket | Sig |
|---|---|---|---|
| pcorr | 0.180 $\pm$ 0.029 | 0.110 $\pm$ 0.056 | * |
| scorr | 0.108 $\pm$ 0.022 | 0.050 $\pm$ 0.102 | |
| mse | 1.640 $\pm$ 0.057 | 1.781 $\pm$ 0.113 | * |
| mae | 0.904 $\pm$ 0.021 | 0.981 $\pm$ 0.018 | * |
| rmse | 1.280 $\pm$ 0.022 | 1.334 $\pm$ 0.042 | * |

code

# %%
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from src import config as cfg
from src.analysis.figures import get_dpkd
NORMALIZE = True
# NOTE: get_in_binding must already be defined; the variant used here labels rows in a
# 'pocket' column (0: wildtype, 1: mutation in pocket, 2: mutation not in pocket)
dfr = pd.read_csv(f"{cfg.DATA_ROOT}/PlatinumDataset/raw/platinum_flat_file.csv", index_col=0)
df = pd.read_csv("/cluster/home/t122995uhn/projects/data/PlatinumDataset/nomsa_binary_original_binary/full/cleaned_XY.csv", index_col=0).dropna()
df = get_in_binding(df, dfr=dfr)

fig = plt.figure(figsize=(14,5))
ax = fig.subplots(1,1)

# must include 0 in both cases since they are the wildtype reference 
dpkd_not_pocket = get_dpkd(df.query('(pocket == 0) | (pocket == 2)'), 'pkd', NORMALIZE)
sns.histplot(dpkd_not_pocket, kde=True, ax=ax, alpha=0.6, color='orange', label='not in pocket', stat='proportion')
dpkd_in_pocket = get_dpkd(df.query('(pocket == 0) | (pocket == 1)'), 'pkd', NORMALIZE)
sns.histplot(dpkd_in_pocket, kde=True, ax=ax, alpha=0.6, color=None, label='in pocket', stat='proportion')
ax.set_title(f"{'Normalized 'if NORMALIZE else ''}TRUE Δpkd distribution")
ax.set_xlabel('Δpkd')
ax.legend()

2.3.2. distance to ligand

image

  • We will treat this as a classification problem (near is <4A)

| | counts |
|---|---|
| wt | 981 |
| near lig | 577 |
| not near lig | 372 |

| | mutation near lig (<4A) | mutation not near lig (>4A) | p-val |
|---|---|---|---|
| pcorr | 0.164 $\pm$ 0.012 | 0.198 $\pm$ 0.017 | 0.1405 |
| scorr | 0.104 $\pm$ 0.009 | 0.079 $\pm$ 0.015 | 0.1844 |
| mse | 1.672 $\pm$ 0.023 | 1.604 $\pm$ 0.034 | 0.1405 |
| mae | 0.912 $\pm$ 0.010 | 0.939 $\pm$ 0.005 | 0.0408 |
| rmse | 1.293 $\pm$ 0.009 | 1.266 $\pm$ 0.013 | 0.1381 |

Distribution

image

Code

from src.analysis.figures import tbl_dpkd_metrics_overlap, tbl_dpkd_metrics_n_mut, tbl_dpkd_metrics_in_binding, predictive_performance, tbl_stratified_dpkd_metrics
from src.analysis.metrics import get_metrics
from src import config as cfg
import pandas as pd

#%%
MODEL = lambda i: f"results/model_media/test_set_pred/GVPLM_PDBbind{i}D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE_PLATINUM.csv"
RAW_PLT_CSV=f"{cfg.DATA_ROOT}/PlatinumDataset/raw/platinum_flat_file.csv"
NORMALIZE = True
n_models=5
verbose=True
plot=True
dfr = pd.read_csv(RAW_PLT_CSV, index_col=0)
dfp = pd.read_csv(MODEL(0), index_col=0)

#%%
import seaborn as sns
sns.histplot(dfr['mut.distance_to_lig'])


#%%
# add in_binding info to df
thres = 4
def get_in_binding(df, dfr):
    """
    df is the predicted csv with index as <raw_idx>_wt (or *_mt) where raw_idx 
    corresponds to an index in dfr which contains the raw data for platinum including 
    ('mut.distance_to_lig'). Adds a 'near_lig' column to df (in place) with:
        - 0: wildtype rows
        - 1: near (< thres Ang)
        - 2: not near (>= thres Ang, or distance missing)
    """
    near_lig = dfr[dfr['mut.distance_to_lig'] < thres].index   
    pclass = []
    for code in df.index:
        if '_wt' in code:
            pclass.append(0)
        elif int(code.split('_')[0]) in near_lig:
            pclass.append(1)
        else:
            pclass.append(2)
            
    df['near_lig'] = pclass
    return df

conditions = ['(near_lig == 0) | (near_lig == 1)', '(near_lig == 0) | (near_lig == 2)']
names = [f'mutation near lig (<{thres}A)', f'mutation not near lig (>{thres}A)']

df = get_in_binding(dfp, dfr)
if verbose: 
    # sort by class label (0, 1, 2) so the names below line up regardless of class sizes
    cnts = df.near_lig.value_counts().sort_index()
    cnts.index = ['wt', 'near lig', 'not near lig']
    cnts.name = "counts"
    print(cnts.to_markdown(), end="\n\n")

#%%
tbl_stratified_dpkd_metrics(MODEL, NORMALIZE, n_models=n_models, df_transform=get_in_binding,
                                    conditions=conditions, names=names, verbose=verbose, plot=plot, dfr=dfr)

@jyaacoub
Owner Author

jyaacoub commented May 1, 2024

3. Significant Mutation impact analysis

With overlap, the best threshold is 0.1*STD:

Figures (ROC curve and sample confusion matrix)

image
image

Without overlap, the best threshold is 0.3*STD:

Figures (ROC curve and sample confusion matrix)

image
image
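
One way to arrive at a "best threshold" expressed as a multiple of the standard deviation is to sweep candidate multiples and keep the one that maximises Youden's J (TPR minus FPR). The sketch below does that in plain Python. It is an assumption about the selection rule, not a description of `generate_roc_curve`, whose internals are not shown here; `flags`, `rates`, `sweep`, and the toy values are hypothetical.

```python
from statistics import mean, pstdev

def flags(values, k):
    """Mark values more than k standard deviations away from their own mean."""
    mu, sigma = mean(values), pstdev(values)
    return [abs(v - mu) > k * sigma for v in values]

def rates(true_flags, pred_flags):
    """Return (tpr, fpr), or None when either class is empty."""
    pos = sum(true_flags)
    neg = len(true_flags) - pos
    if pos == 0 or neg == 0:
        return None
    tp = sum(t and p for t, p in zip(true_flags, pred_flags))
    fp = sum((not t) and p for t, p in zip(true_flags, pred_flags))
    return tp / pos, fp / neg

def sweep(true_vals, pred_vals, start=0.0, stop=5.0, step=0.1):
    """Try k = start, start+step, ..., stop; return (best_k, best_j) by Youden's J."""
    best_k, best_j = None, float("-inf")
    n_steps = int(round((stop - start) / step))
    for i in range(n_steps + 1):
        k = round(start + i * step, 10)  # avoid floating-point drift in k
        r = rates(flags(true_vals, k), flags(pred_vals, k))
        if r is None:
            continue  # every row or no row is "significant" at this k
        tpr, fpr = r
        if tpr - fpr > best_j:
            best_k, best_j = k, tpr - fpr
    return best_k, best_j

# Toy experimental and predicted delta-pkd values (invented for illustration).
true_dpkd = [0.1, -0.2, 2.5, 0.0, -2.4, 0.3, 1.1, -0.9]
pred_dpkd = [0.0, 0.1, 1.9, -0.1, -1.2, 0.1, 0.2, -0.3]
best_k, best_j = sweep(true_dpkd, pred_dpkd)
print(best_k, round(best_j, 3))
```

A very small selected multiple, such as the 0.1*STD reported with overlap, means nearly every mutation is labelled significant, so the resulting confusion matrix should be read with the class balance in mind.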

code
#%%
from src.analysis.figures import get_dpkd, fig_sig_mutations_conf_matrix, generate_roc_curve
from src.analysis.metrics import get_metrics
import numpy as np
import pandas as pd
from scipy.stats import ttest_ind

MODEL = lambda i: f"results/model_media/test_set_pred/GVPLM_PDBbind{i}D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE_PLATINUM.csv" 

data_p = lambda set: f'/cluster/home/t122995uhn/projects/data/PDBbindDataset/nomsa_aflow_gvp_binary/{set}0/cleaned_XY.csv'
df_t = pd.Index.append(pd.read_csv(data_p('train'), index_col=0).index, 
                       pd.read_csv(data_p('val'), index_col=0).index)
df_t = df_t.str.upper()

results_with_overlap = []
results_without_overlap = []
i=0

df = pd.read_csv(MODEL(i), index_col=0).dropna()
df['pdb'] = df['prot_id'].str.split('_').str[0]
df_no = df[~(df['pdb'].isin(df_t))]

#%%
true_dpkd = get_dpkd(df, pkd_col='actual')
pred_dpkd = get_dpkd(df, pkd_col='pred')
true_dpkd_no = get_dpkd(df_no, pkd_col='actual')
pred_dpkd_no = get_dpkd(df_no, pkd_col='pred')

# %%
# ROC 
_, _, _, best_threshold = generate_roc_curve(true_dpkd, pred_dpkd, thres_range=(0,5), step=0.1)
_ = fig_sig_mutations_conf_matrix(true_dpkd, pred_dpkd, std=round(best_threshold, 3))

# %%
_, _, _, best_threshold = generate_roc_curve(true_dpkd_no, pred_dpkd_no, thres_range=(0,5), step=0.1)
_ = fig_sig_mutations_conf_matrix(true_dpkd_no, pred_dpkd_no, std=round(best_threshold, 3))

jyaacoub added a commit that referenced this issue May 1, 2024
jyaacoub added a commit that referenced this issue May 2, 2024
jyaacoub added a commit that referenced this issue May 2, 2024
jyaacoub added a commit that referenced this issue May 2, 2024
jyaacoub added a commit that referenced this issue May 4, 2024
jyaacoub added a commit that referenced this issue May 4, 2024
@jyaacoub jyaacoub closed this as completed May 4, 2024
jyaacoub added a commit that referenced this issue May 8, 2024
Platinum analysis figures and TCGA init #94 and #95
@jyaacoub jyaacoub pinned this issue May 8, 2024
jyaacoub added a commit that referenced this issue May 8, 2024
ligand_name is only really used by platinum (#26, #27). Davis and kiba use CIDs and CHEMBL respectively.
jyaacoub added a commit that referenced this issue May 8, 2024
feat(download): cids + CHEMBL for sdf download via pubchem #94 #27
@jyaacoub
Owner Author

jyaacoub commented May 14, 2024

Pretrained Davis results

outline

  • Build dataset from aflow conformations
  • Run hyperparameter tuning
  • Train davis
  • Run platinum evaluation

results:

Code

# %%
import torch, os
import pandas as pd

from src import cfg
from src import TUNED_MODEL_CONFIGS
from src.utils.loader import Loader
from src.train_test.training import test
from src.analysis.figures import predictive_performance, tbl_stratified_dpkd_metrics, tbl_dpkd_metrics_overlap, tbl_dpkd_metrics_in_binding

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

INFERENCE = False
VERBOSE = True
out_dir = f'{cfg.MEDIA_SAVE_DIR}/test_set_pred/'
os.makedirs(out_dir, exist_ok=True)
cp_dir = cfg.CHECKPOINT_SAVE_DIR
RAW_PLT_CSV=f"{cfg.DATA_ROOT}/PlatinumDataset/raw/platinum_flat_file.csv"

#%% load up model:
for KEY, CONFIG in TUNED_MODEL_CONFIGS.items():
    MODEL_KEY = lambda fold: Loader.get_model_key(CONFIG['model'], CONFIG['dataset'], CONFIG['feature_opt'], CONFIG['edge_opt'], 
                                    CONFIG['batch_size'], CONFIG['lr'], CONFIG['architecture_kwargs']['dropout'],
                                    n_epochs=2000, fold=fold, 
                                    ligand_feature=CONFIG['lig_feat_opt'], ligand_edge=CONFIG['lig_edge_opt'])
    print('\n\n'+ '## ' + KEY)
    OUT_PLT = lambda i: f'{out_dir}/{MODEL_KEY(i)}_PLATINUM.csv'
    db_p = f"{CONFIG['feature_opt']}_{CONFIG['edge_opt']}_{CONFIG['lig_feat_opt']}_{CONFIG['lig_edge_opt']}"
    
    if CONFIG['dataset'] in ['kiba', 'davis']:
        db_p = f"DavisKibaDataset/{CONFIG['dataset']}/{db_p}"
    else:
        db_p = f"{CONFIG['dataset']}Dataset/{db_p}"
        
    train_p = lambda set: f"{cfg.DATA_ROOT}/{db_p}/{set}0/cleaned_XY.csv"
    
    if not os.path.exists(OUT_PLT(0)) or INFERENCE:
        print('running inference!')
        cp = lambda fold: f"{cp_dir}/{MODEL_KEY(fold)}.model"
        
        model = Loader.init_model(model=CONFIG["model"], pro_feature=CONFIG["feature_opt"],
                                    pro_edge=CONFIG["edge_opt"],**CONFIG['architecture_kwargs'])

        # load up platinum test db
        loaders = Loader.load_DataLoaders(cfg.DATA_OPT.platinum,
                                    pro_feature    = CONFIG['feature_opt'], 
                                    edge_opt       = CONFIG['edge_opt'],
                                    ligand_feature = CONFIG['lig_feat_opt'], 
                                    ligand_edge    = CONFIG['lig_edge_opt'],
                                    datasets=['test'])

        for i in range(5):
            model.safe_load_state_dict(torch.load(cp(i), map_location=device))
            model.to(device)
            model.eval()

            loss, pred, actual = test(model, loaders['test'], device, verbose=True)
            
            # saving as csv with columns code, pred, actual
            # get codes from test loader
            codes, pid = [b['code'][0] for b in loaders['test']], [b['prot_id'][0] for b in loaders['test']]
            df = pd.DataFrame({'prot_id': pid, 'pred': pred, 'actual': actual}, index=codes)
            df.index.name = 'code'
            df.to_csv(OUT_PLT(i))

    # run platinum eval:
    print('\n### 1. predictive performance')
    mkdown = predictive_performance(OUT_PLT, train_p, verbose=VERBOSE, plot=False)
    print('\n### 2 Mutation impact analysis')
    print('\n#### 2.1 $\\Delta pkd$ predictive performance')  # escaped backslash: "\D" is an invalid escape
    mkdn = tbl_dpkd_metrics_overlap(OUT_PLT, train_p, verbose=VERBOSE, plot=False)
    print('\n#### 2.2 Stratified by location of mutation (binding pocket vs not in binding pocket)')
    m = tbl_dpkd_metrics_in_binding(OUT_PLT, RAW_PLT_CSV, verbose=VERBOSE, plot=False)

davis_gvpl_aflow

1. predictive performance

| | mean predictive performance |
|---|---|
| cindex | 0.472 $\pm$ 0.026 |
| pcorr | -0.062 $\pm$ 0.050 |
| scorr | -0.082 $\pm$ 0.078 |
| mse | 2.123 $\pm$ 0.099 |
| mae | 1.122 $\pm$ 0.049 |
| rmse | 1.455 $\pm$ 0.034 |

2 Mutation impact analysis

2.1 $\Delta pkd$ predictive performance

| | with overlap | without overlap | p-val |
|---|---|---|---|
| pcorr | 0.005 $\pm$ 0.012 | 0.005 $\pm$ 0.012 | 1 |
| scorr | 0.003 $\pm$ 0.025 | 0.003 $\pm$ 0.025 | 1 |
| mse | 1.990 $\pm$ 0.025 | 1.990 $\pm$ 0.025 | 1 |
| mae | 0.980 $\pm$ 0.013 | 0.980 $\pm$ 0.013 | 1 |
| rmse | 1.411 $\pm$ 0.009 | 1.411 $\pm$ 0.009 | 1 |

2.2 Stratified by location of mutation (binding pocket vs not in binding pocket)

| | counts |
|---|---|
| wt | 981 |
| pckt | 708 |
| not pckt | 241 |

| | mutation in pocket | mutation NOT in pocket | p-val |
|---|---|---|---|
| pcorr | 0.004 $\pm$ 0.015 | 0.025 $\pm$ 0.032 | 0.5777 |
| scorr | -0.009 $\pm$ 0.026 | 0.067 $\pm$ 0.071 | 0.3433 |
| mse | 1.992 $\pm$ 0.030 | 1.950 $\pm$ 0.065 | 0.5777 |
| mae | 1.000 $\pm$ 0.009 | 0.981 $\pm$ 0.021 | 0.4374 |
| rmse | 1.411 $\pm$ 0.011 | 1.396 $\pm$ 0.023 | 0.5578 |

davis_gvpl

1. predictive performance

| | mean predictive performance |
|---|---|
| cindex | 0.469 $\pm$ 0.007 |
| pcorr | -0.058 $\pm$ 0.020 |
| scorr | -0.095 $\pm$ 0.021 |
| mse | 2.114 $\pm$ 0.040 |
| mae | 1.148 $\pm$ 0.017 |
| rmse | 1.454 $\pm$ 0.014 |

2 Mutation impact analysis

2.1 $\Delta pkd$ predictive performance

| | with overlap | without overlap | p-val |
|---|---|---|---|
| pcorr | 0.008 $\pm$ 0.014 | 0.008 $\pm$ 0.014 | 1 |
| scorr | 0.023 $\pm$ 0.011 | 0.023 $\pm$ 0.011 | 1 |
| mse | 1.983 $\pm$ 0.029 | 1.983 $\pm$ 0.029 | 1 |
| mae | 0.974 $\pm$ 0.011 | 0.974 $\pm$ 0.011 | 1 |
| rmse | 1.408 $\pm$ 0.010 | 1.408 $\pm$ 0.010 | 1 |

2.2 Stratified by location of mutation (binding pocket vs not in binding pocket)

| | counts |
|---|---|
| wt | 981 |
| pckt | 725 |
| not pckt | 256 |

| | mutation in pocket | mutation NOT in pocket | p-val |
|---|---|---|---|
| pcorr | 0.012 $\pm$ 0.012 | -0.008 $\pm$ 0.052 | 0.7244 |
| scorr | 0.014 $\pm$ 0.019 | 0.025 $\pm$ 0.024 | 0.7289 |
| mse | 1.977 $\pm$ 0.024 | 2.016 $\pm$ 0.105 | 0.7244 |
| mae | 0.985 $\pm$ 0.014 | 0.979 $\pm$ 0.023 | 0.8043 |
| rmse | 1.406 $\pm$ 0.009 | 1.418 $\pm$ 0.037 | 0.7599 |

davis_aflow

1. predictive performance

| | mean predictive performance |
|---|---|
| cindex | 0.446 $\pm$ 0.018 |
| pcorr | -0.127 $\pm$ 0.059 |
| scorr | -0.172 $\pm$ 0.051 |
| mse | 2.253 $\pm$ 0.117 |
| mae | 1.236 $\pm$ 0.032 |
| rmse | 1.499 $\pm$ 0.039 |

2 Mutation impact analysis

2.1 $\Delta pkd$ predictive performance

| | with overlap | without overlap | p-val |
|---|---|---|---|
| pcorr | -0.007 $\pm$ 0.010 | -0.007 $\pm$ 0.010 | 1 |
| scorr | -0.019 $\pm$ 0.015 | -0.019 $\pm$ 0.015 | 1 |
| mse | 2.015 $\pm$ 0.019 | 2.015 $\pm$ 0.019 | 1 |
| mae | 0.922 $\pm$ 0.024 | 0.922 $\pm$ 0.024 | 1 |
| rmse | 1.419 $\pm$ 0.007 | 1.419 $\pm$ 0.007 | 1 |

2.2 Stratified by location of mutation (binding pocket vs not in binding pocket)

| | counts |
|---|---|
| wt | 981 |
| pckt | 708 |
| not pckt | 241 |

| | mutation in pocket | mutation NOT in pocket | p-val |
|---|---|---|---|
| pcorr | -0.013 $\pm$ 0.009 | 0.031 $\pm$ 0.024 | 0.117 |
| scorr | -0.039 $\pm$ 0.018 | 0.063 $\pm$ 0.030 | 0.0208 |
| mse | 2.027 $\pm$ 0.018 | 1.937 $\pm$ 0.048 | 0.117 |
| mae | 0.922 $\pm$ 0.023 | 0.951 $\pm$ 0.027 | 0.4358 |
| rmse | 1.424 $\pm$ 0.006 | 1.391 $\pm$ 0.017 | 0.1167 |

PDBbind_gvpl_aflow

1. predictive performance

| | mean predictive performance |
|---|---|
| cindex | 0.641 $\pm$ 0.011 |
| pcorr | 0.415 $\pm$ 0.028 |
| scorr | 0.411 $\pm$ 0.039 |
| mse | 0.947 $\pm$ 0.034 |
| mae | 0.754 $\pm$ 0.009 |
| rmse | 0.972 $\pm$ 0.018 |

2 Mutation impact analysis

2.1 $\Delta pkd$ predictive performance

| | with overlap | without overlap | p-val |
|---|---|---|---|
| pcorr | 0.176 $\pm$ 0.012 | 0.037 $\pm$ 0.035 | 0.0058 |
| scorr | 0.099 $\pm$ 0.009 | 0.046 $\pm$ 0.027 | 0.0974 |
| mse | 1.649 $\pm$ 0.024 | 1.927 $\pm$ 0.071 | 0.0058 |
| mae | 0.899 $\pm$ 0.013 | 1.014 $\pm$ 0.010 | 0.0001 |
| rmse | 1.284 $\pm$ 0.009 | 1.387 $\pm$ 0.025 | 0.005 |

2.2 Stratified by location of mutation (binding pocket vs not in binding pocket)

| | counts |
|---|---|
| wt | 981 |
| pckt | 708 |
| not pckt | 241 |

| | mutation in pocket | mutation NOT in pocket | p-val |
|---|---|---|---|
| pcorr | 0.180 $\pm$ 0.013 | 0.110 $\pm$ 0.025 | 0.0374 |
| scorr | 0.108 $\pm$ 0.010 | 0.050 $\pm$ 0.046 | 0.2533 |
| mse | 1.640 $\pm$ 0.026 | 1.781 $\pm$ 0.050 | 0.0374 |
| mae | 0.904 $\pm$ 0.010 | 0.981 $\pm$ 0.008 | 0.0003 |
| rmse | 1.280 $\pm$ 0.010 | 1.334 $\pm$ 0.019 | 0.0365 |

@jyaacoub jyaacoub reopened this May 14, 2024
jyaacoub added a commit that referenced this issue May 14, 2024
@jyaacoub jyaacoub mentioned this issue May 16, 2024
5 tasks
jyaacoub added a commit that referenced this issue May 29, 2024
jyaacoub added a commit that referenced this issue May 31, 2024
jyaacoub added a commit that referenced this issue Oct 3, 2024
Still need to move this to root dir and create a sample SBATCH script for running it.
@jyaacoub jyaacoub linked a pull request Oct 3, 2024 that will close this issue
jyaacoub added a commit that referenced this issue Oct 16, 2024