Skip to content

Commit

Permalink
fix(PDBbindDataset): rename lig_name -> lig_id on correct axis
Browse files Browse the repository at this point in the history
  • Loading branch information
jyaacoub committed Jul 18, 2024
1 parent a830819 commit 0c62001
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 95 deletions.
100 changes: 6 additions & 94 deletions playground.py
Original file line number Diff line number Diff line change
@@ -1,79 +1,4 @@
#%%
# test gvpl model to make sure it worked correctly!
from src.utils.loader import Loader

m = Loader.init_model('GVPL', 'nomsa', 'binary')

#%%
import torch
smi, lig = next(iter(torch.load('../data/DavisKibaDataset/davis/nomsa_binary_gvp_binary/test/data_mol.pt').items()))
pid, pro = next(iter(torch.load('../data/DavisKibaDataset/davis/nomsa_binary_gvp_binary/test/data_pro.pt').items()))

#%%
from copy import deepcopy
s0 = deepcopy(m.state_dict())

# %% train with single sample
from torch import nn
criterion = nn.MSELoss()
optim = torch.optim.Adam(m.parameters(), lr=1)

m.train()
loss = criterion(m(pro, lig), torch.tensor([1.0]))

optim.zero_grad()
loss.backward()
optim.step()
#%%
for k in m.state_dict():
v1 = s0[k]
v2 = m.state_dict()[k]
if torch.allclose(v1, v2):
print(k)



