diff --git a/rayTrain_Tune.py b/rayTrain_Tune.py index 5d5cf21..e3274d3 100644 --- a/rayTrain_Tune.py +++ b/rayTrain_Tune.py @@ -104,7 +104,7 @@ def train_func(config): ## hyperparameters to tune: "lr": ray.tune.loguniform(1e-5, 1e-3), - "batch_size": ray.tune.choice([4, 8, 10, 12]), # local batch size + "batch_size": ray.tune.choice([16, 32, 64, 128]), # local batch size # model architecture hyperparams "architecture_kwargs":{ @@ -124,6 +124,10 @@ def train_func(config): arch_kwargs["pro_extra_fc_lyr"] = ray.tune.choice([True, False]) arch_kwargs["pro_emb_dim"] = ray.tune.choice([128, 256, 320]) + if 'esm' in search_space['model'].lower(): + search_space['batch_size'] = ray.tune.choice([4,8,12,16]) + + # each worker is a node from the ray cluster. # WARNING: SBATCH GPU directive should match num_workers*GPU_per_worker # same for cpu-per-task directive @@ -133,7 +137,7 @@ def train_func(config): # trainer_resources={"CPU": 2, "GPU": 1}, # placement_strategy="PACK", # place workers on same node ) - + print('init Tuner') tuner = ray.tune.Tuner( TorchTrainer(train_func), diff --git a/results/model_media/model_stats.csv b/results/model_media/model_stats.csv index fd71a75..017d191 100644 --- a/results/model_media/model_stats.csv +++ b/results/model_media/model_stats.csv @@ -21,11 +21,6 @@ DGM_kibaD_nomsaF_af2E_128B_0.0001LR_0.4D_2000E,0.7421544771099204,0.732030831449 DGM_davisD_nomsaF_af2E_64B_0.0001LR_0.4D_2000E,0.8371815910480507,0.7268638307639713,0.6356548295003425,0.4133190204335282,0.3667126292760789,0.642898919297216 EDIM_PDBbindD_nomsaF_anmE_32B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6929249918113332,0.5373233573074141,0.5551808770122657,2.6186518594530783,1.258185145017263,1.6182249100335462 EDIM_PDBbindD_nomsaF_anmE_64B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6904623260666497,0.5184772560303198,0.5502737601079014,2.8358724195646863,1.322164628956769,1.6840048751606054 -DGM_davis4D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E,0.8372000580825498,0.7097254405448715,0.6167136172064404,0.4038543732328936,0.3597778618417262,0.6354953762482412 -DGM_davis1D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E,0.8259045903365158,0.6727098716351115,0.5969862332191551,0.4777166877928197,0.3703959785636244,0.6911705200547978 -DGM_davis3D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E,0.842545467746034,0.7168413490003224,0.623085263864242,0.4159492493391408,0.3558861870119901,0.6449412758842008 -DGM_davis2D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E,0.8219248012487749,0.6776357016678246,0.5900163003204868,0.4347317921133404,0.368497045519577,0.6593419386883717 -DGM_davis0D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E,0.8118490670490435,0.681294197609124,0.5718484991508404,0.4293578322249251,0.3500587517846625,0.6552540211436517 DGM_davis0D_msaF_binaryE_64B_0.0001LR_0.4D_2000E,0.8293327722196834,0.7023779705235912,0.6229347721106537,0.4493630702908394,0.3565351886078822,0.6703454857689723 DGM_davis1D_msaF_binaryE_64B_0.0001LR_0.4D_2000E,0.8212757240173636,0.6902984939419878,0.6084680827736445,0.4717995428983164,0.3579402144362287,0.6868766577037805 DGM_davis2D_msaF_binaryE_64B_0.0001LR_0.4D_2000E,0.826504184715153,0.6936343607518232,0.6220887747340503,0.4539645759565525,0.3707979739652987,0.6737689336534837 @@ -142,11 +137,6 @@ RNGM_PDBbind0D_nomsaF_ring3E_48B_0.0001071LR_0.3383D_2000E_originalLF_binaryLE,0 RNGM_PDBbind3D_nomsaF_ring3E_48B_0.0001071LR_0.3383D_2000E_originalLF_binaryLE,0.6845608518246677,0.5357292380549965,0.5262643639158158,2.8033910565957822,1.3118297254568652,1.674333018427273 RNGM_PDBbind1D_nomsaF_ring3E_48B_0.0001071LR_0.3383D_2000E_originalLF_binaryLE,0.6750296545462767,0.5185901583683697,0.5034912454720577,2.8088523469193527,1.2984564747465284,1.6759631102501489 RNGM_PDBbind4D_nomsaF_ring3E_48B_0.0001071LR_0.3383D_2000E_originalLF_binaryLE,0.6763045391457404,0.5057256652457955,0.5047101950511758,2.618698229612088,1.2698856009464514,1.6182392374467034 -DGM_PDBbind0D_nomsaF_aflowE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.7138386406186538,0.632161676666746,0.5969379768993622,2.305938375785052,1.20129866702216,1.5185316512292564 -DGM_PDBbind1D_nomsaF_aflowE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.7071928697529236,0.6119312740868981,0.5825821216776425,2.227630691345381,1.1882162758282253,1.492524938265817 -DGM_PDBbind2D_nomsaF_aflowE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.7131792416345327,0.6300842144746225,0.5973837933267879,2.1487191174387785,1.156409247072916,1.4658509874604508 -DGM_PDBbind3D_nomsaF_aflowE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.660154663447613,0.477928410118703,0.4613848634220896,2.752867466543676,1.3047904886139765,1.6591767436122278 -DGM_PDBbind4D_nomsaF_aflowE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.6849966099610509,0.5521829232190218,0.528410191518849,2.6060140340268965,1.2870582908675785,1.614315345286322 RNGM_PDBbind0D_nomsaF_aflow_ring3E_48B_0.0001071LR_0.3383D_2000E_originalLF_binaryLE,0.705837663251816,0.5925769047757962,0.5819610003983905,2.305359361226593,1.1903962990215846,1.518340989773573 RNGM_PDBbind1D_nomsaF_aflow_ring3E_48B_0.0001071LR_0.3383D_2000E_originalLF_binaryLE,0.6960348680473861,0.5877518727647189,0.5567275097196549,2.3528305312600524,1.195004482080066,1.5338939113446055 RNGM_PDBbind2D_nomsaF_aflow_ring3E_48B_0.0001071LR_0.3383D_2000E_originalLF_binaryLE,0.7174010132899236,0.6351541961659758,0.6076186024009358,2.095580516039397,1.1347450122757563,1.447612004661262 @@ -190,7 +180,7 @@ GVPLM_davis2D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_ DGM_davis1D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7618267865956071,0.4923371930328827,0.4613806752914631,0.4662829663578006,0.411145990216076,0.6828491534429846 DGM_davis0D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7636053478309109,0.4921759568376547,0.4648783804454143,0.4664726196058547,0.4114446927293202,0.6829880083909634 DGM_davis2D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7627598258454263,0.4909617783259909,0.4629240453488026,0.4675438018075736,0.4069470615043133,0.6837717468626308 -DGM_davis3D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7586087337447441,0.48792801067181446,0.45565623472699257,0.4699824575652958,0.406830881759383,0.6855526657852742 +DGM_davis3D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7586087337447441,0.4879280106718144,0.4556562347269925,0.4699824575652958,0.406830881759383,0.6855526657852742 DGM_davis4D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7627751294107745,0.4910000034597989,0.463033422941668,0.4667885899388577,0.4158912707098973,0.6832192839336854 GVPLM_kiba0D_nomsaF_aflowE_32B_5.480618584919115e-05LR_0.0808130125360696D_2000E_gvpLF_binaryLE,0.6990933443140916,0.6227583767207777,0.5124656886492125,0.3874969197099424,0.4424602830299733,0.6224925057460069 GVPLM_kiba1D_nomsaF_aflowE_32B_5.480618584919115e-05LR_0.0808130125360696D_2000E_gvpLF_binaryLE,0.70978923821572,0.6245931929660367,0.5489066995277002,0.4011673269900183,0.407424328365187,0.6333777127354722 @@ -202,3 +192,23 @@ GVPLM_kiba2D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_200 GVPLM_kiba4D_nomsaF_aflowE_32B_5.480618584919115e-05LR_0.0808130125360696D_2000E_gvpLF_binaryLE,0.7013528905401658,0.6250248722716073,0.5171373465290922,0.3808855093784095,0.4271708875630601,0.6171592253044667 GVPLM_kiba3D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.6975845475921519,0.6064399311172254,0.5135792108542759,0.39910064830136,0.4353593608266028,0.63174413198807 GVPLM_kiba4D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.6935772666692046,0.6022132230956811,0.5044512666096735,0.4025857463348725,0.4328029182107182,0.6344964510025825 +DGM_kiba0D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.6833072849286168,0.6179585946798826,0.4727062867906341,0.4652485169636371,0.4847738638760487,0.6820912819877095 +DGM_kiba1D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.6319399207854125,0.5827772702783164,0.3484199826852496,0.4922941579440291,0.5148003016031965,0.7016367706613081 +DGM_kiba2D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.6957607612858523,0.6214073872153388,0.4986083842961703,0.452748185796174,0.4765085403346378,0.6728656521150221 +DGM_kiba3D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.6974294222044406,0.6012218059551159,0.5013974025588194,0.491169945360959,0.4609703423233422,0.7008351770287783 +DGM_kiba4D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7244322791516395,0.6458432916136384,0.5594966433719704,0.4308429350512986,0.4458192640807841,0.6563862697004704 +GVPLM_PDBbind0D_nomsaF_binaryE_128B_0.0001LR_0.4661593536060576D_2000E_gvpLF_binaryLE,0.6678798890840328,0.4896741551372414,0.4891211256936132,2.904874286920722,1.3519085603519854,1.7043691756543595 +GVPLM_PDBbind2D_nomsaF_binaryE_128B_0.0001LR_0.4661593536060576D_2000E_gvpLF_binaryLE,0.6709156636560543,0.4988289093016924,0.4953432777076078,2.864293386602017,1.339912495673216,1.69242234285713 +GVPLM_PDBbind3D_nomsaF_binaryE_128B_0.0001LR_0.4661593536060576D_2000E_gvpLF_binaryLE,0.6885412122614492,0.546594227557393,0.5441389649076723,2.7604737400170225,1.2944456345114177,1.661467345456125 +GVPLM_PDBbind1D_nomsaF_binaryE_128B_0.0001LR_0.4661593536060576D_2000E_gvpLF_binaryLE,0.6933219028688695,0.5538827419859854,0.5553824976404049,2.705459027185816,1.2695058672557165,1.6448279627930138 +GVPLM_PDBbind4D_nomsaF_binaryE_128B_0.0001LR_0.4661593536060576D_2000E_gvpLF_binaryLE,0.676430736316634,0.5174490877105464,0.5106368202466385,2.7894814673104245,1.3190480916775753,1.6701740829357952 +GVPLM_PDBbind2D_nomsaF_binaryE_128B_0.00020066831190641135LR_0.4661593536060576D_2000E_gvpLF_binaryLE,0.6574373252713155,0.4637746181496499,0.4607019256521232,2.989336503081218,1.3696321064622343,1.728969780846738 +GVPLM_PDBbind4D_nomsaF_binaryE_128B_0.00020066831190641135LR_0.4661593536060576D_2000E_gvpLF_binaryLE,0.6679713075111049,0.4962407370372961,0.4900979111492484,2.8746986390390545,1.3413662884434798,1.6954936269532406 +GVPLM_PDBbind3D_nomsaF_binaryE_128B_0.00020066831190641135LR_0.4661593536060576D_2000E_gvpLF_binaryLE,0.6815000967290411,0.5208901839051482,0.5303883169579683,2.8010927687600886,1.3162882675910494,1.6736465483369207 +GVPLM_PDBbind1D_nomsaF_binaryE_128B_0.00020066831190641135LR_0.4661593536060576D_2000E_gvpLF_binaryLE,0.6861127443356612,0.5310529816224904,0.5380334435709785,2.759075340929806,1.297114357627186,1.661046459594013 +GVPLM_PDBbind0D_nomsaF_aflowE_128B_0.00020048122460779208LR_0.042268679447260635D_2000E_gvpLF_binaryLE,0.7015701786607161,0.5634102832749822,0.5691773211791266,2.442901347407971,1.2002801681321764,1.5629783579461267 +GVPLM_PDBbind0D_nomsaF_binaryE_128B_0.00020066831190641135LR_0.4661593536060576D_2000E_gvpLF_binaryLE,0.6850555528159531,0.5311038215844754,0.5335067223135304,2.820819380481298,1.315100869166162,1.6795295116434537 +GVPLM_PDBbind3D_nomsaF_aflowE_128B_0.00020048122460779208LR_0.042268679447260635D_2000E_gvpLF_binaryLE,0.70023803089709,0.5589897695780572,0.5660036088465252,2.415406919150389,1.1958711257435026,1.5541579453679697 +GVPLM_PDBbind4D_nomsaF_aflowE_128B_0.00020048122460779208LR_0.042268679447260635D_2000E_gvpLF_binaryLE,0.7127039080110099,0.5976033645560985,0.6006064125516688,2.282110245197861,1.156606004465194,1.5106654974539735 +GVPLM_PDBbind1D_nomsaF_aflowE_128B_0.00020048122460779208LR_0.042268679447260635D_2000E_gvpLF_binaryLE,0.7104874374380044,0.594977799338086,0.5945328584235519,2.270213382939843,1.165092654266055,1.5067227292836074 +GVPLM_PDBbind2D_nomsaF_aflowE_128B_0.00020048122460779208LR_0.042268679447260635D_2000E_gvpLF_binaryLE,0.7082159495264464,0.5882531427551811,0.5879428169952992,2.321420004071001,1.1716289097543746,1.5236206890400907 diff --git a/results/model_media/model_stats_val.csv b/results/model_media/model_stats_val.csv index 6d7ceab..5dd82b9 100644 --- a/results/model_media/model_stats_val.csv +++ b/results/model_media/model_stats_val.csv @@ -1,9 +1,4 @@ run,cindex,pearson,spearman,mse,mae,rmse -DGM_davis4D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E,0.8283315494528303,0.7069918172461128,0.6043805206997321,0.4413566391511215,0.3769707157421256,0.6643467762781133 -DGM_davis0D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E,0.8124934102025569,0.6909827817401493,0.5899715229081308,0.4697099884212956,0.3777247017505122,0.6853539147194649 -DGM_davis3D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E,0.8195586154273969,0.717774769130841,0.5720718176514452,0.3657891513377326,0.3382081135692242,0.6048050523414406 -DGM_davis2D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E,0.8233405467261639,0.7040322378359025,0.5687852322102779,0.3543421502907676,0.3433588875976263,0.5952664531877868 -DGM_davis1D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E,0.82715800662027,0.7083018179748569,0.6132060396459778,0.4895027135441312,0.3843769813986386,0.699644705221251 DGM_davis0D_msaF_binaryE_64B_0.0001LR_0.4D_2000E,0.8363462603536834,0.7300972139373315,0.6247030031563877,0.4115841439930461,0.3520789737327426,0.6415482398020013 DGM_davis1D_msaF_binaryE_64B_0.0001LR_0.4D_2000E,0.809188899866386,0.7272563430244341,0.5709758219318876,0.4267040368901719,0.3404274728645112,0.6532258697343301 DGM_davis2D_msaF_binaryE_64B_0.0001LR_0.4D_2000E,0.8256781625398019,0.6981605259116868,0.5810792177230069,0.3963907833649285,0.3506354182128197,0.6295957301037933 @@ -121,11 +116,6 @@ RNGM_PDBbind0D_nomsaF_ring3E_48B_0.0001071LR_0.3383D_2000E_originalLF_binaryLE,0 RNGM_PDBbind3D_nomsaF_ring3E_48B_0.0001071LR_0.3383D_2000E_originalLF_binaryLE,0.6965391213545163,0.5625444709177245,0.5615638003459615,2.6078179393716696,1.2726609132908009,1.6148739701201669 RNGM_PDBbind1D_nomsaF_ring3E_48B_0.0001071LR_0.3383D_2000E_originalLF_binaryLE,0.6898928004339594,0.5523099218709815,0.5449872381776708,2.66497247713436,1.2915848421041296,1.6324743419528407 RNGM_PDBbind4D_nomsaF_ring3E_48B_0.0001071LR_0.3383D_2000E_originalLF_binaryLE,0.7007950096666646,0.5899618659548683,0.5686640405834549,2.56015738655239,1.2682400034375922,1.6000491825417087 -DGM_PDBbind0D_nomsaF_aflowE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.6834134775963704,0.5252423281589604,0.5224103809545283,2.6965886561729806,1.3100058635233471,1.642129305558177 -DGM_PDBbind1D_nomsaF_aflowE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.6725081056502712,0.4958000718750098,0.4954655195111531,2.615718692551158,1.292999546388684,1.6173183646243423 -DGM_PDBbind2D_nomsaF_aflowE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.697200369413798,0.5612627340712764,0.5573429695874568,2.5342934699448687,1.2302899472072593,1.5919464406646564 -DGM_PDBbind3D_nomsaF_aflowE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.6486121701612797,0.4101926834257898,0.4379612695804781,2.729065093306028,1.3194110187149473,1.651988224324262 -DGM_PDBbind4D_nomsaF_aflowE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.6882754004672126,0.5566029334560295,0.5405423083756331,2.7151125873136825,1.328552851976572,1.647759869432947 RNGM_PDBbind0D_nomsaF_aflow_ring3E_48B_0.0001071LR_0.3383D_2000E_originalLF_binaryLE,0.6861028697545904,0.5155660146307575,0.5315176144803204,2.64558570968027,1.2594536337055984,1.626525656016612 RNGM_PDBbind1D_nomsaF_aflow_ring3E_48B_0.0001071LR_0.3383D_2000E_originalLF_binaryLE,0.6688078142143364,0.4999913674973958,0.4853228941348466,2.578325113105034,1.2773636633218242,1.6057163862603616 RNGM_PDBbind2D_nomsaF_aflow_ring3E_48B_0.0001071LR_0.3383D_2000E_originalLF_binaryLE,0.6936428830274479,0.5532477376519472,0.5506798228359694,2.5371871259092567,1.2411349264071958,1.5928550235063004 @@ -169,7 +159,7 @@ GVPLM_davis2D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_ DGM_davis1D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7604358543507954,0.5258641768225495,0.4678115715796511,0.4397487084865327,0.4022678869222504,0.6631355129131095 DGM_davis0D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7656357706050457,0.4589840305150596,0.4566457476087189,0.5370825057205953,0.431299532494959,0.7328591308843708 DGM_davis2D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7739145158374762,0.5392158863378831,0.5089040317525807,0.5015169159434333,0.4169669235899112,0.7081785904300082 -DGM_davis3D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.792242665227612,0.5533025236584668,0.49892823193027686,0.4181546354842427,0.3822523623535258,0.6466487728931701 +DGM_davis3D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.792242665227612,0.5533025236584668,0.4989282319302768,0.4181546354842427,0.3822523623535258,0.6466487728931701 DGM_davis4D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7677158296355124,0.5054133287375089,0.4922434839494996,0.5150349706238619,0.4282377090806979,0.717659369494931 GVPLM_kiba0D_nomsaF_aflowE_32B_5.480618584919115e-05LR_0.0808130125360696D_2000E_gvpLF_binaryLE,0.7092897627190399,0.6548681693449723,0.5396687042734233,0.3536005401673807,0.4179247613791516,0.5946432040874432 GVPLM_kiba1D_nomsaF_aflowE_32B_5.480618584919115e-05LR_0.0808130125360696D_2000E_gvpLF_binaryLE,0.6986045733427746,0.6183317995531191,0.5105718680744119,0.4168811810841643,0.4282363309861637,0.6456633651401977 @@ -180,4 +170,24 @@ GVPLM_kiba0D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_200 GVPLM_kiba1D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.7313644516097428,0.6312817306753151,0.5827493870742838,0.3595024379826006,0.3891491395890177,0.599585221617912 GVPLM_kiba2D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.7134694975801379,0.6146468608313317,0.5508657731405919,0.4262578886230455,0.4420421989218309,0.6528842842518462 GVPLM_kiba3D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.7265855364664442,0.6707244641287359,0.5849809016207204,0.4177953815885197,0.4526643843244999,0.6463709318870394 -GVPLM_kiba4D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.6974728406520919,0.592039617802516,0.5169668490636758,0.457912664105794,0.47530745148151476,0.6766924442505576 +GVPLM_kiba4D_nomsaF_binaryE_32B_3.372637625954074e-05LR_0.09399264336737133D_2000E_gvpLF_binaryLE,0.6974728406520919,0.592039617802516,0.5169668490636758,0.457912664105794,0.4753074514815147,0.6766924442505576 +DGM_kiba0D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.6936711315000789,0.6005211932246391,0.504191230794894,0.4061498746095516,0.422564779250127,0.6372988895404977 +DGM_kiba1D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7511877727018614,0.6646873091581573,0.63999609502579,0.4314372442516314,0.4257131987866365,0.6568388266931481 +DGM_kiba2D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7041074849390627,0.6292608788699515,0.5243632657169943,0.4182002499105649,0.4551389928119791,0.6466840417936451 +DGM_kiba3D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7087689560411555,0.6109761826413282,0.5446445880040978,0.439896171251659,0.4333795149243592,0.6632466895896723 +DGM_kiba4D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7199849879779502,0.6508641886029578,0.553199309507323,0.4207193671639506,0.4372470265476958,0.6486288362106256 +GVPLM_PDBbind0D_nomsaF_binaryE_128B_0.0001LR_0.4661593536060576D_2000E_gvpLF_binaryLE,0.6705469945947667,0.4911956557148271,0.4947147850885711,2.7984128464034073,1.3298601096606495,1.6728457329961444 +GVPLM_PDBbind2D_nomsaF_binaryE_128B_0.0001LR_0.4661593536060576D_2000E_gvpLF_binaryLE,0.6902404232610884,0.5401274960633111,0.5451661801348844,2.5066734114402864,1.24434781868825,1.5832477416501456 +GVPLM_PDBbind3D_nomsaF_binaryE_128B_0.0001LR_0.4661593536060576D_2000E_gvpLF_binaryLE,0.7095752540257866,0.5941410557532136,0.5931963538385324,2.2809019280273777,1.1906128084150198,1.5102655157380034 +GVPLM_PDBbind1D_nomsaF_binaryE_128B_0.0001LR_0.4661593536060576D_2000E_gvpLF_binaryLE,0.7023499053264508,0.5717034353066683,0.5768404717948784,2.597026881763407,1.2668071200288409,1.61152936112359 +GVPLM_PDBbind4D_nomsaF_binaryE_128B_0.0001LR_0.4661593536060576D_2000E_gvpLF_binaryLE,0.682705190455417,0.5275119691618371,0.5285510910454421,2.578897436668115,1.2767680418003973,1.6058945907711737 +GVPLM_PDBbind2D_nomsaF_binaryE_128B_0.00020066831190641135LR_0.4661593536060576D_2000E_gvpLF_binaryLE,0.6780083044177326,0.5021233642551106,0.5131736615195965,2.6616380081093767,1.2854885146447852,1.631452729351781 +GVPLM_PDBbind4D_nomsaF_binaryE_128B_0.00020066831190641135LR_0.4661593536060576D_2000E_gvpLF_binaryLE,0.6714325212748902,0.4941398581831344,0.5009215560881669,2.7004517496469846,1.306688972264116,1.643305129806082 +GVPLM_PDBbind3D_nomsaF_binaryE_128B_0.00020066831190641135LR_0.4661593536060576D_2000E_gvpLF_binaryLE,0.7101387238933209,0.594200805778623,0.5958936006650243,2.183300254492003,1.160913492747904,1.4775994905562206 +GVPLM_PDBbind1D_nomsaF_binaryE_128B_0.00020066831190641135LR_0.4661593536060576D_2000E_gvpLF_binaryLE,0.7001322643778998,0.5728546426839862,0.5694176286636995,2.4404130942819298,1.2330857080714064,1.5621821578426538 +GVPLM_PDBbind0D_nomsaF_aflowE_128B_0.00020048122460779208LR_0.042268679447260635D_2000E_gvpLF_binaryLE,0.6599455744201389,0.472484336171776,0.4634475416394563,2.792344446508325,1.3312219531005862,1.671030953186782 +GVPLM_PDBbind0D_nomsaF_binaryE_128B_0.00020066831190641135LR_0.4661593536060576D_2000E_gvpLF_binaryLE,0.6977940825243933,0.568541518084829,0.5625221606996651,2.6271055990872454,1.2775717433364508,1.6208348463329771 +GVPLM_PDBbind3D_nomsaF_aflowE_128B_0.00020048122460779208LR_0.042268679447260635D_2000E_gvpLF_binaryLE,0.6867643360605489,0.5204019111247571,0.5329889074281522,2.386227898319556,1.2097184585601857,1.5447420167521682 +GVPLM_PDBbind4D_nomsaF_aflowE_128B_0.00020048122460779208LR_0.042268679447260635D_2000E_gvpLF_binaryLE,0.7191498977429507,0.6310691763028511,0.6220020295790686,2.3079838932615666,1.2161203560016527,1.5192050201541485 +GVPLM_PDBbind1D_nomsaF_aflowE_128B_0.00020048122460779208LR_0.042268679447260635D_2000E_gvpLF_binaryLE,0.692730133231894,0.5465182719144834,0.5539502171380625,2.372344757420169,1.224056827991747,1.5402417853766237 +GVPLM_PDBbind2D_nomsaF_aflowE_128B_0.00020048122460779208LR_0.042268679447260635D_2000E_gvpLF_binaryLE,0.682824682368369,0.5237913264455828,0.5242914903942202,2.6452677106060514,1.270501777071001,1.626427898987856 diff --git a/src/__init__.py b/src/__init__.py index 1bb1995..dfab147 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -13,13 +13,14 @@ "lig_feat_opt": cfg.LIG_FEAT_OPT.gvp, "lig_edge_opt": cfg.LIG_EDGE_OPT.binary, - "lr": 0.0001360163557088453, - "batch_size": 128, # local batch size - - "architecture_kwargs":{ - "dropout": 0.027175922988649594, - "output_dim": 128, - } + 'lr': 0.00014968791626986144, + 'batch_size': 128, + + 'architecture_kwargs': { + 'dropout': 0.00039427600918916277, + 'output_dim': 256, + 'num_GVPLayers': 3 + } }, #GVPLM_davis3D_nomsaF_binaryE_128B_0.00020535607176845963LR_0.08845592454543601D_2000E_gvpLF_binaryLE 'davis_gvpl': { @@ -38,8 +39,7 @@ 'output_dim': 512 } }, - - 'davis_aflow':{ # not trained yet... + 'davis_aflow':{ "model": cfg.MODEL_OPT.DG, "dataset": cfg.DATA_OPT.davis, @@ -60,6 +60,25 @@ ##################################################### ############## kiba ################################# ##################################################### + 'kiba_gvpl_aflow': { + "model": cfg.MODEL_OPT.GVPL, + + "dataset": cfg.DATA_OPT.kiba, + "feature_opt": cfg.PRO_FEAT_OPT.nomsa, + "edge_opt": cfg.PRO_EDGE_OPT.aflow, + "lig_feat_opt": cfg.LIG_FEAT_OPT.gvp, + "lig_edge_opt": cfg.LIG_EDGE_OPT.binary, + + + 'lr': 0.00005480618584919115, + 'batch_size': 32, + + 'architecture_kwargs': { + 'dropout': 0.0808130125360696, + 'output_dim': 512, + 'num_GVPLayers': 4 + } + }, 'kiba_gvpl': { "model": cfg.MODEL_OPT.GVPL, @@ -79,8 +98,6 @@ 'num_GVPLayers': 4 } }, - - ##################################################### ########### PDBbind ################################# ##################################################### @@ -94,11 +111,46 @@ "lig_feat_opt": cfg.LIG_FEAT_OPT.gvp, "lig_edge_opt": cfg.LIG_EDGE_OPT.binary, - 'lr': 0.00022659, + 'lr': 0.00020048122460779208, + 'batch_size': 128, + + 'architecture_kwargs': { + 'dropout': 0.042268679447260635, + 'output_dim': 512, + 'num_GVPLayers': 3, + } + }, + 'PDBbind_gvpl':{ + "model": cfg.MODEL_OPT.GVPL, + + "dataset": cfg.DATA_OPT.PDBbind, + "feature_opt": cfg.PRO_FEAT_OPT.nomsa, + "edge_opt": cfg.PRO_EDGE_OPT.binary, + "lig_feat_opt": cfg.LIG_FEAT_OPT.gvp, + "lig_edge_opt": cfg.LIG_EDGE_OPT.binary, + + 'lr': 0.00020066831190641135, + 'batch_size': 128, + + 'architecture_kwargs': { + 'dropout': 0.4661593536060576, + 'output_dim': 512 + } + }, + 'PDBbind_aflow':{ + "model": cfg.MODEL_OPT.DG, + + "dataset": cfg.DATA_OPT.PDBbind, + "feature_opt": cfg.PRO_FEAT_OPT.nomsa, + "edge_opt": cfg.PRO_EDGE_OPT.aflow, + "lig_feat_opt": cfg.LIG_FEAT_OPT.original, + "lig_edge_opt": cfg.LIG_EDGE_OPT.binary, + + 'lr': 0.0009185598967356679, 'batch_size': 128, 'architecture_kwargs': { - 'dropout': 0.02414, + 'dropout': 0.22880989869337157, 'output_dim': 256 } }, diff --git a/src/data_prep/init_dataset.py b/src/data_prep/init_dataset.py index d961304..ff2ce16 100644 --- a/src/data_prep/init_dataset.py +++ b/src/data_prep/init_dataset.py @@ -1,12 +1,13 @@ import os import sys import itertools -from src.utils import config as cfg +import pandas as pd # Add the project root directory to Python path so imports work if file is run PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')) sys.path.append(PROJECT_ROOT) +from src.utils import config as cfg from src.data_prep.feature_extraction.protein_nodes import create_pfm_np_files from src.data_prep.datasets import DavisKibaDataset, PDBbindDataset, PlatinumDataset from src.train_test.splitting import train_val_test_split, balanced_kfold_split @@ -19,7 +20,8 @@ def create_datasets(data_opt:list[str]|str, feat_opt:list[str]|str, edge_opt:lis random_seed:int=0, train_split:float=0.8, val_split:float=0.1, - overwrite=True, + overwrite=True, + test_prots_csv:str=None, **kwargs) -> None: """ Creates the datasets for the given data, feature, and edge options. @@ -44,6 +46,10 @@ def create_datasets(data_opt:list[str]|str, feat_opt:list[str]|str, edge_opt:lis `k_folds` : int, optional If not None, the number of folds to split the final training set into for cross validation, by default None + `test_prots_csv` : str, optional + If not None, the path to a csv file containing the test proteins to use, + by default None. The csv file should have a 'prot_id' column. + """ if isinstance(data_opt, str): data_opt = [data_opt] if isinstance(feat_opt, str): feat_opt = [feat_opt] @@ -124,8 +130,14 @@ def create_datasets(data_opt:list[str]|str, feat_opt:list[str]|str, edge_opt:lis else: assert test_split > 0, f"Invalid train/val/test split: {train_split}/{val_split}/{test_split}" assert not pro_overlap, f"No support for overlapping proteins with k-folds rn." + if test_prots_csv is not None: + df = pd.read_csv(test_prots_csv) + test_prots = set(df['prot_id'].tolist()) + else: + test_prots = None + train_loader, val_loader, test_loader = balanced_kfold_split(dataset, - k_folds=k_folds, test_split=test_split, + k_folds=k_folds, test_split=test_split, test_prots=test_prots, random_seed=random_seed) # only non-overlapping splits for k-folds subset_names = ['train', 'val', 'test'] @@ -140,7 +152,7 @@ def create_datasets(data_opt:list[str]|str, feat_opt:list[str]|str, edge_opt:lis # loops through all k folds and saves as train1, train2, etc. dataset.save_subset_folds(train_loader, subset_names[0]) dataset.save_subset_folds(val_loader, subset_names[1]) - + dataset.save_subset(test_loader, subset_names[2]) - + del dataset # free up memory diff --git a/src/models/branches.py b/src/models/branches.py index 65c147e..865c781 100644 --- a/src/models/branches.py +++ b/src/models/branches.py @@ -86,22 +86,22 @@ def forward(self, data): ew = data.edge_weight if (self.edge_weight is not None and self.edge_weight != 'binary') else None - target_x = self.relu(target_x) - ei_drp, _, _ = dropout_node(ei, p=self.dropout_gnn, num_nodes=target_x.shape[0], - training=self.training) + xt = self.relu(target_x) + ei_drp, e_mask, _ = dropout_node(ei, p=self.dropout_gnn, num_nodes=target_x.shape[0], + training=self.training) # conv1 - xt = self.conv1(target_x, ei_drp, ew) + xt = self.conv1(xt, ei_drp, ew[e_mask] if ew is not None else ew) xt = self.relu(xt) - ei_drp, _, _ = dropout_node(ei, p=self.dropout_gnn, num_nodes=target_x.shape[0], + ei_drp, e_mask, _ = dropout_node(ei, p=self.dropout_gnn, num_nodes=target_x.shape[0], training=self.training) # conv2 - xt = self.conv2(xt, ei_drp, ew) + xt = self.conv2(xt, ei_drp, ew[e_mask] if ew is not None else ew) xt = self.relu(xt) - ei_drp, _, _ = dropout_node(ei, p=self.dropout_gnn, num_nodes=target_x.shape[0], + ei_drp, e_mask, _ = dropout_node(ei, p=self.dropout_gnn, num_nodes=target_x.shape[0], training=self.training) # conv3 - xt = self.conv3(xt, ei_drp, ew) + xt = self.conv3(xt, ei_drp, ew[e_mask] if ew is not None else ew) xt = self.relu(xt) # flatten/pool diff --git a/src/models/esm_models.py b/src/models/esm_models.py index a88f38a..372b12c 100644 --- a/src/models/esm_models.py +++ b/src/models/esm_models.py @@ -95,22 +95,22 @@ def forward_pro(self, data): ew = data.edge_weight if (self.edge_weight is not None and self.edge_weight != 'binary') else None - target_x = self.relu(target_x) - ei_drp, _, _ = dropout_node(ei, p=self.dropout_prot_p, num_nodes=target_x.shape[0], - training=self.training) + xt = self.relu(target_x) + ei_drp, e_mask, _ = dropout_node(ei, p=self.dropout_gnn, num_nodes=target_x.shape[0], + training=self.training) # conv1 - xt = self.pro_conv1(target_x, ei_drp, ew) + xt = self.conv1(xt, ei_drp, ew[e_mask] if ew is not None else ew) xt = self.relu(xt) - ei_drp, _, _ = dropout_node(ei, p=self.dropout_prot_p, num_nodes=target_x.shape[0], + ei_drp, e_mask, _ = dropout_node(ei, p=self.dropout_gnn, num_nodes=target_x.shape[0], training=self.training) # conv2 - xt = self.pro_conv2(xt, ei_drp, ew) + xt = self.conv2(xt, ei_drp, ew[e_mask] if ew is not None else ew) xt = self.relu(xt) - ei_drp, _, _ = dropout_node(ei, p=self.dropout_prot_p, num_nodes=target_x.shape[0], + ei_drp, e_mask, _ = dropout_node(ei, p=self.dropout_gnn, num_nodes=target_x.shape[0], training=self.training) # conv3 - xt = self.pro_conv3(xt, ei_drp, ew) + xt = self.conv3(xt, ei_drp, ew[e_mask] if ew is not None else ew) xt = self.relu(xt) # flatten/pool @@ -257,24 +257,24 @@ def forward_pro(self, data): ew = data.edge_weight if (self.edge_weight is not None and self.edge_weight != 'binary') else None - target_x = self.relu(target_x) - ei_drp, _, _ = dropout_node(ei, p=self.dropout_prot_p, num_nodes=target_x.shape[0], - training=self.training) + xt = self.relu(target_x) + ei_drp, e_mask, _ = dropout_node(ei, p=self.dropout_gnn, num_nodes=target_x.shape[0], + training=self.training) # conv1 - xt = self.pro_conv1(target_x, ei_drp, ew) + xt = self.conv1(xt, ei_drp, ew[e_mask] if ew is not None else ew) xt = self.relu(xt) - ei_drp, _, _ = dropout_node(ei, p=self.dropout_prot_p, num_nodes=target_x.shape[0], + ei_drp, e_mask, _ = dropout_node(ei, p=self.dropout_gnn, num_nodes=target_x.shape[0], training=self.training) # conv2 - xt = self.pro_conv2(xt, ei_drp, ew) + xt = self.conv2(xt, ei_drp, ew[e_mask] if ew is not None else ew) xt = self.relu(xt) - ei_drp, _, _ = dropout_node(ei, p=self.dropout_prot_p, num_nodes=target_x.shape[0], + ei_drp, e_mask, _ = dropout_node(ei, p=self.dropout_gnn, num_nodes=target_x.shape[0], training=self.training) # conv3 - xt = self.pro_conv3(xt, ei_drp, ew) + xt = self.conv3(xt, ei_drp, ew[e_mask] if ew is not None else ew) xt = self.relu(xt) - + # flatten/pool xt = gep(xt, data.batch) # global pooling xt = self.relu(xt) diff --git a/src/models/gvp_models.py b/src/models/gvp_models.py index f11754f..76b40cc 100644 --- a/src/models/gvp_models.py +++ b/src/models/gvp_models.py @@ -77,7 +77,7 @@ def __init__(self, dropout_gnn=pro_dropout_gnn, extra_fc_lyr=pro_extra_fc_lyr, output_dim=output_dim, dropout=dropout, - edge_weight_opt=edge_weight_opt) + edge_weight=edge_weight_opt) self.dense_out = nn.Sequential( nn.Linear(2*output_dim, 1024), diff --git a/src/models/ring3.py b/src/models/ring3.py index bab8944..7e0ff4e 100644 --- a/src/models/ring3.py +++ b/src/models/ring3.py @@ -211,20 +211,22 @@ def forward_pro(self, data): self.edge_weight != 'binary') else None #### Graph NN #### - target_x = self.relu(target_x) - # WARNING: dropout_node doesnt work if `ew` isnt also dropped out - # ei_drp, _, _ = dropout_node(ei, p=self.dropout_prot_p, num_nodes=target_x.shape[0], - # training=self.training) - # GNN layers: - xt = self.pro_conv1(target_x, ei, ew) + xt = self.relu(target_x) + ei_drp, e_mask, _ = dropout_node(ei, p=self.dropout_gnn, num_nodes=target_x.shape[0], + training=self.training) + + # conv1 + xt = self.conv1(xt, ei_drp, ew[e_mask] if ew is not None else ew) xt = self.relu(xt) - # ei_drp, _, _ = dropout_node(ei, p=self.dropout_prot_p, num_nodes=target_x.shape[0], - # training=self.training) - xt = self.pro_conv2(xt, ei, ew) + ei_drp, e_mask, _ = dropout_node(ei, p=self.dropout_gnn, num_nodes=target_x.shape[0], + training=self.training) + # conv2 + xt = self.conv2(xt, ei_drp, ew[e_mask] if ew is not None else ew) xt = self.relu(xt) - # ei_drp, _, _ = dropout_node(ei, p=self.dropout_prot_p, num_nodes=target_x.shape[0], - # training=self.training) - xt = self.pro_conv3(xt, ei, ew) + ei_drp, e_mask, _ = dropout_node(ei, p=self.dropout_gnn, num_nodes=target_x.shape[0], + training=self.training) + # conv3 + xt = self.conv3(xt, ei_drp, ew[e_mask] if ew is not None else ew) xt = self.relu(xt) # flatten/pool diff --git a/src/train_test/splitting.py b/src/train_test/splitting.py index 918b1cd..91114d2 100644 --- a/src/train_test/splitting.py +++ b/src/train_test/splitting.py @@ -136,7 +136,7 @@ def balanced_kfold_split(dataset: BaseDataset, k_folds:int=5, test_split=.1, shuffle_dataset=True, random_seed=None, batch_train=128, - verbose=False) -> tuple[DataLoader]: + verbose=False, test_prots:set=None) -> tuple[DataLoader]: """ Same as train_val_test_split_kfold but we make considerations for the fact that each protein might not show up in equal proportions (e.g.: @@ -157,12 +157,20 @@ def balanced_kfold_split(dataset: BaseDataset, seed for shuffle, by default None `batch_train` : int, optional size of batch, by default 128 + `verbose` : bool, optional + If true, will print out the number of proteins in each fold, by default False + `test_prots` : set, optional + If not None, will use this set of proteins ("prot_ids" only!) for the test set, by default + None. Returns ------- tuple[DataLoader] Train, val, and test loaders """ + # throwing error so that we make sure to always use the same test_prots + assert test_prots is None or isinstance(test_prots, set), 'test_prots must be set for consistency' + if random_seed is not None: np.random.seed(random_seed) torch.manual_seed(random_seed) @@ -180,13 +188,17 @@ def balanced_kfold_split(dataset: BaseDataset, prots = list(prot_counts.keys()) np.random.shuffle(prots) - #### Getting test set - count = 0 - test_prots = {} + #### Add manually selected proteins here + test_prots = test_prots if test_prots is not None else set() + # increment count by number of samples in test_prots + count = sum([prot_counts[p] for p in test_prots]) + + #### Sampling remaining proteins for test set (if we are under the test_size) for p in prots: # O(k); k = number of proteins - if count + prot_counts[p] <= test_size: - test_prots[p] = True - count += prot_counts[p] + if count + prot_counts[p] > test_size: + break + test_prots.add(p) + count += prot_counts[p] # looping through dataset to get indices for test test_indices = [i for i in range(dataset_size) if dataset[i]['prot_id'] in test_prots] @@ -198,6 +210,7 @@ def balanced_kfold_split(dataset: BaseDataset, prots = [p for p in prots if p not in test_prots] print(f'Number of unique proteins in test set: {len(test_prots)} == {count} samples') + ########## split remaining proteins into k_folds ########## # Steps for this basically follow Greedy Number Partitioning # tuple of (list of proteins, total weight, current-score): diff --git a/train_test.py b/train_test.py index a8c02d9..b662b52 100644 --- a/train_test.py +++ b/train_test.py @@ -46,7 +46,7 @@ cp_saver = CheckpointSaver(model=None, save_path=None, train_all=False, # forces full training - patience=50, min_delta=0.5) + patience=150, min_delta=0.5) # %% Training loop metrics = {}