feat(rayTune_DDP): init RayTune env for DDP #68
jyaacoub committed Dec 8, 2023
1 parent ac971ed commit 67370ae
Showing 4 changed files with 139 additions and 15 deletions.
8 changes: 4 additions & 4 deletions raytune.py
@@ -9,9 +9,10 @@


from src.utils.loader import Loader
from src.train_test.tune import simple_train, simple_test
from src.train_test.tune import simple_train, simple_eval
from src.utils import config as cfg

resources = {"cpu":6, "gpu": 4} # NOTE: must match SBATCH directives

def objective(config):
save_checkpoint = config.get("save_checkpoint", False)
@@ -35,7 +36,7 @@ def objective(config):
simple_train(model, optimizer, loaders['train'],
device=device,
epochs=1) # Train the model
loss = simple_test(model, loaders['val'],
loss = simple_eval(model, loaders['val'],
                       device=device) # Compute validation loss

checkpoint = None
@@ -46,7 +47,6 @@ def objective(config):
session.report({"mean_loss": loss}, checkpoint=checkpoint)

algo = OptunaSearch()
# algo = ConcurrencyLimiter(algo, max_concurrent=4)
search_space = {
# constants:
"epochs": 15, # 15 epochs
@@ -64,7 +64,7 @@ def objective(config):
}

tuner = tune.Tuner(
tune.with_resources(objective, resources={"cpu": 6, "gpu": 1}), # NOTE: must match SBATCH directives
tune.with_resources(objective, resources=resources),
tune_config=tune.TuneConfig(
metric="mean_loss",
mode="min",
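
The resources dict is now defined once at module level and reused in tune.with_resources, so the per-trial CPU/GPU request has a single source of truth to keep in sync with the SBATCH directives. Below is a minimal, self-contained sketch of that pattern (the toy objective and values are illustrative, not from the repo; it assumes the Ray 2.x ray.air.session API that the script above already uses):

from ray import tune
from ray.air import session

resources = {"cpu": 6, "gpu": 4}  # must match the CPUs/GPUs granted by the SBATCH directives

def toy_objective(config):
    # stand-in for the real train/eval loop; report a metric for the searcher
    session.report({"mean_loss": (config["x"] - 1.0) ** 2})

tuner = tune.Tuner(
    # Ray only co-schedules trials whose combined resources fit the cluster,
    # so at 4 GPUs per trial a single 4-GPU node runs one trial at a time.
    tune.with_resources(toy_objective, resources=resources),
    param_space={"x": tune.uniform(-2.0, 2.0)},
    tune_config=tune.TuneConfig(metric="mean_loss", mode="min", num_samples=4),
)
results = tuner.fit()
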
132 changes: 132 additions & 0 deletions raytune_DDP.py
@@ -0,0 +1,132 @@
# This is a simple tuning script for the raytune library.
# Adds single-node distributed (DDP) training support; no inter-node distribution.

import random, os, socket
import torch

from torch import nn
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group

import ray
from ray import tune, train
from ray.air import session # this session just comes from train._internal.session._session
from ray.train.torch import TorchCheckpoint
from ray.tune.search.optuna import OptunaSearch


from src.utils.loader import Loader
from src.train_test.tune import simple_train, simple_eval
from src.utils import config as cfg

def main(rank, world_size, config):
# ============= Set up DDP environment =====================
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = config['port']
init_process_group(backend="nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)

device = torch.device(rank)

# ============ Load up distributed training data ==============
p_grp = session.get_trial_resources() # returns placement group object from TrialInfo
# first item is resource list, in our case resources are the same across all trials
# so it is safe to just take the first from the list to get our gpu count
trial_resources = p_grp._bound.args[0][0]
ncpus = trial_resources['cpu']

local_bs = config['global_batch_size']/world_size
if not local_bs.is_integer():
print(f'WARNING: batch size is not divisible by world size. Local batch size is {local_bs}.')

local_bs = int(local_bs)

loaders = Loader.load_distributed_DataLoaders(
num_replicas=world_size, rank=rank, seed=42, # DDP specific params

data=config['dataset'],
pro_feature=config['feature_opt'],
edge_opt=config['edge_opt'],
batch_train=local_bs, # global_bs/world_size
datasets=['train', 'val'],
training_fold=config['fold_selection'],
num_workers=ncpus, # number of subproc used for data loading
)

# ============ Init Model ==============
model = Loader.init_model(model=config["model"], pro_feature=config["feature_opt"],
pro_edge=config["edge_opt"], dropout=config["dropout"]
).to(device)

model = nn.SyncBatchNorm.convert_sync_batchnorm(model) # use if model contains batchnorm.
model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])

torch.distributed.barrier() # Sync params across GPUs

# ============ Train Model for n epochs ============
optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])
save_checkpoint = config.get("save_checkpoint", False)

for _ in range(config["epochs"]):
torch.distributed.barrier()
simple_train(model, optimizer, loaders['train'],
device=device,
epochs=1) # one epoch

torch.distributed.barrier()
        loss = simple_eval(model, loaders['val'], device) # Compute validation loss

checkpoint = None
if save_checkpoint and rank == 0:
checkpoint = TorchCheckpoint.from_state_dict(model.state_dict())

# Report metrics (and possibly a checkpoint) to Tune
session.report({"mean_loss": loss}, checkpoint=checkpoint)

destroy_process_group()


def objective_DDP(config): # NO inter-node distribution due to communication difficulties
world_size = torch.cuda.device_count()
    # mp.spawn inserts the process rank as the first arg to main()
mp.spawn(main, args=(world_size, config,), nprocs=world_size)


algo = OptunaSearch()

search_space = {
# constants:
"epochs": 10, # 15 epochs
"model": "DG",
"dataset": "davis",
"feature_opt": "nomsa",
"edge_opt": "binary",
"fold_selection": 0,
"save_checkpoint": False,

# DDP specific constants:
"port": random.randint(49152,65535),

# hyperparameters to tune:
"lr": tune.loguniform(1e-4, 1e-2),
"dropout": tune.uniform(0, 0.5),
"embedding_dim": tune.choice([64, 128, 256]),

"global_batch_size": tune.choice([16, 32, 48]), # global batch size is divided by ngpus/world_size
}

tuner = tune.Tuner(
tune.with_resources(objective_DDP, resources={"cpu": 4, "gpu": 2}),
param_space=search_space,
tune_config=tune.TuneConfig(
metric="mean_loss",
mode="min",
search_alg=algo,
num_samples=50,
),
)

results = tuner.fit()
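
The core pattern in raytune_DDP.py: each Tune trial (objective_DDP) spawns one worker process per visible GPU, every worker joins a single-node NCCL process group, and the global batch size is split evenly across ranks. The sketch below isolates just that spawn/batch-split mechanics, with Ray and the repo's Loader left out; the toy model, port, and batch size are illustrative, and it assumes at least one CUDA GPU with NCCL available.

import os
import torch
import torch.multiprocessing as mp
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group

def _worker(rank: int, world_size: int, global_bs: int):
    # mp.spawn injects the process index (rank) as the first argument
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"   # any free port; the repo picks one at random per trial
    init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    device = torch.device(f"cuda:{rank}")

    # each replica sees global_bs / world_size samples per optimizer step
    local_bs = global_bs // world_size
    if global_bs % world_size:
        print(f"WARNING: global batch size {global_bs} not divisible by {world_size} ranks")

    model = DDP(nn.Linear(8, 1).to(device), device_ids=[rank])
    x = torch.randn(local_bs, 8, device=device)
    model(x).sum().backward()             # gradients are all-reduced across ranks here
    destroy_process_group()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(_worker, args=(world_size, 32), nprocs=world_size)
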
6 changes: 3 additions & 3 deletions src/train_test/tune.py
@@ -54,7 +54,7 @@ def simple_train(model: BaseModel, optimizer:torch.optim.Optimizer,
optimizer.step()


def simple_test(model:BaseModel, test_loader:DataLoader, device:torch.device,
def simple_eval(model:BaseModel, data_loader:DataLoader, device:torch.device,
CRITERION:torch.nn.Module=None) -> float:
"""
Run inference on the test set.
@@ -81,7 +81,7 @@ def simple_test(model:BaseModel, test_loader:DataLoader, device:torch.device,
CRITERION = CRITERION or torch.nn.MSELoss()

with torch.no_grad():
for data in test_loader:
for data in data_loader:
batch_pro = data['protein'].to(device)
batch_mol = data['ligand'].to(device)
labels = data['y'].reshape(-1,1).to(device)
@@ -94,6 +94,6 @@ def simple_test(model:BaseModel, test_loader:DataLoader, device:torch.device,
test_loss += loss.item()

# Compute average test loss
test_loss /= len(test_loader)
test_loss /= len(data_loader)

return test_loss
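
With the rename, simple_eval is the general evaluation helper shared by both tuning scripts: it averages the criterion (MSELoss by default) over whatever loader it is given and returns that mean loss. A hedged usage sketch, assuming model and loaders have already been built via the repo's Loader as in raytune.py (the L1 criterion is only an illustration of the CRITERION argument):

import torch
from src.train_test.tune import simple_eval

# `model` and `loaders` are assumed to exist already, e.g. from Loader.init_model(...)
# and the repo's DataLoader helpers, with a 'val' split.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

val_mse = simple_eval(model, loaders['val'], device=device)          # default criterion: MSELoss
val_l1 = simple_eval(model, loaders['val'], device=device,
                     CRITERION=torch.nn.L1Loss())                    # any torch.nn loss works
print(f"val MSE: {val_mse:.4f} | val L1: {val_l1:.4f}")
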
8 changes: 0 additions & 8 deletions src/train_test/utils.py
@@ -421,17 +421,9 @@ def debug(model: BaseModel, data_loader:DataLoader,
return train, eval

## ==================== Distributed Training ==================== ##
def handle_sigusr1(signum, frame):
# requeues the job
os.system(f'scontrol requeue {os.getenv("SLURM_JOB_ID")}')
exit()

def init_node(args):
args.ngpus_per_node = torch.cuda.device_count()

# requeue job on SLURM preemption
signal.signal(signal.SIGUSR1, handle_sigusr1)

# find the common host name on all nodes
cmd = 'scontrol show hostnames ' + os.getenv('SLURM_JOB_NODELIST')
stdout = subprocess.check_output(cmd.split())
