Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(id_generation): recursive hashing of class methods and custom attribs #593

Merged
merged 1 commit into from
Dec 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions tests/unit/mazepa/test_id_generation.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

from functools import partial
from typing import Any, Callable, Mapping

import attrs

from zetta_utils import mazepa
from zetta_utils.mazepa import taskable_operation_cls
from zetta_utils.mazepa.id_generation import generate_invocation_id as gen_id

Expand Down Expand Up @@ -80,6 +82,20 @@ def __call__(self, b):
return self.a * b


@mazepa.flow_schema_cls
@attrs.mutable
class FlowSchema:
    """Minimal flow schema fixture that wraps an arbitrary callable.

    The instance carries two custom attributes so that invocation-ID tests can
    check whether attributes stored on a flow schema influence the generated ID.
    """

    # Extra mapping stored on the instance; ``flow`` itself never reads it.
    fn_kwargs: Mapping[Any, Any]
    # The callable that ``flow`` delegates to.
    callable_fn: Callable[..., Any]

    def __init__(self, fn_kwargs: Mapping[Any, Any], callable_fn: Callable[..., Any]):
        # Hand-written initializer: both arguments are stored unchanged.
        self.callable_fn = callable_fn
        self.fn_kwargs = fn_kwargs

    def flow(self, *args, **kwargs):
        """Forward every positional and keyword argument to ``callable_fn``."""
        target = self.callable_fn
        return target(*args, **kwargs)


def test_generate_invocation_id_method() -> None:
assert gen_id(ClassA().method, [], {}) != gen_id(ClassB().method, [], {})
assert gen_id(ClassB().method, [], {}) != gen_id(ClassC().method, [], {})
Expand Down Expand Up @@ -130,3 +146,33 @@ def test_generate_invocation_id_taskable_op() -> None:

assert gen_id(TaskableD(1), [], {}) == gen_id(TaskableD(1), [], {})
assert gen_id(TaskableD(1), [], {}) != gen_id(TaskableD(2), [], {})


def test_generate_invocation_id_flow_schema() -> None:
    """Check invocation IDs generated for ``FlowSchema(...).flow``.

    The ID must depend on the callable stored inside the schema as well as on
    the positional and keyword arguments of the invocation, and it must be
    reproducible for identical inputs.
    """

    def flow_id(method: Callable[..., Any], args: list, kwargs: dict) -> Any:
        # Wrap ``method`` in a fresh schema (with empty ``fn_kwargs``) and
        # generate the ID for invoking the schema's ``flow`` method.
        return gen_id(FlowSchema({}, method).flow, args, kwargs)

    # Methods of different classes stored in the schema -> different IDs.
    assert flow_id(ClassA().method, [], {}) != flow_id(ClassB().method, [], {})
    assert flow_id(ClassB().method, [], {}) != flow_id(ClassC().method, [], {})

    # Same wrapped method and same arguments -> same ID.
    assert flow_id(ClassA().method, [4, 2], {}) == flow_id(ClassA().method, [4, 2], {})
    assert flow_id(ClassA().method, [], {"a": 1}) == flow_id(ClassA().method, [], {"a": 1})

    # Same wrapped method but different arguments -> different IDs.
    assert flow_id(ClassA().method, [4, 2], {}) != flow_id(ClassA().method, [6, 3], {})
    assert flow_id(ClassA().method, [], {"a": 1}) != flow_id(ClassA().method, [], {"a": 2})

    # Two distinct classes (defined elsewhere in this test module) must not collide.
    assert flow_id(ClassD1().method, [], {}) != flow_id(ClassD2().method, [], {})
    # Same class constructed with different arguments must not collide either.
    assert flow_id(ClassE(1).method, [], {}) != flow_id(ClassE(2).method, [], {})
59 changes: 30 additions & 29 deletions zetta_utils/mazepa/id_generation.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
# pylint: disable=unused-argument
from __future__ import annotations

import functools
import uuid
from typing import Callable, Optional

import xxhash
from coolname import generate_slug

import zetta_utils.mazepa.tasks
from zetta_utils import log

logger = log.get_logger("mazepa")
Expand Down Expand Up @@ -37,46 +35,49 @@ def get_unique_id(


def _get_code_hash(
fn: Callable, _hash: Optional[xxhash.xxh128] = None, _prefix=""
fn: Callable, _hash: Optional[xxhash.xxh128] = None, _visited: Optional[set[int]] = None
) -> xxhash.xxh128:
if _hash is None:
_hash = xxhash.xxh128()
if _visited is None:
_visited = set()

# Check to prevent infinite recursion: skip any object whose id() was already visited.
# This is a bit silly, as the entire custom code hashing endeavor is done to avoid
# issues with Python's code hash in the first place...
# However, PYTHONHASHSEED is not an issue for tracking objects within the same session.
# Generating recursive loops with the same code hash requires some effort by the user
if id(fn) in _visited:
return _hash

try:
_hash.update(fn.__qualname__)
_visited.add(id(fn))

for attribute_name in {x for x in dir(fn) if not x.startswith("__")}:
attrib = getattr(fn, attribute_name)
if callable(attrib):
_get_code_hash(attrib, _hash, _visited)
else:
_hash.update(f"{attribute_name}: {attrib}".encode())

if hasattr(fn, "__self__") and fn.__self__ is not None:
_get_code_hash(fn.__self__, _hash, _visited)

try:
# Mypy wants to see (BuiltinFunctionType, MethodType, MethodWrapperType),
# but not all have __self__.__dict__ that is not a mappingproxy
method_kwargs = fn.__self__.__dict__ # type: ignore
if isinstance(method_kwargs, dict):
_hash.update(method_kwargs.__repr__())
_get_code_hash(fn.__self__.__call__.__func__, _hash, _visited)
except AttributeError:
pass

_hash.update(fn.__code__.co_code)

return _hash
try:
_hash.update(fn.__qualname__)
except AttributeError:
pass

if isinstance(fn, functools.partial):
_hash.update(fn.args.__repr__().encode())

_hash.update(fn.keywords.__repr__().encode())

_hash = _get_code_hash(fn.func, _hash=_hash, _prefix=_prefix + " ")
return _hash

if isinstance(
fn, (zetta_utils.mazepa.tasks.TaskableOperation, zetta_utils.builder.BuilderPartial)
):
_hash.update(fn.__repr__())

_hash = _get_code_hash(fn.__call__, _hash=_hash, _prefix=_prefix + " ")
return _hash
try:
_hash.update(fn.__code__.co_code)
except AttributeError:
pass

raise TypeError(f"Can't hash code for fn of type {type(fn)}")
return _hash


def generate_invocation_id(
Expand Down
Loading