Skip to content

Commit

Permalink
refactor easier step
Browse files Browse the repository at this point in the history
  • Loading branch information
samsja committed Sep 28, 2024
1 parent 410bac4 commit f49635c
Showing 1 changed file with 1 addition and 8 deletions.
9 changes: 1 addition & 8 deletions src/zeroband/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def save(self, ckpt_path: str) -> None:

time_start = time.perf_counter()

ckpt_path = self._get_ckpt_folder_name(ckpt_path, self.training_progress.step)
ckpt_path = os.path.join(ckpt_path, f"step_{self.training_progress.step}")
catch_warning = self._logger.getEffectiveLevel() <= logging.INFO
# pytorch has an annoying warning when saving the optimizer state https://github.com/pytorch/pytorch/issues/136907
# we can ignore it if we are not logging in DEBUG mode
Expand Down Expand Up @@ -166,10 +166,3 @@ def load(self, resume_ckpt_path: str) -> None:
self.dataloader.load_state_dict(rank_state_dict["data_loader"])

self._logger.info(f"Loaded checkpoint from {resume_ckpt_path} in {time.perf_counter() - time_start} seconds")

@staticmethod
def _get_ckpt_folder_name(ckpt_path: str, step: int) -> str:
    """
    Return the sub-directory that holds the checkpoint for a given step.

    A checkpoint folder can contain several checkpoints, one per step; each
    one lives in a sub-directory named ``step_<step>`` under ``ckpt_path``.
    """
    step_dir_name = f"step_{step}"
    return os.path.join(ckpt_path, step_dir_name)

0 comments on commit f49635c

Please sign in to comment.