From 7414f1f1df9b4cbeed61231660d59d8444280759 Mon Sep 17 00:00:00 2001 From: Dag Brattli Date: Fri, 27 Nov 2020 18:24:46 +0100 Subject: [PATCH] Fix tests for async iteration --- aioreactive/__init__.py | 24 +++++---- aioreactive/combine.py | 6 +-- aioreactive/create.py | 42 ++++++++++++++-- aioreactive/filtering.py | 23 ++++++--- aioreactive/timeshift.py | 6 +-- aioreactive/transform.py | 4 +- test/test_async_iteration.py | 95 ++++++++++++++++-------------------- 7 files changed, 118 insertions(+), 82 deletions(-) diff --git a/aioreactive/__init__.py b/aioreactive/__init__.py index 1901b8e..45feb1f 100644 --- a/aioreactive/__init__.py +++ b/aioreactive/__init__.py @@ -31,10 +31,11 @@ class AsyncRx(AsyncObservable[TSource]): """An AsyncObservable class similar to classic Rx. - This class supports has all operators as methods and supports + This class provides all operators as methods and supports method chaining. - Subscribe is also a method. + Example: + >>> AsyncRx.from_iterable([1,2,3]).map(lambda x: x + 2).filter(lambda x: x < 3) All methods are lazy imported. """ @@ -112,6 +113,10 @@ def empty(cls) -> "AsyncRx[TSource]": def from_iterable(cls, iter: Iterable[TSource]) -> "AsyncRx[TSource]": return AsyncRx(from_iterable(iter)) + @classmethod + def from_async_iterable(cls, iter: AsyncIterable[TSource]) -> "AsyncObservable[TSource]": + return AsyncRx(from_async_iterable(iter)) + @classmethod def single(cls, value: TSource) -> "AsyncRx[TSource]": from .create import single @@ -252,11 +257,12 @@ def flat_map_latest_async( most-recently transformed observable sequence. Args: - mapper (Callable[[TSource]): [description] - Awaitable ([type]): [description] + mapper: Function to transform each item into a new async + observable. Returns: - Stream[TSource, TResult]: [description] + An async observable that only merges values from the latest + async observable produced by the mapper. """ return AsyncRx(pipe(self, flat_map_latest_async(mapper))) @@ -611,12 +617,10 @@ def flat_map_latest_async(mapper: Callable[[TSource], Awaitable[AsyncObservable[ return flat_map_latest_async(mapper) -def from_async_iterable(iter: Iterable[TSource]) -> "AsyncObservable[TSource]": - from aioreactive.operators.from_async_iterable import from_async_iterable - - from .create import from_async_iterable +def from_async_iterable(iter: AsyncIterable[TSource]) -> "AsyncObservable[TSource]": + from .create import of_async_iterable - return AsyncRx(from_async_iterable(iter)) + return AsyncRx(of_async_iterable(iter)) def interval(seconds: float, period: int) -> AsyncObservable[int]: diff --git a/aioreactive/combine.py b/aioreactive/combine.py index 4a4210b..a18291b 100644 --- a/aioreactive/combine.py +++ b/aioreactive/combine.py @@ -4,7 +4,7 @@ from typing import Any, Callable, Generic, Iterable, Tuple, TypeVar, cast from expression.collections import FrozenList, Map, frozenlist, map -from expression.core import MailboxProcessor, Nothing, Option, Result, Some, TailCall, match, pipe, recursive_async +from expression.core import MailboxProcessor, Nothing, Option, Result, Some, TailCall, match, pipe, tailrec_async from expression.system import AsyncDisposable from .create import of_seq @@ -168,7 +168,7 @@ async def subscribe_async(aobv: AsyncObserver[Tuple[TSource, TOther]]) -> AsyncD safe_obv, auto_detach = auto_detach_observer(aobv) async def worker(inbox: MailboxProcessor[Msg]) -> None: - @recursive_async + @tailrec_async async def message_loop( source_value: Option[TSource], other_value: Option[TOther] ) -> Result[AsyncObservable[TSource], Exception]: @@ -245,7 +245,7 @@ async def subscribe_async(aobv: AsyncObserver[Tuple[TSource, TOther]]) -> AsyncD safe_obv, auto_detach = auto_detach_observer(aobv) async def worker(inbox: MailboxProcessor[Msg]) -> None: - @recursive_async + @tailrec_async async def message_loop(latest: Option[TOther]) -> Result[TSource, Exception]: cn = await inbox.receive() diff --git a/aioreactive/create.py b/aioreactive/create.py index 6933880..4de6dc3 100644 --- a/aioreactive/create.py +++ b/aioreactive/create.py @@ -1,8 +1,9 @@ import asyncio import logging -from typing import Awaitable, Callable, Iterable, Tuple, TypeVar +from asyncio import Future +from typing import AsyncIterable, Awaitable, Callable, Iterable, Optional, Tuple, TypeVar -from expression.core import Ok, Result, aiotools, recursive_async +from expression.core import Ok, Result, aiotools, tailrec_async from expression.core.fn import TailCall from expression.system import AsyncDisposable, CancellationToken, CancellationTokenSource @@ -68,6 +69,41 @@ async def worker(obv: AsyncObserver[TSource], _: CancellationToken) -> None: return of_async_worker(worker) +def of_async_iterable(iterable: AsyncIterable[TSource]) -> AsyncObservable[TSource]: + """Convert an async iterable to a source stream. + 2 - xs = from_async_iterable(async_iterable) + Returns the source stream whose elements are pulled from the + given (async) iterable sequence.""" + + async def subscribe_async(observer: AsyncObserver[TSource]) -> AsyncDisposable: + task: Optional[Future[TSource]] = None + + async def cancel(): + if task is not None: + task.cancel() + + sub = AsyncDisposable.create(cancel) + + async def worker() -> None: + async for value in iterable: + try: + await observer.asend(value) + except Exception as ex: + await observer.athrow(ex) + return + + await observer.aclose() + + try: + task = asyncio.ensure_future(worker()) + except Exception as ex: + log.debug("FromIterable:worker(), Exception: %s" % ex) + await observer.athrow(ex) + return sub + + return AsyncAnonymousObservable(subscribe_async) + + def single(value: TSource) -> AsyncObservable[TSource]: """Returns an observable sequence containing the single specified element.""" @@ -155,7 +191,7 @@ def interval(seconds: float, period: float) -> AsyncObservable[int]: async def subscribe_async(aobv: AsyncObserver[int]) -> AsyncDisposable: cancel, token = canceller() - @recursive_async + @tailrec_async async def handler(seconds: float, next: int) -> Result[None, Exception]: await asyncio.sleep(seconds) await aobv.asend(next) diff --git a/aioreactive/filtering.py b/aioreactive/filtering.py index f4deaba..d506e20 100644 --- a/aioreactive/filtering.py +++ b/aioreactive/filtering.py @@ -1,17 +1,24 @@ -from typing import (Any, Awaitable, Callable, List, Optional, Tuple, TypeVar, - overload) +from typing import Any, Awaitable, Callable, List, Optional, Tuple, TypeVar, overload from expression.collections import seq -from expression.core import (MailboxProcessor, Option, Result, TailCall, - aiotools, compose, fst, match, pipe, - recursive_async) +from expression.core import ( + MailboxProcessor, + Option, + Result, + TailCall, + aiotools, + compose, + fst, + match, + pipe, + tailrec_async, +) from expression.system.disposable import AsyncDisposable from .combine import zip_seq from .notification import Notification, OnCompleted, OnError, OnNext from .observables import AsyncAnonymousObservable -from .observers import (AsyncAnonymousObserver, AsyncNotificationObserver, - auto_detach_observer) +from .observers import AsyncAnonymousObserver, AsyncNotificationObserver, auto_detach_observer from .transform import map, transform from .types import AsyncObservable, AsyncObserver, Stream @@ -114,7 +121,7 @@ async def subscribe_async(aobv: AsyncObserver[TSource]) -> AsyncDisposable: safe_obv, auto_detach = auto_detach_observer(aobv) async def worker(inbox: MailboxProcessor[Notification[TSource]]) -> None: - @recursive_async + @tailrec_async async def message_loop(latest: Notification[TSource]) -> Result[Notification[TSource], Exception]: n = await inbox.receive() diff --git a/aioreactive/timeshift.py b/aioreactive/timeshift.py index 9b0c1b5..60c96be 100644 --- a/aioreactive/timeshift.py +++ b/aioreactive/timeshift.py @@ -4,7 +4,7 @@ from typing import Iterable, Tuple, TypeVar from expression.collections import seq -from expression.core import MailboxProcessor, Result, TailCall, aiotools, match, pipe, recursive_async, snd +from expression.core import MailboxProcessor, Result, TailCall, aiotools, match, pipe, tailrec_async, snd from expression.system import CancellationTokenSource from .combine import with_latest_from @@ -37,7 +37,7 @@ def _delay(source: AsyncObservable[TSource]) -> AsyncObservable[TSource]: async def subscribe_async(aobv: AsyncObserver[TSource]) -> AsyncDisposable: async def worker(inbox: MailboxProcessor[Tuple[Notification[TSource], datetime]]) -> None: - @recursive_async + @tailrec_async async def loop() -> Result[None, Exception]: ns, due_time = await inbox.receive() @@ -102,7 +102,7 @@ async def subscribe_async(aobv: AsyncObserver[TSource]) -> AsyncDisposable: infinite: Iterable[int] = seq.infinite() async def worker(inbox: MailboxProcessor[Tuple[Notification[TSource], int]]) -> None: - @recursive_async + @tailrec_async async def message_loop(current_index: int) -> Result[TSource, Exception]: n, index = await inbox.receive() diff --git a/aioreactive/transform.py b/aioreactive/transform.py index b8c576d..4e0b8c2 100644 --- a/aioreactive/transform.py +++ b/aioreactive/transform.py @@ -11,7 +11,7 @@ compose, match, pipe, - recursive_async, + tailrec_async, ) from expression.system import AsyncDisposable @@ -256,7 +256,7 @@ async def aclose() -> None: return AsyncAnonymousObserver(asend, athrow, aclose) async def worker(inbox: MailboxProcessor[Msg]) -> None: - @recursive_async + @tailrec_async async def message_loop( current: Option[AsyncDisposable], is_stopped: bool, current_id: int ) -> Result[None, Exception]: diff --git a/test/test_async_iteration.py b/test/test_async_iteration.py index 48c4ace..45add44 100644 --- a/test/test_async_iteration.py +++ b/test/test_async_iteration.py @@ -1,74 +1,63 @@ -# import pytest -# import asyncio -# import logging +import asyncio +import logging -# from aioreactive.testing import VirtualTimeEventLoop -# from aioreactive.core import AsyncObservable, run, subscribe, AsyncStream, AsyncAnonymousObserver, AsyncIteratorObserver -# from aioreactive.operators.pipe import pipe -# from aioreactive.operators import op, from_async_iterable -# from aioreactive.operators.to_async_iterable import to_async_iterable +import aioreactive as rx +import pytest +from aioreactive.testing import VirtualTimeEventLoop -# log = logging.getLogger(__name__) -# logging.basicConfig(level=logging.DEBUG) +log = logging.getLogger(__name__) +logging.basicConfig(level=logging.DEBUG) -# @pytest.yield_fixture() -# def event_loop() -> None: -# loop = VirtualTimeEventLoop() -# yield loop -# loop.close() +@pytest.yield_fixture() # type: ignore +def event_loop(): + loop = VirtualTimeEventLoop() + yield loop + loop.close() -# @pytest.mark.asyncio -# async def test_async_iteration() -> None: -# xs = AsyncObservable.from_iterable([1, 2, 3]) -# result = [] +@pytest.mark.asyncio +async def test_async_iteration() -> None: + xs = rx.from_iterable([1, 2, 3]) + result = [] -# async for x in to_async_iterable(xs): -# result.append(x) + async for x in rx.to_async_iterable(xs): + result.append(x) -# assert result == [1, 2, 3] + assert result == [1, 2, 3] -# @pytest.mark.asyncio -# async def test_async_comprehension() -> None: -# xs = AsyncObservable.from_iterable([1, 2, 3]) +@pytest.mark.asyncio +async def test_async_comprehension() -> None: + xs = rx.from_iterable([1, 2, 3]) -# result = [x async for x in to_async_iterable(xs)] + result = [x async for x in rx.to_async_iterable(xs)] -# assert result == [1, 2, 3] + assert result == [1, 2, 3] -# @pytest.mark.asyncio -# async def test_async_iteration_aync_with() -> None: -# xs = AsyncObservable.from_iterable([1, 2, 3]) -# result = [] +@pytest.mark.asyncio +async def test_async_iteration_aync_with() -> None: + xs = rx.from_iterable([1, 2, 3]) + result = [] -# obv = AsyncIteratorObserver() -# async with subscribe(xs, obv): -# async for x in obv: -# result.append(x) + obv = rx.AsyncIteratorObserver(xs) + async for x in obv: + result.append(x) -# assert result == [1, 2, 3] + assert result == [1, 2, 3] -# @pytest.mark.asyncio -# async def test_async_iteration_inception() -> None: -# # iterable to async source to async iterator to async source -# obv = AsyncIteratorObserver() +@pytest.mark.asyncio +async def test_async_iteration_inception() -> None: + # iterable to async source to async iterator to async source + xs = rx.from_iterable([1, 2, 3]) + obv = rx.AsyncIteratorObserver(xs) -# xs = AsyncObservable.from_iterable([1, 2, 3]) -# await subscribe(xs, obv) -# ys = from_async_iterable(obv) -# result = [] + ys = rx.from_async_iterable(obv) + result = [] -# async for y in to_async_iterable(ys): -# result.append(y) + async for y in rx.to_async_iterable(ys): + result.append(y) -# assert result == [1, 2, 3] - - -# if __name__ == '__main__': -# loop = asyncio.get_event_loop() -# loop.run_until_complete(test_async_iteration_inception()) -# loop.close() + assert result == [1, 2, 3]