Skip to content

Commit

Permalink
Merge pull request #124 from jyaacoub/development
Browse files Browse the repository at this point in the history
Development
  • Loading branch information
jyaacoub committed Jul 23, 2024
2 parents 46f4586 + b751abc commit 286baaf
Show file tree
Hide file tree
Showing 14 changed files with 293 additions and 79 deletions.
84 changes: 84 additions & 0 deletions inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# %%
from src.data_prep.feature_extraction.ligand import smile_to_graph
from src.data_prep.feature_extraction.protein import target_to_graph
from src.data_prep.feature_extraction.protein_edges import get_target_edge_weights
from src import cfg
import torch
import torch_geometric as torchg
import numpy as np

# Dataset and featurization options used to build the inputs below.
DATA = cfg.DATA_OPT.davis
lig_feature = cfg.LIG_FEAT_OPT.original
lig_edge = cfg.LIG_EDGE_OPT.binary
pro_feature = cfg.PRO_FEAT_OPT.nomsa
pro_edge = cfg.PRO_EDGE_OPT.binary

# Ligand to run inference on (passed to smile_to_graph, so presumably a SMILES string).
lig_seq = "CN(C)CC=CC(=O)NC1=C(C=C2C(=C1)C(=NC=N2)NC3=CC(=C(C=C3)F)Cl)OC4CCOC4"

#%% build ligand graph
# smile_to_graph returns two values, used below as the node-feature matrix
# (wrapped in a float tensor) and the edge index (wrapped in a long tensor).
# The raw ligand string is stored on the graph object as `lig_seq`.
mol_feat, mol_edge = smile_to_graph(lig_seq, lig_feature=lig_feature, lig_edge=lig_edge)
lig = torchg.data.Data(x=torch.Tensor(mol_feat), edge_index=torch.LongTensor(mol_edge),
lig_seq=lig_seq)

#%% build protein graph
# predicted using - https://zhanggroup.org/NeBcon/
prot_id = 'EGFR(L858R)'
pro_seq = 'MRPSGTAGAALLALLAALCPASRALEEKKVCQGTSNKLTQLGTFEDHFLSLQRMFNNCEVVLGNLEITYVQRNYDLSFLKTIQEVAGYVLIALNTVERIPLENLQIIRGNMYYENSYALAVLSNYDANKTGLKELPMRNLQEILHGAVRFSNNPALCNVESIQWRDIVSSDFLSNMSMDFQNHLGSCQKCDPSCPNGSCWGAGEENCQKLTKIICAQQCSGRCRGKSPSDCCHNQCAAGCTGPRESDCLVCRKFRDEATCKDTCPPLMLYNPTTYQMDVNPEGKYSFGATCVKKCPRNYVVTDHGSCVRACGADSYEMEEDGVRKCKKCEGPCRKVCNGIGIGEFKDSLSINATNIKHFKNCTSISGDLHILPVAFRGDSFTHTPPLDPQELDILKTVKEITGFLLIQAWPENRTDLHAFENLEIIRGRTKQHGQFSLAVVSLNITSLGLRSLKEISDGDVIISGNKNLCYANTINWKKLFGTSGQKTKIISNRGENSCKATGQVCHALCSPEGCWGPEPRDCVSCRNVSRGRECVDKCNLLEGEPREFVENSECIQCHPECLPQAMNITCTGRGPDNCIQCAHYIDGPHCVKTCPAGVMGENNTLVWKYADAGHVCHLCHPNCTYGCTGPGLEGCPTNGPKIPSIATGMVGALLLLLVVALGIGLFMRRRHIVRKRTLRRLLQERELVEPLTPSGEAPNQALLRILKETEFKKIKVLGSGAFGTVYKGLWIPEGEKVKIPVAIKELREATSPKANKEILDEAYVMASVDNPHVCRLLGICLTSTVQLITQLMPFGCLLDYVREHKDNIGSQYLLNWCVQIAKGMNYLEDRRLVHRDLAARNVLVKTPQHVKITDFGLAKLLGAEEKEYHAEGGKVPIKWMALESILHRIYTHQSDVWSYGVTVWELMTFGSKPYDGIPASEISSILEKGERLPQPPICTIDVYMIMVKCWMIDADSRPKFRELIIEFSKMARDPQRYLVIQGDERMHLPSPTDSNFYRALMDEEDMDDVVDADEYLIPQQGFFSSPSTSRTPLLSSLSATSNNSTVACIDRNGLQSCPIKEDSFLQRYSSDPTGALTEDSIDDTFLPVPEYINQSVPKRPAGSVQNPVYHNQPLNPAPSRDPHYQDPHSTAVGNPEYLNTVQPTCVNSTFDSPAHWAQKGSHQISLDNPDYQQDFFPKEAKPNGIFKGSTAENAEYLRVAPQSSEFIGA'
cmap_p = f'/cluster/home/t122995uhn/projects/data/davis/pconsc4/{prot_id}.npy'

