From 84f3e4a9464db7a9165d1862af5aa16294fd7623 Mon Sep 17 00:00:00 2001 From: Tao He Date: Tue, 26 Dec 2023 10:30:56 +0800 Subject: [PATCH] Enable sharing pytorch modules (#1695) Resolves part of #1659. Signed-off-by: Tao He --- python/core.cc | 13 +++ python/vineyard/contrib/ml/README.md | 38 +++++++ .../vineyard/contrib/ml/tests/test_torch.py | 45 ++++++++ python/vineyard/contrib/ml/torch.py | 102 ++++++++++++++---- python/vineyard/core/resolver.py | 3 + python/vineyard/data/utils.py | 2 + 6 files changed, 185 insertions(+), 18 deletions(-) diff --git a/python/core.cc b/python/core.cc index e827b073e..06bb3db67 100644 --- a/python/core.cc +++ b/python/core.cc @@ -224,6 +224,19 @@ void bind_core(py::module& mod) { return py::cast(self->GetMember(key)); }, doc::ObjectMeta_get_member) + .def( + "member" /* alias for get_member() */, + [](ObjectMeta* self, std::string const& key) -> py::object { + auto const& tree = self->MetaData(); + auto iter = tree.find(key); + if (iter == tree.end()) { + return py::none(); + } + VINEYARD_ASSERT(iter->is_object() && !iter->empty(), + "The value is not a member, but a meta"); + return py::cast(self->GetMember(key)); + }, + doc::ObjectMeta_get_member) .def("get_buffer", [](ObjectMeta* self, const ObjectID key) -> py::memoryview { std::shared_ptr buffer; diff --git a/python/vineyard/contrib/ml/README.md b/python/vineyard/contrib/ml/README.md index 88f8beaf4..6b790796d 100644 --- a/python/vineyard/contrib/ml/README.md +++ b/python/vineyard/contrib/ml/README.md @@ -10,6 +10,8 @@ and inference tasks in these frameworks. Examples -------- +### Datasets + The following examples shows how `DataFrame` in vineyard can be used as the input of Dataset for PyTorch: @@ -49,6 +51,42 @@ for data, label in pipe: pass ``` +### Pytorch Modules + +The following example shows how to use vineyard to share pytorch modules between processes: + +```python +import torch +import vineyard + +# connected to vineyard, see also: https://v6d.io/notes/getting-started.html +client = vineyard.connect(os.environ['VINEYARD_IPC_SOCKET']) + +# generate a dummy model in vineyard +class Model(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(1, 20, 5) + self.conv2 = nn.Conv2d(20, 20, 5) + + def forward(self, x): + x = F.relu(self.conv1(x)) + return F.relu(self.conv2(x)) + +model = Model() + +# put the model into vineyard +from vineyard.contrib.ml.torch import torch_context +with torch_context(): + object_id = client.put(model) + +# get the module state dict from vineyard and load it into a new model +model = Model() +with torch_context(): + state_dict = client.get(object_id) +model.load_state_dict(state_dict, assign=True) +``` + Reference and Implementation ---------------------------- diff --git a/python/vineyard/contrib/ml/tests/test_torch.py b/python/vineyard/contrib/ml/tests/test_torch.py index fb4fd4a4b..ee69e5b35 100644 --- a/python/vineyard/contrib/ml/tests/test_torch.py +++ b/python/vineyard/contrib/ml/tests/test_torch.py @@ -16,6 +16,11 @@ # limitations under the License. # +import copy +import itertools +from typing import Any +from typing import Dict + import numpy as np import pandas as pd import pyarrow as pa @@ -30,6 +35,8 @@ from vineyard.data.dataframe import NDArrayArray torch = lazy_import.lazy_module("torch") +nn = lazy_import.lazy_module("torch.nn") +F = lazy_import.lazy_module("torch.nn.functional") torchdata = lazy_import.lazy_module("torchdata") @@ -130,3 +137,41 @@ def test_torch_dataset_table(vineyard_client): assert torch.isclose( value.tensors[2], torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float64) ).all() + + +class Model(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(1, 20, 5) + self.conv2 = nn.Conv2d(20, 20, 5) + + def forward(self, x): + x = F.relu(self.conv1(x)) + return F.relu(self.conv2(x)) + + +def assert_torch_module_equal(model1, model2): + assert isinstance(model1, nn.Module) + assert isinstance(model2, nn.Module) + assert len(list(model1.parameters())) == len(list(model2.parameters())) + for p1, p2 in zip(model1.parameters(), model2.parameters()): + assert torch.allclose(p1, p2), f'{p1} != {p2}' + + +@pytest_cases.parametrize( + "vineyard_client,model", + itertools.product( + [vineyard_client, vineyard_rpc_client], + [nn.Linear(5, 2), nn.Conv2d(1, 20, 5), Model()], + ), +) +def test_torch_module(vineyard_client, model): + object_id = vineyard_client.put(model) + value: Dict[str, Any] = vineyard_client.get(object_id) + + result = copy.deepcopy(model) + result.to(torch.device('meta')) + result.load_state_dict(value, assign=True) + + # check the module's equality + assert_torch_module_equal(model, result) diff --git a/python/vineyard/contrib/ml/torch.py b/python/vineyard/contrib/ml/torch.py index 7a5d9b4b5..96ee6f29b 100644 --- a/python/vineyard/contrib/ml/torch.py +++ b/python/vineyard/contrib/ml/torch.py @@ -30,10 +30,10 @@ from vineyard._C import ObjectMeta from vineyard.core import context +from vineyard.data.utils import from_json from vineyard.data.utils import to_json torch = lazy_import.lazy_module("torch") -torchdata = lazy_import.lazy_module("torchdata") class WholeBatchSampler(torch.utils.data.Sampler[List[int]]): @@ -137,25 +137,9 @@ def torch_global_dataframe_resolver(obj, resolver, **_kw): return torch.utils.data.ConcatDataset(data) -def register_torch_types(builder_ctx, resolver_ctx): - if builder_ctx is not None: - builder_ctx.register(torch.Tensor, torch_tensor_builder) - builder_ctx.register(torch.utils.data.Dataset, torch_dataset_builder) - - if resolver_ctx is not None: - resolver_ctx.register('vineyard::Tensor', torch_tensor_resolver) - resolver_ctx.register('vineyard::DataFrame', torch_dataset_resolver) - resolver_ctx.register('vineyard::RecordBatch', torch_dataset_resolver) - resolver_ctx.register('vineyard::Table', torch_dataset_resolver) - resolver_ctx.register('vineyard::GlobalTensor', torch_global_tensor_resolver) - resolver_ctx.register( - 'vineyard::GlobalDataFrame', torch_global_dataframe_resolver - ) - - def datapipe( dataset: torch.utils.data.Dataset, -) -> torchdata.datapipes.iter.IterableWrapper: +): # -> "torchdata.datapipes.iter.IterableWrapper": '''Convert a torch.utils.data.Dataset to a torchdata.datapipes.iter.IterableWrapper. e.g., @@ -182,9 +166,91 @@ def datapipe( Returns: A torchdata.datapipes.iter.IterableWrapper. ''' + import torchdata + return torchdata.datapipes.iter.IterableWrapper(dataset) +def torch_module_builder(client, value, builder, **kw): + def go(state_dict, key_prefix, tensors): + if isinstance(state_dict, torch.Tensor): + r = builder.run(client, state_dict, **kw) + tensors[key_prefix] = r + if isinstance(r, ObjectMeta): + r = r.id + return r + elif isinstance(state_dict, dict): + keys = list(state_dict.keys()) + for key in keys: + state_dict[key] = go(state_dict[key], f'{key_prefix}.{key}', tensors) + return state_dict + elif isinstance(state_dict, (tuple, list)): + return [ + go(element, f'{key_prefix}.{i}', tensors) + for i, element in enumerate(state_dict) + ] + else: + return state_dict + + if isinstance(value, torch.nn.Module): + value = value.state_dict() + + tensors = dict() + value = go(value, 'tensor', tensors) + + meta = ObjectMeta() + meta['typename'] = 'vineyard::torch::Module' + meta['state_dict'] = to_json(value) + for key, tensor in tensors.items(): + meta.add_member(key, tensor) + return client.create_metadata(meta) + + +def torch_module_resolver(obj, resolver, **kw): + def go(state_dict, key_prefix, tensors): + if key_prefix in tensors: + return tensors[key_prefix] + elif isinstance(state_dict, dict): + keys = list(state_dict.keys()) + for key in keys: + state_dict[key] = go(state_dict[key], f'{key_prefix}.{key}', tensors) + return state_dict + elif isinstance(state_dict, (tuple, list)): + return [ + go(element, f'{key_prefix}.{i}', tensors) + for i, element in enumerate(state_dict) + ] + else: + return state_dict + + meta = obj.meta + state_dict = from_json(meta['state_dict']) + tensors = dict() + for key, value in meta.items(): + if key.startswith('tensor.'): + tensors[key] = resolver.run(value, **kw) + state_dict = go(state_dict, 'tensor', tensors) + return state_dict + + +def register_torch_types(builder_ctx, resolver_ctx): + if builder_ctx is not None: + builder_ctx.register(torch.Tensor, torch_tensor_builder) + builder_ctx.register(torch.utils.data.Dataset, torch_dataset_builder) + builder_ctx.register(torch.nn.Module, torch_module_builder) + + if resolver_ctx is not None: + resolver_ctx.register('vineyard::Tensor', torch_tensor_resolver) + resolver_ctx.register('vineyard::DataFrame', torch_dataset_resolver) + resolver_ctx.register('vineyard::RecordBatch', torch_dataset_resolver) + resolver_ctx.register('vineyard::Table', torch_dataset_resolver) + resolver_ctx.register('vineyard::GlobalTensor', torch_global_tensor_resolver) + resolver_ctx.register( + 'vineyard::GlobalDataFrame', torch_global_dataframe_resolver + ) + resolver_ctx.register('vineyard::torch::Module', torch_module_resolver) + + @contextlib.contextmanager def torch_context(): with context() as (builder_ctx, resolver_ctx): diff --git a/python/vineyard/core/resolver.py b/python/vineyard/core/resolver.py index 5f2d1c6e9..8def7d4f0 100644 --- a/python/vineyard/core/resolver.py +++ b/python/vineyard/core/resolver.py @@ -45,6 +45,9 @@ def __init__(self, parent_context: Optional["ResolverContext"] = None): def __str__(self) -> str: return str(self._factory) + def __repr__(self) -> str: + return repr(self._factory) + @property def parent_context(self) -> "ResolverContext": return self._parent_context diff --git a/python/vineyard/data/utils.py b/python/vineyard/data/utils.py index 2415ceb2d..48e0d82d4 100644 --- a/python/vineyard/data/utils.py +++ b/python/vineyard/data/utils.py @@ -184,6 +184,8 @@ def build_numpy_buffer(client, array): def default_json_encoder(value): if isinstance(value, (np.integer, np.floating)): return value.item() + if isinstance(value, ObjectID): + return int(value) raise TypeError