Skip to content

Commit

Permalink
Reduce memory overhead for TransportableObject
Browse files Browse the repository at this point in the history
* Always represent TransportableObject internally as a single array of
bytes. Various properties, such as `header`, or `object_string`, decode
various segments of the byte array.

* Store the serialized object as raw picklebytes without
base64-encoding. As a result, `get_deserialized()` no longer needs to
create a temporary copy of the raw picklebytes. The data segment is
directly unpickled. Base64-encoding is applied to the data segment or
the entire internal buffer whenever a print friendly representation of
the `TransportableObject` is desired.

* Since the properties of `TransportableObject` are simply views into
the underlying buffer, `TransportableObject` may itself be serialized
efficiently by simply writing out the byte array.
  • Loading branch information
cjao committed Feb 29, 2024
1 parent d7841c7 commit e890bc5
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 184 deletions.
211 changes: 74 additions & 137 deletions covalent/_workflow/transportable_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,76 +31,9 @@
BYTE_ORDER = "big"


class _TOArchive:
    """Archived transportable object.

    Holds the three serialized segments of a TransportableObject (JSON
    header, printable object string, and b64 picklebytes) and knows how to
    concatenate them behind two fixed-width byte offsets.
    """

    def __init__(self, header: bytes, object_string: bytes, data: bytes):
        """
        Initialize TOArchive.

        Args:
            header: Archived transportable object header.
            object_string: Archived transportable object string.
            data: Archived transportable object data.

        Returns:
            None
        """

        self.header = header
        self.object_string = object_string
        self.data = data

    def cat(self) -> bytes:
        """
        Concatenate TOArchive.

        Layout: [string_offset][data_offset][header][object_string][data],
        where both offsets are fixed-width big-endian unsigned integers
        measured from the start of the buffer.

        Returns:
            Concatenated TOArchive.
        """

        header_size = len(self.header)
        string_size = len(self.object_string)
        # Offsets are absolute positions of the string and data segments.
        data_offset = STRING_OFFSET_BYTES + DATA_OFFSET_BYTES + header_size + string_size
        string_offset = STRING_OFFSET_BYTES + DATA_OFFSET_BYTES + header_size

        data_offset = data_offset.to_bytes(DATA_OFFSET_BYTES, BYTE_ORDER, signed=False)
        string_offset = string_offset.to_bytes(STRING_OFFSET_BYTES, BYTE_ORDER, signed=False)

        return string_offset + data_offset + self.header + self.object_string + self.data

    @staticmethod
    def load(serialized: bytes, header_only: bool, string_only: bool) -> "_TOArchive":
        """
        Load TOArchive object from serialized bytes.

        Args:
            serialized: Serialized transportable object.
            header_only: Load header only.
            string_only: Load header and object string only (skip data).

        Returns:
            Archived transportable object.
        """

        string_offset = TOArchiveUtils.string_offset(serialized)
        header = TOArchiveUtils.parse_header(serialized, string_offset)
        object_string = b""
        data = b""

        if not header_only:
            # Compute the data offset once; it delimits the string segment and
            # starts the data segment. Keeping the data parse nested here
            # guarantees data_offset is defined before it is used.
            data_offset = TOArchiveUtils.data_offset(serialized)
            object_string = TOArchiveUtils.parse_string(serialized, string_offset, data_offset)

            if not string_only:
                data = TOArchiveUtils.parse_data(serialized, data_offset)

        return _TOArchive(header, object_string, data)


class TOArchiveUtils:
"""Utilities for reading serialized TransportableObjects"""

@staticmethod
def data_offset(serialized: bytes) -> int:
size64 = serialized[STRING_OFFSET_BYTES : STRING_OFFSET_BYTES + DATA_OFFSET_BYTES]
Expand All @@ -120,24 +53,38 @@ def string_byte_range(serialized: bytes) -> Tuple[int, int]:

@staticmethod
def data_byte_range(serialized: bytes) -> Tuple[int, int]:
"""Return byte range for the b64 picklebytes"""
"""Return byte range for the picklebytes"""
start_byte = TOArchiveUtils.data_offset(serialized)
return start_byte, -1

