results: 1. predictive performance #94
jyaacoub committed May 3, 2024
1 parent 9df48e7 commit 78ad610
Showing 3 changed files with 134 additions and 47 deletions.
50 changes: 12 additions & 38 deletions playground.py
@@ -1,46 +1,20 @@
# %%
from src.analysis.figures import tbl_dpkd_metrics_n_mut, tbl_stratified_dpkd_metrics
MODEL = lambda i: f"results/model_media/test_set_pred/GVPLM_PDBbind{i}D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE_PLATINUM.csv"
TRAIN_DATA_P = lambda set: f'/cluster/home/t122995uhn/projects/data/PDBbindDataset/nomsa_aflow_gvp_binary/{set}0/cleaned_XY.csv'
NORMALIZE = True
from src.analysis.figures import tbl_dpkd_metrics_overlap, tbl_dpkd_metrics_n_mut, tbl_dpkd_metrics_in_binding, predictive_performance
from src.analysis.metrics import get_metrics

_ = predictive_performance(compare_overlap=True, verbose=True, plot=True, NORMALIZE=False)

# %%
_ = predictive_performance(compare_overlap=True, verbose=True, plot=True, NORMALIZE=True)

# %%
tbl_dpkd_metrics_overlap()

#%%
# add in_binding info to df
def get_in_binding(df, dfr):
    """
    df is the predicted csv with index as <raw_idx>_wt (or *_mt) where raw_idx
    corresponds to an index in dfr which contains the raw data for platinum including
    ('mut.in_binding_site')
        - 0: wildtype rows
        - 1: in pocket
        - 2: outside of pocket
    """
    in_pocket = dfr[dfr['mut.in_binding_site'] == 'YES'].index
    pclass = []
    for code in df.index:
        if '_wt' in code:
            pclass.append(0)
        elif int(code.split('_')[0]) in in_pocket:
            pclass.append(1)
        else:
            pclass.append(2)

    df['in_pocket'] = pclass
    return df
tbl_dpkd_metrics_n_mut()

# get df_binding info from /cluster/home/t122995uhn/projects/data/PlatinumDataset/raw/platinum_flat_file.csv
conditions = ['(in_pocket == 0) | (in_pocket == 1)', '(in_pocket == 0) | (in_pocket == 2)']
names = ['mutation in pocket', 'mutation NOT in pocket']
#%%
import pandas as pd
dfr = pd.read_csv('/cluster/home/t122995uhn/projects/data/PlatinumDataset/raw/platinum_flat_file.csv', index_col=0)
dfp = pd.read_csv(MODEL(0), index_col=0)
tbl_dpkd_metrics_in_binding()

df = get_in_binding(dfp, dfr)
print(df.in_pocket.value_counts())

#%%
tbl_stratified_dpkd_metrics(MODEL, NORMALIZE, n_models=5, df_transform=get_in_binding,
                            conditions=conditions, names=names, verbose=True, plot=True, dfr=dfr)
# %%
#%
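As a quick illustration (toy values, not from the repository) of the two preprocessing steps that the new predictive_performance helper applies to each model's test-set predictions: z-normalizing the actual/pred columns and dropping entries whose PDB code also appears in the train/val split.

import pandas as pd

# toy test-set predictions and a toy set of PDB codes seen during training
df = pd.DataFrame({'prot_id': ['1ABC_x', '2DEF_y', '3GHI_z'],
                   'actual':  [4.0, 5.5, 7.0],
                   'pred':    [4.2, 5.1, 6.8]})
df_t = pd.Index(['1abc', '9zzz']).str.upper()  # train/val codes, upper-cased as in the helper

# z-normalization of the true/predicted pkd columns (zero mean, unit std)
cols = ['actual', 'pred']
df[cols] = (df[cols] - df[cols].mean()) / df[cols].std()

# drop test entries whose PDB code overlaps with the training data
df['pdb'] = df['prot_id'].str.split('_').str[0]
df_no_overlap = df[~df['pdb'].isin(df_t)]
print(df_no_overlap[['pdb'] + cols].round(3))  # only the 2DEF and 3GHI rows remain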
109 changes: 108 additions & 1 deletion src/analysis/figures.py
@@ -498,6 +498,65 @@ def prepare_df(csv_p:str=cfg.MODEL_STATS_CSV, old_csv_p:str=None) -> pd.DataFram
#######################################################################################################
##################################### MUTATION ANALYSIS RELATED FIGS: #################################
#######################################################################################################
def predictive_performance(
        MODEL = lambda i: f"results/model_media/test_set_pred/GVPLM_PDBbind{i}D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE_PLATINUM.csv",
        TRAIN_DATA_P = lambda set: f'{cfg.DATA_ROOT}/PDBbindDataset/nomsa_aflow_gvp_binary/{set}0/cleaned_XY.csv',
        NORMALIZE = True,
        n_models=5,
        compare_overlap=False,
        verbose=True,
        plot=False,
        ):
    df_t = pd.Index.append(pd.read_csv(TRAIN_DATA_P('train'), index_col=0).index,
                           pd.read_csv(TRAIN_DATA_P('val'), index_col=0).index)
    df_t = df_t.str.upper()

    results_with_overlap = []
    results_without_overlap = []

    for i in range(n_models):
        df = pd.read_csv(MODEL(i), index_col=0).dropna()
        df['pdb'] = df['prot_id'].str.split('_').str[0]
        if NORMALIZE:
            mean_df = df[['actual','pred']].mean(axis=0, numeric_only=True)
            std_df = df[['actual','pred']].std(axis=0, numeric_only=True)

            df[['actual','pred']] = (df[['actual','pred']] - mean_df) / std_df # z-normalization

        # with overlap
        cindex, p_corr, s_corr, mse, mae, rmse = get_metrics(df['actual'], df['pred'])
        results_with_overlap.append([cindex, p_corr[0], s_corr[0], mse, mae, rmse])

        # without overlap
        df_no_overlap = df[~(df['pdb'].isin(df_t))]
        cindex, p_corr, s_corr, mse, mae, rmse = get_metrics(df_no_overlap['actual'], df_no_overlap['pred'])
        results_without_overlap.append([cindex, p_corr[0], s_corr[0], mse, mae, rmse])

        if i == 0 and plot:
            n_plots = int(compare_overlap) + 1
            fig = plt.figure(figsize=(14, 5*n_plots))
            axes = fig.subplots(n_plots, 1)
            ax = axes[0] if compare_overlap else axes

            # first panel: full test set when comparing overlap, otherwise the no-overlap set
            df_plot = df if compare_overlap else df_no_overlap
            sns.histplot(df_plot['actual'], kde=True, ax=ax, alpha=0.5, label='True pkd')
            sns.histplot(df_plot['pred'], kde=True, ax=ax, alpha=0.5, label='Predicted pkd', color='orange')
            ax.set_title(f"{'Normalized ' if NORMALIZE else ''}pkd distribution")
            ax.legend()

            if compare_overlap:
                # second panel: test set with train/val overlap removed
                sns.histplot(df_no_overlap['actual'], kde=True, ax=axes[1], alpha=0.5, label='True pkd')
                sns.histplot(df_no_overlap['pred'], kde=True, ax=axes[1], alpha=0.5, label='Predicted pkd', color='orange')
                axes[1].set_title(f"{'Normalized ' if NORMALIZE else ''}pkd distribution (no overlap)")
                axes[1].legend()

    if compare_overlap:
        return generate_markdown([results_with_overlap, results_without_overlap],
                                 names=['with overlap', 'without overlap'],
                                 cindex=True, verbose=verbose)

    return generate_markdown([results_without_overlap], names=['mean $\pm$ se'], cindex=True, verbose=verbose)



def get_dpkd(df, pkd_col='pkd', normalize=False) -> np.ndarray:
"""
2. Mutation impact analysis - Delta pkd given df containing wt and mutated proteins and their pkd values
@@ -605,7 +664,7 @@ def tbl_stratified_dpkd_metrics(

def tbl_dpkd_metrics_overlap(
        MODEL = lambda i: f"results/model_media/test_set_pred/GVPLM_PDBbind{i}D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE_PLATINUM.csv",
        TRAIN_DATA_P = lambda set: f'/cluster/home/t122995uhn/projects/data/PDBbindDataset/nomsa_aflow_gvp_binary/{set}0/cleaned_XY.csv',
        TRAIN_DATA_P = lambda set: f'{cfg.DATA_ROOT}/PDBbindDataset/nomsa_aflow_gvp_binary/{set}0/cleaned_XY.csv',
        NORMALIZE = True,
        verbose=True,
        plot=False,
@@ -663,6 +722,54 @@ def tbl_dpkd_metrics_n_mut(
    return tbl_stratified_dpkd_metrics(MODEL, NORMALIZE, n_models, df_transform=get_mut_count,
                                       conditions=conditions, names=names, verbose=verbose, plot=plot)

def tbl_dpkd_metrics_in_binding(
        MODEL = lambda i: f"results/model_media/test_set_pred/GVPLM_PDBbind{i}D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE_PLATINUM.csv",
        RAW_PLT_CSV=f"{cfg.DATA_ROOT}/PlatinumDataset/raw/platinum_flat_file.csv",
        NORMALIZE = True,
        n_models=5,
        verbose=True,
        plot=False,
        ):
    """Generates a table comparing the metrics for mutations in the pocket vs. outside the pocket."""
    # add in_binding info to df
    def get_in_binding(df, dfr):
        """
        df is the predicted csv with index as <raw_idx>_wt (or *_mt) where raw_idx
        corresponds to an index in dfr which contains the raw data for platinum including
        ('mut.in_binding_site')
            - 0: wildtype rows
            - 1: in pocket
            - 2: outside of pocket
        """
        in_pocket = dfr[dfr['mut.in_binding_site'] == 'YES'].index
        pclass = []
        for code in df.index:
            if '_wt' in code:
                pclass.append(0)
            elif int(code.split('_')[0]) in in_pocket:
                pclass.append(1)
            else:
                pclass.append(2)

        df['in_pocket'] = pclass
        return df

    conditions = ['(in_pocket == 0) | (in_pocket == 1)', '(in_pocket == 0) | (in_pocket == 2)']
    names = ['mutation in pocket', 'mutation NOT in pocket']

    dfr = pd.read_csv(RAW_PLT_CSV, index_col=0)

    dfp = pd.read_csv(MODEL(0), index_col=0)
    df = get_in_binding(dfp, dfr)
    if verbose:
        cnts = df.in_pocket.value_counts()
        cnts.index = ['wt', 'pckt', 'not pckt']
        cnts.name = "counts"
        print(cnts.to_markdown(), end="\n\n")

    return tbl_stratified_dpkd_metrics(MODEL, NORMALIZE, n_models=n_models, df_transform=get_in_binding,
                                       conditions=conditions, names=names, verbose=verbose, plot=plot, dfr=dfr)

### 3. significant mutations as a classification problem
def fig_sig_mutations_conf_matrix(true_dpkd, pred_dpkd, std=2, verbose=True, plot=True, show_plot=False, ax=None):
"""For 3. significant mutation impact analysis"""
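For reference, a minimal sketch (toy data, not part of this commit) of the labelling rule that the nested get_in_binding helper applies: 0 for wildtype rows, 1 for mutations inside the binding pocket, 2 for mutations outside it.

import pandas as pd

# toy raw Platinum table (dfr) and toy predictions (df) indexed as <raw_idx>_wt / <raw_idx>_mt
dfr = pd.DataFrame({'mut.in_binding_site': ['YES', 'NO']}, index=[12, 34])
df = pd.DataFrame({'pred': [5.1, 4.8, 6.0]}, index=['12_wt', '12_mt', '34_mt'])

in_pocket = dfr[dfr['mut.in_binding_site'] == 'YES'].index
df['in_pocket'] = [0 if '_wt' in code
                   else 1 if int(code.split('_')[0]) in in_pocket
                   else 2
                   for code in df.index]
print(df['in_pocket'].tolist())  # expected: [0, 1, 2]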
22 changes: 14 additions & 8 deletions src/analysis/utils.py
@@ -2,6 +2,7 @@
from collections import OrderedDict
from scipy.stats import ttest_ind
import pandas as pd
import numpy as np


def count_missing_res(pdb_file: str) -> Tuple[int,int]:
@@ -67,7 +68,7 @@ def get_mut_count(df):
    df['n_mut'] = n_mut
    return df

def generate_markdown(results, names=None, verbose=False, thresh_sig=False):
def generate_markdown(results, names=None, verbose=False, thresh_sig=False, cindex=False):
"""
generates a markdown given a list or single df containing metrics from get_metrics
@@ -81,29 +82,34 @@ def generate_markdown(results, names=None, verbose=False, thresh_sig=False):
    n_groups = len(results)
    names = names if names else [str(i) for i in range(n_groups)]
    # Convert results to DataFrame
    results_df = [None for i in range(n_groups)]
    results_df = [None for _ in range(n_groups)]
    md_table = None
    cols = ['cindex'] if cindex else []
    cols += ['pcorr', 'scorr', 'mse', 'mae', 'rmse']
    for i, r in enumerate(results):
        df = pd.DataFrame(r, columns=['pcorr', 'scorr', 'mse', 'mae', 'rmse'])
        df = pd.DataFrame(r, columns=cols)

        mean = df.mean()
        std = df.std()
        mean = df.mean(numeric_only=True)
        std = df.std(numeric_only=True)
        results_df[i] = df

        # calculate standard error:
        se = std / np.sqrt(len(df))

        # formatting for markdown table:
        combined = mean.map(lambda x: f"{x:.3f}") + " $\pm$ " + std.map(lambda x: f"{x:.3f}")
        combined = mean.map(lambda x: f"{x:.3f}") + " $\pm$ " + se.map(lambda x: f"{x:.3f}")
        md_table = combined if md_table is None else pd.concat([md_table, combined], axis=1)

    if n_groups == 2: # no support for sig if groups are more than 2
        # T-tests for significance
        # two-sided t-tests for significance
        ttests = {col: ttest_ind(results_df[0][col], results_df[1][col]) for col in results_df[0].columns}
        if thresh_sig:
            sig = pd.Series({col: '*' if ttests[col].pvalue < 0.05 else '' for col in results_df[0].columns})
        else:
            sig = pd.Series({col: f"{ttests[col].pvalue:.4f}" for col in results_df[0].columns})

        md_table = pd.concat([md_table, sig], axis=1)
        md_table.columns = [*names, 'Sig']
        md_table.columns = [*names, 'p-val']
    else:
        md_table.columns = names

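A rough usage sketch (toy numbers; the import path is assumed to follow the same style as the playground imports) for the updated generate_markdown: each inner list holds one model's [cindex, pcorr, scorr, mse, mae, rmse], and with exactly two groups the assembled table reports mean $\pm$ standard error per metric plus a two-sided t-test p-value column.

from src.analysis.utils import generate_markdown

group_a = [[0.71, 0.65, 0.63, 1.10, 0.82, 1.05],
           [0.69, 0.62, 0.60, 1.20, 0.85, 1.10]]
group_b = [[0.66, 0.55, 0.52, 1.45, 0.95, 1.20],
           [0.64, 0.53, 0.50, 1.50, 0.98, 1.22]]

# returned object captured the same way the tbl_* helpers return their tables
md_table = generate_markdown([group_a, group_b],
                             names=['with overlap', 'without overlap'],
                             cindex=True, verbose=True)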
