Add utility functions to handle distributed groups
ducksoup committed Jul 16, 2019
1 parent ffabb74 commit 5c3a083
Showing 3 changed files with 51 additions and 1 deletion.
README.md: 1 addition, 1 deletion
@@ -52,7 +52,7 @@ To install PyTorch, please refer to https://github.com/pytorch/pytorch#installat

To install the package containing the iABN layers:
```bash
-pip install git+https://github.com/mapillary/[email protected].2
+pip install git+https://github.com/mapillary/[email protected].3
```
Note that some parts of InPlace-ABN have native C++/CUDA implementations, meaning that the command above will need to
compile them.
inplace_abn/__init__.py: 1 addition, 0 deletions
@@ -1,2 +1,3 @@
from ._version import version as __version__
from .abn import ABN, InPlaceABN, InPlaceABNSync
+from .group import active_group, set_active_group
inplace_abn/group.py: 49 additions, 0 deletions
@@ -0,0 +1,49 @@
import torch
import torch.distributed as distributed
import torch.nn as nn


def active_group(active):
    """Initialize a distributed group where each process can independently decide whether to participate or not

    Parameters
    ----------
    active : bool
        Whether this process will be active in the group or not

    Returns
    -------
    A distributed group containing all processes that passed `active=True`, or `None` if all passed `False`
    """
    world_size = distributed.get_world_size()
    rank = distributed.get_rank()

    # Check if cache is initialized, add WORLD and None to it
    if not hasattr(active_group, "__cache__"):
        active_group.__cache__ = {
            frozenset(range(world_size)): distributed.group.WORLD,
            frozenset(): None
        }

    # Gather active status from all workers
    active = torch.tensor(rank if active else -1, dtype=torch.long, device=torch.cuda.current_device())
    active_workers = torch.empty(world_size, dtype=torch.long, device=torch.cuda.current_device())
    distributed.all_gather(list(active_workers.unbind(0)), active)

    # Create and cache group if it doesn't exist yet
    active_workers = frozenset(int(i) for i in active_workers.tolist() if i != -1)
    if active_workers not in active_group.__cache__:
        group = distributed.new_group(list(active_workers))
        active_group.__cache__[active_workers] = group

    return active_group.__cache__[active_workers]


def set_active_group(module: nn.Module, group):
    """Scan all submodules, passing a distributed group to all those that implement `set_group`"""

    def _set_group(m):
        if hasattr(m, "set_group"):
            m.set_group(group)

    module.apply(_set_group)
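
For context, a minimal usage sketch of the two new helpers (not part of this commit): it assumes `torch.distributed` has already been initialized on every rank, that the model's synchronized layers (e.g. `InPlaceABNSync`) implement the `set_group` method that `set_active_group` looks for, and the helper name `configure_partial_sync` is hypothetical.

```python
import torch.nn as nn

from inplace_abn import active_group, set_active_group


def configure_partial_sync(model: nn.Module, is_active: bool):
    # Hypothetical helper, not part of the library. Every rank must call it,
    # even the inactive ones, because active_group() performs a collective
    # all_gather to learn which ranks want to participate.
    group = active_group(is_active)

    # Forward the resulting group (or None) to every submodule that
    # implements set_group(), e.g. synchronized InPlace-ABN layers.
    set_active_group(model, group)
    return group
```

Because `active_group` caches groups keyed by the set of participating ranks, calling such a helper repeatedly with the same participation pattern reuses the existing process group instead of creating a new one.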
