Skip to content

Commit

Permalink
Merge pull request #92 from jyaacoub/development
Browse files Browse the repository at this point in the history
GVPL x aflow and aflow_ring3 implementation and results
  • Loading branch information
jyaacoub authored Apr 12, 2024
2 parents bd15c9d + 99c5316 commit e4a6f8d
Show file tree
Hide file tree
Showing 24 changed files with 1,572 additions and 231 deletions.
10 changes: 5 additions & 5 deletions docs/requirements_versions.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
numpy==1.23.5
pandas==1.5.3
tqdm==4.65.0
rdkit==2023.3.1
scipy==1.10.1

# for generating figures:
Expand All @@ -12,7 +11,7 @@ statannotations==0.6.0
lifelines==0.27.7 # used for concordance index calc

# model building
torch==2.0.1
torch==1.12.1
torch-geometric==2.3.1
transformers==4.36.0 # huggingface needed for esm

Expand All @@ -25,7 +24,8 @@ requests==2.31.0
#ray[tune]

submitit==1.4.5
ProDy==2.4.1

# for chemgpt
selfies==1.0.4
# For protein/ligand processing:
rdkit==2023.3.1
ProDy==2.4.1
selfies==1.0.4 # ChemGPT uses this
45 changes: 35 additions & 10 deletions playground.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,42 @@
#%%
from src.analysis.figures import prepare_df
# %%
import logging
from typing import OrderedDict

df = prepare_df()
import seaborn as sns
from matplotlib import pyplot as plt
from statannotations.Annotator import Annotator

from src.analysis.figures import prepare_df, custom_fig, fig_combined

df = prepare_df()
sel_dataset = 'PDBbind'
exclude = []
sel_col = 'cindex'
# %%
df[df.edge.str.contains('ring3') | df.edge.str.contains('aflow')]

#%%
from src.analysis.figures import fig5_edge_feat_violin
# models to plot:
# - Original model with (nomsa, binary) and (original, binary) features for protein and ligand respectively
# - Aflow models with (nomsa, aflow*) and (original, binary) # x2 models here (aflow and aflow_ring3)
# - GVP protein model (gvp, binary) and (original, binary)
# - GVP ligand model (nomsa, binary) and (gvp, binary)

models = {
'DG': ('nomsa', 'binary', 'original', 'binary'),
'aflow': ('nomsa', 'aflow', 'original', 'binary'),
'aflow_ring3': ('nomsa', 'aflow_ring3', 'original', 'binary'),
# 'gvpP': ('gvp', 'binary', 'original', 'binary'),
'gvpL': ('nomsa', 'binary', 'gvp', 'binary'),
'gvpL_aflow': ('nomsa', 'aflow', 'gvp', 'binary'),
'gvpL_aflow_rng3': ('nomsa', 'aflow_ring3', 'gvp', 'binary'),
}

# custom_fig(df, models, sel_dataset, sel_col)

# %%
fig, axes = fig_combined(df, datasets=['PDBbind'], fig_callable=custom_fig,
models=models, metrics=['pearson', 'cindex', 'mse', 'mae'],
fig_scale=(8,5))
plt.xticks(rotation=45)

fig5_edge_feat_violin(df, sel_dataset='PDBbind', exclude=['simple'],
sel_col='cindex', show=True, add_stats=False)
fig5_edge_feat_violin(df, sel_dataset='PDBbind', exclude=['simple'],
sel_col='mse', show=True, add_stats=False)

# %%
100 changes: 79 additions & 21 deletions rayTrain_Tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,19 @@ def train_func(config):
model = Loader.init_model(model=config["model"], pro_feature=config["feature_opt"],
pro_edge=config["edge_opt"],
# additional kwargs send to model class to handle
dropout=config["dropout"],
dropout_prot=config["dropout_prot"],
pro_emb_dim=config["pro_emb_dim"],
**config['architecture_kwargs']
)

# prepare model with rayTrain (moves it to correct device and wraps it in DDP)
model = ray.train.torch.prepare_model(model)
model = ray.train.torch.prepare_model(model, parallel_strategy='ddp',
parallel_strategy_kwargs={'find_unused_parameters':True})

