
Commit

done
hanwen-sun committed Sep 14, 2024
1 parent 380074e commit fa2137e
Showing 3 changed files with 76 additions and 34 deletions.
30 changes: 30 additions & 0 deletions torchacc/dist/distributed_parallel.py
@@ -2,6 +2,7 @@

from torchacc.config import Config
from torchacc.dist import ParallelModule, DataParallel, FullyShardedDataParallel, PipelineParallel, SpmdFullyShardedDataParallel
from typing import Any, Dict


class DistributedParallel(ParallelModule):
@@ -71,3 +72,32 @@ def forward_backward(self, *args, output_fn=None, **kwargs):
"forward_backward is only supported for pipeline parallel.")
assert isinstance(self.model, PipelineParallel)
return self.model.forward_backward(*args, output_fn=output_fn, **kwargs)

def sharded_optim_state_dict(self, optim: torch.optim.Optimizer):
if not self.has_fsdp:
raise NotImplementedError(
"sharded_optim_state_dict is only support for FullyShardedDataParallel"
)
assert isinstance(self.model, FullyShardedDataParallel)
return self.model.sharded_optim_state_dict(self.model, optim)

def full_optim_state_dict(self,
optim: torch.optim.Optimizer,
rank0_only: bool = True,
cpu_offload: bool = True):
if not self.has_fsdp:
raise NotImplementedError(
"full_optim_state_dict is only support for FullyShardedDataParallel"
)
assert isinstance(self.model, FullyShardedDataParallel)
return self.model.full_optim_state_dict(self.model, optim, rank0_only, cpu_offload)

def load_optim_state_dict(self,
optim_state_dict: Dict[str, Any],
rank0_only: bool = True):
if not self.has_fsdp:
raise NotImplementedError(
"load_optim_state_dict is only support for FullyShardedDataParallel"
)
assert isinstance(self.model, FullyShardedDataParallel)
return self.model.load_optim_state_dict(self.model, optim_state_dict, rank0_only)
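
For context, here is a minimal usage sketch of the three wrappers added above. It is not part of this commit: dist_model, optim, and the checkpoint path are hypothetical placeholders, and dist_model is assumed to be a DistributedParallel instance that wraps the model with FSDP.

# Hypothetical helpers, not from this repo: save and restore optimizer
# state through the DistributedParallel wrappers shown above.
from typing import Any, Dict

import torch


def save_optim_state(dist_model, optim: torch.optim.Optimizer,
                     path: str, full: bool = False) -> None:
    if full:
        # Consolidated, unflattened state; only rank 0 gets a populated dict.
        state: Dict[str, Any] = dist_model.full_optim_state_dict(
            optim, rank0_only=True, cpu_offload=True)
    else:
        # Per-rank shard plus "shard_metadata"; use a per-rank path.
        state = dist_model.sharded_optim_state_dict(optim)
    torch.save(state, path)


def restore_optim_state(dist_model, optim: torch.optim.Optimizer,
                        path: str) -> None:
    loaded = torch.load(path, map_location="cpu")
    # load_optim_state_dict detects sharded input via the "shard_metadata"
    # key; otherwise it re-flattens and re-shards the consolidated state.
    flat_state = dist_model.load_optim_state_dict(loaded, rank0_only=True)
    optim.load_state_dict(flat_state)
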
62 changes: 36 additions & 26 deletions torchacc/dist/fsdp.py
@@ -207,29 +207,33 @@ def clip_grad_norm_(self, max_grad_norm):
def forward(self, *args, **kwargs):
return self.model(*args, **kwargs)

@staticmethod
def sharded_optim_state_dict(
self,
model: torch.nn.Module,
optim: torch.optim.Optimizer,
):
"""
Return the optimizer state-dict in its sharded form.
Args:
optim (torch.optim.Optimizer): Optimizer for self.model's
model (torch.nn.Module): FSDP model whose parameters were
passed into the optimizer ``optim``.
optim (torch.optim.Optimizer): Optimizer for model's
parameters.
Returns:
Dict[str, Any]: A :class:`dict` containing the optimizer state for
self.model. Each rank get the sharded optim state added with shard_metadata.
the FSDP model. Each rank gets its sharded optim state together with shard_metadata.
"""
optimizer = {
"optimizer": optim.state_dict(),
"shard_metadata": self.model.get_shard_metadata(),
"shard_metadata": model.model.get_shard_metadata(),
}

return optimizer

def full_optim_state_dict(self,
@staticmethod
def full_optim_state_dict(model: torch.nn.Module,
optim: torch.optim.Optimizer,
rank0_only: bool = True,
cpu_offload: bool = True) -> Dict[str, Any]:
@@ -239,15 +243,17 @@ def full_optim_state_dict(self,
as a :class:`dict` following the convention of
:meth:`torch.optim.Optimizer.state_dict`, i.e. with keys ``"state"``
and ``"param_groups"``. The flattened parameters in ``FSDP`` modules
contained in self.model are mapped back to their unflattened parameters.
contained in model are mapped back to their unflattened parameters.
.. warning:: This needs to be called on all ranks since it uses
collective communications. However, if ``rank0_only=True``, then
the state dict is only populated on rank 0, and all other ranks
return an empty :class:`dict`.
Args:
optim (torch.optim.Optimizer): Optimizer for self.model 's
model (torch.nn.Module): FSDP model whose parameters were
passed into the optimizer ``optim``.
optim (torch.optim.Optimizer): Optimizer for model's
parameters.
rank0_only (bool): If ``True``, return the populated :class:`dict`
only on rank 0; if ``False``, return it on all ranks. (Default:
@@ -262,7 +268,7 @@ def full_optim_state_dict(self,
:meth:`torch.optim.Optimizer.state_dict`. If ``rank0_only=True``,
then nonzero ranks return a :class:`dict` with the keys but empty values.
"""
shard_meta_data = self.model.get_shard_metadata()
shard_meta_data = model.model.get_shard_metadata()
sharded_optim_state = optim.state_dict()['state']
optim_state_param_groups = optim.state_dict()['param_groups']
# unflattened and consolidated state_dict
@@ -273,14 +279,14 @@ def full_optim_state_dict(self,

# param_names(2-dim list), param_shapes(2-dim list), param_numel(2-dim list)
layer_name_lists, layer_size_lists, layer_numel_lists = optim_utils.get_layer_full_info(
shard_meta_data, self.state_dict())
shard_meta_data, model.state_dict())

# flatten the 2-dim name list into a 1-dim name list
flatten_name_list = [
fn for layer_fn in layer_name_lists for fn in layer_fn
]
# (rank0_only and self.model.rank == 0) or (not rank0_only)
if not rank0_only or self.model.rank == 0:
if not rank0_only or model.model.rank == 0:
consolidate_optim_state_dict['param_groups'] = copy.deepcopy(
optim_state_param_groups)
consolidate_optim_state_dict['param_groups'][0]['params'].clear()
@@ -296,18 +302,18 @@ def full_optim_state_dict(self,
layer_numels = layer_numel_lists[idx]
for state_name, state_params in layer_state.items():
tensor_buffer = optim_utils.all_gather_state(
state_params, self.model.sharding_groups,
self.model.all_gather_op)
state_params, model.model.sharding_groups,
model.model.all_gather_op)
tensor_buffer = optim_utils.unpad(
tensor_buffer, layer_numels,
self.model.world_size * self.model._shard_size_multiple)
model.model.world_size * model.model._shard_size_multiple)
orig_params = optim_utils.unflatten_optim_params(
tensor_buffer, layer_names, layer_shapes, layer_numels)

if not rank0_only or self.model.rank == 0:
if not rank0_only or model.model.rank == 0:
for fn, fp in zip(layer_names, orig_params):
if cpu_offload:
ta.mark_step()
ta.mark_step() # tensor evaluation
unflat_state_dict[fn][state_name] = fp.cpu()
else:
unflat_state_dict[fn][state_name] = fp
@@ -316,25 +322,28 @@ def full_optim_state_dict(self,

return consolidate_optim_state_dict

def load_optim_state_dict(self,
@staticmethod
def load_optim_state_dict(model: torch.nn.Module,
optim_state_dict: Dict[str, Any],
rank0_only: bool = True) -> Dict[str, Any]:
"""
Convert an optimizer state-dict so that it can be loaded into the
optimizer associated with the FSDP model.
We check whether the optim_state_dict is sharded automatically
Whether the optim_state_dict is sharded is detected automatically.
Args:
model (torch.nn.Module): FSDP model whose parameters were
passed into the optimizer whose state_dict is ``optim_state_dict``.
optim_state_dict (Dict[str, Any]): The optimizer states to be loaded.
rank0_only (bool): whether to load the state_dict only from
rank 0 initially. (Default: ``True``) If set to ``True``,
nonzero ranks load a :class:`dict` with the keys but empty values.
Returns:
Dict[str, Any]: A :class:`dict` containing the optimizer state for
self.model which is sharded.
model, in sharded form.
"""
shard_meta_data = self.model.get_shard_metadata()
shard_meta_data = model.model.get_shard_metadata()

# for sharded optim_state, we return directly
if 'shard_metadata' in optim_state_dict.keys():
@@ -349,16 +358,17 @@ def load_optim_state_dict(self,
flat_optim_state: Dict[str, Any] = {'state': {}, 'param_groups': {}}

layer_name_lists, layer_size_lists, layer_numel_lists = optim_utils.get_layer_full_info(
shard_meta_data, self.state_dict())
shard_meta_data, model.state_dict())

if rank0_only:
unflat_optim_state = optim_utils.broadcast_processed_state(
unflat_optim_state, self.model.rank, self.model.sharding_groups)
unflat_optim_state, model.model.rank,
model.model.sharding_groups)
unflat_state = unflat_optim_state['state']

flat_optim_state['param_groups'] = copy.deepcopy(
unflat_optim_state['param_groups'])
# flatten and sharded state_dict

for idx, layer_names in enumerate(layer_name_lists):
flat_value: Dict[str, Any] = {}
# broadcast tensor to other ranks per layer per state
@@ -376,16 +386,16 @@ def load_optim_state_dict(self,
tensor_buffer = unflat_state[name][state_name]
if rank0_only:
tensor_buffer = optim_utils.broadcast_state(
state_params, self.model.xla_device,
self.model.rank, self.model.sharding_groups,
self.model.collective_broadcast_op)
state_params, model.model.xla_device,
model.model.rank, model.model.sharding_groups,
model.model.collective_broadcast_op)
tensor_buffer_list.append(tensor_buffer)

flat_tensor = optim_utils.flatten_optim_state(
tensor_buffer_list)

if len(flat_tensor):
flat_value[state_name] = self.model._get_shard(flat_tensor)
flat_value[state_name] = model.model._get_shard(flat_tensor)
ta.mark_step()

flat_optim_state['state'][idx] = flat_value
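
As an aside, the per-layer reconstruction in full_optim_state_dict boils down to an unpad-and-unflatten step after the all-gather. Below is an illustrative sketch under assumed shapes; the helper name is made up and this is not the library's optim_utils code.

# Illustrative only: the padded flat buffer is truncated to its real
# element count, split by each parameter's numel, and reshaped back.
from typing import List

import torch


def unflatten_layer(flat: torch.Tensor, shapes: List[torch.Size],
                    numels: List[int]) -> List[torch.Tensor]:
    flat = flat[:sum(numels)]           # drop shard-alignment padding
    chunks = torch.split(flat, numels)  # one chunk per original parameter
    return [c.view(s) for c, s in zip(chunks, shapes)]


# Two parameters of shape (2, 3) and (4,) packed into a padded flat buffer.
flat = torch.arange(12.0)  # 10 real elements + 2 padding elements
params = unflatten_layer(flat, [torch.Size([2, 3]), torch.Size([4])], [6, 4])
assert [tuple(p.shape) for p in params] == [(2, 3), (4,)]
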
18 changes: 10 additions & 8 deletions torchacc/utils/optim_utils.py
@@ -1,6 +1,6 @@
import torch
import torch.distributed as dist
from typing import NamedTuple, Optional
from typing import Any, Dict, NamedTuple, Optional
from torch.utils._pytree import tree_map_only
import torch_xla.core.xla_model as xm

@@ -23,13 +23,13 @@ def get_layer_full_info(shard_metadata, model_state_dict):
The state_dict from an FSDP model.
Returns:
For all ranks, we get the same shard_metadata and model_state_dict, so the return value is
same.
layer_name_list: 2-dimension list, contains the full name information.
if parameters if flattened, each layer may have mutiple names.
layer_size_list: 2-dimension list, contains the unflatten and unshard shape information of
For all ranks, we get the same shard_metadata and model_state_dict, and the return value is
the same:
layer_name_list (list): 2-dimensional list ([[layer_name_group1], [layer_name_group2], ...]) containing the full name information.
If parameters are flattened, each layer may have multiple original names and parameters.
layer_size_list (list): 2-dimensional list ([[layer_size_group1], [layer_size_group2], ...]) containing the unflattened and unsharded shape information of
each layer.
layer_numel_list: 2-dimension list, contains the unflatten and unshard numel information of
layer_numel_list (list): 2-dimensional list ([[layer_numel_group1], [layer_numel_group2], ...]) containing the unflattened and unsharded numel information of
each layer.
"""
layer_name_list = []
@@ -156,7 +156,7 @@ def _cleanup_gloo_distributed(pg):
dist.destroy_process_group(pg)


def broadcast_processed_state(optim_state: dict[str, any], rank,
def broadcast_processed_state(optim_state: dict[str, Any], rank,
sharding_groups):
objects: list[Any] = [None]
if rank == 0:
@@ -170,12 +170,14 @@ def broadcast_processed_state(optim_state: dict[str, any], rank,
ordinal = xm.get_ordinal()
new_group = []

# broadcast within each sharding_group
for group in sharding_groups:
if ordinal in group:
new_group = group
break

pg_group = _setup_gloo_distributed(new_group)
# src is the global rank of this sharding group's rank 0
dist.broadcast_object_list(
objects, src=dist.get_global_rank(pg_group, 0), group=pg_group)
_cleanup_gloo_distributed(pg_group)
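
For reference, a self-contained sketch of the broadcast pattern used above: find the sharding group containing the current rank, build a gloo group over it, and broadcast a Python object from that group's local rank 0. The helper name is hypothetical and this is not the module's _setup/_cleanup API; it assumes torch.distributed is already initialized and that ordinals coincide with global ranks.

# Sketch only, under the assumptions stated above.
from typing import Any, List

import torch.distributed as dist


def broadcast_within_sharding_group(obj: Any,
                                    sharding_groups: List[List[int]]) -> Any:
    rank = dist.get_rank()
    # dist.new_group is collective, so every rank creates every group and
    # keeps only the one it belongs to.
    my_pg = None
    for ranks in sharding_groups:
        pg = dist.new_group(ranks=ranks, backend="gloo")
        if rank in ranks:
            my_pg = pg
    src = dist.get_global_rank(my_pg, 0)  # global rank of the group's rank 0
    objects = [obj if rank == src else None]
    dist.broadcast_object_list(objects, src=src, group=my_pg)
    return objects[0]

The module above instead builds a temporary gloo group per call and tears it down with destroy_process_group once the broadcast completes.
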
