distributed.py
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

# Emit the half-precision/gloo warning at most once per process.
warn_on_half = True


def broadcast_params(params):
    """Broadcast every parameter tensor from rank 0 to all other ranks."""
    for p in params:
        if torch.is_tensor(p):
            dist.broadcast(p, 0)


def reduce_gradients(module):
    """All-reduce and average the gradients of `module` across all ranks."""
    global warn_on_half
    # Group parameters by tensor type so each bucket can be flattened into a
    # single contiguous buffer and reduced with one collective call.
    buckets = {}
    for name, param in module.named_parameters():
        if param.requires_grad and param.grad is not None:
            tp = type(param.data)
            if tp not in buckets:
                buckets[tp] = []
            buckets[tp].append(param)

    if warn_on_half:
        if torch.cuda.HalfTensor in buckets:
            print("WARNING: gloo dist backend for half parameters may be slow."
                  " It is recommended to use the NCCL backend in this case.")
            warn_on_half = False

    for tp in buckets:
        bucket = buckets[tp]
        grads = [param.grad.data for param in bucket]
        # Flatten the bucket's gradients into one buffer, all-reduce it,
        # average by world size, and copy the synced values back in place.
        coalesced = _flatten_dense_tensors(grads)
        dist.all_reduce(coalesced)
        torch.cuda.synchronize()
        coalesced /= dist.get_world_size()
        for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
            buf.copy_(synced)
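

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the original
    # module): broadcast the model's weights from rank 0 once at startup,
    # then average gradients across ranks after each backward pass and
    # before the optimizer step. It assumes the process is launched with
    # the usual distributed environment variables (RANK, WORLD_SIZE,
    # MASTER_ADDR, MASTER_PORT) and that a CUDA device is available; the
    # model, optimizer, and synthetic data below are placeholders.
    import torch.nn as nn
    import torch.optim as optim

    dist.init_process_group(backend="nccl", init_method="env://")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    model = nn.Linear(16, 1).cuda()
    optimizer = optim.SGD(model.parameters(), lr=0.01)

    # Make every rank start from rank 0's weights.
    broadcast_params(list(model.parameters()))

    inputs = torch.randn(8, 16).cuda()
    targets = torch.randn(8, 1).cuda()
    for _ in range(3):
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(inputs), targets)
        loss.backward()
        # Average gradients across all ranks before stepping.
        reduce_gradients(model)
        optimizer.step()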