diff --git a/src/zeroband/checkpoint.py b/src/zeroband/checkpoint.py
index c06bebdb..95a6aae6 100644
--- a/src/zeroband/checkpoint.py
+++ b/src/zeroband/checkpoint.py
@@ -124,7 +124,7 @@ def save(self, ckpt_path: str) -> None:
         time_start = time.perf_counter()
 
-        ckpt_path = self._get_ckpt_folder_name(ckpt_path, self.training_progress.step)
+        ckpt_path = os.path.join(ckpt_path, f"step_{self.training_progress.step}")
 
         catch_warning = self._logger.getEffectiveLevel() <= logging.INFO
         # pytorch has an annoying warning when saving the optimizer state https://github.com/pytorch/pytorch/issues/136907
         # we can ignore it if we are not logging in DEBUG mode
@@ -166,10 +166,3 @@ def load(self, resume_ckpt_path: str) -> None:
         self.dataloader.load_state_dict(rank_state_dict["data_loader"])
 
         self._logger.info(f"Loaded checkpoint from {resume_ckpt_path} in {time.perf_counter() - time_start} seconds")
-
-    @staticmethod
-    def _get_ckpt_folder_name(ckpt_path: str, step: int) -> str:
-        """
-        The ckpt folder can contains multiple ckpt with different step name. This function return the sub directory name for the ckpt with the given step.
-        """
-        return os.path.join(ckpt_path, f"step_{step}")
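
For context: the change inlines the single-use `_get_ckpt_folder_name` static helper into `save()`, so each checkpoint is still written under a per-step `step_{N}` subdirectory of the shared checkpoint folder. Below is a minimal standalone sketch of that naming scheme; the `step_folder` function name is hypothetical and exists only to illustrate the expression the diff inlines, it is not part of the module after this change.

```python
import os

def step_folder(ckpt_path: str, step: int) -> str:
    # Each checkpoint lands in a per-step subdirectory of the shared
    # checkpoint folder, e.g. "step_1000" for step 1000.
    return os.path.join(ckpt_path, f"step_{step}")

print(step_folder("outputs/ckpt", 1000))  # outputs/ckpt/step_1000 (POSIX separators)
```

Behavior is unchanged by the diff; the helper was only called from `save()`, so inlining it removes indirection without altering the on-disk layout.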