Skip to content

Commit

Permalink
extract_paths, use get_object_state (#209)
Browse files Browse the repository at this point in the history
This fixes the same problems as #207, but now for extract_paths,
by reusing the same shared code (get_object_state).
  • Loading branch information
albertz authored Oct 21, 2024
1 parent 9e597db commit 311ebfd
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 18 deletions.
43 changes: 39 additions & 4 deletions sisyphus/hash.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Tuple
import enum
import hashlib
import pathlib
Expand Down Expand Up @@ -37,13 +38,43 @@ def short_hash(obj, length=12, chars="0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdef
return "".join(ls)


# Built-in types that sis_hash_helper handles directly (by exact type), grouped by how they are hashed.
# Scalars: hashed via their repr().
_BasicTypes: Tuple[type, ...] = (int, float, bool, str, complex)
# Ordered sequences: elements are hashed in order.
_BasicSeqTypes: Tuple[type, ...] = (list, tuple)
# Unordered sets: element hashes are sorted before being combined.
_BasicSetTypes: Tuple[type, ...] = (set, frozenset)
# Mappings: item hashes are sorted before being combined.
_BasicDictTypes: Tuple[type, ...] = (dict,)
# Union of all groups above, in the same order; used by get_object_state for its isinstance checks.
_BasicTypesCombined: Tuple[type, ...] = (
    *_BasicTypes,
    *_BasicSeqTypes,
    *_BasicSetTypes,
    *_BasicDictTypes,
)


def get_object_state(obj):
"""
Export current object status
Comment: Maybe obj.__reduce__() is a better idea? is it stable for hashing?
"""

# Note: sis_hash_helper does not call get_object_state in these cases.
# However, other code might (e.g. extract_paths),
# so we keep consistent to the behavior of sis_hash_helper.
if obj is None:
return None
if isinstance(obj, _BasicTypesCombined):
for type_ in _BasicTypesCombined:
if isinstance(obj, type_):
# Note: For compatibility with old behavior, first only allow if the type is not derived
# (thus type(obj) is type_ check).
# For derived cases, we need to be more careful.
if type(obj) is type_:
return obj
if hasattr(obj, "__getnewargs_ex__") or hasattr(obj, "__getnewargs__"):
# E.g. a namedtuple. This is sometimes used, and we must keep the behavior for compatibility.
break # Use the original behavior.
# For now, let's fail, and extend this logic maybe later, to make sure we don't miss anything.
# E.g. a dict/tuple/list would contain the elements itself (which should be part of the state),
# and maybe *additionally* some other things in __dict__.
raise TypeError(f"derived type {obj!r} {type(obj)!r} not handled yet")
if isfunction(obj) or isclass(obj):
return obj.__module__, obj.__qualname__

if isinstance(obj, pathlib.PurePath):
# pathlib paths have a somewhat technical internal state
# ('_drv', '_root', '_parts', '_str', '_hash', '_pparts', '_cached_cparts'),
Expand Down Expand Up @@ -101,13 +132,17 @@ def sis_hash_helper(obj):
byte_list.append(obj)
elif obj is None:
pass
elif type(obj) in (int, float, bool, str, complex):
# Note: Using type(obj) in _Types instead of isinstance(obj, _Types)
# because of historical reasons (and we cannot change this now).
# For derived types (e.g. namedtuple, np.float), it is then handled by get_object_state.
# That's why the handling of get_object_state for those types is important.
elif type(obj) in _BasicTypes:
byte_list.append(repr(obj).encode())
elif type(obj) in (list, tuple):
elif type(obj) in _BasicSeqTypes:
byte_list += map(sis_hash_helper, obj)
elif type(obj) in (set, frozenset):
elif type(obj) in _BasicSetTypes:
byte_list += sorted(map(sis_hash_helper, obj))
elif isinstance(obj, dict):
elif isinstance(obj, _BasicDictTypes):
# sort items to ensure they are always in the same order
byte_list += sorted(map(sis_hash_helper, obj.items()))
elif isfunction(obj):
Expand Down
22 changes: 8 additions & 14 deletions sisyphus/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import sisyphus.global_settings as gs
from sisyphus.block import Block
from sisyphus.hash import get_object_state


def get_system_informations(file=sys.stdout):
Expand Down Expand Up @@ -84,29 +85,22 @@ def extract_paths(args: Any) -> Set:
if id(obj) in visited_obj_ids:
continue
visited_obj_ids[id(obj)] = obj
if obj is None:
continue
if isinstance(obj, (bool, int, float, complex, str)):
continue
if isinstance(obj, Block) or isinstance(obj, enum.Enum):
continue
if hasattr(obj, "_sis_path") and obj._sis_path is True and not type(obj) is type:
out.add(obj)
elif isinstance(obj, (list, tuple, set)):
elif isinstance(obj, (list, tuple, set, frozenset)):
queue.extend(obj)
elif isinstance(obj, dict):
for k, v in obj.items():
if not type(k) == str or not k.startswith("_sis_"):
queue.append(v)
elif hasattr(obj, "__sis_state__") and not inspect.isclass(obj):
queue.append(obj.__sis_state__())
elif hasattr(obj, "__getstate__") and not inspect.isclass(obj):
queue.append(obj.__getstate__())
elif hasattr(obj, "__dict__"):
for k, v in obj.__dict__.items():
if not type(k) == str or not k.startswith("_sis_"):
queue.append(v)
elif hasattr(obj, "__slots__"):
for k in obj.__slots__:
if hasattr(obj, k) and not k.startswith("_sis_"):
a = getattr(obj, k)
queue.append(a)
else:
queue.append(get_object_state(obj))
return out


Expand Down
16 changes: 16 additions & 0 deletions tests/hash_unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,22 @@ def d():
self.assertEqual(sis_hash_helper(b), b"(function, (tuple, (str, '" + __name__.encode() + b"'), (str, 'b')))")
self.assertRaises(AssertionError, sis_hash_helper, c)

def test_get_object_state_cls(self):
    # sis_hash_helper has its own special logic for classes, so the hash of a class
    # currently does not depend on get_object_state, and the value returned here
    # does not really matter for hashing.
    # extract_paths does call get_object_state, though, so the returned state is checked here.
    expected_state = ("builtins", "str")
    self.assertEqual(get_object_state(str), expected_state)

def test_get_object_state_function(self):
    # sis_hash_helper has its own special logic for functions, so the hash of a function
    # currently does not depend on get_object_state, and the value returned here
    # does not really matter for hashing.
    # extract_paths does call get_object_state, though, so the returned state is checked here.
    expected_state = (b.__module__, b.__name__)
    self.assertEqual(get_object_state(b), expected_state)

def test_enum(self):
self.assertEqual(
sis_hash_helper(MyEnum.Entry1),
Expand Down
9 changes: 9 additions & 0 deletions tests/tools_unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,5 +126,14 @@ def test_copy(self):
shutil.rmtree(dst)


def test_extract_paths_functools_partial():
    # A path bound as a positional argument of a functools.partial
    # must be found by extract_paths.
    from functools import partial
    from sisyphus.tools import extract_paths

    bound_path = job_path.Path("foo/bar")
    partial_obj = partial(job_path.Path.join_right, bound_path)
    assert extract_paths(partial_obj) == {bound_path}


if __name__ == "__main__":
unittest.main()

0 comments on commit 311ebfd

Please sign in to comment.