diff --git a/README.md b/README.md index 8be8091..ba2d613 100644 --- a/README.md +++ b/README.md @@ -90,6 +90,14 @@ print(obj.value == obj2.value) # True ``` +### Context managing + +EzSerialization supports two context managers: +- `with no_serialization(): ...` - disables injecting class type metadata into the result of `to_dict()` method. + Leaves the result dict unfit to be deserialized automatically via `deserialize()`; +- `with use_serialization(): ...` - opposite of `no_serialization()`, enables class type metadata injection. + Useful when using inside the disabled serialization scope. + ## Configuration Currently only a single option is available for customizing `ezserialization`: diff --git a/pyproject.toml b/pyproject.toml index c81aa9c..f2b95b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "ezserialization" -version = "0.2.10" +version = "0.3.0" description = "Simple, easy to use & transparent python objects serialization & deserialization." authors = ["Matas Gumbinas "] repository = "https://github.com/gMatas/ezserialization" diff --git a/src/ezserialization/_serialization.py b/src/ezserialization/_serialization.py index fdc3529..319ceb2 100644 --- a/src/ezserialization/_serialization.py +++ b/src/ezserialization/_serialization.py @@ -43,45 +43,44 @@ def _is_serializable_subclass(cls: Type) -> bool: :param cls: Type to check. """ - return hasattr(cls, "from_dict") and hasattr(cls, "to_dict") + return isinstance(cls, type) and hasattr(cls, "from_dict") and hasattr(cls, "to_dict") -_T = TypeVar("_T", bound=Serializable) -""" -Serializable object type. -""" - _thread_local = threading.local() -_thread_local.enabled = (_SERIALIZATION_ENABLED_DEFAULT := True) -""" -Thread-safe serialization enabling/disabling flag. -""" -def using_serialization() -> bool: +def _get_serialization_enabled() -> bool: if not hasattr(_thread_local, "enabled"): - _thread_local.enabled = _SERIALIZATION_ENABLED_DEFAULT + _thread_local.enabled = True return cast(bool, _thread_local.enabled) +def _set_serialization_enabled(enabled: bool) -> None: + _thread_local.enabled = enabled + + +def using_serialization() -> bool: + return _get_serialization_enabled() + + @contextlib.contextmanager def use_serialization() -> Iterator[None]: - prev = using_serialization() + prev = _get_serialization_enabled() try: - _thread_local.enabled = _SERIALIZATION_ENABLED_DEFAULT + _set_serialization_enabled(True) yield finally: - _thread_local.enabled = prev + _set_serialization_enabled(prev) @contextlib.contextmanager def no_serialization() -> Iterator[None]: - prev = using_serialization() + prev = _get_serialization_enabled() try: - _thread_local.enabled = not _SERIALIZATION_ENABLED_DEFAULT + _set_serialization_enabled(False) yield finally: - _thread_local.enabled = prev + _set_serialization_enabled(prev) _types_: Dict[str, Type[Serializable]] = {} @@ -125,6 +124,12 @@ def _is_same_type_by_qualname(a: Type, b: Type) -> bool: return _abs_qualname(a) == _abs_qualname(b) +_T = TypeVar("_T", bound=Serializable) +""" +Serializable object type. +""" + + def serializable(cls: Optional[Type[_T]] = None, *, name: Optional[str] = None): def wrapper(cls_: Type[_T]) -> Type[_T]: nonlocal name @@ -142,15 +147,15 @@ def wrapper(cls_: Type[_T]) -> Type[_T]: def wrap_to_dict(method: Callable[..., Mapping]): @functools.wraps(method) - def to_dict_wrapper(obj: Serializable) -> Mapping: - data = method(obj) + def to_dict_wrapper(__ctx, *__args, **__kwargs) -> Mapping: + data = method(__ctx, *__args, **__kwargs) # Wrap object with serialization metadata. if TYPE_FIELD_NAME in data: raise KeyError(f"Key '{TYPE_FIELD_NAME}' already exist in the serialized data mapping!") - if using_serialization(): - typename = _typenames_[type(obj)] - return {TYPE_FIELD_NAME: typename, **data} - return copy(data) + if _get_serialization_enabled(): + typename = _typenames_[__ctx if isinstance(__ctx, type) else type(__ctx)] + return {TYPE_FIELD_NAME: typename, **data} # TODO: avoid copying data if possible + return copy(data) # TODO: avoid copying data if possible return to_dict_wrapper @@ -158,15 +163,30 @@ def to_dict_wrapper(obj: Serializable) -> Mapping: def wrap_from_dict(method: Callable[..., Serializable]): @functools.wraps(method) - def from_dict_wrapper(*args) -> Serializable: - # See if `from_dict` method is staticmethod-like or classmethod-like (or normal method-like), - # i.e. `Serializable.from_dict(data)` or `Serializable().from_dict(data)`. - src = args[1] if len(args) == 2 else args[0] - # Remove deserialization metadata. - src = dict(src) + def from_dict_wrapper(*__args, **__kwargs) -> Serializable: + # Differentiate between different ways this method was called. + first_arg_type = val if isinstance(val := __args[0], type) else type(val) + if _is_same_type_by_qualname(first_arg_type, cls_): + # When this method was called as instance-method i.e. Serializable().from_dict(...) + __cls = first_arg_type + src = __args[1] + __args = __args[2:] + else: + # When this method was called as class-method i.e. Serializable.from_dict(...) + __cls = cls_ + src = __args[0] + __args = __args[1:] + + # Drop deserialization metadata. + src = dict(src) # TODO: avoid copying data src.pop(TYPE_FIELD_NAME, None) - # Deserialize as-is. - return method(src) + + # Deserialize. + if hasattr(method, "__self__"): + # As bounded method (class or instance method) + return method(src, *__args, **__kwargs) + # As staticmethod (simple function) + return method(__cls, src, *__args, **__kwargs) return from_dict_wrapper diff --git a/tests/ezserialization_tests/test_serializable_decorator.py b/tests/ezserialization_tests/test_serializable_decorator.py new file mode 100644 index 0000000..ac4270c --- /dev/null +++ b/tests/ezserialization_tests/test_serializable_decorator.py @@ -0,0 +1,66 @@ +import json +from typing import Mapping, cast + +from ezserialization import ( + Serializable, + deserialize, + serializable, +) + + +@serializable # <- valid for serialization +@serializable(name="A") +@serializable(name="XXX") +class _CaseAUsingAutoName(Serializable): + def __init__(self, value: str): + self.value = value + + def to_raw_dict(self) -> dict: + return {"value": self.value} + + def to_dict(self) -> Mapping: + return self.to_raw_dict() + + @classmethod + def from_dict(cls, src: Mapping): + return cls(value=src["value"]) + + @classmethod + def abs_qualname(cls) -> str: + return f"{cls.__module__}.{cls.__qualname__}" + + +@serializable(name="B") # <- valid for serialization +@serializable(name="YYY") +@serializable +@serializable(name="ZZZ") +class _CaseBUsingNameAlias(Serializable): + def __init__(self, value: str): + self.value = value + + def to_dict(self) -> Mapping: + return {"value": self.value} + + @classmethod + def from_dict(cls, src: Mapping): + return cls(value=src["value"]) + + +def test_serialization_typenames_order(): + """ + Expected behaviour: Only the top typename is used to serialize instances. + On the other hand, for deserialization all typenames are valid. + """ + + a = _CaseAUsingAutoName("a") + data = a.to_dict() + + a.from_dict(data) + + assert data["_type_"] == _CaseAUsingAutoName.abs_qualname() + assert a.value == cast(_CaseAUsingAutoName, deserialize(json.loads(json.dumps(data)))).value + + b = _CaseBUsingNameAlias("b") + data = b.to_dict() + assert data["_type_"] == "B" + assert b.value == cast(_CaseBUsingNameAlias, deserialize(json.loads(json.dumps(data)))).value diff --git a/tests/ezserialization_tests/test_serialization.py b/tests/ezserialization_tests/test_threadsafe_usage.py similarity index 61% rename from tests/ezserialization_tests/test_serialization.py rename to tests/ezserialization_tests/test_threadsafe_usage.py index e15a5db..3714dc5 100644 --- a/tests/ezserialization_tests/test_serialization.py +++ b/tests/ezserialization_tests/test_threadsafe_usage.py @@ -1,10 +1,9 @@ import threading import time -from typing import Mapping, cast +from typing import Mapping from ezserialization import ( Serializable, - deserialize, no_serialization, serializable, use_serialization, @@ -12,10 +11,8 @@ ) -@serializable # <- valid for serialization -@serializable(name="A") -@serializable(name="XXX") -class _CaseAUsingAutoName(Serializable): +@serializable +class _TestSerializable(Serializable): def __init__(self, value: str): self.value = value @@ -29,46 +26,33 @@ def to_dict(self) -> Mapping: def from_dict(cls, src: Mapping): return cls(value=src["value"]) - @classmethod - def abs_qualname(cls) -> str: - return f"{cls.__module__}.{cls.__qualname__}" - - -@serializable(name="B") # <- valid for serialization -@serializable(name="YYY") -@serializable -@serializable(name="ZZZ") -class _CaseBUsingNameAlias(Serializable): - def __init__(self, value: str): - self.value = value - - def to_dict(self) -> Mapping: - return {"value": self.value} - - @classmethod - def from_dict(cls, src: Mapping): - return cls(value=src["value"]) - - -def test_serialization_typenames_order(): - """ - Expected behaviour: Only the top typename is used to serialize instances. - On the other hand, for deserialization all typenames are valid. - """ - a = _CaseAUsingAutoName("a") - data = a.to_dict() - assert data["_type_"] == _CaseAUsingAutoName.abs_qualname() - assert a.value == cast(_CaseAUsingAutoName, deserialize(data)).value +class _TestThread(threading.Thread): + def __init__(self): + self.exception = None + self.finished = False + self.should_stop = False + self.serialization_explicitly_enabled = False + super().__init__(target=self._fun, daemon=True) - b = _CaseBUsingNameAlias("b") - data = b.to_dict() - assert data["_type_"] == "B" - assert b.value == cast(_CaseBUsingNameAlias, deserialize(data)).value + def _fun(self): + try: + assert using_serialization() + with use_serialization(): + assert using_serialization() + self.serialization_explicitly_enabled = True + while not self.should_stop: + time.sleep(0.1) + except Exception as e: + self.exception = e + finally: + self.finished = True + self.should_stop = True + self.serialization_explicitly_enabled = True def test_threadsafe_serialization_enabling_and_disabling(): - a = _CaseAUsingAutoName("foo") + a = _TestSerializable("foo") assert using_serialization(), "By default, serialization must be enabled!" @@ -97,27 +81,3 @@ def test_threadsafe_serialization_enabling_and_disabling(): raise thread.exception assert using_serialization() - - -class _TestThread(threading.Thread): - def __init__(self): - self.exception = None - self.finished = False - self.should_stop = False - self.serialization_explicitly_enabled = False - super().__init__(target=self._fun, daemon=True) - - def _fun(self): - try: - assert using_serialization() - with use_serialization(): - assert using_serialization() - self.serialization_explicitly_enabled = True - while not self.should_stop: - time.sleep(0.1) - except Exception as e: - self.exception = e - finally: - self.finished = True - self.should_stop = True - self.serialization_explicitly_enabled = True diff --git a/tests/ezserialization_tests/test_to_and_from_dict_methods.py b/tests/ezserialization_tests/test_to_and_from_dict_methods.py new file mode 100644 index 0000000..8ac98b6 --- /dev/null +++ b/tests/ezserialization_tests/test_to_and_from_dict_methods.py @@ -0,0 +1,63 @@ +import json +from abc import ABC +from typing import Mapping, Type, cast, overload + +from ezserialization import ( + Serializable, + deserialize, + serializable, +) + + +class _BaseTestCase(Serializable, ABC): + def __init__(self, value: str): + self.value = value + + def to_dict(self) -> Mapping: + return self.to_raw_dict() + + def to_raw_dict(self) -> dict: + return {"value": self.value} + + +@serializable +class _TestFromDictWithClassmethod(_BaseTestCase): + @classmethod + def from_dict(cls, src: Mapping) -> "_TestFromDictWithClassmethod": + return cls(value=src["value"]) + + +@serializable +class _TestFromDictWithStaticmethod(_BaseTestCase): + @staticmethod + @overload + def from_dict(cls: Type["_TestFromDictWithStaticmethod"], src: Mapping) -> "_TestFromDictWithStaticmethod": ... + + @staticmethod + @overload + def from_dict(*args) -> "_TestFromDictWithStaticmethod": ... + + @staticmethod + def from_dict(*args, **kwargs) -> "_TestFromDictWithStaticmethod": + obj = args[0](value=args[1]["value"]) + assert isinstance(obj, _TestFromDictWithStaticmethod) + return obj + + +def test_from_dict_as_classmethod(): + obj = _TestFromDictWithClassmethod("wow") + obj_dict = obj.to_dict() + assert obj.value == obj.from_dict(obj_dict).value + assert obj.value == _TestFromDictWithClassmethod.from_dict(obj_dict).value + + assert obj.value == cast(_TestFromDictWithClassmethod, deserialize(json.loads(json.dumps(obj_dict)))).value + + +def test_from_dict_as_staticmethod(): + obj = _TestFromDictWithStaticmethod("wow") + obj_dict = obj.to_dict() + assert obj.value == obj.from_dict(obj_dict).value + assert obj.value == _TestFromDictWithStaticmethod.from_dict(obj, obj_dict).value + assert obj.value == _TestFromDictWithStaticmethod.from_dict(_TestFromDictWithStaticmethod, obj_dict).value + + assert obj.value == cast(_TestFromDictWithStaticmethod, deserialize(json.loads(json.dumps(obj_dict)))).value