# ============ Load dataset ==============
print("Loading Dataset")
loaders = Loader.load_DataLoaders(data=config['dataset'], pro_feature=config['feature_opt'],
edge_opt=config['edge_opt'],
ligand_feature=config['lig_feat_opt'],
ligand_edge=config['lig_edge_opt'],
path=cfg.DATA_ROOT,
batch_train=config['batch_size'],
datasets=['train', 'val'],
Expand All @@ -49,11 +50,15 @@ def train_func(config):
optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])
save_checkpoint = config.get("save_checkpoint", False)

for _ in range(config['epochs']):

# NOTE: no need to pass in device, rayTrain will handle that for us
simple_train(model, optimizer, loaders['train'], epochs=1) # Train the model
loss = simple_eval(model, loaders['val']) # Compute test accuracy
for _ in range(config['epochs']):
try:
# NOTE: no need to pass in device, rayTrain will handle that for us
simple_train(model, optimizer, loaders['train'], epochs=1) # Train the model
loss = simple_eval(model, loaders['val']) # Compute test accuracy
except RuntimeError as e: # potential memory error
print("RuntimeError:", e)
ray.train.report({"loss": 100})
break


# Report metrics (and possibly a checkpoint) to ray
Expand All @@ -64,7 +69,7 @@ def train_func(config):
torch.save(model.state_dict(), checkpoint_path)
checkpoint = Checkpoint.from_directory(checkpoint_dir)

ray.train.report({"loss": loss}, checkpoint=checkpoint)
ray.train.report({"loss": loss}, checkpoint=checkpoint)


if __name__ == "__main__":
Expand All @@ -73,32 +78,85 @@ def train_func(config):
print("Cuda support:", torch.cuda.is_available(),":",
torch.cuda.device_count(), "devices")
print("CUDA VERSION:", torch.__version__)
# ray.init(num_gpus=1, num_cpus=8, ignore_reinit_error=True)

search_space = {
## constants:
"epochs": 20,
"model": "RNG",
"dataset": "PDBbind",
"feature_opt": "nomsa", # NOTE: SPD requires foldseek features!!!
"edge_opt": "ring3",
"model": cfg.MODEL_OPT.DG,
"dataset": cfg.DATA_OPT.PDBbind,
"feature_opt": cfg.PRO_FEAT_OPT.nomsa,
"edge_opt": cfg.PRO_EDGE_OPT.aflow,
"lig_feat_opt": cfg.LIG_FEAT_OPT.original,
"lig_edge_opt": cfg.LIG_EDGE_OPT.binary,

"fold_selection": 0,
"save_checkpoint": False,

## hyperparameters to tune:
"lr": ray.tune.loguniform(1e-5, 1e-3),
"batch_size": ray.tune.choice([8, 16, 24]), # batch size is per GPU! #NOTE: multiply this by num_workers
"batch_size": ray.tune.choice([32, 64, 128]), # local batch size

# model architecture hyperparams
"dropout": ray.tune.uniform(0.0, 0.5), # for fc layers
"dropout_prot": ray.tune.uniform(0.0, 0.5),
"pro_emb_dim": ray.tune.choice([128, 256, 512]), # input from SaProt is 480 dims
"architecture_kwargs":{
"dropout": ray.tune.uniform(0.0, 0.5),
"output_dim": ray.tune.choice([128, 256, 512]),
}
}
# 'gvpL_aflow': ('nomsa', 'aflow', 'gvp', 'binary'):
# search_space = {
# ## constants:
# "epochs": 20,
# "model": cfg.MODEL_OPT.GVPL,
# "dataset": cfg.DATA_OPT.PDBbind,
# "feature_opt": cfg.PRO_FEAT_OPT.nomsa,
# "edge_opt": cfg.PRO_EDGE_OPT.aflow,
# "lig_feat_opt": cfg.LIG_FEAT_OPT.gvp,
# "lig_edge_opt": cfg.LIG_EDGE_OPT.binary,

# "fold_selection": 0,
# "save_checkpoint": False,

# ## hyperparameters to tune:
# "lr": ray.tune.loguniform(1e-5, 1e-3),
# "batch_size": ray.tune.choice([32, 64, 128]), # local batch size

