Skip to content

Commit

Permalink
Merge pull request #63 from jyaacoub/development
Browse files Browse the repository at this point in the history
Fix SaProt loading and creation for training
  • Loading branch information
jyaacoub authored Nov 23, 2023
2 parents 9755564 + 10ec424 commit 5b0288c
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 160 deletions.
84 changes: 47 additions & 37 deletions playground.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,52 @@
# %%
from src.data_analysis.figures import fig2_pro_feat, fig3_edge_feat, prepare_df, fig4_pro_feat_violin, fig5_edge_feat_violin

df = prepare_df('results/model_media/model_stats.csv')
data = 'kiba'
verbose=False

pdb_f = fig4_pro_feat_violin(df, sel_col='mse', sel_dataset='PDBbind', verbose=verbose)
kiba_e = fig5_edge_feat_violin(df, sel_col='mse', sel_dataset='kiba', verbose=verbose)

# %% simplified plot for debugging
from seaborn import violinplot
from matplotlib import pyplot as plt
import pandas as pd
from statannotations.Annotator import Annotator

new_df = pd.DataFrame({'binary': kiba_e[0], 'simple': kiba_e[1], 'anm': kiba_e[2], 'af2': kiba_e[3]})
new_df

#%%
ax = violinplot(data=kiba_e)
ax.set_xticklabels(['binary', 'simple', 'anm', 'a2'])

pairs = [(0,1), (0,2), (0,3), (1,2), (1,3), (2,3)]
annotator = Annotator(ax, pairs, data=kiba_e, verbose=verbose)
annotator.configure(test='Mann-Whitney', text_format='star', loc='inside',
hide_non_significant=not verbose)
annotator.apply_and_annotate()

import torch
from src.utils import config as cfg
from src.utils.loader import Loader

model = Loader.init_model('SPD', pro_feature='nomsa', pro_edge='binary', dropout=0.4)

# %%
fig4_pro_feat_violin(df, sel_col='cindex', sel_dataset=data, verbose=verbose)
fig4_pro_feat_violin(df, sel_col='mse', sel_dataset=data, verbose=verbose)

#%%
fig5_edge_feat_violin(df, sel_col='cindex', sel_dataset=data, verbose=verbose)
fig5_edge_feat_violin(df, sel_col='mse', sel_dataset=data, verbose=verbose)

d = Loader.load_dataset(data='davis',
pro_feature='foldseek',
edge_opt='binary',
subset='val0')

dl = Loader.load_DataLoaders(loaded_datasets={'val0':d}, batch_train=2)['val0']
# %% ESM emb ####
# cls and sep tokens are added to the sequence by the tokenizer
data = next(iter(dl))

out = model(data['protein'], data['ligand'])
# seq_tok = model.esm_tok(data.pro_seq,
# return_tensors='pt',
# padding=True) # [B, L_max+2]
# seq_tok['input_ids'] = seq_tok['input_ids'].to(data.x.device)
# seq_tok['attention_mask'] = seq_tok['attention_mask'].to(data.x.device)

# esm_emb = model.esm_mdl(**seq_tok).last_hidden_state # [B, L_max+2, emb_dim]

# # mask tokens don't make it through to the final output
# # thus the final output is the same length as if we were to run it through the original ESM

# #%% removing <cls> token
# esm_emb = esm_emb[:,1:,:] # [B, L_max+1, emb_dim]

# # %% removing <sep>/<eos> and <pad> token by applying mask
# # for saProt token 2 == <eos>
# L_max = esm_emb.shape[1] # L_max+1
# mask = torch.arange(L_max)[None, :] < torch.tensor([len(seq)/2 #NOTE: this is the main difference from normal ESM since the input sequence includes SA tokens
# for seq in data.pro_seq])[:, None]
# mask = mask.flatten(0,1) # [B*L_max+1]

# #%% flatten from [B, L_max+1, emb_dim]
# esm_emb = esm_emb.flatten(0,1) # to [B*L_max+1, emb_seqdim]
# esm_emb = esm_emb[mask] # [B*L, emb_dim]

