Skip to content

Commit

Permalink
feat(mutagenesis): sequence plotting and sample plot #136
Browse files Browse the repository at this point in the history
  • Loading branch information
jyaacoub committed Aug 13, 2024
1 parent d8ba18e commit 45702a2
Show file tree
Hide file tree
Showing 4 changed files with 195 additions and 93 deletions.
117 changes: 31 additions & 86 deletions playground.py
Original file line number Diff line number Diff line change
@@ -1,89 +1,34 @@
#%%
from src import cfg

#%%
from src.utils.mutate_model import run_modeller

# Scratch cell: mutate residue position 50 on chain "A" of the O00141 structure
# to "MET". The arguments map onto run_modeller's (modelname, respos, restyp,
# chain) parameters; the return value is discarded here.
run_modeller('/cluster/home/t122995uhn/projects/tmp/d1/O00141.pdb', 50, "MET", "A")

#%%
import pandas as pd

def get_test_oncokbs(train_df=None,
                     oncokb_fp='/cluster/home/t122995uhn/projects/data/tcga/mart_export.tsv',
                     biomart='/cluster/home/t122995uhn/projects/downloads/oncoKB_DrugGenePairList.csv'):
    """Split the OncoKB drug-gene table by whether each gene appears in the training data.

    Training rows are mapped to gene names through a tab-separated mapping
    table (matched on UniProt ID and on lower-cased PDB ID), and the resulting
    gene names are intersected with the ``gene`` column of the drug-gene table.

    NOTE(review): the parameter names look swapped relative to their defaults
    and usage -- ``oncokb_fp`` is read as the tab-separated gene-mapping export
    and ``biomart`` as the drug-gene CSV. The names are kept unchanged so that
    existing keyword callers keep working.

    Args:
        train_df: DataFrame with at least ``prot_id`` and ``code`` columns.
            When ``None`` the full PDBbind ``cleaned_XY.csv`` is loaded. The
            frame passed in is never modified.
        oncokb_fp: Path to a TSV with ``PDB ID``, ``UniProtKB/Swiss-Prot ID``
            and ``Gene name`` columns.
        biomart: Path to a CSV with a ``gene`` column.

    Returns:
        A 3-tuple ``(untrained, trained, all_rows)``: the rows of the
        drug-gene table whose gene is absent from / present in the training
        data, followed by the complete drug-gene table.
    """
    if train_df is None:
        # Loaded lazily: reading this inside the signature would hit the disk
        # at definition time and share one frame across every call.
        train_df = pd.read_csv('/cluster/home/t122995uhn/projects/data/test/PDBbindDataset/nomsa_binary_original_binary/full/cleaned_XY.csv')

    # Non-inplace reset: leaves the caller's frame untouched and lets the
    # function be called repeatedly on the same frame (an in-place reset fails
    # the second time because the 'idx' column already exists).
    train_df = train_df.reset_index(names='idx')

    # Get gene names for PDBbind
    dfbm = pd.read_csv(oncokb_fp, sep='\t')
    dfbm['PDB ID'] = dfbm['PDB ID'].str.lower()

    df_uni = train_df.merge(dfbm, how='inner', left_on='prot_id', right_on='UniProtKB/Swiss-Prot ID')
    df_pdb = train_df.merge(dfbm, how='inner', left_on='code', right_on='PDB ID')

    # Identify overlap with the drug-gene table.
    # df_all keeps one row per (training row, gene name), so entries with
    # multiple gene names still appear more than once.
    df_all = pd.concat([df_uni, df_pdb]).drop_duplicates(['idx', 'Gene name'])[['idx', 'code', 'Gene name']]

    dfkb = pd.read_csv(biomart)
    df_all_kb = df_all.merge(dfkb.drop_duplicates('gene'), left_on='Gene name', right_on='gene', how='inner')

    trained_genes = set(df_all_kb.gene)

    # Compute the membership mask once and reuse it for both partitions.
    is_trained = dfkb['gene'].isin(trained_genes)
    return dfkb[~is_trained], dfkb[is_trained], dfkb


# Load the fold-0 train and validation splits, pool them into a single frame
# and check that pooled set against the drug-gene table.
train_df, val_df = map(pd.read_csv, (
    '/cluster/home/t122995uhn/projects/data/test/PDBbindDataset/nomsa_binary_original_binary/train0/cleaned_XY.csv',
    '/cluster/home/t122995uhn/projects/data/test/PDBbindDataset/nomsa_binary_original_binary/val0/cleaned_XY.csv',
))

