Merge pull request #106 from jyaacoub/development
Development
jyaacoub committed Jun 14, 2024
2 parents a42a242 + 29d8093 commit 543a2a0
Showing 9 changed files with 45 additions and 20 deletions.
17 changes: 11 additions & 6 deletions rayTrain_Tune.py
@@ -91,20 +91,20 @@ def train_func(config):
search_space = {
## constants:
"epochs": 20,
"model": cfg.MODEL_OPT.DG,
"model": cfg.MODEL_OPT.GVPL_ESM,

"dataset": cfg.DATA_OPT.kiba,
"dataset": cfg.DATA_OPT.davis,
"feature_opt": cfg.PRO_FEAT_OPT.nomsa,
"edge_opt": cfg.PRO_EDGE_OPT.aflow,
"lig_feat_opt": cfg.LIG_FEAT_OPT.original,
"lig_feat_opt": cfg.LIG_FEAT_OPT.gvp,
"lig_edge_opt": cfg.LIG_EDGE_OPT.binary,

"fold_selection": 0,
"save_checkpoint": False,

## hyperparameters to tune:
"lr": ray.tune.loguniform(1e-5, 1e-3),
"batch_size": ray.tune.choice([32, 64, 128]), # local batch size
"batch_size": ray.tune.choice([4, 8, 10, 12]), # local batch size

# model architecture hyperparams
"architecture_kwargs":{
@@ -118,11 +118,16 @@ def train_func(config):
elif search_space['model'] == cfg.MODEL_OPT.GVPL_RNG:
arch_kwargs["pro_emb_dim"] = ray.tune.choice([64, 128, 256])
arch_kwargs["nheads_pro"] = ray.tune.choice([3, 4, 5])
+ elif search_space['model'] == cfg.MODEL_OPT.GVPL_ESM:
+     arch_kwargs["num_GVPLayers"] = ray.tune.choice([2, 3, 4])
+     arch_kwargs["pro_dropout_gnn"] = ray.tune.uniform(0.0, 0.5)
+     arch_kwargs["pro_extra_fc_lyr"] = ray.tune.choice([True, False])
+     arch_kwargs["pro_emb_dim"] = ray.tune.choice([128, 256, 320])

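[Editor's note] For context, a hedged sketch of one resolved trial config as train_func(config) might receive it once Tune samples the GVPL_ESM branch of the search space above; every sampled value below is illustrative, not taken from a real run.

    # Illustrative only: all sampled values are made up.
    sampled_config = {
        "epochs": 20,
        "model": "GVPL_ESM",          # cfg.MODEL_OPT.GVPL_ESM
        "dataset": "davis",           # cfg.DATA_OPT.davis
        "lr": 2.3e-4,                 # drawn from loguniform(1e-5, 1e-3)
        "batch_size": 8,              # drawn from choice([4, 8, 10, 12])
        "architecture_kwargs": {
            "num_GVPLayers": 3,       # choice([2, 3, 4])
            "pro_dropout_gnn": 0.31,  # uniform(0.0, 0.5)
            "pro_extra_fc_lyr": False,
            "pro_emb_dim": 256,       # choice([128, 256, 320])
        },
    }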
# each worker is a node from the ray cluster.
# WARNING: SBATCH GPU directive should match num_workers*GPU_per_worker
# same for cpu-per-task directive
- scaling_config = ScalingConfig(num_workers=1, # number of ray actors to launch to distribute compute across
+ scaling_config = ScalingConfig(num_workers=4, # number of ray actors to launch to distribute compute across
use_gpu=True, # default is for each worker to have 1 GPU (overridden by resources_per_worker)
resources_per_worker={"CPU": 2, "GPU": 1},
# trainer_resources={"CPU": 2, "GPU": 1},
@@ -142,7 +147,7 @@ def train_func(config):
search_alg=OptunaSearch(), # using ray.tune.search.Repeater() could be useful to get multiple trials per set of params
# would be even better if we could set trial-wise dependencies for a certain fold.
# https://github.com/ray-project/ray/issues/33677
- num_samples=1000,
+ num_samples=500,
),
)

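[Editor's note] The comment above mentions ray.tune.search.Repeater; a minimal sketch of that idea follows, assuming a metric named "loss" reported from train_func (the metric name is an assumption, not from this diff).

    from ray import tune
    from ray.tune.search import Repeater
    from ray.tune.search.optuna import OptunaSearch

    # Repeat each sampled hyperparameter set 5 times (e.g. once per fold);
    # Repeater averages the reported metric across repeats before handing
    # the result back to the underlying Optuna searcher.
    search_alg = Repeater(OptunaSearch(metric="loss", mode="min"), repeat=5)
    tune_config = tune.TuneConfig(
        search_alg=search_alg,
        num_samples=500,  # total trial budget, counting repeats
    )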
10 changes: 10 additions & 0 deletions results/model_media/model_stats.csv
@@ -192,3 +192,13 @@ DGM_davis0D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E
DGM_davis2D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7627598258454263,0.4909617783259909,0.4629240453488026,0.4675438018075736,0.4069470615043133,0.6837717468626308
DGM_davis3D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7586087337447441,0.48792801067181446,0.45565623472699257,0.4699824575652958,0.406830881759383,0.6855526657852742
DGM_davis4D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7627751294107745,0.4910000034597989,0.463033422941668,0.4667885899388577,0.4158912707098973,0.6832192839336854
+ GVPLM_kiba0D_nomsaF_aflowE_32B_5.480618584919115e-05LR_0.0808130125360696D_2000E_gvpLF_binaryLE,0.6990933443140916,0.6227583767207777,0.5124656886492125,0.3874969197099424,0.4424602830299733,0.6224925057460069
+ GVPLM_kiba1D_nomsaF_aflowE_32B_5.480618584919115e-05LR_0.0808130125360696D_2000E_gvpLF_binaryLE,0.70978923821572,0.6245931929660367,0.5489066995277002,0.4011673269900183,0.407424328365187,0.6333777127354722
+ GVPLM_kiba2D_nomsaF_aflowE_32B_5.480618584919115e-05LR_0.0808130125360696D_2000E_gvpLF_binaryLE,0.7209881878996922,0.6223837050134309,0.5681747928654506,0.4084023193612667,0.4323828255794715,0.6390636270053763
+ GVPLM_kiba3D_nomsaF_aflowE_32B_5.480618584919115e-05LR_0.0808130125360696D_2000E_gvpLF_binaryLE,0.6981566847135613,0.6245976640510174,0.5107333557005451,0.3822818250482888,0.438698144767069,0.6182894346891986
+ GVPLM_kiba0D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.6953137059634658,0.6022650380739804,0.5077108516657277,0.4026754666211966,0.4342038435039596,0.6345671490245903
+ GVPLM_kiba1D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.7214611958633622,0.5829207378175565,0.5585350054296212,0.444391720517441,0.4188053511364005,0.6666271225486111
+ GVPLM_kiba2D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.7270670616197747,0.6261480674852774,0.5808679477063539,0.3849585394612216,0.4063130463450911,0.6204502715457715
+ GVPLM_kiba4D_nomsaF_aflowE_32B_5.480618584919115e-05LR_0.0808130125360696D_2000E_gvpLF_binaryLE,0.7013528905401658,0.6250248722716073,0.5171373465290922,0.3808855093784095,0.4271708875630601,0.6171592253044667
+ GVPLM_kiba3D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.6975845475921519,0.6064399311172254,0.5135792108542759,0.39910064830136,0.4353593608266028,0.63174413198807
+ GVPLM_kiba4D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.6935772666692046,0.6022132230956811,0.5044512666096735,0.4025857463348725,0.4328029182107182,0.6344964510025825
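[Editor's note] The stats CSVs are shown here without their header row; a hedged loading snippet follows. The column names are inferred rather than taken from this diff: the last column is exactly the square root of the fourth, so the MSE/RMSE pairing is solid, while the remaining names are assumptions.

    import pandas as pd

    # Assumed column order: run id, concordance index, Pearson r, Spearman rho, MSE, MAE, RMSE.
    # Drop names=/header= if the file on disk already carries its own header row.
    cols = ["run", "cindex", "pearson", "spearman", "mse", "mae", "rmse"]
    df = pd.read_csv("results/model_media/model_stats.csv", names=cols, header=None)
    print(df.tail(10))  # the ten GVPLM_kiba* rows added in this commit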
10 changes: 10 additions & 0 deletions results/model_media/model_stats_val.csv
@@ -171,3 +171,13 @@ DGM_davis0D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E
DGM_davis2D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7739145158374762,0.5392158863378831,0.5089040317525807,0.5015169159434333,0.4169669235899112,0.7081785904300082
DGM_davis3D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.792242665227612,0.5533025236584668,0.49892823193027686,0.4181546354842427,0.3822523623535258,0.6466487728931701
DGM_davis4D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7677158296355124,0.5054133287375089,0.4922434839494996,0.5150349706238619,0.4282377090806979,0.717659369494931
+ GVPLM_kiba0D_nomsaF_aflowE_32B_5.480618584919115e-05LR_0.0808130125360696D_2000E_gvpLF_binaryLE,0.7092897627190399,0.6548681693449723,0.5396687042734233,0.3536005401673807,0.4179247613791516,0.5946432040874432
+ GVPLM_kiba1D_nomsaF_aflowE_32B_5.480618584919115e-05LR_0.0808130125360696D_2000E_gvpLF_binaryLE,0.6986045733427746,0.6183317995531191,0.5105718680744119,0.4168811810841643,0.4282363309861637,0.6456633651401977
+ GVPLM_kiba2D_nomsaF_aflowE_32B_5.480618584919115e-05LR_0.0808130125360696D_2000E_gvpLF_binaryLE,0.7383284840154883,0.652738529680564,0.6084551659361673,0.4666099733018358,0.4453689996572655,0.683088554509469
+ GVPLM_kiba3D_nomsaF_aflowE_32B_5.480618584919115e-05LR_0.0808130125360696D_2000E_gvpLF_binaryLE,0.69778020457795,0.5833997763573209,0.5187109375347536,0.4580786225022705,0.4709254114953695,0.6768150578276687
+ GVPLM_kiba4D_nomsaF_aflowE_32B_5.480618584919115e-05LR_0.0808130125360696D_2000E_gvpLF_binaryLE,0.709715591723624,0.6463041421148916,0.5398907787474324,0.3748993404970558,0.4213049353496972,0.6122902420397175
+ GVPLM_kiba0D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.7129487277274348,0.6254086540231212,0.5523443855595185,0.4778276435668844,0.465151212219583,0.6912507819647544
+ GVPLM_kiba1D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.7313644516097428,0.6312817306753151,0.5827493870742838,0.3595024379826006,0.3891491395890177,0.599585221617912
+ GVPLM_kiba2D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.7134694975801379,0.6146468608313317,0.5508657731405919,0.4262578886230455,0.4420421989218309,0.6528842842518462
+ GVPLM_kiba3D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.7265855364664442,0.6707244641287359,0.5849809016207204,0.4177953815885197,0.4526643843244999,0.6463709318870394
+ GVPLM_kiba4D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.6974728406520919,0.592039617802516,0.5169668490636758,0.457912664105794,0.47530745148151476,0.6766924442505576
6 changes: 5 additions & 1 deletion src/models/branches.py
@@ -16,9 +16,13 @@
class ESMBranch(nn.Module):
def __init__(self, esm_head:str='facebook/esm2_t6_8M_UR50D',
num_feat=320, emb_dim=512, output_dim=128, dropout=0.2,
- dropout_gnn=0.0, extra_fc_lyr=False):
+ dropout_gnn=0.0, extra_fc_lyr=False, esm_only=True,
+ edge_weight='binary'):

super(ESMBranch, self).__init__()

+ self.esm_only = esm_only
+ self.edge_weight = edge_weight

# Protein graph:
self.conv1 = GCNConv(num_feat, emb_dim)
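[Editor's note] A hedged construction sketch for the widened ESMBranch signature; the semantics of the two new arguments are inferred from their names, not stated in this diff.

    from src.models.branches import ESMBranch  # module shown above

    branch = ESMBranch(
        esm_head='facebook/esm2_t6_8M_UR50D',
        num_feat=320,           # per-residue embedding width of esm2_t6_8M_UR50D
        emb_dim=512,
        output_dim=128,
        dropout=0.2,
        dropout_gnn=0.0,
        extra_fc_lyr=False,
        esm_only=True,          # assumption: rely on ESM embeddings alone, skipping the protein GCN
        edge_weight='binary',   # assumption: weighting scheme for protein-graph edges
    )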
1 change: 0 additions & 1 deletion src/models/esm_models.py
@@ -181,7 +181,6 @@ def forward(self, data_pro, data_mol):
xc = self.dropout(xc)
xc = self.fc2(xc)
xc = self.relu(xc)
- xc = self.dropout(xc)
out = self.out(xc)
return out

10 changes: 4 additions & 6 deletions src/models/gvp_models.py
@@ -43,7 +43,6 @@ def __init__(self, num_features_pro=54,
nn.ReLU(),

nn.Linear(512, 128),
- nn.Dropout(dropout),
nn.ReLU(),

nn.Linear(128, 1),
@@ -65,9 +64,10 @@ def __init__(self,
num_GVPLayers=3,
dropout=0.2,
output_dim=512,
+ edge_weight_opt='binary',
**kwargs):
output_dim = int(output_dim)
- super(GVPLigand_DGPro, self).__init__()
+ super(GVPL_ESM, self).__init__()

self.gvp_ligand = GVPBranchLigand(num_layers=num_GVPLayers,
final_out=output_dim,
@@ -76,7 +76,8 @@ def __init__(self,
self.esm_branch = ESMBranch(num_feat=pro_num_feat, emb_dim=pro_emb_dim,
dropout_gnn=pro_dropout_gnn,
extra_fc_lyr=pro_extra_fc_lyr,
- output_dim=output_dim, dropout=dropout)
+ output_dim=output_dim, dropout=dropout,
+ edge_weight_opt=edge_weight_opt)

self.dense_out = nn.Sequential(
nn.Linear(2*output_dim, 1024),
@@ -88,7 +89,6 @@ def __init__(self,
nn.ReLU(),

nn.Linear(512, 128),
- nn.Dropout(dropout),
nn.ReLU(),

nn.Linear(128, 1),
@@ -137,7 +137,6 @@ def __init__(self, dropout=0.2, pro_emb_dim=128, output_dim=250,
nn.Dropout(dropout),
nn.ReLU(),
nn.Linear(512, 128),
- nn.Dropout(dropout),
nn.ReLU(),
nn.Linear(128, 1),
)
@@ -228,6 +227,5 @@ def forward(self, data_pro, data_mol):
xc = self.dropout(xc)
xc = self.fc2(xc)
xc = self.relu(xc)
- xc = self.dropout(xc)
out = self.out(xc)
return out
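[Editor's note] Several files in this commit (gvp_models.py above, and prior_work.py and ring3.py below) follow the same pattern: the Dropout sitting between the last hidden layer and the output head is removed. A self-contained sketch of the resulting dense_out, with dimensions following the layers listed above; the middle of the stack is approximated where the diff elides it.

    import torch
    import torch.nn as nn

    output_dim = 512  # per-branch embedding width, as in GVPL_ESM above
    dense_out = nn.Sequential(
        nn.Linear(2 * output_dim, 1024), nn.ReLU(),
        nn.Linear(1024, 512), nn.ReLU(),   # approximated; elided in the diff
        nn.Linear(512, 128), nn.ReLU(),    # the Dropout that used to follow this Linear is what the commit removes
        nn.Linear(128, 1),
    )

    # Stand-ins for the GVP ligand embedding and the ESM protein embedding:
    xl, xp = torch.randn(8, output_dim), torch.randn(8, output_dim)
    out = dense_out(torch.cat([xl, xp], dim=1))  # -> shape [8, 1]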
2 changes: 0 additions & 2 deletions src/models/prior_work.py
@@ -125,7 +125,6 @@ def forward(self, data_pro, data_mol):
xc = self.dropout(xc)
xc = self.fc2(xc)
xc = self.relu(xc)
- xc = self.dropout(xc)
out = self.out(xc)
return out

@@ -278,7 +277,6 @@ def forward(self, data):
xc = self.dropout(xc)
xc = self.fc2(xc)
xc = self.relu(xc)
- xc = self.dropout(xc)
out = self.out(xc)
return out

2 changes: 0 additions & 2 deletions src/models/ring3.py
@@ -44,7 +44,6 @@ def __init__(self, pro_emb_dim=128, output_dim=250,
nn.Dropout(dropout),
nn.ReLU(),
nn.Linear(1024, output_dim),
- nn.Dropout(dropout),
nn.ReLU(),
)

@@ -127,7 +126,6 @@ def __init__(self, pro_emb_dim=128, output_dim=250,
nn.Dropout(dropout),
nn.Linear(1024, 512),
nn.ReLU(),
- nn.Dropout(dropout),
nn.Linear(512, 1),
)

7 changes: 5 additions & 2 deletions src/utils/loader.py
@@ -152,12 +152,15 @@ def init_model(model:str, pro_feature:str, pro_edge:str, dropout:float=0.2, **kw

elif model == cfg.MODEL_OPT.GVPL:
model = GVPLigand_DGPro(num_features_pro=num_feat_pro,
- dropout=dropout,
+ dropout=dropout,
+ edge_weight_opt=pro_edge,
**kwargs)
elif model == cfg.MODEL_OPT.GVPL_RNG:
model = GVPLigand_RNG3(dropout=dropout, **kwargs)
elif model == cfg.MODEL_OPT.GVPL_ESM:
- model = GVPL_ESM(pro_num_feat=320+num_feat_pro, **kwargs)
+ model = GVPL_ESM(pro_num_feat=320,
+ edge_weight_opt=pro_edge,
+ **kwargs)
return model

@staticmethod
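[Editor's note] A hedged call sketch for the updated factory, assuming init_model is exposed as a static method on a Loader class (as the @staticmethod above suggests) and that the cfg module import path matches this repo's layout.

    from src.utils.loader import Loader   # assumed import path
    from src.utils import config as cfg   # assumed import; matches the cfg.* options used above

    model = Loader.init_model(
        model=cfg.MODEL_OPT.GVPL_ESM,
        pro_feature=cfg.PRO_FEAT_OPT.nomsa,
        pro_edge=cfg.PRO_EDGE_OPT.aflow,  # now forwarded to the model as edge_weight_opt
        dropout=0.2,
    )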
