Skip to content

Commit

Permalink
Enable sharing pytorch modules (#1695)
Browse files Browse the repository at this point in the history
Resolves part of #1659.

Signed-off-by: Tao He <[email protected]>
  • Loading branch information
sighingnow authored Dec 26, 2023
1 parent 5b0ef1a commit 84f3e4a
Show file tree
Hide file tree
Showing 6 changed files with 185 additions and 18 deletions.
13 changes: 13 additions & 0 deletions python/core.cc
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,19 @@ void bind_core(py::module& mod) {
return py::cast(self->GetMember(key));
},
doc::ObjectMeta_get_member)
// Bind `member`: look up `key` in the metadata tree and return the member
// object stored under it, or Python `None` when the key is absent.
.def(
"member" /* alias for get_member() */,
[](ObjectMeta* self, std::string const& key) -> py::object {
auto const& tree = self->MetaData();
auto iter = tree.find(key);
if (iter == tree.end()) {
// Missing key: hand back None to the Python caller.
return py::none();
}
// Only a non-empty object entry is accepted as a member; anything else
// trips the assertion below.
VINEYARD_ASSERT(iter->is_object() && !iter->empty(),
"The value is not a member, but a meta");
return py::cast(self->GetMember(key));
},
// Reuses the doc::ObjectMeta_get_member docstring.
doc::ObjectMeta_get_member)
.def("get_buffer",
[](ObjectMeta* self, const ObjectID key) -> py::memoryview {
std::shared_ptr<Buffer> buffer;
Expand Down
38 changes: 38 additions & 0 deletions python/vineyard/contrib/ml/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ and inference tasks in these frameworks.
Examples
--------

### Datasets

The following example shows how a `DataFrame` in vineyard can be used as the input
of a Dataset for PyTorch:

Expand Down Expand Up @@ -49,6 +51,42 @@ for data, label in pipe:
pass
```

### PyTorch Modules

The following example shows how to use vineyard to share PyTorch modules between processes:

```python
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import vineyard

# connect to vineyard, see also: https://v6d.io/notes/getting-started.html
client = vineyard.connect(os.environ['VINEYARD_IPC_SOCKET'])

# define a dummy model
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

model = Model()

# put the model into vineyard
from vineyard.contrib.ml.torch import torch_context
with torch_context():
    object_id = client.put(model)

# get the module state dict from vineyard and load it into a new model
model = Model()
with torch_context():
    state_dict = client.get(object_id)
model.load_state_dict(state_dict, assign=True)
```

Reference and Implementation
----------------------------

Expand Down
45 changes: 45 additions & 0 deletions python/vineyard/contrib/ml/tests/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@
# limitations under the License.
#

import copy
import itertools
from typing import Any
from typing import Dict

import numpy as np
import pandas as pd
import pyarrow as pa
Expand All @@ -30,6 +35,8 @@
from vineyard.data.dataframe import NDArrayArray

torch = lazy_import.lazy_module("torch")
nn = lazy_import.lazy_module("torch.nn")
F = lazy_import.lazy_module("torch.nn.functional")
torchdata = lazy_import.lazy_module("torchdata")


Expand Down Expand Up @@ -130,3 +137,41 @@ def test_torch_dataset_table(vineyard_client):
assert torch.isclose(
value.tensors[2], torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float64)
).all()


class Model(nn.Module):
    """Small two-layer convolutional network used as a test module."""

    def __init__(self):
        super().__init__()
        # conv1 is registered before conv2, which fixes the parameter order.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=20, out_channels=20, kernel_size=5)

    def forward(self, x):
        """Apply both convolutions, each followed by a ReLU."""
        hidden = F.relu(self.conv1(x))
        output = F.relu(self.conv2(hidden))
        return output


def assert_torch_module_equal(model1, model2):
    """Assert that two modules have pairwise-close parameters.

    Both arguments must be ``nn.Module`` instances with the same number of
    parameters; parameters are compared in iteration order with
    ``torch.allclose``.
    """
    assert isinstance(model1, nn.Module)
    assert isinstance(model2, nn.Module)

    lhs_params = list(model1.parameters())
    rhs_params = list(model2.parameters())
    assert len(lhs_params) == len(rhs_params)

    for p1, p2 in zip(lhs_params, rhs_params):
        assert torch.allclose(p1, p2), f'{p1} != {p2}'


@pytest_cases.parametrize(
    "vineyard_client,model",
    itertools.product(
        [vineyard_client, vineyard_rpc_client],
        [nn.Linear(5, 2), nn.Conv2d(1, 20, 5), Model()],
    ),
)
def test_torch_module(vineyard_client, model):
    """Round-trip a module through vineyard and compare its parameters."""
    # Store the module, then read back what was stored; the result is
    # loaded below as a state dict.
    stored_id = vineyard_client.put(model)
    state_dict: Dict[str, Any] = vineyard_client.get(stored_id)

    # Move the copy onto the meta device before loading, so the compared
    # values have to come from the fetched state dict.
    restored = copy.deepcopy(model)
    meta_device = torch.device('meta')
    restored.to(meta_device)
    restored.load_state_dict(state_dict, assign=True)

    # check the module's equality
    assert_torch_module_equal(model, restored)
102 changes: 84 additions & 18 deletions python/vineyard/contrib/ml/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@

from vineyard._C import ObjectMeta
from vineyard.core import context
from vineyard.data.utils import from_json
from vineyard.data.utils import to_json

torch = lazy_import.lazy_module("torch")
torchdata = lazy_import.lazy_module("torchdata")


class WholeBatchSampler(torch.utils.data.Sampler[List[int]]):
Expand Down Expand Up @@ -137,25 +137,9 @@ def torch_global_dataframe_resolver(obj, resolver, **_kw):
return torch.utils.data.ConcatDataset(data)


def register_torch_types(builder_ctx, resolver_ctx):
    """Register the torch builders and resolvers on the given contexts.

    Either context may be ``None``, in which case that half is skipped.
    """
    if builder_ctx is not None:
        builders = (
            (torch.Tensor, torch_tensor_builder),
            (torch.utils.data.Dataset, torch_dataset_builder),
        )
        for python_type, build in builders:
            builder_ctx.register(python_type, build)

    if resolver_ctx is not None:
        resolvers = (
            ('vineyard::Tensor', torch_tensor_resolver),
            ('vineyard::DataFrame', torch_dataset_resolver),
            ('vineyard::RecordBatch', torch_dataset_resolver),
            ('vineyard::Table', torch_dataset_resolver),
            ('vineyard::GlobalTensor', torch_global_tensor_resolver),
            ('vineyard::GlobalDataFrame', torch_global_dataframe_resolver),
        )
        for typename, resolve in resolvers:
            resolver_ctx.register(typename, resolve)

def datapipe(
dataset: torch.utils.data.Dataset,
) -> torchdata.datapipes.iter.IterableWrapper:
): # -> "torchdata.datapipes.iter.IterableWrapper":
'''Convert a torch.utils.data.Dataset to a torchdata.datapipes.iter.IterableWrapper.
e.g.,
Expand All @@ -182,9 +166,91 @@ def datapipe(
Returns:
A torchdata.datapipes.iter.IterableWrapper.
'''
import torchdata

return torchdata.datapipes.iter.IterableWrapper(dataset)


def torch_module_builder(client, value, builder, **kw):
    """Store a torch module's state dict in vineyard.

    Every tensor found in the state dict is built through ``builder.run`` and
    attached to the resulting metadata as a member named ``tensor.<path>``,
    where ``<path>`` is the dot-joined chain of dict keys / sequence indices
    leading to it. The remaining structure, with each tensor replaced by the
    value returned from the builder (or its ``id`` when that value is an
    ``ObjectMeta``), is stored as JSON under ``state_dict``.

    Args:
        client: The vineyard client used to create the metadata.
        value: A ``torch.nn.Module``; any other value is walked as-is as if
            it were already a state dict.
        builder: The builder context used to build the individual tensors.
        **kw: Forwarded to ``builder.run``.

    Returns:
        The result of ``client.create_metadata`` for the module object.
    """

    def go(state_dict, key_prefix, tensors):
        if isinstance(state_dict, torch.Tensor):
            r = builder.run(client, state_dict, **kw)
            tensors[key_prefix] = r
            return r.id if isinstance(r, ObjectMeta) else r
        if isinstance(state_dict, dict):
            # Build a fresh dict instead of assigning into the input: nested
            # containers (or a state dict passed in directly) may still be
            # owned by the caller and must keep their tensors.
            return {
                key: go(item, f'{key_prefix}.{key}', tensors)
                for key, item in state_dict.items()
            }
        if isinstance(state_dict, (tuple, list)):
            # Tuples come back as lists; both serialize to a JSON array.
            return [
                go(element, f'{key_prefix}.{i}', tensors)
                for i, element in enumerate(state_dict)
            ]
        return state_dict

    if isinstance(value, torch.nn.Module):
        value = value.state_dict()

    tensors = {}
    value = go(value, 'tensor', tensors)

    meta = ObjectMeta()
    meta['typename'] = 'vineyard::torch::Module'
    meta['state_dict'] = to_json(value)
    for key, tensor in tensors.items():
        meta.add_member(key, tensor)
    return client.create_metadata(meta)


def torch_module_resolver(obj, resolver, **kw):
    """Rebuild a state dict from a ``vineyard::torch::Module`` object.

    The JSON stored under ``state_dict`` is decoded, every metadata entry
    whose key starts with ``tensor.`` is resolved through ``resolver.run``,
    and each position in the decoded structure whose path matches such a key
    is replaced by the resolved value.

    Args:
        obj: The vineyard object; only its ``meta`` attribute is read.
        resolver: The resolver context used to resolve the members.
        **kw: Forwarded to ``resolver.run``.

    Returns:
        The decoded state dict with the resolved values substituted in.
    """
    meta = obj.meta
    decoded = from_json(meta['state_dict'])

    resolved = {}
    for name, member in meta.items():
        if name.startswith('tensor.'):
            resolved[name] = resolver.run(member, **kw)

    def restore(node, path):
        # A path that names a stored member wins over whatever placeholder
        # the JSON holds at that position.
        if path in resolved:
            return resolved[path]
        if isinstance(node, dict):
            for name in list(node):
                node[name] = restore(node[name], f'{path}.{name}')
            return node
        if isinstance(node, (tuple, list)):
            return [
                restore(item, f'{path}.{index}')
                for index, item in enumerate(node)
            ]
        return node

    return restore(decoded, 'tensor')


def register_torch_types(builder_ctx, resolver_ctx):
    """Register the torch builders and resolvers on the given contexts.

    Either context may be ``None``, in which case that half is skipped.
    """
    if builder_ctx is not None:
        builders = (
            (torch.Tensor, torch_tensor_builder),
            (torch.utils.data.Dataset, torch_dataset_builder),
            (torch.nn.Module, torch_module_builder),
        )
        for python_type, build in builders:
            builder_ctx.register(python_type, build)

    if resolver_ctx is not None:
        resolvers = (
            ('vineyard::Tensor', torch_tensor_resolver),
            ('vineyard::DataFrame', torch_dataset_resolver),
            ('vineyard::RecordBatch', torch_dataset_resolver),
            ('vineyard::Table', torch_dataset_resolver),
            ('vineyard::GlobalTensor', torch_global_tensor_resolver),
            ('vineyard::GlobalDataFrame', torch_global_dataframe_resolver),
            ('vineyard::torch::Module', torch_module_resolver),
        )
        for typename, resolve in resolvers:
            resolver_ctx.register(typename, resolve)


@contextlib.contextmanager
def torch_context():
with context() as (builder_ctx, resolver_ctx):
Expand Down
3 changes: 3 additions & 0 deletions python/vineyard/core/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ def __init__(self, parent_context: Optional["ResolverContext"] = None):
def __str__(self) -> str:
    """Return ``str()`` of ``self._factory``."""
    factory = self._factory
    return str(factory)

def __repr__(self) -> str:
    """Return ``repr()`` of ``self._factory``."""
    factory = self._factory
    return repr(factory)

@property
def parent_context(self) -> "ResolverContext":
    """Return the context stored in ``self._parent_context``."""
    parent = self._parent_context
    return parent
Expand Down
2 changes: 2 additions & 0 deletions python/vineyard/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,8 @@ def build_numpy_buffer(client, array):
def default_json_encoder(value):
    """Convert a value the JSON encoder cannot handle into a native one.

    NumPy integer, floating and boolean scalars are converted with
    ``.item()``; an ``ObjectID`` is converted with ``int()``.
    Presumably used as the ``default=`` hook of the JSON encoder -- confirm
    against the caller.

    Raises:
        TypeError: If ``value`` is of any other type.
    """
    if isinstance(value, (np.integer, np.floating, np.bool_)):
        return value.item()
    if isinstance(value, ObjectID):
        return int(value)
    # Name the offending type so serialization failures can be diagnosed.
    raise TypeError(
        f'Object of type {type(value).__name__} is not JSON serializable'
    )


Expand Down

0 comments on commit 84f3e4a

Please sign in to comment.