Skip to content

Commit

Permalink
perf: Delay SignalRelay connection to when a callback is connected (#277
Browse files Browse the repository at this point in the history
)

* fix: delay relay connection

* fix: evented model

* test: add test

* test: extend test

* Update src/psygnal/_group.py

Co-authored-by: Grzegorz Bokota <[email protected]>

* Update src/psygnal/_group_descriptor.py

Co-authored-by: Grzegorz Bokota <[email protected]>

* test: disconnect relay

---------

Co-authored-by: Grzegorz Bokota <[email protected]>
  • Loading branch information
tlambert03 and Czaki authored Feb 21, 2024
1 parent e9ee496 commit ab5dfc5
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 30 deletions.
4 changes: 2 additions & 2 deletions src/psygnal/_evented_model_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,10 +362,10 @@ def __setattr__(self, name: str, value: Any) -> None:
deps_with_callbacks = {
dep_name
for dep_name in self.__field_dependents__.get(name, ())
if len(group[dep_name]) > 1
if len(group[dep_name])
}
if (
len(signal_instance) < 2 # the signal itself has no listeners
len(signal_instance) < 1 # the signal itself has no listeners
and not deps_with_callbacks # no dependent properties with listeners
and not len(group._psygnal_relay) # no listeners on the SignalGroup
):
Expand Down
4 changes: 2 additions & 2 deletions src/psygnal/_evented_model_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,10 +348,10 @@ def __setattr__(self, name: str, value: Any) -> None:
deps_with_callbacks = {
dep_name
for dep_name in self.__field_dependents__.get(name, ())
if len(group[dep_name]) > 1
if len(group[dep_name])
}
if (
len(signal_instance) < 2 # the signal itself has no listeners
len(signal_instance) < 1 # the signal itself has no listeners
and not deps_with_callbacks # no dependent properties with listeners
and not len(group._psygnal_relay) # no listeners on the SignalGroup
):
Expand Down
22 changes: 21 additions & 1 deletion src/psygnal/_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@

import warnings
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
ContextManager,
Iterable,
Iterator,
Literal,
Mapping,
NamedTuple,
)
Expand All @@ -26,6 +28,9 @@

from ._mypyc import mypyc_attr

if TYPE_CHECKING:
from psygnal._weak_callback import WeakCallback

__all__ = ["EmissionInfo", "SignalGroup"]


