Merge pull request #140 from jyaacoub/v94-platinum-analysis
v94 platinum analysis
jyaacoub authored Oct 16, 2024
2 parents 9168fa3 + 91b0d78 commit 5d511fd
Showing 6 changed files with 114 additions and 64 deletions.
24 changes: 24 additions & 0 deletions SBATCH/run_platinum.sh
@@ -0,0 +1,24 @@
#!/bin/bash
#SBATCH -t 10:00
#SBATCH --job-name=run_platinum
#SBATCH --mem=10G

#SBATCH --gpus-per-node=a100:1
#SBATCH --cpus-per-task=4

#SBATCH --output=./%x_%a.out
#SBATCH --array=0

# runs across all folds for a model
# should produce a matrix for each fold

# Then, to get the most accurate mutagenesis, you can average these matrices
# and visualize them with src.analysis.mutagenesis_plot.plot_sequence (see the sketch after this diff)

cd /home/jyaacoub/projects/def-sushant/jyaacoub/MutDTA
source .venv/bin/activate

python -u run_platinum.py \
--model_opt davis_DG \
--fold ${SLURM_ARRAY_TASK_ID} \
--out_dir ./
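
The in-script comments describe the intended post-processing: one matrix per fold, averaged and then visualized with src.analysis.mutagenesis_plot.plot_sequence. To cover all five folds, the job could be submitted with something like sbatch --array=0-4 SBATCH/run_platinum.sh (the script as committed uses --array=0). Below is a minimal sketch of that averaging/plotting step; the per-fold filenames and the protein sequence placeholder are assumptions, not part of this commit.

# Hedged sketch of the averaging/visualization step described above.
# The per-fold filenames (muta_fold*.npy) and pro_seq are assumptions.
import glob
import numpy as np
from src.analysis.mutagenesis_plot import plot_sequence

# load one mutagenesis matrix per fold, then average across folds
fold_matrices = [np.load(f) for f in sorted(glob.glob('muta_fold*.npy'))]
muta_avg = np.mean(np.stack(fold_matrices, axis=0), axis=0)

pro_seq = "MKT..."  # placeholder: the protein sequence the matrices were computed for
plot_sequence(muta_avg, pro_seq=pro_seq)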
24 changes: 11 additions & 13 deletions run_mutagenesis.py
@@ -28,22 +28,20 @@
RES_START = args.res_start
RES_END = args.res_end

# Your code logic here
print("#"*50)
print(f"LIGAND_SMILE: {LIGAND_SMILE}")
print(f"LIGAND_SMILE_NAME: {LIGAND_SMILE_NAME}")
print(f"PDB_FILE: {PDB_FILE}")
print(f"OUT_PATH: {OUT_PATH}")
print(f"MODEL_OPT: {MODEL_OPT}")
print(f"FOLD: {FOLD}")
print(f"RES_START: {RES_START}")
print(f"RES_END: {RES_END}")
print("#"*50)

import os
import logging
logging.getLogger().setLevel(logging.DEBUG)
logging.debug("#"*50)
logging.debug(f"LIGAND_SMILE: {LIGAND_SMILE}")
logging.debug(f"LIGAND_SMILE_NAME: {LIGAND_SMILE_NAME}")
logging.debug(f"PDB_FILE: {PDB_FILE}")
logging.debug(f"OUT_PATH: {OUT_PATH}")
logging.debug(f"MODEL_OPT: {MODEL_OPT}")
logging.debug(f"FOLD: {FOLD}")
logging.debug(f"RES_START: {RES_START}")
logging.debug(f"RES_END: {RES_END}")
logging.debug("#"*50)

import os
import numpy as np
import torch
import torch_geometric as torchg
66 changes: 66 additions & 0 deletions run_platinum.py
@@ -0,0 +1,66 @@
import argparse
parser = argparse.ArgumentParser(description='Runs model on platinum dataset to evaluate it.')
parser.add_argument('--model_opt', type=str, default='davis_DG',
choices=['davis_DG', 'davis_gvpl', 'davis_esm',
'kiba_DG', 'kiba_esm', 'kiba_gvpl',
'PDBbind_DG', 'PDBbind_esm', 'PDBbind_gvpl'],
help='Model option. See MutDTA/src/__init__.py for details.')
parser.add_argument('--fold', type=int, default=1,
help='Which model fold to use (there are 5 models for each option due to 5-fold CV).')
parser.add_argument('--out_dir', type=str, default='./',
help='Output directory path to save csv file for prediction results.')

args = parser.parse_args()
MODEL_OPT = args.model_opt
FOLD = args.fold
OUT_DIR = args.out_dir

import logging
logging.getLogger().setLevel(logging.DEBUG)
logging.debug("#"*50)
logging.debug(f"MODEL_OPT: {MODEL_OPT}")
logging.debug(f"FOLD: {FOLD}")
logging.debug(f"OUT_DIR: {OUT_DIR}")
logging.debug("#"*50)

import torch, os
import pandas as pd

from src import cfg
from src import TUNED_MODEL_CONFIGS
from src.utils.loader import Loader
from src.train_test.training import test

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
MODEL_PARAMS = TUNED_MODEL_CONFIGS[MODEL_OPT]

### Loading the model
logging.debug(f"Loading the model {MODEL_OPT}")
model, model_kwargs = Loader.load_tuned_model(MODEL_OPT, fold=FOLD)
MODEL_KEY = Loader.get_model_key(**model_kwargs)
model.to(DEVICE)
model.eval()

### Loading the data and Test
logging.debug("Loading platinum test dataloader for model")
loaders = Loader.load_DataLoaders(cfg.DATA_OPT.platinum,
pro_feature = MODEL_PARAMS['feature_opt'],
edge_opt = MODEL_PARAMS['edge_opt'],
ligand_feature = MODEL_PARAMS['lig_feat_opt'],
ligand_edge = MODEL_PARAMS['lig_edge_opt'],
datasets=['test'])

logging.debug("Running inference on test loader")
loss, pred, actual = test(model, loaders['test'], DEVICE, verbose=True)

# save as a CSV with cols: code, prot_id, pred, actual
logging.debug(f"Saving output to '{OUT_DIR}/{MODEL_KEY}_PLATINUM.csv'")
df = pd.DataFrame({
'prot_id': [b['prot_id'][0] for b in loaders['test']],
'pred': pred,
'actual': actual
},
index=[b['code'][0] for b in loaders['test']])

df.index.name = 'code'
df.to_csv(f'{OUT_DIR}/{MODEL_KEY}_PLATINUM.csv')
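
The CSV written above (columns: code, prot_id, pred, actual) can then be scored offline. A minimal sketch follows, assuming the filename produced by a davis_DG run and using pandas/scipy; neither the filename nor the metric choice comes from this commit.

# Hedged sketch: read the prediction CSV from run_platinum.py and compute
# simple evaluation metrics. Substitute the actual {MODEL_KEY}_PLATINUM.csv name.
import pandas as pd
from scipy.stats import pearsonr, spearmanr

df = pd.read_csv('davis_DG_PLATINUM.csv', index_col='code')  # hypothetical filename
mse = ((df['pred'] - df['actual']) ** 2).mean()
p_corr, _ = pearsonr(df['pred'], df['actual'])
s_corr, _ = spearmanr(df['pred'], df['actual'])
print(f"MSE={mse:.4f}  Pearson={p_corr:.4f}  Spearman={s_corr:.4f}")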
4 changes: 4 additions & 0 deletions src/analysis/mutagenesis_plot.py
@@ -55,6 +55,10 @@ def plot_sequence(muta, pro_seq, pep_opts=ResInfo.amino_acids[:-1], delta=False)
if __name__ == "__main__":
res_range = (0,215)
muta = np.load(f'muta-{res_range[0]}_{res_range[1]}.npy')
# to plot full sequence
# plot_sequence(muta, pro_seq=Chain(pdb_file).sequence)

# to plot a specific residue range (e.g.: a pocket)
#plot_sequence(muta[:,res_range[0]:res_range[1]],
# pro_seq=Chain(pdb_file).sequence[res_range[0]:res_range[1]])

47 changes: 0 additions & 47 deletions src/analysis/platinum.py

This file was deleted.

13 changes: 9 additions & 4 deletions src/utils/mutate_model.py
@@ -1,9 +1,14 @@
import sys
import os

from modeller import *
from modeller.optimizers import MolecularDynamics, ConjugateGradients
from modeller.automodel import autosched
import logging

# try/except around modeller since it is only really needed for run_mutagenesis.py
try:
from modeller import *
from modeller.optimizers import MolecularDynamics, ConjugateGradients
from modeller.automodel import autosched
except ImportError:
logging.warning("Modeller failed to import - will not be able to run mutagenesis scripts.")


def optimize(atmsel, sched):
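
A natural companion to this guard (a pattern sketch under assumptions, not code from this commit) is to record whether the import succeeded, so mutagenesis entry points can fail with a clear error instead of a NameError later on.

# Pattern sketch (assumption, not part of this commit): track modeller
# availability and raise a clear error only when mutagenesis is requested.
import logging

try:
    import modeller  # probe import only
    MODELLER_AVAILABLE = True
except ImportError:
    MODELLER_AVAILABLE = False
    logging.warning("Modeller failed to import - mutagenesis scripts will be unavailable.")

def require_modeller():
    if not MODELLER_AVAILABLE:
        raise ImportError("Modeller is required for run_mutagenesis.py; install it first.")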
