Skip to content

Commit

Permalink
Merge pull request #137 from jyaacoub/136-mutagenesis-modeller
Browse files Browse the repository at this point in the history
mutagenesis
  • Loading branch information
jyaacoub committed Sep 15, 2024
2 parents b53a0ff + bb12c28 commit af725c6
Show file tree
Hide file tree
Showing 5 changed files with 416 additions and 80 deletions.
109 changes: 31 additions & 78 deletions playground.py
Original file line number Diff line number Diff line change
@@ -1,81 +1,34 @@
#%%
import pandas as pd

def get_test_oncokbs(train_df=None,
                     oncokb_fp='/cluster/home/t122995uhn/projects/data/tcga/mart_export.tsv',
                     biomart='/cluster/home/t122995uhn/projects/downloads/oncoKB_DrugGenePairList.csv'):
    """Split the drug-gene table by whether each gene is covered by ``train_df``.

    Gene names are attached to ``train_df`` by matching either its ``prot_id``
    column against ``UniProtKB/Swiss-Prot ID`` or its ``code`` column against
    the lower-cased ``PDB ID`` of the mapping table.

    NOTE: the parameter names are historical and swapped relative to what they
    load. ``oncokb_fp`` is read as the tab-separated mapping table (columns
    ``PDB ID``, ``UniProtKB/Swiss-Prot ID``, ``Gene name``) and ``biomart`` is
    read as the comma-separated drug-gene table (column ``gene``). They are
    kept as-is so existing keyword callers keep working.

    Args:
        train_df: DataFrame with ``code`` and ``prot_id`` columns. When None,
            the full cleaned PDBbind table is loaded from its default path.
            The frame passed in is not modified.
        oncokb_fp: Path to the tab-separated ID -> gene-name mapping table.
        biomart: Path to the comma-separated drug-gene table.

    Returns:
        Tuple ``(not_trained, trained, all_rows)`` of DataFrames taken from the
        drug-gene table: rows whose gene is absent from ``train_df``, rows
        whose gene is present, and the unfiltered table.
    """
    if train_df is None:
        # Loaded lazily so that defining/importing this function does no file I/O.
        train_df = pd.read_csv('/cluster/home/t122995uhn/projects/data/test/PDBbindDataset/nomsa_binary_original_binary/full/cleaned_XY.csv')

    # Get gene names for PDBbind
    dfbm = pd.read_csv(oncokb_fp, sep='\t')
    dfbm['PDB ID'] = dfbm['PDB ID'].str.lower()

    # Non in-place: keeps the caller's frame untouched, so repeated calls with
    # the same frame do not fail on an already existing 'idx' column.
    train_df = train_df.reset_index(names='idx')

    df_uni = train_df.merge(dfbm, how='inner', left_on='prot_id', right_on='UniProtKB/Swiss-Prot ID')
    df_pdb = train_df.merge(dfbm, how='inner', left_on='code', right_on='PDB ID')

    # identifying overlap with oncokb
    # df_all will have duplicate entries for entries with multiple gene names...
    df_all = pd.concat([df_uni, df_pdb]).drop_duplicates(['idx', 'Gene name'])[['idx', 'code', 'Gene name']]

    dfkb = pd.read_csv(biomart)
    df_all_kb = df_all.merge(dfkb.drop_duplicates('gene'), left_on='Gene name', right_on='gene', how='inner')

    trained_genes = set(df_all_kb.gene)

    # Identify non-trained genes
    is_trained = dfkb['gene'].isin(trained_genes)
    return dfkb[~is_trained], dfkb[is_trained], dfkb


# Pool the train0 and val0 splits into a single frame and run the gene
# overlap check on that pooled set.
val_df = pd.read_csv('/cluster/home/t122995uhn/projects/data/test/PDBbindDataset/nomsa_binary_original_binary/val0/cleaned_XY.csv')
train_df = pd.concat([
    pd.read_csv('/cluster/home/t122995uhn/projects/data/test/PDBbindDataset/nomsa_binary_original_binary/train0/cleaned_XY.csv'),
    val_df,
])

get_test_oncokbs(train_df=train_df)





#%%
# %%
########################################################################
########################## BUILD DATASETS ##############################
########################## VIOLIN PLOTTING #############################
########################################################################
import os
from src.data_prep.init_dataset import create_datasets
from src import cfg
import logging
# Debug-level logging for the dataset build below.
cfg.logger.setLevel(logging.DEBUG)

# Directory with the split CSVs: one test.csv plus one val<i>.csv per fold.
splits = '/cluster/home/t122995uhn/projects/MutDTA/splits/davis/'
n_folds = 5