# #%%
# if model.esm_only:
# target_x = esm_emb # [B*L, emb_dim]
# else:
# # append esm embeddings to protein input
# target_x = torch.cat((esm_emb, data.x), axis=1)
# # ->> [B*L, emb_dim+feat_dim]
# %%
fig2_pro_feat(df, sel_col='pearson')
fig3_edge_feat(df, exclude=['af2-anm'], sel_col='pearson')
162 changes: 50 additions & 112 deletions src/data_analysis/figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ def fig1_pro_overlap(df, sel_col='cindex', verbose=False, show=True):

# overlap
col = group[group['overlap']][sel_col]
t_overlap_val = col.max() if sel_col == 'cindex' else col.min()
t_overlap_val = col.mean()
if np.isnan(t_overlap_val):
t_overlap_val = 0
t_overlap.append(t_overlap_val)

# no overlap
col = group[~group['overlap']][sel_col]
f_overlap_val = col.max() if sel_col == 'cindex' else col.min()
f_overlap_val = col.mean()
if np.isnan(f_overlap_val):
f_overlap_val = 0
f_overlap.append(f_overlap_val)
Expand Down Expand Up @@ -75,89 +75,61 @@ def fig1_pro_overlap(df, sel_col='cindex', verbose=False, show=True):
# Features -> nomsa, msa, shannon, and esm
def fig2_pro_feat(df, verbose=False, sel_col='cindex', exclude=[], show=True, add_labels=True):
# Extract relevant data
filtered_df = df[(df['edge'] == 'binary') & (~df['overlap']) & (df['lig_feat'].isna())] # NOTE
filtered_df = df[(df['edge'] == 'binary') & (~df['overlap'])
& (df['fold'] != '') & (df['lig_feat'].isna())]

# Initialize lists to store cindex values for each dataset type
nomsa = []
msa = []
shannon = []
esm = []
dataset_types = []

for dataset, group in filtered_df.groupby('data'):
if verbose: print(f"\nGroup Name: {dataset}")
if verbose: print(group[['cindex', 'mse', 'feat']])

# Extract max or min values based on sel_col
if sel_col in ['cindex', 'pearson', 'spearman']:
nomsa_v = group[group['feat'] == 'nomsa'][sel_col].max()
msa_v = group[group['feat'] == 'msa'][sel_col].max()
shannon_v = group[group['feat'] == 'shannon'][sel_col].max()
ESM_v = group[group['feat'] == 'ESM'][sel_col].max()
else:
nomsa_v = group[group['feat'] == 'nomsa'][sel_col].min()
msa_v = group[group['feat'] == 'msa'][sel_col].min()
shannon_v = group[group['feat'] == 'shannon'][sel_col].min()
ESM_v = group[group['feat'] == 'ESM'][sel_col].min()
# get only data, feat, and sel_col columns
plot_df = filtered_df[['data', 'feat', sel_col]]

# Append values or 0 if NaN
nomsa.append(nomsa_v if not np.isnan(nomsa_v) else 0)
msa.append(msa_v if not np.isnan(msa_v) else 0)
shannon.append(shannon_v if not np.isnan(shannon_v) else 0)
esm.append(ESM_v if not np.isnan(ESM_v) else 0)
dataset_types.append(dataset)

# Create a DataFrame for plotting
plot_data = pd.DataFrame({
'Dataset': dataset_types,
'nomsa': nomsa,
'msa': msa,
'shannon': shannon,
'esm': esm
})
for c in exclude:
plot_data.drop(c, axis=1, inplace=True)

# Melt the DataFrame for Seaborn barplot
melted_data = pd.melt(plot_data, id_vars=['Dataset'], var_name='Node feature',
value_name=sel_col)

hue_order = ['nomsa', 'msa', 'shannon', 'ESM']
for f in exclude:
plot_df = plot_df[plot_df['feat'] != f]
if f in hue_order:
hue_order.remove(f)

