
Commit

Introduces multi-node training setup
Simon Adamov committed May 4, 2024
1 parent b0050b9 commit 896e9a5
Showing 1 changed file with 22 additions and 7 deletions.
29 changes: 22 additions & 7 deletions train_model.py
@@ -1,4 +1,5 @@
 # Standard library
+import os
 import random
 import time
 from argparse import ArgumentParser
@@ -230,13 +231,25 @@ def main():
     )
 
     # Instantiate model + trainer
+    if args.eval:
+        use_distributed_sampler = False
+    else:
+        use_distributed_sampler = True
+
+    devices = 1
+    num_nodes = 1
     if torch.cuda.is_available():
-        device_name = "cuda"
-        torch.set_float32_matmul_precision(
-            "high"
-        )  # Allows using Tensor Cores on A100s
+        accelerator = "cuda"
+        if "SLURM_JOB_ID" in os.environ:
+            devices = int(
+                os.environ.get(
+                    "SLURM_GPUS_PER_NODE",
+                    torch.cuda.device_count()))
+            num_nodes = int(os.environ.get("SLURM_JOB_NUM_NODES", 1))
+        # Allows using Tensor Cores on A100s
+        torch.set_float32_matmul_precision("high")
     else:
-        device_name = "cpu"
+        accelerator = "cpu"
 
     # Load model parameters Use new args for model
     model_class = MODELS[args.model]
@@ -269,8 +282,10 @@ def main():
     trainer = pl.Trainer(
         max_epochs=args.epochs,
         deterministic=True,
-        strategy="ddp",
-        accelerator=device_name,
+        accelerator=accelerator,
+        devices=devices,
+        num_nodes=num_nodes,
+        use_distributed_sampler=use_distributed_sampler,
         logger=logger,
         log_every_n_steps=1,
         callbacks=[checkpoint_callback],
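For reference, the SLURM-aware resource detection and the new Trainer arguments can be exercised outside the full training script. The following is a minimal sketch, not part of this commit; the helper name resolve_resources is made up here, pytorch_lightning is imported as pl as in the repository, and SLURM_GPUS_PER_NODE is assumed to hold a plain integer (as the commit's code also assumes).

# Standalone sketch mirroring the logic introduced in this commit (hypothetical helper).
import os

import pytorch_lightning as pl
import torch


def resolve_resources(eval_mode=False):
    """Return (accelerator, devices, num_nodes, use_distributed_sampler)."""
    devices = 1
    num_nodes = 1
    if torch.cuda.is_available():
        accelerator = "cuda"
        if "SLURM_JOB_ID" in os.environ:
            # Inside a SLURM job: read the per-node GPU count and the node count
            # from the environment, falling back to the locally visible GPUs.
            devices = int(
                os.environ.get("SLURM_GPUS_PER_NODE", torch.cuda.device_count())
            )
            num_nodes = int(os.environ.get("SLURM_JOB_NUM_NODES", 1))
        # Allows using Tensor Cores on A100s
        torch.set_float32_matmul_precision("high")
    else:
        accelerator = "cpu"
    # The commit turns the distributed sampler off for evaluation runs (args.eval).
    use_distributed_sampler = not eval_mode
    return accelerator, devices, num_nodes, use_distributed_sampler


if __name__ == "__main__":
    accelerator, devices, num_nodes, use_sampler = resolve_resources()
    # Lightning derives the DDP world size as devices * num_nodes.
    trainer = pl.Trainer(
        accelerator=accelerator,
        devices=devices,
        num_nodes=num_nodes,
        use_distributed_sampler=use_sampler,
        max_epochs=1,
    )
    print(f"{accelerator=} {devices=} {num_nodes=} world_size={devices * num_nodes}")

Because devices and num_nodes are filled in from the SLURM environment, the same script scales from a single local GPU to a multi-node allocation without further code changes.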
