From e0afb6cd1de96e935106ee349b6035aa7797c85f Mon Sep 17 00:00:00 2001 From: jyaacoub Date: Mon, 16 Sep 2024 23:16:34 -0400 Subject: [PATCH 1/2] feat: created `run_mutagenesis.py` for generating mutagensis matrix --- inference.py | 2 +- run_mutagenesis.py | 153 ++++++++++++++++++++++++++ src/analysis/mutagenesis.py | 181 ------------------------------- src/analysis/mutagenesis_plot.py | 64 +++++++++++ src/utils/loader.py | 2 +- src/utils/mutate_model.py | 2 +- 6 files changed, 220 insertions(+), 184 deletions(-) create mode 100644 run_mutagenesis.py delete mode 100644 src/analysis/mutagenesis.py create mode 100644 src/analysis/mutagenesis_plot.py diff --git a/inference.py b/inference.py index ee797e8..a0cb570 100644 --- a/inference.py +++ b/inference.py @@ -45,7 +45,7 @@ from src.utils.loader import Loader logging.getLogger().setLevel(logging.DEBUG) -m = Loader.load_tuned_model('davis_esm', fold=1) +m, _ = Loader.load_tuned_model('davis_esm', fold=1) # %% m(pro, lig) diff --git a/run_mutagenesis.py b/run_mutagenesis.py new file mode 100644 index 0000000..5800003 --- /dev/null +++ b/run_mutagenesis.py @@ -0,0 +1,153 @@ +import argparse +parser = argparse.ArgumentParser(description='Runs Mutagenesis on an input PDB file and a given ligand SMILES.') +parser.add_argument('--ligand_smile', type=str, required=True, help='Ligand SMILES string.') +parser.add_argument('--ligand_smile_name', type=str, required=True, help='Ligand SMILES name, required for output path.') +parser.add_argument('--pdb_file', type=str, required=True, help='Path to the PDB file.') +parser.add_argument('--out_path', type=str, default='./', + help='Output directory path to save resulting mutagenesis numpy matrix with predicted pkd values') +parser.add_argument('--res_start', type=int, default=0, help='Start index for mutagenesis (zero-indexed).') +parser.add_argument('--res_end', type=int, default=float('inf'), help='End index for mutagenesis.') + +parser.add_argument('--model_opt', type=str, default='davis_DG', + choices=['davis_DG', 'davis_gvpl', 'davis_esm', + 'kiba_DG', 'kiba_esm', 'kiba_gvpl', + 'PDBbind_DG', 'PDBbind_esm', 'PDBbind_gvpl'], + help='Model option.') +parser.add_argument('--fold', type=int, default=1, + help='Which model fold to use (there are 5 models for each option due to 5-fold CV).') + +args = parser.parse_args() + +# Assign variables +LIGAND_SMILE = args.ligand_smile +LIGAND_SMILE_NAME = args.ligand_smile_name +PDB_FILE = args.pdb_file +OUT_PATH = args.out_path +MODEL_OPT = args.model_opt +FOLD = args.fold +RES_START = args.res_start +RES_END = args.res_end + +# Your code logic here +print("#"*50) +print(f"LIGAND_SMILE: {LIGAND_SMILE}") +print(f"LIGAND_SMILE_NAME: {LIGAND_SMILE_NAME}") +print(f"PDB_FILE: {PDB_FILE}") +print(f"OUT_PATH: {OUT_PATH}") +print(f"MODEL_OPT: {MODEL_OPT}") +print(f"FOLD: {FOLD}") +print(f"RES_START: {RES_START}") +print(f"RES_END: {RES_END}") +print("#"*50) + +import os +import logging +logging.getLogger().setLevel(logging.DEBUG) + +import numpy as np +import torch +import torch_geometric as torchg +from tqdm import tqdm + +from src import cfg +from src import TUNED_MODEL_CONFIGS + +from src.utils.loader import Loader +from src.utils.residue import ResInfo, Chain +from src.data_prep.feature_extraction.ligand import smile_to_graph +from src.data_prep.feature_extraction.protein import target_to_graph +from src.data_prep.feature_extraction.protein_edges import get_target_edge_weights +from src.utils.residue import ResInfo, Chain + +DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +MODEL_PARAMS = TUNED_MODEL_CONFIGS[MODEL_OPT] +PDB_FILE_NAME = os.path.basename(PDB_FILE).split('.pdb')[0] + +# Get initial pkd value: +def get_protein_features(pdb_file_path, cmap_thresh=8.0): + pdb = Chain(pdb_file_path) + pro_cmap = pdb.get_contact_map() + + updated_seq, extra_feat, edge_idx = target_to_graph(target_sequence=pdb.sequence, + contact_map=pro_cmap, + threshold=cmap_thresh, + pro_feat=MODEL_PARAMS['feature_opt']) + pro_edge_weight = None + if MODEL_PARAMS['edge_opt'] in cfg.OPT_REQUIRES_CONF: + raise NotImplementedError(f"{MODEL_PARAMS['edge_opt']} is not supported since it requires "+\ + "multiple conformation files to run and generate edges.") + else: + # includes edge_attr like ring3 + pro_edge_weight = get_target_edge_weights(pdb_file_path, pdb.sequence, + edge_opt=MODEL_PARAMS['edge_opt'], + cmap=pro_cmap, + n_modes=5, n_cpu=4) + if len(pro_edge_weight.shape) == 2: + pro_edge_weight = torch.Tensor(pro_edge_weight[edge_idx[0], edge_idx[1]]) + elif len(pro_edge_weight.shape) == 3: # has edge attr! (This is our GVPL features) + pro_edge_weight = torch.Tensor(pro_edge_weight[edge_idx[0], edge_idx[1], :]) + + pro_feat = torch.Tensor(extra_feat) + + pro = torchg.data.Data(x=torch.Tensor(pro_feat), + edge_index=torch.LongTensor(edge_idx), + pro_seq=updated_seq, # Protein sequence for downstream esm model + prot_id=PDB_FILE_NAME, + edge_weight=pro_edge_weight) + return pro, pdb + +################################################ +# Loading the model and get original pkd value # +################################################ +m = Loader.load_tuned_model(MODEL_OPT, fold=FOLD) +m.to(DEVICE) +m.eval() + +# build ligand graph +mol_feat, mol_edge = smile_to_graph(LIGAND_SMILE, lig_feature=MODEL_PARAMS['lig_feat_opt'], lig_edge=MODEL_PARAMS['lig_edge_opt']) +lig = torchg.data.Data(x=torch.Tensor(mol_feat), edge_index=torch.LongTensor(mol_edge), lig_seq=LIGAND_SMILE) + +# build protein graph +pro, pdb = get_protein_features(PDB_FILE) + +original_pkd = m(pro.to(DEVICE), lig.to(DEVICE)) +print("Original pkd:",original_pkd) + + +################################################ +# Mutate and regenerate graphs ################# +################################################ +# zero indexed res range to mutate: +res_range = (max(RES_START, 0), + min(RES_END, len(pdb.sequence))) + +from src.utils.mutate_model import run_modeller +amino_acids = ResInfo.amino_acids[:-1] # not including "X" - unknown +muta = np.zeros(shape=(len(amino_acids), len(pdb.sequence))) + +with tqdm(range(*res_range), ncols=100, total=(res_range[1]-res_range[0]), desc='Mutating') as t: + for j in t: + for i, AA in enumerate(amino_acids): + if i%2 == 0: + t.set_postfix(res=j, AA=i+1) + + if pdb.sequence[j] == AA: + muta[i,j] = original_pkd + continue + out_pdb_fp = run_modeller(PDB_FILE, j+1, ResInfo.code_to_pep[AA], "A") + + pro, pdb = get_protein_features(PDB_FILE) + assert pro.pro_seq != pdb.sequence and pro.pro_seq[j] == AA, "ERROR in modeller" + + muta[i,j] = m(pro.to(DEVICE), lig.to(DEVICE)) + + # delete after use + os.remove(out_pdb_fp) + + +# Save mutagenesis matrix +OUT_DIR = f'{OUT_PATH}/{LIGAND_SMILE_NAME}/{MODEL_OPT}' +os.makedirs(OUT_DIR) +OUT_FP = f"{OUT_DIR}/{res_range[0]}_{res_range[1]}.npy" +print("Saving mutagenesis numpy matrix to", OUT_FP) +np.save(OUT_FP, muta) \ No newline at end of file diff --git a/src/analysis/mutagenesis.py b/src/analysis/mutagenesis.py deleted file mode 100644 index eb78c8a..0000000 --- a/src/analysis/mutagenesis.py +++ /dev/null @@ -1,181 +0,0 @@ -import numpy as np -import matplotlib.pyplot as plt -import matplotlib.patheffects as PathEffects - -from src.utils.residue import ResInfo, Chain -from src.utils.mutate_model import run_modeller -import copy - -def plot_sequence(muta, pep_opts, pro_seq, delta=False): - if delta: - muta = copy.deepcopy(muta) - original_pkd = None - for i, AA in enumerate(pep_opts): - if AA == pro_seq[0]: - original_pkd = muta[i,0] - - muta -= original_pkd - - - # Create the plot - plt.figure(figsize=(len(pro_seq), 20)) # Adjust the figure size as needed - plt.imshow(muta, aspect='auto', cmap='coolwarm', interpolation='none') - - # Set the x-ticks to correspond to positions in the protein sequence - plt.xticks(ticks=np.arange(len(pro_seq)), labels=pro_seq, fontsize=16) - plt.yticks(ticks=np.arange(len(pep_opts)), labels=pep_opts, fontsize=16) - plt.xlabel('Original Protein Sequence', fontsize=75) - plt.ylabel('Mutated to Amino Acid code', fontsize=75) - - # Add text labels to each square - for i in range(len(pep_opts)): - for j in range(len(pro_seq)): - text = plt.text(j, i, f'{pep_opts[i]}', ha='center', va='center', color='black', fontsize=12) - # Add a white outline to the text - text.set_path_effects([ - PathEffects.Stroke(linewidth=1, foreground='white'), - PathEffects.Normal() - ]) - - - # Adjust gridlines to be off-center, forming cell boundaries - plt.gca().set_xticks(np.arange(-0.5, len(pro_seq), 1), minor=True) - plt.gca().set_yticks(np.arange(-0.5, len(pep_opts), 1), minor=True) - plt.grid(which='minor', color='gray', linestyle='-', linewidth=0.5) - - # Remove the major gridlines (optional, for clarity) - plt.grid(which='major', color='none') - - # Add a colorbar to show the scale of mutation values - plt.colorbar(label='Mutation Values') - plt.title('Mutation Array Visualization', fontsize=100) - plt.show() - - - -if __name__ == "__main__": - # %% - import os - import logging - logging.getLogger().setLevel(logging.DEBUG) - - import numpy as np - import torch - import torch_geometric as torchg - from tqdm import tqdm - - from src import cfg - from src.utils.loader import Loader - from src.utils.residue import ResInfo, Chain - from src.data_prep.feature_extraction.ligand import smile_to_graph - from src.data_prep.feature_extraction.protein import target_to_graph - from src.utils.residue import ResInfo, Chain - - #%% - DATA = cfg.DATA_OPT.kiba - lig_feature = cfg.LIG_FEAT_OPT.original - lig_edge = cfg.LIG_EDGE_OPT.binary - pro_feature = cfg.PRO_FEAT_OPT.nomsa - pro_edge = cfg.PRO_EDGE_OPT.binary - - lig_seq = "COC1=C(C=CC(=C1)C2=CC3=C(C=C2)C(=CC4=CC=CN4)C(=O)N3)O" #CHEMBL202930 - - # %% build ligand graph - mol_feat, mol_edge = smile_to_graph(lig_seq, lig_feature=lig_feature, lig_edge=lig_edge) - lig = torchg.data.Data(x=torch.Tensor(mol_feat), edge_index=torch.LongTensor(mol_edge),lig_seq=lig_seq) - - # %% Get initial pkd value: - pro_id = "P67870" - pdb_file = f'/cluster/home/t122995uhn/projects/tmp/kiba/{pro_id}.pdb' - - def get_protein_features(pro_id, pdb_file, DATA=DATA): - pdb = Chain(pdb_file) - pro_seq = pdb.sequence - pro_cmap = pdb.get_contact_map() - - updated_seq, extra_feat, edge_idx = target_to_graph(target_sequence=pro_seq, - contact_map=pro_cmap, - threshold=8.0 if DATA is cfg.DATA_OPT.PDBbind else -0.5, - pro_feat=pro_feature) - pro_feat = torch.Tensor(extra_feat) - - pro = torchg.data.Data(x=torch.Tensor(pro_feat), - edge_index=torch.LongTensor(edge_idx), - pro_seq=updated_seq, # Protein sequence for downstream esm model - prot_id=pro_id, - edge_weight=None) - return pro, pro_seq - - pro, pro_seq = get_protein_features(pro_id, pdb_file) - - # %% Loading the model - m = Loader.load_tuned_model('davis_DG', fold=1) - m.eval() - original_pkd = m(pro, lig) - print(original_pkd) - - # %% mutate and regenerate graphs - - # zero indexed res range to mutate: - res_range = (0,250) - res_range = (max(res_range[0], 0), - min(res_range[1], len(pro_seq))) - - # %% - from src.utils.mutate_model import run_modeller - - amino_acids = ResInfo.amino_acids[:-1] # not including "X" - unknown - muta = np.zeros(shape=(len(amino_acids), len(pro_seq))) - - n_continued = 0 - with tqdm(range(*res_range), ncols=80, total=(res_range[1]-res_range[0])) as t: - for j in t: - for i, AA in enumerate(amino_acids): - if i%2 == 0: - t.set_postfix(res=j, AA=i+1) - - if pro_seq[j] == AA: - muta[i,j] = original_pkd - n_continued += 1 - continue - - pro_id = "P67870" - pdb_file = f'/cluster/home/t122995uhn/projects/tmp/kiba/{pro_id}.pdb' - out_pdb_fp = run_modeller(pdb_file, j+1, ResInfo.code_to_pep[AA], "A") - - pro, _ = get_protein_features(pro_id, out_pdb_fp) - assert pro.pro_seq != pro_seq and pro.pro_seq[j] == AA, "ERROR in modeller" - muta[i,j] = m(pro, lig) - - # delete after use - os.remove(out_pdb_fp) - print('n_continued:', n_continued) - #%% - np.save(f"muta-{res_range[0]}_{res_range[1]}.npy", muta) - - exit() - - # %% - import numpy as np - from src.analysis.mutagenesis import plot_sequence - from src.utils.residue import ResInfo, Chain - - pro_id = "P67870" - pdb_file = f'/cluster/home/t122995uhn/projects/tmp/kiba/{pro_id}.pdb' - res_range = (0,215) - muta = np.load(f'muta-{res_range[0]}_{res_range[1]}.npy') - - # pkd -> kd - # muta = 10**(-muta) - # %% - amino_acids = ResInfo.amino_acids[:-1] # not including "X" - unknown - - plot_sequence(muta[:,res_range[0]:res_range[1]], pep_opts=amino_acids, - pro_seq=Chain(pdb_file).sequence[res_range[0]:res_range[1]]) - - plot_sequence(muta[:,res_range[0]:res_range[1]], pep_opts=amino_acids, - pro_seq=Chain(pdb_file).sequence[res_range[0]:res_range[1]], delta=True) - - - - diff --git a/src/analysis/mutagenesis_plot.py b/src/analysis/mutagenesis_plot.py new file mode 100644 index 0000000..88e2e4c --- /dev/null +++ b/src/analysis/mutagenesis_plot.py @@ -0,0 +1,64 @@ +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.patheffects as PathEffects + +from src.utils.residue import ResInfo, Chain +from src.utils.mutate_model import run_modeller +import copy + +def plot_sequence(muta, pro_seq, pep_opts=ResInfo.amino_acids[:-1], delta=False): + if delta: + muta = copy.deepcopy(muta) + original_pkd = None + for i, AA in enumerate(pep_opts): + if AA == pro_seq[0]: + original_pkd = muta[i,0] + + muta -= original_pkd + + + # Create the plot + plt.figure(figsize=(len(pro_seq), 20)) # Adjust the figure size as needed + plt.imshow(muta, aspect='auto', cmap='coolwarm', interpolation='none') + + # Set the x-ticks to correspond to positions in the protein sequence + plt.xticks(ticks=np.arange(len(pro_seq)), labels=pro_seq, fontsize=16) + plt.yticks(ticks=np.arange(len(pep_opts)), labels=pep_opts, fontsize=16) + plt.xlabel('Original Protein Sequence', fontsize=75) + plt.ylabel('Mutated to Amino Acid code', fontsize=75) + + # Add text labels to each square + for i in range(len(pep_opts)): + for j in range(len(pro_seq)): + text = plt.text(j, i, f'{pep_opts[i]}', ha='center', va='center', color='black', fontsize=12) + # Add a white outline to the text + text.set_path_effects([ + PathEffects.Stroke(linewidth=1, foreground='white'), + PathEffects.Normal() + ]) + + + # Adjust gridlines to be off-center, forming cell boundaries + plt.gca().set_xticks(np.arange(-0.5, len(pro_seq), 1), minor=True) + plt.gca().set_yticks(np.arange(-0.5, len(pep_opts), 1), minor=True) + plt.grid(which='minor', color='gray', linestyle='-', linewidth=0.5) + + # Remove the major gridlines (optional, for clarity) + plt.grid(which='major', color='none') + + # Add a colorbar to show the scale of mutation values + plt.colorbar(label='Mutation Values') + plt.title('Mutation Array Visualization', fontsize=100) + plt.show() + + +if __name__ == "__main__": + + res_range = (0,215) + muta = np.load(f'muta-{res_range[0]}_{res_range[1]}.npy') + plot_sequence(muta[:,res_range[0]:res_range[1]], + pro_seq=Chain(pdb_file).sequence[res_range[0]:res_range[1]]) + + + + diff --git a/src/utils/loader.py b/src/utils/loader.py index 48ead68..367c49f 100644 --- a/src/utils/loader.py +++ b/src/utils/loader.py @@ -105,7 +105,7 @@ def reformat_kwargs(model_kwargs): logging.debug(f'loading: {model_p}') model = Loader.init_model(model=model_kwargs['model'], pro_feature=model_kwargs['pro_feature'], pro_edge=model_kwargs['edge'], **MODEL_TUNED_PARAMS['architecture_kwargs']) - return model + return model, model_kwargs @staticmethod @validate_args({'model': model_opt, 'edge': edge_opt, 'pro_feature': pro_feature_opt, diff --git a/src/utils/mutate_model.py b/src/utils/mutate_model.py index c6a2b6a..6cccf39 100644 --- a/src/utils/mutate_model.py +++ b/src/utils/mutate_model.py @@ -63,7 +63,7 @@ def run_modeller(modelname:str, respos:int|str, restyp:str, chain:str, out_path: respos = str(respos) log.none() - TMP_FILE_PATH = modelname+restyp+respos+'.tmp' + TMP_FILE_PATH = f"{modelname}-{restyp}_{respos}.tmp" OUT_FILE_PATH = f"{modelname}-{restyp}_{respos}.pdb" # Set a different value for rand_seed to get a different final model From 660d06a0ffa3b1ae6aa80ba361a84e7ee56e8097 Mon Sep 17 00:00:00 2001 From: Jean Charle Yaacoub <50300488+jyaacoub@users.noreply.github.com> Date: Mon, 16 Sep 2024 23:27:40 -0400 Subject: [PATCH 2/2] fix: mutagenesis_plot.py Error due to useless placeholder code --- src/analysis/mutagenesis_plot.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/analysis/mutagenesis_plot.py b/src/analysis/mutagenesis_plot.py index 88e2e4c..1edf2d9 100644 --- a/src/analysis/mutagenesis_plot.py +++ b/src/analysis/mutagenesis_plot.py @@ -56,8 +56,8 @@ def plot_sequence(muta, pro_seq, pep_opts=ResInfo.amino_acids[:-1], delta=False) res_range = (0,215) muta = np.load(f'muta-{res_range[0]}_{res_range[1]}.npy') - plot_sequence(muta[:,res_range[0]:res_range[1]], - pro_seq=Chain(pdb_file).sequence[res_range[0]:res_range[1]]) + #plot_sequence(muta[:,res_range[0]:res_range[1]], + # pro_seq=Chain(pdb_file).sequence[res_range[0]:res_range[1]])