
Commit

Merge branch 'develop' into axonn-cpu
siddharth9820 authored Jun 18, 2024
2 parents d24fa4c + b03b58e commit ad779c3
Showing 7 changed files with 292 additions and 39 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/nvidia-rtx-3090-tests.yaml
@@ -21,7 +21,7 @@ jobs:
       - uses: actions/checkout@v3
       - name: Install AxoNN
         run: |
-          pip install -r requirements.txt
+          pip install -e .
       - name: Download dataset
         run: |
           python -c "import torchvision; torchvision.datasets.MNIST(root=\"./axonn/tests\", download=True, train=True)"
@@ -44,7 +44,7 @@ jobs:
       - uses: actions/checkout@v3
       - name: Install AxoNN
         run: |
-          pip install -r requirements.txt
+          pip install -e .
       - name: Run intra-layer FC unit tests
         run: |
           mpirun -mca orte_allowed_exit_without_sync 1 -n 2 pytest --with-mpi ./axonn/tests/test_intra_layer_fc.py
60 changes: 32 additions & 28 deletions axonn/checkpoint.py
@@ -1,33 +1,37 @@
 # Copyright 2021 Parallel Software and Systems Group, University of Maryland.
 # See the top-level LICENSE file for details.
 #
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

 import torch
 from . import config
 import os
-from .axonn import model_params_fp16, model_params_fp32, model


-def save_model_and_optimizer(model, optimizer, checkpoint_folder):
-    inter_rank = config.inter_layer_parallel_rank
-    intra_rank = config.intra_layer_parallel_rank
-    data_rank = config.data_parallel_rank
-
-    if intra_rank == 0 and data_rank == 0:
-        model_path = os.path.join(checkpoint_folder, f"model_{inter_rank}.pt")
-        optim_path = os.path.join(checkpoint_folder, f"optim_{inter_rank}.pt")
-        torch.save(model.state_dict(), model_path)
-        torch.save(optimizer.state_dict(), optim_path)
-
-
-def load_model(model, checkpoint_folder):
-    inter_rank = config.inter_layer_parallel_rank
-
-    model_path = os.path.join(checkpoint_folder, f"model_{inter_rank}.pt")
-    model.load_state_dict(torch.load(model_path, map_location="cpu"))
-    return model
-
-
-def load_optimizer(optimizer, checkpoint_folder):
-    inter_rank = config.inter_layer_parallel_rank
-    optim_path = os.path.join(checkpoint_folder, f"optim_{inter_rank}.pt")
-    optimizer.load_state_dict(torch.load(optim_path, map_location="cpu"))
-    if model is not None:
-        model_params_fp16.copy_(model_params_fp32)
-    return optimizer
+def get_prefix_for_checkpoint():
+    row_tp_rank = config.intra_layer_row_parallel_rank
+    column_tp_rank = config.intra_layer_column_parallel_rank
+    depth_tp_rank = config.intra_layer_depth_parallel_rank
+    return f"tp_row_{row_tp_rank}_col_{column_tp_rank}_depth_{depth_tp_rank}"
+
+
+def save(state, checkpoint_folder, checkpoint_name, overwrite=True):
+    if config.data_parallel_rank == 0:
+        checkpoint_folder = os.path.join(checkpoint_folder, get_prefix_for_checkpoint())
+        if not os.path.exists(checkpoint_folder):
+            os.makedirs(checkpoint_folder)
+        checkpoint_file = os.path.join(checkpoint_folder, f"{checkpoint_name}.pt")
+        if os.path.exists(checkpoint_file) and not overwrite:
+            raise ValueError(f"Checkpoint {checkpoint_file} already exists")
+        torch.save(state, checkpoint_file)
+
+
+def load(state, checkpoint_folder, checkpoint_name):
+    assert os.path.isdir(
+        checkpoint_folder
+    ), f"folder {checkpoint_folder} does not exist"
+    checkpoint_file = os.path.join(
+        checkpoint_folder, f"{get_prefix_for_checkpoint()}_{checkpoint_name}.pt"
+    )
+    torch.load(checkpoint_file)
+    return state
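For orientation, a minimal usage sketch of the new prefix-based checkpoint API. The mesh sizes, folder, and checkpoint names below are illustrative, and AxoNN is assumed to be initialized so that the tensor-parallel ranks read by get_prefix_for_checkpoint() exist; this is a sketch of intended use, not code from this commit.

    import torch
    from axonn import axonn as ax
    from axonn import checkpoint

    # Illustrative initialization; the row/column/depth ranks configured here
    # determine the "tp_row_*_col_*_depth_*" prefix used by save() and load().
    ax.init(G_data=1, G_inter=1, G_intra_r=2, G_intra_c=1, G_intra_d=1)

    model = torch.nn.Linear(32, 32)  # placeholder model

    # Only data-parallel rank 0 writes; each tensor-parallel rank combination
    # saves its own shard inside a prefix-named subfolder.
    checkpoint.save(model.state_dict(), "checkpoints", "step_100")

    # load() looks for "<prefix>_<name>.pt" inside the given folder and
    # returns the state object that was passed in.
    state = checkpoint.load(model.state_dict(), "checkpoints", "step_100")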
19 changes: 14 additions & 5 deletions axonn/intra_layer/__init__.py
@@ -72,9 +72,9 @@ def accumulate():
     global pending_grad_accumulations
     for param, grad in pending_grad_accumulations:
         if param.grad is None:
-            param.grad = grad
+            param.grad = grad.to(param.dtype)
         else:
-            param.grad.add_(grad)
+            param.grad.add_(grad.to(param.dtype))

     pending_grad_accumulations = []

@@ -204,7 +206,9 @@ def optimize_communication(


 @torch.no_grad()
-def sync_gradients(model, gradient_attr_name="grad", mean=False, vectorize=False):
+def sync_gradients(
+    model, gradient_attr_name="grad", mean=False, vectorize=False, mean_weight=None
+):
     grads_to_sync = []
     for param in model.parameters():
         if param.requires_grad:
Expand Down Expand Up @@ -239,6 +241,13 @@ def sync_gradients(model, gradient_attr_name="grad", mean=False, vectorize=False
             old_tensor.data = new_tensor
     else:
         for grad in grads_to_sync:
-            dist.all_reduce(grad, group=ax.comm_handle.depth_intra_layer_parallel_group)
             if mean:
-                grad.div_(world_size)
+                if mean_weight is None:
+                    grad.div_(world_size)
+                else:
+                    mean_weight_pt = torch.tensor(
+                        [mean_weight], device="cuda", dtype=torch.float32
+                    )
+                    dist.all_reduce(mean_weight_pt)
+                    grad.mul_(mean_weight).div_(mean_weight_pt)
+            dist.all_reduce(grad, group=ax.comm_handle.depth_intra_layer_parallel_group)
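To illustrate the new mean_weight path in sync_gradients, a rough sketch of a call site where ranks contribute unequally sized microbatches. The module and token count are illustrative, and AxoNN plus torch.distributed are assumed to be initialized already.

    import torch
    from axonn.intra_layer import sync_gradients

    model = torch.nn.Linear(8, 8)  # placeholder module with gradients
    local_token_count = 512        # illustrative per-rank weight

    # With mean_weight set, each gradient is scaled by the local weight and
    # divided by the all-reduced sum of weights before the depth-group
    # all-reduce, giving a weighted average instead of dividing by world_size.
    sync_gradients(model, mean=True, mean_weight=local_token_count)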
9 changes: 6 additions & 3 deletions axonn/intra_layer/automatic_parallelism.py
@@ -3,6 +3,8 @@
 from axonn.intra_layer import Linear
 from contextlib import contextmanager

+reference_to_original_linear_class = nn.Linear
+

 def is_parallelizable(in_features, out_features):
     G_row = ax.config.G_intra_r
@@ -23,7 +25,9 @@ def __new__(cls, in_features, out_features, bias=True, device=None, dtype=None):
             parallel_layer = parallel_layer.to(dtype)
             return parallel_layer
         else:
-            sequential_layer = nn.Linear(in_features, out_features, bias=bias)
+            sequential_layer = reference_to_original_linear_class(
+                in_features, out_features, bias=bias
+            )
             if device is not None:
                 sequential_layer = sequential_layer.to(device)
             if dtype is not None:
@@ -33,9 +37,8 @@ def __new__(cls, in_features, out_features, bias=True, device=None, dtype=None):

 @contextmanager
 def auto_parallelize():
-    old_linear = nn.Linear
     nn.Linear = patched_linear
     try:
         yield None
     finally:
-        nn.Linear = old_linear
+        nn.Linear = reference_to_original_linear_class
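A short usage sketch for the updated auto_parallelize() context manager, assuming AxoNN has already been initialized; the layer sizes are illustrative. While the context is active, nn.Linear resolves to patched_linear, and non-parallelizable layers fall back to the saved reference_to_original_linear_class.

    import torch.nn as nn
    from axonn.intra_layer.automatic_parallelism import auto_parallelize

    # Inside the context, parallelizable nn.Linear layers are constructed as
    # AxoNN tensor-parallel Linear modules; the rest use torch's original
    # nn.Linear via the module-level reference saved above.
    with auto_parallelize():
        mlp = nn.Sequential(
            nn.Linear(1024, 4096),  # illustrative sizes
            nn.GELU(),
            nn.Linear(4096, 1024),
        )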
6 changes: 6 additions & 0 deletions axonn/lightning/__init__.py
@@ -0,0 +1,6 @@
# Copyright 2021 Parallel Software and Systems Group, University of Maryland.
# See the top-level LICENSE file for details.
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .axonn_strategy import AxonnStrategy # noqa: F401
228 changes: 228 additions & 0 deletions axonn/lightning/axonn_strategy.py
@@ -0,0 +1,228 @@
# Copyright 2021 Parallel Software and Systems Group, University of Maryland.
# See the top-level LICENSE file for details.
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from datetime import timedelta
from typing import Any, Dict, List, Optional, Union

import torch
import torch.distributed
from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
from torch import Tensor
from torch.nn import Module
from typing_extensions import override

from lightning.fabric.accelerators.accelerator import Accelerator
from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO
from lightning.fabric.plugins.precision import Precision
from lightning.fabric.strategies.launchers.subprocess_script import (
    _SubprocessScriptLauncher,
)
from lightning.fabric.strategies.parallel import ParallelStrategy
from lightning.fabric.strategies.registry import _StrategyRegistry
from lightning.fabric.strategies.strategy import TBroadcast
from lightning.fabric.utilities.distributed import (
    ReduceOp,
    _distributed_is_initialized,
    _get_default_process_group_backend_for_device,
    _init_dist_connection,
    _sync_ddp_if_available,
)
from lightning.fabric.utilities.distributed import group as _group
from lightning.fabric.utilities.rank_zero import rank_zero_only
from axonn import axonn as ax
from axonn.intra_layer import sync_gradients


class AxonnStrategy(ParallelStrategy):

    def __init__(
        self,
        accelerator: Optional[Accelerator] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision: Optional[Precision] = None,
        process_group_backend: Optional[str] = None,
        timeout: Optional[timedelta] = default_pg_timeout,
        G_data: int = 1,
        G_inter: int = 1,
        G_intra_r: int = 1,
        G_intra_c: int = 1,
        G_intra_d: int = 1,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            accelerator=accelerator,
            parallel_devices=parallel_devices,
            cluster_environment=cluster_environment,
            checkpoint_io=checkpoint_io,
            precision=precision,
        )

        assert G_data == 1, "Data Parallelism not Supported in AxoNNStrategy"
        assert (
            G_inter == 1
        ), "Inter-layer (or pipeline) Parallellism not Supported in AxoNNStrategy"
        self._num_nodes = 1
        self._process_group_backend: Optional[str] = process_group_backend
        self._timeout: Optional[timedelta] = timeout
        self.G_data = G_data
        self.G_inter = G_inter
        self.G_intra_r = G_intra_r
        self.G_intra_c = G_intra_c
        self.G_intra_d = G_intra_d
        self._axonn_kwargs = kwargs

    @property
    @override
    def root_device(self) -> torch.device:
        assert self.parallel_devices is not None
        return self.parallel_devices[self.local_rank]

    @property
    def num_nodes(self) -> int:
        return self._num_nodes

    @num_nodes.setter
    def num_nodes(self, num_nodes: int) -> None:
        self._num_nodes = num_nodes

    @property
    def num_processes(self) -> int:
        return len(self.parallel_devices) if self.parallel_devices is not None else 0

    @property
    @override
    def distributed_sampler_kwargs(self) -> Dict[str, Any]:
        return {
            "num_replicas": ax.config.G_intra_d * ax.config.G_data,
            "rank": ax.config.G_intra_d * ax.config.data_parallel_rank
            + ax.config.intra_layer_depth_parallel_rank,
        }

    @property
    def process_group_backend(self) -> Optional[str]:
        return self._process_group_backend

    @override
    def _configure_launcher(self) -> None:
        assert self.cluster_environment is not None
        self._launcher = _SubprocessScriptLauncher(
            self.cluster_environment, self.num_processes, self.num_nodes
        )

    @override
    def setup_environment(self) -> None:
        super().setup_environment()
        self._setup_distributed()

    @override
    def setup_module(self, module: Module):
        return module  # use autoparallelize later

    @override
    def module_to_device(self, module: Module) -> None:
        module.to(self.root_device)

    @override
    def all_reduce(
        self,
        tensor: Tensor,
        group: Optional[Any] = None,
        reduce_op: Optional[Union[ReduceOp, str]] = "mean",
    ) -> Tensor:
        if isinstance(tensor, Tensor):
            return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
        return tensor

    @override
    def barrier(self, *args: Any, **kwargs: Any) -> None:
        if not _distributed_is_initialized():
            return
        if torch.distributed.get_backend() == "nccl":
            torch.distributed.barrier(device_ids=self._determine_device_ids())
        else:
            torch.distributed.barrier()

    @override
    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
        if not _distributed_is_initialized():
            return obj

        obj = [obj]
        torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
        return obj[0]

    @classmethod
    @override
    def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
        pass

    def _setup_distributed(self) -> None:
        self._set_world_ranks()
        self._process_group_backend = self._get_process_group_backend()
        assert self.cluster_environment is not None
        _init_dist_connection(
            self.cluster_environment, self._process_group_backend, timeout=self._timeout
        )

        ax.init(
            G_data=self.G_data,
            G_inter=self.G_inter,
            G_intra_r=self.G_intra_r,
            G_intra_c=self.G_intra_c,
            G_intra_d=self.G_intra_d,
        )

    def _get_process_group_backend(self) -> str:
        return (
            self._process_group_backend
            or _get_default_process_group_backend_for_device(self.root_device)
        )

    def _set_world_ranks(self) -> None:
        if self.cluster_environment is not None:
            self.cluster_environment.set_global_rank(
                self.node_rank * self.num_processes + self.local_rank
            )
            self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
        rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank

    def _determine_device_ids(self) -> Optional[List[int]]:
        return None if self.root_device.type == "cpu" else [self.root_device.index]

    @override
    def backward(
        self, tensor: Tensor, module: Optional[Module] = None, *args: Any, **kwargs: Any
    ) -> None:
        super().backward(tensor / self.G_intra_d, module, *args, **kwargs)
        if self.G_intra_d > 1:
            assert module is not None, (
                "When using G_intra_d > 1 with AxoNN,"
                " you need to pass the model in fabric.backward(model=..)"
            )
            sync_gradients(module)

    def save_checkpoint(
        self,
        *args,
        **kwargs,
    ) -> None:
        assert False, (
            "Current fabric.save(..) is not supported with the "
            "AxoNN strategy. Use axonn.save instead."
        )

    def load_checkpoint(
        self,
        *args,
        **kwargs,
    ) -> None:
        assert False, (
            "Current fabric.load(..) is not supported with the"
            " AxoNN strategy. Use axonn.load instead."
        )
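A minimal end-to-end sketch of the new AxonnStrategy with Lightning Fabric; the device count, mesh sizes, and model are illustrative, not part of this commit. Per the backward() assertion above, the module must be passed to fabric.backward when G_intra_d > 1.

    import torch
    from lightning.fabric import Fabric
    from axonn.lightning import AxonnStrategy

    # Illustrative 4-GPU mesh: 2-way row tensor parallelism x 2-way depth parallelism.
    strategy = AxonnStrategy(G_intra_r=2, G_intra_d=2)
    fabric = Fabric(accelerator="cuda", devices=4, strategy=strategy)
    fabric.launch()

    model = torch.nn.Linear(1024, 1024)
    optimizer = torch.optim.AdamW(model.parameters())
    model, optimizer = fabric.setup(model, optimizer)

    x = torch.randn(8, 1024, device=fabric.device)
    loss = model(x).sum()
    # The strategy scales the loss by 1 / G_intra_d and syncs gradients
    # across the depth group, which is why the model is passed here.
    fabric.backward(loss, model=model)
    optimizer.step()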