-
Notifications
You must be signed in to change notification settings - Fork 2
/
run_mutagenesis.py
153 lines (126 loc) · 6.34 KB
/
run_mutagenesis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import argparse
# Command-line interface. Parsing happens at module level, so running the
# script without the required flags exits immediately with a usage message.
parser = argparse.ArgumentParser(
    description='Runs Mutagenesis on an input PDB file and a given ligand SMILES.',
)
parser.add_argument(
    '--ligand_smile',
    type=str,
    required=True,
    help='Ligand SMILES string.',
)
parser.add_argument(
    '--ligand_smile_name',
    type=str,
    required=True,
    help='Ligand SMILES name, required for output path.',
)
parser.add_argument(
    '--pdb_file',
    type=str,
    required=True,
    help='Path to the PDB file.',
)
parser.add_argument(
    '--out_path',
    type=str,
    default='./',
    help='Output directory path to save resulting mutagenesis numpy matrix with predicted pkd values',
)
parser.add_argument(
    '--res_start',
    type=int,
    default=0,
    help='Start index for mutagenesis (zero-indexed).',
)
# The default is +inf (a float, even though type=int) so that "no upper
# bound" can later be clamped against the sequence length with min().
parser.add_argument(
    '--res_end',
    type=int,
    default=float('inf'),
    help='End index for mutagenesis.',
)
parser.add_argument(
    '--model_opt',
    type=str,
    default='davis_DG',
    choices=[
        'davis_DG', 'davis_gvpl', 'davis_esm',
        'kiba_DG', 'kiba_esm', 'kiba_gvpl',
        'PDBbind_DG', 'PDBbind_esm', 'PDBbind_gvpl',
    ],
    help='Model option.',
)
parser.add_argument(
    '--fold',
    type=int,
    default=1,
    help='Which model fold to use (there are 5 models for each option due to 5-fold CV).',
)
args = parser.parse_args()

# Expose the parsed options as module-level constants used by the rest of the script.
LIGAND_SMILE, LIGAND_SMILE_NAME = args.ligand_smile, args.ligand_smile_name
PDB_FILE, OUT_PATH = args.pdb_file, args.out_path
MODEL_OPT, FOLD = args.model_opt, args.fold
RES_START, RES_END = args.res_start, args.res_end

# Echo the effective settings between two banner lines.
_banner = "#"*50
print(_banner)
for _label, _value in (
    ("LIGAND_SMILE", LIGAND_SMILE),
    ("LIGAND_SMILE_NAME", LIGAND_SMILE_NAME),
    ("PDB_FILE", PDB_FILE),
    ("OUT_PATH", OUT_PATH),
    ("MODEL_OPT", MODEL_OPT),
    ("FOLD", FOLD),
    ("RES_START", RES_START),
    ("RES_END", RES_END),
):
    print(f"{_label}: {_value}")
print(_banner)
import os
import logging
logging.getLogger().setLevel(logging.DEBUG)
import numpy as np
import torch
import torch_geometric as torchg
from tqdm import tqdm
from src import cfg
from src import TUNED_MODEL_CONFIGS
from src.utils.loader import Loader
from src.utils.residue import ResInfo, Chain
from src.data_prep.feature_extraction.ligand import smile_to_graph
from src.data_prep.feature_extraction.protein import target_to_graph
from src.data_prep.feature_extraction.protein_edges import get_target_edge_weights
from src.utils.residue import ResInfo, Chain
# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

# Configuration entry for the model option selected on the command line.
MODEL_PARAMS = TUNED_MODEL_CONFIGS[MODEL_OPT]

# Base name of the input file up to the first '.pdb'; used as the protein id
# of the graphs built below.
PDB_FILE_NAME = os.path.split(PDB_FILE)[1].partition('.pdb')[0]
# Build the protein graph features needed to predict pkd values:
def get_protein_features(pdb_file_path, cmap_thresh=8.0):
    """Build the protein graph for a PDB file using the selected model's options.

    Args:
        pdb_file_path: Path to the PDB file; it is loaded with ``Chain`` and
            also handed to ``get_target_edge_weights``.
        cmap_thresh: Threshold forwarded to ``target_to_graph`` along with the
            contact map.

    Returns:
        A ``(pro, pdb)`` tuple: the ``torch_geometric`` ``Data`` object for the
        protein and the ``Chain`` it was built from.

    Raises:
        NotImplementedError: If the model's ``edge_opt`` is listed in
            ``cfg.OPT_REQUIRES_CONF``.
    """
    # Fail fast: reject unsupported edge options before any expensive
    # structure parsing or graph construction is done.
    if MODEL_PARAMS['edge_opt'] in cfg.OPT_REQUIRES_CONF:
        raise NotImplementedError(f"{MODEL_PARAMS['edge_opt']} is not supported since it requires "+\
                                  "multiple conformation files to run and generate edges.")

    pdb = Chain(pdb_file_path)
    pro_cmap = pdb.get_contact_map()

    updated_seq, extra_feat, edge_idx = target_to_graph(target_sequence=pdb.sequence,
                                                        contact_map=pro_cmap,
                                                        threshold=cmap_thresh,
                                                        pro_feat=MODEL_PARAMS['feature_opt'])

    # includes edge_attr like ring3
    pro_edge_weight = get_target_edge_weights(pdb_file_path, pdb.sequence,
                                              edge_opt=MODEL_PARAMS['edge_opt'],
                                              cmap=pro_cmap,
                                              n_modes=5, n_cpu=4)

    # NOTE(review): guard against the helper returning None (e.g. for edge
    # options that carry no weights) -- confirm against get_target_edge_weights.
    if pro_edge_weight is not None:
        # Keep only the entries that correspond to edges present in the graph.
        if len(pro_edge_weight.shape) == 2:
            pro_edge_weight = torch.Tensor(pro_edge_weight[edge_idx[0], edge_idx[1]])
        elif len(pro_edge_weight.shape) == 3:  # has edge attr! (This is our GVPL features)
            pro_edge_weight = torch.Tensor(pro_edge_weight[edge_idx[0], edge_idx[1], :])

    pro = torchg.data.Data(x=torch.Tensor(extra_feat),
                           edge_index=torch.LongTensor(edge_idx),
                           pro_seq=updated_seq,  # Protein sequence for downstream esm model
                           # Module-level name of the input PDB, not derived
                           # from pdb_file_path.
                           prot_id=PDB_FILE_NAME,
                           edge_weight=pro_edge_weight)
    return pro, pdb