create_datasets(
    [cfg.DATA_OPT.PDBbind, cfg.DATA_OPT.davis, cfg.DATA_OPT.kiba],
    feat_opt=cfg.PRO_FEAT_OPT.nomsa,
    edge_opt=[cfg.PRO_EDGE_OPT.binary, cfg.PRO_EDGE_OPT.aflow],
    ligand_features=[cfg.LIG_FEAT_OPT.original, cfg.LIG_FEAT_OPT.gvp],
    ligand_edges=cfg.LIG_EDGE_OPT.binary,
    overwrite=False,
    k_folds=n_folds,
    test_prots_csv=f'{splits}/test.csv',
    val_prots_csv=[f'{splits}/val{fold}.csv' for fold in range(n_folds)],
)
# optional extra argument, currently disabled: data_root=os.path.abspath('../data/test/')

# %% Copy splits to commit them:
#from to:
import shutil
# Source root (dataset directories) and destination root (split CSVs) for the copy.
from_dir_p = '/cluster/home/t122995uhn/projects/data/v131/'
to_dir_p = '/cluster/home/t122995uhn/projects/MutDTA/splits/'
# Dataset sub-directories under from_dir_p, paired by position with the
# directory names used under to_dir_p.
from_db = ['PDBbindDataset', 'DavisKibaDataset/kiba', 'DavisKibaDataset/davis']
to_db = ['pdbbind', 'kiba', 'davis']

# Expand both lists to full paths; every source uses the
# nomsa_binary_original_binary variant of the dataset.
from_db = [f'{from_dir_p}/{f}/nomsa_binary_original_binary/' for f in from_db]
to_db = [f'{to_dir_p}/{f}' for f in to_db]

# For every dataset, copy XY.csv of train0..train4 and val0..val4 to
# <dst>/<split><i>.csv, printing each source/destination pair first.
for src, dst in zip(from_db, to_db):
for x in ['train', 'val']:
for i in range(5):
print(f"{src}/{x}{i}/XY.csv", f"{dst}/{x}{i}.csv")
shutil.copy(f"{src}/{x}{i}/XY.csv", f"{dst}/{x}{i}.csv")

# NOTE(review): indentation is not visible here; presumably the test-split copy
# below runs once per dataset (inside the outer loop) - confirm.
print(f"{src}/test/XY.csv", f"{dst}/test.csv")
shutil.copy(f"{src}/test/XY.csv", f"{dst}/test.csv")


# %%
from matplotlib import pyplot as plt

from src.analysis.figures import prepare_df, fig_combined, custom_fig

# Model statistics loaded from the v115 results: dft from model_stats.csv
# (plotted below as the test set), dfv from model_stats_val.csv (validation set).
dft = prepare_df('./results/v115/model_media/model_stats.csv')
dfv = prepare_df('./results/v115/model_media/model_stats_val.csv')

# Plot label -> 4-tuple of option names selecting a model configuration.
# NOTE(review): presumably (protein feature, protein edge, ligand feature,
# ligand edge) - confirm against fig_combined. Commented entries are
# configurations left out of the figure.
models = {
'DG': ('nomsa', 'binary', 'original', 'binary'),
'esm': ('ESM', 'binary', 'original', 'binary'), # esm model
'aflow': ('nomsa', 'aflow', 'original', 'binary'),
# 'gvpP': ('gvp', 'binary', 'original', 'binary'),
'gvpL': ('nomsa', 'binary', 'gvp', 'binary'),
# 'aflow_ring3': ('nomsa', 'aflow_ring3', 'original', 'binary'),
'gvpL_aflow': ('nomsa', 'aflow', 'gvp', 'binary'),
# 'gvpL_aflow_rng3': ('nomsa', 'aflow_ring3', 'gvp', 'binary'),
#GVPL_ESMM_davis3D_nomsaF_aflowE_48B_0.00010636872718329864LR_0.23282479481785903D_2000E_gvpLF_binaryLE
# 'gvpl_esm_aflow': ('ESM', 'aflow', 'gvp', 'binary'),
}

# Test-set figure: cindex and mse on davis for the selected models.
fig, axes = fig_combined(dft, datasets=['davis'], fig_callable=custom_fig,
models=models, metrics=['cindex', 'mse'],
fig_scale=(10,5), add_stats=True, title_postfix=" test set performance", box=True, fold_labels=True)
# Rotate the x tick labels of the current axes.
plt.xticks(rotation=45)

