diff --git a/raytune.py b/raytune.py
index fb69870d..ba7f145d 100644
--- a/raytune.py
+++ b/raytune.py
@@ -9,9 +9,10 @@
 from src.utils.loader import Loader
-from src.train_test.tune import simple_train, simple_test
+from src.train_test.tune import simple_train, simple_eval
 from src.utils import config as cfg
 
+resources = {"cpu":6, "gpu": 4} # NOTE: must match SBATCH directives
 
 def objective(config):
     save_checkpoint = config.get("save_checkpoint", False)
@@ -35,7 +36,7 @@ def objective(config):
     simple_train(model, optimizer, loaders['train'],
                  device=device,
                  epochs=1) # Train the model
-    loss = simple_test(model, loaders['val'],
+    loss = simple_eval(model, loaders['val'],
                        device=device) # Compute test accuracy
 
     checkpoint = None
@@ -46,7 +47,6 @@ def objective(config):
     session.report({"mean_loss": loss}, checkpoint=checkpoint)
 
 algo = OptunaSearch()
-# algo = ConcurrencyLimiter(algo, max_concurrent=4)
 search_space = {
     # constants:
     "epochs": 15, # 15 epochs
@@ -64,7 +64,7 @@ def objective(config):
 }
 
 tuner = tune.Tuner(
-    tune.with_resources(objective, resources={"cpu": 6, "gpu": 1}), # NOTE: must match SBATCH directives
+    tune.with_resources(objective, resources=resources),
     tune_config=tune.TuneConfig(
         metric="mean_loss",
         mode="min",
diff --git a/raytune_DDP.py b/raytune_DDP.py
new file mode 100644
index 00000000..8b109153
--- /dev/null
+++ b/raytune_DDP.py
@@ -0,0 +1,132 @@
+# This is a simple tuning script for the raytune library with DDP support.
+# Distributed training is limited to a single node (no inter-node distribution).
+
+import random, os, socket
+import torch
+
+from torch import nn
+import torch.multiprocessing as mp
+from torch.utils.data.distributed import DistributedSampler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.distributed import init_process_group, destroy_process_group
+
+import ray
+from ray import tune, train
+from ray.air import session # this session just comes from train._internal.session._session
+from ray.train.torch import TorchCheckpoint
+from ray.tune.search.optuna import OptunaSearch
+
+
+from src.utils.loader import Loader
+from src.train_test.tune import simple_train, simple_eval
+from src.utils import config as cfg
+
+
+def main(rank, world_size, config):
+    # ============= Set up DDP environment =====================
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = str(config['port']) # env vars must be strings
+    init_process_group(backend="nccl", rank=rank, world_size=world_size)
+    torch.cuda.set_device(rank)
+
+    device = torch.device(rank)
+
+    # ============ Load up distributed training data ==============
+    p_grp = session.get_trial_resources() # returns placement group object from TrialInfo
+    # first item is the resource list; in our case resources are the same across all trials,
+    # so it is safe to just take the first from the list to get our cpu count
+    trial_resources = p_grp._bound.args[0][0]
+    ncpus = trial_resources['cpu']
+
+    local_bs = config['global_batch_size']/world_size
+    if not local_bs.is_integer():
+        print(f'WARNING: batch size is not divisible by world size. '
+              f'Local batch size is {local_bs}.')
+
+    local_bs = int(local_bs)
+
+    loaders = Loader.load_distributed_DataLoaders(
+        num_replicas=world_size, rank=rank, seed=42, # DDP specific params
+
+        data=config['dataset'],
+        pro_feature=config['feature_opt'],
+        edge_opt=config['edge_opt'],
+        batch_train=local_bs, # global_bs/world_size
+        datasets=['train', 'val'],
+        training_fold=config['fold_selection'],
+        num_workers=ncpus, # number of subprocesses used for data loading
+    )
+
+    # ============ Init Model ==============
+    model = Loader.init_model(model=config["model"], pro_feature=config["feature_opt"],
+                              pro_edge=config["edge_opt"], dropout=config["dropout"]
+                              ).to(device)
+
+    model = nn.SyncBatchNorm.convert_sync_batchnorm(model) # use if model contains batchnorm
+    model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])
+
+    torch.distributed.barrier() # sync params across GPUs
+
+    # ============ Train Model for n epochs ============
+    optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])
+    save_checkpoint = config.get("save_checkpoint", False)
+
+    for _ in range(config["epochs"]):
+        torch.distributed.barrier()
+        simple_train(model, optimizer, loaders['train'],
+                     device=device,
+                     epochs=1) # one epoch
+
+        torch.distributed.barrier()
+        loss = simple_eval(model, loaders['val'], device) # compute validation loss
+
+        checkpoint = None
+        if save_checkpoint and rank == 0:
+            checkpoint = TorchCheckpoint.from_state_dict(model.state_dict())
+
+        # Report metrics (and possibly a checkpoint) to Tune
+        session.report({"mean_loss": loss}, checkpoint=checkpoint)
+
+    destroy_process_group()
+
+
+def objective_DDP(config): # NO inter-node distribution due to communication difficulties
+    world_size = torch.cuda.device_count()
+    # mp.spawn inserts the process index (rank) as the first arg to main()
+    mp.spawn(main, args=(world_size, config,), nprocs=world_size)
+
+
+algo = OptunaSearch()
+
+search_space = {
+    # constants:
+    "epochs": 10, # 10 epochs
+    "model": "DG",
+    "dataset": "davis",
+    "feature_opt": "nomsa",
+    "edge_opt": "binary",
+    "fold_selection": 0,
+    "save_checkpoint": False,
+
+    # DDP specific constants:
+    "port": random.randint(49152,65535),
+
+    # hyperparameters to tune:
+    "lr": tune.loguniform(1e-4, 1e-2),
+    "dropout": tune.uniform(0, 0.5),
+    "embedding_dim": tune.choice([64, 128, 256]),
+
+    "global_batch_size": tune.choice([16, 32, 48]), # global batch size is divided by ngpus/world_size
+}
+
+tuner = tune.Tuner(
+    tune.with_resources(objective_DDP, resources={"cpu": 4, "gpu": 2}),
+    param_space=search_space,
+    tune_config=tune.TuneConfig(
+        metric="mean_loss",
+        mode="min",
+        search_alg=algo,
+        num_samples=50,
+    ),
+)
+
+results = tuner.fit()
\ No newline at end of file
diff --git a/src/train_test/tune.py b/src/train_test/tune.py
index 10caf821..d9d5cb4f 100644
--- a/src/train_test/tune.py
+++ b/src/train_test/tune.py
@@ -54,7 +54,7 @@ def simple_train(model: BaseModel, optimizer:torch.optim.Optimizer,
             optimizer.step()
 
 
-def simple_test(model:BaseModel, test_loader:DataLoader, device:torch.device,
+def simple_eval(model:BaseModel, data_loader:DataLoader, device:torch.device,
                 CRITERION:torch.nn.Module=None) -> float:
     """
     Run inference on the test set.
@@ -81,7 +81,7 @@ def simple_test(model:BaseModel, test_loader:DataLoader, device:torch.device,
     CRITERION = CRITERION or torch.nn.MSELoss()
 
     with torch.no_grad():
-        for data in test_loader:
+        for data in data_loader:
             batch_pro = data['protein'].to(device)
             batch_mol = data['ligand'].to(device)
             labels = data['y'].reshape(-1,1).to(device)
@@ -94,6 +94,6 @@ def simple_test(model:BaseModel, test_loader:DataLoader, device:torch.device,
             test_loss += loss.item()
 
     # Compute average test loss
-    test_loss /= len(test_loader)
+    test_loss /= len(data_loader)
 
     return test_loss
\ No newline at end of file
diff --git a/src/train_test/utils.py b/src/train_test/utils.py
index 81349070..b082e6db 100644
--- a/src/train_test/utils.py
+++ b/src/train_test/utils.py
@@ -421,17 +421,9 @@ def debug(model: BaseModel, data_loader:DataLoader,
     return train, eval
 
 ## ==================== Distributed Training ==================== ##
-def handle_sigusr1(signum, frame):
-    # requeues the job
-    os.system(f'scontrol requeue {os.getenv("SLURM_JOB_ID")}')
-    exit()
-
 def init_node(args):
     args.ngpus_per_node = torch.cuda.device_count()
 
-    # requeue job on SLURM preemption
-    signal.signal(signal.SIGUSR1, handle_sigusr1)
-
     # find the common host name on all nodes
     cmd = 'scontrol show hostnames ' + os.getenv('SLURM_JOB_NODELIST')
     stdout = subprocess.check_output(cmd.split())
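With the explicit ConcurrencyLimiter dropped from raytune.py, trial concurrency is bounded only by the per-trial resources request. A minimal sketch (illustrative values, assuming Ray is initialized inside the same SLURM allocation) of estimating how many trials can run in parallel from that request:

import ray

ray.init()  # picks up the CPUs/GPUs visible in the current allocation
cluster_gpus = ray.cluster_resources().get("GPU", 0)
per_trial_gpus = 4  # matches resources = {"cpu":6, "gpu": 4} in raytune.py
max_concurrent_trials = int(cluster_gpus // per_trial_gpus)
print(f"At most {max_concurrent_trials} trial(s) will run concurrently.")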