Skip to content

Commit

Permalink
u
Browse files Browse the repository at this point in the history
  • Loading branch information
thangckt committed Sep 10, 2024
1 parent f25be29 commit a5b0b12
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
7 changes: 5 additions & 2 deletions mace/cli/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,8 +696,11 @@ def run(args: argparse.Namespace) -> None:
device=device,
)
model.to(device)
if args.distributed and args.device == "cuda":
distributed_model = DDP(model, device_ids=[local_rank])
if args.distributed:
if args.device == "cuda":
distributed_model = DDP(model, device_ids=[local_rank])
elif args.device == "cpu":
distributed_model = DDP(model, device_ids=[rank])
model_to_evaluate = model if not args.distributed else distributed_model
if swa_eval:
logging.info(f"Loaded Stage two model from epoch {epoch} for evaluation")
Expand Down
1 change: 1 addition & 0 deletions mace/tools/slurm_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,4 @@ def __init__(self):
self.world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
self.local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
self.rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
return

0 comments on commit a5b0b12

Please sign in to comment.