#%%
train_genes = ["CDC2L1", "ABL1(E255K)", "IRAK1", "AAK1", "RET(M918T)", "RSK1(KinDom.2-C-terminal)", "HIPK1", "CSNK1G3,"
"CHEK2", "PDPK1", "EGFR(S752I759del)", "CDC2L2", "ERN1", "CDK4-cyclinD1", "PLK3", "TIE1", "TIE2", "CDKL5,"
"RSK4(KinDom.2-C-terminal)", "WEE1", "RIPK4", "TNK2", "SGK3", "MRCKA", "SLK", "MAK", "GCN2(KinDom2S808G),"
"p38-beta", "PIK4CB", "CLK2", "FLT3(N841I)", "TEC", "AMPK-alpha2", "MLK3", "DAPK1", "SYK", "GRK1", "MARK2,"
"PRKCD", "PRKCQ", "TBK1", "PRKCE", "ERK5", "CAMKK2", "JAK3(JH1domain-catalytic)", "INSR", "PKN2", "ADCK3,"
"ABL1(F317I)", "ADCK4", "NEK2", "S6K1", "EGFR(L747E749del)", "EPHA4", "EPHB3", "SgK110", "JAK1(JH1domain-catalytic),"
"PLK2", "PRKG2", "ERBB4", "PHKG2", "SRPK2", "CHEK1", "DCAMKL2", "NEK9", "ACVR2A", "AKT1", "CSK", "STK35,"
"IKK-alpha", "CSF1R", "MST1", "TYK2(JH2domain-pseudokinase)", "PCTK3", "YANK3", "ACVRL1", "DAPK3", "EGFR(L747S752del),"
"EPHA6", "LIMK1", "CDKL3", "HIPK2", "PKAC-alpha", "MST2", "CSNK1D", "OSR1", "EPHA5", "CDKL2", "MEK3,"
"PHKG1", "PIP5K2B", "TLK2", "CAMKK1", "MINK", "EGFR", "MKNK2", "PRKD3", "INSRR", "BRK", "EIF2AK1", "AURKA,"
"ERK3", "HCK", "JAK2(JH1domain-catalytic)", "TRKB", "TNK1", "MAP3K1", "MAP3K4", "LIMK2", "GAK", "ERBB3,"
"FLT3(R834Q)", "MYO3A", "MARK3", "LATS2", "IRAK4", "MEK4", "PRKR", "STK39", "YES", "RIPK1", "FLT3", "BIKE,"
"CSNK1E", "LYN", "PKN1", "RET", "ABL1(F317L)p", "RSK1(KinDom.1-N-terminal)", "ROS1", "CAMK1", "VEGFR2,"
"MERTK", "BRSK1", "IKK-beta", "RSK4(KinDom.1-N-terminal)", "p38-gamma", "YSK4", "PFPK5(Pfalciparum),"
"TTK", "MYO3B", "CLK4", "PRKG1", "MAP4K2", "LZK", "RIOK1", "EGFR(L858RT790M)", "LCK", "FRK", "PLK1,"
"DYRK1A", "TSSK1B", "MST3", "CSNK1A1L", "EGFR(L861Q)", "FAK", "ABL1(T315I)p", "MET(Y1235D)", "PIM2,"
"TRKC", "RPS6KA4(KinDom.2-C-terminal)", "TNIK", "FLT3(D835Y)", "PAK7", "PAK3", "QSK", "MKNK1", "VRK2,"
"PIM1", "MKK7", "CSNK1A1", "ROCK2", "RET(V804L)", "MEK5", "ARK5", "FER", "CDK5", "ERK8", "RIPK5", "NLK,"
"PIP5K1C", "PKAC-beta", "ABL1", "CAMK2G", "MEK6", "RIOK2", "ABL1(M351T)", "CSNK2A1", "ZAP70", "RSK2(KinDom.1-N-terminal),"
"TESK1", "STK36", "CDK9", "CAMK2B", "ABL1(F317I)p", "HUNK", "NEK1", "TAOK3", "MST1R", "YSK1", "CTK,"
"MYLK2", "PIM3", "PIK3CG", "FLT4", "HPK1", "AURKB", "PKNB(Mtuberculosis)", "SRMS", "ICK", "TLK1", "CSNK1G1,"
"FLT1", "PAK1", "NEK4", "RPS6KA4(KinDom.1-N-terminal)", "MYLK", "DYRK2", "CDK11", "GSK3B", "CDC2L5,"
"MAPKAPK5", "DAPK2", "MLK1", "WEE2", "DCAMKL3", "TRPM6", "FYN", "ROCK1", "MELK", "FGFR1", "ULK1", "SNARK,"
"FES", "PLK4", "TAOK2", "MAP3K15", "EPHB2", "CAMK1D", "RSK3(KinDom.2-C-terminal)", "EPHA8", "TYK2(JH1domain-catalytic),"
"TYRO3", "HIPK3", "BMPR1B", "CDK2", "ZAK", "LATS1", "ABL1(Q252H)", "RSK3(KinDom.1-N-terminal)", "FLT3(ITD),"
"ABL1(F317L)", "MAP4K4", "LTK", "PYK2", "TAOK1", "SIK", "RIPK2", "PAK4", "MTOR", "EPHB4", "ANKK1", "MAP3K3,"
"JAK1(JH2domain-pseudokinase)", "CSNK1G2", "MUSK", "ULK2", "ABL1(H396P)", "ABL1(Y253F)", "STK16", "ABL2,"
"FLT3(D835H)", "MAP4K5", "TGFBR2", "PRP4", "PIK3CB", "ALK", "ABL1(Q252H)p", "CDKL1", "EGFR(G719C)", "AKT2,"
"EPHA7", "ULK3", "TRKA", "ABL1(T315I)", "MEK2", "SBK1", "RET(V804M)", "HIPK4", "CAMK2A", "ASK1", "CLK1,"
"PFTK1", "JNK1", "YANK2", "DMPK", "MARK1", "p38-alpha", "MLCK", "PRKD1", "MARK4", "ASK2", "DYRK1B", "FGR,"
"EPHB6", "ITK", "PFTAIRE2", "SRPK1", "ABL1(H396P)p", "ERK1", "ABL1p", "DDR2", "DMPK2", "SRC", "JNK3,"
"YANK1", "CDK4-cyclinD3", "MET", "PIK3C2G", "GRK4", "PKMYT1", "NEK6", "STK33", "ERK4", "MRCKB", "CDK8,"
"NEK11", "ACVR1B", "TNNI3K", "DRAK2", "EPHA1", "EGFR(L747T751del)", "ERK2", "DLK", "PDGFRB", "TGFBR1,"
"CAMK2D", "EGFR(T790M)", "GSK3A", "PAK6", "BMX", "LKB1", "IGF1R", "MYLK4", "AKT3", "BLK", "EPHB1", "CDK7,"
"MAPKAPK2", "PCTK2", "FGFR4", "EGFR(L858R)", "NIM1", "DDR1", "PIK3CD", "CASK", "MAP3K2", "CDK3", "IRAK3,"
"MST4", "EGFR(G719S)", "SNRK", "BMPR1A", "AURKC", "PRKCI", "EGFR(E746A750del)", "CAMK4", "PFCDPK1(Pfalciparum),"
"PAK2", "AXL", "MAST1", "PRKCH", "CLK3", "NDR1", "GRK7", "MET(M1250T)", "DRAK1", "EPHA2", "PRKX", "AMPK-alpha1,"
"TXK", "SRPK3", "RIOK3", "FLT3(K663Q)", "CSNK2A2", "CIT", "DCAMKL1", "LRRK2(G2019S)", "PRKD2", "EPHA3,"
"BTK", "p38-delta", "ACVR1", "CAMK1G", "LRRK2", "PCTK1", "BRSK2", "JNK2", "MAP4K3"]
#%%
# %%
import pandas as pd
train_df = pd.read_csv('/cluster/home/t122995uhn/projects/data/DavisKibaDataset/davis/nomsa_binary_original_binary/full/XY.csv')
test_df = pd.read_csv('/cluster/home/t122995uhn/projects/MutDTA/splits/davis/test.csv')
Expand Down Expand Up @@ -146,29 +71,16 @@
plt.xticks(rotation=45)

