Skip to content

Commit

Permalink
fix(gvpl): unused parameters #113
Browse files Browse the repository at this point in the history
Unused parameters due to inheriting from DGraphDTA but not using the forward_pro method
  • Loading branch information
jyaacoub committed Jul 17, 2024
1 parent 60bee50 commit cff0bab
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 6 deletions.
46 changes: 46 additions & 0 deletions playground.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
#%%
# Smoke test: build the GVPL model and push one protein/ligand pair through it.
from src.utils.loader import Loader

m = Loader.init_model('GVPL', 'nomsa', 'binary')

#%%
import torch

# Each .pt file is loaded as a mapping (it exposes .items()); only the first
# entry of each is needed for a single forward pass.
mol_graphs = torch.load('../data/DavisKibaDataset/davis/nomsa_binary_gvp_binary/test/data_mol.pt')
pro_graphs = torch.load('../data/DavisKibaDataset/davis/nomsa_binary_gvp_binary/test/data_pro.pt')

# presumably keys are the ligand SMILES / protein id - confirm against the dataset builder
smi, lig = next(iter(mol_graphs.items()))
pid, pro = next(iter(pro_graphs.items()))

# Model takes the protein graph first, then the ligand graph.
m(pro, lig)


#%%
train_genes = ["CDC2L1", "ABL1(E255K)", "IRAK1", "AAK1", "RET(M918T)", "RSK1(KinDom.2-C-terminal)", "HIPK1", "CSNK1G3,"
"CHEK2", "PDPK1", "EGFR(S752I759del)", "CDC2L2", "ERN1", "CDK4-cyclinD1", "PLK3", "TIE1", "TIE2", "CDKL5,"
Expand Down Expand Up @@ -44,6 +58,38 @@
test_df = pd.read_csv('/cluster/home/t122995uhn/projects/MutDTA/splits/davis/test.csv')
train_df = train_df[~train_df.prot_id.isin(set(test_df.prot_id))]

# %%
import pandas as pd

# Load the Davis test split and derive a `gene` column: everything in
# `prot_id` before the first "(" (e.g. "ABL1(E255K)" -> "ABL1").
davis_test_df = pd.read_csv(f"/home/jean/projects/MutDTA/splits/davis/test.csv").assign(
    gene=lambda frame: frame['prot_id'].str.split('(').str[0]
)

#%% ONCO KB MERGE
# Attach OncoKB info to every Davis test row (one OncoKB row per gene).
onco_df = pd.read_csv("../data/oncoKB_DrugGenePairList.csv")
onco_unique_genes = onco_df.drop_duplicates("gene")
davis_join_onco = davis_test_df.merge(onco_unique_genes, on="gene", how="inner")

# %%
# Reverse direction: count OncoKB rows per gene that also appears in the
# Davis test split. Left as a bare expression so the cell displays the result.
onco_df = pd.read_csv("../data/oncoKB_DrugGenePairList.csv")
davis_unique_genes = davis_test_df.drop_duplicates("gene")
onco_df.merge(davis_unique_genes, on="gene", how="inner").value_counts("gene")









# %%
from src.train_test.splitting import resplit
from src import cfg


def db_p(x):
    """Return the Davis dataset directory for protein-edge option `x`."""
    return f'{cfg.DATA_ROOT}/DavisKibaDataset/davis/nomsa_{x}_gvp_binary'


# Re-split the 'binary' dataset using the split files stored with the 'aflow' one.
binary_dataset_dir = db_p('binary')
aflow_split_dir = db_p('aflow')
db = resplit(dataset=binary_dataset_dir, split_files=aflow_split_dir, use_train_set=True)



# %%
########################################################################
########################## VIOLIN PLOTTING #############################
Expand Down
52 changes: 47 additions & 5 deletions src/models/gvp_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,43 @@
from src.models.ring3 import Ring3Branch


class GVPLigand_DGPro(DGraphDTA):
class GVPLigand_DGPro(BaseModel):
"""
DG model with GVP Ligand branch
model = GVPLigand_DGPro(num_features_pro=num_feat_pro,
dropout=dropout,
edge_weight_opt=pro_edge,
**kwargs)
"""
def __init__(self, num_features_pro=54,
num_features_mol=78,
output_dim=512,
dropout=0.2,
num_GVPLayers=3,
edge_weight_opt='binary', **kwargs):
output_dim = int(output_dim)
super(GVPLigand_DGPro, self).__init__(num_features_pro,
num_features_mol, output_dim,
dropout, edge_weight_opt)
super(GVPLigand_DGPro, self).__init__(pro_feat=None,
edge_weight_opt=edge_weight_opt)

self.gvp_ligand = GVPBranchLigand(num_layers=num_GVPLayers,
final_out=output_dim,
drop_rate=dropout)

# protein branch:
emb_feat= 54 # to ensure constant embedding size regardless of input size (for fair comparison)
self.pro_conv1 = GCNConv(num_features_pro, emb_feat)
self.pro_conv2 = GCNConv(emb_feat, emb_feat * 2)
self.pro_conv3 = GCNConv(emb_feat * 2, emb_feat * 4)
self.pro_fc = nn.Sequential(
nn.Linear(emb_feat * 4, 1024),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(1024, output_dim),
nn.Dropout(dropout)
)
self.relu = nn.ReLU()

# concat branch to feedforward network
self.dense_out = nn.Sequential(
nn.Linear(2*output_dim, 1024),
nn.Dropout(dropout),
Expand All @@ -51,6 +69,30 @@ def __init__(self, num_features_pro=54,
def forward_mol(self, data):
    """Run the ligand graph `data` through the GVP ligand branch and return its output."""
    ligand_embedding = self.gvp_ligand(data)
    return ligand_embedding

def forward_pro(self, data):
    """Embed a protein graph.

    Applies the three protein graph convolutions (each followed by ReLU),
    pools the node features per graph with `gep`, and passes the pooled
    vector through the `pro_fc` feed-forward head.

    Args:
        data: graph batch exposing `x`, `edge_index`, `batch` and, when
            edge weights are enabled, `edge_weight`.

    Returns:
        The output of `self.pro_fc` for the pooled protein features.
    """
    node_feats, edge_index, batch_index = data.x, data.edge_index, data.batch

    # Edge weights are only read when the model is configured to use them;
    # otherwise None is handed to the conv layers.
    # NOTE(review): self.edge_weight is presumably set by the base-class
    # __init__ from `edge_weight_opt` - confirm.
    edge_weight = data.edge_weight if self.edge_weight else None

    # Three conv stages, each followed by the shared ReLU.
    hidden = node_feats
    for conv in (self.pro_conv1, self.pro_conv2, self.pro_conv3):
        hidden = conv(hidden, edge_index, edge_weight)
        hidden = self.relu(hidden)

    # Global pooling: collapse node features to one vector per graph.
    pooled = gep(hidden, batch_index)

    # FFNN head
    return self.pro_fc(pooled)

def forward(self, data_pro, data_mol):
xm = self.forward_mol(data_mol)
xp = self.forward_pro(data_pro)
Expand Down
3 changes: 2 additions & 1 deletion src/models/prior_work.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ def forward_pro(self, data):
xt = gep(xt, target_batch) # global pooling

# flatten
xt = self.relu(self.pro_fc_g1(xt))
xt = self.pro_fc_g1(xt)
xt = self.relu(xt)
xt = self.dropout(xt)
xt = self.pro_fc_g2(xt)
xt = self.dropout(xt)
Expand Down

0 comments on commit cff0bab

Please sign in to comment.