utils.py
import logging
import os
import pickle
import random
import numpy as np
import torch
import torch.distributed as dist

logger = logging.getLogger(__name__)


def set_seed(seed: int = 42) -> None:
    """Seed Python, NumPy, and PyTorch RNGs and force deterministic cuDNN."""
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ["PYTHONHASHSEED"] = str(seed)
    logger.info(f"Random seed set as {seed}")


def is_global_master(args):
    return args.rank == 0


def is_local_master(args):
    return args.local_rank == 0


def is_master(args, local=False):
    return is_local_master(args) if local else is_global_master(args)


def is_distributed():
    if "WORLD_SIZE" in os.environ:
        return int(os.environ["WORLD_SIZE"]) > 1
    if "SLURM_NTASKS" in os.environ:
        return int(os.environ["SLURM_NTASKS"]) > 1
    return False


def world_info_from_env():
    """Read (local_rank, global_rank, world_size) from common launcher env vars."""
    local_rank = 0
    for v in (
        "LOCAL_RANK",
        "MPI_LOCALRANKID",
        "SLURM_LOCALID",
        "OMPI_COMM_WORLD_LOCAL_RANK",
    ):
        if v in os.environ:
            local_rank = int(os.environ[v])
            break
    global_rank = 0
    for v in ("RANK", "PMI_RANK", "SLURM_PROCID", "OMPI_COMM_WORLD_RANK"):
        if v in os.environ:
            global_rank = int(os.environ[v])
            break
    world_size = 1
    for v in ("WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS", "OMPI_COMM_WORLD_SIZE"):
        if v in os.environ:
            world_size = int(os.environ[v])
            break
    return local_rank, global_rank, world_size
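
# Illustrative sketch (assumed values, not part of the original module): under
# a torchrun-style launch that exports LOCAL_RANK, RANK, and WORLD_SIZE, the
# helper simply echoes those variables, e.g.
#
#   os.environ.update({"LOCAL_RANK": "1", "RANK": "5", "WORLD_SIZE": "8"})
#   local_rank, rank, world_size = world_info_from_env()  # -> (1, 5, 8)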


def init_distributed_device(args):
    """Initialize torch.distributed and record rank/world-size info on ``args``."""
    args.distributed = False
    args.world_size = 1
    args.rank = 0
    args.local_rank = 0
    if is_distributed():
        if "SLURM_PROCID" in os.environ:
            # SLURM launch: ranks come from the scheduler, so resolve them from
            # the environment and re-export the standard torch.distributed vars.
            args.local_rank, args.rank, args.world_size = world_info_from_env()
            os.environ["LOCAL_RANK"] = str(args.local_rank)
            os.environ["RANK"] = str(args.rank)
            os.environ["WORLD_SIZE"] = str(args.world_size)
            dist.init_process_group(
                backend=args.dist_backend,
                init_method=args.dist_url,
                world_size=args.world_size,
                rank=args.rank,
            )
        else:
            # torchrun / launcher-provided env: let init_process_group read the
            # environment, then query the resulting rank and world size.
            args.local_rank, _, _ = world_info_from_env()
            dist.init_process_group(
                backend=args.dist_backend, init_method=args.dist_url
            )
            args.world_size = dist.get_world_size()
            args.rank = dist.get_rank()
        args.distributed = True
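
# Illustrative usage sketch (not part of the original module): ``args`` is
# assumed to be an argparse-style namespace carrying the ``dist_backend`` and
# ``dist_url`` attributes that init_distributed_device expects, e.g.
#
#   import argparse
#   args = argparse.Namespace(dist_backend="nccl", dist_url="env://")
#   init_distributed_device(args)
#   if is_master(args):
#       logger.info(f"running on {args.world_size} process(es)")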


def assign_learning_rate(param_group, new_lr):
    param_group["lr"] = new_lr


def _warmup_lr(base_lr, warmup_length, step):
    # Linear warmup from 0 to base_lr over ``warmup_length`` steps.
    return base_lr * (step + 1) / warmup_length


def cosine_lr(optimizer, base_lrs, warmup_length, steps):
    """Return a per-step adjuster: linear warmup followed by cosine decay."""
    if not isinstance(base_lrs, list):
        base_lrs = [base_lrs for _ in optimizer.param_groups]
    assert len(base_lrs) == len(optimizer.param_groups)

    def _lr_adjuster(step):
        for param_group, base_lr in zip(optimizer.param_groups, base_lrs):
            if step < warmup_length:
                lr = _warmup_lr(base_lr, warmup_length, step)
            else:
                e = step - warmup_length
                es = steps - warmup_length
                lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
            assign_learning_rate(param_group, lr)

    return _lr_adjuster
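
# Illustrative usage sketch (assumed optimizer and step counts, not part of the
# original module): the returned adjuster is called once per optimization step,
# e.g.
#
#   optimizer = torch.optim.SGD(model.parameters(), lr=0.0)  # lr set by adjuster
#   scheduler = cosine_lr(optimizer, base_lrs=1e-3, warmup_length=500, steps=10_000)
#   for step in range(10_000):
#       scheduler(step)
#       optimizer.step()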


def pickle_save(classifier, save_path):
    with open(save_path, "wb") as f:
        pickle.dump(classifier.cpu(), f)


def pickle_load(save_path, device=None):
    with open(save_path, "rb") as f:
        classifier = pickle.load(f)
    if device is not None:
        classifier = classifier.to(device)
    return classifier
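
# Illustrative round-trip sketch (hypothetical file path, not part of the
# original module): the classifier is moved to CPU before pickling and can be
# moved back to an accelerator on load, e.g.
#
#   pickle_save(classifier, "classifier.pkl")
#   classifier = pickle_load("classifier.pkl", device="cuda")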


def instantiate(config, *args, **kwargs):
    """Build an object from a dict whose ``_target_`` names a class, recursively."""
    if isinstance(config, dict):
        cls = config.pop("_target_", None)
        if cls is None:
            raise ValueError("'_target_' key is required in the config dictionary")
        if isinstance(cls, str):
            # Resolve a dotted path such as "package.module.ClassName".
            module_name, class_name = cls.rsplit(".", 1)
            module = __import__(module_name, fromlist=[class_name])
            cls = getattr(module, class_name)
        for key, value in config.items():
            if isinstance(value, dict) and "_target_" in value:
                config[key] = instantiate(value)
        merged_kwargs = {**config, **kwargs}
        return cls(*args, **merged_kwargs)
    else:
        return config
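
# Illustrative usage sketch (hypothetical config values, not part of the
# original module): remaining keys become keyword arguments, and explicit
# keyword arguments passed to instantiate() take precedence, e.g.
#
#   config = {"_target_": "torch.nn.Linear", "in_features": 512, "out_features": 10}
#   layer = instantiate(config, bias=False)  # torch.nn.Linear(512, 10, bias=False)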