train_df = pd.concat((train_df, val_df))

get_test_oncokbs(train_df=train_df)





#%%
# %%
########################################################################
########################## BUILD DATASETS ##############################
########################## VIOLIN PLOTTING #############################
########################################################################
import os
from src.data_prep.init_dataset import create_datasets
from src import cfg
import logging
# Verbose logging while the datasets are being built.
cfg.logger.setLevel(logging.DEBUG)

# Directory holding the committed davis split CSVs (test.csv, val0..val4.csv).
splits = '/cluster/home/t122995uhn/projects/MutDTA/splits/davis/'

# Build the requested feature/edge combinations for all three datasets,
# reusing the fixed test/validation protein splits; nothing is overwritten.
create_datasets(
    [cfg.DATA_OPT.PDBbind, cfg.DATA_OPT.davis, cfg.DATA_OPT.kiba],
    # protein side
    feat_opt=cfg.PRO_FEAT_OPT.nomsa,
    edge_opt=[cfg.PRO_EDGE_OPT.binary, cfg.PRO_EDGE_OPT.aflow],
    # ligand side
    ligand_features=[cfg.LIG_FEAT_OPT.original, cfg.LIG_FEAT_OPT.gvp],
    ligand_edges=cfg.LIG_EDGE_OPT.binary,
    # splitting
    k_folds=5,
    test_prots_csv=f'{splits}/test.csv',
    val_prots_csv=[f'{splits}/val{fold}.csv' for fold in range(5)],
    overwrite=False,
)
# data_root=os.path.abspath('../data/test/'))

# %% Copy splits to commit them:
#from to:
import shutil
from_dir_p = '/cluster/home/t122995uhn/projects/data/v131/'
to_dir_p = '/cluster/home/t122995uhn/projects/MutDTA/splits/'
from_db = ['PDBbindDataset', 'DavisKibaDataset/kiba', 'DavisKibaDataset/davis']
to_db = ['pdbbind', 'kiba', 'davis']

# Expand the short names into full source / destination directories.
from_db = [f'{from_dir_p}/{f}/nomsa_binary_original_binary/' for f in from_db]
to_db = [f'{to_dir_p}/{f}' for f in to_db]

# For each dataset copy the five train and five val fold CSVs, then the test
# CSV, announcing every (source, destination) pair before it is copied.
# NOTE(review): the test copy is taken to run once per dataset (inside the
# outer loop only) -- confirm against the original indentation.
for src, dst in zip(from_db, to_db):
    pairs = [(f"{src}/{x}{i}/XY.csv", f"{dst}/{x}{i}.csv")
             for x in ['train', 'val']
             for i in range(5)]
    pairs.append((f"{src}/test/XY.csv", f"{dst}/test.csv"))

    for source_csv, target_csv in pairs:
        print(source_csv, target_csv)
        shutil.copy(source_csv, target_csv)


# %%
from matplotlib import pyplot as plt

from src.analysis.figures import prepare_df, fig_combined, custom_fig

# Test-set and validation-set statistics of the v115 models.
dft = prepare_df('./results/v115/model_media/model_stats.csv')
dfv = prepare_df('./results/v115/model_media/model_stats_val.csv')

# Display name -> 4-tuple of option strings identifying each model to plot.
# Commented-out entries are alternatives that are currently not plotted.
models = {
    'DG': ('nomsa', 'binary', 'original', 'binary'),
    'esm': ('ESM', 'binary', 'original', 'binary'), # esm model
    'aflow': ('nomsa', 'aflow', 'original', 'binary'),
    # 'gvpP': ('gvp', 'binary', 'original', 'binary'),
    'gvpL': ('nomsa', 'binary', 'gvp', 'binary'),
    # 'aflow_ring3': ('nomsa', 'aflow_ring3', 'original', 'binary'),
    'gvpL_aflow': ('nomsa', 'aflow', 'gvp', 'binary'),
    # 'gvpL_aflow_rng3': ('nomsa', 'aflow_ring3', 'gvp', 'binary'),
    #GVPL_ESMM_davis3D_nomsaF_aflowE_48B_0.00010636872718329864LR_0.23282479481785903D_2000E_gvpLF_binaryLE
    # 'gvpl_esm_aflow': ('ESM', 'aflow', 'gvp', 'binary'),
}

