From 14c304e509e7c9c922258422f5040a87b522d138 Mon Sep 17 00:00:00 2001
From: shw
Date: Wed, 4 Sep 2024 17:48:17 +0800
Subject: [PATCH] refine optim state dict error messages and drop unused
 gather buffer

Clarify the NotImplementedError raised by optim_state_dict() and
load_optim_state_dict() when flatten_parameters is disabled, and remove
the pre-allocated tensor_buffer in _all_gather_state(): the name was
immediately rebound to the return value of model.all_gather_op(), so
the new_zeros allocation was dead code.
---
 torchacc/dist/fsdp.py         | 4 ++--
 torchacc/utils/optim_utils.py | 4 ----
 2 files changed, 2 insertions(+), 6 deletions(-)

diff --git a/torchacc/dist/fsdp.py b/torchacc/dist/fsdp.py
index c814a77..da9c7d9 100644
--- a/torchacc/dist/fsdp.py
+++ b/torchacc/dist/fsdp.py
@@ -255,7 +255,7 @@ def optim_state_dict(
             raise NotImplementedError(
                 "we only support 'FULL_SATE_DICT' StateDictType now")
         if not self.model.flatten_parameters:
-            raise NotImplementedError("we only support flatten_parameters now")
+            raise NotImplementedError("we only support flatten_parameters=True now")
         shard_meta_data = self.model.get_shard_metadata()
         sharded_optim_state = optim.state_dict()['state']
 
@@ -342,7 +342,7 @@ def load_optim_state_dict(
             raise NotImplementedError(
                 "we only support 'FULL_SATE_DICT' StateDictType now")
         if not self.model.flatten_parameters:
-            raise NotImplementedError("we only support flatten_parameters now")
+            raise NotImplementedError("we only support flatten_parameters=True now")
         shard_meta_data = self.model.get_shard_metadata()
         unflat_optim_state = optim_state_dict
         flat_optim_state: Dict[str, Any] = {'state': {}, 'param_groups': {}}
diff --git a/torchacc/utils/optim_utils.py b/torchacc/utils/optim_utils.py
index 43823a0..4e8292a 100644
--- a/torchacc/utils/optim_utils.py
+++ b/torchacc/utils/optim_utils.py
@@ -113,10 +113,6 @@ def _all_gather_state(state_params, model):
     if state_params.dim() == 0:
         return state_params
 
-    shape_list = list(state_params.size())
-    shape_list[0] = shape_list[0] * model.world_size
-    buffer_size = tuple(shape_list)
-    tensor_buffer = state_params.new_zeros(*buffer_size)
     tensor_buffer = model.all_gather_op(
         state_params, groups=model.sharding_groups)
 
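For readers skimming the optim_utils.py hunk: the removed lines allocated
tensor_buffer with new_zeros and then immediately rebound the name to the
return value of model.all_gather_op, so the pre-sized buffer was never read
or written. Below is a minimal standalone sketch (not part of the patch;
torch.cat over per-rank shards stands in for model.all_gather_op across
model.sharding_groups) of what the post-patch helper does:

import torch

def all_gather_state_sketch(state_params, shards):
    # Scalar state (e.g. Adam's "step" counter) is replicated rather than
    # sharded, so it passes through, mirroring the dim() == 0 guard.
    if state_params.dim() == 0:
        return state_params
    # The gather returns a freshly allocated tensor of shape
    # (world_size * dim0, ...); no pre-allocated zero buffer is needed.
    return torch.cat(shards, dim=0)

# Example: two "ranks", each holding a 2-element shard of a 4-element state.
full = torch.arange(4.0)
shards = list(full.chunk(2))
assert torch.equal(all_gather_state_sketch(shards[0], shards), full)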