dist_setup.py (forked from atomicarchitects/equiformer)
'''
1. Copy distutils.setup from https://github.com/Open-Catalyst-Project/ocp/blob/89948582edfb8debb736406d54db9813a5f2c88d/ocpmodels/common/distutils.py#L16
2. Add OpenMPI multi-node training (e.g., for Summit/LSF), since Submitit is not supported there.
'''
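# NOTE (hedged launch sketch, not part of the original OCP file): setup()
# expects one of three launch modes, inferred from the branches below:
#   * config["submit"]  -> SLURM job; rank/local_rank are derived from
#     SLURM_* environment variables and `scontrol show hostnames`.
#   * config["summit"]  -> OpenMPI-style launcher (e.g. mpirun/jsrun), which
#     provides the OMPI_COMM_WORLD_* variables; MASTER_ADDR / MASTER_PORT are
#     assumed to be exported already for the env:// rendezvous.
#   * otherwise         -> plain env:// initialization, e.g. a torchrun-style
#     launch that exports MASTER_ADDR and MASTER_PORT.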
import logging
import os
import subprocess

import torch
import torch.distributed as dist


def setup(config):
    if config["submit"]:
        node_list = os.environ.get("SLURM_STEP_NODELIST")
        if node_list is None:
            node_list = os.environ.get("SLURM_JOB_NODELIST")
        if node_list is not None:
            try:
                hostnames = subprocess.check_output(
                    ["scontrol", "show", "hostnames", node_list]
                )
                config["init_method"] = "tcp://{host}:{port}".format(
                    host=hostnames.split()[0].decode("utf-8"),
                    port=config["distributed_port"],
                )
                nnodes = int(os.environ.get("SLURM_NNODES"))
                ntasks_per_node = os.environ.get("SLURM_NTASKS_PER_NODE")
                if ntasks_per_node is not None:
                    ntasks_per_node = int(ntasks_per_node)
                else:
                    ntasks = int(os.environ.get("SLURM_NTASKS"))
                    nnodes = int(os.environ.get("SLURM_NNODES"))
                    assert ntasks % nnodes == 0
                    ntasks_per_node = int(ntasks / nnodes)
                if ntasks_per_node == 1:
                    assert config["world_size"] % nnodes == 0
                    gpus_per_node = config["world_size"] // nnodes
                    node_id = int(os.environ.get("SLURM_NODEID"))
                    config["rank"] = node_id * gpus_per_node
                    config["local_rank"] = 0
                else:
                    assert ntasks_per_node == config["world_size"] // nnodes
                    config["rank"] = int(os.environ.get("SLURM_PROCID"))
                    config["local_rank"] = int(os.environ.get("SLURM_LOCALID"))

                logging.info(
                    f"Init: {config['init_method']}, {config['world_size']}, {config['rank']}"
                )

                # ensures GPU0 does not have extra context/higher peak memory
                torch.cuda.set_device(config["local_rank"])

                dist.init_process_group(
                    backend=config["distributed_backend"],
                    init_method=config["init_method"],
                    world_size=config["world_size"],
                    rank=config["rank"],
                )
            except subprocess.CalledProcessError as e:  # scontrol failed
                raise e
            except FileNotFoundError:  # Slurm is not installed
                pass
    elif config["summit"]:
        world_size = int(os.getenv('OMPI_COMM_WORLD_SIZE'))
        world_rank = int(os.getenv('OMPI_COMM_WORLD_RANK'))
        # Should be set already
        # get_master = (
        #     "echo $(cat {} | sort | uniq | grep -v batch | grep -v login | head -1)"
        # ).format(os.environ["LSB_DJOB_HOSTFILE"])
        # os.environ["MASTER_ADDR"] = str(
        #     subprocess.check_output(get_master, shell=True)
        # )[2:-3]
        # os.environ["MASTER_PORT"] = "23456"
        os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
        os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
        config["local_rank"] = int(os.getenv('OMPI_COMM_WORLD_LOCAL_RANK'))
        # NCCL and MPI initialization
        dist.init_process_group(
            backend="nccl",
            rank=world_rank,
            world_size=world_size,
            init_method="env://",
        )
    else:
        dist.init_process_group(
            backend=config["distributed_backend"],
            init_method="env://",
            rank=config['local_rank'],
            world_size=config['world_size'],
        )
    torch.cuda.set_device(config["local_rank"])
    # TODO: SLURM
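

# Minimal usage sketch (an assumption, not part of the original repo): the
# config keys below are exactly the ones setup() reads above. Note that the
# fallback branch passes config["local_rank"] as the global rank, so it
# effectively assumes a single-node run.
if __name__ == "__main__":
    example_config = {
        "submit": False,   # True when running under SLURM
        "summit": False,   # True for OpenMPI multi-node (e.g. Summit/LSF)
        "distributed_backend": "nccl",
        "distributed_port": 13356,  # only used to build the SLURM tcp:// init_method
        "world_size": int(os.environ.get("WORLD_SIZE", 1)),
        "local_rank": int(os.environ.get("LOCAL_RANK", 0)),
    }
    setup(example_config)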