# Same figure for the validation statistics.
fig, axes = fig_combined(dfv, datasets=['davis'], fig_callable=custom_fig,
models=models, metrics=['cindex', 'mse'],
fig_scale=(10,5), add_stats=True, title_postfix=" validation set performance", box=True, fold_labels=True)
plt.xticks(rotation=45)
35 changes: 35 additions & 0 deletions src/analysis/figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,3 +972,38 @@ def generate_roc_curve(true_dpkd, pred_dpkd, thres_range=(0,5), step=0.1):
# performance drop values:
grp = plot_df.groupby(['data', 'overlap']).mean()
grp[grp.index.get_level_values(1)].values - grp[~grp.index.get_level_values(1)].values

#%%
########################################################################
########################## VIOLIN PLOTTING #############################
########################################################################
import logging
from matplotlib import pyplot as plt

from src.analysis.figures import prepare_df, fig_combined, custom_fig

# Model statistics from the v115 results: dft feeds the test-set figure,
# dfv the validation-set figure.
dft = prepare_df('./results/v115/model_media/model_stats.csv')
dfv = prepare_df('./results/v115/model_media/model_stats_val.csv')

# Plot label -> 4-tuple of option names selecting a model configuration.
# NOTE(review): presumably (protein feature, protein edge, ligand feature,
# ligand edge) - confirm against fig_combined. Commented entries are
# configurations left out of the figure.
models = {
    'DG': ('nomsa', 'binary', 'original', 'binary'),
    # 'esm': ('ESM', 'binary', 'original', 'binary'), # esm model
    'aflow': ('nomsa', 'aflow', 'original', 'binary'),
    # 'gvpP': ('gvp', 'binary', 'original', 'binary'),
    # 'gvpL': ('nomsa', 'binary', 'gvp', 'binary'),
    # 'aflow_ring3': ('nomsa', 'aflow_ring3', 'original', 'binary'),
    # 'gvpL_aflow': ('nomsa', 'aflow', 'gvp', 'binary'),
    # 'gvpL_aflow_rng3': ('nomsa', 'aflow_ring3', 'gvp', 'binary'),
    #GVPL_ESMM_davis3D_nomsaF_aflowE_48B_0.00010636872718329864LR_0.23282479481785903D_2000E_gvpLF_binaryLE
    # 'gvpl_esm_aflow': ('ESM', 'aflow', 'gvp', 'binary'),
}

# One combined figure per statistics table, test set first; the x tick labels
# of the current axes are rotated after each figure is created.
for stats_df, postfix in ((dft, " test set performance"),
                          (dfv, " validation set performance")):
    fig, axes = fig_combined(
        stats_df,
        datasets=['davis'],
        fig_callable=custom_fig,
        models=models,
        metrics=['cindex', 'mse'],
        fig_scale=(10,5),
        add_stats=True,
        title_postfix=postfix,
        box=True,
        fold_labels=True,
    )
    plt.xticks(rotation=45)
181 changes: 181 additions & 0 deletions src/analysis/mutagenesis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patheffects as PathEffects

from src.utils.residue import ResInfo, Chain
from src.utils.mutate_model import run_modeller
import copy

def plot_sequence(muta, pep_opts, pro_seq, delta=False):
    """Show a mutagenesis matrix as an annotated heatmap.

    Args:
        muta: 2D array indexed as ``muta[i, j]``; row ``i`` corresponds to
            ``pep_opts[i]`` and column ``j`` to ``pro_seq[j]``.
        pep_opts: Amino-acid codes, one per row (y axis).
        pro_seq: Original protein sequence, one residue per column (x axis).
        delta: When True, plot the values relative to the unmutated value
            instead of the raw values. The caller's array is not modified.

    Raises:
        ValueError: If ``delta`` is True and no residue of ``pro_seq`` occurs
            in ``pep_opts``, so no unmutated reference cell exists.
    """
    if delta:
        # A cell whose row amino acid equals the column's original residue is
        # a "no mutation" cell and holds the reference value. Take the first
        # such cell anywhere in the sequence instead of only looking at
        # column 0, which may hold a residue that is not part of pep_opts.
        reference_cells = (muta[i, j]
                           for j, native in enumerate(pro_seq)
                           for i, AA in enumerate(pep_opts)
                           if AA == native)
        original_pkd = next(reference_cells, None)
        if original_pkd is None:
            raise ValueError("delta=True requires at least one residue of pro_seq "
                             "to be present in pep_opts")

        # Out-of-place subtraction: leaves the input untouched and also works
        # for integer arrays, where an in-place float subtraction would fail.
        muta = np.asarray(muta) - original_pkd

    # Create the plot
    plt.figure(figsize=(len(pro_seq), 20)) # Adjust the figure size as needed
    plt.imshow(muta, aspect='auto', cmap='coolwarm', interpolation='none')

    # Set the x-ticks to correspond to positions in the protein sequence
    plt.xticks(ticks=np.arange(len(pro_seq)), labels=pro_seq, fontsize=16)
    plt.yticks(ticks=np.arange(len(pep_opts)), labels=pep_opts, fontsize=16)
    plt.xlabel('Original Protein Sequence', fontsize=75)
    plt.ylabel('Mutated to Amino Acid code', fontsize=75)

    # Label every cell with the amino acid of its row, outlined in white so
    # the letter stays readable on any background colour.
    outline = [
        PathEffects.Stroke(linewidth=1, foreground='white'),
        PathEffects.Normal(),
    ]
    for i, AA in enumerate(pep_opts):
        for j in range(len(pro_seq)):
            text = plt.text(j, i, f'{AA}', ha='center', va='center', color='black', fontsize=12)
            text.set_path_effects(outline)

    # Adjust gridlines to be off-center, forming cell boundaries
    ax = plt.gca()
    ax.set_xticks(np.arange(-0.5, len(pro_seq), 1), minor=True)
    ax.set_yticks(np.arange(-0.5, len(pep_opts), 1), minor=True)
    plt.grid(which='minor', color='gray', linestyle='-', linewidth=0.5)

    # Remove the major gridlines (optional, for clarity)
    plt.grid(which='major', color='none')

    # Add a colorbar to show the scale of mutation values
    plt.colorbar(label='Mutation Values')
    plt.title('Mutation Array Visualization', fontsize=100)
    plt.show()