# One combined davis figure per statistics table, test set first.
for stats_df, postfix in ((dft, " test set performance"),
                          (dfv, " validation set performance")):
    fig, axes = fig_combined(stats_df, datasets=['davis'], fig_callable=custom_fig,
                             models=models, metrics=['cindex', 'mse'],
                             fig_scale=(10,5), add_stats=True, title_postfix=postfix,
                             box=True, fold_labels=True)
    plt.xticks(rotation=45)
153 changes: 153 additions & 0 deletions src/analysis/mutagenesis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patheffects as PathEffects

from src.utils.residue import ResInfo, Chain
from src.utils.mutate_model import run_modeller


def plot_sequence(muta, pep_opts, pro_seq):
    """Show a mutation matrix as an annotated heatmap.

    Rows correspond to the entries of ``pep_opts`` and columns to the
    positions of ``pro_seq``. Every cell is annotated with the code of its
    row's peptide option, and the figure is displayed with ``plt.show()``.

    Args:
        muta: 2-D array of values to colour, indexed ``[row, column]``.
            NOTE(review): assumed to have shape
            ``(len(pep_opts), len(pro_seq))`` with rows in the same order as
            ``pep_opts`` -- confirm against the code that fills the matrix.
        pep_opts: Row labels; each must be a key of ``ResInfo.pep_to_code``.
        pro_seq: Protein sequence; each element must be a key of
            ``ResInfo.code_to_pep``.
    """
    n_rows, n_cols = len(pep_opts), len(pro_seq)

    # Figure width grows with the sequence length (one unit per position).
    plt.figure(figsize=(n_cols, 20))
    plt.imshow(muta, aspect='auto', cmap='coolwarm', interpolation='none')

    # Major ticks sit on the cell centres: residues along x, options along y.
    plt.xticks(ticks=np.arange(n_cols),
               labels=[ResInfo.code_to_pep[p] for p in pro_seq],
               rotation=45, fontsize=16)
    plt.yticks(ticks=np.arange(n_rows), labels=pep_opts, fontsize=16)
    plt.xlabel('Protein Sequence Position', fontsize=75)
    plt.ylabel('Peptide Options', fontsize=75)

    # Label every square. The previous version left the loop after its first
    # pass via a stray `break`, so most squares were never labelled.
    # White outline keeps the black text legible on dark cells.
    outline = [
        PathEffects.Stroke(linewidth=1, foreground='white'),
        PathEffects.Normal(),
    ]
    for i, pep in enumerate(pep_opts):
        cell_label = f'{ResInfo.pep_to_code[pep]}'  # identical along a row
        for j in range(n_cols):
            text = plt.text(j, i, cell_label, ha='center', va='center',
                            color='black', fontsize=8)
            text.set_path_effects(outline)

    # Minor ticks at the half-way points make the grid trace cell boundaries.
    ax = plt.gca()
    ax.set_xticks(np.arange(-0.5, n_cols, 1), minor=True)
    ax.set_yticks(np.arange(-0.5, n_rows, 1), minor=True)
    plt.grid(which='minor', color='gray', linestyle='-', linewidth=0.5)

    # Hide the major gridlines, which would cut through the cell centres.
    plt.grid(which='major', color='none')

    # Colour scale for the plotted values.
    plt.colorbar(label='Mutation Values')
    plt.title('Mutation Array Visualization', fontsize=100)
    plt.show()



if __name__ == "__main__":
    # Demo: plot a previously saved mutation matrix for protein P67870 against
    # the sequence read from its PDB file.
    pro_id = "P67870"
    pdb_file = f'/cluster/home/t122995uhn/projects/tmp/kiba/{pro_id}.pdb'

    saved_matrix = np.load('muta-100_200.npy')
    row_labels = [*ResInfo.pep_to_code.keys()]
    plot_sequence(saved_matrix, pep_opts=row_labels,
                  pro_seq=Chain(pdb_file).sequence)


# %%
import os
import logging
logging.getLogger().setLevel(logging.DEBUG)

import numpy as np
import torch
import torch_geometric as torchg
from tqdm import tqdm

