diff --git a/tests/unit/mazepa/test_id_generation.py b/tests/unit/mazepa/test_id_generation.py
index 03ae321af..b5e5b23e8 100644
--- a/tests/unit/mazepa/test_id_generation.py
+++ b/tests/unit/mazepa/test_id_generation.py
@@ -1,9 +1,11 @@
 from __future__ import annotations
 
 from functools import partial
+from typing import Any, Callable, Mapping
 
 import attrs
 
+from zetta_utils import mazepa
 from zetta_utils.mazepa import taskable_operation_cls
 from zetta_utils.mazepa.id_generation import generate_invocation_id as gen_id
 
@@ -80,6 +82,20 @@ def __call__(self, b):
         return self.a * b
 
 
+@mazepa.flow_schema_cls
+@attrs.mutable
+class FlowSchema:
+    fn_kwargs: Mapping[Any, Any]
+    callable_fn: Callable[..., Any]
+
+    def __init__(self, fn_kwargs: Mapping[Any, Any], callable_fn: Callable[..., Any]):
+        self.fn_kwargs = fn_kwargs
+        self.callable_fn = callable_fn
+
+    def flow(self, *args, **kwargs):
+        return self.callable_fn(*args, **kwargs)
+
+
 def test_generate_invocation_id_method() -> None:
     assert gen_id(ClassA().method, [], {}) != gen_id(ClassB().method, [], {})
     assert gen_id(ClassB().method, [], {}) != gen_id(ClassC().method, [], {})
@@ -130,3 +146,33 @@ def test_generate_invocation_id_taskable_op() -> None:
 
     assert gen_id(TaskableD(1), [], {}) == gen_id(TaskableD(1), [], {})
     assert gen_id(TaskableD(1), [], {}) != gen_id(TaskableD(2), [], {})
+
+
+def test_generate_invocation_id_flow_schema() -> None:
+    assert gen_id(FlowSchema({}, ClassA().method).flow, [], {}) != gen_id(
+        FlowSchema({}, ClassB().method).flow, [], {}
+    )
+    assert gen_id(FlowSchema({}, ClassB().method).flow, [], {}) != gen_id(
+        FlowSchema({}, ClassC().method).flow, [], {}
+    )
+
+    assert gen_id(FlowSchema({}, ClassA().method).flow, [4, 2], {}) == gen_id(
+        FlowSchema({}, ClassA().method).flow, [4, 2], {}
+    )
+    assert gen_id(FlowSchema({}, ClassA().method).flow, [], {"a": 1}) == gen_id(
+        FlowSchema({}, ClassA().method).flow, [], {"a": 1}
+    )
+
+    assert gen_id(FlowSchema({}, ClassA().method).flow, [4, 2], {}) != gen_id(
+        FlowSchema({}, ClassA().method).flow, [6, 3], {}
+    )
+    assert gen_id(FlowSchema({}, ClassA().method).flow, [], {"a": 1}) != gen_id(
+        FlowSchema({}, ClassA().method).flow, [], {"a": 2}
+    )
+
+    assert gen_id(FlowSchema({}, ClassD1().method).flow, [], {}) != gen_id(
+        FlowSchema({}, ClassD2().method).flow, [], {}
+    )
+    assert gen_id(FlowSchema({}, ClassE(1).method).flow, [], {}) != gen_id(
+        FlowSchema({}, ClassE(2).method).flow, [], {}
+    )
diff --git a/zetta_utils/mazepa/id_generation.py b/zetta_utils/mazepa/id_generation.py
index dab8990b2..12f202cf3 100644
--- a/zetta_utils/mazepa/id_generation.py
+++ b/zetta_utils/mazepa/id_generation.py
@@ -1,14 +1,12 @@
 # pylint: disable=unused-argument
 from __future__ import annotations
 
-import functools
 import uuid
 from typing import Callable, Optional
 
 import xxhash
 from coolname import generate_slug
 
-import zetta_utils.mazepa.tasks
 from zetta_utils import log
 
 logger = log.get_logger("mazepa")
@@ -37,46 +35,51 @@ def get_unique_id(
 
 
 def _get_code_hash(
-    fn: Callable, _hash: Optional[xxhash.xxh128] = None, _prefix=""
+    fn: Callable, _hash: Optional[xxhash.xxh128] = None, _visited: Optional[set[int]] = None
 ) -> xxhash.xxh128:
     if _hash is None:
         _hash = xxhash.xxh128()
+    if _visited is None:
+        _visited = set()
+
+    # Check to prevent infinite recursion
+    # This is a bit silly, as the entire custom code hashing endeavor is done to avoid
+    # issues with Python's code hash in the first place...
+    # However, PYTHONHASHSEED is not an issue for tracking methods within the same session.
+    # Generating recursive loops with the same code hash requires some effort by the user
+    if id(fn) in _visited:
+        return _hash
 
-    try:
-        _hash.update(fn.__qualname__)
+    _visited.add(id(fn))
+
+    # Walk attributes in sorted order: iterating a set of str depends on PYTHONHASHSEED,
+    # which would make the resulting hash differ from one session to the next.
+    for attribute_name in sorted(x for x in dir(fn) if not x.startswith("__")):
+        attrib = getattr(fn, attribute_name)
+        if callable(attrib):
+            _get_code_hash(attrib, _hash, _visited)
+        else:
+            _hash.update(f"{attribute_name}: {attrib}".encode())
+
+    if hasattr(fn, "__self__") and fn.__self__ is not None:
+        _get_code_hash(fn.__self__, _hash, _visited)
 
         try:
-            # Mypy wants to see (BuiltinFunctionType, MethodType, MethodWrapperType),
-            # but not all have __self__.__dict__ that is not a mappingproxy
-            method_kwargs = fn.__self__.__dict__  # type: ignore
-            if isinstance(method_kwargs, dict):
-                _hash.update(method_kwargs.__repr__())
+            _get_code_hash(fn.__self__.__call__.__func__, _hash, _visited)
         except AttributeError:
             pass
 
-        _hash.update(fn.__code__.co_code)
-
-        return _hash
+    try:
+        _hash.update(fn.__qualname__)
     except AttributeError:
         pass
 
-    if isinstance(fn, functools.partial):
-        _hash.update(fn.args.__repr__().encode())
-
-        _hash.update(fn.keywords.__repr__().encode())
-
-        _hash = _get_code_hash(fn.func, _hash=_hash, _prefix=_prefix + " ")
-        return _hash
-
-    if isinstance(
-        fn, (zetta_utils.mazepa.tasks.TaskableOperation, zetta_utils.builder.BuilderPartial)
-    ):
-        _hash.update(fn.__repr__())
-
-        _hash = _get_code_hash(fn.__call__, _hash=_hash, _prefix=_prefix + " ")
-        return _hash
+    try:
+        _hash.update(fn.__code__.co_code)
+    except AttributeError:
+        pass
 
-    raise TypeError(f"Can't hash code for fn of type {type(fn)}")
+    return _hash
 
 
 def generate_invocation_id(