# # model architecture hyperparams
# "architecture_kwargs":{
# "dropout": ray.tune.uniform(0.0, 0.5),
# "output_dim": ray.tune.choice([128, 256, 512]),
# }
# }
# search space for GVPL_RNG MODEL:
# search_space = {
# ## constants:
# "epochs": 20,
# "model": cfg.MODEL_OPT.GVPL_RNG,
# "dataset": cfg.DATA_OPT.PDBbind,
# "feature_opt": cfg.PRO_FEAT_OPT.nomsa,
# "edge_opt": cfg.PRO_EDGE_OPT.aflow_ring3,
# "lig_feat_opt": cfg.LIG_FEAT_OPT.gvp,
# "lig_edge_opt": cfg.LIG_EDGE_OPT.binary,
#
# "fold_selection": 0,
# "save_checkpoint": False,
#
# ## hyperparameters to tune:
# "lr": ray.tune.loguniform(1e-5, 1e-3),
# "batch_size": ray.tune.choice([16,32,64]), # local batch size
#
# # model architecture hyperparams
# "architecture_kwargs":{
# "dropout": ray.tune.uniform(0.0, 0.5),
# "pro_emb_dim": ray.tune.choice([64, 128, 256]),
# "output_dim": ray.tune.choice([128, 256, 512]),
# "nheads_pro": ray.tune.choice([3, 4, 5]),
# }
# }

# each worker is a node from the ray cluster.
# WARNING: SBATCH GPU directive should match num_workers*GPU_per_worker
# same for cpu-per-task directive
scaling_config = ScalingConfig(num_workers=2, # number of ray actors to launch to distribute compute across
scaling_config = ScalingConfig(num_workers=1, # number of ray actors to launch to distribute compute across
use_gpu=True, # default is for each worker to have 1 GPU (overrided by resources per worker)
resources_per_worker={"CPU": 2, "GPU": 1},
# trainer_resources={"CPU": 2, "GPU": 1},
Expand All @@ -118,7 +176,7 @@ def train_func(config):
search_alg=OptunaSearch(), # using ray.tune.search.Repeater() could be useful to get multiple trials per set of params
# would be even better if we could set trial-wise dependencies for a certain fold.
# https://github.com/ray-project/ray/issues/33677
num_samples=200,
num_samples=1000,
),
)

