-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(mutagenesis): sequence plotting and sample plot #136
- Loading branch information
Showing
4 changed files
with
195 additions
and
93 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,89 +1,34 @@ | ||
#%% | ||
from src import cfg | ||
|
||
#%% | ||
from src.utils.mutate_model import run_modeller | ||
|
||
run_modeller('/cluster/home/t122995uhn/projects/tmp/d1/O00141.pdb', 50, "MET", "A") | ||
|
||
#%% | ||
import pandas as pd | ||
|
||
def get_test_oncokbs(train_df=pd.read_csv('/cluster/home/t122995uhn/projects/data/test/PDBbindDataset/nomsa_binary_original_binary/full/cleaned_XY.csv'), | ||
oncokb_fp='/cluster/home/t122995uhn/projects/data/tcga/mart_export.tsv', | ||
biomart='/cluster/home/t122995uhn/projects/downloads/oncoKB_DrugGenePairList.csv'): | ||
#Get gene names for PDBbind | ||
dfbm = pd.read_csv(oncokb_fp, sep='\t') | ||
dfbm['PDB ID'] = dfbm['PDB ID'].str.lower() | ||
train_df.reset_index(names='idx',inplace=True) | ||
|
||
df_uni = train_df.merge(dfbm, how='inner', left_on='prot_id', right_on='UniProtKB/Swiss-Prot ID') | ||
df_pdb = train_df.merge(dfbm, how='inner', left_on='code', right_on='PDB ID') | ||
|
||
# identifying ovelap with oncokb | ||
# df_all will have duplicate entries for entries with multiple gene names... | ||
df_all = pd.concat([df_uni, df_pdb]).drop_duplicates(['idx', 'Gene name'])[['idx', 'code', 'Gene name']] | ||
|
||
dfkb = pd.read_csv(biomart) | ||
df_all_kb = df_all.merge(dfkb.drop_duplicates('gene'), left_on='Gene name', right_on='gene', how='inner') | ||
|
||
trained_genes = set(df_all_kb.gene) | ||
|
||
#Identify non-trained genes | ||
return dfkb[~dfkb['gene'].isin(trained_genes)], dfkb[dfkb['gene'].isin(trained_genes)], dfkb | ||
|
||
|
||
train_df = pd.read_csv('/cluster/home/t122995uhn/projects/data/test/PDBbindDataset/nomsa_binary_original_binary/train0/cleaned_XY.csv') | ||
val_df = pd.read_csv('/cluster/home/t122995uhn/projects/data/test/PDBbindDataset/nomsa_binary_original_binary/val0/cleaned_XY.csv') | ||
|
||
train_df = pd.concat([train_df, val_df]) | ||
|
||
get_test_oncokbs(train_df=train_df) | ||
|
||
|
||
|
||
|
||
|
||
#%% | ||
# %% | ||
######################################################################## | ||
########################## BUILD DATASETS ############################## | ||
########################## VIOLIN PLOTTING ############################# | ||
######################################################################## | ||
import os | ||
from src.data_prep.init_dataset import create_datasets | ||
from src import cfg | ||
import logging | ||
cfg.logger.setLevel(logging.DEBUG) | ||
|
||
splits = '/cluster/home/t122995uhn/projects/MutDTA/splits/davis/' | ||
create_datasets([cfg.DATA_OPT.PDBbind, cfg.DATA_OPT.davis, cfg.DATA_OPT.kiba], | ||
feat_opt=cfg.PRO_FEAT_OPT.nomsa, | ||
edge_opt=[cfg.PRO_EDGE_OPT.binary, cfg.PRO_EDGE_OPT.aflow], | ||
ligand_features=[cfg.LIG_FEAT_OPT.original, cfg.LIG_FEAT_OPT.gvp], | ||
ligand_edges=cfg.LIG_EDGE_OPT.binary, overwrite=False, | ||
k_folds=5, | ||
test_prots_csv=f'{splits}/test.csv', | ||
val_prots_csv=[f'{splits}/val{i}.csv' for i in range(5)],) | ||
# data_root=os.path.abspath('../data/test/')) | ||
|
||
# %% Copy splits to commit them: | ||
#from to: | ||
import shutil | ||
from_dir_p = '/cluster/home/t122995uhn/projects/data/v131/' | ||
to_dir_p = '/cluster/home/t122995uhn/projects/MutDTA/splits/' | ||
from_db = ['PDBbindDataset', 'DavisKibaDataset/kiba', 'DavisKibaDataset/davis'] | ||
to_db = ['pdbbind', 'kiba', 'davis'] | ||
|
||
from_db = [f'{from_dir_p}/{f}/nomsa_binary_original_binary/' for f in from_db] | ||
to_db = [f'{to_dir_p}/{f}' for f in to_db] | ||
|
||
for src, dst in zip(from_db, to_db): | ||
for x in ['train', 'val']: | ||
for i in range(5): | ||
print(f"{src}/{x}{i}/XY.csv", f"{dst}/{x}{i}.csv") | ||
shutil.copy(f"{src}/{x}{i}/XY.csv", f"{dst}/{x}{i}.csv") | ||
|
||
print(f"{src}/test/XY.csv", f"{dst}/test.csv") | ||
shutil.copy(f"{src}/test/XY.csv", f"{dst}/test.csv") | ||
|
||
|
||
# %% | ||
from matplotlib import pyplot as plt | ||
|
||
from src.analysis.figures import prepare_df, fig_combined, custom_fig | ||
|
||
dft = prepare_df('./results/v115/model_media/model_stats.csv') | ||
dfv = prepare_df('./results/v115/model_media/model_stats_val.csv') | ||
|
||
models = { | ||
'DG': ('nomsa', 'binary', 'original', 'binary'), | ||
'esm': ('ESM', 'binary', 'original', 'binary'), # esm model | ||
'aflow': ('nomsa', 'aflow', 'original', 'binary'), | ||
# 'gvpP': ('gvp', 'binary', 'original', 'binary'), | ||
'gvpL': ('nomsa', 'binary', 'gvp', 'binary'), | ||
# 'aflow_ring3': ('nomsa', 'aflow_ring3', 'original', 'binary'), | ||
'gvpL_aflow': ('nomsa', 'aflow', 'gvp', 'binary'), | ||
# 'gvpL_aflow_rng3': ('nomsa', 'aflow_ring3', 'gvp', 'binary'), | ||
#GVPL_ESMM_davis3D_nomsaF_aflowE_48B_0.00010636872718329864LR_0.23282479481785903D_2000E_gvpLF_binaryLE | ||
# 'gvpl_esm_aflow': ('ESM', 'aflow', 'gvp', 'binary'), | ||
} | ||
|
||
fig, axes = fig_combined(dft, datasets=['davis'], fig_callable=custom_fig, | ||
models=models, metrics=['cindex', 'mse'], | ||
fig_scale=(10,5), add_stats=True, title_postfix=" test set performance", box=True, fold_labels=True) | ||
plt.xticks(rotation=45) | ||
|
||
fig, axes = fig_combined(dfv, datasets=['davis'], fig_callable=custom_fig, | ||
models=models, metrics=['cindex', 'mse'], | ||
fig_scale=(10,5), add_stats=True, title_postfix=" validation set performance", box=True, fold_labels=True) | ||
plt.xticks(rotation=45) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
import matplotlib.patheffects as PathEffects | ||
|
||
from src.utils.residue import ResInfo, Chain | ||
from src.utils.mutate_model import run_modeller | ||
|
||
|
||
def plot_sequence(muta, pep_opts, pro_seq): | ||
# Create the plot | ||
plt.figure(figsize=(len(pro_seq), 20)) # Adjust the figure size as needed | ||
plt.imshow(muta, aspect='auto', cmap='coolwarm', interpolation='none') | ||
|
||
# Set the x-ticks to correspond to positions in the protein sequence | ||
plt.xticks(ticks=np.arange(len(pro_seq)), labels=[ResInfo.code_to_pep[p] for p in pro_seq], rotation=45, fontsize=16) | ||
plt.yticks(ticks=np.arange(len(pep_opts)), labels=pep_opts, fontsize=16) | ||
plt.xlabel('Protein Sequence Position', fontsize=75) | ||
plt.ylabel('Peptide Options', fontsize=75) | ||
|
||
# Add text labels to each square | ||
for i in range(len(pep_opts)): | ||
for j in range(len(pro_seq)): | ||
text = plt.text(j, i, f'{ResInfo.pep_to_code[pep_opts[i]]}', ha='center', va='center', color='black', fontsize=8) | ||
# Add a white outline to the text | ||
text.set_path_effects([ | ||
PathEffects.Stroke(linewidth=1, foreground='white'), | ||
PathEffects.Normal() | ||
]) | ||
break | ||
|
||
|
||
# Adjust gridlines to be off-center, forming cell boundaries | ||
plt.gca().set_xticks(np.arange(-0.5, len(pro_seq), 1), minor=True) | ||
plt.gca().set_yticks(np.arange(-0.5, len(pep_opts), 1), minor=True) | ||
plt.grid(which='minor', color='gray', linestyle='-', linewidth=0.5) | ||
|
||
# Remove the major gridlines (optional, for clarity) | ||
plt.grid(which='major', color='none') | ||
|
||
# Add a colorbar to show the scale of mutation values | ||
plt.colorbar(label='Mutation Values') | ||
plt.title('Mutation Array Visualization', fontsize=100) | ||
plt.show() | ||
|
||
|
||
|
||
if __name__ == "__main__": | ||
pro_id = "P67870" | ||
pdb_file = f'/cluster/home/t122995uhn/projects/tmp/kiba/{pro_id}.pdb' | ||
|
||
|
||
plot_sequence(np.load('muta-100_200.npy'), pep_opts=list(ResInfo.pep_to_code.keys()), | ||
pro_seq=Chain(pdb_file).sequence) | ||
|
||
|
||
# %% | ||
import os | ||
import logging | ||
logging.getLogger().setLevel(logging.DEBUG) | ||
|
||
import numpy as np | ||
import torch | ||
import torch_geometric as torchg | ||
from tqdm import tqdm | ||
|
||
from src import cfg | ||
from src.utils.loader import Loader | ||
from src.utils.residue import ResInfo, Chain | ||
from src.data_prep.feature_extraction.ligand import smile_to_graph | ||
from src.data_prep.feature_extraction.protein import target_to_graph | ||
from src.utils.residue import ResInfo, Chain | ||
|
||
#%% | ||
DATA = cfg.DATA_OPT.kiba | ||
lig_feature = cfg.LIG_FEAT_OPT.original | ||
lig_edge = cfg.LIG_EDGE_OPT.binary | ||
pro_feature = cfg.PRO_FEAT_OPT.nomsa | ||
pro_edge = cfg.PRO_EDGE_OPT.binary | ||
|
||
lig_seq = "COC1=C(C=CC(=C1)C2=CC3=C(C=C2)C(=CC4=CC=CN4)C(=O)N3)O" #CHEMBL202930 | ||
|
||
# %% build ligand graph | ||
mol_feat, mol_edge = smile_to_graph(lig_seq, lig_feature=lig_feature, lig_edge=lig_edge) | ||
lig = torchg.data.Data(x=torch.Tensor(mol_feat), edge_index=torch.LongTensor(mol_edge),lig_seq=lig_seq) | ||
|
||
# %% Get initial pkd value: | ||
pro_id = "P67870" | ||
pdb_file = f'/cluster/home/t122995uhn/projects/tmp/kiba/{pro_id}.pdb' | ||
|
||
def get_protein_features(pro_id, pdb_file, DATA=DATA): | ||
pdb = Chain(pdb_file) | ||
pro_seq = pdb.sequence | ||
pro_cmap = pdb.get_contact_map() | ||
|
||
updated_seq, extra_feat, edge_idx = target_to_graph(target_sequence=pro_seq, | ||
contact_map=pro_cmap, | ||
threshold=8.0 if DATA is cfg.DATA_OPT.PDBbind else -0.5, | ||
pro_feat=pro_feature) | ||
pro_feat = torch.Tensor(extra_feat) | ||
|
||
pro = torchg.data.Data(x=torch.Tensor(pro_feat), | ||
edge_index=torch.LongTensor(edge_idx), | ||
pro_seq=updated_seq, # Protein sequence for downstream esm model | ||
prot_id=pro_id, | ||
edge_weight=None) | ||
return pro, pro_seq | ||
|
||
pro, pro_seq = get_protein_features(pro_id, pdb_file) | ||
|
||
# %% Loading the model | ||
m = Loader.load_tuned_model('davis_DG', fold=1) | ||
m.eval() | ||
original_pkd = m(pro, lig) | ||
print(original_pkd) | ||
|
||
# %% mutate and regenerate graphs | ||
muta = np.zeros(shape=(len(ResInfo.pep_to_code.keys()), len(pro_seq))) | ||
|
||
# zero indexed res range to mutate: | ||
res_range = (100, 200) | ||
res_range = (max(res_range[0], 0), | ||
min(res_range[1], len(pro_seq))) | ||
|
||
# %% | ||
from src.utils.mutate_model import run_modeller | ||
|
||
amino_acids = ResInfo.amino_acids[:-1] # not including "X" - unknown | ||
|
||
with tqdm(range(*res_range), ncols=80, total=(res_range[1]-res_range[0])) as t: | ||
for j in t: | ||
for i, AA in enumerate(amino_acids): | ||
if i%2 == 0: | ||
t.set_postfix(res=j, AA=i+1) | ||
|
||
if pro_seq[i] == AA: | ||
muta[i,j] = original_pkd | ||
continue | ||
|
||
pro_id = "P67870" | ||
pdb_file = f'/cluster/home/t122995uhn/projects/tmp/kiba/{pro_id}.pdb' | ||
out_pdb_fp = run_modeller(pdb_file, 1, ResInfo.code_to_pep[AA], "A") | ||
|
||
pro, _ = get_protein_features(pro_id, out_pdb_fp) | ||
muta[i,j] = m(pro, lig) | ||
|
||
# delete after use | ||
os.remove(out_pdb_fp) | ||
|
||
#%% | ||
np.save(f"muta-{res_range[0]}_{res_range[1]}.npy", muta) | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters