Skip to content

Commit

Permalink
more test fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
tlambert03 committed Feb 13, 2024
1 parent a78d9d6 commit 9f44dfc
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 50 deletions.
51 changes: 33 additions & 18 deletions src/psygnal/_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@

from psygnal._signal import Signal, SignalInstance, _SignalBlocker

__all__ = ["EmissionInfo", "SignalGroup"]


class EmissionInfo(NamedTuple):
"""Tuple containing information about an emission event.
Expand All @@ -42,20 +44,23 @@ class EmissionInfo(NamedTuple):
class SignalRelay(SignalInstance):
"""Special SignalInstance that can be used to connect to all signals in a group."""

def __init__(self, group: SignalGroup, instance: Any = None) -> None:
self._group = group
def __init__(
self, signals: Mapping[str, SignalInstance], instance: Any = None
) -> None:
super().__init__(signature=(EmissionInfo,), instance=instance)
self._signals = signals
self._sig_was_blocked: dict[str, bool] = {}

# silence any warnings about failed weakrefs (will occur in compiled version)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
for sig in group._psygnal_instances.values():
sig.connect(self._slot_relay, check_nargs=False, check_types=False)
for sig in signals.values():
sig.connect(
self._slot_relay, check_nargs=False, check_types=False, unique=True
)

def _slot_relay(self, *args: Any) -> None:
emitter = Signal.current_emitter()
if emitter:
if emitter := Signal.current_emitter():
info = EmissionInfo(emitter, args)
self._run_emit_loop((info,))

Expand Down Expand Up @@ -103,8 +108,8 @@ def connect_direct(
"""

