Commit
Merge branch 'develop' into fix-model-kwargs-in-parallelize
siddharth9820 authored Oct 25, 2024
2 parents 630f0e5 + 660a224 commit c8959eb
Showing 3 changed files with 77 additions and 34 deletions.
71 changes: 49 additions & 22 deletions axonn/communication.py
@@ -15,6 +15,7 @@
MPI4PY = False
import torch
import numpy as np
from typing import Sequence, Optional


class communication_handle:
@@ -135,7 +136,6 @@ def __init__(
self.coll_nccl_comm = ith_jth_data_parallel_group
self.data_parallel_group = ith_jth_data_parallel_group

# create communicators for intra-layer parallelism
for i_ in range(G_data):
for j_ in range(G_inter):
ranks_in_ith_jth_intra_layer_group = [
@@ -152,13 +152,46 @@
== G_intra_r * G_intra_c * G_intra_d
)

# store intra-layer groups here
# the keys of this dictionary are
# tuples of (G_intra_r, G_intra_c, G_intra_d)

self.intra_layer_group_cache = {}
(
self.inner_intra_layer_parallel_group,
self.outer_intra_layer_parallel_group,
self.depth_intra_layer_parallel_group,
) = self.get_intra_layer_groups()

def get_intra_layer_groups(
self, tensor_parallel_dims: Optional[Sequence[int]] = None
):
G_inter, G_data, G_intra = self.G_inter, self.G_data, self.G_intra
if tensor_parallel_dims is None:
G_intra_r, G_intra_c, G_intra_d = (
self.G_intra_r,
self.G_intra_c,
self.G_intra_d,
)
else:
G_intra_r, G_intra_c, G_intra_d = tensor_parallel_dims
# first check if these communicators have already
# been created
group_key = (G_intra_r, G_intra_c, G_intra_d)
if group_key in self.intra_layer_group_cache:
return self.intra_layer_group_cache[group_key]

# create communicators for intra-layer parallelism
for i_ in range(G_data):
for j_ in range(G_inter):
ranks_in_ith_jth_intra_layer_group = [
i_ * G_inter * G_intra + j_ * G_intra + k for k in range(G_intra)
]
ranks_in_ith_jth_intra_layer_group = np.array(
ranks_in_ith_jth_intra_layer_group
).reshape(G_intra_d, G_intra_r, G_intra_c)
# form row and column tensor parallel groups
# G_intra_d x G_intra_r x G_intra_c

# inner
# inner/column
for i in range(G_intra_d):
for j in range(G_intra_r):
group_members = list(
@@ -168,9 +201,9 @@ def __init__(
ranks=group_members, backend="nccl"
)
if self.world_rank in group_members:
self.inner_intra_layer_parallel_group = group
inner_intra_layer_parallel_group = group

# outer
# outer/row
for i in range(G_intra_d):
for j in range(G_intra_c):
group_members = list(
@@ -180,9 +213,9 @@ def __init__(
ranks=group_members, backend="nccl"
)
if self.world_rank in group_members:
self.outer_intra_layer_parallel_group = group
outer_intra_layer_parallel_group = group

# depth
# depth/fsdp
for i in range(G_intra_r):
for j in range(G_intra_c):
group_members = list(
@@ -192,21 +225,15 @@ def __init__(
ranks=group_members, backend="nccl"
)
if self.world_rank in group_members:
self.depth_intra_layer_parallel_group = group
depth_intra_layer_parallel_group = group

# combined inner+outer
for i in range(G_intra_d):
group_members = list(
ranks_in_ith_jth_intra_layer_group[i, :, :].flatten()
)
group = torch.distributed.new_group(
ranks=group_members, backend="nccl"
)
if self.world_rank in group_members:
self.outer_inner_intra_layer_parallel_group = group
self.outer_inner_intra_layer_parallel_group_root = (
group_members[0]
)
self.intra_layer_group_cache[group_key] = (
inner_intra_layer_parallel_group,
outer_intra_layer_parallel_group,
depth_intra_layer_parallel_group,
)

return self.intra_layer_group_cache[group_key]

def _torch_to_mpi(self, tensor: torch.Tensor):
"""Converts a PyTorch tensor into an mpi4py compatible array using its
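For readers following the change above, here is an illustrative, standalone sketch of the memoization pattern that `get_intra_layer_groups` introduces; it is not AxoNN code. `make_group` is a hypothetical stand-in for `torch.distributed.new_group`, which the real handle must call collectively on every rank, so each `(G_intra_r, G_intra_c, G_intra_d)` combination is built once and then served from `intra_layer_group_cache`.

```python
# Illustrative sketch of the memoization in get_intra_layer_groups (not AxoNN code).
# make_group is a hypothetical stand-in for torch.distributed.new_group, which the
# real handle must call collectively on every rank for every candidate group.
from typing import Dict, Optional, Sequence, Tuple

import numpy as np

_group_cache: Dict[Tuple[int, int, int], Tuple[str, str, str]] = {}


def make_group(name: str, ranks) -> str:
    # Placeholder for torch.distributed.new_group(ranks=ranks, backend="nccl").
    return f"{name}:{[int(r) for r in ranks]}"


def get_intra_layer_groups(
    G_intra_r: int,
    G_intra_c: int,
    G_intra_d: int,
    tensor_parallel_dims: Optional[Sequence[int]] = None,
):
    # A per-layer override replaces the handle's default dims when given.
    if tensor_parallel_dims is not None:
        G_intra_r, G_intra_c, G_intra_d = tensor_parallel_dims
    key = (G_intra_r, G_intra_c, G_intra_d)
    if key in _group_cache:  # communicators for these dims already exist
        return _group_cache[key]
    ranks = np.arange(G_intra_r * G_intra_c * G_intra_d).reshape(
        G_intra_d, G_intra_r, G_intra_c
    )
    inner = make_group("inner", ranks[0, 0, :])  # column-parallel (inner) group
    outer = make_group("outer", ranks[0, :, 0])  # row-parallel (outer) group
    depth = make_group("depth", ranks[:, 0, 0])  # depth / FSDP-like group
    _group_cache[key] = (inner, outer, depth)
    return _group_cache[key]


print(get_intra_layer_groups(2, 2, 1))
print(get_intra_layer_groups(2, 2, 1) is get_intra_layer_groups(2, 2, 1))  # True: cached
```

Caching by the requested dims is what lets individual layers ask for non-default tensor-parallel shapes (see fully_connected.py below) without re-creating NCCL communicators on every construction.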
17 changes: 14 additions & 3 deletions axonn/intra_layer/automatic_parallelism.py
@@ -31,17 +31,28 @@ def is_parallelizable_embedding(num_embeddings, embedding_dim):


class patched_linear:
def __new__(cls, in_features, out_features, bias=True, device=None, dtype=None):
def __new__(
cls,
in_features,
out_features,
*args,
bias=True,
device=None,
dtype=None,
**kwargs,
):
if is_parallelizable_linear(in_features, out_features):
parallel_layer = Linear(in_features, out_features, bias=bias)
parallel_layer = Linear(
in_features, out_features, bias=bias, *args, **kwargs
)
if device is not None:
parallel_layer = parallel_layer.to(device)
if dtype is not None:
parallel_layer = parallel_layer.to(dtype)
return parallel_layer
else:
sequential_layer = reference_to_original_linear_class(
in_features, out_features, bias=bias
in_features, out_features, bias=bias, *args, **kwargs
)
if device is not None:
sequential_layer = sequential_layer.to(device)
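A minimal, self-contained illustration of why this change matters for the model-kwargs fix: the `__new__`-based factory now forwards any extra positional and keyword arguments to whichever implementation it picks. The two classes and the divisibility check below are hypothetical stand-ins, not AxoNN's parallel `Linear` or the original `torch.nn.Linear`.

```python
# Toy illustration of the patched factory: extra args/kwargs are forwarded
# to whichever "linear" class is selected (stand-ins, not AxoNN's classes).
class ParallelLinear:
    def __init__(self, in_features, out_features, bias=True, **kwargs):
        self.kind = "parallel"
        self.extra = kwargs  # e.g. model-specific keyword arguments


class SequentialLinear:
    def __init__(self, in_features, out_features, bias=True, **kwargs):
        self.kind = "sequential"
        self.extra = kwargs


def is_parallelizable(in_features: int, out_features: int) -> bool:
    # Placeholder for AxoNN's divisibility check against the TP group sizes.
    return in_features % 2 == 0 and out_features % 2 == 0


class patched_linear:
    def __new__(cls, in_features, out_features, *args, bias=True, **kwargs):
        target = (
            ParallelLinear
            if is_parallelizable(in_features, out_features)
            else SequentialLinear
        )
        # Forward everything the caller supplied so model-specific kwargs survive.
        return target(in_features, out_features, *args, bias=bias, **kwargs)


layer = patched_linear(8, 16, some_model_kwarg=True)
print(layer.kind, layer.extra)  # parallel {'some_model_kwarg': True}
```

Before this patch, a model constructing its linear layers with extra keyword arguments would hit a TypeError once `nn.Linear` was monkey-patched; forwarding `*args`/`**kwargs` keeps those call sites working.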
23 changes: 14 additions & 9 deletions axonn/intra_layer/fully_connected.py
@@ -25,6 +25,7 @@
GatherChannelsScatterBatch,
gather_batch_sizes,
)
from typing import Optional, Sequence


# Wrapper for custom_fwd to handle different versions of PyTorch
@@ -197,6 +198,7 @@ def __init__(
skip_bias_add=False,
init_method=None,
expert_mode=False,
tensor_parallel_dims: Optional[Sequence[int]] = None,
**kwargs,
):
super(Linear, self).__init__()
@@ -205,13 +207,16 @@
# in_features are distributed across self.inner_group (X tensor parallel group)
# out_features are distributed across self.outer_group (Y tensor parallel group)
# if transpose is true then X and Y are swapped

if not transpose:
self.inner_group = ax.comm_handle.inner_intra_layer_parallel_group
self.outer_group = ax.comm_handle.outer_intra_layer_parallel_group
else:
self.inner_group = ax.comm_handle.outer_intra_layer_parallel_group
self.outer_group = ax.comm_handle.inner_intra_layer_parallel_group
if tensor_parallel_dims is not None and torch.distributed.get_rank() == 0:
print(
"Manually setting TP dims for a layer with shape",
f" - {(in_features, out_features)} | tp-dims = {tensor_parallel_dims}",
)
self.inner_group, self.outer_group, self.depth_group = (
ax.comm_handle.get_intra_layer_groups(tensor_parallel_dims)
)
if transpose:
self.inner_group, self.outer_group = self.outer_group, self.inner_group

# depth_group is the Z tensor parallel group (akin to FSDP)
self.depth_group = ax.comm_handle.depth_intra_layer_parallel_group
@@ -303,7 +308,7 @@ def forward(
original_shape_x = x.shape
x = x.reshape(-1, x.shape[-1])
weight = self.weight
if not self.expert_mode:
if not self.expert_mode and (self.inner_group_size * self.outer_group_size > 1):
# extra communication to transition from pure data parallelism
# to 4D hybrid parallelism
inner_group_batch_sizes = gather_batch_sizes(x.shape[0], self.inner_group)
@@ -321,7 +326,7 @@
(self.local_out_features, self.local_in_features),
cache_weights_in_all_gather,
)
if not self.expert_mode:
if not self.expert_mode and (self.inner_group_size * self.outer_group_size > 1):
# extra communication to transition from 4D hybrid parallelism
# to pure data parallelism
x = GatherChannelsScatterBatch.apply(
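A hedged usage sketch of the new per-layer `tensor_parallel_dims` argument. The `axonn` import paths and the `ax.init` keyword names below are assumptions based on the attribute names in communication.py, not something this diff shows, and the script needs a multi-process launch (e.g. torchrun across 8 GPUs with NCCL) to actually run.

```python
# Usage sketch (assumed API surface; not part of this commit).
# Launch with e.g. `torchrun --nproc_per_node=8 this_script.py`.
import torch
from axonn import axonn as ax            # assumed import path
from axonn.intra_layer import Linear     # the layer modified in this diff

# Assumed initializer: 8 GPUs arranged as G_intra_r=2, G_intra_c=2, G_intra_d=2
# (keyword names mirror communication_handle; check the AxoNN docs for the exact call).
ax.init(G_inter=1, G_data=1, G_intra_r=2, G_intra_c=2, G_intra_d=2)

# Default layer: uses the globally configured (2, 2, 2) tensor-parallel dims.
default_fc = Linear(1024, 4096)

# Per-layer override: shard this layer over a (4, 2, 1) grid instead; the product
# must still equal G_intra, and get_intra_layer_groups((4, 2, 1)) is cached for reuse.
custom_fc = Linear(1024, 4096, tensor_parallel_dims=(4, 2, 1))

x = torch.randn(8, 1024).cuda()
y = custom_fc(x)
```

The guard added to `forward` (`inner_group_size * outer_group_size > 1`) also means a layer whose row and column groups are both of size 1 skips the batch scatter/gather entirely and behaves like a purely data-parallel layer.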