from src import cfg
from src.utils.loader import Loader
from src.utils.residue import ResInfo, Chain
from src.data_prep.feature_extraction.ligand import smile_to_graph
from src.data_prep.feature_extraction.protein import target_to_graph
from src.utils.residue import ResInfo, Chain

#%%
# Dataset and feature/edge options used by the cells below. DATA also becomes
# the default of get_protein_features(), where it selects the contact-map
# threshold.
DATA = cfg.DATA_OPT.kiba
lig_feature = cfg.LIG_FEAT_OPT.original
lig_edge = cfg.LIG_EDGE_OPT.binary
pro_feature = cfg.PRO_FEAT_OPT.nomsa
pro_edge = cfg.PRO_EDGE_OPT.binary

# SMILES string of the ligand that every protein variant is scored against.
lig_seq = "COC1=C(C=CC(=C1)C2=CC3=C(C=C2)C(=CC4=CC=CN4)C(=O)N3)O" #CHEMBL202930

# %% build ligand graph
# Convert the SMILES string into node features and an edge index, then wrap
# them in a torch_geometric Data object (the SMILES is kept as `lig_seq`).
mol_feat, mol_edge = smile_to_graph(lig_seq, lig_feature=lig_feature, lig_edge=lig_edge)
lig = torchg.data.Data(x=torch.Tensor(mol_feat), edge_index=torch.LongTensor(mol_edge),lig_seq=lig_seq)

# %% Get initial pkd value:
# Protein under study and the PDB file its structure is read from.
pro_id = "P67870"
pdb_file = f'/cluster/home/t122995uhn/projects/tmp/kiba/{pro_id}.pdb'

def get_protein_features(pro_id, pdb_file, DATA=DATA):
    """Build the protein graph for one PDB file.

    The sequence and contact map are read from ``pdb_file`` through ``Chain``
    and turned into node features and an edge index by ``target_to_graph``,
    using the module-level ``pro_feature`` option.

    Args:
        pro_id: Identifier stored on the returned graph as ``prot_id``.
        pdb_file: Path of the PDB file to read.
        DATA: Dataset option; PDBbind uses a contact threshold of 8.0, any
            other value uses -0.5. Defaults to the module-level ``DATA`` as it
            was when this function was defined.

    Returns:
        ``(pro, pro_seq)`` where ``pro`` is a ``torch_geometric`` ``Data``
        object and ``pro_seq`` is the sequence as read from the PDB file
        (before any update made by ``target_to_graph``).
    """
    pdb = Chain(pdb_file)
    pro_seq = pdb.sequence
    pro_cmap = pdb.get_contact_map()

    # The contact threshold depends on which dataset the model was built for.
    threshold = 8.0 if DATA is cfg.DATA_OPT.PDBbind else -0.5

    updated_seq, extra_feat, edge_idx = target_to_graph(target_sequence=pro_seq,
                                                        contact_map=pro_cmap,
                                                        threshold=threshold,
                                                        pro_feat=pro_feature)

    # A single conversion is enough; wrapping the tensor in torch.Tensor a
    # second time added nothing.
    pro = torchg.data.Data(x=torch.Tensor(extra_feat),
                           edge_index=torch.LongTensor(edge_idx),
                           pro_seq=updated_seq,  # Protein sequence for downstream esm model
                           prot_id=pro_id,
                           edge_weight=None)
    return pro, pro_seq

# Graph and sequence of the unmodified protein.
pro, pro_seq = get_protein_features(pro_id, pdb_file)

# %% Loading the model
# Load fold 1 of the tuned 'davis_DG' model, switch it to eval mode and score
# the unmodified protein against the ligand.
# NOTE(review): the forward pass is not wrapped in torch.no_grad(), so the
# result may still carry autograd state -- confirm whether that is intended.
m = Loader.load_tuned_model('davis_DG', fold=1)
m.eval()
original_pkd = m(pro, lig)
print(original_pkd)

# %% mutate and regenerate graphs
# Result matrix: one row per key of ResInfo.pep_to_code, one column per
# position of the protein sequence; entries start at zero.
muta = np.zeros(shape=(len(ResInfo.pep_to_code.keys()), len(pro_seq)))

# zero indexed res range to mutate:
res_range = (100, 200)
# Clamp the requested range to the bounds of the sequence.
res_range = (max(res_range[0], 0),
             min(res_range[1], len(pro_seq)))

