diff --git a/torchacc/dist/fsdp.py b/torchacc/dist/fsdp.py index 3c45a0a..9d97fbb 100644 --- a/torchacc/dist/fsdp.py +++ b/torchacc/dist/fsdp.py @@ -289,7 +289,7 @@ def full_optim_state_dict(model: torch.nn.Module, xla_fsdp.XlaFullyShardedDataParallel) or isinstance( model, FullyShardedDataParallel) - if isinstance(model, xla_fsdp.XlaFullyShardedDataParallel): + if isinstance(model, FullyShardedDataParallel): model = model.model shard_meta_data = model.get_shard_metadata()