Skip to content

Commit

Permalink
chore: Merge branch 'delay_comparision_in_evented_model' into unify-p…
Browse files Browse the repository at this point in the history
…ydantic
  • Loading branch information
tlambert03 committed Feb 24, 2024
2 parents 7ed0bc7 + 79e2b61 commit 80186dc
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 28 deletions.
125 changes: 97 additions & 28 deletions src/psygnal/_evented_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Dict,
Iterator,
NamedTuple,
Optional,
Set,
Type,
Union,
Expand All @@ -32,6 +33,7 @@

if TYPE_CHECKING:
from inspect import Signature
from types import TracebackType

from pydantic import ConfigDict
from pydantic._internal import _model_construction as pydantic_main
Expand Down Expand Up @@ -169,6 +171,23 @@ def _model_dump(obj: pydantic.BaseModel) -> dict:
return obj.dict()


class ComparisonDelayer:
def __init__(self, target: "EventedModel") -> None:
self._target = target

def __enter__(self) -> None:
self._target._delay_check_semaphore += 1

def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional["TracebackType"],
) -> None:
self._target._delay_check_semaphore -= 1
self._target._check_if_values_changed_and_emit_if_needed()


class EventedMetaclass(pydantic_main.ModelMetaclass):
"""pydantic ModelMetaclass that preps "equality checking" operations.
Expand Down Expand Up @@ -246,6 +265,10 @@ def __new__(
cls, model_config, model_fields
)
cls.__signal_group__ = type(f"{name}SignalGroup", (SignalGroup,), signals)
if not cls.__field_dependents__ and hasattr(cls, "_setattr_no_dependants"):
cls._setattr_default = cls._setattr_no_dependants
elif hasattr(cls, "_setattr_with_dependents"):
cls._setattr_default = cls._setattr_with_dependents
return cls


Expand Down Expand Up @@ -391,12 +414,12 @@ def c(self, val: Sequence[int]) -> None:
class Config:
allow_property_setters = True
property_dependencies = {"c": ["a", "b"]}
field_dependencies = {"c": ["a", "b"]}
m = MyModel()
assert m.c == [1, 1]
m.events.c.connect(lambda v: print(f"c updated to {v}"))
m.a = 2 # prints 'c updated to [2, 1]'
m.a = 2 # prints 'c updated to [2, 1]'
```
"""
Expand All @@ -413,6 +436,10 @@ class Config:
__field_dependents__: ClassVar[Dict[str, Set[str]]]
__eq_operators__: ClassVar[Dict[str, "EqOperator"]]

_changes_queue: Dict[str, Any] = PrivateAttr(default_factory=dict)
_primary_changes: Set[str] = PrivateAttr(default_factory=set)
_delay_check_semaphore: int = PrivateAttr(0)

if PYDANTIC_V1:

class Config:
Expand Down Expand Up @@ -510,15 +537,44 @@ def reset(self) -> None:
elif not model_config.get("frozen") and not model_fields[name].frozen:
setattr(self, name, value)

def __setattr__(self, name: str, value: Any) -> None:
if (
name == "_events"
or not hasattr(self, "_events") # can happen on init
or name not in self._events
):
# fallback to default behavior
return self._super_setattr_(name, value)
def _check_if_values_changed_and_emit_if_needed(self) -> None:
"""
Check if field values changed and emit events if needed.
The advantage of moving this to the end of all the modifications is
that comparisons will be performed only once for every potential change.
"""
if self._delay_check_semaphore > 0 or len(self._changes_queue) == 0:
# do not run whole machinery if there is no need
return
to_emit = []
for name in self._primary_changes:
# primary changes should contains only fields
# that are changed directly by assigment
old_value = self._changes_queue[name]
new_value = getattr(self, name)
if not _check_field_equality(type(self), name, new_value, old_value):
to_emit.append((name, new_value))
self._changes_queue.pop(name)
if not to_emit:
# If no direct changes was made then we can skip whole machinery
self._changes_queue.clear()
self._primary_changes.clear()
return
for name, old_value in self._changes_queue.items():
# check if any of dependent properties changed
new_value = getattr(self, name)
if not _check_field_equality(type(self), name, new_value, old_value):
to_emit.append((name, new_value))
self._changes_queue.clear()
self._primary_changes.clear()

