Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use cloudpickle for id_generation instead of dill #859

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ dependencies = [
"rich >= 12.6.0",
"python-logging-loki >= 0.3.1",
"neuroglancer >= 2.32",
"cloudpickle >= 3.1.1",
"dill >= 0.3.6",
"pyyaml ~= 6.0.1",
"requests==2.31.0", # version conflicts otherwise
Expand Down
77 changes: 41 additions & 36 deletions tests/unit/mazepa/test_id_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,49 +226,54 @@ def test_generate_invocation_id_subchunkable_flow() -> None:

def _gen_id_calls(_) -> dict[str, str]:
gen_ids = {
'gen_id(ClassA().method, [], {"a": 1})': gen_id(ClassA().method, [], {"a": 1}),
"gen_id(ClassD1().method, [], {})": gen_id(ClassD1().method, [], {}),
"gen_id(ClassE(1).method, [], {})": gen_id(ClassE(1).method, [], {}),
"gen_id(partial(ClassA().method, 42), [], {})": gen_id(
partial(ClassA().method, 42), [], {}
),
"gen_id(partial(ClassD1().method, 42), [], {})": gen_id(
partial(ClassD1().method, 42), [], {}
),
"gen_id(partial(ClassE(1).method, 42), [], {})": gen_id(
partial(ClassE(1).method, 42), [], {}
),
"gen_id(TaskableA(), [], {})": gen_id(TaskableA(), [], {}),
"gen_id(TaskableD(1), [], {})": gen_id(TaskableD(1), [], {}),
"gen_id(FlowSchema({}, ClassA().method).flow, [], {})": gen_id(
FlowSchema({}, ClassA().method).flow, [], {}
),
"gen_id(FlowSchema({}, ClassD1().method).flow, [], {})": gen_id(
FlowSchema({}, ClassD1().method).flow, [], {}
),
"gen_id(FlowSchema({}, ClassE(1).method).flow, [], {})": gen_id(
FlowSchema({}, ClassE(1).method).flow, [], {}
),
"gen_id(subchunkable_flow(), [], {})": gen_id(
subchunkable_flow().fn, subchunkable_flow().args, subchunkable_flow().kwargs
),
# 'gen_id(ClassA().method, [], {"a": 1})': gen_id(ClassA().method, [], {"a": 1}),
"gen_id(ClassD1().method, [], {})": gen_id(ClassD1().method, [], {}, None, True),
# "gen_id(ClassE(1).method, [], {})": gen_id(ClassE(1).method, [], {}),
# "gen_id(partial(ClassA().method, 42), [], {})": gen_id(
# partial(ClassA().method, 42), [], {}
# ),
# "gen_id(partial(ClassD1().method, 42), [], {})": gen_id(
# partial(ClassD1().method, 42), [], {}
# ),
# "gen_id(partial(ClassE(1).method, 42), [], {})": gen_id(
# partial(ClassE(1).method, 42), [], {}
# ),
# "gen_id(TaskableA(), [], {})": gen_id(TaskableA(), [], {}),
# "gen_id(TaskableD(1), [], {})": gen_id(TaskableD(1), [], {}),
# "gen_id(FlowSchema({}, ClassA().method).flow, [], {})": gen_id(
# FlowSchema({}, ClassA().method).flow, [], {}
# ),
# "gen_id(FlowSchema({}, ClassD1().method).flow, [], {})": gen_id(
# FlowSchema({}, ClassD1().method).flow, [], {}
# ),
# "gen_id(FlowSchema({}, ClassE(1).method).flow, [], {})": gen_id(
# FlowSchema({}, ClassE(1).method).flow, [], {}
# ),
# "gen_id(subchunkable_flow(), [], {})": gen_id(
# subchunkable_flow().fn, subchunkable_flow().args, subchunkable_flow().kwargs
# ),
}
return gen_ids


def test_persistence_across_sessions() -> None:
# Create two separate processes - spawn ensures a new PYTHONHASHSEED is used
ctx = multiprocessing.get_context("spawn")
with ctx.Pool(processes=2) as pool:
result = pool.map(_gen_id_calls, range(2))

assert result[0] == result[1]
#ctx = multiprocessing.get_context("spawn")
ctx = multiprocessing.get_context("fork")
for _ in range(1):
with ctx.Pool(processes=2) as pool:
result = pool.map(_gen_id_calls, range(2))

assert result[0] == result[1]
print(result[0])
print(result[1])
#assert False

def test_unpickleable_fn(mocker) -> None:
# See https://github.com/uqfoundation/dill/issues/147 and possibly
# https://github.com/uqfoundation/dill/issues/56

unpickleable_fn = mocker.MagicMock()
"""
def test_unpickleable_invocation(mocker) -> None:
# gen_id will return a random UUID in case of pickle errors
assert gen_id(unpickleable_fn, [], {}) != gen_id(unpickleable_fn, [], {})
some_fn = lambda x: x
unpicklable_arg = [1]
assert gen_id(some_fn, unpicklable_arg, {}) != gen_id(some_fn, unpicklable_arg, {})
"""
4 changes: 2 additions & 2 deletions zetta_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@

logger = get_logger("zetta_utils")

builder.registry.MUTLIPROCESSING_INCOMPATIBLE_CLASSES.add("mazepa")
builder.registry.MUTLIPROCESSING_INCOMPATIBLE_CLASSES.add("lightning")
builder.registry.MULTIPROCESSING_INCOMPATIBLE_CLASSES.add("mazepa")
builder.registry.MULTIPROCESSING_INCOMPATIBLE_CLASSES.add("lightning")
log.add_supress_traceback_module(builder)


Expand Down
4 changes: 2 additions & 2 deletions zetta_utils/builder/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
T = TypeVar("T", bound=Callable)

REGISTRY: dict[str, list[RegistryEntry]] = defaultdict(list)
MUTLIPROCESSING_INCOMPATIBLE_CLASSES: set[str] = set()
MULTIPROCESSING_INCOMPATIBLE_CLASSES: set[str] = set()


@attrs.frozen
Expand Down Expand Up @@ -70,7 +70,7 @@ def register(

def decorator(fn: T) -> T:
nonlocal allow_parallel
for k in MUTLIPROCESSING_INCOMPATIBLE_CLASSES:
for k in MULTIPROCESSING_INCOMPATIBLE_CLASSES:
if fn.__module__ is not None and k.lower() in fn.__module__.lower():
allow_parallel = False
break
Expand Down
34 changes: 22 additions & 12 deletions zetta_utils/mazepa/id_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import uuid
from typing import Callable, Optional

import dill
from sympy import im

import cloudpickle
import xxhash
from coolname import generate_slug

Expand Down Expand Up @@ -40,9 +42,10 @@
args: Optional[list] = None,
kwargs: Optional[dict] = None,
prefix: Optional[str] = None,
debug: Optional[bool] = False,
) -> str:
"""Generate a unique and deterministic ID for a function invocation.
The ID is generated using xxhash and dill to hash the function and its arguments.
The ID is generated using xxhash and cloudpickle to hash the function and its arguments.

:param fn: the function, or really any Callable, defaults to None
:param args: the function arguments, or any list, defaults to None
Expand All @@ -51,18 +54,25 @@
:return: A unique, yet deterministic string that identifies (fn, args, kwargs) in
the current Python environment.
"""
# import dill
import pickletools
#return cloudpickle.dumps((fn, args, kwargs), protocol=dill.DEFAULT_PROTOCOL)s
if debug:
pickletools.dis(pickletools.optimize(cloudpickle.dumps((fn, args, kwargs))))

Check warning on line 61 in zetta_utils/mazepa/id_generation.py

View check run for this annotation

Codecov / codecov/patch

zetta_utils/mazepa/id_generation.py#L61

Added line #L61 was not covered by tests

return str(cloudpickle.dumps((fn, args, kwargs)))
#return cloudpickle.dumps((fn, args, kwargs), protocol=dill.DEFAULT_PROTOCOL)s
x = xxhash.xxh128()
try:
x.update(
dill.dumps(
(fn, args, kwargs),
protocol=dill.DEFAULT_PROTOCOL,
byref=False,
recurse=True,
fmode=dill.FILE_FMODE,
)
)
except dill.PicklingError as e:
x.update(cloudpickle.dumps((fn, args, kwargs)))
#x.update(dill.dumps(
#(fn, args, kwargs),
#protocol=dill.DEFAULT_PROTOCOL,
#byref=False,
#recurse=True,
#fmode=dill.FILE_FMODE,
#))
except Exception as e: # pylint: disable=broad-exception-caught
logger.warning(f"Failed to pickle {fn} with args {args} and kwargs {kwargs}: {e}")
x.update(str(uuid.uuid4()))

Expand Down
Loading