Skip to content

Commit

Permalink
Merge pull request #138 from jyaacoub/136-mutagenesis-modeller
Browse files Browse the repository at this point in the history
  • Loading branch information
jyaacoub authored Sep 17, 2024
2 parents ef9ac0f + 660d06a commit 85aa42b
Show file tree
Hide file tree
Showing 6 changed files with 220 additions and 184 deletions.
2 changes: 1 addition & 1 deletion inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from src.utils.loader import Loader
logging.getLogger().setLevel(logging.DEBUG)

m = Loader.load_tuned_model('davis_esm', fold=1)
m, _ = Loader.load_tuned_model('davis_esm', fold=1)

# %%
m(pro, lig)
Expand Down
153 changes: 153 additions & 0 deletions run_mutagenesis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import argparse
# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(
    description='Runs Mutagenesis on an input PDB file and a given ligand SMILES.',
)

# Required inputs: the ligand (SMILES + a name used in the output path) and
# the protein structure file.
parser.add_argument('--ligand_smile', required=True, type=str,
                    help='Ligand SMILES string.')
parser.add_argument('--ligand_smile_name', required=True, type=str,
                    help='Ligand SMILES name, required for output path.')
parser.add_argument('--pdb_file', required=True, type=str,
                    help='Path to the PDB file.')

# Optional output location and residue window.
parser.add_argument(
    '--out_path', type=str, default='./',
    help='Output directory path to save resulting mutagenesis numpy matrix with predicted pkd values',
)
parser.add_argument('--res_start', type=int, default=0,
                    help='Start index for mutagenesis (zero-indexed).')
# The default is +inf (a float, despite type=int) and is printed as "inf" in
# the echo below; values given on the command line are parsed as int.
parser.add_argument('--res_end', type=int, default=float('inf'),
                    help='End index for mutagenesis.')

# Model selection.
parser.add_argument(
    '--model_opt', type=str, default='davis_DG',
    choices=[
        'davis_DG', 'davis_gvpl', 'davis_esm',
        'kiba_DG', 'kiba_esm', 'kiba_gvpl',
        'PDBbind_DG', 'PDBbind_esm', 'PDBbind_gvpl',
    ],
    help='Model option.',
)
parser.add_argument(
    '--fold', type=int, default=1,
    help='Which model fold to use (there are 5 models for each option due to 5-fold CV).',
)

args = parser.parse_args()

# Module-level configuration constants read by the rest of the script.
LIGAND_SMILE = args.ligand_smile
LIGAND_SMILE_NAME = args.ligand_smile_name
PDB_FILE = args.pdb_file
OUT_PATH = args.out_path
MODEL_OPT = args.model_opt
FOLD = args.fold
RES_START = args.res_start
RES_END = args.res_end

# Echo the effective configuration so it is recorded in the run log.
_BANNER = "#"*50
print(_BANNER)
for _label, _value in (
        ("LIGAND_SMILE", LIGAND_SMILE),
        ("LIGAND_SMILE_NAME", LIGAND_SMILE_NAME),
        ("PDB_FILE", PDB_FILE),
        ("OUT_PATH", OUT_PATH),
        ("MODEL_OPT", MODEL_OPT),
        ("FOLD", FOLD),
        ("RES_START", RES_START),
        ("RES_END", RES_END),
):
    print(f"{_label}: {_value}")
print(_BANNER)

import os
import logging
logging.getLogger().setLevel(logging.DEBUG)

import numpy as np
import torch
import torch_geometric as torchg
from tqdm import tqdm

from src import cfg
from src import TUNED_MODEL_CONFIGS

from src.utils.loader import Loader
from src.utils.residue import ResInfo, Chain
from src.data_prep.feature_extraction.ligand import smile_to_graph
from src.data_prep.feature_extraction.protein import target_to_graph
from src.data_prep.feature_extraction.protein_edges import get_target_edge_weights
from src.utils.residue import ResInfo, Chain

# Use the first CUDA device when one is available, otherwise run on the CPU.
_device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(_device_name)

# Configuration entry for the model option selected on the command line.
MODEL_PARAMS = TUNED_MODEL_CONFIGS[MODEL_OPT]

# File name of the input PDB with everything from the first ".pdb" onwards
# removed.
_pdb_basename = os.path.basename(PDB_FILE)
PDB_FILE_NAME = _pdb_basename.split('.pdb')[0]