with ComparisonDelayer(self):
# Again delay comparison to avoid having events caused by callback functions
for name, new_value in to_emit:
getattr(self._events, name)(new_value)

def _setattr_impl(self, name: str, value: Any) -> None:
# if there are no listeners, we can just set the value without emitting
# so first check if there are any listeners for this field or any of its
# dependent properties.
Expand All @@ -537,28 +593,41 @@ def __setattr__(self, name: str, value: Any) -> None:
and not len(group._psygnal_relay) # no listeners on the SignalGroup
):
return self._super_setattr_(name, value)
self._primary_changes.add(name)
if name not in self._changes_queue:
self._changes_queue[name] = getattr(self, name, object())

# grab the current value and those of any dependent properties
# so that we can check if they have changed after setting the value
before = getattr(self, name, object())
deps_before: Dict[str, Any] = {
dep: getattr(self, dep) for dep in deps_with_callbacks
}
for dep in deps_with_callbacks:
if dep not in self._changes_queue:
self._changes_queue[dep] = getattr(self, dep, object())
self._super_setattr_(name, value)

# set value using original setter
with signal_instance.blocked():
self._super_setattr_(name, value)
def _setattr_default(self, name: str, value: Any) -> None:
"""Will be overwritten by metaclass __new__."""

# if the value has changed we emit the event with new value
after = getattr(self, name)
if not _check_field_equality(type(self), name, after, before):
signal_instance.emit(after) # emit event
def _setattr_with_dependents(self, name: str, value: Any) -> None:
with ComparisonDelayer(self):
self._setattr_impl(name, value)

# also emit events for any dependent attributes that have changed as well
for dep, before_val in deps_before.items():
after_val = getattr(self, dep)
if not _check_field_equality(type(self), dep, after_val, before_val):
getattr(self._events, dep).emit(after_val)
def _setattr_no_dependants(self, name: str, value: Any) -> None:
group = self._events
signal_instance: SignalInstance = group[name]
if len(signal_instance) < 1:
return self._super_setattr_(name, value)
old_value = getattr(self, name, object())
self._super_setattr_(name, value)
if not _check_field_equality(type(self), name, value, old_value):
getattr(self._events, name)(value)

def __setattr__(self, name: str, value: Any) -> None:
if (
name == "_events"
or not hasattr(self, "_events") # can happen on init
or name not in self._events
):
# fallback to default behavior
return self._super_setattr_(name, value)
self._setattr_default(name, value)

if PYDANTIC_V1:

Expand Down
52 changes: 52 additions & 0 deletions tests/test_evented_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,9 @@ def test_evented_model_with_property_setters_events():
mock_c.assert_called_with([5, 20])
mock_b.assert_not_called()
assert t.c == [5, 20]
mock_a.reset_mock()
t.a = 5 # no change, no events
mock_a.assert_not_called()


def test_non_setter_with_dependencies() -> None:
Expand Down Expand Up @@ -845,3 +848,52 @@ class Model(EventedModel):
m.a = 3
check_mock.assert_has_calls([call(Model, "a", 3, 1)])
mock1.assert_called_once()


def test_if_event_is_emited_only_once():
"""Check if, for complex property setters, the event is emitted only once."""

class SampleClass(EventedModel):
a: int = 1
b: int = 2

if PYDANTIC_V2:
model_config = {
"allow_property_setters": True,
"guess_property_dependencies": True,
}
else:

class Config:
allow_property_setters = True
guess_property_dependencies = True

@property
def c(self):
return self.a + self.b

@c.setter
def c(self, value):
self.a = value - self.b

@property
def d(self):
return self.a + self.b

@d.setter
def d(self, value):
self.a = value // 2
self.b = value - self.a

s = SampleClass()
a_m = Mock()
c_m = Mock()
d_m = Mock()
s.events.a.connect(a_m)
s.events.c.connect(c_m)
s.events.d.connect(d_m)

s.d = 5
a_m.assert_called_once()
c_m.assert_called_once()
d_m.assert_called_once()

0 comments on commit 80186dc

Please sign in to comment.