Merge pull request #104 from jyaacoub/development
Development
jyaacoub authored Jun 4, 2024
2 parents 23c6fe6 + 57dcb48 commit d23a8f5
Showing 15 changed files with 453 additions and 260 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -222,3 +222,4 @@ results/model_checkpoints/ours/*.model
results/model_media/*/train_log/*
results/model_media/*/train_set_pred/*
results/model_media/*/test_set_pred/*
results/model_media/test_set_pred
269 changes: 133 additions & 136 deletions playground.py
@@ -1,148 +1,145 @@
#%% 1. Gather data for davis, kiba and pdbbind datasets
import os
import pandas as pd
import matplotlib.pyplot as plt
from src.analysis.utils import combine_dataset_pids
from src import config as cfg
df_prots = combine_dataset_pids(dbs=[cfg.DATA_OPT.davis, cfg.DATA_OPT.PDBbind], # just davis and pdbbind for now
                                subset='test')


#%% 2. Load TCGA data
df_tcga = pd.read_csv('../downloads/TCGA_ALL.maf', sep='\t')

#%% 3. Pre-filtering
df_tcga = df_tcga[df_tcga['Variant_Classification'] == 'Missense_Mutation']
df_tcga['seq_len'] = pd.to_numeric(df_tcga['Protein_position'].str.split('/').str[1])
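# note: 'Protein_position' here appears to follow the VEP "<position>/<total_length>"
# format (e.g. "42/393"), so taking index 1 of the split recovers the full sequence length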
df_tcga = df_tcga[df_tcga['seq_len'] < 5000]
df_tcga['seq_len'].plot.hist(bins=100, title="sequence length histogram capped at 5K")
plt.show()
df_tcga = df_tcga[df_tcga['seq_len'] < 1200]
df_tcga['seq_len'].plot.hist(bins=100, title="sequence length after capped at 1.2K")

#%% 4. Merging df_prots with TCGA
df_tcga['uniprot'] = df_tcga['SWISSPROT'].str.split('.').str[0]

dfm = df_tcga.merge(df_prots[df_prots.db != 'davis'],
                    left_on='uniprot', right_on='prot_id', how='inner')

# for davis we have to merge on Hugo_Symbol
dfm_davis = df_tcga.merge(df_prots[df_prots.db == 'davis'],
                          left_on='Hugo_Symbol', right_on='prot_id', how='inner')

dfm = pd.concat([dfm,dfm_davis], axis=0)

del dfm_davis # to save mem

# %% 5. Post-filtering step
# 5.1. Filter for only those sequences with matching sequence length (to get rid of unmatched isoforms)
# seq_len_x is from tcga, seq_len_y is from our dataset
tmp = len(dfm)
# allow for some error due to missing amino acids from pdb file in PDBbind dataset
# - assumption here is that isoforms will differ by more than 50 amino acids
dfm = dfm[(dfm.seq_len_y <= dfm.seq_len_x) & (dfm.seq_len_x<= dfm.seq_len_y+50)]
print(f"Filter #1 (seq_len) : {tmp:5d} - {tmp-len(dfm):5d} = {len(dfm):5d}")

# 5.2. Filter out those that don't have the same reference seq according to the "Protein_position" and "Amino_acids" columns

# Extract mutation location and reference amino acid from 'Protein_position' and 'Amino_acids' columns
dfm['mt_loc'] = pd.to_numeric(dfm['Protein_position'].str.split('/').str[0])
dfm = dfm[dfm['mt_loc'] < dfm['seq_len_y']]
dfm[['ref_AA', 'mt_AA']] = dfm['Amino_acids'].str.split('/', expand=True)

dfm['db_AA'] = dfm.apply(lambda row: row['prot_seq'][row['mt_loc']-1], axis=1)
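# mt_loc is 1-indexed (as in the MAF annotation), hence the -1 when indexing prot_seq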

# Filter #2: Match proteins with the same reference amino acid at the mutation location
tmp = len(dfm)
dfm = dfm[dfm['db_AA'] == dfm['ref_AA']]
print(f"Filter #2 (ref_AA match): {tmp:5d} - {tmp-len(dfm):5d} = {len(dfm):5d}")
print('\n',dfm.db.value_counts())


# %% final seq len distribution
n_bins = 25
lengths = dfm.seq_len_x
fig, ax = plt.subplots(1, 1, figsize=(10, 5))

# Plot histogram
n, bins, patches = ax.hist(lengths, bins=n_bins, color='blue', alpha=0.7)
ax.set_title('TCGA final filtering for db matches')

# Add counts to each bin
for count, x, patch in zip(n, bins, patches):
    ax.text(x + 0.5, count, str(int(count)), ha='center', va='bottom')
# %%
import logging
from typing import OrderedDict

ax.set_xlabel('Sequence Length')
ax.set_ylabel('Frequency')
import seaborn as sns
from matplotlib import pyplot as plt
from statannotations.Annotator import Annotator

plt.tight_layout()
plt.show()
from src.analysis.figures import prepare_df, custom_fig, fig_combined

# %% Getting updated sequences
def apply_mut(row):
    ref_seq = list(row['prot_seq'])
    ref_seq[row['mt_loc']-1] = row['mt_AA']
    return ''.join(ref_seq)
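# e.g. a row with prot_seq='ACDE', mt_loc=2, mt_AA='G' yields 'AGDE'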
df = prepare_df()
# %%
models = {
    'DG': ('nomsa', 'binary', 'original', 'binary'),
    'aflow': ('nomsa', 'aflow', 'original', 'binary'),
    # 'aflow_ring3': ('nomsa', 'aflow_ring3', 'original', 'binary'),
    # 'gvpP': ('gvp', 'binary', 'original', 'binary'),
    # 'gvpL': ('nomsa', 'binary', 'gvp', 'binary'),
    'gvpL_aflow': ('nomsa', 'aflow', 'gvp', 'binary'),
    'gvpL': ('nomsa', 'binary', 'gvp', 'binary'),
    # 'gvpL_aflow_rng3': ('nomsa', 'aflow_ring3', 'gvp', 'binary'),
}
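# each 4-tuple presumably maps to (feature_opt, edge_opt, lig_feat_opt, lig_edge_opt),
# mirroring the CONFIG keys used for model selection later in this file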

dfm['mt_seq'] = dfm.apply(apply_mut, axis=1)
# %%
fig, axes = fig_combined(df, datasets=['davis','PDBbind'], fig_callable=custom_fig,
                         models=models, metrics=['cindex', 'mse'],
                         fig_scale=(8,5))
plt.xticks(rotation=45)


# %%
dfm.to_csv("/cluster/home/t122995uhn/projects/data/tcga/tcga_maf_davis_pdbbind.csv")
# %%
from src.utils.seq_alignment import MSARunner
from tqdm import tqdm
########################################################################
########################## PLATINUM ANALYSIS ###########################
########################################################################
import torch, os
import pandas as pd

DATA_DIR = '/cluster/home/t122995uhn/projects/data/tcga'
CSV = f'{DATA_DIR}/tcga_maf_davis_pdbbind.csv'
N_CPUS = 6
NUM_ARRAYS = 10
array_idx = 0  # ${SLURM_ARRAY_TASK_ID}
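# when submitted as a SLURM array job, array_idx would instead be read from the
# environment, e.g. array_idx = int(os.environ.get('SLURM_ARRAY_TASK_ID', 0))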

df = pd.read_csv(CSV, index_col=0)
df.sort_values(by='seq_len_y', inplace=True)


# %%
for DB in df.db.unique():
    print('DB', DB)
    RAW_DIR = f'{DATA_DIR}/{DB}'
    # should already be unique if these are proteins mapped from tcga!
    unique_df = df[df['db'] == DB]

    ########################## Get job partition
    partition_size = len(unique_df) / NUM_ARRAYS
    start, end = int(array_idx*partition_size), int((array_idx+1)*partition_size)

    unique_df = unique_df[start:end]

    #################################### create fastas
    fa_dir = os.path.join(RAW_DIR, f'{DB}_fa')
    fasta_fp = lambda idx, pid: os.path.join(fa_dir, f"{idx}-{pid}.fasta")
    os.makedirs(fa_dir, exist_ok=True)
    for idx, (prot_id, pro_seq) in tqdm(
            unique_df[['prot_id', 'prot_seq']].iterrows(),
            desc='Creating fastas',
            total=len(unique_df)):
        with open(fasta_fp(idx, prot_id), "w") as f:
            f.write(f">{prot_id},{idx},{DB}\n{pro_seq}")

    ##################################### Run hhblits
    aln_dir = os.path.join(RAW_DIR, f'{DB}_aln')
    aln_fp = lambda idx, pid: os.path.join(aln_dir, f"{idx}-{pid}.a3m")
    os.makedirs(aln_dir, exist_ok=True)

    # finally running
    for idx, (prot_id, pro_seq) in tqdm(
            unique_df[['prot_id', 'mt_seq']].iterrows(),
            desc='Running hhblits',
            total=len(unique_df)):
        in_fp = fasta_fp(idx, prot_id)
        out_fp = aln_fp(idx, prot_id)
from src import cfg
from src import TUNED_MODEL_CONFIGS
from src.utils.loader import Loader
from src.train_test.training import test
from src.analysis.figures import predictive_performance, tbl_stratified_dpkd_metrics, tbl_dpkd_metrics_overlap, tbl_dpkd_metrics_in_binding

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

INFERENCE = True
VERBOSE = True
out_dir = f'{cfg.MEDIA_SAVE_DIR}/test_set_pred/'
os.makedirs(out_dir, exist_ok=True)
cp_dir = cfg.CHECKPOINT_SAVE_DIR
RAW_PLT_CSV=f"{cfg.DATA_ROOT}/PlatinumDataset/raw/platinum_flat_file.csv"

#%% load up model:
for KEY, CONFIG in TUNED_MODEL_CONFIGS.items():
    MODEL_KEY = lambda fold: Loader.get_model_key(CONFIG['model'], CONFIG['dataset'], CONFIG['feature_opt'], CONFIG['edge_opt'],
                                                  CONFIG['batch_size'], CONFIG['lr'], CONFIG['architecture_kwargs']['dropout'],
                                                  n_epochs=2000, fold=fold,
                                                  ligand_feature=CONFIG['lig_feat_opt'], ligand_edge=CONFIG['lig_edge_opt'])
    print('\n\n' + '## ' + KEY)
    OUT_PLT = lambda i: f'{out_dir}/{MODEL_KEY(i)}_PLATINUM.csv'
    db_p = f"{CONFIG['feature_opt']}_{CONFIG['edge_opt']}_{CONFIG['lig_feat_opt']}_{CONFIG['lig_edge_opt']}"

    if CONFIG['dataset'] in ['kiba', 'davis']:
        db_p = f"DavisKibaDataset/{CONFIG['dataset']}/{db_p}"
    else:
        db_p = f"{CONFIG['dataset']}Dataset/{db_p}"

        if not os.path.isfile(out_fp):
            print(MSARunner.hhblits(in_fp, out_fp, n_cpus=N_CPUS, return_cmd=True))
            break
    train_p = lambda set: f"{cfg.DATA_ROOT}/{db_p}/{set}0/cleaned_XY.csv"

    if not os.path.exists(OUT_PLT(0)) and INFERENCE:
        print('running inference!')
        cp = lambda fold: f"{cp_dir}/{MODEL_KEY(fold)}.model"

        model = Loader.init_model(model=CONFIG["model"], pro_feature=CONFIG["feature_opt"],
                                  pro_edge=CONFIG["edge_opt"], **CONFIG['architecture_kwargs'])

        # load up platinum test db
        loaders = Loader.load_DataLoaders(cfg.DATA_OPT.platinum,
                                          pro_feature=CONFIG['feature_opt'],
                                          edge_opt=CONFIG['edge_opt'],
                                          ligand_feature=CONFIG['lig_feat_opt'],
                                          ligand_edge=CONFIG['lig_edge_opt'],
                                          datasets=['test'])

        for i in range(5):
            model.safe_load_state_dict(torch.load(cp(i), map_location=device))
            model.to(device)
            model.eval()

            loss, pred, actual = test(model, loaders['test'], device, verbose=True)

            # saving as csv with columns code, pred, actual
            # get codes from test loader
            codes, pid = [b['code'][0] for b in loaders['test']], [b['prot_id'][0] for b in loaders['test']]
            df = pd.DataFrame({'prot_id': pid, 'pred': pred, 'actual': actual}, index=codes)
            df.index.name = 'code'
            df.to_csv(OUT_PLT(i))

    # run platinum eval:
    print('\n### 1. predictive performance')
    mkdown = predictive_performance(OUT_PLT, train_p, verbose=VERBOSE, plot=False)
    print('\n### 2. Mutation impact analysis')
    print('\n#### 2.1 $\\Delta pkd$ predictive performance')
    mkdn = tbl_dpkd_metrics_overlap(OUT_PLT, train_p, verbose=VERBOSE, plot=False)
    print('\n#### 2.2 Stratified by location of mutation (binding pocket vs not in binding pocket)')
    m = tbl_dpkd_metrics_in_binding(OUT_PLT, RAW_PLT_CSV, verbose=VERBOSE, plot=False)

# %%
dfr = pd.read_csv(RAW_PLT_CSV, index_col=0)

# add in_binding info to df
def get_in_binding(df, dfr):
    """
    df is the prediction csv indexed by <raw_idx>_wt (or <raw_idx>_mt), where raw_idx
    corresponds to an index in dfr, the raw Platinum data, which includes the
    'mut.in_binding_site' column. Assigned pocket classes:
    - 0: wildtype rows
    - 1: mutation close to the binding site (<8 Ang)
    - 2: mutation far from the binding site (>8 Ang)
    """
    pocket = dfr[dfr['mut.in_binding_site'] == 'YES'].index
    pclass = []
    for code in df.index:
        if '_wt' in code:
            pclass.append(0)
        elif int(code.split('_')[0]) in pocket:
            pclass.append(1)
        else:
            pclass.append(2)
    df['pocket'] = pclass
    return df

df = get_in_binding(pd.read_csv(OUT_PLT(0), index_col=0), dfr)
if VERBOSE:
    cnts = df.pocket.value_counts()
    cnts.index = ['wt', 'in pocket', 'not in pocket']
    cnts.name = "counts"
    print(cnts.to_markdown(), end="\n\n")

tbl_stratified_dpkd_metrics(OUT_PLT, NORMALIZE=True, n_models=5, df_transform=get_in_binding,
                            conditions=['(pocket == 0) | (pocket == 1)', '(pocket == 0) | (pocket == 2)'],
                            names=['in pocket', 'not in pocket'],
                            verbose=VERBOSE, plot=True, dfr=dfr)
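# the 'conditions' strings appear to be pandas-query-style filters over the transformed df:
# each keeps the wildtype rows (pocket == 0) plus one mutation class, so the delta-pkd
# metrics are computed separately for in-pocket and out-of-pocket mutations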

65 changes: 11 additions & 54 deletions rayTrain_Tune.py
@@ -78,39 +78,16 @@ def train_func(config):
print("Cuda support:", torch.cuda.is_available(),":",
torch.cuda.device_count(), "devices")
print("CUDA VERSION:", torch.__version__)

# search_space = {
# ## constants:
# "epochs": 20,
# "model": cfg.MODEL_OPT.DG,
# "dataset": cfg.DATA_OPT.PDBbind,
# "feature_opt": cfg.PRO_FEAT_OPT.nomsa,
# "edge_opt": cfg.PRO_EDGE_OPT.aflow,
# "lig_feat_opt": cfg.LIG_FEAT_OPT.original,
# "lig_edge_opt": cfg.LIG_EDGE_OPT.binary,

# "fold_selection": 0,
# "save_checkpoint": False,

# ## hyperparameters to tune:
# "lr": ray.tune.loguniform(1e-5, 1e-3),
# "batch_size": ray.tune.choice([32, 64, 128]), # local batch size

# # model architecture hyperparams
# "architecture_kwargs":{
# "dropout": ray.tune.uniform(0.0, 0.5),
# "output_dim": ray.tune.choice([128, 256, 512]),
# }
# }
# 'gvpL_aflow': ('nomsa', 'aflow', 'gvp', 'binary'):
search_space = {
    ## constants:
    "epochs": 20,
    "model": cfg.MODEL_OPT.GVPL,
    "dataset": cfg.DATA_OPT.davis,
    "dataset": cfg.DATA_OPT.kiba,
    "feature_opt": cfg.PRO_FEAT_OPT.nomsa,
    "edge_opt": cfg.PRO_EDGE_OPT.aflow,
    "lig_feat_opt": cfg.LIG_FEAT_OPT.gvp,
    "lig_feat_opt": cfg.LIG_FEAT_OPT.original,
    "lig_edge_opt": cfg.LIG_EDGE_OPT.binary,

    "fold_selection": 0,
@@ -123,35 +100,15 @@ def train_func(config):
    # model architecture hyperparams
    "architecture_kwargs": {
        "dropout": ray.tune.uniform(0.0, 0.5),
        "output_dim": ray.tune.choice([128, 256, 512]),
    },
}
# search space for GVPL_RNG MODEL:
# search_space = {
# ## constants:
# "epochs": 20,
# "model": cfg.MODEL_OPT.GVPL_RNG,
# "dataset": cfg.DATA_OPT.PDBbind,
# "feature_opt": cfg.PRO_FEAT_OPT.nomsa,
# "edge_opt": cfg.PRO_EDGE_OPT.aflow_ring3,
# "lig_feat_opt": cfg.LIG_FEAT_OPT.gvp,
# "lig_edge_opt": cfg.LIG_EDGE_OPT.binary,
#
# "fold_selection": 0,
# "save_checkpoint": False,
#
# ## hyperparameters to tune:
# "lr": ray.tune.loguniform(1e-5, 1e-3),
# "batch_size": ray.tune.choice([16,32,64]), # local batch size
#
# # model architecture hyperparams
# "architecture_kwargs":{
# "dropout": ray.tune.uniform(0.0, 0.5),
# "pro_emb_dim": ray.tune.choice([64, 128, 256]),
# "output_dim": ray.tune.choice([128, 256, 512]),
# "nheads_pro": ray.tune.choice([3, 4, 5]),
# }
# }
arch_kwargs = search_space['architecture_kwargs']
if search_space['model'] == cfg.MODEL_OPT.GVPL:
    arch_kwargs["num_GVPLayers"] = ray.tune.choice([2, 3, 4])
elif search_space['model'] == cfg.MODEL_OPT.GVPL_RNG:
    arch_kwargs["pro_emb_dim"] = ray.tune.choice([64, 128, 256])
    arch_kwargs["nheads_pro"] = ray.tune.choice([3, 4, 5])

# each worker is a node from the ray cluster.
# WARNING: SBATCH GPU directive should match num_workers*GPU_per_worker
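# e.g. (an assumption; the exact directive is cluster-specific): with num_workers=2 and
# 1 GPU per worker, the submission script would need something like
# "#SBATCH --nodes=2 --gres=gpu:1"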
