From ffd4789d165978b4d358d8ab938fce02125b27c9 Mon Sep 17 00:00:00 2001 From: shw Date: Wed, 4 Sep 2024 17:51:31 +0800 Subject: [PATCH] format --- torchacc/dist/fsdp.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchacc/dist/fsdp.py b/torchacc/dist/fsdp.py index da9c7d9..77ac844 100644 --- a/torchacc/dist/fsdp.py +++ b/torchacc/dist/fsdp.py @@ -255,7 +255,8 @@ def optim_state_dict( raise NotImplementedError( "we only support 'FULL_SATE_DICT' StateDictType now") if not self.model.flatten_parameters: - raise NotImplementedError("we only support flatten_parameters=True now") + raise NotImplementedError( + "we only support flatten_parameters=True now") shard_meta_data = self.model.get_shard_metadata() sharded_optim_state = optim.state_dict()['state'] @@ -342,7 +343,8 @@ def load_optim_state_dict( raise NotImplementedError( "we only support 'FULL_SATE_DICT' StateDictType now") if not self.model.flatten_parameters: - raise NotImplementedError("we only support flatten_parameters=True now") + raise NotImplementedError( + "we only support flatten_parameters=True now") shard_meta_data = self.model.get_shard_metadata() unflat_optim_state = optim_state_dict flat_optim_state: Dict[str, Any] = {'state': {}, 'param_groups': {}}