Add a SLURM job script for run_mutagenesis.py and tidy the mutagenesis driver.

  * .gitignore: ignore results/model_checkpoints.tar.gz and *.out files.
  * SBATCH/run_mutagenesis.sh: new batch script; passes SLURM_ARRAY_TASK_ID to
    run_mutagenesis.py as --fold and stops early if the project directory is
    missing.
  * run_mutagenesis.py: index the edge weights only when they are not None,
    unpack the two values returned by Loader.load_tuned_model, keep the
    unmutated sequence in original_seq, and build each mutant's features from
    the path returned by run_modeller.
  * src/__init__.py, src/analysis/mutagenesis_plot.py: comment and demo tweaks.

diff --git a/.gitignore b/.gitignore
index 3d39d50..bff6b5a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -226,3 +226,5 @@ results/model_media/test_set_pred
 splits/**/*.csv
 
 results/*/model_media/*/train_log/*.json
+results/model_checkpoints.tar.gz
+*.out
diff --git a/SBATCH/run_mutagenesis.sh b/SBATCH/run_mutagenesis.sh
new file mode 100644
index 0000000..153796e
--- /dev/null
+++ b/SBATCH/run_mutagenesis.sh
@@ -0,0 +1,39 @@
+#!/bin/bash
+#SBATCH -t 10:00
+#SBATCH --job-name=run_mutagenesis
+#SBATCH --mem=10G
+
+#SBATCH --gpus-per-node=a100:1
+#SBATCH --cpus-per-task=4
+
+#SBATCH --output=./%x_%a.out
+#SBATCH --array=0
+
+# runs across all folds for a model
+# should produce a matrix for each fold
+
+# Then to get most accurate mutagenesis you can average these matrices
+# and visualize them with src.analysis.mutagenesis_plot.plot_sequence
+
+# Modeller is needed for this to run... (see: Generic install - https://salilab.org/modeller/10.5/release.html#unix)
+export PYTHONPATH="${PYTHONPATH}:/home/jyaacoub/bin/modeller10.5/lib/x86_64-intel8/python3.3:/home/jyaacoub/bin/modeller10.5/lib/x86_64-intel8:/home/jyaacoub/bin/modeller10.5/modlib"
+export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/home/jyaacoub/bin/modeller10.5/lib/x86_64-intel8"
+
+cd /home/jyaacoub/projects/def-sushant/jyaacoub/MutDTA || exit 1
+source .venv/bin/activate
+
+# NOTE: To get SMILE from a .mol2 or .sdf file you can use RDKIT:
+#
+# from rdkit import Chem
+# mol2smile = lambda x: Chem.MolToSmiles(Chem.MolFromMol2File(x), isomericSmiles=False)
+# mol2smile("/home/jyaacoub/scratch/pdbbind_demo/1a30/1a30_ligand.mol2")
+
+python -u run_mutagenesis.py \
+    --ligand_smile "CC(C)CC(NC(=O)C(CC(=O)[O-])NC(=O)C([NH3+])CCC(=O)[O-])C(=O)[O-]" \
+    --ligand_smile_name "1a30_ligand" \
+    --pdb_file "/home/jyaacoub/projects/def-sushant/jyaacoub/data/kiba/alphaflow_io/out_pdb_MD-distilled/P67870.pdb" \
+    --out_path "/home/jyaacoub/scratch/mutagenesis_tests/" \
+    --res_start 0 \
+    --res_end 5 \
+    --model_opt davis_DG \
+    --fold ${SLURM_ARRAY_TASK_ID}
diff --git a/run_mutagenesis.py b/run_mutagenesis.py
index 5800003..f07c07c 100644
--- a/run_mutagenesis.py
+++ b/run_mutagenesis.py
@@ -12,7 +12,7 @@
                     choices=['davis_DG', 'davis_gvpl', 'davis_esm',
                              'kiba_DG', 'kiba_esm', 'kiba_gvpl',
                              'PDBbind_DG', 'PDBbind_esm', 'PDBbind_gvpl'],
-                    help='Model option.')
+                    help='Model option. See MutDTA/src/__init__.py for details.')
 parser.add_argument('--fold', type=int, default=1,
                     help='Which model fold to use (there are 5 models for each option due to 5-fold CV).')
 
@@ -82,10 +82,11 @@ def get_protein_features(pdb_file_path, cmap_thresh=8.0):
                                               edge_opt=MODEL_PARAMS['edge_opt'],
                                               cmap=pro_cmap,
                                               n_modes=5, n_cpu=4)
-    if len(pro_edge_weight.shape) == 2:
-        pro_edge_weight = torch.Tensor(pro_edge_weight[edge_idx[0], edge_idx[1]])
-    elif len(pro_edge_weight.shape) == 3: # has edge attr! (This is our GVPL features)
-        pro_edge_weight = torch.Tensor(pro_edge_weight[edge_idx[0], edge_idx[1], :])
+    if pro_edge_weight is not None:
+        if len(pro_edge_weight.shape) == 2:
+            pro_edge_weight = torch.Tensor(pro_edge_weight[edge_idx[0], edge_idx[1]])
+        elif len(pro_edge_weight.shape) == 3: # has edge attr! (This is our GVPL features)
+            pro_edge_weight = torch.Tensor(pro_edge_weight[edge_idx[0], edge_idx[1], :])
 
     pro_feat = torch.Tensor(extra_feat)
 
@@ -96,10 +97,10 @@ def get_protein_features(pdb_file_path, cmap_thresh=8.0):
                            edge_weight=pro_edge_weight)
     return pro, pdb
 
-################################################
-# Loading the model and get original pkd value #
-################################################
-m = Loader.load_tuned_model(MODEL_OPT, fold=FOLD)
+##################################################
+### Loading the model and get original pkd value #
+##################################################
+m, _ = Loader.load_tuned_model(MODEL_OPT, fold=FOLD)
 m.to(DEVICE)
 m.eval()
 
@@ -108,22 +109,23 @@ def get_protein_features(pdb_file_path, cmap_thresh=8.0):
 lig = torchg.data.Data(x=torch.Tensor(mol_feat), edge_index=torch.LongTensor(mol_edge), lig_seq=LIGAND_SMILE)
 
 # build protein graph
-pro, pdb = get_protein_features(PDB_FILE)
+pro, pdb_original = get_protein_features(PDB_FILE)
+original_seq = pdb_original.sequence
 original_pkd = m(pro.to(DEVICE), lig.to(DEVICE))
-print("Original pkd:",original_pkd)
+print("Original pkd:", original_pkd)
 
 
-################################################
-# Mutate and regenerate graphs #################
-################################################
+##################################################
+### Mutate and regenerate graphs #################
+##################################################
 # zero indexed res range to mutate:
 res_range = (max(RES_START, 0),
-             min(RES_END, len(pdb.sequence)))
+             min(RES_END, len(original_seq)))
 
 from src.utils.mutate_model import run_modeller
 
 amino_acids = ResInfo.amino_acids[:-1] # not including "X" - unknown
-muta = np.zeros(shape=(len(amino_acids), len(pdb.sequence)))
+muta = np.zeros(shape=(len(amino_acids), len(original_seq)))
 
 with tqdm(range(*res_range), ncols=100, total=(res_range[1]-res_range[0]), desc='Mutating') as t:
     for j in t:
@@ -131,13 +133,14 @@ def get_protein_features(pdb_file_path, cmap_thresh=8.0):
             if i%2 == 0:
                 t.set_postfix(res=j, AA=i+1)
 
-            if pdb.sequence[j] == AA:
+            if original_seq[j] == AA:
                 # skip same AA modifications
                 muta[i,j] = original_pkd
                 continue
 
             out_pdb_fp = run_modeller(PDB_FILE, j+1, ResInfo.code_to_pep[AA], "A")
-            pro, pdb = get_protein_features(PDB_FILE)
-            assert pro.pro_seq != pdb.sequence and pro.pro_seq[j] == AA, "ERROR in modeller"
+            pro, _ = get_protein_features(out_pdb_fp)
+            assert pro.pro_seq != original_seq and pro.pro_seq[j] == AA, \
+                f"ERROR in modeller, {pro.pro_seq} == {original_seq} \nor {pro.pro_seq[j]} != {AA}"
 
             muta[i,j] = m(pro.to(DEVICE), lig.to(DEVICE))
diff --git a/src/__init__.py b/src/__init__.py
index 370ef00..449b9f6 100644
--- a/src/__init__.py
+++ b/src/__init__.py
@@ -3,7 +3,7 @@
 
 
 TUNED_MODEL_CONFIGS = {
-    #DGM_davis0D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E
+    #DGM_davis0D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E.model
     'davis_DG':{
         "model": cfg.MODEL_OPT.DG,
 
diff --git a/src/analysis/mutagenesis_plot.py b/src/analysis/mutagenesis_plot.py
index 1edf2d9..13f5970 100644
--- a/src/analysis/mutagenesis_plot.py
+++ b/src/analysis/mutagenesis_plot.py
@@ -53,7 +53,6 @@ def plot_sequence(muta, pro_seq, pep_opts=ResInfo.amino_acids[:-1], delta=False)
 
 
 if __name__ == "__main__":
-    res_range = (0,215)
     muta = np.load(f'muta-{res_range[0]}_{res_range[1]}.npy')
     #plot_sequence(muta[:,res_range[0]:res_range[1]],