Handle non-async functions too #15

Closed
simonw opened this issue Jan 20, 2023 · 8 comments
Labels
enhancement New feature or request research


@simonw
Owner

simonw commented Jan 20, 2023

It may be useful if this could also handle regular def ... functions, in addition to async def functions.

I'm imagining using this for Datasette extras, where some extras might be able to operate directly on data that has already been fetched by other functions, for example an extra that transforms objects in some way.

Refs:

simonw added the enhancement (New feature or request) and research labels Jan 20, 2023
@simonw
Owner Author

simonw commented Jan 20, 2023

Relevant code:

async def _execute_sequential(self, results, ts):
    for name in ts.static_order():
        if name not in self._registry:
            continue
        results[name] = await self._get_awaitable(name, results)

async def _execute_parallel(self, results, ts):
    ts.prepare()
    tasks = []

    def schedule():
        for name in ts.get_ready():
            if name not in self._registry:
                ts.done(name)
                continue
            tasks.append(asyncio.create_task(worker(name)))

    async def worker(name):
        res = await self._get_awaitable(name, results)
        results[name] = res
        ts.done(name)
        schedule()

    schedule()
    while tasks:
        await asyncio.gather(*[tasks.pop() for _ in range(len(tasks))])

I need to figure out how to have the asyncio.gather() step split out the non-async functions and call them separately, while still running the async functions in parallel.
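To see why the code above can't simply be handed a regular function, here is a minimal sketch with hypothetical one() and worker() stand-ins (not the real Registry internals). worker() mirrors the original shape: call the registered function, then await whatever it returned. A regular function returns a plain value, and awaiting a plain value raises TypeError.

```python
import asyncio


def one():
    # A regular (non-async) function, like the extras described above
    return 1


async def worker(fn):
    # Hypothetical stand-in for the original worker(): call the registered
    # function, then await whatever the call returned
    return await fn()


try:
    asyncio.run(worker(one))
    caught = False
except TypeError as ex:
    # one() returned the int 1, which is not awaitable
    caught = True
    print(ex)
```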

@simonw
Owner Author

simonw commented Jan 20, 2023

I probably need the await_me_maybe pattern for this: https://simonwillison.net/2020/Sep/2/await-me-maybe/

import asyncio

async def await_me_maybe(value):
    if callable(value):
        value = value()
    if asyncio.iscoroutine(value):
        value = await value
    return value
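A quick sketch of how await_me_maybe() treats the kinds of values it might receive; sync_fn() and async_fn() are hypothetical examples. Note that it calls anything callable, so a value that happens to be callable gets invoked rather than returned as-is.

```python
import asyncio


async def await_me_maybe(value):
    # Same helper as above: call it if callable, await it if it's a coroutine
    if callable(value):
        value = value()
    if asyncio.iscoroutine(value):
        value = await value
    return value


def sync_fn():
    return 1


async def async_fn():
    return 2


async def main():
    return [
        await await_me_maybe(sync_fn),     # regular function: called
        await await_me_maybe(async_fn),    # async def function: called, then awaited
        await await_me_maybe(async_fn()),  # coroutine object: awaited
        await await_me_maybe(3),           # plain value: returned unchanged
    ]


results = asyncio.run(main())
print(results)  # [1, 2, 2, 3]
```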

@simonw
Owner Author

simonw commented Jan 20, 2023

This seems to work; it passes the tests:

diff --git a/asyncinject/__init__.py b/asyncinject/__init__.py
index a553660..3a92f41 100644
--- a/asyncinject/__init__.py
+++ b/asyncinject/__init__.py
@@ -27,7 +27,7 @@ class Registry:
     def _make_time_logger(self, awaitable):
         async def inner():
             start = time.perf_counter()
-            result = await awaitable
+            result = await await_me_maybe(awaitable)
             end = time.perf_counter()
             self.timer(awaitable.__name__, start, end)
             return result
@@ -90,8 +90,9 @@ class Registry:
             **{k: v for k, v in results.items() if k in self.graph[name]},
         )
         if self.timer:
-            aw = self._make_time_logger(aw)
-        return aw
+            return self._make_time_logger(aw)
+        else:
+            return await_me_maybe(aw)
 
     async def _execute_sequential(self, results, ts):
         for name in ts.static_order():
@@ -132,3 +133,11 @@ class Registry:
             await self._execute_sequential(results, ts)
 
         return results
+
+
+async def await_me_maybe(value):
+    if callable(value):
+        value = value()
+    if asyncio.iscoroutine(value):
+        value = await value
+    return value

And in the Python shell (with top-level await thanks to python -m asyncio):

>>> import asyncio
>>> from asyncinject import Registry
>>> one = lambda: 1
>>> two = lambda: 2
>>> three = lambda one, two: one + two
>>> 
>>> three
<function <lambda> at 0x107d19870>
>>> three.__name__
'<lambda>'
>>> r = Registry()
>>> r.register(one, name='one')
>>> r.register(two, name='two')
>>> await r.resolve(three)
3

@simonw
Owner Author

simonw commented Jan 20, 2023

Need to write tests that show how that operates when mixed with async functions, in particular the planning and timing stuff.

simonw changed the title from "Handle non-async functions too?" to "Handle non-async functions too" Jan 20, 2023
@simonw
Owner Author

simonw commented Jan 20, 2023

Ran into a problem where this interacts with the timing/logging mechanism. I changed the diff to this:

diff --git a/asyncinject/__init__.py b/asyncinject/__init__.py
index a553660..1e3bd1f 100644
--- a/asyncinject/__init__.py
+++ b/asyncinject/__init__.py
@@ -90,8 +90,9 @@ class Registry:
             **{k: v for k, v in results.items() if k in self.graph[name]},
         )
         if self.timer:
-            aw = self._make_time_logger(aw)
-        return aw
+            return self._make_time_logger(await_me_maybe(aw))
+        else:
+            return await_me_maybe(aw)
 
     async def _execute_sequential(self, results, ts):
         for name in ts.static_order():
@@ -132,3 +133,11 @@ class Registry:
             await self._execute_sequential(results, ts)
 
         return results
+
+
+async def await_me_maybe(value):
+    if callable(value):
+        value = value()
+    if asyncio.iscoroutine(value):
+        value = await value
+    return value
diff --git a/tests/test_asyncinject.py b/tests/test_asyncinject.py
index fdd3318..ea11603 100644
--- a/tests/test_asyncinject.py
+++ b/tests/test_asyncinject.py
@@ -187,7 +187,8 @@ async def test_resolve_unregistered_function(use_async):
 async def test_register():
     registry = Registry()
 
-    async def one():
+    # Mix in a non-async function too:
+    def one():
         return "one"
 
     async def two_():
@@ -207,3 +208,26 @@ async def test_register():
     result = await registry.resolve(three)
 
     assert result == "onetwo"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("parallel", (True, False))
+async def test_just_sync_functions(parallel):
+    def one():
+        return 1
+
+    def two():
+        return 2
+
+    def three(one, two):
+        return one + two
+
+    timed = []
+
+    registry = Registry(
+        one, two, three, parallel=parallel, timer=lambda *args: timed.append(args)
+    )
+    result = await registry.resolve(three)
+    assert result == 3
+
+    assert False

And got this test failure:

    @pytest.mark.asyncio
    async def test_timer(complex_registry):
        collected = []
        complex_registry.timer = lambda name, start, end: collected.append(
            (name, start, end)
        )
        await complex_registry.resolve("go")
        assert len(collected) == 6
        names = [c[0] for c in collected]
        starts = [c[1] for c in collected]
        ends = [c[2] for c in collected]
        assert all(isinstance(n, float) for n in starts)
        assert all(isinstance(n, float) for n in ends)
>       assert names[0] == "log"
E       AssertionError: assert 'await_me_maybe' == 'log'
E         - log
E         + await_me_maybe

@simonw
Owner Author

simonw commented Jan 20, 2023

Here's the problem:

def _make_time_logger(self, awaitable):
    async def inner():
        start = time.perf_counter()
        result = await awaitable
        end = time.perf_counter()
        self.timer(awaitable.__name__, start, end)
        return result

    return inner()

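The logging mechanism introspects the function's name via awaitable.__name__. With the change above, the awaitable passed to _make_time_logger() is the coroutine created by await_me_maybe(), so the timer records "await_me_maybe" instead of the original function's name.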
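A minimal sketch of that failure, with a hypothetical log() coroutine standing in for the registered "log" function: a coroutine object's __name__ comes from the coroutine function that created it, so once the call is wrapped in await_me_maybe(), __name__ reports the wrapper.

```python
import asyncio


async def await_me_maybe(value):
    if callable(value):
        value = value()
    if asyncio.iscoroutine(value):
        value = await value
    return value


async def log():
    # Hypothetical stand-in for the "log" function in the test registry
    return "log"


async def main():
    direct = log()
    wrapped = await_me_maybe(log())
    names = (direct.__name__, wrapped.__name__)
    # Await both so neither coroutine triggers a "never awaited" warning
    await direct
    await wrapped
    return names


names = asyncio.run(main())
print(names)  # ('log', 'await_me_maybe')
```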

@simonw
Owner Author

simonw commented Jan 20, 2023

Actually I think the core problem is here:

def _get_awaitable(self, name, results):
    aw = self._registry[name](
        **{k: v for k, v in results.items() if k in self.graph[name]},
    )
    if self.timer:
        aw = self._make_time_logger(aw)
    return aw

If self._registry[name] is a regular function (not an async def function), it is executed immediately on that line (line 89 of asyncinject/__init__.py), so aw holds its return value rather than an awaitable.
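A small sketch of the difference, using hypothetical sync_extra() and async_extra() functions: calling a regular function runs its body right away and hands back the result, while calling an async def function only creates a coroutine whose body runs once it is awaited.

```python
import asyncio

calls = []


def sync_extra():
    calls.append("sync ran")
    return 1


async def async_extra():
    calls.append("async ran")
    return 2


sync_result = sync_extra()  # body runs right here
coro = async_extra()        # body has not run yet; this is just a coroutine object
print(calls)  # ['sync ran']
print(asyncio.iscoroutine(sync_result), asyncio.iscoroutine(coro))  # False True

async_result = asyncio.run(coro)  # the async body only runs when awaited
print(calls)  # ['sync ran', 'async ran']
```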

@simonw
Owner Author

simonw commented Jan 20, 2023

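OK, a better solution: the _get_awaitable() method now spots when fn is a regular (non-async) function and wraps it in an async def function, copying over its __name__, so it can be awaited just like any other registered function: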

diff --git a/asyncinject/__init__.py b/asyncinject/__init__.py
index a553660..4a4cac9 100644
--- a/asyncinject/__init__.py
+++ b/asyncinject/__init__.py
@@ -86,9 +86,20 @@ class Registry:
         return ts
 
     def _get_awaitable(self, name, results):
-        aw = self._registry[name](
-            **{k: v for k, v in results.items() if k in self.graph[name]},
-        )
+        fn = self._registry[name]
+        kwargs = {k: v for k, v in results.items() if k in self.graph[name]}
+
+        awaitable_fn = fn
+
+        if not asyncio.iscoroutinefunction(fn):
+
+            async def _awaitable(*args, **kwargs):
+                return fn(*args, **kwargs)
+
+            _awaitable.__name__ = fn.__name__
+            awaitable_fn = _awaitable
+
+        aw = awaitable_fn(**kwargs)
         if self.timer:
             aw = self._make_time_logger(aw)
         return aw
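The same idea as a standalone sketch, not the library's actual code. make_awaitable() and timed_call() are hypothetical helpers, and timed_call() loosely imitates _make_time_logger(). The sketch swaps in functools.wraps() (which copies __name__ and related attributes) for the manual __name__ assignment, and inspect.iscoroutinefunction(), which gives the same answer as asyncio.iscoroutinefunction() for plain async def functions. Because the wrapped sync body only runs when the coroutine is awaited, the timer measures it too. A sync function still blocks the event loop while it runs, so in parallel mode it won't overlap with other work.

```python
import asyncio
import functools
import inspect
import time


def make_awaitable(fn):
    """Return async def functions unchanged; wrap regular functions in one."""
    if inspect.iscoroutinefunction(fn):
        return fn

    @functools.wraps(fn)  # copies __name__ (and __qualname__, __doc__, ...)
    async def _awaitable(*args, **kwargs):
        # The sync body runs here, when the coroutine is awaited
        return fn(*args, **kwargs)

    return _awaitable


timed = []


async def timed_call(fn, **kwargs):
    # Hypothetical stand-in for Registry._make_time_logger()
    awaitable = make_awaitable(fn)(**kwargs)
    start = time.perf_counter()
    result = await awaitable
    end = time.perf_counter()
    # Coroutine objects take __name__ from the function that created them
    timed.append((awaitable.__name__, start, end))
    return result


def one():
    return 1


async def two():
    return 2


def three(one, two):
    return one + two


async def main():
    a = await timed_call(one)
    b = await timed_call(two)
    return await timed_call(three, one=a, two=b)


result = asyncio.run(main())
names = [name for name, _, _ in timed]
print(result, names)  # 3 ['one', 'two', 'three']
```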

simonw closed this as completed in c6b8245 Apr 14, 2023
simonw added a commit that referenced this issue Apr 14, 2023