Expand Down
40 changes: 30 additions & 10 deletions results/model_media/model_stats.csv
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,33 @@ RNGM_PDBbind0D_nomsaF_ring3E_48B_0.0001071LR_0.3383D_2000E_originalLF_binaryLE,0
RNGM_PDBbind3D_nomsaF_ring3E_48B_0.0001071LR_0.3383D_2000E_originalLF_binaryLE,0.6845608518246677,0.5357292380549965,0.5262643639158158,2.8033910565957822,1.3118297254568652,1.674333018427273
RNGM_PDBbind1D_nomsaF_ring3E_48B_0.0001071LR_0.3383D_2000E_originalLF_binaryLE,0.6750296545462767,0.5185901583683697,0.5034912454720577,2.8088523469193527,1.2984564747465284,1.6759631102501489
RNGM_PDBbind4D_nomsaF_ring3E_48B_0.0001071LR_0.3383D_2000E_originalLF_binaryLE,0.6763045391457404,0.5057256652457955,0.5047101950511758,2.618698229612088,1.2698856009464514,1.6182392374467034
DGM_PDBbind0D_nomsaF_aflowE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.6426988052884589,0.406588134840554,0.4227348856560998,3.120454147810393,1.3707334409001868,1.7664807238717304
DGM_PDBbind1D_nomsaF_aflowE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.6337638489591875,0.3795519469266925,0.3967305500071522,3.1314639431203086,1.3655928022642796,1.7695942877169073
DGM_PDBbind2D_nomsaF_aflowE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.6513570892279613,0.4305515739847962,0.4442416627474518,2.9943118695487394,1.3387308476755284,1.730408006670317
DGM_PDBbind3D_nomsaF_aflowE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.6310054921700599,0.3776573127064433,0.3890187084575129,3.182500500059491,1.3841141298313269,1.7839564176457594
DGM_PDBbind4D_nomsaF_aflowE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.6474083717941423,0.41884297350678834,0.4353919583788023,3.031050776147684,1.3427042543754748,1.7409913199518496
RNGM_PDBbind0D_nomsaF_aflow_ring3E_48B_0.0001071LR_0.3383D_2000E_originalLF_binaryLE,0.6653959676484489,0.4594468871141732,0.4801002426988867,4.154945171290107,1.568700380900933,2.038368261941425
RNGM_PDBbind1D_nomsaF_aflow_ring3E_48B_0.0001071LR_0.3383D_2000E_originalLF_binaryLE,0.6448596880768305,0.3998129142551612,0.4190505739131763,3.739009794564509,1.482601251936172,1.9336519321130443
RNGM_PDBbind2D_nomsaF_aflow_ring3E_48B_0.0001071LR_0.3383D_2000E_originalLF_binaryLE,0.679906458717887,0.4976193939048351,0.5214740426536342,3.8101620455639575,1.4884788046650663,1.951963638381606
RNGM_PDBbind3D_nomsaF_aflow_ring3E_48B_0.0001071LR_0.3383D_2000E_originalLF_binaryLE,0.6433259320688763,0.4039176981504933,0.4161141374386168,3.6818319560241206,1.4657820598201496,1.918810036461171
RNGM_PDBbind4D_nomsaF_aflow_ring3E_48B_0.0001071LR_0.3383D_2000E_originalLF_binaryLE,0.6795698349428855,0.5029467437374481,0.5170586141766542,3.798684555382926,1.4859081403686818,1.9490214353318247
DGM_PDBbind0D_nomsaF_aflowE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.7138386406186538,0.632161676666746,0.5969379768993622,2.305938375785052,1.20129866702216,1.5185316512292564
DGM_PDBbind1D_nomsaF_aflowE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.7071928697529236,0.6119312740868981,0.5825821216776425,2.227630691345381,1.1882162758282253,1.492524938265817
DGM_PDBbind2D_nomsaF_aflowE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.7131792416345327,0.6300842144746225,0.5973837933267879,2.1487191174387785,1.156409247072916,1.4658509874604508
DGM_PDBbind3D_nomsaF_aflowE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.660154663447613,0.477928410118703,0.4613848634220896,2.752867466543676,1.3047904886139765,1.6591767436122278
DGM_PDBbind4D_nomsaF_aflowE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.6849966099610509,0.5521829232190218,0.528410191518849,2.6060140340268965,1.2870582908675785,1.614315345286322
RNGM_PDBbind0D_nomsaF_aflow_ring3E_48B_0.0001071LR_0.3383D_2000E_originalLF_binaryLE,0.705837663251816,0.5925769047757962,0.5819610003983905,2.305359361226593,1.1903962990215846,1.518340989773573
RNGM_PDBbind1D_nomsaF_aflow_ring3E_48B_0.0001071LR_0.3383D_2000E_originalLF_binaryLE,0.6960348680473861,0.5877518727647189,0.5567275097196549,2.3528305312600524,1.195004482080066,1.5338939113446055
RNGM_PDBbind2D_nomsaF_aflow_ring3E_48B_0.0001071LR_0.3383D_2000E_originalLF_binaryLE,0.7174010132899236,0.6351541961659758,0.6076186024009358,2.095580516039397,1.1347450122757563,1.447612004661262
RNGM_PDBbind3D_nomsaF_aflow_ring3E_48B_0.0001071LR_0.3383D_2000E_originalLF_binaryLE,0.6987873531320238,0.5978688522768106,0.5608647772334816,2.3233468636850785,1.201160574375637,1.5242528870515806
RNGM_PDBbind4D_nomsaF_aflow_ring3E_48B_0.0001071LR_0.3383D_2000E_originalLF_binaryLE,0.7026114627006717,0.6058195497309069,0.5730322590850676,2.262839648052509,1.184151166999151,1.5042737942450866
GVPM_PDBbind0D_gvpF_binaryE_48B_0.0001351LR_0.28157D_2000E_originalLF_binaryLE,0.6947350202078785,0.5498226711352023,0.5564602646332666,2.90614029644558,1.3244741636008794,1.7047405364000645
GVPM_PDBbind1D_gvpF_binaryE_48B_0.0001351LR_0.28157D_2000E_originalLF_binaryLE,0.6978109902923234,0.5434204593557842,0.5598089615551383,2.993091177534504,1.3343549799655197,1.730055252740358
GVPM_PDBbind2D_gvpF_binaryE_48B_0.0001351LR_0.28157D_2000E_originalLF_binaryLE,0.6936707284856692,0.5513294802303081,0.5552863808764578,2.7897382178391164,1.295975758580704,1.6702509445706404
GVPM_PDBbind3D_gvpF_binaryE_48B_0.0001351LR_0.28157D_2000E_originalLF_binaryLE,0.6974192823546201,0.5509073528693347,0.5676511207869644,3.0212262633577,1.365750774832816,1.738167501525011
GVPM_PDBbind4D_gvpF_binaryE_48B_0.0001351LR_0.28157D_2000E_originalLF_binaryLE,0.6949703486200646,0.5562085550041075,0.5596365374441437,2.736305140729789,1.281308484546666,1.6541780861593436
GVPLM_PDBbind2D_nomsaF_binaryE_128B_0.0002LR_0.2D_2000E_gvpLF_binaryLE,0.68008178345592,0.5242173198194416,0.517237612030569,2.76043080608376,1.3091822563868754,1.661454424919251
GVPLM_PDBbind0D_nomsaF_binaryE_128B_0.0002LR_0.2D_2000E_gvpLF_binaryLE,0.6910246829753094,0.5407443773715759,0.5474605079379407,2.712591784676441,1.2831902848911754,1.6469947737246895
GVPLM_PDBbind4D_nomsaF_binaryE_128B_0.0002LR_0.2D_2000E_gvpLF_binaryLE,0.6824302128418234,0.5299776653987239,0.5237527633262151,2.7502360319385275,1.2891588426970615,1.658383559957867
GVPLM_PDBbind3D_nomsaF_binaryE_128B_0.0002LR_0.2D_2000E_gvpLF_binaryLE,0.698694347610034,0.5674191707227977,0.5711326873928373,2.610268729668935,1.2641291124574372,1.6156326097442248
GVPLM_PDBbind1D_nomsaF_binaryE_128B_0.0002LR_0.2D_2000E_gvpLF_binaryLE,0.693415597273379,0.5451562579085044,0.5519121096465932,2.697724061805882,1.271342821263446,1.6424749805722711
GVPL_RNGM_PDBbind0D_nomsaF_aflow_ring3E_64B_0.000511787LR_0.27768D_2000E_gvpLF_binaryLE,0.7094741880607248,0.5879563811311338,0.5900623863392058,2.713032467763912,1.275113668311671,1.6471285522884704
GVPL_RNGM_PDBbind3D_nomsaF_aflow_ring3E_64B_0.000511787LR_0.27768D_2000E_gvpLF_binaryLE,0.7193853129709765,0.6170283320796139,0.6158246125125828,2.312149866619802,1.1534027675838034,1.5205755050703014
GVPL_RNGM_PDBbind4D_nomsaF_aflow_ring3E_64B_0.000511787LR_0.27768D_2000E_gvpLF_binaryLE,0.7083567810917276,0.6202003017018763,0.5935502039649166,2.535679235274617,1.2507571397245232,1.592381623630032
GVPL_RNGM_PDBbind1D_nomsaF_aflow_ring3E_64B_0.000511787LR_0.27768D_2000E_gvpLF_binaryLE,0.7242996397130641,0.6306614897395052,0.628684311009208,2.4235628205128066,1.202591815348809,1.5567796313264144
GVPL_RNGM_PDBbind2D_nomsaF_aflow_ring3E_64B_0.000511787LR_0.27768D_2000E_gvpLF_binaryLE,0.7268918298307233,0.6369287259575389,0.6288642313267695,2.528205215340215,1.217288659384408,1.590033086240728
GVPLM_PDBbind1D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE,0.7109736936213868,0.5959134582585407,0.5958003797210757,2.277362268673046,1.1636958009856089,1.509093194164312
GVPLM_PDBbind0D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE,0.7058942987473846,0.6291223362210724,0.5803380752439791,2.180345262170221,1.1609475984649054,1.4765992219184667
GVPLM_PDBbind2D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE,0.7194491955332394,0.6365982015372369,0.6153139418604331,2.098544323479561,1.1343965978092618,1.4486353314342295
GVPLM_PDBbind4D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE,0.7203456545202407,0.6299801928832721,0.6142370717682227,2.207161887712393,1.160070003592779,1.485652007608913
GVPLM_PDBbind3D_nomsaF_aflowE_128B_0.00022659LR_0.02414D_2000E_gvpLF_binaryLE,0.709231747593396,0.5876231909596512,0.5904267672441162,2.3008469727213514,1.1705255782036554,1.5168543017446836
Loading

0 comments on commit e4a6f8d

Please sign in to comment.