Skip to content

Commit

Permalink
Implement the torch builder and resolver to support bfloat16 (#1889)
Browse files Browse the repository at this point in the history
Fixes #1885

Signed-off-by: Ye Cao <[email protected]>
  • Loading branch information
dashanji authored May 21, 2024
1 parent 176f071 commit a403b12
Showing 1 changed file with 97 additions and 25 deletions.
122 changes: 97 additions & 25 deletions python/vineyard/contrib/ml/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#

import contextlib
import ctypes
import time
import warnings
from collections import OrderedDict
Expand All @@ -38,6 +39,7 @@
from vineyard._C import RemoteBlobBuilder
from vineyard.core import Client
from vineyard.core import context
from vineyard.data.utils import build_buffer
from vineyard.data.utils import from_json
from vineyard.data.utils import normalize_cpptype
from vineyard.data.utils import to_json
Expand Down Expand Up @@ -74,8 +76,96 @@ def __len__(self) -> int:
return 1


def torch_tensor_builder(client, value, builder, **kw):
    """Build a vineyard object for a torch tensor via the given builder.

    The tensor is exposed as a NumPy array with ``value.numpy()`` and that
    array is passed to ``builder.run`` together with ``client`` and any extra
    keyword arguments.  Whatever ``builder.run`` returns is returned as is.
    """
    as_ndarray = value.numpy()
    return builder.run(client, as_ndarray, **kw)
def build_torch_buffer(client, tensor):
    """Pass a tensor's raw memory to ``build_buffer``.

    A non-contiguous tensor is first materialised with ``contiguous()`` so
    that the ``nbytes`` bytes starting at ``data_ptr()`` form one dense run.
    The start address and byte count are then handed to ``build_buffer``,
    whose result is returned unchanged.
    """
    dense = tensor if tensor.is_contiguous() else tensor.contiguous()
    return build_buffer(client, dense.data_ptr(), dense.nbytes)


def normalize_tensor_dtype(dtype):
    """Resolve a dtype name such as ``'torch.bfloat16'`` to a ``torch.dtype``.

    ``str(some_torch_dtype)`` always has the form ``'torch.<attr>'`` where
    ``<attr>`` is the attribute of the ``torch`` module holding that dtype,
    so the name is looked up on the module instead of being matched against
    a hand-maintained list.  This covers every dtype (and alias such as
    ``'torch.float'`` or ``'torch.long'``) that the installed torch provides.

    Args:
        dtype: A ``torch.dtype`` instance, or a string of the form
            ``'torch.<name>'``.

    Returns:
        The matching ``torch.dtype``.  Anything that cannot be resolved
        (a string without the ``'torch.'`` prefix, a name that is not a
        dtype in the installed torch, or a non-string value) is returned
        unchanged.
    """
    if isinstance(dtype, torch.dtype):
        return dtype

    prefix = 'torch.'
    if isinstance(dtype, str) and dtype.startswith(prefix):
        # Only accept attributes that really are dtypes, so names such as
        # 'torch.nn' or 'torch.load' are never mistaken for one.
        candidate = getattr(torch, dtype[len(prefix):], None)
        if isinstance(candidate, torch.dtype):
            return candidate

    return dtype


def torch_tensor_builder(client, value, **kw):
    """Create vineyard metadata and a payload buffer for a torch tensor.

    The dtype is recorded by its torch name (``str(value.dtype)``), so
    dtypes without a NumPy equivalent, such as ``torch.bfloat16``, are
    stored without conversion.

    Args:
        client: Client used to create the buffer and the metadata.
        value: The torch tensor to store.
        **kw: Optional ``partition_index`` (defaults to ``[]``).

    Returns:
        The result of ``client.create_metadata`` for the assembled metadata.
    """
    # The payload is always written as one dense, row-major run of bytes, so
    # make the tensor contiguous up front and describe the stored layout as
    # 'C'.  A non-contiguous input (e.g. a transposed or strided view) is not
    # column-major in the buffer, so it must not be labelled 'F'.
    if not value.is_contiguous():
        value = value.contiguous()

    dtype_name = str(value.dtype)

    meta = ObjectMeta()
    meta['shape_'] = to_json(value.shape)
    meta['partition_index_'] = to_json(kw.get('partition_index', []))
    meta['nbytes'] = value.nbytes
    meta['order_'] = to_json('C')

    meta['typename'] = 'vineyard::Tensor<%s>' % dtype_name
    meta['value_type_'] = dtype_name
    meta.add_member('buffer_', build_torch_buffer(client, value))

    return client.create_metadata(meta)


def torch_tensor_resolver(obj):
    """Rebuild a torch tensor from a vineyard tensor object.

    Reads ``value_type_``, ``shape_`` and (optionally) ``order_`` from the
    object's metadata and wraps the memory of the ``buffer_`` member as a
    tensor of that dtype and shape.

    Args:
        obj: The vineyard object; must expose ``meta`` and
            ``member('buffer_').address``.

    Returns:
        A ``torch.Tensor`` with the recorded shape and dtype.  For an empty
        shape product a fresh zero-sized tensor is returned and the buffer
        is not touched.
    """
    meta = obj.meta
    value_type = normalize_tensor_dtype(meta['value_type_'])
    shape = from_json(meta['shape_'])
    # 'order_' is stored to_json-encoded by the builder, so the fallback has
    # to be encoded the same way before it goes through from_json.
    order = from_json(meta.get('order_', to_json('C')))

    element_count = int(np.prod(shape))
    if element_count == 0:
        return torch.zeros(shape, dtype=value_type)

    # View the blob's memory as a flat byte array of exactly the tensor size.
    # NOTE(review): the tensor aliases this memory rather than copying it;
    # confirm the underlying blob outlives the returned tensor.
    nbytes = element_count * value_type.itemsize
    buffer = (ctypes.c_char * nbytes).from_address(obj.member('buffer_').address)

    c_tensor = torch.frombuffer(buffer, dtype=value_type).reshape(shape)
    return c_tensor if order == 'C' else c_tensor.contiguous()


def torch_dataset_builder(client, value, builder, **kw):
Expand Down Expand Up @@ -103,13 +193,6 @@ def torch_dataset_builder(client, value, builder, **kw):
return client.create_metadata(meta)


def torch_tensor_resolver(obj, resolver, **kw):
    """Resolve a vineyard object to a torch tensor via the parent resolver.

    ``resolver.parent_context.run`` is called with ``obj`` and any extra
    keyword arguments, and its result is wrapped with ``torch.from_numpy``.
    ``UserWarning`` is ignored while the wrapping takes place.
    """
    ndarray = resolver.parent_context.run(obj, **kw)
    with warnings.catch_warnings():
        # Ignore UserWarning raised during the conversion only; the previous
        # warning filters are restored when the block exits.
        warnings.simplefilter("ignore", UserWarning)
        tensor = torch.from_numpy(ndarray)
    return tensor


def torch_dataset_resolver(obj, resolver, **kw):
value = resolver.parent_context.run(obj, **kw)
if isinstance(value, pd.DataFrame):
Expand Down Expand Up @@ -215,26 +298,15 @@ def put_torch_tensors(client, tensors) -> List[Union[ObjectID, ObjectMeta]]:
blobs = client.create_remote_blob(blob_writers)

metadatas = []
found_bfloat16 = False
for tensor, size, blob in zip(tensors, sizes, blobs):
if tensor.dtype == torch.bfloat16:
if not found_bfloat16:
warnings.warn(
"Important, bfloat16 is not supported by vineyard, "
"converting to float16 instead, which may cause precision loss."
)
found_bfloat16 = True
tensor = tensor.to(torch.float16)

value = tensor.numpy()
for tensor, size, blob in zip(tensors, sizes, blobs):
meta = ObjectMeta()
meta['typename'] = 'vineyard::Tensor<%s>' % normalize_cpptype(value.dtype)
meta['value_type_'] = value.dtype.name
meta['value_type_meta_'] = value.dtype.str
meta['shape_'] = to_json(value.shape)
meta['typename'] = 'vineyard::Tensor<%s>' % str(tensor.dtype)
meta['value_type_'] = str(tensor.dtype)
meta['shape_'] = to_json(tensor.shape)
meta['partition_index_'] = to_json([])
meta['nbytes'] = size
meta['order_'] = to_json(('C' if value.flags['C_CONTIGUOUS'] else 'F'))
meta['order_'] = to_json(('C' if tensor.is_contiguous() else 'F'))
meta.add_member('buffer_', blob)
metadatas.append(meta)

Expand Down

0 comments on commit a403b12

Please sign in to comment.