@staticmethod
def parse_header(serialized: bytes, string_offset: int) -> bytes:
def header(serialized: bytes) -> dict:
string_offset = TOArchiveUtils.string_offset(serialized)
header = serialized[HEADER_OFFSET:string_offset]
return header
return json.loads(header.decode("utf-8"))

@staticmethod
def parse_string(serialized: bytes, string_offset: int, data_offset: int) -> bytes:
def string_segment(serialized: bytes) -> bytes:
string_offset = TOArchiveUtils.string_offset(serialized)
data_offset = TOArchiveUtils.data_offset(serialized)
return serialized[string_offset:data_offset]

@staticmethod
def parse_data(serialized: bytes, data_offset: int) -> bytes:
def data_segment(serialized: bytes) -> bytes:
data_offset = TOArchiveUtils.data_offset(serialized)
return serialized[data_offset:]


class _ByteArrayFile:
"""File-like interface for appending to a bytearray."""

def __init__(self, buf: bytearray):
self._buf = buf

def write(self, data: bytes):
self._buf.extend(data)


class TransportableObject:
"""
A function is converted to a transportable object by serializing it using cloudpickle
Expand All @@ -150,38 +97,65 @@ class TransportableObject:
"""

def __init__(self, obj: Any) -> None:
b64object = base64.b64encode(cloudpickle.dumps(obj))
object_string_u8 = str(obj).encode("utf-8")
self._buffer = bytearray()

self._object = b64object.decode("utf-8")
self._object_string = object_string_u8.decode("utf-8")
# Reserve space for the byte offsets to be written at the end
self._buffer.extend(b"\0" * HEADER_OFFSET)