# Load the contact map stored at cmap_p (a .npy file named after prot_id).
pro_cmap = np.load(cmap_p)
# updated_seq is for updated foldseek 3di combined seq
# The contact-map threshold is 8.0 when DATA is PDBbind and -0.5 for any other
# dataset. NOTE(review): presumably a distance cutoff vs. a predicted-contact
# score cutoff -- confirm against target_to_graph.
updated_seq, extra_feat, edge_idx = target_to_graph(target_sequence=pro_seq,
contact_map=pro_cmap,
threshold=8.0 if DATA is cfg.DATA_OPT.PDBbind else -0.5,
pro_feat=pro_feature)
pro_feat = torch.Tensor(extra_feat)



# Assemble the protein graph object.
# NOTE(review): pro_feat is already a torch.Tensor at this point, so the second
# torch.Tensor(...) wrap below looks redundant.
pro = torchg.data.Data(x=torch.Tensor(pro_feat),
edge_index=torch.LongTensor(edge_idx),
pro_seq=updated_seq, # Protein sequence for downstream esm model
prot_id=prot_id,
edge_weight=None)

#%% Loading the model
from src.utils.loader import Loader
from src import TUNED_MODEL_CONFIGS
import os

def reformat_kwargs(model_kwargs):
    """Flatten a tuned-model config entry into keyword arguments for the model key.

    The returned dict is unpacked into ``Loader.get_model_key(**...)`` by the
    script, so its keys are renamed from the config's naming
    (``dataset``, ``feature_opt``, ``edge_opt``, ``lig_feat_opt``,
    ``lig_edge_opt``) to ``data``, ``pro_feature``, ``edge``,
    ``ligand_feature`` and ``ligand_edge``.

    Note that the nested ``architecture_kwargs`` dict is NOT carried over;
    only its ``dropout`` entry is extracted. Callers that still need the
    architecture settings must keep a reference to the original config.

    Args:
        model_kwargs: Mapping with the required keys ``model``, ``dataset``,
            ``feature_opt``, ``edge_opt``, ``batch_size``, ``lr``,
            ``architecture_kwargs`` (containing ``dropout``),
            ``lig_feat_opt`` and ``lig_edge_opt``. The keys ``n_epochs``,
            ``pro_overlap`` and ``fold`` are optional.

    Returns:
        A new flat dict; the input mapping is not modified.

    Raises:
        KeyError: If any required key (including
            ``architecture_kwargs['dropout']``) is missing.
    """
    return {
        'model': model_kwargs['model'],
        'data': model_kwargs['dataset'],
        'pro_feature': model_kwargs['feature_opt'],
        'edge': model_kwargs['edge_opt'],
        'batch_size': model_kwargs['batch_size'],
        'lr': model_kwargs['lr'],
        'dropout': model_kwargs['architecture_kwargs']['dropout'],
        # Optional fields fall back to fixed defaults when absent.
        'n_epochs': model_kwargs.get('n_epochs', 2000),
        'pro_overlap': model_kwargs.get('pro_overlap', False),
        'fold': model_kwargs.get('fold', 0),
        'ligand_feature': model_kwargs['lig_feat_opt'],
        'ligand_edge': model_kwargs['lig_edge_opt'],
    }