# Create a bar plot using Seaborn
plt.figure(figsize=(14, 7))
sns.set(style="darkgrid")
sns.set_context('poster')
ax = sns.barplot(x='Dataset', y=sel_col, hue='Node feature',
data=melted_data, palette='deep')
ax = sns.barplot(data=plot_df, x='data', y=sel_col, hue='feat', palette='deep', estimator=np.mean,
order=['PDBbind', 'davis', 'kiba'], hue_order=hue_order, errcolor='gray', errwidth=2)
sns.stripplot(data=plot_df, x='data', y=sel_col, hue='feat', palette='deep',
order=['PDBbind', 'davis', 'kiba'], hue_order=hue_order,
size=4, jitter=True, dodge=True, alpha=0.8, ax=ax)
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[len(hue_order):], labels[len(hue_order):],
title='', loc='upper right')

if add_labels:
for i in ax.containers:
ax.bar_label(i, fmt='%.3f', fontsize=13)
ax.bar_label(i, fmt='%.3f', fontsize=13, label_type='center')

# Set the title
ax.set_title(f'Node feature performance ({"concordance index" if sel_col == "cindex" else sel_col})')

handles, labels = ax.get_legend_handles_labels()
ax.legend(handles=handles, labels=labels, loc='upper right')

# Set the y-axis label and limit
ax.set_ylabel(sel_col)
if sel_col == 'cindex':
ax.set_ylim([0.5, 1]) # 0.5 is the worst cindex value

# Add statistical annotations

