Merge pull request #924 from workingloong/refactor-mnist-example
Refactor the code that saves and loads DDP checkpoints.
samplise authored Jan 2, 2024
2 parents 08f84d2 + caf96e4 commit 325c2ab
Showing 22 changed files with 240 additions and 868 deletions.
41 changes: 29 additions & 12 deletions dlrover/python/elastic_agent/torch/ckpt_saver.py
@@ -746,18 +746,6 @@ def save_step_checkpoint(self, step: int):

self._writing_storage = False

def persist_to_storage(
self,
local_shard_id: int,
ckpt_config: SingleFileCheckpointConfig,
):
"""Persist the checkpoint from CPU memory buffer into the storage."""
state_dict = self._shm_handlers[local_shard_id].load_state_dict()
state_dict.pop(DLROVER_CKPT_CONFIG_KEY, None)
checkpoint_dir = os.path.dirname(ckpt_config.path)
os.makedirs(checkpoint_dir, exist_ok=True)
torch.save(state_dict, ckpt_config.path)

def commit_checkpoint(self, step: int, step_done_dir: str, timeout=600):
"""
The node rank 0 will update the tracker file with the step
@@ -950,6 +938,24 @@ def commit_checkpoint( # type: ignore
time.sleep(2)


class DdpCheckpointSaver(CommonDirCheckpointSaver):
"""Persist the checkpoint from CPU memory buffer into the storage."""

def persist_to_storage(
self,
local_shard_id: int,
ckpt_config: SingleFileCheckpointConfig,
):
if self._node_rank != 0:
logger.info("Skip and only rank 0 saves checkpoint in a DDP job.")
return
state_dict = self._shm_handlers[local_shard_id].load_state_dict()
state_dict.pop(DLROVER_CKPT_CONFIG_KEY, None)
checkpoint_dir = os.path.dirname(ckpt_config.path)
os.makedirs(checkpoint_dir, exist_ok=True)
torch.save(state_dict, ckpt_config.path)


class MegatronCheckpointSaver(CommonDirCheckpointSaver):
TRACER_FILE = "latest_checkpointed_iteration.txt"

@@ -971,6 +977,17 @@ def update_tracker_file(self, step):
with open(ds_tracker_filename, "w") as f:
f.write(str(step))

def persist_to_storage(
self,
local_shard_id: int,
ckpt_config: SingleFileCheckpointConfig,
):
state_dict = self._shm_handlers[local_shard_id].load_state_dict()
state_dict.pop(DLROVER_CKPT_CONFIG_KEY, None)
checkpoint_dir = os.path.dirname(ckpt_config.path)
os.makedirs(checkpoint_dir, exist_ok=True)
torch.save(state_dict, ckpt_config.path)


class DeepSpeedCheckpointSaver(CommonDirCheckpointSaver):
TRACER_FILE = "latest"
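For orientation, here is a minimal sketch of the rank-0-only persistence behaviour that the new DdpCheckpointSaver introduces. It mirrors the updated test_save_to_storage test further down; the dummy state dict, the direct use of the saver's internal _shm_handlers and _node_rank attributes, and the import location of SingleFileCheckpointConfig are assumptions based on how the updated tests use them, not a recommended public API.

    import tempfile
    from pathlib import Path

    import torch

    from dlrover.python.elastic_agent.torch.ckpt_saver import (
        DdpCheckpointSaver,
        SingleFileCheckpointConfig,
    )

    with tempfile.TemporaryDirectory() as tmpdir:
        saver = DdpCheckpointSaver(tmpdir)
        ckpt_config = SingleFileCheckpointConfig(
            step=100, path=Path(tmpdir) / "checkpoint.pt"
        )
        # Stage a dummy checkpoint in the shared-memory buffer.
        state_dict = {"step": 100, "weight": torch.rand(4, 4)}
        saver._shm_handlers[0].save_state_dict(state_dict, ckpt_config)

        saver._node_rank = 1
        saver.persist_to_storage(0, ckpt_config)  # skipped: only rank 0 writes

        saver._node_rank = 0
        saver.persist_to_storage(0, ckpt_config)  # writes checkpoint.pt
        print(torch.load(ckpt_config.path)["step"])  # 100
        saver.close()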
12 changes: 8 additions & 4 deletions dlrover/python/master/scaler/pod_scaler.py
@@ -142,6 +142,8 @@ def _retry_to_get_master_pod(self):

def scale(self, plan: ScalePlan):
"""Scale in/out Pods by a ScalePlan."""

self._remove_nodes(plan)
while True:
waited = False
with self._lock:
@@ -179,12 +181,14 @@ def scale(self, plan: ScalePlan):
self._scale_down_pods(type, plan, cur_pods)
for node in plan.launch_nodes:
self._create_node_queue.append(node)
for node in plan.remove_nodes:
removed = self._remove_not_create_pod(node.name)
if not removed:
self._k8s_client.delete_pod(node.name)
self._update_job_pods(job_pods)

def _remove_nodes(self, plan: ScalePlan):
for node in plan.remove_nodes:
removed = self._remove_not_create_pod(node.name)
if not removed:
self._k8s_client.delete_pod(node.name)

def _update_job_pods(self, job_pods: Dict[str, List[Node]]):
for type in [
NodeType.CHIEF,
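The change above pulls the removal handling out of scale() into a _remove_nodes helper that runs before any new pods are queued: pods that were never created are simply dropped from the create queue, and existing pods are deleted through the Kubernetes client. A small standalone sketch of that pattern (toy code, not DLRover's PodScaler; all names here are hypothetical):

    from dataclasses import dataclass, field
    from typing import List


    @dataclass
    class ToyPlan:
        launch_nodes: List[str] = field(default_factory=list)
        remove_nodes: List[str] = field(default_factory=list)


    class ToyScaler:
        def __init__(self) -> None:
            self._create_node_queue: List[str] = []
            self.deleted_pods: List[str] = []

        def scale(self, plan: ToyPlan) -> None:
            self._remove_nodes(plan)  # removals are handled before creation now
            self._create_node_queue.extend(plan.launch_nodes)

        def _remove_nodes(self, plan: ToyPlan) -> None:
            for name in plan.remove_nodes:
                if name in self._create_node_queue:
                    # Pod was queued but never created: just drop it.
                    self._create_node_queue.remove(name)
                else:
                    # Pod already exists: delete it (stands in for delete_pod).
                    self.deleted_pods.append(name)


    scaler = ToyScaler()
    scaler.scale(ToyPlan(launch_nodes=["worker-1"], remove_nodes=["worker-3"]))
    print(scaler._create_node_queue, scaler.deleted_pods)  # ['worker-1'] ['worker-3']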
17 changes: 10 additions & 7 deletions dlrover/python/tests/test_ckpt_saver.py
@@ -36,7 +36,7 @@
CheckpointEventType,
CheckpointShardConfig,
CheckpointSharedObjPrefix,
CommonDirCheckpointSaver,
DdpCheckpointSaver,
FsdpDcpSaver,
SaverClassMeta,
SharedMemoryHandler,
@@ -110,8 +110,8 @@ def tearDown(self) -> None:
def test_create_checkpoint_saver(self):
sq = SharedQueue(name="factory", create=False)
class_meta = SaverClassMeta(
module_path=CommonDirCheckpointSaver.__module__,
class_name=CommonDirCheckpointSaver.__name__,
module_path=DdpCheckpointSaver.__module__,
class_name=DdpCheckpointSaver.__name__,
init_args={"checkpoint_dir": "test_ckpt"},
)
sq.put(class_meta)
@@ -123,7 +123,7 @@ def test_create_checkpoint_saver(self):
self.assertIsNotNone(AsyncCheckpointSaver._saver_instance)

def test_close_saver(self):
saver = CommonDirCheckpointSaver("test_ckpt")
saver = DdpCheckpointSaver("test_ckpt")
try:
SharedMemory(name="test").unlink()
except Exception:
@@ -168,7 +168,7 @@ def test_save_to_storage(self):
step=step,
)
with tempfile.TemporaryDirectory() as tmpdir:
saver = CommonDirCheckpointSaver(tmpdir)
saver = DdpCheckpointSaver(tmpdir)
path = Path(tmpdir) / "checkpoint.pt"
ckpt_config = SingleFileCheckpointConfig(step=100, path=path)
saver._shm_handlers[0].save_state_dict(state_dict, ckpt_config)
@@ -192,9 +192,12 @@ def test_save_to_storage(self):
self.assertEqual(len(ckpt_files), 3)
saver.close()

saver._node_rank = 1
saver.persist_to_storage(0, None)

def test_shard_num_changes(self):
with tempfile.TemporaryDirectory() as tmpdir:
saver = CommonDirCheckpointSaver(tmpdir)
saver = DdpCheckpointSaver(tmpdir)
saver.global_shard_num = 1
threading.Thread(
target=saver._sync_shm_to_storage, daemon=True
@@ -216,7 +219,7 @@ def test_commit_checkpoint(self):
with tempfile.TemporaryDirectory() as tmpdir:
step_done_dir = os.path.join(tmpdir, ".done/10/")
os.makedirs(step_done_dir, exist_ok=True)
saver = CommonDirCheckpointSaver(tmpdir)
saver = DdpCheckpointSaver(tmpdir)
saver.global_shard_num = 1
saver.commit_checkpoint(100, step_done_dir, 2)

4 changes: 2 additions & 2 deletions dlrover/python/tests/test_elastic_training_agent.py
@@ -32,7 +32,7 @@
from dlrover.python.elastic_agent.monitor.training import TorchTrainingMonitor
from dlrover.python.elastic_agent.torch.ckpt_saver import (
AsyncCheckpointSaver,
CommonDirCheckpointSaver,
DdpCheckpointSaver,
)
from dlrover.python.elastic_agent.torch.training import (
ElasticLaunchConfig,
@@ -284,7 +284,7 @@ def test_restart_training(self):
start_method=self.config.start_method,
log_dir=self.config.log_dir,
)
saver = CommonDirCheckpointSaver("/tmp/test")
saver = DdpCheckpointSaver("/tmp/test")
AsyncCheckpointSaver._saver_instance = saver
agent._save_ckpt_to_storage()
agent._stop_workers_to_restart()
3 changes: 3 additions & 0 deletions dlrover/python/tests/test_pod_scaler.py
@@ -176,6 +176,9 @@ def test_scale(self):
scale_plan.launch_nodes.append(
Node(NodeType.WORKER, 1, NodeResource(0, 0))
)
scale_plan.remove_nodes.append(
Node(NodeType.WORKER, 3, NodeResource(0, 0))
)
scaler.scale(scale_plan)
self.assertFalse(scale_plan.empty())
self.assertEqual(len(scaler._create_node_queue), 2)
1 change: 1 addition & 0 deletions dlrover/python/tests/test_utils.py
@@ -253,6 +253,7 @@ def mock_k8s_client():
return_value=True
)
k8s_client.create_pod = mock.MagicMock(return_value=True) # type: ignore
k8s_client.delete_pod = mock.MagicMock(return_value=True) # type: ignore
k8s_client.create_service = mock.MagicMock( # type: ignore
return_value=True
)
153 changes: 0 additions & 153 deletions dlrover/trainer/tests/torch/checkpoint_test.py

This file was deleted.