# Keep a handle on the raw tuned config: reformat_kwargs() returns only the
# flat fields used to build the model key and drops 'architecture_kwargs',
# which Loader.init_model still needs further down.
tuned_config = TUNED_MODEL_CONFIGS['davis_esm']
model_kwargs = reformat_kwargs(tuned_config)

MODEL_KEY = Loader.get_model_key(**model_kwargs)

model_p_tmp = f'{cfg.MODEL_SAVE_DIR}/{MODEL_KEY}.model_tmp'
model_p = f'{cfg.MODEL_SAVE_DIR}/{MODEL_KEY}.model'

# MODEL_KEY = 'DDP-' + MODEL_KEY # distributed model
# Prefer the `.model` checkpoint; fall back to the `.model_tmp` file if it is absent.
model_p = model_p if os.path.isfile(model_p) else model_p_tmp
assert os.path.isfile(model_p), f"MISSING MODEL CHECKPOINT {model_p}"

print(model_p)
# %%
args = model_kwargs
# The architecture settings come from the raw config: the reformatted dict has
# no 'architecture_kwargs' key, so indexing it there raised KeyError.
# NOTE(review): the checkpoint at model_p is located but not loaded into the
# model in the visible part of this script.
model = Loader.init_model(model=args['model'], pro_feature=args['pro_feature'],
                          pro_edge=args['edge'], **tuned_config['architecture_kwargs'])


69 changes: 37 additions & 32 deletions playground.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,47 @@
# %% oncokb proteins
# %%
import pandas as pd
train_df = pd.read_csv('/cluster/home/t122995uhn/projects/data/DavisKibaDataset/davis/nomsa_binary_original_binary/full/XY.csv')
test_df = pd.read_csv('/cluster/home/t122995uhn/projects/MutDTA/splits/davis/test.csv')
train_df = train_df[~train_df.prot_id.isin(set(test_df.prot_id))]

# %%
import pandas as pd

kb_df = pd.read_csv('/cluster/home/t122995uhn/projects/downloads/oncoKB_DrugGenePairList.csv')
kb_prots = set(kb_df.gene)
davis_test_df = pd.read_csv(f"/home/jean/projects/MutDTA/splits/davis/test.csv")
davis_test_df['gene'] = davis_test_df['prot_id'].str.split('(').str[0]

#%% ONCO KB MERGE
onco_df = pd.read_csv("../data/oncoKB_DrugGenePairList.csv")
davis_join_onco = davis_test_df.merge(onco_df.drop_duplicates("gene"), on="gene", how="inner")

# %%
onco_df = pd.read_csv("../data/oncoKB_DrugGenePairList.csv")
onco_df.merge(davis_test_df.drop_duplicates("gene"), on="gene", how="inner").value_counts("gene")





davis_df = pd.read_csv('/cluster/home/t122995uhn/projects/MutDTA/splits/davis/test.csv')
davis_df['gene'] = davis_df.prot_id.str.split('(').str[0]
kiba_df = pd.read_csv('/cluster/home/t122995uhn/projects/MutDTA/splits/kiba/test.csv')
pdb_df = pd.read_csv('/cluster/home/t122995uhn/projects/MutDTA/splits/pdbbind/test.csv')

davis_df['db'] = 'davis'
kiba_df['db'] = 'kiba'
pdb_df['db'] = 'pdbbind'

#%%
all_df = pd.concat([davis_df, kiba_df, pdb_df], axis=0)
new_order = ['db'] + [x for x in all_df.columns if x != 'db']
all_df = all_df[new_order].drop(['seq_len',
'gene_matched_on_pdb_id',
'gene_matched_on_uniprot_id'], axis=1)

all_df.to_csv('/cluster/home/t122995uhn/projects/MutDTA/splits/all_tests.csv')

kb_overlap_test = all_df[all_df.gene.isin(kb_prots)]
# %%
from src.train_test.splitting import resplit
from src import cfg

db_p = lambda x: f'{cfg.DATA_ROOT}/DavisKibaDataset/davis/nomsa_{x}_gvp_binary'

db = resplit(dataset=db_p('binary'), split_files=db_p('aflow'), use_train_set=True)

kb_overlap_test.to_csv('/cluster/home/t122995uhn/projects/MutDTA/splits/all_tests_oncokb.csv')

