Skip to content

Commit

Permalink
Fix tests for async iteration
Browse files Browse the repository at this point in the history
  • Loading branch information
dbrattli committed Nov 27, 2020
1 parent 0055e5a commit 7414f1f
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 82 deletions.
24 changes: 14 additions & 10 deletions aioreactive/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@
class AsyncRx(AsyncObservable[TSource]):
"""An AsyncObservable class similar to classic Rx.
This class supports has all operators as methods and supports
This class provides all operators as methods and supports
method chaining.
Subscribe is also a method.
Example:
>>> AsyncRx.from_iterable([1,2,3]).map(lambda x: x + 2).filter(lambda x: x < 3)
All methods are lazy imported.
"""
Expand Down Expand Up @@ -112,6 +113,10 @@ def empty(cls) -> "AsyncRx[TSource]":
def from_iterable(cls, iter: Iterable[TSource]) -> "AsyncRx[TSource]":
return AsyncRx(from_iterable(iter))

@classmethod
def from_async_iterable(cls, iter: AsyncIterable[TSource]) -> "AsyncObservable[TSource]":
return AsyncRx(from_async_iterable(iter))

@classmethod
def single(cls, value: TSource) -> "AsyncRx[TSource]":
from .create import single
Expand Down Expand Up @@ -252,11 +257,12 @@ def flat_map_latest_async(
most-recently transformed observable sequence.
Args:
mapper (Callable[[TSource]): [description]
Awaitable ([type]): [description]
mapper: Function to transform each item into a new async
observable.
Returns:
Stream[TSource, TResult]: [description]
An async observable that only merges values from the latest
async observable produced by the mapper.
"""
return AsyncRx(pipe(self, flat_map_latest_async(mapper)))

Expand Down Expand Up @@ -611,12 +617,10 @@ def flat_map_latest_async(mapper: Callable[[TSource], Awaitable[AsyncObservable[
return flat_map_latest_async(mapper)


def from_async_iterable(iter: Iterable[TSource]) -> "AsyncObservable[TSource]":
from aioreactive.operators.from_async_iterable import from_async_iterable

from .create import from_async_iterable
def from_async_iterable(iter: AsyncIterable[TSource]) -> "AsyncObservable[TSource]":
from .create import of_async_iterable

return AsyncRx(from_async_iterable(iter))
return AsyncRx(of_async_iterable(iter))


def interval(seconds: float, period: int) -> AsyncObservable[int]:
Expand Down
6 changes: 3 additions & 3 deletions aioreactive/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any, Callable, Generic, Iterable, Tuple, TypeVar, cast

from expression.collections import FrozenList, Map, frozenlist, map
from expression.core import MailboxProcessor, Nothing, Option, Result, Some, TailCall, match, pipe, recursive_async
from expression.core import MailboxProcessor, Nothing, Option, Result, Some, TailCall, match, pipe, tailrec_async
from expression.system import AsyncDisposable

from .create import of_seq
Expand Down Expand Up @@ -168,7 +168,7 @@ async def subscribe_async(aobv: AsyncObserver[Tuple[TSource, TOther]]) -> AsyncD
safe_obv, auto_detach = auto_detach_observer(aobv)

async def worker(inbox: MailboxProcessor[Msg]) -> None:
@recursive_async
@tailrec_async
async def message_loop(
source_value: Option[TSource], other_value: Option[TOther]
) -> Result[AsyncObservable[TSource], Exception]:
Expand Down Expand Up @@ -245,7 +245,7 @@ async def subscribe_async(aobv: AsyncObserver[Tuple[TSource, TOther]]) -> AsyncD
safe_obv, auto_detach = auto_detach_observer(aobv)

async def worker(inbox: MailboxProcessor[Msg]) -> None:
@recursive_async
@tailrec_async
async def message_loop(latest: Option[TOther]) -> Result[TSource, Exception]:
cn = await inbox.receive()

Expand Down
42 changes: 39 additions & 3 deletions aioreactive/create.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import asyncio
import logging
from typing import Awaitable, Callable, Iterable, Tuple, TypeVar
from asyncio import Future
from typing import AsyncIterable, Awaitable, Callable, Iterable, Optional, Tuple, TypeVar

from expression.core import Ok, Result, aiotools, recursive_async
from expression.core import Ok, Result, aiotools, tailrec_async
from expression.core.fn import TailCall
from expression.system import AsyncDisposable, CancellationToken, CancellationTokenSource

Expand Down Expand Up @@ -68,6 +69,41 @@ async def worker(obv: AsyncObserver[TSource], _: CancellationToken) -> None:
return of_async_worker(worker)


def of_async_iterable(iterable: AsyncIterable[TSource]) -> AsyncObservable[TSource]:
"""Convert an async iterable to a source stream.
2 - xs = from_async_iterable(async_iterable)
Returns the source stream whose elements are pulled from the
given (async) iterable sequence."""

async def subscribe_async(observer: AsyncObserver[TSource]) -> AsyncDisposable:
task: Optional[Future[TSource]] = None

async def cancel():
if task is not None:
task.cancel()

sub = AsyncDisposable.create(cancel)

async def worker() -> None:
async for value in iterable:
try:
await observer.asend(value)
except Exception as ex:
await observer.athrow(ex)
return

await observer.aclose()

try:
task = asyncio.ensure_future(worker())
except Exception as ex:
log.debug("FromIterable:worker(), Exception: %s" % ex)
await observer.athrow(ex)
return sub

return AsyncAnonymousObservable(subscribe_async)


def single(value: TSource) -> AsyncObservable[TSource]:
"""Returns an observable sequence containing the single specified element."""

Expand Down Expand Up @@ -155,7 +191,7 @@ def interval(seconds: float, period: float) -> AsyncObservable[int]:
async def subscribe_async(aobv: AsyncObserver[int]) -> AsyncDisposable:
cancel, token = canceller()

@recursive_async
@tailrec_async
async def handler(seconds: float, next: int) -> Result[None, Exception]:
await asyncio.sleep(seconds)
await aobv.asend(next)
Expand Down
23 changes: 15 additions & 8 deletions aioreactive/filtering.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
from typing import (Any, Awaitable, Callable, List, Optional, Tuple, TypeVar,
overload)
from typing import Any, Awaitable, Callable, List, Optional, Tuple, TypeVar, overload

from expression.collections import seq
from expression.core import (MailboxProcessor, Option, Result, TailCall,
aiotools, compose, fst, match, pipe,
recursive_async)
from expression.core import (
MailboxProcessor,
Option,
Result,
TailCall,
aiotools,
compose,
fst,
match,
pipe,
tailrec_async,
)
from expression.system.disposable import AsyncDisposable

from .combine import zip_seq
from .notification import Notification, OnCompleted, OnError, OnNext
from .observables import AsyncAnonymousObservable
from .observers import (AsyncAnonymousObserver, AsyncNotificationObserver,
auto_detach_observer)
from .observers import AsyncAnonymousObserver, AsyncNotificationObserver, auto_detach_observer
from .transform import map, transform
from .types import AsyncObservable, AsyncObserver, Stream

Expand Down Expand Up @@ -114,7 +121,7 @@ async def subscribe_async(aobv: AsyncObserver[TSource]) -> AsyncDisposable:
safe_obv, auto_detach = auto_detach_observer(aobv)

async def worker(inbox: MailboxProcessor[Notification[TSource]]) -> None:
@recursive_async
@tailrec_async
async def message_loop(latest: Notification[TSource]) -> Result[Notification[TSource], Exception]:
n = await inbox.receive()

Expand Down
6 changes: 3 additions & 3 deletions aioreactive/timeshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Iterable, Tuple, TypeVar

from expression.collections import seq
from expression.core import MailboxProcessor, Result, TailCall, aiotools, match, pipe, recursive_async, snd
from expression.core import MailboxProcessor, Result, TailCall, aiotools, match, pipe, tailrec_async, snd
from expression.system import CancellationTokenSource

from .combine import with_latest_from
Expand Down Expand Up @@ -37,7 +37,7 @@ def _delay(source: AsyncObservable[TSource]) -> AsyncObservable[TSource]:

async def subscribe_async(aobv: AsyncObserver[TSource]) -> AsyncDisposable:
async def worker(inbox: MailboxProcessor[Tuple[Notification[TSource], datetime]]) -> None:
@recursive_async
@tailrec_async
async def loop() -> Result[None, Exception]:
ns, due_time = await inbox.receive()

Expand Down Expand Up @@ -102,7 +102,7 @@ async def subscribe_async(aobv: AsyncObserver[TSource]) -> AsyncDisposable:
infinite: Iterable[int] = seq.infinite()

async def worker(inbox: MailboxProcessor[Tuple[Notification[TSource], int]]) -> None:
@recursive_async
@tailrec_async
async def message_loop(current_index: int) -> Result[TSource, Exception]:
n, index = await inbox.receive()

Expand Down
4 changes: 2 additions & 2 deletions aioreactive/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
compose,
match,
pipe,
recursive_async,
tailrec_async,
)
from expression.system import AsyncDisposable

Expand Down Expand Up @@ -256,7 +256,7 @@ async def aclose() -> None:
return AsyncAnonymousObserver(asend, athrow, aclose)

async def worker(inbox: MailboxProcessor[Msg]) -> None:
@recursive_async
@tailrec_async
async def message_loop(
current: Option[AsyncDisposable], is_stopped: bool, current_id: int
) -> Result[None, Exception]:
Expand Down
95 changes: 42 additions & 53 deletions test/test_async_iteration.py
Original file line number Diff line number Diff line change
@@ -1,74 +1,63 @@
# import pytest
# import asyncio
# import logging
import asyncio
import logging

# from aioreactive.testing import VirtualTimeEventLoop
# from aioreactive.core import AsyncObservable, run, subscribe, AsyncStream, AsyncAnonymousObserver, AsyncIteratorObserver
# from aioreactive.operators.pipe import pipe
# from aioreactive.operators import op, from_async_iterable
# from aioreactive.operators.to_async_iterable import to_async_iterable
import aioreactive as rx
import pytest
from aioreactive.testing import VirtualTimeEventLoop

# log = logging.getLogger(__name__)
# logging.basicConfig(level=logging.DEBUG)
log = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)


# @pytest.yield_fixture()
# def event_loop() -> None:
# loop = VirtualTimeEventLoop()
# yield loop
# loop.close()
@pytest.yield_fixture() # type: ignore
def event_loop():
loop = VirtualTimeEventLoop()
yield loop
loop.close()


# @pytest.mark.asyncio
# async def test_async_iteration() -> None:
# xs = AsyncObservable.from_iterable([1, 2, 3])
# result = []
@pytest.mark.asyncio
async def test_async_iteration() -> None:
xs = rx.from_iterable([1, 2, 3])
result = []

# async for x in to_async_iterable(xs):
# result.append(x)
async for x in rx.to_async_iterable(xs):
result.append(x)

# assert result == [1, 2, 3]
assert result == [1, 2, 3]


# @pytest.mark.asyncio
# async def test_async_comprehension() -> None:
# xs = AsyncObservable.from_iterable([1, 2, 3])
@pytest.mark.asyncio
async def test_async_comprehension() -> None:
xs = rx.from_iterable([1, 2, 3])

# result = [x async for x in to_async_iterable(xs)]
result = [x async for x in rx.to_async_iterable(xs)]

# assert result == [1, 2, 3]
assert result == [1, 2, 3]


# @pytest.mark.asyncio
# async def test_async_iteration_aync_with() -> None:
# xs = AsyncObservable.from_iterable([1, 2, 3])
# result = []
@pytest.mark.asyncio
async def test_async_iteration_aync_with() -> None:
xs = rx.from_iterable([1, 2, 3])
result = []

# obv = AsyncIteratorObserver()
# async with subscribe(xs, obv):
# async for x in obv:
# result.append(x)
obv = rx.AsyncIteratorObserver(xs)
async for x in obv:
result.append(x)

# assert result == [1, 2, 3]
assert result == [1, 2, 3]


# @pytest.mark.asyncio
# async def test_async_iteration_inception() -> None:
# # iterable to async source to async iterator to async source
# obv = AsyncIteratorObserver()
@pytest.mark.asyncio
async def test_async_iteration_inception() -> None:
# iterable to async source to async iterator to async source
xs = rx.from_iterable([1, 2, 3])
obv = rx.AsyncIteratorObserver(xs)

# xs = AsyncObservable.from_iterable([1, 2, 3])
# await subscribe(xs, obv)
# ys = from_async_iterable(obv)
# result = []
ys = rx.from_async_iterable(obv)
result = []

# async for y in to_async_iterable(ys):
# result.append(y)
async for y in rx.to_async_iterable(ys):
result.append(y)

# assert result == [1, 2, 3]


# if __name__ == '__main__':
# loop = asyncio.get_event_loop()
# loop.run_until_complete(test_async_iteration_inception())
# loop.close()
assert result == [1, 2, 3]

0 comments on commit 7414f1f

Please sign in to comment.