Skip to content

Commit

Permalink
results: stratified in_pocket metrics #94
Browse files Browse the repository at this point in the history
  • Loading branch information
jyaacoub committed May 2, 2024
1 parent 57be597 commit 9df48e7
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 5 deletions.
44 changes: 41 additions & 3 deletions playground.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,46 @@
# %%
from src.analysis.figures import tbl_dpkd_metrics_overlap, tbl_dpkd_metrics_n_mut
from src.analysis.figures import tbl_dpkd_metrics_n_mut, tbl_stratified_dpkd_metrics
MODEL = lambda i: f"results/model_media/test_set_pred/GVPLM_PDBbind{i}D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE_PLATINUM.csv"
TRAIN_DATA_P = lambda set: f'/cluster/home/t122995uhn/projects/data/PDBbindDataset/nomsa_aflow_gvp_binary/{set}0/cleaned_XY.csv'
NORMALIZE = True

print('NUM MUTATIONS:')
mkdnm = tbl_dpkd_metrics_n_mut(MODEL, NORMALIZE, conditions=[1,2], plot=True)

#%%
# add in_binding info to df
def get_in_binding(df, dfr):
"""
df is the predicted csv with index as <raw_idx>_wt (or *_mt) where raw_idx
corresponds to an index in dfr which contains the raw data for platinum including
('mut.in_binding_site')
- 0: wildtype rows
- 1: in pocket
- 2: outside of pocket
"""
in_pocket = dfr[dfr['mut.in_binding_site'] == 'YES'].index
pclass = []
for code in df.index:
if '_wt' in code:
pclass.append(0)
elif int(code.split('_')[0]) in in_pocket:
pclass.append(1)
else:
pclass.append(2)

df['in_pocket'] = pclass
return df

# get df_binding info from /cluster/home/t122995uhn/projects/data/PlatinumDataset/raw/platinum_flat_file.csv
conditions = ['(in_pocket == 0) | (in_pocket == 1)', '(in_pocket == 0) | (in_pocket == 2)']
names = ['mutation in pocket', 'mutation NOT in pocket']
#%%
import pandas as pd
dfr = pd.read_csv('/cluster/home/t122995uhn/projects/data/PlatinumDataset/raw/platinum_flat_file.csv', index_col=0)
dfp = pd.read_csv(MODEL(0), index_col=0)

df = get_in_binding(dfp, dfr)
print(df.in_pocket.value_counts())

#%%
tbl_stratified_dpkd_metrics(MODEL, NORMALIZE, n_models=5, df_transform=get_in_binding,
conditions=conditions, names=names, verbose=True, plot=True, dfr=dfr)
# %%
7 changes: 5 additions & 2 deletions src/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def get_mut_count(df):
df['n_mut'] = n_mut
return df

def generate_markdown(results, names=None, verbose=False):
def generate_markdown(results, names=None, verbose=False, thresh_sig=False):
"""
generates a markdown given a list or single df containing metrics from get_metrics
Expand Down Expand Up @@ -97,7 +97,10 @@ def generate_markdown(results, names=None, verbose=False):
if n_groups == 2: # no support for sig if groups are more than 2
# T-tests for significance
ttests = {col: ttest_ind(results_df[0][col], results_df[1][col]) for col in results_df[0].columns}
sig = pd.Series({col: '*' if ttests[col].pvalue < 0.05 else '' for col in results_df[0].columns})
if thresh_sig:
sig = pd.Series({col: '*' if ttests[col].pvalue < 0.05 else '' for col in results_df[0].columns})
else:
sig =pd.Series({col: f"{ttests[col].pvalue:.4f}" for col in results_df[0].columns})

md_table = pd.concat([md_table, sig], axis=1)
md_table.columns = [*names, 'Sig']
Expand Down

0 comments on commit 9df48e7

Please sign in to comment.