Reduce memory overhead for TransportableObject #1883

Open · wants to merge 1 commit into develop
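For context: the memory saving comes from no longer base64-encoding the pickled payload (base64 inflates binary data by a factor of 4/3) and from streaming cloudpickle's output into a single preallocated bytearray instead of holding several intermediate copies. A minimal sketch of the encoding overhead alone, assuming only `cloudpickle` is installed:

```python
import base64

import cloudpickle

# Any picklable value works; a large list makes the overhead visible.
payload = list(range(100_000))

raw = cloudpickle.dumps(payload)  # raw picklebytes (new format)
b64 = base64.b64encode(raw)       # base64 text (old format)

# base64 emits 4 output bytes for every 3 input bytes.
print(len(b64) / len(raw))        # ~1.33
```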
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html)

- Added support for Python 3.11
- Removed official support for Python 3.8
+- Improved memory overhead for operations involving TransportableObject

## [0.235.1-rc.0] - 2024-06-10

254 changes: 120 additions & 134 deletions covalent/_workflow/transportable_object.py
@@ -19,7 +19,7 @@
import base64
import json
import platform
-from typing import Any, Callable, Tuple
+from typing import Any, Callable, Dict, Tuple

import cloudpickle

@@ -29,77 +29,12 @@
DATA_OFFSET_BYTES = 8
HEADER_OFFSET = STRING_OFFSET_BYTES + DATA_OFFSET_BYTES
BYTE_ORDER = "big"


-class _TOArchive:
-    """Archived transportable object."""
-
-    def __init__(self, header: bytes, object_string: bytes, data: bytes):
-        """
-        Initialize TOArchive.
-
-        Args:
-            header: Archived transportable object header.
-            object_string: Archived transportable object string.
-            data: Archived transportable object data.
-
-        Returns:
-            None
-        """
-
-        self.header = header
-        self.object_string = object_string
-        self.data = data
-
-    def cat(self) -> bytes:
-        """
-        Concatenate TOArchive.
-
-        Returns:
-            Concatenated TOArchive.
-
-        """
-
-        header_size = len(self.header)
-        string_size = len(self.object_string)
-        data_offset = STRING_OFFSET_BYTES + DATA_OFFSET_BYTES + header_size + string_size
-        string_offset = STRING_OFFSET_BYTES + DATA_OFFSET_BYTES + header_size
-
-        data_offset = data_offset.to_bytes(DATA_OFFSET_BYTES, BYTE_ORDER, signed=False)
-        string_offset = string_offset.to_bytes(STRING_OFFSET_BYTES, BYTE_ORDER, signed=False)
-
-        return string_offset + data_offset + self.header + self.object_string + self.data
-
-    @staticmethod
-    def load(serialized: bytes, header_only: bool, string_only: bool) -> "_TOArchive":
-        """
-        Load TOArchive object from serialized bytes.
-
-        Args:
-            serialized: Serialized transportable object.
-            header_only: Load header only.
-            string_only: Load string only.
-
-        Returns:
-            Archived transportable object.
-
-        """
-
-        string_offset = TOArchiveUtils.string_offset(serialized)
-        header = TOArchiveUtils.parse_header(serialized, string_offset)
-        object_string = b""
-        data = b""
-
-        if not header_only:
-            data_offset = TOArchiveUtils.data_offset(serialized)
-            object_string = TOArchiveUtils.parse_string(serialized, string_offset, data_offset)
-
-        if not string_only:
-            data = TOArchiveUtils.parse_data(serialized, data_offset)
-        return _TOArchive(header, object_string, data)
+TOBJ_FMT_STR = "0.1"


class TOArchiveUtils:
    """Utilities for reading serialized TransportableObjects"""

    @staticmethod
    def data_offset(serialized: bytes) -> int:
        size64 = serialized[STRING_OFFSET_BYTES : STRING_OFFSET_BYTES + DATA_OFFSET_BYTES]
@@ -119,24 +54,38 @@ def string_byte_range(serialized: bytes) -> Tuple[int, int]:

    @staticmethod
    def data_byte_range(serialized: bytes) -> Tuple[int, int]:
-        """Return byte range for the b64 picklebytes"""
+        """Return byte range for the picklebytes"""
        start_byte = TOArchiveUtils.data_offset(serialized)
        return start_byte, -1

    @staticmethod
-    def parse_header(serialized: bytes, string_offset: int) -> bytes:
+    def header(serialized: bytes) -> dict:
+        string_offset = TOArchiveUtils.string_offset(serialized)
        header = serialized[HEADER_OFFSET:string_offset]
-        return header
+        return json.loads(header.decode("utf-8"))

    @staticmethod
-    def parse_string(serialized: bytes, string_offset: int, data_offset: int) -> bytes:
+    def string_segment(serialized: bytes) -> bytes:
+        string_offset = TOArchiveUtils.string_offset(serialized)
+        data_offset = TOArchiveUtils.data_offset(serialized)
        return serialized[string_offset:data_offset]

    @staticmethod
-    def parse_data(serialized: bytes, data_offset: int) -> bytes:
+    def data_segment(serialized: bytes) -> bytes:
+        data_offset = TOArchiveUtils.data_offset(serialized)
        return serialized[data_offset:]


+class _ByteArrayFile:
+    """File-like interface for appending to a bytearray."""

+    def __init__(self, buf: bytearray):
+        self._buf = buf
+
+    def write(self, data: bytes):
+        self._buf.extend(data)


class TransportableObject:
    """
    A function is converted to a transportable object by serializing it using cloudpickle
@@ -149,13 +98,13 @@
"""

def __init__(self, obj: Any) -> None:
b64object = base64.b64encode(cloudpickle.dumps(obj))
object_string_u8 = str(obj).encode("utf-8")
self._buffer = bytearray()

self._object = b64object.decode("utf-8")
self._object_string = object_string_u8.decode("utf-8")
# Reserve space for the byte offsets to be written at the end
self._buffer.extend(b"\0" * HEADER_OFFSET)

self._header = {
_header = {
"format": TOBJ_FMT_STR,
"py_version": platform.python_version(),
"cloudpickle_version": cloudpickle.__version__,
"attrs": {
Expand All @@ -164,23 +113,48 @@ def __init__(self, obj: Any) -> None:
},
}

+        # Write header and object string
+        header_u8 = json.dumps(_header).encode("utf-8")
+        header_len = len(header_u8)
+
+        object_string_u8 = str(obj).encode("utf-8")
+        object_string_len = len(object_string_u8)
+
+        self._buffer.extend(header_u8)
+        self._buffer.extend(object_string_u8)
+        del object_string_u8
+
+        # Append picklebytes (not base64-encoded)
+        cloudpickle.dump(obj, _ByteArrayFile(self._buffer))
+
+        # Write byte offsets
+        string_offset = HEADER_OFFSET + header_len
+        data_offset = string_offset + object_string_len
+
+        string_offset_bytes = string_offset.to_bytes(STRING_OFFSET_BYTES, BYTE_ORDER)
+        data_offset_bytes = data_offset.to_bytes(DATA_OFFSET_BYTES, BYTE_ORDER)
+        self._buffer[:STRING_OFFSET_BYTES] = string_offset_bytes
+        self._buffer[STRING_OFFSET_BYTES:HEADER_OFFSET] = data_offset_bytes

    @property
    def python_version(self):
-        return self._header["py_version"]
+        return self.header["py_version"]

    @property
    def header(self):
-        return self._header
+        return TOArchiveUtils.header(self._buffer)

    @property
    def attrs(self):
-        return self._header["attrs"]
+        return self.header["attrs"]

    @property
    def object_string(self):
        # For compatibility with older Covalent
        try:
-            return self._object_string
+            return (
+                TOArchiveUtils.string_segment(memoryview(self._buffer)).tobytes().decode("utf-8")
+            )
        except AttributeError:
            return self.__dict__["object_string"]

@@ -201,11 +175,15 @@ def get_deserialized(self) -> Callable:

"""

return cloudpickle.loads(base64.b64decode(self._object.encode("utf-8")))
return cloudpickle.loads(TOArchiveUtils.data_segment(memoryview(self._buffer)))

    def to_dict(self) -> dict:
        """Return a JSON-serializable dictionary representation of self"""
-        return {"type": "TransportableObject", "attributes": self.__dict__.copy()}
+        attr_dict = {
+            "buffer_b64": base64.b64encode(memoryview(self._buffer)).decode("utf-8"),
+        }
+
+        return {"type": "TransportableObject", "attributes": attr_dict}

    @staticmethod
    def from_dict(object_dict) -> "TransportableObject":
@@ -219,7 +197,7 @@ def from_dict(object_dict) -> "TransportableObject":
        """

        sc = TransportableObject(None)
-        sc.__dict__ = object_dict["attributes"]
+        sc._buffer = base64.b64decode(object_dict["attributes"]["buffer_b64"].encode("utf-8"))
        return sc

    def get_serialized(self) -> str:
@@ -233,7 +211,9 @@ def get_serialized(self) -> str:
            object: The serialized transportable object.
        """

-        return self._object
+        # For backward compatibility
+        data_segment = TOArchiveUtils.data_segment(memoryview(self._buffer))
+        return base64.b64encode(data_segment).decode("utf-8")

    def serialize(self) -> bytes:
        """
@@ -246,7 +226,7 @@ def serialize(self) -> bytes:
            pickled_object: The serialized object along with the python version.
        """

-        return _to_archive(self).cat()
+        return self._buffer

    def serialize_to_json(self) -> str:
        """
@@ -295,9 +275,7 @@ def make_transportable(obj) -> "TransportableObject":
        return TransportableObject(obj)

    @staticmethod
-    def deserialize(
-        serialized: bytes, *, header_only: bool = False, string_only: bool = False
-    ) -> "TransportableObject":
+    def deserialize(serialized: bytes) -> "TransportableObject":
        """
        Deserialize the transportable object.

@@ -307,9 +285,58 @@ def deserialize(
        Returns:
            object: The deserialized transportable object.
        """
-        ar = _TOArchive.load(serialized, header_only, string_only)
-        return _from_archive(ar)
+        to = TransportableObject(None)
+        header = TOArchiveUtils.header(serialized)

+        # For backward compatibility
+        if header.get("format") is None:
+            # Re-encode TObj serialized using older versions of the SDK,
+            # characterized by the lack of a "format" field in the
+            # header. TObj was previously serialized as
+            # [offsets][header][string][b64-encoded picklebytes],
+            # whereas starting from format 0.1 we store them as
+            # [offsets][header][string][picklebytes].
+            to._buffer = TransportableObject._upgrade_tobj_format(serialized, header)
+        else:
+            to._buffer = serialized
+        return to

+    @staticmethod
+    def _upgrade_tobj_format(serialized: bytes, header: Dict) -> bytes:
+        """Re-encode a serialized TObj in the newer format.
+
+        This involves adding a format version to the header and
+        base64-decoding the data segment. Because the header sits at
+        the beginning of the byte array and its length changes, the
+        string and data offsets need to be recomputed.
+        """
+        buf = bytearray()
+
+        # Upgrade header and recompute byte offsets
+        header["format"] = TOBJ_FMT_STR
+        serialized_header = json.dumps(header).encode("utf-8")
+        string_offset = HEADER_OFFSET + len(serialized_header)
+
+        # This is just a view into the bytearray and consumes
+        # negligible space on its own.
+        string_segment = TOArchiveUtils.string_segment(serialized)
+
+        data_offset = string_offset + len(string_segment)
+        string_offset_bytes = string_offset.to_bytes(STRING_OFFSET_BYTES, BYTE_ORDER)
+        data_offset_bytes = data_offset.to_bytes(DATA_OFFSET_BYTES, BYTE_ORDER)
+
+        # Write the new byte offsets
+        buf.extend(b"\0" * HEADER_OFFSET)
+        buf[:STRING_OFFSET_BYTES] = string_offset_bytes
+        buf[STRING_OFFSET_BYTES:HEADER_OFFSET] = data_offset_bytes
+
+        buf.extend(serialized_header)
+        buf.extend(string_segment)
+
+        # base64-decode the data segment into raw picklebytes
+        buf.extend(base64.b64decode(TOArchiveUtils.data_segment(serialized)))
+
+        return buf

    @staticmethod
    def deserialize_list(collection: list) -> list:
@@ -356,44 +383,3 @@ def deserialize_dict(collection: dict) -> dict:
        else:
            raise TypeError("Couldn't deserialize collection")
        return new_dict


-def _to_archive(to: TransportableObject) -> _TOArchive:
-    """
-    Convert a TransportableObject to a _TOArchive.
-
-    Args:
-        to: Transportable object to be converted.
-
-    Returns:
-        Archived transportable object.
-
-    """
-
-    header = json.dumps(to._header).encode("utf-8")
-    object_string = to._object_string.encode("utf-8")
-    data = to._object.encode("utf-8")
-    return _TOArchive(header=header, object_string=object_string, data=data)
-
-
-def _from_archive(ar: _TOArchive) -> TransportableObject:
-    """
-    Convert a _TOArchive to a TransportableObject.
-
-    Args:
-        ar: Archived transportable object.
-
-    Returns:
-        Transportable object.
-
-    """
-
-    decoded_object_str = ar.object_string.decode("utf-8")
-    decoded_data = ar.data.decode("utf-8")
-    decoded_header = json.loads(ar.header.decode("utf-8"))
-    to = TransportableObject(None)
-    to._header = decoded_header
-    to._object_string = decoded_object_str or ""
-    to._object = decoded_data or ""
-
-    return to
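As the backward-compatibility comment above describes, a format-0.1 buffer is laid out as [offsets][header][string][picklebytes]. A standalone sketch of a reader for that layout, assuming `STRING_OFFSET_BYTES = 4` (the diff only shows `DATA_OFFSET_BYTES = 8`; the exact value lives above the first hunk):

```python
# Hypothetical re-implementation of the segment parsing in TOArchiveUtils,
# for illustration only.
STRING_OFFSET_BYTES = 4  # assumed; not visible in this diff
DATA_OFFSET_BYTES = 8
HEADER_OFFSET = STRING_OFFSET_BYTES + DATA_OFFSET_BYTES
BYTE_ORDER = "big"


def split_segments(serialized: bytes) -> tuple[bytes, bytes, bytes]:
    # The first HEADER_OFFSET bytes hold two big-endian unsigned offsets.
    string_offset = int.from_bytes(serialized[:STRING_OFFSET_BYTES], BYTE_ORDER)
    data_offset = int.from_bytes(serialized[STRING_OFFSET_BYTES:HEADER_OFFSET], BYTE_ORDER)

    header = serialized[HEADER_OFFSET:string_offset]  # JSON metadata
    string = serialized[string_offset:data_offset]    # str(obj), UTF-8
    data = serialized[data_offset:]                   # raw picklebytes
    return header, string, data
```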
3 changes: 2 additions & 1 deletion tests/covalent_dispatcher_tests/_service/assets_test.py
@@ -16,6 +16,7 @@

"""Unit tests for the FastAPI asset endpoints"""

+import base64
import tempfile
from contextlib import contextmanager
from typing import Generator
@@ -704,7 +705,7 @@ def test_get_pickle_offsets():

    start, end = _get_tobj_pickle_offsets(f"file://{write_file.name}")

-    assert data[start:].decode("utf-8") == tobj.get_serialized()
+    assert data[start:] == base64.b64decode(tobj.get_serialized().encode("utf-8"))


def test_generate_partial_file_slice():
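The updated assertion in `test_get_pickle_offsets` reflects that the stored data segment now holds raw picklebytes, while `get_serialized()` still returns the base64 string for older clients. A round-trip sketch of that relationship, assuming this branch's `covalent` package is importable:

```python
import base64

from covalent._workflow.transportable_object import (
    TOArchiveUtils,
    TransportableObject,
)

tobj = TransportableObject({"x": 1})
serialized = bytes(tobj.serialize())

# Format-0.1 buffers are adopted as-is by deserialize(); pre-0.1 buffers
# would be re-encoded via _upgrade_tobj_format().
assert TOArchiveUtils.header(serialized)["format"] == "0.1"
assert TransportableObject.deserialize(serialized).get_deserialized() == {"x": 1}

# The data segment holds raw picklebytes; get_serialized() still hands
# out their base64 encoding for backward compatibility.
start, _ = TOArchiveUtils.data_byte_range(serialized)
assert serialized[start:] == base64.b64decode(tobj.get_serialized())
```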