self._header = {
"py_version": platform.python_version(),
"cloudpickle_version": cloudpickle.__version__,
_header = {
"format": "0.1",
"ver": {
"python": platform.python_version(),
"cloudpickle": cloudpickle.__version__,
},
"attrs": {
"doc": getattr(obj, "__doc__", ""),
"name": getattr(obj, "__name__", ""),
},
}

# Write header and object string
header_u8 = json.dumps(_header).encode("utf-8")
header_len = len(header_u8)

object_string_u8 = str(obj).encode("utf-8")
object_string_len = len(object_string_u8)

self._buffer.extend(header_u8)
self._buffer.extend(object_string_u8)
del object_string_u8

# Append picklebytes
cloudpickle.dump(obj, _ByteArrayFile(self._buffer))

# Write byte offsets
string_offset = HEADER_OFFSET + header_len
data_offset = string_offset + object_string_len

string_offset_bytes = string_offset.to_bytes(STRING_OFFSET_BYTES, BYTE_ORDER)
data_offset_bytes = data_offset.to_bytes(DATA_OFFSET_BYTES, BYTE_ORDER)
self._buffer[:STRING_OFFSET_BYTES] = string_offset_bytes
self._buffer[STRING_OFFSET_BYTES:HEADER_OFFSET] = data_offset_bytes

@property
def python_version(self):
return self._header["py_version"]
return self.header["ver"]["python"]

@property
def header(self):
return self._header
return TOArchiveUtils.header(self._buffer)

@property
def attrs(self):
return self._header["attrs"]
return self.header["attrs"]

Check warning on line 150 in covalent/_workflow/transportable_object.py

View check run for this annotation

Codecov / codecov/patch

covalent/_workflow/transportable_object.py#L150

Added line #L150 was not covered by tests

@property
def object_string(self):
# For compatibility with older Covalent
try:
return self._object_string
return (
TOArchiveUtils.string_segment(memoryview(self._buffer)).tobytes().decode("utf-8")
)
except AttributeError:
return self.__dict__["object_string"]

Expand All @@ -202,11 +176,15 @@ def get_deserialized(self) -> Callable:
"""

return cloudpickle.loads(base64.b64decode(self._object.encode("utf-8")))
return cloudpickle.loads(TOArchiveUtils.data_segment(memoryview(self._buffer)))

def to_dict(self) -> dict:
"""Return a JSON-serializable dictionary representation of self"""
return {"type": "TransportableObject", "attributes": self.__dict__.copy()}
attr_dict = {
"buffer_b64": base64.b64encode(memoryview(self._buffer)).decode("utf-8"),
}

return {"type": "TransportableObject", "attributes": attr_dict}

@staticmethod
def from_dict(object_dict) -> "TransportableObject":
Expand All @@ -220,7 +198,7 @@ def from_dict(object_dict) -> "TransportableObject":
"""

sc = TransportableObject(None)
sc.__dict__ = object_dict["attributes"]
sc._buffer = base64.b64decode(object_dict["attributes"]["buffer_b64"].encode("utf-8"))
return sc

def get_serialized(self) -> str:
Expand All @@ -234,7 +212,9 @@ def get_serialized(self) -> str:
object: The serialized transportable object.
"""

return self._object
# For backward compatibility
data_segment = TOArchiveUtils.data_segment(memoryview(self._buffer))
return base64.b64encode(data_segment).decode("utf-8")

def serialize(self) -> bytes:
"""
Expand All @@ -247,7 +227,7 @@ def serialize(self) -> bytes:
pickled_object: The serialized object along with the python version.
"""

return _to_archive(self).cat()
return self._buffer

def serialize_to_json(self) -> str:
"""
Expand Down Expand Up @@ -296,9 +276,7 @@ def make_transportable(obj) -> "TransportableObject":
return TransportableObject(obj)

@staticmethod
def deserialize(
serialized: bytes, *, header_only: bool = False, string_only: bool = False
) -> "TransportableObject":
def deserialize(serialized: bytes) -> "TransportableObject":
"""
Deserialize the transportable object.
Expand All @@ -308,9 +286,9 @@ def deserialize(
Returns:
object: The deserialized transportable object.
"""

ar = _TOArchive.load(serialized, header_only, string_only)
return _from_archive(ar)
to = TransportableObject(None)
to._buffer = serialized
return to

@staticmethod
def deserialize_list(collection: list) -> list:
Expand Down Expand Up @@ -357,44 +335,3 @@ def deserialize_dict(collection: dict) -> dict:
else:
raise TypeError("Couldn't deserialize collection")
return new_dict


def _to_archive(to: TransportableObject) -> _TOArchive:
    """Build a _TOArchive from the given TransportableObject.

    Args:
        to: Transportable object to be converted.

    Returns:
        Archived transportable object.
    """

    # Each segment is stored as UTF-8 bytes; the header is JSON-encoded first.
    return _TOArchive(
        header=json.dumps(to._header).encode("utf-8"),
        object_string=to._object_string.encode("utf-8"),
        data=to._object.encode("utf-8"),
    )


def _from_archive(ar: _TOArchive) -> TransportableObject:
    """Reconstruct a TransportableObject from a _TOArchive.

    Args:
        ar: Archived transportable object.

    Returns:
        Transportable object.
    """

    to = TransportableObject(None)
    to._header = json.loads(ar.header.decode("utf-8"))
    # Fall back to empty strings when the archive omitted these segments
    # (e.g. a header-only or string-only load produced empty bytes).
    to._object_string = ar.object_string.decode("utf-8") or ""
    to._object = ar.data.decode("utf-8") or ""
    return to
3 changes: 2 additions & 1 deletion tests/covalent_dispatcher_tests/_service/assets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

"""Unit tests for the FastAPI asset endpoints"""

import base64
import tempfile
from contextlib import contextmanager
from typing import Generator
Expand Down Expand Up @@ -704,7 +705,7 @@ def test_get_pickle_offsets():

start, end = _get_tobj_pickle_offsets(f"file://{write_file.name}")

assert data[start:].decode("utf-8") == tobj.get_serialized()
assert data[start:] == base64.b64decode(tobj.get_serialized().encode("utf-8"))


def test_generate_partial_file_slice():
Expand Down
Loading

0 comments on commit e890bc5

Please sign in to comment.