################################################
# Loading the model and get original pkd value #
################################################
# Load the tuned model for the requested option/fold and switch it to
# evaluation mode on the selected device.
m = Loader.load_tuned_model(MODEL_OPT, fold=FOLD)
m.to(DEVICE)
m.eval()

# build ligand graph
mol_feat, mol_edge = smile_to_graph(LIGAND_SMILE,
                                    lig_feature=MODEL_PARAMS['lig_feat_opt'],
                                    lig_edge=MODEL_PARAMS['lig_edge_opt'])
lig = torchg.data.Data(x=torch.Tensor(mol_feat),
                       edge_index=torch.LongTensor(mol_edge),
                       lig_seq=LIGAND_SMILE)

# build protein graph
pro, pdb = get_protein_features(PDB_FILE)

# Inference only: disable autograd so no graph is built or kept alive for
# the baseline (unmutated) prediction.
with torch.no_grad():
    original_pkd = m(pro.to(DEVICE), lig.to(DEVICE))
print("Original pkd:",original_pkd)
################################################
# Mutate and regenerate graphs #################
################################################
# Zero-indexed, half-open residue range to mutate, clamped to the sequence.
res_range = (max(RES_START, 0),
             min(RES_END, len(pdb.sequence)))

# Imported here, right before it is needed, as in the original script.
from src.utils.mutate_model import run_modeller

amino_acids = ResInfo.amino_acids[:-1]  # not including "X" - unknown

# Keep the wild-type sequence and baseline prediction fixed for the whole
# scan; float() turns the single-element model output into a plain number
# that can be stored in the NumPy matrix regardless of device or grad state.
wild_type_seq = pdb.sequence
original_pkd_value = float(original_pkd)

# Rows are substituted amino acids, columns are residue positions; columns
# outside res_range are left at zero.
muta = np.zeros(shape=(len(amino_acids), len(wild_type_seq)))

n_positions = max(res_range[1] - res_range[0], 0)  # never report a negative total
with tqdm(range(*res_range), ncols=100, total=n_positions, desc='Mutating') as t:
    for j in t:
        for i, AA in enumerate(amino_acids):
            if i % 2 == 0:
                t.set_postfix(res=j, AA=i+1)

            # Substituting a residue with itself is the wild type.
            if wild_type_seq[j] == AA:
                muta[i, j] = original_pkd_value
                continue

            # run_modeller is called with position j+1 and chain "A".
            out_pdb_fp = run_modeller(PDB_FILE, j+1, ResInfo.code_to_pep[AA], "A")
            try:
                # Build the graph from the MUTATED structure and do not
                # overwrite the wild-type `pdb` used as the reference.
                mut_pro, _ = get_protein_features(out_pdb_fp)
                if mut_pro.pro_seq == wild_type_seq or mut_pro.pro_seq[j] != AA:
                    raise RuntimeError("ERROR in modeller")

                with torch.no_grad():  # inference only
                    muta[i, j] = float(m(mut_pro.to(DEVICE), lig.to(DEVICE)))
            finally:
                # delete after use, even when feature extraction or the
                # sanity check fails
                os.remove(out_pdb_fp)
# Save mutagenesis matrix
OUT_DIR = f'{OUT_PATH}/{LIGAND_SMILE_NAME}/{MODEL_OPT}'
# The file name encodes the residue range, so several runs are expected to
# share this directory; do not fail (and lose the computed matrix) when it
# already exists.
os.makedirs(OUT_DIR, exist_ok=True)

OUT_FP = f"{OUT_DIR}/{res_range[0]}_{res_range[1]}.npy"
print("Saving mutagenesis numpy matrix to", OUT_FP)
np.save(OUT_FP, muta)