Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Monkeypatch protocol.loads ala dask/distributed#8216 #1247

Merged
merged 1 commit into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dask_cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

__version__ = "23.10.00"

from . import compat

# Monkey patching Dask to make use of explicit-comms when `DASK_EXPLICIT_COMMS=True`
dask.dataframe.shuffle.rearrange_by_column = get_rearrange_by_column_wrapper(
Expand Down
118 changes: 118 additions & 0 deletions dask_cuda/compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import pickle

import msgpack
from packaging.version import Version

import dask
import distributed
import distributed.comm.utils
import distributed.protocol
from distributed.comm.utils import OFFLOAD_THRESHOLD, nbytes, offload
from distributed.protocol.core import (
Serialized,
decompress,
logger,
merge_and_deserialize,
msgpack_decode_default,
msgpack_opts,
)

if Version(distributed.__version__) >= Version("2023.8.1"):
# Monkey-patch protocol.core.loads (and its users)
async def from_frames(
frames, deserialize=True, deserializers=None, allow_offload=True
):
"""
Unserialize a list of Distributed protocol frames.
"""
size = False

def _from_frames():
try:
# Patched code
return loads(
frames, deserialize=deserialize, deserializers=deserializers
)
# end patched code
except EOFError:
if size > 1000:
datastr = "[too large to display]"
else:
datastr = frames
# Aid diagnosing
logger.error("truncated data stream (%d bytes): %s", size, datastr)
raise

if allow_offload and deserialize and OFFLOAD_THRESHOLD:
size = sum(map(nbytes, frames))
if (
allow_offload
and deserialize
and OFFLOAD_THRESHOLD
and size > OFFLOAD_THRESHOLD
):
res = await offload(_from_frames)
else:
res = _from_frames()

return res

def loads(frames, deserialize=True, deserializers=None):
"""Transform bytestream back into Python value"""

allow_pickle = dask.config.get("distributed.scheduler.pickle")

try:

def _decode_default(obj):
offset = obj.get("__Serialized__", 0)
if offset > 0:
sub_header = msgpack.loads(
frames[offset],
object_hook=msgpack_decode_default,
use_list=False,
**msgpack_opts,
)
offset += 1
sub_frames = frames[offset : offset + sub_header["num-sub-frames"]]
if deserialize:
if "compression" in sub_header:
sub_frames = decompress(sub_header, sub_frames)
return merge_and_deserialize(
sub_header, sub_frames, deserializers=deserializers
)
else:
return Serialized(sub_header, sub_frames)

offset = obj.get("__Pickled__", 0)
if offset > 0:
sub_header = msgpack.loads(frames[offset])
offset += 1
sub_frames = frames[offset : offset + sub_header["num-sub-frames"]]
# Patched code
if "compression" in sub_header:
sub_frames = decompress(sub_header, sub_frames)
# end patched code
if allow_pickle:
return pickle.loads(
sub_header["pickled-obj"], buffers=sub_frames
)
else:
raise ValueError(
"Unpickle on the Scheduler isn't allowed, "
"set `distributed.scheduler.pickle=true`"
)

return msgpack_decode_default(obj)

return msgpack.loads(
frames[0], object_hook=_decode_default, use_list=False, **msgpack_opts
)

except Exception:
logger.critical("Failed to deserialize", exc_info=True)
raise

distributed.protocol.loads = loads
distributed.protocol.core.loads = loads
distributed.comm.utils.from_frames = from_frames
18 changes: 18 additions & 0 deletions dask_cuda/tests/test_from_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pytest

import dask.array as da
from distributed import Client

from dask_cuda import LocalCUDACluster

pytest.importorskip("ucp")
cupy = pytest.importorskip("cupy")


@pytest.mark.parametrize("protocol", ["ucx", "tcp"])
def test_ucx_from_array(protocol):
N = 10_000
with LocalCUDACluster(protocol=protocol) as cluster:
with Client(cluster):
val = da.from_array(cupy.arange(N), chunks=(N // 10,)).sum().compute()
assert val == (N * (N - 1)) // 2