diff --git a/python/client.cc b/python/client.cc
index 933b97c7d..fffca4a40 100644
--- a/python/client.cc
+++ b/python/client.cc
@@ -832,6 +832,10 @@ void bind_client(py::module& mod) {
       [](RPCClient* self,
          const std::vector<std::shared_ptr<RemoteBlobWriter>>&
              remote_blob_builders) -> std::vector<ObjectMeta> {
+        // Release the GIL to avoid blocking other threads
+        // See also
+        // https://pybind11.readthedocs.io/en/stable/advanced/misc.html#global-interpreter-lock-gil
+        py::gil_scoped_release release;
         std::vector<ObjectMeta> blob_metas;
         throw_on_error(
             self->CreateRemoteBlobs(remote_blob_builders, blob_metas));
diff --git a/python/vineyard/contrib/ml/README.md b/python/vineyard/contrib/ml/README.md
index 6b790796d..f6f18335d 100644
--- a/python/vineyard/contrib/ml/README.md
+++ b/python/vineyard/contrib/ml/README.md
@@ -87,6 +87,30 @@ with torch_context():
     model.load_state_dict(state_dict, assign=True)
 ```
 
+By default, compression is enabled for the vineyard client. Since compression may not be efficient for torch modules, you can disable it as follows:
+
+```python
+from vineyard.contrib.ml.torch import torch_context
+# pass the client to torch_context to disable compression
+with torch_context(client):
+    object_id = client.put(model)
+
+# pass the client to torch_context to disable compression
+with torch_context(client):
+    state_dict = client.get(object_id)
+```
+
+Besides, if you want to spread the torch modules across all vineyard workers to aggregate the network bandwidth of the whole cluster, you can enable the spread option as follows:
+
+```python
+from vineyard.contrib.ml.torch import torch_context
+with torch_context(client, spread=True):
+    object_id = client.put(model)
+
+with torch_context(client):
+    state_dict = client.get(object_id)
+```
+
 Reference and Implementation
 ----------------------------
 
diff --git a/python/vineyard/contrib/ml/torch.py b/python/vineyard/contrib/ml/torch.py
index be9f11d0c..9694504fc 100644
--- a/python/vineyard/contrib/ml/torch.py
+++ b/python/vineyard/contrib/ml/torch.py
@@ -18,9 +18,11 @@
 import contextlib
 import ctypes
-import time
 import warnings
 from collections import OrderedDict
+from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import as_completed
+from math import ceil
 from typing import Iterable
 from typing import Iterator
 from typing import List
 from typing import Union
@@ -33,6 +35,7 @@
 import lazy_import
 
 import vineyard
+from vineyard import envvars
 from vineyard._C import NotEnoughMemoryException
 from vineyard._C import ObjectID
 from vineyard._C import ObjectMeta
@@ -268,6 +271,83 @@ def datapipe(
     return torchdata.datapipes.iter.IterableWrapper(dataset)
 
 
+def distribute_tensors(client, tensor_values):
+    cluster_info = client.meta
+    instance_ids = cluster_info.keys()
+    chunk_size = len(cluster_info)
+
+    def split_tensors_into_chunks(tensor_values, chunk_size):
+        average_size = ceil(
+            sum(t.numel() * t.element_size() for t in tensor_values) / chunk_size
+        )
+        current_size = 0
+        tensor_chunks = []
+        current_chunk = []
+        for t in tensor_values:
+            if current_size >= average_size and current_chunk:
+                tensor_chunks.append(current_chunk)
+                current_size = 0
+                current_chunk = []
+            current_chunk.append(t)
+            current_size += t.numel() * t.element_size()
+        if current_chunk:
+            tensor_chunks.append(current_chunk)
+        return tensor_chunks
+
+    tensor_chunks = split_tensors_into_chunks(tensor_values, chunk_size)
+
+    def thread_put_torch_tensors(
+        cluster_info, instance_id, tensor_chunk, client, output_objects
+    ):
+        compression = client.compression
+        connected_instance_id = (
+            client.instance_id if client.is_ipc else client.remote_instance_id
+        )
+        rpc_client = None
+        if connected_instance_id != instance_id:
+            instance_status = cluster_info.get(instance_id)
+            if instance_status is None or instance_status['rpc_endpoint'] is None:
+                raise RuntimeError(
+                    "The rpc endpoint of the vineyard instance "
+                    f"{instance_id} is not available."
+                )
+
+            host, port = cluster_info[instance_id]['rpc_endpoint'].split(':')
+            try:
+                with envvars('VINEYARD_RPC_SKIP_RETRY', '1'):
+                    rpc_client = vineyard.connect(host=host, port=int(port))
+                    rpc_client.compression = compression
+            except Exception as exc:
+                raise RuntimeError(
+                    f"Failed to connect to the vineyard instance {instance_id} "
+                    f"at {host}:{port}."
+                ) from exc
+        used_client = rpc_client if rpc_client else client
+        result = put_torch_tensors(used_client, tensor_chunk)
+        output_objects[instance_id] = result
+
+    tensor_objects_dict = {}
+    with ThreadPoolExecutor() as executor:
+        futures = []
+        for instance_id, tensor_chunk in zip(instance_ids, tensor_chunks):
+            future = executor.submit(
+                thread_put_torch_tensors,
+                cluster_info,
+                instance_id,
+                tensor_chunk,
+                client,
+                tensor_objects_dict,
+            )
+            futures.append(future)
+        for future in as_completed(futures):
+            future.result()
+
+    tensor_objects = []
+    for instance_id in instance_ids:
+        tensor_objects.extend(tensor_objects_dict[instance_id])
+    return tensor_objects
+
+
 def put_torch_tensors(client, tensors) -> List[Union[ObjectID, ObjectMeta]]:
     pointers, sizes = [], []
     tensors = [tensor.contiguous() for tensor in tensors]
@@ -359,8 +439,11 @@ def assign(state_dict, key_prefix, tensors):
 
     go(value, 'tensor', tensors)
 
     tensor_keys, tensor_values = list(tensors.keys()), list(tensors.values())
-    tensor_objects = put_torch_tensors(client, tensor_values)
+    if client.spread:
+        tensor_objects = distribute_tensors(client, tensor_values)
+    else:
+        tensor_objects = put_torch_tensors(client, tensor_values)
     tensors = dict(zip(tensor_keys, tensor_objects))
 
     new_value = assign(value, 'tensor', tensors)
@@ -369,7 +452,10 @@ def assign(state_dict, key_prefix, tensors):
     meta['state_dict'] = to_json(new_value)
     for key, tensor in tensors.items():
         meta.add_member(key, tensor)
-    return client.create_metadata(meta)
+    if client.spread:
+        meta.set_global(True)
+    o = client.create_metadata(meta)
+    return o
 
 
 def torch_module_resolver(obj, resolver, **kw):
@@ -420,13 +506,14 @@ def register_torch_types(builder_ctx, resolver_ctx):
 
 
 @contextlib.contextmanager
-def torch_context(client: Client = None):
+def torch_context(client: Client = None, spread=False):
     if client is not None:
         with client.with_compression(False):
-            with context() as (builder_ctx, resolver_ctx):
-                with contextlib.suppress(ImportError):
-                    register_torch_types(builder_ctx, resolver_ctx)
-                yield builder_ctx, resolver_ctx
+            with client.with_spread(spread):
+                with context() as (builder_ctx, resolver_ctx):
+                    with contextlib.suppress(ImportError):
+                        register_torch_types(builder_ctx, resolver_ctx)
+                    yield builder_ctx, resolver_ctx
     else:
         with context() as (builder_ctx, resolver_ctx):
             with contextlib.suppress(ImportError):
diff --git a/python/vineyard/core/client.py b/python/vineyard/core/client.py
index 35434fea9..6d83519c1 100644
--- a/python/vineyard/core/client.py
+++ b/python/vineyard/core/client.py
@@ -274,6 +274,8 @@ def __init__(
             except VineyardException:
                 continue
 
+        self._spread = False
+        self._compression = True
         if self._ipc_client is None and self._rpc_client is None:
             raise ConnectionError(
                "Failed to connect to vineyard via both IPC and RPC connection. "
@@ -287,12 +289,22 @@ def compression(self) -> bool:
         '''Whether the compression is enabled for underlying RPC client.'''
         if self._rpc_client:
             return self._rpc_client.compression
-        return None
+        return self._compression
 
     @compression.setter
     def compression(self, value: bool = True):
         if self._rpc_client:
             self._rpc_client.compression = value
+        self._compression = value
+
+    @property
+    def spread(self) -> bool:
+        '''Whether spreading objects across the cluster is enabled.'''
+        return self._spread
+
+    @spread.setter
+    def spread(self, value: bool = False):
+        self._spread = value
 
     @property
     def ipc_client(self) -> IPCClient:
@@ -789,5 +801,13 @@ def with_compression(self, enabled: bool = True):
         yield
         self.compression = compression
 
+    @contextlib.contextmanager
+    def with_spread(self, enabled: bool = True):
+        """Enable spread for the following put operations."""
+        tmp_spread = self._spread
+        self.spread = enabled
+        yield
+        self.spread = tmp_spread
+
 
 __all__ = ['Client']
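
For reference, a minimal end-to-end sketch of the two options this patch introduces: passing the client to `torch_context` disables compression, and `spread=True` distributes the tensor payloads across the cluster. The endpoint `127.0.0.1:9600` and the toy module are placeholders, and a running vineyard cluster with multiple workers is assumed.

```python
import torch
import vineyard

from vineyard.contrib.ml.torch import torch_context

# hypothetical RPC endpoint; replace with the address of a real vineyardd
client = vineyard.connect(host='127.0.0.1', port=9600)

model = torch.nn.Linear(8, 4)  # toy module, used only for illustration

# passing the client disables compression; spread=True additionally
# spreads the tensors across all vineyard workers
with torch_context(client, spread=True):
    object_id = client.put(model)

# reading back only needs the plain torch_context with the client
with torch_context(client):
    state_dict = client.get(object_id)
model.load_state_dict(state_dict, assign=True)
```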