if __name__ == "__main__":
# %%
import os
import logging
logging.getLogger().setLevel(logging.DEBUG)

import numpy as np
import torch
import torch_geometric as torchg
from tqdm import tqdm

from src import cfg
from src.utils.loader import Loader
from src.utils.residue import ResInfo, Chain
from src.data_prep.feature_extraction.ligand import smile_to_graph
from src.data_prep.feature_extraction.protein import target_to_graph
from src.utils.residue import ResInfo, Chain

#%%
# Dataset and feature/edge options used for the ligand and protein graphs below.
DATA = cfg.DATA_OPT.kiba
lig_feature = cfg.LIG_FEAT_OPT.original
lig_edge = cfg.LIG_EDGE_OPT.binary
pro_feature = cfg.PRO_FEAT_OPT.nomsa
pro_edge = cfg.PRO_EDGE_OPT.binary

# Ligand string used for every prediction in this script.
lig_seq = "COC1=C(C=CC(=C1)C2=CC3=C(C=C2)C(=CC4=CC=CN4)C(=O)N3)O" #CHEMBL202930

# %% build ligand graph
# Turn the ligand string into node features and an edge index, then wrap them
# (together with the string itself) in a torch_geometric Data object.
mol_feat, mol_edge = smile_to_graph(lig_seq, lig_feature=lig_feature, lig_edge=lig_edge)
lig = torchg.data.Data(x=torch.Tensor(mol_feat), edge_index=torch.LongTensor(mol_edge),lig_seq=lig_seq)

# %% Get initial pkd value:
# Protein identifier and the PDB file its structure is read from.
pro_id = "P67870"
pdb_file = f'/cluster/home/t122995uhn/projects/tmp/kiba/{pro_id}.pdb'

def get_protein_features(pro_id, pdb_file, DATA=DATA):
    """Build the protein graph for the structure stored in ``pdb_file``.

    Args:
        pro_id: Identifier stored on the graph as ``prot_id``.
        pdb_file: Path of the PDB file to read.
        DATA: Dataset option; selects the contact-map threshold
            (8.0 for PDBbind, -0.5 for anything else).

    Returns:
        Tuple ``(pro, pro_seq)``: the torch_geometric ``Data`` graph and the
        sequence read from the PDB chain.
    """
    chain = Chain(pdb_file)
    pro_seq = chain.sequence
    contact_map = chain.get_contact_map()

    if DATA is cfg.DATA_OPT.PDBbind:
        cmap_threshold = 8.0
    else:
        cmap_threshold = -0.5

    updated_seq, extra_feat, edge_idx = target_to_graph(
        target_sequence=pro_seq,
        contact_map=contact_map,
        threshold=cmap_threshold,
        pro_feat=pro_feature,
    )
    node_feats = torch.Tensor(extra_feat)

    graph_kwargs = dict(
        x=torch.Tensor(node_feats),
        edge_index=torch.LongTensor(edge_idx),
        pro_seq=updated_seq,  # Protein sequence for downstream esm model
        prot_id=pro_id,
        edge_weight=None,
    )
    return torchg.data.Data(**graph_kwargs), pro_seq

