Tensor parallel embedding (#93)
siddharth9820 authored Jul 9, 2024
1 parent ade0c1a commit 459a443
Showing 9 changed files with 391 additions and 11 deletions.
10 changes: 7 additions & 3 deletions .github/workflows/nvidia-rtx-3090-tests.yaml
@@ -29,7 +29,7 @@ jobs:
          export G_inter=${{ matrix.ginter }}
          export G_data=$(( 2 / G_inter ))
          echo "training with G_inter = ${G_inter}, G_data = $(( 2 / G_inter )) ${{ matrix.memopt }}"
-          mpirun -mca orte_allowed_exit_without_sync 1 -n 2 pytest --with-mpi ./axonn/tests/test_vit.py
+          mpirun -n 2 pytest --with-mpi ./axonn/tests/test_vit.py
      - name: Uninstall AxoNN
        run: |
          pip uninstall --yes axonn
@@ -45,10 +45,14 @@ jobs:
          pip install -e .
      - name: Run intra-layer FC unit tests
        run: |
-          mpirun -mca orte_allowed_exit_without_sync 1 -n 2 pytest --with-mpi ./axonn/tests/test_intra_layer_fc.py
+          torchrun --nproc_per_node 2 --no_python python -m pytest ./axonn/tests/test_intra_layer_fc.py
      - name: Run intra-layer Conv unit tests
        run: |
-          mpirun -mca orte_allowed_exit_without_sync 1 -n 2 pytest --with-mpi ./axonn/tests/test_intra_layer_conv.py
+          torchrun --nproc_per_node 2 --no_python python -m pytest ./axonn/tests/test_intra_layer_conv.py
+      - name: Run intra-layer Embedding unit tests
+        run: |
+          torchrun --nproc_per_node 2 --no_python python -m pytest ./axonn/tests/test_intra_layer_emb.py -k bw_pass
+          torchrun --nproc_per_node 2 --no_python python -m pytest ./axonn/tests/test_intra_layer_emb.py -k fw_pass
      - name: Uninstall AxoNN
        run: |
          pip uninstall --yes axonn
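
For local debugging, the same commands the workflow uses can be run directly, assuming a machine with two visible GPUs and an editable install of AxoNN (the `pip install -e .` step above):

```bash
# Run the new tensor-parallel Embedding unit tests on 2 GPUs (same commands as CI)
torchrun --nproc_per_node 2 --no_python python -m pytest ./axonn/tests/test_intra_layer_emb.py -k bw_pass
torchrun --nproc_per_node 2 --no_python python -m pytest ./axonn/tests/test_intra_layer_emb.py -k fw_pass
```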
4 changes: 2 additions & 2 deletions axonn/axonn.py
@@ -24,8 +24,8 @@


def init(
-    G_inter: int,
-    G_data: int,
+    G_inter: int = 1,
+    G_data: int = 1,
    G_intra_r: int = 1,
    G_intra_c: int = 1,
    G_intra_d: int = 1,
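
With `G_inter` and `G_data` now defaulting to 1, a purely tensor-parallel setup no longer has to pass them. A minimal sketch, not part of this commit, assuming the script is launched with `torchrun` on the participating GPUs:

```python
# Sketch: initialize AxoNN with only intra-layer (tensor) parallelism,
# relying on the new G_inter=1, G_data=1 defaults.
import torch.distributed as dist
from axonn import axonn as ax

if not dist.is_initialized():
    dist.init_process_group(backend="nccl")

ax.init(G_intra_r=2)  # 2-way tensor parallelism; G_inter and G_data default to 1
```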
6 changes: 6 additions & 0 deletions axonn/intra_layer/__init__.py
@@ -1,6 +1,12 @@
+# Copyright 2021 Parallel Software and Systems Group, University of Maryland.
+# See the top-level LICENSE file for details.
+#
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
from contextlib import contextmanager
from .fully_connected import Linear  # noqa: F401
from .conv import Conv2d  # noqa: F401
+from .embedding import Embedding  # noqa: F401

from .communication import Drop, Gather
from .gradient_normalization import clip_grad_norm_  # noqa: F401
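
The new layer is re-exported next to `Linear` and `Conv2d`, so downstream code can import it from the package directly:

```python
from axonn.intra_layer import Embedding  # tensor-parallel embedding added in this commit
```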
222 changes: 222 additions & 0 deletions axonn/intra_layer/embedding.py
@@ -0,0 +1,222 @@
# Copyright 2021 Parallel Software and Systems Group, University of Maryland.
# See the top-level LICENSE file for details.
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import torch.distributed as dist
import torch
import torch.nn.functional as F

from axonn import axonn as ax
from .communication import (
    Drop,
    Gather,
    ForwardGather_BackwardReduceScatter,
)


def divide(a, b):
    assert a % b == 0
    return a // b


@torch.no_grad()
def extract_local_params_from_full_params(
    params, out_features_group, in_features_group, depth_group
):
    params = Drop.apply(params, out_features_group)
    params = Drop.apply(torch.t(params).contiguous(), in_features_group)
    params = torch.t(params).contiguous()
    params = Drop.apply(params.reshape(-1), depth_group)  # create 1D view
    return params


@torch.no_grad()
def initialize_params(
    out_features,
    in_features,
    out_features_group,
    in_features_group,
    depth_group,
    init_method,
    init_device="cuda",
):
    params = torch.empty((in_features, out_features), device=init_device)
    init_method(params)
    params = extract_local_params_from_full_params(
        params, out_features_group, in_features_group, depth_group
    ).cpu()
    return params


@torch.no_grad()
def default_init_method(weight, padding_idx=None):
    return torch.nn.init.normal_(weight)


class Embedding(torch.nn.Module):
    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        padding_idx=None,
        max_norm=None,
        norm_type=2.0,
        scale_grad_by_freq=False,
        sparse=False,
        _weight=None,
        _freeze=False,
        *args,
        transpose=False,
        init_method=None,
        expert_mode=False,
        **kwargs,
    ):
        assert not _weight, "_weight argument is not supported."
        super(Embedding, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        if padding_idx is not None:
            if padding_idx > 0:
                assert (
                    padding_idx < self.num_embeddings
                ), "Padding_idx must be within num_embeddings"
            elif padding_idx < 0:
                assert (
                    padding_idx >= -self.num_embeddings
                ), "Padding_idx must be within num_embeddings"
                padding_idx = self.num_embeddings + padding_idx
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        self.sparse = sparse

        self.inner_group = ax.comm_handle.inner_intra_layer_parallel_group
        self.outer_group = ax.comm_handle.outer_intra_layer_parallel_group
        self.depth_group = ax.comm_handle.depth_intra_layer_parallel_group

        self.inner_group_size = dist.get_world_size(self.inner_group)
        self.outer_group_size = dist.get_world_size(self.outer_group)
        self.depth_group_size = dist.get_world_size(self.depth_group)

        self.out_features = self.embedding_dim
        self.in_features = self.num_embeddings

        if init_method is None:
            init_method = default_init_method

        if not transpose:
            assert self.inner_group_size == 1
            assert self.embedding_dim % self.outer_group_size == 0
            self.local_in_features = self.num_embeddings
            self.local_out_features = divide(self.embedding_dim, self.outer_group_size)
            initial_params = initialize_params(
                self.out_features,
                self.in_features,
                self.outer_group,
                self.inner_group,
                self.depth_group,
                init_method,
            )
        else:
            assert self.outer_group_size == 1
            assert embedding_dim % self.inner_group_size == 0
            self.local_in_features = self.num_embeddings
            self.local_out_features = divide(self.embedding_dim, self.inner_group_size)
            initial_params = initialize_params(
                self.out_features,
                self.in_features,
                self.inner_group,
                self.outer_group,
                self.depth_group,
                init_method,
            )

        if self.padding_idx is not None:
            initial_params[padding_idx].fill_(0)

        self.weight = torch.nn.Parameter(initial_params, requires_grad=not _freeze)

        setattr(self.weight, "is_tensor_parallel", True)
        setattr(self.weight, "needs_gradient_sync", False)
        setattr(
            self.weight,
            "process_group_for_norm_reduction",
            ax.comm_handle.intra_layer_group,
        )

        self.expert_mode = expert_mode
        self.transpose = transpose
        self._old_load_from_state_dict = self._load_from_state_dict
        self._load_from_state_dict = self._modified_load_from_state_dict

    def get_output_feature_size(self):
        return self.local_out_features

    def forward(self, x):
        # gather weights from depth parallel group
        # reduce scatter in the backward pass
        weight = self.weight
        weight = ForwardGather_BackwardReduceScatter.apply(
            weight, self.depth_group
        ).reshape(self.local_in_features, self.local_out_features)
        x = F.embedding(
            x,
            weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
        if not self.expert_mode:
            x = Gather.apply(
                x, self.outer_group if not self.transpose else self.inner_group
            )

        return x

    def _is_full_weight_matrix(self, weight):
        return (
            weight.ndim == 2
            and weight.size(0) == self.in_features
            and weight.size(1) == self.out_features
        )

    def _is_sharded_weight_matrix(self, weight):
        return weight.ndim == 1 and weight.size(0) == divide(
            self.local_out_features * self.local_in_features, self.depth_group_size
        )

    @torch.no_grad()
    def _modified_load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        weight = (
            state_dict[prefix + "weight"] if prefix + "weight" in state_dict else None
        )

        if weight is not None:
            is_full_weight_matrix = self._is_full_weight_matrix(weight)
            is_sharded_weight_matrix = self._is_sharded_weight_matrix(weight)

            assert (
                is_full_weight_matrix or is_sharded_weight_matrix
            ), "This is neither a full checkpoint nor a sharded checkpoint"

            if is_full_weight_matrix:
                out_features_group, in_features_group = (
                    self.outer_group,
                    self.inner_group,
                )
                if self.transpose:
                    out_features_group, in_features_group = (
                        self.inner_group,
                        self.outer_group,
                    )
                weight = extract_local_params_from_full_params(
                    weight, out_features_group, in_features_group, self.depth_group
                )

            state_dict[prefix + "weight"] = weight

        self._old_load_from_state_dict(state_dict, prefix, *args, **kwargs)
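
To summarize the new module: the weight of shape `(num_embeddings, embedding_dim)` is sharded along `embedding_dim` across the outer intra-layer group (or the inner group when `transpose=True`) and additionally flattened and sharded across the depth group; the forward pass gathers the depth shards, performs a plain `F.embedding` lookup, and, unless `expert_mode=True`, gathers the outputs so every rank sees the full `embedding_dim`. A usage sketch, not part of this commit; the mapping of `G_intra_r`/`G_intra_c` onto the inner/outer groups is an assumption here, so the sketch picks `transpose` from the group sizes it actually gets:

```python
# Launch with: torchrun --nproc_per_node 2 embedding_demo.py   (hypothetical script name)
import os

import torch
import torch.distributed as dist

from axonn import axonn as ax
from axonn.intra_layer import Embedding

if not dist.is_initialized():
    dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))  # standard torchrun practice

ax.init(G_intra_r=2)  # 2-way tensor parallelism; G_inter and G_data default to 1

# The default layout (transpose=False) shards embedding_dim over the *outer*
# group and requires the *inner* group to have size 1; pick whichever layout
# matches the groups this configuration produced.
inner_size = dist.get_world_size(ax.comm_handle.inner_intra_layer_parallel_group)
emb = Embedding(num_embeddings=32000, embedding_dim=1024, transpose=(inner_size > 1)).cuda()

tokens = torch.randint(0, 32000, (8, 128), device="cuda")
out = emb(tokens)

# With expert_mode=False (the default) the sharded outputs are gathered,
# so every rank ends up with the full embedding dimension.
print(out.shape)  # torch.Size([8, 128, 1024])
```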
5 changes: 5 additions & 0 deletions axonn/intra_layer/gradient_normalization.py
@@ -1,3 +1,8 @@
+# Copyright 2021 Parallel Software and Systems Group, University of Maryland.
+# See the top-level LICENSE file for details.
+#
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
import torch

# for backwards compatibility with pytorch 1.13
10 changes: 9 additions & 1 deletion axonn/tests/test_intra_layer_conv.py
@@ -1,3 +1,8 @@
+# Copyright 2021 Parallel Software and Systems Group, University of Maryland.
+# See the top-level LICENSE file for details.
+#
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
import torch
import pytest
from mpi4py import MPI  # noqa: F401
@@ -34,7 +39,6 @@ def norm_allclose(X, Y):
        return False


-@pytest.mark.mpi
@pytest.mark.parametrize("H, W, C", [(64, 64, 4), (64, 64, 8), (64, 32, 8)])
@pytest.mark.parametrize("B", [2, 4, 16])
@pytest.mark.parametrize(
@@ -47,6 +51,8 @@ def test_fw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, W, C, easy_tp, bias):
    # These tests are in fp-32
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
+    if not torch.distributed.is_initialized():
+        dist.init_process_group(backend="nccl")
    # Need to remove all non-determinism from convolutions
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.benchmark = False
@@ -135,6 +141,8 @@ def test_bw_pass(
    # Need to remove all non-determinism from convolutions
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
+    if not torch.distributed.is_initialized():
+        dist.init_process_group(backend="nccl")
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True