Create the tensor from shm with torch.frombuffer. (#868)
workingloong authored Dec 3, 2023
1 parent 4f1004d commit 33345aa
Showing 4 changed files with 69 additions and 99 deletions.
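
The change replaces the earlier numpy round-trip (np.frombuffer over the shared-memory buffer, then torch.tensor, which copies the data) with torch.frombuffer, which returns a tensor view that shares memory with the buffer directly; torch.frombuffer is available from PyTorch 1.10 onward. Below is a minimal standalone sketch of that pattern, not code from the repository: it assumes a multiprocessing.shared_memory block sized for a single tensor, and all names are illustrative.

import torch
from multiprocessing import shared_memory

# Illustrative only: pack one tensor into a shared-memory block and read it back.
value = torch.rand(10, 10, dtype=torch.float32)
shm = shared_memory.SharedMemory(
    create=True, size=value.numel() * value.element_size()
)

# Write: view the buffer as a tensor (no copy) and fill it from the source tensor.
shm_tensor = torch.frombuffer(
    shm.buf, dtype=value.dtype, count=value.numel(), offset=0
).reshape(value.shape)
shm_tensor.copy_(value)

# Read: a process attached to the same block rebuilds the tensor the same way.
restored = torch.frombuffer(
    shm.buf, dtype=value.dtype, count=value.numel(), offset=0
).reshape(value.shape)
assert torch.equal(restored, value)

del shm_tensor, restored  # drop buffer exports before closing the block
shm.close()
shm.unlink()
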
4 changes: 2 additions & 2 deletions dlrover/python/common/multi_process.py
@@ -65,7 +65,7 @@ def _create_socket_client(path):
client.connect(path)
connected = True
break
except FileNotFoundError:
except (FileNotFoundError, ConnectionRefusedError):
time.sleep(0.1)
if not connected:
client.connect(path)
@@ -460,7 +460,7 @@ def update(self, new_dict):
self._shared_queue.put(1)
self._request(request)
except Exception:
logger.info("The recv processs has breakdown.")
logger.info("The recv process has breakdown.")

def get(self, local=False):
"""
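
For the multi_process.py hunk above: when a client starts before the server, the Unix domain socket file can already exist while nothing is accepting on it yet, so connect() raises ConnectionRefusedError rather than FileNotFoundError; the change retries on both. A standalone sketch of that retry pattern, with illustrative names rather than the repository's actual helper:

import socket
import time

def connect_with_retry(path, attempts=30, delay=0.1):
    """Connect to a Unix domain socket, retrying until the server accepts."""
    client = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    for _ in range(attempts):
        try:
            client.connect(path)
            return client
        except (FileNotFoundError, ConnectionRefusedError):
            # Socket file missing or server not accepting yet; retry shortly.
            time.sleep(delay)
    client.connect(path)  # final attempt lets the underlying error propagate
    return client
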
54 changes: 12 additions & 42 deletions dlrover/python/elastic_agent/torch/ckpt_saver.py
@@ -22,7 +22,6 @@
from datetime import timedelta
from typing import Callable, Dict, List, Mapping, Tuple

import numpy as np
import torch
import torch.distributed as dist

@@ -77,27 +76,6 @@ def _init_dir(dir):
os.makedirs(dir)


def _convert_torch_dtype_to_numpy(torch_dtype):
"""Conver the torch dtype to numpy dtype."""
dtype_map = {
torch.float32: np.float32,
torch.float: np.float32,
torch.float64: np.float64,
torch.double: np.double,
torch.float16: np.float16,
torch.half: np.half,
torch.uint8: np.uint8,
torch.int8: np.int8,
torch.int16: np.int16,
torch.short: np.short,
torch.int32: np.int32,
torch.int: np.int32,
torch.long: np.int64,
torch.bool: np.dtype("bool"),
}
return dtype_map[torch_dtype]


def _traverse_state_dict(value: object, visitor: Callable[[object], None]):
"""
Invoke ``visitor`` for each value recursively in ``state_dict``.
@@ -129,13 +107,13 @@ def _read_tensor_from_buf(value, shm_tensor_buffer):
Read a tensor from the buffer of shared memory.
"""
if isinstance(value, TensorMeta):
data_array = np.frombuffer(
shm_tensor = torch.frombuffer(
buffer=shm_tensor_buffer.buf,
dtype=value.dtype,
offset=value.offset,
count=value.numel,
)
value = torch.reshape(torch.tensor(data_array), value.shape)
value = shm_tensor.reshape(value.shape)
return value
else:
return value
@@ -213,21 +191,14 @@ def _tarverse_copy_to_shm(value, meta, buffer):
meta[i] = v


def _write_shared_memory(value, meta: TensorMeta, buffer):
def _write_shared_memory(value: torch.Tensor, meta: TensorMeta, buffer):
"""
Write a CPU tensor into the shared memory.
"""
data_array = value.cpu().numpy()
write_array = np.ndarray(
data_array.shape,
dtype=data_array.dtype,
buffer=buffer,
offset=meta.offset,
)
if data_array.shape == ():
write_array.fill(data_array)
else:
write_array[:] = data_array[:]
shm_tensor = torch.frombuffer(
buffer, dtype=value.dtype, count=value.numel(), offset=meta.offset
).reshape(value.shape)
shm_tensor.copy_(value)


def _load_from_historic_checkpoint(checkpoint_dir):
@@ -270,10 +241,6 @@ def __init__(self, checkpoint_dir, num_proc=1):
self.checkpoint_dir = checkpoint_dir
self.num_proc = num_proc

@abstractmethod
def _sync_shm_to_storage(self):
pass

@classmethod
def start_async_saving_ckpt(cls):
"""
@@ -299,6 +266,10 @@ def _save():
target=_save, name="checkpoint-saver", daemon=True
).start()

@abstractmethod
def _sync_shm_to_storage(self):
pass

@classmethod
def get_ckpt_saver(cls):
return cls._saver_instance
@@ -618,10 +589,9 @@ def _create_tensor_meta(self, value: torch.Tensor):
"""
if not torch.is_tensor(value):
return value
dtype = _convert_torch_dtype_to_numpy(value.dtype)
meta = TensorMeta(
shape=tuple(value.shape), # type: ignore
dtype=dtype,
dtype=value.dtype,
element_size=value.element_size(),
numel=value.numel(),
offset=self._buffer_size,
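
In _create_tensor_meta above, each tensor's TensorMeta records its shape, dtype, numel, element_size, and a byte offset taken from the engine's running buffer size (presumably advanced as each tensor is registered); torch.frombuffer then interprets offset as bytes and count as elements when reading the tensor back. A rough sketch of how such a layout can be planned for several tensors packed into one flat buffer; the helper and its names are illustrative, not the repository's implementation:

import torch

def plan_buffer_layout(tensors):
    """Assign each tensor a byte offset in a single flat buffer (illustrative)."""
    metas, offset = [], 0
    for t in tensors:
        metas.append(
            dict(
                shape=tuple(t.shape),
                dtype=t.dtype,
                numel=t.numel(),
                element_size=t.element_size(),
                offset=offset,  # byte offset later passed to torch.frombuffer
            )
        )
        offset += t.numel() * t.element_size()
    return metas, offset  # the final offset is the total buffer size in bytes

metas, total_size = plan_buffer_layout(
    [torch.rand(10, 10), torch.zeros(5, dtype=torch.int64)]
)
assert total_size == 10 * 10 * 4 + 5 * 8  # 440 bytes
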
60 changes: 48 additions & 12 deletions dlrover/python/tests/test_ckpt_saver.py
@@ -17,7 +17,6 @@
import time
import unittest

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -29,7 +28,8 @@
NoShardingCheckpointEngine,
NoShardingSaver,
SaverClassMeta,
_convert_torch_dtype_to_numpy,
_create_shared_memory,
_load_from_historic_checkpoint,
_traverse_state_dict,
)

@@ -96,16 +96,6 @@ def visitor(value):
new_dict = _traverse_state_dict(state_dict, visitor)
self.assertEqual(new_dict, state_dict)

def test_convert_torch_dtype_to_numpy(self):
np_dtype = _convert_torch_dtype_to_numpy(torch.float32)
self.assertEqual(np_dtype, np.float32)

np_dtype = _convert_torch_dtype_to_numpy(torch.float)
self.assertEqual(np_dtype, np.float32)

np_dtype = _convert_torch_dtype_to_numpy(torch.int32)
self.assertEqual(np_dtype, np.int32)

def test_save_to_storage(self):
model = SimpleNet()
step = 100
@@ -131,3 +121,49 @@ def test_save_to_storage(self):
ckpt_files = os.listdir(tmpdir)
self.assertEqual(len(ckpt_files), 1)
sq.close()


class CheckpointEngineTest(unittest.TestCase):
def setUp(self):
CheckpointSaver._saver_instance = None
CheckpointSaver.start_async_saving_ckpt()

def test_create_shared_memory(self):
shm = _create_shared_memory("test", False)
self.assertIsNone(shm)

def test_create_tensor_meta(self):
engine = NoShardingCheckpointEngine("test-ckpt")
value = torch.rand((10, 10), dtype=torch.float32)
meta = engine._create_tensor_meta(value)
self.assertEqual(meta.numel, 100)
self.assertEqual(meta.element_size, 4)
self.assertEqual(meta.offset, 0)
self.assertEqual(meta.shape, (10, 10))
self.assertEqual(meta.dtype, torch.float32)
engine.close()

def test_load_no_sharding(self):
model = SimpleNet()
step = 100
state_dict = dict(
model=model.state_dict(),
step=step,
)

with tempfile.TemporaryDirectory() as tmpdirname:
engine = NoShardingCheckpointEngine(tmpdirname)
path = os.path.join(tmpdirname, "checkpoint-10/checkpoint.pt")
os.makedirs(os.path.dirname(path))
torch.save(state_dict, path)
path = os.path.join(tmpdirname, "checkpoint-20/checkpoint.pt")
os.makedirs(os.path.dirname(path))
with open(path, "w") as f:
f.write("A error checkpoint\n")
loaded_state_dict = _load_from_historic_checkpoint(
engine.checkpoint_dir
)
for key, value in state_dict["model"].items():
loaded_value = loaded_state_dict["model"][key]
self.assertTrue(torch.equal(value, loaded_value))
engine.close()
50 changes: 7 additions & 43 deletions dlrover/trainer/tests/torch/checkpoint_test.py
@@ -28,10 +28,7 @@
from dlrover.python.common import grpc
from dlrover.python.elastic_agent.torch.ckpt_saver import (
CheckpointSaver,
NoShardingCheckpointEngine,
_create_shared_memory,
_get_latest_checkpoint,
_load_from_historic_checkpoint,
)
from dlrover.trainer.torch.elastic.checkpoint import CheckpointManger
from dlrover.trainer.torch.elastic.sampler import ElasticDistributedSampler
@@ -106,17 +103,14 @@ def setUp(self):
CheckpointSaver._saver_instance = None
CheckpointSaver.start_async_saving_ckpt()

def test_create_shared_memory(self):
shm = _create_shared_memory("test", False)
self.assertIsNone(shm)

def test_ddp_save_load(self):
os.environ["LOCAL_RANK"] = "0"
port = grpc.find_free_port()
set_torch_dist_env(port)
dist.init_process_group(backend="gloo")
model, optimizer, dataloader = create_torch_modules()
model = DDP(model)
msd = model.state_dict()
with tempfile.TemporaryDirectory() as tmpdirname:
ckpt_manager = CheckpointManger.init_checkpoint_manager(
model,
@@ -137,41 +131,11 @@ def test_ddp_save_load(self):

ckpt_manager.load()
self.assertEqual(dataloader.sampler.total_size, 60002)
resume_msd = ckpt_manager.model.state_dict()
self.assertTrue(
torch.equal(
msd["module.fc1.weight"], resume_msd["module.fc1.weight"]
)
)
ckpt_manager._ckpt_engine.close()
dist.destroy_process_group()

def test_create_tensor_meta(self):
engine = NoShardingCheckpointEngine("test-ckpt")
value = torch.rand((10, 10), dtype=torch.float32)
meta = engine._create_tensor_meta(value)
self.assertEqual(meta.numel, 100)
self.assertEqual(meta.element_size, 4)
self.assertEqual(meta.offset, 0)
self.assertEqual(meta.shape, (10, 10))
self.assertEqual(meta.dtype, np.float32)
engine.close()

def test_load_no_sharding(self):
model = SimpleNet()
step = 100
state_dict = dict(
model=model.state_dict(),
step=step,
)

with tempfile.TemporaryDirectory() as tmpdirname:
engine = NoShardingCheckpointEngine(tmpdirname)
path = os.path.join(tmpdirname, "checkpoint-10/checkpoint.pt")
os.makedirs(os.path.dirname(path))
torch.save(state_dict, path)
path = os.path.join(tmpdirname, "checkpoint-20/checkpoint.pt")
os.makedirs(os.path.dirname(path))
with open(path, "w") as f:
f.write("A error checkpoint\n")
loaded_state_dict = _load_from_historic_checkpoint(
engine.checkpoint_dir
)
for key, value in state_dict["model"].items():
loaded_value = loaded_state_dict["model"][key]
self.assertTrue(torch.equal(value, loaded_value))
engine.close()