['BRAF', 'ERBB2', 'FGFR2', 'FGFR3', 'KIT', 'PDGFRA', 'PIK3CA',
'RAF1', 'CHEK1', 'CHEK2', 'FGFR1', 'MAP2K1', 'MAP2K2', 'MTOR',
'EZH2', 'KDM6A', 'HRAS', 'KRAS', 'IDH1', 'PTEN', 'ESR1', 'BRIP1']

# %%
########################################################################
########################## VIOLIN PLOTTING #############################
########################################################################
import logging
from typing import OrderedDict

import seaborn as sns
from matplotlib import pyplot as plt
from statannotations.Annotator import Annotator

from src.analysis.figures import prepare_df, fig_combined, custom_fig

Expand All @@ -49,7 +50,7 @@

models = {
'DG': ('nomsa', 'binary', 'original', 'binary'),
# 'esm': ('ESM', 'binary', 'original', 'binary'), # esm model
'esm': ('ESM', 'binary', 'original', 'binary'), # esm model
'aflow': ('nomsa', 'aflow', 'original', 'binary'),
# 'gvpP': ('gvp', 'binary', 'original', 'binary'),
'gvpL': ('nomsa', 'binary', 'gvp', 'binary'),
Expand All @@ -70,15 +71,19 @@
fig_scale=(10,5), add_stats=True, title_postfix=" validation set performance", box=True, fold_labels=True)
plt.xticks(rotation=45)


# %%
#%%
########################################################################
########################## BUILD DATASETS ##############################
########################################################################
from src.data_prep.init_dataset import create_datasets
from src import cfg
import logging
cfg.logger.setLevel(logging.DEBUG)

