diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py
index 28c3dd26..c0953956 100644
--- a/mace/cli/run_train.py
+++ b/mace/cli/run_train.py
@@ -696,8 +696,14 @@ def run(args: argparse.Namespace) -> None:
             device=device,
         )
         model.to(device)
-        if args.distributed and args.device == "cuda":
-            distributed_model = DDP(model, device_ids=[local_rank])
+        if args.distributed:
+            # DDP accepts device_ids only for single-device GPU modules and
+            # raises ValueError when it is set for a CPU module, so pass it
+            # for CUDA only and let DDP use the module's own device otherwise.
+            if args.device == "cuda":
+                distributed_model = DDP(model, device_ids=[local_rank])
+            else:
+                distributed_model = DDP(model)
         model_to_evaluate = model if not args.distributed else distributed_model
         if swa_eval:
             logging.info(f"Loaded Stage two model from epoch {epoch} for evaluation")
diff --git a/mace/tools/slurm_distributed.py b/mace/tools/slurm_distributed.py
index 866cbab3..f58ac675 100644
--- a/mace/tools/slurm_distributed.py
+++ b/mace/tools/slurm_distributed.py
@@ -39,3 +39,4 @@ def __init__(self):
             self.world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
             self.local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
             self.rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
+        return