Commit 14c304e
refine
hanwen-sun committed Sep 4, 2024
1 parent a729b08 commit 14c304e
Showing 2 changed files with 2 additions and 6 deletions.
4 changes: 2 additions & 2 deletions torchacc/dist/fsdp.py
@@ -255,7 +255,7 @@ def optim_state_dict(
             raise NotImplementedError(
                 "we only support 'FULL_SATE_DICT' StateDictType now")
         if not self.model.flatten_parameters:
-            raise NotImplementedError("we only support flatten_parameters now")
+            raise NotImplementedError("we only support flatten_parameters=True now")
 
         shard_meta_data = self.model.get_shard_metadata()
         sharded_optim_state = optim.state_dict()['state']
@@ -342,7 +342,7 @@ def load_optim_state_dict(
             raise NotImplementedError(
                 "we only support 'FULL_SATE_DICT' StateDictType now")
         if not self.model.flatten_parameters:
-            raise NotImplementedError("we only support flatten_parameters now")
+            raise NotImplementedError("we only support flatten_parameters=True now")
         shard_meta_data = self.model.get_shard_metadata()
         unflat_optim_state = optim_state_dict
         flat_optim_state: Dict[str, Any] = {'state': {}, 'param_groups': {}}
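
Both hunks make the same fix to the error message: flatten_parameters is a boolean flag on the wrapped model, and the old wording ("we only support flatten_parameters now") could be read as naming an unsupported feature rather than a required setting. A minimal sketch of the guard pattern, with the surrounding method context assumed:

    # Sketch of the guard; self.model is the FSDP-wrapped module from the diff.
    if not self.model.flatten_parameters:
        # Optimizer state (de)serialization here assumes the flattened
        # parameter layout, so name the exact setting the caller must enable.
        raise NotImplementedError("we only support flatten_parameters=True now")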
4 changes: 0 additions & 4 deletions torchacc/utils/optim_utils.py
@@ -113,10 +113,6 @@ def _all_gather_state(state_params, model):
     if state_params.dim() == 0:
         return state_params
 
-    shape_list = list(state_params.size())
-    shape_list[0] = shape_list[0] * model.world_size
-    buffer_size = tuple(shape_list)
-    tensor_buffer = state_params.new_zeros(*buffer_size)
     tensor_buffer = model.all_gather_op(
         state_params, groups=model.sharding_groups)
 
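
The four deleted lines pre-computed the gathered shape and zero-filled a buffer, but model.all_gather_op returns its own result tensor, so tensor_buffer was rebound on the very next statement and the allocation was dead work. A sketch of the equivalent logic in plain torch.distributed, assuming all_gather_op acts as a functional all-gather along dim 0 across the sharding group:

    # Plain-PyTorch sketch; torchacc's model.all_gather_op is assumed to be
    # a functional all-gather that returns a freshly allocated result.
    import torch
    import torch.distributed as dist

    def _all_gather_state_sketch(state_params: torch.Tensor,
                                 world_size: int) -> torch.Tensor:
        # Scalar state (e.g. Adam's "step") is replicated, not sharded.
        if state_params.dim() == 0:
            return state_params
        # No pre-allocated [world_size * n, ...] zero buffer is needed:
        # the collective below produces the output tensor itself.
        gathered = [torch.empty_like(state_params) for _ in range(world_size)]
        dist.all_gather(gathered, state_params)
        return torch.cat(gathered, dim=0)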
