From e4a3a3942207f5266d424ac6830842007c3f4c35 Mon Sep 17 00:00:00 2001 From: Susan Zhang Date: Sat, 5 Nov 2022 20:02:17 +0100 Subject: [PATCH 01/24] split out a get_checkpoint_path_to_load --- metaseq/checkpoint_utils.py | 87 ++++++++++++++++++------------------- 1 file changed, 43 insertions(+), 44 deletions(-) diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py index 5dacdcaba..948467d98 100644 --- a/metaseq/checkpoint_utils.py +++ b/metaseq/checkpoint_utils.py @@ -228,30 +228,10 @@ def _delete_old_checkpoint_files( os.remove(old_chk) -def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): - """ - Load a checkpoint and restore the training iterator. - - *passthrough_args* will be passed through to - ``trainer.get_train_iterator``. - """ - - reset_optimizer = cfg.reset_optimizer - reset_lr_scheduler = cfg.reset_lr_scheduler - optimizer_overrides = ast.literal_eval(cfg.optimizer_overrides) - reset_meters = cfg.reset_meters - reset_dataloader = cfg.reset_dataloader - - if cfg.finetune_from_model is not None and ( - reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader - ): - raise ValueError( - "--finetune-from-model can not be set together with either --reset-optimizer" - " or reset_lr_scheduler or reset_meters or reset_dataloader" - ) - +def get_checkpoint_path_to_load(cfg: CheckpointConfig, trainer) -> str: suffix = trainer.checkpoint_suffix default_restore_file = "checkpoint_last.pt" + # default to loading from restore file. if cfg.restore_file == default_restore_file: checkpoint_path_to_load = os.path.join( @@ -261,10 +241,10 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): if cfg.finetune_from_model is not None and first_launch: # if there is no last checkpoint to restore, start the finetune from pretrained model # else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc. 
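# Sketch of the refactor applied in this hunk: once path resolution is split out
# of load_checkpoint, the reset flags have to live on cfg (rather than in local
# variables) so the helper can hand its decision back to the caller. SimpleCfg
# and both function bodies are simplified stand-ins, not the actual metaseq
# CheckpointConfig or trainer API.
from dataclasses import dataclass
from typing import Optional


@dataclass
class SimpleCfg:
    finetune_from_model: Optional[str] = None
    reset_optimizer: bool = False
    reset_dataloader: bool = False


def get_checkpoint_path_to_load(cfg: SimpleCfg) -> str:
    if cfg.finetune_from_model is not None:
        # Mutating cfg is what lets load_checkpoint see the reset decision
        # after this helper returns.
        cfg.reset_optimizer = True
        cfg.reset_dataloader = True
        return cfg.finetune_from_model
    return "checkpoints/checkpoint_last.pt"


def load_checkpoint(cfg: SimpleCfg) -> None:
    path = get_checkpoint_path_to_load(cfg)
    print(f"loading {path} (reset_optimizer={cfg.reset_optimizer})")


load_checkpoint(SimpleCfg(finetune_from_model="pretrained.pt"))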
- reset_optimizer = True - reset_lr_scheduler = True - reset_meters = True - reset_dataloader = True + cfg.reset_optimizer = True + cfg.reset_lr_scheduler = True + cfg.reset_meters = True + cfg.reset_dataloader = True checkpoint_path_to_load = None if PathManager.exists(cfg.finetune_from_model): checkpoint_path_to_load = cfg.finetune_from_model @@ -316,7 +296,7 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): specific_restore_file_provided = cfg.restore_file != default_restore_file slurm_was_restarted = int(os.environ.get("SLURM_RESTART_COUNT", 0)) > 0 restart_from_latest = slurm_was_restarted or ( - cfg.finetune_from_model is None and not specific_restore_file_provided + cfg.finetune_from_model is None and not specific_restore_file_provided ) if restart_from_latest and os.path.exists(nfs_path): max_checkpoint = None @@ -343,13 +323,13 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): elif cfg.cluster_env == ComputeEnvs.AZURE.value and has_metaseq_internal: if ( - # --restore-file was not passed, always download latest checkpoint - ( - cfg.restore_file == default_restore_file - and cfg.finetune_from_model is None - ) - # --restore-file was passed, but we requeued, so download latest checkpoint - or int(os.environ.get("SLURM_RESTART_COUNT", 0)) > 0 + # --restore-file was not passed, always download latest checkpoint + ( + cfg.restore_file == default_restore_file + and cfg.finetune_from_model is None + ) + # --restore-file was passed, but we requeued, so download latest checkpoint + or int(os.environ.get("SLURM_RESTART_COUNT", 0)) > 0 ): # download checkpoint into local save_dir checkpoint_path_to_load = os.path.join( @@ -359,9 +339,9 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): cfg.cloud_upload_path, checkpoint_path_to_load, suffix + ".pt" ) elif ( - # --restore-file was passed and is a blob URL, download that checkpoint - cfg.restore_file != default_restore_file - and "windows.net" in cfg.restore_file + # --restore-file was passed and is a blob URL, download that checkpoint + cfg.restore_file != default_restore_file + and "windows.net" in cfg.restore_file ): blob_url = cfg.restore_file.replace(".pt", suffix + ".pt") # download checkpoint into local save_dir @@ -376,8 +356,8 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): # RSC logic: --restore-file was passed, and we requeued elif ( - cfg.restore_file != default_restore_file - and int(os.environ.get("SLURM_RESTART_COUNT", 0)) > 0 + cfg.restore_file != default_restore_file + and int(os.environ.get("SLURM_RESTART_COUNT", 0)) > 0 ): # point checkpoint_path to the current checkpoint directory for loading, if it exists. save_dir_last = os.path.join( @@ -385,6 +365,25 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): ) if PathManager.isfile(save_dir_last): checkpoint_path_to_load = save_dir_last + return checkpoint_path_to_load + + +def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): + """ + Load a checkpoint and restore the training iterator. + + *passthrough_args* will be passed through to + ``trainer.get_train_iterator``. 
+ """ + if cfg.finetune_from_model is not None and ( + cfg.reset_optimizer or cfg.reset_lr_scheduler or cfg.reset_meters or cfg.reset_dataloader + ): + raise ValueError( + "--finetune-from-model can not be set together with either --reset-optimizer" + " or reset_lr_scheduler or reset_meters or reset_dataloader" + ) + + checkpoint_path_to_load = get_checkpoint_path_to_load(cfg, trainer) logger.info(f"attempting to load checkpoint from: {checkpoint_path_to_load}") @@ -393,13 +392,13 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): extra_state = trainer.load_checkpoint( checkpoint_path_to_load, - reset_optimizer, - reset_lr_scheduler, - optimizer_overrides, - reset_meters=reset_meters, + cfg.reset_optimizer, + cfg.reset_lr_scheduler, + ast.literal_eval(cfg.optimizer_overrides), + reset_meters=cfg.reset_meters, ) - if extra_state is not None and not reset_dataloader: + if extra_state is not None and not cfg.reset_dataloader: # restore iterator from checkpoint itr_state = extra_state["train_iterator"] epoch_itr = trainer.get_train_iterator( From 1f6071bec3717bfc15bcd6e2d36f43bc86ef16e0 Mon Sep 17 00:00:00 2001 From: Susan Zhang Date: Sat, 5 Nov 2022 20:14:34 +0100 Subject: [PATCH 02/24] make restore_file optional, indent in trainer.load_checkpoint path exists check --- metaseq/checkpoint_utils.py | 6 ++ metaseq/dataclass/configs.py | 7 +-- metaseq/trainer.py | 111 +++++++++++++++++------------------ 3 files changed, 64 insertions(+), 60 deletions(-) diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py index 948467d98..15185178d 100644 --- a/metaseq/checkpoint_utils.py +++ b/metaseq/checkpoint_utils.py @@ -232,6 +232,12 @@ def get_checkpoint_path_to_load(cfg: CheckpointConfig, trainer) -> str: suffix = trainer.checkpoint_suffix default_restore_file = "checkpoint_last.pt" + # Logic flow: + # - restore_file is defined, will load from this first + # - if no restore_file: < try to grab latest from checkpoint / cloud > + # - if cloud upload is defined, pull from cloud upload + # - if no cloud upload, pull from checkpoint + # default to loading from restore file. if cfg.restore_file == default_restore_file: checkpoint_path_to_load = os.path.join( diff --git a/metaseq/dataclass/configs.py b/metaseq/dataclass/configs.py index 8a2bd49df..7503c29e0 100644 --- a/metaseq/dataclass/configs.py +++ b/metaseq/dataclass/configs.py @@ -452,11 +452,10 @@ class CheckpointConfig(MetaseqDataclass): save_dir: str = field( default="checkpoints", metadata={"help": "path to save checkpoints"} ) - restore_file: str = field( - default="checkpoint_last.pt", + restore_file: Optional[str] = field( + default=None, # Used to be: "checkpoint_last.pt" metadata={ - "help": "filename from which to load checkpoint " - "(default: /checkpoint_last.pt" + "help": "filename from which to load checkpoint" }, ) finetune_from_model: Optional[str] = field( diff --git a/metaseq/trainer.py b/metaseq/trainer.py index 784674aab..0acb1e034 100644 --- a/metaseq/trainer.py +++ b/metaseq/trainer.py @@ -413,10 +413,9 @@ def load_checkpoint( other ranks. 
""" extra_state, self._optim_history, last_optim_state = None, [], None - is_distributed = self.data_parallel_world_size > 1 - bexists = PathManager.isfile(filename) - if bexists: + + if PathManager.isfile(filename): logger.info(f"Preparing to load checkpoint {filename}") load_on_all_ranks = ( self.cfg.checkpoint.load_checkpoint_on_all_dp_ranks @@ -481,70 +480,70 @@ def load_checkpoint( extra_state = state["extra_state"] self._optim_history = state["optimizer_history"] - if last_optim_state is not None and not reset_optimizer: - # rebuild optimizer after loading model, since params may have changed - self._build_optimizer() + if last_optim_state is not None and not reset_optimizer: + # rebuild optimizer after loading model, since params may have changed + self._build_optimizer() - # only reload optimizer and lr_scheduler if they match - last_optim = self._optim_history[-1] - assert ( - last_optim["criterion_name"] == self.get_criterion().__class__.__name__ - ), ( - f"Criterion does not match; please reset the optimizer " - f"(--reset-optimizer). {last_optim['criterion_name']} vs " - f"{self.get_criterion().__class__.__name__}" - ) - assert last_optim["optimizer_name"] == self.optimizer.__class__.__name__, ( - f"Optimizer does not match; please reset the optimizer " - f"(--reset-optimizer). {last_optim['optimizer_name']} vs " - f"{self.optimizer.__class__.__name__}" - ) - - if not reset_lr_scheduler: - self.lr_scheduler.load_state_dict(last_optim["lr_scheduler_state"]) - - if not load_on_all_ranks and is_distributed: - last_optim_state = self.optimizer.broadcast_global_state_dict( - last_optim_state + # only reload optimizer and lr_scheduler if they match + last_optim = self._optim_history[-1] + assert ( + last_optim["criterion_name"] == self.get_criterion().__class__.__name__ + ), ( + f"Criterion does not match; please reset the optimizer " + f"(--reset-optimizer). {last_optim['criterion_name']} vs " + f"{self.get_criterion().__class__.__name__}" ) - elif self.is_fsdp and not self.use_sharded_state: - last_optim_state = self.model.get_shard_from_optim_state_dict( - last_optim_state + assert last_optim["optimizer_name"] == self.optimizer.__class__.__name__, ( + f"Optimizer does not match; please reset the optimizer " + f"(--reset-optimizer). 
{last_optim['optimizer_name']} vs " + f"{self.optimizer.__class__.__name__}" ) - logger.info(f"FSDP got shard from optim_state for {filename}") - self.optimizer.load_state_dict(last_optim_state, optimizer_overrides) - logger.info(f"Loaded optim_state for {filename}") - self.set_num_updates(last_optim["num_updates"]) + if not reset_lr_scheduler: + self.lr_scheduler.load_state_dict(last_optim["lr_scheduler_state"]) - if extra_state is not None: - itr_state = extra_state["train_iterator"] - epoch = itr_state["epoch"] + if not load_on_all_ranks and is_distributed: + last_optim_state = self.optimizer.broadcast_global_state_dict( + last_optim_state + ) + elif self.is_fsdp and not self.use_sharded_state: + last_optim_state = self.model.get_shard_from_optim_state_dict( + last_optim_state + ) + logger.info(f"FSDP got shard from optim_state for {filename}") - if "previous_training_time" in extra_state: - self._previous_training_time = extra_state["previous_training_time"] - self._start_time = time.time() + self.optimizer.load_state_dict(last_optim_state, optimizer_overrides) + logger.info(f"Loaded optim_state for {filename}") + self.set_num_updates(last_optim["num_updates"]) - self.lr_step(epoch) + if extra_state is not None: + itr_state = extra_state["train_iterator"] + epoch = itr_state["epoch"] - if ( - itr_state.get("version", 1) >= 2 - and itr_state["iterations_in_epoch"] == 0 - ): - # reset meters at start of epoch - reset_meters = True + if "previous_training_time" in extra_state: + self._previous_training_time = extra_state["previous_training_time"] + self._start_time = time.time() - if "metrics" in extra_state and not reset_meters: - metrics.load_state_dict(extra_state["metrics"]) + self.lr_step(epoch) - # reset TimeMeters, since their start times don't make sense anymore - for meter in metrics.get_meters("default"): - if isinstance(meter, meters.TimeMeter): - meter.reset() + if ( + itr_state.get("version", 1) >= 2 + and itr_state["iterations_in_epoch"] == 0 + ): + # reset meters at start of epoch + reset_meters = True - logger.info( - f"Loaded checkpoint {filename} (epoch {epoch} @ {self.get_num_updates()} updates)" - ) + if "metrics" in extra_state and not reset_meters: + metrics.load_state_dict(extra_state["metrics"]) + + # reset TimeMeters, since their start times don't make sense anymore + for meter in metrics.get_meters("default"): + if isinstance(meter, meters.TimeMeter): + meter.reset() + + logger.info( + f"Loaded checkpoint {filename} (epoch {epoch} @ {self.get_num_updates()} updates)" + ) else: logger.info("No existing checkpoint found {}".format(filename)) From 6ee9c4de332065540d9dc42232819644ff18bfac Mon Sep 17 00:00:00 2001 From: Susan Zhang Date: Sat, 5 Nov 2022 20:17:51 +0100 Subject: [PATCH 03/24] switch from default_restore_file string matching to is/not None checks --- metaseq/checkpoint_utils.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py index 15185178d..052a0aafb 100644 --- a/metaseq/checkpoint_utils.py +++ b/metaseq/checkpoint_utils.py @@ -228,18 +228,16 @@ def _delete_old_checkpoint_files( os.remove(old_chk) -def get_checkpoint_path_to_load(cfg: CheckpointConfig, trainer) -> str: - suffix = trainer.checkpoint_suffix - default_restore_file = "checkpoint_last.pt" - +def get_and_prep_checkpoint_path(cfg: CheckpointConfig, trainer) -> str: # Logic flow: # - restore_file is defined, will load from this first # - if no restore_file: < try to grab latest from checkpoint / 
cloud > # - if cloud upload is defined, pull from cloud upload # - if no cloud upload, pull from checkpoint - # default to loading from restore file. - if cfg.restore_file == default_restore_file: + suffix = trainer.checkpoint_suffix + + if cfg.restore_file is None: checkpoint_path_to_load = os.path.join( cfg.save_dir, "checkpoint_last{}.pt".format(suffix) ) @@ -272,7 +270,7 @@ def get_checkpoint_path_to_load(cfg: CheckpointConfig, trainer) -> str: else: checkpoint_path_to_load = cfg.restore_file - if cfg.restore_file != default_restore_file and cfg.finetune_from_model: + if cfg.restore_file is not None and cfg.finetune_from_model: raise ValueError( "--finetune-from-model and --restore-file (non-default value) " "can not be specified together: " + str(cfg) @@ -299,7 +297,7 @@ def get_checkpoint_path_to_load(cfg: CheckpointConfig, trainer) -> str: ) nfs_path = cfg.cloud_upload_path[4:] filename = None - specific_restore_file_provided = cfg.restore_file != default_restore_file + specific_restore_file_provided = cfg.restore_file is not None slurm_was_restarted = int(os.environ.get("SLURM_RESTART_COUNT", 0)) > 0 restart_from_latest = slurm_was_restarted or ( cfg.finetune_from_model is None and not specific_restore_file_provided @@ -331,7 +329,7 @@ def get_checkpoint_path_to_load(cfg: CheckpointConfig, trainer) -> str: if ( # --restore-file was not passed, always download latest checkpoint ( - cfg.restore_file == default_restore_file + cfg.restore_file is None and cfg.finetune_from_model is None ) # --restore-file was passed, but we requeued, so download latest checkpoint @@ -346,7 +344,7 @@ def get_checkpoint_path_to_load(cfg: CheckpointConfig, trainer) -> str: ) elif ( # --restore-file was passed and is a blob URL, download that checkpoint - cfg.restore_file != default_restore_file + cfg.restore_file is not None and "windows.net" in cfg.restore_file ): blob_url = cfg.restore_file.replace(".pt", suffix + ".pt") @@ -362,7 +360,7 @@ def get_checkpoint_path_to_load(cfg: CheckpointConfig, trainer) -> str: # RSC logic: --restore-file was passed, and we requeued elif ( - cfg.restore_file != default_restore_file + cfg.restore_file is not None and int(os.environ.get("SLURM_RESTART_COUNT", 0)) > 0 ): # point checkpoint_path to the current checkpoint directory for loading, if it exists. 
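# Standalone sketch of the restart decision used in this patch: restart from the
# latest checkpoint if SLURM requeued the job, or if neither
# --finetune-from-model nor an explicit --restore-file was given. The function
# name and bare-argument signature are illustrative only.
import os
from typing import Optional


def should_restart_from_latest(
    finetune_from_model: Optional[str], restore_file: Optional[str]
) -> bool:
    slurm_was_restarted = int(os.environ.get("SLURM_RESTART_COUNT", 0)) > 0
    explicit_restore_file = restore_file is not None
    return slurm_was_restarted or (
        finetune_from_model is None and not explicit_restore_file
    )


# True on a fresh launch with no overrides: fall back to the latest checkpoint.
print(should_restart_from_latest(None, None))
# False unless SLURM_RESTART_COUNT is set: a fine-tuning launch keeps its model.
print(should_restart_from_latest("pretrained.pt", None))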
@@ -389,7 +387,7 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): " or reset_lr_scheduler or reset_meters or reset_dataloader" ) - checkpoint_path_to_load = get_checkpoint_path_to_load(cfg, trainer) + checkpoint_path_to_load = get_and_prep_checkpoint_path(cfg, trainer) logger.info(f"attempting to load checkpoint from: {checkpoint_path_to_load}") From c68d5780a3c4a0c14c84ff8d9dc59db8c1f68d01 Mon Sep 17 00:00:00 2001 From: Susan Zhang Date: Sat, 5 Nov 2022 20:19:18 +0100 Subject: [PATCH 04/24] none guard --- metaseq/checkpoint_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py index 052a0aafb..80a925277 100644 --- a/metaseq/checkpoint_utils.py +++ b/metaseq/checkpoint_utils.py @@ -316,7 +316,7 @@ def get_and_prep_checkpoint_path(cfg: CheckpointConfig, trainer) -> str: nfs_path, candidate, f"checkpoint{suffix}.pt" ) else: - filename = cfg.restore_file.replace(".pt", suffix + ".pt") + filename = cfg.restore_file.replace(".pt", suffix + ".pt") if cfg.restore_file is not None else None if filename is not None and os.path.exists(checkpoint_path_to_load): logger.info( f"Copying checkpoint from nfs {filename} -> {checkpoint_path_to_load}" From a6fd478f2f4b40779756c458b7525f77fb021515 Mon Sep 17 00:00:00 2001 From: Susan Zhang Date: Sat, 5 Nov 2022 20:27:00 +0100 Subject: [PATCH 05/24] sorting out restore_file logic --- metaseq/checkpoint_utils.py | 32 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py index 80a925277..bf2cef5a9 100644 --- a/metaseq/checkpoint_utils.py +++ b/metaseq/checkpoint_utils.py @@ -230,7 +230,6 @@ def _delete_old_checkpoint_files( def get_and_prep_checkpoint_path(cfg: CheckpointConfig, trainer) -> str: # Logic flow: - # - restore_file is defined, will load from this first # - if no restore_file: < try to grab latest from checkpoint / cloud > # - if cloud upload is defined, pull from cloud upload # - if no cloud upload, pull from checkpoint @@ -250,31 +249,30 @@ def get_and_prep_checkpoint_path(cfg: CheckpointConfig, trainer) -> str: cfg.reset_meters = True cfg.reset_dataloader = True checkpoint_path_to_load = None + logger.warning( + "Resetting optimizer, lr scheduler, meters, and dataloader for fine-tuning!" 
+ ) if PathManager.exists(cfg.finetune_from_model): - checkpoint_path_to_load = cfg.finetune_from_model + return cfg.finetune_from_model elif suffix is not None: # check for sharded version sharded_path = cfg.finetune_from_model.replace(".pt", suffix + ".pt") if PathManager.exists(sharded_path): - checkpoint_path_to_load = sharded_path - if checkpoint_path_to_load is None: + return sharded_path + else: raise ValueError( f"--finetune-from-model {cfg.finetune_from_model} does not exist either as is or sharded" ) + else: # restore_file specified + if suffix is not None: + checkpoint_path_to_load = cfg.restore_file.replace(".pt", suffix + ".pt") + else: + checkpoint_path_to_load = cfg.restore_file - logger.info( - f"loading pretrained model from {checkpoint_path_to_load}: " - "optimizer, lr scheduler, meters, dataloader will be reset" + if cfg.restore_file is not None and cfg.finetune_from_model: + raise ValueError( + "--finetune-from-model and --restore-file (non-default value) " + "can not be specified together: " + str(cfg) ) - elif suffix is not None: - checkpoint_path_to_load = cfg.restore_file.replace(".pt", suffix + ".pt") - else: - checkpoint_path_to_load = cfg.restore_file - - if cfg.restore_file is not None and cfg.finetune_from_model: - raise ValueError( - "--finetune-from-model and --restore-file (non-default value) " - "can not be specified together: " + str(cfg) - ) # Azure logic try: From 18c6d166a369ce7a1982718046068ed5b3973723 Mon Sep 17 00:00:00 2001 From: Peter Albert Date: Thu, 24 Nov 2022 05:29:41 -0800 Subject: [PATCH 06/24] make save_async a single command --- metaseq/checkpoint_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py index 509ffb80a..4b76726f3 100644 --- a/metaseq/checkpoint_utils.py +++ b/metaseq/checkpoint_utils.py @@ -97,7 +97,7 @@ def is_better(a, b): ) def _copy_if_not_async(src, dest): - if cfg.write_checkpoints_asynchronously: + if cfg.save_async: pass # TODO[file_io]: Need to implement a delayed asynchronous file copying/moving feature. else: assert PathManager.copy( From 570fcedc11f1d4d1e502678f224cd686dea896b3 Mon Sep 17 00:00:00 2001 From: Peter Albert Date: Thu, 24 Nov 2022 06:27:53 -0800 Subject: [PATCH 07/24] remove duplicate flags --- metaseq/cli/train.py | 2 +- metaseq/dataclass/configs.py | 28 +++++++++++++--------------- metaseq/trainer.py | 2 +- 3 files changed, 15 insertions(+), 17 deletions(-) diff --git a/metaseq/cli/train.py b/metaseq/cli/train.py index c7a6e478f..74d71cc96 100644 --- a/metaseq/cli/train.py +++ b/metaseq/cli/train.py @@ -185,7 +185,7 @@ def main(cfg: DictConfig) -> None: logger.info("done training in {:.1f} seconds".format(train_meter.sum)) # Wait for all asynchronous file writes to complete. - if cfg.checkpoint.write_checkpoints_asynchronously: + if cfg.checkpoint.save_async: logger.info( "PathManager waiting for all asynchronous checkpoint writes to finish." 
) diff --git a/metaseq/dataclass/configs.py b/metaseq/dataclass/configs.py index faaa6a812..69ea8ec89 100644 --- a/metaseq/dataclass/configs.py +++ b/metaseq/dataclass/configs.py @@ -437,9 +437,9 @@ class OptimizationConfig(MetaseqDataclass): @dataclass class CheckpointConfig(MetaseqDataclass): - save_dir: str = field( - default="checkpoints", metadata={"help": "path to save checkpoints"} - ) + # save_dir: str = field( + # default="None", metadata={"help": "path to save checkpoints"} + # ) restore_file: Optional[str] = field( default=None, # Used to be: "checkpoint_last.pt" metadata={ @@ -537,26 +537,24 @@ class CheckpointConfig(MetaseqDataclass): "(default: only load on rank 0 and broadcast to other devices)" }, ) - write_checkpoints_asynchronously: bool = field( + save_async: bool = field( default=False, metadata={ "help": ( "Write checkpoints asynchronously in a separate " "thread. NOTE: This feature is currently being tested." ), - "argparse_alias": "--save-async", - }, - ) - cloud_upload_path: Optional[str] = field( - default=None, - metadata={ - "help": ( - "Upload checkpoints asynchronously in a separate " - "thread to blob store. NOTE: This feature is currently being tested." - ), - "argparse_alias": "--cloud-dir", }, ) + # cloud_upload_path: Optional[str] = field( + # default=None, + # metadata={ + # "help": ( + # "Upload checkpoints asynchronously in a separate " + # "thread to blob store. NOTE: This feature is currently being tested." + # ), + # }, + # ) # TODO(susanz): After https://github.com/fairinternal/fairseq-big-internal/issues/22 is tackled, modify this # to use ComputeEnvs constant cluster_env: str = field( diff --git a/metaseq/trainer.py b/metaseq/trainer.py index b1e777d4d..02c3c46df 100644 --- a/metaseq/trainer.py +++ b/metaseq/trainer.py @@ -386,7 +386,7 @@ def save_checkpoint( ) state_dict["extra_state"].update(extra_state) if self.should_save_checkpoint_on_current_rank: - if self.cfg.checkpoint.write_checkpoints_asynchronously: + if self.cfg.checkpoint.save_async: if not hasattr(self, "async_checkpoint"): self.async_checkpoint = ThreadPoolExecutor(max_workers=1) From a9989aa3b34677a5847b4dc39d7bff1a2a655f66 Mon Sep 17 00:00:00 2001 From: Peter Albert Date: Fri, 25 Nov 2022 05:39:31 -0800 Subject: [PATCH 08/24] change naming to local_checkpoints_dir --- metaseq/checkpoint_utils.py | 18 +++++++++--------- metaseq/cli/train.py | 20 ++++++++++---------- metaseq/dataclass/configs.py | 14 +++++++------- 3 files changed, 26 insertions(+), 26 deletions(-) diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py index 4b76726f3..88568caad 100644 --- a/metaseq/checkpoint_utils.py +++ b/metaseq/checkpoint_utils.py @@ -40,7 +40,7 @@ def save_checkpoint( # only one worker should attempt to create the required dir if trainer.data_parallel_rank == 0: - os.makedirs(cfg.save_dir, exist_ok=True) + os.makedirs(cfg.local_checkpoints_dir, exist_ok=True) trainer.consolidate_optimizer() @@ -83,7 +83,7 @@ def is_better(a, b): extra_state = {"train_iterator": epoch_itr.state_dict()} checkpoints = [ - os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond + os.path.join(cfg.local_checkpoints_dir, fn) for fn, cond in checkpoint_conds.items() if cond ] if len(checkpoints) > 0: if PathManager.islink(checkpoints[0]): @@ -129,7 +129,7 @@ def _delete_old_checkpoint_files( # remove old checkpoints; checkpoints are sorted in descending order for one_suffix in suffixes: checkpoints = _checkpoint_paths( - cfg.save_dir, 
pattern=r"checkpoint_(\d+){}\.pt".format(one_suffix) + cfg.local_checkpoints_dir, pattern=r"checkpoint_(\d+){}\.pt".format(one_suffix) ) for old_chk in checkpoints[cfg.keep_last_updates :]: if os.path.lexists(old_chk): @@ -137,7 +137,7 @@ def _delete_old_checkpoint_files( if cfg.keep_last_epochs > 0: # remove old epoch checkpoints; checkpoints are sorted in descending order checkpoints = _checkpoint_paths( - cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix) + cfg.local_checkpoints_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix) ) for old_chk in checkpoints[cfg.keep_last_epochs :]: if os.path.lexists(old_chk): @@ -154,7 +154,7 @@ def get_and_prep_checkpoint_path(cfg: CheckpointConfig, trainer) -> str: if cfg.restore_file is None: checkpoint_path_to_load = os.path.join( - cfg.save_dir, "checkpoint_last{}.pt".format(suffix) + cfg.local_checkpoints_dir, "checkpoint_last{}.pt".format(suffix) ) first_launch = not PathManager.exists(checkpoint_path_to_load) if cfg.finetune_from_model is not None and first_launch: @@ -207,7 +207,7 @@ def get_and_prep_checkpoint_path(cfg: CheckpointConfig, trainer) -> str: if cfg.cloud_upload_path: if cfg.cloud_upload_path.startswith("nfs:"): checkpoint_path_to_load = os.path.join( - cfg.save_dir, "checkpoint_last{}.pt".format(suffix) + cfg.local_checkpoints_dir, "checkpoint_last{}.pt".format(suffix) ) nfs_path = cfg.cloud_upload_path[4:] filename = None @@ -266,7 +266,7 @@ def get_and_prep_checkpoint_path(cfg: CheckpointConfig, trainer) -> str: ): # download checkpoint into local save_dir checkpoint_path_to_load = os.path.join( - cfg.save_dir, "checkpoint_last{}.pt".format(suffix) + cfg.local_checkpoints_dir, "checkpoint_last{}.pt".format(suffix) ) azure_utils.download_recent_ckpt( cfg.cloud_upload_path, checkpoint_path_to_load, suffix + ".pt" @@ -279,7 +279,7 @@ def get_and_prep_checkpoint_path(cfg: CheckpointConfig, trainer) -> str: blob_url = cfg.restore_file.replace(".pt", suffix + ".pt") # download checkpoint into local save_dir checkpoint_path_to_load = os.path.join( - cfg.save_dir, "checkpoint_last{}.pt".format(suffix) + cfg.local_checkpoints_dir, "checkpoint_last{}.pt".format(suffix) ) azure_utils.download_specific_ckpt(blob_url, checkpoint_path_to_load) else: @@ -294,7 +294,7 @@ def get_and_prep_checkpoint_path(cfg: CheckpointConfig, trainer) -> str: ): # point checkpoint_path to the current checkpoint directory for loading, if it exists. 
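# The cleanup hunks in this patch call a _checkpoint_paths helper whose body is
# not part of this diff; the callers assume it returns paths matching a numbered
# pattern, sorted newest-first. The sketch below is a guess at that behaviour,
# not the real implementation.
import os
import re
from typing import List


def checkpoint_paths_sketch(directory: str, pattern: str) -> List[str]:
    regex = re.compile(pattern)
    matched = []
    for filename in os.listdir(directory):
        m = regex.fullmatch(filename)
        if m:
            # Sort key is the integer captured by the first group, e.g. the
            # update count in r"checkpoint_(\d+)\.pt".
            matched.append((int(m.group(1)), filename))
    return [
        os.path.join(directory, name) for _, name in sorted(matched, reverse=True)
    ]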
save_dir_last = os.path.join( - cfg.save_dir, "checkpoint_last{}.pt".format(suffix) + cfg.local_checkpoints_dir, "checkpoint_last{}.pt".format(suffix) ) if PathManager.isfile(save_dir_last): checkpoint_path_to_load = save_dir_last diff --git a/metaseq/cli/train.py b/metaseq/cli/train.py index 74d71cc96..8c7cf1936 100644 --- a/metaseq/cli/train.py +++ b/metaseq/cli/train.py @@ -55,16 +55,16 @@ def main(cfg: DictConfig) -> None: # replace with actual job id slurm_jobid = os.environ.get("SLURM_JOBID", None) - if "%jobid" in cfg.checkpoint.save_dir and slurm_jobid is not None: - cfg.checkpoint.save_dir = cfg.checkpoint.save_dir.replace("%jobid", slurm_jobid) + if "%jobid" in cfg.checkpoint.local_checkpoints_dir and slurm_jobid is not None: + cfg.checkpoint.local_checkpoints_dir = cfg.checkpoint.local_checkpoints_dir.replace("%jobid", slurm_jobid) - checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir) + checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.local_checkpoints_dir) if distributed_utils.is_master(cfg.distributed_training): # save a (vaguely human readable) copy of the training config OmegaConf.save( config=_flatten_config(cfg), - f=os.path.join(cfg.checkpoint.save_dir, "config.yml"), + f=os.path.join(cfg.checkpoint.local_checkpoints_dir, "config.yml"), ) if ( @@ -238,14 +238,14 @@ def train( if distributed_utils.is_master(cfg.distributed_training) else None ), - aim_param_checkpoint_dir=cfg.checkpoint.save_dir, + aim_param_checkpoint_dir=cfg.checkpoint.local_checkpoints_dir, wandb_project=( cfg.common.wandb_project if distributed_utils.is_master(cfg.distributed_training) else None ), wandb_run_name=os.environ.get( - "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir) + "WANDB_NAME", os.path.basename(cfg.checkpoint.local_checkpoints_dir) ), ) progress.update_config(_flatten_config(cfg)) @@ -297,7 +297,7 @@ def train( valid_losses, should_stop = train(i, samples) torch.cuda.synchronize() with open( - os.path.join(cfg.checkpoint.save_dir, "memory_usage.txt"), "a" + os.path.join(cfg.checkpoint.local_checkpoints_dir, "memory_usage.txt"), "a" ) as sourceFile: print( prof.key_averages(group_by_stack_n=5).table( @@ -306,7 +306,7 @@ def train( file=sourceFile, ) prof.export_chrome_trace( - os.path.join(cfg.checkpoint.save_dir, "profiler_trace.json") + os.path.join(cfg.checkpoint.local_checkpoints_dir, "profiler_trace.json") ) else: valid_losses, should_stop = train(i, samples) @@ -599,14 +599,14 @@ def validate( if distributed_utils.is_master(cfg.distributed_training) else None ), - aim_param_checkpoint_dir=cfg.checkpoint.save_dir, + aim_param_checkpoint_dir=cfg.checkpoint.local_checkpoints_dir, wandb_project=( cfg.common.wandb_project if distributed_utils.is_master(cfg.distributed_training) else None ), wandb_run_name=os.environ.get( - "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir) + "WANDB_NAME", os.path.basename(cfg.checkpoint.local_checkpoints_dir) ), ) diff --git a/metaseq/dataclass/configs.py b/metaseq/dataclass/configs.py index 69ea8ec89..3cfd725b1 100644 --- a/metaseq/dataclass/configs.py +++ b/metaseq/dataclass/configs.py @@ -652,13 +652,13 @@ class CommonEvalConfig(MetaseqDataclass): @dataclass class ReshardConfig(MetaseqDataclass): - save_dir: Optional[str] = field( - default=None, - metadata={ - "help": "where to save the resharded checkpoints", - "argparse_alias": "--dest-dir", - }, - ) + # save_dir: Optional[str] = field( + # default=None, + # metadata={ + # "help": "where to save the resharded checkpoints", + # "argparse_alias": 
"--dest-dir", + # }, + # ) save_prefix: Optional[str] = field( default="reshard", metadata={"help": "save to dest-dir/save-prefix-shard{i}.pt"} ) From 6eefb065e10a47571eaf323b0b7597ac12a77c4c Mon Sep 17 00:00:00 2001 From: Peter Albert Date: Tue, 29 Nov 2022 04:25:03 -0800 Subject: [PATCH 09/24] finish core prepare_local_checkpoint_path function --- metaseq/checkpoint_utils.py | 257 +++++++++++++++++++++++++++++------- 1 file changed, 211 insertions(+), 46 deletions(-) diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py index 88568caad..683d6ef5b 100644 --- a/metaseq/checkpoint_utils.py +++ b/metaseq/checkpoint_utils.py @@ -83,7 +83,9 @@ def is_better(a, b): extra_state = {"train_iterator": epoch_itr.state_dict()} checkpoints = [ - os.path.join(cfg.local_checkpoints_dir, fn) for fn, cond in checkpoint_conds.items() if cond + os.path.join(cfg.local_checkpoints_dir, fn) + for fn, cond in checkpoint_conds.items() + if cond ] if len(checkpoints) > 0: if PathManager.islink(checkpoints[0]): @@ -129,7 +131,8 @@ def _delete_old_checkpoint_files( # remove old checkpoints; checkpoints are sorted in descending order for one_suffix in suffixes: checkpoints = _checkpoint_paths( - cfg.local_checkpoints_dir, pattern=r"checkpoint_(\d+){}\.pt".format(one_suffix) + cfg.local_checkpoints_dir, + pattern=r"checkpoint_(\d+){}\.pt".format(one_suffix), ) for old_chk in checkpoints[cfg.keep_last_updates :]: if os.path.lexists(old_chk): @@ -144,6 +147,166 @@ def _delete_old_checkpoint_files( os.remove(old_chk) +def get_storage_type(path): + if path.startswith("nfs:"): + return "nfs" + elif "windows.net" in path: + return "azure_blob" + else: + return "local" + + +def get_checkpoint_steps(path): + match = re.search(r"checkpoint_(\d+)", path) + if match[1] is None: + return 0 + return int(match[1]) + + +def get_all_checkpoints_from_directory(directory, suffix, increased_priority, storage_type): + checkpoints = [] + for candidate in os.listdir(directory): + steps = get_checkpoint_steps(candidate) + if steps == 0: + continue + # TODO needs to be adapated for local dir + expected_file_count = distributed_utils.get_global_world_size() + present_files = len( + [ + f + for f in os.listdir(os.path.join(directory, candidate)) + if not f.startswith("_") + ] + ) + if present_files != expected_file_count: + logger.info( + f"skipping checkpoint {candidate} because it only has" + f" {present_files} files (expected {expected_file_count})" + ) + continue + + # TODO validate this + checkpoints.append( + { + "path": os.path.join(directory, candidate, f"checkpoint{suffix}.pt"), + "priority": steps + increased_priority, + "storage_type": storage_type, + } + ) + return checkpoints + + +def get_recent_checkpoint_from_azure_blob(blob_url, suffix, increased_priority): + file_to_load = azure_utils.get_most_recent_ckpt(blob_url, suffix) + if file_to_load is None: + return [] + steps = get_checkpoint_steps(file_to_load) + return [ + { + "path": blob_url + "/" + file_to_load, + "priority": steps + increased_priority, + "storage_type": "azure_blob", + } + ] +def get_checkpoint_to_finetune(finetune_path, suffix, priority): + if PathManager.exists(finetune_path): + validated_path = finetune_path + else: # check for sharded version + sharded_path = finetune_path.replace(".pt", suffix + ".pt") + if PathManager.exists(sharded_path): + validated_path = sharded_path + if validated_path is None: + raise ValueError( + f"--finetune-from-model {cfg.finetune_from_model} does not exist either as is or sharded" + ) + return { + "path": 
validated_path, + "priority": priority, + "storage_type": get_storage_type(validated_path), + "run_before_loading": reset_for_finetuning + } + +def reset_for_finetuning(cfg, checkpoint): + cfg.reset_optimizer = True + cfg.reset_lr_scheduler = True + cfg.reset_meters = True + cfg.reset_dataloader = True + logger.warning( + "Resetting optimizer, lr scheduler, meters, and dataloader for fine-tuning!" + ) + +def prepare_local_checkpoint_path(cfg: CheckpointConfig, trainer): + suffix = trainer.checkpoint_suffix + checkpoints = [] + + if cfg.finetune_from_model: + checkpoints.append( + get_checkpoint_to_finetune(cfg.finetune_from_model, suffix, 0) + ) + + if cfg.restore_file: + checkpoints.append( + { + "path": cfg.restore_file.replace(".pt", suffix + ".pt"), + "priority": get_checkpoint_steps(cfg.restore_file) + 0.1, + "storage_type": get_storage_type(cfg.restore_file), + } + ) + + if cfg.cloud_upload_path: + cloud_storage_type = get_storage_type(cfg.cloud_upload_path) + if cloud_storage_type == "nfs": + checkpoints.extend( + get_all_checkpoints_from_directory( + cfg.cloud_upload_path[4:], + suffix, + increased_priority=0.2, + storage_type="nfs", + ) + ) + elif cloud_storage_type == "azure_blob": + checkpoints.extend( + get_recent_checkpoint_from_azure_blob( + cfg.cloud_upload_path, suffix, increased_priority=0.2 + ) + ) + + + checkpoints.extend( + get_all_checkpoints_from_directory( + cfg.local_checkpoints_dir, increased_priority=0.3, storage_type="local" + ) + ) + + checkpoints.sort(key = lambda checkpoint: checkpoint["priority"]) + if len(checkpoints) == 0: + return "" + logger.info(f"The following checkpoints were found to be ready to load: {str(checkpoints)}") + + + selected = checkpoints[-1] + if "run_before_loading" in selected: + selected["before_loading"](cfg, selected) + + if selected["storage_type"] == "local": + return selected["path"] + + local_tmp_dir = os.path.join( + cfg.local_checkpoints_dir, "checkpoint_tmp{}.pt".format(suffix) + ) + logger.info( + f"Copying checkpoint from {selected["path"]} -> {local_tmp_dir}" + ) + if selected["storage_type"] == "nfs": + shutil.copyfile(selected["path"], local_tmp_dir) + elif selected["storage_type"] == "azure_blob": + azure_utils.download_specific_ckpt(selected["path"], local_tmp_dir) + + return local_tmp_dir + + # based on the storage type, process and return final path + + def get_and_prep_checkpoint_path(cfg: CheckpointConfig, trainer) -> str: # Logic flow: # - if no restore_file: < try to grab latest from checkpoint / cloud > @@ -156,28 +319,29 @@ def get_and_prep_checkpoint_path(cfg: CheckpointConfig, trainer) -> str: checkpoint_path_to_load = os.path.join( cfg.local_checkpoints_dir, "checkpoint_last{}.pt".format(suffix) ) - first_launch = not PathManager.exists(checkpoint_path_to_load) - if cfg.finetune_from_model is not None and first_launch: - # if there is no last checkpoint to restore, start the finetune from pretrained model - # else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc. - cfg.reset_optimizer = True - cfg.reset_lr_scheduler = True - cfg.reset_meters = True - cfg.reset_dataloader = True - checkpoint_path_to_load = None - logger.warning( - "Resetting optimizer, lr scheduler, meters, and dataloader for fine-tuning!" 
- ) - if PathManager.exists(cfg.finetune_from_model): - return cfg.finetune_from_model - elif suffix is not None: # check for sharded version - sharded_path = cfg.finetune_from_model.replace(".pt", suffix + ".pt") - if PathManager.exists(sharded_path): - return sharded_path - else: - raise ValueError( - f"--finetune-from-model {cfg.finetune_from_model} does not exist either as is or sharded" - ) + # move this out + # first_launch = not PathManager.exists(checkpoint_path_to_load) + # if cfg.finetune_from_model is not None and first_launch: + # # if there is no last checkpoint to restore, start the finetune from pretrained model + # # else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc. + # cfg.reset_optimizer = True + # cfg.reset_lr_scheduler = True + # cfg.reset_meters = True + # cfg.reset_dataloader = True + # checkpoint_path_to_load = None + # logger.warning( + # "Resetting optimizer, lr scheduler, meters, and dataloader for fine-tuning!" + # ) + # if PathManager.exists(cfg.finetune_from_model): + # return cfg.finetune_from_model + # elif suffix is not None: # check for sharded version + # sharded_path = cfg.finetune_from_model.replace(".pt", suffix + ".pt") + # if PathManager.exists(sharded_path): + # return sharded_path + # else: + # raise ValueError( + # f"--finetune-from-model {cfg.finetune_from_model} does not exist either as is or sharded" + # ) else: # restore_file specified if suffix is not None: checkpoint_path_to_load = cfg.restore_file.replace(".pt", suffix + ".pt") @@ -205,6 +369,8 @@ def get_and_prep_checkpoint_path(cfg: CheckpointConfig, trainer) -> str: # Note that we compare by value since ComputeEnvs may be imported from metaseq_internal if cfg.cloud_upload_path: + + # RSC NFS LOGIC if cfg.cloud_upload_path.startswith("nfs:"): checkpoint_path_to_load = os.path.join( cfg.local_checkpoints_dir, "checkpoint_last{}.pt".format(suffix) @@ -213,8 +379,9 @@ def get_and_prep_checkpoint_path(cfg: CheckpointConfig, trainer) -> str: filename = None specific_restore_file_provided = cfg.restore_file is not None slurm_was_restarted = int(os.environ.get("SLURM_RESTART_COUNT", 0)) > 0 + # --cloud_upload_path, slurm restart or no (finetune and restart file) restart_from_latest = slurm_was_restarted or ( - cfg.finetune_from_model is None and not specific_restore_file_provided + cfg.finetune_from_model is None and not specific_restore_file_provided ) if restart_from_latest: checkpoints = [] @@ -245,8 +412,14 @@ def get_and_prep_checkpoint_path(cfg: CheckpointConfig, trainer) -> str: f" {present_files} files (expected {expected_file_count})" ) else: - filename = cfg.restore_file.replace(".pt", suffix + ".pt") if cfg.restore_file is not None else None + # --cloud_upload_path, no slurm restart, or a finetune or restart file) -> use restore_file + filename = ( + cfg.restore_file.replace(".pt", suffix + ".pt") + if cfg.restore_file is not None + else None + ) if filename is not None: + # rsc nfs copying logger.info( f"Copying checkpoint from nfs {filename} -> {checkpoint_path_to_load}" ) @@ -254,15 +427,13 @@ def get_and_prep_checkpoint_path(cfg: CheckpointConfig, trainer) -> str: else: logger.info(f"No NFS checkpoints found") + # AZURE LOGIC elif cfg.cluster_env == ComputeEnvs.AZURE.value and has_metaseq_internal: if ( - # --restore-file was not passed, always download latest checkpoint - ( - cfg.restore_file is None - and cfg.finetune_from_model is None - ) - # --restore-file was passed, but we requeued, so download latest checkpoint - or 
int(os.environ.get("SLURM_RESTART_COUNT", 0)) > 0 + # --restore-file was not passed, always download latest checkpoint + (cfg.restore_file is None and cfg.finetune_from_model is None) + # --restore-file was passed, but we requeued, so download latest checkpoint + or int(os.environ.get("SLURM_RESTART_COUNT", 0)) > 0 ): # download checkpoint into local save_dir checkpoint_path_to_load = os.path.join( @@ -272,9 +443,9 @@ def get_and_prep_checkpoint_path(cfg: CheckpointConfig, trainer) -> str: cfg.cloud_upload_path, checkpoint_path_to_load, suffix + ".pt" ) elif ( - # --restore-file was passed and is a blob URL, download that checkpoint - cfg.restore_file is not None - and "windows.net" in cfg.restore_file + # --restore-file was passed and is a blob URL, download that checkpoint + cfg.restore_file is not None + and "windows.net" in cfg.restore_file ): blob_url = cfg.restore_file.replace(".pt", suffix + ".pt") # download checkpoint into local save_dir @@ -288,9 +459,10 @@ def get_and_prep_checkpoint_path(cfg: CheckpointConfig, trainer) -> str: ) # RSC logic: --restore-file was passed, and we requeued + # no cloud upload path specified elif ( - cfg.restore_file is not None - and int(os.environ.get("SLURM_RESTART_COUNT", 0)) > 0 + cfg.restore_file is not None + and int(os.environ.get("SLURM_RESTART_COUNT", 0)) > 0 ): # point checkpoint_path to the current checkpoint directory for loading, if it exists. save_dir_last = os.path.join( @@ -308,15 +480,8 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): *passthrough_args* will be passed through to ``trainer.get_train_iterator``. """ - if cfg.finetune_from_model is not None and ( - cfg.reset_optimizer or cfg.reset_lr_scheduler or cfg.reset_meters or cfg.reset_dataloader - ): - raise ValueError( - "--finetune-from-model can not be set together with either --reset-optimizer" - " or reset_lr_scheduler or reset_meters or reset_dataloader" - ) - checkpoint_path_to_load = get_and_prep_checkpoint_path(cfg, trainer) + checkpoint_path_to_load = prepare_local_checkpoint_path(cfg, trainer) logger.info(f"attempting to load checkpoint from: {checkpoint_path_to_load}") From a8b06134653c62e1880f8be0926b26140c5d26c8 Mon Sep 17 00:00:00 2001 From: Peter Albert Date: Tue, 29 Nov 2022 04:28:38 -0800 Subject: [PATCH 10/24] black lint --- metaseq/checkpoint_utils.py | 27 ++++++++++++++++----------- metaseq/cli/train.py | 11 ++++++++--- metaseq/dataclass/configs.py | 4 +--- metaseq/trainer.py | 7 +++++-- 4 files changed, 30 insertions(+), 19 deletions(-) diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py index 683d6ef5b..37460d586 100644 --- a/metaseq/checkpoint_utils.py +++ b/metaseq/checkpoint_utils.py @@ -163,7 +163,9 @@ def get_checkpoint_steps(path): return int(match[1]) -def get_all_checkpoints_from_directory(directory, suffix, increased_priority, storage_type): +def get_all_checkpoints_from_directory( + directory, suffix, increased_priority, storage_type +): checkpoints = [] for candidate in os.listdir(directory): steps = get_checkpoint_steps(candidate) @@ -208,6 +210,8 @@ def get_recent_checkpoint_from_azure_blob(blob_url, suffix, increased_priority): "storage_type": "azure_blob", } ] + + def get_checkpoint_to_finetune(finetune_path, suffix, priority): if PathManager.exists(finetune_path): validated_path = finetune_path @@ -223,8 +227,9 @@ def get_checkpoint_to_finetune(finetune_path, suffix, priority): "path": validated_path, "priority": priority, "storage_type": get_storage_type(validated_path), - "run_before_loading": 
reset_for_finetuning - } + "run_before_loading": reset_for_finetuning, + } + def reset_for_finetuning(cfg, checkpoint): cfg.reset_optimizer = True @@ -235,6 +240,7 @@ def reset_for_finetuning(cfg, checkpoint): "Resetting optimizer, lr scheduler, meters, and dataloader for fine-tuning!" ) + def prepare_local_checkpoint_path(cfg: CheckpointConfig, trainer): suffix = trainer.checkpoint_suffix checkpoints = [] @@ -271,18 +277,18 @@ def prepare_local_checkpoint_path(cfg: CheckpointConfig, trainer): ) ) - checkpoints.extend( get_all_checkpoints_from_directory( cfg.local_checkpoints_dir, increased_priority=0.3, storage_type="local" ) ) - checkpoints.sort(key = lambda checkpoint: checkpoint["priority"]) + checkpoints.sort(key=lambda checkpoint: checkpoint["priority"]) if len(checkpoints) == 0: return "" - logger.info(f"The following checkpoints were found to be ready to load: {str(checkpoints)}") - + logger.info( + f"The following checkpoints were found to be ready to load: {str(checkpoints)}" + ) selected = checkpoints[-1] if "run_before_loading" in selected: @@ -292,11 +298,10 @@ def prepare_local_checkpoint_path(cfg: CheckpointConfig, trainer): return selected["path"] local_tmp_dir = os.path.join( - cfg.local_checkpoints_dir, "checkpoint_tmp{}.pt".format(suffix) - ) - logger.info( - f"Copying checkpoint from {selected["path"]} -> {local_tmp_dir}" + cfg.local_checkpoints_dir, f"checkpoint_tmp{suffix}.pt" ) + + logger.info(f"Copying checkpoint from {selected['path']} -> {local_tmp_dir}") if selected["storage_type"] == "nfs": shutil.copyfile(selected["path"], local_tmp_dir) elif selected["storage_type"] == "azure_blob": diff --git a/metaseq/cli/train.py b/metaseq/cli/train.py index 8c7cf1936..3749b3ba7 100644 --- a/metaseq/cli/train.py +++ b/metaseq/cli/train.py @@ -56,7 +56,9 @@ def main(cfg: DictConfig) -> None: # replace with actual job id slurm_jobid = os.environ.get("SLURM_JOBID", None) if "%jobid" in cfg.checkpoint.local_checkpoints_dir and slurm_jobid is not None: - cfg.checkpoint.local_checkpoints_dir = cfg.checkpoint.local_checkpoints_dir.replace("%jobid", slurm_jobid) + cfg.checkpoint.local_checkpoints_dir = ( + cfg.checkpoint.local_checkpoints_dir.replace("%jobid", slurm_jobid) + ) checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.local_checkpoints_dir) @@ -297,7 +299,8 @@ def train( valid_losses, should_stop = train(i, samples) torch.cuda.synchronize() with open( - os.path.join(cfg.checkpoint.local_checkpoints_dir, "memory_usage.txt"), "a" + os.path.join(cfg.checkpoint.local_checkpoints_dir, "memory_usage.txt"), + "a", ) as sourceFile: print( prof.key_averages(group_by_stack_n=5).table( @@ -306,7 +309,9 @@ def train( file=sourceFile, ) prof.export_chrome_trace( - os.path.join(cfg.checkpoint.local_checkpoints_dir, "profiler_trace.json") + os.path.join( + cfg.checkpoint.local_checkpoints_dir, "profiler_trace.json" + ) ) else: valid_losses, should_stop = train(i, samples) diff --git a/metaseq/dataclass/configs.py b/metaseq/dataclass/configs.py index 3cfd725b1..0b6dc9bae 100644 --- a/metaseq/dataclass/configs.py +++ b/metaseq/dataclass/configs.py @@ -442,9 +442,7 @@ class CheckpointConfig(MetaseqDataclass): # ) restore_file: Optional[str] = field( default=None, # Used to be: "checkpoint_last.pt" - metadata={ - "help": "filename from which to load checkpoint" - }, + metadata={"help": "filename from which to load checkpoint"}, ) finetune_from_model: Optional[str] = field( default=None, diff --git a/metaseq/trainer.py b/metaseq/trainer.py index 02c3c46df..faa995367 100644 --- 
a/metaseq/trainer.py +++ b/metaseq/trainer.py @@ -502,13 +502,16 @@ def load_checkpoint( # only reload optimizer and lr_scheduler if they match last_optim = self._optim_history[-1] assert ( - last_optim["criterion_name"] == self.get_criterion().__class__.__name__ + last_optim["criterion_name"] + == self.get_criterion().__class__.__name__ ), ( f"Criterion does not match; please reset the optimizer " f"(--reset-optimizer). {last_optim['criterion_name']} vs " f"{self.get_criterion().__class__.__name__}" ) - assert last_optim["optimizer_name"] == self.optimizer.__class__.__name__, ( + assert ( + last_optim["optimizer_name"] == self.optimizer.__class__.__name__ + ), ( f"Optimizer does not match; please reset the optimizer " f"(--reset-optimizer). {last_optim['optimizer_name']} vs " f"{self.optimizer.__class__.__name__}" From 6fe250c067069856c8bf2d14d9c27e0cb8e0d9bb Mon Sep 17 00:00:00 2001 From: Peter Albert Date: Tue, 29 Nov 2022 05:31:48 -0800 Subject: [PATCH 11/24] add dataclass --- metaseq/checkpoint_utils.py | 84 +++++++++++++++++++++---------------- 1 file changed, 47 insertions(+), 37 deletions(-) diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py index 37460d586..d62e2a6f8 100644 --- a/metaseq/checkpoint_utils.py +++ b/metaseq/checkpoint_utils.py @@ -11,8 +11,10 @@ import re import traceback import socket -from typing import Any, Dict, List, Optional, Tuple import shutil +from typing import Any, Dict, List, Optional, Tuple +from dataclasses import dataclass, field + import torch from omegaconf import OmegaConf @@ -189,11 +191,11 @@ def get_all_checkpoints_from_directory( # TODO validate this checkpoints.append( - { - "path": os.path.join(directory, candidate, f"checkpoint{suffix}.pt"), - "priority": steps + increased_priority, - "storage_type": storage_type, - } + CheckpointPath( + path=os.path.join(directory, candidate, f"checkpoint{suffix}.pt"), + storage_type=storage_type, + priority=steps + increased_priority, + ) ) return checkpoints @@ -204,11 +206,11 @@ def get_recent_checkpoint_from_azure_blob(blob_url, suffix, increased_priority): return [] steps = get_checkpoint_steps(file_to_load) return [ - { - "path": blob_url + "/" + file_to_load, - "priority": steps + increased_priority, - "storage_type": "azure_blob", - } + CheckpointPath( + path=blob_url + "/" + file_to_load, + storage_type="azure_blob", + priority=steps + increased_priority, + ) ] @@ -223,12 +225,12 @@ def get_checkpoint_to_finetune(finetune_path, suffix, priority): raise ValueError( f"--finetune-from-model {cfg.finetune_from_model} does not exist either as is or sharded" ) - return { - "path": validated_path, - "priority": priority, - "storage_type": get_storage_type(validated_path), - "run_before_loading": reset_for_finetuning, - } + return CheckpointPath( + path=validated_path, + storage_type=get_storage_type(validated_path), + priority=priority, + run_before_loading=[reset_for_finetuning], + ) def reset_for_finetuning(cfg, checkpoint): @@ -241,10 +243,11 @@ def reset_for_finetuning(cfg, checkpoint): ) -def prepare_local_checkpoint_path(cfg: CheckpointConfig, trainer): +def prepare_local_checkpoint_path(cfg: CheckpointConfig, trainer) -> str: suffix = trainer.checkpoint_suffix - checkpoints = [] + # collect all possible checkpoint paths + checkpoints = [] if cfg.finetune_from_model: checkpoints.append( get_checkpoint_to_finetune(cfg.finetune_from_model, suffix, 0) @@ -252,11 +255,11 @@ def prepare_local_checkpoint_path(cfg: CheckpointConfig, trainer): if cfg.restore_file: checkpoints.append( - { 
- "path": cfg.restore_file.replace(".pt", suffix + ".pt"), - "priority": get_checkpoint_steps(cfg.restore_file) + 0.1, - "storage_type": get_storage_type(cfg.restore_file), - } + CheckpointPath( + path=cfg.restore_file.replace(".pt", suffix + ".pt"), + priority=get_checkpoint_steps(cfg.restore_file) + 0.1, + storage_type=get_storage_type(cfg.restore_file), + ) ) if cfg.cloud_upload_path: @@ -283,34 +286,33 @@ def prepare_local_checkpoint_path(cfg: CheckpointConfig, trainer): ) ) - checkpoints.sort(key=lambda checkpoint: checkpoint["priority"]) + # get the most recent valid checkpoint + checkpoints.sort(key=lambda checkpoint: checkpoint.priority) if len(checkpoints) == 0: return "" logger.info( f"The following checkpoints were found to be ready to load: {str(checkpoints)}" ) + checkpoint = checkpoints[-1] - selected = checkpoints[-1] - if "run_before_loading" in selected: - selected["before_loading"](cfg, selected) + _ = [hook(cfg, checkpoint) for hook in checkpoint.run_before_loading] - if selected["storage_type"] == "local": - return selected["path"] + if checkpoint.storage_type == "local": + return checkpoint.path + # copy cloud checkpoints to a local tmp_dir local_tmp_dir = os.path.join( cfg.local_checkpoints_dir, f"checkpoint_tmp{suffix}.pt" ) - logger.info(f"Copying checkpoint from {selected['path']} -> {local_tmp_dir}") - if selected["storage_type"] == "nfs": - shutil.copyfile(selected["path"], local_tmp_dir) - elif selected["storage_type"] == "azure_blob": - azure_utils.download_specific_ckpt(selected["path"], local_tmp_dir) + logger.info(f"Copying checkpoint from {checkpoint.path} -> {local_tmp_dir}") + if checkpoint.storage_type == "nfs": + shutil.copyfile(checkpoint.path, local_tmp_dir) + elif checkpoint.storage_type == "azure_blob": + azure_utils.download_specific_ckpt(checkpoint.path, local_tmp_dir) return local_tmp_dir - # based on the storage type, process and return final path - def get_and_prep_checkpoint_path(cfg: CheckpointConfig, trainer) -> str: # Logic flow: @@ -941,3 +943,11 @@ def _get_pad_info(state_dict: Dict) -> Dict[str, int]: assert full_key not in res, f"collision: {full_key} already in {res}" res[full_key] = v["padding"] return res + + +@dataclass +class CheckpointPath: + path: str + storage_type: str + priority: float = 0 + run_before_loading: list = field(default_factory=list) From 2ddf1fb322e2997bf92c9972d4c4c8e9f265d658 Mon Sep 17 00:00:00 2001 From: Peter Albert Date: Tue, 29 Nov 2022 05:53:28 -0800 Subject: [PATCH 12/24] remove old load function, add type hints, move out internal import --- metaseq/checkpoint_utils.py | 196 ++++-------------------------------- 1 file changed, 21 insertions(+), 175 deletions(-) diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py index d62e2a6f8..2df89c6c9 100644 --- a/metaseq/checkpoint_utils.py +++ b/metaseq/checkpoint_utils.py @@ -30,6 +30,14 @@ OPT_KEY = "last_optimizer_state" +try: + from metaseq_internal import azure_utils +except ImportError: + logger.warning( + "Proceeding without metaseq-internal installed! Please check if you need this!" + "It is required for loading from azure blob." 
+ ) + def save_checkpoint( cfg: CheckpointConfig, @@ -149,7 +157,7 @@ def _delete_old_checkpoint_files( os.remove(old_chk) -def get_storage_type(path): +def get_storage_type(path: str) -> str: if path.startswith("nfs:"): return "nfs" elif "windows.net" in path: @@ -158,7 +166,7 @@ def get_storage_type(path): return "local" -def get_checkpoint_steps(path): +def get_checkpoint_steps(path: str) -> int: match = re.search(r"checkpoint_(\d+)", path) if match[1] is None: return 0 @@ -166,8 +174,8 @@ def get_checkpoint_steps(path): def get_all_checkpoints_from_directory( - directory, suffix, increased_priority, storage_type -): + directory: str, suffix: str, increased_priority: float, storage_type: str +) -> List[CheckpointPath]: checkpoints = [] for candidate in os.listdir(directory): steps = get_checkpoint_steps(candidate) @@ -200,7 +208,9 @@ def get_all_checkpoints_from_directory( return checkpoints -def get_recent_checkpoint_from_azure_blob(blob_url, suffix, increased_priority): +def get_recent_checkpoint_from_azure_blob( + blob_url, suffix, increased_priority +) -> List[CheckpointPath]: file_to_load = azure_utils.get_most_recent_ckpt(blob_url, suffix) if file_to_load is None: return [] @@ -214,7 +224,9 @@ def get_recent_checkpoint_from_azure_blob(blob_url, suffix, increased_priority): ] -def get_checkpoint_to_finetune(finetune_path, suffix, priority): +def get_checkpoint_to_finetune( + finetune_path: str, suffix: str, priority: float +) -> CheckpointPath: if PathManager.exists(finetune_path): validated_path = finetune_path else: # check for sharded version @@ -233,7 +245,7 @@ def get_checkpoint_to_finetune(finetune_path, suffix, priority): ) -def reset_for_finetuning(cfg, checkpoint): +def reset_for_finetuning(cfg: CheckpointConfig, checkpoint: CheckpointPath) -> None: cfg.reset_optimizer = True cfg.reset_lr_scheduler = True cfg.reset_meters = True @@ -287,7 +299,7 @@ def prepare_local_checkpoint_path(cfg: CheckpointConfig, trainer) -> str: ) # get the most recent valid checkpoint - checkpoints.sort(key=lambda checkpoint: checkpoint.priority) + checkpoints.sort(key=lambda ckpt: ckpt.priority) if len(checkpoints) == 0: return "" logger.info( @@ -300,7 +312,7 @@ def prepare_local_checkpoint_path(cfg: CheckpointConfig, trainer) -> str: if checkpoint.storage_type == "local": return checkpoint.path - # copy cloud checkpoints to a local tmp_dir + # copy cloud checkpoints to a local temporary file local_tmp_dir = os.path.join( cfg.local_checkpoints_dir, f"checkpoint_tmp{suffix}.pt" ) @@ -314,172 +326,6 @@ def prepare_local_checkpoint_path(cfg: CheckpointConfig, trainer) -> str: return local_tmp_dir -def get_and_prep_checkpoint_path(cfg: CheckpointConfig, trainer) -> str: - # Logic flow: - # - if no restore_file: < try to grab latest from checkpoint / cloud > - # - if cloud upload is defined, pull from cloud upload - # - if no cloud upload, pull from checkpoint - - suffix = trainer.checkpoint_suffix - - if cfg.restore_file is None: - checkpoint_path_to_load = os.path.join( - cfg.local_checkpoints_dir, "checkpoint_last{}.pt".format(suffix) - ) - # move this out - # first_launch = not PathManager.exists(checkpoint_path_to_load) - # if cfg.finetune_from_model is not None and first_launch: - # # if there is no last checkpoint to restore, start the finetune from pretrained model - # # else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc. 
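# Runnable illustration of the priority scheme that replaces the branching being
# removed in this hunk: each candidate checkpoint gets priority = step count
# plus a small tie-breaker (finetune 0.0 < restore_file 0.1 < cloud 0.2 <
# local 0.3), and the highest-priority entry wins. Candidate and the literal
# paths are simplified stand-ins for the CheckpointPath entries built by
# prepare_local_checkpoint_path.
from dataclasses import dataclass


@dataclass
class Candidate:
    path: str
    storage_type: str
    priority: float


candidates = [
    Candidate("pretrained/model.pt", "nfs", 0 + 0.0),  # --finetune-from-model
    Candidate("blob/checkpoint_500/checkpoint.pt", "azure_blob", 500 + 0.2),
    Candidate("save_dir/checkpoint_500/checkpoint.pt", "local", 500 + 0.3),
]
candidates.sort(key=lambda c: c.priority)
chosen = candidates[-1]
# At the same step count the local copy wins, so no redundant download happens.
print(chosen.path, chosen.storage_type)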
- # cfg.reset_optimizer = True - # cfg.reset_lr_scheduler = True - # cfg.reset_meters = True - # cfg.reset_dataloader = True - # checkpoint_path_to_load = None - # logger.warning( - # "Resetting optimizer, lr scheduler, meters, and dataloader for fine-tuning!" - # ) - # if PathManager.exists(cfg.finetune_from_model): - # return cfg.finetune_from_model - # elif suffix is not None: # check for sharded version - # sharded_path = cfg.finetune_from_model.replace(".pt", suffix + ".pt") - # if PathManager.exists(sharded_path): - # return sharded_path - # else: - # raise ValueError( - # f"--finetune-from-model {cfg.finetune_from_model} does not exist either as is or sharded" - # ) - else: # restore_file specified - if suffix is not None: - checkpoint_path_to_load = cfg.restore_file.replace(".pt", suffix + ".pt") - else: - checkpoint_path_to_load = cfg.restore_file - - if cfg.restore_file is not None and cfg.finetune_from_model: - raise ValueError( - "--finetune-from-model and --restore-file (non-default value) " - "can not be specified together: " + str(cfg) - ) - - # Azure logic - try: - from metaseq_internal import azure_utils - - has_metaseq_internal = True - except ImportError: - has_metaseq_internal = False - logger.warning( - "Proceeding without metaseq-internal installed! Please check if you need this!" - ) - - # TODO(susanz): fix all of this spagetti, split out logic by env - # Note that we compare by value since ComputeEnvs may be imported from metaseq_internal - - if cfg.cloud_upload_path: - - # RSC NFS LOGIC - if cfg.cloud_upload_path.startswith("nfs:"): - checkpoint_path_to_load = os.path.join( - cfg.local_checkpoints_dir, "checkpoint_last{}.pt".format(suffix) - ) - nfs_path = cfg.cloud_upload_path[4:] - filename = None - specific_restore_file_provided = cfg.restore_file is not None - slurm_was_restarted = int(os.environ.get("SLURM_RESTART_COUNT", 0)) > 0 - # --cloud_upload_path, slurm restart or no (finetune and restart file) - restart_from_latest = slurm_was_restarted or ( - cfg.finetune_from_model is None and not specific_restore_file_provided - ) - if restart_from_latest: - checkpoints = [] - expected_file_count = distributed_utils.get_global_world_size() - for candidate in os.listdir(nfs_path): - if candidate == "checkpoint_last": - raise RuntimeError( - "trying to restart a job that already wrote checkpoint_last" - ) - m = re.match(r"checkpoint_(\d+)", candidate) - if m: - checkpoints.append((int(m[1]), candidate)) - for _, candidate in sorted(checkpoints, reverse=True): - present_files = len( - [ - f - for f in os.listdir(os.path.join(nfs_path, candidate)) - if not f.startswith("_") - ] - ) - if present_files == expected_file_count: - filename = os.path.join( - nfs_path, candidate, f"checkpoint{suffix}.pt" - ) - break - logger.info( - f"skipping checkpoint {candidate} because it only has" - f" {present_files} files (expected {expected_file_count})" - ) - else: - # --cloud_upload_path, no slurm restart, or a finetune or restart file) -> use restore_file - filename = ( - cfg.restore_file.replace(".pt", suffix + ".pt") - if cfg.restore_file is not None - else None - ) - if filename is not None: - # rsc nfs copying - logger.info( - f"Copying checkpoint from nfs {filename} -> {checkpoint_path_to_load}" - ) - shutil.copyfile(filename, checkpoint_path_to_load) - else: - logger.info(f"No NFS checkpoints found") - - # AZURE LOGIC - elif cfg.cluster_env == ComputeEnvs.AZURE.value and has_metaseq_internal: - if ( - # --restore-file was not passed, always download latest checkpoint - 
(cfg.restore_file is None and cfg.finetune_from_model is None) - # --restore-file was passed, but we requeued, so download latest checkpoint - or int(os.environ.get("SLURM_RESTART_COUNT", 0)) > 0 - ): - # download checkpoint into local save_dir - checkpoint_path_to_load = os.path.join( - cfg.local_checkpoints_dir, "checkpoint_last{}.pt".format(suffix) - ) - azure_utils.download_recent_ckpt( - cfg.cloud_upload_path, checkpoint_path_to_load, suffix + ".pt" - ) - elif ( - # --restore-file was passed and is a blob URL, download that checkpoint - cfg.restore_file is not None - and "windows.net" in cfg.restore_file - ): - blob_url = cfg.restore_file.replace(".pt", suffix + ".pt") - # download checkpoint into local save_dir - checkpoint_path_to_load = os.path.join( - cfg.local_checkpoints_dir, "checkpoint_last{}.pt".format(suffix) - ) - azure_utils.download_specific_ckpt(blob_url, checkpoint_path_to_load) - else: - logger.info( - f"Using checkpoint {checkpoint_path_to_load} even while on Azure" - ) - - # RSC logic: --restore-file was passed, and we requeued - # no cloud upload path specified - elif ( - cfg.restore_file is not None - and int(os.environ.get("SLURM_RESTART_COUNT", 0)) > 0 - ): - # point checkpoint_path to the current checkpoint directory for loading, if it exists. - save_dir_last = os.path.join( - cfg.local_checkpoints_dir, "checkpoint_last{}.pt".format(suffix) - ) - if PathManager.isfile(save_dir_last): - checkpoint_path_to_load = save_dir_last - return checkpoint_path_to_load - - def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): """ Load a checkpoint and restore the training iterator. From d46404ab02b9670995447d66ae464e530ea3f7a0 Mon Sep 17 00:00:00 2001 From: Peter Albert Date: Tue, 29 Nov 2022 06:16:21 -0800 Subject: [PATCH 13/24] changed naming of prio --- metaseq/checkpoint_utils.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py index 2df89c6c9..91effc799 100644 --- a/metaseq/checkpoint_utils.py +++ b/metaseq/checkpoint_utils.py @@ -174,7 +174,7 @@ def get_checkpoint_steps(path: str) -> int: def get_all_checkpoints_from_directory( - directory: str, suffix: str, increased_priority: float, storage_type: str + directory: str, suffix: str, add_priority: float, storage_type: str ) -> List[CheckpointPath]: checkpoints = [] for candidate in os.listdir(directory): @@ -202,14 +202,14 @@ def get_all_checkpoints_from_directory( CheckpointPath( path=os.path.join(directory, candidate, f"checkpoint{suffix}.pt"), storage_type=storage_type, - priority=steps + increased_priority, + priority=steps + add_priority, ) ) return checkpoints def get_recent_checkpoint_from_azure_blob( - blob_url, suffix, increased_priority + blob_url, suffix, add_priority ) -> List[CheckpointPath]: file_to_load = azure_utils.get_most_recent_ckpt(blob_url, suffix) if file_to_load is None: @@ -219,7 +219,7 @@ def get_recent_checkpoint_from_azure_blob( CheckpointPath( path=blob_url + "/" + file_to_load, storage_type="azure_blob", - priority=steps + increased_priority, + priority=steps + add_priority, ) ] @@ -281,20 +281,20 @@ def prepare_local_checkpoint_path(cfg: CheckpointConfig, trainer) -> str: get_all_checkpoints_from_directory( cfg.cloud_upload_path[4:], suffix, - increased_priority=0.2, + add_priority=0.2, storage_type="nfs", ) ) elif cloud_storage_type == "azure_blob": checkpoints.extend( get_recent_checkpoint_from_azure_blob( - cfg.cloud_upload_path, suffix, increased_priority=0.2 + 
cfg.cloud_upload_path, suffix, add_priority=0.2 ) ) checkpoints.extend( get_all_checkpoints_from_directory( - cfg.local_checkpoints_dir, increased_priority=0.3, storage_type="local" + cfg.local_checkpoints_dir, add_priority=0.3, storage_type="local" ) ) From 725bfb8e9ad6a513915c5c7716bc6852741757fb Mon Sep 17 00:00:00 2001 From: Peter Albert Date: Tue, 29 Nov 2022 06:25:49 -0800 Subject: [PATCH 14/24] change naming for epoch checkpoints to include num steps --- metaseq/checkpoint_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py index 91effc799..a960dc21c 100644 --- a/metaseq/checkpoint_utils.py +++ b/metaseq/checkpoint_utils.py @@ -84,7 +84,7 @@ def is_better(a, b): and updates % cfg.save_interval_updates == 0 ) - checkpoint_conds[f"checkpoint{epoch}{suffix}.pt"] = save_for_epoch + checkpoint_conds[f"checkpoint_{updates}_epoch_{epoch}{suffix}.pt"] = save_for_epoch checkpoint_conds[f"checkpoint_{updates}{suffix}.pt"] = save_for_updates checkpoint_conds[f"checkpoint_last{suffix}.pt"] = ( training_finished and cfg.save_last_checkpoint From 5e6218eaf208071edd3098c4eed1ecc141bc849c Mon Sep 17 00:00:00 2001 From: Peter Albert Date: Tue, 29 Nov 2022 06:46:01 -0800 Subject: [PATCH 15/24] add local caching by including num steps in cache file name --- metaseq/checkpoint_utils.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py index a960dc21c..7e8678031 100644 --- a/metaseq/checkpoint_utils.py +++ b/metaseq/checkpoint_utils.py @@ -312,18 +312,19 @@ def prepare_local_checkpoint_path(cfg: CheckpointConfig, trainer) -> str: if checkpoint.storage_type == "local": return checkpoint.path - # copy cloud checkpoints to a local temporary file - local_tmp_dir = os.path.join( - cfg.local_checkpoints_dir, f"checkpoint_tmp{suffix}.pt" + # copy cloud checkpoints to a local cache file + local_cache_dir = os.path.join( + cfg.local_checkpoints_dir, + f"cached_checkpoint_{checkpoint.priority}{suffix}.pt", ) - logger.info(f"Copying checkpoint from {checkpoint.path} -> {local_tmp_dir}") + logger.info(f"Copying checkpoint from {checkpoint.path} -> {local_cache_dir}") if checkpoint.storage_type == "nfs": - shutil.copyfile(checkpoint.path, local_tmp_dir) + shutil.copyfile(checkpoint.path, local_cache_dir) elif checkpoint.storage_type == "azure_blob": - azure_utils.download_specific_ckpt(checkpoint.path, local_tmp_dir) + azure_utils.download_specific_ckpt(checkpoint.path, local_cache_dir) - return local_tmp_dir + return local_cache_dir def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): From 24b4ab880065cc03c5b7be620873bdc71a862b4b Mon Sep 17 00:00:00 2001 From: Peter Albert Date: Fri, 6 Jan 2023 07:34:13 -0800 Subject: [PATCH 16/24] move checkpoint path out --- metaseq/checkpoint_utils.py | 12 +----------- metaseq/dataclass/configs.py | 30 +++++++++++++++++++++--------- metaseq/dataclass/utils.py | 9 ++++++++- 3 files changed, 30 insertions(+), 21 deletions(-) diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py index 7e8678031..04461178f 100644 --- a/metaseq/checkpoint_utils.py +++ b/metaseq/checkpoint_utils.py @@ -13,14 +13,12 @@ import socket import shutil from typing import Any, Dict, List, Optional, Tuple -from dataclasses import dataclass, field - import torch from omegaconf import OmegaConf from metaseq.dataclass.configs import CheckpointConfig -from metaseq.dataclass.utils import overwrite_args_by_name 
+from metaseq.dataclass.utils import overwrite_args_by_name, CheckpointPath from metaseq.distributed import utils as distributed_utils from metaseq.file_io import PathManager, torch_load_cpu from metaseq.launcher.opt_job_constants import ComputeEnvs @@ -790,11 +788,3 @@ def _get_pad_info(state_dict: Dict) -> Dict[str, int]: assert full_key not in res, f"collision: {full_key} already in {res}" res[full_key] = v["padding"] return res - - -@dataclass -class CheckpointPath: - path: str - storage_type: str - priority: float = 0 - run_before_loading: list = field(default_factory=list) diff --git a/metaseq/dataclass/configs.py b/metaseq/dataclass/configs.py index 0b6dc9bae..8d8406f4f 100644 --- a/metaseq/dataclass/configs.py +++ b/metaseq/dataclass/configs.py @@ -440,6 +440,26 @@ class CheckpointConfig(MetaseqDataclass): # save_dir: str = field( # default="None", metadata={"help": "path to save checkpoints"} # ) + local_checkpoints_dir: Optional[str] = field( + default=None, + metadata={ + "help": ( + "local directory for checkpoints" + ), + }, + ) + + + cloud_upload_path: Optional[str] = field( + default=None, + metadata={ + "help": ( + "Upload checkpoints asynchronously in a separate " + "thread to blob store. NOTE: This feature is currently being tested." + ), + }, + ) + restore_file: Optional[str] = field( default=None, # Used to be: "checkpoint_last.pt" metadata={"help": "filename from which to load checkpoint"}, @@ -544,15 +564,7 @@ class CheckpointConfig(MetaseqDataclass): ), }, ) - # cloud_upload_path: Optional[str] = field( - # default=None, - # metadata={ - # "help": ( - # "Upload checkpoints asynchronously in a separate " - # "thread to blob store. NOTE: This feature is currently being tested." - # ), - # }, - # ) + # TODO(susanz): After https://github.com/fairinternal/fairseq-big-internal/issues/22 is tackled, modify this # to use ComputeEnvs constant cluster_env: str = field( diff --git a/metaseq/dataclass/utils.py b/metaseq/dataclass/utils.py index 05f31df95..4b4d21138 100644 --- a/metaseq/dataclass/utils.py +++ b/metaseq/dataclass/utils.py @@ -9,7 +9,7 @@ import os import re from argparse import ArgumentError, ArgumentParser, Namespace -from dataclasses import _MISSING_TYPE, MISSING +from dataclasses import dataclass, field, _MISSING_TYPE, MISSING from enum import Enum from typing import Any, Dict, List, Optional, Tuple, Type @@ -473,3 +473,10 @@ def merge_with_parent(dc: MetaseqDataclass, cfg: MetaseqDataclass): merged_cfg.__dict__["_parent"] = cfg.__dict__["_parent"] OmegaConf.set_struct(merged_cfg, True) return merged_cfg + +@dataclass +class CheckpointPath: + path: str + storage_type: str + priority: float = 0 + run_before_loading: list = field(default_factory=list) From 243c327983e54fd19a75e4e0abeb6233e516198a Mon Sep 17 00:00:00 2001 From: Peter Albert Date: Fri, 20 Jan 2023 03:54:45 -0800 Subject: [PATCH 17/24] add types and clean --- metaseq/checkpoint_utils.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py index 1209a98cd..a22804e4d 100644 --- a/metaseq/checkpoint_utils.py +++ b/metaseq/checkpoint_utils.py @@ -160,7 +160,7 @@ def get_all_checkpoints_from_directory( def get_recent_checkpoint_from_azure_blob( - blob_url, suffix, add_priority + blob_url: str, suffix: str, add_priority: float ) -> List[CheckpointPath]: file_to_load = azure_utils.get_most_recent_ckpt(blob_url, suffix) if file_to_load is None: @@ -196,13 +196,6 @@ def get_checkpoint_to_finetune( ) - has_metaseq_internal = True - 
except ImportError: - has_metaseq_internal = False - logger.warning( - "Resetting optimizer, lr scheduler, meters, and dataloader for fine-tuning!" - ) - def prepare_local_checkpoint_path(cfg: CheckpointConfig, trainer) -> str: suffix = trainer.checkpoint_suffix From c3396b83f08fc15f737fdb6d31fe9e4339dd1fbe Mon Sep 17 00:00:00 2001 From: Peter Albert Date: Wed, 1 Feb 2023 11:54:22 -0800 Subject: [PATCH 18/24] fixes --- metaseq/checkpoint_utils.py | 25 ++++++++++++++++--------- metaseq/dataclass/configs.py | 6 +++--- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py index a22804e4d..cebf7d350 100644 --- a/metaseq/checkpoint_utils.py +++ b/metaseq/checkpoint_utils.py @@ -119,7 +119,7 @@ def get_storage_type(path: str) -> str: def get_checkpoint_steps(path: str) -> int: match = re.search(r"checkpoint_(\d+)", path) - if match[1] is None: + if match is None: return 0 return int(match[1]) @@ -127,6 +127,8 @@ def get_checkpoint_steps(path: str) -> int: def get_all_checkpoints_from_directory( directory: str, suffix: str, add_priority: float, storage_type: str ) -> List[CheckpointPath]: + # from metaseq.pdb import set_trace_rank0 + # set_trace_rank0() checkpoints = [] for candidate in os.listdir(directory): steps = get_checkpoint_steps(candidate) @@ -141,11 +143,14 @@ def get_all_checkpoints_from_directory( if not f.startswith("_") ] ) + logger.info(f"{present_files} parts found for this checkpoint") if present_files != expected_file_count: logger.info( - f"skipping checkpoint {candidate} because it only has" + f"skipping checkpoint {candidate} in {directory} because it only has" f" {present_files} files (expected {expected_file_count})" ) + present_files = [ f for f in os.listdir(os.path.join(directory, candidate)) if not f.startswith("_")] + logger.info(str(present_files)) continue # TODO validate this @@ -236,7 +241,7 @@ def prepare_local_checkpoint_path(cfg: CheckpointConfig, trainer) -> str: checkpoints.extend( get_all_checkpoints_from_directory( - cfg.local_checkpoints_dir, add_priority=0.3, storage_type="local" + cfg.save_dir, suffix, add_priority=0.3, storage_type="local" ) ) @@ -256,17 +261,19 @@ def prepare_local_checkpoint_path(cfg: CheckpointConfig, trainer) -> str: # copy cloud checkpoints to a local cache file local_cache_dir = os.path.join( - cfg.local_checkpoints_dir, - f"cached_checkpoint_{checkpoint.priority}{suffix}.pt", + cfg.save_dir, + f"cached_checkpoint_{int(checkpoint.priority)}" ) + os.makedirs(local_cache_dir, exist_ok=True) + local_cache_file = os.path.join(local_cache_dir, f"checkpoint{suffix}.pt") - logger.info(f"Copying checkpoint from {checkpoint.path} -> {local_cache_dir}") + logger.info(f"Copying checkpoint from {checkpoint.path} -> {local_cache_file}") if checkpoint.storage_type == "nfs": - shutil.copyfile(checkpoint.path, local_cache_dir) + shutil.copyfile(checkpoint.path, local_cache_file) elif checkpoint.storage_type == "azure_blob": - azure_utils.download_specific_ckpt(checkpoint.path, local_cache_dir) + azure_utils.download_specific_ckpt(checkpoint.path, local_cache_file) - return local_cache_dir + return local_cache_file def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): diff --git a/metaseq/dataclass/configs.py b/metaseq/dataclass/configs.py index d5da9bb39..e0db99cc9 100644 --- a/metaseq/dataclass/configs.py +++ b/metaseq/dataclass/configs.py @@ -449,11 +449,11 @@ class CheckpointConfig(MetaseqDataclass): save_dir: str = field( default="checkpoints", 
metadata={"help": "path to save checkpoints"} ) - restore_file: str = field( - default="checkpoint_last.pt", + restore_file: Optional[str] = field( + default=None, metadata={ "help": "filename from which to load checkpoint " - "(default: /checkpoint_last.pt" + "in the form nfs:path/to/dir/checkpoint.pt" }, ) finetune_from_model: Optional[str] = field( From 58dcfb2d451c8872497f48dad19a37be5430be0d Mon Sep 17 00:00:00 2001 From: Peter Albert Date: Mon, 6 Feb 2023 04:57:19 -0800 Subject: [PATCH 19/24] two paths for local vs nfs checkpoints --- metaseq/checkpoint_utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py index cebf7d350..4b7a51b6f 100644 --- a/metaseq/checkpoint_utils.py +++ b/metaseq/checkpoint_utils.py @@ -123,18 +123,19 @@ def get_checkpoint_steps(path: str) -> int: return 0 return int(match[1]) +def get_all_checkpoints_from_directory_with_subdirs(directory: str, suffix: str, add_priority: float, storage_type: str): -def get_all_checkpoints_from_directory( + + +def get_all_checkpoints_from_directory_with_subdirs( directory: str, suffix: str, add_priority: float, storage_type: str ) -> List[CheckpointPath]: - # from metaseq.pdb import set_trace_rank0 - # set_trace_rank0() checkpoints = [] for candidate in os.listdir(directory): steps = get_checkpoint_steps(candidate) if steps == 0: continue - # TODO needs to be adapated for local dir + # TODO needs to be adapated for local dir, which has only num_gpu checkpoints expected_file_count = distributed_utils.get_global_world_size() present_files = len( [ From 65c03bc160290df0a424b28c6d95159c79be0394 Mon Sep 17 00:00:00 2001 From: Peter Albert Date: Fri, 10 Feb 2023 08:36:43 -0800 Subject: [PATCH 20/24] add some debugging and load local checkpoints --- metaseq/checkpoint_utils.py | 41 +++++++++++++++++++++++++++++-------- 1 file changed, 33 insertions(+), 8 deletions(-) diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py index 4b7a51b6f..660b12fef 100644 --- a/metaseq/checkpoint_utils.py +++ b/metaseq/checkpoint_utils.py @@ -123,11 +123,8 @@ def get_checkpoint_steps(path: str) -> int: return 0 return int(match[1]) -def get_all_checkpoints_from_directory_with_subdirs(directory: str, suffix: str, add_priority: float, storage_type: str): - - -def get_all_checkpoints_from_directory_with_subdirs( +def get_all_checkpoints_from_directory( directory: str, suffix: str, add_priority: float, storage_type: str ) -> List[CheckpointPath]: checkpoints = [] @@ -135,7 +132,36 @@ def get_all_checkpoints_from_directory_with_subdirs( steps = get_checkpoint_steps(candidate) if steps == 0: continue - # TODO needs to be adapated for local dir, which has only num_gpu checkpoints + + # in scratch saved files are in this form: checkpoint_180-model_part-0-shard0.pt + if candidate.endswith(".pt"): + logger.info("is .pt file") + if not suffix in candidate: + continue + logger.info(suffix) + logger.info(candidate) + checkpoints.append( + CheckpointPath( + path=os.path.join(directory, candidate), + storage_type=storage_type, + priority=steps + add_priority, + ) + ) + # delete this all + prefix = candidate.split(suffix)[0] + counter = 0 + for other_file in os.listdir(directory): + if prefix in other_file: + logger.info(other_file) + counter += 1 + + logger.info(f"Files found for it: {counter}") + expected_file_count = distributed_utils.get_global_world_size() + logger.info(f"World size: {expected_file_count}") + + continue + + # nfs and cached files look like 
this: checkpoint_180/checkpoint-model_part-0-shard0.pt expected_file_count = distributed_utils.get_global_world_size() present_files = len( [ @@ -144,7 +170,7 @@ def get_all_checkpoints_from_directory_with_subdirs( if not f.startswith("_") ] ) - logger.info(f"{present_files} parts found for this checkpoint") + logger.info(f"{present_files} parts found for {candidate} checkpoint") if present_files != expected_file_count: logger.info( f"skipping checkpoint {candidate} in {directory} because it only has" @@ -154,7 +180,6 @@ def get_all_checkpoints_from_directory_with_subdirs( logger.info(str(present_files)) continue - # TODO validate this checkpoints.append( CheckpointPath( path=os.path.join(directory, candidate, f"checkpoint{suffix}.pt"), @@ -288,7 +313,7 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): checkpoint_path_to_load = prepare_local_checkpoint_path(cfg, trainer) logger.info(f"attempting to load checkpoint from: {checkpoint_path_to_load}") - + logger.info("V3") # make sure everyone is done downloading their checkpoints before we load distributed_utils.global_barrier() From e2efa2df31e8f487147aa5d22b3072aa6bb975de Mon Sep 17 00:00:00 2001 From: Peter Albert Date: Fri, 10 Feb 2023 09:57:34 -0800 Subject: [PATCH 21/24] more cleanup, add reset_for_finetuning back in --- metaseq/checkpoint_utils.py | 27 +++++++++------------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py index 660b12fef..2a43d2dee 100644 --- a/metaseq/checkpoint_utils.py +++ b/metaseq/checkpoint_utils.py @@ -138,8 +138,6 @@ def get_all_checkpoints_from_directory( logger.info("is .pt file") if not suffix in candidate: continue - logger.info(suffix) - logger.info(candidate) checkpoints.append( CheckpointPath( path=os.path.join(directory, candidate), @@ -147,18 +145,6 @@ def get_all_checkpoints_from_directory( priority=steps + add_priority, ) ) - # delete this all - prefix = candidate.split(suffix)[0] - counter = 0 - for other_file in os.listdir(directory): - if prefix in other_file: - logger.info(other_file) - counter += 1 - - logger.info(f"Files found for it: {counter}") - expected_file_count = distributed_utils.get_global_world_size() - logger.info(f"World size: {expected_file_count}") - continue # nfs and cached files look like this: checkpoint_180/checkpoint-model_part-0-shard0.pt @@ -170,14 +156,11 @@ def get_all_checkpoints_from_directory( if not f.startswith("_") ] ) - logger.info(f"{present_files} parts found for {candidate} checkpoint") if present_files != expected_file_count: logger.info( f"skipping checkpoint {candidate} in {directory} because it only has" f" {present_files} files (expected {expected_file_count})" ) - present_files = [ f for f in os.listdir(os.path.join(directory, candidate)) if not f.startswith("_")] - logger.info(str(present_files)) continue checkpoints.append( @@ -227,6 +210,15 @@ def get_checkpoint_to_finetune( ) +def reset_for_finetuning(cfg, checkpoint): + cfg.reset_optimizer = True + cfg.reset_lr_scheduler = True + cfg.reset_meters = True + cfg.reset_dataloader = True + logger.info( + "Resetting optimizer, lr scheduler, meters, and dataloader for fine-tuning!" 
+ ) + def prepare_local_checkpoint_path(cfg: CheckpointConfig, trainer) -> str: suffix = trainer.checkpoint_suffix @@ -313,7 +305,6 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): checkpoint_path_to_load = prepare_local_checkpoint_path(cfg, trainer) logger.info(f"attempting to load checkpoint from: {checkpoint_path_to_load}") - logger.info("V3") # make sure everyone is done downloading their checkpoints before we load distributed_utils.global_barrier() From cbff07cb3d2c2fe65bd6433220173048f6bd4b26 Mon Sep 17 00:00:00 2001 From: Peter Albert Date: Tue, 14 Feb 2023 06:04:54 -0800 Subject: [PATCH 22/24] run black . --- metaseq/checkpoint_utils.py | 3 +-- metaseq/dataclass/utils.py | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py index 2a43d2dee..3e55d7d68 100644 --- a/metaseq/checkpoint_utils.py +++ b/metaseq/checkpoint_utils.py @@ -279,8 +279,7 @@ def prepare_local_checkpoint_path(cfg: CheckpointConfig, trainer) -> str: # copy cloud checkpoints to a local cache file local_cache_dir = os.path.join( - cfg.save_dir, - f"cached_checkpoint_{int(checkpoint.priority)}" + cfg.save_dir, f"cached_checkpoint_{int(checkpoint.priority)}" ) os.makedirs(local_cache_dir, exist_ok=True) local_cache_file = os.path.join(local_cache_dir, f"checkpoint{suffix}.pt") diff --git a/metaseq/dataclass/utils.py b/metaseq/dataclass/utils.py index 4b4d21138..562dd8b50 100644 --- a/metaseq/dataclass/utils.py +++ b/metaseq/dataclass/utils.py @@ -474,6 +474,7 @@ def merge_with_parent(dc: MetaseqDataclass, cfg: MetaseqDataclass): OmegaConf.set_struct(merged_cfg, True) return merged_cfg + @dataclass class CheckpointPath: path: str From 0bde2fe497c2638a8d1cc95d5ecf184779fc7338 Mon Sep 17 00:00:00 2001 From: Peter Albert Date: Tue, 14 Feb 2023 06:23:41 -0800 Subject: [PATCH 23/24] flake8 --- metaseq/checkpoint_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py index 3d2c17c91..fa5e765cc 100644 --- a/metaseq/checkpoint_utils.py +++ b/metaseq/checkpoint_utils.py @@ -19,7 +19,6 @@ from metaseq.dataclass.utils import overwrite_args_by_name, CheckpointPath from metaseq.distributed import utils as distributed_utils from metaseq.file_io import PathManager, torch_load_cpu -from metaseq.launcher.opt_job_constants import ComputeEnvs logger = logging.getLogger(__name__) @@ -205,7 +204,7 @@ def get_checkpoint_to_finetune( validated_path = sharded_path if validated_path is None: raise ValueError( - f"--finetune-from-model {cfg.finetune_from_model} does not exist either as is or sharded" + f"--finetune-from-model {finetune_path} does not exist either as is or sharded" ) return CheckpointPath( path=validated_path, From 7991503527824650c94f919ea3c0d5cd4735df8f Mon Sep 17 00:00:00 2001 From: Peter Albert Date: Tue, 14 Feb 2023 08:36:19 -0800 Subject: [PATCH 24/24] flake 8 --- metaseq/checkpoint_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py index fa5e765cc..5b53eec44 100644 --- a/metaseq/checkpoint_utils.py +++ b/metaseq/checkpoint_utils.py @@ -140,7 +140,7 @@ def get_all_checkpoints_from_directory( # in scratch saved files are in this form: checkpoint_180-model_part-0-shard0.pt if candidate.endswith(".pt"): logger.info("is .pt file") - if not suffix in candidate: + if suffix not in candidate: continue checkpoints.append( CheckpointPath(