sync latest update to model_context. #559

Merged
merged 1 commit into from
Aug 1, 2023
170 changes: 163 additions & 7 deletions atorch/atorch/auto/model_context.py
@@ -17,6 +17,7 @@

import atorch
from atorch.auto.device_context import get_device_context
from atorch.auto.opt_lib.utils import _propose_leaf_modules, _propose_wrap_cls, to_module_class_by_name
from atorch.common.log_utils import default_logger as logger
from atorch.data import ShmDataloader, expand_batch_dim, get_sample_batch
from atorch.distributed.distributed import (
@@ -27,6 +28,7 @@
rank,
)
from atorch.utils.graph_transform_utils import map_aggregate
from atorch.utils.version import torch_version

try:
from pippy.IR import LossWrapper
@@ -450,24 +452,79 @@ def adjust_wrappers(self):
self.post_wrappers.pop("ddp")

if pipe_wrapper_exist:
# TODO: support pipe
pass
if "checkpoint" in self.post_wrappers:
ckpt_wrapper = self.post_wrappers.pop("checkpoint")
pipe_config = self.pre_wrappers["pipe"][1] or {}
ckpt_wrap_cls = ckpt_wrapper[1]
pipe_wrap_cls = set(to_module_class_by_name(self.model, ckpt_wrap_cls))
leaf_modules = _propose_leaf_modules(pipe_wrap_cls)
pipe_config["compiler_configs"]["checkpoint"] = True
pipe_config["compiler_configs"]["leaf_modules"] = leaf_modules
self.pre_wrappers["pipe"] = (
self.pre_wrappers["pipe"][0],
pipe_config,
)

if "amp_native" in self.post_wrappers:
amp_wrapper = self.post_wrappers.pop("amp_native")
amp_config = amp_wrapper[1] or None
pipe_config = self.pre_wrappers["pipe"][1] or {}
pipe_config["compiler_configs"]["amp_config"] = amp_config
self.pre_wrappers["pipe"] = (
self.pre_wrappers["pipe"][0],
pipe_config,
)

if "module_replace" in self.pre_wrappers:
self.pre_wrappers.pop("module_replace")
pipe_config = self.pre_wrappers["pipe"][1] or {}
pipe_config["compiler_configs"]["module_replace"] = True
self.pre_wrappers["pipe"] = (
self.pre_wrappers["pipe"][0],
pipe_config,
)

# FIXME Allow mixing of DDP/ZeRO with MP?
if mp_wrapper_exist:
# TODO: support mp
pass
mp_config = self.pre_wrappers["mp"][1] or {}
pipe_config = mp_config["pipe_config"]
if pipe_config is not None:
if "checkpoint" in self.post_wrappers:
ckpt_wrapper = self.post_wrappers.pop("checkpoint")
ckpt_wrap_cls = ckpt_wrapper[1]
pipe_wrap_cls = set(to_module_class_by_name(self.model, ckpt_wrap_cls))
leaf_modules = _propose_leaf_modules(pipe_wrap_cls)
pipe_config["compiler_configs"]["checkpoint"] = True
pipe_config["compiler_configs"]["leaf_modules"] = leaf_modules

if "amp_native" in self.post_wrappers:
amp_wrapper = self.post_wrappers.pop("amp_native")
amp_config = amp_wrapper[1] or None
pipe_config["compiler_configs"]["amp_config"] = amp_config

if "module_replace" in self.pre_wrappers:
self.pre_wrappers.pop("module_replace")
pipe_config["compiler_configs"]["module_replace"] = True

mp_config["pipe_config"] = pipe_config

self.pre_wrappers["mp"] = (
self.pre_wrappers["mp"][0],
mp_config,
)

ddp_wrapper_exist = "ddp" in self.post_wrappers
fairscale_zero2_wrapper_exist = "zero2" in self.post_wrappers
fsdp_wrapper_exist = "fsdp" in self.pre_wrappers or "zero2" in self.pre_wrappers
tensor_parallel_wrapper_exist = "tp" in self.pre_wrappers
# ckpt_wrapper_exist = "checkpoint" in self.post_wrappers
ckpt_wrapper_exist = "checkpoint" in self.post_wrappers
native_dynamo_wrapper_exist = "native_dynamo" in self.pre_wrappers

# remove ddp wrapper when using zero2
if ddp_wrapper_exist and (fairscale_zero2_wrapper_exist or fsdp_wrapper_exist):
logger.info("Found Zero and ddp wrapper or pipe wrapper, remove ddp wrapper")
self.post_wrappers.pop("ddp")
ddp_wrapper_exist = False
if fsdp_wrapper_exist and "amp_native" in self.post_wrappers:
logger.info("Found fsdp and amp_native wrapper, turn on mixed_precision in FSDP")
_, amp_native_config = self.post_wrappers["amp_native"]
@@ -502,9 +559,108 @@ def adjust_wrappers(self):
wrappers_list.pop(ddp_or_zero2_wrapper_index)
self.post_wrappers = dict(wrappers_list)

# move dynamo_native wrapper behind ddp or fsdp
# Note that dynamo_native wrapper and fsdp wrapper are pre-wrappers while ddp wrapper is a post-wrapper.
if native_dynamo_wrapper_exist:
if fsdp_wrapper_exist:
# both dynamo_native wrapper and fsdp wrapper are pre-wrappers
native_dynamo_wrapper_index, fsdp_wrapper_index = -1, -1
pre_wrappers_list = []
for i, (wrapper_name, v) in enumerate(self.pre_wrappers.items()):
pre_wrappers_list.append((wrapper_name, v))
if wrapper_name == "fsdp":
fsdp_wrapper_index = i
elif wrapper_name == "native_dynamo":
native_dynamo_wrapper_index = i
if native_dynamo_wrapper_index < fsdp_wrapper_index:
native_dynamo_wrapper = pre_wrappers_list[native_dynamo_wrapper_index]
pre_wrappers_list.insert(fsdp_wrapper_index + 1, native_dynamo_wrapper)
pre_wrappers_list.pop(native_dynamo_wrapper_index)
self.pre_wrappers = dict(pre_wrappers_list)
elif ddp_wrapper_exist:
# ddp wrapper is a post-wrapper. Popping dynamo_native wrapper from pre-wrappers
# then insert it after ddp wrapper.
post_wrappers_list = []
ddp_wrapper_index = -1
for i, (wrapper_name, v) in enumerate(self.post_wrappers.items()):
post_wrappers_list.append((wrapper_name, v))
if wrapper_name == "ddp":
ddp_wrapper_index = i
native_dynamo_wrapper = self.pre_wrappers.pop("native_dynamo")
post_wrappers_list.insert(ddp_wrapper_index + 1, ("native_dynamo", native_dynamo_wrapper))
self.post_wrappers = dict(post_wrappers_list)

if tensor_parallel_wrapper_exist:
# todo: support tp
pass
wrap_cls = None
if fsdp_wrapper_exist and torch_version() >= (1, 12, 0):
fsdp_wrapper = self.pre_wrappers["fsdp"]
fsdp_wrapper = list(fsdp_wrapper)
if fsdp_wrapper[1] is None:
fsdp_wrapper[1] = dict()

fsdp_config = fsdp_wrapper[1]
if wrap_cls is None:
wrap_cls = set(to_module_class_by_name(self.model, fsdp_config.get("atorch_wrap_cls", set())))
else:
wrap_cls = wrap_cls & set(
to_module_class_by_name(self.model, fsdp_config.get("atorch_wrap_cls", set()))
)

if ckpt_wrapper_exist:
ckpt_wrapper = self.post_wrappers["checkpoint"]
ckpt_wrapper = list(ckpt_wrapper)
if ckpt_wrapper[1] is None:
ckpt_wrapper[1] = tuple()
ckpt_wrap_cls = ckpt_wrapper[1]
if wrap_cls is None:
wrap_cls = set(to_module_class_by_name(self.model, ckpt_wrap_cls))
else:
wrap_cls = set(to_module_class_by_name(self.model, ckpt_wrap_cls)) & wrap_cls

leaf_modules = _propose_leaf_modules(wrap_cls)
auto_wrap_cls = _propose_wrap_cls(leaf_modules)

if fsdp_wrapper_exist and torch_version() >= (1, 12, 0):
if "atorch_wrap_cls" in fsdp_config:
if auto_wrap_cls is not None:
fsdp_config["atorch_wrap_cls"] = auto_wrap_cls
else:
fsdp_config.pop("atorch_wrap_cls")

fsdp_wrapper[1] = fsdp_config
self.pre_wrappers["fsdp"] = tuple(fsdp_wrapper)

if ckpt_wrapper_exist:
if auto_wrap_cls is not None:
ckpt_wrapper[1] = tuple(auto_wrap_cls)
self.post_wrappers["checkpoint"] = tuple(ckpt_wrapper)
else:
# in this case module structure should have been broken, nothing to checkpoint on
self.post_wrappers.pop("checkpoint")

# let tensor parallel wrapper be the first, make sure meta models fully reloaded
tensor_parallel_wrapper_item = ("tp", self.pre_wrappers["tp"])
wrappers_list = list(self.pre_wrappers.items())
tensor_parallel_idx = wrappers_list.index(tensor_parallel_wrapper_item)
wrappers_list.pop(tensor_parallel_idx)
# wrapper item and wrapper are all tuples
tensor_parallel_wrapper_item = list(tensor_parallel_wrapper_item)
tensor_parallel_wrapper_item[1] = list(tensor_parallel_wrapper_item[1])
tensor_parallel_wrapper_item[1][1]["leaf_modules"] = leaf_modules
if fsdp_wrapper_exist or pipe_wrapper_exist:
tensor_parallel_wrapper_item[1][1]["defer_init"] = True
tensor_parallel_wrapper_item[1] = tuple(tensor_parallel_wrapper_item[1])
tensor_parallel_wrapper_item = tuple(tensor_parallel_wrapper_item)
wrappers_list.insert(0, tensor_parallel_wrapper_item)
self.pre_wrappers = dict(wrappers_list)

# TP checkpointing needs amp_config explicitly to take effect, HACK here
if "amp_native" in self.post_wrappers and "checkpoint" in self.post_wrappers:
amp_wrapper = self.post_wrappers.pop("amp_native")
amp_config = amp_wrapper[1] or None
from atorch.modules.distributed_modules.activation_checkpointing import _insert_amp_config_for_tp_ckpt

_insert_amp_config_for_tp_ckpt(amp_config)

def add_wrapper(self, wrapper_name, wrapper_func, wrapper_config=None, is_pre_wrapper=True):
"""
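The dynamo_native reordering added above relies on Python's insertion-ordered dicts: the wrapper registry is flattened to a list of (name, wrapper) pairs, the moved entry is popped and re-inserted after its anchor, and the list is turned back into a dict. A minimal standalone sketch of that pattern follows; move_wrapper_after and the placeholder wrapper tuples are illustrative names, not part of the PR.

def move_wrapper_after(wrappers, name_to_move, anchor_name):
    # Flatten to (name, wrapper) pairs so entries can be reordered by index.
    items = list(wrappers.items())
    move_idx = next(i for i, (name, _) in enumerate(items) if name == name_to_move)
    moved = items.pop(move_idx)
    anchor_idx = next(i for i, (name, _) in enumerate(items) if name == anchor_name)
    # Re-insert right after the anchor; dict() keeps this order on Python 3.7+.
    items.insert(anchor_idx + 1, moved)
    return dict(items)

pre_wrappers = {"native_dynamo": (None, None), "fsdp": (None, {})}
print(list(move_wrapper_after(pre_wrappers, "native_dynamo", "fsdp")))  # ['fsdp', 'native_dynamo']
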
56 changes: 56 additions & 0 deletions atorch/atorch/auto/opt_lib/utils.py
@@ -3,9 +3,65 @@
except ImportError:
pipe_split = None

from atorch.modules.distributed_modules.modules_registry import _SHARD_MAP, _SHARDABLE_OPERATORS, _SHARDED_OPERATORS
from atorch.utils.graph_transform_utils import map_aggregate


# FIXME This is a temporary fix for FSDP and TP compatibility,
# TP will only trace the module to the level of FSDP wrapping, with constraint of being shardable.
# The wrapping will be broken if the wrapped op is shardable, causing FSDP to be slow (in most cases).
def _propose_leaf_modules(atorch_wrap_cls=None):
leaf_modules = None
if atorch_wrap_cls is not None:
leaf_modules = list(set(_SHARDABLE_OPERATORS.values()) & set(atorch_wrap_cls))
if leaf_modules is None or len(leaf_modules) == 0:
leaf_modules = list(_SHARDABLE_OPERATORS.values())
return leaf_modules


def _propose_wrap_cls(leaf_modules=set()):
leaf_module_names = [name for name, cls in _SHARDABLE_OPERATORS.items() if cls in leaf_modules]
if len(leaf_modules) != 0:
atorch_wrap_cls = {
op
for name in leaf_module_names
for op in [_SHARDABLE_OPERATORS[name]] + [_SHARDED_OPERATORS[shard] for shard in _SHARD_MAP[name]]
}
else:
atorch_wrap_cls = None
if len(leaf_modules) == len(list(_SHARDABLE_OPERATORS.items())):
        # in this case, auto wrap is meaningless
atorch_wrap_cls = None
return atorch_wrap_cls


def propose_leaf_modules_by_strategy(model, strategy=None):
wrap_cls = None
if strategy is None:
return _propose_leaf_modules(wrap_cls)
for opt in strategy:
opt_name = opt[0]
if opt_name == "fsdp":
if len(opt) > 1:
opt_config = opt[1]
atorch_wrap_cls = set(to_module_class_by_name(model, opt_config.get("atorch_wrap_cls", set())))
if wrap_cls is None:
wrap_cls = atorch_wrap_cls
else:
wrap_cls = wrap_cls & atorch_wrap_cls
if opt_name == "checkpoint":
if len(opt) > 1:
opt_config = opt[1]
ckpt_wrap_cls = set(to_module_class_by_name(model, opt_config))
if wrap_cls is None:
wrap_cls = ckpt_wrap_cls
else:
wrap_cls = wrap_cls & ckpt_wrap_cls

leaf_modules = _propose_leaf_modules(wrap_cls)
return leaf_modules


def find_modules(model, m_list):
if isinstance(model, tuple(m_list)):
return [model]
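propose_leaf_modules_by_strategy above narrows the wrap set by intersecting the classes named in the fsdp and checkpoint configs before proposing leaf modules. A small illustration of that intersection semantics, using hypothetical Block and Head module classes rather than a real model or the real _SHARDABLE_OPERATORS registry:

import torch.nn as nn

class Block(nn.Module):
    pass

class Head(nn.Module):
    pass

fsdp_wrap_cls = {Block, Head}  # e.g. resolved from fsdp's "atorch_wrap_cls"
ckpt_wrap_cls = {Block}        # e.g. resolved from the checkpoint opt config
wrap_cls = fsdp_wrap_cls & ckpt_wrap_cls
print(wrap_cls)  # {Block}: only classes named by both configs feed _propose_leaf_modules
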
16 changes: 16 additions & 0 deletions atorch/atorch/tests/model_context_test.py
@@ -146,6 +146,22 @@ def test_adjust_amp_apex_zero2_ddp_wrapper(self):
self.assertLess(amp_apex_o2_index, zero_index)
self.assertEqual(wrappers_order, [0, 1, 2, 3, 4, 5])

def test_adjust_dynamo_and_fsdp_wrapper(self):
self.context.add_wrapper("native_dynamo", None, None, is_pre_wrapper=True)
self.context.add_wrapper("fsdp", None, None, is_pre_wrapper=True)
self.context.adjust_wrappers()
wrapper_names = [name for name, _ in self.context.pre_wrappers.items()]
wrappers_order = [wrapper_names.index("fsdp"), wrapper_names.index("native_dynamo")]
self.assertListEqual(wrappers_order, [0, 1])

def test_adjust_dynamo_and_ddp_wrapper(self):
self.context.add_wrapper("native_dynamo", None, None, is_pre_wrapper=True)
self.context.add_wrapper("ddp", None, None, is_pre_wrapper=False)
self.context.adjust_wrappers()
wrapper_names = [name for name, _ in self.context.post_wrappers.items()]
wrappers_order = [wrapper_names.index("ddp"), wrapper_names.index("native_dynamo")]
self.assertListEqual(wrappers_order, [0, 1])


def use_shm_dataloader_func():
if torch.cuda.is_available():
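The two new tests assert the wrapper order produced by adjust_wrappers(). Assuming the repository's unittest-based layout, they could be run selectively as sketched below; the module path is inferred from the file path in the diff and may need adjusting.

# Run only the new wrapper-ordering tests (unittest's -k filters by substring):
# python -m unittest atorch.tests.model_context_test -k adjust_dynamo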