diff --git a/odmantic/model.py b/odmantic/model.py index f387f937..49c7e3e7 100644 --- a/odmantic/model.py +++ b/odmantic/model.py @@ -1,6 +1,7 @@ import datetime import decimal import enum +import functools import pathlib import uuid import warnings @@ -16,6 +17,7 @@ FrozenSet, Iterable, List, + NamedTuple, Optional, Set, Tuple, @@ -499,6 +501,29 @@ def __new__( BaseT = TypeVar("BaseT", bound="_BaseODMModel") +TraversalStateT = NamedTuple( + "TraversalStateT", [("output", List[Any]), ("staging", List[Any])] +) + + +def flat_tree(o: BaseT) -> List[BaseT]: + state = TraversalStateT(output=[], staging=[o]) + + def obj_fields(obj): + return [getattr(obj, name) for name in set(obj.__odm_fields__)] + + def unpack(acc: TraversalStateT, obj: Any) -> TraversalStateT: + output, (_, *staging_tail) = acc + if isinstance(obj, _BaseODMModel): + return TraversalStateT(output + [obj], staging_tail + obj_fields(obj)) + elif isinstance(obj, Iterable) and not isinstance(obj, (str, dict)): + return TraversalStateT(output, staging_tail + [*obj]) + else: + return TraversalStateT(output, staging_tail) + + while state.staging: + state = functools.reduce(unpack, state.staging, state) + return state.output class _BaseODMModel(pydantic.BaseModel, metaclass=ABCMeta): @@ -583,11 +608,8 @@ def _post_copy_update(self: BaseT) -> None: """Recursively update internal fields of the copied model after it has been copied. """ - object.__setattr__(self, "__fields_modified__", set(self.__fields__)) - for field_name, field in self.__odm_fields__.items(): - if isinstance(field, ODMEmbedded): - value = getattr(self, field_name) - value._post_copy_update() + for model in flat_tree(self): + object.__setattr__(model, "__fields_modified__", set(model.__fields__)) def update( self, diff --git a/tests/unit/test_model_logic.py b/tests/unit/test_model_logic.py index 7b12ed27..f2287452 100644 --- a/tests/unit/test_model_logic.py +++ b/tests/unit/test_model_logic.py @@ -389,6 +389,27 @@ class M(Model): assert instance.e.f.g != copied.e.f.g +@pytest.mark.parametrize( + "hint, ctor", + [ + pytest.param(List, list), + pytest.param(Tuple, tuple), + ], +) +def test_model_copy_deep_embedded_model_collection(hint, ctor): + class E(EmbeddedModel): + f: int + + class M(Model): + e: hint[E] + + instance = M(e=ctor([E(f=1)])) + copied = instance.copy(deep=True) + copied.e[0].f = 2 + + assert copied.e[0].f != instance.e[0].f + + def test_model_copy_not_deep_embedded(): class E(EmbeddedModel): f: int