Skip to content

Commit

Permalink
Torch tensor from vineyard (numpy.ndarray): in the zero-copy way (#1726)
Browse files Browse the repository at this point in the history
Improve the performance of vineyard.get() for torch tensors.

Signed-off-by: Tao He <[email protected]>
  • Loading branch information
sighingnow authored Jan 22, 2024
1 parent 50f60cf commit 020cfa7
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 15 deletions.
4 changes: 2 additions & 2 deletions docs/notes/integration/ml.rst
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ Using Dataframe
.. code:: python
>>> df = pd.DataFrame({'a': [1, 2, 3, 4], 'b': [5, 6, 7, 8], 'c': [1.0, 2.0, 3.0, 4.0]})
>>> label = torch.tensor(df['c'].values.astype(np.float32))
>>> data = torch.tensor(df.drop('c', axis=1).values.astype(np.float32))
>>> label = torch.from_numpy(df['c'].values.astype(np.float32))
>>> data = torch.from_numpy(df.drop('c', axis=1).values.astype(np.float32))
>>> dataset = torch.utils.data.TensorDataset(data, label)
>>> data_id = vineyard_client.put(dataset, typename='Dataframe', cols=['a', 'b', 'c'], label='c')
>>> vin_data = vineyard_client.get(data_id, label='c')
Expand Down
23 changes: 13 additions & 10 deletions python/vineyard/contrib/ml/tests/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,10 @@ def test_torch_tensor(vineyard_client):
@pytest_cases.parametrize("vineyard_client", [vineyard_client, vineyard_rpc_client])
def test_torch_dataset(vineyard_client):
dataset = torch.utils.data.TensorDataset(
*[torch.tensor(np.random.rand(2, 3)), torch.tensor(np.random.rand(2, 3))],
*[
torch.from_numpy(np.random.rand(2, 3)),
torch.from_numpy(np.random.rand(2, 3)),
],
)
object_id = vineyard_client.put(dataset)
value = vineyard_client.get(object_id)
Expand All @@ -83,10 +86,10 @@ def test_torch_dataset_dataframe(vineyard_client):
assert isinstance(value, torch.utils.data.TensorDataset)
assert len(df.columns) == len(value.tensors)

assert torch.isclose(value.tensors[0], torch.tensor([1, 2, 3, 4])).all()
assert torch.isclose(value.tensors[1], torch.tensor([5, 6, 7, 8])).all()
assert torch.isclose(value.tensors[0], torch.from_numpy(np.array([1, 2, 3, 4], dtype=np.int64))).all()
assert torch.isclose(value.tensors[1], torch.from_numpy(np.array([5, 6, 7, 8], dtype=np.int64))).all()
assert torch.isclose(
value.tensors[2], torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float64)
value.tensors[2], torch.from_numpy(np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float64))
).all()


Expand Down Expand Up @@ -115,10 +118,10 @@ def test_torch_dataset_recordbatch(vineyard_client):
assert isinstance(value, torch.utils.data.TensorDataset)
assert len(df.columns) == len(value.tensors)

assert torch.isclose(value.tensors[0], torch.tensor([1, 2, 3, 4])).all()
assert torch.isclose(value.tensors[1], torch.tensor([5, 6, 7, 8])).all()
assert torch.isclose(value.tensors[0], torch.from_numpy(np.array([1, 2, 3, 4], dtype=np.int64))).all()
assert torch.isclose(value.tensors[1], torch.from_numpy(np.array([5, 6, 7, 8], dtype=np.int64))).all()
assert torch.isclose(
value.tensors[2], torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float64)
value.tensors[2], torch.from_numpy(np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float64))
).all()


Expand All @@ -132,10 +135,10 @@ def test_torch_dataset_table(vineyard_client):
assert isinstance(value, torch.utils.data.TensorDataset)
assert len(df.columns) == len(value.tensors)

assert torch.isclose(value.tensors[0], torch.tensor([1, 2, 3, 4])).all()
assert torch.isclose(value.tensors[1], torch.tensor([5, 6, 7, 8])).all()
assert torch.isclose(value.tensors[0], torch.from_numpy(np.array([1, 2, 3, 4], dtype=np.int64))).all()
assert torch.isclose(value.tensors[1], torch.from_numpy(np.array([5, 6, 7, 8], dtype=np.int64))).all()
assert torch.isclose(
value.tensors[2], torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float64)
value.tensors[2], torch.from_numpy(np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float64))
).all()


Expand Down
12 changes: 9 additions & 3 deletions python/vineyard/contrib/ml/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#

import contextlib
from collections import OrderedDict
from typing import Iterable
from typing import Iterator
from typing import List
Expand Down Expand Up @@ -96,18 +97,21 @@ def torch_dataset_builder(client, value, builder, **kw):

def torch_tensor_resolver(obj, resolver, **kw):
value = resolver.parent_context.run(obj, **kw)
return torch.tensor(value)
return torch.from_numpy(value)


def torch_dataset_resolver(obj, resolver, **kw):
value = resolver.parent_context.run(obj, **kw)
if isinstance(value, pd.DataFrame):
return torch.utils.data.TensorDataset(
*[torch.tensor(np.array(value[column].values)) for column in value.columns]
*[
torch.from_numpy(np.array(value[column].values))
for column in value.columns
]
)
elif isinstance(value, (pa.Table, pa.RecordBatch)):
return torch.utils.data.TensorDataset(
*[torch.tensor(column.to_numpy()) for column in value.columns]
*[torch.from_numpy(column.to_numpy()) for column in value.columns]
)
else:
raise TypeError(f'torch dataset: unsupported type {type(value)}')
Expand Down Expand Up @@ -238,6 +242,8 @@ def register_torch_types(builder_ctx, resolver_ctx):
builder_ctx.register(torch.Tensor, torch_tensor_builder)
builder_ctx.register(torch.utils.data.Dataset, torch_dataset_builder)
builder_ctx.register(torch.nn.Module, torch_module_builder)
builder_ctx.register(dict, torch_module_builder)
builder_ctx.register(OrderedDict, torch_module_builder)

if resolver_ctx is not None:
resolver_ctx.register('vineyard::Tensor', torch_tensor_resolver)
Expand Down

0 comments on commit 020cfa7

Please sign in to comment.