pairs=[("PDBbind", "kiba"), ("PDBbind", "davis"), ("davis", "kiba")]
annotator = Annotator(ax, pairs, data=df, x='data', y='cindex', order=['PDBbind', 'davis', 'kiba'],
annotator = Annotator(ax, pairs, data=plot_df,
x='data', y=sel_col, order=['PDBbind', 'davis', 'kiba'], #NOTE: this needs to be fixed
verbose=verbose)
annotator.configure(test='Mann-Whitney', text_format='star', loc='inside', hide_non_significant=True,
line_height=0.005, verbose=verbose)
annotator.apply_and_annotate()

# Show the plot
if show:
plt.show()

# reset stylesheet back to defaults
mpl.rcParams.update(mpl.rcParamsDefault)

return nomsa, msa, shannon, esm
return plot_df, filtered_df

# Figure 3 - Edge type cindex difference
# Edges -> binary, simple, anm, af2
Expand All @@ -167,79 +139,45 @@ def fig3_edge_feat(df, verbose=False, sel_col='cindex', exclude=[], show=True, a

# this will capture multiple models per dataset (different LR, batch size, etc)
# Taking the max cindex value for each dataset will give us the best model for each dataset
filtered_df = df[(df['feat'] == 'nomsa') & (~df['overlap'])]

# these groups are spaced by the data type, physically grouping bars of the same dataset together.
# Initialize lists to store cindex values for each dataset type
binary = []
simple = []
anm = []
af2 = []
af2_anm = []
dataset_types = []
filtered_df = df[(df['feat'] == 'nomsa') & (~df['overlap'])
& (df['fold'] != '') & (df['lig_feat'].isna())]
plot_df = filtered_df[['data', 'edge', sel_col]]

for dataset, group in filtered_df.groupby('data'):
if verbose: print('')
if verbose: print(group[['cindex', 'mse', 'overlap', 'data']])

value_dict = {}

for edge_value in ['binary', 'simple', 'anm', 'af2', 'af2-anm']:
filtered_group = group[group['edge'] == edge_value]

if sel_col == ['cindex', 'pearson', 'spearman']:
value = filtered_group[sel_col].max()
else:
value = filtered_group[sel_col].min()

value_dict[edge_value] = value if not np.isnan(value) else 0

binary.append(value_dict['binary'])
simple.append(value_dict['simple'])
anm.append(value_dict['anm'])
af2.append(value_dict['af2'])
af2_anm.append(value_dict['af2-anm'])
dataset_types.append(dataset)
hue_order = ['binary', 'simple', 'anm', 'af2', 'af2-anm']
for f in exclude:
plot_df = plot_df[plot_df['edge'] != f]
if f in hue_order:
hue_order.remove(f)

# Create a DataFrame for plotting
plot_data = pd.DataFrame({
'Dataset': dataset_types,
'binary': binary,
'simple': simple,
'anm': anm,
'af2': af2,
'af2-anm': af2_anm
})
for c in exclude:
plot_data.drop(c, axis=1, inplace=True)

# Melt the DataFrame for Seaborn barplot
melted_data = pd.melt(plot_data, id_vars=['Dataset'], var_name='Edge type',
value_name=sel_col)

# Create a bar plot using Seaborn
plt.figure(figsize=(14, 7))
sns.set(style="darkgrid")
sns.set_context('poster')
ax = sns.barplot(x='Dataset', y=sel_col, hue='Edge type',
data=melted_data, palette='deep')
ax = sns.barplot(data=plot_df, x='data', y=sel_col, hue='edge', palette='deep', estimator=np.mean,
order=['PDBbind', 'davis', 'kiba'], hue_order=hue_order, errcolor='gray', errwidth=2)
sns.stripplot(data=plot_df, x='data', y=sel_col, hue='edge', palette='deep',
order=['PDBbind', 'davis', 'kiba'], hue_order=hue_order,
size=4, jitter=True, dodge=True, alpha=0.8, ax=ax)
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[len(hue_order):], labels[len(hue_order):],
title='', loc='upper right')

if add_labels:
for i in ax.containers:
ax.bar_label(i, fmt='%.3f', fontsize=13)
ax.bar_label(i, fmt='%.3f', fontsize=13, label_type='center')

# Set the title
ax.set_title(f'Edge type performance ({"concordance index" if sel_col == "cindex" else sel_col})')

handles, labels = ax.get_legend_handles_labels()
ax.legend(handles=handles, labels=labels, loc='upper right')

# Set the y-axis label and limit
ax.set_ylabel(sel_col)
if sel_col == 'cindex':
ax.set_ylim([0.5, 1]) # 0.5 is the worst cindex value

# Add statistical annotations
pairs=[("PDBbind", "kiba"), ("PDBbind", "davis"), ("davis", "kiba")]
annotator = Annotator(ax, pairs, data=df, x='data', y='cindex', order=['PDBbind', 'davis', 'kiba'],
annotator = Annotator(ax, pairs, data=filtered_df, # Use the original filtered DataFrame
x='data', y=sel_col, order=['PDBbind', 'davis', 'kiba'],
verbose=verbose)
annotator.configure(test='Mann-Whitney', text_format='star', loc='inside', hide_non_significant=True,
line_height=0.005, verbose=verbose)
Expand All @@ -251,7 +189,7 @@ def fig3_edge_feat(df, verbose=False, sel_col='cindex', exclude=[], show=True, a
# reset stylesheet back to defaults
mpl.rcParams.update(mpl.rcParamsDefault)

return binary, simple, anm, af2, af2_anm
return filtered_df

# Figure 4: violin plot with error bars for Cross-validation results to show significance among pro feats
def fig4_pro_feat_violin(df, sel_dataset='davis', verbose=False, sel_col='cindex', exclude=[],
Expand Down
4 changes: 2 additions & 2 deletions src/data_analysis/stratify_protein.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def kinbase_to_df(fasta_fp:str=f'{cfg.DATA_ROOT}/misc/Human_kinase_domain.fasta'
for i in range(len(lines)):
line = lines[i]
if line[0] == '>': # header
seq = lines[i+1]
seq = lines[i+1].strip()
name = re.search(r'^>(.+?)_Hsap', line).group(1)
# every entry in the fasta has a protein family descriptor with at least 2 elements
protein_family = re.search(r'\((.*)\)', line).group(1)
Expand Down Expand Up @@ -74,7 +74,7 @@ def check_davis_names(davis_prots:dict, df:pd.DataFrame) -> list:

df = kinbase_to_df() if df is None else df

greek = {'alpha', 'beta', 'gamma', 'delta'} # for checking if protein name has greek letter
greek = {'alpha', 'beta', 'gamma', 'delta', 'epsilon'} # for checking if protein name has greek letter

found_prots = {}
for k in davis_prots.keys():
Expand Down
2 changes: 1 addition & 1 deletion src/data_processing/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def __len__(self):
return len(self.df)

def __getitem__(self, idx) -> dict:
row = self.df.iloc[idx]
row = self.df.iloc[idx] #WARNING: idx must be a list in future versions of pandas since the current usage is deprecated
code = row.name
prot_id = row['prot_id']
lig_seq = row['SMILE']
Expand Down
Loading

0 comments on commit 5b0288c

Please sign in to comment.