Merge pull request #63 from jyaacoub/development
Fix SaProt loading and creation for training
Showing 6 changed files with 184 additions and 160 deletions.
@@ -1,42 +1,52 @@
# %%
from src.data_analysis.figures import fig2_pro_feat, fig3_edge_feat, prepare_df, fig4_pro_feat_violin, fig5_edge_feat_violin

df = prepare_df('results/model_media/model_stats.csv')
data = 'kiba'
verbose=False

pdb_f = fig4_pro_feat_violin(df, sel_col='mse', sel_dataset='PDBbind', verbose=verbose)
kiba_e = fig5_edge_feat_violin(df, sel_col='mse', sel_dataset='kiba', verbose=verbose)

# %% simplified plot for debugging
from seaborn import violinplot
from matplotlib import pyplot as plt
import pandas as pd
from statannotations.Annotator import Annotator

new_df = pd.DataFrame({'binary': kiba_e[0], 'simple': kiba_e[1], 'anm': kiba_e[2], 'af2': kiba_e[3]})
new_df

#%%
ax = violinplot(data=kiba_e)
ax.set_xticklabels(['binary', 'simple', 'anm', 'af2'])

pairs = [(0,1), (0,2), (0,3), (1,2), (1,3), (2,3)]
annotator = Annotator(ax, pairs, data=kiba_e, verbose=verbose)
annotator.configure(test='Mann-Whitney', text_format='star', loc='inside',
                    hide_non_significant=not verbose)
annotator.apply_and_annotate()

import torch
from src.utils import config as cfg
from src.utils.loader import Loader

model = Loader.init_model('SPD', pro_feature='nomsa', pro_edge='binary', dropout=0.4)

# %%
fig4_pro_feat_violin(df, sel_col='cindex', sel_dataset=data, verbose=verbose)
fig4_pro_feat_violin(df, sel_col='mse', sel_dataset=data, verbose=verbose)

#%%
fig5_edge_feat_violin(df, sel_col='cindex', sel_dataset=data, verbose=verbose)
fig5_edge_feat_violin(df, sel_col='mse', sel_dataset=data, verbose=verbose)

d = Loader.load_dataset(data='davis',
                        pro_feature='foldseek',
                        edge_opt='binary',
                        subset='val0')

dl = Loader.load_DataLoaders(loaded_datasets={'val0':d}, batch_train=2)['val0']
# %% ESM emb ####
# cls and sep tokens are added to the sequence by the tokenizer
data = next(iter(dl))

out = model(data['protein'], data['ligand'])
# seq_tok = model.esm_tok(data.pro_seq,
#                         return_tensors='pt',
#                         padding=True) # [B, L_max+2]
# seq_tok['input_ids'] = seq_tok['input_ids'].to(data.x.device)
# seq_tok['attention_mask'] = seq_tok['attention_mask'].to(data.x.device)

# esm_emb = model.esm_mdl(**seq_tok).last_hidden_state # [B, L_max+2, emb_dim]

# # mask tokens don't make it through to the final output,
# # so the final output is the same length as if we ran it through the original ESM

# #%% removing <cls> token
# esm_emb = esm_emb[:,1:,:] # [B, L_max+1, emb_dim]

# # %% removing <sep>/<eos> and <pad> tokens by applying mask
# # for SaProt, token 2 == <eos>
# L_max = esm_emb.shape[1] # L_max+1
# mask = torch.arange(L_max)[None, :] < torch.tensor([len(seq)/2 # NOTE: this is the main difference from normal ESM since the input sequence includes SA tokens
#                                                     for seq in data.pro_seq])[:, None]
# mask = mask.flatten(0,1) # [B*(L_max+1)]

# #%% flatten from [B, L_max+1, emb_dim]
# esm_emb = esm_emb.flatten(0,1) # to [B*(L_max+1), emb_dim]
# esm_emb = esm_emb[mask] # [B*L, emb_dim]

# #%%
# if model.esm_only:
#     target_x = esm_emb # [B*L, emb_dim]
# else:
#     # append esm embeddings to protein input
#     target_x = torch.cat((esm_emb, data.x), axis=1)
#     # ->> [B*L, emb_dim+feat_dim]
# %%
fig2_pro_feat(df, sel_col='pearson')
fig3_edge_feat(df, exclude=['af2-anm'], sel_col='pearson')
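
The commented-out block above outlines how per-residue embeddings would be pulled from the SaProt/ESM encoder: tokenize, take the last hidden state, drop the <cls> token, then mask away <eos> and padding. The SaProt-specific twist is dividing the string length by two, since structure-aware sequences carry an amino-acid character plus a Foldseek structure character per residue. Below is a minimal standalone sketch of that masking step using toy tensors in place of the repo's model outputs (pro_seq, emb_dim, and the random embedding are illustrative assumptions, not the actual pipeline):

# toy stand-in for the SaProt masking logic sketched in the comments above
import torch

emb_dim = 4
# each residue contributes two characters (amino acid + structure token),
# so the residue count is len(seq)//2
pro_seq = ['MpKvVa' * 2, 'MpKv']                 # residue lengths 6 and 2
lens = torch.tensor([len(s) // 2 for s in pro_seq])
B, L_max = len(pro_seq), int(lens.max())

# stand-in for the encoder output with <cls> already stripped: [B, L_max+1, emb_dim]
esm_emb = torch.randn(B, L_max + 1, emb_dim)

# keep only real residue positions; everything past len(seq)//2 is <eos>/<pad>
mask = torch.arange(L_max + 1)[None, :] < lens[:, None]   # [B, L_max+1]
esm_emb = esm_emb.flatten(0, 1)[mask.flatten(0, 1)]       # [sum(lens), emb_dim]
assert esm_emb.shape == (int(lens.sum()), emb_dim)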