From c3d4a270792ced604715bf7da77bdb659de02bbd Mon Sep 17 00:00:00 2001 From: Qinlong Wang Date: Tue, 2 Jan 2024 14:36:40 +0800 Subject: [PATCH 1/9] Refactor the codes to save/load checkpoint of DDP. --- .../python/elastic_agent/torch/ckpt_saver.py | 44 +- dlrover/python/master/scaler/pod_scaler.py | 12 +- .../trainer/tests/torch/checkpoint_test.py | 153 ------ dlrover/trainer/tests/torch/elastic_test.py | 32 +- dlrover/trainer/torch/elastic/checkpoint.py | 439 ------------------ dlrover/trainer/torch/elastic/trainer.py | 115 +---- dlrover/trainer/torch/flash_checkpoint/ddp.py | 9 +- .../torch/flash_checkpoint/ddp_engine.py | 12 +- .../trainer/torch/flash_checkpoint/engine.py | 8 +- examples/pytorch/mnist/cnn_train.py | 79 ++-- 10 files changed, 94 insertions(+), 809 deletions(-) delete mode 100644 dlrover/trainer/tests/torch/checkpoint_test.py delete mode 100644 dlrover/trainer/torch/elastic/checkpoint.py diff --git a/dlrover/python/elastic_agent/torch/ckpt_saver.py b/dlrover/python/elastic_agent/torch/ckpt_saver.py index 3f0c83894..bfc38cebb 100644 --- a/dlrover/python/elastic_agent/torch/ckpt_saver.py +++ b/dlrover/python/elastic_agent/torch/ckpt_saver.py @@ -746,18 +746,6 @@ def save_step_checkpoint(self, step: int): self._writing_storage = False - def persist_to_storage( - self, - local_shard_id: int, - ckpt_config: SingleFileCheckpointConfig, - ): - """Persist the checkpoint from CPU memory buffer into the storage.""" - state_dict = self._shm_handlers[local_shard_id].load_state_dict() - state_dict.pop(DLROVER_CKPT_CONFIG_KEY, None) - checkpoint_dir = os.path.dirname(ckpt_config.path) - os.makedirs(checkpoint_dir, exist_ok=True) - torch.save(state_dict, ckpt_config.path) - def commit_checkpoint(self, step: int, step_done_dir: str, timeout=600): """ The node rank 0 will update the tracker file with the step @@ -950,6 +938,24 @@ def commit_checkpoint( # type: ignore time.sleep(2) +class DdpCheckpointSaver(CommonDirCheckpointSaver): + """Persist the checkpoint from CPU memory buffer into the storage.""" + + def persist_to_storage( + self, + local_shard_id: int, + ckpt_config: SingleFileCheckpointConfig, + ): + if self._node_rank != 0: + logger.info("Skip saving checkpoint because only rank 0 saves.") + return + state_dict = self._shm_handlers[local_shard_id].load_state_dict() + state_dict.pop(DLROVER_CKPT_CONFIG_KEY, None) + checkpoint_dir = os.path.dirname(ckpt_config.path) + os.makedirs(checkpoint_dir, exist_ok=True) + torch.save(state_dict, ckpt_config.path) + + class MegatronCheckpointSaver(CommonDirCheckpointSaver): TRACER_FILE = "latest_checkpointed_iteration.txt" @@ -971,6 +977,20 @@ def update_tracker_file(self, step): with open(ds_tracker_filename, "w") as f: f.write(str(step)) + def persist_to_storage( + self, + local_shard_id: int, + ckpt_config: SingleFileCheckpointConfig, + ): + if self._node_rank != 0: + logger.info("Skip saving checkpoint because only rank 0 saves.") + return + state_dict = self._shm_handlers[local_shard_id].load_state_dict() + state_dict.pop(DLROVER_CKPT_CONFIG_KEY, None) + checkpoint_dir = os.path.dirname(ckpt_config.path) + os.makedirs(checkpoint_dir, exist_ok=True) + torch.save(state_dict, ckpt_config.path) + class DeepSpeedCheckpointSaver(CommonDirCheckpointSaver): TRACER_FILE = "latest" diff --git a/dlrover/python/master/scaler/pod_scaler.py b/dlrover/python/master/scaler/pod_scaler.py index ba4b8c310..090bae45c 100644 --- a/dlrover/python/master/scaler/pod_scaler.py +++ b/dlrover/python/master/scaler/pod_scaler.py @@ -142,6 +142,8 @@ def 
_retry_to_get_master_pod(self): def scale(self, plan: ScalePlan): """Scale in/out Pods by a ScalePlan.""" + + self._remove_nodes(plan) while True: waited = False with self._lock: @@ -179,12 +181,14 @@ def scale(self, plan: ScalePlan): self._scale_down_pods(type, plan, cur_pods) for node in plan.launch_nodes: self._create_node_queue.append(node) - for node in plan.remove_nodes: - removed = self._remove_not_create_pod(node.name) - if not removed: - self._k8s_client.delete_pod(node.name) self._update_job_pods(job_pods) + def _remove_nodes(self, plan: ScalePlan): + for node in plan.remove_nodes: + removed = self._remove_not_create_pod(node.name) + if not removed: + self._k8s_client.delete_pod(node.name) + def _update_job_pods(self, job_pods: Dict[str, List[Node]]): for type in [ NodeType.CHIEF, diff --git a/dlrover/trainer/tests/torch/checkpoint_test.py b/dlrover/trainer/tests/torch/checkpoint_test.py deleted file mode 100644 index 5788dfb57..000000000 --- a/dlrover/trainer/tests/torch/checkpoint_test.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright 2023 The DLRover Authors. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import tempfile -import time -import unittest - -import numpy as np -import torch -import torch.distributed as dist -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.data import DataLoader, Dataset - -from dlrover.python.common import grpc -from dlrover.python.common.constants import CheckpointConstant -from dlrover.python.elastic_agent.torch.ckpt_saver import AsyncCheckpointSaver -from dlrover.trainer.torch.elastic.checkpoint import CheckpointManger -from dlrover.trainer.torch.elastic.sampler import ElasticDistributedSampler - - -def set_torch_dist_env(port): - os.environ["WORLD_SIZE"] = "1" - os.environ["RANK"] = "0" - os.environ["MASTER_ADDR"] = "127.0.0.1" - os.environ["MASTER_PORT"] = str(port) - - -class SimpleDataset(Dataset): - def __init__(self): - self.data = np.arange(0, 60001) - - def __len__(self): - return len(self.data) - - def __getitem__(self, index): - return self.data[index] - - -class SimpleNet(nn.Module): - def __init__(self): - super(SimpleNet, self).__init__() - self.fc1 = nn.Linear(64, 32) - self.fc2 = nn.Linear(32, 10) - self.dropout = nn.Dropout(0.5) - - def forward(self, x): - x = self.fc1(x) - x = F.relu(x) - x = self.dropout(x) - x = self.fc2(x) - output = F.log_softmax(x, dim=1) - return output - - -def create_torch_modules(): - model = SimpleNet() - optimizer = optim.SGD( - model.parameters(), - lr=0.01, - momentum=0.001, - ) - dataset = SimpleDataset() - sampler = ElasticDistributedSampler( - dataset=dataset, - num_replicas=2, - rank=0, - shuffle=False, - ) - dataloader = DataLoader( - dataset, - batch_size=4, - sampler=sampler, - ) - return model, optimizer, dataloader - - -def _wait_async_saving_finished(dir_name, step): - ckpt_path = os.path.join(dir_name, f"checkpoint-{step}.pt") - while True: - if 
os.path.exists(ckpt_path): - return - time.sleep(0.2) - - -class CheckpointManagerTest(unittest.TestCase): - def setUp(self): - AsyncCheckpointSaver._saver_instance = None - AsyncCheckpointSaver.start_async_saving_ckpt() - - def tearDown(self) -> None: - if AsyncCheckpointSaver._saver_instance: - AsyncCheckpointSaver._saver_instance.close() - - def test_ddp_save_load(self): - os.environ["LOCAL_RANK"] = "0" - port = grpc.find_free_port() - set_torch_dist_env(port) - dist.init_process_group(backend="gloo") - try: - model, optimizer, dataloader = create_torch_modules() - model = DDP(model) - msd = model.state_dict() - with tempfile.TemporaryDirectory() as tmpdirname: - ckpt_manager = CheckpointManger.init_checkpoint_manager( - model, - optimizer, - dataloader, - tmpdirname, - max_to_keep=2, - ) - for step in [10, 20, 30]: - ckpt_manager.save(epoch=0, step=step) - _wait_async_saving_finished(tmpdirname, step) - ckpt_dirs = os.listdir(tmpdirname) - ckpt_num = 0 - for d in ckpt_dirs: - if d.endswith(".pt"): - ckpt_num += 1 - self.assertEqual(ckpt_num, 2) - - tracer_file = os.path.join( - tmpdirname, CheckpointConstant.TRACER_FILE_NAME - ) - with open(tracer_file, "r") as f: - restored_step = int(f.read()) - self.assertEqual(step, restored_step) - - ckpt_manager.load() - self.assertEqual(dataloader.sampler.total_size, 60002) - resume_msd = ckpt_manager.model.state_dict() - self.assertTrue( - torch.equal( - msd["module.fc1.weight"], - resume_msd["module.fc1.weight"], - ) - ) - ckpt_manager._ckpt_engine.close() - finally: - dist.destroy_process_group() diff --git a/dlrover/trainer/tests/torch/elastic_test.py b/dlrover/trainer/tests/torch/elastic_test.py index 4d4b370c5..060731792 100644 --- a/dlrover/trainer/tests/torch/elastic_test.py +++ b/dlrover/trainer/tests/torch/elastic_test.py @@ -15,7 +15,7 @@ import os import tempfile import unittest -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import numpy as np import torch @@ -25,7 +25,6 @@ from dlrover.python.common.grpc import ParallelConfig from dlrover.trainer.torch.elastic.dataloader import ElasticDataLoader from dlrover.trainer.torch.elastic.trainer import ( - CheckpointInterval, ElasticTrainer, _ElasticLRScheduler, _ElasticOptimizer, @@ -43,22 +42,6 @@ def __getitem__(self, index): return self.data[index] -class CheckpointIntervalTest(unittest.TestCase): - def test_steps(self): - ci = CheckpointInterval(steps=10) - self.assertTrue(ci.should_save(current_step=10)) - self.assertFalse(ci.should_save(current_step=5)) - - def test_epochs(self): - ci = CheckpointInterval(epochs=3) - self.assertTrue(ci.should_save(current_epoch=3)) - self.assertFalse(ci.should_save(current_epoch=1)) - - def test_invalid_input(self): - with self.assertRaises(ValueError): - CheckpointInterval(epochs=3, steps=10) - - class ElasticTrainerTest(unittest.TestCase): def setUp(self): self.model_mock = MagicMock() @@ -68,23 +51,12 @@ def test_epoch_context(self): with self.elastic_trainer.epoch(1): self.assertEqual(self.elastic_trainer.gradient_state.num_steps, 0) - @patch( - "dlrover.trainer.torch.elastic.trainer.ElasticTrainer._save_fsdp_ckpt" - ) - @patch( - "dlrover.trainer.torch.elastic.trainer.CheckpointInterval.should_save" - ) def test_step_context( self, mock_should_save: MagicMock, mock_save: MagicMock ): mock_should_save.return_value = False model = torch.nn.Linear(10, 10) - fsdp_trainer = ElasticTrainer( - model, - use_fsdp=True, - shared_storage_path="fake://", - ckpt_interval=CheckpointInterval(steps=100), - ) + fsdp_trainer = 
ElasticTrainer(model) optimizer = torch.optim.SGD(model.parameters(), lr=0.001) optimizer = self.elastic_trainer.prepare(optimizer) diff --git a/dlrover/trainer/torch/elastic/checkpoint.py b/dlrover/trainer/torch/elastic/checkpoint.py deleted file mode 100644 index 7e41f2144..000000000 --- a/dlrover/trainer/torch/elastic/checkpoint.py +++ /dev/null @@ -1,439 +0,0 @@ -# Copyright 2023 The DLRover Authors. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import shutil -from abc import ABCMeta, abstractmethod -from typing import Dict - -import torch.distributed as dist -import torch.distributed.checkpoint as dist_cp -from torch.distributed.checkpoint.optimizer import ( - load_sharded_optimizer_state_dict, -) -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp import StateDictType -from torch.nn.parallel import DistributedDataParallel as DDP - -from dlrover.python.common.constants import CheckpointConstant -from dlrover.python.common.log import default_logger as logger -from dlrover.trainer.torch.elastic.sampler import ElasticDistributedSampler -from dlrover.trainer.torch.flash_checkpoint.ddp_engine import ( - DdpCheckpointEngine, -) -from dlrover.trainer.torch.flash_checkpoint.engine import CheckpointEngine - -CKPT_DIR_PREFIX = "checkpoint-" - - -def _sync(): - if dist.is_initialized(): - dist.barrier() - - -def _keep_topk_checkpoint(checkpoint_dir, max_to_keep): - """Keep top k checkpoints and remove other checkpoints. - - Arguments: - checkpoint_dir: the directory to save checkpoint files. - max_to_keep: the number of checkpoint files to keep. - """ - step_names: Dict[int, str] = {} - if not os.path.exists(checkpoint_dir): - return - for ckpt_name in os.listdir(checkpoint_dir): - if not ckpt_name.startswith( - CheckpointConstant.CKPT_NAME_PREFIX - ) or not ckpt_name.endswith(".pt"): - continue - name = ckpt_name.split("-")[-1] - if name.endswith(".pt"): - step = int(name.split(".")[0]) - else: - step = int(name) - step_names[step] = ckpt_name - - steps = sorted(list(step_names.keys())) - if len(steps) <= max_to_keep: - return - if max_to_keep == 0: - remove_steps = steps - else: - remove_steps = steps[: -1 * max_to_keep] - for step in remove_steps: - ckpt_name = os.path.join(checkpoint_dir, step_names[step]) - logger.info(f"Remove the checkpoint {ckpt_name}") - if os.path.isfile(ckpt_name): - os.remove(ckpt_name) - else: - shutil.rmtree(ckpt_name) - - -class CheckpointManger(metaclass=ABCMeta): - """CheckpontManager can save and load checkpoint states. - - Args: - model (nn.Module): an instance of `torch.nn.Module`. - optimizer (Optimizer): an instance of `torch.optim.Optimizer`. - dataloader (DataLader): an instance of `torch.utils.data.DataLoader`. - The sampler of DataLoader should be an instance of - `dlrover.trainer.torch.elastic.ElasticDistribuedSampler`. - checkpoint_dir (str): the directory to save the checkpoint states. 
- save_storage_interval (int, optinal): The step inverval to save the - checkoint state dict into the storage. Default: ``1``. - max_to_keep (int, optinal): the max number of checkpoint to keep. The - oldest checkpoint files will be removed if the number of - checkpoints is bigger than max_to_kep. Default: ``1``. - - Example:: - >>> ckpt_manager = LocalCheckpointManger( - >>> model=model, - >>> optimizer=optimizer, - >>> dataloader=train_dataloader, - >>> checkpoint_dir="/tmp/checkpoint/", - >>> save_storage_interval=5, - >>> ) - >>> ckpt_manager.save(0, 10) - >>> ckpt_manger.load() - """ - - def __init__( - self, - model, - optimizer, - dataloader, - checkpoint_dir, - save_storage_interval=1, - max_to_keep=1, - ): - self.model = model - self.optimizer = optimizer - self.dataloader = dataloader - self.checkpoint_dir = checkpoint_dir - self.save_storage_interval = save_storage_interval - self.max_to_keep = max_to_keep - if dist.is_initialized(): - self._rank = dist.get_rank() - self._local_rank = int(os.environ["LOCAL_RANK"]) - else: - self._rank = 0 - self._local_rank = int(os.getenv("LOCAL_RANK", 0)) - - def _log_rank0(self, log): - if self._rank == 0: - logger.info(log) - - def _engine_save(self, engine: CheckpointEngine, step, state_dict): - """ - The each rank has the complete state dict without sharding. Only - the locak rank 0 on each node saves the state dict into the shared - memory and only the rank 0 saves the state dict into the storage. - """ - ckpt_path = os.path.join( - self.checkpoint_dir, - f"{CheckpointConstant.CKPT_NAME_PREFIX}{step}.pt", - ) - engine.save_to_memory(step, state_dict, ckpt_path) - if step % self.save_storage_interval == 0: - if self._rank == 0: - _keep_topk_checkpoint( - self.checkpoint_dir, self.max_to_keep - 1 - ) - engine.save_to_storage(step, state_dict, ckpt_path) - - @abstractmethod - def save(self, epoch, step): - """ - Save the checkpoint of model, optimizer and sampler. - - Args: - epoch (int): the epoch index. - step (int): the iteration step in the epoch. - """ - pass - - @abstractmethod - def load(self, resuming_path=None): - """ - The manager loads the states from the files in the - checkpoint direcotry to the model, optimizer and sampler. - - Args: - resuming_path (str, optinoal): The manager will load checkpoint - from the path. If the path is None, the manager will load - the state checkpoint from the file with the maximum step. - - Return: - step (int): the iteration step. - A dict: a state dict. - """ - pass - - @classmethod - def init_checkpoint_manager( - cls, - model, - optimizer, - dataloader, - directory, - max_to_keep=1, - save_storage_interval=1, - ): - """A factory method to initialize a checkpoint manager by the model - class. - """ - if not dist.is_initialized(): - return LocalCheckpointManger( - model, - optimizer, - dataloader, - directory, - save_storage_interval, - max_to_keep, - ) - elif isinstance(model, DDP): - return DDPCheckpointManger( - model, - optimizer, - dataloader, - directory, - save_storage_interval, - max_to_keep, - ) - elif isinstance(model, FSDP): - return FSDPCheckpointManger( - model, - optimizer, - dataloader, - directory, - save_storage_interval, - max_to_keep, - ) - else: - raise NotImplementedError(f"Not support model class {model}") - - -class LocalCheckpointManger(CheckpointManger): - """ - The manager saves and loads checkpoint states of the local - model and optimizer without distributed execution. 
- """ - - def __init__( - self, - model, - optimizer, - dataloader, - checkpoint_dir, - save_storage_interval=1, - max_to_keep=1, - ): - super().__init__( - model, - optimizer, - dataloader, - checkpoint_dir, - save_storage_interval, - max_to_keep, - ) - self._ckpt_engine = DdpCheckpointEngine( - checkpoint_dir, - ) - - def save(self, epoch, step): - """ - Save the checkpoint of model, optimizer, dataloader into the directory - `{self.directory}/checkpoint-{step}/checkpoint.pt`. - """ - logger.info(f"Save checkpoint of step={step} of epoch={epoch}.") - step = step + epoch * len(self.dataloader) - msd = self.model.state_dict() - osd = self.optimizer.state_dict() - ssd = {} - if isinstance(self.dataloader.sampler, ElasticDistributedSampler): - ssd = self.dataloader.sampler.state_dict( - step, self.dataloader.batch_size - ) - checkpoint = { - "model": msd, - "optimizer": osd, - "sampler": ssd, - "epoch": epoch, - "step": step, - } - self._engine_save(self._ckpt_engine, step, checkpoint) - - def load(self, resuming_path=None): - """ - Load teh state dict from checkpointing data to the model and optimizer. - """ - checkpoint = self._ckpt_engine.load(resuming_path) - if not checkpoint: - return {} - sampler = self.dataloader.sampler - if isinstance(sampler, ElasticDistributedSampler): - sampler.load_state_dict(checkpoint.get("sampler", {})) - model_state_dict = checkpoint.get("model", {}) - optim_state_dict = checkpoint.get("optimizer", {}) - self.model.load_state_dict(model_state_dict) - self.optimizer.load_state_dict(optim_state_dict) - return checkpoint - - -class DDPCheckpointManger(LocalCheckpointManger): - """ - DDPCheckpontManager saves and loads checkpoint states of a DDP model. - """ - - def __init__( - self, - model, - optimizer, - dataloader, - checkpoint_dir, - save_storage_interval=1, - max_to_keep=1, - ): - super().__init__( - model, - optimizer, - dataloader, - checkpoint_dir, - save_storage_interval, - max_to_keep, - ) - - def load(self, resuming_path=None): - """ - Load teh state dict from checkpointing data to the model and optimizer. - """ - checkpoint = super().load(resuming_path=resuming_path) - _sync() - return checkpoint - - -class FSDPCheckpointManger(CheckpointManger): - """ - DDPCheckpontManager saves and loads checkpoint states of a DDP model. - """ - - def __init__( - self, - model, - optimizer, - dataloader, - checkpoint_dir, - save_storage_interval=1, - max_to_keep=1, - ): - super().__init__( - model, - optimizer, - dataloader, - checkpoint_dir, - save_storage_interval, - max_to_keep, - ) - self._ckpt_engine = DdpCheckpointEngine(checkpoint_dir) - - def save(self, epoch, step): - """ - Save the checkpoint of model, optimizer, dataloader into the directory - `{self.directory}/checkpoint-{step}/`. All ranks will save - the part of the model and optimizer states into the file - `checkpoint-{step}/part-{rank}.pt`. 
- """ - self._log_rank0(f"Save checkpoint of step={step} of epoch={epoch}.") - if self.dataloader: - step = step + epoch * len(self.dataloader) - - with FSDP.state_dict_type( - self.model, - StateDictType.SHARDED_STATE_DICT, - ): - state_dict = { - "model": self.model.state_dict(), - "optim": FSDP.optim_state_dict(self.model, self.optimizer), - } - - ssd = {} - if self.dataloader and isinstance( - self.dataloader.sampler, ElasticDistributedSampler - ): - ssd = self.dataloader.sampler.state_dict( - step, self.dataloader.batch_size - ) - state_dict["sampler"] = ssd - state_dict["epoch"] = epoch - state_dict["step"] = step - subdir_name = CheckpointConstant.CKPT_NAME_PREFIX + str(step) - checkpoint_dir = os.path.join(self.checkpoint_dir, subdir_name) - dist_cp.save_state_dict( - state_dict=state_dict, - storage_writer=dist_cp.FileSystemWriter(checkpoint_dir), - ) - tracer_file = os.path.join( - self.checkpoint_dir, CheckpointConstant.TRACER_FILE_NAME - ) - with open(tracer_file, "w") as f: - f.write(str(step)) - - def load(self, resuming_path=None): - """ - Load teh state dict from checkpointing data to the model and optimizer. - """ - - if resuming_path is None: - tracer_file = os.path.join( - self.checkpoint_dir, CheckpointConstant.TRACER_FILE_NAME - ) - if not os.path.exists(tracer_file): - return {} - with open(tracer_file, "r") as f: - step = f.read() - subdir_name = CheckpointConstant.CKPT_NAME_PREFIX + step - resuming_path = os.path.join(self.checkpoint_dir, subdir_name) - with FSDP.state_dict_type( - self.model, StateDictType.SHARDED_STATE_DICT - ): - # cannot load the optimizer state_dict together - # with the model state_dict. - state_dict = { - "model": self.model.state_dict(), - "step": 0, - "epoch": 0, - "sampler": {}, - } - - dist_cp.load_state_dict( - state_dict=state_dict, - storage_reader=dist_cp.FileSystemReader(resuming_path), - ) - self.model.load_state_dict(state_dict["model"]) - - optim_state = load_sharded_optimizer_state_dict( - model_state_dict=state_dict["model"], - optimizer_key="optim", - storage_reader=dist_cp.FileSystemReader(resuming_path), - ) - - flattened_osd = FSDP.optim_state_dict_to_load( - self.model, self.optimizer, optim_state["optim"] - ) - self.optimizer.load_state_dict(flattened_osd) - - if self.dataloader: - sampler = self.dataloader.sampler - if isinstance(sampler, ElasticDistributedSampler): - sampler.load_state_dict(state_dict.get("sampler", {})) - return state_dict diff --git a/dlrover/trainer/torch/elastic/trainer.py b/dlrover/trainer/torch/elastic/trainer.py index b88ae0111..9f38c2702 100644 --- a/dlrover/trainer/torch/elastic/trainer.py +++ b/dlrover/trainer/torch/elastic/trainer.py @@ -17,7 +17,7 @@ import time from contextlib import contextmanager from dataclasses import dataclass -from typing import Any, Dict, Optional +from typing import Any, Dict import torch import torch.distributed as dist @@ -178,50 +178,6 @@ def print_lr(self, *args, **kwargs): return self.scheduler.print_lr(*args, **kwargs) -class CheckpointInterval: - def __init__( - self, steps: Optional[int] = None, epochs: Optional[int] = None - ): - """Initializes the CheckpointInterval class to determine intervals - for saving checkpoints. - - Args: - steps (int, optional): Number of steps for checkpoint intervals. - epochs (int, optional): Number of epochs for checkpoint intervals. - - Raises: - ValueError: If both 'steps' and 'epochs' are set simultaneously. - - **Note:** - Only one of 'steps' or 'epochs' should be set. 
- """ - if steps and epochs: - raise ValueError("Only one of 'steps' or 'epochs' should be set.") - self.steps = steps - self.epochs = epochs - - def should_save( - self, - current_step: Optional[int] = None, - current_epoch: Optional[int] = None, - ) -> bool: - """Determines if a checkpoint should be saved based on - provided parameters. - - Args: - current_step (int, optional): The current training step. - current_epoch (int, optional): The current training epoch. - - Returns: - bool: True if the checkpoint should be saved, otherwise False. - """ - if self.steps and current_step and current_step % self.steps == 0: - return True - if self.epochs and current_epoch and current_epoch % self.epochs == 0: - return True - return False - - class ElasticTrainer(object): """Creates an instance of an elastic trainer for elastic distributed training on multi-nodes. The elastic trainer will do: @@ -250,22 +206,11 @@ def __init__( self, model, dataloader: ElasticDataLoader = None, - use_fsdp: bool = False, - ckpt_interval: CheckpointInterval = None, - shared_storage_path: str = None, ): - if use_fsdp and (ckpt_interval is None or shared_storage_path is None): - raise ValueError( - "When 'use_fsdp' is True, both 'ckpt_interval' and \ - 'shared_storage_path' must be provided and not None." - ) self.model = model self.dataloader = dataloader self.gradient_state = GradientState() self.gradient_accumulation_steps = 1 - self.use_fsdp = use_fsdp - self.ckpt_interval = ckpt_interval - self.shared_storage_path = shared_storage_path self._report_step_interval = 15 # 15s self._last_report_time = 0 @@ -292,60 +237,6 @@ def prepare(self, optimizer, lr_scheduler=None): else: return optimizer - def _save_fsdp_ckpt( - self, - epoch_num: Optional[int] = None, - step_num: Optional[int] = None, - ): - """Intended for saving Fully Sharded Data Parallel (FSDP) checkpoints. - This method's implementation is pending in a subsequent PR. - - Args: - epoch_num (Optional[int], optional): - The current epoch number. Defaults to None. - step_num (Optional[int], optional): - The current step number. Defaults to None. - - Raises: - NotImplementedError: This method has not been implemented yet. - - **TODO:** - - Implement the detailed logic for saving FSDP checkpoints to - the file system in a subsequent PR. - """ - raise NotImplementedError( - "The save_checkpoint method will be implemented \ - in a subsequent PR." - ) - - def _before_epoch(self): - """Prepares necessary setups before starting a new epoch.""" - self.reset() - - def _after_epoch(self, num_epochs: int): - """Handles post-epoch operations based on the completed number - of epochs. - - Args: - num_epochs (int): The total number of completed epochs. - """ - if self.ckpt_interval and self.ckpt_interval.should_save( - current_epoch=num_epochs - ): - self._save_fsdp_ckpt(num_epochs) - - @contextmanager - def epoch(self, num_epochs: int): - """Context manager for pre-epoch setup and post-epoch cleanup. - - Args: - num_epochs (int): - The number of epochs to consider within this context. 
- """ - self._before_epoch() - yield - self._after_epoch(num_epochs) - @contextmanager def step(self, fix_total_batch_size=False): """ @@ -404,10 +295,6 @@ def _before_step(self, fix_total_batch_size): ) def _after_step(self): - if self.ckpt_interval and self.ckpt_interval.should_save( - current_step=self.num_steps - ): - self._save_fsdp_ckpt() if self.gradient_state.sync_gradients: self.gradient_state.num_steps += 1 now = time.time() diff --git a/dlrover/trainer/torch/flash_checkpoint/ddp.py b/dlrover/trainer/torch/flash_checkpoint/ddp.py index c83298152..fe77102e5 100644 --- a/dlrover/trainer/torch/flash_checkpoint/ddp.py +++ b/dlrover/trainer/torch/flash_checkpoint/ddp.py @@ -11,6 +11,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os + +from dlrover.python.common.constants import CheckpointConstant from .checkpointer import Checkpointer, StorageType from .ddp_engine import DdpCheckpointEngine @@ -43,11 +46,15 @@ class DdpCheckpointer(Checkpointer): """ def __init__(self, checkpoint_dir: str): + self.checkpoint_dir = checkpoint_dir self._engine = DdpCheckpointEngine(checkpoint_dir) def save_checkpoint( - self, step, state_dict, path, storage_type=StorageType.DISK + self, step, state_dict, path="", storage_type=StorageType.DISK ): + if path == "": + ckpt_name = f"{CheckpointConstant.CKPT_NAME_PREFIX}{step}.pt" + path = os.path.join(self.checkpoint_dir, ckpt_name) if storage_type == StorageType.MEMORY: self._engine.save_to_memory(step, state_dict, path) elif storage_type == StorageType.DISK: diff --git a/dlrover/trainer/torch/flash_checkpoint/ddp_engine.py b/dlrover/trainer/torch/flash_checkpoint/ddp_engine.py index d1c4836dd..4e8cfa89e 100644 --- a/dlrover/trainer/torch/flash_checkpoint/ddp_engine.py +++ b/dlrover/trainer/torch/flash_checkpoint/ddp_engine.py @@ -24,7 +24,7 @@ DLROVER_CKPT_CONFIG_KEY, CheckpointEvent, CheckpointEventType, - CommonDirCheckpointSaver, + DdpCheckpointSaver, ) from .engine import CheckpointEngine, timer @@ -74,6 +74,7 @@ def _get_saver_ranks(self): for i in range(group_size): saver_rank = i * local_world_size save_ranks.append(saver_rank) + logger.info(f"The ranks to save checkpoint are {save_ranks}.") return save_ranks def get_local_shard_num(self): @@ -83,7 +84,7 @@ def get_global_shard_num(self): return 1 def get_saver_class(self): - return CommonDirCheckpointSaver + return DdpCheckpointSaver @timer def save_to_storage(self, step, state_dict, path=""): @@ -101,7 +102,7 @@ def save_to_storage(self, step, state_dict, path=""): Note, the ckpt_name is used to save the state dict to storage only if the training process fails. """ - if self._rank != 0: + if self._local_rank != 0: return if not path: name = f"{CheckpointConstant.CKPT_NAME_PREFIX}{step}.pt" @@ -109,8 +110,9 @@ def save_to_storage(self, step, state_dict, path=""): if step > self._cached_step: self.save_to_memory(step, state_dict, path) event = CheckpointEvent(type=CheckpointEventType.SAVE, step=step) - if self._local_rank == 0: - self._event_queue.put(event) + + # Only rank 0 persist the checkpoint to the storage. 
+ self._event_queue.put(event) def load(self, resume_path=""): """ diff --git a/dlrover/trainer/torch/flash_checkpoint/engine.py b/dlrover/trainer/torch/flash_checkpoint/engine.py index 35714e222..9a3326a9b 100644 --- a/dlrover/trainer/torch/flash_checkpoint/engine.py +++ b/dlrover/trainer/torch/flash_checkpoint/engine.py @@ -101,8 +101,8 @@ def __init__(self, checkpoint_dir: str): self._saver_group = None self._cached_step = 0 self._restart_count = env_utils.get_torch_restart_count() - # queue for agent to save to storage, only rank 0 - if self._rank == 0: + # queue for agent to save to storage, only lock rank 0 needs the queue. + if self._local_rank == 0: self._event_queue = SharedQueue( name=CheckpointSharedObjPrefix.SAVE_STEP_QNAME + str(0), create=False, @@ -165,6 +165,10 @@ def _update_saver_config(self): type=CheckpointEventType.UPDATE_SHARD, global_shard_num=global_shard_num, ) + if self._event_queue is None: + raise ValueError( + "The event queue cannot be None on local rank 0." + ) self._event_queue.put(event) @timer diff --git a/examples/pytorch/mnist/cnn_train.py b/examples/pytorch/mnist/cnn_train.py index 8ce9a2683..53ccb8e8a 100644 --- a/examples/pytorch/mnist/cnn_train.py +++ b/examples/pytorch/mnist/cnn_train.py @@ -12,7 +12,6 @@ # limitations under the License. import argparse -import functools import os from datetime import datetime, timedelta @@ -23,19 +22,15 @@ import torch.optim as optim import torchvision from torch.distributed.elastic.multiprocessing.errors import record -from torch.distributed.fsdp import CPUOffload, FullStateDictConfig -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp import StateDictType -from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim.lr_scheduler import StepLR from torch.utils.data import DataLoader from torchvision import transforms -from dlrover.trainer.torch.elastic.checkpoint import CheckpointManger from dlrover.trainer.torch.elastic.dataloader import ElasticDataLoader from dlrover.trainer.torch.elastic.sampler import ElasticDistributedSampler from dlrover.trainer.torch.elastic.trainer import ElasticTrainer +from dlrover.trainer.torch.flash_checkpoint.ddp import DdpCheckpointer # Note, we need to set the path of a shared file # system like nas, cpfs or hdfs. 
@@ -128,40 +123,24 @@ def train(args): # create model and move it to GPU with id rank model = model.to(local_rank) - if args.use_fsdp: - print(f"Running basic FSDP example on local rank {local_rank}.") - my_auto_wrap_policy = functools.partial( - size_based_auto_wrap_policy, min_num_params=1000 - ) - cpu_offload = ( - CPUOffload(offload_params=True) if args.cpu_offload else None - ) - model = FSDP( - model, - device_id=local_rank, - auto_wrap_policy=my_auto_wrap_policy, - cpu_offload=cpu_offload, - ) - else: - print(f"Running basic DDP example on local rank {local_rank}.") - model = DDP(model, device_ids=[local_rank]) - print(f"Model device {model.device}") + print(f"Running basic DDP example on local rank {local_rank}.") + model = DDP(model, device_ids=[local_rank]) + print(f"Model device {model.device}") else: - if args.use_fsdp: - raise ValueError("fsdp requires cuda devices") model = DDP(model) optimizer = optim.SGD( model.parameters(), lr=args.learning_rate, momentum=args.momentum ) scheduler = StepLR(optimizer, step_size=1, gamma=0.5) - ckpt_manager = CheckpointManger.init_checkpoint_manager( - model, - optimizer, - train_loader, - CHEKPOINT_DIR, - ) - ckpt_manager.load() + checkpointer = DdpCheckpointer(CHEKPOINT_DIR) + state_dict = checkpointer.load_checkpoint() + if "model" in state_dict: + model.load_state_dict(state_dict["model"]) + if "optimizer" in state_dict: + optimizer.load_state_dict(state_dict["optimizer"]) + if "sampler" in state_dict: + train_loader.sampler.load_state_dict(state_dict["sampler"]) elastic_trainer = ElasticTrainer(model, dataloader=train_loader) optimizer, scheduler = elastic_trainer.prepare(optimizer, scheduler) @@ -180,14 +159,14 @@ def train(args): optimizer, train_loader, device, - ckpt_manager, + checkpointer, args.fixed_batch_size, ) log_rank0("Test model after epoch {}".format(epoch)) test(model, device, test_loader) if args.save_model: rank = int(os.environ.get("RANK", "0")) - save_model(model, args.num_epochs, rank, args.use_fsdp) + save_model(model, args.num_epochs, rank) dist.barrier() @@ -198,7 +177,7 @@ def train_epoch( optimizer, train_loader, device, - ckpt_manager: CheckpointManger, + checkpointer: DdpCheckpointer, fixed_batch_size=False, ): """ @@ -207,6 +186,9 @@ def train_epoch( # Note: Set epoch into the sampler. train_loader.sampler.set_epoch(epoch) for _, (data, target) in enumerate(train_loader): + + # Automatically adjust the accumulated step to keep the global batch + # size fixed even if the number of workers changes. 
with elastic_trainer.step(fixed_batch_size): optimizer.zero_grad() target = target.type(torch.LongTensor) @@ -220,24 +202,25 @@ def train_epoch( log_rank0("loss = {}, step = {}".format(loss, train_step)) if train_step > 0 and train_step % 200 == 0: - ckpt_manager.save(epoch, train_step) + sd = { + "step": train_step, + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + } + if isinstance(train_loader.sampler, ElasticDistributedSampler): + sd["sampler"] = train_loader.sampler.state_dict( + train_step, train_loader.batch_size + ) + checkpointer.save_checkpoint(train_step, sd) print("Finish save checkpoint.") -def save_model(model, epoch, rank, use_fsdp=False): +def save_model(model, epoch, rank): # save if rank == 0: print("--> entering save model state") - if use_fsdp: - save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) - with FSDP.state_dict_type( - model, StateDictType.FULL_STATE_DICT, save_policy - ): - cpu_state = model.state_dict() - else: - cpu_state = model.state_dict() - + cpu_state = model.state_dict() if rank == 0: print("--> saving model ...") currEpoch = "-" + str(epoch) + ".pt" @@ -283,8 +266,6 @@ def arg_parser(): parser.add_argument("--batch_size", type=int, default=32, required=False) parser.add_argument("--num_epochs", type=int, default=1, required=False) parser.add_argument("--shuffle", type=bool, default=True, required=False) - parser.add_argument("--use_fsdp", action="store_true", required=False) - parser.add_argument("--cpu_offload", action="store_true", required=False) parser.add_argument( "--fixed_batch_size", type=bool, default=True, required=False ) From 642bb9472e2b2dfe011c18f5f64781ba0107ab66 Mon Sep 17 00:00:00 2001 From: Qinlong Wang Date: Tue, 2 Jan 2024 14:37:34 +0800 Subject: [PATCH 2/9] Set the default value of checkpoint path. --- examples/pytorch/nanogpt/train.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/pytorch/nanogpt/train.py b/examples/pytorch/nanogpt/train.py index 73b4dffa6..c215cf1ff 100644 --- a/examples/pytorch/nanogpt/train.py +++ b/examples/pytorch/nanogpt/train.py @@ -306,15 +306,14 @@ def flash_save_checkpoint( iter_num, train_loader.batch_size ) state_dict["ds_sampler"] = sampler_sd - ckpt_path = os.path.join(checkpoint_dir, f"checkpoint-{iter_num}.pt") if iter_num % save_memory_interval == 0: checkpointer.save_checkpoint( - iter_num, state_dict, ckpt_path, storage_type=StorageType.MEMORY + iter_num, state_dict, storage_type=StorageType.MEMORY ) saved = True if iter_num % save_storage_interval == 0: checkpointer.save_checkpoint( - iter_num, state_dict, ckpt_path, storage_type=StorageType.DISK + iter_num, state_dict, storage_type=StorageType.DISK ) saved = True return saved From dc1fbb52dffaf5e1b821a9203c4ad6cbaaebbacd Mon Sep 17 00:00:00 2001 From: Qinlong Wang Date: Tue, 2 Jan 2024 14:53:13 +0800 Subject: [PATCH 3/9] Fix test cases. 
--- dlrover/python/tests/test_ckpt_saver.py | 14 +++++++------- .../python/tests/test_elastic_training_agent.py | 4 ++-- dlrover/trainer/tests/torch/elastic_test.py | 12 +----------- 3 files changed, 10 insertions(+), 20 deletions(-) diff --git a/dlrover/python/tests/test_ckpt_saver.py b/dlrover/python/tests/test_ckpt_saver.py index 0784a8aae..c81b1bb67 100644 --- a/dlrover/python/tests/test_ckpt_saver.py +++ b/dlrover/python/tests/test_ckpt_saver.py @@ -36,7 +36,7 @@ CheckpointEventType, CheckpointShardConfig, CheckpointSharedObjPrefix, - CommonDirCheckpointSaver, + DdpCheckpointSaver, FsdpDcpSaver, SaverClassMeta, SharedMemoryHandler, @@ -110,8 +110,8 @@ def tearDown(self) -> None: def test_create_checkpoint_saver(self): sq = SharedQueue(name="factory", create=False) class_meta = SaverClassMeta( - module_path=CommonDirCheckpointSaver.__module__, - class_name=CommonDirCheckpointSaver.__name__, + module_path=DdpCheckpointSaver.__module__, + class_name=DdpCheckpointSaver.__name__, init_args={"checkpoint_dir": "test_ckpt"}, ) sq.put(class_meta) @@ -123,7 +123,7 @@ def test_create_checkpoint_saver(self): self.assertIsNotNone(AsyncCheckpointSaver._saver_instance) def test_close_saver(self): - saver = CommonDirCheckpointSaver("test_ckpt") + saver = DdpCheckpointSaver("test_ckpt") try: SharedMemory(name="test").unlink() except Exception: @@ -168,7 +168,7 @@ def test_save_to_storage(self): step=step, ) with tempfile.TemporaryDirectory() as tmpdir: - saver = CommonDirCheckpointSaver(tmpdir) + saver = DdpCheckpointSaver(tmpdir) path = Path(tmpdir) / "checkpoint.pt" ckpt_config = SingleFileCheckpointConfig(step=100, path=path) saver._shm_handlers[0].save_state_dict(state_dict, ckpt_config) @@ -194,7 +194,7 @@ def test_save_to_storage(self): def test_shard_num_changes(self): with tempfile.TemporaryDirectory() as tmpdir: - saver = CommonDirCheckpointSaver(tmpdir) + saver = DdpCheckpointSaver(tmpdir) saver.global_shard_num = 1 threading.Thread( target=saver._sync_shm_to_storage, daemon=True @@ -216,7 +216,7 @@ def test_commit_checkpoint(self): with tempfile.TemporaryDirectory() as tmpdir: step_done_dir = os.path.join(tmpdir, ".done/10/") os.makedirs(step_done_dir, exist_ok=True) - saver = CommonDirCheckpointSaver(tmpdir) + saver = DdpCheckpointSaver(tmpdir) saver.global_shard_num = 1 saver.commit_checkpoint(100, step_done_dir, 2) diff --git a/dlrover/python/tests/test_elastic_training_agent.py b/dlrover/python/tests/test_elastic_training_agent.py index 3ef3301e4..6758b602b 100644 --- a/dlrover/python/tests/test_elastic_training_agent.py +++ b/dlrover/python/tests/test_elastic_training_agent.py @@ -32,7 +32,7 @@ from dlrover.python.elastic_agent.monitor.training import TorchTrainingMonitor from dlrover.python.elastic_agent.torch.ckpt_saver import ( AsyncCheckpointSaver, - CommonDirCheckpointSaver, + DdpCheckpointSaver, ) from dlrover.python.elastic_agent.torch.training import ( ElasticLaunchConfig, @@ -284,7 +284,7 @@ def test_restart_training(self): start_method=self.config.start_method, log_dir=self.config.log_dir, ) - saver = CommonDirCheckpointSaver("/tmp/test") + saver = DdpCheckpointSaver("/tmp/test") AsyncCheckpointSaver._saver_instance = saver agent._save_ckpt_to_storage() agent._stop_workers_to_restart() diff --git a/dlrover/trainer/tests/torch/elastic_test.py b/dlrover/trainer/tests/torch/elastic_test.py index 060731792..2db720d90 100644 --- a/dlrover/trainer/tests/torch/elastic_test.py +++ b/dlrover/trainer/tests/torch/elastic_test.py @@ -47,14 +47,7 @@ def setUp(self): self.model_mock = 
MagicMock() self.elastic_trainer = ElasticTrainer(self.model_mock) - def test_epoch_context(self): - with self.elastic_trainer.epoch(1): - self.assertEqual(self.elastic_trainer.gradient_state.num_steps, 0) - - def test_step_context( - self, mock_should_save: MagicMock, mock_save: MagicMock - ): - mock_should_save.return_value = False + def test_step_context(self): model = torch.nn.Linear(10, 10) fsdp_trainer = ElasticTrainer(model) optimizer = torch.optim.SGD(model.parameters(), lr=0.001) @@ -70,9 +63,7 @@ def test_step_context( optimizer.step() optimizer.zero_grad() self.assertTrue(self.elastic_trainer.gradient_state.sync_gradients) - self.assertEqual(mock_save.call_count, 0) - mock_should_save.return_value = True with fsdp_trainer.step(): output = model(data) loss = torch.sum(output) @@ -80,7 +71,6 @@ def test_step_context( optimizer.step() optimizer.zero_grad() self.assertTrue(self.elastic_trainer.gradient_state.sync_gradients) - self.assertEqual(mock_save.call_count, 1) def test_prepare_without_lr_scheduler(self): optimizer_mock = MagicMock() From b7e6c77d749386a17bf36525fc6b13e9e4be25ae Mon Sep 17 00:00:00 2001 From: Qinlong Wang Date: Tue, 2 Jan 2024 15:11:21 +0800 Subject: [PATCH 4/9] Add test cases. --- dlrover/python/elastic_agent/torch/ckpt_saver.py | 5 +---- dlrover/python/tests/test_ckpt_saver.py | 3 +++ 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/dlrover/python/elastic_agent/torch/ckpt_saver.py b/dlrover/python/elastic_agent/torch/ckpt_saver.py index bfc38cebb..6dfe86598 100644 --- a/dlrover/python/elastic_agent/torch/ckpt_saver.py +++ b/dlrover/python/elastic_agent/torch/ckpt_saver.py @@ -947,7 +947,7 @@ def persist_to_storage( ckpt_config: SingleFileCheckpointConfig, ): if self._node_rank != 0: - logger.info("Skip saving checkpoint because only rank 0 saves.") + logger.info("Skip and only rank 0 saves checkpoint in a DDP job.") return state_dict = self._shm_handlers[local_shard_id].load_state_dict() state_dict.pop(DLROVER_CKPT_CONFIG_KEY, None) @@ -982,9 +982,6 @@ def persist_to_storage( local_shard_id: int, ckpt_config: SingleFileCheckpointConfig, ): - if self._node_rank != 0: - logger.info("Skip saving checkpoint because only rank 0 saves.") - return state_dict = self._shm_handlers[local_shard_id].load_state_dict() state_dict.pop(DLROVER_CKPT_CONFIG_KEY, None) checkpoint_dir = os.path.dirname(ckpt_config.path) diff --git a/dlrover/python/tests/test_ckpt_saver.py b/dlrover/python/tests/test_ckpt_saver.py index c81b1bb67..8cc30ee29 100644 --- a/dlrover/python/tests/test_ckpt_saver.py +++ b/dlrover/python/tests/test_ckpt_saver.py @@ -192,6 +192,9 @@ def test_save_to_storage(self): self.assertEqual(len(ckpt_files), 3) saver.close() + saver._node_rank = 1 + saver.persist_to_storage(0, None) + def test_shard_num_changes(self): with tempfile.TemporaryDirectory() as tmpdir: saver = DdpCheckpointSaver(tmpdir) From fc987bc49c92733c5208ed26708ef26d7215642d Mon Sep 17 00:00:00 2001 From: Qinlong Wang Date: Tue, 2 Jan 2024 15:16:25 +0800 Subject: [PATCH 5/9] Add test cases. 
--- dlrover/python/tests/test_pod_scaler.py | 3 +++ dlrover/python/tests/test_utils.py | 1 + 2 files changed, 4 insertions(+) diff --git a/dlrover/python/tests/test_pod_scaler.py b/dlrover/python/tests/test_pod_scaler.py index bfedfdf66..9d8c90acf 100644 --- a/dlrover/python/tests/test_pod_scaler.py +++ b/dlrover/python/tests/test_pod_scaler.py @@ -176,6 +176,9 @@ def test_scale(self): scale_plan.launch_nodes.append( Node(NodeType.WORKER, 1, NodeResource(0, 0)) ) + scale_plan.remove_nodes.append( + Node(NodeType.WORKER, 3, NodeResource(0, 0)) + ) scaler.scale(scale_plan) self.assertFalse(scale_plan.empty()) self.assertEqual(len(scaler._create_node_queue), 2) diff --git a/dlrover/python/tests/test_utils.py b/dlrover/python/tests/test_utils.py index 78f81ed0b..f2767978f 100644 --- a/dlrover/python/tests/test_utils.py +++ b/dlrover/python/tests/test_utils.py @@ -253,6 +253,7 @@ def mock_k8s_client(): return_value=True ) k8s_client.create_pod = mock.MagicMock(return_value=True) # type: ignore + k8s_client.delete_pod = mock.MagicMock(return_value=True) # type: ignore k8s_client.create_service = mock.MagicMock( # type: ignore return_value=True ) From ca73ff4a0b35e090f2313259472bc9804b94ec83 Mon Sep 17 00:00:00 2001 From: Qinlong Wang Date: Tue, 2 Jan 2024 16:18:21 +0800 Subject: [PATCH 6/9] Format codes. --- .../tests/torch/ddp_checkpointer_test.py | 61 +++++++++++++++++++ examples/pytorch/nanogpt/ds_train.py | 2 +- examples/pytorch/nanogpt/fsdp_train.py | 9 +++ examples/pytorch/nanogpt/train.py | 8 +++ 4 files changed, 79 insertions(+), 1 deletion(-) create mode 100644 dlrover/trainer/tests/torch/ddp_checkpointer_test.py diff --git a/dlrover/trainer/tests/torch/ddp_checkpointer_test.py b/dlrover/trainer/tests/torch/ddp_checkpointer_test.py new file mode 100644 index 000000000..760053fe2 --- /dev/null +++ b/dlrover/trainer/tests/torch/ddp_checkpointer_test.py @@ -0,0 +1,61 @@ +# Copyright 2023 The DLRover Authors. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import tempfile +import unittest + +import torch.nn as nn +import torch.nn.functional as F + +from dlrover.python.elastic_agent.torch.ckpt_saver import DdpCheckpointSaver +from dlrover.trainer.torch.flash_checkpoint.ddp import ( + DdpCheckpointer, + StorageType, +) + + +class SimpleNet(nn.Module): + def __init__(self): + super(SimpleNet, self).__init__() + self.fc1 = nn.Linear(64, 32) + self.fc2 = nn.Linear(32, 10) + self.dropout = nn.Dropout(0.5) + + def forward(self, x): + x = self.fc1(x) + x = F.relu(x) + x = self.dropout(x) + x = self.fc2(x) + output = F.log_softmax(x, dim=1) + return output + + +class DdpCheckpoinerTest(unittest.TestCase): + def setUp(self) -> None: + DdpCheckpointSaver.start_async_saving_ckpt() + + def tearDown(self) -> None: + if DdpCheckpointSaver._saver_instance: + DdpCheckpointSaver._saver_instance.close() + + def test_ddp_checkppinter(self): + model = SimpleNet() + with tempfile.TemporaryDirectory() as tmpdir: + checkpointer = DdpCheckpointer(tmpdir) + step = 100 + sd = {"model": model.state_dict()} + checkpointer.save_checkpoint( + step, sd, storage_type=StorageType.MEMORY + ) + sd = checkpointer.load_checkpoint() + self.assertTrue("model" in sd) diff --git a/examples/pytorch/nanogpt/ds_train.py b/examples/pytorch/nanogpt/ds_train.py index 1f0c3703c..d820f684e 100644 --- a/examples/pytorch/nanogpt/ds_train.py +++ b/examples/pytorch/nanogpt/ds_train.py @@ -18,7 +18,7 @@ dlrover-run --nnodes=1 --max_restarts=2 --nproc_per_node=2 \ ds_train.py --n_layer 36 --n_head 20 --n_embd 1280 \ --data_dir './' --ds_config ./ds_config.json \ - --epochs 50 --checkpoint_step 50 + --epochs 50 --save_memory_interval 50 --save_storage_interval 500 """ import argparse diff --git a/examples/pytorch/nanogpt/fsdp_train.py b/examples/pytorch/nanogpt/fsdp_train.py index 9ab662ad6..e82a318da 100644 --- a/examples/pytorch/nanogpt/fsdp_train.py +++ b/examples/pytorch/nanogpt/fsdp_train.py @@ -12,6 +12,15 @@ # limitations under the License. +""" +The start command on a local ndoe: + +dlrover-run --nproc_per_node=2 fsdp_train.py \ + --n_layer 48 --n_head 16 --n_embd 1600 --data_dir './' \ + --epochs 50 --save_memory_interval 50 --save_storage_interval 500 +""" + + import argparse import contextlib import functools diff --git a/examples/pytorch/nanogpt/train.py b/examples/pytorch/nanogpt/train.py index c215cf1ff..3d5f01255 100644 --- a/examples/pytorch/nanogpt/train.py +++ b/examples/pytorch/nanogpt/train.py @@ -11,6 +11,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +The start command on a local ndoe: + +dlrover-run --nproc_per_node=2 train.py \ + --n_layer 48 --n_head 16 --n_embd 1600 --data_dir './' \ + --epochs 50 --save_memory_interval 50 --save_storage_interval 500 +""" + import argparse import contextlib From e833e69fb5a619f516786b9252238a7948097e62 Mon Sep 17 00:00:00 2001 From: Qinlong Wang Date: Tue, 2 Jan 2024 16:47:22 +0800 Subject: [PATCH 7/9] Fix test cases. 
--- dlrover/trainer/tests/torch/ddp_checkpointer_test.py | 3 ++- examples/pytorch/nanogpt/fsdp_train.py | 2 -- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/dlrover/trainer/tests/torch/ddp_checkpointer_test.py b/dlrover/trainer/tests/torch/ddp_checkpointer_test.py index 760053fe2..cb6c19df0 100644 --- a/dlrover/trainer/tests/torch/ddp_checkpointer_test.py +++ b/dlrover/trainer/tests/torch/ddp_checkpointer_test.py @@ -42,13 +42,14 @@ def forward(self, x): class DdpCheckpoinerTest(unittest.TestCase): def setUp(self) -> None: + DdpCheckpointSaver._saver_instance = None DdpCheckpointSaver.start_async_saving_ckpt() def tearDown(self) -> None: if DdpCheckpointSaver._saver_instance: DdpCheckpointSaver._saver_instance.close() - def test_ddp_checkppinter(self): + def test_ddp_checkpointer(self): model = SimpleNet() with tempfile.TemporaryDirectory() as tmpdir: checkpointer = DdpCheckpointer(tmpdir) diff --git a/examples/pytorch/nanogpt/fsdp_train.py b/examples/pytorch/nanogpt/fsdp_train.py index e82a318da..0582e9b3f 100644 --- a/examples/pytorch/nanogpt/fsdp_train.py +++ b/examples/pytorch/nanogpt/fsdp_train.py @@ -298,7 +298,6 @@ def native_save_checkpoint(step, model, optimizer, save_storage_interval): "optim": FSDP.optim_state_dict(model, optimizer), "step": step, } - print(f"save checkpoint to {ckpt_dir}") if step % save_storage_interval == 0: dist_cp.save_state_dict( state_dict=state_dict, @@ -356,7 +355,6 @@ def flash_save_checkpoint( "step": step, } ckpt_dir = os.path.join(checkpoint_dir, str(step)) - print(f"save checkpoint to {ckpt_dir}") if step % save_memory_interval == 0: checkpointer.save_checkpoint( step, state_dict, ckpt_dir, storage_type=StorageType.MEMORY From faa5b4f182b221dd80486005db505f76ebd26827 Mon Sep 17 00:00:00 2001 From: Qinlong Wang Date: Tue, 2 Jan 2024 17:22:23 +0800 Subject: [PATCH 8/9] Add an argument to set checkpoint dir. --- .../torch/flash_checkpoint/fsdp_engine.py | 3 +++ examples/pytorch/example.dockerfile | 3 ++- examples/pytorch/nanogpt/ds_train.py | 6 +++--- examples/pytorch/nanogpt/fsdp_train.py | 20 ++++++++++++------- examples/pytorch/nanogpt/train.py | 12 +++++++---- examples/pytorch/nanogpt/train_utils.py | 3 +++ 6 files changed, 32 insertions(+), 15 deletions(-) diff --git a/dlrover/trainer/torch/flash_checkpoint/fsdp_engine.py b/dlrover/trainer/torch/flash_checkpoint/fsdp_engine.py index 4afb4be47..65a873736 100644 --- a/dlrover/trainer/torch/flash_checkpoint/fsdp_engine.py +++ b/dlrover/trainer/torch/flash_checkpoint/fsdp_engine.py @@ -503,6 +503,9 @@ def save_to_storage(self, step, state_dict, path): if self._local_rank != 0: return if path: + logger.info( + "Put a save event to notify the agent persists checkpoint." 
+ ) event = CheckpointEvent(type=CheckpointEventType.SAVE, step=step) self._event_queue.put(event) diff --git a/examples/pytorch/example.dockerfile b/examples/pytorch/example.dockerfile index 44994179a..a62f44cf1 100644 --- a/examples/pytorch/example.dockerfile +++ b/examples/pytorch/example.dockerfile @@ -13,7 +13,8 @@ FROM python:3.8.14 as base WORKDIR /dlrover RUN apt-get update && apt-get install -y sudo vim libgl1-mesa-glx libglib2.0-dev -RUN pip install deprecated pyparsing torch==2.0.1 opencv-python==4.7.0.72 torchvision==0.15.2 transformers +RUN pip install deprecated pyparsing torch==2.0.1 opencv-python==4.7.0.72 \ +torchvision==0.15.2 transformers deepspeed COPY ./data /data COPY ./examples ./examples diff --git a/examples/pytorch/nanogpt/ds_train.py b/examples/pytorch/nanogpt/ds_train.py index d820f684e..e54d6e4fe 100644 --- a/examples/pytorch/nanogpt/ds_train.py +++ b/examples/pytorch/nanogpt/ds_train.py @@ -47,12 +47,10 @@ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" -# We should use a shared storage to persist the checkpiont. -checkpoint_dir = "/nas/nanogpt-ckpt-ds/" - def train(): args = arg_parser() + checkpoint_dir = args.save_dir setup() os.makedirs(checkpoint_dir, exist_ok=True) world_size = int(os.getenv("WORLD_SIZE", 1)) @@ -212,6 +210,7 @@ def train(): iter_num, args.save_memory_interval, args.save_storage_interval, + checkpoint_dir, ) if saved: save_time = round(time.time() - start_save_t, 2) @@ -243,6 +242,7 @@ def flash_save_checkpoint( iter_num, save_memory_interval, save_storage_interval, + checkpoint_dir, ): saved = False if iter_num % save_memory_interval == 0: diff --git a/examples/pytorch/nanogpt/fsdp_train.py b/examples/pytorch/nanogpt/fsdp_train.py index 0582e9b3f..7323518ff 100644 --- a/examples/pytorch/nanogpt/fsdp_train.py +++ b/examples/pytorch/nanogpt/fsdp_train.py @@ -55,12 +55,10 @@ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" -# We should use a shared storage to persist the checkpiont. 
-checkpoint_dir = "/nas/nanogpt-ckpt-fsdp/" - def train(): args = arg_parser() + checkpoint_dir = args.save_dir setup() os.makedirs(checkpoint_dir, exist_ok=True) world_size = int(os.getenv("WORLD_SIZE", 1)) @@ -162,7 +160,7 @@ def train(): start_load_t = time.time() if args.use_native_ckpt: - iter_num = native_load_checkpoint(0, model, optimizer) + iter_num = native_load_checkpoint(0, model, optimizer, checkpoint_dir) else: checkpointer = FsdpCheckpointer(checkpoint_dir) iter_num = flash_load_checkpoint(checkpointer, model, optimizer) @@ -231,7 +229,11 @@ def train(): start_save_t = time.time() if args.use_native_ckpt: saved = native_save_checkpoint( - iter_num, model, optimizer, args.save_storage_interval + iter_num, + model, + optimizer, + args.save_storage_interval, + checkpoint_dir, ) else: saved = flash_save_checkpoint( @@ -241,6 +243,7 @@ def train(): optimizer, args.save_memory_interval, args.save_storage_interval, + checkpoint_dir, ) if saved: save_time = round(time.time() - start_save_t, 2) @@ -255,7 +258,7 @@ def train(): break -def native_load_checkpoint(step, model, optimizer): +def native_load_checkpoint(step, model, optimizer, checkpoint_dir): with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): state_dict = { "model": model.state_dict(), @@ -286,7 +289,9 @@ def native_load_checkpoint(step, model, optimizer): return state_dict["step"] -def native_save_checkpoint(step, model, optimizer, save_storage_interval): +def native_save_checkpoint( + step, model, optimizer, save_storage_interval, checkpoint_dir +): saved = False if step % save_storage_interval != 0: return saved @@ -344,6 +349,7 @@ def flash_save_checkpoint( optimizer, save_memory_interval, save_storage_interval, + checkpoint_dir, ): saved = False if step % save_memory_interval != 0 and step % save_storage_interval != 0: diff --git a/examples/pytorch/nanogpt/train.py b/examples/pytorch/nanogpt/train.py index 3d5f01255..2598ccda3 100644 --- a/examples/pytorch/nanogpt/train.py +++ b/examples/pytorch/nanogpt/train.py @@ -48,12 +48,10 @@ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" -# We should use a shared storage to persist the checkpiont. 
-checkpoint_dir = "/nas/nanogpt-ckpt/" - def train(): args = arg_parser() + checkpoint_dir = args.save_dir setup() os.makedirs(checkpoint_dir, exist_ok=True) world_size = int(os.getenv("WORLD_SIZE", 1)) @@ -244,6 +242,7 @@ def train(): optimizer, train_loader, args.save_storage_interval, + checkpoint_dir, ) else: saved = flash_save_checkpoint( @@ -269,7 +268,12 @@ def train(): def native_save_checkpoint( - iter_num, model, optimizer, train_loader, save_storage_interval + iter_num, + model, + optimizer, + train_loader, + save_storage_interval, + checkpoint_dir, ): saved = False if iter_num % save_storage_interval != 0: diff --git a/examples/pytorch/nanogpt/train_utils.py b/examples/pytorch/nanogpt/train_utils.py index 0ab917f40..2e504c710 100644 --- a/examples/pytorch/nanogpt/train_utils.py +++ b/examples/pytorch/nanogpt/train_utils.py @@ -234,3 +234,6 @@ def add_train_args(parser: argparse.ArgumentParser): parser.add_argument( "--use_native_ckpt", action="store_true", required=False ) + parser.add_argument( + "--save_dir", type=str, default="/tmp/checkpoint/", required=False + ) From caf96e4fdd65fcadf7bb9774209734c4f8867b63 Mon Sep 17 00:00:00 2001 From: Qinlong Wang Date: Tue, 2 Jan 2024 17:54:20 +0800 Subject: [PATCH 9/9] Fix train_batch size --- examples/pytorch/nanogpt/ds_config.json | 36 ++++++++++++------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/examples/pytorch/nanogpt/ds_config.json b/examples/pytorch/nanogpt/ds_config.json index 09df4686a..56afc6fcf 100644 --- a/examples/pytorch/nanogpt/ds_config.json +++ b/examples/pytorch/nanogpt/ds_config.json @@ -1,21 +1,21 @@ { "zero_optimization": { - "stage": 1, - "overlap_comm": true, - "contiguous_gradients": true, - "sub_group_size": 1e9, - "reduce_bucket_size": "auto", - "stage3_prefetch_bucket_size": "auto", - "stage3_param_persistence_threshold": "auto", - "stage3_max_live_parameters": 1e9, - "stage3_max_reuse_distance": 1e9, - "stage3_gather_fp16_weights_on_model_save": true - }, + "stage": 1, + "overlap_comm": true, + "contiguous_gradients": true, + "sub_group_size": 1e9, + "reduce_bucket_size": "auto", + "stage3_prefetch_bucket_size": "auto", + "stage3_param_persistence_threshold": "auto", + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_gather_fp16_weights_on_model_save": true + }, - "gradient_accumulation_steps": 1, - "gradient_clipping": 0.1, - "steps_per_print": 100, - "train_batch_size": 32, - "train_micro_batch_size_per_gpu": 16, - "wall_clock_breakdown": false - } \ No newline at end of file + "gradient_accumulation_steps": 1, + "gradient_clipping": 0.1, + "steps_per_print": 100, + "train_batch_size": 256, + "train_micro_batch_size_per_gpu": 16, + "wall_clock_breakdown": false +} \ No newline at end of file
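
For reference, the DDP flash-checkpoint API that this series converges on can be exercised roughly as below. This is a minimal sketch based on the usage shown in examples/pytorch/mnist/cnn_train.py and the new dlrover/trainer/tests/torch/ddp_checkpointer_test.py in the diffs above; the toy model, the step number, and the /tmp/flash_ckpt_demo directory are illustrative only, and in a real dlrover-run job the elastic agent starts the saver process while the checkpoint directory should live on shared storage such as NAS or HDFS.

    # Minimal sketch of the DdpCheckpointer usage introduced by this series.
    import torch.nn as nn

    from dlrover.python.elastic_agent.torch.ckpt_saver import DdpCheckpointSaver
    from dlrover.trainer.torch.flash_checkpoint.ddp import (
        DdpCheckpointer,
        StorageType,
    )

    # Started manually here, mirroring the unit test added by this series;
    # in a dlrover-run job the agent process launches the saver.
    DdpCheckpointSaver.start_async_saving_ckpt()

    model = nn.Linear(10, 10)  # stand-in for a real (DDP-wrapped) model
    checkpointer = DdpCheckpointer("/tmp/flash_ckpt_demo")
    state_dict = {"model": model.state_dict(), "step": 100}

    # Frequent, low-overhead snapshot into shared CPU memory.
    checkpointer.save_checkpoint(100, state_dict, storage_type=StorageType.MEMORY)

    # Less frequent persistent save; with no explicit path, the checkpointer
    # writes a checkpoint-{step}.pt file under the checkpoint directory and
    # the saver process persists it asynchronously.
    checkpointer.save_checkpoint(100, state_dict, storage_type=StorageType.DISK)

    # Restore the most recent checkpoint state.
    restored = checkpointer.load_checkpoint()
    if "model" in restored:
        model.load_state_dict(restored["model"])

    if DdpCheckpointSaver._saver_instance:
        DdpCheckpointSaver._saver_instance.close()

The two save calls correspond to the --save_memory_interval and --save_storage_interval arguments used by the nanogpt examples in this series: in-memory snapshots can run every few steps at low cost, while writes to durable storage happen at a much coarser interval.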