Commit

Merge pull request #67 from jyaacoub/development

Development

jyaacoub committed Dec 18, 2023
2 parents 3642588 + c1137a6 commit 94b3048
Showing 64 changed files with 1,187,408 additions and 205 deletions.
42 changes: 34 additions & 8 deletions playground.py
@@ -1,11 +1,37 @@
# %%
from matplotlib import pyplot as plt
from src.data_analysis.figures import fig0_dataPro_overlap
#%%
import torch

from src.utils.loader import Loader
from src.train_test.utils import debug
from src.utils import config as cfg
from torch_geometric.utils import dropout_edge, dropout_node

#%%
for data in ['kiba', 'davis']:
fig0_dataPro_overlap(data=data)
plt.savefig(f'results/figures/fig0_pro_overlap_{data}.png', dpi=300, bbox_inches='tight')
plt.clf()
device = torch.device('cuda:0')  # hardcoded to GPU 0; use 'cuda:0' if torch.cuda.is_available() else 'cpu' for a CPU fallback

MODEL, DATA = 'SPD', 'davis'
FEATURE, EDGEW = 'foldseek', 'binary'
ligand_feature, ligand_edge = None, None
BATCH_SIZE = 20
fold = 0
pro_overlap = False
DROPOUT = 0.2


# ==== LOAD DATA ====
loaders = Loader.load_DataLoaders(data=DATA, pro_feature=FEATURE, edge_opt=EDGEW, path=cfg.DATA_ROOT,
ligand_feature=ligand_feature, ligand_edge=ligand_edge,
batch_train=BATCH_SIZE,
datasets=['train'],
training_fold=fold,
protein_overlap=pro_overlap)


# ==== LOAD MODEL ====
print(f'#Device: {device}')
model = Loader.init_model(model=MODEL, pro_feature=FEATURE, pro_edge=EDGEW, dropout=DROPOUT,
ligand_feature=ligand_feature, ligand_edge=ligand_edge).to(device)


# %%
train, eval = debug(model, loaders['train'], device=device)
# %%
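The new imports above pull in `dropout_edge` and `dropout_node` from PyTorch Geometric, but the snippet does not yet show them in use. Below is a minimal sketch of how they could be applied to a batch from the training loader; it assumes the batch exposes a standard PyG `edge_index` attribute, and the 0.2 drop rate is illustrative only.

# %% sketch: graph augmentation with the already-imported dropout_edge / dropout_node
batch = next(iter(loaders['train'])).to(device)

# randomly drop 20% of edges; returns the thinned edge_index and a boolean mask over edges
edge_index, edge_mask = dropout_edge(batch.edge_index, p=0.2, training=True)

# randomly drop nodes together with their incident edges; also returns a node mask
edge_index, edge_mask, node_mask = dropout_node(batch.edge_index, p=0.2, training=True)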
124 changes: 124 additions & 0 deletions rayTrain_Tune.py
@@ -0,0 +1,124 @@
# Hyperparameter tuning script using Ray Tune together with Ray Train's TorchTrainer,
# so each trial can run distributed (DDP) training across multiple workers.

import random
import os
import tempfile

import torch

import ray
from ray.air import session # this session just comes from train._internal.session._session
from ray.train import ScalingConfig, Checkpoint
from ray.train.torch import TorchCheckpoint, TorchTrainer
from ray.tune.search.optuna import OptunaSearch


from src.utils.loader import Loader
from src.train_test.simple import simple_train, simple_eval
from src.utils import config as cfg

def train_func(config):
# ============ Init Model ==============
    model = Loader.init_model(model=config["model"], pro_feature=config["feature_opt"],
                              pro_edge=config["edge_opt"],
                              # additional kwargs sent to the model class to handle
                              dropout=config["dropout"],
                              dropout_prot=config["dropout_prot"],
                              pro_emb_dim=config["pro_emb_dim"],
                              extra_profc_layer=config["extra_profc_layer"])

# prepare model with rayTrain (moves it to correct device and wraps it in DDP)
model = ray.train.torch.prepare_model(model)

# ============ Load dataset ==============
print("Loading Dataset")
loaders = Loader.load_DataLoaders(data=config['dataset'], pro_feature=config['feature_opt'],
edge_opt=config['edge_opt'],
path=cfg.DATA_ROOT,
batch_train=config['batch_size'],
datasets=['train', 'val'],
training_fold=config['fold_selection'])

# prepare dataloaders with rayTrain (adds DistributedSampler and moves to correct device)
for k in loaders.keys():
loaders[k] = ray.train.torch.prepare_data_loader(loaders[k])


# ============= Simple training and eval loop =====================
optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])
save_checkpoint = config.get("save_checkpoint", False)

for _ in range(config['epochs']):

# NOTE: no need to pass in device, rayTrain will handle that for us
simple_train(model, optimizer, loaders['train'], epochs=1) # Train the model
        loss = simple_eval(model, loaders['val'])  # compute validation loss


# Report metrics (and possibly a checkpoint) to ray
        checkpoint = None
        if save_checkpoint:
            # use a fresh temp dir so only the checkpoint file gets bundled, not the whole system temp dir
            checkpoint_dir = tempfile.mkdtemp()
            checkpoint_path = os.path.join(checkpoint_dir, "model.checkpoint")
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint = Checkpoint.from_directory(checkpoint_dir)

ray.train.report({"loss": loss}, checkpoint=checkpoint)


if __name__ == "__main__":
print("DATA_ROOT:", cfg.DATA_ROOT)
print("os.environ['TRANSFORMERS_CACHE']", os.environ['TRANSFORMERS_CACHE'])
print("Cuda support:", torch.cuda.is_available(),":",
torch.cuda.device_count(), "devices")
    print("CUDA VERSION:", torch.version.cuda)
# ray.init(num_gpus=1, num_cpus=8, ignore_reinit_error=True)

search_space = {
## constants:
"epochs": 10,
"model": "EDI",
"dataset": "davis",
"feature_opt": "nomsa",
"edge_opt": "binary",
"fold_selection": 0,
"save_checkpoint": False,

## hyperparameters to tune:
"lr": ray.tune.loguniform(1e-4, 1e-2),
"batch_size": ray.tune.choice([16, 32, 48]), # batch size is per GPU!?

# model architecture hyperparams
"dropout": ray.tune.uniform(0, 0.5), # for fc layers
"dropout_prot": ray.tune.uniform(0, 0.5),
"pro_emb_dim": ray.tune.choice([480, 512, 1024]), # input from SaProt is 480 dims
"extra_profc_layer": ray.tune.choice([True, False])
}

    # each worker is a Ray actor (one per GPU here); workers may be placed on one or more cluster nodes.
    # WARNING: the SBATCH GPU directive should match num_workers * GPUs_per_worker
scaling_config = ScalingConfig(num_workers=4, # number of ray actors to launch to distribute compute across
                                   use_gpu=True,  # default is for each worker to have 1 GPU (overridden by resources_per_worker)
resources_per_worker={"CPU": 2, "GPU": 1},
# trainer_resources={"CPU": 2, "GPU": 1},
# placement_strategy="PACK", # place workers on same node
)

print('init Tuner')
tuner = ray.tune.Tuner(
TorchTrainer(train_func),
param_space={
"train_loop_config": search_space,
"scaling_config": scaling_config
},
tune_config=ray.tune.TuneConfig(
metric="loss",
mode="min",
search_alg=OptunaSearch(), # using ray.tune.search.Repeater() could be useful to get multiple trials per set of params
# would be even better if we could set trial-wise dependencies for a certain fold.
# https://github.com/ray-project/ray/issues/33677
num_samples=50,
),
)

results = tuner.fit()
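Once `tuner.fit()` returns, the best trial can be read back from the returned ResultGrid. A minimal sketch, assuming the metric name matches the `loss` reported in `train_func`; note that with a TorchTrainer the hyperparameters sit under the nested `train_loop_config` key:

best = results.get_best_result(metric="loss", mode="min")
print("best val loss:", best.metrics["loss"])
best_hparams = best.config["train_loop_config"]
print("best lr:", best_hparams["lr"], "| batch size:", best_hparams["batch_size"])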
77 changes: 77 additions & 0 deletions raytune.py
@@ -0,0 +1,77 @@
# This is a simple tuning script for the raytune library.
# No support for distributed training in this file.

import torch
from ray.air import session
from ray.train.torch import TorchCheckpoint
from ray import tune
from ray.tune.search.optuna import OptunaSearch


from src.utils.loader import Loader
from src.train_test.simple import simple_train, simple_eval
from src.utils import config as cfg

resources = {"cpu":6, "gpu": 2} # NOTE: must match SBATCH directives

def objective(config):
save_checkpoint = config.get("save_checkpoint", False)
loaders = Loader.load_DataLoaders(data=config['dataset'], pro_feature=config['feature_opt'],
edge_opt=config['edge_opt'],
path=cfg.DATA_ROOT,
batch_train=config['batch_size'],
datasets=['train', 'val'],
training_fold=config['fold_selection'])

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = Loader.init_model(model=config["model"], pro_feature=config["feature_opt"],
pro_edge=config["edge_opt"], dropout=config["dropout"]
# WARNING: no ligand features for now
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])

for _ in range(config["epochs"]):
simple_train(model, optimizer, loaders['train'],
device=device,
epochs=1) # Train the model
loss = simple_eval(model, loaders['val'],
                           device=device)  # compute validation loss

checkpoint = None
if save_checkpoint:
checkpoint = TorchCheckpoint.from_state_dict(model.state_dict())

# Report metrics (and possibly a checkpoint) to Tune
session.report({"loss": loss}, checkpoint=checkpoint)

algo = OptunaSearch()
search_space = {
# constants:
"epochs": 15, # 15 epochs
"model": "DG",
"dataset": "davis",
"feature_opt": "nomsa",
"edge_opt": "binary",
"fold_selection": 0,
"save_checkpoint": False,

# hyperparameters to tune:
"lr": tune.loguniform(1e-4, 1e-2),
"dropout": tune.uniform(0, 0.5),
"batch_size": tune.choice([16, 32, 64, 128]),
}

tuner = tune.Tuner(
tune.with_resources(objective, resources=resources),
tune_config=tune.TuneConfig(
metric="loss",
mode="min",
search_alg=algo,
num_samples=50,
),
param_space=search_space,
)

results = tuner.fit()
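Each trial above runs all 15 epochs regardless of how it is doing; an early-stopping scheduler such as ASHA could be combined with the Optuna search to terminate weak trials sooner. A sketch of the alternative TuneConfig (the ASHAScheduler values are illustrative, not tuned):

from ray.tune.schedulers import ASHAScheduler

tune_config = tune.TuneConfig(
    metric="loss",
    mode="min",
    search_alg=OptunaSearch(),
    # uses the per-epoch session.report calls as the time axis; stops clearly-bad trials after a few epochs
    scheduler=ASHAScheduler(max_t=15, grace_period=3),
    num_samples=50,
)

This would be passed as `tune_config=tune_config` in place of the inline TuneConfig above.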
134 changes: 134 additions & 0 deletions raytune_DDP.py
@@ -0,0 +1,134 @@
# Hyperparameter tuning script for Ray Tune with manual DDP: each trial spawns one
# process per GPU on a single node (no inter-node distribution).

import random, os, socket, time
import torch

from torch import nn
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os

import ray
from ray import tune, train
from ray.air import session # this session just comes from train._internal.session._session
from ray.train.torch import TorchCheckpoint
from ray.tune.search.optuna import OptunaSearch


from src.utils.loader import Loader
from src.train_test.simple import simple_train, simple_eval
from src.utils import config as cfg

def main(rank, world_size, config): # define this inside objective??
# ============= Set up DDP environment =====================
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = config['port']
init_process_group(backend="nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)

device = torch.device(rank)

# ============ Load up distributed training data ==============
p_grp = session.get_trial_resources() # returns placement group object from TrialInfo
# first item is resource list, in our case resources are the same across all trials
# so it is safe to just take the first from the list to get our gpu count
trial_resources = p_grp._bound.args[0][0]
ncpus = trial_resources['cpu']

    local_bs = config['global_batch_size'] / world_size
    if not local_bs.is_integer():
        print(f'WARNING: global batch size is not divisible by world size; truncating local batch size to {int(local_bs)}.')

    local_bs = int(local_bs)

loaders = Loader.load_distributed_DataLoaders(
num_replicas=world_size, rank=rank, seed=42, # DDP specific params

data=config['dataset'],
pro_feature=config['feature_opt'],
edge_opt=config['edge_opt'],
batch_train=local_bs, # global_bs/world_size
datasets=['train', 'val'],
training_fold=config['fold_selection'],
num_workers=ncpus, # number of subproc used for data loading
)

# ============ Init Model ==============
model = Loader.init_model(model=config["model"], pro_feature=config["feature_opt"],
pro_edge=config["edge_opt"], dropout=config["dropout"]
).to(device)

model = nn.SyncBatchNorm.convert_sync_batchnorm(model) # use if model contains batchnorm.
model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])

torch.distributed.barrier() # Sync params across GPUs

# ============ Train Model for n epochs ============
optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])
save_checkpoint = config.get("save_checkpoint", False)

for _ in range(config["epochs"]):
torch.distributed.barrier()
simple_train(model, optimizer, loaders['train'],
device=device,
epochs=1) # one epoch

torch.distributed.barrier()
        loss = simple_eval(model, loaders['val'], device)  # compute validation loss

checkpoint = None
if save_checkpoint and rank == 0:
checkpoint = TorchCheckpoint.from_state_dict(model.state_dict())

# Report metrics (and possibly a checkpoint) to Tune
session.report({"mean_loss": loss}, checkpoint=checkpoint)

destroy_process_group()


def objective_DDP(config): # NO inter-node distribution due to communication difficulties
world_size = torch.cuda.device_count()
# device is also inserted as the first arg to main()
print(f'World size: {world_size}')
mp.spawn(main, args=(world_size, config,), nprocs=world_size)


if __name__ == "__main__":
search_space = {
# constants:
"epochs": 10, # 15 epochs
"model": "DG",
"dataset": "davis",
"feature_opt": "nomsa",
"edge_opt": "binary",
"fold_selection": 0,
"save_checkpoint": False,

# DDP specific constants:
"port": random.randint(49152,65535),

# hyperparameters to tune:
"lr": tune.loguniform(1e-4, 1e-2),
"dropout": tune.uniform(0, 0.5),
"embedding_dim": tune.choice([64, 128, 256]),

"global_batch_size": tune.choice([16, 32, 48]), # global batch size is divided by ngpus/world_size
}

ray.init(num_gpus=1, num_cpus=8, ignore_reinit_error=True)

tuner = tune.Tuner(
tune.with_resources(objective_DDP, resources={"cpu": 6, "gpu": 2}),
param_space=search_space,
tune_config=tune.TuneConfig(
metric="mean_loss",
mode="min",
search_alg=OptunaSearch(),
num_samples=50,
),
)

results = tuner.fit()
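`MASTER_PORT` is picked with `random.randint`, which can occasionally collide with a port that is already bound on the node; the otherwise-unused `socket` import hints at the usual remedy. A sketch of a hypothetical `find_free_port` helper (not part of the repo) that lets the OS choose an open port:

def find_free_port() -> str:
    # bind to port 0 so the OS assigns an unused port, then hand it to DDP as MASTER_PORT
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return str(s.getsockname()[1])

# in the search space: "port": find_free_port(),

A small race window remains between probing the port and DDP binding it, but this avoids blind collisions with ports already in use.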
Binary file added results/figures/davis_kinaseFamilies.png
11 changes: 10 additions & 1 deletion results/model_media/model_stats.csv
@@ -127,4 +127,13 @@ EDIM_kiba4D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E,0.7554049750038857,0.72048816
EDIM_kiba0D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E,0.7470462628129559,0.719261877780631,0.6123189591738586,0.3581329716211184,0.4034757009810871,0.5984421205272223
EDIM_kiba1D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E,0.7245467124961595,0.6561356476619974,0.562435629479892,0.4278950398199864,0.4566482743879321,0.6541368662749306
EDIM_kiba2D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E,0.7652894580828917,0.7295933180985272,0.6555127509611349,0.3505783094362673,0.3984159894469308,0.5920965372608316
EDIM_kiba3D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E,0.7524012037760042,0.7053345660980792,0.6236428370262223,0.3743329274493446,0.41874673207439383,0.6118275308036936
EDIM_kiba3D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E,0.7524012037760042,0.7053345660980792,0.6236428370262223,0.3743329274493446,0.4187467320743938,0.6118275308036936
SPDM_davis0D_foldseekF_binaryE_48B_0.0001LR_0.4D_2000E,0.8525156503393888,0.7659447551090673,0.6581153879764675,0.3708225026029861,0.338606736402342,0.6089519706865116
SPDM_davis1D_foldseekF_binaryE_48B_0.0001LR_0.4D_2000E,0.8481255372327631,0.7711495477512904,0.6503928717525675,0.3643582525326579,0.3295057302623702,0.6036209510385288
SPDM_davis2D_foldseekF_binaryE_48B_0.0001LR_0.4D_2000E,0.8580614874482597,0.7742503824348432,0.6651237962212082,0.3632413843209577,0.3079463243484497,0.602695100627969
SPDM_davis4D_foldseekF_binaryE_48B_0.0001LR_0.4D_2000E,0.861350839182363,0.7729691181533276,0.6731323529056853,0.3606378936494832,0.3228092759480718,0.6005313427702864
DGM_davis0D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E,0.822423721276364,0.6836410200968875,0.5905873651981737,0.4237553412219658,0.3582645132831933,0.6509649308695252
DGM_davis3D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E,0.8227123189458018,0.6860915819553239,0.5912681095280419,0.4464863570217145,0.3584857518630543,0.6681963461601047
DGM_davis4D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E,0.8408179656586924,0.7432233299406013,0.619193221562032,0.3618264547363151,0.3197961523823144,0.6015201199763106
DGM_davis2D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E,0.8289943097251969,0.6969761208616752,0.6016205091333046,0.4075082262256328,0.3545933397621854,0.6383637099848587
DGM_davis1D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E,0.8267878534867681,0.69467191933228,0.5980002891045212,0.4307114257343926,0.3479064531756817,0.6562860852817105