# Graph and sequence of the unmutated protein.
pro, pro_seq = get_protein_features(pro_id, pdb_file)

# %% Loading the model
m = Loader.load_tuned_model('davis_DG', fold=1)
m.eval()
# Inference only: no autograd graph is needed for any prediction in this script.
with torch.no_grad():
    original_pkd = m(pro, lig)
print(original_pkd)
# Plain Python float so it can be written straight into the NumPy result array.
original_pkd_value = float(original_pkd)

# %% mutate and regenerate graphs

# zero indexed res range to mutate:
res_range = (0,250)
# Clamp the requested range to the actual sequence length.
res_range = (max(res_range[0], 0),
             min(res_range[1], len(pro_seq)))

# %%
from src.utils.mutate_model import run_modeller

amino_acids = ResInfo.amino_acids[:-1] # not including "X" - unknown
# muta[i, j]: prediction for residue j (zero indexed) mutated to amino_acids[i].
muta = np.zeros(shape=(len(amino_acids), len(pro_seq)))

# Number of (residue, amino acid) pairs skipped because they are not a mutation.
n_continued = 0
with tqdm(range(*res_range), ncols=80, total=(res_range[1]-res_range[0])) as t:
    for j in t:
        for i, AA in enumerate(amino_acids):
            if i%2 == 0:
                t.set_postfix(res=j, AA=i+1)

            if pro_seq[j] == AA:
                # Same residue as the original: reuse the unmutated prediction.
                muta[i,j] = original_pkd_value
                n_continued += 1
                continue

            # run_modeller is given a 1-based residue position, hence j+1.
            out_pdb_fp = run_modeller(pdb_file, j+1, ResInfo.code_to_pep[AA], "A")
            try:
                mut_pro, _ = get_protein_features(pro_id, out_pdb_fp)
                # Explicit check instead of assert, so it still runs under `python -O`.
                if mut_pro.pro_seq == pro_seq or mut_pro.pro_seq[j] != AA:
                    raise RuntimeError(
                        f"ERROR in modeller: expected {AA!r} at zero-indexed residue {j} "
                        f"of {out_pdb_fp}")
                with torch.no_grad():
                    muta[i,j] = float(m(mut_pro, lig))
            finally:
                # delete after use, even when feature extraction or inference failed
                os.remove(out_pdb_fp)
print('n_continued:', n_continued)
#%%
np.save(f"muta-{res_range[0]}_{res_range[1]}.npy", muta)

exit()

# %%
import numpy as np
from src.analysis.mutagenesis import plot_sequence
from src.utils.residue import ResInfo, Chain

# Protein, structure file and residue range identifying the saved scan to plot.
pro_id = "P67870"
pdb_file = f'/cluster/home/t122995uhn/projects/tmp/kiba/{pro_id}.pdb'
res_range = (0,215)
muta = np.load(f'muta-{res_range[0]}_{res_range[1]}.npy')

# pkd -> kd
# muta = 10**(-muta)
# %%
amino_acids = ResInfo.amino_acids[:-1] # not including "X" - unknown

# Restrict both the matrix columns and the sequence to the scanned range, then
# draw the raw heatmap followed by the one relative to the unmutated value.
first_res, last_res = res_range
scan_window = muta[:, first_res:last_res]
seq_window = Chain(pdb_file).sequence[first_res:last_res]

for use_delta in (False, True):
    plot_sequence(scan_window, pep_opts=amino_acids,
                  pro_seq=seq_window, delta=use_delta)




5 changes: 3 additions & 2 deletions src/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ class LIG_FEAT_OPT(StringEnum):
from pathlib import Path

# Model save paths
issue_number = 131
issue_number = None

DATA_BASENAME = f'data/{f"v{issue_number}" if issue_number else ""}'
RESULTS_PATH = os.path.abspath(f'results/{f"v{issue_number}/" if issue_number else ""}')
MEDIA_SAVE_DIR = f'{RESULTS_PATH}/model_media/'
Expand All @@ -110,7 +111,7 @@ class LIG_FEAT_OPT(StringEnum):
SLURM_ACCOUNT = None
SLURM_GPU_NAME = 'v100'

if 'uhnh4h' in DOMAIN_NAME:
if ('uhnh4h' in DOMAIN_NAME or 'h4h' in DOMAIN_NAME):
CLUSTER = 'h4h'
SLURM_PARTITION = 'gpu'
SLURM_CONSTRAINT = 'gpu32g'
Expand Down
Loading

0 comments on commit af725c6

Please sign in to comment.