Skip to content

Commit

Permalink
u
Browse files (browse the repository at this point in the history)
  • Loading branch information
thangckt committed Sep 10, 2024
1 parent a5b0b12 commit 1ace578
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 5 deletions.
7 changes: 3 additions & 4 deletions mace/cli/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,7 @@ def run(args: argparse.Namespace) -> None:
local_rank = distr_env.local_rank
rank = distr_env.rank
if rank == 0:
- print(distr_env)
-
+ print("Using distributed Environment: ", distr_env)
torch.distributed.init_process_group(backend=args.distributed_backend)
else:
rank = int(0)
Expand Down Expand Up @@ -579,7 +578,7 @@ def run(args: argparse.Namespace) -> None:
if args.device == "cuda":
distributed_model = DDP(model, device_ids=[local_rank])
elif args.device == "cpu":
- distributed_model = DDP(model, device_ids=[rank])
+ distributed_model = DDP(model)
else:
distributed_model = None

Expand Down Expand Up @@ -700,7 +699,7 @@ def run(args: argparse.Namespace) -> None:
if args.device == "cuda":
distributed_model = DDP(model, device_ids=[local_rank])
elif args.device == "cpu":
- distributed_model = DDP(model, device_ids=[rank])
+ distributed_model = DDP(model)
model_to_evaluate = model if not args.distributed else distributed_model
if swa_eval:
logging.info(f"Loaded Stage two model from epoch {epoch} for evaluation")
Expand Down
1 change: 0 additions & 1 deletion mace/tools/slurm_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,3 @@ def __init__(self):
self.world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
self.local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
self.rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
- return

0 comments on commit 1ace578

Please sign in to comment.