From 39d689654f1eb3eaa95bb3a8e45be0f9712c393e Mon Sep 17 00:00:00 2001 From: shw Date: Thu, 10 Oct 2024 20:13:02 +0800 Subject: [PATCH 01/24] consolidate model and optimizer done --- torchacc/dist/consolidate_sharded_ckpts.py | 62 +++++ torchacc/dist/state_dict_utils.py | 302 +++++++++++++++++++++ 2 files changed, 364 insertions(+) create mode 100644 torchacc/dist/consolidate_sharded_ckpts.py create mode 100644 torchacc/dist/state_dict_utils.py diff --git a/torchacc/dist/consolidate_sharded_ckpts.py b/torchacc/dist/consolidate_sharded_ckpts.py new file mode 100644 index 0000000..f02ae06 --- /dev/null +++ b/torchacc/dist/consolidate_sharded_ckpts.py @@ -0,0 +1,62 @@ +from argparse import ArgumentParser + +from .state_dict_utils import consolidate_sharded_model_checkpoints, consolidate_sharded_optimizers_checkpoints + + +def main(): + parser = ArgumentParser() + parser.add_argument( + "--ckpt_prefix", + type=str, + required=True, + help=( + "The path prefix of the XLA FSDP checkpoint files to be consolidated. " + "Files matching the pattern ``ckpt_prefix + ckpt_suffix`` will be loaded." + ), + ) + parser.add_argument( + "--ckpt_suffix", + type=str, + default="*.pth", + help=( + "The path suffix of the XLA FSDP checkpoint files to be consolidated. " + "Files matching the pattern ``ckpt_prefix + ckpt_suffix`` will be loaded." + ), + ) + parser.add_argument( + "--ckpt_type", + type=str, + default="model", + help=( + "Consolidate model or optimizer. Please consolidate model fisrt, and then consolidate optimizer." + ), + ) + parser.add_argument( + "--reshard_num", + type=int, + default=1, + help=( + "We now support the reshard of FSDP model." + ) + ) + parser.add_argument( + "--save_path", + type=str, + default="", + help=("The save path of the output consolidated model state dict " + "(default is ``ckpt_prefix + '_consolidated.pth'``)"), + ) + args = parser.parse_args() + assert args.ckpt_type in ['model', 'optimizer'], ( + 'the ckpt_type should be model or optimizer' + ) + + if args.ckpt_type == "model": + consolidate_sharded_model_checkpoints(args.ckpt_prefix, args.ckpt_suffix, args.reshard_num, + args.save_path) + else: + consolidate_sharded_optimizer_checkpoints(args.ckpt_prefix, args.ckpt_suffix, args.reshard_num, + args.save_path) + +if __name__ == "__main__": + main() diff --git a/torchacc/dist/state_dict_utils.py b/torchacc/dist/state_dict_utils.py new file mode 100644 index 0000000..43b0f43 --- /dev/null +++ b/torchacc/dist/state_dict_utils.py @@ -0,0 +1,302 @@ +from collections import OrderedDict +from glob import glob +from typing import Dict +import pickle + +import torch + +def _numel(shape): + numel = 1 + for d in shape: + numel *= d + return numel + +def unpad(tensor_buffer, layer_numels, world_size): + if tensor_buffer.dim() == 0: + return tensor_buffer + numel = 0 + for layer_numel in layer_numels: + numel += layer_numel + if numel % world_size != 0: + pad_size = world_size - numel % world_size + tensor_buffer = tensor_buffer[:-pad_size] + return tensor_buffer + + +def unflatten_params(params, param_names, param_shapes, param_numels): + if params.dim() == 0: + full_params = [params for _ in range(len(param_names))] + else: + full_params = [ + t.view(s) + for (t, s) in zip(params.split(param_numels), param_shapes) + ] + + return full_params + + +def get_layer_full_info(shard_metadata, model_state_dict): + """ + Get full name, shape and numel info of unflatten and unshard state_dict according + to shard_metadata and model's state_dict; + Args: + shard_metadata (dict): + 
``model.get_shard_metadata()`` from an FSDP model of any rank + model_state_dict(dict): + The state_dict from an FSDP model. + + Returns: + For all ranks, we get the same shard_metadata and model_state_dict, and the return value is + same: + layer_name_list(list): 2-dimension list([[layer_name_group1], [layer_name_group2], ...]), contains the full name information. + if parameters if flatten, each layer may have mutiple orig name and parameter. + layer_size_list(list): 2-dimension list([[layer_size_group1], [layer_size_group2], ...]), contains the unflatten and unshard shape information of + each layer. + layer_numel_list(list): 2-dimension list([[layer_numel_group1], [layer_numel_group2], ...]), contains the unflatten and unshard numel information of + each layer. + sharded_list(list): 1-dimension list, contains whether the layer params is sharded. + """ + layer_name_list = [] + layer_size_list = [] + layer_numel_list = [] + sharded_list = [] + buffer_info = shard_metadata.get("buffer_info", {}) + + # consolidate the sharded parameters + for name, param in model_state_dict.items(): + if name in buffer_info: # cast buffer back to its original dtype + p = p.to(buffer_info[name]["_orig_dtype"]) + + is_sharded = False + name_splits = name.split(".") + model_num = 0 + # if start with 'model', we just skip the model + for name in name_splits: + if name != 'model': + break + else: + model_num = model_num + 1 + name_splits = name_splits[model_num:] + name = ".".join(name_splits) + + for idx, sep in enumerate(name_splits): + if sep.startswith("_fsdp_shard"): + is_sharded = True + prefix = ".".join(name_splits[:idx]) + suffix = ".".join(name_splits[idx:]) + break + + sharded_list.append(is_sharded) + if is_sharded: + p_info = shard_metadata["shard_info"][prefix][suffix] + orig_name = p_info["_orig_name"] + orig_size = p_info["_orig_size"] + full_name = orig_name + if prefix != "": + full_name = prefix + "." + orig_name + layer_name_list.append(full_name) + layer_size_list.append(orig_size) + layer_numel_list.append(_numel(orig_size)) + + else: + # unsharded buffers, we don't need the info in shard_metadata + layer_name_list.append(name) + layer_size_list.append(param.shape) + layer_numel_list.append(_numel(param.shape)) + + # flatten_parameters = True + flatten_info = shard_metadata["flatten_info"] + if flatten_info != {}: + layer_name_list_ = [] + layer_size_list_ = [] + layer_numel_list_ = [] + for name in layer_name_list: + if "_fsdp_wrapped_module.flat_param_" in name: + metadata = flatten_info[name] + prefix = ".".join(name.split(".")[:-1]) + param_names, param_shapes, param_numel = metadata + full_names = param_names + + if prefix != "": + full_names = [prefix + "." 
+ n for n in full_names] + + full_names = [ + fn.replace("_fsdp_wrapped_module.", + "").replace("_fpw_module.", "") + for fn in full_names + ] + + layer_name_list_.append(full_names) + layer_size_list_.append(param_shapes) + layer_numel_list_.append(param_numel) + + return (layer_name_list_, layer_size_list_, layer_numel_list_, sharded_list) + + # return with lists + layer_name_list = [[ + fn.replace("_fsdp_wrapped_module.", "").replace("_fpw_module.", "") + ] for fn in layer_name_list] + layer_size_list = [[s] for s in layer_size_list] + layer_numel_list = [[n] for n in layer_numel_list] + + return (layer_name_list, layer_size_list, layer_numel_list, sharded_list) + +def load_checkpoints(ckpt_prefix, ckpt_suffix="*.pth"): + ckpt_path_pattern = ckpt_prefix + ckpt_suffix + ckpt_paths = glob(ckpt_path_pattern) + + checkpoints_and_paths = [] + for path in ckpt_paths: + ckpt = torch.load(path, map_location="cpu") + checkpoints_and_paths.append((ckpt, path)) + + checkpoints_and_paths.sort(key=lambda c: c[0]["shard_metadata"]["rank"]) + checkpoints = [c[0] for c in checkpoints_and_paths] + for rank, (ckpt, path) in enumerate(checkpoints_and_paths): + assert ckpt["shard_metadata"]["world_size"] == len(checkpoints), ( + f'Expecting {ckpt["shard_metadata"]["world_size"]} files ' + f"(based on metadata in {path}) but got {len(checkpoints)} files. " + f"Please check if you have missing or unexpected files in {ckpt_path_pattern}." + ) + assert ckpt["shard_metadata"]["rank"] == rank, ( + f'Expecting rank {ckpt["shard_metadata"]["rank"]} for {path} but it is ' + f"ranked {rank} (out of {len(checkpoints)} files). " + f"Please check if you have missing or unexpected files in {ckpt_path_pattern}." + ) + + return checkpoints + +def consolidate_sharded_model_checkpoints(ckpt_prefix, + ckpt_suffix="*.pth", + save_path="", + save_model=True): + """ + Consolidate the sharded FSDP checkpoints into a single model checkpoint. + + Args: + ckpt_prefix (str): + prefix to FSDP checkpoint files from all ranks + ckpt_suffix (str, Optional): + suffix to FSDP checkpoint files from all ranks. Files matching the + pattern ``ckpt_prefix + ckpt_suffix`` will be loaded. The each + checkpoint file is assumed to be a dict with a "model" key + containing the FSDP model's ``model.state_dict()`` and a + "shard_metadata" key containing the FSDP model's + ``model.get_shard_metadata()``. + save_path (str, Optional): + the save path to the consolidated model checkpoint file (if + ``save_model`` is ``True``). The checkpoint file is a dict with a + "model" key containing the consolidated model state dict. + save_model (str, Optional): + if ``True``, the consolidated model checkpoint will be saved to + ``save_path`` (or ``ckpt_prefix + "_consolidated.pth"`` if + ``save_path`` is empty). 
+ + Returns: + full_state_dict: the consolidated model state dict + actual_save_path: the path to the consolidated model checkpoint file + (``None`` if ``save_model`` is ``False``) + """ + checkpoints = load_checkpoints(ckpt_prefix, ckpt_suffix) + state_dict_list = [ckpt["model"] for ckpt in checkpoints] + shard_metadata = checkpoints[0]["shard_metadata"] + layer_name_list, layer_size_list, layer_numel_list, sharded_list = get_layer_full_info(shard_metadata, state_dict_list[0]) + file_path = ckpt_prefix + "layer_info.pickle" + + with open(file_path, 'wb') as f: + pickle.dump([layer_name_list, layer_size_list, layer_numel_list, sharded_list], f) + + full_state_dict = OrderedDict() + + # consolidate and unflatten + for idx, (state_name, state_params) in enumerate(state_dict_list[0].items()): + layer_name = layer_name_list[idx] + layer_size = layer_size_list[idx] + layer_numel = layer_numel_list[idx] + is_sharded = sharded_list[idx] + + consolidate_params = state_params + if is_sharded: + p_shard_list = [] + for state_dict in state_dict_list: + p_shard_list.append(state_dict[state_name]) + consolidate_params = torch.cat(p_shard_list, dim=0) + orig_params = unflatten_params(consolidate_params, layer_name, layer_size, layer_numel) + + for fn, fp in zip(layer_name, orig_params): + full_state_dict[fn] = fp + + actual_save_path = None + if save_model: + actual_save_path = save_path if save_path else ckpt_prefix + "_consolidated.pth" + torch.save({"model": full_state_dict}, actual_save_path) + print(f"saved consolidated model to {actual_save_path}") + + return full_state_dict, actual_save_path + + +def consolidate_sharded_optimizer_checkpoints(ckpt_prefix, + ckpt_suffix="*.pth", + save_path="", + save_model=True): + ''' + Consolidate the sharded FSDP checkpoints into a single optimizer checkpoint. 
+ we need first consolidate model checkpoint to reuse the layer_info + ''' + checkpoints = load_checkpoints(ckpt_prefix, ckpt_suffix) + optim_state_dict_list = [ckpt['optimizer'] for ckpt in checkpoints] + shard_metadata = checkpoints[0]["shard_metadata"] + file_path = ckpt_prefix + "layer_info.pickle" + layer_info = [] + try: + with open(file_path, 'rb') as f: + layer_info = pickle.load(f) + except FileNotFoundError: + print(f"please consolidate model first!") + + layer_name_list, layer_size_list, layer_numel_list, sharded_list = layer_info + flatten_name_list = [ + fn for layer_fn in layer_name_list for fn in layer_fn + ] + + full_optim_state_dict: Dict[str, Any] = { + 'state': {}, + 'param_groups': {} + } + + full_optim_state_dict['param_groups'] = optim_state_dict_list[0]['param_groups'] + + full_optim_state_dict['param_groups'][0]['params'].clear() + for fn in flatten_name_list: + full_optim_state_dict['param_groups'][0][ + 'params'].append(fn) + + unflat_state_dict = {fn: {} for fn in flatten_name_list} + + for idx, layer_state in enumerate(optim_state_dict_list[0]['state'].values()): + layer_name = layer_name_list[idx] + layer_size = layer_size_list[idx] + layer_numel = layer_numel_list[idx] + is_sharded = sharded_list[idx] + + for state_name, state_param in layer_state.items(): + consolidate_params = state_param + if is_sharded and state_param.dim() != 0: + p_shard_list = [] + for optim_state_dict in optim_state_dict_list: + p_shard_list.append(optim_state_dict['state'][idx][state_name]) + + consolidate_params = torch.cat(p_shard_list, dim=0) + orig_params = unflatten_params(consolidate_params, layer_name, layer_size, layer_numel) + + for fn, fp in zip(layer_name, orig_params): + unflat_state_dict[fn][state_name] = fp + full_optim_state_dict['state'] = unflat_state_dict + + actual_save_path = None + if save_model: + actual_save_path = save_path if save_path else ckpt_prefix + "_consolidated.pth" + torch.save({"optimizer": full_state_dict}, actual_save_path) + print(f"saved consolidated optimizer to {actual_save_path}") + + return full_optim_state_dict, actual_save_path From 5a3bd2d15cf1e412f3201357abdd55e9a31d62a7 Mon Sep 17 00:00:00 2001 From: shw Date: Fri, 11 Oct 2024 19:46:46 +0800 Subject: [PATCH 02/24] offonline_consolidate_utils_done --- torchacc/dist/consolidate_sharded_ckpts.py | 62 ---- torchacc/dist/state_dict_utils.py | 380 +++++++++++++++----- torchacc/utils/consolidate_sharded_ckpts.py | 73 ++++ 3 files changed, 369 insertions(+), 146 deletions(-) delete mode 100644 torchacc/dist/consolidate_sharded_ckpts.py create mode 100644 torchacc/utils/consolidate_sharded_ckpts.py diff --git a/torchacc/dist/consolidate_sharded_ckpts.py b/torchacc/dist/consolidate_sharded_ckpts.py deleted file mode 100644 index f02ae06..0000000 --- a/torchacc/dist/consolidate_sharded_ckpts.py +++ /dev/null @@ -1,62 +0,0 @@ -from argparse import ArgumentParser - -from .state_dict_utils import consolidate_sharded_model_checkpoints, consolidate_sharded_optimizers_checkpoints - - -def main(): - parser = ArgumentParser() - parser.add_argument( - "--ckpt_prefix", - type=str, - required=True, - help=( - "The path prefix of the XLA FSDP checkpoint files to be consolidated. " - "Files matching the pattern ``ckpt_prefix + ckpt_suffix`` will be loaded." - ), - ) - parser.add_argument( - "--ckpt_suffix", - type=str, - default="*.pth", - help=( - "The path suffix of the XLA FSDP checkpoint files to be consolidated. " - "Files matching the pattern ``ckpt_prefix + ckpt_suffix`` will be loaded." 
- ), - ) - parser.add_argument( - "--ckpt_type", - type=str, - default="model", - help=( - "Consolidate model or optimizer. Please consolidate model fisrt, and then consolidate optimizer." - ), - ) - parser.add_argument( - "--reshard_num", - type=int, - default=1, - help=( - "We now support the reshard of FSDP model." - ) - ) - parser.add_argument( - "--save_path", - type=str, - default="", - help=("The save path of the output consolidated model state dict " - "(default is ``ckpt_prefix + '_consolidated.pth'``)"), - ) - args = parser.parse_args() - assert args.ckpt_type in ['model', 'optimizer'], ( - 'the ckpt_type should be model or optimizer' - ) - - if args.ckpt_type == "model": - consolidate_sharded_model_checkpoints(args.ckpt_prefix, args.ckpt_suffix, args.reshard_num, - args.save_path) - else: - consolidate_sharded_optimizer_checkpoints(args.ckpt_prefix, args.ckpt_suffix, args.reshard_num, - args.save_path) - -if __name__ == "__main__": - main() diff --git a/torchacc/dist/state_dict_utils.py b/torchacc/dist/state_dict_utils.py index 43b0f43..e9216fd 100644 --- a/torchacc/dist/state_dict_utils.py +++ b/torchacc/dist/state_dict_utils.py @@ -1,9 +1,12 @@ from collections import OrderedDict +import copy from glob import glob -from typing import Dict import pickle +import threading +from typing import Dict import torch +import torch.nn.functional as F def _numel(shape): numel = 1 @@ -11,18 +14,6 @@ def _numel(shape): numel *= d return numel -def unpad(tensor_buffer, layer_numels, world_size): - if tensor_buffer.dim() == 0: - return tensor_buffer - numel = 0 - for layer_numel in layer_numels: - numel += layer_numel - if numel % world_size != 0: - pad_size = world_size - numel % world_size - tensor_buffer = tensor_buffer[:-pad_size] - return tensor_buffer - - def unflatten_params(params, param_names, param_shapes, param_numels): if params.dim() == 0: full_params = [params for _ in range(len(param_names))] @@ -37,17 +28,15 @@ def unflatten_params(params, param_names, param_shapes, param_numels): def get_layer_full_info(shard_metadata, model_state_dict): """ - Get full name, shape and numel info of unflatten and unshard state_dict according - to shard_metadata and model's state_dict; + Get full name, shape and numel info of unflatten and unshard model's state_dict according + to shard_metadata and shard model's state_dict; Args: shard_metadata (dict): ``model.get_shard_metadata()`` from an FSDP model of any rank model_state_dict(dict): - The state_dict from an FSDP model. + The FSDP model's state_dict. Returns: - For all ranks, we get the same shard_metadata and model_state_dict, and the return value is - same: layer_name_list(list): 2-dimension list([[layer_name_group1], [layer_name_group2], ...]), contains the full name information. if parameters if flatten, each layer may have mutiple orig name and parameter. 
layer_size_list(list): 2-dimension list([[layer_size_group1], [layer_size_group2], ...]), contains the unflatten and unshard shape information of @@ -141,14 +130,25 @@ def get_layer_full_info(shard_metadata, model_state_dict): return (layer_name_list, layer_size_list, layer_numel_list, sharded_list) -def load_checkpoints(ckpt_prefix, ckpt_suffix="*.pth"): - ckpt_path_pattern = ckpt_prefix + ckpt_suffix +def load_checkpoints(ckpt_dir, ckpt_name="*.pth"): + ckpt_path_pattern = ckpt_dir + ckpt_name ckpt_paths = glob(ckpt_path_pattern) checkpoints_and_paths = [] + + def load_ckpt(path): + ckpt = torch.load(path, map_location="cpu") + checkpoints_and_paths.append((ckpt, path)) + + threads = [] + for path in ckpt_paths: - ckpt = torch.load(path, map_location="cpu") - checkpoints_and_paths.append((ckpt, path)) + thread = threading.Thread(target=load_ckpt, args=(path,)) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() checkpoints_and_paths.sort(key=lambda c: c[0]["shard_metadata"]["rank"]) checkpoints = [c[0] for c in checkpoints_and_paths] @@ -166,49 +166,47 @@ def load_checkpoints(ckpt_prefix, ckpt_suffix="*.pth"): return checkpoints -def consolidate_sharded_model_checkpoints(ckpt_prefix, - ckpt_suffix="*.pth", - save_path="", - save_model=True): +def save_checkpoints(state_dict_list, shard_metadata_list, save_paths, save_type): + if not isinstance(state_dict_list, list): + torch.save(state_dict_list, save_paths) + return + + def save_checkpoint(state_dict,shard_metadata, save_path, save_type): + model = { + f"{save_type}": state_dict, + "shard_metadata": shard_metadata, + } + torch.save(model, save_path) + + threads = [] + #import pdb + #pdb.set_trace() + for state_dict, shard_metadata, save_path in zip(state_dict_list, shard_metadata_list, save_paths): + thread = threading.Thread(target=save_checkpoint, args=(state_dict, shard_metadata, save_path, save_type)) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() + + +def consolidate_sharded_model_checkpoints(ckpt_dir, + checkpoints): """ - Consolidate the sharded FSDP checkpoints into a single model checkpoint. - - Args: - ckpt_prefix (str): - prefix to FSDP checkpoint files from all ranks - ckpt_suffix (str, Optional): - suffix to FSDP checkpoint files from all ranks. Files matching the - pattern ``ckpt_prefix + ckpt_suffix`` will be loaded. The each - checkpoint file is assumed to be a dict with a "model" key - containing the FSDP model's ``model.state_dict()`` and a - "shard_metadata" key containing the FSDP model's - ``model.get_shard_metadata()``. - save_path (str, Optional): - the save path to the consolidated model checkpoint file (if - ``save_model`` is ``True``). The checkpoint file is a dict with a - "model" key containing the consolidated model state dict. - save_model (str, Optional): - if ``True``, the consolidated model checkpoint will be saved to - ``save_path`` (or ``ckpt_prefix + "_consolidated.pth"`` if - ``save_path`` is empty). - - Returns: - full_state_dict: the consolidated model state dict - actual_save_path: the path to the consolidated model checkpoint file - (``None`` if ``save_model`` is ``False``) + Consolidate the sharded FSDP checkpoints into a single model checkpoint. 
""" - checkpoints = load_checkpoints(ckpt_prefix, ckpt_suffix) + state_dict_list = [ckpt["model"] for ckpt in checkpoints] shard_metadata = checkpoints[0]["shard_metadata"] layer_name_list, layer_size_list, layer_numel_list, sharded_list = get_layer_full_info(shard_metadata, state_dict_list[0]) - file_path = ckpt_prefix + "layer_info.pickle" + file_path = ckpt_dir + "layer_info.pickle" with open(file_path, 'wb') as f: pickle.dump([layer_name_list, layer_size_list, layer_numel_list, sharded_list], f) full_state_dict = OrderedDict() - # consolidate and unflatten + # consolidate and unflatten per layer for idx, (state_name, state_params) in enumerate(state_dict_list[0].items()): layer_name = layer_name_list[idx] layer_size = layer_size_list[idx] @@ -225,34 +223,17 @@ def consolidate_sharded_model_checkpoints(ckpt_prefix, for fn, fp in zip(layer_name, orig_params): full_state_dict[fn] = fp - - actual_save_path = None - if save_model: - actual_save_path = save_path if save_path else ckpt_prefix + "_consolidated.pth" - torch.save({"model": full_state_dict}, actual_save_path) - print(f"saved consolidated model to {actual_save_path}") - return full_state_dict, actual_save_path + return full_state_dict -def consolidate_sharded_optimizer_checkpoints(ckpt_prefix, - ckpt_suffix="*.pth", - save_path="", - save_model=True): - ''' +def consolidate_sharded_optimizer_checkpoints(ckpt_dir, + checkpoints, layer_info): + """ Consolidate the sharded FSDP checkpoints into a single optimizer checkpoint. - we need first consolidate model checkpoint to reuse the layer_info - ''' - checkpoints = load_checkpoints(ckpt_prefix, ckpt_suffix) + """ optim_state_dict_list = [ckpt['optimizer'] for ckpt in checkpoints] shard_metadata = checkpoints[0]["shard_metadata"] - file_path = ckpt_prefix + "layer_info.pickle" - layer_info = [] - try: - with open(file_path, 'rb') as f: - layer_info = pickle.load(f) - except FileNotFoundError: - print(f"please consolidate model first!") layer_name_list, layer_size_list, layer_numel_list, sharded_list = layer_info flatten_name_list = [ @@ -264,15 +245,16 @@ def consolidate_sharded_optimizer_checkpoints(ckpt_prefix, 'param_groups': {} } - full_optim_state_dict['param_groups'] = optim_state_dict_list[0]['param_groups'] - + full_optim_state_dict['param_groups'] = copy.deepcopy(optim_state_dict_list[0]['param_groups']) full_optim_state_dict['param_groups'][0]['params'].clear() + for fn in flatten_name_list: full_optim_state_dict['param_groups'][0][ 'params'].append(fn) unflat_state_dict = {fn: {} for fn in flatten_name_list} - + + # consolidate and unflatten per layer per state for idx, layer_state in enumerate(optim_state_dict_list[0]['state'].values()): layer_name = layer_name_list[idx] layer_size = layer_size_list[idx] @@ -292,11 +274,241 @@ def consolidate_sharded_optimizer_checkpoints(ckpt_prefix, for fn, fp in zip(layer_name, orig_params): unflat_state_dict[fn][state_name] = fp full_optim_state_dict['state'] = unflat_state_dict + + return full_optim_state_dict + +def _get_shard(tensor, shard_num): + """ + Return the shard tensor list of a full flatten tensor. 
+ """ + if tensor.numel() % shard_num != 0: + pad_size = shard_num - tensor.numel() % shard_num + tensor = F.pad(tensor, [0, pad_size]) + + local_size = tensor.size(0) // shard_num + tensor_list = [] + for i in range(shard_num): + begin = i * local_size + end = (i + 1) * local_size + tensor_list.append(tensor[begin:end]) + + return tensor_list + +def flatten_tensor_list(param_list): + if len(param_list) == 0: + return param_list + + flat_tensors = [torch.flatten(param) for param in param_list] + + return torch.cat(flat_tensors, dim=0) + +def reshard_model_dict(consolidate_model_dict, shard_model, layer_name_lists, reshard_num): + """ + reshard the consolidate model into shard_model_state_dict_list according to the reshard_num. + Return the shard_model_state_dict_list and shard_metadata_list. + """ + shard_model_state_dict: Dict[str, Any] = {} + shard_model_state_dict_list = [copy.deepcopy(shard_model_state_dict) for _ in range(reshard_num)] + + # flatten and shard tensor per layer + for (shard_model_name, layer_names) in zip(shard_model['model'].keys(), layer_name_lists): + tensor_buffer_list = [] + for name in layer_names: + tensor_buffer = consolidate_model_dict[name] + tensor_buffer_list.append(tensor_buffer) + flat_tensor = flatten_tensor_list( + tensor_buffer_list) + shard_tensor_list = _get_shard(flat_tensor, reshard_num) + + for shard_tensor, shard_model_dict in zip(shard_tensor_list, shard_model_state_dict_list): + shard_model_dict[shard_model_name] = shard_tensor + + # get shardmeta_list + shard_metadata_list = [] + for idx in range(reshard_num): + shard_meta_data = copy.deepcopy(shard_model["shard_metadata"]) + shard_meta_data['world_size'] = reshard_num + shard_meta_data['rank'] = idx + shard_metadata_list.append(shard_meta_data) + + return shard_model_state_dict_list, shard_metadata_list + + +def reshard_optim_dict(consolidate_optim_dict, shard_optim, layer_name_lists, reshard_num): + """ + reshard the consolidate optim into shard_optim_state_dict_list according to the reshard_num. + Return the shard_optim_state_dict_list and shard_metadata_list. 
+ """ + consolidate_optim_state = consolidate_optim_dict['state'] + + shard_optim_state_dict: Dict[str, Any] = { + 'state': {}, + 'param_groups': {} + } + shard_optim_state_dict_list = [copy.deepcopy(shard_optim_state_dict) for _ in range(reshard_num)] + + # flatten and shard tensor per layer per state_name + for idx, layer_names in enumerate(layer_name_lists): + shard_value: Dict[str, Any] = {} + shard_value_list = [copy.deepcopy(shard_value) for _ in range(reshard_num)] + for state_name in consolidate_optim_state[layer_names[0]].keys(): + tensor_buffer_list = [] + # we need the params of a whole layer state to be flatten and shard + for name in layer_names: + state_params = consolidate_optim_state[name][state_name] + # state name 'step' + if isinstance(state_params, + torch.Tensor) and state_params.dim() == 0: + for shard_value in shard_value_list: + shard_value[state_name] = state_params + break + + tensor_buffer_list.append(state_params) + + flat_tensor = flatten_tensor_list( + tensor_buffer_list) + + if state_params.dim() != 0: + shard_tensor_list = _get_shard(flat_tensor, reshard_num) + for (shard_value, shard_tensor) in zip(shard_value_list, shard_tensor_list): + shard_value[state_name] = shard_tensor + + for (shard_value, shard_optim_state_dict) in zip(shard_value_list, shard_optim_state_dict_list): + shard_optim_state_dict['state'][idx] = shard_value + + shard_metadata_list = [] + + # get the param_group of optim_state_dict and shard_meta_lists + for (idx, shard_optim_state_dict) in enumerate(shard_optim_state_dict_list): + shard_optim_state_dict['param_groups'] = shard_optim['optimizer']['param_groups'] + + shard_meta_data = copy.deepcopy(shard_optim["shard_metadata"]) + shard_meta_data['world_size'] = reshard_num + shard_meta_data['rank'] = idx + shard_metadata_list.append(shard_meta_data) + + return shard_optim_state_dict_list, shard_metadata_list + + +def consolidate_and_reshard_model_dict(ckpt_dir, + ckpt_name="", + reshard_num=1, + save_path="", + save_model=True): + """ + Consolidate the sharded FSDP checkpoints into a single model checkpoint. Then + reshard the FSDP model according to the reshard_num. + + Args: + ckpt_dir (str): + The dir to FSDP checkpoint files from all ranks + ckpt_name (str, Optional): + The name_pattern to FSDP checkpoint files from all ranks. Files matching the + pattern ``ckpt_dir + ckpt_name`` will be loaded. The each + checkpoint file is assumed to be a dict with a "model" key + containing the FSDP model's ``model.state_dict()`` and a + "shard_metadata" key containing the FSDP model's + ``model.get_shard_metadata()``. + reshard_num (int, Optional): + Reshard the fsdp model with reshard_num. If set to 1, we don't need to do + resharding. + save_path (str, Optional): + the save path to the consolidated model checkpoint file (if + ``save_model`` is ``True``). The checkpoint file is a dict with a + "model" key containing the consolidated model state dict. + save_model (str, Optional): + if ``True``, the model checkpoint will be saved to + ``save_path`` (or ``ckpt_dir + "consolidated_model.pth"`` if + ``save_path`` is empty). + + Returns: + model_state_dict: the consolidated model state dict or reshard model state dict list. 
+ """ + + checkpoints = load_checkpoints(ckpt_dir, ckpt_name) + full_state_dict = consolidate_sharded_model_checkpoints(ckpt_dir, checkpoints) + + if reshard_num == 1: + if save_model: + actual_save_path = save_path if save_path else ckpt_dir + "model_consolidated.pth" + save_checkpoints(full_state_dict, checkpoints[0]['shard_metadata'], actual_save_path, 'model') + + return full_state_dict + # load layer_info + file_path = ckpt_dir + "layer_info.pickle" + layer_info = [] + try: + with open(file_path, 'rb') as f: + layer_info = pickle.load(f) + except FileNotFoundError: + print(f"please consolidate model first!") + + model_state_dict_list, shard_metadata_list = reshard_model_dict(full_state_dict, checkpoints[0], layer_info[0], reshard_num) + - actual_save_path = None if save_model: - actual_save_path = save_path if save_path else ckpt_prefix + "_consolidated.pth" - torch.save({"optimizer": full_state_dict}, actual_save_path) - print(f"saved consolidated optimizer to {actual_save_path}") + if save_path == "": + save_path = ckpt_dir + + actual_save_path = [ + f"rank-{rank}-of-{reshard_num}-model.pth" + for rank in range(reshard_num) + ] + + save_checkpoints(model_state_dict_list, shard_metadata_list, actual_save_path, 'model') + + + return model_state_dict_list + + +def consolidate_and_reshard_optim_dict(ckpt_dir, + ckpt_name="", + reshard_num=1, + save_path="", + save_optimizer=True): + """ + Consolidate the sharded FSDP checkpoints into a single optimizer checkpoint. Then + reshard the FSDP optimizer according to the reshard_num. - return full_optim_state_dict, actual_save_path + Returns: + optim_state_dict: the consolidated model state dict or reshard model state dict list + """ + # load checkpoints + checkpoints = load_checkpoints(ckpt_dir, ckpt_name) + + # load layer_info + file_path = ckpt_dir + "layer_info.pickle" + layer_info = [] + try: + with open(file_path, 'rb') as f: + layer_info = pickle.load(f) + except FileNotFoundError: + print(f"please consolidate model first!") + + full_optim_state_dict = consolidate_sharded_optimizer_checkpoints(ckpt_dir, checkpoints, layer_info) + + actual_save_path = None + + if reshard_num == 1: + if save_optimizer: + actual_save_path = save_path if save_path else ckpt_dir + "consolidated_optimizer.pth" + save_checkpoints(full_optim_state_dict, checkpoints[0]['shard_metadata'], actual_save_path, 'optimizer') + + return full_optim_state_dict + + optim_state_dict_list, shard_metadata_list = reshard_optim_dict(full_optim_state_dict, checkpoints[0], layer_info[0], reshard_num) + + + if save_optimizer: + if save_path == "": + save_path = ckpt_dir + + actual_save_path = [ + save_path + f"rank-{rank}-of-{reshard_num}-optimizer.pth" + for rank in range(reshard_num) + ] + import pdb + pdb.set_trace() + save_checkpoints(optim_state_dict_list, shard_metadata_list, actual_save_path, 'optimizer') + + return optim_state_dict_list \ No newline at end of file diff --git a/torchacc/utils/consolidate_sharded_ckpts.py b/torchacc/utils/consolidate_sharded_ckpts.py new file mode 100644 index 0000000..fcef1ed --- /dev/null +++ b/torchacc/utils/consolidate_sharded_ckpts.py @@ -0,0 +1,73 @@ +from argparse import ArgumentParser +from torchacc.dist.state_dict_utils import consolidate_and_reshard_model_dict, consolidate_and_reshard_optim_dict + +MODEL_NAME_PATTERN = "rank*-of-*-model.pth" +OPTIM_NAME_PATTERN = "rank*-of-*-optimizer.pth" + +def main(): + parser = ArgumentParser() + parser.add_argument( + "--ckpt_dir", + type=str, + required=True, + help=( + f"The name pattern of 
the XLA FSDP checkpoint files to be consolidated. " + f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be loaded." + f"For model, the default pattern is {MODEL_NAME_PATTERN}. For optimizer," + f"the default pattern is {OPTIM_NAME_PATTERN}" + ), + ) + parser.add_argument( + "--ckpt_name", + type=str, + default="", + help=( + f"The name pattern of the XLA FSDP checkpoint files to be consolidated. " + f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be loaded." + f"For model, the default pattern is {MODEL_NAME_PATTERN}. For optimizer," + f"the default pattern is {OPTIM_NAME_PATTERN}" + ), + ) + parser.add_argument( + "--ckpt_type", + type=str, + default="model", + help=( + "The type of checkpoint to consolidate, you can choose model or optimizer. Please consolidate model fisrt, and then consolidate optimizer." + ), + ) + parser.add_argument( + "--reshard_num", + type=int, + default=1, + help=( + "We now support the reshard of XLA FSDP checkpoint according to the reshard_num." + ) + ) + parser.add_argument( + "--save_path", + type=str, + default="", + help=(f"The save path of the output state dict " + f"(default consolidate_path is ``ckpt_dir + model/optimizer_consolidated.pth``)" + f"If you need to reshard the checkpoint, please only pass the save_dir(default is ckpt_dir)," + f"we will save the file in path ``save_path + {MODEL_NAME_PATTERN}/{OPTIM_NAME_PATTERN}``"), + ) + args = parser.parse_args() + assert args.ckpt_type in ['model', 'optimizer'], ( + 'the ckpt_type should be model or optimizer' + ) + + if args.ckpt_type == "model": + if args.ckpt_name == "": + args.ckpt_name = MODEL_NAME_PATTERN + consolidate_and_reshard_model_dict(args.ckpt_dir, args.ckpt_name, args.reshard_num, + args.save_path) + else: + if args.ckpt_name == "": + args.ckpt_name = OPTIM_NAME_PATTERN + consolidate_and_reshard_optim_dict(args.ckpt_dir, args.ckpt_name, args.reshard_num, + args.save_path) + +if __name__ == "__main__": + main() From b4473696f385dbfcf85434fec6ad38afefd139ce Mon Sep 17 00:00:00 2001 From: shw Date: Fri, 11 Oct 2024 19:52:05 +0800 Subject: [PATCH 03/24] format --- torchacc/dist/state_dict_utils.py | 364 +++++++++++--------- torchacc/utils/consolidate_sharded_ckpts.py | 124 +++---- 2 files changed, 259 insertions(+), 229 deletions(-) diff --git a/torchacc/dist/state_dict_utils.py b/torchacc/dist/state_dict_utils.py index e9216fd..f1d4275 100644 --- a/torchacc/dist/state_dict_utils.py +++ b/torchacc/dist/state_dict_utils.py @@ -8,12 +8,14 @@ import torch import torch.nn.functional as F + def _numel(shape): numel = 1 for d in shape: numel *= d return numel + def unflatten_params(params, param_names, param_shapes, param_numels): if params.dim() == 0: full_params = [params for _ in range(len(param_names))] @@ -74,7 +76,7 @@ def get_layer_full_info(shard_metadata, model_state_dict): prefix = ".".join(name_splits[:idx]) suffix = ".".join(name_splits[idx:]) break - + sharded_list.append(is_sharded) if is_sharded: p_info = shard_metadata["shard_info"][prefix][suffix] @@ -119,7 +121,8 @@ def get_layer_full_info(shard_metadata, model_state_dict): layer_size_list_.append(param_shapes) layer_numel_list_.append(param_numel) - return (layer_name_list_, layer_size_list_, layer_numel_list_, sharded_list) + return (layer_name_list_, layer_size_list_, layer_numel_list_, + sharded_list) # return with lists layer_name_list = [[ @@ -130,48 +133,51 @@ def get_layer_full_info(shard_metadata, model_state_dict): return (layer_name_list, layer_size_list, layer_numel_list, sharded_list) + def 
load_checkpoints(ckpt_dir, ckpt_name="*.pth"): - ckpt_path_pattern = ckpt_dir + ckpt_name - ckpt_paths = glob(ckpt_path_pattern) - - checkpoints_and_paths = [] - - def load_ckpt(path): - ckpt = torch.load(path, map_location="cpu") - checkpoints_and_paths.append((ckpt, path)) - - threads = [] - - for path in ckpt_paths: - thread = threading.Thread(target=load_ckpt, args=(path,)) - threads.append(thread) - thread.start() - - for thread in threads: - thread.join() - - checkpoints_and_paths.sort(key=lambda c: c[0]["shard_metadata"]["rank"]) - checkpoints = [c[0] for c in checkpoints_and_paths] - for rank, (ckpt, path) in enumerate(checkpoints_and_paths): - assert ckpt["shard_metadata"]["world_size"] == len(checkpoints), ( - f'Expecting {ckpt["shard_metadata"]["world_size"]} files ' - f"(based on metadata in {path}) but got {len(checkpoints)} files. " - f"Please check if you have missing or unexpected files in {ckpt_path_pattern}." - ) - assert ckpt["shard_metadata"]["rank"] == rank, ( - f'Expecting rank {ckpt["shard_metadata"]["rank"]} for {path} but it is ' - f"ranked {rank} (out of {len(checkpoints)} files). " - f"Please check if you have missing or unexpected files in {ckpt_path_pattern}." - ) - - return checkpoints - -def save_checkpoints(state_dict_list, shard_metadata_list, save_paths, save_type): + ckpt_path_pattern = ckpt_dir + ckpt_name + ckpt_paths = glob(ckpt_path_pattern) + + checkpoints_and_paths = [] + + def load_ckpt(path): + ckpt = torch.load(path, map_location="cpu") + checkpoints_and_paths.append((ckpt, path)) + + threads = [] + + for path in ckpt_paths: + thread = threading.Thread(target=load_ckpt, args=(path,)) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() + + checkpoints_and_paths.sort(key=lambda c: c[0]["shard_metadata"]["rank"]) + checkpoints = [c[0] for c in checkpoints_and_paths] + for rank, (ckpt, path) in enumerate(checkpoints_and_paths): + assert ckpt["shard_metadata"]["world_size"] == len(checkpoints), ( + f'Expecting {ckpt["shard_metadata"]["world_size"]} files ' + f"(based on metadata in {path}) but got {len(checkpoints)} files. " + f"Please check if you have missing or unexpected files in {ckpt_path_pattern}." + ) + assert ckpt["shard_metadata"]["rank"] == rank, ( + f'Expecting rank {ckpt["shard_metadata"]["rank"]} for {path} but it is ' + f"ranked {rank} (out of {len(checkpoints)} files). " + f"Please check if you have missing or unexpected files in {ckpt_path_pattern}." 
+ ) + + return checkpoints + + +def save_checkpoints(state_dict_list, shard_metadata_list, save_paths, + save_type): if not isinstance(state_dict_list, list): torch.save(state_dict_list, save_paths) return - def save_checkpoint(state_dict,shard_metadata, save_path, save_type): + def save_checkpoint(state_dict, shard_metadata, save_path, save_type): model = { f"{save_type}": state_dict, "shard_metadata": shard_metadata, @@ -181,54 +187,62 @@ def save_checkpoint(state_dict,shard_metadata, save_path, save_type): threads = [] #import pdb #pdb.set_trace() - for state_dict, shard_metadata, save_path in zip(state_dict_list, shard_metadata_list, save_paths): - thread = threading.Thread(target=save_checkpoint, args=(state_dict, shard_metadata, save_path, save_type)) + for state_dict, shard_metadata, save_path in zip(state_dict_list, + shard_metadata_list, + save_paths): + thread = threading.Thread( + target=save_checkpoint, + args=(state_dict, shard_metadata, save_path, save_type)) threads.append(thread) thread.start() for thread in threads: thread.join() - - -def consolidate_sharded_model_checkpoints(ckpt_dir, - checkpoints): - """ + + +def consolidate_sharded_model_checkpoints(ckpt_dir, checkpoints): + """ Consolidate the sharded FSDP checkpoints into a single model checkpoint. """ - state_dict_list = [ckpt["model"] for ckpt in checkpoints] - shard_metadata = checkpoints[0]["shard_metadata"] - layer_name_list, layer_size_list, layer_numel_list, sharded_list = get_layer_full_info(shard_metadata, state_dict_list[0]) - file_path = ckpt_dir + "layer_info.pickle" - - with open(file_path, 'wb') as f: - pickle.dump([layer_name_list, layer_size_list, layer_numel_list, sharded_list], f) - - full_state_dict = OrderedDict() - - # consolidate and unflatten per layer - for idx, (state_name, state_params) in enumerate(state_dict_list[0].items()): - layer_name = layer_name_list[idx] - layer_size = layer_size_list[idx] - layer_numel = layer_numel_list[idx] - is_sharded = sharded_list[idx] - - consolidate_params = state_params - if is_sharded: - p_shard_list = [] - for state_dict in state_dict_list: - p_shard_list.append(state_dict[state_name]) - consolidate_params = torch.cat(p_shard_list, dim=0) - orig_params = unflatten_params(consolidate_params, layer_name, layer_size, layer_numel) - - for fn, fp in zip(layer_name, orig_params): - full_state_dict[fn] = fp - - return full_state_dict - - -def consolidate_sharded_optimizer_checkpoints(ckpt_dir, - checkpoints, layer_info): + state_dict_list = [ckpt["model"] for ckpt in checkpoints] + shard_metadata = checkpoints[0]["shard_metadata"] + layer_name_list, layer_size_list, layer_numel_list, sharded_list = get_layer_full_info( + shard_metadata, state_dict_list[0]) + file_path = ckpt_dir + "layer_info.pickle" + + with open(file_path, 'wb') as f: + pickle.dump( + [layer_name_list, layer_size_list, layer_numel_list, sharded_list], + f) + + full_state_dict = OrderedDict() + + # consolidate and unflatten per layer + for idx, (state_name, + state_params) in enumerate(state_dict_list[0].items()): + layer_name = layer_name_list[idx] + layer_size = layer_size_list[idx] + layer_numel = layer_numel_list[idx] + is_sharded = sharded_list[idx] + + consolidate_params = state_params + if is_sharded: + p_shard_list = [] + for state_dict in state_dict_list: + p_shard_list.append(state_dict[state_name]) + consolidate_params = torch.cat(p_shard_list, dim=0) + orig_params = unflatten_params(consolidate_params, layer_name, + layer_size, layer_numel) + + for fn, fp in zip(layer_name, 
orig_params): + full_state_dict[fn] = fp + + return full_state_dict + + +def consolidate_sharded_optimizer_checkpoints(ckpt_dir, checkpoints, + layer_info): """ Consolidate the sharded FSDP checkpoints into a single optimizer checkpoint. """ @@ -236,64 +250,64 @@ def consolidate_sharded_optimizer_checkpoints(ckpt_dir, shard_metadata = checkpoints[0]["shard_metadata"] layer_name_list, layer_size_list, layer_numel_list, sharded_list = layer_info - flatten_name_list = [ - fn for layer_fn in layer_name_list for fn in layer_fn - ] + flatten_name_list = [fn for layer_fn in layer_name_list for fn in layer_fn] - full_optim_state_dict: Dict[str, Any] = { - 'state': {}, - 'param_groups': {} - } - - full_optim_state_dict['param_groups'] = copy.deepcopy(optim_state_dict_list[0]['param_groups']) + full_optim_state_dict: Dict[str, Any] = {'state': {}, 'param_groups': {}} + + full_optim_state_dict['param_groups'] = copy.deepcopy( + optim_state_dict_list[0]['param_groups']) full_optim_state_dict['param_groups'][0]['params'].clear() - + for fn in flatten_name_list: - full_optim_state_dict['param_groups'][0][ - 'params'].append(fn) - + full_optim_state_dict['param_groups'][0]['params'].append(fn) + unflat_state_dict = {fn: {} for fn in flatten_name_list} - + # consolidate and unflatten per layer per state - for idx, layer_state in enumerate(optim_state_dict_list[0]['state'].values()): + for idx, layer_state in enumerate( + optim_state_dict_list[0]['state'].values()): layer_name = layer_name_list[idx] layer_size = layer_size_list[idx] layer_numel = layer_numel_list[idx] is_sharded = sharded_list[idx] - + for state_name, state_param in layer_state.items(): consolidate_params = state_param if is_sharded and state_param.dim() != 0: p_shard_list = [] for optim_state_dict in optim_state_dict_list: - p_shard_list.append(optim_state_dict['state'][idx][state_name]) - + p_shard_list.append( + optim_state_dict['state'][idx][state_name]) + consolidate_params = torch.cat(p_shard_list, dim=0) - orig_params = unflatten_params(consolidate_params, layer_name, layer_size, layer_numel) - + orig_params = unflatten_params(consolidate_params, layer_name, + layer_size, layer_numel) + for fn, fp in zip(layer_name, orig_params): - unflat_state_dict[fn][state_name] = fp + unflat_state_dict[fn][state_name] = fp full_optim_state_dict['state'] = unflat_state_dict return full_optim_state_dict + def _get_shard(tensor, shard_num): """ Return the shard tensor list of a full flatten tensor. """ if tensor.numel() % shard_num != 0: - pad_size = shard_num - tensor.numel() % shard_num - tensor = F.pad(tensor, [0, pad_size]) - + pad_size = shard_num - tensor.numel() % shard_num + tensor = F.pad(tensor, [0, pad_size]) + local_size = tensor.size(0) // shard_num tensor_list = [] for i in range(shard_num): begin = i * local_size - end = (i + 1) * local_size + end = (i + 1) * local_size tensor_list.append(tensor[begin:end]) - + return tensor_list + def flatten_tensor_list(param_list): if len(param_list) == 0: return param_list @@ -302,27 +316,32 @@ def flatten_tensor_list(param_list): return torch.cat(flat_tensors, dim=0) -def reshard_model_dict(consolidate_model_dict, shard_model, layer_name_lists, reshard_num): + +def reshard_model_dict(consolidate_model_dict, shard_model, layer_name_lists, + reshard_num): """ reshard the consolidate model into shard_model_state_dict_list according to the reshard_num. Return the shard_model_state_dict_list and shard_metadata_list. 
""" shard_model_state_dict: Dict[str, Any] = {} - shard_model_state_dict_list = [copy.deepcopy(shard_model_state_dict) for _ in range(reshard_num)] - + shard_model_state_dict_list = [ + copy.deepcopy(shard_model_state_dict) for _ in range(reshard_num) + ] + # flatten and shard tensor per layer - for (shard_model_name, layer_names) in zip(shard_model['model'].keys(), layer_name_lists): + for (shard_model_name, layer_names) in zip(shard_model['model'].keys(), + layer_name_lists): tensor_buffer_list = [] for name in layer_names: - tensor_buffer = consolidate_model_dict[name] - tensor_buffer_list.append(tensor_buffer) - flat_tensor = flatten_tensor_list( - tensor_buffer_list) + tensor_buffer = consolidate_model_dict[name] + tensor_buffer_list.append(tensor_buffer) + flat_tensor = flatten_tensor_list(tensor_buffer_list) shard_tensor_list = _get_shard(flat_tensor, reshard_num) - - for shard_tensor, shard_model_dict in zip(shard_tensor_list, shard_model_state_dict_list): + + for shard_tensor, shard_model_dict in zip(shard_tensor_list, + shard_model_state_dict_list): shard_model_dict[shard_model_name] = shard_tensor - + # get shardmeta_list shard_metadata_list = [] for idx in range(reshard_num): @@ -330,27 +349,29 @@ def reshard_model_dict(consolidate_model_dict, shard_model, layer_name_lists, re shard_meta_data['world_size'] = reshard_num shard_meta_data['rank'] = idx shard_metadata_list.append(shard_meta_data) - + return shard_model_state_dict_list, shard_metadata_list -def reshard_optim_dict(consolidate_optim_dict, shard_optim, layer_name_lists, reshard_num): +def reshard_optim_dict(consolidate_optim_dict, shard_optim, layer_name_lists, + reshard_num): """ reshard the consolidate optim into shard_optim_state_dict_list according to the reshard_num. Return the shard_optim_state_dict_list and shard_metadata_list. 
""" consolidate_optim_state = consolidate_optim_dict['state'] - - shard_optim_state_dict: Dict[str, Any] = { - 'state': {}, - 'param_groups': {} - } - shard_optim_state_dict_list = [copy.deepcopy(shard_optim_state_dict) for _ in range(reshard_num)] - + + shard_optim_state_dict: Dict[str, Any] = {'state': {}, 'param_groups': {}} + shard_optim_state_dict_list = [ + copy.deepcopy(shard_optim_state_dict) for _ in range(reshard_num) + ] + # flatten and shard tensor per layer per state_name for idx, layer_names in enumerate(layer_name_lists): shard_value: Dict[str, Any] = {} - shard_value_list = [copy.deepcopy(shard_value) for _ in range(reshard_num)] + shard_value_list = [ + copy.deepcopy(shard_value) for _ in range(reshard_num) + ] for state_name in consolidate_optim_state[layer_names[0]].keys(): tensor_buffer_list = [] # we need the params of a whole layer state to be flatten and shard @@ -358,37 +379,40 @@ def reshard_optim_dict(consolidate_optim_dict, shard_optim, layer_name_lists, re state_params = consolidate_optim_state[name][state_name] # state name 'step' if isinstance(state_params, - torch.Tensor) and state_params.dim() == 0: + torch.Tensor) and state_params.dim() == 0: for shard_value in shard_value_list: shard_value[state_name] = state_params break - + tensor_buffer_list.append(state_params) - flat_tensor = flatten_tensor_list( - tensor_buffer_list) + flat_tensor = flatten_tensor_list(tensor_buffer_list) if state_params.dim() != 0: shard_tensor_list = _get_shard(flat_tensor, reshard_num) - for (shard_value, shard_tensor) in zip(shard_value_list, shard_tensor_list): - shard_value[state_name] = shard_tensor + for (shard_value, shard_tensor) in zip(shard_value_list, + shard_tensor_list): + shard_value[state_name] = shard_tensor - for (shard_value, shard_optim_state_dict) in zip(shard_value_list, shard_optim_state_dict_list): + for (shard_value, + shard_optim_state_dict) in zip(shard_value_list, + shard_optim_state_dict_list): shard_optim_state_dict['state'][idx] = shard_value - + shard_metadata_list = [] # get the param_group of optim_state_dict and shard_meta_lists for (idx, shard_optim_state_dict) in enumerate(shard_optim_state_dict_list): - shard_optim_state_dict['param_groups'] = shard_optim['optimizer']['param_groups'] + shard_optim_state_dict['param_groups'] = shard_optim['optimizer'][ + 'param_groups'] shard_meta_data = copy.deepcopy(shard_optim["shard_metadata"]) shard_meta_data['world_size'] = reshard_num shard_meta_data['rank'] = idx shard_metadata_list.append(shard_meta_data) - + return shard_optim_state_dict_list, shard_metadata_list - + def consolidate_and_reshard_model_dict(ckpt_dir, ckpt_name="", @@ -424,14 +448,16 @@ def consolidate_and_reshard_model_dict(ckpt_dir, Returns: model_state_dict: the consolidated model state dict or reshard model state dict list. 
""" - + checkpoints = load_checkpoints(ckpt_dir, ckpt_name) - full_state_dict = consolidate_sharded_model_checkpoints(ckpt_dir, checkpoints) - + full_state_dict = consolidate_sharded_model_checkpoints( + ckpt_dir, checkpoints) + if reshard_num == 1: if save_model: actual_save_path = save_path if save_path else ckpt_dir + "model_consolidated.pth" - save_checkpoints(full_state_dict, checkpoints[0]['shard_metadata'], actual_save_path, 'model') + save_checkpoints(full_state_dict, checkpoints[0]['shard_metadata'], + actual_save_path, 'model') return full_state_dict # load layer_info @@ -442,25 +468,25 @@ def consolidate_and_reshard_model_dict(ckpt_dir, layer_info = pickle.load(f) except FileNotFoundError: print(f"please consolidate model first!") - - model_state_dict_list, shard_metadata_list = reshard_model_dict(full_state_dict, checkpoints[0], layer_info[0], reshard_num) - - + + model_state_dict_list, shard_metadata_list = reshard_model_dict( + full_state_dict, checkpoints[0], layer_info[0], reshard_num) + if save_model: - if save_path == "": + if save_path == "": save_path = ckpt_dir - - actual_save_path = [ + + actual_save_path = [ f"rank-{rank}-of-{reshard_num}-model.pth" for rank in range(reshard_num) ] - - save_checkpoints(model_state_dict_list, shard_metadata_list, actual_save_path, 'model') - - + + save_checkpoints(model_state_dict_list, shard_metadata_list, + actual_save_path, 'model') + return model_state_dict_list - - + + def consolidate_and_reshard_optim_dict(ckpt_dir, ckpt_name="", reshard_num=1, @@ -475,7 +501,7 @@ def consolidate_and_reshard_optim_dict(ckpt_dir, """ # load checkpoints checkpoints = load_checkpoints(ckpt_dir, ckpt_name) - + # load layer_info file_path = ckpt_dir + "layer_info.pickle" layer_info = [] @@ -484,31 +510,35 @@ def consolidate_and_reshard_optim_dict(ckpt_dir, layer_info = pickle.load(f) except FileNotFoundError: print(f"please consolidate model first!") - - full_optim_state_dict = consolidate_sharded_optimizer_checkpoints(ckpt_dir, checkpoints, layer_info) - + + full_optim_state_dict = consolidate_sharded_optimizer_checkpoints( + ckpt_dir, checkpoints, layer_info) + actual_save_path = None - + if reshard_num == 1: if save_optimizer: actual_save_path = save_path if save_path else ckpt_dir + "consolidated_optimizer.pth" - save_checkpoints(full_optim_state_dict, checkpoints[0]['shard_metadata'], actual_save_path, 'optimizer') + save_checkpoints(full_optim_state_dict, + checkpoints[0]['shard_metadata'], actual_save_path, + 'optimizer') return full_optim_state_dict - optim_state_dict_list, shard_metadata_list = reshard_optim_dict(full_optim_state_dict, checkpoints[0], layer_info[0], reshard_num) - - + optim_state_dict_list, shard_metadata_list = reshard_optim_dict( + full_optim_state_dict, checkpoints[0], layer_info[0], reshard_num) + if save_optimizer: if save_path == "": save_path = ckpt_dir - + actual_save_path = [ save_path + f"rank-{rank}-of-{reshard_num}-optimizer.pth" for rank in range(reshard_num) ] import pdb pdb.set_trace() - save_checkpoints(optim_state_dict_list, shard_metadata_list, actual_save_path, 'optimizer') - - return optim_state_dict_list \ No newline at end of file + save_checkpoints(optim_state_dict_list, shard_metadata_list, + actual_save_path, 'optimizer') + + return optim_state_dict_list diff --git a/torchacc/utils/consolidate_sharded_ckpts.py b/torchacc/utils/consolidate_sharded_ckpts.py index fcef1ed..7169d15 100644 --- a/torchacc/utils/consolidate_sharded_ckpts.py +++ b/torchacc/utils/consolidate_sharded_ckpts.py @@ -4,70 +4,70 @@ 
 MODEL_NAME_PATTERN = "rank*-of-*-model.pth"
 OPTIM_NAME_PATTERN = "rank*-of-*-optimizer.pth"
 
+
 def main():
-    parser = ArgumentParser()
-    parser.add_argument(
-        "--ckpt_dir",
-        type=str,
-        required=True,
-        help=(
-            f"The name pattern of the XLA FSDP checkpoint files to be consolidated. "
-            f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be loaded."
-            f"For model, the default pattern is {MODEL_NAME_PATTERN}. For optimizer,"
-            f"the default pattern is {OPTIM_NAME_PATTERN}"
-        ),
-    )
-    parser.add_argument(
-        "--ckpt_name",
-        type=str,
-        default="",
-        help=(
-            f"The name pattern of the XLA FSDP checkpoint files to be consolidated. "
-            f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be loaded."
-            f"For model, the default pattern is {MODEL_NAME_PATTERN}. For optimizer,"
-            f"the default pattern is {OPTIM_NAME_PATTERN}"
-        ),
-    )
-    parser.add_argument(
-        "--ckpt_type",
-        type=str,
-        default="model",
-        help=(
-            "The type of checkpoint to consolidate, you can choose model or optimizer. Please consolidate model fisrt, and then consolidate optimizer."
-        ),
-    )
-    parser.add_argument(
-        "--reshard_num",
-        type=int,
-        default=1,
-        help=(
-            "We now support the reshard of XLA FSDP checkpoint according to the reshard_num."
-        )
-    )
-    parser.add_argument(
-        "--save_path",
-        type=str,
-        default="",
-        help=(f"The save path of the output state dict "
+    parser = ArgumentParser()
+    parser.add_argument(
+        "--ckpt_dir",
+        type=str,
+        required=True,
+        help=(
+            f"The directory containing the XLA FSDP checkpoint files to be consolidated. "
+            f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be loaded."
+            f"For model, the default pattern is {MODEL_NAME_PATTERN}. For optimizer,"
+            f"the default pattern is {OPTIM_NAME_PATTERN}"),
+    )
+    parser.add_argument(
+        "--ckpt_name",
+        type=str,
+        default="",
+        help=(
+            f"The name pattern of the XLA FSDP checkpoint files to be consolidated. "
+            f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be loaded."
+            f"For model, the default pattern is {MODEL_NAME_PATTERN}. For optimizer,"
+            f"the default pattern is {OPTIM_NAME_PATTERN}"),
+    )
+    parser.add_argument(
+        "--ckpt_type",
+        type=str,
+        default="model",
+        help=(
+            "The type of checkpoint to consolidate, you can choose model or optimizer. Please consolidate the model first, and then consolidate the optimizer."
+        ),
+    )
+    parser.add_argument(
+        "--reshard_num",
+        type=int,
+        default=1,
+        help=(
+            "Reshard the XLA FSDP checkpoint according to the reshard_num. If set to 1, the checkpoint is only consolidated and not resharded."
+ )) + parser.add_argument( + "--save_path", + type=str, + default="", + help=( + f"The save path of the output state dict " f"(default consolidate_path is ``ckpt_dir + model/optimizer_consolidated.pth``)" f"If you need to reshard the checkpoint, please only pass the save_dir(default is ckpt_dir)," - f"we will save the file in path ``save_path + {MODEL_NAME_PATTERN}/{OPTIM_NAME_PATTERN}``"), - ) - args = parser.parse_args() - assert args.ckpt_type in ['model', 'optimizer'], ( - 'the ckpt_type should be model or optimizer' - ) - - if args.ckpt_type == "model": - if args.ckpt_name == "": - args.ckpt_name = MODEL_NAME_PATTERN - consolidate_and_reshard_model_dict(args.ckpt_dir, args.ckpt_name, args.reshard_num, - args.save_path) - else: - if args.ckpt_name == "": - args.ckpt_name = OPTIM_NAME_PATTERN - consolidate_and_reshard_optim_dict(args.ckpt_dir, args.ckpt_name, args.reshard_num, - args.save_path) + f"we will save the file in path ``save_path + {MODEL_NAME_PATTERN}/{OPTIM_NAME_PATTERN}``" + ), + ) + args = parser.parse_args() + assert args.ckpt_type in ['model', 'optimizer' + ], ('the ckpt_type should be model or optimizer') + + if args.ckpt_type == "model": + if args.ckpt_name == "": + args.ckpt_name = MODEL_NAME_PATTERN + consolidate_and_reshard_model_dict(args.ckpt_dir, args.ckpt_name, + args.reshard_num, args.save_path) + else: + if args.ckpt_name == "": + args.ckpt_name = OPTIM_NAME_PATTERN + consolidate_and_reshard_optim_dict(args.ckpt_dir, args.ckpt_name, + args.reshard_num, args.save_path) + if __name__ == "__main__": - main() + main() From 2dcee05f9d110d6e96c02fc27d93b22750e05410 Mon Sep 17 00:00:00 2001 From: shw Date: Sat, 12 Oct 2024 11:40:15 +0800 Subject: [PATCH 04/24] fix model save path --- torchacc/dist/state_dict_utils.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/torchacc/dist/state_dict_utils.py b/torchacc/dist/state_dict_utils.py index f1d4275..0c330d1 100644 --- a/torchacc/dist/state_dict_utils.py +++ b/torchacc/dist/state_dict_utils.py @@ -182,11 +182,11 @@ def save_checkpoint(state_dict, shard_metadata, save_path, save_type): f"{save_type}": state_dict, "shard_metadata": shard_metadata, } + torch.save(model, save_path) threads = [] - #import pdb - #pdb.set_trace() + for state_dict, shard_metadata, save_path in zip(state_dict_list, shard_metadata_list, save_paths): @@ -477,7 +477,7 @@ def consolidate_and_reshard_model_dict(ckpt_dir, save_path = ckpt_dir actual_save_path = [ - f"rank-{rank}-of-{reshard_num}-model.pth" + save_path + f"rank-{rank}-of-{reshard_num}-model.pth" for rank in range(reshard_num) ] @@ -536,8 +536,7 @@ def consolidate_and_reshard_optim_dict(ckpt_dir, save_path + f"rank-{rank}-of-{reshard_num}-optimizer.pth" for rank in range(reshard_num) ] - import pdb - pdb.set_trace() + save_checkpoints(optim_state_dict_list, shard_metadata_list, actual_save_path, 'optimizer') From 622c40ddc476d475ff9c0590a271b58930006dbe Mon Sep 17 00:00:00 2001 From: shw Date: Sat, 12 Oct 2024 13:50:59 +0800 Subject: [PATCH 05/24] add clone for get shard --- torchacc/dist/state_dict_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchacc/dist/state_dict_utils.py b/torchacc/dist/state_dict_utils.py index 0c330d1..bda1dbb 100644 --- a/torchacc/dist/state_dict_utils.py +++ b/torchacc/dist/state_dict_utils.py @@ -303,7 +303,7 @@ def _get_shard(tensor, shard_num): for i in range(shard_num): begin = i * local_size end = (i + 1) * local_size - tensor_list.append(tensor[begin:end]) + 
tensor_list.append(tensor[begin:end].clone()) return tensor_list From d5bf4825e1ec47b84c467885a86e79d8f7813f23 Mon Sep 17 00:00:00 2001 From: shw Date: Sat, 12 Oct 2024 17:19:21 +0800 Subject: [PATCH 06/24] fix ckpt_dir --- torchacc/dist/state_dict_utils.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/torchacc/dist/state_dict_utils.py b/torchacc/dist/state_dict_utils.py index bda1dbb..1da6c96 100644 --- a/torchacc/dist/state_dict_utils.py +++ b/torchacc/dist/state_dict_utils.py @@ -2,6 +2,7 @@ import copy from glob import glob import pickle +import os import threading from typing import Dict @@ -135,9 +136,13 @@ def get_layer_full_info(shard_metadata, model_state_dict): def load_checkpoints(ckpt_dir, ckpt_name="*.pth"): - ckpt_path_pattern = ckpt_dir + ckpt_name + ckpt_path_pattern = os.path.join(ckpt_dir, "") + ckpt_name + #print(ckpt_dir) + #print(ckpt_name) + #print(ckpt_path_pattern) ckpt_paths = glob(ckpt_path_pattern) - + #import pdb + #pdb.set_trace() checkpoints_and_paths = [] def load_ckpt(path): @@ -455,13 +460,14 @@ def consolidate_and_reshard_model_dict(ckpt_dir, if reshard_num == 1: if save_model: - actual_save_path = save_path if save_path else ckpt_dir + "model_consolidated.pth" + actual_save_path = save_path if save_path else os.path.join( + ckpt_dir, "consolidated_optimizer.pth") save_checkpoints(full_state_dict, checkpoints[0]['shard_metadata'], actual_save_path, 'model') return full_state_dict # load layer_info - file_path = ckpt_dir + "layer_info.pickle" + file_path = cos.path.join(ckpt_dir, "layer_info.pickle") layer_info = [] try: with open(file_path, 'rb') as f: @@ -477,7 +483,7 @@ def consolidate_and_reshard_model_dict(ckpt_dir, save_path = ckpt_dir actual_save_path = [ - save_path + f"rank-{rank}-of-{reshard_num}-model.pth" + os.path.join(save_path, f"rank-{rank}-of-{reshard_num}-model.pth") for rank in range(reshard_num) ] @@ -503,7 +509,7 @@ def consolidate_and_reshard_optim_dict(ckpt_dir, checkpoints = load_checkpoints(ckpt_dir, ckpt_name) # load layer_info - file_path = ckpt_dir + "layer_info.pickle" + file_path = cos.path.join(ckpt_dir, "layer_info.pickle") layer_info = [] try: with open(file_path, 'rb') as f: @@ -518,7 +524,8 @@ def consolidate_and_reshard_optim_dict(ckpt_dir, if reshard_num == 1: if save_optimizer: - actual_save_path = save_path if save_path else ckpt_dir + "consolidated_optimizer.pth" + actual_save_path = save_path if save_path else os.path.join( + ckpt_dir, "consolidated_optimizer.pth") save_checkpoints(full_optim_state_dict, checkpoints[0]['shard_metadata'], actual_save_path, 'optimizer') @@ -533,7 +540,8 @@ def consolidate_and_reshard_optim_dict(ckpt_dir, save_path = ckpt_dir actual_save_path = [ - save_path + f"rank-{rank}-of-{reshard_num}-optimizer.pth" + os.path.join(save_path, + f"rank-{rank}-of-{reshard_num}-optimizer.pth") for rank in range(reshard_num) ] From c2365a2eeda2c235f747a11ea840b9e56d3ed483 Mon Sep 17 00:00:00 2001 From: shw Date: Mon, 14 Oct 2024 15:25:01 +0800 Subject: [PATCH 07/24] memory efficient --- torchacc/dist/state_dict_utils.py | 60 +++++++++++++++++++------------ 1 file changed, 38 insertions(+), 22 deletions(-) diff --git a/torchacc/dist/state_dict_utils.py b/torchacc/dist/state_dict_utils.py index 1da6c96..4e662fe 100644 --- a/torchacc/dist/state_dict_utils.py +++ b/torchacc/dist/state_dict_utils.py @@ -135,24 +135,24 @@ def get_layer_full_info(shard_metadata, model_state_dict): return (layer_name_list, layer_size_list, layer_numel_list, sharded_list) -def 
load_checkpoints(ckpt_dir, ckpt_name="*.pth"): - ckpt_path_pattern = os.path.join(ckpt_dir, "") + ckpt_name - #print(ckpt_dir) - #print(ckpt_name) - #print(ckpt_path_pattern) +def load_checkpoints(ckpt_dir, ckpt_name): + """ + Load checkpoints that match the pattern of `ckpt_dir + ckpt_name`. + We use multiple thread to accelerate the loading progress, each thread + load one shard checkpoint. + """ + ckpt_path_pattern = os.path.join(ckpt_dir, ckpt_name) ckpt_paths = glob(ckpt_path_pattern) - #import pdb - #pdb.set_trace() - checkpoints_and_paths = [] + checkpoints_and_paths = [[] for _ in range(len(ckpt_paths))] - def load_ckpt(path): + def load_ckpt(path, idx): ckpt = torch.load(path, map_location="cpu") - checkpoints_and_paths.append((ckpt, path)) + checkpoints_and_paths[idx] = (ckpt, path) threads = [] - for path in ckpt_paths: - thread = threading.Thread(target=load_ckpt, args=(path,)) + for idx, path in enumerate(ckpt_paths): + thread = threading.Thread(target=load_ckpt, args=(path, idx)) threads.append(thread) thread.start() @@ -178,6 +178,11 @@ def load_ckpt(path): def save_checkpoints(state_dict_list, shard_metadata_list, save_paths, save_type): + """ + Save checkpoints to save_paths. + We use multiple thread to accelerate the saving progress, each thread + save one shard checkpoint. + """ if not isinstance(state_dict_list, list): torch.save(state_dict_list, save_paths) return @@ -208,8 +213,8 @@ def save_checkpoint(state_dict, shard_metadata, save_path, save_type): def consolidate_sharded_model_checkpoints(ckpt_dir, checkpoints): """ Consolidate the sharded FSDP checkpoints into a single model checkpoint. - """ - + Release the tensor in sharded FSDP checkpoints immediately to save memory. + """ state_dict_list = [ckpt["model"] for ckpt in checkpoints] shard_metadata = checkpoints[0]["shard_metadata"] layer_name_list, layer_size_list, layer_numel_list, sharded_list = get_layer_full_info( @@ -236,7 +241,10 @@ def consolidate_sharded_model_checkpoints(ckpt_dir, checkpoints): p_shard_list = [] for state_dict in state_dict_list: p_shard_list.append(state_dict[state_name]) + state_dict[state_name] = None + consolidate_params = torch.cat(p_shard_list, dim=0) + orig_params = unflatten_params(consolidate_params, layer_name, layer_size, layer_numel) @@ -249,7 +257,8 @@ def consolidate_sharded_model_checkpoints(ckpt_dir, checkpoints): def consolidate_sharded_optimizer_checkpoints(ckpt_dir, checkpoints, layer_info): """ - Consolidate the sharded FSDP checkpoints into a single optimizer checkpoint. + Consolidate the sharded FSDP checkpoints into a single optimizer checkpoint. + Release the tensor in sharded FSDP checkpoints immediately to save memory. """ optim_state_dict_list = [ckpt['optimizer'] for ckpt in checkpoints] shard_metadata = checkpoints[0]["shard_metadata"] @@ -283,8 +292,10 @@ def consolidate_sharded_optimizer_checkpoints(ckpt_dir, checkpoints, for optim_state_dict in optim_state_dict_list: p_shard_list.append( optim_state_dict['state'][idx][state_name]) + optim_state_dict['state'][idx][state_name] = None consolidate_params = torch.cat(p_shard_list, dim=0) + orig_params = unflatten_params(consolidate_params, layer_name, layer_size, layer_numel) @@ -297,7 +308,7 @@ def consolidate_sharded_optimizer_checkpoints(ckpt_dir, checkpoints, def _get_shard(tensor, shard_num): """ - Return the shard tensor list of a full flatten tensor. + Return the shard tensor list of a full flatten tensor. 
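The pad-and-split step described in that docstring can be illustrated with a small, self-contained sketch; the sizes and the shard count here are invented for illustration, and the clone() plays the same role as in the earlier patch, giving each shard its own storage so the full flat buffer can be released:

import torch

def split_flat_tensor(tensor: torch.Tensor, shard_num: int):
    # Pad the flat tensor so its length divides evenly by shard_num.
    if tensor.numel() % shard_num != 0:
        pad_size = shard_num - tensor.numel() % shard_num
        tensor = torch.nn.functional.pad(tensor, (0, pad_size))
    local_size = tensor.numel() // shard_num
    # clone() gives each shard independent storage, so the big flat buffer
    # is not kept alive by the returned views.
    return [tensor[i * local_size:(i + 1) * local_size].clone()
            for i in range(shard_num)]

flat = torch.arange(10, dtype=torch.float32)   # 10 elements, not divisible by 4
shards = split_flat_tensor(flat, shard_num=4)  # 4 shards of 3 elements each
assert sum(s.numel() for s in shards) == 12    # 2 padding elements were added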
""" if tensor.numel() % shard_num != 0: pad_size = shard_num - tensor.numel() % shard_num @@ -325,8 +336,9 @@ def flatten_tensor_list(param_list): def reshard_model_dict(consolidate_model_dict, shard_model, layer_name_lists, reshard_num): """ - reshard the consolidate model into shard_model_state_dict_list according to the reshard_num. - Return the shard_model_state_dict_list and shard_metadata_list. + reshard the consolidate model into shard_model_state_dict_list according to the reshard_num. + Release tensor in consolidate_model_dict immediately to save tensor. + Return the shard_model_state_dict_list and shard_metadata_list. """ shard_model_state_dict: Dict[str, Any] = {} shard_model_state_dict_list = [ @@ -340,6 +352,7 @@ def reshard_model_dict(consolidate_model_dict, shard_model, layer_name_lists, for name in layer_names: tensor_buffer = consolidate_model_dict[name] tensor_buffer_list.append(tensor_buffer) + consolidate_model_dict[name] = None flat_tensor = flatten_tensor_list(tensor_buffer_list) shard_tensor_list = _get_shard(flat_tensor, reshard_num) @@ -361,8 +374,9 @@ def reshard_model_dict(consolidate_model_dict, shard_model, layer_name_lists, def reshard_optim_dict(consolidate_optim_dict, shard_optim, layer_name_lists, reshard_num): """ - reshard the consolidate optim into shard_optim_state_dict_list according to the reshard_num. - Return the shard_optim_state_dict_list and shard_metadata_list. + reshard the consolidate optim into shard_optim_state_dict_list according to the reshard_num. + Release tensor in consolidate_optim_dict immediately to save tensor. + Return the shard_optim_state_dict_list and shard_metadata_list. """ consolidate_optim_state = consolidate_optim_dict['state'] @@ -382,6 +396,8 @@ def reshard_optim_dict(consolidate_optim_dict, shard_optim, layer_name_lists, # we need the params of a whole layer state to be flatten and shard for name in layer_names: state_params = consolidate_optim_state[name][state_name] + consolidate_optim_state[name][state_name] = None + # state name 'step' if isinstance(state_params, torch.Tensor) and state_params.dim() == 0: @@ -467,7 +483,7 @@ def consolidate_and_reshard_model_dict(ckpt_dir, return full_state_dict # load layer_info - file_path = cos.path.join(ckpt_dir, "layer_info.pickle") + file_path = os.path.join(ckpt_dir, "layer_info.pickle") layer_info = [] try: with open(file_path, 'rb') as f: @@ -509,7 +525,7 @@ def consolidate_and_reshard_optim_dict(ckpt_dir, checkpoints = load_checkpoints(ckpt_dir, ckpt_name) # load layer_info - file_path = cos.path.join(ckpt_dir, "layer_info.pickle") + file_path = os.path.join(ckpt_dir, "layer_info.pickle") layer_info = [] try: with open(file_path, 'rb') as f: From bacb8acf6a04d3da01cbc4f01ef7c328b5e2ab4c Mon Sep 17 00:00:00 2001 From: shw Date: Mon, 14 Oct 2024 15:42:32 +0800 Subject: [PATCH 08/24] fix typo --- torchacc/dist/state_dict_utils.py | 2 +- torchacc/utils/consolidate_sharded_ckpts.py | 73 --------------------- 2 files changed, 1 insertion(+), 74 deletions(-) delete mode 100644 torchacc/utils/consolidate_sharded_ckpts.py diff --git a/torchacc/dist/state_dict_utils.py b/torchacc/dist/state_dict_utils.py index 4e662fe..f6ff226 100644 --- a/torchacc/dist/state_dict_utils.py +++ b/torchacc/dist/state_dict_utils.py @@ -519,7 +519,7 @@ def consolidate_and_reshard_optim_dict(ckpt_dir, reshard the FSDP optimizer according to the reshard_num. 
Returns: - optim_state_dict: the consolidated model state dict or reshard model state dict list + optim_state_dict: the consolidated optim state dict or reshard optim state dict list """ # load checkpoints checkpoints = load_checkpoints(ckpt_dir, ckpt_name) diff --git a/torchacc/utils/consolidate_sharded_ckpts.py b/torchacc/utils/consolidate_sharded_ckpts.py deleted file mode 100644 index 7169d15..0000000 --- a/torchacc/utils/consolidate_sharded_ckpts.py +++ /dev/null @@ -1,73 +0,0 @@ -from argparse import ArgumentParser -from torchacc.dist.state_dict_utils import consolidate_and_reshard_model_dict, consolidate_and_reshard_optim_dict - -MODEL_NAME_PATTERN = "rank*-of-*-model.pth" -OPTIM_NAME_PATTERN = "rank*-of-*-optimizer.pth" - - -def main(): - parser = ArgumentParser() - parser.add_argument( - "--ckpt_dir", - type=str, - required=True, - help=( - f"The name pattern of the XLA FSDP checkpoint files to be consolidated. " - f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be loaded." - f"For model, the default pattern is {MODEL_NAME_PATTERN}. For optimizer," - f"the default pattern is {OPTIM_NAME_PATTERN}"), - ) - parser.add_argument( - "--ckpt_name", - type=str, - default="", - help=( - f"The name pattern of the XLA FSDP checkpoint files to be consolidated. " - f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be loaded." - f"For model, the default pattern is {MODEL_NAME_PATTERN}. For optimizer," - f"the default pattern is {OPTIM_NAME_PATTERN}"), - ) - parser.add_argument( - "--ckpt_type", - type=str, - default="model", - help=( - "The type of checkpoint to consolidate, you can choose model or optimizer. Please consolidate model fisrt, and then consolidate optimizer." - ), - ) - parser.add_argument( - "--reshard_num", - type=int, - default=1, - help=( - "We now support the reshard of XLA FSDP checkpoint according to the reshard_num." 
- )) - parser.add_argument( - "--save_path", - type=str, - default="", - help=( - f"The save path of the output state dict " - f"(default consolidate_path is ``ckpt_dir + model/optimizer_consolidated.pth``)" - f"If you need to reshard the checkpoint, please only pass the save_dir(default is ckpt_dir)," - f"we will save the file in path ``save_path + {MODEL_NAME_PATTERN}/{OPTIM_NAME_PATTERN}``" - ), - ) - args = parser.parse_args() - assert args.ckpt_type in ['model', 'optimizer' - ], ('the ckpt_type should be model or optimizer') - - if args.ckpt_type == "model": - if args.ckpt_name == "": - args.ckpt_name = MODEL_NAME_PATTERN - consolidate_and_reshard_model_dict(args.ckpt_dir, args.ckpt_name, - args.reshard_num, args.save_path) - else: - if args.ckpt_name == "": - args.ckpt_name = OPTIM_NAME_PATTERN - consolidate_and_reshard_optim_dict(args.ckpt_dir, args.ckpt_name, - args.reshard_num, args.save_path) - - -if __name__ == "__main__": - main() From c6af54e98022c26704d593b713bade614e033bf1 Mon Sep 17 00:00:00 2001 From: shw Date: Mon, 14 Oct 2024 15:43:01 +0800 Subject: [PATCH 09/24] rename consolidate util4 --- .../utils/consolidate_and_reshard_ckpts.py | 73 +++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 torchacc/utils/consolidate_and_reshard_ckpts.py diff --git a/torchacc/utils/consolidate_and_reshard_ckpts.py b/torchacc/utils/consolidate_and_reshard_ckpts.py new file mode 100644 index 0000000..7169d15 --- /dev/null +++ b/torchacc/utils/consolidate_and_reshard_ckpts.py @@ -0,0 +1,73 @@ +from argparse import ArgumentParser +from torchacc.dist.state_dict_utils import consolidate_and_reshard_model_dict, consolidate_and_reshard_optim_dict + +MODEL_NAME_PATTERN = "rank*-of-*-model.pth" +OPTIM_NAME_PATTERN = "rank*-of-*-optimizer.pth" + + +def main(): + parser = ArgumentParser() + parser.add_argument( + "--ckpt_dir", + type=str, + required=True, + help=( + f"The name pattern of the XLA FSDP checkpoint files to be consolidated. " + f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be loaded." + f"For model, the default pattern is {MODEL_NAME_PATTERN}. For optimizer," + f"the default pattern is {OPTIM_NAME_PATTERN}"), + ) + parser.add_argument( + "--ckpt_name", + type=str, + default="", + help=( + f"The name pattern of the XLA FSDP checkpoint files to be consolidated. " + f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be loaded." + f"For model, the default pattern is {MODEL_NAME_PATTERN}. For optimizer," + f"the default pattern is {OPTIM_NAME_PATTERN}"), + ) + parser.add_argument( + "--ckpt_type", + type=str, + default="model", + help=( + "The type of checkpoint to consolidate, you can choose model or optimizer. Please consolidate model fisrt, and then consolidate optimizer." + ), + ) + parser.add_argument( + "--reshard_num", + type=int, + default=1, + help=( + "We now support the reshard of XLA FSDP checkpoint according to the reshard_num." 
+ )) + parser.add_argument( + "--save_path", + type=str, + default="", + help=( + f"The save path of the output state dict " + f"(default consolidate_path is ``ckpt_dir + model/optimizer_consolidated.pth``)" + f"If you need to reshard the checkpoint, please only pass the save_dir(default is ckpt_dir)," + f"we will save the file in path ``save_path + {MODEL_NAME_PATTERN}/{OPTIM_NAME_PATTERN}``" + ), + ) + args = parser.parse_args() + assert args.ckpt_type in ['model', 'optimizer' + ], ('the ckpt_type should be model or optimizer') + + if args.ckpt_type == "model": + if args.ckpt_name == "": + args.ckpt_name = MODEL_NAME_PATTERN + consolidate_and_reshard_model_dict(args.ckpt_dir, args.ckpt_name, + args.reshard_num, args.save_path) + else: + if args.ckpt_name == "": + args.ckpt_name = OPTIM_NAME_PATTERN + consolidate_and_reshard_optim_dict(args.ckpt_dir, args.ckpt_name, + args.reshard_num, args.save_path) + + +if __name__ == "__main__": + main() From 128517cf11d2729e24b9e777e765e3b979e8c8f1 Mon Sep 17 00:00:00 2001 From: shw Date: Tue, 15 Oct 2024 11:52:52 +0800 Subject: [PATCH 10/24] ut done --- .../consolidate_and_reshard_ckpts.py | 294 ++++++++++++++++++ tests/standalone/offload.py | 3 +- torchacc/dist/state_dict_utils.py | 12 +- 3 files changed, 303 insertions(+), 6 deletions(-) create mode 100644 tests/standalone/consolidate_and_reshard_ckpts.py diff --git a/tests/standalone/consolidate_and_reshard_ckpts.py b/tests/standalone/consolidate_and_reshard_ckpts.py new file mode 100644 index 0000000..8e3c016 --- /dev/null +++ b/tests/standalone/consolidate_and_reshard_ckpts.py @@ -0,0 +1,294 @@ +import argparse +import os + +import torch +import torch_xla.core.xla_model as xm +import torchacc as ta +from torchacc.dist.state_dict_utils import consolidate_and_reshard_model_dict, consolidate_and_reshard_optim_dict, load_checkpoints + +from utils import EchoDataset, set_seed + + +class Net(torch.nn.Module): + + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(1024, 1024) + self.fc2 = torch.nn.Linear(1024, 1024) + self.fc3 = torch.nn.Linear(1024, 1024) + self.fc4 = torch.nn.Linear(1024, 1024) + self.fc5 = torch.nn.Linear(1024, 1024) + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + x = self.fc3(x) + x = self.fc4(x) + x = self.fc5(x) + return x + + +def train(args, model, device, train_loader, optimizer): + steps_per_print = args.steps_per_print + train_steps = args.train_steps * args.gradient_accumulation_steps + + scaler = ta.amp.GradScaler() if args.fp16 else None + + amp_dtype = torch.float16 if args.fp16 else torch.bfloat16 + amp_enabled = args.fp16 or args.bf16 + gradient_accumulation_steps = args.gradient_accumulation_steps + + total_loss = torch.tensor(0.0).to(device) + global_step = 1 + for step, data in enumerate(train_loader): + with torch.cuda.amp.autocast( + enabled=amp_enabled, cache_enabled=True, dtype=amp_dtype): + loss = model(data[0]) + loss = torch.nn.functional.nll_loss(loss, data[1]) + if scaler is not None: + scaler.scale(loss).backward() + else: + loss.backward() + step += 1 + loss = loss.clone().detach() / gradient_accumulation_steps + total_loss += loss + if step % gradient_accumulation_steps == 0: + if scaler is not None: + scaler.step(optimizer) + scaler.update() + else: + optimizer.step() + optimizer.zero_grad() + if ta.dist.local_rank() == 0: + if global_step % steps_per_print == 0: + ta.sync() + ta.utils.logger.info( + f"step: {global_step}, loss: {total_loss}") + if global_step == train_steps: + ta.sync() + return + global_step += 1 + 
total_loss.zero_() + + +def compare_model_dict(dict1, dict2, idx): + if dict1.keys() != dict2.keys(): + print("model dict keys are different") + return + + difference = False + + for key in dict2.keys(): + tensor1 = dict1[key] + tensor2 = dict2[key] + + if not torch.equal(tensor1, tensor2): + print(f"Difference found at key: {key}") + print(f"Tensor 1: {tensor1}") + print(f"Tensor 2: {tensor2}") + difference = True + + if not difference: + print(f"The model dict shard {idx} are same.") + + +def compare_optim_dict(state_dict1, state_dict2, idx): + state1 = state_dict1['state'] + state2 = state_dict2['state'] + if state1.keys() != state2.keys(): + print("optimizer state keys are different") + return + + difference = False + for key in state2.keys(): + dict1 = state1[key] + dict2 = state2[key] + for state_name in dict1.keys(): + tensor1 = dict1[state_name] + tensor2 = dict2[state_name] + + if not torch.equal(tensor1, tensor2): + print(f"Difference found at state key: {key}-{state_name}") + print(f"Tensor 1: {tensor1}") + print(f"Tensor 2: {tensor2}") + difference = True + + param_list1 = state_dict1['param_groups'] + param_list2 = state_dict2['param_groups'] + + for param1, param2 in zip(param_list1, param_list2): + if param1.keys() != param2.keys(): + print("optimizer param_groups keys are different") + return + + for key in param2.keys(): + if param2[key] != param1[key]: + print(f"Difference found at param_group key: {key}") + print(f"value 1: {param1[key]}") + print(f"value 2: {param2[key]}") + difference = True + + if not difference: + print(f"The optim dict shard {idx} are same.") + + +def main(args): + fsdp_num = args.fsdp_num + batch_size = args.batch_size + train_steps = args.train_steps * args.gradient_accumulation_steps + ckpt_dir = args.ckpt_dir + reshard_num = args.reshard_num + + model = Net() + + # set config + config = ta.Config() + config.backend = args.backend + config.compute.fp16 = args.fp16 + config.compute.bf16 = args.bf16 + + config.dist.fsdp.size = fsdp_num + config.dist.fsdp.wrap_layer_cls = {"Linear"} + config.dist.fsdp.flatten_parameters = True + + # accelerate + model = ta.accelerate(model, config=config) + device = model.device + + optimizer = torch.optim.AdamW(model.parameters(), lr=0.001) + + scaler = ta.amp.GradScaler() if args.fp16 else None + + amp_dtype = torch.float16 if args.fp16 else torch.bfloat16 + amp_enabled = args.fp16 or args.bf16 + + train_loader = EchoDataset( + data=[ + torch.zeros(batch_size, 1024), + torch.zeros(batch_size, dtype=torch.int64) + ], + repeat_count=train_steps) + + train_loader = ta.AsyncLoader(train_loader, device) + + # train model + train(args, model, device, train_loader, optimizer) + + # save shard model and optimizer + if not os.path.exists(ckpt_dir): + os.makedirs(ckpt_dir) + + model_ckpt = { + "model": model.state_dict(), + "shard_metadata": model.model.model.get_shard_metadata( + ), # we need first get the xla model + } + model_ckpt_path = os.path.join( + ckpt_dir, + f"rank{ta.dist.local_rank()}-of-{ta.dist.world_size()}-model.pth") + ta.save(model_ckpt, model_ckpt_path, master_only=False) + xm.rendezvous("saving_model") + + optim_ckpt = { + "optimizer": optimizer.state_dict(), + "shard_metadata": model.model.model.get_shard_metadata(), + } + optim_ckpt_path = os.path.join( + ckpt_dir, + f"rank{ta.dist.local_rank()}-of-{ta.dist.world_size()}-optim.pth") + ta.save(optim_ckpt, optim_ckpt_path, master_only=False) + xm.rendezvous("saving_optim") + + # rank 0 do consolidate and reshard: + if ta.dist.local_rank() == 0: + # 
consolidate and reshard model and optimizer + model_reshard_dicts, _ = consolidate_and_reshard_model_dict( + ckpt_dir=ckpt_dir, + ckpt_name=f"rank*-of-*-model.pth", + reshard_num=reshard_num, + save_model=True, + ) + print(f"model consolidate and reshard to path:{ckpt_dir}") + + optim_reshard_dicts, _ = consolidate_and_reshard_optim_dict( + ckpt_dir=ckpt_dir, + ckpt_name=f"rank*-of-*-optim.pth", + reshard_num=reshard_num, + save_optimizer=True, + ) + print(f"optimizer consolidate and reshard to path:{ckpt_dir}") + + # compare shard model and optimizer + if reshard_num == fsdp_num: + model_shard_dicts = load_checkpoints( + kpt_dir=ckpt_dir, ckpt_name=f"rank*-of-*-model.pth") + optim_shard_dicts = load_checkpoints( + ckpt_dir=ckpt_dir, ckpt_name=f"rank*-of-*-optim.pth") + + for idx, (dict1, dict2) in enumerate( + zip(model_shard_dicts, model_reshard_dicts)): + compare_model_dict(dict1['model'], dict2, idx) + + for idx, (dict1, dict2) in enumerate( + zip(optim_shard_dicts, optim_reshard_dicts)): + compare_optim_dict(dict1['optimizer'], dict2, idx) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description='TorchAcc Consolidate And Reshard FSDP Checkpoints') + parser.add_argument('--fsdp_num', type=int, default=1) + parser.add_argument('--gradient_accumulation_steps', type=int, default=1) + parser.add_argument('--batch_size', type=int, default=2) + parser.add_argument('--steps_per_print', type=int, default=1) + parser.add_argument('--train_steps', type=int, default=10) + parser.add_argument("--fp16", action="store_true", default=False) + parser.add_argument("--bf16", action="store_true", default=False) + parser.add_argument("--backend", type=str, default="lazy") + + MODEL_NAME_PATTERN = "rank*-of-*-model.pth" + OPTIM_NAME_PATTERN = "rank*-of-*-optimizer.pth" + # ckpt arguments + parser.add_argument( + "--ckpt_dir", + type=str, + required=True, + help=( + f"The name pattern of the XLA FSDP checkpoint files to be consolidated. " + f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be loaded." + f"For model, the default pattern is {MODEL_NAME_PATTERN}. For optimizer," + f"the default pattern is {OPTIM_NAME_PATTERN}"), + ) + parser.add_argument( + "--ckpt_name", + type=str, + default="", + help=( + f"The name pattern of the XLA FSDP checkpoint files to be consolidated. " + f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be loaded." + f"For model, the default pattern is {MODEL_NAME_PATTERN}. For optimizer," + f"the default pattern is {OPTIM_NAME_PATTERN}"), + ) + parser.add_argument( + "--reshard_num", + type=int, + default=1, + help=( + "We now support the reshard of XLA FSDP checkpoint according to the reshard_num." 
+ )) + parser.add_argument( + "--save_path", + type=str, + default="", + help=( + f"The save path of the output state dict " + f"(default consolidate_path is ``ckpt_dir + model/optimizer_consolidated.pth``)" + f"If you need to reshard the checkpoint, please only pass the save_dir(default is ckpt_dir)," + f"we will save the file in path ``save_path + {MODEL_NAME_PATTERN}/{OPTIM_NAME_PATTERN}``" + ), + ) + + args = parser.parse_args() + + set_seed() + main(args) diff --git a/tests/standalone/offload.py b/tests/standalone/offload.py index 85b2468..f789902 100644 --- a/tests/standalone/offload.py +++ b/tests/standalone/offload.py @@ -86,7 +86,7 @@ def train(args, model, device, train_loader, optimizer, scaler): def main(): - parser = argparse.ArgumentParser(description='PyTorch MNIST Example') + parser = argparse.ArgumentParser(description='Torchacc Offload Example') parser.add_argument( '--batch-size', type=int, @@ -115,6 +115,7 @@ def main(): device = dist.get_rank() model = Net() model.to(device) + optimizer = torch.optim.Adam(model.parameters(), lr=0.001) scaler = torch.cuda.amp.GradScaler() diff --git a/torchacc/dist/state_dict_utils.py b/torchacc/dist/state_dict_utils.py index f6ff226..0f06072 100644 --- a/torchacc/dist/state_dict_utils.py +++ b/torchacc/dist/state_dict_utils.py @@ -219,8 +219,8 @@ def consolidate_sharded_model_checkpoints(ckpt_dir, checkpoints): shard_metadata = checkpoints[0]["shard_metadata"] layer_name_list, layer_size_list, layer_numel_list, sharded_list = get_layer_full_info( shard_metadata, state_dict_list[0]) - file_path = ckpt_dir + "layer_info.pickle" + file_path = os.path.join(ckpt_dir, "layer_info.pickle") with open(file_path, 'wb') as f: pickle.dump( [layer_name_list, layer_size_list, layer_numel_list, sharded_list], @@ -468,6 +468,7 @@ def consolidate_and_reshard_model_dict(ckpt_dir, Returns: model_state_dict: the consolidated model state dict or reshard model state dict list. + shard_meta_list: the reshard metadatalist. The consolidated model return None. """ checkpoints = load_checkpoints(ckpt_dir, ckpt_name) @@ -481,7 +482,7 @@ def consolidate_and_reshard_model_dict(ckpt_dir, save_checkpoints(full_state_dict, checkpoints[0]['shard_metadata'], actual_save_path, 'model') - return full_state_dict + return full_state_dict, None # load layer_info file_path = os.path.join(ckpt_dir, "layer_info.pickle") layer_info = [] @@ -506,7 +507,7 @@ def consolidate_and_reshard_model_dict(ckpt_dir, save_checkpoints(model_state_dict_list, shard_metadata_list, actual_save_path, 'model') - return model_state_dict_list + return model_state_dict_list, shard_metadata_list def consolidate_and_reshard_optim_dict(ckpt_dir, @@ -520,6 +521,7 @@ def consolidate_and_reshard_optim_dict(ckpt_dir, Returns: optim_state_dict: the consolidated optim state dict or reshard optim state dict list + shard_meta_list: the reshard metadatalist. The consolidated optim return None. 
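A minimal usage sketch of the two entry points with the tuple return values described above, following the signatures at this point in the series; the checkpoint directory and name patterns are assumptions borrowed from the test script, and saving is disabled so the calls only return the resharded dicts:

from torchacc.dist.state_dict_utils import (
    consolidate_and_reshard_model_dict, consolidate_and_reshard_optim_dict)

ckpt_dir = "./ckpt/"  # assumed directory holding the sharded checkpoints

# Consolidate the model first: this pass also writes layer_info.pickle,
# which the optimizer pass below depends on.
model_dicts, model_meta = consolidate_and_reshard_model_dict(
    ckpt_dir, ckpt_name="rank*-of-*-model.pth", reshard_num=2,
    save_model=False)

# Then reshard the optimizer with the same reshard_num.
optim_dicts, optim_meta = consolidate_and_reshard_optim_dict(
    ckpt_dir, ckpt_name="rank*-of-*-optim.pth", reshard_num=2,
    save_optimizer=False)

# With reshard_num > 1 each call returns one state dict and one shard_metadata
# entry per new rank; with reshard_num == 1 the second element is None.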
""" # load checkpoints checkpoints = load_checkpoints(ckpt_dir, ckpt_name) @@ -546,7 +548,7 @@ def consolidate_and_reshard_optim_dict(ckpt_dir, checkpoints[0]['shard_metadata'], actual_save_path, 'optimizer') - return full_optim_state_dict + return full_optim_state_dict, None optim_state_dict_list, shard_metadata_list = reshard_optim_dict( full_optim_state_dict, checkpoints[0], layer_info[0], reshard_num) @@ -564,4 +566,4 @@ def consolidate_and_reshard_optim_dict(ckpt_dir, save_checkpoints(optim_state_dict_list, shard_metadata_list, actual_save_path, 'optimizer') - return optim_state_dict_list + return optim_state_dict_list, shard_metadata_list From 7260beb042df14d3d91d93caae4e52e8c9d6707b Mon Sep 17 00:00:00 2001 From: shw Date: Tue, 15 Oct 2024 11:54:30 +0800 Subject: [PATCH 11/24] add save comments --- torchacc/dist/state_dict_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchacc/dist/state_dict_utils.py b/torchacc/dist/state_dict_utils.py index 0f06072..27c5500 100644 --- a/torchacc/dist/state_dict_utils.py +++ b/torchacc/dist/state_dict_utils.py @@ -183,6 +183,7 @@ def save_checkpoints(state_dict_list, shard_metadata_list, save_paths, We use multiple thread to accelerate the saving progress, each thread save one shard checkpoint. """ + # save consolidate checkpoint if not isinstance(state_dict_list, list): torch.save(state_dict_list, save_paths) return @@ -196,7 +197,7 @@ def save_checkpoint(state_dict, shard_metadata, save_path, save_type): torch.save(model, save_path) threads = [] - + # save reshard checkpoints for state_dict, shard_metadata, save_path in zip(state_dict_list, shard_metadata_list, save_paths): From 5504402233186a2e6519c7e94a1014426e862bb5 Mon Sep 17 00:00:00 2001 From: shw Date: Wed, 16 Oct 2024 14:14:37 +0800 Subject: [PATCH 12/24] isort --- tests/run_tests.sh | 1 + tests/standalone/consolidate_and_reshard_ckpts.py | 10 +++------- torchacc/dist/state_dict_utils.py | 6 +++--- torchacc/utils/consolidate_and_reshard_ckpts.py | 4 +++- 4 files changed, 10 insertions(+), 11 deletions(-) diff --git a/tests/run_tests.sh b/tests/run_tests.sh index daa98e5..0d957e1 100755 --- a/tests/run_tests.sh +++ b/tests/run_tests.sh @@ -8,6 +8,7 @@ function test_standalone() { torchrun --nproc_per_node=4 standalone/pipeline.py --pp_num 4 --gc --bf16 torchrun --nproc_per_node=4 standalone/pipeline.py --pp_num 4 --test_skip torchrun --nproc_per_node=4 standalone/ta_accelerate.py --gc + torchrun --nproc_per_node=4 standalone/consolidate_and_reshard_ckpts.py --fsdp_num 4 --ckpt_dir standalone/ckpt --reshard_num 4 # PyTorch DDP torchrun --nproc_per_node=4 standalone/ta_accelerate.py --backend eager # PyTorch FSDP diff --git a/tests/standalone/consolidate_and_reshard_ckpts.py b/tests/standalone/consolidate_and_reshard_ckpts.py index 8e3c016..cd39562 100644 --- a/tests/standalone/consolidate_and_reshard_ckpts.py +++ b/tests/standalone/consolidate_and_reshard_ckpts.py @@ -4,8 +4,9 @@ import torch import torch_xla.core.xla_model as xm import torchacc as ta -from torchacc.dist.state_dict_utils import consolidate_and_reshard_model_dict, consolidate_and_reshard_optim_dict, load_checkpoints - +from torchacc.dist.state_dict_utils import (consolidate_and_reshard_model_dict, + consolidate_and_reshard_optim_dict, + load_checkpoints) from utils import EchoDataset, set_seed @@ -157,11 +158,6 @@ def main(args): optimizer = torch.optim.AdamW(model.parameters(), lr=0.001) - scaler = ta.amp.GradScaler() if args.fp16 else None - - amp_dtype = torch.float16 if args.fp16 else 
torch.bfloat16 - amp_enabled = args.fp16 or args.bf16 - train_loader = EchoDataset( data=[ torch.zeros(batch_size, 1024), diff --git a/torchacc/dist/state_dict_utils.py b/torchacc/dist/state_dict_utils.py index 27c5500..060eca1 100644 --- a/torchacc/dist/state_dict_utils.py +++ b/torchacc/dist/state_dict_utils.py @@ -1,9 +1,9 @@ -from collections import OrderedDict import copy -from glob import glob -import pickle import os +import pickle import threading +from collections import OrderedDict +from glob import glob from typing import Dict import torch diff --git a/torchacc/utils/consolidate_and_reshard_ckpts.py b/torchacc/utils/consolidate_and_reshard_ckpts.py index 7169d15..de74333 100644 --- a/torchacc/utils/consolidate_and_reshard_ckpts.py +++ b/torchacc/utils/consolidate_and_reshard_ckpts.py @@ -1,5 +1,7 @@ from argparse import ArgumentParser -from torchacc.dist.state_dict_utils import consolidate_and_reshard_model_dict, consolidate_and_reshard_optim_dict + +from torchacc.dist.state_dict_utils import (consolidate_and_reshard_model_dict, + consolidate_and_reshard_optim_dict) MODEL_NAME_PATTERN = "rank*-of-*-model.pth" OPTIM_NAME_PATTERN = "rank*-of-*-optimizer.pth" From 920adbc587e065f34038c996969c53c6bff6398a Mon Sep 17 00:00:00 2001 From: shw Date: Wed, 16 Oct 2024 17:00:14 +0800 Subject: [PATCH 13/24] refine console scripts --- setup.py | 7 + .../consolidate_and_reshard_ckpts.py | 8 +- torchacc/dist/state_dict_utils.py | 127 +++++++++++------- .../utils/consolidate_and_reshard_ckpts.py | 80 +++++++---- 4 files changed, 146 insertions(+), 76 deletions(-) diff --git a/setup.py b/setup.py index aafecf4..c480f32 100644 --- a/setup.py +++ b/setup.py @@ -32,6 +32,13 @@ def get_and_set_version(): packages=['torchacc'] + ['torchacc.' + \ pkg for pkg in find_packages('torchacc')], + # add console_scripts + entry_points={ + 'console_scripts': [ + 'consolidate_and_reshape_ckpts = torchacc.utils.consolidate_and_reshard_ckpts:main', + ], + }, + # Add _ prefix to the names of temporary build dirs options={'build': {'build_base': '_build'}, }, zip_safe=True, diff --git a/tests/standalone/consolidate_and_reshard_ckpts.py b/tests/standalone/consolidate_and_reshard_ckpts.py index cd39562..2e8a8d1 100644 --- a/tests/standalone/consolidate_and_reshard_ckpts.py +++ b/tests/standalone/consolidate_and_reshard_ckpts.py @@ -202,17 +202,17 @@ def main(args): ckpt_dir=ckpt_dir, ckpt_name=f"rank*-of-*-model.pth", reshard_num=reshard_num, - save_model=True, + save_model=False, ) - print(f"model consolidate and reshard to path:{ckpt_dir}") + print(f"model consolidate and reshard done.") optim_reshard_dicts, _ = consolidate_and_reshard_optim_dict( ckpt_dir=ckpt_dir, ckpt_name=f"rank*-of-*-optim.pth", reshard_num=reshard_num, - save_optimizer=True, + save_optimizer=False, ) - print(f"optimizer consolidate and reshard to path:{ckpt_dir}") + print(f"optimizer consolidate and reshard done.") # compare shard model and optimizer if reshard_num == fsdp_num: diff --git a/torchacc/dist/state_dict_utils.py b/torchacc/dist/state_dict_utils.py index 060eca1..f120a36 100644 --- a/torchacc/dist/state_dict_utils.py +++ b/torchacc/dist/state_dict_utils.py @@ -436,54 +436,57 @@ def reshard_optim_dict(consolidate_optim_dict, shard_optim, layer_name_lists, return shard_optim_state_dict_list, shard_metadata_list -def consolidate_and_reshard_model_dict(ckpt_dir, - ckpt_name="", - reshard_num=1, - save_path="", - save_model=True): +def consolidate_and_reshard_fsdp_model_dict(ckpt_dir, + ckpt_name, + save_dir="", + save_name="", + 
reshard_num=1, + save_model=True): """ Consolidate the sharded FSDP checkpoints into a single model checkpoint. Then reshard the FSDP model according to the reshard_num. Args: ckpt_dir (str): - The dir to FSDP checkpoint files from all ranks - ckpt_name (str, Optional): - The name_pattern to FSDP checkpoint files from all ranks. Files matching the - pattern ``ckpt_dir + ckpt_name`` will be loaded. The each + The dir to all FSDP shard model checkpoint files. + ckpt_name (str): + The name_pattern to all FSDP shard model checkpoint files. Files matching the + pattern ``ckpt_dir + ckpt_name`` will be loaded. Each checkpoint file is assumed to be a dict with a "model" key containing the FSDP model's ``model.state_dict()`` and a "shard_metadata" key containing the FSDP model's ``model.get_shard_metadata()``. + save_dir (str): + The save dir for consolidate or reshard model checkpoints. + save_name (str, Optional): + The name_pattern for consolidate or reshard model checkpoints. + For reshard checkpoints name pattern: ``rank*-of-*-model.pth`` + The final save_path is save_dir + save_name. reshard_num (int, Optional): - Reshard the fsdp model with reshard_num. If set to 1, we don't need to do + Reshard the fsdp model by reshard_num. If set to 1, we don't need to do resharding. - save_path (str, Optional): - the save path to the consolidated model checkpoint file (if - ``save_model`` is ``True``). The checkpoint file is a dict with a - "model" key containing the consolidated model state dict. save_model (str, Optional): - if ``True``, the model checkpoint will be saved to - ``save_path`` (or ``ckpt_dir + "consolidated_model.pth"`` if - ``save_path`` is empty). + if ``True``, the model checkpoint will be saved to ``save_dir + save_name``. Returns: model_state_dict: the consolidated model state dict or reshard model state dict list. - shard_meta_list: the reshard metadatalist. The consolidated model return None. + shard_meta_list: the reshard metadatalist. For consolidated model, return None. 
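One way to picture the save-name handling for resharded checkpoints: the two '*' placeholders in a pattern such as rank*-of-*-model.pth are expanded into the rank index and the reshard count before joining with save_dir. A small illustrative sketch; the helper name and the example values are not part of the patch:

import os

def expand_save_paths(save_dir: str, save_name: str, reshard_num: int):
    paths = []
    for rank in range(reshard_num):
        # First '*' becomes the rank index, second '*' the total shard count.
        name = save_name.replace("*", str(rank), 1)
        name = name.replace("*", str(reshard_num), 1)
        paths.append(os.path.join(save_dir, name))
    return paths

print(expand_save_paths("./ckpt", "rank*-of-*-model.pth", 2))
# ['./ckpt/rank0-of-2-model.pth', './ckpt/rank1-of-2-model.pth']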
""" - checkpoints = load_checkpoints(ckpt_dir, ckpt_name) full_state_dict = consolidate_sharded_model_checkpoints( ckpt_dir, checkpoints) if reshard_num == 1: if save_model: - actual_save_path = save_path if save_path else os.path.join( - ckpt_dir, "consolidated_optimizer.pth") + if not save_dir or not save_name: + raise ValueError("save_dir and save_name should not be None!") + actual_save_path = os.path.join(save_dir, save_name) + save_checkpoints(full_state_dict, checkpoints[0]['shard_metadata'], actual_save_path, 'model') return full_state_dict, None + # load layer_info file_path = os.path.join(ckpt_dir, "layer_info.pickle") layer_info = [] @@ -491,19 +494,23 @@ def consolidate_and_reshard_model_dict(ckpt_dir, with open(file_path, 'rb') as f: layer_info = pickle.load(f) except FileNotFoundError: - print(f"please consolidate model first!") + raise NotImplementedError("please consolidate model first!") model_state_dict_list, shard_metadata_list = reshard_model_dict( full_state_dict, checkpoints[0], layer_info[0], reshard_num) if save_model: - if save_path == "": - save_path = ckpt_dir - - actual_save_path = [ - os.path.join(save_path, f"rank-{rank}-of-{reshard_num}-model.pth") - for rank in range(reshard_num) - ] + if not save_dir or not save_name: + raise ValueError("save_dir and save_name should not be None!") + + actual_save_path = [] + for idx in range(reshard_num): + save_name_ = re.sub( + r'\*', + lambda m: str(idx) if m.group(0) == '*' else str(reshard_num), + save_name, + count=2) + actual_save_path.append(os.path.join(save_dir, save_name_)) save_checkpoints(model_state_dict_list, shard_metadata_list, actual_save_path, 'model') @@ -511,18 +518,39 @@ def consolidate_and_reshard_model_dict(ckpt_dir, return model_state_dict_list, shard_metadata_list -def consolidate_and_reshard_optim_dict(ckpt_dir, - ckpt_name="", - reshard_num=1, - save_path="", - save_optimizer=True): +def consolidate_and_reshard_fsdp_optim_dict(ckpt_dir, + ckpt_name, + save_dir="", + save_name="", + reshard_num=1, + save_optimizer=True): """ Consolidate the sharded FSDP checkpoints into a single optimizer checkpoint. Then reshard the FSDP optimizer according to the reshard_num. - + Args: + ckpt_dir (str): + The dir to all FSDP shard optimizer checkpoint files. + ckpt_name (str): + The name_pattern to all FSDP shard optimizer checkpoint files. Files matching the + pattern ``ckpt_dir + ckpt_name`` will be loaded. Each + checkpoint file is assumed to be a dict with a "optimizer" key + containing the FSDP optimizer's ``optimizer.state_dict()`` and a + "shard_metadata" key containing the FSDP model's + ``model.get_shard_metadata()``. + save_dir (str, Optional): + The save dir for consolidate or reshard optimizer checkpoints. + save_name (str, Optional): + The name_pattern for consolidate or reshard optimizer checkpoints. + For reshard checkpoints name pattern:: `rank*-of-*-optimizer.pth` + The final save_path is save_dir + save_name. + reshard_num (int, Optional): + Reshard the fsdp optimizer by reshard_num. If set to 1, we don't need to do + resharding. + save_model (str, Optional): + if ``True``, the model checkpoint will be saved to ``save_dir + save_name``. Returns: optim_state_dict: the consolidated optim state dict or reshard optim state dict list - shard_meta_list: the reshard metadatalist. The consolidated optim return None. + shard_meta_list: the reshard metadatalist. For consolidated optim, return None. 
""" # load checkpoints checkpoints = load_checkpoints(ckpt_dir, ckpt_name) @@ -539,12 +567,12 @@ def consolidate_and_reshard_optim_dict(ckpt_dir, full_optim_state_dict = consolidate_sharded_optimizer_checkpoints( ckpt_dir, checkpoints, layer_info) - actual_save_path = None - if reshard_num == 1: if save_optimizer: - actual_save_path = save_path if save_path else os.path.join( - ckpt_dir, "consolidated_optimizer.pth") + if not save_dir or not save_name: + raise ValueError("save_dir and save_name should not be None!") + actual_save_path = os.path.join(save_dir, save_name) + save_checkpoints(full_optim_state_dict, checkpoints[0]['shard_metadata'], actual_save_path, 'optimizer') @@ -555,14 +583,17 @@ def consolidate_and_reshard_optim_dict(ckpt_dir, full_optim_state_dict, checkpoints[0], layer_info[0], reshard_num) if save_optimizer: - if save_path == "": - save_path = ckpt_dir - - actual_save_path = [ - os.path.join(save_path, - f"rank-{rank}-of-{reshard_num}-optimizer.pth") - for rank in range(reshard_num) - ] + if not save_dir or not save_name: + raise ValueError("save_dir and save_name should not be None!") + + actual_save_path = [] + for idx in range(reshard_num): + save_name_ = re.sub( + r'\*', + lambda m: str(idx) if m.group(0) == '*' else str(reshard_num), + save_name, + count=2) + actual_save_path.append(os.path.join(save_dir, save_name_)) save_checkpoints(optim_state_dict_list, shard_metadata_list, actual_save_path, 'optimizer') diff --git a/torchacc/utils/consolidate_and_reshard_ckpts.py b/torchacc/utils/consolidate_and_reshard_ckpts.py index de74333..8fcf9ae 100644 --- a/torchacc/utils/consolidate_and_reshard_ckpts.py +++ b/torchacc/utils/consolidate_and_reshard_ckpts.py @@ -1,10 +1,11 @@ from argparse import ArgumentParser -from torchacc.dist.state_dict_utils import (consolidate_and_reshard_model_dict, - consolidate_and_reshard_optim_dict) +from torchacc.dist.state_dict_utils import ( + consolidate_and_reshard_fsdp_model_dict, + consolidate_and_reshard_fsdp_optim_dict) -MODEL_NAME_PATTERN = "rank*-of-*-model.pth" -OPTIM_NAME_PATTERN = "rank*-of-*-optimizer.pth" +DEFAULT_MODEL_NAME_PATTERN = "rank*-of-*-model.pth" +DEFAULT_OPTIM_NAME_PATTERN = "rank*-of-*-optimizer.pth" def main(): @@ -14,24 +15,25 @@ def main(): type=str, required=True, help=( - f"The name pattern of the XLA FSDP checkpoint files to be consolidated. " + f"The name dir of the XLA FSDP checkpoint files to be consolidated and reshard. " f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be loaded." - f"For model, the default pattern is {MODEL_NAME_PATTERN}. For optimizer," - f"the default pattern is {OPTIM_NAME_PATTERN}"), + f"For model, the default pattern is {DEFAULT_MODEL_NAME_PATTERN}. For optimizer," + f"the default pattern is {DEFAULT_OPTIM_NAME_PATTERN}"), ) parser.add_argument( "--ckpt_name", type=str, default="", help=( - f"The name pattern of the XLA FSDP checkpoint files to be consolidated. " + f"The name pattern of the XLA FSDP checkpoint files to be consolidated and reshard. " f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be loaded." - f"For model, the default pattern is {MODEL_NAME_PATTERN}. For optimizer," - f"the default pattern is {OPTIM_NAME_PATTERN}"), + f"For model, the default pattern is {DEFAULT_MODEL_NAME_PATTERN}. For optimizer," + f"the default pattern is {DEFAULT_OPTIM_NAME_PATTERN}"), ) parser.add_argument( "--ckpt_type", type=str, + choices=["model", "optimizer"], default="model", help=( "The type of checkpoint to consolidate, you can choose model or optimizer. 
Please consolidate model fisrt, and then consolidate optimizer." @@ -43,32 +45,62 @@ def main(): default=1, help=( "We now support the reshard of XLA FSDP checkpoint according to the reshard_num." - )) + ), + ) parser.add_argument( - "--save_path", + "--save_dir", type=str, default="", help=( - f"The save path of the output state dict " - f"(default consolidate_path is ``ckpt_dir + model/optimizer_consolidated.pth``)" - f"If you need to reshard the checkpoint, please only pass the save_dir(default is ckpt_dir)," - f"we will save the file in path ``save_path + {MODEL_NAME_PATTERN}/{OPTIM_NAME_PATTERN}``" + f"The save dir of the output checkpoint files, the default value will be set to arg: ckpt_dir." + f"Files will be saved in path: ``save_dir + save_name``." + f"For consolidated checkpoint, the default path is: ``save_dir + model/optimizer_consolidated.pth``." + f"For reshard checkpoints, the default path is: ``save_dir + {DEFULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}``." ), ) + parser.add_argument( + "--save_name", + type=str, + default="", + help=( + f"The save name pattern of the output checkpoint files, the default value is {DEFULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}." + f"Files will be saved in path: ``save_dir + save_name`.`" + f"For consolidated checkpoint, the default path is: ``save_dir + model/optimizer_consolidated.pth``" + f"For reshard checkpoints, the default path is: ``save_dir + {DEFULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}``." + f"For reshard checkpoints, please use the same name patthern as {DEFULT_MODEL_NAME_PATTERN} and {DEFAULT_OPTIM_NAME_PATTERN}." + ), + ) + args = parser.parse_args() - assert args.ckpt_type in ['model', 'optimizer' - ], ('the ckpt_type should be model or optimizer') if args.ckpt_type == "model": if args.ckpt_name == "": - args.ckpt_name = MODEL_NAME_PATTERN - consolidate_and_reshard_model_dict(args.ckpt_dir, args.ckpt_name, - args.reshard_num, args.save_path) + args.ckpt_name = DEFULT_MODEL_NAME_PATTERN + if args.save_path == "": + args.save_path = args.ckpt_dir + if args.save_name == "": + if args.reshard_name == 1: + args.save_name = "model_consolidated.pth" + else: + args.save_name = DEFAULT_MODEL_NAME_PATTERN + + consolidate_and_reshard_fsdp_model_dict(args.ckpt_dir, args.ckpt_name, + args.save_dir, args.save_name, + args.reshard_num) else: if args.ckpt_name == "": - args.ckpt_name = OPTIM_NAME_PATTERN - consolidate_and_reshard_optim_dict(args.ckpt_dir, args.ckpt_name, - args.reshard_num, args.save_path) + args.ckpt_name = DEFULT_MODEL_NAME_PATTERN + if args.save_path == "": + args.save_path = args.ckpt_dir + if args.save_name == "": + if args.reshard_name == 1: + args.save_name = "optimizer_consolidated.pth" + else: + args.save_name = DEFAULT_OPTIM_NAME_PATTERN + + consolidate_and_reshard_fsdp_optim_dict(args.ckpt_dir, args.ckpt_name, + args.save_dir, args.save_name, + args.reshard_num) if __name__ == "__main__": From 175800853d143c765ffdf2bb161c268b8e6a4fdf Mon Sep 17 00:00:00 2001 From: shw Date: Wed, 16 Oct 2024 19:37:57 +0800 Subject: [PATCH 14/24] add name pattern matching --- .../consolidate_and_reshard_ckpts.py | 54 +++++++++++-------- torchacc/dist/state_dict_utils.py | 21 ++++---- .../utils/consolidate_and_reshard_ckpts.py | 34 ++++++------ 3 files changed, 60 insertions(+), 49 deletions(-) diff --git a/tests/standalone/consolidate_and_reshard_ckpts.py b/tests/standalone/consolidate_and_reshard_ckpts.py index 2e8a8d1..3d09360 100644 --- a/tests/standalone/consolidate_and_reshard_ckpts.py +++ 
b/tests/standalone/consolidate_and_reshard_ckpts.py @@ -4,9 +4,9 @@ import torch import torch_xla.core.xla_model as xm import torchacc as ta -from torchacc.dist.state_dict_utils import (consolidate_and_reshard_model_dict, - consolidate_and_reshard_optim_dict, - load_checkpoints) +from torchacc.dist.state_dict_utils import ( + consolidate_and_reshard_fsdp_model_dict, + consolidate_and_reshard_fsdp_optim_dict, load_checkpoints) from utils import EchoDataset, set_seed @@ -198,7 +198,7 @@ def main(args): # rank 0 do consolidate and reshard: if ta.dist.local_rank() == 0: # consolidate and reshard model and optimizer - model_reshard_dicts, _ = consolidate_and_reshard_model_dict( + model_reshard_dicts, _ = consolidate_and_reshard_fsdp_model_dict( ckpt_dir=ckpt_dir, ckpt_name=f"rank*-of-*-model.pth", reshard_num=reshard_num, @@ -206,7 +206,7 @@ def main(args): ) print(f"model consolidate and reshard done.") - optim_reshard_dicts, _ = consolidate_and_reshard_optim_dict( + optim_reshard_dicts, _ = consolidate_and_reshard_fsdp_optim_dict( ckpt_dir=ckpt_dir, ckpt_name=f"rank*-of-*-optim.pth", reshard_num=reshard_num, @@ -217,7 +217,7 @@ def main(args): # compare shard model and optimizer if reshard_num == fsdp_num: model_shard_dicts = load_checkpoints( - kpt_dir=ckpt_dir, ckpt_name=f"rank*-of-*-model.pth") + ckpt_dir=ckpt_dir, ckpt_name=f"rank*-of-*-model.pth") optim_shard_dicts = load_checkpoints( ckpt_dir=ckpt_dir, ckpt_name=f"rank*-of-*-optim.pth") @@ -242,28 +242,28 @@ def main(args): parser.add_argument("--bf16", action="store_true", default=False) parser.add_argument("--backend", type=str, default="lazy") - MODEL_NAME_PATTERN = "rank*-of-*-model.pth" - OPTIM_NAME_PATTERN = "rank*-of-*-optimizer.pth" + DEFAULT_MODEL_NAME_PATTERN = "rank*-of-*-model.pth" + DEFAULT_OPTIM_NAME_PATTERN = "rank*-of-*-optimizer.pth" # ckpt arguments parser.add_argument( "--ckpt_dir", type=str, required=True, help=( - f"The name pattern of the XLA FSDP checkpoint files to be consolidated. " - f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be loaded." - f"For model, the default pattern is {MODEL_NAME_PATTERN}. For optimizer," - f"the default pattern is {OPTIM_NAME_PATTERN}"), + f"The name dir of the XLA FSDP checkpoint files to be consolidated and reshard. " + f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be load." + f"For model, the default pattern is {DEFAULT_MODEL_NAME_PATTERN}. For optimizer," + f"the default pattern is {DEFAULT_OPTIM_NAME_PATTERN}"), ) parser.add_argument( "--ckpt_name", type=str, default="", help=( - f"The name pattern of the XLA FSDP checkpoint files to be consolidated. " - f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be loaded." - f"For model, the default pattern is {MODEL_NAME_PATTERN}. For optimizer," - f"the default pattern is {OPTIM_NAME_PATTERN}"), + f"The name pattern of the XLA FSDP checkpoint files to be consolidated and reshard. " + f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be load." + f"For model, the default pattern is {DEFAULT_MODEL_NAME_PATTERN}. For optimizer," + f"the default pattern is {DEFAULT_OPTIM_NAME_PATTERN}"), ) parser.add_argument( "--reshard_num", @@ -273,14 +273,26 @@ def main(args): "We now support the reshard of XLA FSDP checkpoint according to the reshard_num." 
)) parser.add_argument( - "--save_path", + "--save_dir", type=str, default="", help=( - f"The save path of the output state dict " - f"(default consolidate_path is ``ckpt_dir + model/optimizer_consolidated.pth``)" - f"If you need to reshard the checkpoint, please only pass the save_dir(default is ckpt_dir)," - f"we will save the file in path ``save_path + {MODEL_NAME_PATTERN}/{OPTIM_NAME_PATTERN}``" + f"The save dir of the output checkpoint files, the default value will be set to arg: ckpt_dir." + f"Files will be saved in path: ``save_dir + save_name``." + f"For consolidated checkpoint, the default path is: ``save_dir + model/optimizer_consolidated.pth``." + f"For reshard checkpoints, the default path is: ``save_dir + {DEFAULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}``." + ), + ) + parser.add_argument( + "--save_name", + type=str, + default="", + help=( + f"The save name pattern of the output checkpoint files, the default value is {DEFAULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}." + f"Files will be saved in path: ``save_dir + save_name`.`" + f"For consolidated checkpoint, the default path is: ``save_dir + model/optimizer_consolidated.pth``" + f"For reshard checkpoints, the default path is: ``save_dir + {DEFAULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}``." + f"For reshard checkpoints, please use the same name patthern as {DEFAULT_MODEL_NAME_PATTERN} and {DEFAULT_OPTIM_NAME_PATTERN}." ), ) diff --git a/torchacc/dist/state_dict_utils.py b/torchacc/dist/state_dict_utils.py index f120a36..7090f09 100644 --- a/torchacc/dist/state_dict_utils.py +++ b/torchacc/dist/state_dict_utils.py @@ -2,6 +2,7 @@ import os import pickle import threading +import re from collections import OrderedDict from glob import glob from typing import Dict @@ -505,12 +506,10 @@ def consolidate_and_reshard_fsdp_model_dict(ckpt_dir, actual_save_path = [] for idx in range(reshard_num): - save_name_ = re.sub( - r'\*', - lambda m: str(idx) if m.group(0) == '*' else str(reshard_num), - save_name, - count=2) - actual_save_path.append(os.path.join(save_dir, save_name_)) + # replace the two '*' + save_name_temp = save_name.replace('*', str(idx), 1) + save_name_temp = save_name_temp.replace('*', str(reshard_num), 1) + actual_save_path.append(os.path.join(save_dir, save_name_temp)) save_checkpoints(model_state_dict_list, shard_metadata_list, actual_save_path, 'model') @@ -588,12 +587,10 @@ def consolidate_and_reshard_fsdp_optim_dict(ckpt_dir, actual_save_path = [] for idx in range(reshard_num): - save_name_ = re.sub( - r'\*', - lambda m: str(idx) if m.group(0) == '*' else str(reshard_num), - save_name, - count=2) - actual_save_path.append(os.path.join(save_dir, save_name_)) + # replace the two '*' + save_name_temp = save_name.replace('*', str(idx), 1) + save_name_temp = save_name_temp.replace('*', str(reshard_num), 1) + actual_save_path.append(os.path.join(save_dir, save_name_temp)) save_checkpoints(optim_state_dict_list, shard_metadata_list, actual_save_path, 'optimizer') diff --git a/torchacc/utils/consolidate_and_reshard_ckpts.py b/torchacc/utils/consolidate_and_reshard_ckpts.py index 8fcf9ae..112a49f 100644 --- a/torchacc/utils/consolidate_and_reshard_ckpts.py +++ b/torchacc/utils/consolidate_and_reshard_ckpts.py @@ -4,8 +4,8 @@ consolidate_and_reshard_fsdp_model_dict, consolidate_and_reshard_fsdp_optim_dict) -DEFAULT_MODEL_NAME_PATTERN = "rank*-of-*-model.pth" -DEFAULT_OPTIM_NAME_PATTERN = "rank*-of-*-optimizer.pth" +DEFAULT_MODEL_NAME_PATTERN = "rank-*-of-*-model.pth" +DEFAULT_OPTIM_NAME_PATTERN = 
"rank-*-of-*-optimizer.pth" def main(): @@ -16,7 +16,7 @@ def main(): required=True, help=( f"The name dir of the XLA FSDP checkpoint files to be consolidated and reshard. " - f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be loaded." + f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be load." f"For model, the default pattern is {DEFAULT_MODEL_NAME_PATTERN}. For optimizer," f"the default pattern is {DEFAULT_OPTIM_NAME_PATTERN}"), ) @@ -26,7 +26,7 @@ def main(): default="", help=( f"The name pattern of the XLA FSDP checkpoint files to be consolidated and reshard. " - f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be loaded." + f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be load." f"For model, the default pattern is {DEFAULT_MODEL_NAME_PATTERN}. For optimizer," f"the default pattern is {DEFAULT_OPTIM_NAME_PATTERN}"), ) @@ -55,7 +55,7 @@ def main(): f"The save dir of the output checkpoint files, the default value will be set to arg: ckpt_dir." f"Files will be saved in path: ``save_dir + save_name``." f"For consolidated checkpoint, the default path is: ``save_dir + model/optimizer_consolidated.pth``." - f"For reshard checkpoints, the default path is: ``save_dir + {DEFULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}``." + f"For reshard checkpoints, the default path is: ``save_dir + {DEFAULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}``." ), ) parser.add_argument( @@ -63,11 +63,11 @@ def main(): type=str, default="", help=( - f"The save name pattern of the output checkpoint files, the default value is {DEFULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}." + f"The save name pattern of the output checkpoint files, the default value is {DEFAULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}." f"Files will be saved in path: ``save_dir + save_name`.`" f"For consolidated checkpoint, the default path is: ``save_dir + model/optimizer_consolidated.pth``" - f"For reshard checkpoints, the default path is: ``save_dir + {DEFULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}``." - f"For reshard checkpoints, please use the same name patthern as {DEFULT_MODEL_NAME_PATTERN} and {DEFAULT_OPTIM_NAME_PATTERN}." + f"For reshard checkpoints, the default path is: ``save_dir + {DEFAULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}``." + f"For reshard checkpoints, please use the same name patthern as {DEFAULT_MODEL_NAME_PATTERN} and {DEFAULT_OPTIM_NAME_PATTERN}." 
), ) @@ -76,10 +76,10 @@ def main(): if args.ckpt_type == "model": if args.ckpt_name == "": args.ckpt_name = DEFULT_MODEL_NAME_PATTERN - if args.save_path == "": - args.save_path = args.ckpt_dir + if args.save_dir == "": + args.save_dir = args.ckpt_dir if args.save_name == "": - if args.reshard_name == 1: + if args.reshard_num == 1: args.save_name = "model_consolidated.pth" else: args.save_name = DEFAULT_MODEL_NAME_PATTERN @@ -89,15 +89,17 @@ def main(): args.reshard_num) else: if args.ckpt_name == "": - args.ckpt_name = DEFULT_MODEL_NAME_PATTERN - if args.save_path == "": - args.save_path = args.ckpt_dir + args.ckpt_name = DEFULT_OPTIM_NAME_PATTERN + if args.save_dir == "": + args.save_dir = args.ckpt_dir if args.save_name == "": - if args.reshard_name == 1: + if args.reshard_num == 1: args.save_name = "optimizer_consolidated.pth" else: args.save_name = DEFAULT_OPTIM_NAME_PATTERN - + print(args.ckpt_dir) + print(args.save_dir) + print(args.save_name) consolidate_and_reshard_fsdp_optim_dict(args.ckpt_dir, args.ckpt_name, args.save_dir, args.save_name, args.reshard_num) From 2edd0466601e3ba8108a474cd47cb5f849a0f066 Mon Sep 17 00:00:00 2001 From: shw Date: Thu, 17 Oct 2024 11:38:49 +0800 Subject: [PATCH 15/24] add model and optimizer args --- .../consolidate_and_reshard_ckpts.py | 53 +++++++-- torchacc/dist/state_dict_utils.py | 58 ++++++--- .../utils/consolidate_and_reshard_ckpts.py | 110 +++++++++++++----- 3 files changed, 163 insertions(+), 58 deletions(-) diff --git a/tests/standalone/consolidate_and_reshard_ckpts.py b/tests/standalone/consolidate_and_reshard_ckpts.py index 3d09360..2ab84bb 100644 --- a/tests/standalone/consolidate_and_reshard_ckpts.py +++ b/tests/standalone/consolidate_and_reshard_ckpts.py @@ -200,7 +200,7 @@ def main(args): # consolidate and reshard model and optimizer model_reshard_dicts, _ = consolidate_and_reshard_fsdp_model_dict( ckpt_dir=ckpt_dir, - ckpt_name=f"rank*-of-*-model.pth", + model_ckpt_name_pattern=f"rank*-of-*-model.pth", reshard_num=reshard_num, save_model=False, ) @@ -208,7 +208,7 @@ def main(args): optim_reshard_dicts, _ = consolidate_and_reshard_fsdp_optim_dict( ckpt_dir=ckpt_dir, - ckpt_name=f"rank*-of-*-optim.pth", + optimizer_ckpt_name_pattern=f"rank*-of-*-optim.pth", reshard_num=reshard_num, save_optimizer=False, ) @@ -244,6 +244,7 @@ def main(args): DEFAULT_MODEL_NAME_PATTERN = "rank*-of-*-model.pth" DEFAULT_OPTIM_NAME_PATTERN = "rank*-of-*-optimizer.pth" + # ckpt arguments parser.add_argument( "--ckpt_dir", @@ -251,45 +252,77 @@ def main(args): required=True, help=( f"The name dir of the XLA FSDP checkpoint files to be consolidated and reshard. " - f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be load." + f"Files matching the pattern ``ckpt_dir + ckpt_name_pattern`` will be load." f"For model, the default pattern is {DEFAULT_MODEL_NAME_PATTERN}. For optimizer," f"the default pattern is {DEFAULT_OPTIM_NAME_PATTERN}"), ) parser.add_argument( - "--ckpt_name", + "--model_ckpt_name_pattern", type=str, - default="", + default=DEFAULT_MODEL_NAME_PATTERN, help=( f"The name pattern of the XLA FSDP checkpoint files to be consolidated and reshard. " - f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be load." + f"Files matching the pattern ``ckpt_dir + ckpt_name_pattern`` will be load." f"For model, the default pattern is {DEFAULT_MODEL_NAME_PATTERN}. 
For optimizer," f"the default pattern is {DEFAULT_OPTIM_NAME_PATTERN}"), ) + parser.add_argument( + "--optimizer_ckpt_name_pattern", + type=str, + default=DEFAULT_OPTIM_NAME_PATTERN, + help=( + f"The name pattern of the XLA FSDP checkpoint files to be consolidated and reshard. " + f"Files matching the pattern ``ckpt_dir + ckpt_name_pattern`` will be load." + f"For model, the default pattern is {DEFAULT_MODEL_NAME_PATTERN}. For optimizer," + f"the default pattern is {DEFAULT_OPTIM_NAME_PATTERN}"), + ) + parser.add_argument( + "--ckpt_type", + type=str, + choices=["all", "model", "optimizer"], + default="all", + help=( + f"The type of checkpoint to consolidate, you can choose to consolidate model and optimizer all or seperately." + f"Please consolidate model first and then optimizer."), + ) parser.add_argument( "--reshard_num", type=int, default=1, help=( "We now support the reshard of XLA FSDP checkpoint according to the reshard_num." - )) + ), + ) parser.add_argument( "--save_dir", type=str, default="", help=( f"The save dir of the output checkpoint files, the default value will be set to arg: ckpt_dir." - f"Files will be saved in path: ``save_dir + save_name``." + f"Files will be saved in path: ``save_dir + save_name_pattern``." f"For consolidated checkpoint, the default path is: ``save_dir + model/optimizer_consolidated.pth``." f"For reshard checkpoints, the default path is: ``save_dir + {DEFAULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}``." ), ) parser.add_argument( - "--save_name", + "--model_save_name_pattern", + type=str, + default="", + help=( + f"The save name pattern of the output checkpoint files, the default value is {DEFAULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}." + f"Files will be saved in path: ``save_dir + save_name_pattern`.`" + f"For consolidated checkpoint, the default path is: ``save_dir + model/optimizer_consolidated.pth``" + f"For reshard checkpoints, the default path is: ``save_dir + {DEFAULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}``." + f"For reshard checkpoints, please use the same name patthern as {DEFAULT_MODEL_NAME_PATTERN} and {DEFAULT_OPTIM_NAME_PATTERN}." + ), + ) + parser.add_argument( + "--optimizer_save_name_pattern", type=str, default="", help=( f"The save name pattern of the output checkpoint files, the default value is {DEFAULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}." - f"Files will be saved in path: ``save_dir + save_name`.`" + f"Files will be saved in path: ``save_dir + save_name_pattern`.`" f"For consolidated checkpoint, the default path is: ``save_dir + model/optimizer_consolidated.pth``" f"For reshard checkpoints, the default path is: ``save_dir + {DEFAULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}``." f"For reshard checkpoints, please use the same name patthern as {DEFAULT_MODEL_NAME_PATTERN} and {DEFAULT_OPTIM_NAME_PATTERN}." diff --git a/torchacc/dist/state_dict_utils.py b/torchacc/dist/state_dict_utils.py index 7090f09..3b3b821 100644 --- a/torchacc/dist/state_dict_utils.py +++ b/torchacc/dist/state_dict_utils.py @@ -438,9 +438,9 @@ def reshard_optim_dict(consolidate_optim_dict, shard_optim, layer_name_lists, def consolidate_and_reshard_fsdp_model_dict(ckpt_dir, - ckpt_name, + model_ckpt_name_pattern, save_dir="", - save_name="", + model_save_name_pattern="", reshard_num=1, save_model=True): """ @@ -450,7 +450,7 @@ def consolidate_and_reshard_fsdp_model_dict(ckpt_dir, Args: ckpt_dir (str): The dir to all FSDP shard model checkpoint files. 
- ckpt_name (str): + model_ckpt_name_pattern (str): The name_pattern to all FSDP shard model checkpoint files. Files matching the pattern ``ckpt_dir + ckpt_name`` will be loaded. Each checkpoint file is assumed to be a dict with a "model" key @@ -459,7 +459,7 @@ def consolidate_and_reshard_fsdp_model_dict(ckpt_dir, ``model.get_shard_metadata()``. save_dir (str): The save dir for consolidate or reshard model checkpoints. - save_name (str, Optional): + model_save_name_pattern (str, Optional): The name_pattern for consolidate or reshard model checkpoints. For reshard checkpoints name pattern: ``rank*-of-*-model.pth`` The final save_path is save_dir + save_name. @@ -473,13 +473,13 @@ def consolidate_and_reshard_fsdp_model_dict(ckpt_dir, model_state_dict: the consolidated model state dict or reshard model state dict list. shard_meta_list: the reshard metadatalist. For consolidated model, return None. """ - checkpoints = load_checkpoints(ckpt_dir, ckpt_name) + checkpoints = load_checkpoints(ckpt_dir, model_ckpt_name_pattern) full_state_dict = consolidate_sharded_model_checkpoints( ckpt_dir, checkpoints) if reshard_num == 1: if save_model: - if not save_dir or not save_name: + if not save_dir or not model_save_name_pattern: raise ValueError("save_dir and save_name should not be None!") actual_save_path = os.path.join(save_dir, save_name) @@ -501,13 +501,14 @@ def consolidate_and_reshard_fsdp_model_dict(ckpt_dir, full_state_dict, checkpoints[0], layer_info[0], reshard_num) if save_model: - if not save_dir or not save_name: - raise ValueError("save_dir and save_name should not be None!") + if not save_dir or not model_save_name_pattern: + raise ValueError( + "save_dir and save_name_pattern should not be None!") actual_save_path = [] for idx in range(reshard_num): # replace the two '*' - save_name_temp = save_name.replace('*', str(idx), 1) + save_name_temp = model_save_name_pattern.replace('*', str(idx), 1) save_name_temp = save_name_temp.replace('*', str(reshard_num), 1) actual_save_path.append(os.path.join(save_dir, save_name_temp)) @@ -518,9 +519,9 @@ def consolidate_and_reshard_fsdp_model_dict(ckpt_dir, def consolidate_and_reshard_fsdp_optim_dict(ckpt_dir, - ckpt_name, + optimizer_ckpt_name_pattern, save_dir="", - save_name="", + optimizer_save_name_pattern="", reshard_num=1, save_optimizer=True): """ @@ -529,7 +530,7 @@ def consolidate_and_reshard_fsdp_optim_dict(ckpt_dir, Args: ckpt_dir (str): The dir to all FSDP shard optimizer checkpoint files. - ckpt_name (str): + optimizer_ckpt_name_pattern (str): The name_pattern to all FSDP shard optimizer checkpoint files. Files matching the pattern ``ckpt_dir + ckpt_name`` will be loaded. Each checkpoint file is assumed to be a dict with a "optimizer" key @@ -538,7 +539,7 @@ def consolidate_and_reshard_fsdp_optim_dict(ckpt_dir, ``model.get_shard_metadata()``. save_dir (str, Optional): The save dir for consolidate or reshard optimizer checkpoints. - save_name (str, Optional): + optimizer_save_name_pattern (str, Optional): The name_pattern for consolidate or reshard optimizer checkpoints. For reshard checkpoints name pattern:: `rank*-of-*-optimizer.pth` The final save_path is save_dir + save_name. @@ -552,7 +553,7 @@ def consolidate_and_reshard_fsdp_optim_dict(ckpt_dir, shard_meta_list: the reshard metadatalist. For consolidated optim, return None. 
""" # load checkpoints - checkpoints = load_checkpoints(ckpt_dir, ckpt_name) + checkpoints = load_checkpoints(ckpt_dir, optimizer_ckpt_name_pattern) # load layer_info file_path = os.path.join(ckpt_dir, "layer_info.pickle") @@ -568,7 +569,7 @@ def consolidate_and_reshard_fsdp_optim_dict(ckpt_dir, if reshard_num == 1: if save_optimizer: - if not save_dir or not save_name: + if not save_dir or not optimizer_save_name_pattern: raise ValueError("save_dir and save_name should not be None!") actual_save_path = os.path.join(save_dir, save_name) @@ -582,13 +583,14 @@ def consolidate_and_reshard_fsdp_optim_dict(ckpt_dir, full_optim_state_dict, checkpoints[0], layer_info[0], reshard_num) if save_optimizer: - if not save_dir or not save_name: + if not save_dir or not optimizer_save_name_pattern: raise ValueError("save_dir and save_name should not be None!") actual_save_path = [] for idx in range(reshard_num): # replace the two '*' - save_name_temp = save_name.replace('*', str(idx), 1) + save_name_temp = optimizer_save_name_pattern.replace( + '*', str(idx), 1) save_name_temp = save_name_temp.replace('*', str(reshard_num), 1) actual_save_path.append(os.path.join(save_dir, save_name_temp)) @@ -596,3 +598,25 @@ def consolidate_and_reshard_fsdp_optim_dict(ckpt_dir, actual_save_path, 'optimizer') return optim_state_dict_list, shard_metadata_list + + +def consolidate_and_reshard_fsdp_checkpoint(ckpt_dir, + model_ckpt_name_pattern, + optimizer_ckpt_name_pattern, + save_dir="", + model_save_name_pattern="", + optimizer_save_name_pattern="", + reshard_num=1, + save_checkpoint=True): + """ + Consolidate the sharded FSDP model and optimizer checkpoints into a single checkpoint. Then + reshard the FSDP checkpoint according to the reshard_num. + """ + consolidate_and_reshard_fsdp_model_dict(ckpt_dir, model_ckpt_name_pattern, + save_dir, model_save_name_pattern, + reshard_num, save_checkpoint) + consolidate_and_reshard_fsdp_optim_dict(ckpt_dir, + optimizer_ckpt_name_pattern, + save_dir, + optimizer_save_name_pattern, + reshard_num, save_checkpoint) diff --git a/torchacc/utils/consolidate_and_reshard_ckpts.py b/torchacc/utils/consolidate_and_reshard_ckpts.py index 112a49f..6b11f21 100644 --- a/torchacc/utils/consolidate_and_reshard_ckpts.py +++ b/torchacc/utils/consolidate_and_reshard_ckpts.py @@ -1,6 +1,7 @@ from argparse import ArgumentParser from torchacc.dist.state_dict_utils import ( + consolidate_and_reshard_fsdp_checkpoint, consolidate_and_reshard_fsdp_model_dict, consolidate_and_reshard_fsdp_optim_dict) @@ -16,28 +17,38 @@ def main(): required=True, help=( f"The name dir of the XLA FSDP checkpoint files to be consolidated and reshard. " - f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be load." + f"Files matching the pattern ``ckpt_dir + ckpt_name_pattern`` will be load." f"For model, the default pattern is {DEFAULT_MODEL_NAME_PATTERN}. For optimizer," f"the default pattern is {DEFAULT_OPTIM_NAME_PATTERN}"), ) parser.add_argument( - "--ckpt_name", + "--model_ckpt_name_pattern", type=str, - default="", + default=DEFAULT_MODEL_NAME_PATTERN, help=( f"The name pattern of the XLA FSDP checkpoint files to be consolidated and reshard. " - f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be load." + f"Files matching the pattern ``ckpt_dir + ckpt_name_pattern`` will be load." + f"For model, the default pattern is {DEFAULT_MODEL_NAME_PATTERN}. 
For optimizer,"
f"the default pattern is {DEFAULT_OPTIM_NAME_PATTERN}"),
)
+ parser.add_argument(
+ "--optimizer_ckpt_name_pattern",
+ type=str,
+ default=DEFAULT_OPTIM_NAME_PATTERN,
+ help=(
+ f"The name pattern of the XLA FSDP checkpoint files to be consolidated and reshard. "
+ f"Files matching the pattern ``ckpt_dir + ckpt_name_pattern`` will be loaded. "
f"For model, the default pattern is {DEFAULT_MODEL_NAME_PATTERN}. For optimizer,"
f"the default pattern is {DEFAULT_OPTIM_NAME_PATTERN}"),
)
parser.add_argument(
"--ckpt_type",
type=str,
- choices=["model", "optimizer"],
- default="model",
+ choices=["all", "model", "optimizer"],
+ default="all",
help=(
- "The type of checkpoint to consolidate, you can choose model or optimizer. Please consolidate model fisrt, and then consolidate optimizer."
- ),
+ f"The type of checkpoint to consolidate, you can choose to consolidate the model and optimizer together or separately. "
+ f"Please consolidate model first and then optimizer."),
)
parser.add_argument(
"--reshard_num",
@@ -53,18 +64,30 @@ def main():
default="",
help=(
f"The save dir of the output checkpoint files, the default value will be set to arg: ckpt_dir."
- f"Files will be saved in path: ``save_dir + save_name``."
+ f"Files will be saved in path: ``save_dir + save_name_pattern``."
f"For consolidated checkpoint, the default path is: ``save_dir + model/optimizer_consolidated.pth``."
f"For reshard checkpoints, the default path is: ``save_dir + {DEFAULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}``."
),
)
parser.add_argument(
- "--save_name",
+ "--model_save_name_pattern",
+ type=str,
+ default="",
+ help=(
+ f"The save name pattern of the output checkpoint files, the default value is {DEFAULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}."
+ f"Files will be saved in path: ``save_dir + save_name_pattern``. "
+ f"For consolidated checkpoint, the default path is: ``save_dir + model/optimizer_consolidated.pth``"
+ f"For reshard checkpoints, the default path is: ``save_dir + {DEFAULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}``."
+ f"For reshard checkpoints, please use the same name pattern as {DEFAULT_MODEL_NAME_PATTERN} and {DEFAULT_OPTIM_NAME_PATTERN}."
+ ),
+ )
+ parser.add_argument(
+ "--optimizer_save_name_pattern",
type=str,
default="",
help=(
f"The save name pattern of the output checkpoint files, the default value is {DEFAULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}."
- f"Files will be saved in path: ``save_dir + save_name`.`"
+ f"Files will be saved in path: ``save_dir + save_name_pattern``. "
f"For consolidated checkpoint, the default path is: ``save_dir + model/optimizer_consolidated.pth``"
f"For reshard checkpoints, the default path is: ``save_dir + {DEFAULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}``."
f"For reshard checkpoints, please use the same name patthern as {DEFAULT_MODEL_NAME_PATTERN} and {DEFAULT_OPTIM_NAME_PATTERN}."
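For reference, the two ``*`` placeholders in these save-name patterns are filled in with the shard index and the total shard count, mirroring the ``str.replace('*', ..., 1)`` substitution added in ``state_dict_utils.py``. A minimal sketch with hypothetical values (not part of the patch itself):

```python
# Sketch: expand "rank*-of-*-model.pth" into per-shard file names.
# Assumes reshard_num = 4 and save_dir = "./ckpt_dir" purely for illustration.
import os

save_dir = "./ckpt_dir"
save_name_pattern = "rank*-of-*-model.pth"
reshard_num = 4

actual_save_paths = []
for idx in range(reshard_num):
    # first '*' -> shard index, second '*' -> total shard count
    name = save_name_pattern.replace('*', str(idx), 1)
    name = name.replace('*', str(reshard_num), 1)
    actual_save_paths.append(os.path.join(save_dir, name))

print(actual_save_paths)
# e.g. ['./ckpt_dir/rank0-of-4-model.pth', ..., './ckpt_dir/rank3-of-4-model.pth']
```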
@@ -72,37 +95,62 @@
)
args = parser.parse_args()
+ if args.ckpt_type == "all":
+ if args.save_dir == "":
+ args.save_dir = args.ckpt_dir
+ if args.model_save_name_pattern == "":
+ if args.reshard_num == 1:
+ args.model_save_name_pattern = "model_consolidated.pth"
+ else:
+ args.model_save_name_pattern = DEFAULT_MODEL_NAME_PATTERN
+ if args.optimizer_save_name_pattern == "":
+ if args.reshard_num == 1:
+ args.optimizer_save_name_pattern = "optimizer_consolidated.pth"
+ else:
+ args.optimizer_save_name_pattern = DEFAULT_OPTIM_NAME_PATTERN
- if args.ckpt_type == "model":
- if args.ckpt_name == "":
- args.ckpt_name = DEFULT_MODEL_NAME_PATTERN
+ consolidate_and_reshard_fsdp_checkpoint(
+ ckpt_dir=args.ckpt_dir,
+ model_ckpt_name_pattern=args.model_ckpt_name_pattern,
+ optimizer_ckpt_name_pattern=args.optimizer_ckpt_name_pattern,
+ save_dir=args.save_dir,
+ model_save_name_pattern=args.model_save_name_pattern,
+ optimizer_save_name_pattern=args.optimizer_save_name_pattern,
+ reshard_num=args.reshard_num)
+ elif args.ckpt_type == "model":
if args.save_dir == "":
args.save_dir = args.ckpt_dir
- if args.save_name == "":
+ if args.model_save_name_pattern == "":
if args.reshard_num == 1:
- args.save_name = "model_consolidated.pth"
+ args.model_save_name_pattern = "model_consolidated.pth"
else:
- args.save_name = DEFAULT_MODEL_NAME_PATTERN
+ args.model_save_name_pattern = DEFAULT_MODEL_NAME_PATTERN
- consolidate_and_reshard_fsdp_model_dict(args.ckpt_dir, args.ckpt_name,
- args.save_dir, args.save_name,
- args.reshard_num)
+ consolidate_and_reshard_fsdp_model_dict(
+ ckpt_dir=args.ckpt_dir,
+ model_ckpt_name_pattern=args.model_ckpt_name_pattern,
+ save_dir=args.save_dir,
+ model_save_name_pattern=args.model_save_name_pattern,
+ reshard_num=args.reshard_num)
else:
- if args.ckpt_name == "":
- args.ckpt_name = DEFULT_OPTIM_NAME_PATTERN
if args.save_dir == "":
args.save_dir = args.ckpt_dir
- if args.save_name == "":
+ if args.optimizer_save_name_pattern == "":
if args.reshard_num == 1:
- args.save_name = "optimizer_consolidated.pth"
+ args.optimizer_save_name_pattern = "optimizer_consolidated.pth"
else:
- args.save_name = DEFAULT_OPTIM_NAME_PATTERN
- print(args.ckpt_dir)
- print(args.save_dir)
- print(args.save_name)
- consolidate_and_reshard_fsdp_optim_dict(args.ckpt_dir, args.ckpt_name,
- args.save_dir, args.save_name,
- args.reshard_num)
+ args.optimizer_save_name_pattern = DEFAULT_OPTIM_NAME_PATTERN
+
+ consolidate_and_reshard_fsdp_optim_dict(
+ ckpt_dir=args.ckpt_dir,
+ optimizer_ckpt_name_pattern=args.optimizer_ckpt_name_pattern,
+ save_dir=args.save_dir,
+ optimizer_save_name_pattern=args.optimizer_save_name_pattern,
+ reshard_num=args.reshard_num)

if __name__ == "__main__":

From 4f7922eece15fc1c92e2c22739f89be1a74f6602 Mon Sep 17 00:00:00 2001
From: shw
Date: Fri, 18 Oct 2024 13:05:07 +0800
Subject: [PATCH 16/24] add unpad

---
 docs/source/dist/fsdp.md | 60 +++++++++++--------
 setup.py | 2 +-
 .../consolidate_and_reshard_ckpts.py | 4 +-
 torchacc/dist/state_dict_utils.py | 36 ++++++++---
 4 files changed, 68 insertions(+), 34 deletions(-)

diff --git a/docs/source/dist/fsdp.md b/docs/source/dist/fsdp.md
index c7826b4..8334835 100644
--- a/docs/source/dist/fsdp.md
+++ b/docs/source/dist/fsdp.md
@@ -117,7 +117,7 @@ The main changes:
The shell command for
running FSDP tasks is the same as data parallelism: ```bash -$ torchrun --nproc_per_node=4 resnet_acc.py +$ torchrun --nproc_per_node=4 gpt2_acc.py ``` @@ -127,40 +127,52 @@ $ torchrun --nproc_per_node=4 resnet_acc.py ### Save Checkpoint -Save the model parameters for each FSDP shard, optimizer states, and LR scheduler. Note that you need to save `shard_metadata` to restore the correct shard information. +Save the model parameters and optimizer states for each FSDP shard and LR scheduler. Note that you need to save ``shard_metadata`` to restore the correct shard information. ```python +import torch_xla.core.xla_model as xm +shard_meta_data = model.model.model.get_shard_metadata() + # 1) Save model shards -torchacc.dist.rendezvous("saving_model") -torchacc.dist.mark_step() -ckpt = { +xm.rendezvous("saving_model") +model_ckpt = { 'model': model.state_dict(), - 'shard_metadata': model.get_shard_metadata(), + 'shard_metadata': shard_meta_data, } -torchacc.save(ckpt, CKPT_DIR, master_only=False) -# 2) Save optimizer states and LR scheduler -torchacc.dist.rendezvous("saving_optimizer_states") -torchacc.save(optimizer.state_dict(), OPTIMIZER_DIR) +torchacc.save(model_ckpt, CKPT_DIR + MODEL_NAME_PATTERN, master_only=False) + +# 2) Save optimizer shards +xm.rendezvous("saving_optimizer_states") +optim_ckpt = { + 'optimizer': optimizer.state_dict(), + 'shard_metadata': shard_meta_data, +} +torchacc.save(optim_ckpt, CKPT_DIR + OPTIM_NAME_PATTERN, master_only=False) + +# 3) Save lr_scheduler torchacc.save(lr_scheduler.state_dict(), LR_SCHEDULER_DIR) ``` -### Load from Checkpoint +### offline consolidation +We now support offline consolidate and reshard fsdp model and optimizer ckpts. You can run ``consolidate_and_reshard_fsdp_ckpts --help`` to refer to the instruction. 
+```shell +# consolidate model and optimizer +consolidate_and_reshard_fsdp_ckpts --ckpt_dir CKPT_DIR --model_ckpt_name_pattern MODEL_NAME_PATTERN --optimizer_ckpt_name_pattern OPTIM_NAME_PATTERN +# you can use --reshard_num to reshard the fsdp checkpoints +``` +### Load from Checkpoint ```python -# 1) Reorganize shards -if torchacc.dist.is_master_ordinal(local=False): - torchacc.dist.fsdp.consolidate_sharded_model_checkpoints( - CKPT_DIR, ckpt_suffix) -torchacc.dist.rendezvous("ckpt_consolidation") - -# 2) Load model -ckpt_consolidated = torch.load("consolidated.pth") -model.load_state_dict(ckpt_consolidated['model']) - -# 3) Load optimizer states and LR scheduler -optimizer_state = torch.load(OPTIMIZER_DIR) +# 1) Load model +model_consolidated = torch.load("model_consolidated.pth") +model.load_state_dict(model_consolidated) + +# 2) Load optimizer +optimizer_consolidated = torch.load("optimizer_consolidated.pth") +optimizer.load_state_dict(optimizer_consolidated) + +# 3) Load LR scheduler lr_scheduler_state = torch.load(LR_SCHEDULER_DIR) -optimizer.load_state_dict(optimizer_state) lr_scheduler.load_state_dict(lr_scheduler_state) ``` diff --git a/setup.py b/setup.py index c480f32..19c25fa 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,7 @@ def get_and_set_version(): # add console_scripts entry_points={ 'console_scripts': [ - 'consolidate_and_reshape_ckpts = torchacc.utils.consolidate_and_reshard_ckpts:main', + 'consolidate_and_reshard_fsdp_ckpts = torchacc.utils.consolidate_and_reshard_ckpts:main', ], }, diff --git a/tests/standalone/consolidate_and_reshard_ckpts.py b/tests/standalone/consolidate_and_reshard_ckpts.py index 2ab84bb..58ae3a2 100644 --- a/tests/standalone/consolidate_and_reshard_ckpts.py +++ b/tests/standalone/consolidate_and_reshard_ckpts.py @@ -174,6 +174,7 @@ def main(args): if not os.path.exists(ckpt_dir): os.makedirs(ckpt_dir) + xm.rendezvous("saving_model") model_ckpt = { "model": model.state_dict(), "shard_metadata": model.model.model.get_shard_metadata( @@ -183,8 +184,8 @@ def main(args): ckpt_dir, f"rank{ta.dist.local_rank()}-of-{ta.dist.world_size()}-model.pth") ta.save(model_ckpt, model_ckpt_path, master_only=False) - xm.rendezvous("saving_model") + xm.rendezvous("saving_optim") optim_ckpt = { "optimizer": optimizer.state_dict(), "shard_metadata": model.model.model.get_shard_metadata(), @@ -193,7 +194,6 @@ def main(args): ckpt_dir, f"rank{ta.dist.local_rank()}-of-{ta.dist.world_size()}-optim.pth") ta.save(optim_ckpt, optim_ckpt_path, master_only=False) - xm.rendezvous("saving_optim") # rank 0 do consolidate and reshard: if ta.dist.local_rank() == 0: diff --git a/torchacc/dist/state_dict_utils.py b/torchacc/dist/state_dict_utils.py index 3b3b821..ed4b2e7 100644 --- a/torchacc/dist/state_dict_utils.py +++ b/torchacc/dist/state_dict_utils.py @@ -30,6 +30,18 @@ def unflatten_params(params, param_names, param_shapes, param_numels): return full_params +def unpad(params, layer_numel, world_size): + if params.dim() == 0: + return params + numel = 0 + for layer_numel in layer_numel: + numel += layer_numel + if numel % world_size != 0: + pad_size = world_size - numel % world_size + params = params[:-pad_size] + return params + + def get_layer_full_info(shard_metadata, model_state_dict): """ Get full name, shape and numel info of unflatten and unshard model's state_dict according @@ -246,6 +258,9 @@ def consolidate_sharded_model_checkpoints(ckpt_dir, checkpoints): state_dict[state_name] = None consolidate_params = torch.cat(p_shard_list, dim=0) + consolidate_params = 
unpad(consolidate_params, layer_numel, + shard_metadata['world_size'] * + 128) # world_size * _shard_size_multiple orig_params = unflatten_params(consolidate_params, layer_name, layer_size, layer_numel) @@ -297,7 +312,9 @@ def consolidate_sharded_optimizer_checkpoints(ckpt_dir, checkpoints, optim_state_dict['state'][idx][state_name] = None consolidate_params = torch.cat(p_shard_list, dim=0) - + consolidate_params = unpad(consolidate_params, layer_numel, + shard_metadata['world_size'] * + 128) # world_size * _shard_size_multiple orig_params = unflatten_params(consolidate_params, layer_name, layer_size, layer_numel) @@ -312,8 +329,9 @@ def _get_shard(tensor, shard_num): """ Return the shard tensor list of a full flatten tensor. """ - if tensor.numel() % shard_num != 0: - pad_size = shard_num - tensor.numel() % shard_num + if tensor.numel() % (shard_num * + 128) != 0: # world_size * _shard_size_multiple + pad_size = (shard_num * 128) - tensor.numel() % (shard_num * 128) tensor = F.pad(tensor, [0, pad_size]) local_size = tensor.size(0) // shard_num @@ -480,8 +498,9 @@ def consolidate_and_reshard_fsdp_model_dict(ckpt_dir, if reshard_num == 1: if save_model: if not save_dir or not model_save_name_pattern: - raise ValueError("save_dir and save_name should not be None!") - actual_save_path = os.path.join(save_dir, save_name) + raise ValueError( + "save_dir and model_save_name_pattern should not be None!") + actual_save_path = os.path.join(save_dir, model_save_name_pattern) save_checkpoints(full_state_dict, checkpoints[0]['shard_metadata'], actual_save_path, 'model') @@ -570,8 +589,11 @@ def consolidate_and_reshard_fsdp_optim_dict(ckpt_dir, if reshard_num == 1: if save_optimizer: if not save_dir or not optimizer_save_name_pattern: - raise ValueError("save_dir and save_name should not be None!") - actual_save_path = os.path.join(save_dir, save_name) + raise ValueError( + "save_dir and optimizer_save_name_pattern should not be None!" + ) + actual_save_path = os.path.join(save_dir, + optimizer_save_name_pattern) save_checkpoints(full_optim_state_dict, checkpoints[0]['shard_metadata'], actual_save_path, From 9a7d16edfe68ed37b659299b59a9c554c49f5c47 Mon Sep 17 00:00:00 2001 From: shw Date: Fri, 18 Oct 2024 13:07:56 +0800 Subject: [PATCH 17/24] format --- tests/standalone/offload.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/standalone/offload.py b/tests/standalone/offload.py index f789902..b0a50d1 100644 --- a/tests/standalone/offload.py +++ b/tests/standalone/offload.py @@ -115,7 +115,6 @@ def main(): device = dist.get_rank() model = Net() model.to(device) - optimizer = torch.optim.Adam(model.parameters(), lr=0.001) scaler = torch.cuda.amp.GradScaler() From 36d14d5d0e2404a19e8438fd9375aa5cade312bb Mon Sep 17 00:00:00 2001 From: shw Date: Fri, 18 Oct 2024 15:47:57 +0800 Subject: [PATCH 18/24] refine fsdp doc --- docs/source/dist/fsdp.md | 43 +++++++++++++++++++++++++++++---------- torchacc/dist/__init__.py | 18 ++++++++++++++++ 2 files changed, 50 insertions(+), 11 deletions(-) diff --git a/docs/source/dist/fsdp.md b/docs/source/dist/fsdp.md index 8334835..1cfa8d3 100644 --- a/docs/source/dist/fsdp.md +++ b/docs/source/dist/fsdp.md @@ -129,46 +129,67 @@ $ torchrun --nproc_per_node=4 gpt2_acc.py Save the model parameters and optimizer states for each FSDP shard and LR scheduler. Note that you need to save ``shard_metadata`` to restore the correct shard information. 
```python
-import torch_xla.core.xla_model as xm
shard_meta_data = model.model.model.get_shard_metadata()
+CKPT_DIR="./ckpt_dir"
+MODEL_NAME=f"rank{ta.dist.local_rank()}-of-{ta.dist.world_size()}-model.pth"
+OPTIM_NAME=f"rank{ta.dist.local_rank()}-of-{ta.dist.world_size()}-optim.pth"
# 1) Save model shards
-xm.rendezvous("saving_model")
+torchacc.dist.rendezvous("saving_model")
model_ckpt = {
'model': model.state_dict(),
'shard_metadata': shard_meta_data,
}
-torchacc.save(model_ckpt, CKPT_DIR + MODEL_NAME_PATTERN, master_only=False)
+torchacc.save(model_ckpt, os.path.join(CKPT_DIR, MODEL_NAME), master_only=False)
# 2) Save optimizer shards
-xm.rendezvous("saving_optimizer_states")
+torchacc.dist.rendezvous("saving_optimizer_states")
optim_ckpt = {
'optimizer': optimizer.state_dict(),
'shard_metadata': shard_meta_data,
}
-torchacc.save(optim_ckpt, CKPT_DIR + OPTIM_NAME_PATTERN, master_only=False)
+torchacc.save(optim_ckpt, os.path.join(CKPT_DIR, OPTIM_NAME), master_only=False)
# 3) Save lr_scheduler
torchacc.save(lr_scheduler.state_dict(), LR_SCHEDULER_DIR)
```
-### offline consolidation
-We now support offline consolidate and reshard fsdp model and optimizer ckpts. You can run ``consolidate_and_reshard_fsdp_ckpts --help`` to refer to the instruction.
+### Load Checkpoint
+We can load from the sharded ckpts and continue training if the fsdp config does not change.
+For example, we can save with fsdp_size = 4 and load with fsdp_size = 4.
+
+```python
+CKPT_DIR="./ckpt_dir"
+MODEL_NAME=f"rank{ta.dist.local_rank()}-of-{ta.dist.world_size()}-model.pth"
+OPTIM_NAME=f"rank{ta.dist.local_rank()}-of-{ta.dist.world_size()}-optim.pth"
+
+model_ckpt = torch.load(os.path.join(CKPT_DIR, MODEL_NAME))
+model.load_state_dict(model_ckpt['model'])
+
+optim_ckpt = torch.load(os.path.join(CKPT_DIR, OPTIM_NAME))
+optimizer.load_state_dict(optim_ckpt['optimizer'])
+```
+
+### Offline Consolidation and Reshard
+We now support offline consolidation and resharding of FSDP checkpoints. For example, you can save sharded ckpts with fsdp_size = 4, consolidate them offline into a full checkpoint, and then load that full checkpoint. What's more, you can reshard the ckpts to 8 shards and then load the resharded ckpts with the new fsdp config: fsdp_size = 8.
+
+You can run ``consolidate_and_reshard_fsdp_ckpts --help`` for more instructions.
```shell
# consolidate model and optimizer
-consolidate_and_reshard_fsdp_ckpts --ckpt_dir CKPT_DIR --model_ckpt_name_pattern MODEL_NAME_PATTERN --optimizer_ckpt_name_pattern OPTIM_NAME_PATTERN
+consolidate_and_reshard_fsdp_ckpts --ckpt_dir CKPT_DIR --model_ckpt_name_pattern "rank*-of-*-model.pth" --optimizer_ckpt_name_pattern "rank*-of-*-optim.pth"
# you can use --reshard_num to reshard the fsdp checkpoints
+consolidate_and_reshard_fsdp_ckpts --ckpt_dir CKPT_DIR --model_ckpt_name_pattern "rank*-of-*-model.pth" --optimizer_ckpt_name_pattern "rank*-of-*-optim.pth" --reshard_num 8
```
-### Load from Checkpoint
+### Load from Full Checkpoint
```python
# 1) Load model
-model_consolidated = torch.load("model_consolidated.pth")
+model_consolidated = torch.load("model_consolidated.pth") # the default consolidated model name
model.load_state_dict(model_consolidated)
# 2) Load optimizer
-optimizer_consolidated = torch.load("optimizer_consolidated.pth")
+optimizer_consolidated = torch.load("optimizer_consolidated.pth") # the default consolidated optimizer name
optimizer.load_state_dict(optimizer_consolidated)
# 3) Load LR scheduler
lr_scheduler_state = torch.load(LR_SCHEDULER_DIR)
lr_scheduler.load_state_dict(lr_scheduler_state)
```
diff --git a/torchacc/dist/__init__.py b/torchacc/dist/__init__.py
index 8cbbac3..7f9880f 100644
--- a/torchacc/dist/__init__.py
+++ b/torchacc/dist/__init__.py
@@ -4,6 +4,7 @@
import torch
import torch.distributed as dist
import torch_xla
+import torch_xla.core.xla_model as xm
import torchacc as ta
@@ -92,3 +93,20 @@ def init_nccl_context(config) -> None:
ta.sync()
_NCCL_CONTEXT_INITED = True
+
+def rendezvous(tag, payload=b'', replicas=[]):
+ """Waits for all the mesh clients to reach the named rendezvous.
+ We use the rendezvous api of xla directly.
+
+ Args:
+ tag (string): The name of the rendezvous to join.
+ payload (bytes, optional): The payload to be sent to the rendezvous.
+ replicas (list, int): The replica ordinals taking part of the rendezvous.
+ Empty means all replicas in the mesh.
+ Default: []
+
+ Returns:
+ The payloads exchanged by all the other cores, with the payload of core
+ ordinal `i` at position `i` in the returned tuple.
+ """ + return xm.rendezvous(payload, replicas or None, tag=tag) From 8ad04f7fd101c613db0640e54895d9d83142ed29 Mon Sep 17 00:00:00 2001 From: shw Date: Fri, 18 Oct 2024 15:52:08 +0800 Subject: [PATCH 19/24] add api --- docs/source/dist/fsdp.md | 4 ++-- tests/standalone/consolidate_and_reshard_ckpts.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/dist/fsdp.md b/docs/source/dist/fsdp.md index 1cfa8d3..fa3f0a3 100644 --- a/docs/source/dist/fsdp.md +++ b/docs/source/dist/fsdp.md @@ -134,7 +134,7 @@ CKP_DIR="./ckpt_dir" MODEL_NAME=f"rank{ta.dist.local_rank()}-of-{ta.dist.world_size()}-model.pth" OPTIM_NAME=f"rank{ta.dist.local_rank()}-of-{ta.dist.world_size()}-optim.pth" -# 1) Save model shards +# 1) Each rank save model shard torchacc.dist.rendezvous("saving_model") model_ckpt = { 'model': model.state_dict(), @@ -143,7 +143,7 @@ model_ckpt = { torchacc.save(model_ckpt, os.path.join(CKPT_DIR, MODEL_NAME), master_only=False) -# 2) Save optimizer shards +# 2) Each rank save optimizer shard torchacc.dist.rendezvous("saving_optimizer_states") optim_ckpt = { 'optimizer': optimizer.state_dict(), diff --git a/tests/standalone/consolidate_and_reshard_ckpts.py b/tests/standalone/consolidate_and_reshard_ckpts.py index 58ae3a2..e8061bc 100644 --- a/tests/standalone/consolidate_and_reshard_ckpts.py +++ b/tests/standalone/consolidate_and_reshard_ckpts.py @@ -174,7 +174,7 @@ def main(args): if not os.path.exists(ckpt_dir): os.makedirs(ckpt_dir) - xm.rendezvous("saving_model") + ta.dist.rendezvous("saving_model") model_ckpt = { "model": model.state_dict(), "shard_metadata": model.model.model.get_shard_metadata( @@ -185,7 +185,7 @@ def main(args): f"rank{ta.dist.local_rank()}-of-{ta.dist.world_size()}-model.pth") ta.save(model_ckpt, model_ckpt_path, master_only=False) - xm.rendezvous("saving_optim") + ta.dist.rendezvous("saving_optim") optim_ckpt = { "optimizer": optimizer.state_dict(), "shard_metadata": model.model.model.get_shard_metadata(), From e4bec1bec9d58d67a3886586c1c82c2498c385fb Mon Sep 17 00:00:00 2001 From: shw Date: Fri, 18 Oct 2024 15:54:37 +0800 Subject: [PATCH 20/24] format --- torchacc/dist/__init__.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/torchacc/dist/__init__.py b/torchacc/dist/__init__.py index 7f9880f..a35de3a 100644 --- a/torchacc/dist/__init__.py +++ b/torchacc/dist/__init__.py @@ -5,22 +5,17 @@ import torch.distributed as dist import torch_xla import torch_xla.core.xla_model as xm - import torchacc as ta # register lazy backend -from . import backend - -from .mesh import Mesh - -from .parallel_module import ParallelModule +from . import backend, fsdp, pp, tp +from .distributed_parallel import DistributedParallel from .dp import DataParallel from .fsdp import FullyShardedDataParallel -from .spmd_fsdp import SpmdFullyShardedDataParallel +from .mesh import Mesh +from .parallel_module import ParallelModule from .pp import PipelineParallel -from .distributed_parallel import DistributedParallel - -from . 
import fsdp, pp, tp +from .spmd_fsdp import SpmdFullyShardedDataParallel BACKEND_NAME = backend._BACKEND_NAME EAGER_BACKEND_NAME = backend._EAGER_BACKEND_NAME From 33f000765fd15320244244ccf542730efe28b4dc Mon Sep 17 00:00:00 2001 From: shw Date: Fri, 18 Oct 2024 16:02:39 +0800 Subject: [PATCH 21/24] typo --- docs/source/dist/fsdp.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/dist/fsdp.md b/docs/source/dist/fsdp.md index fa3f0a3..e168a03 100644 --- a/docs/source/dist/fsdp.md +++ b/docs/source/dist/fsdp.md @@ -131,8 +131,8 @@ Save the model parameters and optimizer states for each FSDP shard and LR schedu ```python shard_meta_data = model.model.model.get_shard_metadata() CKP_DIR="./ckpt_dir" -MODEL_NAME=f"rank{ta.dist.local_rank()}-of-{ta.dist.world_size()}-model.pth" -OPTIM_NAME=f"rank{ta.dist.local_rank()}-of-{ta.dist.world_size()}-optim.pth" +MODEL_NAME=f"rank{torchacc.dist.local_rank()}-of-{torchacc.dist.world_size()}-model.pth" +OPTIM_NAME=f"rank{torchacc.dist.local_rank()}-of-{torchacc.dist.world_size()}-optim.pth" # 1) Each rank save model shard torchacc.dist.rendezvous("saving_model") @@ -161,8 +161,8 @@ For example, we can save with fsdp_size = 4 and load with fsdp_size = 4. ```python CKPT_DIR="./ckpt_dir" -MODEL_NAME=f"rank{ta.dist.local_rank()}-of-{ta.dist.world_size()}-model.pth" -OPTIM_NAME=f"rank{ta.dist.local_rank()}-of-{ta.dist.world_size()}-optim.pth" +MODEL_NAME=f"rank{torchacc.dist.local_rank()}-of-{torchacc.dist.world_size()}-model.pth" +OPTIM_NAME=f"rank{torchacc.dist.local_rank()}-of-{torchacc.dist.world_size()}-optim.pth" model_ckpt = torch.load(os.path.join(CKPT_DIR, MODEL_NAME)) model.load_state_dict(model_ckpt['model']) From 77fda58d8f6a165b79d90680f2eb61cdf1b31d77 Mon Sep 17 00:00:00 2001 From: shw Date: Fri, 18 Oct 2024 16:05:13 +0800 Subject: [PATCH 22/24] format --- torchacc/dist/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchacc/dist/__init__.py b/torchacc/dist/__init__.py index a35de3a..ccc6fac 100644 --- a/torchacc/dist/__init__.py +++ b/torchacc/dist/__init__.py @@ -89,8 +89,9 @@ def init_nccl_context(config) -> None: _NCCL_CONTEXT_INITED = True + def rendezvous(tag, payload=b'', replicas=[]): - """Waits for all the mesh clients to reach the named rendezvous. + """Waits for all the mesh clients to reach the named rendezvous. We use the rendezvous api of xla directly. Args: @@ -104,4 +105,4 @@ def rendezvous(tag, payload=b'', replicas=[]): The payloads exchanged by all the other cores, with the payload of core ordinal `i` at position `i` in the returned tuple. """ - return xm.rendezvous(payload, replicas or None, tag=tag) + return xm.rendezvous(payload, replicas or None, tag=tag) From ab3248028731aa7b5e2b04c0ffab5a3ca44f83b9 Mon Sep 17 00:00:00 2001 From: shw Date: Fri, 18 Oct 2024 16:16:32 +0800 Subject: [PATCH 23/24] revert dist init --- torchacc/dist/__init__.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/torchacc/dist/__init__.py b/torchacc/dist/__init__.py index ccc6fac..c158d61 100644 --- a/torchacc/dist/__init__.py +++ b/torchacc/dist/__init__.py @@ -5,17 +5,22 @@ import torch.distributed as dist import torch_xla import torch_xla.core.xla_model as xm + import torchacc as ta # register lazy backend -from . import backend, fsdp, pp, tp -from .distributed_parallel import DistributedParallel -from .dp import DataParallel -from .fsdp import FullyShardedDataParallel +from . 
import backend + from .mesh import Mesh + from .parallel_module import ParallelModule -from .pp import PipelineParallel +from .dp import DataParallel +from .fsdp import FullyShardedDataParallel from .spmd_fsdp import SpmdFullyShardedDataParallel +from .pp import PipelineParallel +from .distributed_parallel import DistributedParallel + +from . import fsdp, pp, tp BACKEND_NAME = backend._BACKEND_NAME EAGER_BACKEND_NAME = backend._EAGER_BACKEND_NAME From ed42f2ba3aed0ad352bf07e56cdd72111e13b354 Mon Sep 17 00:00:00 2001 From: shw Date: Fri, 18 Oct 2024 16:23:46 +0800 Subject: [PATCH 24/24] dist init isort skip --- torchacc/dist/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchacc/dist/__init__.py b/torchacc/dist/__init__.py index c158d61..9c29e1d 100644 --- a/torchacc/dist/__init__.py +++ b/torchacc/dist/__init__.py @@ -13,7 +13,7 @@ from .mesh import Mesh -from .parallel_module import ParallelModule +from .parallel_module import ParallelModule # isort: skip from .dp import DataParallel from .fsdp import FullyShardedDataParallel from .spmd_fsdp import SpmdFullyShardedDataParallel
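For reference, a minimal sketch of the padding arithmetic behind the ``unpad``/``_get_shard`` changes in PATCH 16: a flat parameter is padded up to a multiple of ``shard_num * 128`` (the ``_shard_size_multiple`` noted in the patch comments) before being split into shards, and the same amount of padding is trimmed again after the shards are concatenated back during consolidation. Helper names and tensor sizes below are illustrative, not the exact torchacc code:

```python
import torch
import torch.nn.functional as F

_SHARD_SIZE_MULTIPLE = 128  # the factor noted in the patch comments

def pad_and_shard(flat_param, shard_num):
    # Pad the flat parameter so its length is a multiple of shard_num * 128,
    # then split it into equally sized shards (mirrors _get_shard).
    multiple = shard_num * _SHARD_SIZE_MULTIPLE
    if flat_param.numel() % multiple != 0:
        pad_size = multiple - flat_param.numel() % multiple
        flat_param = F.pad(flat_param, [0, pad_size])
    local_size = flat_param.size(0) // shard_num
    return [flat_param[i * local_size:(i + 1) * local_size] for i in range(shard_num)]

def unpad(consolidated, layer_numels, padded_world_size):
    # Drop the tail padding after the shards are concatenated back together
    # (mirrors unpad; padded_world_size is world_size * 128 at the call site).
    numel = sum(layer_numels)
    if numel % padded_world_size != 0:
        pad_size = padded_world_size - numel % padded_world_size
        consolidated = consolidated[:-pad_size]
    return consolidated

flat = torch.arange(1000, dtype=torch.float32)       # e.g. one flattened layer of 1000 elements
shards = pad_and_shard(flat, shard_num=2)             # two shards of 512 elements each (padded to 1024)
restored = unpad(torch.cat(shards), [1000], 2 * _SHARD_SIZE_MULTIPLE)
assert torch.equal(restored, flat)
```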