From 744e69f8becb2f10c8b3debeb771dd70b13834ba Mon Sep 17 00:00:00 2001
From: Sami Jaghouar
Date: Sun, 29 Sep 2024 22:25:41 +0000
Subject: [PATCH] save into diloco specific folder

---
 src/zeroband/checkpoint.py       | 29 ++++++++++++++++++++++++-----
 src/zeroband/utils/world_info.py |  4 ++++
 2 files changed, 28 insertions(+), 5 deletions(-)

diff --git a/src/zeroband/checkpoint.py b/src/zeroband/checkpoint.py
index 59158cbd..56eb8bdb 100644
--- a/src/zeroband/checkpoint.py
+++ b/src/zeroband/checkpoint.py
@@ -124,7 +124,9 @@ def __init__(
             self.diloco_offloaded_param_list = diloco_offloaded_param_list
             # even if diloco_offloaded targets the cpu parameter list, we still use the gpu model to load and save state.
             # the main reason is that we don't actually have a cpu model, just a list of cpu parameters.
-            self.states["diloco_offloaded_optimizer"] = self.diloco_offloaded_optimizer
+            self.diloco_states = {"optimizer": self.diloco_offloaded_optimizer}
+        else:
+            self.diloco_states = {}
 
         self.process_group = process_group
         self._logger = get_logger()
@@ -145,7 +147,7 @@ def save(self, ckpt_path: str, remote_ckpt_path: str | None) -> None:
 
         # pytorch has an annoying warning when saving the optimizer state https://github.com/pytorch/pytorch/issues/136907
         # we can ignore it if we are not logging in DEBUG mode
-        rank = get_world_info().local_rank
+        world_info = get_world_info()
 
         with warnings.catch_warnings():
             if catch_warning:
@@ -153,8 +155,14 @@ def save(self, ckpt_path: str, remote_ckpt_path: str | None) -> None:
 
             dcp.save(self.states, process_group=self.process_group, checkpoint_id=ckpt_path)
 
+        if self.diloco_states:
+            diloco_ckpt_path = os.path.join(ckpt_path, f"diloco_{world_info.diloco_rank}")
+            dcp.save(self.diloco_states, process_group=self.process_group, checkpoint_id=diloco_ckpt_path)
 
         ## the next part is a fix so that each rank saves its own dataloader state. It is not efficient because it reads the state twice from disk
-        with open(os.path.join(ckpt_path, f"__{rank}_0.pt"), "wb") as f:
+
+        # the dataloader is different for each diloco worker. If diloco is enabled we save it under diloco_ckpt_path
+        dataloader_path = diloco_ckpt_path if self.diloco_states else ckpt_path
+        with open(os.path.join(dataloader_path, f"__{world_info.local_rank}_0.pt"), "wb") as f:
             torch.save({"data_loader": self.dataloader.state_dict()}, f)
 
         self._logger.info(f"Saved checkpoint to {ckpt_path} in {time.perf_counter() - time_start} seconds")
@@ -204,10 +212,21 @@ def load(self, resume_ckpt_path: str) -> None:
 
         self.states = dcp.load(self.states, process_group=self.process_group, checkpoint_id=resume_ckpt_path)
 
-        rank = get_world_info().local_rank  # todo check after on/off ramping pr which rank is good here
+        world_info = get_world_info()
+
+        self._logger.debug(msg=f"diloco_states {self.diloco_states}")
+        if self.diloco_states:
+            resume_ckpt_path_diloco = os.path.join(resume_ckpt_path, f"diloco_{world_info.diloco_rank}")
+            dcp.load(self.diloco_states, process_group=self.process_group, checkpoint_id=resume_ckpt_path_diloco)
+
+        self._logger.debug(msg=f"postdiloco_states {self.diloco_states}")
 
         ## the next part is a fix so that each rank saves its own dataloader state. It is not efficient because it reads the state twice from disk
-        with open(os.path.join(resume_ckpt_path, f"__{rank}_0.pt"), "rb") as f:
+
+        # the dataloader is different for each diloco worker. If diloco is enabled we load it from resume_ckpt_path_diloco
+        dataloader_path = resume_ckpt_path_diloco if self.diloco_states else resume_ckpt_path
+        self._logger.debug(f"loading dataloader from {dataloader_path}")
+        with open(os.path.join(dataloader_path, f"__{world_info.local_rank}_0.pt"), "rb") as f:
             rank_state_dict = torch.load(f)
             self.dataloader.load_state_dict(rank_state_dict["data_loader"])
 
diff --git a/src/zeroband/utils/world_info.py b/src/zeroband/utils/world_info.py
index 9b73f328..f7c6548b 100644
--- a/src/zeroband/utils/world_info.py
+++ b/src/zeroband/utils/world_info.py
@@ -27,6 +27,10 @@ def __init__(self):
     def __repr__(self):
         return f"WorldInfo(world_size={self.world_size}, rank={self.rank}, local_rank={self.local_rank}, local_world_size={self.local_world_size}, nnodes={self.nnodes}, global_unique_id={self.global_unique_id}, global_addr={self.global_addr}, global_port={self.global_port}, global_world_size={self.global_world_size}, global_rank={self.global_rank})"
 
+    @property
+    def diloco_rank(self):
+        return self.rank // self.local_world_size
+
 
 def get_world_info() -> WorldInfo:
     """
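
Note on the resulting checkpoint layout: with this patch, the shared dcp state is still saved at the checkpoint root, while each diloco worker writes its offloaded outer optimizer and its dataloader state into its own subfolder diloco_<diloco_rank>, where diloco_rank = rank // local_world_size (this groups all ranks of one node into one diloco worker). Below is a minimal sketch of the rank-to-folder mapping; it is not part of the patch, and the checkpoint path and world sizes are made up for illustration:

    import os

    def diloco_rank(rank: int, local_world_size: int) -> int:
        # mirrors WorldInfo.diloco_rank from the patch: all ranks on the
        # same node share one diloco rank
        return rank // local_world_size

    ckpt_path = "/ckpt/step_100"   # hypothetical checkpoint root
    local_world_size = 8           # hypothetical: 8 GPUs per node

    for rank in range(16):         # two nodes -> diloco workers 0 and 1
        worker_dir = os.path.join(ckpt_path, f"diloco_{diloco_rank(rank, local_world_size)}")
        local_rank = rank % local_world_size
        # each local rank stores its dataloader state as __<local_rank>_0.pt
        # inside its worker's folder, matching the save() path in the patch
        print(rank, os.path.join(worker_dir, f"__{local_rank}_0.pt"))

In this example ranks 0-7 map to diloco_0 and ranks 8-15 to diloco_1, so two diloco workers never write to the same folder even though their local_rank-based filenames collide.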