Skip to content

Commit

Permalink
Handle type guards properly in Receiver.filter()
Browse files Browse the repository at this point in the history
Now the `Receiver` type returned by `Receiver.filter()` will have the
narrowed type when a `TypeGuard` is used.

Signed-off-by: Leandro Lucarella <[email protected]>
  • Loading branch information
llucax committed Nov 15, 2024
1 parent 486d1be commit be7f0a9
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 1 deletion.
2 changes: 2 additions & 0 deletions RELEASE_NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
print('Received from recv2:', selected.message)
```

* `Receiver.filter()` can now properly handle `TypeGuard`s. The resulting receiver will now have the narrowed type when a `TypeGuard` is used.

## Bug Fixes

<!-- Here goes notable bug fixes that are worth a special mention or explanation -->
60 changes: 59 additions & 1 deletion src/frequenz/channels/_receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,17 @@

from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Generic, Self, TypeGuard
from typing import TYPE_CHECKING, Any, Generic, Self, TypeGuard, TypeVar, overload

from ._exceptions import Error
from ._generic import MappedMessageT_co, ReceiverMessageT_co

if TYPE_CHECKING:
from ._select import Selected

FilteredMessageT_co = TypeVar("FilteredMessageT_co", covariant=True)
"""Type variable for the filtered message type."""


class Receiver(ABC, Generic[ReceiverMessageT_co]):
"""An endpoint to receive messages."""
Expand Down Expand Up @@ -267,11 +270,66 @@ def map(
"""
return _Mapper(receiver=self, mapping_function=mapping_function)

@overload
def filter(
self,
filter_function: Callable[
[ReceiverMessageT_co], TypeGuard[FilteredMessageT_co]
],
/,
) -> Receiver[FilteredMessageT_co]:
"""Apply a type guard on the messages on a receiver.
Tip:
The returned receiver type won't have all the methods of the original
receiver. If you need to access methods of the original receiver that are
not part of the `Receiver` interface you should save a reference to the
original receiver and use that instead.
Args:
filter_function: The function to be applied on incoming messages to
determine if they should be received.
Returns:
A new receiver that only receives messages that pass the filter.
"""
... # pylint: disable=unnecessary-ellipsis

@overload
def filter(
self, filter_function: Callable[[ReceiverMessageT_co], bool], /
) -> Receiver[ReceiverMessageT_co]:
"""Apply a filter function on the messages on a receiver.
Tip:
The returned receiver type won't have all the methods of the original
receiver. If you need to access methods of the original receiver that are
not part of the `Receiver` interface you should save a reference to the
original receiver and use that instead.
Args:
filter_function: The function to be applied on incoming messages to
determine if they should be received.
Returns:
A new receiver that only receives messages that pass the filter.
"""
... # pylint: disable=unnecessary-ellipsis

def filter(
self,
filter_function: (
Callable[[ReceiverMessageT_co], bool]
| Callable[[ReceiverMessageT_co], TypeGuard[FilteredMessageT_co]]
),
/,
) -> Receiver[ReceiverMessageT_co] | Receiver[FilteredMessageT_co]:
"""Apply a filter function on the messages on a receiver.
Note:
You can pass a [type guard][typing.TypeGuard] as the filter function to
narrow the type of the messages that pass the filter.
Tip:
The returned receiver type won't have all the methods of the original
receiver. If you need to access methods of the original receiver that are
Expand Down
26 changes: 26 additions & 0 deletions tests/test_broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import asyncio
from dataclasses import dataclass
from typing import TypeGuard, assert_never

import pytest

Expand Down Expand Up @@ -248,6 +249,31 @@ async def test_broadcast_filter() -> None:
assert (await receiver.receive()) == 15


async def test_broadcast_filter_type_guard() -> None:
"""Ensure filter type guard works."""
chan = Broadcast[int | str](name="input-chan")
sender = chan.new_sender()

def _is_int(num: int | str) -> TypeGuard[int]:
return isinstance(num, int)

# filter out all numbers less than 10.
receiver = chan.new_receiver().filter(_is_int)

await sender.send("hello")
await sender.send(8)

message = await receiver.receive()
assert message == 8
is_int = False
match message:
case int():
is_int = True
case unexpected:
assert_never(unexpected)
assert is_int


async def test_broadcast_receiver_drop() -> None:
"""Ensure deleted receivers get cleaned up."""
chan = Broadcast[int](name="input-chan")
Expand Down

0 comments on commit be7f0a9

Please sign in to comment.