From a0101bcdf4f943b4256998f220cd4c2581f1359d Mon Sep 17 00:00:00 2001 From: Jeremy Howard Date: Mon, 15 Jul 2024 10:11:55 +1000 Subject: [PATCH] fix async --- fastcore/xtras.py | 18 +++++++++++++----- nbs/03_xtras.ipynb | 41 +++++++++++++++++++++++++++++++++-------- 2 files changed, 46 insertions(+), 13 deletions(-) diff --git a/fastcore/xtras.py b/fastcore/xtras.py index b48a1177..fcbd0149 100644 --- a/fastcore/xtras.py +++ b/fastcore/xtras.py @@ -689,18 +689,17 @@ def mk_dataclass(cls): # %% ../nbs/03_xtras.ipynb 179 def flexicache(*funcs, maxsize=128): "Like `lru_cache`, but customisable with policy `funcs`" + import asyncio def _f(func): cache,states = {}, [None]*len(funcs) - @wraps(func) - def wrapper(*args, **kwargs): - key = f"{args} // {kwargs}" + def _cache_logic(key, execute_func): if key in cache: result,states = cache[key] if not any(f(state) for f,state in zip(funcs, states)): cache[key] = cache.pop(key) return result del cache[key] - try: newres = func(*args, **kwargs) + try: newres = execute_func() except: if key not in cache: raise cache[key] = cache.pop(key) @@ -708,7 +707,16 @@ def wrapper(*args, **kwargs): cache[key] = (newres, [f(None) for f in funcs]) if len(cache) > maxsize: cache.popitem() return newres - return wrapper + + @wraps(func) + def wrapper(*args, **kwargs): + return _cache_logic(f"{args} // {kwargs}", lambda: func(*args, **kwargs)) + + @wraps(func) + async def async_wrapper(*args, **kwargs): + return await _cache_logic(f"{args} // {kwargs}", lambda: asyncio.ensure_future(func(*args, **kwargs))) + + return async_wrapper if asyncio.iscoroutinefunction(func) else wrapper return _f # %% ../nbs/03_xtras.ipynb 181 diff --git a/nbs/03_xtras.ipynb b/nbs/03_xtras.ipynb index 1f1eb76d..3ce54bda 100644 --- a/nbs/03_xtras.ipynb +++ b/nbs/03_xtras.ipynb @@ -2849,18 +2849,17 @@ "#| export\n", "def flexicache(*funcs, maxsize=128):\n", " \"Like `lru_cache`, but customisable with policy `funcs`\"\n", + " import asyncio\n", " def _f(func):\n", " cache,states = {}, [None]*len(funcs)\n", - " @wraps(func)\n", - " def wrapper(*args, **kwargs):\n", - " key = f\"{args} // {kwargs}\"\n", + " def _cache_logic(key, execute_func):\n", " if key in cache:\n", " result,states = cache[key]\n", " if not any(f(state) for f,state in zip(funcs, states)):\n", " cache[key] = cache.pop(key)\n", " return result\n", " del cache[key]\n", - " try: newres = func(*args, **kwargs)\n", + " try: newres = execute_func()\n", " except:\n", " if key not in cache: raise\n", " cache[key] = cache.pop(key)\n", @@ -2868,7 +2867,16 @@ " cache[key] = (newres, [f(None) for f in funcs])\n", " if len(cache) > maxsize: cache.popitem()\n", " return newres\n", - " return wrapper\n", + "\n", + " @wraps(func)\n", + " def wrapper(*args, **kwargs):\n", + " return _cache_logic(f\"{args} // {kwargs}\", lambda: func(*args, **kwargs))\n", + "\n", + " @wraps(func)\n", + " async def async_wrapper(*args, **kwargs):\n", + " return await _cache_logic(f\"{args} // {kwargs}\", lambda: asyncio.ensure_future(func(*args, **kwargs)))\n", + "\n", + " return async_wrapper if asyncio.iscoroutinefunction(func) else wrapper\n", " return _f" ] }, @@ -2913,10 +2921,23 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "3" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "@flexicache(time_policy(10), mtime_policy('000_tour.ipynb'))\n", - "def cached_func(x, y): return x+y" + "def cached_func(x, y): return x+y\n", + "\n", + "cached_func(1,2)" ] }, { @@ -2936,7 +2957,11 @@ } ], "source": [ - "cached_func(1,2)" + "@flexicache(time_policy(10), mtime_policy('000_tour.ipynb'))\n", + "async def cached_func(x, y): return x+y\n", + "\n", + "await cached_func(1,2)\n", + "await cached_func(1,2)" ] }, {