From 020cfa76445e3699a91638e8693957542c656cb7 Mon Sep 17 00:00:00 2001 From: Tao He Date: Mon, 22 Jan 2024 15:36:40 +0800 Subject: [PATCH] Torch tensor from vineyard (numpy.ndarray): in the zero-copy way (#1726) Improve the performance of vineyard.get() for torch tensors. Signed-off-by: Tao He --- docs/notes/integration/ml.rst | 4 ++-- .../vineyard/contrib/ml/tests/test_torch.py | 23 +++++++++++-------- python/vineyard/contrib/ml/torch.py | 12 +++++++--- 3 files changed, 24 insertions(+), 15 deletions(-) diff --git a/docs/notes/integration/ml.rst b/docs/notes/integration/ml.rst index db9df0607..16c383743 100644 --- a/docs/notes/integration/ml.rst +++ b/docs/notes/integration/ml.rst @@ -93,8 +93,8 @@ Using Dataframe .. code:: python >>> df = pd.DataFrame({'a': [1, 2, 3, 4], 'b': [5, 6, 7, 8], 'c': [1.0, 2.0, 3.0, 4.0]}) - >>> label = torch.tensor(df['c'].values.astype(np.float32)) - >>> data = torch.tensor(df.drop('c', axis=1).values.astype(np.float32)) + >>> label = torch.from_numpy(df['c'].values.astype(np.float32)) + >>> data = torch.from_numpy(df.drop('c', axis=1).values.astype(np.float32)) >>> dataset = torch.utils.data.TensorDataset(data, label) >>> data_id = vineyard_client.put(dataset, typename='Dataframe', cols=['a', 'b', 'c'], label='c') >>> vin_data = vineyard_client.get(data_id, label='c) diff --git a/python/vineyard/contrib/ml/tests/test_torch.py b/python/vineyard/contrib/ml/tests/test_torch.py index ee69e5b35..10662a61c 100644 --- a/python/vineyard/contrib/ml/tests/test_torch.py +++ b/python/vineyard/contrib/ml/tests/test_torch.py @@ -61,7 +61,10 @@ def test_torch_tensor(vineyard_client): @pytest_cases.parametrize("vineyard_client", [vineyard_client, vineyard_rpc_client]) def test_torch_dataset(vineyard_client): dataset = torch.utils.data.TensorDataset( - *[torch.tensor(np.random.rand(2, 3)), torch.tensor(np.random.rand(2, 3))], + *[ + torch.from_numpy(np.random.rand(2, 3)), + torch.from_numpy(np.random.rand(2, 3)), + ], ) object_id = 
vineyard_client.put(dataset) value = vineyard_client.get(object_id) @@ -83,10 +86,10 @@ def test_torch_dataset_dataframe(vineyard_client): assert isinstance(value, torch.utils.data.TensorDataset) assert len(df.columns) == len(value.tensors) - assert torch.isclose(value.tensors[0], torch.tensor([1, 2, 3, 4])).all() - assert torch.isclose(value.tensors[1], torch.tensor([5, 6, 7, 8])).all() + assert torch.isclose(value.tensors[0], torch.as_tensor([1, 2, 3, 4])).all() + assert torch.isclose(value.tensors[1], torch.as_tensor([5, 6, 7, 8])).all() assert torch.isclose( - value.tensors[2], torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float64) + value.tensors[2], torch.as_tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float64) ).all() @@ -115,10 +118,10 @@ def test_torch_dataset_recordbatch(vineyard_client): assert isinstance(value, torch.utils.data.TensorDataset) assert len(df.columns) == len(value.tensors) - assert torch.isclose(value.tensors[0], torch.tensor([1, 2, 3, 4])).all() - assert torch.isclose(value.tensors[1], torch.tensor([5, 6, 7, 8])).all() + assert torch.isclose(value.tensors[0], torch.as_tensor([1, 2, 3, 4])).all() + assert torch.isclose(value.tensors[1], torch.as_tensor([5, 6, 7, 8])).all() assert torch.isclose( - value.tensors[2], torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float64) + value.tensors[2], torch.as_tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float64) ).all() @@ -132,10 +135,10 @@ def test_torch_dataset_table(vineyard_client): assert isinstance(value, torch.utils.data.TensorDataset) assert len(df.columns) == len(value.tensors) - assert torch.isclose(value.tensors[0], torch.tensor([1, 2, 3, 4])).all() - assert torch.isclose(value.tensors[1], torch.tensor([5, 6, 7, 8])).all() + assert torch.isclose(value.tensors[0], torch.as_tensor([1, 2, 3, 4])).all() + assert torch.isclose(value.tensors[1], torch.as_tensor([5, 6, 7, 8])).all() assert torch.isclose( - value.tensors[2], torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float64) + 
value.tensors[2], torch.as_tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float64) ).all() diff --git a/python/vineyard/contrib/ml/torch.py b/python/vineyard/contrib/ml/torch.py index 96ee6f29b..deee97101 100644 --- a/python/vineyard/contrib/ml/torch.py +++ b/python/vineyard/contrib/ml/torch.py @@ -17,6 +17,7 @@ # import contextlib +from collections import OrderedDict from typing import Iterable from typing import Iterator from typing import List @@ -96,18 +97,21 @@ def torch_dataset_builder(client, value, builder, **kw): def torch_tensor_resolver(obj, resolver, **kw): value = resolver.parent_context.run(obj, **kw) - return torch.tensor(value) + return torch.from_numpy(value) def torch_dataset_resolver(obj, resolver, **kw): value = resolver.parent_context.run(obj, **kw) if isinstance(value, pd.DataFrame): return torch.utils.data.TensorDataset( - *[torch.tensor(np.array(value[column].values)) for column in value.columns] + *[ + torch.from_numpy(np.array(value[column].values)) + for column in value.columns + ] ) elif isinstance(value, (pa.Table, pa.RecordBatch)): return torch.utils.data.TensorDataset( - *[torch.tensor(column.to_numpy()) for column in value.columns] + *[torch.from_numpy(column.to_numpy()) for column in value.columns] ) else: raise TypeError(f'torch dataset: unsupported type {type(value)}') @@ -238,6 +242,8 @@ def register_torch_types(builder_ctx, resolver_ctx): builder_ctx.register(torch.Tensor, torch_tensor_builder) builder_ctx.register(torch.utils.data.Dataset, torch_dataset_builder) builder_ctx.register(torch.nn.Module, torch_module_builder) + builder_ctx.register(dict, torch_module_builder) + builder_ctx.register(OrderedDict, torch_module_builder) if resolver_ctx is not None: resolver_ctx.register('vineyard::Tensor', torch_tensor_resolver)