Skip to content

Commit

Permalink
save into diloco-specific folder
Browse files Browse the repository at this point in the history
  • Loading branch information
samsja committed Sep 29, 2024
1 parent ec499ce commit 744e69f
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 5 deletions.
29 changes: 24 additions & 5 deletions src/zeroband/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,9 @@ def __init__(
self.diloco_offloaded_param_list = diloco_offloaded_param_list
# Even if the diloco offloaded optimizer targets the cpu parameter list, we still use the gpu model to load and save state.
# The main reason is that we don't actually have a cpu model, just a list of cpu parameters.
self.states["diloco_offloaded_optimizer"] = self.diloco_offloaded_optimizer
self.diloco_states = {"optimizer": self.diloco_offloaded_optimizer}
else:
self.diloco_states = {}

self.process_group = process_group
self._logger = get_logger()
Expand All @@ -145,16 +147,22 @@ def save(self, ckpt_path: str, remote_ckpt_path: str | None) -> None:
# pytorch has an annoying warning when saving the optimizer state https://github.com/pytorch/pytorch/issues/136907
# we can ignore it if we are not logging in DEBUG mode

rank = get_world_info().local_rank
world_info = get_world_info()

with warnings.catch_warnings():
if catch_warning:
warnings.simplefilter("ignore")

dcp.save(self.states, process_group=self.process_group, checkpoint_id=ckpt_path)

if self.diloco_states:
diloco_ckpt_path = os.path.join(ckpt_path, f"diloco_{world_info.diloco_rank}")
dcp.save(self.diloco_states, process_group=self.process_group, checkpoint_id=diloco_ckpt_path)
## The next part is a fix so that each rank saves its own dataloader state. It is not efficient because it reads the state twice from disk.
with open(os.path.join(ckpt_path, f"__{rank}_0.pt"), "wb") as f:

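# The dataloader state is different for each diloco worker. If diloco is enabled we use diloco_ckpt_path.
# NOTE(review): diloco_states appears to be set to {} (not None) when diloco is disabled, so the `is None` test below may never select ckpt_path and diloco_ckpt_path may be unbound in that case; confirm.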
dataloader_path = ckpt_path if self.diloco_states is None else diloco_ckpt_path
with open(os.path.join(dataloader_path, f"__{world_info.local_rank}_0.pt"), "wb") as f:
torch.save({"data_loader": self.dataloader.state_dict()}, f)

self._logger.info(f"Saved checkpoint to {ckpt_path} in {time.perf_counter() - time_start} seconds")
Expand Down Expand Up @@ -204,10 +212,21 @@ def load(self, resume_ckpt_path: str) -> None:

self.states = dcp.load(self.states, process_group=self.process_group, checkpoint_id=resume_ckpt_path)

rank = get_world_info().local_rank # todo check after on/off ramping pr which rank is good here
world_info = get_world_info()

self._logger.debug(msg=f"diloco_states {self.diloco_states}")
if self.diloco_states:
resume_ckpt_path_diloco = os.path.join(resume_ckpt_path, f"diloco_{world_info.diloco_rank}")
dcp.load(self.diloco_states, process_group=self.process_group, checkpoint_id=resume_ckpt_path_diloco)

self._logger.debug(msg=f"postdiloco_states {self.diloco_states}")

## The next part is a fix so that each rank loads its own dataloader state. It is not efficient because it reads the state twice from disk.
with open(os.path.join(resume_ckpt_path, f"__{rank}_0.pt"), "rb") as f:

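# The dataloader state is different for each diloco worker. If diloco is enabled we use resume_ckpt_path_diloco.
# NOTE(review): diloco_states appears to be set to {} (not None) when diloco is disabled, so the `is None` test below may never select resume_ckpt_path and resume_ckpt_path_diloco may be unbound in that case; confirm.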
dataloader_path = resume_ckpt_path if self.diloco_states is None else resume_ckpt_path_diloco
self._logger.debug(f"loading dataloader from {dataloader_path}")
with open(os.path.join(dataloader_path, f"__{world_info.local_rank}_0.pt"), "rb") as f:
rank_state_dict = torch.load(f)
self.dataloader.load_state_dict(rank_state_dict["data_loader"])

Expand Down
4 changes: 4 additions & 0 deletions src/zeroband/utils/world_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ def __init__(self):
def __repr__(self):
    """Return a one-line ``WorldInfo(...)`` summary of this object's rank and world-size fields."""
    # The text is identical to a single long f-string; it is only split across
    # adjacent literals to keep each source line readable.
    return (
        f"WorldInfo(world_size={self.world_size}, rank={self.rank}, "
        f"local_rank={self.local_rank}, local_world_size={self.local_world_size}, "
        f"nnodes={self.nnodes}, global_unique_id={self.global_unique_id}, "
        f"global_addr={self.global_addr}, global_port={self.global_port}, "
        f"global_world_size={self.global_world_size}, global_rank={self.global_rank})"
    )

@property
def diloco_rank(self):
    """Return ``rank`` floor-divided by ``local_world_size``.

    NOTE(review): presumably this is the index of the diloco worker (one per
    group of ``local_world_size`` ranks) that this process belongs to; confirm.
    """
    process_rank = self.rank
    return process_rank // self.local_world_size


def get_world_info() -> WorldInfo:
"""
Expand Down

0 comments on commit 744e69f

Please sign in to comment.