[FSDP][optim_state_dict][8/N] Enable fully_shard optim state_dict save and load (pytorch#91234)

**What does this PR do?**
This PR refactors `_optim_utils.py` to use `_FSDPState` instead of the `FullyShardedDataParallel` class. This change enables optim state_dict support for `fully_shard`.
Pull Request resolved: pytorch#91234
Approved by: https://github.com/rohan-varma
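
For orientation, a minimal sketch of the flow this change enables (not code from this PR): it assumes a `torchrun`-style environment for `init_process_group`, uses a toy `nn.Linear` as the model, and exercises the private `FSDP._optim_state_dict` / `FSDP._optim_state_dict_to_load` hooks that the tests below call; the `fully_shard` import path is assumed to match the composable API at the time of this commit.

```python
import torch
import torch.distributed as dist
from torch.distributed._composable import fully_shard
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes a torchrun-style environment (RANK/WORLD_SIZE/MASTER_ADDR set) and a
# CUDA device per rank; the tests below handle this via FSDPTest instead.
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# Toy model; fully_shard is applied in place instead of wrapping with FSDP(...).
model = torch.nn.Linear(100, 100, device="cuda")
fully_shard(model)
optim = torch.optim.Adam(model.parameters(), lr=1e-2)

# One training step so the optimizer has state worth saving.
model(torch.randn(2, 100, device="cuda")).sum().backward()
optim.step()

# Save: the private hook now accepts a fully_shard-managed module and returns
# an optimizer state dict keyed by parameter FQNs.
osd = FSDP._optim_state_dict(model, optim)

# Load: convert the saved dict into the sharded form the local optimizer expects.
optim.load_state_dict(FSDP._optim_state_dict_to_load(osd, model, optim))
```

The first changed test file below checks the save path by comparing the optimizer state dict of a `fully_shard` model against that of an equivalent `FSDP`-wrapped model.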
fegin authored and pytorchmergebot committed Dec 30, 2022
1 parent f8740db commit 0e8565d
Showing 5 changed files with 202 additions and 114 deletions.
57 changes: 57 additions & 0 deletions test/distributed/_composable/test_fully_shard.py
@@ -1,5 +1,6 @@
# Owner(s): ["oncall: distributed"]

import unittest
import contextlib
import copy
import functools
@@ -712,5 +713,61 @@ def _check_model_parity(self, m1: nn.Module, m2: nn.Module):
            self.assertEqual(p1, p2)


class TestFSDPOptimStateDict(FSDPTest):
    """Composable FSDP optimizer state dict tests."""

    @property
    def world_size(self) -> int:
        return 2

    def _test_optim_state_save_load(self, model1, optim1, model2, optim2) -> None:
        batch = torch.randn(2, 100, device="cuda")
        for model, optim in (
            (model1, optim1),
            (model2, optim2),
        ):
            optim.zero_grad(set_to_none=True)
            model(batch).sum().backward()
            optim.step()

        optim_state_dict1 = FSDP._optim_state_dict(model1, optim1)
        optim_state_dict2 = FSDP._optim_state_dict(model2, optim2)

        self.assertEqual(len(optim_state_dict1["state"]), len(optim_state_dict2["state"]))
        for fqn, state in optim_state_dict1["state"].items():
            self.assertEqual(state, optim_state_dict2["state"][fqn], fqn)

        for group1, group2 in itertools.zip_longest(
            optim_state_dict1["param_groups"], optim_state_dict2["param_groups"]
        ):
            for key, value in group1.items():
                self.assertEqual(value, group2[key])

    @unittest.skip("The test currently fails on CI.")
    @skip_if_lt_x_gpu(2)
    def test_optim_state_dict_save_load(self):
        orig_model = CompositeParamModel(device=torch.device("cuda"))
        composable_model = copy.deepcopy(orig_model)
        fully_shard(composable_model, policy=ModuleWrapPolicy({UnitModule}))
        composable_optim = torch.optim.Adam(composable_model.parameters(), lr=1e-2)
        orig_model = FSDP(orig_model)
        orig_optim = torch.optim.Adam(orig_model.parameters(), lr=1e-2)

        self._test_optim_state_save_load(orig_model, orig_optim, composable_model, composable_optim)

    @unittest.skip("The test currently fails on CI.")
    @skip_if_lt_x_gpu(2)
    def test_optim_state_dict_submodule_fully_shard(self):
        orig_model = CompositeParamModel(device=torch.device("cuda"))
        composable_model = copy.deepcopy(orig_model)
        fully_shard(composable_model.u1)
        fully_shard(composable_model.u2)
        composable_optim = torch.optim.Adam(composable_model.parameters(), lr=1e-2)
        orig_model = FSDP(orig_model)
        orig_optim = torch.optim.Adam(orig_model.parameters(), lr=1e-2)

        self._test_optim_state_save_load(orig_model, orig_optim, composable_model, composable_optim)


if __name__ == "__main__":
    run_tests()
18 changes: 13 additions & 5 deletions test/distributed/fsdp/test_fsdp_optim_state.py
@@ -2,6 +2,7 @@

import bisect
import sys
import unittest
from enum import auto, Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Type

@@ -782,8 +783,9 @@ def test_flatten_sharded_optim_state_dict_transformer(self) -> None:
            num_iters=3,
        )

    @unittest.skip("The test currently fails on CI.")
    @skip_if_lt_x_gpu(2)
-    def _test_use_orig_params(self) -> None:
+    def test_use_orig_params(self) -> None:
        """Tests :meth:`optim_state_dict` for an FSDP-root nested model."""
        self._test_load_optim_state(
            _ModelClass.NESTED,
@@ -928,8 +930,8 @@ def _test_load_optim_state(
                optim=optim2,
            )
        elif osd_comm_method == _OSDCommMethod.OPTIM_STATE_DICT:
-            sharded_osd1 = FSDP._load_optim_state_dict(fsdp_osd1, model2, optim2)
-            sharded_osd2 = FSDP._load_optim_state_dict(fsdp_osd2, model2, optim2)
+            sharded_osd1 = FSDP._optim_state_dict_to_load(fsdp_osd1, model2, optim2)
+            sharded_osd2 = FSDP._optim_state_dict_to_load(fsdp_osd2, model2, optim2)

        # As a sanity check, check that sharding the second model's full/sharded
        # optimizer state dict according to itself is equivalent to its local
@@ -1440,8 +1442,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
            loss.backward()
            optim.step()

    @unittest.skip("The test currently fails on CI.")
    @skip_if_lt_x_gpu(2)
-    def _test_compatible_with_named_optimizer(self):
+    def test_compatible_with_named_optimizer(self):
        class TestDummyModel(torch.nn.Module):
            def __init__(self):
                super(TestDummyModel, self).__init__()
@@ -1475,7 +1478,12 @@ def forward(self, x):
            loss = model(batch).sum()
            loss.backward()
            optim.step()
-            state_dicts.append(FSDP._optim_state_dict(model, optim))
+            if isinstance(optim, _NamedOptimizer):
+                state_dict = optim.state_dict()
+                state_dict = FSDP._optim_state_dict_post_hook(model, optim, state_dict)
+                state_dicts.append(state_dict)
+            else:
+                state_dicts.append(FSDP._optim_state_dict(model, optim))

        self._check_same_param_groups(
            state_dicts[0], state_dicts[1], check_same_param_keys=False
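
The last hunk above handles `_NamedOptimizer` separately: its `state_dict()` is already keyed by parameter FQN, so the test only applies the FSDP post-processing hook on top of it, while plain optimizers go through `FSDP._optim_state_dict`. A hedged sketch of that branch factored into a standalone helper (the helper name and the `_NamedOptimizer` import path are assumptions; the two FSDP calls are taken verbatim from the diff):

```python
from typing import Any, Dict

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.optim import _NamedOptimizer  # assumed import path


def _get_fsdp_optim_state_dict(model: nn.Module, optim) -> Dict[str, Any]:
    """Hypothetical helper mirroring the if/else added in the test above."""
    if isinstance(optim, _NamedOptimizer):
        # _NamedOptimizer already produces an FQN-keyed state dict, so only the
        # FSDP post-hook is applied on top of it.
        state_dict = optim.state_dict()
        return FSDP._optim_state_dict_post_hook(model, optim, state_dict)
    # Plain optimizers (e.g. torch.optim.Adam) use the full private save path.
    return FSDP._optim_state_dict(model, optim)
```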
21 changes: 20 additions & 1 deletion torch/distributed/fsdp/_common_utils.py
@@ -16,14 +16,15 @@
)

import torch
import torch.distributed as dist
import torch.distributed.fsdp.flat_param as flat_param_file
import torch.nn as nn
from torch.distributed._composable_state import _get_module_state, _State
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    _CHECKPOINT_PREFIX,
)

-from .api import FullStateDictConfig, StateDictConfig, StateDictType
+from .api import FullStateDictConfig, ShardingStrategy, StateDictConfig, StateDictType

FSDP_WRAPPED_MODULE = "_fsdp_wrapped_module"
FSDP_PREFIX = FSDP_WRAPPED_MODULE + "."
@@ -41,7 +42,14 @@ def __init__(self) -> None:
        self._is_root: Optional[bool] = None
        self._handles: List[flat_param_file.FlatParamHandle] = []
        self._ignored_modules: Set[nn.Module] = set()
        self._fully_sharded_module_to_handles: Dict[
            nn.Module, flat_param_file.FlatParamHandle
        ] = {}
        self.rank: int = -1
        self.world_size: int = -1
        self.sharding_strategy = ShardingStrategy.FULL_SHARD
        self.compute_device = torch.device("cuda", torch.cuda.current_device())
        self.process_group: Optional[dist.ProcessGroup] = None


def _get_module_fsdp_state(module: nn.Module) -> Optional[_FSDPState]:
@@ -51,6 +59,17 @@ def _get_module_fsdp_state(module: nn.Module) -> Optional[_FSDPState]:
    return state


def _get_module_fsdp_state_if_comm_module(module: nn.Module) -> Optional[_FSDPState]:
    state = _get_module_fsdp_state(module)
    if state is None:
        return None
    if state == module:  # FullyShardedDataParallel module case.
        return state
    if module in state._fully_sharded_module_to_handles:  # fully_shard case.
        return state
    return None


class TrainingState(Enum):
    """
    An enum that indicates the state of a ``FullyShardedDataParallel`` instance.
(Diffs for the remaining two changed files are not shown.)
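
The `_common_utils.py` hunk above is what lets the refactored `_optim_utils.py` work from an `_FSDPState` rather than a `FullyShardedDataParallel` instance: `_get_module_fsdp_state_if_comm_module` returns a state only for a module that actually owns communication, i.e. an `FSDP` wrapper itself or a module passed to `fully_shard`. A hedged sketch of how a consumer might use it (the traversal and function name are illustrative, not code from this PR):

```python
from typing import List

import torch.nn as nn
from torch.distributed.fsdp._common_utils import (
    _FSDPState,
    _get_module_fsdp_state_if_comm_module,
)


def _collect_comm_states(model: nn.Module) -> List[_FSDPState]:
    """Hypothetical traversal collecting the states that own FSDP handles."""
    states: List[_FSDPState] = []
    for submodule in model.modules():
        state = _get_module_fsdp_state_if_comm_module(submodule)
        # None means `submodule` either has no FSDP state at all or is not the
        # communication-owning module, so it is skipped here.
        if state is not None:
            states.append(state)
    return states
```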