# Get initial pkd value:
def get_protein_features(pdb_file_path, cmap_thresh=8.0):
    """Build the protein graph for the structure stored at ``pdb_file_path``.

    The structure is loaded as a ``Chain``, its contact map is turned into
    node features and an edge index via ``target_to_graph``, and edge weights
    are computed with ``get_target_edge_weights`` according to the module-level
    ``MODEL_PARAMS``.

    Args:
        pdb_file_path: Path to the PDB file to featurize.
        cmap_thresh: Value forwarded to ``target_to_graph`` as ``threshold``.

    Returns:
        A ``(graph, chain)`` tuple: the ``torch_geometric`` ``Data`` object for
        the protein and the ``Chain`` it was built from.

    Raises:
        NotImplementedError: If the configured edge option is listed in
            ``cfg.OPT_REQUIRES_CONF``.
    """
    chain = Chain(pdb_file_path)
    contact_map = chain.get_contact_map()

    graph_seq, node_feat, edge_index = target_to_graph(
        target_sequence=chain.sequence,
        contact_map=contact_map,
        threshold=cmap_thresh,
        pro_feat=MODEL_PARAMS['feature_opt'],
    )

    edge_opt = MODEL_PARAMS['edge_opt']
    if edge_opt in cfg.OPT_REQUIRES_CONF:
        raise NotImplementedError(
            f"{edge_opt} is not supported since it requires "
            "multiple conformation files to run and generate edges."
        )

    # includes edge_attr like ring3
    edge_weight = get_target_edge_weights(
        pdb_file_path, chain.sequence,
        edge_opt=edge_opt,
        cmap=contact_map,
        n_modes=5, n_cpu=4,
    )

    # Keep only the entries that correspond to edges present in the graph.
    weight_rank = len(edge_weight.shape)
    if weight_rank == 2:
        edge_weight = torch.Tensor(edge_weight[edge_index[0], edge_index[1]])
    elif weight_rank == 3:  # has edge attr! (This is our GVPL features)
        edge_weight = torch.Tensor(edge_weight[edge_index[0], edge_index[1], :])

    graph = torchg.data.Data(
        x=torch.Tensor(node_feat),
        edge_index=torch.LongTensor(edge_index),
        pro_seq=graph_seq,  # Protein sequence for downstream esm model
        prot_id=PDB_FILE_NAME,
        edge_weight=edge_weight,
    )
    return graph, chain

################################################
# Loading the model and get original pkd value #
################################################
# NOTE(review): inference.py unpacks Loader.load_tuned_model as a 2-tuple
# (`m, _ = ...`), so the model is taken from the first element here as well;
# binding the whole return value to `m` would make `m.to(DEVICE)` fail.
# Confirm against the loader's return signature.
m, _ = Loader.load_tuned_model(MODEL_OPT, fold=FOLD)
m.to(DEVICE)
m.eval()  # evaluation mode: this script only runs predictions

# build ligand graph
mol_feat, mol_edge = smile_to_graph(LIGAND_SMILE,
                                    lig_feature=MODEL_PARAMS['lig_feat_opt'],
                                    lig_edge=MODEL_PARAMS['lig_edge_opt'])
lig = torchg.data.Data(x=torch.Tensor(mol_feat),
                       edge_index=torch.LongTensor(mol_edge),
                       lig_seq=LIGAND_SMILE)

# build protein graph (wild-type structure); `pdb` keeps the original chain
pro, pdb = get_protein_features(PDB_FILE)

# Prediction only, so no autograd graph needs to be recorded.
with torch.no_grad():
    original_pkd = m(pro.to(DEVICE), lig.to(DEVICE))
print("Original pkd:",original_pkd)


################################################
# Mutate and regenerate graphs #################
################################################
# zero indexed res range to mutate:
res_range = (max(RES_START, 0),
             min(RES_END, len(pdb.sequence)))

from src.utils.mutate_model import run_modeller
amino_acids = ResInfo.amino_acids[:-1] # not including "X" - unknown
# Rows index the substituted amino acid, columns the residue position.
# Columns outside res_range are never written and stay 0.
muta = np.zeros(shape=(len(amino_acids), len(pdb.sequence)))

# Plain-float copy of the wild-type prediction, used for substitutions that
# leave the residue unchanged. Assumes the model output holds one element.
wild_type_pkd = float(original_pkd)

with tqdm(range(*res_range), ncols=100, total=(res_range[1]-res_range[0]), desc='Mutating') as t:
    for j in t:
        for i, AA in enumerate(amino_acids):
            if i%2 == 0:
                t.set_postfix(res=j, AA=i+1)

            if pdb.sequence[j] == AA:
                # Same residue as the wild type: nothing to model.
                muta[i,j] = wild_type_pkd
                continue

            # j+1: presumably run_modeller expects a 1-indexed residue
            # position and "A" is the chain id -- confirm against its
            # signature.
            out_pdb_fp = run_modeller(PDB_FILE, j+1, ResInfo.code_to_pep[AA], "A")
            try:
                # Featurize the MUTATED structure. Different names are used so
                # that `pdb` keeps referring to the wild-type chain, which the
                # check below and the identity test above rely on.
                mut_pro, _ = get_protein_features(out_pdb_fp)
                assert mut_pro.pro_seq != pdb.sequence and mut_pro.pro_seq[j] == AA, "ERROR in modeller"

                with torch.no_grad():
                    # .item() yields a Python float regardless of device.
                    muta[i,j] = m(mut_pro.to(DEVICE), lig.to(DEVICE)).item()
            finally:
                # delete after use, even if featurization or inference failed
                os.remove(out_pdb_fp)


# Save mutagenesis matrix
OUT_DIR = f'{OUT_PATH}/{LIGAND_SMILE_NAME}/{MODEL_OPT}'
# exist_ok=True: runs over different residue ranges for the same ligand and
# model write into the same directory, so it may already exist. Without it
# the script would raise only after the whole computation has finished.
os.makedirs(OUT_DIR, exist_ok=True)
# The file name encodes the (clamped, zero-indexed) residue range of this run.
OUT_FP = f"{OUT_DIR}/{res_range[0]}_{res_range[1]}.npy"
print("Saving mutagenesis numpy matrix to", OUT_FP)
np.save(OUT_FP, muta)
181 changes: 0 additions & 181 deletions src/analysis/mutagenesis.py

This file was deleted.

Loading

0 comments on commit 85aa42b

Please sign in to comment.