Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow streaming nodes to timeout #125

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions tmp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import asyncio
import time
import tributary.streaming as ts
import requests

def create(interval):
async def foo_():
for _ in range(5):
yield interval
await asyncio.sleep(interval)
return foo_


fast = ts.Foo(create(1))
med = ts.Foo(create(2))
slow = ts.Foo(create(3))

def reducer(fast, med, slow):
return {"fast": fast, "med": med, "slow": slow}

node = ts.Reduce(fast, med, slow, reducer=reducer).print()
ts.run(node, period=1)
3 changes: 2 additions & 1 deletion tributary/streaming/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from .utils import *


def run(node, blocking=True, **kwargs):
def run(node, blocking=True, period=None, **kwargs):
graph = node.constructGraph()
kwargs["blocking"] = blocking
kwargs["period"] = period
return graph.run(**kwargs)
61 changes: 44 additions & 17 deletions tributary/streaming/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,23 +49,52 @@ def rebuild(self):
def stop(self):
self._stop = True

async def _run(self):
async def _run(self, period=None):
"""this is the main graph runner. it is pretty straightforward, we go through
all the layers of the graph and execute the layer as a batch of coroutines.

If we generate a stop event (e.g. graph is done), we stop.

If a period is set, a layer in the graph will run for at max `period` seconds
before pushing a None.

Args:
period (Optional[int]): max period to wait
"""
value, last, self._stop = None, None, False

# run onstarts
await asyncio.gather(*(asyncio.create_task(s()) for s in self._onstarts))

while True:
for level in self._nodes:
if self._stop:
break
if period is not None:
sets = {}
for i, level in enumerate(self._nodes):
sets[i] = set()
for n in level:
sets[i].add(asyncio.create_task(n()))

await asyncio.gather(*(asyncio.create_task(n()) for n in level))
# TODO
# wrap each individual node in a task
# add tasks to set
# execute all and remove from set on callback
# how loop checking if tasks are done up until `period`
# force push None for remaining (`_output(None)`)
# next level
# on next loop around only re-wrap and re-call those that aren't still in the set
raise NotImplementedError()
else:
for level in self._nodes:
if self._stop:
break

await asyncio.gather(*(asyncio.create_task(n()) for n in level))
await asyncio.gather(*(asyncio.create_task(n()) for n in level))

self.rebuild()
self.rebuild()

if self._stop:
break
if self._stop:
break

value, last = self._starting_node.value(), value

Expand All @@ -78,7 +107,8 @@ async def _run(self):
# return last val
return last

def run(self, blocking=True, newloop=False, start=True):
def run(self, blocking=True, newloop=False, period=None):

if sys.platform == "win32":
# Set to proactor event loop on window
# (default in python 3.8+)
Expand All @@ -94,7 +124,7 @@ def run(self, blocking=True, newloop=False, start=True):

asyncio.set_event_loop(loop)

task = loop.create_task(self._run())
task = loop.create_task(self._run(period=period))

if blocking:
# block until done
Expand All @@ -103,13 +133,10 @@ def run(self, blocking=True, newloop=False, start=True):
except KeyboardInterrupt:
return

if start:
t = Thread(target=loop.run_until_complete, args=(task,))
t.daemon = True
t.start()
return loop

return loop, task
t = Thread(target=loop.run_until_complete, args=(task,))
t.daemon = True
t.start()
return loop

def graph(self):
return self._starting_node.graph()
Expand Down
14 changes: 12 additions & 2 deletions tributary/streaming/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,9 @@ async def __call__(self):
if isinstance(val, StreamEnd):
return await self._finish()

if isinstance(val, (StreamNone,)):
ready = False

# set as active
self._active[i] = val
else:
Expand All @@ -234,6 +237,7 @@ async def __call__(self):
# Private interface
# ***********************
def __hash__(self):
"""nodes are unique"""
return self._id

def __rshift__(self, other):
Expand Down Expand Up @@ -273,10 +277,12 @@ async def _execute(self):
# else call it
elif isinstance(self._foo, types.FunctionType):
try:
# could be a generator
# could be a kicked generator, so wrap in try
try:
# execute wrapped function
_last = self._foo(*self._active, **self._foo_kwargs)
except ZeroDivisionError:
# catch divide by zero and force inf
_last = float("inf")

except ValueError:
Expand All @@ -285,6 +291,7 @@ async def _execute(self):
continue

else:
# can only wrap function types
raise TributaryException("Cannot use type:{}".format(type(self._foo)))

# calculation was valid
Expand All @@ -294,7 +301,7 @@ async def _execute(self):
self._execution_count += 1

if isinstance(_last, types.AsyncGeneratorType):

# Swap to async generator unroller
async def _foo(g=_last):
return await _agen_to_foo(g)

Expand All @@ -308,6 +315,7 @@ async def _foo(g=_last):
_last = self._foo()

elif asyncio.iscoroutine(_last):
# await coroutine
_last = await _last

if self._repeat:
Expand All @@ -319,8 +327,10 @@ async def _foo(g=_last):
else:
self._last = _last

# push result downstream
await self._output(self._last)

# allow new inputs
for i in range(len(self._active)):
self._active[i] = StreamNone()

Expand Down