Expand Down Expand Up @@ -63,14 +68,29 @@ def __init__(
self._signals = signals
self._sig_was_blocked: dict[str, bool] = {}

def _append_slot(self, slot: WeakCallback) -> None:
super()._append_slot(slot)
if len(self._slots) == 1:
self._connect_relay()

def _connect_relay(self) -> None:
# silence any warnings about failed weakrefs (will occur in compiled version)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
for sig in signals.values():
for sig in self._signals.values():
sig.connect(
self._slot_relay, check_nargs=False, check_types=False, unique=True
)

def _remove_slot(self, slot: int | WeakCallback | Literal["all"]) -> None:
super()._remove_slot(slot)
if not self._slots:
self._disconnect_relay()

def _disconnect_relay(self) -> None:
for sig in self._signals.values():
sig.disconnect(self._slot_relay)

def _slot_relay(self, *args: Any) -> None:
if emitter := Signal.current_emitter():
info = EmissionInfo(emitter, args)
Expand Down
2 changes: 1 addition & 1 deletion src/psygnal/_group_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def _setattr_and_emit_(self: object, name: str, value: Any) -> None:

# don't emit if the signal doesn't exist or has no listeners
signal: SignalInstance = group[name]
if len(signal) < 2 and not len(group._psygnal_relay):
if len(signal) < 1:
return super_setattr(self, name, value)

with _changes_emitted(self, name, signal):
Expand Down
32 changes: 23 additions & 9 deletions src/psygnal/_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,14 +521,28 @@ def _wrapper(
finalize=self._try_discard,
on_ref_error=_on_ref_err,
)
if thread is None:
self._slots.append(cb)
else:
self._slots.append(QueuedCallback(cb, thread=thread))
if thread is not None:
cb = QueuedCallback(cb, thread=thread)
self._append_slot(cb)
return slot

return _wrapper if slot is None else _wrapper(slot)

def _append_slot(self, slot: WeakCallback) -> None:
"""Append a slot to the list of slots."""
# implementing this as a method allows us to override/extend it in subclasses
self._slots.append(slot)

def _remove_slot(self, slot: Literal["all"] | int | WeakCallback) -> None:
"""Remove a slot from the list of slots."""
# implementing this as a method allows us to override/extend it in subclasses
if slot == "all":
self._slots.clear()
elif isinstance(slot, int):
self._slots.pop(slot)
else:
self._slots.remove(cast("WeakCallback", slot))

def _try_discard(self, callback: WeakCallback, missing_ok: bool = True) -> None:
"""Try to discard a callback from the list of slots.
Expand All @@ -540,7 +554,7 @@ def _try_discard(self, callback: WeakCallback, missing_ok: bool = True) -> None:
If `True`, do not raise an error if the callback is not found in the list.
"""
try:
self._slots.remove(callback)
self._remove_slot(callback)
except ValueError:
if not missing_ok:
raise
Expand Down Expand Up @@ -633,7 +647,7 @@ def connect_setattr(
finalize=self._try_discard,
on_ref_error=on_ref_error,
)
self._slots.append(caller)
self._append_slot(caller)
return caller

def disconnect_setattr(
Expand Down Expand Up @@ -747,7 +761,7 @@ def connect_setitem(
finalize=self._try_discard,
on_ref_error=on_ref_error,
)
self._slots.append(caller)
self._append_slot(caller)

return caller

Expand Down Expand Up @@ -859,12 +873,12 @@ def disconnect(self, slot: Callable | None = None, missing_ok: bool = True) -> N
with self._lock:
if slot is None:
# NOTE: clearing an empty list is actually a RuntimeError in Qt
self._slots.clear()
self._remove_slot("all")
return

idx = self._slot_index(slot)
if idx != -1:
self._slots.pop(idx)
self._remove_slot(idx)
elif not missing_ok:
raise ValueError(f"slot is not connected: {slot}")

Expand Down
56 changes: 45 additions & 11 deletions tests/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class MyGroup(SignalGroup):
sig2 = Signal(str)


def test_signal_group():
def test_signal_group() -> None:
assert not MyGroup.psygnals_uniform()
with pytest.warns(
FutureWarning, match="The `is_uniform` method on SignalGroup is deprecated"
Expand All @@ -33,7 +33,7 @@ def test_signal_group():
group.sig3 # noqa: B018


def test_uniform_group():
def test_uniform_group() -> None:
"""In a uniform group, all signals must have the same signature."""

class MyStrictGroup(SignalGroup, strict=True):
Expand All @@ -55,7 +55,7 @@ class BadGroup(SignalGroup, strict=True):


@pytest.mark.skipif(Annotated is None, reason="requires typing.Annotated")
def test_nonhashable_args():
def test_nonhashable_args() -> None:
"""Test that non-hashable annotations are allowed in a SignalGroup"""

class MyGroup(SignalGroup):
Expand All @@ -72,7 +72,7 @@ class MyGroup2(SignalGroup, strict=True):


@pytest.mark.parametrize("direct", [True, False])
def test_signal_group_connect(direct: bool):
def test_signal_group_connect(direct: bool) -> None:
mock = Mock()
group = MyGroup()
if direct:
Expand Down Expand Up @@ -104,7 +104,7 @@ def test_signal_group_connect(direct: bool):
mock.assert_has_calls(expected_calls)


def test_signal_group_connect_no_args():
def test_signal_group_connect_no_args() -> None:
"""Test that group.all.connect can take a callback that wants no args"""
group = MyGroup()
count = []
Expand All @@ -118,7 +118,7 @@ def my_slot() -> None:
assert len(count) == 2


def test_group_blocked():
def test_group_blocked() -> None:
group = MyGroup()

mock1 = Mock()
Expand Down Expand Up @@ -149,7 +149,7 @@ def test_group_blocked():
mock2.assert_not_called()


def test_group_blocked_exclude():
def test_group_blocked_exclude() -> None:
"""Test that we can exempt certain signals from being blocked."""
group = MyGroup()

Expand All @@ -166,7 +166,7 @@ def test_group_blocked_exclude():
mock2.assert_called_once_with("hi")


def test_group_disconnect_single_slot():
def test_group_disconnect_single_slot() -> None:
"""Test that we can disconnect single slots from groups."""
group = MyGroup()

Expand All @@ -184,7 +184,7 @@ def test_group_disconnect_single_slot():
mock2.assert_called_once()


def test_group_disconnect_all_slots():
def test_group_disconnect_all_slots() -> None:
"""Test that we can disconnect all slots from groups."""
group = MyGroup()

Expand All @@ -202,7 +202,7 @@ def test_group_disconnect_all_slots():
mock2.assert_not_called()


def test_weakref():
def test_weakref() -> None:
"""Make sure that the group doesn't keep a strong reference to the instance."""
import gc

Expand All @@ -218,7 +218,7 @@ class T: ...

def test_group_deepcopy() -> None:
class T:
def method(self): ...
def method(self) -> None: ...

obj = T()
group = MyGroup(obj)
Expand Down Expand Up @@ -254,3 +254,37 @@ class MyGroup(SignalGroup):

assert "_psygnal_thing" not in MyGroup._psygnal_signals
assert "other_signal" in MyGroup._psygnal_signals


def test_delayed_relay_connect() -> None:
group = MyGroup()
mock = Mock()
gmock = Mock()
assert len(group.sig1) == 0

group.sig1.connect(mock)
# group relay hasn't been connected to sig1 or sig2 yet
assert len(group.sig1) == 1
assert len(group.sig2) == 0

group.all.connect(gmock)
# NOW the relay is connected
assert len(group.sig1) == 2
assert len(group.sig2) == 1
method = group.sig1._slots[-1].dereference()
assert method
assert method.__name__ == "_slot_relay"

group.sig1.emit(1)
mock.assert_called_once_with(1)
gmock.assert_called_once_with(EmissionInfo(group.sig1, (1,)))

group.all.disconnect(gmock)
assert len(group.sig1) == 1
assert len(group.all) == 0

mock.reset_mock()
gmock.reset_mock()
group.sig1.emit(1)
mock.assert_called_once_with(1)
gmock.assert_not_called()
7 changes: 3 additions & 4 deletions tests/test_psygnal.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ def test_weakref(slot):
"partial",
],
)
def test_group_weakref(slot):
def test_group_weakref(slot) -> None:
"""Test that a connected method doesn't hold strong ref."""
from psygnal import SignalGroup

Expand All @@ -395,8 +395,7 @@ class MyGroup(SignalGroup):
group = MyGroup()
obj = MyObj()

# simply by nature of being in a group, sig1 will have a callback
assert len(group.sig1) == 1
assert len(group.sig1) == 0
# but the group itself doesn't have any
assert len(group._psygnal_relay) == 0

Expand All @@ -412,7 +411,7 @@ class MyGroup(SignalGroup):
del obj
gc.collect()
group.sig1.emit(1) # this should trigger deletion, so would emitter.emit()
assert len(group.sig1) == 1
assert len(group.sig1) == 0 # NOTE! this is 0 not 1, because the relay is also gone
assert len(group._psygnal_relay) == 0 # it's been cleaned up


Expand Down

0 comments on commit ab5dfc5

Please sign in to comment.