From be7f0a93dd5cc3796120c409e73c70b1e7037d75 Mon Sep 17 00:00:00 2001 From: Leandro Lucarella Date: Thu, 14 Nov 2024 14:06:02 +0100 Subject: [PATCH] Handle type guards properly in `Receiver.filter()` Now the `Receiver` type returned by `Receiver.filter()` will have the narrowed type when a `TypeGuard` is used. Signed-off-by: Leandro Lucarella --- RELEASE_NOTES.md | 2 + src/frequenz/channels/_receiver.py | 60 +++++++++++++++++++++++++++++- tests/test_broadcast.py | 26 +++++++++++++ 3 files changed, 87 insertions(+), 1 deletion(-) diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index 5fb36f8c..54e1e20f 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -20,6 +20,8 @@ print('Received from recv2:', selected.message) ``` +* `Receiver.filter()` can now properly handle `TypeGuard`s. The resulting receiver will now have the narrowed type when a `TypeGuard` is used. + ## Bug Fixes diff --git a/src/frequenz/channels/_receiver.py b/src/frequenz/channels/_receiver.py index 7b57a631..53862a45 100644 --- a/src/frequenz/channels/_receiver.py +++ b/src/frequenz/channels/_receiver.py @@ -155,7 +155,7 @@ from abc import ABC, abstractmethod from collections.abc import Callable -from typing import TYPE_CHECKING, Any, Generic, Self, TypeGuard +from typing import TYPE_CHECKING, Any, Generic, Self, TypeGuard, TypeVar, overload from ._exceptions import Error from ._generic import MappedMessageT_co, ReceiverMessageT_co @@ -163,6 +163,9 @@ if TYPE_CHECKING: from ._select import Selected +FilteredMessageT_co = TypeVar("FilteredMessageT_co", covariant=True) +"""Type variable for the filtered message type.""" + class Receiver(ABC, Generic[ReceiverMessageT_co]): """An endpoint to receive messages.""" @@ -267,11 +270,66 @@ def map( """ return _Mapper(receiver=self, mapping_function=mapping_function) + @overload + def filter( + self, + filter_function: Callable[ + [ReceiverMessageT_co], TypeGuard[FilteredMessageT_co] + ], + /, + ) -> Receiver[FilteredMessageT_co]: + """Apply a type guard on the messages on a receiver. + + Tip: + The returned receiver type won't have all the methods of the original + receiver. If you need to access methods of the original receiver that are + not part of the `Receiver` interface you should save a reference to the + original receiver and use that instead. + + Args: + filter_function: The function to be applied on incoming messages to + determine if they should be received. + + Returns: + A new receiver that only receives messages that pass the filter. + """ + ... # pylint: disable=unnecessary-ellipsis + + @overload def filter( self, filter_function: Callable[[ReceiverMessageT_co], bool], / ) -> Receiver[ReceiverMessageT_co]: """Apply a filter function on the messages on a receiver. + Tip: + The returned receiver type won't have all the methods of the original + receiver. If you need to access methods of the original receiver that are + not part of the `Receiver` interface you should save a reference to the + original receiver and use that instead. + + Args: + filter_function: The function to be applied on incoming messages to + determine if they should be received. + + Returns: + A new receiver that only receives messages that pass the filter. + """ + ... # pylint: disable=unnecessary-ellipsis + + def filter( + self, + filter_function: ( + Callable[[ReceiverMessageT_co], bool] + | Callable[[ReceiverMessageT_co], TypeGuard[FilteredMessageT_co]] + ), + /, + ) -> Receiver[ReceiverMessageT_co] | Receiver[FilteredMessageT_co]: + """Apply a filter function on the messages on a receiver. + + Note: + You can pass a [type guard][typing.TypeGuard] as the filter function to + narrow the type of the messages that pass the filter. + Tip: The returned receiver type won't have all the methods of the original receiver. If you need to access methods of the original receiver that are diff --git a/tests/test_broadcast.py b/tests/test_broadcast.py index f480d194..565ee0fb 100644 --- a/tests/test_broadcast.py +++ b/tests/test_broadcast.py @@ -6,6 +6,7 @@ import asyncio from dataclasses import dataclass +from typing import TypeGuard, assert_never import pytest @@ -248,6 +249,31 @@ async def test_broadcast_filter() -> None: assert (await receiver.receive()) == 15 +async def test_broadcast_filter_type_guard() -> None: + """Ensure filter type guard works.""" + chan = Broadcast[int | str](name="input-chan") + sender = chan.new_sender() + + def _is_int(num: int | str) -> TypeGuard[int]: + return isinstance(num, int) + + # filter out all numbers less than 10. + receiver = chan.new_receiver().filter(_is_int) + + await sender.send("hello") + await sender.send(8) + + message = await receiver.receive() + assert message == 8 + is_int = False + match message: + case int(): + is_int = True + case unexpected: + assert_never(unexpected) + assert is_int + + async def test_broadcast_receiver_drop() -> None: """Ensure deleted receivers get cleaned up.""" chan = Broadcast[int](name="input-chan")