sync latest update to model_context. (#559)
nash635 authored Aug 1, 2023
1 parent fe945e5 commit 61f07e1
Showing 3 changed files with 235 additions and 7 deletions.
170 changes: 163 additions & 7 deletions atorch/atorch/auto/model_context.py
@@ -17,6 +17,7 @@

import atorch
from atorch.auto.device_context import get_device_context
from atorch.auto.opt_lib.utils import _propose_leaf_modules, _propose_wrap_cls, to_module_class_by_name
from atorch.common.log_utils import default_logger as logger
from atorch.data import ShmDataloader, expand_batch_dim, get_sample_batch
from atorch.distributed.distributed import (
@@ -27,6 +28,7 @@
rank,
)
from atorch.utils.graph_transform_utils import map_aggregate
from atorch.utils.version import torch_version

try:
from pippy.IR import LossWrapper
@@ -450,24 +452,79 @@ def adjust_wrappers(self):
self.post_wrappers.pop("ddp")

if pipe_wrapper_exist:
# TODO: support pipe
pass
if "checkpoint" in self.post_wrappers:
ckpt_wrapper = self.post_wrappers.pop("checkpoint")
pipe_config = self.pre_wrappers["pipe"][1] or {}
ckpt_wrap_cls = ckpt_wrapper[1]
pipe_wrap_cls = set(to_module_class_by_name(self.model, ckpt_wrap_cls))
leaf_modules = _propose_leaf_modules(pipe_wrap_cls)
pipe_config["compiler_configs"]["checkpoint"] = True
pipe_config["compiler_configs"]["leaf_modules"] = leaf_modules
self.pre_wrappers["pipe"] = (
self.pre_wrappers["pipe"][0],
pipe_config,
)

if "amp_native" in self.post_wrappers:
amp_wrapper = self.post_wrappers.pop("amp_native")
amp_config = amp_wrapper[1] or None
pipe_config = self.pre_wrappers["pipe"][1] or {}
pipe_config["compiler_configs"]["amp_config"] = amp_config
self.pre_wrappers["pipe"] = (
self.pre_wrappers["pipe"][0],
pipe_config,
)

if "module_replace" in self.pre_wrappers:
self.pre_wrappers.pop("module_replace")
pipe_config = self.pre_wrappers["pipe"][1] or {}
pipe_config["compiler_configs"]["module_replace"] = True
self.pre_wrappers["pipe"] = (
self.pre_wrappers["pipe"][0],
pipe_config,
)

# FIXME Allow mixing of DDP/ZeRO with MP?
if mp_wrapper_exist:
# TODO: support mp
pass
mp_config = self.pre_wrappers["mp"][1] or {}
pipe_config = mp_config["pipe_config"]
if pipe_config is not None:
if "checkpoint" in self.post_wrappers:
ckpt_wrapper = self.post_wrappers.pop("checkpoint")
ckpt_wrap_cls = ckpt_wrapper[1]
pipe_wrap_cls = set(to_module_class_by_name(self.model, ckpt_wrap_cls))
leaf_modules = _propose_leaf_modules(pipe_wrap_cls)
pipe_config["compiler_configs"]["checkpoint"] = True
pipe_config["compiler_configs"]["leaf_modules"] = leaf_modules

if "amp_native" in self.post_wrappers:
amp_wrapper = self.post_wrappers.pop("amp_native")
amp_config = amp_wrapper[1] or None
pipe_config["compiler_configs"]["amp_config"] = amp_config

if "module_replace" in self.pre_wrappers:
self.pre_wrappers.pop("module_replace")
pipe_config["compiler_configs"]["module_replace"] = True

mp_config["pipe_config"] = pipe_config

self.pre_wrappers["mp"] = (
self.pre_wrappers["mp"][0],
mp_config,
)

ddp_wrapper_exist = "ddp" in self.post_wrappers
fairscale_zero2_wrapper_exist = "zero2" in self.post_wrappers
fsdp_wrapper_exist = "fsdp" in self.pre_wrappers or "zero2" in self.pre_wrappers
tensor_parallel_wrapper_exist = "tp" in self.pre_wrappers
# ckpt_wrapper_exist = "checkpoint" in self.post_wrappers
ckpt_wrapper_exist = "checkpoint" in self.post_wrappers
native_dynamo_wrapper_exist = "native_dynamo" in self.pre_wrappers

# remove ddp wrapper when using zero2
if ddp_wrapper_exist and (fairscale_zero2_wrapper_exist or fsdp_wrapper_exist):
logger.info("Found Zero and ddp wrapper or pipe wrapper, remove ddp wrapper")
self.post_wrappers.pop("ddp")
ddp_wrapper_exist = False
if fsdp_wrapper_exist and "amp_native" in self.post_wrappers:
logger.info("Found fsdp and amp_native wrapper, turn on mixed_precision in FSDP")
_, amp_native_config = self.post_wrappers["amp_native"]
@@ -502,9 +559,108 @@ def adjust_wrappers(self):
wrappers_list.pop(ddp_or_zero2_wrapper_index)
self.post_wrappers = dict(wrappers_list)

# move dynamo_native wrapper behind ddp or fsdp
# Note that dynamo_native wrapper and fsdp wrapper are pre-wrappers while ddp wrapper is a post-wrapper.
if native_dynamo_wrapper_exist:
if fsdp_wrapper_exist:
# both dynamo_native wrapper and fsdp wrapper are pre-wrappers
native_dynamo_wrapper_index, fsdp_wrapper_index = -1, -1
pre_wrappers_list = []
for i, (wrapper_name, v) in enumerate(self.pre_wrappers.items()):
pre_wrappers_list.append((wrapper_name, v))
if wrapper_name == "fsdp":
fsdp_wrapper_index = i
elif wrapper_name == "native_dynamo":
native_dynamo_wrapper_index = i
if native_dynamo_wrapper_index < fsdp_wrapper_index:
native_dynamo_wrapper = pre_wrappers_list[native_dynamo_wrapper_index]
pre_wrappers_list.insert(fsdp_wrapper_index + 1, native_dynamo_wrapper)
pre_wrappers_list.pop(native_dynamo_wrapper_index)
self.pre_wrappers = dict(pre_wrappers_list)
elif ddp_wrapper_exist:
# ddp wrapper is a post-wrapper. Pop the dynamo_native wrapper from the pre-wrappers,
# then insert it after the ddp wrapper.
post_wrappers_list = []
ddp_wrapper_index = -1
for i, (wrapper_name, v) in enumerate(self.post_wrappers.items()):
post_wrappers_list.append((wrapper_name, v))
if wrapper_name == "ddp":
ddp_wrapper_index = i
native_dynamo_wrapper = self.pre_wrappers.pop("native_dynamo")
post_wrappers_list.insert(ddp_wrapper_index + 1, ("native_dynamo", native_dynamo_wrapper))
self.post_wrappers = dict(post_wrappers_list)

if tensor_parallel_wrapper_exist:
# todo: support tp
pass
wrap_cls = None
if fsdp_wrapper_exist and torch_version() >= (1, 12, 0):
fsdp_wrapper = self.pre_wrappers["fsdp"]
fsdp_wrapper = list(fsdp_wrapper)
if fsdp_wrapper[1] is None:
fsdp_wrapper[1] = dict()

fsdp_config = fsdp_wrapper[1]
if wrap_cls is None:
wrap_cls = set(to_module_class_by_name(self.model, fsdp_config.get("atorch_wrap_cls", set())))
else:
wrap_cls = wrap_cls & set(
to_module_class_by_name(self.model, fsdp_config.get("atorch_wrap_cls", set()))
)

if ckpt_wrapper_exist:
ckpt_wrapper = self.post_wrappers["checkpoint"]
ckpt_wrapper = list(ckpt_wrapper)
if ckpt_wrapper[1] is None:
ckpt_wrapper[1] = tuple()
ckpt_wrap_cls = ckpt_wrapper[1]
if wrap_cls is None:
wrap_cls = set(to_module_class_by_name(self.model, ckpt_wrap_cls))
else:
wrap_cls = set(to_module_class_by_name(self.model, ckpt_wrap_cls)) & wrap_cls

leaf_modules = _propose_leaf_modules(wrap_cls)
auto_wrap_cls = _propose_wrap_cls(leaf_modules)

if fsdp_wrapper_exist and torch_version() >= (1, 12, 0):
if "atorch_wrap_cls" in fsdp_config:
if auto_wrap_cls is not None:
fsdp_config["atorch_wrap_cls"] = auto_wrap_cls
else:
fsdp_config.pop("atorch_wrap_cls")

fsdp_wrapper[1] = fsdp_config
self.pre_wrappers["fsdp"] = tuple(fsdp_wrapper)

if ckpt_wrapper_exist:
if auto_wrap_cls is not None:
ckpt_wrapper[1] = tuple(auto_wrap_cls)
self.post_wrappers["checkpoint"] = tuple(ckpt_wrapper)
else:
# in this case the module structure should have been broken, so there is nothing to checkpoint on
self.post_wrappers.pop("checkpoint")

# let the tensor parallel wrapper be the first, to make sure meta models are fully reloaded
tensor_parallel_wrapper_item = ("tp", self.pre_wrappers["tp"])
wrappers_list = list(self.pre_wrappers.items())
tensor_parallel_idx = wrappers_list.index(tensor_parallel_wrapper_item)
wrappers_list.pop(tensor_parallel_idx)
# wrapper item and wrapper are all tuples
tensor_parallel_wrapper_item = list(tensor_parallel_wrapper_item)
tensor_parallel_wrapper_item[1] = list(tensor_parallel_wrapper_item[1])
tensor_parallel_wrapper_item[1][1]["leaf_modules"] = leaf_modules
if fsdp_wrapper_exist or pipe_wrapper_exist:
tensor_parallel_wrapper_item[1][1]["defer_init"] = True
tensor_parallel_wrapper_item[1] = tuple(tensor_parallel_wrapper_item[1])
tensor_parallel_wrapper_item = tuple(tensor_parallel_wrapper_item)
wrappers_list.insert(0, tensor_parallel_wrapper_item)
self.pre_wrappers = dict(wrappers_list)

# TP checkpointing needs amp_config explicitly to take effect, HACK here
if "amp_native" in self.post_wrappers and "checkpoint" in self.post_wrappers:
amp_wrapper = self.post_wrappers.pop("amp_native")
amp_config = amp_wrapper[1] or None
from atorch.modules.distributed_modules.activation_checkpointing import _insert_amp_config_for_tp_ckpt

_insert_amp_config_for_tp_ckpt(amp_config)

def add_wrapper(self, wrapper_name, wrapper_func, wrapper_config=None, is_pre_wrapper=True):
"""
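The pipe/mp branches of adjust_wrappers above fold the checkpoint, amp_native and module_replace wrappers into the pipe wrapper's compiler_configs instead of applying them as separate wrappers. The snippet below is a minimal, self-contained sketch of that merging pattern; the wrapper tuples and config values are simplified stand-ins (e.g. the "DecoderLayer" wrap class and the amp dict), not the exact structures atorch builds.

# Minimal sketch of the config-merging pattern, with simplified structures.
post_wrappers = {
    "checkpoint": (None, ("DecoderLayer",)),      # (wrapper_func, wrap_cls)
    "amp_native": (None, {"dtype": "float16"}),   # (wrapper_func, amp_config)
}
pre_wrappers = {
    "module_replace": (None, None),
    "pipe": (None, {"compiler_configs": {}}),
}

pipe_config = pre_wrappers["pipe"][1] or {}

if "checkpoint" in post_wrappers:
    ckpt_wrap_cls = post_wrappers.pop("checkpoint")[1]
    pipe_config["compiler_configs"]["checkpoint"] = True
    # the real code derives leaf_modules via _propose_leaf_modules(wrap_cls)
    pipe_config["compiler_configs"]["leaf_modules"] = list(ckpt_wrap_cls)

if "amp_native" in post_wrappers:
    pipe_config["compiler_configs"]["amp_config"] = post_wrappers.pop("amp_native")[1]

if "module_replace" in pre_wrappers:
    pre_wrappers.pop("module_replace")
    pipe_config["compiler_configs"]["module_replace"] = True

pre_wrappers["pipe"] = (pre_wrappers["pipe"][0], pipe_config)
print(pre_wrappers["pipe"][1]["compiler_configs"])
# {'checkpoint': True, 'leaf_modules': ['DecoderLayer'],
#  'amp_config': {'dtype': 'float16'}, 'module_replace': True}
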
56 changes: 56 additions & 0 deletions atorch/atorch/auto/opt_lib/utils.py
@@ -3,9 +3,65 @@
except ImportError:
pipe_split = None

from atorch.modules.distributed_modules.modules_registry import _SHARD_MAP, _SHARDABLE_OPERATORS, _SHARDED_OPERATORS
from atorch.utils.graph_transform_utils import map_aggregate


# FIXME This is a temporary fix for FSDP and TP compatibility:
# TP only traces the module down to the level of the FSDP wrapping, under the constraint that it is shardable.
# The wrapping will be broken if the wrapped op is shardable, causing FSDP to be slow (in most cases).
def _propose_leaf_modules(atorch_wrap_cls=None):
leaf_modules = None
if atorch_wrap_cls is not None:
leaf_modules = list(set(_SHARDABLE_OPERATORS.values()) & set(atorch_wrap_cls))
if leaf_modules is None or len(leaf_modules) == 0:
leaf_modules = list(_SHARDABLE_OPERATORS.values())
return leaf_modules


def _propose_wrap_cls(leaf_modules=set()):
leaf_module_names = [name for name, cls in _SHARDABLE_OPERATORS.items() if cls in leaf_modules]
if len(leaf_modules) != 0:
atorch_wrap_cls = {
op
for name in leaf_module_names
for op in [_SHARDABLE_OPERATORS[name]] + [_SHARDED_OPERATORS[shard] for shard in _SHARD_MAP[name]]
}
else:
atorch_wrap_cls = None
if len(leaf_modules) == len(list(_SHARDABLE_OPERATORS.items())):
# in this case, auto wrap is meaningless
atorch_wrap_cls = None
return atorch_wrap_cls


def propose_leaf_modules_by_strategy(model, strategy=None):
wrap_cls = None
if strategy is None:
return _propose_leaf_modules(wrap_cls)
for opt in strategy:
opt_name = opt[0]
if opt_name == "fsdp":
if len(opt) > 1:
opt_config = opt[1]
atorch_wrap_cls = set(to_module_class_by_name(model, opt_config.get("atorch_wrap_cls", set())))
if wrap_cls is None:
wrap_cls = atorch_wrap_cls
else:
wrap_cls = wrap_cls & atorch_wrap_cls
if opt_name == "checkpoint":
if len(opt) > 1:
opt_config = opt[1]
ckpt_wrap_cls = set(to_module_class_by_name(model, opt_config))
if wrap_cls is None:
wrap_cls = ckpt_wrap_cls
else:
wrap_cls = wrap_cls & ckpt_wrap_cls

leaf_modules = _propose_leaf_modules(wrap_cls)
return leaf_modules


def find_modules(model, m_list):
if isinstance(model, tuple(m_list)):
return [model]
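The two helpers added above, _propose_leaf_modules and _propose_wrap_cls, intersect the caller's wrap classes with the registry of shardable operators and then expand the result with the corresponding sharded implementations. Below is a toy reproduction of that logic with made-up registries and classes; the real _SHARDABLE_OPERATORS, _SHARD_MAP and _SHARDED_OPERATORS come from atorch.modules.distributed_modules.modules_registry and map names to actual nn.Module classes.

# Toy registries standing in for the real modules_registry contents.
class Attention: ...
class MLP: ...
class ShardedAttention: ...

SHARDABLE_OPERATORS = {"attention": Attention, "mlp": MLP}
SHARD_MAP = {"attention": ["sharded_attention"], "mlp": []}
SHARDED_OPERATORS = {"sharded_attention": ShardedAttention}

def propose_leaf_modules(atorch_wrap_cls=None):
    # keep the shardable operators the caller wants wrapped; if nothing is
    # requested (or the intersection is empty), fall back to all of them
    leaf_modules = None
    if atorch_wrap_cls is not None:
        leaf_modules = list(set(SHARDABLE_OPERATORS.values()) & set(atorch_wrap_cls))
    if not leaf_modules:
        leaf_modules = list(SHARDABLE_OPERATORS.values())
    return leaf_modules

def propose_wrap_cls(leaf_modules):
    # wrap each proposed leaf together with all of its sharded variants; if every
    # shardable operator is a leaf, auto wrap adds nothing and None is returned
    names = [name for name, cls in SHARDABLE_OPERATORS.items() if cls in leaf_modules]
    if not leaf_modules or len(leaf_modules) == len(SHARDABLE_OPERATORS):
        return None
    return {
        op
        for name in names
        for op in [SHARDABLE_OPERATORS[name]] + [SHARDED_OPERATORS[s] for s in SHARD_MAP[name]]
    }

leaves = propose_leaf_modules({Attention})   # -> [Attention]
print(propose_wrap_cls(leaves))              # -> {Attention, ShardedAttention}
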
16 changes: 16 additions & 0 deletions atorch/atorch/tests/model_context_test.py
@@ -146,6 +146,22 @@ def test_adjust_amp_apex_zero2_ddp_wrapper(self):
self.assertLess(amp_apex_o2_index, zero_index)
self.assertEqual(wrappers_order, [0, 1, 2, 3, 4, 5])

def test_adjust_dynamo_and_fsdp_wrapper(self):
self.context.add_wrapper("native_dynamo", None, None, is_pre_wrapper=True)
self.context.add_wrapper("fsdp", None, None, is_pre_wrapper=True)
self.context.adjust_wrappers()
wrapper_names = [name for name, _ in self.context.pre_wrappers.items()]
wrappers_order = [wrapper_names.index("fsdp"), wrapper_names.index("native_dynamo")]
self.assertListEqual(wrappers_order, [0, 1])

def test_adjust_dynamo_and_ddp_wrapper(self):
self.context.add_wrapper("native_dynamo", None, None, is_pre_wrapper=True)
self.context.add_wrapper("ddp", None, None, is_pre_wrapper=False)
self.context.adjust_wrappers()
wrapper_names = [name for name, _ in self.context.post_wrappers.items()]
wrappers_order = [wrapper_names.index("ddp"), wrapper_names.index("native_dynamo")]
self.assertListEqual(wrappers_order, [0, 1])


def use_shm_dataloader_func():
if torch.cuda.is_available():
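The two new tests assert that adjust_wrappers moves the native_dynamo wrapper behind fsdp (both pre-wrappers) or behind ddp (a post-wrapper). The sketch below keeps only that ordering idea, operating on plain dicts with placeholder wrapper values rather than the real wrapper tuples and the separate pre/post dicts.

# Sketch of the reordering the tests verify: dicts preserve insertion order,
# so reordering is done by rebuilding the (name, wrapper) item list.
def move_behind(wrappers, name, anchor):
    """Return a new wrapper dict in which `name` comes right after `anchor`."""
    if name not in wrappers or anchor not in wrappers:
        return dict(wrappers)
    items = [(k, v) for k, v in wrappers.items() if k != name]
    idx = [k for k, _ in items].index(anchor)
    items.insert(idx + 1, (name, wrappers[name]))
    return dict(items)

pre_wrappers = {"native_dynamo": (None, None), "fsdp": (None, None)}
print(list(move_behind(pre_wrappers, "native_dynamo", "fsdp")))  # ['fsdp', 'native_dynamo']
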