#%%
########################################################################
########################## BUILD DATASETS ##############################
########################################################################
from src.data_prep.init_dataset import create_datasets
from src import cfg
import logging
cfg.logger.setLevel(logging.DEBUG)

splits = '/cluster/home/t122995uhn/projects/MutDTA/splits/kiba/'
create_datasets(cfg.DATA_OPT.kiba,
feat_opt=cfg.PRO_FEAT_OPT.nomsa,
edge_opt=cfg.PRO_EDGE_OPT.aflow,
ligand_features=cfg.LIG_FEAT_OPT.gvp,
ligand_edges=cfg.LIG_EDGE_OPT.binary,
k_folds=5,
test_prots_csv=f'{splits}/test.csv',
val_prots_csv=[f'{splits}/val{i}.csv' for i in range(5)])

# %%
from src.data_prep.init_dataset import create_datasets
from src import cfg
import logging
cfg.logger.setLevel(logging.DEBUG)

splits = '/cluster/home/t122995uhn/projects/MutDTA/splits/kiba/'
create_datasets(cfg.DATA_OPT.kiba,
splits = '/cluster/home/t122995uhn/projects/MutDTA/splits/pdbbind/'
create_datasets(cfg.DATA_OPT.PDBbind,
feat_opt=cfg.PRO_FEAT_OPT.nomsa,
edge_opt=[cfg.PRO_EDGE_OPT.binary, cfg.PRO_EDGE_OPT.aflow],
ligand_features=[cfg.LIG_FEAT_OPT.original, cfg.LIG_FEAT_OPT.gvp],
Expand Down
2 changes: 1 addition & 1 deletion src/data_prep/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@ def pre_process(self):
# Get binding data:
df_binding = PDBbindProcessor.get_binding_data(self.raw_paths[0]) # _data.2020
df_binding.drop(columns=['resolution', 'release_year'], inplace=True)
df_binding.rename({'lig_name':'lig_id'}, inplace=True)
df_binding.rename({'lig_name':'lig_id'}, inplace=True, axis=1)
pdb_codes = df_binding.index # pdbcodes

############## validating codes #############
Expand Down

0 comments on commit 0c62001

Please sign in to comment.