Skip to content

Commit

Permalink
Merge pull request #260 from Krukov/fix-iterator-on-error
Browse files Browse the repository at this point in the history
Fix iterator on error
  • Loading branch information
Krukov authored Aug 13, 2024
2 parents 418c7b7 + 3e2ec5d commit ad91db0
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 10 deletions.
5 changes: 4 additions & 1 deletion cashews/contrib/_starlette.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from starlette.responses import StreamingResponse

from cashews.backends.interface import Backend
from cashews.serialize import register_type
from cashews.serialize import DecodeError, register_type


async def encode_streaming_response(
Expand All @@ -16,6 +16,8 @@ async def encode_streaming_response(


async def decode_streaming_response(value: bytes, backend: Backend, key: str, **kwargs) -> StreamingResponse:
if not await backend.get(f"{key}:done"):
raise DecodeError()
status_code, headers = value.split(b":")
raw_headers = []
for header in headers.split(b";"):
Expand All @@ -36,6 +38,7 @@ async def set_iterator(backend: Backend, key: str, iterator, expire: int):
await backend.set(f"{key}:chunk:{chunk_number}", chunk, expire=expire)
yield chunk
chunk_number += 1
await backend.set(f"{key}:done", True, expire=expire) # mark as finished


async def get_iterator(backend: Backend, key: str):
Expand Down
28 changes: 24 additions & 4 deletions cashews/decorators/cache/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from cashews.key import get_cache_key, get_cache_key_template
from cashews.ttl import ttl_to_seconds

from ._exception import RaiseException, return_or_raise
from .defaults import context_cache_detect

if TYPE_CHECKING: # pragma: no cover
Expand All @@ -16,6 +17,12 @@
__all__ = ("iterator",)


if "anext" not in globals():

async def anext(ait):
return await ait.__anext__()


def iterator(
backend: _BackendInterface,
ttl: TTL,
Expand Down Expand Up @@ -49,14 +56,27 @@ async def _wrap(*args, **kwargs):
chunk = await backend.get(_cache_key + f":{chunk_number}")
if not chunk:
return
yield chunk
yield return_or_raise(chunk)
chunk_number += 1

_to_cache = condition(None, args, kwargs, key=_cache_key)
start = time.monotonic()
async for chunk in async_iterator(*args, **kwargs):
_to_cache = False
_async_iterator = async_iterator(*args, **kwargs)
while True:
try:
chunk = await anext(_async_iterator)
except StopAsyncIteration:
break
except Exception as exc:
cond_res = condition(exc, args, kwargs, key=_cache_key)
if cond_res and isinstance(cond_res, Exception):
_to_cache = True
await backend.set(_cache_key + f":{chunk_number}", RaiseException(exc), expire=_ttl)
await backend.set(_cache_key, True, expire=_ttl - time.monotonic() + start)
raise exc
yield chunk
if _to_cache:
if condition(chunk, args, kwargs, key=_cache_key):
_to_cache = True
await backend.set(_cache_key + f":{chunk_number}", chunk, expire=_ttl)
chunk_number += 1
if _to_cache:
Expand Down
10 changes: 8 additions & 2 deletions cashews/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,14 @@ async def _custom_decode(self, backend: Backend, key: Key, value: bytes, default
if value_type not in self._type_mapping:
return default
_, decoder = self._type_mapping[value_type]
decode_value = await decoder(value, backend, key)
return decode_value
try:
return await decoder(value, backend, key)
except DecodeError:
return default


class DecodeError(Exception):
pass


register_type = Serializer.register_type
Expand Down
17 changes: 17 additions & 0 deletions tests/test_cache_exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,23 @@ async def func(raise_exception=None):
assert mock.call_count == 3


async def test_cache_exc_not_cached_by_default(cache: Cache):
mock = Mock()

@cache(ttl=EXPIRE)
async def func():
mock()
raise CustomError()

with pytest.raises(CustomError):
await func()

with pytest.raises(CustomError):
await func()

assert mock.call_count == 2


async def test_cache_only_exceptions(cache: Cache):
mock = Mock()

Expand Down
25 changes: 23 additions & 2 deletions tests/test_intergations/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
pytest.param("diskcache", marks=pytest.mark.integration),
],
)
def _cache(request, redis_dsn):
async def _cache(request, redis_dsn):
dsn = "mem://"
if request.param == "diskcache":
dsn = "disk://"
Expand All @@ -27,7 +27,8 @@ def _cache(request, redis_dsn):
dsn = redis_dsn + "&client_side=t"
cache = Cache()
cache.setup(dsn, suppress=False)
return cache
yield cache
await cache.clear()


@pytest.fixture(name="app")
Expand Down Expand Up @@ -114,6 +115,26 @@ async def stream():
assert response.content == b"0123456789"


def test_cache_stream_on_error(client, app, cache):
from starlette.responses import StreamingResponse

def iterator():
for i in range(10):
if i == 5:
raise Exception
yield f"{i}"

@app.get("/stream")
@cache(ttl="10s", key="stream")
async def stream():
return StreamingResponse(iterator(), status_code=201, headers={"X-Test": "TRUE"})

with pytest.raises(Exception):
client.get("/stream")
with pytest.raises(Exception):
client.get("/stream")


def test_cache_delete_middleware(client_with_middleware, app, cache):
from cashews.contrib.fastapi import CacheDeleteMiddleware

Expand Down
68 changes: 67 additions & 1 deletion tests/test_iterators_cache.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import asyncio
from unittest.mock import Mock

from cashews import Cache
import pytest

from cashews import Cache, with_exceptions


async def test_iterator(cache: Cache):
Expand Down Expand Up @@ -29,3 +31,67 @@ async def func():
i += 1

assert call.call_count == 0


class MyException(Exception):
pass


async def test_iterator_error_with_cond(cache: Cache):
call = Mock(side_effect=["a", "b", MyException(), "c", "d"])

@cache.iterator(ttl=10, key="iterator", condition=with_exceptions(MyException))
async def func():
while True:
try:
yield call()
except StopIteration:
return
await asyncio.sleep(0)

full = ""
with pytest.raises(MyException):
async for chunk in func():
assert chunk
full += chunk

assert full == "ab"
assert call.call_count == 3
call.reset_mock()

with pytest.raises(MyException):
async for chunk in func():
assert chunk
full += chunk

assert full == "abab"
assert call.call_count == 0


async def test_iterator_error(cache: Cache):
call = Mock(side_effect=["a", MyException(), "c"])

@cache.iterator(ttl=10, key="iterator")
async def func():
while True:
try:
yield call()
except StopIteration:
return
await asyncio.sleep(0)

with pytest.raises(MyException):
async for chunk in func():
assert chunk == "a"

assert call.call_count == 2 # return a and raise error

async for chunk in func():
assert chunk == "c"

assert call.call_count == 4 # return a, raise error + return c and stopIteration

async for chunk in func():
assert chunk == "c"

assert call.call_count == 4

0 comments on commit ad91db0

Please sign in to comment.