splits = '/cluster/home/t122995uhn/projects/MutDTA/splits/davis/'
create_datasets(cfg.DATA_OPT.davis,
splits = '/cluster/home/t122995uhn/projects/MutDTA/splits/pdbbind/'
create_datasets(cfg.DATA_OPT.PDBbind,
feat_opt=cfg.PRO_FEAT_OPT.nomsa,
edge_opt=cfg.PRO_EDGE_OPT.aflow,
edge_opt=[cfg.PRO_EDGE_OPT.binary, cfg.PRO_EDGE_OPT.aflow],
ligand_features=[cfg.LIG_FEAT_OPT.original, cfg.LIG_FEAT_OPT.gvp],
ligand_edges=cfg.LIG_EDGE_OPT.binary,
k_folds=5,
Expand Down
27 changes: 24 additions & 3 deletions results/v113/model_media/model_stats.csv
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,28 @@ EDIM_davis1D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.821300
EDIM_davis2D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.8225446485053062,0.6603442335278025,0.5949532834745241,0.6383847882308629,0.4154985330988705,0.7989898548985856
EDIM_davis3D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.8257508664241492,0.7310753856354779,0.5889781203307795,0.4852039936128347,0.3546071329782175,0.6965658573407361
EDIM_davis4D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.8404380957025455,0.7536920423101249,0.6146193037809486,0.4475845677765498,0.344083449582884,0.6690176139508958
DGM_kiba0D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.5790729147254525,0.3237567466038684,0.220483622988913,0.5872366004827765,0.5416789341999021,0.7663136436752098
DGM_kiba1D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6991041045925556,0.5827312646091369,0.5234334020017861,0.4359633916482334,0.4614902562500025,0.6602752393117838
DGM_kiba3D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6703138545001528,0.5130147763454005,0.4537929112353873,0.4829267623645413,0.484057142425249,0.6949293218483023
DGM_kiba0D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.7098487012268593,0.5707564567872483,0.5408755032739221,0.4702066637233398,0.456037305701769,0.68571616848616
DGM_kiba1D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6991040727775232,0.5827312636226913,0.5234333190865307,0.4359633921504585,0.4614902574862928,0.6602752396920988
DGM_kiba3D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.673975414774567,0.5176426001573622,0.4632216123889441,0.4814432352161062,0.4857926130830525,0.6938611065740076
DGM_kiba2D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.672607400192465,0.5094659678690476,0.4582769804170869,0.5474786138565632,0.5343923583458045,0.7399179777898109
DGM_kiba4D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.725334039945205,0.5904997735444316,0.5827307698737493,0.4852863028425619,0.452352737630416,0.6966249369944791
DGM_kiba3D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.6779600069482259,0.5366735812473975,0.4684844875359095,0.4911992003209447,0.4879214073204387,0.7008560482159976
DGM_kiba1D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.6572547696170707,0.4910169881324348,0.4223030308489566,0.5056608650145961,0.4887140404038343,0.7110983511544631
DGM_kiba0D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7034265325632763,0.5925355722791278,0.5277625252545413,0.4342528236949791,0.4414575484107895,0.6589786215765874
DGM_kiba2D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7048145943290888,0.5949726563519002,0.5309972890766975,0.4319786589236721,0.4397173516202035,0.657250834098879
DGM_kiba4D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.701385350951551,0.5879330884340543,0.5238489684293576,0.43711720760821254,0.44305434527103205,0.6611484005941575
DGM_PDBbind0D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6655853516013849,0.4884783169365537,0.4768266531852068,2.716448143336347,1.29524747639392,1.648165083763258
DGM_PDBbind1D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6616134720685233,0.4760514005042608,0.465349083935074,2.7772597147478297,1.309049959958126,1.6665112405104952
DGM_PDBbind2D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6547720831157129,0.4674109621848075,0.4473761042353416,2.911396055546887,1.350302681084595,1.7062813529857517
DGM_PDBbind3D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6590676461310532,0.4730640252598409,0.4579028540331872,2.813666484139417,1.3159578356701942,1.6773987254494436
DGM_PDBbind4D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6560204986341623,0.4656808967835565,0.4491845247551415,2.830198214336796,1.3234003991911518,1.6823192961910638
GVPLM_kiba0D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.6916246228601947,0.5614767553441977,0.4996457822085293,1.395078889806946,1.0094453810491737,1.181134577347961
GVPLM_kiba3D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.6563609963973004,0.5357178296100711,0.4197069130416901,1.1612746598503312,0.8684839913064618,1.0776245449368398
GVPLM_kiba1D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.6676856814770858,0.5583291036072489,0.443468456589727,1.3486743771266174,0.9835523785170164,1.161324406497434
GVPLM_kiba4D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.706312738166959,0.618586606318348,0.5346567955843519,1.2316886588884002,0.945708981855823,1.1098146957435733
GVPLM_kiba2D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.6999996243009051,0.5844167249718527,0.5172933123278608,1.1060115689114345,0.8743866055307609,1.051670846278166
GVPLM_kiba4D_nomsaF_aflowE_16B_0.00010990897170411903LR_0.03599877069828837D_2000E_gvpLF_binaryLE,0.7028775926753292,0.5511820968815552,0.5274587057439738,0.5191856014604727,0.464434987594827,0.7205453500373676
GVPLM_kiba3D_nomsaF_aflowE_16B_0.00010990897170411903LR_0.03599877069828837D_2000E_gvpLF_binaryLE,0.6974533040075589,0.5366909353583093,0.5092679820627712,0.5267452909346243,0.4756001094118965,0.7257722031978245
GVPLM_kiba2D_nomsaF_aflowE_16B_0.00010990897170411903LR_0.03599877069828837D_2000E_gvpLF_binaryLE,0.6914933128392398,0.5535320315157304,0.49916906416324947,0.5167043183823063,0.46841519355949945,0.7188214787986696
GVPLM_kiba0D_nomsaF_aflowE_16B_0.00010990897170411903LR_0.03599877069828837D_2000E_gvpLF_binaryLE,0.667835957156033,0.5118338017508158,0.4512406693094591,0.529732104745498,0.4987863530176246,0.7278269744558098
GVPLM_kiba1D_nomsaF_aflowE_16B_0.00010990897170411903LR_0.03599877069828837D_2000E_gvpLF_binaryLE,0.673244203521782,0.5584520336614822,0.4601543459719329,0.4634166140494674,0.4859944296833685,0.6807470999199831
Loading

0 comments on commit 286baaf

Please sign in to comment.