From 96e3e03f6f303d8a19e329dc7c3fa28dd722699e Mon Sep 17 00:00:00 2001 From: jyaacoub Date: Tue, 11 Jun 2024 16:04:05 -0400 Subject: [PATCH 1/4] feat(raytune): ESM_GVPL search space --- rayTrain_Tune.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/rayTrain_Tune.py b/rayTrain_Tune.py index 666d16f..5d5cf21 100644 --- a/rayTrain_Tune.py +++ b/rayTrain_Tune.py @@ -91,12 +91,12 @@ def train_func(config): search_space = { ## constants: "epochs": 20, - "model": cfg.MODEL_OPT.DG, + "model": cfg.MODEL_OPT.GVPL_ESM, - "dataset": cfg.DATA_OPT.kiba, + "dataset": cfg.DATA_OPT.davis, "feature_opt": cfg.PRO_FEAT_OPT.nomsa, "edge_opt": cfg.PRO_EDGE_OPT.aflow, - "lig_feat_opt": cfg.LIG_FEAT_OPT.original, + "lig_feat_opt": cfg.LIG_FEAT_OPT.gvp, "lig_edge_opt": cfg.LIG_EDGE_OPT.binary, "fold_selection": 0, @@ -104,7 +104,7 @@ def train_func(config): ## hyperparameters to tune: "lr": ray.tune.loguniform(1e-5, 1e-3), - "batch_size": ray.tune.choice([32, 64, 128]), # local batch size + "batch_size": ray.tune.choice([4, 8, 10, 12]), # local batch size # model architecture hyperparams "architecture_kwargs":{ @@ -118,11 +118,16 @@ def train_func(config): elif search_space['model'] == cfg.MODEL_OPT.GVPL_RNG: arch_kwargs["pro_emb_dim"] = ray.tune.choice([64, 128, 256]) arch_kwargs["nheads_pro"] = ray.tune.choice([3, 4, 5]) + elif search_space['model'] == cfg.MODEL_OPT.GVPL_ESM: + arch_kwargs["num_GVPLayers"] = ray.tune.choice([2, 3, 4]) + arch_kwargs["pro_dropout_gnn"] = ray.tune.uniform(0.0, 0.5) + arch_kwargs["pro_extra_fc_lyr"] = ray.tune.choice([True, False]) + arch_kwargs["pro_emb_dim"] = ray.tune.choice([128, 256, 320]) # each worker is a node from the ray cluster. # WARNING: SBATCH GPU directive should match num_workers*GPU_per_worker # same for cpu-per-task directive - scaling_config = ScalingConfig(num_workers=1, # number of ray actors to launch to distribute compute across + scaling_config = ScalingConfig(num_workers=4, # number of ray actors to launch to distribute compute across use_gpu=True, # default is for each worker to have 1 GPU (overrided by resources per worker) resources_per_worker={"CPU": 2, "GPU": 1}, # trainer_resources={"CPU": 2, "GPU": 1}, @@ -142,7 +147,7 @@ def train_func(config): search_alg=OptunaSearch(), # using ray.tune.search.Repeater() could be useful to get multiple trials per set of params # would be even better if we could set trial-wise dependencies for a certain fold. # https://github.com/ray-project/ray/issues/33677 - num_samples=1000, + num_samples=500, ), ) From de2a9e5b50cabf5bea820d8645a7822d007521cc Mon Sep 17 00:00:00 2001 From: jyaacoub Date: Wed, 12 Jun 2024 14:51:17 -0400 Subject: [PATCH 2/4] fix(dropout): dropout right before last layer is a bad idea Since model cannot compensate for dropped nodes with only 1 layer. "no ability to "correct" errors induced by dropout before the classification happens." 
- https://stats.stackexchange.com/questions/299292/dropout-makes-performance-worse --- src/models/esm_models.py | 1 - src/models/gvp_models.py | 4 ---- src/models/prior_work.py | 2 -- src/models/ring3.py | 2 -- 4 files changed, 9 deletions(-) diff --git a/src/models/esm_models.py b/src/models/esm_models.py index fa1b154..a88f38a 100644 --- a/src/models/esm_models.py +++ b/src/models/esm_models.py @@ -181,7 +181,6 @@ def forward(self, data_pro, data_mol): xc = self.dropout(xc) xc = self.fc2(xc) xc = self.relu(xc) - xc = self.dropout(xc) out = self.out(xc) return out diff --git a/src/models/gvp_models.py b/src/models/gvp_models.py index 92024f6..a9d98d7 100644 --- a/src/models/gvp_models.py +++ b/src/models/gvp_models.py @@ -43,7 +43,6 @@ def __init__(self, num_features_pro=54, nn.ReLU(), nn.Linear(512, 128), - nn.Dropout(dropout), nn.ReLU(), nn.Linear(128, 1), @@ -88,7 +87,6 @@ def __init__(self, nn.ReLU(), nn.Linear(512, 128), - nn.Dropout(dropout), nn.ReLU(), nn.Linear(128, 1), @@ -137,7 +135,6 @@ def __init__(self, dropout=0.2, pro_emb_dim=128, output_dim=250, nn.Dropout(dropout), nn.ReLU(), nn.Linear(512, 128), - nn.Dropout(dropout), nn.ReLU(), nn.Linear(128, 1), ) @@ -228,6 +225,5 @@ def forward(self, data_pro, data_mol): xc = self.dropout(xc) xc = self.fc2(xc) xc = self.relu(xc) - xc = self.dropout(xc) out = self.out(xc) return out diff --git a/src/models/prior_work.py b/src/models/prior_work.py index 420bba6..4ce32d1 100644 --- a/src/models/prior_work.py +++ b/src/models/prior_work.py @@ -125,7 +125,6 @@ def forward(self, data_pro, data_mol): xc = self.dropout(xc) xc = self.fc2(xc) xc = self.relu(xc) - xc = self.dropout(xc) out = self.out(xc) return out @@ -278,7 +277,6 @@ def forward(self, data): xc = self.dropout(xc) xc = self.fc2(xc) xc = self.relu(xc) - xc = self.dropout(xc) out = self.out(xc) return out diff --git a/src/models/ring3.py b/src/models/ring3.py index c60fc62..bab8944 100644 --- a/src/models/ring3.py +++ b/src/models/ring3.py @@ -44,7 +44,6 @@ def __init__(self, pro_emb_dim=128, output_dim=250, nn.Dropout(dropout), nn.ReLU(), nn.Linear(1024, output_dim), - nn.Dropout(dropout), nn.ReLU(), ) @@ -127,7 +126,6 @@ def __init__(self, pro_emb_dim=128, output_dim=250, nn.Dropout(dropout), nn.Linear(1024, 512), nn.ReLU(), - nn.Dropout(dropout), nn.Linear(512, 1), ) From 8c89a24318b41c66e9007c65ff55e16cb5ddb059 Mon Sep 17 00:00:00 2001 From: jyaacoub Date: Fri, 14 Jun 2024 09:49:46 -0400 Subject: [PATCH 3/4] fix(GVPL_ESM): edge_weight_opt specified in loader --- src/models/branches.py | 6 +++++- src/models/gvp_models.py | 6 ++++-- src/utils/loader.py | 7 +++++-- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/src/models/branches.py b/src/models/branches.py index ace5611..65c147e 100644 --- a/src/models/branches.py +++ b/src/models/branches.py @@ -16,9 +16,13 @@ class ESMBranch(nn.Module): def __init__(self, esm_head:str='facebook/esm2_t6_8M_UR50D', num_feat=320, emb_dim=512, output_dim=128, dropout=0.2, - dropout_gnn=0.0, extra_fc_lyr=False): + dropout_gnn=0.0, extra_fc_lyr=False, esm_only=True, + edge_weight='binary'): super(ESMBranch, self).__init__() + + self.esm_only = esm_only + self.edge_weight = edge_weight # Protein graph: self.conv1 = GCNConv(num_feat, emb_dim) diff --git a/src/models/gvp_models.py b/src/models/gvp_models.py index a9d98d7..f11754f 100644 --- a/src/models/gvp_models.py +++ b/src/models/gvp_models.py @@ -64,9 +64,10 @@ def __init__(self, num_GVPLayers=3, dropout=0.2, output_dim=512, + edge_weight_opt='binary', **kwargs): output_dim = 
int(output_dim) - super(GVPLigand_DGPro, self).__init__() + super(GVPL_ESM, self).__init__() self.gvp_ligand = GVPBranchLigand(num_layers=num_GVPLayers, final_out=output_dim, @@ -75,7 +76,8 @@ def __init__(self, self.esm_branch = ESMBranch(num_feat=pro_num_feat, emb_dim=pro_emb_dim, dropout_gnn=pro_dropout_gnn, extra_fc_lyr=pro_extra_fc_lyr, - output_dim=output_dim, dropout=dropout) + output_dim=output_dim, dropout=dropout, + edge_weight_opt=edge_weight_opt) self.dense_out = nn.Sequential( nn.Linear(2*output_dim, 1024), diff --git a/src/utils/loader.py b/src/utils/loader.py index d46f43b..7d149d4 100644 --- a/src/utils/loader.py +++ b/src/utils/loader.py @@ -152,12 +152,15 @@ def init_model(model:str, pro_feature:str, pro_edge:str, dropout:float=0.2, **kw elif model == cfg.MODEL_OPT.GVPL: model = GVPLigand_DGPro(num_features_pro=num_feat_pro, - dropout=dropout, + dropout=dropout, + edge_weight_opt=pro_edge, **kwargs) elif model == cfg.MODEL_OPT.GVPL_RNG: model = GVPLigand_RNG3(dropout=dropout, **kwargs) elif model == cfg.MODEL_OPT.GVPL_ESM: - model = GVPL_ESM(pro_num_feat=320+num_feat_pro, **kwargs) + model = GVPL_ESM(pro_num_feat=320, + edge_weight_opt=pro_edge, + **kwargs) return model @staticmethod From 29d8093d68fd49f44c680e5bd4cfc403b126ebcb Mon Sep 17 00:00:00 2001 From: jyaacoub Date: Fri, 14 Jun 2024 09:59:11 -0400 Subject: [PATCH 4/4] results(kiba): gvpl and gvpl_aflow #narval #90 --- results/model_media/model_stats.csv | 10 ++++++++++ results/model_media/model_stats_val.csv | 10 ++++++++++ 2 files changed, 20 insertions(+) diff --git a/results/model_media/model_stats.csv b/results/model_media/model_stats.csv index e72df15..fd71a75 100644 --- a/results/model_media/model_stats.csv +++ b/results/model_media/model_stats.csv @@ -192,3 +192,13 @@ DGM_davis0D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E DGM_davis2D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7627598258454263,0.4909617783259909,0.4629240453488026,0.4675438018075736,0.4069470615043133,0.6837717468626308 DGM_davis3D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7586087337447441,0.48792801067181446,0.45565623472699257,0.4699824575652958,0.406830881759383,0.6855526657852742 DGM_davis4D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7627751294107745,0.4910000034597989,0.463033422941668,0.4667885899388577,0.4158912707098973,0.6832192839336854 +GVPLM_kiba0D_nomsaF_aflowE_32B_5.480618584919115e-05LR_0.0808130125360696D_2000E_gvpLF_binaryLE,0.6990933443140916,0.6227583767207777,0.5124656886492125,0.3874969197099424,0.4424602830299733,0.6224925057460069 +GVPLM_kiba1D_nomsaF_aflowE_32B_5.480618584919115e-05LR_0.0808130125360696D_2000E_gvpLF_binaryLE,0.70978923821572,0.6245931929660367,0.5489066995277002,0.4011673269900183,0.407424328365187,0.6333777127354722 +GVPLM_kiba2D_nomsaF_aflowE_32B_5.480618584919115e-05LR_0.0808130125360696D_2000E_gvpLF_binaryLE,0.7209881878996922,0.6223837050134309,0.5681747928654506,0.4084023193612667,0.4323828255794715,0.6390636270053763 +GVPLM_kiba3D_nomsaF_aflowE_32B_5.480618584919115e-05LR_0.0808130125360696D_2000E_gvpLF_binaryLE,0.6981566847135613,0.6245976640510174,0.5107333557005451,0.3822818250482888,0.438698144767069,0.6182894346891986 
+GVPLM_kiba0D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.6953137059634658,0.6022650380739804,0.5077108516657277,0.4026754666211966,0.4342038435039596,0.6345671490245903 +GVPLM_kiba1D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.7214611958633622,0.5829207378175565,0.5585350054296212,0.444391720517441,0.4188053511364005,0.6666271225486111 +GVPLM_kiba2D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.7270670616197747,0.6261480674852774,0.5808679477063539,0.3849585394612216,0.4063130463450911,0.6204502715457715 +GVPLM_kiba4D_nomsaF_aflowE_32B_5.480618584919115e-05LR_0.0808130125360696D_2000E_gvpLF_binaryLE,0.7013528905401658,0.6250248722716073,0.5171373465290922,0.3808855093784095,0.4271708875630601,0.6171592253044667 +GVPLM_kiba3D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.6975845475921519,0.6064399311172254,0.5135792108542759,0.39910064830136,0.4353593608266028,0.63174413198807 +GVPLM_kiba4D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.6935772666692046,0.6022132230956811,0.5044512666096735,0.4025857463348725,0.4328029182107182,0.6344964510025825 diff --git a/results/model_media/model_stats_val.csv b/results/model_media/model_stats_val.csv index 95dd6c1..6d7ceab 100644 --- a/results/model_media/model_stats_val.csv +++ b/results/model_media/model_stats_val.csv @@ -171,3 +171,13 @@ DGM_davis0D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E DGM_davis2D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7739145158374762,0.5392158863378831,0.5089040317525807,0.5015169159434333,0.4169669235899112,0.7081785904300082 DGM_davis3D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.792242665227612,0.5533025236584668,0.49892823193027686,0.4181546354842427,0.3822523623535258,0.6466487728931701 DGM_davis4D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7677158296355124,0.5054133287375089,0.4922434839494996,0.5150349706238619,0.4282377090806979,0.717659369494931 +GVPLM_kiba0D_nomsaF_aflowE_32B_5.480618584919115e-05LR_0.0808130125360696D_2000E_gvpLF_binaryLE,0.7092897627190399,0.6548681693449723,0.5396687042734233,0.3536005401673807,0.4179247613791516,0.5946432040874432 +GVPLM_kiba1D_nomsaF_aflowE_32B_5.480618584919115e-05LR_0.0808130125360696D_2000E_gvpLF_binaryLE,0.6986045733427746,0.6183317995531191,0.5105718680744119,0.4168811810841643,0.4282363309861637,0.6456633651401977 +GVPLM_kiba2D_nomsaF_aflowE_32B_5.480618584919115e-05LR_0.0808130125360696D_2000E_gvpLF_binaryLE,0.7383284840154883,0.652738529680564,0.6084551659361673,0.4666099733018358,0.4453689996572655,0.683088554509469 +GVPLM_kiba3D_nomsaF_aflowE_32B_5.480618584919115e-05LR_0.0808130125360696D_2000E_gvpLF_binaryLE,0.69778020457795,0.5833997763573209,0.5187109375347536,0.4580786225022705,0.4709254114953695,0.6768150578276687 +GVPLM_kiba4D_nomsaF_aflowE_32B_5.480618584919115e-05LR_0.0808130125360696D_2000E_gvpLF_binaryLE,0.709715591723624,0.6463041421148916,0.5398907787474324,0.3748993404970558,0.4213049353496972,0.6122902420397175 +GVPLM_kiba0D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.7129487277274348,0.6254086540231212,0.5523443855595185,0.4778276435668844,0.465151212219583,0.6912507819647544 
+GVPLM_kiba1D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.7313644516097428,0.6312817306753151,0.5827493870742838,0.3595024379826006,0.3891491395890177,0.599585221617912 +GVPLM_kiba2D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.7134694975801379,0.6146468608313317,0.5508657731405919,0.4262578886230455,0.4420421989218309,0.6528842842518462 +GVPLM_kiba3D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.7265855364664442,0.6707244641287359,0.5849809016207204,0.4177953815885197,0.4526643843244999,0.6463709318870394 +GVPLM_kiba4D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.6974728406520919,0.592039617802516,0.5169668490636758,0.457912664105794,0.47530745148151476,0.6766924442505576
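
For reference, the PATCH 1/4 search space reads as the following consolidated Ray Tune sketch. The `search_space` values are taken directly from the patch; the TorchTrainer/Tuner wiring around it is only an approximation of rayTrain_Tune.py (it assumes the script's `train_func` and the project config module `cfg` are importable, and the "loss" metric name is an assumption, not taken from the patch).

from ray import tune
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
from ray.tune.search.optuna import OptunaSearch

# `cfg` is the project's config module and `train_func` the per-worker training
# loop defined in rayTrain_Tune.py; both import paths below are assumptions.
from src.utils import config as cfg
from rayTrain_Tune import train_func

search_space = {
    # constants (from the patch):
    "epochs": 20,
    "model": cfg.MODEL_OPT.GVPL_ESM,
    "dataset": cfg.DATA_OPT.davis,
    "feature_opt": cfg.PRO_FEAT_OPT.nomsa,
    "edge_opt": cfg.PRO_EDGE_OPT.aflow,
    "lig_feat_opt": cfg.LIG_FEAT_OPT.gvp,
    "lig_edge_opt": cfg.LIG_EDGE_OPT.binary,
    "fold_selection": 0,
    # hyperparameters to tune:
    "lr": tune.loguniform(1e-5, 1e-3),
    "batch_size": tune.choice([4, 8, 10, 12]),  # local (per-worker) batch size
    "architecture_kwargs": {
        "num_GVPLayers": tune.choice([2, 3, 4]),
        "pro_dropout_gnn": tune.uniform(0.0, 0.5),
        "pro_extra_fc_lyr": tune.choice([True, False]),
        "pro_emb_dim": tune.choice([128, 256, 320]),
    },
}

trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(
        num_workers=4,  # should match the SBATCH GPU/cpu-per-task directives
        use_gpu=True,
        resources_per_worker={"CPU": 2, "GPU": 1},
    ),
)
tuner = tune.Tuner(
    trainer,
    param_space={"train_loop_config": search_space},
    tune_config=tune.TuneConfig(
        search_alg=OptunaSearch(),
        metric="loss", mode="min",  # metric name assumed
        num_samples=500,
    ),
)
results = tuner.fit()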
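
PATCH 2/4 removes every Dropout that sat immediately before a model's final Linear layer: with no layer after it, the network has no way to compensate for the randomly zeroed activations, so that dropout hurts rather than regularizes (see the linked Stack Exchange discussion). Below is a minimal sketch of the corrected head, following the fc1/fc2/out naming used in the patched forward() methods; the dimensions are illustrative, not the models' actual sizes.

import torch.nn as nn

class DenseHead(nn.Module):
    """Regression head with dropout only after hidden layers,
    never directly before the output layer (cf. PATCH 2/4)."""
    def __init__(self, in_dim=1024, dropout=0.2):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, 512)
        self.fc2 = nn.Linear(512, 128)
        self.out = nn.Linear(128, 1)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, xc):
        xc = self.relu(self.fc1(xc))
        xc = self.dropout(xc)  # fc2 and out can still correct for dropped units
        xc = self.relu(self.fc2(xc))
        # no dropout here: the single output layer cannot compensate for it
        return self.out(xc)

The same ordering applies to the nn.Sequential heads in gvp_models.py and ring3.py, where the patch drops the nn.Dropout entries placed just before the last nn.Linear.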