def _inner(slot: Callable) -> Callable:
for sig in self._group:
self._group[sig].connect(
for sig in self._signals:
self._signals[sig].connect(
slot,
check_nargs=check_nargs,
check_types=check_types,
Expand All @@ -118,8 +123,8 @@ def _inner(slot: Callable) -> Callable:
def block(self, exclude: Iterable[str | SignalInstance] = ()) -> None:
"""Block this signal and all emitters from emitting."""
super().block()
for k in self._group:
v = self._group[k]
for k in self._signals:
v = self._signals[k]
if exclude and v in exclude or k in exclude:
continue
self._sig_was_blocked[k] = v._is_blocked
Expand All @@ -128,9 +133,9 @@ def block(self, exclude: Iterable[str | SignalInstance] = ()) -> None:
def unblock(self) -> None:
"""Unblock this signal and all emitters, allowing them to emit."""
super().unblock()
for k in self._group:
for k in self._signals:
if not self._sig_was_blocked.pop(k, False):
self._group[k].unblock()
self._signals[k].unblock()

def blocked(
self, exclude: Iterable[str | SignalInstance] = ()
Expand Down Expand Up @@ -162,8 +167,8 @@ def disconnect(self, slot: Callable | None = None, missing_ok: bool = True) -> N
ValueError
If `slot` is not connected and `missing_ok` is False.
"""
for signal in self._group:
self._group[signal].disconnect(slot, missing_ok)
for signal in self._signals:
self._signals[signal].disconnect(slot, missing_ok)
super().disconnect(slot, missing_ok)


Expand All @@ -186,7 +191,7 @@ def __init__(self, instance: Any = None) -> None:
self._psygnal_instances: dict[str, SignalInstance] = {
name: signal.__get__(self, cls) for name, signal in cls._signals_.items()
}
self._psygnal_relay = SignalRelay(self, instance)
self._psygnal_relay = SignalRelay(self._psygnal_instances, instance)

# determine the public name of the signal relay.
# by default, this is "all", but it can be overridden by the user by creating
Expand Down Expand Up @@ -222,7 +227,14 @@ def __getattr__(self, name: str) -> Any:
if name in self._psygnal_instances:
return self._psygnal_instances[name]

Check warning on line 228 in src/psygnal/_group.py

View check run for this annotation

Codecov / codecov/patch

src/psygnal/_group.py#L228

Added line #L228 was not covered by tests
if name != "_psygnal_relay" and hasattr(self._psygnal_relay, name):
# TODO: add deprecation warning and redirect to `self.all`
warnings.warn(
f"Accessing SignalInstance attribute {name!r} on a SignalGroup is "
f"deprecated. Access it on the {self._psygnal_relay_name!r} "
f"attribute instead. e.g. `group.{self._psygnal_relay_name}.{name}`"
". This will be an error in a future version.",
FutureWarning,
stacklevel=2,
)
return getattr(self._psygnal_relay, name)
raise AttributeError(f"{type(self).__name__!r} has no attribute {name!r}")

Expand Down Expand Up @@ -251,11 +263,14 @@ def __repr__(self) -> str:
@classmethod
def is_uniform(cls) -> bool:
"""Return true if all signals in the group have the same signature."""
# TODO: Deprecate this method
# TODO: Deprecate this method?
return cls._uniform

def __deepcopy__(self, memo: dict[int, Any]) -> SignalGroup:
# TODO: should we also copy connections?
# TODO:
# This really isn't a deep copy. Should we also copy connections?
# a working deepcopy is important for pydantic support, but in most cases
# it will be a group without any signals connected
return type(self)(instance=self._psygnal_relay.instance)


Expand Down
6 changes: 3 additions & 3 deletions tests/containers/test_evented_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ class E:
e_obj = E()
root: EventedList[E] = EventedList(child_events=True)
mock = Mock()
root.events.connect(mock)
root.events.all.connect(mock)
root.append(e_obj)
assert len(e_obj.test) == 1
assert root == [e_obj]
Expand Down Expand Up @@ -347,7 +347,7 @@ def __init__(self):
e_obj = E()
root: EventedList[E] = EventedList(child_events=True)
mock = Mock()
root.events.connect(mock)
root.events.all.connect(mock)
root.append(e_obj)
assert root == [e_obj]
e_obj.events.test2.emit("hi")
Expand All @@ -372,7 +372,7 @@ def __init__(self):

# note that we can get back to the actual object in the list using the .instance
# attribute on signal instances.
assert e_obj.events.test2.instance.instance == e_obj
assert e_obj.events.test2.instance.all.instance == e_obj
mock.assert_has_calls(expected)


Expand Down
16 changes: 13 additions & 3 deletions tests/test_evented_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ class Config:
assert model1 == model2


def test_values_updated():
def test_values_updated() -> None:
class User(EventedModel):
"""Demo evented model.
Expand All @@ -201,7 +201,17 @@ class User(EventedModel):
user1_events = Mock()
u1_id_events = Mock()
u2_id_events = Mock()
user1.events.connect(user1_events)

with pytest.warns(
FutureWarning,
match="Accessing SignalInstance attribute 'connect' on a SignalGroup "
"is deprecated",
):
user1.events.connect(user1_events)
user1.events.connect(user1_events)

user1.events.id.connect(u1_id_events)
user2.events.id.connect(u2_id_events)
user1.events.id.connect(u1_id_events)
user2.events.id.connect(u2_id_events)

Expand Down Expand Up @@ -837,7 +847,7 @@ class Model(EventedModel):
check_mock.assert_not_called()
mock1.assert_not_called()

m.events.connect(mock1)
m.events.all.connect(mock1)
with patch.object(
model_module,
"_check_field_equality",
Expand Down
28 changes: 14 additions & 14 deletions tests/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,11 @@ def test_signal_group_connect(direct: bool):
group = MyGroup()
if direct:
# the callback wants the emitted arguments directly
group.connect_direct(mock)
group.all.connect_direct(mock)
else:
# the callback will receive an EmissionInfo tuple
# (SignalInstance, arg_tuple)
group.connect(mock)
group.all.connect(mock)
group.sig1.emit(1)
group.sig2.emit("hi")

Expand All @@ -89,14 +89,14 @@ def test_signal_group_connect(direct: bool):


def test_signal_group_connect_no_args():
"""Test that group.connect can take a callback that wants no args"""
"""Test that group.all.connect can take a callback that wants no args"""
group = MyGroup()
count = []

def my_slot() -> None:
count.append(1)

group.connect(my_slot)
group.all.connect(my_slot)
group.sig1.emit(1)
group.sig2.emit("hi")
assert len(count) == 2
Expand All @@ -108,7 +108,7 @@ def test_group_blocked():
mock1 = Mock()
mock2 = Mock()

group.connect(mock1)
group.all.connect(mock1)
group.sig1.connect(mock2)
group.sig1.emit(1)

Expand All @@ -121,7 +121,7 @@ def test_group_blocked():
group.sig2.block()
assert group.sig2._is_blocked

with group.blocked():
with group.all.blocked():
group.sig1.emit(1)
assert group.sig1._is_blocked

Expand All @@ -143,7 +143,7 @@ def test_group_blocked_exclude():
group.sig1.connect(mock1)
group.sig2.connect(mock2)

with group.blocked(exclude=("sig2",)):
with group.all.blocked(exclude=("sig2",)):
group.sig1.emit(1)
group.sig2.emit("hi")
mock1.assert_not_called()
Expand All @@ -160,7 +160,7 @@ def test_group_disconnect_single_slot():
group.sig1.connect(mock1)
group.sig2.connect(mock2)

group.disconnect(mock1)
group.all.disconnect(mock1)
group.sig1.emit()
mock1.assert_not_called()

Expand All @@ -178,7 +178,7 @@ def test_group_disconnect_all_slots():
group.sig1.connect(mock1)
group.sig2.connect(mock2)

group.disconnect()
group.all.disconnect()
group.sig1.emit()
group.sig2.emit()

Expand All @@ -195,10 +195,10 @@ class T:

obj = T()
group = MyGroup(obj)
assert group.instance is obj
assert group.all.instance is obj
del obj
gc.collect()
assert group.instance is None
assert group.all.instance is None


def test_group_deepcopy() -> None:
Expand All @@ -210,16 +210,16 @@ def method(self):
group = MyGroup(obj)
assert deepcopy(group) is not group # but no warning

group.connect(obj.method)
group.all.connect(obj.method)

# with pytest.warns(UserWarning, match="does not copy connected weakly"):
group2 = deepcopy(group)

assert not len(group2._psygnal_relay)
mock = Mock()
mock2 = Mock()
group.connect(mock)
group2.connect(mock2)
group.all.connect(mock)
group2.all.connect(mock2)

group2.sig1.emit(1)
mock.assert_not_called()
Expand Down
22 changes: 11 additions & 11 deletions tests/test_psygnal.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,28 +391,28 @@ def test_group_weakref(slot):
class MyGroup(SignalGroup):
sig1 = Signal(int)

emitter = MyGroup()
group = MyGroup()
obj = MyObj()

# simply by nature of being in a group, sig1 will have a callback
assert len(emitter.sig1) == 1
assert len(group.sig1) == 1
# but the group itself doesn't have any
assert len(emitter._psygnal_relay) == 0
assert len(group._psygnal_relay) == 0

# connecting something to the group adds to the group connections
emitter.connect(
group.all.connect(
partial(obj.f_int_int, 1) if slot == "partial" else getattr(obj, slot)
)
assert len(emitter.sig1) == 1
assert len(emitter._psygnal_relay) == 1
assert len(group.sig1) == 1
assert len(group._psygnal_relay) == 1

emitter.sig1.emit(1)
assert len(emitter.sig1) == 1
group.sig1.emit(1)
assert len(group.sig1) == 1
del obj
gc.collect()
emitter.sig1.emit(1) # this should trigger deletion, so would emitter.emit()
assert len(emitter.sig1) == 1
assert len(emitter._psygnal_relay) == 0 # it's been cleaned up
group.sig1.emit(1) # this should trigger deletion, so would emitter.emit()
assert len(group.sig1) == 1
assert len(group._psygnal_relay) == 0 # it's been cleaned up


# def test_norm_slot():
Expand Down
3 changes: 2 additions & 1 deletion typesafety/test_group.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
t = T()
reveal_type(T.e) # N: Revealed type is "psygnal._group_descriptor.SignalGroupDescriptor"
reveal_type(t.e) # N: Revealed type is "psygnal._group.SignalGroup"
reveal_type(t.e.x) # N: Revealed type is "psygnal._signal.SignalInstance"
reveal_type(t.e['x']) # N: Revealed type is "psygnal._signal.SignalInstance"
reveal_type(t.e.x) # N: Revealed type is "Any"
@t.e.x.connect
def func(x: int) -> None:
Expand Down

0 comments on commit 9f44dfc

Please sign in to comment.