# %%
from src.utils.mutate_model import run_modeller

amino_acids = ResInfo.amino_acids[:-1] # not including "X" - unknown

# The wild-type structure is the same for every mutation, so it is set once
# here instead of on every loop iteration.
pro_id = "P67870"
pdb_file = f'/cluster/home/t122995uhn/projects/tmp/kiba/{pro_id}.pdb'

# Scan every position j in res_range and every candidate amino acid i, storing
# the model prediction for that single-residue mutant in muta[i, j].
with tqdm(range(*res_range), ncols=80, total=(res_range[1]-res_range[0])) as t:
    for j in t:
        for i, AA in enumerate(amino_acids):
            if i%2 == 0:
                t.set_postfix(res=j, AA=i+1)

            # Wild-type residue at this position: reuse the baseline
            # prediction. The sequence must be indexed by the residue index j;
            # i is the row (amino-acid) index.
            if pro_seq[j] == AA:
                muta[i,j] = float(original_pkd)
                continue

            # Mutate the residue being scanned, not a fixed position.
            # NOTE(review): j is zero-indexed, so j+1 assumes the PDB residue
            # numbering starts at 1 and has no gaps -- confirm for this file.
            out_pdb_fp = run_modeller(pdb_file, j+1, ResInfo.code_to_pep[AA], "A")

            try:
                pro, _ = get_protein_features(pro_id, out_pdb_fp)
                # Inference only: skip autograd bookkeeping and store a plain
                # float in the numpy matrix.
                with torch.no_grad():
                    muta[i,j] = float(m(pro, lig))
            finally:
                # delete after use, even if featurisation or inference failed
                os.remove(out_pdb_fp)

#%%
# Persist the scan; the file name records the residue range that was covered.
np.save(f"muta-{res_range[0]}_{res_range[1]}.npy", muta)



3 changes: 2 additions & 1 deletion src/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ class LIG_FEAT_OPT(StringEnum):
from pathlib import Path

# Model save paths
# NOTE(review): issue_number is assigned twice in a row; the second assignment
# wins, so the effective value is None. This looks like the before/after lines
# of a diff shown together -- confirm that only one of them belongs in the file.
issue_number = 131
issue_number = None

# When issue_number is set, data and results live under a "v<issue_number>"
# sub-directory; when it is None the version component is left out entirely.
DATA_BASENAME = f'data/{f"v{issue_number}" if issue_number else ""}'
RESULTS_PATH = os.path.abspath(f'results/{f"v{issue_number}/" if issue_number else ""}')
MEDIA_SAVE_DIR = f'{RESULTS_PATH}/model_media/'
Expand Down
15 changes: 9 additions & 6 deletions src/utils/mutate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,11 @@ def run_modeller(modelname:str, respos:int|str, restyp:str, chain:str, out_path:
raise FileExistsError('would overwrite existing file at "{modelname}.pdb"')

respos = str(respos)
log.verbose()
log.none()

TMP_FILE_PATH = modelname+restyp+respos+'.tmp'
OUT_FILE_PATH = f"{modelname}-{restyp}_{respos}.pdb"

# Set a different value for rand_seed to get a different final model
env = Environ(rand_seed=-49837)

Expand Down Expand Up @@ -120,9 +123,8 @@ def run_modeller(modelname:str, respos:int|str, restyp:str, chain:str, out_path:
#before proceeding, because not all sequence related information about MODEL
#is changed by this command (e.g., internal coordinates, charges, and atom
#types and radii are not updated).

mdl1.write(file=modelname+restyp+respos+'.tmp')
mdl1.read(file=modelname+restyp+respos+'.tmp')
mdl1.write(file=TMP_FILE_PATH)
mdl1.read(file=TMP_FILE_PATH)

#set up restraints before computing energy
#we do this a second time because the model has been written out and read in,
Expand Down Expand Up @@ -157,7 +159,8 @@ def run_modeller(modelname:str, respos:int|str, restyp:str, chain:str, out_path:
s.energy()

#give a proper name
mdl1.write(file=f"{modelname}-{restyp}_{respos}.pdb")
mdl1.write(file=OUT_FILE_PATH)

#delete the temporary file
os.remove(modelname+restyp+respos+'.tmp')
os.remove(TMP_FILE_PATH)
return OUT_FILE_PATH

0 comments on commit 45702a2

Please sign in to comment.